# Interpretability techniques for single samples from RVL-CDIP subset - LayoutLMv3

## GITHUB

In [None]:
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
%cd /content/Interpretability-for-VRDU-models

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install -U datasets

## Libraries

In [None]:
#code
from datasets import load_from_disk
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor, BrosModel, AutoTokenizer, BrosPreTrainedModel, AutoConfig
import sys
import importlib
def reload_modules():
    """Reload every already-imported module whose name starts with a project prefix.

    The module names are snapshotted before any reload happens, so modules
    imported as a side effect of a reload are not visited in the same call.
    Each reloaded module name is printed.
    """
    prefixes = ('vrdu_utils', 'Explain', 'lime', 'Eval')
    # str.startswith accepts a tuple, which replaces the chained `or` tests.
    matching = [name for name in list(sys.modules.keys()) if name.startswith(prefixes)]
    for name in matching:
        print(f"Reloading module: {name}")
        importlib.reload(sys.modules[name])

reload_modules()

from vrdu_utils.encoders import *
from Explain.lime import *
from vrdu_utils.utils import *
import torch
from Eval.eval_suite import *
from Eval.evaluation import *
# from Eval.fidelity import *
from Explain.shap import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import warnings
from transformers import logging as hf_logging

warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="transformers.modeling_utils",   # the module that emits the msg
)
hf_logging.set_verbosity_error()


## Data + Model Setup

In [None]:
from google.colab import drive
drive.mount("/content/drive")
!cp -r /content/drive/MyDrive/THESIS/rvl_cdip_financial_subset /content

In [None]:
rvl = load_from_disk('/content/rvl_cdip_financial_subset')
dataset_split = rvl.train_test_split(test_size=0.2, seed=42)
val = dataset_split["test"]
val_ds = DocSampleDataset(val)

AGREE = val_ds[846]
CLASH = val_ds[282]

In [None]:
AGREE[0].label

In [None]:
CLASH[0].image

### BrosForDocumentClassification classifier head


In [None]:
from torch import nn
class BrosForDocumentClassification(BrosPreTrainedModel):
    """BROS backbone followed by dropout and a linear classification layer.

    The document representation is the hidden state at sequence position 0 of
    the backbone's last layer. ``forward`` returns a plain dict with the keys
    ``"loss"`` and ``"logits"``.
    """

    def __init__(self, config):
        """Build the backbone, dropout and linear head from ``config``."""
        super().__init__(config)
        self.num_labels = config.num_labels

        # Attribute names are kept as-is: they determine the parameter names
        # that pretrained weights are matched against.
        self.bros = BrosModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        """Classify a batch of encoded documents.

        Args:
            input_ids, bbox, attention_mask, token_type_ids: forwarded
                unchanged to the BROS backbone.
            labels: optional class indices; when given, a cross-entropy loss
                is computed against the logits.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.
        """
        encoded = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Position 0 of the last hidden state serves as the pooled document vector.
        pooled = self.dropout(encoded.last_hidden_state[:, 0, :])
        logits = self.classifier(pooled)

        if labels is None:
            loss = None
        else:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

### Model config

In [None]:
# Class-index <-> class-name maps; both checkpoints are loaded with the same
# five-class mapping, so it is defined once and copied per model.
ID2LABEL = {0: "form", 1: "invoice", 2: "budget", 3: "file folder", 4: "questionnaire"}
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}

LLMV3 = LayoutLMv3ForSequenceClassification.from_pretrained(
    "adamadam111/layoutlmv3-docclass-finetuned-frz",
    num_labels=5,
    id2label=dict(ID2LABEL),
    label2id=dict(LABEL2ID),
)
# apply_ocr=False: presumably the dataset already supplies words and boxes -- TODO confirm.
LLMV3_proc = AutoProcessor.from_pretrained("adamadam111/layoutlmv3-docclass-finetuned-frz", apply_ocr=False)

LLMV3.to(device)

bros_config = AutoConfig.from_pretrained(
    "adamadam111/bros-docclass-finetuned-frz",
    num_labels=5,
    id2label=dict(ID2LABEL),
    label2id=dict(LABEL2ID),
)

BROS = BrosForDocumentClassification.from_pretrained(
    "adamadam111/bros-docclass-finetuned-frz",
    config=bros_config
)
BROS_t = AutoTokenizer.from_pretrained("adamadam111/bros-docclass-finetuned-frz", do_lower_case=True)

# Keep BROS on the same device as LLMV3: later cells encode BROS inputs on
# `device` and call the model directly, which fails if the weights stay on CPU.
BROS.to(device)


## Interpreting the 'CLASH' sample

In [None]:
# Encoders for each model, built from its processor / tokenizer.
LLMV3_encode = make_layoutlmv3_encoder(LLMV3_proc)
BROS_encode = make_bros_encoder(BROS_t)

# One prediction function per model, each bound to the class index used for
# that model throughout this notebook (4 for LayoutLMv3, 0 for BROS).
# The BROS function previously overwrote `pred_fn_llmv3`; it now has its own name.
pred_fn_llmv3 = FidelityEvaluator(LLMV3, LLMV3_encode, mask_token=LLMV3_proc.tokenizer.mask_token)._get_prediction_function(4)
pred_fn_bros = FidelityEvaluator(BROS, BROS_encode, mask_token=BROS_t.mask_token)._get_prediction_function(0)


### Text Modality


In [None]:
print(CLASH[0].words)

In [None]:
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
from Explain.lime import *
from Explain.shap import *

In [None]:

