<a href="https://colab.research.google.com/github/adamserag1/Interpretability-for-VRDU-models/blob/main/notebooks/LayoutLM3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LayoutLMv3

## Imports

In [None]:
# UNCOMMENT FOR USE IN COLAB
# NOTE(review): the clone below is currently active (not commented out); comment it out when running outside Colab.
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
!pip install datasets seqeval evaluate transformers torch

In [None]:
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification, set_seed
from PIL import Image,ImageDraw, ImageFont
from datasets import load_dataset
import torch
import pandas as pd
import evaluate
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Fix the random seed (transformers.set_seed) so repeated runs start from the same RNG state.
set_seed(0)

## Setup

### Data

In [None]:
# Load FUNSD and build the two lookup tables between integer tag ids and tag names.
funsd = load_dataset("nielsr/funsd", trust_remote_code=True)
funsd_train = funsd["train"]

# Tag names in id order, taken from the dataset's "ner_tags" feature.
labels = funsd_train.features["ner_tags"].feature.names

id2label = dict(enumerate(labels))                            # id   -> name
label2id = {name: idx for idx, name in id2label.items()}      # name -> id

print(id2label)
print(label2id)

In [None]:
# Processor for the LayoutLMv3 base checkpoint.
# apply_ocr=False: encode() passes the dataset's own words and boxes to the processor
# NOTE(review): presumably this disables the processor's built-in OCR step — confirm.
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

In [None]:
def encode(example):
  """Convert a batch of FUNSD rows into processor outputs.

  `example` is indexed by column name; "image_path" is iterated, so it is
  expected to hold one path per row of the batch. Every image is opened and
  converted to RGB, then handed to the global `processor` together with the
  "words", "bboxes" and "ner_tags" columns.

  Returns whatever `processor` returns (padded to max length, truncated,
  with PyTorch tensors requested).
  """
  page_images = []
  for path in example["image_path"]:
    page_images.append(Image.open(path).convert("RGB"))

  return processor(
      page_images,
      example["words"],
      boxes=example["bboxes"],
      word_labels=example["ner_tags"],
      padding="max_length",
      truncation=True,
      return_tensors="pt",
  )

In [None]:
# 80:20 Train : Validate
# An explicit seed makes the split identical every time this cell runs
# (without it, re-running the cell shuffles the rows differently and the
# validation scores are no longer comparable between runs).
split = funsd["train"].train_test_split(test_size=0.2, seed=0)

# The raw columns are dropped so that only the processor outputs remain.
raw_columns = funsd["train"].column_names
train_dataset = split["train"].map(encode, batched=True, remove_columns=raw_columns)
val_dataset = split["test"].map(encode, batched=True, remove_columns=raw_columns)

### Finetuning

In [None]:
metric = evaluate.load("seqeval")
import numpy as np
def compute_metrics(p):
  """Score token-classification predictions with the seqeval metric.

  `p` unpacks into (scores, label ids). The predicted id of each position is
  the argmax of the scores over axis 2. Positions whose label id is -100 are
  left out of both the predictions and the references; the remaining ids are
  mapped to tag names through the global `id2label`.

  Returns a dict with the overall precision, recall, f1 and accuracy.
  """
  scores, gold_ids = p
  pred_ids = np.argmax(scores, axis=2)

  true_preds = []
  true_labels = []
  for pred_row, gold_row in zip(pred_ids, gold_ids):
    # Keep only the (prediction, label) pairs that carry a real label.
    kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]
    true_preds.append([id2label[pred] for pred, _ in kept])
    true_labels.append([id2label[gold] for _, gold in kept])

  results = metric.compute(predictions=true_preds, references=true_labels)

  summary = {}
  summary["precision"] = results["overall_precision"]
  summary["recall"] = results["overall_recall"]
  summary["f1"] = results["overall_f1"]
  summary["accuracy"] = results["overall_accuracy"]
  return summary

In [None]:
# Authenticate with the Hugging Face Hub before the push_to_hub calls later in the notebook.
# NOTE(review): login() presumably prompts for an access token interactively — confirm.
from huggingface_hub import login
login()

In [None]:
from transformers import TrainingArguments, Trainer

# Local directory that receives the checkpoints written during training.
output_dir = "./layoutlmv3-finetuned-funsd"

# Base checkpoint with a token-classification head sized to the FUNSD tag set.
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=len(labels), id2label=id2label, label2id=label2id)

