In [None]:
!pip install datasets

In [None]:
import copy
import logging
import os
import pickle
import pprint

import numpy as np
import torch
from datasets import load_metric
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AdamW
from transformers import VisualBertModel, BertTokenizer

logging.basicConfig(level=logging.ERROR)

# Dataset Class

In [None]:
class CustomDataset(Dataset):
    """Per-document dataset pairing tokenized text with precomputed visual features.

    Every item is built from one record of ``doc_info`` and is returned with
    fixed sizes so that the default ``DataLoader`` collation can stack items:
    the text is padded/truncated to ``max_text_len`` tokens and the visual
    features, object mask and labels are padded/truncated to ``padding_len``.
    """

    def __init__(self, doc_info, tokenizer, padding_len=100, label2id=None,
                 max_text_len=512, visual_dim=2048):
        """
        Args:
            doc_info (dict): Mapping from a document key to a record with a
                ``"form"`` entry (sequence of entities, each holding ``"text"``
                and ``"label"``) and a ``"visual_list"`` entry (sequence of
                feature vectors; presumably one per entity - TODO confirm).
            tokenizer: Object exposing ``encode_plus`` (e.g. a ``BertTokenizer``).
            padding_len (int): Number of visual slots (and labels) per item.
            label2id (dict, optional): Maps a label string to its class index.
                When omitted, the module-level ``label_dict`` is looked up at
                item-access time, which is the original behaviour.
            max_text_len (int): Fixed length of the tokenized text.
            visual_dim (int): Width of the zero vectors used to pad the visual
                features; must match the width of the real feature vectors.
        """
        self.tokenizer = tokenizer
        self.doc_info = doc_info
        self.padding_len = padding_len
        self.label2id = label2id
        self.max_text_len = max_text_len
        self.visual_dim = visual_dim
        # Materialise the key order once instead of rebuilding it per item.
        self._doc_keys = list(doc_info.keys())

    def __len__(self):
        return len(self._doc_keys)

    def __getitem__(self, idx):
        """Return a dict of tensors for document number ``idx``.

        Keys: ``ids``, ``mask``, ``token_type_ids`` (length ``max_text_len``),
        ``visual_feat`` (``padding_len`` x ``visual_dim``), ``target`` and
        ``object_mask`` (length ``padding_len``). Padded label positions hold
        ``-100``; padded visual positions hold zeros and a mask value of 0.
        """
        doc = self.doc_info[self._doc_keys[idx]]
        label2id = self.label2id if self.label2id is not None else label_dict

        entities = doc["form"]
        text = " ".join(entity["text"] for entity in entities)
        target_id = [label2id[entity["label"]] for entity in entities]

        # Explicit padding/truncation: downstream code relies on the text
        # segment having exactly `max_text_len` tokens.
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_text_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        # Slicing into a new list is enough to protect the stored features:
        # only the outer list is extended, inner vectors are never modified.
        visual_feat = list(doc["visual_list"][:self.padding_len])
        n_objects = len(visual_feat)
        n_missing = self.padding_len - n_objects
        visual_feat.extend([[0.0] * self.visual_dim] * n_missing)
        object_mask = [1.0] * n_objects + [0.0] * n_missing

        # Pad/truncate the labels on their own length so the target always has
        # `padding_len` entries, even if "form" and "visual_list" differ in size.
        target_id = target_id[:self.padding_len]
        target_id.extend([-100] * (self.padding_len - len(target_id)))

        return {
            'ids': torch.tensor(ids, dtype=torch.long),  # text input ids
            'mask': torch.tensor(mask, dtype=torch.float),  # text attention mask
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'visual_feat': torch.tensor(visual_feat, dtype=torch.float),
            # Kept as float for compatibility; consumers cast to long for the loss.
            'target': torch.tensor(target_id, dtype=torch.float),
            'object_mask': torch.tensor(object_mask, dtype=torch.float),
        }

# Model Class

It uses a pretrained VisualBERT model, considering all the output embeddings corresponding to visual tokens (from token 512 onwards) to perform the final classification.

