# Summary Statistics for Interpretability Techniques

## GitHub setup

In [None]:
# Clone the project repository (presumably the source of the vrdu_utils, Explain and Eval packages imported below).
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
# Pull the latest commits from the same repository.
# NOTE(review): this cell runs before the `%cd` below, so it only succeeds if the current
# directory is already a git checkout of the repository — confirm the intended cell order.
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
# Work from inside the cloned repository; later cells use paths relative to it
# (requirements.txt, explanations_*.pkl).
%cd /content/Interpretability-for-VRDU-models

In [None]:
# Install the dependencies listed in the repository's requirements.txt.
!pip install -r requirements.txt

In [None]:
# Upgrade the Hugging Face `datasets` library (provides load_from_disk used below).
!pip install -U datasets

## Libraries

In [None]:
#code
from datasets import load_from_disk
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor, BrosModel, AutoTokenizer, BrosPreTrainedModel, AutoConfig
import sys
import importlib
def reload_modules(prefixes=('vrdu_utils', 'Explain', 'lime', 'Eval')):
    """Reload every already-imported module whose name starts with one of *prefixes*.

    Lets edits to the project packages take effect in a running notebook
    kernel without restarting it. Modules that have not been imported yet
    are not touched.

    Args:
        prefixes: A module-name prefix or an iterable of prefixes. The default
            is the notebook's original list. ``'lime'`` matches any module
            whose name starts with "lime", which would include a separately
            installed ``lime`` package if one has been imported.
    """
    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    prefixes = tuple(prefixes)
    # Snapshot the names: reloading can import new modules and mutate sys.modules.
    for name in list(sys.modules):
        if not name.startswith(prefixes):
            continue
        module = sys.modules.get(name)
        # sys.modules may hold None placeholders (blocked imports); reload() rejects them.
        if module is None:
            continue
        print(f"Reloading module: {name}")
        importlib.reload(module)

# Pick up edits to already-imported project modules before (re-)importing them.
# On a fresh kernel none of them are in sys.modules yet, so this reloads nothing.
reload_modules()

# NOTE(review): names used later that are not defined anywhere in this notebook
# (DocSampleDataset, DocSample, make_*_encoder, Lime*/SHAP* explainers, slic, Image, plt)
# presumably come from these star imports — confirm.
from vrdu_utils.encoders import *
from Explain.lime import *
from vrdu_utils.utils import *
import torch
from Eval.eval_suite import *
from Eval.fidelity import *
from Explain.shap import *
# Run models on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import warnings
from transformers import logging as hf_logging

# Hide FutureWarnings whose originating module matches transformers.modeling_utils
# (`module` is a regex, so the dots match any character).
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="transformers.modeling_utils",   # the module that emits the msg
)
# Only show transformers log messages at ERROR level or above.
hf_logging.set_verbosity_error()

## Data + Model Setup

In [None]:
# Mount Google Drive (Colab only) so the THESIS folder is reachable.
from google.colab import drive
drive.mount('/content/drive')
# Copy a previously saved explanations pickle into the repository checkout.
!cp /content/drive/MyDrive/THESIS/explanations.pkl /content/Interpretability-for-VRDU-models/

In [None]:
# Copy the saved RVL-CDIP financial subset to local storage and load it.
!cp -r /content/drive/MyDrive/THESIS/rvl_cdip_financial_subset /content
rvl = load_from_disk('/content/rvl_cdip_financial_subset')
# Fixed seed so the same 80/20 split is produced on every run; the 20% "test"
# part serves as the validation set for this analysis.
dataset_split = rvl.train_test_split(test_size=0.2, seed=42)
val = dataset_split['test']
# Project wrapper; later cells index/iterate it as (sample, id) pairs.
val_ds = DocSampleDataset(val)

