In [None]:
!pip install datasets



In [None]:
import numpy as np
import torch
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import logging
logging.basicConfig(level=logging.ERROR)
from transformers import AdamW
from tqdm import tqdm
from datasets import load_metric
import numpy as np
import pprint
import copy
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Dataset Class

In [None]:
class CustomDataset(Dataset):
    """Per-document dataset of padded textual + visual entity features.

    Each item is one document of ``doc_info``. Its per-entity feature lists are
    truncated or zero-padded to ``padding_len`` entries and returned as float
    tensors together with the entity label ids and a padding mask.
    """

    def __init__(self, doc_info, padding_len=100, label_dict=None, visual_dim=2048, text_dim=768):
        """
        Args:
            doc_info (dict): maps a document key to a dict with the keys
                ``"visual_list"`` (one visual feature vector per entity),
                ``"bert_cls"`` (one textual feature vector per entity) and
                ``"form"`` (list of entity dicts, each with a ``"label"``).
            padding_len (int): number of entities every item is truncated or
                padded to.
            label_dict (dict, optional): maps a label string to its integer id.
                When None, the module-level ``label_dict`` is looked up each
                time an item is read (the original behaviour).
            visual_dim (int): length of the zero vector used to pad visual features.
            text_dim (int): length of the zero vector used to pad textual features.
        """
        self.doc_info = doc_info
        self.padding_len = padding_len
        self.label_dict = label_dict
        self.visual_dim = visual_dim
        self.text_dim = text_dim
        # Snapshot the key order once: rebuilding list(keys()) inside
        # __getitem__ costs O(#documents) for every single item.
        self._keys = list(doc_info.keys())

    def __len__(self):
        """Number of documents."""
        return len(self._keys)

    @staticmethod
    def _fit(seq, length, pad_item):
        """Return a new list of exactly ``length`` items: ``seq`` truncated, or right-padded with ``pad_item``."""
        fitted = list(seq[:length])
        fitted.extend([pad_item] * (length - len(fitted)))
        return fitted

    def __getitem__(self, idx):
        """Return the padded tensors of the ``idx``-th document.

        Returns:
            dict of float tensors, each with ``padding_len`` rows:
                ``visual_feat`` / ``bert_cls``: entity features, zero-padded
                    (rows are ``visual_dim`` / ``text_dim`` wide assuming the
                    stored vectors have those lengths);
                ``target``: label ids, ``-100`` on padding (the default
                    ``ignore_index`` of ``torch.nn.CrossEntropyLoss``);
                ``object_mask``: 1.0 for real entities, 0.0 for padding.
        """
        doc = self.doc_info[self._keys[idx]]
        mapping = self.label_dict if self.label_dict is not None else label_dict
        target_id = [mapping[entity["label"]] for entity in doc["form"]]

        length = self.padding_len
        # Slicing + extending builds fresh outer lists, so the stored document
        # features are never mutated and no deepcopy is needed. Each sequence is
        # fitted independently so all tensors always have `length` rows.
        visual_feat = self._fit(doc["visual_list"], length, [0.0] * self.visual_dim)
        bert_cls = self._fit(doc["bert_cls"], length, [0.0] * self.text_dim)
        target_id = self._fit(target_id, length, -100)
        n_real = min(len(doc["visual_list"]), length)
        object_mask = [1.0] * n_real + [0.0] * (length - n_real)

        return {
            'visual_feat': torch.tensor(visual_feat, dtype=torch.float),
            'bert_cls': torch.tensor(bert_cls, dtype=torch.float),
            'target': torch.tensor(target_id, dtype=torch.float),
            'object_mask': torch.tensor(object_mask, dtype=torch.float),
        }

# Model Class
It is a standard transformer-based architecture implemented with PyTorch's `TransformerEncoderLayer`.
Each token has a size equal to the sum of the textual and visual embedding sizes, i.e., 768 + 2048 = 2816.

