In [None]:
"""BROS"""

In [1]:
!pip install transformers torch datasets evaluate seqeval

Collecting evaluate
  Downloading evaluate-0.4.4-py3-none-any.whl.metadata (9.5 kB)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70

In [27]:
import torch
import torch.nn as nn
from transformers import BrosPreTrainedModel, BrosModel, AutoConfig, AutoTokenizer
from PIL import Image,ImageDraw, ImageFont
from datasets import load_dataset, load_from_disk
import pandas as pd
import evaluate
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

In [3]:
# mount drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!cp -r /content/drive/MyDrive/THESIS/rvl_cdip_financial_subset /content

In [5]:
class BrosForDocumentClassification(BrosPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bros = BrosModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        outputs = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Use the [CLS] token's representation (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]  # shape: (batch_size, hidden_size)

        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

In [6]:
config = AutoConfig.from_pretrained(
    "naver-clova-ocr/bros-base-uncased",
    num_labels=5,
    id2label={0: "form", 1: "invoice", 2: "budget", 3: "file folder", 4: "questionnaire"},
    label2id={"form": 0, "invoice": 1, "budget": 2, "file folder": 3, "questionnaire": 4}
)

model = BrosForDocumentClassification.from_pretrained(
    "naver-clova-ocr/bros-base-uncased",
    config=config
)
tokenizer = AutoTokenizer.from_pretrained("naver-clova-ocr/bros-base-uncased",do_lower_case=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

config.json:   0%|          | 0.00/535 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of BrosForDocumentClassification were not initialized from the model checkpoint at naver-clova-ocr/bros-base-uncased and are newly initialized: ['bros.bbox_embeddings.bbox_projection.weight', 'bros.bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq', 'bros.bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


vocab.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

BrosForDocumentClassification(
  (bros): BrosModel(
    (embeddings): BrosTextEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (bbox_embeddings): BrosBboxEmbeddings(
      (bbox_sinusoid_emb): BrosPositionalEmbedding2D(
        (x_pos_emb): BrosPositionalEmbedding1D()
        (y_pos_emb): BrosPositionalEmbedding1D()
      )
      (bbox_projection): Linear(in_features=192, out_features=64, bias=False)
    )
    (encoder): BrosEncoder(
      (layer): ModuleList(
        (0-11): 12 x BrosLayer(
          (attention): BrosAttention(
            (self): BrosSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): 

In [38]:
def normalize_bbox(bbox, width, height):
    return [
        int(1000 * (bbox[0] / width)),
        int(1000 * (bbox[1] / height)),
        int(1000 * (bbox[2] / width)),
        int(1000 * (bbox[3] / height)),
    ]
def encode(batch):
    encodings = tokenizer(
        batch["words"],
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )

    batch_aligned_bboxes = []

    for idx, (bboxes, image) in enumerate(zip(batch["bboxes"], batch["image"])):
        # Get image size for normalization
        if isinstance(image, Image.Image):
            width, height = image.size
        else:
            image = Image.open(image)
            width, height = image.size

        # Normalize bounding boxes
        normalized_bboxes = [normalize_bbox(bbox, width, height) for bbox in bboxes]

        # Align bboxes to subword tokens
        aligned_bboxes = []
        word_ids = encodings.word_ids(batch_index=idx)

        for word_id in word_ids:
            if word_id is None:
                aligned_bboxes.append([0, 0, 0, 0])
            else:
                aligned_bboxes.append(normalized_bboxes[word_id])

        batch_aligned_bboxes.append(aligned_bboxes)

    encodings["bbox"] = batch_aligned_bboxes
    encodings["labels"] = batch["label"]  # document-level label per example

    return encodings

def tokenize_words(batch):
  encodings = tokenizer(
    batch["words"],
    is_split_into_words=True,
    truncation=True,
    padding="max_length",
    max_length=512,
    return_tensors="pt"
  )

  batch_normalized_bboxes, encoded_labels = [], []
  for idx, (bboxes, img_path, labels) in enumerate(zip(batch["bboxes"], batch["image"], batch["label"])):
    width, height = Image.open(img_path).size
    normalized_bboxes = [normalize_bbox(bbox, width, height) for bbox in bboxes]

    # Align boxes to sub words
    aligned_boxes, aligned_labels = [], []
    for word_id in encodings.word_ids(batch_index=idx):
      if word_id is None:
        aligned_boxes.append([0, 0, 0, 0])
        aligned_labels.append(-100)
      else:
        aligned_boxes.append(normalized_bboxes[word_id])
        aligned_labels.append(labels[word_id])

    batch_normalized_bboxes.append(aligned_boxes)
    encoded_labels.append(aligned_labels)

  encodings['bbox'] = batch_normalized_bboxes
  encodings['labels'] = encoded_labels

  return encodings

In [39]:
rvl = load_from_disk("/content/rvl_cdip_financial_subset")
dataset_split = rvl.train_test_split(test_size=0.2, seed=42)

train = dataset_split["train"]
val = dataset_split["test"]
print(len(train))
print(len(val))
train_dataset = train.map(encode, batched=True, remove_columns=train.column_names)
val_dataset = val.map(encode, batched=True, remove_columns=val.column_names)

4000
1000


Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [28]:
def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids

    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

In [40]:
from transformers import TrainingArguments, Trainer, DefaultDataCollator
training_args = TrainingArguments(
    output_dir="./bros-docclass-finetuned",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,                 # Enough to converge on 5-class task
    weight_decay=0.01,
    # warmup_ratio=0.1,
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    save_total_limit=2,
    seed=42,
    report_to="none"                     #
)
# Data collator
data_collator = DefaultDataCollator(return_tensors="pt")

In [41]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=DefaultDataCollator(),
    compute_metrics=compute_metrics,
)
trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.5654,0.637008,0.813,0.817836,0.813,0.811303
2,0.5545,0.571972,0.838,0.845242,0.838,0.836524
3,0.38,0.763508,0.82,0.83685,0.82,0.818305
4,0.3074,0.673216,0.845,0.849043,0.845,0.844261
5,0.2474,0.817511,0.835,0.840681,0.835,0.834766
6,0.2001,0.823353,0.85,0.851455,0.85,0.849621
7,0.1411,0.855983,0.85,0.852066,0.85,0.849144
8,0.1287,0.915286,0.846,0.846765,0.846,0.845527
9,0.1255,0.909841,0.849,0.850455,0.849,0.848423
10,0.0825,0.911784,0.844,0.845101,0.844,0.843436




TrainOutput(global_step=5000, training_loss=0.2918694201469421, metrics={'train_runtime': 1715.7985, 'train_samples_per_second': 23.313, 'train_steps_per_second': 2.914, 'total_flos': 1.0526235648e+16, 'train_loss': 0.2918694201469421, 'epoch': 10.0})