# Summary Statistics for interpretability techniques

## GITHUB

In [1]:
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

Cloning into 'Interpretability-for-VRDU-models'...
remote: Enumerating objects: 1607, done.[K
remote: Counting objects: 100% (135/135), done.[K
remote: Compressing objects: 100% (93/93), done.[K
remote: Total 1607 (delta 77), reused 88 (delta 41), pack-reused 1472 (from 1)[K
Receiving objects: 100% (1607/1607), 26.39 MiB | 27.41 MiB/s, done.
Resolving deltas: 100% (1016/1016), done.


In [None]:
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [2]:
%cd /content/Interpretability-for-VRDU-models

/content/Interpretability-for-VRDU-models


In [3]:
!pip install -r requirements.txt

Collecting jupyter (from -r requirements.txt (line 2))
  Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)
Collecting seqeval (from -r requirements.txt (line 8))
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting evaluate (from -r requirements.txt (line 10))
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting lime (from -r requirements.txt (line 12))
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jupyterlab (from jupyter->-r requirements.txt (line 2))
  Downloading jupyterlab-4.4.5-py3-none-any.whl.metadata (16 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->-r requirements

## Libraries

In [4]:
#code
from datasets import load_from_disk
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor, BrosModel, AutoTokenizer, BrosPreTrainedModel, AutoConfig
import sys
import importlib
def reload_modules():
    for module in list(sys.modules.keys()):
        if module.startswith('vrdu_utils') or module.startswith('Classification_Explain') or module.startswith('lime') or module.startswith('Eval'):
            print(f"Reloading module: {module}")
            importlib.reload(sys.modules[module])

reload_modules()

from vrdu_utils.encoders import *
from Classification_Explain.lime import *
from vrdu_utils.utils import *
import torch
from Eval.eval_suite import *
from Eval.fidelity import *
from Classification_Explain.shap import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import warnings
from transformers import logging as hf_logging

warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="transformers.modeling_utils",   # the module that emits the msg
)
hf_logging.set_verbosity_error()

## Data + Model Setup

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
!cp -r /content/drive/MyDrive/THESIS/rvl_cdip_financial_subset /content
rvl = load_from_disk('/content/rvl_cdip_financial_subset')
dataset_split = rvl.train_test_split(test_size=0.2, seed=42)
val = dataset_split['test']
val_ds = DocSampleDataset(val)

In [7]:
from torch import nn
class BrosForDocumentClassification(BrosPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bros = BrosModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        outputs = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Use the [CLS] token's representation (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]  # shape: (batch_sizef, hidden_size)

        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

In [198]:
label2id = {"form": 0, "file folder": 1, "budget": 2, "invoice": 3, "questionnaire": 4}
id2label = {0: "form", 1: "file folder", 2: "budget", 3: "invoice", 4: "questionnaire"}

In [199]:

LLMV3 = LayoutLMv3ForSequenceClassification.from_pretrained("adamadam111/layoutlmv3-docclass-finetuned-frz",
                                                            num_labels=5,
                                                            id2label={0: "form", 1: "file folder", 2: "budget", 3: "invoice", 4: "questionnaire"},
                                                            label2id={"form": 0, "file folder": 1, "budget": 2, "invoice": 3, "questionnaire": 4})
LLMV3_proc = AutoProcessor.from_pretrained("adamadam111/layoutlmv3-docclass-finetuned-frz", apply_ocr=False)

LLMV3.to(device)

bros_config = AutoConfig.from_pretrained(
    "adamadam111/bros-docclass-finetuned-frz",
    num_labels=5,
    id2label={0: "form", 1: "file folder", 2: "budget", 3: "invoice", 4: "questionnaire"},
    label2id={"form": 0, "file folder": 1, "budget": 2, "invoice": 3, "questionnaire": 4}
)

BROS = BrosForDocumentClassification.from_pretrained(
    "adamadam111/bros-docclass-finetuned-frz",
    config=bros_config
)
BROS_t = AutoTokenizer.from_pretrained("adamadam111/bros-docclass-finetuned-frz",do_lower_case=True)
LLMV3.to(device)
BROS.to(device)

BrosForDocumentClassification(
  (bros): BrosModel(
    (embeddings): BrosTextEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (bbox_embeddings): BrosBboxEmbeddings(
      (bbox_sinusoid_emb): BrosPositionalEmbedding2D(
        (x_pos_emb): BrosPositionalEmbedding1D()
        (y_pos_emb): BrosPositionalEmbedding1D()
      )
      (bbox_projection): Linear(in_features=192, out_features=64, bias=False)
    )
    (encoder): BrosEncoder(
      (layer): ModuleList(
        (0-11): 12 x BrosLayer(
          (attention): BrosAttention(
            (self): BrosSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): 

In [200]:
LLMV3_encode = make_layoutlmv3_encoder(LLMV3_proc)
BROS_encode = make_bros_encoder(BROS_t)

@torch.no_grad()
def pred(model, encode_fn, samples):
  enc = encode_fn(samples, device)
  try:
    logits = model(**encode_fn(samples)).logits
  except:
    logits_loss_dict = model(**encode_fn(samples, device))
    logits = logits_loss_dict["logits"]
  return torch.softmax(logits, dim=1).cpu().numpy()



## Finding Samples

In [202]:
pred(LLMV3, LLMV3_encode,[val_ds[5][0]]).argmax()

np.int64(3)

In [203]:
samples = {
    "form": [],
    "file folder": [],
    "budget": [],
    "invoice": [],
    "questionnaire": []
}

classes = {
    'form': [(s,id) for s, id in val_ds if s.label == label2id['form']],
    'invoice': [(s,id) for s, id in val_ds if s.label == label2id['invoice']],
    'budget': [(s,id) for s, id in val_ds if s.label == label2id['budget']],
    'file folder': [(s,id) for s, id in val_ds if s.label == label2id['file folder']],
    'questionnaire': [(s,id) for s, id in val_ds if s.label == label2id['questionnaire']]
}
wrong_ids = {
    "form": [],
    "invoice": [],
    "budget": [],
    "file folder": [],
    "questionnaire": []
}

In [211]:
wrong_ids = {
    "form": [],
    "invoice": [],
    "budget": [],
    "file folder": [],
    "questionnaire": []
}

In [212]:
# 5 Samples from each class
# for example in forms:

def find_incorrect(key):
  for doc, id in classes[key]:
    bros_p = pred(BROS, BROS_encode, [doc])
    llmv3_p = pred(LLMV3, LLMV3_encode, [doc])
    if bros_p.argmax() != label2id[key] or llmv3_p.argmax() != label2id[key]:
      wrong_ids[key].append(id)

find_incorrect('form')
find_incorrect('invoice')
find_incorrect('budget')
find_incorrect('file folder')
find_incorrect('questionnaire')

In [215]:
wrong_ids['invoice']
# print(len(wrong_ids['invoice']))
# print(len(classes['invoice']))

[8,
 50,
 90,
 91,
 104,
 113,
 121,
 123,
 128,
 152,
 154,
 178,
 211,
 231,
 243,
 257,
 309,
 315,
 330,
 336,
 337,
 377,
 407,
 451,
 468,
 474,
 479,
 481,
 496,
 542,
 555,
 620,
 624,
 754,
 757,
 761,
 767,
 786,
 804,
 821,
 843,
 850,
 887,
 938,
 993,
 995]

In [262]:
t = LimeTextExplainer(LLMV3, LLMV3_encode, mask_token = LLMV3_proc.tokenizer.mask_token, kernel_width_factor = 0.75, labels = [3])
temp = t.explain(invoices[2][0], num_samples=500)

Begging EXPLAINER
Begging EXPLAIN_INSTANCE
MADE PREDICT


[LIME] - Text: 100%|██████████| 32/32 [00:14<00:00,  2.24it/s]


In [271]:
forms = [classes['form'][5], classes['form'][10], classes['form'][56], classes['form'][150], classes['form'][77]]
invoices = [classes['invoice'][3], classes['invoice'][12], classes['invoice'][29], classes['invoice'][40], classes['invoice'][22]]
budgets = [classes['budget'][1], classes['budget'][3], classes['budget'][13], classes['budget'][33], classes['budget'][21]]
file_folders = [classes['file folder'][32], classes['file folder'][2], classes['file folder'][10], classes['file folder'][95], classes['file folder'][29]]
questionnaires = [classes['questionnaire'][5], classes['questionnaire'][1], classes['questionnaire'][35], classes['questionnaire'][36], classes['questionnaire'][70]]

def check_incorrect(key, docs):
  for idx, (doc, id) in enumerate(docs):
    if id in wrong_ids[key]:
      print(idx)
      return False
  return True

assert check_incorrect('form', forms)
assert check_incorrect('invoice', invoices)
assert check_incorrect('budget', budgets)
assert check_incorrect('file folder', file_folders)
assert check_incorrect('questionnaire', questionnaires)

In [229]:
final_samples = {
    "form": forms,
    "file folder": invoices,
    "budget": budgets,
    "invoice": file_folders,
    "questionnaire": questionnaires
}

## Obtain Explanations

In [254]:
from tqdm import tqdm
def obtain_explainers(key, label):
  text_explainers = {
      'BROS lime' : LimeTextExplainer(BROS, BROS_encode, mask_token=BROS_t.mask_token, kernel_width_factor = 0.75, labels=[label]),
      'LLMV3 lime' : LimeTextExplainer(LLMV3, LLMV3_encode, mask_token = LLMV3_proc.tokenizer.mask_token, kernel_width_factor = 0.75, labels = [label]),
      'BROS shap' : SHAPTextExplainer(BROS, BROS_encode, BROS_t, mask_token=BROS_t.mask_token, device=device),
      'LLMV3 shap' : SHAPTextExplainer(LLMV3, LLMV3_encode, LLMV3_proc.tokenizer,mask_token=LLMV3_proc.tokenizer.mask_token, device=device)
  }

  layout_explainers = {
      'BROS lime' : LimeLayoutExplainer(BROS, BROS_encode, mask_token=BROS_t.mask_token, kernel_width_factor = 0.75, labels=[label]),
      'LLMV3 lime' : LimeLayoutExplainer(LLMV3, LLMV3_encode, mask_token = LLMV3_proc.tokenizer.mask_token, kernel_width_factor = 0.75, labels = [label]),
      'BROS shap' : SHAPLayoutExplainer(BROS, BROS_encode, device=device),
      'LLMV3 shap' : SHAPLayoutExplainer(LLMV3, LLMV3_encode, device=device)
  }

  vision_explainers = {
      'LLMV3 lime' : LimeVisionExplainer(LLMV3, LLMV3_encode, label = [label], device=device),
      'LLMV3 shap' : SHAPVisionExplainer(LLMV3, LLMV3_encode, device=device, class_idx=4, mask_value='blur(64,64)')
  }

  return text_explainers, layout_explainers, vision_explainers

def obtain_explanations(text, layout, vision, key):
  explanations = {}
  print(text)
  for sample in tqdm(final_samples[key], desc=f"Obtaining explanations for {key}s"):
    explanations.update({f'[T] - {key}' : explainer.explain(sample[0], nsamples=2000) for key, explainer in text.items() if 'shap' in key})
    explanations.update({f'[L] - {key}' : explainer.explain(sample[0], nsamples=2000) for key, explainer in layout.items() if 'shap' in key})
    explanations.update({f'[V] - {key}' : explainer.explain(sample[0], nsamples=1000) for key, explainer in vision.items() if 'shap' in key})
    explanations.update({f'[T] - {key}' : explainer.explain(sample[0], num_samples=2000) for key, explainer in text.items() if 'lime' in key})
    explanations.update({f'[L] - {key}' : explainer.explain(sample[0], num_samples=2000) for key, explainer in layout.items() if 'lime' in key})
    explanations.update({f'[V] - {key}' : explainer.explain(sample[0], num_samples=1000) for key, explainer in vision.items() if 'lime' in key})

  return explanations



In [263]:
print(len(class_explanations))

3


In [265]:
import pickle
explainers = {key: obtain_explainers(key, label2id[key]) for key, _ in final_samples.items()}
# class_explanations = {}
# class_explanations.update({'form_explanations' : obtain_explanations(explainers['form'][0], explainers['form'][1], explainers['form'][2], 'form')})
# class_explanations.update({'file_folder_explanations' : obtain_explanations(explainers['file folder'][0], explainers['file folder'][1], explainers['file folder'][2], 'file folder')})
# class_explanations.update({'budget_explanations' : obtain_explanations(explainers['budget'][0], explainers['budget'][1], explainers['budget'][2], 'budget')})
# class_explanations.update({'invoice_explanations' : obtain_explanations(explainers['invoice'][0], explainers['invoice'][1], explainers['invoice'][2], 'invoice')})
# class_explanations.update({'questionnaire_explanations' : obtain_explanations(explainers['questionnaire'][0], explainers['questionnaire'][1], explainers['questionnaire'][2], 'questionnaire')})


with open('explanations.pkl', 'wb') as fp:
    pickle.dump(class_explanations, fp)
    print('Dictionary saved successfully to file')

Dictionary saved successfully to file


In [266]:
!cp explanations.pkl /content/drive/MyDrive/THESIS/

In [252]:
import pickle
temp = {'hello': print('hello')}
with open('person_data.pkl', 'wb') as fp:
    pickle.dump(temp, fp)
    print('dictionary saved successfully to file')

hello
dictionary saved successfully to file


In [None]:
form_explanations = obtain_explanations(explainers['file folder'][0], explainers['file folder'][1], explainers['file folder'][2], 'file folder')
form_explanations = obtain_explanations(explainers['budget'][0], explainers['budget'][1], explainers['budget'][2], 'budget')
form_explanations = obtain_explanations(explainers['invoice'][0], explainers['invoice'][1], explainers['invoice'][2], 'form')
form_explanations = obtain_explanations(explainers['questionnaire'][0], explainers['questionnaire'][1], explainers['questionnaire'][2], 'questionnaire')

In [268]:
ttt = {}
with (open("explanations.pkl", "rb")) as openfile:
    while True:
        try:
            ttt.update(pickle.load(openfile))
        except EOFError:
            break

In [269]:
print(ttt)

{'form_explanations': {'[T] - BROS shap': .values =
array([[[ 7.68571482e-01, -8.26928900e-03, -4.64106961e-02,
         -1.44490416e+00, -3.99228383e-01],
        [ 4.24439805e-02, -9.66768942e-02, -2.84873952e-02,
          7.93326172e-03, -7.49220008e-02],
        [-2.79477880e-01, -1.86295618e-02,  7.60602580e-01,
          8.73843588e-01,  7.27536678e-02],
        [ 4.37782292e-02, -1.03070590e-01, -8.14246915e-02,
         -9.96775993e-02, -8.69527811e-02],
        [ 4.48228389e-03, -2.01650479e-02,  2.19438640e-01,
          1.13336741e-01,  1.08702996e-01],
        [ 2.70457686e-02, -1.69764707e-01, -4.45195328e-02,
         -7.91337231e-02, -2.19386930e-01],
        [ 7.32130268e-02,  2.93779820e-02, -1.60636143e-02,
         -1.20847327e-02, -1.70598724e-01],
        [ 2.11023405e-02, -7.70054514e-02, -4.43540595e-02,
         -7.10025870e-03, -5.02473658e-03],
        [ 8.97032158e-03, -2.29093161e-04,  9.80796380e-02,
          2.48255307e-02,  1.29238047e-01],
        [ 7.