In [None]:
class New_model(torch.nn.Module):
    """Transformer encoder that classifies every token of a document.

    A token is the concatenation of a textual and a visual embedding
    (768 + 2048 = 2816 features with the defaults). The encoder contextualises
    the tokens of each document, and a small head maps every token to
    ``num_labels`` logits.

    Args:
        d_model (int): token size; must equal textual dim + visual dim.
        nhead (int): number of attention heads (must divide ``d_model``).
        num_layers (int): number of stacked encoder layers.
        num_labels (int): number of output classes per token.
        dropout (float): dropout rate of the classification head.
    """

    def __init__(self, d_model=2816, nhead=16, num_layers=6, num_labels=4, dropout=0.1):
        super(New_model, self).__init__()

        # batch_first=True: DataLoader batches are (batch, seq, feature). With the
        # default batch_first=False PyTorch treats dim 0 as the sequence, so
        # attention would mix tokens of *different* documents in the batch.
        self.encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.pre_classifier = torch.nn.Linear(d_model, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(d_model, num_labels)

    def forward(self, textual_embed, visual_embed, attention_mask):
        """Return per-token logits.

        Args:
            textual_embed: (batch, seq, text_dim) float tensor.
            visual_embed: (batch, seq, visual_dim) float tensor.
            attention_mask: (batch, seq) tensor, 1 for real tokens and 0 for padding.

        Returns:
            (batch, seq, num_labels) tensor of logits.
        """
        embed = torch.cat((textual_embed, visual_embed), dim=2)
        # src_key_padding_mask must be True at positions to IGNORE. A float 1/0
        # mask is instead *added* to the attention scores and masks nothing, so
        # convert to a boolean "is padding" mask (no transpose with batch_first).
        padding_mask = ~attention_mask.bool()

        output = self.encoder(embed, src_key_padding_mask=padding_mask)
        output = torch.tanh(self.pre_classifier(output))
        output = self.dropout(output)
        return self.classifier(output)

Definition of the metric-computation function

In [None]:
def compute_metrics(p):
    """Score token-level predictions with the global seqeval ``metric``.

    Args:
        p: ``(predictions, labels)`` pair. ``predictions`` holds per-token class
            scores and is arg-maxed over axis 2. ``labels`` holds class ids, where
            ``-100`` marks padding positions; those positions are dropped from
            both sequences before scoring.

    Returns:
        dict: the metric result. When ``return_entity_level_metrics`` is truthy,
        nested per-entity dicts are flattened into ``"<key>_<stat>"`` entries.

    NOTE(review): seqeval parses B-/I- tag prefixes; if the label names in
    ``label_list`` carry no such prefix, confirm the entity grouping is intended.
    """
    logits, gold = p
    pred_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_labels = []
    for row_pred, row_gold in zip(pred_ids, gold):
        kept = [(pi, gi) for pi, gi in zip(row_pred, row_gold) if gi != -100]
        true_predictions.append([label_list[int(pi)] for pi, _ in kept])
        true_labels.append([label_list[int(gi)] for _, gi in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if not return_entity_level_metrics:
        return results

    flat = {}
    for key, value in results.items():
        if isinstance(value, dict):
            flat.update({f"{key}_{stat}": v for stat, v in value.items()})
        else:
            flat[key] = value
    return flat


# Main

Load the precomputed textual and visual features

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: pickle.load can execute arbitrary code; only load files you trust.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


train_data_visual = _load_pickle("data/FUNSD/dataset/training_data/all_annotations_visual.pickle")
train_data_textual = _load_pickle("data/FUNSD/dataset/training_data/all_annotations_textual.pickle")
test_data_visual = _load_pickle("data/FUNSD/dataset/testing_data/all_annotations_visual.pickle")
test_data_textual = _load_pickle("data/FUNSD/dataset/testing_data/all_annotations_textual.pickle")


In [None]:
# Give each label an integer id in order of first appearance in the training set.
label_dict = {}
for doc in train_data_textual:
    for entity in train_data_textual[doc]["form"]:
        # setdefault only inserts unseen labels; len() (evaluated before the
        # insertion) is the next free id.
        label_dict.setdefault(entity["label"], len(label_dict))
label_id = len(label_dict)

# Attach the precomputed visual features to each document's textual record.
for split_textual, split_visual in ((train_data_textual, train_data_visual),
                                    (test_data_textual, test_data_visual)):
    for doc in split_textual:
        split_textual[doc]["visual_list"] = split_visual[doc]["visual_list"]

Define the model, loss, and optimizer, and create the dataset objects

In [None]:
# Seed torch so weight initialisation and DataLoader shuffling are reproducible
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = New_model()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

metric = load_metric("seqeval")
return_entity_level_metrics = True
# label_dict ids were assigned in insertion order, so list position == id.
label_list = list(label_dict.keys())

train_dataset = CustomDataset(doc_info=train_data_textual)
test_dataset = CustomDataset(doc_info=test_data_textual)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
# weight_decay are set to transformers.AdamW's defaults to keep the same update.
# NOTE: the training-loop cell rebinds `optimizer` (lr=1e-5) before training runs.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)
# Targets padded with -100 are excluded from the loss.
loss_function = torch.nn.CrossEntropyLoss(ignore_index=-100)
model.train();

Training and validation functions

In [None]:
def train(num_train_epochs):
    """Train the global ``model`` for ``num_train_epochs`` passes over ``train_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``loss_function`` and
    ``device``, and prints the mean batch loss after each epoch.

    Args:
        num_train_epochs (int): number of passes over the training data.
    """
    for _ in range(num_train_epochs):
        # Evaluation puts the model in eval mode (dropout off); make sure every
        # training epoch runs in train mode.
        model.train()
        total_loss = 0.0
        for data in tqdm(train_dataloader):
            labels = data['target'].to(device, dtype=torch.long)
            visual_feats = data['visual_feat'].to(device, dtype=torch.float)
            bert_cls = data['bert_cls'].to(device, dtype=torch.float)
            object_mask = data['object_mask'].to(device, dtype=torch.float)

            optimizer.zero_grad()
            outputs = model(bert_cls, visual_feats, object_mask)
            # Flatten to (tokens, num_labels); the class count comes from the
            # model output instead of being hard-coded.
            loss = loss_function(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        print("Train Loss:", total_loss / max(len(train_dataloader), 1))


In [None]:
def eval(test_dataloader):
    """Evaluate the global ``model`` on ``test_dataloader`` and return the metrics.

    Prints the mean batch loss and the metrics dict, then returns the dict
    produced by ``compute_metrics``. The model's previous train/eval mode is
    restored before returning.

    Args:
        test_dataloader: DataLoader yielding CustomDataset-style batches.

    Returns:
        dict: output of ``compute_metrics`` over all batches.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches.
    """
    # NOTE: this name shadows the builtin ``eval``; kept for existing callers.
    was_training = model.training
    model.eval()
    batch_logits = []
    batch_labels = []
    total_loss = 0.0
    try:
        with torch.no_grad():
            for data in tqdm(test_dataloader):
                labels = data['target'].to(device, dtype=torch.long)
                visual_feats = data['visual_feat'].to(device, dtype=torch.float)
                bert_cls = data['bert_cls'].to(device, dtype=torch.float)
                object_mask = data['object_mask'].to(device, dtype=torch.float)

                outputs = model(bert_cls, visual_feats, object_mask)
                loss = loss_function(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
                total_loss += loss.item()
                batch_logits.append(outputs.cpu().numpy())
                batch_labels.append(data["target"].cpu().numpy())
    finally:
        # Otherwise later training epochs would silently run without dropout.
        model.train(was_training)

    if not batch_logits:
        raise ValueError("test_dataloader yielded no batches; nothing to evaluate")
    print("Val Loss:", total_loss / len(batch_logits))
    # Concatenate once; np.append inside the loop re-copies everything each batch.
    preds_val = np.concatenate(batch_logits, axis=0)
    out_label_ids = np.concatenate(batch_labels, axis=0)
    # Compute once: the same dict is both printed and returned.
    results = compute_metrics((preds_val, out_label_ids))
    pprint.pprint(results)
    return results

## Train Loop

Run training for 30 epochs, validating the model after every epoch. Every time a new best model (by overall F1) is found, the checkpoint is saved. At the end of the training loop, the best-performing epoch and its evaluation results are printed.

In [None]:
import os

NUM_EPOCHS = 30
CHECKPOINT_PATH = 'results/vanilla-transformer_funsd.pth'

# Start below any real F1 so the first epoch is always recorded; with 0 and a
# strict ">" an all-zero F1 run left best_val_result undefined (NameError below).
current_f1 = float("-inf")
best_epoch = 0
best_val_result = None
# torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
# weight_decay match transformers.AdamW's defaults.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH) or ".", exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print("Epoch:", epoch+1)
    train(1)
    val_result = eval(test_dataloader)
    if val_result['overall_f1'] > current_f1:
        current_f1 = val_result['overall_f1']
        best_epoch = epoch
        best_val_result = val_result
        # Pickles the whole module (reload with torch.load); a state_dict would be more portable.
        torch.save(model, CHECKPOINT_PATH)

print("Best Epoch:", best_epoch+1)
pprint.pprint(best_val_result)