text_explainers = {
    'BROS lime' : LimeTextExplainer(BROS, BROS_encode, mask_token=BROS_t.mask_token, kernel_width_factor = 0.75, labels=[4,0]),
    'LLMV3 lime' : LimeTextExplainer(LLMV3, LLMV3_encode, mask_token = LLMV3_proc.tokenizer.mask_token, kernel_width_factor = 0.75, labels = [4,0]),
    'BROS shap' : SHAPTextExplainer(BROS, BROS_encode, BROS_t, mask_token=BROS_t.mask_token, device=device),
    'LLMV3 shap' : SHAPTextExplainer(LLMV3, LLMV3_encode, LLMV3_proc.tokenizer,mask_token=LLMV3_proc.tokenizer.mask_token, device=device)
}

In [None]:
# RUN FOR
shap_text_explainations = {key : explainer.explain(CLASH[0], nsamples = 4000) for key, explainer in text_explainers.items() if 'shap' in key}
lime_text_explainations = {key : explainer.explain(CLASH[0], num_samples = 4000) for key, explainer in text_explainers.items() if 'lime' in key}

In [None]:
bros_pred = text_explainers['BROS lime']._predict([CLASH[0]])
llmv3_pred = text_explainers['LLMV3 lime']._predict([CLASH[0]]).flatten()
print(llmv3_pred.argmax())
print(bros_pred, llmv3_pred)
CLASH[0].label

In [None]:
# Pertaining to label 4 (correct, Questionnaire) (idx 0 shap)
text_explanations_hms = {}
text_explanations_weights = {}
for key, explanation in shap_text_explainations.items():
  weights = {tok : float(val) for tok, val in zip(CLASH[0].words, shap_text_explainations[key].values[:,:,4].flatten())}
  text_explanations_weights.update({key : weights})
  text_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=weights, alpha=0.5) })

for key, explanation in lime_text_explainations.items():
  weights = dict(lime_text_explainations[key].as_list(label=4))
  clean_weights = {key.replace('=1', ''): value for key, value in weights.items()}
  text_explanations_weights.update({key : clean_weights})
  text_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=clean_weights, alpha=0.5) })

# LLMV3 Fidelity
# pred_fn_llmv3_s = FidelityEvaluator(LLMV3, LLMV3_encode, mask_token=LLMV3_proc.tokenizer.mask_token)._get_prediction_function(4)
# pred_fn_llmv3_l = FidelityEvaluator(LLMV3, LLMV3_encode, mask_token='|~|')._get_prediction_function(4)
# comp_s = calculate_comprehensiveness(pred_fn_llmv3_s, CLASH[0], text_explanations_weights['LLMV3 shap'], mask_token='|~|',  top_k=5)
# comp_l = calculate_comprehensiveness(pred_fn_llmv3_l, CLASH[0], text_explanations_weights['LLMV3 lime'], mask_token=LLMV3_proc.tokenizer.mask_token,  top_k=5)
# suf_s = calculate_sufficiency(pred_fn_llmv3_s, CLASH[0], text_explanations_weights['LLMV3 shap'], mask_token='|~|',  top_k=5)
# suf_l = calculate_sufficiency(pred_fn_llmv3_l, CLASH[0], text_explanations_weights['LLMV3 lime'], mask_token=LLMV3_proc.tokenizer.mask_token,  top_k=5)

# display_image_grid([text_explanations_hms['LLMV3 lime'], text_explanations_hms['LLMV3 shap']],
#                    [f'LLMV3 - Lime w.r.t text\nComprehensiveness = {comp_l:.3g}, Sufficiency = {suf_l:.3g}', f'LLMV3 - SHAP w.r.t text\nComprehensiveness = {comp_s:.3g}, Sufficiency = {suf_s:.3g}'],
#                     (1,2),
#                    main_title='Layout LMV3 Correct Prediction w.r.t Text\n "Questionnaire"')

# text_explanations_hms['LLMV3 shap']
text_explanations_hms['LLMV3 lime']

In [None]:
text_explanations_hms = {}
text_explanations_weights = {}
for key, explanation in shap_text_explainations.items():
  weights = {tok : float(val) for tok, val in zip(CLASH[0].words, shap_text_explainations[key].values[:,:,0].flatten())}
  text_explanations_weights.update({key : weights})
  text_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=weights, alpha=0.5) })

for key, explanation in lime_text_explainations.items():
  weights = dict(lime_text_explainations[key].as_list(label=0))
  clean_weights = {key.replace('=1', ''): value for key, value in weights.items()}
  text_explanations_weights.update({key : clean_weights})
  text_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=clean_weights, alpha=0.5) })

# BROS Fidelity
pred_fn_bros_s = FidelityEvaluator(BROS, BROS_encode, mask_token=BROS_t.mask_token)._get_prediction_function(0)
pred_fn_bros_l = FidelityEvaluator(BROS, BROS_encode, mask_token='|~|')._get_prediction_function(0)
comp_s = calculate_comprehensiveness(pred_fn_bros_s, CLASH[0], text_explanations_weights['BROS shap'], mask_token='|~|',  top_k=5)
comp_l = calculate_comprehensiveness(pred_fn_bros_l, CLASH[0], text_explanations_weights['BROS lime'], mask_token=BROS_t.mask_token, top_k=5)
suf_s = calculate_sufficiency(pred_fn_bros_s, CLASH[0], text_explanations_weights['BROS shap'], mask_token='|~|', top_k=5)
suf_l = calculate_sufficiency(pred_fn_bros_l, CLASH[0], text_explanations_weights['BROS lime'], mask_token=BROS_t.mask_token, top_k=5)

# display_image_grid([text_explanations_hms['BROS lime'], text_explanations_hms['BROS shap']],
#                    [f'BROS - Lime\nComprehensiveness = {comp_l:.3g}, Sufficiency = {suf_l:.3g}', f'BROS - SHAP\nComprehensiveness = {comp_s:.3g}, Sufficiency = {suf_s:.3g}'],
#                     (1,2),
#                    main_title='BROS inorrect Prediction w.r.t Text\n "Form"')

