In [None]:
!pip install datasets

In [None]:
import copy
import logging
import math
import os
import pickle
import pprint

import numpy as np
import torch
from datasets import load_metric
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AdamW
from transformers import LxmertTokenizer, LxmertModel

logging.basicConfig(level=logging.ERROR)

LXMERT utility functions: 1D positional encoding and bounding-box normalization

In [None]:
def positionalencoding1d(d_model, feature_list):
    """
    Build sinusoidal positional encodings for one or more position values.

    Even columns hold sin(position * freq) and odd columns hold
    cos(position * freq), with freq = exp(-2k * ln(10000) / d_model).

    :param d_model: dimension of the model (must be even)
    :param feature_list: a single position (scalar) or a 1-D sequence of positions
    :return: float64 numpy array of shape (num_positions, d_model); a scalar
             position yields shape (1, d_model)
    :raises ValueError: if d_model is odd
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    # Column vector of positions so that it broadcasts against the row of
    # frequencies below, giving one encoding row per position.
    positions = torch.as_tensor(feature_list, dtype=torch.float).reshape(-1, 1)
    div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                         -(math.log(10000.0) / d_model)))
    angles = positions * div_term
    pe = torch.zeros(positions.shape[0], d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    # Computed in float32, returned widened to float64 (same as a tolist() round trip).
    return pe.numpy().astype(np.float64)

def normalize_bbox(bbox, width, height):
    """
    Scale a corner-format box to image-relative [x, y, w, h].

    :param bbox: indexable with at least four entries (x1, y1, x2, y2)
    :param width: image width used to scale horizontal quantities
    :param height: image height used to scale vertical quantities
    :return: list [x1/width, y1/height, |x2-x1|/width, |y2-y1|/height]
    """
    left, top = bbox[0], bbox[1]
    right, bottom = bbox[2], bbox[3]
    # Absolute differences keep the extents non-negative even if corners are swapped.
    box_w = abs(right - left)
    box_h = abs(bottom - top)
    return [left / width, top / height, box_w / width, box_h / height]

# Dataset Class

In [None]:
class CustomDataset(Dataset):
    """
    Per-document dataset: one item bundles the tokenized text of a document with
    padded per-entity visual features, boxes, labels and an entity mask.

    Each value of `doc_info` is read through two keys: "form" (a list of
    entities, each with "text", "label" and "box") and "visual_list" (a list of
    per-entity feature vectors).
    NOTE(review): "visual_list" entries are presumably 2048-dim, matching the
    zero padding rows added here -- confirm against the feature extractor.
    """

    VISUAL_FEAT_DIM = 2048  # width of the zero rows used to pad the visual features
    BBOX_DIM = 4            # width of the zero rows used to pad the normalized boxes

    def __init__(self, doc_info, tokenizer, positional_encoding, padding_len=100, split="train",
                 label_map=None, data_root="data/FUNSD/dataset"):
        """
        Args:
            doc_info (dict): Maps a document key (image file stem) to its annotations.
            tokenizer: Object exposing `encode_plus`; used to tokenize the joined entity texts.
            positional_encoding: Array-like returned unchanged (as a float tensor) with every item.
            padding_len (int): Number of entity slots every item is truncated/padded to.
            split (str): "train" or "test"; selects the "<split>ing_data" sub-directory.
            label_map (dict, optional): Maps label names to integer ids. When None, the
                module-level `label_dict` is looked up at item-access time.
            data_root (str): Directory containing the "<split>ing_data/images" folders.
        """
        self.tokenizer = tokenizer
        self.doc_info = doc_info
        self.positional_encoding = positional_encoding
        self.padding_len = padding_len
        self.split = split
        self.label_map = label_map
        self.data_root = data_root
        # Materialize the key order once instead of rebuilding the list per item.
        self._keys = list(doc_info.keys())

    def __len__(self):
        """Return the number of documents."""
        return len(self._keys)

    def __getitem__(self, idx):
        """Build the dict of tensors for the document at position `idx`."""
        key = self._keys[idx]
        doc_info = self.doc_info[key]
        image_path = f"{self.data_root}/{self.split}ing_data/images/{key}.png"
        # Only the image size is needed; the context manager releases the file
        # handle immediately instead of leaking one per item.
        with Image.open(image_path) as image:
            width, height = image.size

        entities = doc_info["form"]
        labels = label_dict if self.label_map is None else self.label_map
        text = " ".join(entity["text"] for entity in entities)
        target_id = [labels[entity["label"]] for entity in entities]
        norm_bbox = [normalize_bbox(entity["box"], width, height) for entity in entities]

        # Get Fine-Grained Level Information
        # `padding="max_length"` + `truncation=True` is the supported spelling of
        # the deprecated `pad_to_max_length=True` (fixed 512-token sequences).
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        # A shallow copy is sufficient: only the outer list is sliced/extended,
        # the stored feature vectors themselves are never modified.
        visual_feat = list(doc_info["visual_list"])

        if len(visual_feat) >= self.padding_len:
            visual_feat = visual_feat[:self.padding_len]
            norm_bbox = norm_bbox[:self.padding_len]
            object_mask = [1]*self.padding_len
            target_id = target_id[:self.padding_len]
        else:
            size = len(visual_feat)
            visual_feat.extend([[0.0]*self.VISUAL_FEAT_DIM]*(self.padding_len-size))
            norm_bbox.extend([[0.0]*self.BBOX_DIM]*(self.padding_len-len(norm_bbox)))
            object_mask = [1]*size+[0.0]*(self.padding_len-size)
            # -100 marks padded slots in the target.
            target_id.extend([-100]*(self.padding_len-len(target_id)))

        return {
            'ids': torch.tensor(ids, dtype=torch.long), #Key information input ids
            'mask': torch.tensor(mask, dtype=torch.float), # Key information masks
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long), # Key information token type ids
            'visual_feat': torch.tensor(visual_feat, dtype=torch.float),
            'target': torch.tensor(target_id, dtype=torch.float),
            'positional_encoding': torch.tensor(self.positional_encoding, dtype = torch.float),
            'object_mask':torch.tensor(object_mask, dtype=torch.float),
            'norm_bbox':torch.tensor(norm_bbox, dtype=torch.float)
        }

# Model Class

In [None]:
class New_model(torch.nn.Module):
    """
    LXMERT encoder with a per-visual-position classification head.

    The head (Linear -> tanh -> Dropout -> Linear) is applied to the last entry
    of the encoder's `vision_hidden_states`, producing one score vector per
    visual input position.
    """

    def __init__(self, num_labels=4):
        """
        Args:
            num_labels (int): Size of the classifier output (defaults to 4).
        """
        super(New_model, self).__init__()
        self.l1 = LxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased",output_hidden_states = True)
        # Take the width from the loaded encoder rather than hard-coding it.
        hidden_size = self.l1.config.hidden_size
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, ids, mask, token_type_ids, visual_feat,attention_mask,norm_bbox):
        """
        Args:
            ids, mask, token_type_ids: Text inputs forwarded to the encoder as
                `input_ids`, `attention_mask` and `token_type_ids`.
            visual_feat: Forwarded as `visual_feats`.
            attention_mask: Mask over the visual positions; forwarded as
                `visual_attention_mask` after a float cast.
            norm_bbox: Forwarded as `visual_pos`.

        Returns:
            Classifier outputs computed from the last vision hidden state.
        """
        # Follow the device of the visual inputs instead of relying on a
        # notebook-level `device` variable being defined.
        visual_attention_mask = attention_mask.to(device=visual_feat.device, dtype=torch.float)
        output_1 = self.l1(input_ids=ids,
                           attention_mask=mask,
                           token_type_ids=token_type_ids,
                           visual_feats =visual_feat,
                           visual_pos = norm_bbox,
                           visual_attention_mask=visual_attention_mask)
        vision_hidden = output_1.vision_hidden_states[-1]
        output = self.pre_classifier(vision_hidden)
        output = torch.tanh(output)
        output = self.dropout(output)
        output = self.classifier(output)
        return output

Compute metric function

In [None]:
def compute_metrics(p):
    """
    Score predictions against references with the module-level `metric`.

    :param p: pair (scores, labels); scores are arg-maxed over axis 2 to get
              class ids, labels equal to -100 are skipped
    :return: the metric's result dict; when `return_entity_level_metrics` is
             truthy, nested dicts are flattened into "<key>_<subkey>" entries
    """
    scores, gold = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(predicted_ids, gold):
        # Remove ignored index (special tokens)
        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(pred_row, gold_row) if gold_id != -100]
        true_predictions.append([label_list[int(pred_id)] for pred_id, _ in kept])
        true_labels.append([label_list[int(gold_id)] for _, gold_id in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if not return_entity_level_metrics:
        return results

    # Unpack nested dictionaries
    final_results = {}
    for key, value in results.items():
        if isinstance(value, dict):
            final_results.update({f"{key}_{n}": v for n, v in value.items()})
        else:
            final_results[key] = value
    return final_results

# Main

Load the precomputed visual features. Precomputed textual features are not needed because LXMERT uses its own tokenizer.

In [None]:
def _load_pickle(path):
    """Open `path` in binary mode and return the unpickled object."""
    # NOTE(review): pickle.load can execute arbitrary code -- only use on trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Annotation dictionaries (with precomputed visual features) for each split.
train_data = _load_pickle("data/FUNSD/dataset/training_data/all_annotations_visual.pickle")
test_data = _load_pickle("data/FUNSD/dataset/testing_data/all_annotations_visual.pickle")


Define model, loss and optimizer and create dataset objects

In [None]:
# Give every label an integer id in order of first appearance in the training annotations.
label_dict = {}
for doc, annotation in train_data.items():
    for entity in annotation["form"]:
        # len(label_dict) is evaluated before insertion, so ids are 0, 1, 2, ...
        label_dict.setdefault(entity["label"], len(label_dict))
label_id = len(label_dict)  # next free id, kept for parity with the counter-based version
label_list = list(label_dict.keys())

# Tokenizer and model; the model is moved to the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = LxmertTokenizer.from_pretrained('unc-nlp/lxmert-base-uncased')
model = New_model()
model.to(device)

# Evaluation configuration read by compute_metrics.
metric = load_metric("seqeval")
return_entity_level_metrics = True

# One 768-dim encoding row for each of the positions 0..99.
positional_encoding = np.array(
    [positionalencoding1d(768, position)[0] for position in range(100)]
)

# Datasets and loaders; only the training loader shuffles.
train_dataset = CustomDataset(
    doc_info=train_data,
    tokenizer=tokenizer,
    positional_encoding=positional_encoding,
    split="train",
)
test_dataset = CustomDataset(
    doc_info=test_data,
    tokenizer=tokenizer,
    positional_encoding=positional_encoding,
    split="test",
)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Loss and optimizer.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()

Train and Validation functions


In [None]:
def train(num_train_epochs):
    """
    Run `num_train_epochs` passes over `train_dataloader`, updating `model`.

    Uses the module-level `model`, `optimizer`, `loss_function`, `device` and
    `train_dataloader`. Prints the mean batch loss after each epoch.

    :param num_train_epochs: number of epochs to run
    :return: None
    """
    # eval() leaves the model in evaluation mode; switch back explicitly so
    # dropout is active again in every training epoch, not just the first.
    model.train()
    num_batches = len(train_dataloader)
    for _ in range(num_train_epochs):
        total_loss = 0
        for data in tqdm(train_dataloader):
            # get the inputs;
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.float)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            labels = data['target'].to(device, dtype = torch.long)
            visual_feats = data['visual_feat'].to(device, dtype = torch.float)
            norm_bbox = data['norm_bbox'].to(device, dtype = torch.float)
            object_mask = data['object_mask'].to(device, dtype = torch.float)

            optimizer.zero_grad()
            outputs = model(ids, mask, token_type_ids,visual_feats,object_mask,norm_bbox)
            # Flatten to (positions, classes); the class count is read from the
            # model output instead of being hard-coded.
            loss = loss_function(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        print("Train Loss:", total_loss/max(num_batches, 1))




In [None]:
def eval(test_dataloader):
    """
    Evaluate `model` on `test_dataloader` and return the computed metrics.

    Puts the module-level `model` into evaluation mode (and leaves it there),
    runs every batch without gradient tracking, prints the mean batch loss and
    pretty-prints the result of `compute_metrics`.

    :param test_dataloader: iterable of batches shaped like the training batches
    :return: whatever `compute_metrics` returns for the stacked outputs/targets
    """
    model.eval()
    total_loss = 0
    batch_preds = []
    batch_labels = []
    with torch.no_grad():
        for data in tqdm(test_dataloader):
            # get the inputs;
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.float)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            visual_feats = data['visual_feat'].to(device, dtype = torch.float)
            norm_bbox = data['norm_bbox'].to(device, dtype = torch.float)
            object_mask = data['object_mask'].to(device, dtype = torch.float)
            labels = data['target'].to(device, dtype = torch.long)

            # No optimizer interaction here: nothing is back-propagated under no_grad.
            outputs = model(ids, mask, token_type_ids, visual_feats, object_mask,norm_bbox)
            loss = loss_function(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            total_loss += loss.item()

            # Collect per-batch arrays and stack once at the end instead of
            # re-copying the running array with np.append on every batch.
            batch_preds.append(outputs.detach().cpu().numpy())
            batch_labels.append(data["target"].detach().cpu().numpy())

    print("Val Loss:", total_loss/len(test_dataloader))
    preds_val = np.concatenate(batch_preds, axis=0)
    out_label_ids = np.concatenate(batch_labels, axis=0)
    # Compute the metrics once and reuse the result for printing and returning.
    metrics = compute_metrics((preds_val, out_label_ids))
    pprint.pprint(metrics)
    return metrics

## Train Loop

Run training for 30 epochs, validating the model after every epoch. Every time a new best model is found, its checkpoint is saved. At the end of the training loop, the best-performing epoch and its evaluation results are printed.

In [None]:
# Track the best validation F1 seen so far and the epoch that produced it.
current_f1 = 0
best_epoch = 0
best_val_result = None  # stays None if no epoch ever beats the initial F1 of 0
optimizer = AdamW(model.parameters(), lr=1e-5)

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_path = 'results/lxmert_funsd.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(30):
    print("Epoch:", epoch+1)
    train(1)
    val_result = eval(test_dataloader)
    # Checkpoint only on a strict improvement of the overall F1.
    if val_result['overall_f1'] > current_f1:
        current_f1 = val_result['overall_f1']
        best_epoch = epoch
        best_val_result = val_result
        torch.save(model, checkpoint_path)

if best_val_result is None:
    print("No epoch improved on the initial F1 of", current_f1)
else:
    print("Best Epoch:", best_epoch+1)
    pprint.pprint(best_val_result)