training_args = TrainingArguments(
    output_dir=output_dir,
    max_steps=1000,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=1e-5,
    eval_strategy="steps",
    eval_steps=100,
    # load_best_model_at_end can only restore a checkpoint that was saved.
    # Saving on the same schedule as evaluation lets every evaluated step
    # compete for "best"; with the default save interval most evaluations
    # would have no checkpoint behind them.
    save_strategy="steps",
    save_steps=100,
    # Bound disk usage now that checkpoints are written every 100 steps.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # NOTE(review): newer transformers releases name this argument
    # `processing_class`; confirm against the installed version.
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Trainer.push_to_hub takes a commit message as its first positional argument,
# not a repository id, so the repo name passed to it was only used as the commit
# text. Push the trained model through the model's own push_to_hub instead, and
# use one repo id for both uploads (the two original names differed by a hyphen).
hub_repo_id = "adamadam111/layoutlmv3b-finetuned-funsd-01"
trainer.model.push_to_hub(hub_repo_id)
processor.push_to_hub(hub_repo_id)

In [None]:
# Run the fine-tuned model on one held-out page so that `example`, `image`,
# `encoding` and `logits` exist when the notebook is executed top to bottom.
# assumes the dataset's "test" split has the same columns as "train" — TODO confirm
example = funsd["test"][0]
image = Image.open(example["image_path"]).convert("RGB")
encoding = processor(
    image,
    example["words"],
    boxes=example["bboxes"],
    word_labels=example["ner_tags"],
    truncation=True,
    return_tensors="pt",
)

model.eval()  # inference mode (no dropout)
with torch.no_grad():
  # Move the inputs to wherever the model lives (CPU or GPU) before the forward pass.
  logits = model(**{name: tensor.to(model.device) for name, tensor in encoding.items()}).logits

predictions = logits.argmax(-1).squeeze().tolist()
token_boxes = encoding.bbox.squeeze().tolist()
# Stored under its own name so the global `labels` (the list of tag names) is not overwritten.
token_labels = encoding.labels.squeeze().tolist()

# -100 marks positions without a label (the same sentinel compute_metrics filters on).
# Predictions and boxes are filtered with the same mask so they stay aligned one-to-one.
predictions = [p for p, l in zip(predictions, token_labels) if l != -100]
token_boxes = [box for box, l in zip(token_boxes, token_labels) if l != -100]

width, height = image.size

def denormalize_box(bbox, width, height):
  """Scale a box from the 0-1000 grid to the given width and height.

  The first and third entries of `bbox` are scaled by `width`, the second and
  fourth by `height`. Returns a new four-element list.
  """
  extents = (width, height, width, height)
  return [extents[i] * (bbox[i] / 1000) for i in range(4)]

# Tag names for the predicted ids, and the boxes rescaled to the image size.
true_preds = [id2label[pred] for pred in predictions]
true_boxes = [denormalize_box(box, width, height) for box in token_boxes]

print(true_preds)
print(len(example["ner_tags"]))
print(len(logits.argmax(-1).squeeze().tolist()))

# Fraction of predictions equal to the tag at the same index in the example.
count = len(predictions)
correct = sum(1 for idx, pred in enumerate(predictions) if pred == example["ner_tags"][idx])

accuracy = correct/count
print(accuracy)


In [None]:
# Drawing handle bound to `image`; the boxes and captions drawn through it are
# what is shown when `image` is displayed at the end of this cell.
# NOTE(review): this presumably modifies `image` in place, so re-running the
# cell draws over earlier annotations — confirm.
draw = ImageDraw.Draw(image)

# Pillow's default font, used for the label captions.
font = ImageFont.load_default()

def iob_to_label(label):
    """Drop the first two characters of a tag; return 'other' if nothing is left."""
    return label[2:] or 'other'

# Outline/caption colour for each label.
label2color = {
    'question': 'blue',
    'answer': 'green',
    'header': 'orange',
    'other': 'violet',
}

# One rectangle plus a caption per (prediction, box) pair.
for prediction, box in zip(true_preds, true_boxes):
    predicted_label = iob_to_label(prediction).lower()
    color = label2color[predicted_label]
    caption_origin = (box[0]+10, box[1]-10)
    draw.rectangle(box, outline=color)
    draw.text(caption_origin, text=predicted_label, fill=color, font=font)

image