# text_explanations_hms['BROS lime']
text_explanations_hms['BROS shap']

In [None]:
# Rebuild the per-explainer token weights and heatmaps for label 0 ("form").
text_explanations_hms = {}
text_explanations_weights = {}
for key, explanation in shap_text_explainations.items():
    # Slice the class axis the same way as the other SHAP cells
    # (`values[:, :, c].flatten()`); `values[:, 0]` picked the first token's
    # per-class vector instead of the class-0 score of every token.
    class_scores = explanation.values[:, :, 0].flatten()
    weights = {tok: float(val) for tok, val in zip(CLASH[0].words, class_scores)}
    text_explanations_weights[key] = weights
    text_explanations_hms[key] = draw_lime_token_heatmap(
        image=CLASH[0].image,
        words=CLASH[0].words,
        boxes=CLASH[0].bboxes,
        weights=weights,
        alpha=0.5,
    )

for key, explanation in lime_text_explainations.items():
    weights = dict(explanation.as_list(label=0))
    # LIME feature names carry a '=1' suffix; strip it so keys match the words.
    clean_weights = {feature.replace('=1', ''): value for feature, value in weights.items()}
    text_explanations_weights[key] = clean_weights
    text_explanations_hms[key] = draw_lime_token_heatmap(
        image=CLASH[0].image,
        words=CLASH[0].words,
        boxes=CLASH[0].bboxes,
        weights=clean_weights,
        alpha=0.5,
    )

In [None]:
## Pertaining to label 0 (incorrect, Form)

In [None]:
 print(shap_text_explainations['BROS shap'].values[:,:,4].flatten())

In [None]:
print(lime_text_explainations['BROS lime'])

In [None]:
print( [dict(lime_text_explainations['BROS lime'].as_list(label=0))])

In [None]:
CLASH[0].label

In [None]:
# FIDELITY
from Eval.evaluation import evaluate_sample as eval
# from Eval.eval_suite_updated import calculate_sufficiency as calculate_sufficiency
@torch.no_grad()
def predict_fn_bros(sample):
    """Return BROS's softmax probability of class 0 for a single sample.

    Runs under ``torch.no_grad()`` like ``predict_fn_llmv3``: this is pure
    inference, so no autograd graph needs to be built or kept alive.

    Args:
        sample: one document sample accepted by ``BROS_encode``.

    Returns:
        float: probability assigned to class index 0.
    """
    enc = BROS_encode([sample], device)
    # The custom BROS head returns a dict, so the logits are read by key.
    logits = BROS(**enc)['logits']
    probs = torch.softmax(logits, -1)[0]
    return probs[0].item()

@torch.no_grad()
def predict_fn_llmv3(sample):
    """Return LayoutLMv3's softmax probability of class 4 for a single sample.

    Args:
        sample: one document sample accepted by ``LLMV3_encode``.

    Returns:
        float: probability assigned to class index 4.
    """
    batch = LLMV3_encode([sample], device)
    class_probs = torch.softmax(LLMV3(**batch).logits, -1)
    # Single-sample batch: take row 0, then the class-4 entry.
    return class_probs[0][4].item()


# class_ids_dict = {
#     "BROS lime"  : 0,   # probability of class-0 for BROS
#     "LLMV3 lime" : 4,   # probability of class-4 for LayoutLMv3
#     "BROS shap"  : 0,
#     "LLMV3 shap" : 4,
# }

# models_dict = {
#     'BROS lime' : BROS,
#     'LLMV3 lime' : LLMV3,
#     'BROS shap' : BROS,
#     'LLMV3 shap' : LLMV3
# }

# encoders_dict = {
#     'BROS lime' : BROS_encode,
#     'LLMV3 lime' : LLMV3_encode,
#     'BROS shap' : BROS_encode,
#     'LLMV3 shap' : LLMV3_encode
# }

# mask_tokens_dict = {
#     'BROS lime' : BROS_t.mask_token,
#     'LLMV3 lime' : LLMV3_proc.tokenizer.mask_token,
#     'BROS shap' : BROS_t.mask_token,
#     'LLMV3 shap' : LLMV3_proc.tokenizer.mask_token
# }
# text_explanations_dict = {
#     'BROS lime' : [dict(lime_text_explainations['BROS lime'].as_list(label=0))],
#     'LLMV3 lime' : [dict(lime_text_explainations['LLMV3 lime'].as_list(label=4))],
#     'BROS shap' : [{tok : float(val) for tok, val in zip(CLASH[0].words, shap_text_explainations['BROS shap'].values[:,:,0].flatten())}],
#     'LLMV3 shap' : [{tok : float(val) for tok, val in zip(CLASH[0].words, shap_text_explainations['LLMV3 shap'].values[:,:,4].flatten())}]
# }

# aopc = compute_aopc_curves([CLASH[0]], text_explanations_dict, modality="text", models_dict=models_dict, encoders_dict=encoders_dict, mask_tokens_dict=mask_tokens_dict, class_ids_dict = class_ids_dict, max_k=50)
# plot_aopc_curves(aopc, modality="text")
# text_explanations_dict = shap_text_explainations
# print(lime_text_explainations)
# text_explanations_dict.update({k : d.as_list(label=0) for k, d in lime_text_explainations.items()})
# print(text_explanations_dict)
# print(len(shap_text_explainations))
# Text fidelity of every SHAP / LIME explanation on the CLASH sample.
# `eval` is Eval.evaluation.evaluate_sample, imported under that alias at the
# top of this cell (it shadows the builtin of the same name).
shap_text_fid = {}
lime_text_fid = {}