In [None]:
class New_model(torch.nn.Module):
    """Pretrained VisualBERT encoder with a classification head on the visual tokens.

    The encoder receives the text tokens followed by the visual embeddings.
    Only the final hidden states located after the text segment are passed
    through ``pre_classifier -> tanh -> dropout -> classifier``.
    """

    def __init__(self, num_labels=4):
        """
        Args:
            num_labels (int): Number of output classes per visual token.
        """
        super(New_model, self).__init__()

        self.l1 = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre",output_hidden_states=True)
        # Size the head from the loaded encoder instead of assuming 768.
        hidden_size = self.l1.config.hidden_size
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, ids, mask, token_type_ids, visual_feat, attention_mask):
        """Return per-visual-token logits.

        Args:
            ids, mask, token_type_ids: Text inputs forwarded to the encoder as
                ``input_ids``, ``attention_mask`` and ``token_type_ids``.
            visual_feat: Visual embeddings; all dimensions except the last are
                used as the shape of the visual token-type ids.
            attention_mask: Mask over the visual slots, forwarded as
                ``visual_attention_mask``.

        Returns:
            Tensor of logits whose last dimension is the number of classes.
        """
        # Build auxiliary tensors on the inputs' own device rather than relying
        # on a module-level `device` variable.
        target_device = visual_feat.device
        visual_token_type_ids = torch.ones(visual_feat.shape[:-1], dtype=torch.long, device=target_device)
        visual_attention_mask = attention_mask.to(device=target_device, dtype=torch.float)
        output_1 = self.l1(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids,
                           visual_embeds=visual_feat, visual_token_type_ids=visual_token_type_ids,
                           visual_attention_mask=visual_attention_mask)
        hidden_state = output_1.hidden_states[-1]
        # Skip the text segment; its length is taken from `ids` instead of a
        # hard-coded 512 so other text lengths work as well.
        text_len = ids.shape[1]
        visual_states = hidden_state[:, text_len:, :]
        output = torch.tanh(self.pre_classifier(visual_states))
        output = self.dropout(output)
        output = self.classifier(output)
        return output

Compute metric function definition

In [None]:
def compute_metrics(p):
    """Score predictions against labels with the module-level ``metric``.

    Args:
        p: Pair ``(predictions, labels)``. The class is chosen with an argmax
            over axis 2 of ``predictions``; positions whose label equals -100
            are dropped from both sequences.

    Returns:
        The dictionary produced by ``metric.compute``. When the module-level
        flag ``return_entity_level_metrics`` is truthy, nested dictionaries are
        flattened into ``"<key>_<subkey>"`` entries.
    """
    scores, gold = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold):
        kept_predictions = []
        kept_labels = []
        for predicted_id, gold_id in zip(predicted_row, gold_row):
            # -100 marks padded positions that must not be scored.
            if gold_id == -100:
                continue
            kept_predictions.append(label_list[int(predicted_id)])
            kept_labels.append(label_list[int(gold_id)])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if not return_entity_level_metrics:
        return results

    # Flatten one level of nesting so every value is addressable by a flat key.
    final_results = {}
    for key, value in results.items():
        if not isinstance(value, dict):
            final_results[key] = value
            continue
        for sub_key, sub_value in value.items():
            final_results[f"{key}_{sub_key}"] = sub_value
    return final_results

# Main

Load the precomputed visual features. Precomputed textual features are not needed because VisualBERT uses its own tokenizer.

