<a href="https://colab.research.google.com/github/adamserag1/Interpretability-for-VRDU-models/blob/main/finetuning/BROS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

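# BROS

Fine-tunes `naver-clova-ocr/bros-base-uncased` for token classification (IOB entity tagging) on the FUNSD dataset, then runs OCR-based inference on a custom form image.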

## Setup and imports

In [None]:
# Colab only: clone the project repository. Comment this line out when running locally.
# Nothing else in this notebook appears to import from the clone.
# NOTE: re-running this in the same runtime prints a git error because the directory already exists.
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
# Install the libraries used below; `seqeval` backs `evaluate.load("seqeval")` in the metrics cell.
# NOTE(review): versions are unpinned. The cells below use `eval_strategy` and `processing_class`,
# which presumably need a recent transformers release; consider pinning for reproducibility.
!pip install transformers torch datasets evaluate seqeval

In [None]:
import torch
from PIL import Image,ImageDraw, ImageFont
from datasets import load_dataset
import pandas as pd
import evaluate
from transformers import BrosProcessor, BrosSpadeEEForTokenClassification, AutoTokenizer

## Data

In [None]:
funsd = load_dataset("nielsr/funsd", trust_remote_code=True)

# Tag names come from the dataset's `ner_tags` feature; build both lookup directions.
label_list = funsd["train"].features["ner_tags"].feature.names
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}

# Sanity check: first training document's words, raw tag ids, and tag names.
first_doc = funsd["train"][0]
print(first_doc["words"])
print(first_doc["ner_tags"])
id_0_ner_tags = [id2label[tag_id] for tag_id in first_doc["ner_tags"]]
print(id_0_ner_tags)

In [None]:
# Base BROS checkpoint; the fine-tuning cell starts from the same checkpoint.
# `processor` does not appear to be used later in this notebook.
BROS_BASE = "naver-clova-ocr/bros-base-uncased"
processor = BrosProcessor.from_pretrained(BROS_BASE)
tokenizer = AutoTokenizer.from_pretrained(BROS_BASE, do_lower_case=True)

In [None]:
# Peek at one training example and load its page image as RGB.
SAMPLE_IDX = 66
sample = funsd["train"][SAMPLE_IDX]
image = Image.open(sample["image_path"]).convert("RGB")
print(sample["words"])

In [None]:
def normalize_bbox(bbox, width, height):
    """Scale a pixel-space box onto the 0-1000 integer grid.

    Args:
        bbox: ``[x0, y0, x1, y1]`` in pixels (extra trailing items are ignored).
        width: page width in pixels.
        height: page height in pixels.

    Returns:
        ``[x0, y0, x1, y1]`` as ints clamped to ``[0, 1000]``. Boxes that
        overhang the page edge (e.g. OCR output) therefore cannot yield
        negative or >1000 positions.

    NOTE(review): the Hugging Face BROS documentation example normalizes boxes
    to [0, 1] rather than to 0-1000. Confirm which scale the pretrained
    checkpoint expects before relying on this grid.
    """
    def _scale(value, extent):
        # Truncate like the original int() conversion, then clamp to the grid.
        return min(1000, max(0, int(1000 * (value / extent))))

    return [
        _scale(bbox[0], width),
        _scale(bbox[1], height),
        _scale(bbox[2], width),
        _scale(bbox[3], height),
    ]

def tokenize_words(batch, label_all_subwords=False):
  """Tokenize a batch of FUNSD examples and align boxes/labels to sub-word tokens.

  Intended for ``datasets.Dataset.map(tokenize_words, batched=True)``.

  Args:
    batch: dict of column lists with keys ``"words"``, ``"bboxes"`` (one pixel
      box per word), ``"image_path"`` and ``"ner_tags"`` (label ids per word).
    label_all_subwords: if False (default), only the first sub-word token of a
      word carries the word's label and its continuation pieces get -100.
      compute_metrics drops -100 positions, so scoring becomes word-level.
      Giving every piece of a ``B-X`` word the tag ``B-X`` would make seqeval
      count one entity per piece. Pass True to label every piece, as earlier
      versions of this function did.

  Returns:
    The tokenizer's encoding, padded/truncated to 512 tokens, with two extra keys:
    ``"bbox"`` (per-token box on the 0-1000 grid; ``[0, 0, 0, 0]`` for
    special/padding tokens) and ``"labels"`` (per-token label id or -100).
  """
  encodings = tokenizer(
    batch["words"],
    is_split_into_words=True,
    truncation=True,
    padding="max_length",
    max_length=512,
    return_tensors="pt"
  )

  batch_normalized_bboxes, encoded_labels = [], []
  for idx, (bboxes, img_path, labels) in enumerate(zip(batch["bboxes"], batch["image_path"], batch["ner_tags"])):
    # Only the page size is needed. The context manager closes the file handle
    # that a bare Image.open would keep open until garbage collection.
    with Image.open(img_path) as page:
      width, height = page.size
    normalized_bboxes = [normalize_bbox(bbox, width, height) for bbox in bboxes]

    # Align boxes and labels to sub-word tokens.
    aligned_boxes, aligned_labels = [], []
    prev_word_id = None
    for word_id in encodings.word_ids(batch_index=idx):
      if word_id is None:
        # Special/padding token: dummy box, label ignored (-100).
        aligned_boxes.append([0, 0, 0, 0])
        aligned_labels.append(-100)
      else:
        # Every piece of a word shares that word's box.
        aligned_boxes.append(normalized_bboxes[word_id])
        if label_all_subwords or word_id != prev_word_id:
          aligned_labels.append(labels[word_id])
        else:
          aligned_labels.append(-100)
      prev_word_id = word_id

    batch_normalized_bboxes.append(aligned_boxes)
    encoded_labels.append(aligned_labels)

  encodings['bbox'] = batch_normalized_bboxes
  encodings['labels'] = encoded_labels

  return encodings

In [None]:
# Tokenize both splits and drop each split's own raw columns, so only model inputs remain.
# NOTE(review): the FUNSD "test" split doubles as the validation set. Early stopping and
# load_best_model_at_end then select on test data, so reported test scores are optimistic.
train_dataset = funsd["train"].map(tokenize_words, batched=True, remove_columns=funsd["train"].column_names)
val_dataset = funsd["test"].map(tokenize_words, batched=True, remove_columns=funsd["test"].column_names)

train_dataset.set_format("torch")
val_dataset.set_format("torch")

## Finetuning

In [None]:
from transformers import AutoModelForTokenClassification, BrosForTokenClassification

# Pretrained BROS encoder with a token-classification head sized to the FUNSD tag set.
model = BrosForTokenClassification.from_pretrained(
    "naver-clova-ocr/bros-base-uncased",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
from transformers import TrainingArguments, Trainer, DefaultDataCollator

training_args = TrainingArguments(
    output_dir="./bros-funsd-finetuned",
    eval_strategy="epoch",
    save_strategy="epoch",          # must match eval_strategy when load_best_model_at_end=True
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=100,           # upper bound; an attached EarlyStoppingCallback can end sooner
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    # Without a limit, every epoch writes a full checkpoint (weights + optimizer state),
    # so up to 100 of them accumulate. Keep only the best and the most recent.
    save_total_limit=2,
    push_to_hub=False,
)

# tokenize_words pads every example to 512 tokens, so plain tensor stacking suffices.
data_collator = DefaultDataCollator(return_tensors="pt")

In [None]:
import numpy as np

metric = evaluate.load("seqeval")

def compute_metrics(p):
  """Entity-level seqeval scores for a Trainer evaluation pass.

  Args:
    p: ``(logits, label_ids)`` pair from the Trainer. Positions whose gold label
       is -100 (ignored tokens) are dropped before scoring.

  Returns:
    dict with overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
    The Trainer reports them with an ``eval_`` prefix, hence ``eval_f1`` in the
    training arguments.
  """
  logits, label_ids = p
  pred_ids = np.argmax(logits, axis=-1)

  pred_tags, gold_tags = [], []
  for seq_preds, seq_labels in zip(pred_ids, label_ids):
    kept = [(pid, lid) for pid, lid in zip(seq_preds, seq_labels) if lid != -100]
    pred_tags.append([id2label[pid] for pid, _ in kept])
    gold_tags.append([id2label[lid] for _, lid in kept])

  scores = metric.compute(predictions=pred_tags, references=gold_tags)
  return {
      "precision": scores["overall_precision"],
      "recall": scores["overall_recall"],
      "f1": scores["overall_f1"],
      "accuracy": scores["overall_accuracy"],
  }

In [None]:
from transformers import EarlyStoppingCallback

# Stop after 5 consecutive evaluations with no improvement (threshold 0.0)
# in the best-model metric configured in the training arguments.
early_stop = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.0)

In [None]:
# Assemble the Trainer from the pieces defined above. `early_stop` must be passed
# explicitly: constructing the callback does not attach it to any Trainer, and without
# it training always runs every one of the configured epochs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stop],
)

trainer.train()

In [None]:
# NOTE(review): the inference section loads "adamadam111/bros-funsd-finetuned", but the URL
# below names "bros-finetuned-funsd". Confirm which repo id actually exists on the Hub.
'''
Model available on huggingface hub: https://huggingface.co/adamadam111/bros-finetuned-funsd
'''

## Inference

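Upload a scanned form image named `test.jpg` to the working directory before running the cells below. Words and boxes are extracted with Tesseract OCR (`pytesseract`) and then tagged by the fine-tuned model.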

In [None]:
!pip install pytesseract

In [None]:
import pytesseract, cv2
# Imported here so this section runs without first executing the training cells.
from transformers import AutoTokenizer, BrosForTokenClassification

# NOTE(review): the cell after training links to a repo named "bros-finetuned-funsd".
# Confirm which of the two ids is the one that exists on the Hub.
FINETUNED_REPO = "adamadam111/bros-funsd-finetuned"
tokenizer = AutoTokenizer.from_pretrained(FINETUNED_REPO, do_lower_case=True)
model = BrosForTokenClassification.from_pretrained(FINETUNED_REPO)

image_path = "test.jpg"
img = cv2.imread(image_path)
# cv2.imread returns None (instead of raising) for a missing or unreadable file.
if img is None:
  raise FileNotFoundError(f"Could not read {image_path!r}; upload a test form image first.")

# Word-level OCR as a dict of parallel lists ("text", "left", "top", "width", "height", ...).
ocr_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)

In [None]:
# Keep non-empty OCR tokens and turn Tesseract's (left, top, width, height)
# into pixel corner boxes [x0, y0, x1, y1].
words, bboxes = [], []
ocr_columns = zip(
  ocr_data["text"], ocr_data["left"], ocr_data["top"], ocr_data["width"], ocr_data["height"]
)
for raw, left, top, box_w, box_h in ocr_columns:
  token = raw.strip()
  if not token:
    continue
  words.append(token)
  bboxes.append([left, top, left + box_w, top + box_h])


In [None]:
# OpenCV images are (rows, cols, channels) arrays, i.e. height first.
height, width = img.shape[0], img.shape[1]
boxes_1000 = [normalize_bbox(pixel_box, width, height) for pixel_box in bboxes]


In [None]:
enc = tokenizer(
  words,
  is_split_into_words=True,
  truncation=True,
  padding="max_length",
  max_length=512,
  return_tensors="pt",
)

# One box per token: special/padding tokens (word id None) get a zero box;
# each sub-word piece inherits the box of the OCR word it came from.
word_ids = enc.word_ids()
aligned_boxes = [[0, 0, 0, 0] if wi is None else boxes_1000[wi] for wi in word_ids]

# Batch dimension of 1 to match the tokenizer's tensors.
enc["bbox"] = torch.tensor([aligned_boxes])


In [None]:
def unnormalize_box(box_1000, w, h):
  """Map a box on the 0-1000 grid back to pixel coordinates of a ``w`` x ``h`` page.

  Inverse of normalize_bbox up to integer truncation.
  """
  extents = (w, h, w, h)
  return [int(box_1000[i] * extents[i] / 1000) for i in range(4)]

def iob_to_label(tag):
  """Collapse an IOB tag into the lowercase entity name used for colouring.

  "B-QUESTION" / "I-QUESTION" -> "question". The IOB outside tag "O" (and an
  empty entity such as "B-") -> "other", so it matches label2color's "other" key.
  Previously "O" fell through to "o" and was drawn as an unknown label.
  """
  if tag == "O":
    return "other"
  core = tag[2:] if tag.startswith(("B-", "I-")) else tag
  return core.lower() if core else "other"

# Outline/text colour per entity name (keys are the outputs of iob_to_label).
label2color = dict(
  question="blue",
  answer="green",
  header="orange",
  other="violet",
)

In [None]:
# Single forward pass; gradients are not needed for inference.
with torch.no_grad():
  outputs = model(**enc)
pred_ids = outputs.logits.argmax(dim=-1).squeeze(0).tolist()
word_ids = enc.word_ids()

# Keep one prediction per OCR word: its first sub-word token. Special/padding tokens
# (word id None) and continuation pieces are skipped, so no zero boxes are drawn.
# Tag names come from the loaded model's config, so this cell does not depend on the
# dataset cells. Assumes the checkpoint was saved with the FUNSD id2label, as the
# training cell passes it — TODO confirm for the hub model.
model_id2label = model.config.id2label
true_preds, true_boxes = [], []
prev_word = None
for pid, wid, box in zip(pred_ids, word_ids, aligned_boxes):
  if wid is None or wid == prev_word:
    continue
  true_preds.append(model_id2label[pid])
  true_boxes.append(unnormalize_box(box, width, height))
  prev_word = wid

# cv2.imread yields BGR channel order; convert so PIL shows the page in its real colours.
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img_pil)
font = ImageFont.load_default()

for pred, box in zip(true_preds, true_boxes):
  label = iob_to_label(pred)
  color = label2color.get(label, "red")  # red flags any tag without a colour entry
  draw.rectangle(box, outline=color, width=2)
  draw.text((box[0] + 3, box[1] - 10), text=label, fill=color, font=font)

img_pil
# Optional: save the annotated page with img_pil.save('/content/test_ann.jpg')