# Per model family: (model, encoder, mask token, class index being explained).
text_fid_setups = {
    'LLMV3': (LLMV3, LLMV3_encode, LLMV3_proc.tokenizer.mask_token, 4),
    'BROS': (BROS, BROS_encode, BROS_t.mask_token, 0),
}

for key, exp in shap_text_explainations.items():
    for family, (model, encode, mask_token, class_id) in text_fid_setups.items():
        if family not in key:
            continue
        weights = {tok: float(val) for tok, val in zip(CLASH[0].words, exp.values[:, :, class_id].flatten())}
        shap_text_fid[key] = eval(CLASH[0], weights, 'text', model, encode, top_k=5, device=device,
                                  mask_token=mask_token, target_class_id=class_id)

for key, exp in lime_text_explainations.items():
    for family, (model, encode, mask_token, class_id) in text_fid_setups.items():
        if family not in key:
            continue
        # Read the LIME weights for the same class that is scored below; the
        # LLMV3 branch previously read label 0 while evaluating class 4.
        weights = dict(exp.as_list(label=class_id))
        # LIME feature names carry a '=1' suffix; strip it so keys match the words.
        clean_weights = {feature.replace('=1', ''): value for feature, value in weights.items()}
        # Store LIME results in the LIME dict (they used to land in shap_text_fid,
        # leaving lime_text_fid empty).
        lime_text_fid[key] = eval(CLASH[0], clean_weights, 'text', model, encode, top_k=5, device=device,
                                  mask_token=mask_token, target_class_id=class_id)

print(lime_text_fid)
print(shap_text_fid)

## Layout

In [None]:
layout_explainers = {
    'BROS lime' : LimeLayoutExplainer(BROS, BROS_encode, mask_token=BROS_t.mask_token, kernel_width_factor = 0.75, labels=[4,0]),
    'LLMV3 lime' : LimeLayoutExplainer(LLMV3, LLMV3_encode, mask_token = LLMV3_proc.tokenizer.mask_token, kernel_width_factor = 0.75, labels = [4,0]),
    'BROS shap' : SHAPLayoutExplainer(BROS, BROS_encode, device=device),
    'LLMV3 shap' : SHAPLayoutExplainer(LLMV3, LLMV3_encode, device=device)
}

In [None]:
shap_layout_explainations = {key : explainer.explain(CLASH[0], nsamples = 4000) for key, explainer in layout_explainers.items() if 'shap' in key}
lime_layout_explainations = {key : explainer.explain(CLASH[0], num_samples = 4000) for key, explainer in layout_explainers.items() if 'lime' in key}

In [None]:
# Pertaining to label 4 (correct, Questionnaire) (idx 0 shap)
layout_explanations_hms = {}
layout_explanations_weights = {}
for key, explanation in shap_layout_explainations.items():
  weights = {tok : float(val) for tok, val in zip(CLASH[0].words, shap_layout_explainations[key].values[:,:,4].flatten())}
  layout_explanations_weights.update({key : weights})
  layout_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=weights, alpha=0.5) })

for key, explanation in lime_layout_explainations.items():
  weights = dict(lime_layout_explainations[key].as_list(label=4))
  clean_weights = {key.replace('=1', ''): value for key, value in weights.items()}
  layout_explanations_weights.update({key : clean_weights})
  layout_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=clean_weights, alpha=0.5) })

# LLMV3 Fidelity
pred_fn_llmv3_s = FidelityEvaluator(LLMV3, LLMV3_encode, mask_token=LLMV3_proc.tokenizer.mask_token)._get_prediction_function(4)
pred_fn_llmv3_l = FidelityEvaluator(LLMV3, LLMV3_encode, mask_token='|~|')._get_prediction_function(4)
comp_s = calculate_comprehensiveness(pred_fn_llmv3_s, CLASH[0], layout_explanations_weights['LLMV3 shap'], mask_token=LLMV3_proc.tokenizer.mask_token,  top_k=10, modality='layout')
comp_l = calculate_comprehensiveness(pred_fn_llmv3_l, CLASH[0], layout_explanations_weights['LLMV3 lime'], mask_token=LLMV3_proc.tokenizer.mask_token,  top_k=10,modality='layout')
suf_s = calculate_sufficiency(pred_fn_llmv3_s, CLASH[0], layout_explanations_weights['LLMV3 shap'], LLMV3_proc.tokenizer.mask_token,  top_k=5,modality='layout')
suf_l = calculate_sufficiency(pred_fn_llmv3_l, CLASH[0], layout_explanations_weights['LLMV3 lime'], mask_token=LLMV3_proc.tokenizer.mask_token,  top_k=5,modality='layout')

# display_image_grid([layout_explanations_hms['LLMV3 lime'], layout_explanations_hms['LLMV3 shap']],
#                    [f'LLMV3 - Lime w.r.t layout\nComprehensiveness = {comp_l:.3g}, Sufficiency = {suf_l:.3g}', f'LLMV3 - SHAP w.r.t layout\nComprehensiveness = {comp_s:.3g}, Sufficiency = {suf_s:.3g}'],
#                     (1,2),
#                    main_title='Layout LMV3 Correct Prediction w.r.t Layout\n "Questionnaire"')

# layout_explanations_hms['LLMV3 lime']
layout_explanations_hms['LLMV3 shap']

In [None]:
# Pertaining to label 0 (incorrect, Form)
layout_explanations_hms = {}
layout_explanations_weights = {}
for key, explanation in shap_layout_explainations.items():
  weights = {tok : float(val) for tok, val in zip(CLASH[0].words, shap_layout_explainations[key].values[:,:,0].flatten())}
  layout_explanations_weights.update({key : weights})
  layout_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=weights, alpha=0.5) })

