In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

# NOTE: '__file__' is a string literal here, not the __file__ variable (which a
# notebook kernel does not define). abspath() resolves the literal against the
# current working directory, so dirname() yields the working directory itself.
current_path = os.path.dirname(os.path.abspath('__file__'))

# The parent of the working directory is used as the project root.
root_path = os.path.dirname(current_path)

# Make the project root importable (presumably so the `src` package imported
# below can be found). The membership guard keeps this cell idempotent:
# re-running it no longer appends a duplicate entry to sys.path.
if root_path not in sys.path:
    sys.path.append(root_path)

from IPython.display import display
import numpy as np
from PIL import Image
import tqdm as tqdm


# lib
from src.dataset import DocumentDataset
from src.utils import draw_boxes, apply_ocr

# nn
import torch
from torch.utils.data import DataLoader

from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification
from transformers import logging


  from .autonotebook import tqdm as notebook_tqdm


## Config

In [2]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# (Independent of the dataset, so it is resolved first.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The documents live in the `data` directory directly under the project root.
data_path = os.path.join(root_path, 'data')

# Build the dataset from that directory.
dataset = DocumentDataset(data_path=data_path)

# Set the transformers logger to warning-level verbosity.
logging.set_verbosity_warning()

## Define and load the model

In [3]:
# Start from the pretrained LayoutLM base checkpoint and size the
# classification head to the number of labels exposed by the dataset.
model = LayoutLMForSequenceClassification.from_pretrained(
    "microsoft/layoutlm-base-uncased",
    num_labels=len(dataset.labels),
)

# Move the weights to the selected device. As the last expression of the cell
# this also displays the module structure.
model.to(device)

Some weights of the model checkpoint at microsoft/layoutlm-base-uncased were not used when initializing LayoutLMForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing LayoutLMForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LayoutLMForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LayoutLMForSequenceClassification were not initialized from the model checkpoint 

