In [None]:
# Clone the project repository; after the `%cd` into it, the vrdu_utils / Explain packages
# imported below are presumably importable from its root — TODO confirm.
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

## Config

In [None]:
# Debug aid: show the kernel's current working directory.
!pwd

In [None]:
# Work from the repository root: the relative `requirements.txt` installed next is resolved from here.
%cd /content/Interpretability-for-VRDU-models

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install -U datasets

## Code

In [None]:
from datasets import load_from_disk
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor
import sys
import importlib
def reload_modules(prefixes=("vrdu_utils", "Explain")):
    """Reload already-imported project modules so repo edits apply without a kernel restart.

    A module is reloaded when its name equals one of ``prefixes`` or is a dotted
    submodule of one (e.g. ``vrdu_utils.encoders``). Unrelated modules that merely
    share the leading characters (e.g. ``ExplainX``) are left alone. Deeper submodules
    are reloaded before their parent packages, so a package ``__init__`` that
    re-exports from a submodule picks up the fresh definitions.

    Args:
        prefixes: top-level package name(s) to reload; a single string is accepted.

    Returns:
        list[str]: names of the reloaded modules, in reload order.
    """
    if isinstance(prefixes, str):
        prefixes = (prefixes,)

    def _matches(name):
        # Match the package itself or a dotted submodule, never a mere name prefix.
        return any(name == p or name.startswith(p + ".") for p in prefixes)

    targets = [name for name, mod in list(sys.modules.items())
               if mod is not None and _matches(name)]
    # Deepest modules first; Python's sort stays stable even with reverse=True.
    targets.sort(key=lambda name: name.count("."), reverse=True)

    reloaded = []
    for name in targets:
        module = sys.modules.get(name)
        if module is None:  # removed or placeholder entry; nothing to reload
            continue
        print(f"Reloading module: {name}")
        importlib.reload(module)
        reloaded.append(name)
    return reloaded

reload_modules()

from vrdu_utils.encoders import *
from Explain.lime import *
from vrdu_utils.utils import *
import torch
from torch.utils.data import DataLoader
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor, BrosPreTrainedModel, BrosModel, AutoConfig, AutoTokenizer
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

import warnings
from transformers import logging as hf_logging

# Quiet the notebook: drop FutureWarnings raised from transformers.modeling_utils
# and only let the Hugging Face logger emit errors.
warnings.filterwarnings(
    action="ignore",
    module="transformers.modeling_utils",
    category=FutureWarning,
)
hf_logging.set_verbosity_error()

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset copy below reads from /content/drive/MyDrive/THESIS.
drive.mount('/content/drive')

In [None]:
# Copy the RVL-CDIP financial subset from Drive to local disk; it is loaded from
# /content/rvl_cdip_financial_subset further down.
!cp -r /content/drive/MyDrive/THESIS/rvl_cdip_financial_subset /content