for key, explanation in lime_layout_explainations.items():
  weights = dict(lime_layout_explainations[key].as_list(label=0))
  clean_weights = {key.replace('=1', ''): value for key, value in weights.items()}
  layout_explanations_weights.update({key : clean_weights})
  layout_explanations_hms.update({key : draw_lime_token_heatmap(image = CLASH[0].image, words = CLASH[0].words, boxes = CLASH[0].bboxes, weights=clean_weights, alpha=0.5) })

# BROS Fidelity
pred_fn_bros_s = FidelityEvaluator(BROS, BROS_encode, mask_token=BROS_t.mask_token)._get_prediction_function(0)
pred_fn_bros_l = FidelityEvaluator(BROS, BROS_encode, mask_token=BROS_t.mask_token)._get_prediction_function(0)
comp_s = calculate_comprehensiveness(pred_fn_bros_s, CLASH[0], layout_explanations_weights['BROS shap'], mask_token=BROS_t.mask_token,  top_k=10, modality='layout')
comp_l = calculate_comprehensiveness(pred_fn_bros_l, CLASH[0], layout_explanations_weights['BROS lime'], mask_token=BROS_t.mask_token,  top_k=5,modality='layout')
suf_s = calculate_sufficiency(pred_fn_bros_s, CLASH[0], layout_explanations_weights['BROS shap'], mask_token=BROS_t.mask_token,  top_k=5,modality='layout')
suf_l = calculate_sufficiency(pred_fn_bros_l, CLASH[0], layout_explanations_weights['BROS lime'], mask_token=BROS_t.mask_token,  top_k=5,modality='layout')

# display_image_grid([layout_explanations_hms['BROS lime'], layout_explanations_hms['BROS shap']],
#                    [f'LLMV3 - Lime w.r.t layout\nComprehensiveness = {comp_l:.3g}, Sufficiency = {suf_l:.3g}', f'LLMV3 - SHAP w.r.t layout\nComprehensiveness = {comp_s:.3g}, Sufficiency = {suf_s:.3g}'],
#                     (1,2),
#                    main_title='Layout LMV3 Correct Prediction w.r.t Layout\n "Questionnaire"')

layout_explanations_hms['BROS shap']
# layout_explanations_hms['BROS shap']

In [None]:
print(shap_layout_explainations.keys())
print(lime_layout_explainations.keys())


In [None]:
shap_layout_fid = {}
lime_layout_fid = {}

for key, exp in shap_layout_explainations.items():
    if "LLMV3" in key:                                  # class-4 explanation
        scores   = exp.values[:, :, 4].flatten()        # (#tokens,)
        weights  = {tuple(box): float(s)
                    for box, s in zip(CLASH[0].bboxes, scores)}
        shap_layout_fid[key] = eval(CLASH[0], weights, "layout",
                                    LLMV3, LLMV3_encode,
                                    top_k=5, device=device,
                                    mask_token=LLMV3_proc.tokenizer.mask_token,
                                    target_class_id=4)

    if "BROS" in key:                                   # class-0 explanation
        scores   = exp.values[:, :, 0].flatten()
        weights  = {tuple(box): float(s)
                    for box, s in zip(CLASH[0].bboxes, scores)}
        shap_layout_fid[key] = eval(CLASH[0], weights, "layout",
                                    BROS, BROS_encode,
                                    top_k=5, device=device,
                                    mask_token=BROS_t.mask_token,
                                    target_class_id=0)


for key, exp in lime_layout_explainations.items():
    if "LLMV3" in key:
        token_w = dict(exp.as_map()[4])                 # class-4 map
        bbox_w  = {tuple(CLASH[0].bboxes[i]): w
                   for i, w in token_w.items()
                   if i < len(CLASH[0].bboxes)}
        lime_layout_fid[key] = eval(CLASH[0], bbox_w, "layout",
                                     LLMV3, LLMV3_encode,
                                     top_k=5, device=device,
                                     mask_token=LLMV3_proc.tokenizer.mask_token,
                                     target_class_id=4)

    if "BROS" in key:
        token_w = dict(exp.as_map()[0])
        bbox_w  = {tuple(CLASH[0].bboxes[i]): w
                   for i, w in token_w.items()
                   if i < len(CLASH[0].bboxes)}
        lime_layout_fid[key] = eval(CLASH[0], bbox_w, "layout",
                                     BROS, BROS_encode,
                                     top_k=5, device=device,
                                     mask_token=BROS_t.mask_token,
                                     target_class_id=0)
print(lime_layout_fid)
print(shap_layout_fid)

## Vision

In [None]:
vision_explainers = {
    'LLMV3 lime' : LimeVisionExplainer(LLMV3, LLMV3_encode, label = [4], device=device),
    'LLMV3 shap' : SHAPVisionExplainer(LLMV3, LLMV3_encode, device=device, class_idx=4, mask_value='blur(64,64)')
}

In [None]:
print(vision_explanations['LLMV3 shap'].values.shape)

In [None]:
vision_explanations = {key : explainer.explain(CLASH[0], num_samples = 1000) for key, explainer in vision_explainers.items() if 'lime' in key}
vision_explanations.update({key : explainer.explain(CLASH[0], nsamples = 1000) for key, explainer in vision_explainers.items() if 'shap' in key})
# temp = vision_explainers['LLMV3 shap'].explain(CLASH[0], nsamples = 1000, max_batch = 32)

In [None]:
print(vision_explanations['LLMV3 shap'])

In [None]:
import numpy as np
from Eval.evaluation import evaluate_sample
from skimage.segmentation import slic
from PIL import Image