LayoutLMForSequenceClassification(
  (layoutlm): LayoutLMModel(
    (embeddings): LayoutLMEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (x_position_embeddings): Embedding(1024, 768)
      (y_position_embeddings): Embedding(1024, 768)
      (h_position_embeddings): Embedding(1024, 768)
      (w_position_embeddings): Embedding(1024, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LayoutLMEncoder(
      (layer): ModuleList(
        (0-11): 12 x LayoutLMLayer(
          (attention): LayoutLMAttention(
            (self): LayoutLMSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True

## Define the train, val and test dataloaders

In [4]:
# Fractions of the dataset assigned to each split. Only the train and
# validation ratios are used in the arithmetic below; the test split receives
# the remainder, so the three sizes always sum to len(dataset).
train_split_ratio = 0.8
val_split_ratio = 0.1
test_split_ratio = 0.1  # informational only: test_size is computed as the remainder

train_size = int(train_split_ratio * len(dataset))
val_size = int(val_split_ratio * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seed the split so that "Restart Kernel -> Run All" reproduces the same
# train/val/test partition. Without a fixed generator every run draws a
# different split and the reported metrics are not comparable across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# One document per step. Only the training loader is shuffled; evaluation
# metrics do not depend on sample order, so val/test are read sequentially.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Train the model

### Model Config

In [5]:
# Number of full passes over the training split.
epochs = 5

# AdamW over every parameter of the model, with a learning rate of 1e-5.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [6]:
def _batch_to_model_inputs(batch, device):
    """Turn one collated batch into keyword arguments for the model's forward pass.

    The per-token fields ('bbox', 'input_ids', 'attention_mask',
    'token_type_ids') are stacked from lists of tensors, which leaves the batch
    axis LAST. The batch axis is moved to the front with permute/transpose.
    A plain ``reshape`` only happens to give the same result when the batch
    size is 1; for larger batches it would interleave tokens of different samples.

    Args:
        batch: mapping produced by the DataLoader with the keys 'bbox',
            'input_ids', 'attention_mask', 'token_type_ids' and 'label'.
        device: torch device every tensor is moved to.

    Returns:
        dict with the keys ``input_ids``, ``attention_mask``, ``bbox``,
        ``token_type_ids`` and ``labels``, ready for ``model(**inputs)``.
    """
    # Assumes the stacked boxes are laid out as (seq_len, 4, batch) -- inferred
    # from the stacking pattern; TODO confirm against DocumentDataset.
    bboxes = torch.stack([torch.stack(bbox) for bbox in batch['bbox']])
    bboxes = bboxes.permute(2, 0, 1).contiguous()

    def _batch_first(key):
        # (seq_len, batch) -> (batch, seq_len)
        return torch.stack(batch[key]).transpose(0, 1).contiguous().to(device)

    return {
        'input_ids': _batch_first('input_ids'),
        'attention_mask': _batch_first('attention_mask'),
        'bbox': bboxes.to(device),
        'token_type_ids': _batch_first('token_type_ids'),
        'labels': batch['label'].to(device),
    }


for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')
    print('-' * 10)

    # ---- training pass ----
    model.train()

    running_loss_train = 0.0
    correct_train = 0
    seen_train = 0  # number of samples processed, the denominator for accuracy
    for batch in tqdm.tqdm(train_dataloader):
        inputs = _batch_to_model_inputs(batch, device)

        # Clear gradients before the forward pass so nothing left over from an
        # interrupted previous run can leak into this step.
        optimizer.zero_grad()

        # forward (the model computes the loss itself because labels are passed)
        outputs = model(**inputs)
        loss = outputs.loss

        # backward + parameter update
        loss.backward()
        optimizer.step()

        # update loss and accuracy
        running_loss_train += loss.item()
        predictions = outputs.logits.argmax(dim=-1)
        correct_train += (predictions == inputs['labels']).sum().item()
        seen_train += inputs['labels'].numel()

    # Loss is averaged per batch; accuracy is averaged per sample so it stays
    # correct for any batch size (dividing by the number of batches is only
    # right when batch_size == 1).
    avg_train_loss = running_loss_train / len(train_dataloader)
    avg_train_acc = correct_train / seen_train
    print(f'Average train loss: {avg_train_loss}')
    print(f'Average train accuracy: {avg_train_acc}')

    # ---- validation pass ----
    model.eval()

    running_loss_val = 0.0
    correct_val = 0
    seen_val = 0
    # No parameter update happens here, so autograd is disabled to avoid
    # building a computation graph for every validation batch.
    with torch.no_grad():
        for batch in tqdm.tqdm(val_dataloader):
            inputs = _batch_to_model_inputs(batch, device)

            outputs = model(**inputs)

            # update loss and accuracy
            running_loss_val += outputs.loss.item()
            predictions = outputs.logits.argmax(dim=-1)
            correct_val += (predictions == inputs['labels']).sum().item()
            seen_val += inputs['labels'].numel()

    avg_val_loss = running_loss_val / len(val_dataloader)
    avg_val_acc = correct_val / seen_val
    print(f'Average val loss: {avg_val_loss}')
    print(f'Average val accuracy: {avg_val_acc}')




Epoch 1/5
----------


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [24:13<00:00,  1.38it/s]


Average train loss: 0.533275840759743
Average train accuracy: 0.84


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [02:57<00:00,  1.41it/s]


Average val loss: 0.408884329829365
Average val accuracy: 0.848
Epoch 2/5
----------


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [24:57<00:00,  1.34it/s]


Average train loss: 0.23069822494010442
Average train accuracy: 0.927


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:02<00:00,  1.37it/s]


Average val loss: 0.34862392212636767
Average val accuracy: 0.888
Epoch 3/5
----------


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [25:50<00:00,  1.29it/s]


Average train loss: 0.13496689462405628
Average train accuracy: 0.96


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:07<00:00,  1.33it/s]


Average val loss: 0.3545875901784748
Average val accuracy: 0.908
Epoch 4/5
----------


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [25:39<00:00,  1.30it/s]


Average train loss: 0.08260223338217475
Average train accuracy: 0.9735


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:03<00:00,  1.36it/s]


Average val loss: 0.4173163246400654
Average val accuracy: 0.88
Epoch 5/5
----------


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [24:43<00:00,  1.35it/s]


Average train loss: 0.0609215815901116
Average train accuracy: 0.9825


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [02:56<00:00,  1.42it/s]

Average val loss: 0.39413693897938357
Average val accuracy: 0.9





## Save the model (the tokenizer is not saved by this cell)

In [10]:
# Target directory for the fine-tuned weights: <project root>/models/layoutlm-model
model_path = os.path.join(root_path, 'models', 'layoutlm-model')

# Write the model to that directory.
model.save_pretrained(save_directory=model_path)

## Load the model from disk
- Comment out this cell if you are training and testing in the same session.

In [11]:
# Reload the fine-tuned weights from disk.
loaded_model = LayoutLMForSequenceClassification.from_pretrained(model_path)

# from_pretrained returns the model on the CPU; move it to the device the
# input tensors are sent to.
loaded_model.to(device)

# The Test cell evaluates `model`, so bind the reloaded network to that name.
# Previously the checkpoint was loaded but never used, and a test-only session
# would have evaluated whatever `model` happened to hold.
model = loaded_model

## Test

In [13]:
# Evaluate `model` on the held-out test split and collect the per-sample
# labels and predictions for the classification report below.
model.eval()

running_loss_test = 0.0
correct_test = 0
y_true = []
y_pred = []

# Inference only: disable autograd so no computation graph is built per batch.
with torch.no_grad():
    for batch in tqdm.tqdm(test_dataloader):

        # The per-token fields are stacked from lists of tensors, which leaves
        # the batch axis last. Move it to the front with permute/transpose.
        # A plain reshape gives the same result only when batch_size == 1 and
        # would mix tokens of different samples otherwise.
        # Assumes stacked boxes are (seq_len, 4, batch) -- TODO confirm against DocumentDataset.
        bboxes = torch.stack([torch.stack(bbox) for bbox in batch['bbox']])
        bboxes = bboxes.permute(2, 0, 1).contiguous().to(device)

        # (seq_len, batch) -> (batch, seq_len)
        input_ids = torch.stack(batch['input_ids']).transpose(0, 1).contiguous().to(device)
        attention_mask = torch.stack(batch['attention_mask']).transpose(0, 1).contiguous().to(device)
        token_type_ids = torch.stack(batch['token_type_ids']).transpose(0, 1).contiguous().to(device)

        # labels
        labels = batch['label'].to(device)

        # forward (the model computes the loss itself because labels are passed)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bboxes,
            token_type_ids=token_type_ids,
            labels=labels
        )

        # update loss and accuracy
        loss = outputs.loss
        running_loss_test += loss.item()
        predictions = outputs.logits.argmax(dim=-1)
        correct_test += (predictions == labels).sum().item()

        # update y_true and y_pred
        y_true.extend(labels.tolist())
        y_pred.extend(predictions.tolist())

# Loss is averaged per batch; accuracy is averaged per sample (len(y_true)) so
# it stays correct for any batch size.
avg_test_loss = running_loss_test / len(test_dataloader)
avg_test_acc = correct_test / len(y_true)
print(f'Average test loss: {avg_test_loss}')
print(f'Average test accuracy: {avg_test_acc}')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:01<00:00,  1.38it/s]

Average test loss: 0.2526314825930167
Average test accuracy: 0.944





## Classification report

In [18]:
from sklearn.metrics import classification_report, confusion_matrix

In [19]:
# Per-class precision / recall / F1 plus the aggregate averages, returned as a
# nested dict (output_dict=True) and displayed as the cell's result.
classification_report(
    y_true=y_true,
    y_pred=y_pred,
    output_dict=True,
)

{'0': {'precision': 0.9807692307692307,
  'recall': 0.9272727272727272,
  'f1-score': 0.9532710280373831,
  'support': 55},
 '1': {'precision': 0.859375,
  'recall': 0.9482758620689655,
  'f1-score': 0.9016393442622951,
  'support': 58},
 '2': {'precision': 0.9148936170212766,
  'recall': 0.9772727272727273,
  'f1-score': 0.945054945054945,
  'support': 44},
 '3': {'precision': 1.0,
  'recall': 0.9574468085106383,
  'f1-score': 0.9782608695652174,
  'support': 47},
 '4': {'precision': 1.0,
  'recall': 0.9130434782608695,
  'f1-score': 0.9545454545454545,
  'support': 46},
 'accuracy': 0.944,
 'macro avg': {'precision': 0.9510075695581015,
  'recall': 0.9446623206771856,
  'f1-score': 0.946554328293059,
  'support': 250},
 'weighted avg': {'precision': 0.9481655073649754,
  'recall': 0.944,
  'f1-score': 0.9447790314813715,
  'support': 250}}

In [20]:
# Confusion matrix of the test predictions (rows: true label, columns: predicted label).
confusion_matrix(y_true=y_true, y_pred=y_pred)

array([[51,  4,  0,  0,  0],
       [ 1, 55,  2,  0,  0],
       [ 0,  1, 43,  0,  0],
       [ 0,  2,  0, 45,  0],
       [ 0,  2,  2,  0, 42]])

## Summary
- Train accuracy (final epoch): 0.9825
- Validation accuracy (final epoch): 0.9
- Test accuracy: 0.944
- Macro-average precision: 0.951
- Macro-average recall: 0.945
- Macro-average F1-score: 0.947