In [None]:
import warnings
warnings.filterwarnings('ignore')

import os

# Document classes the classifier can predict, kept in alphabetical order so
# the integer ids derived below do not depend on how the list is typed here.
labels = sorted([
    "advertisement",
    "budget",
    "email",
    "file_folder",
    "form",
    "handwritten",
    "invoice",
    "letter",
    "memo",
    "news_article",
    "presentation",
    "questionnaire",
    "resume",
    "scientific_publication",
    "scientific_report",
    "specification",
])

# Two-way lookup between a class name and its position in `labels`.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
id2label

In [None]:
import pandas as pd
from glob import glob
import random
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
import torch
from transformers import AdamW
from tqdm.notebook import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every part file under processed_data/, listed in sorted order first so the
# shuffle below always starts from the same sequence.
files = sorted(glob('processed_data/part*'))
print("total num of files", len(files))

# Shuffle with a dedicated seeded generator: chunk i then always contains the
# same files, so a chunk index means the same thing after a kernel restart.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(files)

N = 510          # length token sequences are padded / truncated to
CHUNK_SIZE = 50  # number of part files loaded together as one chunk

# Ceiling division. The previous `int(len / CHUNK_SIZE) + 1` counted one extra,
# empty chunk whenever the file count was an exact multiple of CHUNK_SIZE
# (including zero files), and loading that chunk indexed past the end of `files`.
CHUNKS_NUM = (len(files) + CHUNK_SIZE - 1) // CHUNK_SIZE
print("num of chunks", CHUNKS_NUM)