# Lime
lime_segs = vision_explanations['LLMV3 lime'].segments
seg_weights = {int(sid): float(w) for sid, w in vision_explanations['LLMV3 lime'].local_exp[4]}

metrics_lime = evaluate_sample(CLASH[0], seg_weights, 'vision',
                               LLMV3, LLMV3_encode,
                               top_k = 10,
                               device = device,
                               target_class_id = 4,
                               segments = lime_segs)

print("lime:", metrics_lime)

# SHAP

shap_vals = vision_explanations['LLMV3 shap'].values   # (H,W,3) float
abs_map   = np.abs(shap_vals).sum(-1)                    # (H,W)

segments_S = lime_segs

weights_S = {int(sid): float(abs_map[segments_S == sid].mean())
             for sid in np.unique(segments_S)}

metrics_shap = evaluate_sample(CLASH[0], weights_S, "vision",
                               LLMV3, LLMV3_encode,
                               top_k           = 10,
                               device          = device,
                               target_class_id = 4,
                               segments        = segments_S)

print("SHAP – vision:", metrics_shap)


In [None]:
def shap_overlay(exp, alpha=0.6, cmap="bwr"):
    """Blend a SHAP attribution heat-map over the explained image.

    Args:
        exp: explanation object exposing ``values`` (2-D, or 3-D with a
            trailing channel axis) and ``data`` (the image array).
        alpha: weight of the heat-map in the blend; the image gets ``1 - alpha``.
        cmap: name of the matplotlib colormap used to colour the values.

    Returns:
        PIL.Image.Image: the blended image.
    """
    # Imported here because the notebook only imports pyplot in a later cell,
    # so a top-to-bottom run would otherwise hit a NameError on `plt`.
    import matplotlib.pyplot as plt

    vals = exp.values                    # (H,W) or (H,W,3)
    img = exp.data.astype(np.uint8)      # (H,W,3)

    if vals.ndim == 3:                   # SHAP gave per-channel values
        vals = vals.mean(2)              # collapse to one heat-map

    # Symmetric normalisation: -vmax -> 0, 0 -> 0.5, +vmax -> 1.
    vmax = np.abs(vals).max() + 1e-12
    norm = (vals + vmax) / (2 * vmax)
    rgb = plt.get_cmap(cmap)(norm)[..., :3] * 255   # float RGB scaled to 0..255, alpha channel dropped
    blend = (img * (1 - alpha) + rgb * alpha).astype(np.uint8)
    return Image.fromarray(blend)

shap_overlay(vision_explanations['LLMV3 shap'], alpha=0.6)


In [None]:
import numpy as np, matplotlib.pyplot as plt, matplotlib.cm as cm
from PIL import Image
from skimage.segmentation import slic

def lime_heatmap(expl, class_idx=4, alpha_max=0.6, cmap="bwr"):
    """Colour each LIME segment by its weight and blend it over the image.

    Args:
        expl: LIME image explanation exposing ``image``, ``segments`` and
            ``local_exp``.
        class_idx: key into ``expl.local_exp`` selecting which class's
            segment weights to draw.
        alpha_max: opacity given to the segment with the largest absolute
            weight; other segments scale linearly with ``|weight|``.
        cmap: name of the matplotlib colormap.

    Returns:
        PIL.Image.Image: the image with the weighted colour overlay.
    """
    img = expl.image.astype(float)
    segs = expl.segments
    weights = dict(expl.local_exp[class_idx])
    # default=0.0 keeps an explanation without segments from crashing max().
    vmax = max((abs(w) for w in weights.values()), default=0.0) + 1e-12

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib
    # releases removed; the lookup is loop-invariant, so it is done once.
    colormap = plt.get_cmap(cmap)

    overlay = img.copy()
    for seg_id, w in weights.items():
        mask = segs == seg_id
        # Map the weight from [-vmax, vmax] to [0, 1] before the colormap lookup.
        color = np.array(colormap((w + vmax) / (2 * vmax))[:3]) * 255
        alpha = alpha_max * abs(w) / vmax
        overlay[mask] = overlay[mask] * (1 - alpha) + color * alpha

    return Image.fromarray(overlay.astype(np.uint8))

expl  = vision_explanations['LLMV3 lime']
hm    = lime_heatmap(expl, class_idx=4, alpha_max=0.7)
hm

In [None]:
# Vision fidelity for the LIME and SHAP explanations of the CLASH sample.
# The helpers called here (lime_comprehensiveness, lime_sufficiency,
# shap_comprehensiveness, shap_sufficiency) are defined in the cell below
# ("Trouble importing from evaluation ..."); run that cell first.
pred_f_l = vision_explainers['LLMV3 lime']._batched_predict
pred_f_s = vision_explainers['LLMV3 shap']._batched_predict
lime_drop, im1 = lime_comprehensiveness(pred_f_l, CLASH[0], vision_explanations['LLMV3 lime'], 4, 2)
lime_keep, im2 = lime_sufficiency(pred_f_l, CLASH[0], vision_explanations['LLMV3 lime'], 4, 2)
print(lime_drop, lime_keep)
# SHAP
# `temp` only existed in a commented-out line, so this cell raised NameError;
# use the stored SHAP explanation produced by the same explainer instead.
shap_exp = vision_explanations['LLMV3 shap']
# NOTE(review): the SHAP metrics are scored with the LIME explainer's predict
# function and `pred_f_s` is unused -- confirm that this is intended.
shap_drop, im3 = shap_comprehensiveness(pred_f_l, CLASH[0], shap_exp, 1)
shap_keep, im4 = shap_sufficiency(pred_f_l, CLASH[0], shap_exp, 50)
print(shap_drop, shap_keep)
display_image_grid([im1, im2, im3, im4], titles=['ld', 'lk', 'sd', 'sk'], grid_size=(2, 2), main_title='fidComp')

