# Interpretability Techniques for single samples from FUNSD

## GitHub

In [None]:
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
%cd /content/Interpretability-for-VRDU-models

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install -U datasets

## Libraries

In [None]:
#code
from datasets import load_from_disk
from transformers import LayoutLMv3ForTokenClassification, AutoProcessor, BrosForTokenClassification, AutoTokenizer, AutoConfig
import sys
import importlib
def reload_modules():
    """Reload every already-imported module that belongs to a watched package.

    Each entry of ``sys.modules`` whose dotted name starts with one of the
    watched prefixes is passed to ``importlib.reload``; the name of every
    reloaded module is printed first.
    """
    watched_prefixes = ('vrdu_utils', 'Classification_Explain', 'lime', 'Eval')
    # Iterate over a snapshot so the mapping may change while reloading.
    for name in tuple(sys.modules):
        if not name.startswith(watched_prefixes):
            continue
        print(f"Reloading module: {name}")
        importlib.reload(sys.modules[name])

reload_modules()

from vrdu_utils.encoders import *
from Classification_Explain.lime import *
from vrdu_utils.utils import *
import torch
from Eval.eval_suite import *
from Eval.fidelity import *
from Classification_Explain.shap import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import warnings
from transformers import logging as hf_logging

# Hide FutureWarnings raised from `transformers.modeling_utils` and limit
# Hugging Face logging to errors so the cell output stays readable.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="transformers.modeling_utils",
)
hf_logging.set_verbosity_error()
# `device` was already selected above (right after the torch import), so it is
# not assigned a second time here.

## Data & Model Setup

In [None]:
# `load_dataset` is imported explicitly: the Libraries cell only imports
# `load_from_disk`, so the name would otherwise have to leak in through one of
# the wildcard imports.
from datasets import load_dataset

FUNSD = load_dataset("nielsr/funsd")
# Tag names in class-id order, read from the `ner_tags` feature of the train split.
label_list = FUNSD["train"].features["ner_tags"].feature.names
# Two-way lookup tables between integer class ids and tag names.
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in enumerate(label_list)}
funsd_test = FUNSD["test"]
test_ds = DocSampleDataset(funsd_test)
print(label_list)