In [None]:
from torch import nn
class BrosForDocumentClassification(BrosPreTrainedModel):
    """Document-level classifier on top of a BROS encoder.

    The hidden state of the first token (position 0) is passed through dropout and a
    linear layer to produce ``config.num_labels`` logits. When ``labels`` are given,
    a cross-entropy loss is returned as well.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bros = BrosModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # post_init() is the current transformers hook for final setup; it runs the
        # weight initialisation that the legacy init_weights() call used to trigger.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        """Classify a batch of documents.

        Args:
            input_ids, bbox, attention_mask, token_type_ids: forwarded unchanged to
                the BROS encoder.
            labels: optional class ids; when given, a cross-entropy loss is computed.
            **kwargs: any other encoder inputs are accepted and ignored.

        Returns:
            dict with ``"loss"`` (None when ``labels`` is None) and ``"logits"`` of
            shape (batch_size, num_labels).
        """
        outputs = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Sequence-level representation: hidden state of the first token.
        cls_output = outputs.last_hidden_state[:, 0, :]  # (batch_size, hidden_size)
        logits = self.classifier(self.dropout(cls_output))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

In [None]:
# Label schema shared by both fine-tuned checkpoints. It is defined once so the two
# models cannot drift apart.
ID2LABEL = {0: "form", 1: "invoice", 2: "budget", 3: "file folder", 4: "questionnaire"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}

BROS_CKPT = "adamadam111/bros-docclass-finetuned-frz"
LLMV3_CKPT = "adamadam111/layoutlmv3-docclass-finetuned-frz"

bros_config = AutoConfig.from_pretrained(
    BROS_CKPT,
    num_labels=len(ID2LABEL),
    id2label=ID2LABEL,
    label2id=LABEL2ID,
)

BROS = BrosForDocumentClassification.from_pretrained(
    BROS_CKPT,
    config=bros_config,
)
BROS_t = AutoTokenizer.from_pretrained(BROS_CKPT, do_lower_case=True)

LLMV3 = LayoutLMv3ForSequenceClassification.from_pretrained(
    LLMV3_CKPT,
    num_labels=len(ID2LABEL),
    id2label=ID2LABEL,
    label2id=LABEL2ID,
)
# apply_ocr=False: words/boxes presumably come from the dataset rather than the
# processor's own OCR — TODO confirm against make_layoutlmv3_encoder.
LLMV3_p = AutoProcessor.from_pretrained(LLMV3_CKPT, apply_ocr=False)



In [None]:
# Classification-mode encoders (ner=False), one per model's tokenizer/processor.
BROS_encode = make_bros_encoder(BROS_t, ner=False)
LLMV3_encode = make_layoutlmv3_encoder(LLMV3_p, ner=False)

In [None]:
ds = load_from_disk(
    "/content/rvl_cdip_financial_subset"
)
# Fixed seed, so the same 20 % held-out split is produced on every run; the held-out
# part serves as the validation set in this notebook.
split = ds.train_test_split(
    test_size=0.2,
    seed=42,
    shuffle=True,
)
val = split["test"]

In [None]:
def make_dl(ds, bs=8):
    """Wrap ``ds`` in a DataLoader whose batches are plain lists of dataset items.

    The identity collate function skips tensor stacking, so non-tensor items pass
    through unchanged.

    Args:
        ds: any map-style dataset.
        bs: batch size.
    """
    def keep_items(items):
        return items

    return DataLoader(ds, collate_fn=keep_items, batch_size=bs)

# Position (in the validation split) of a sample excluded from the sweep.
# NOTE(review): the reason for excluding it is not recorded — TODO document it.
EXCLUDED_VAL_INDEX = 715

# Rebuild from split["test"] instead of re-filtering `val`, so re-running this cell
# removes exactly one sample rather than one more on every run.
val = split["test"].filter(
    lambda _example, idx: idx != EXCLUDED_VAL_INDEX,
    with_indices=True,
)
val_ds = DocSampleDataset(val)
dl = make_dl(val_ds)

# Encoder for each model tag; keys match the `models` dict used by the sweep.
encoders = {
    "lmv3": LLMV3_encode,
    "bros": BROS_encode,
}

In [None]:
# Presumably removes the sample at position 159 of val_ds (assuming `pop` is positional
# like list.pop — TODO confirm for DocSampleDataset).
# NOTE(review): this mutates val_ds in place, so re-running the cell presumably drops a
# different sample each time. `dl` was built from this same val_ds object in the previous
# cell, so the sweep below presumably sees the shortened dataset — TODO confirm.
val_ds.pop(159)

In [None]:
from tqdm import tqdm

# Move the models to the same `device` the encoders receive, so the sweep also runs
# on CPU-only runtimes (a hard-coded "cuda" would mismatch CPU inputs).
models = {
    "lmv3": LLMV3.eval().to(device),
    "bros": BROS.eval().to(device),
}


records = []

@torch.no_grad()
def sweep(dl, ds_name):
    """Run every model in `models` over `dl` and append one record per (sample, model) to `records`.

    Inputs are encoded with the shared `encoders` dict (built once with ``ner=False``).
    These are the same encoders `quick_check` uses, so no default-argument encoders are
    rebuilt for every batch.

    Args:
        dl: iterable of batches. Each batch is a list of items whose ``[0]`` is a sample
            with a ``.label`` and whose ``[1]`` is stored as ``file_id``.
        ds_name: value stored in each record's ``ds`` field.
    """
    for batch in tqdm(dl, desc="finding samples"):
        samples = [item[0] for item in batch]
        indices = [item[1] for item in batch]

        for mname, model in models.items():
            outputs = model(**encoders[mname](samples, device))
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits  # (B, n_cls)
            probs = logits.softmax(-1).cpu()
            for file_id, samp, p in zip(indices, samples, probs):
                records.append(dict(
                    ds=ds_name,
                    file_id=file_id,
                    true_label=samp.label,
                    model=mname,
                    pred=p.argmax().item(),
                    conf=p.max().item(),
                ))


sweep(dl, "rvl_cdip_financial_subset")

In [None]:
import pandas as pd
df = pd.DataFrame(records)
wide = (df.pivot(index=["ds", "file_id", "true_label"],
                 columns="model", values=["pred", "conf"])
          .reset_index())
# After pivot + reset_index the columns are (value, model) tuples, with model == ""
# for the former index columns. Flatten them by name ("pred_bros", ...) instead of
# assigning a hard-coded list that silently depends on pivot's column order.
wide.columns = [f"{value}_{model}" if model else value
                for value, model in wide.columns]
wide = wide.rename(columns={"true_label": "true"})

agree = wide.query("pred_bros == true and pred_lmv3 == true")
disagree = wide.query("pred_bros != pred_lmv3")

CONF_COLS = ["conf_bros", "conf_lmv3"]

# Both models correct: pick the row whose less-confident model is most confident.
# None when no such row exists (an empty selection would make .iloc[0] raise).
sample_agree = None
if not agree.empty:
    sample_agree = (agree
                    .assign(margin=lambda d: d[CONF_COLS].min(axis=1))
                    .sort_values("margin", ascending=False)
                    .iloc[0])

# Models disagree: pick the row with the highest mean confidence, or None if there is none.
sample_disagree = None
if not disagree.empty:
    sample_disagree = (disagree
                       .assign(avg_conf=lambda d: d[CONF_COLS].mean(axis=1))
                       .sort_values("avg_conf", ascending=False)
                       .iloc[0])
print(sample_agree)
print(sample_disagree)

In [None]:
# Preview the page image of one validation sample.
# NOTE(review): index 715 was filtered out of `val` and `val_ds.pop(159)` may shift
# positions further, so val_ds[715] is presumably NOT the document at position 715 of
# the original split — TODO confirm which document this is meant to show.
val_ds[715][0].image

In [None]:
def predict_one(model, encode_fn, sample):
    """Classify a single document sample.

    Args:
        model: classifier whose output is either a dict with a "logits" entry or an
            object with a ``.logits`` attribute.
        encode_fn: callable ``(samples, device) -> dict`` of model keyword inputs,
            e.g. an entry of ``encoders``.
        sample: one document sample; it is wrapped in a list to form a batch of one.

    Returns:
        (cls_id, conf): index of the most probable class and its softmax probability.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(**encode_fn([sample], device))
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits  # (1, n_cls) for the single-sample batch
        prob = logits.softmax(-1)
    cls_id = prob.argmax().item()
    conf = prob.max().item()
    return cls_id, conf


ID2NAME = {0: "form", 1: "invoice", 2: "budget", 3: "file folder", 4: "questionnaire"}


def quick_check(sample, name="page"):
    """Print each model's predicted class name and confidence for ``sample``.

    ``name`` is a short tag printed at the start of each line.
    """
    for tag, model in models.items():
        # Unpack into `conf`; the old code bound `con`, so the print raised NameError.
        cls_id, conf = predict_one(model, encoders[tag], sample)
        print(f"{name:>8} | {tag:5} → {ID2NAME[cls_id]:12s}  (p = {conf:.4f})")

In [None]:
# Build a fresh DocSampleDataset over `val` (separate from val_ds) and fetch the two
# showcase samples. This rebinds `agree` / `disagree`, which earlier held DataFrames.
# NOTE(review): 108 / 100 look hand-picked, presumably from the sample_agree /
# sample_disagree rows printed above — TODO confirm they index the same documents.
val_docsam = DocSampleDataset(val)
agree, disagree = val_docsam[108], val_docsam[100]

In [None]:
# Compare both models' predictions on the chosen samples.
quick_check(agree[0], name="agree")
quick_check(disagree[0], name="clash")