In [None]:
# Trouble importing from evaluation so done in notebook

def _predict_prob(predict_fn, sample):
    logits = predict_fn([sample])[0]
    if logits.ndim == 0:
        return 1 / (1 + np.exp(-logits))
    cls = np.argmax(logits)
    probs = 1 / (1 + np.exp(-logits))
    return probs[cls]


def _mask_image(img, mask, hide=(255, 255, 255)):
    out = img.copy()
    out[mask] = hide
    return out


def _doc_from_img(img_np, t):
    """Build a DocSample that reuses ``t``'s fields but carries a new image.

    Args:
        img_np: image array, converted with ``Image.fromarray``.
        t: template sample supplying words, bboxes, ner_tags and label.
    """
    replacement = Image.fromarray(img_np)
    return DocSample(
        image=replacement,
        words=t.words,
        bboxes=t.bboxes,
        ner_tags=t.ner_tags,
        label=t.label,
    )


def _topk_lime_segments(weights: dict, k: int):
    # keep only positive weights
    pos = {sid: w for sid, w in weights.items() if w > 0}
    if not pos:
        return []
    return sorted(pos, key=pos.get, reverse=True)[:k]


def lime_comprehensiveness(predict, sample, exp, class_idx, k=5):
    """Probability drop after hiding the top-``k`` positive LIME segments.

    Args:
        predict: batched prediction callable passed to ``_predict_prob``.
        sample: document sample whose image is perturbed.
        exp: LIME image explanation (``local_exp`` and ``segments`` are read).
        class_idx: key into ``exp.local_exp``; it selects the segment weights
            only, the scored class is whatever ``_predict_prob`` picks.
        k: number of segments to hide.

    Returns:
        tuple: (original score minus perturbed score, perturbed PIL image).
    """
    seg_weights = dict(exp.local_exp[class_idx])
    chosen = _topk_lime_segments(seg_weights, k)

    # Hide exactly the pixels belonging to the chosen segments.
    hidden = np.isin(exp.segments, chosen)

    rgb = np.asarray(sample.image.convert("RGB"))
    occluded = _mask_image(rgb, hidden)
    perturbed = _doc_from_img(occluded, sample)

    drop = _predict_prob(predict, sample) - _predict_prob(predict, perturbed)
    return drop, Image.fromarray(occluded)


def lime_sufficiency(predict, sample, exp, class_idx, k=5):
    """Probability drop when only the top-``k`` positive LIME segments stay visible.

    Args:
        predict: batched prediction callable passed to ``_predict_prob``.
        sample: document sample whose image is perturbed.
        exp: LIME image explanation (``local_exp`` and ``segments`` are read).
        class_idx: key into ``exp.local_exp``; it selects the segment weights
            only, the scored class is whatever ``_predict_prob`` picks.
        k: number of segments to keep.

    Returns:
        tuple: (original score minus perturbed score, perturbed PIL image).
    """
    seg_weights = dict(exp.local_exp[class_idx])
    chosen = _topk_lime_segments(seg_weights, k)

    # Hide everything except the chosen segments.
    hidden = ~np.isin(exp.segments, chosen)

    rgb = np.asarray(sample.image.convert("RGB"))
    occluded = _mask_image(rgb, hidden)
    perturbed = _doc_from_img(occluded, sample)

    drop = _predict_prob(predict, sample) - _predict_prob(predict, perturbed)
    return drop, Image.fromarray(occluded)


def _topk_shap_segments(vals, img_np, k,
                        n_segments=150, compactness=10, sigma=1):
    """Boolean mask of the ``k`` SLIC segments with the highest positive mean value.

    The image is segmented with ``slic``; each segment is scored by the mean
    of ``vals`` inside it. Segments whose mean is not strictly positive are
    never selected, and an all-False mask is returned when none qualifies.
    """
    labels = slic(img_np, n_segments=n_segments,
                  compactness=compactness, sigma=sigma, start_label=0)

    positive_means = {}
    for sid in np.unique(labels):
        segment_mean = vals[labels == sid].mean()
        if segment_mean > 0:
            positive_means[sid] = segment_mean

    if not positive_means:
        return np.zeros_like(vals, bool)

    best = sorted(positive_means, key=positive_means.get, reverse=True)[:k]
    return np.isin(labels, best)


def shap_comprehensiveness(predict, sample, exp, k=10, slic_kw=None):
    """Probability drop after hiding the top-``k`` SHAP-ranked SLIC segments.

    Args:
        predict: batched prediction callable passed to ``_predict_prob``.
        sample: document sample whose image is perturbed.
        exp: SHAP explanation; 3-D ``values`` are averaged over axis 2.
        k: number of segments to hide.
        slic_kw: optional keyword arguments forwarded to ``_topk_shap_segments``.

    Returns:
        tuple: (original score minus perturbed score, perturbed PIL image).
    """
    attributions = exp.values
    if attributions.ndim == 3:
        attributions = attributions.mean(2)

    rgb = np.asarray(sample.image.convert("RGB"))
    segment_options = slic_kw or {}
    hidden = _topk_shap_segments(attributions, rgb, k, **segment_options)

    occluded = _mask_image(rgb, hidden)
    perturbed = _doc_from_img(occluded, sample)

    drop = _predict_prob(predict, sample) - _predict_prob(predict, perturbed)
    return drop, Image.fromarray(occluded)