In [None]:
def _read_pickle(path):
    """Return the object unpickled from the file at ``path``.

    NOTE: unpickling can execute arbitrary code, so only point this at
    annotation files you produced yourself or otherwise trust.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Annotation records (including visual features) for the two FUNSD splits.
train_data = _read_pickle("data/FUNSD/dataset/training_data/all_annotations_visual.pickle")
test_data = _read_pickle("data/FUNSD/dataset/testing_data/all_annotations_visual.pickle")

Define model, loss and optimizer and create dataset objects

In [None]:
# Give every distinct entity label an integer id, in first-seen order over the
# training documents.
label_dict = {}
for doc_key in train_data:
    for entity in train_data[doc_key]["form"]:
        # Only unseen labels are inserted, so the current size is the next free id.
        label_dict.setdefault(entity["label"], len(label_dict))
label_id = len(label_dict)  # number of distinct labels found

# Inverse lookup (id -> label string) used when computing metrics.
label_list = list(label_dict)

# Text side: lower-casing BERT tokenizer ("bert-base-uncased").
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation=True, do_lower_case=True)

# Model, placed on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = New_model()
model.to(device)

# Evaluation settings read by compute_metrics.
metric = load_metric("seqeval")
return_entity_level_metrics = True

# One dataset/dataloader per split; only the training split is shuffled.
train_dataset = CustomDataset(doc_info=train_data, tokenizer=tokenizer)
test_dataset = CustomDataset(doc_info=test_data, tokenizer=tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Optimisation objects used by train() and eval().
loss_function = torch.nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()


Train and Validation functions


In [None]:
def train(num_train_epochs):
    """Train the module-level ``model`` on ``train_dataloader``.

    Uses the module-level ``optimizer``, ``loss_function`` and ``device``.
    Prints the mean batch loss after every epoch.

    Args:
        num_train_epochs (int): Number of full passes over the training data.
    """
    for _ in range(num_train_epochs):
        # eval() leaves the model in evaluation mode; switch back explicitly so
        # dropout is active in every training epoch, not just the first one.
        model.train()
        total_loss = 0
        for data in tqdm(train_dataloader):
            # Move the batch to the target device with the dtypes the model expects.
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.float)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            labels = data['target'].to(device, dtype = torch.long)
            visual_feats = data['visual_feat'].to(device, dtype = torch.float)
            object_mask = data['object_mask'].to(device, dtype = torch.float)

            optimizer.zero_grad()
            outputs = model(ids, mask, token_type_ids, visual_feats, object_mask)
            # Flatten to (positions, classes); the class count is read from the
            # logits so it follows the model instead of a hard-coded 4.
            loss = loss_function(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        print("Train Loss:", total_loss / max(len(train_dataloader), 1))

In [None]:
def eval(test_dataloader):
    """Evaluate the module-level ``model`` on ``test_dataloader``.

    Puts the model in evaluation mode (it is left in that mode on return),
    runs every batch without gradient tracking, prints the mean batch loss and
    the metrics, and returns the metrics.

    Args:
        test_dataloader: Iterable of batches with the keys produced by the
            dataset (``ids``, ``mask``, ``token_type_ids``, ``visual_feat``,
            ``object_mask``, ``target``).

    Returns:
        The value returned by ``compute_metrics`` for all batches combined.

    Raises:
        ValueError: If ``test_dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    batch_logits = []
    batch_labels = []
    with torch.no_grad():
        for data in tqdm(test_dataloader):
            # Move the batch to the target device with the dtypes the model expects.
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.float)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            visual_feats = data['visual_feat'].to(device, dtype = torch.float)
            object_mask = data['object_mask'].to(device, dtype = torch.float)
            labels = data['target'].to(device, dtype = torch.long)

            # No optimizer call here: nothing is back-propagated during evaluation.
            outputs = model(ids, mask, token_type_ids, visual_feats, object_mask)
            loss = loss_function(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
            total_loss += loss.item()
            # Collect per-batch arrays and join them once at the end instead of
            # re-copying the accumulated array on every batch.
            batch_logits.append(outputs.detach().cpu().numpy())
            batch_labels.append(data["target"].detach().cpu().numpy())

    if not batch_logits:
        raise ValueError("eval() received a dataloader that yields no batches")

    print("Val Loss:", total_loss/len(test_dataloader))
    preds_val = np.concatenate(batch_logits, axis=0)
    out_label_ids = np.concatenate(batch_labels, axis=0)
    # Compute the metrics once and reuse the result for printing and returning.
    results = compute_metrics((preds_val, out_label_ids))
    pprint.pprint(results)
    return results



## Train Loop

Run training for 20 epochs, validating the model after every epoch. Every time a new best model is found, its checkpoint is saved. At the end of the training loop, the best-performing epoch and its evaluation results are printed.

In [None]:
# Train for 20 epochs, evaluating after each one and checkpointing the model
# whenever the overall F1 improves on the best value seen so far.
current_f1 = 0
best_epoch = 0
# Defined up front so the final report cannot hit a NameError when no epoch
# ever beats the initial F1 of 0.
best_val_result = None
checkpoint_path = 'results/visualbert_funsd.pth'
# Create the output directory before training so the first save cannot fail
# after a full epoch of work.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
optimizer = AdamW(model.parameters(), lr=1e-5)

for epoch in range(20):
    print("Epoch:", epoch+1)
    train(1)
    val_result = eval(test_dataloader)
    if val_result['overall_f1'] > current_f1:
        current_f1 = val_result['overall_f1']
        best_epoch = epoch
        best_val_result = val_result
        torch.save(model, checkpoint_path)

if best_val_result is None:
    print("No epoch reached an overall F1 above 0; no checkpoint was saved.")
else:
    print("Best Epoch:", best_epoch+1)
    pprint.pprint(best_val_result)