def get_chunk(index, files_, chunk_size=None):
    """Load one chunk of parquet files into a single DataFrame.

    Args:
        index: Zero-based chunk number; the chunk covers
            ``files_[index * chunk_size : (index + 1) * chunk_size]``.
        files_: Sequence of parquet file paths.
        chunk_size: Number of files per chunk. Defaults to the module-level
            ``CHUNK_SIZE`` when omitted.

    Returns:
        The row-wise concatenation of every file in the chunk. The per-file
        row index is kept as-is, so index values may repeat.

    Raises:
        IndexError: If the chunk lies entirely past the end of ``files_``.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    start = chunk_size * index
    paths = files_[start:start + chunk_size]
    if not paths:
        raise IndexError(
            "chunk %d is out of range for %d files" % (index, len(files_))
        )
    # Read everything first and concatenate once; concatenating inside the
    # loop re-copies the accumulated frame for every additional file.
    return pd.concat([pd.read_parquet(path) for path in paths])

def pad_or_cut(arr, n, padd_val = 0):
    """Force a 1-D sequence to exactly ``n`` entries.

    Sequences that already have ``n`` or more entries are sliced down to their
    first ``n`` (the slice keeps the input's own type). Shorter ones are
    right-padded with ``padd_val`` through ``np.pad``, which returns an ndarray.
    """
    missing = n - len(arr)
    if missing <= 0:
        # Long enough already: keep the leading n entries.
        return arr[:n]
    # Too short: append `missing` copies of the pad value at the end only.
    return np.pad(arr, (0, missing), mode='constant', constant_values=padd_val)

def preprocess_df(df, label_column = 'label'):
    """Turn a raw chunk into fixed-length model inputs.

    Args:
        df: Frame with ``path``, ``input_ids``, ``attention_mask``, ``bbox``
            and ``label_column`` columns. It is not modified.
        label_column: Name of the column holding the class name of each row.

    Returns:
        A new frame, indexed 0..len-1, in which
        * ``labels`` is the integer id of the class name (via ``label2id``),
        * ``image_path`` is a copy of ``path`` (``path`` itself is dropped),
        * the previous row index is kept as a column by ``reset_index()``,
        * rows whose ``input_ids``, ``bbox`` or ``attention_mask`` is empty
          are removed,
        * ``input_ids`` / ``attention_mask`` are padded or cut to ``N`` entries,
        * ``bbox`` is padded or cut to ``N * 4`` values and reshaped to rows
          of four.

    Raises:
        KeyError: If a class name is missing from ``label2id``.
    """
    # assign() returns a new frame, so the caller's `df` does not silently
    # gain the `labels` / `image_path` columns.
    df_ = df.assign(
        labels=df[label_column].apply(lambda x: label2id[x]),
        image_path=df['path'],
    )
    df_ = df_.drop(['path'], axis=1).reset_index().drop([label_column], axis=1)

    # Drop rows with any empty model input, one column at a time.
    for column in ('input_ids', 'bbox', 'attention_mask'):
        df_ = df_[df_[column].apply(lambda x: len(x) > 0)]

    # Filtering leaves gaps in the index. Renumber so that positions 0..len-1
    # are valid row labels for consumers that look rows up by integer, and so
    # the assignments below operate on an independent frame, not a slice.
    df_ = df_.reset_index(drop=True)

    df_['input_ids'] = df_['input_ids'].apply(lambda x: pad_or_cut(np.array(x), N))
    df_['attention_mask'] = df_['attention_mask'].apply(lambda x: pad_or_cut(np.array(x), N))
    # NOTE(review): this assumes each bbox cell is a flat sequence of
    # coordinates (4 per token) - confirm against the parquet schema.
    df_['bbox'] = df_['bbox'].apply(lambda x: np.reshape(pad_or_cut(x, N*4), (-1,4)))
    return df_

class PandasDataset(Dataset):
    """Map-style dataset that serves the rows of a DataFrame as dicts.

    Each item is a dict keyed by column name. String cells are returned
    unchanged; every other cell is converted with ``torch.tensor`` and moved
    to ``device``.
    """

    def __init__(self, dataframe, target_key, device):
        """
        Args:
            dataframe: Frame whose columns become the keys of every item.
            target_key: Column used to determine the dataset length.
            device: Device each tensor is moved to when an item is fetched.
        """
        self.dataframe = dataframe
        self.target_key = target_key
        self.device = device

    def __len__(self):
        """Number of rows, measured on the ``target_key`` column."""
        return len(self.dataframe[self.target_key])

    def __getitem__(self, index):
        """Return row number ``index`` (positional) as a dict of column values.

        Raises:
            IndexError: If ``index`` is outside ``0..len(self)-1``.
        """
        item = {}
        for k in self.dataframe:
            # Positional lookup: label-based `series[index]` raises KeyError as
            # soon as the frame's index is not exactly 0..len-1 (for example
            # after rows were filtered out).
            value = self.dataframe[k].iloc[index]
            if isinstance(value, str):
                # torch.tensor cannot hold strings; pass them through as-is.
                item[k] = value
            else:
                item[k] = torch.tensor(value).to(self.device)
        return item

# `df_` only ever existed as a local inside preprocess_df, so this cell raised
# NameError on a fresh kernel. Build it explicitly from the first chunk, using
# the same label column ('act_label') as the training and evaluation cells.
df_ = preprocess_df(get_chunk(0, files), 'act_label')

# Small loader over that chunk. Only the loader object is created here;
# nothing is iterated in this cell.
dataloader = torch.utils.data.DataLoader(PandasDataset(df_[
    ['image_path', 'input_ids', 'attention_mask', 'bbox', 'labels']
    ], 'labels', device), batch_size=2)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, LayoutLMv3FeatureExtractor, LayoutLMv3Processor

# Tokenizer that matches the checkpoint fine-tuned further down.
tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")

# `apply_ocr` is an option of the feature extractor, so it is set there.
# NOTE(review): when passed to LayoutLMv3Processor instead, the keyword is
# either ignored or rejected depending on the installed transformers
# version - confirm against the version in use.
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=True)

# Processor that bundles the feature extractor with the tokenizer.
processor = LayoutLMv3Processor(feature_extractor, tokenizer)

In [None]:
from sklearn.model_selection import train_test_split

# Split at chunk granularity: 90% of the chunk indices are used for training,
# 10% are held out for testing.
# The seed is fixed so the same chunk indices are held out on every run;
# without it, a model saved in one session could be evaluated in another
# session on chunk indices it had been trained on.
SPLIT_SEED = 42
chunk_idxs = list(range(CHUNKS_NUM))
trainingSet_idxs, testSet_idxs = train_test_split(
    chunk_idxs, test_size=0.1, random_state=SPLIT_SEED
)

In [None]:
from transformers import AutoModelForSequenceClassification
import torch

# Prefer CUDA when PyTorch can see a GPU; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pretrained checkpoint loaded for sequence classification, requesting one
# output class per entry of `labels`.
model = AutoModelForSequenceClassification.from_pretrained(
    "SCUT-DLVCLab/lilt-roberta-en-base",
    num_labels=len(labels),
)
model.to(device)

In [None]:
model_save_path = 'fine_tuned_model'
# Create the output directory from Python: unlike a bare `mkdir`, this does
# not complain when the directory is already there (e.g. when re-running).
os.makedirs(model_save_path, exist_ok=True)

# NOTE(review): AdamW is imported from transformers earlier in the notebook;
# that export is deprecated upstream in favour of torch.optim.AdamW - confirm
# it still exists in the installed version before upgrading.
optimizer = AdamW(model.parameters(), lr=2e-5)

global_step = 0
num_train_epochs = 5

#put the model in training mode
model.train()
tot_n = len(trainingSet_idxs)

for epoch in range(num_train_epochs):
    print("Epoch:", epoch)
    # `n` restarts at 1 on every epoch, so n / tot_n is the fraction of this
    # epoch's chunks reached so far and never exceeds 1.
    for n, idx in enumerate(trainingSet_idxs, start=1):
        print("Epoch:", epoch, (n / tot_n))
        df = get_chunk(idx, files)
        dataloader = torch.utils.data.DataLoader(PandasDataset(preprocess_df(df, 'act_label')[
            ['input_ids', 'attention_mask', 'bbox', 'labels']
            ], 'labels', device), batch_size=8)
        running_loss = 0.0
        correct = 0
        seen = 0          # examples processed in this chunk
        num_batches = 0   # batches processed in this chunk
        try:
            for batch in tqdm(dataloader):
                # forward pass
                outputs = model(**batch)
                loss = outputs.loss
                running_loss += loss.item()
                predictions = outputs.logits.argmax(-1)
                # .item() keeps the counters as plain Python numbers
                correct += (predictions == batch['labels']).sum().item()
                seen += batch['labels'].shape[0]
                num_batches += 1
                # backward pass to get the gradients
                loss.backward()
                # update
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
        except Exception as exc:
            # Best effort: a chunk that fails is abandoned and training moves
            # on, but the error is reported instead of being hidden. Catching
            # Exception (not everything) lets Ctrl-C still stop the run.
            print("Exception", "in chunk", idx, repr(exc))
            # Drop gradients left over from the failed batch so they do not
            # leak into the first step of the next chunk.
            optimizer.zero_grad()
        if num_batches:
            print("chunk", idx,
                  "mean loss:", running_loss / num_batches,
                  "accuracy:", 100 * correct / seen)
    # Checkpoint once per epoch; each save overwrites the previous one.
    model.save_pretrained(model_save_path)

In [None]:
from transformers import AutoModelForSequenceClassification
# Rebuild the classifier from whatever was last written to `model_save_path`,
# then place it on `device`.
model = AutoModelForSequenceClassification.from_pretrained(
    model_save_path,
)
model.to(device)

In [None]:
import gc
gc.collect()
from tqdm.notebook import tqdm

correct = 0   # correctly classified examples so far
tot = 0       # examples actually run through the model so far
N = 510

# Evaluation mode: the model may still be in training mode if this cell is
# run straight after the training cell.
model.eval()

for idx in testSet_idxs:
    df = get_chunk(idx, files)
    dataloader = torch.utils.data.DataLoader(PandasDataset(preprocess_df(df, "act_label")[
        ['input_ids', 'attention_mask', 'bbox', 'labels']
        ], 'labels', device), batch_size=1)
    try:
        # No gradients are needed for scoring.
        with torch.no_grad():
            for batch in tqdm(dataloader):
                outputs = model(**batch)
                predictions = outputs.logits.argmax(-1)
                correct += (predictions == batch['labels']).sum().item()
                # Count per batch, not len(df): preprocess_df drops rows, and
                # a chunk that fails part-way must not inflate the denominator.
                tot += batch['labels'].shape[0]
    except Exception as exc:
        # Best effort: report the failure and the accuracy so far, then carry
        # on with the next chunk.
        print("Exception", "in chunk", idx, repr(exc))
        if tot:
            print("Testing accuracy:", 100 * correct / tot)

# Plain Python arithmetic; guard against an empty test set.
accuracy = 100 * correct / tot if tot else float("nan")
print("Testing accuracy:", accuracy)