In [None]:
# Both token-classification heads are loaded with the same label space.
_label_kwargs = dict(
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

LLMV3 = LayoutLMv3ForTokenClassification.from_pretrained(
    'adamadam111/layoutlmv3-finetuned-funsd',
    **_label_kwargs,
)
# The processor is created with its built-in OCR switched off.
LLMV3_proc = AutoProcessor.from_pretrained(
    'adamadam111/layoutlmv3-finetuned-funsd',
    apply_ocr=False,
)

BROS = BrosForTokenClassification.from_pretrained(
    'adamadam111/bros-funsd-finetuned',
    **_label_kwargs,
)
BROS_t = AutoTokenizer.from_pretrained('adamadam111/bros-funsd-finetuned')

# Move both models to the selected device.
BROS.to(device)
LLMV3.to(device)

In [None]:
# Build the batch encoders for each model with `ner=True`; elsewhere in this
# notebook they are called as `encode_fn([sample], device)`.
LLMV3_encode = make_layoutlmv3_encoder(LLMV3_proc, ner=True)
BROS_encode = make_bros_encoder(BROS_t, ner=True)

## Finding Samples

In [None]:

def display_sample_with_gt(sample, id2label):
    """Outline every word box of a page in the colour of its ground-truth tag.

    Args:
        sample: Mapping providing ``"image"``, ``"bboxes"`` and ``"ner_tags"``.
        id2label: Mapping from integer tag id to tag name.

    Returns:
        The RGB-converted page image with a 2-pixel rectangle drawn per box;
        tags missing from the palette are outlined in black.
    """
    palette = {
        "O": "gray",
        "B-HEADER": "navy",
        "I-HEADER": "skyblue",
        "B-QUESTION": "green",
        "I-QUESTION": "lightgreen",
        "B-ANSWER": "red",
        "I-ANSWER": "pink",
    }

    page = sample["image"].convert("RGB")
    pen = ImageDraw.Draw(page)
    page_w, page_h = page.size

    for raw_box, tag_id in zip(sample["bboxes"], sample["ner_tags"]):
        outline = palette.get(id2label[tag_id], "black")
        left, top, right, bottom = unnormalize_box(raw_box, page_w, page_h)
        pen.rectangle([left, top, right, bottom], outline=outline, width=2)

    return page

# Preview test page 29, the same page the later cells explain. It is loaded
# directly because `expample` is only assigned further down the notebook.
display_sample_with_gt(FUNSD["test"][29], id2label)

In [None]:
# Outline colour for each BIO tag when word boxes are drawn on a page.
TAG_COLOURS = {
    "O": "gray",
    "B-HEADER": "navy",
    "I-HEADER": "skyblue",
    "B-QUESTION": "green",
    "I-QUESTION": "lightgreen",
    "B-ANSWER": "red",
    "I-ANSWER": "pink",
}


def draw_page_with_tags(img_path, bboxes, tag_names, tag_colours=TAG_COLOURS, width=2):
    """Outline each word box on the page in the colour of its tag.

    Args:
        img_path: Despite the name, an image object exposing ``convert``;
            nothing is opened from disk here.
        bboxes: Boxes, each passed through ``unnormalize_box`` with the page size.
        tag_names: One tag name per box.
        tag_colours: Tag name -> outline colour; unknown tags fall back to black.
        width: Outline width forwarded to ``ImageDraw.rectangle``.

    Returns:
        The RGB-converted image with the rectangles drawn on it.
    """
    page = img_path.convert("RGB")
    pen = ImageDraw.Draw(page)
    page_w, page_h = page.size

    # Convert all boxes up front, before any drawing happens.
    pixel_boxes = []
    for raw_box in bboxes:
        pixel_boxes.append(unnormalize_box(raw_box, page_w, page_h))

    for rect, tag in zip(pixel_boxes, tag_names):
        pen.rectangle(rect, outline=tag_colours.get(tag, "black"), width=width)

    return page


def _get_word_preds(sample, model, encode_fn, id2label, device="cuda"):
    """Predict one tag name per word of ``sample`` with ``model``.

    The sample is wrapped in a ``DocSample``, encoded with ``encode_fn`` and
    run through the model in eval mode without gradients. For every token
    whose word id is not ``None`` and differs from the last kept word id, the
    arg-max class of that token is taken as the word's prediction.

    Args:
        sample: Mapping with ``"image"``, ``"words"``, ``"bboxes"``, ``"ner_tags"``.
        model: Model called as ``model(**enc)``; its ``.logits`` are read.
        encode_fn: Callable invoked as ``encode_fn([doc], device)``; the first
            element of its result is used as the encoding.
        id2label: Mapping from class id to tag name.
        device: Device handed to ``encode_fn`` and to ``enc.to``.

    Returns:
        List of tag names; collection stops once it holds as many entries as
        ``sample["ner_tags"]``.
    """
    model.eval()

    doc = DocSample(
        image=sample["image"],
        words=sample["words"],
        bboxes=sample["bboxes"],
        ner_tags=sample["ner_tags"],
    )

    encoding, _ = encode_fn([doc], device)
    encoding.to(device)
    with torch.no_grad():
        token_logits = model(**encoding).logits[0]

    token_pred_ids = token_logits.argmax(-1).tolist()
    token_word_ids = encoding.word_ids(batch_index=0)

    word_tags = []
    last_word = None
    for position, word_id in enumerate(token_word_ids):
        starts_new_word = word_id is not None and word_id != last_word
        if starts_new_word:
            word_tags.append(id2label[token_pred_ids[position]])
            last_word = word_id
            if len(word_tags) == len(sample["ner_tags"]):
                break

    return word_tags
def visualise_error_boxes(
    sample,
    model_a, encode_a,
    model_b, encode_b,
    id2label,
    device="cpu",
):
    """Render, for two models, the page with only their mis-tagged words outlined.

    Args:
        sample: Mapping with ``"image"``, ``"bboxes"`` and ``"ner_tags"`` (plus
            whatever ``_get_word_preds`` reads).
        model_a, encode_a: First model and its encoder function.
        model_b, encode_b: Second model and its encoder function.
        id2label: Mapping from class id to tag name.
        device: Device forwarded to ``_get_word_preds``.

    Returns:
        ``(errors_a, errors_b)``: one image per model from ``draw_page_errors``.
    """
    predicted_a = _get_word_preds(sample, model_a, encode_a, id2label, device)
    predicted_b = _get_word_preds(sample, model_b, encode_b, id2label, device)

    gold_tags = [id2label[tag_id] for tag_id in sample["ner_tags"]]

    error_pages = tuple(
        draw_page_errors(sample["image"], sample["bboxes"], predicted, gold_tags)
        for predicted in (predicted_a, predicted_b)
    )
    return error_pages

def draw_page_errors(img,
                     bboxes,
                     pred_tags,
                     gt_tags,
                     tag_colours=TAG_COLOURS,
                     width=3):
    """Outline only the boxes whose predicted tag differs from the gold tag.

    Args:
        img: Image object exposing ``convert``.
        bboxes: Boxes, each passed through ``unnormalize_box`` with the page size.
        pred_tags: Predicted tag name per box.
        gt_tags: Ground-truth tag name per box.
        tag_colours: Tag name -> outline colour; unknown tags fall back to black.
        width: Outline width forwarded to ``ImageDraw.rectangle``.

    Returns:
        A copy of the RGB-converted page with the error boxes drawn on it.
    """
    page = img.convert("RGB").copy()
    pen = ImageDraw.Draw(page)
    page_w, page_h = page.size
    pixel_boxes = [unnormalize_box(raw, page_w, page_h) for raw in bboxes]

    for rect, predicted, expected in zip(pixel_boxes, pred_tags, gt_tags):
        if predicted == expected:
            continue
        # The colour encodes what the model predicted, not the gold tag.
        pen.rectangle(rect, outline=tag_colours.get(predicted, "black"), width=width)

    return page


def visualise_two_models(
    sample,
    model_a, encode_a,
    model_b, encode_b,
    id2label,
    device="cpu",
):
    """Render the page twice, coloured by each model's predicted tags.

    Args:
        sample: Mapping with ``"image"`` and ``"bboxes"`` (plus whatever
            ``_get_word_preds`` reads).
        model_a, encode_a: First model and its encoder function.
        model_b, encode_b: Second model and its encoder function.
        id2label: Mapping from class id to tag name.
        device: Device forwarded to ``_get_word_preds``.

    Returns:
        ``(page_a, page_b)``: one image per model from ``draw_page_with_tags``.
    """
    # Run both models first, then draw.
    predictions = [
        _get_word_preds(sample, model, encode, id2label, device)
        for model, encode in ((model_a, encode_a), (model_b, encode_b))
    ]

    page_a, page_b = (
        draw_page_with_tags(sample["image"], sample["bboxes"], predicted)
        for predicted in predictions
    )
    return page_a, page_b


# Page 29 of the FUNSD test split is compared across both models.
page = FUNSD["test"][29]
_model_args = (LLMV3, LLMV3_encode, BROS, BROS_encode)

img_llmv3, img_bros = visualise_two_models(page, *_model_args, id2label, device=device)
err_llmv3, err_bros = visualise_error_boxes(page, *_model_args, id2label, device=device)

# For each model, show the fully tagged page followed by its errors only.
for _rendered in (img_llmv3, err_llmv3, img_bros, err_bros):
    display(_rendered)

# TAG_COLOURS = {
#     "O":          "gray",
#     "B-HEADER":   "navy",
#     "I-HEADER":   "skyblue",
#     "B-QUESTION": "green",
#     "I-QUESTION": "lightgreen",
#     "B-ANSWER":   "red",
#     "I-ANSWER":   "pink",
# }


In [None]:
# Running example: page 29 of the FUNSD test split, wrapped as a DocSample.
expample = FUNSD['test'][29]
example = DocSample(
    **{field: expample[field] for field in ("image", "words", "bboxes", "ner_tags")}
)

# Word whose prediction is explained. `target_idx` is the index of its last
# occurrence on the page and stays 0 when the word does not occur.
target_word = 'OBJECTIVE'
_matches = [pos for pos, word in enumerate(example.words) if word == target_word]
target_idx = _matches[-1] if _matches else 0

def target_token_fn(enc, *, batch_index: int = 0):
    """Return the index of the first token that belongs to the target word.

    Args:
        enc: Either an encoding object exposing ``word_ids(batch_index=...)``
            or an already-extracted iterable of per-token word ids.
        batch_index: Batch row forwarded to ``enc.word_ids``.

    Returns:
        Position of the first token whose word id equals the module-level
        ``target_idx``, or ``None`` when no token maps to that word.
    """
    try:
        word_ids = enc.word_ids(batch_index=batch_index)
    except (AttributeError, TypeError):
        # Not an encoding object: treat the argument itself as the word ids.
        # Only these two errors are caught, so interrupts and genuine failures
        # inside `word_ids` are no longer silently swallowed.
        word_ids = enc
    for tok_idx, wid in enumerate(word_ids):
        if wid == target_idx:
            return tok_idx
    return None

target_label = label2id['B-HEADER']

## Text

In [None]:
# One LIME and one SHAP text explainer per model, keyed "<model> <method>".
text_explainers = {
    'LLMV3 lime': LimeTextNer(
        LLMV3,
        LLMV3_encode,
        mask_token=LLMV3_proc.tokenizer.mask_token,
        device=device,
        kernel_width_factor=0.75,
        target_token_fn=target_token_fn,
        labels=[target_label],
    ),
    'LLMV3 shap': SHAPTextNer(
        LLMV3,
        LLMV3_encode,
        tokenizer=LLMV3_proc.tokenizer,
        mask_token=LLMV3_proc.tokenizer.mask_token,
        device=device,
        target_token_fn=target_token_fn,
    ),
    'BROS lime': LimeTextNer(
        BROS,
        BROS_encode,
        mask_token=BROS_t.mask_token,
        device=device,
        kernel_width_factor=0.75,
        target_token_fn=target_token_fn,
        # Class 3 is explained because BROS falsely predicts B-QUESTION here.
        labels=[3],
    ),
    'BROS shap': SHAPTextNer(
        BROS,
        BROS_encode,
        tokenizer=BROS_t,
        mask_token=BROS_t.mask_token,
        device=device,
        target_token_fn=target_token_fn,
    ),
}


In [None]:
print(example)

In [None]:
# Run every text explainer on the example page, keeping the explainer's key.
text_explanations = {
    name: explainer.explain(example)
    for name, explainer in text_explainers.items()
}

In [None]:
# Per-explainer outputs for the text modality. `text_explanations_weights` is
# filled by the helper functions below; `text_explanations_hms` is not written
# to anywhere in the code shown in this notebook.
text_explanations_hms = {}
text_explanations_weights = {}
# Page size, used to map the sample's boxes to pixel coordinates.
w, h = example.image.size
bboxes = [unnormalize_box(bbox, w, h) for bbox in example.bboxes]

def _weights_and_heatmap(key, explanation, global_weights_dict, label, alpha):
    """Extract per-word weights for ``label`` and draw them as a page heat-map.

    Shared implementation behind the two per-model helpers below.
    """
    if 'lime' in key:
        # Feature names are expected to look like "<word>=1". Only that
        # trailing marker is stripped, so a word that itself contains "=1"
        # (e.g. "x=10") is no longer mangled.
        weights = {
            feature.removesuffix('=1'): value
            for feature, value in explanation.as_list(label=label)
        }
    else:
        per_word = explanation.values[:, :, label].flatten()
        # Keyed by word text, so repeated words keep the last value seen.
        weights = {tok: float(val) for tok, val in zip(example.words, per_word)}
    global_weights_dict.update({key: weights})
    return draw_lime_token_heatmap(
        image=example.image,
        words=example.words,
        boxes=bboxes,
        weights=weights,
        alpha=alpha,
    )


def get_bros_weights_and_hm(key, explanation, global_weights_dict):
    """Store the BROS word weights for class 3 under ``key`` and return the heat-map.

    Args:
        key: Explainer name; a name containing ``'lime'`` selects the LIME
            extraction path, anything else is read via ``explanation.values``.
        explanation: Explanation object produced by the matching explainer.
        global_weights_dict: Dict updated in place with ``{key: weights}``.

    Returns:
        The image returned by ``draw_lime_token_heatmap`` (alpha 0.5).
    """
    return _weights_and_heatmap(key, explanation, global_weights_dict, label=3, alpha=0.5)


def get_llmv3_weights_and_hm(key, explanation, global_weights_dict):
    """Store the LayoutLMv3 word weights for ``target_label`` and return the heat-map.

    Args:
        key: Explainer name; a name containing ``'lime'`` selects the LIME
            extraction path, anything else is read via ``explanation.values``.
        explanation: Explanation object produced by the matching explainer.
        global_weights_dict: Dict updated in place with ``{key: weights}``.

    Returns:
        The image returned by ``draw_lime_token_heatmap`` (alpha 0.7).
    """
    return _weights_and_heatmap(
        key, explanation, global_weights_dict, label=target_label, alpha=0.7
    )

# Build the heat-maps per model; the helpers also record the word weights in
# `text_explanations_weights` as a side effect.
llmve_hms = [
    get_llmv3_weights_and_hm(name, explanation, text_explanations_weights)
    for name, explanation in text_explanations.items()
    if 'LLMV3' in name
]
bros_hms = [
    get_bros_weights_and_hm(name, explanation, text_explanations_weights)
    for name, explanation in text_explanations.items()
    if 'BROS' in name
]
# display_image_grid(llmve_hms,
#                    [f'', f''],
#                     (1,2),
#                    main_title='LLMV3"')
# display_image_grid(bros_hms,
#                    [f'',f''],
#                    (1,2),
#                    main_title='BROS'
#                    )

# llmve_hms[0]
# llmve_hms[1]
# bros_hms[0]
bros_hms[1]

In [None]:
# Fidelity evaluator for LayoutLMv3, configured with the same target-token
# selector and target label as the LayoutLMv3 explainers.
ner_fidelity = FidelityEvaluator(
    model          = LLMV3,
    encode_fn      = LLMV3_encode,
    mask_token     = LLMV3_proc.tokenizer.mask_token,
    target_token_fn= target_token_fn,    # NEW
    target_label_id= target_label,    # NEW
    device         = device,
)
# NOTE(review): the evaluator is configured with `target_label`, and the
# 'LLMV3 shap' weights used below were extracted for `target_label`, yet the
# prediction function is requested with the literal 3 (the class id used for
# BROS elsewhere in this notebook). Confirm which id is intended.
pred_fn = ner_fidelity._get_prediction_function(3)
# Sufficiency of the single highest-ranked word (top_k=1) of the SHAP explanation.
calculate_sufficiency(predict_fn=pred_fn, sample=example,explanation=text_explanations_weights['LLMV3 shap'] , mask_token = LLMV3_proc.tokenizer.mask_token, top_k=1)

## Layout

In [None]:
# One LIME and one SHAP layout explainer per model, keyed "<model> <method>".
layout_explainers = {
    'LLMV3 lime': LimeLayoutNer(
        LLMV3,
        LLMV3_encode,
        mask_token=LLMV3_proc.tokenizer.mask_token,
        device=device,
        kernel_width_factor=0.75,
        target_token_fn=target_token_fn,
        labels=[target_label],
    ),
    'LLMV3 shap': SHAPLayoutNer(
        LLMV3,
        LLMV3_encode,
        device=device,
        target_token_fn=target_token_fn,
    ),
    'BROS lime': LimeLayoutNer(
        BROS,
        BROS_encode,
        mask_token=BROS_t.mask_token,
        device=device,
        kernel_width_factor=0.75,
        target_token_fn=target_token_fn,
        labels=[3],
    ),
    'BROS shap': SHAPLayoutNer(
        BROS,
        BROS_encode,
        device=device,
        target_token_fn=target_token_fn,
    ),
}


In [None]:
# Run every layout explainer on the example page, keeping the explainer's key.
layout_explanations = {
    name: explainer.explain(example)
    for name, explainer in layout_explainers.items()
}

In [None]:
# Per-explainer outputs for the layout modality; the weights dict is filled by
# the heat-map helpers called below.
layout_explanations_hms = {}
layout_explanations_weights = {}
w, h = example.image.size
bboxes = [unnormalize_box(bbox, w, h) for bbox in example.bboxes]
layout_llmv3_hms = [
    get_llmv3_weights_and_hm(name, explanation, layout_explanations_weights)
    for name, explanation in layout_explanations.items()
    if 'LLMV3' in name
]
layout_bros_hms = [
    get_bros_weights_and_hm(name, explanation, layout_explanations_weights)
    for name, explanation in layout_explanations.items()
    if 'BROS' in name
]

# display_image_grid(layout_llmv3_hms,
#                    [f'', f''],
#                     (1,2),
#                    main_title='LLMV3"')
# display_image_grid(layout_bros_hms,
#                    [f'',f''],
#                    (1,2),
#                    main_title='BROS'
#                    )
# layout_llmv3_hms[0]
# layout_llmv3_hms[1]
# layout_bros_hms[0]
layout_bros_hms[1]



# for key, explanation in layout_explanations.items():
#   if 'lime' in key:
#     weights = dict(layout_explanations[key].as_list(label=3))
#     weights = {key.replace('=1', ''): value for key, value in weights.items()}
#     layout_explanations_weights.update({key : weights})
#     layout_explanations_hms.update({key : draw_lime_token_heatmap(image = example.image, words = example.words, boxes = bboxes, weights=weights, alpha=0.5) })
#   if 'shap' in key:
#     weights = {tok : float(val) for tok, val in zip(example.words, layout_explanations[key].values[:,:,3].flatten())}
#     layout_explanations_weights.update({key : weights})
#     layout_explanations_hms.update({key : draw_lime_token_heatmap(image = example.image, words = example.words, boxes = bboxes, weights=weights, alpha=0.5) })


# display_image_grid([layout_explanations_hms['LLMV3 lime'], layout_explanations_hms['LLMV3 shap']],
#                    [f'', f''],
#                     (1,2),
#                    main_title='Layout LMV3"')

# display_image_grid([layout_explanations_hms['BROS lime'], layout_explanations_hms['BROS shap']],
#                    [f'', f''],
#                     (1,2),
#                    main_title='BROS"')


## Vision

In [None]:
# Vision explainers are built for LayoutLMv3 only; no BROS entries are defined.
# NOTE(review): the LIME entry passes `label=` while the text and layout LIME
# explainers in this notebook pass `labels=`. Confirm against the
# `LimeVisionNer` signature.
# NOTE(review): `class_idx=1` is a literal; presumably it is meant to be the
# same class as `target_label`. Confirm.
vision_explainers = {
    'LLMV3 lime' : LimeVisionNer(LLMV3, LLMV3_encode, device=device, target_token_fn=target_token_fn, label = [target_label]),
    'LLMV3 shap' : SHAPVisionNer(LLMV3, LLMV3_encode, device = device, class_idx=1,mask_value='blur(64,64)',target_token_fn=target_token_fn),
}


In [None]:
# The two explainer families name their sampling budget differently
# (`nsamples` for SHAP, `num_samples` for LIME). SHAP entries are computed and
# inserted first, then LIME.
vision_explanations = {}
for _family, _budget_kwarg in (('shap', 'nsamples'), ('lime', 'num_samples')):
    for _name, _explainer in vision_explainers.items():
        if _family in _name:
            vision_explanations[_name] = _explainer.explain(example, **{_budget_kwarg: 1000})

In [None]:
def shap_overlay(exp, alpha=0.6, cmap="bwr"):
    """Blend a colour-mapped attribution map over the explained image.

    Args:
        exp: Explanation object providing ``values`` (attributions) and
            ``data`` (the image array).
        alpha: Blend weight of the colour layer; the image gets ``1 - alpha``.
        cmap: Name of the Matplotlib colormap.

    Returns:
        A PIL image of the blended result.
    """
    attributions = exp.values
    background = exp.data.astype(np.uint8)

    # A 3-D attribution array is averaged over its last axis.
    if attributions.ndim == 3:
        attributions = attributions.mean(2)

    # Symmetric scaling: -max|v| -> 0, 0 -> 0.5, +max|v| -> 1. The epsilon
    # avoids division by zero when every attribution is 0.
    scale = np.abs(attributions).max() + 1e-12
    unit_range = (attributions + scale) / (2 * scale)
    colour_layer = plt.get_cmap(cmap)(unit_range)[..., :3] * 255
    mixed = background * (1 - alpha) + colour_layer * alpha
    return Image.fromarray(mixed.astype(np.uint8))

shap_overlay(vision_explanations['LLMV3 shap'], alpha=0.6)

In [None]:
def lime_heatmap(expl, class_idx=1, alpha_max=0.6, cmap="bwr"):
    """Tint each explained segment of the image according to its weight.

    Args:
        expl: Explanation object providing ``image``, ``segments`` and
            ``local_exp[class_idx]`` as ``(segment_id, weight)`` pairs.
        class_idx: Key into ``expl.local_exp`` selecting the explained class.
        alpha_max: Opacity given to the segment with the largest absolute
            weight; other segments scale linearly with ``|weight|``.
        cmap: Name of the Matplotlib colormap.

    Returns:
        A PIL image of the tinted page. With no weights the image is returned
        untinted.
    """
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap`
    # is the supported lookup and is what `shap_overlay` already uses.
    import matplotlib.pyplot as plt

    img = expl.image.astype(float)
    segs = expl.segments
    weights = dict(expl.local_exp[class_idx])
    # `default` keeps an empty explanation from raising in `max`; the epsilon
    # avoids division by zero when all weights are 0.
    vmax = max((abs(w) for w in weights.values()), default=0.0) + 1e-12
    # Resolve the colormap once instead of once per segment.
    colormap = plt.get_cmap(cmap)

    overlay = img.copy()
    for seg_id, w in weights.items():
        mask = segs == seg_id
        # Symmetric scaling: -vmax -> 0, 0 -> 0.5, +vmax -> 1.
        color = np.array(colormap((w + vmax) / (2 * vmax))[:3]) * 255
        alpha = alpha_max * abs(w) / vmax
        overlay[mask] = overlay[mask] * (1 - alpha) + color * alpha

    return Image.fromarray(overlay.astype(np.uint8))
lime_heatmap(vision_explanations['LLMV3 lime'],alpha_max=0.6)

## Eval

In [None]:
from skimage.segmentation import slic
from Eval.evaluation import evaluate_sample

# Keyword arguments for `skimage.segmentation.slic`; used as the default
# `slic_kw` of `extract_weights`.
SLIC_KW = dict(n_segments=200, compactness=20.0, sigma=1.0, start_label=1)
# Every explanation computed above, keyed "[<modality>] <model> <method>".
# The evaluation loop reads the modality from the "[T]"/"[L]" prefix and the
# model from the presence of "LLMV3" in the key, so the keys must keep this shape.
all_expls = {
    # TEXT
    "[T] LLMV3 shap": text_explanations["LLMV3 shap"],
    "[T] LLMV3 lime": text_explanations["LLMV3 lime"],
    "[T] BROS  shap": text_explanations["BROS shap"],
    "[T] BROS  lime": text_explanations["BROS lime"],

    # LAYOUT
    "[L] LLMV3 shap": layout_explanations["LLMV3 shap"],
    "[L] LLMV3 lime": layout_explanations["LLMV3 lime"],
    "[L] BROS  shap": layout_explanations["BROS shap"],
    "[L] BROS  lime": layout_explanations["BROS lime"],

    # VISION
    "[V] LLMV3 shap": vision_explanations["LLMV3 shap"],
    "[V] LLMV3 lime": vision_explanations["LLMV3 lime"],
}

def extract_weights(expl, modality: str, cls_idx: int,
                    sample: DocSample | None = None,
                    slic_kw: dict = SLIC_KW):
    """Convert an explanation object into ``({feature_id: weight}, segments)``.

    Objects exposing ``local_exp`` (the LIME-style results in this notebook)
    are handled first, then objects exposing ``values`` (the SHAP-style ones).

    Args:
        expl: Explanation object.
        modality: ``"text"``, ``"layout"`` or ``"vision"``.
        cls_idx: Class whose attributions are extracted.
        sample: Sample whose image is segmented; needed only for the
            ``values`` + vision path.
        slic_kw: Keyword arguments forwarded to ``slic``.

    Returns:
        ``(weights, segments)``; ``segments`` is ``None`` for text and layout.

    Raises:
        ValueError: ``sample`` is missing on the ``values`` + vision path.
        RuntimeError: The object has neither attribute, or ``modality`` is
            not recognised on the ``values`` path.
    """
    is_vision = modality == "vision"

    if hasattr(expl, "local_exp"):
        if is_vision:
            segment_weights = {
                int(seg): float(val) for seg, val in expl.local_exp[cls_idx]
            }
            return segment_weights, expl.segments
        feature_weights = {int(fid): float(w) for fid, w in expl.as_map()[cls_idx]}
        return feature_weights, None

    if not hasattr(expl, "values"):
        raise RuntimeError("Unrecognised explanation object type.")

    vals = expl.values

    if modality in ("text", "layout"):
        # Drop leading axes until at most (feature, class) remains.
        while vals.ndim > 2:
            vals = vals[0]
        if vals.ndim == 2:
            vals = vals[:, cls_idx]
        return dict((idx, float(v)) for idx, v in enumerate(vals)), None

    if is_vision:
        if sample is None:
            raise ValueError("sample must be supplied for SHAP-vision.")
        if vals.ndim == 4:
            vals = vals[0]
        if vals.ndim == 3:
            vals = vals.mean(-1)

        # Pool the per-pixel values into SLIC superpixels of the sample image.
        img_np = np.asarray(sample.image)
        channel_axis = -1 if img_np.ndim == 3 else None
        seg = slic(img_np, channel_axis=channel_axis, **slic_kw)
        pooled = {}
        for sid in np.unique(seg):
            pooled[int(sid)] = float(vals[seg == sid].mean())
        return pooled, seg

    raise RuntimeError("Unrecognised explanation object type.")



def get_target_class_id(model_tag: str,
                        sample: "DocSample",
                        llmv3_fallback: int) -> int:
    """Return the class id whose explanation is evaluated for a model.

    Args:
        model_tag: Model name, compared case-insensitively with ``"LLMV3"``.
        sample: Object whose optional ``label`` attribute is used for LLMV3.
        llmv3_fallback: Class id used for LLMV3 when the sample has no
            ``label`` attribute or the attribute is ``None``.

    Returns:
        ``sample.label`` (or the fallback) for LLMV3; 3 for any other model.
    """
    if model_tag.upper() != "LLMV3":
        return 3
    label = getattr(sample, "label", None)
    # Compare with None explicitly: a plain `or` would discard the valid
    # class id 0 and silently return the fallback instead.
    return llmv3_fallback if label is None else label



# (model, encoder) pair for each model name that can appear in an explanation tag.
MODEL_HANDLES = dict(
    BROS=(BROS, BROS_encode),
    LLMV3=(LLMV3, LLMV3_encode),
)
# Tag prefix -> modality; any other prefix is treated as vision.
_PREFIX_TO_MODALITY = {"[T]": "text", "[L]": "layout"}

fidelity_scores = {}

for tag, expl in all_expls.items():
    modality = _PREFIX_TO_MODALITY.get(tag[:3], "vision")
    model_tag = "LLMV3" if "LLMV3" in tag else "BROS"
    model, encode_fn = MODEL_HANDLES[model_tag]

    sample = example
    cls_id = get_target_class_id(model_tag, sample, llmv3_fallback=target_label)

    # `seg` is None for text and layout explanations.
    weights, seg = extract_weights(expl, modality, cls_id, sample)

    metrics = evaluate_sample(
        sample,
        weights,
        modality,
        model,
        encode_fn,
        top_k=10,
        device=device,
        target_token_fn=target_token_fn,
        target_label_id=cls_id,
        segments=seg,
    )
    fidelity_scores[tag] = metrics

    print(f"{tag:28s}  "
          f"comp = {metrics['comprehensiveness']:.4f}   "
          f"suff = {metrics['sufficiency']:.4f}")