In [None]:
from torch import nn
class BrosForDocumentClassification(BrosPreTrainedModel):
    """Document-level classifier built on a BROS encoder.

    The final hidden state of the first token goes through dropout and a
    linear layer that emits ``config.num_labels`` logits.

    NOTE: the submodule attribute names (``bros``, ``dropout``,
    ``classifier``) define the checkpoint's state-dict keys, and their
    creation order determines random initialisation — keep both as they are.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Encoder, then the classification head on top of it.
        self.bros = BrosModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        """Classify a batch of documents.

        Extra keyword arguments are accepted and ignored.

        Returns:
            A plain dict with ``"loss"`` (cross-entropy when *labels* is
            given, otherwise ``None``) and ``"logits"`` of shape
            ``(batch, num_labels)``.
        """
        encoder_out = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # First-position hidden state used as the document representation.
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(first_token_state))

        if labels is None:
            loss = None
        else:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        return {"loss": loss, "logits": logits}

In [None]:
label2id = {"form": 0, "file folder": 1, "budget": 2, "invoice": 3, "questionnaire": 4}
# Inverse lookup (class index -> class name), derived so the two tables cannot drift apart.
id2label = {idx: name for name, idx in label2id.items()}

In [None]:

# Fine-tuned LayoutLMv3 document classifier and its processor. apply_ocr=False
# presumably because every sample already carries its own words and boxes.
LLMV3 = LayoutLMv3ForSequenceClassification.from_pretrained(
    "adamadam111/layoutlmv3-docclass-finetuned-frz",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
LLMV3_proc = AutoProcessor.from_pretrained("adamadam111/layoutlmv3-docclass-finetuned-frz", apply_ocr=False)

# Fine-tuned BROS checkpoint loaded into the custom classification head;
# the label mappings are shared with the LayoutLMv3 model above.
bros_config = AutoConfig.from_pretrained(
    "adamadam111/bros-docclass-finetuned-frz",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

BROS = BrosForDocumentClassification.from_pretrained(
    "adamadam111/bros-docclass-finetuned-frz",
    config=bros_config,
)
BROS_t = AutoTokenizer.from_pretrained("adamadam111/bros-docclass-finetuned-frz", do_lower_case=True)

# Move both models to the selected device (once each).
LLMV3.to(device)
BROS.to(device)

In [None]:
# Model-specific encoders built by project helpers; they are called below as
# encode_fn(samples, device) and their result is passed to the model as **kwargs.
LLMV3_encode = make_layoutlmv3_encoder(LLMV3_proc)
BROS_encode = make_bros_encoder(BROS_t)

@torch.no_grad()
def pred(model, encode_fn, samples, target_device=None):
    """Return softmax class probabilities for *samples* as a NumPy array.

    The samples are encoded exactly once, on the target device, and run
    through *model* without gradient tracking.

    Args:
        model: Classifier whose output exposes logits either as a mapping key
            (``out["logits"]`` — plain dicts such as the one returned by
            ``BrosForDocumentClassification``, and Hugging Face ``ModelOutput``
            objects) or as an attribute (``out.logits``).
        encode_fn: Callable ``encode_fn(samples, device)`` returning the
            model's keyword inputs.
        samples: Sequence of samples accepted by *encode_fn*.
        target_device: Device to encode onto; defaults to the notebook-level
            ``device``.

    Returns:
        ``np.ndarray`` of probabilities, softmax taken over dim 1 (classes).
    """
    dev = device if target_device is None else target_device
    outputs = model(**encode_fn(samples, dev))
    # Explicit dispatch instead of a bare try/except, so genuine errors surface.
    logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
    return torch.softmax(logits, dim=1).cpu().numpy()



## Finding Samples

In [None]:
# Sanity check: LayoutLMv3's predicted class index for validation sample 5.
pred(LLMV3, LLMV3_encode,[val_ds[5][0]]).argmax()

In [None]:
samples = {
    "form": [],
    "file folder": [],
    "budget": [],
    "invoice": [],
    "questionnaire": []
}

# Class order used for the per-class containers below.
_CLASS_ORDER = ("form", "invoice", "budget", "file folder", "questionnaire")

# Group validation samples as (doc, doc_id) pairs by gold label, in a single
# pass over val_ds instead of iterating the whole dataset once per class.
classes = {name: [] for name in _CLASS_ORDER}
for doc, doc_id in val_ds:
    for name in _CLASS_ORDER:
        if doc.label == label2id[name]:
            classes[name].append((doc, doc_id))

# Ids of samples misclassified by at least one model, filled by find_incorrect.
wrong_ids = {name: [] for name in _CLASS_ORDER}

In [None]:
# Fresh, independent empty id list per class (re-run this cell to reset).
wrong_ids = {name: [] for name in ("form", "invoice", "budget", "file folder", "questionnaire")}

In [None]:
# Collect, per class, the ids of validation samples that BROS or LayoutLMv3
# misclassifies, so they can be excluded when picking showcase samples.

def find_incorrect(key):
    """Append to ``wrong_ids[key]`` every id in ``classes[key]`` that either model gets wrong.

    A sample counts as wrong when BROS's or LayoutLMv3's arg-max prediction
    differs from ``label2id[key]``. Nothing is cleared first, so calling this
    twice without resetting ``wrong_ids`` records ids twice.
    """
    for doc, doc_id in classes[key]:
        bros_top = pred(BROS, BROS_encode, [doc]).argmax()
        llmv3_top = pred(LLMV3, LLMV3_encode, [doc]).argmax()
        both_right = bros_top == label2id[key] and llmv3_top == label2id[key]
        if not both_right:
            wrong_ids[key].append(doc_id)

for class_name in ('form', 'invoice', 'budget', 'file folder', 'questionnaire'):
    find_incorrect(class_name)

In [None]:
# Display the ids of invoice samples that at least one model misclassified.
wrong_ids['invoice']
# print(len(wrong_ids['invoice']))
# print(len(classes['invoice']))

In [None]:
# Hand-picked showcase samples: five (doc, id) pairs per class, selected by
# position within classes[<class>].
forms = [classes['form'][i] for i in (5, 10, 56, 150, 77)]
invoices = [classes['invoice'][i] for i in (3, 12, 29, 40, 22)]
budgets = [classes['budget'][i] for i in (1, 3, 13, 33, 21)]
file_folders = [classes['file folder'][i] for i in (32, 2, 11, 94, 29)]
questionnaires = [classes['questionnaire'][i] for i in (5, 1, 35, 36, 70)]


def check_incorrect(key, docs):
    """Return True if no entry of *docs* appears in ``wrong_ids[key]``.

    *docs* holds ``(doc, id)`` pairs. At the first id found in
    ``wrong_ids[key]``, that entry's position in *docs* is printed and False
    is returned.
    """
    for position, (_doc, doc_id) in enumerate(docs):
        if doc_id not in wrong_ids[key]:
            continue
        print(position)
        return False
    return True


# Every showcase sample must be classified correctly by both models.
for _cls_name, _picked in (
    ('form', forms),
    ('invoice', invoices),
    ('budget', budgets),
    ('file folder', file_folders),
    ('questionnaire', questionnaires),
):
    assert check_incorrect(_cls_name, _picked)

In [None]:
# Showcase samples keyed by class name, in the order
# form, file folder, budget, invoice, questionnaire.
final_samples = dict(zip(
    ("form", "file folder", "budget", "invoice", "questionnaire"),
    (forms, file_folders, budgets, invoices, questionnaires),
))

In [None]:
# Print the positions of file-folder showcase samples that have no words.
for position, entry in enumerate(final_samples['file folder']):
    if not len(entry[0].words):
        print(position)

## Obtain Explanations

In [None]:
from tqdm import tqdm
def obtain_explainers(key, label):
    """Build the text, layout and vision explainers for one target class.

    Args:
        key: Class name. Not used in the body; kept for call-site compatibility.
        label: Integer index of the class the explanations should target.

    Returns:
        ``(text_explainers, layout_explainers, vision_explainers)``: three
        dicts keyed ``"<MODEL> <method>"`` (e.g. ``"BROS lime"``). BROS's
        forward pass takes no image input, so only LayoutLMv3 gets vision
        explainers.
    """
    bros_mask = BROS_t.mask_token
    llmv3_mask = LLMV3_proc.tokenizer.mask_token

    text_explainers = {
        'BROS lime': LimeTextExplainer(BROS, BROS_encode, mask_token=bros_mask, kernel_width_factor=0.75, labels=[label]),
        'LLMV3 lime': LimeTextExplainer(LLMV3, LLMV3_encode, mask_token=llmv3_mask, kernel_width_factor=0.75, labels=[label]),
        'BROS shap': SHAPTextExplainer(BROS, BROS_encode, BROS_t, mask_token=bros_mask, device=device),
        'LLMV3 shap': SHAPTextExplainer(LLMV3, LLMV3_encode, LLMV3_proc.tokenizer, mask_token=llmv3_mask, device=device),
    }

    layout_explainers = {
        'BROS lime': LimeLayoutExplainer(BROS, BROS_encode, mask_token=bros_mask, kernel_width_factor=0.75, labels=[label]),
        'LLMV3 lime': LimeLayoutExplainer(LLMV3, LLMV3_encode, mask_token=llmv3_mask, kernel_width_factor=0.75, labels=[label]),
        'BROS shap': SHAPLayoutExplainer(BROS, BROS_encode, device=device),
        'LLMV3 shap': SHAPLayoutExplainer(LLMV3, LLMV3_encode, device=device),
    }

    vision_explainers = {
        'LLMV3 lime': LimeVisionExplainer(LLMV3, LLMV3_encode, label=[label], device=device),
        # Explain the class under study (same target as the LIME vision
        # explainer) rather than a fixed class index.
        'LLMV3 shap': SHAPVisionExplainer(LLMV3, LLMV3_encode, device=device, class_idx=label, mask_value='blur(64,64)'),
    }

    return text_explainers, layout_explainers, vision_explainers

def obtain_explanations(text, layout, vision, key):
    """Run every SHAP and LIME explainer on each showcase sample of class *key*.

    Args:
        text, layout, vision: Dicts of explainers as returned by
            ``obtain_explainers``. An explainer runs as SHAP if its name
            contains ``'shap'`` and as LIME if it contains ``'lime'``.
        key: Class name indexing ``final_samples``.

    Returns:
        Dict mapping ``"<sample idx> - <[T]|[L]|[V]> - <explainer name>"`` to
        the explainer's output. For each sample, all SHAP explanations
        (text, layout, vision) are computed before the LIME ones.
    """
    # (modality tag, explainers, sample budget): vision gets a smaller budget.
    modalities = (('[T]', text, 2000), ('[L]', layout, 2000), ('[V]', vision, 1000))
    # SHAP and LIME explainers name their sample-count argument differently.
    methods = (('shap', 'nsamples'), ('lime', 'num_samples'))

    docs = final_samples[key]
    explanations = {}
    for idx, sample in tqdm(enumerate(docs), total=len(docs), desc=f"Obtaining explanations for {key}s"):
        for method, budget_kwarg in methods:
            for tag, explainers, budget in modalities:
                for name, explainer in explainers.items():
                    if method in name:
                        explanations[f'{idx} - {tag} - {name}'] = explainer.explain(sample[0], **{budget_kwarg: budget})

    return explanations



In [None]:
# Inspect one stored explanation: BROS text SHAP for form showcase sample 0.
# NOTE(review): class_explanations is only created by the explanation-generation cell
# further down; running this cell first raises NameError.
print(class_explanations['form_explanations']['0 - [T] - BROS shap'])

In [None]:
# Disabled: merge every object pickled in explanations.pkl into class_explanations.
# with (open("explanations.pkl", "rb")) as openfile:
#     while True:
#         try:
#             class_explanations.update(pickle.load(openfile))
#         except EOFError:
#             break
# Number of per-class groups currently in class_explanations.
# NOTE(review): class_explanations (and the `pickle` import the disabled code needs) are
# only defined in the explanation-generation cell below; running this cell first raises NameError.
len(class_explanations)

In [None]:
import pickle

# One (text, layout, vision) explainer triple per class, each targeting that class's index.
explainers = {key: obtain_explainers(key, label2id[key]) for key in final_samples}

# Run all explainers class by class. Groups are keyed
# "<class name with spaces as underscores>_explanations" (e.g. 'file_folder_explanations'),
# the convention the analysis cells below use. Results are stored as each class
# finishes, so completed classes survive if a later one fails mid-run.
class_explanations = {}
for key, (text_expl, layout_expl, vision_expl) in explainers.items():
    group = f"{key.replace(' ', '_')}_explanations"
    class_explanations[group] = obtain_explanations(text_expl, layout_expl, vision_expl, key)


with open('explanations_final_final.pkl', 'wb') as fp:
    pickle.dump(class_explanations, fp)
    print('Dictionary saved successfully to file')

!cp explanations_final_final.pkl /content/drive/MyDrive/THESIS/

In [None]:
!cp explanations_final.pkl /content/drive/MyDrive/THESIS/

## Investigation

In [None]:
# Imported here too so this cell works after a kernel restart without
# re-running the (expensive) explanation-generation cell.
import pickle

# Load the saved explanations. Every object found in the file is merged, so a
# file written by several consecutive pickle.dump() calls is also supported.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
raw_explanations = {}
with open("explanations_final_final.pkl", "rb") as openfile:
    while True:
        try:
            raw_explanations.update(pickle.load(openfile))
        except EOFError:
            break


In [None]:
# Number of per-class explanation groups loaded.
len(raw_explanations)

In [None]:
# Number of explanations stored for the file-folder class.
len(raw_explanations['file_folder_explanations'])

In [None]:
import numpy as np
from collections import defaultdict

TAG2MOD = {'[T]': 'text', '[L]': 'layout', '[V]': 'vision'}

def _sum_abs(expl, modality, cls_idx):
    if hasattr(expl, 'values'):
        vals = np.asarray(expl.values)
    elif isinstance(expl, dict) and 'values' in expl:   # old pickles
        vals = np.asarray(expl['values'])
    else:                                               # fallback
        vals = np.asarray(expl)

    if modality in ('text', 'layout'):
        if vals.ndim >= 2:
            vals = np.take(vals, cls_idx, axis=-1)
    return np.abs(vals).sum()

def compute_mm_shap_nested(class_explanations, label2id):
    bucket = defaultdict(lambda: defaultdict(dict))  # key → modality → expl

    #  regroup by (model-method, class, sample-idx)
    for cls_wrap, expls in class_explanations.items():
        cls = cls_wrap.replace('_explanations', '')
        # handle file folder
        if cls == 'file_folder':
          cls = 'file folder'
        cls_idx = label2id[cls]
        for k, expl in expls.items():
            parts = [p.strip() for p in k.split('-')]
            idx       = int(parts[0])
            tag       = parts[1]                     # '[T]' / '[L]' / '[V]'
            modelmet  = '-'.join(parts[2:])          # 'BROS shap', …
            if 'shap' not in modelmet.lower():
                continue
            modality  = TAG2MOD[tag]
            bucket[(modelmet, cls, idx, cls_idx)][modality] = expl

    mm_by_sample, per_cls_tmp = {}, defaultdict(list)
    for (modelmet, cls, idx, cls_idx), mods in bucket.items():
        phi = {m: _sum_abs(e, m, cls_idx) for m, e in mods.items()}
        total = sum(phi.values()) or 1.0
        share = {m: phi.get(m, 0.0) / total for m in ['text', 'layout', 'vision']}
        mm_by_sample[(modelmet, cls, idx)] = share
        per_cls_tmp[(modelmet, cls)].append(list(share.values()))

    mm_mean = defaultdict(dict)
    for (modelmet, cls), rows in per_cls_tmp.items():
        arr = np.asarray(rows)
        mm_mean[modelmet][cls] = dict(zip(['text', 'layout', 'vision'], arr.mean(0)))

    return mm_by_sample, mm_mean

In [None]:
mm_sample, mm_class = compute_mm_shap_nested(raw_explanations, label2id)

# Per-class mean modality shares for each model, in a fixed class order.
_mm_class_order = ('form', 'invoice', 'file folder', 'budget', 'questionnaire')
LLMV3_mm_shap = {cls: mm_class['LLMV3 shap'][cls] for cls in _mm_class_order}
BROS_mm_shap = {cls: mm_class['BROS shap'][cls] for cls in _mm_class_order}
print(LLMV3_mm_shap)
print(BROS_mm_shap)
import pandas as pd

# Outer keys (classes) become columns; inner keys (text/layout/vision) become rows.
dfllmv3 = pd.DataFrame(LLMV3_mm_shap)
dfbros = pd.DataFrame(BROS_mm_shap)


In [None]:
# Display BROS's modality-share table (columns: classes, rows: modalities).
dfbros

In [None]:
# Model registry for the AOPC analysis: tag -> (model, encode function).
MODELS = dict(
    BROS=(BROS, BROS_encode),
    LLMV3=(LLMV3, LLMV3_encode),
)

# Token substituted for a masked word, per model tokenizer.
MASK_TOKEN = dict(
    BROS=BROS_t.mask_token,
    LLMV3=LLMV3_proc.tokenizer.mask_token,
)

# SLIC superpixel settings used for vision perturbations.
SLIC_KW = {"n_segments": 200, "compactness": 20.0, "sigma": 1.0, "start_label": 1}

def _rank_positive_feats(expl, expl_type, modality, cls_idx):
    if "shap" in expl_type.lower():
        vals = expl.values
        if modality in ("text", "layout"):
            while vals.ndim > 2: vals = vals[0]      # squeeZE GS
            if vals.ndim == 2: vals = vals[:, cls_idx]
        vals = vals.flatten()
        idx = np.argsort(-vals)
        return [i for i in idx if vals[i] > 0]

    # LIME
    m = expl.as_map()[cls_idx]
    pos = [(i, w) for i, w in m if w > 0]
    return [i for i, _ in sorted(pos, key=lambda x: -x[1])]

def _perturb(sample, modality, feat_ids, model_tag, segments=None):
    """Return a new ``DocSample`` with the features in *feat_ids* masked.

    Args:
        sample: Sample to perturb; it is not modified.
        modality: ``"text"`` replaces the selected words with the model's mask
            token; ``"layout"`` replaces the selected boxes with the full-page
            box ``[0, 0, width, height]``; ``"vision"`` blurs the selected
            SLIC segments.
        feat_ids: Word/box indices (text, layout) or segment labels (vision).
            ``None`` entries — the padding ``plot_aopc_curves`` appends once
            a ranking is exhausted — are ignored.
        model_tag: Key into ``MASK_TOKEN``.
        segments: Precomputed SLIC label map (vision only); computed here
            when omitted.

    Raises:
        ValueError: If *modality* is not one of the three above.
    """
    # Set lookup instead of scanning the list once per word/box.
    selected = {f for f in feat_ids if f is not None}

    if modality == "text":
        words = [MASK_TOKEN[model_tag] if i in selected else w
                 for i, w in enumerate(sample.words)]
        return DocSample(sample.image, words, sample.bboxes,
                         ner_tags=sample.ner_tags, label=sample.label)

    if modality == "layout":
        w, h = sample.image.size
        # Full-page box rather than a degenerate [0, 0, 0, 0].
        boxes = [[0, 0, w, h] if i in selected else b
                 for i, b in enumerate(sample.bboxes)]
        return DocSample(sample.image, sample.words, boxes,
                         ner_tags=sample.ner_tags, label=sample.label)

    if modality == "vision":
        from PIL import ImageFilter
        if segments is None:
            img_np = np.asarray(sample.image)
            # A 2-D (grayscale) array has no channel axis; SLIC's default
            # channel_axis=-1 would otherwise treat the image width as channels.
            ch_ax = -1 if img_np.ndim == 3 else None
            segments = slic(img_np, channel_axis=ch_ax, **SLIC_KW)
        blur = sample.image.filter(ImageFilter.GaussianBlur(64))
        base, blur_np = np.array(sample.image), np.array(blur)
        mask = np.isin(segments, list(selected))
        out = base.copy()
        out[mask] = blur_np[mask]
        return DocSample(Image.fromarray(out.astype(np.uint8)),
                         sample.words, sample.bboxes,
                         ner_tags=sample.ner_tags, label=sample.label)

    raise ValueError(f"Unknown modality: {modality}")



@torch.no_grad()
def _predict_prob(samples, model, encode_fn, cls_idx, device):
    enc = encode_fn(samples, device)
    logits = model(**enc).logits
    probs  = torch.softmax(logits, dim=-1)
    return probs[:, cls_idx].cpu().numpy()           # (N,)



def plot_aopc_curves(target_cls: str,
                     expl_type: str,
                     modality: str,
                     K: int = 20,
                     device: torch.device | str = "cuda"):
    """Plot deletion curves for one class/explainer/modality and print the mean AOPC.

    For every stored explanation of *target_cls* whose key contains
    *expl_type* and the modality tag, the top-K positively attributed
    features are masked cumulatively. The probability of *target_cls* is
    recorded at each step, and one curve per sample is drawn.

    Each sample's score is the trapezoidal area of ``p0 - p_k`` over
    k = 1..K, not normalised by K. The mean over samples is printed.
    NOTE(review): the AOPC of Samek et al. averages ``p0 - p_k`` over the
    steps instead, so these values scale with K.

    Args:
        target_cls: Class name, a key of ``label2id`` and ``final_samples``.
        expl_type: Explainer name as used in explanation keys, e.g.
            ``"LLMV3 shap"``. Its first word selects the model in ``MODELS``.
        modality: ``"text"``, ``"layout"`` or ``"vision"``.
        K: Number of deletion steps.
        device: Device passed to the encoder.
    """
    tag = {"text": "[T]", "layout": "[L]"}.get(modality, "[V]")

    # Groups are keyed with spaces as underscores, e.g. 'file_folder_explanations'.
    expls = raw_explanations[f"{target_cls.replace(' ', '_')}_explanations"]

    model_tag = expl_type.split()[0]                 # e.g. "BROS" / "LLMV3"
    model, encode_fn = MODELS[model_tag]
    cls_idx = label2id[target_cls]

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    x = np.arange(0, K + 1)
    aopc_acc = []

    plt.figure(figsize=(6, 4))

    for k_str, expl in expls.items():
        if tag not in k_str or expl_type not in k_str:
            continue
        s_idx = int(k_str.split("-")[0].strip())
        doc = final_samples[target_cls][s_idx][0]

        ranked = _rank_positive_feats(expl, expl_type, modality, cls_idx)
        if not ranked:
            continue

        seg = None
        if modality == "vision":
            img_np = np.asarray(doc.image)
            ch_ax = -1 if img_np.ndim == 3 else None  # 2-D (grayscale) arrays have no channel axis
            seg = slic(img_np, channel_axis=ch_ax, **SLIC_KW)

        p0 = _predict_prob([doc], model, encode_fn, cls_idx, device)[0]
        probs = [p0]

        masked = []
        for k in range(1, K + 1):
            # Once the ranking is exhausted, None is appended; it matches no
            # feature, so later steps repeat the previous perturbation.
            masked.append(ranked[k - 1] if k - 1 < len(ranked) else None)
            pert = _perturb(doc, modality, masked, model_tag, seg)
            probs.append(_predict_prob([pert], model, encode_fn, cls_idx, device)[0])

        plt.plot(x, probs, lw=1)
        aopc_acc.append(trapezoid(p0 - np.array(probs[1:]), dx=1))

    plt.title(f"AOPC – {target_cls} · {expl_type} · {modality}")
    plt.xlabel("k  (top positive features masked)")
    plt.ylabel(f"P(class={target_cls})")
    plt.grid(True)
    plt.show()

    if aopc_acc:
        print(f"Mean AOPC over {len(aopc_acc)} samples: ",
              np.mean(aopc_acc))

In [None]:
# Example: text-modality deletion curves for LayoutLMv3 SHAP on the invoice showcase samples (20 steps).
plot_aopc_curves(target_cls="invoice",
                 expl_type="LLMV3 shap",
                 modality="text",
                 K=20,
                 device=device)