def shap_sufficiency(predict, sample, exp, k=10, slic_kw=None):
    """Probability drop when only the top-``k`` SHAP-ranked SLIC segments stay visible.

    Args:
        predict: batched prediction callable passed to ``_predict_prob``.
        sample: document sample whose image is perturbed.
        exp: SHAP explanation; 3-D ``values`` are averaged over axis 2.
        k: number of segments to keep.
        slic_kw: optional keyword arguments forwarded to ``_topk_shap_segments``.

    Returns:
        tuple: (original score minus perturbed score, perturbed PIL image).
    """
    attributions = exp.values
    if attributions.ndim == 3:
        attributions = attributions.mean(2)

    rgb = np.asarray(sample.image.convert("RGB"))
    segment_options = slic_kw or {}
    visible = _topk_shap_segments(attributions, rgb, k, **segment_options)

    # Everything outside the kept segments is painted over.
    occluded = _mask_image(rgb, ~visible)
    perturbed = _doc_from_img(occluded, sample)

    drop = _predict_prob(predict, sample) - _predict_prob(predict, perturbed)
    return drop, Image.fromarray(occluded)

## DUMP

### Vision

In [None]:
vision_explainer = LimeVisionExplainer(
    LLMV3,
    LLMV3_encode,
    batch_size=4,
    label = 4
)

vision_vals = vision_explainer.explain(CLASH[0], num_samples=1000, num_features=200)

In [None]:
print("Explanation score (R²):", vision_vals.score)

In [None]:
segment_weights = dict(vision_vals.local_exp[2])

# Create weight map
segments = vision_vals.segments
weight_map = np.zeros_like(segments, dtype=float)
for seg_id, weight in segment_weights.items():
    weight_map[segments == seg_id] = weight

# Define overlay function
def lime_weight_to_overlay(image_np, weight_map, alpha=0.4):
    """Blend a red/blue weight colouring over an image.

    Weights are divided by the largest absolute weight (1 if the map is all
    zeros). Positive weights fade from white to red, negative weights from
    white to blue, and zero weights stay white. The colour layer is then mixed
    into ``image_np`` with opacity ``alpha``.

    Returns:
        np.ndarray: uint8 array, clipped to 0..255.
    """
    scale = np.max(np.abs(weight_map))
    if scale == 0:
        scale = 1
    signed = weight_map / scale

    is_pos = signed > 0
    is_neg = signed < 0
    fade_pos = 255 * (1 - signed)   # channel value that drops as a positive weight grows
    fade_neg = 255 * (1 + signed)   # channel value that drops as a negative weight grows

    red = np.where(is_pos, 255, np.where(is_neg, fade_neg, 255))
    green = np.where(is_pos, fade_pos, np.where(is_neg, fade_neg, 255))
    blue = np.where(is_pos, fade_pos, 255)

    colours = np.stack([red, green, blue], axis=-1).astype(np.uint8)
    blended = (1 - alpha) * image_np + alpha * colours
    return blended.clip(0, 255).astype(np.uint8)

# Apply overlay
img_np = vision_vals.image
heat_np = lime_weight_to_overlay(img_np, weight_map, alpha=0.4)

# Display
Image.fromarray(heat_np)

In [None]:
from PIL import Image

In [None]:
img_np, mask = vision_vals.get_image_and_mask(
    label = 2,
    positive_only=False,   # include negative weights too
    num_features=30,
    hide_rest=False
)

heat_np = lime_mask_to_overlay(img_np, mask, alpha=0.40)
Image.fromarray(heat_np)

In [None]:
print("Explanation score (R²):", vision_vals.score)
for word, weight in vision_vals:
  print(f"{word:10s} -> {weight:+.10f}")

In [None]:
text_explainer = LimeTextExplainer(
    LLMV3,
    LLMV3_encode,
    mask_token = LLMV3_proc.tokenizer.mask_token,
    batch_size = 2,
    kernel_width_factor = 0.75,
    labels = [4]
)

In [None]:
text_vals = text_explainer.explain(AGREE[0], align_boxes=True, num_samples=2000, num_features=30)

In [None]:
print("Explanation score (R²):", text_vals.score)
for word, weight in text_vals.as_list(label=4):
  print(f"{word:10s} -> {weight:+.10f}")

In [None]:
from vrdu_utils.utils import draw_lime_token_heatmap
import re
weights = {}
for token, w in text_vals.as_list(label=4):
    clean = re.sub(r"=\d+$", "", token)     # drop '=number' suffix
    weights[clean] = weights.get(clean, 0.0) + w

draw_lime_token_heatmap(image = AGREE[0].image, words = AGREE[0].words, boxes = AGREE[0].bboxes, weights = weights, alpha = 0.25)

In [None]:
fe_text = FidelityEvaluator(
    model = LLMV3,
    encode_fn = LLMV3_encode,
    device = device,
    mask_token = LLMV3_proc.tokenizer.mask_token
)

In [None]:
len(AGREE[0].words)

In [None]:
top20 = sorted(weights, key=weights.get, reverse=True)[: int(len(weights)*0.2)]

print("LIME feature strings:", top20[:10])
print("First 20 words in sample:", AGREE[0].words[:20])

# quick overlap check
print("Overlap size:",
      len({f.split('=')[0] for f in top20}.intersection(AGREE[0].words)))

In [None]:
orig = fe_text._get_prediction_function(AGREE[0].label)(AGREE[0])
pert = fe_text._get_prediction_function(AGREE[0].label)(
            DocSample(image=AGREE[0].image, words=["[UNK]"]*len(AGREE[0].words),
                      bboxes=AGREE[0].bboxes, ner_tags=AGREE[0].ner_tags, label=AGREE[0].label))
print(f"p_orig  = {orig:.6f}")
print(f"p_allUNK= {pert:.6f}")

In [None]:
scores = fe_text.evaluate(
    sample          = AGREE[0],
    explanation     = weights,
    top_k_fraction  = 0.2,          # use 20 % of the most important tokens
)
print(scores)
