In [None]:
# Integrated Gradients notebook for VRDU (visually rich document understanding) models.
#
# Explains a LayoutLMv3 token-classification model fine-tuned on FUNSD by
# attributing its "B-QUESTION" logits to the text embeddings with Captum.

In [7]:
# Delete every __pycache__ directory below the current working directory
# (presumably to discard stale bytecode after reinstalling packages).
!find . -name "__pycache__" -exec rm -rf {} +

In [10]:
# Remove previously installed versions non-interactively. The flag must be `-y`:
# a bare `y` is parsed as a package named "y" and pip still asks "Proceed (Y/n)?".
%pip uninstall -y datasets seqeval evaluate transformers torch captum

Found existing installation: datasets 3.6.0
Uninstalling datasets-3.6.0:
  Would remove:
    /usr/local/bin/datasets-cli
    /usr/local/lib/python3.11/dist-packages/datasets-3.6.0.dist-info/*
    /usr/local/lib/python3.11/dist-packages/datasets/*
Proceed (Y/n)? y
  Successfully uninstalled datasets-3.6.0
Found existing installation: seqeval 1.2.2
Uninstalling seqeval-1.2.2:
  Would remove:
    /usr/local/lib/python3.11/dist-packages/seqeval-1.2.2.dist-info/*
    /usr/local/lib/python3.11/dist-packages/seqeval/*
Proceed (Y/n)? y
  Successfully uninstalled seqeval-1.2.2
Found existing installation: evaluate 0.4.3
Uninstalling evaluate-0.4.3:
  Would remove:
    /usr/local/bin/evaluate-cli
    /usr/local/lib/python3.11/dist-packages/evaluate-0.4.3.dist-info/*
    /usr/local/lib/python3.11/dist-packages/evaluate/*
Proceed (Y/n)? y
  Successfully uninstalled evaluate-0.4.3
Found existing installation: transformers 4.51.3
Uninstalling transformers-4.51.3:
  Would remove:
    /usr/local/bin/t

In [2]:
# Upgrade torchvision together with torch: upgrading torch alone leaves the
# preinstalled torchvision built against a different torch, which then fails on
# import with "operator torchvision::nms does not exist".
# Restart the runtime after this cell so the new versions are actually loaded.
%pip install --upgrade datasets seqeval evaluate transformers torch torchvision captum



In [3]:
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification, set_seed
from PIL import Image,ImageDraw, ImageFont
from datasets import load_dataset
import torch
import pandas as pd
import evaluate
# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fix the random seed through transformers.set_seed for reproducible runs.
set_seed(0)


RuntimeError: Failed to import transformers.models.auto.processing_auto because of the following error (look up to see its traceback):
operator torchvision::nms does not exist

python3: can't open file '/content/test_imports.py': [Errno 2] No such file or directory


In [None]:
funsd = load_dataset("nielsr/funsd")

# Label names in class-index order, as declared by the dataset's ner_tags feature.
labels = funsd["train"].features["ner_tags"].feature.names
id2label = dict(enumerate(labels))                            # class index -> label name
label2id = {name: index for index, name in id2label.items()}  # label name -> class index
print(id2label)
print(label2id)

In [None]:
MODEL_CHECKPOINT = "adamadam111/layoutlmv3-finetuned-funsd"

# Token-classification model fine-tuned on FUNSD; eval() disables dropout.
model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_CHECKPOINT).to(device).eval()

# FUNSD already provides words and boxes, and they are passed to the processor
# explicitly. LayoutLMv3's image processor runs OCR by default (apply_ocr=True)
# and in that mode rejects user-supplied boxes / word_labels, so turn OCR off.
processor = AutoProcessor.from_pretrained(MODEL_CHECKPOINT, apply_ocr=False)

In [None]:
# Print the qualified name of every submodule whose name contains "embed";
# this helps locate the embedding layer to attribute against.
embedding_module_names = [
    module_name
    for module_name, _submodule in model.named_modules()
    if "embed" in module_name
]
for module_name in embedding_module_names:
    print(module_name)

In [None]:
def unnormalize_box(bbox, width, height, scale=1000):
    """Map a box from a normalized 0..``scale`` grid back to pixel coordinates.

    Args:
        bbox: indexable whose first four items are (x0, y0, x1, y1) on the
            normalized grid.
        width: page width in pixels.
        height: page height in pixels.
        scale: size of the normalized grid. Defaults to 1000, the value the
            original implementation hard-coded.

    Returns:
        ``[x0, y0, x1, y1]`` rescaled to ``width`` x ``height``.
    """
    x_factor = width
    y_factor = height
    return [
        x_factor * (bbox[0] / scale),
        y_factor * (bbox[1] / scale),
        x_factor * (bbox[2] / scale),
        y_factor * (bbox[3] / scale),
    ]

In [None]:
# Encode the first FUNSD test page: image, words, boxes and gold word labels.
sample = funsd["test"][0]
page_image = Image.open(sample["image_path"]).convert("RGB")

enc = processor(
    page_image,
    sample["words"],
    boxes=sample["bboxes"],
    word_labels=sample["ner_tags"],
    padding="max_length",
    return_tensors="pt",
).to(device)

# Plain inference pass; no gradients are needed for the predictions.
with torch.no_grad():
    out = model(**enc)

# Predicted class id per position; squeeze() drops the size-1 batch dimension.
pred_ids = out.logits.argmax(dim=-1).squeeze()
print(len(pred_ids))

question_label_id = model.config.label2id["B-QUESTION"]

In [None]:
# Token positions the model labels as B-QUESTION. The encoding is padded to
# max_length, so argmax predictions also exist at padding positions
# (attention_mask == 0); those correspond to no word on the page and are dropped.
# NOTE(review): CLS/SEP and non-first sub-word positions are still included.
is_question = pred_ids == question_label_id
is_real_token = enc["attention_mask"].squeeze(0).bool()
question_token_idx = (is_question & is_real_token).nonzero(as_tuple=True)[0]
print(question_token_idx)

In [None]:
def token_logit_forward(input_ids, bbox, pixel_values, attention_mask, token_index, label_id=None):
    """Captum forward function: one logit per example in the batch.

    Runs the global ``model`` and returns, for every row of the batch, the logit
    of class ``label_id`` at sequence position ``token_index``.

    Args:
        input_ids, bbox, pixel_values, attention_mask: model inputs; during
            attribution Captum passes batches covering the interpolation steps.
        token_index: sequence position whose logit is explained.
        label_id: class id to explain; defaults to the global ``question_label_id``.

    Returns:
        Tensor of shape ``(batch,)``. Captum attribution methods need one scalar
        per example, so the batch dimension is kept instead of selecting row 0
        (which returned a 0-d tensor and ignored the other interpolation steps).
    """
    if label_id is None:
        label_id = question_label_id
    logits = model(
        input_ids=input_ids,
        bbox=bbox,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
    ).logits
    return logits[:, token_index, label_id]

In [None]:
from captum.attr import LayerIntegratedGradients

# Attribute token_logit_forward's output to the output of LayoutLMv3's
# text-embedding module.
embedding_layer = model.layoutlmv3.embeddings
lig = LayerIntegratedGradients(forward_func=token_logit_forward, layer=embedding_layer)

In [None]:
# Show the candidate B-QUESTION token positions again before running attribution.
print(question_token_idx)

In [None]:
from tqdm.auto import tqdm

N_STEPS = 50  # number of steps along the baseline -> input path

# token position -> per-position attribution scores (hidden dim summed out)
attributions = {}
# token position -> convergence delta returned by Captum
convergence_deltas = {}

# Baseline: every token replaced by the tokenizer's padding id.
pad_ids = torch.full_like(enc.input_ids, processor.tokenizer.pad_token_id)

for idx in tqdm(question_token_idx.tolist(), desc="LIG"):
    attrs, delta = lig.attribute(
        inputs=enc.input_ids,
        baselines=pad_ids,
        additional_forward_args=(enc.bbox, enc.pixel_values, enc.attention_mask, idx),
        n_steps=N_STEPS,
        return_convergence_delta=True,
    )
    # Sum over the last (hidden) dimension to get one score per position;
    # attrs presumably has shape (1, seq_len, hidden) -- TODO confirm.
    attributions[idx] = attrs.sum(dim=-1).squeeze().detach().cpu()
    convergence_deltas[idx] = delta.detach().cpu()

In [None]:
# (token_id, position, highest logit) for every real token predicted as B-QUESTION.
# input_ids / pred_labels / highest_logit are not defined in any other cell, so they
# are rebuilt here from the encoded page and the forward pass above.
# NOTE(review): highest_logit is assumed to be the max class logit per token -- confirm.
input_ids = enc.input_ids.squeeze(0).tolist()
token_mask = enc.attention_mask.squeeze(0).tolist()
highest_logit = out.logits.squeeze(0).max(dim=-1).values
pred_labels = [model.config.id2label[label_id] for label_id in pred_ids.tolist()]

tokens_labeled_question = [
    (token_id, idx, highest_logit[idx])
    for idx, (token_id, label, is_real) in enumerate(zip(input_ids, pred_labels, token_mask))
    if label == "B-QUESTION" and is_real  # skip padding positions
]
print(len(tokens_labeled_question))

### Extract embeddings from model

In [None]:
embedding_list = []

def get_input_embeddings_hook(module, input_, output):
    """Forward hook: store a detached copy of the embedding module's output."""
    embedding_list.append(output.detach())

hook = model.layoutlmv3.embeddings.register_forward_hook(get_input_embeddings_hook)
try:
    # Single forward pass to populate embedding_list. The encoded page produced
    # earlier in the notebook is `enc` (`encode_sample` is never defined).
    # The hook detaches its output anyway, so no gradients are needed here.
    with torch.no_grad():
        model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            bbox=enc.bbox,
            pixel_values=enc.pixel_values,
        )
finally:
    # Always remove the hook, even if the forward pass raises, so later forward
    # passes do not keep appending to embedding_list.
    hook.remove()

### Run IG over tokens labeled 'B-QUESTION'

In [None]:
from captum.attr import IntegratedGradients

label = "B-QUESTION"
# Use the model's own label space so the id matches the columns of out.logits / pred_ids.
question_label_id = model.config.label2id[label]
# Text-embedding output captured by the forward hook; IG interpolates from zeros to it.
input_embeddings = embedding_list[0].requires_grad_()

bbox = enc["bbox"].to(input_embeddings.device)

def model_forward_fn(embeds, token_index):
  """Logit of ``question_label_id`` at ``token_index`` for a batch of text embeddings.

  Calling ``model.layoutlmv3.encoder`` directly does not reproduce the model's
  forward pass: it drops the visual patch tokens and expects an extended
  attention mask, and the token-classification head is ``model.classifier``.
  Instead, the full model is run with a temporary forward hook that replaces
  the output of ``model.layoutlmv3.embeddings`` by ``embeds``.

  Returns a tensor of shape ``(batch,)`` -- one scalar per interpolation step,
  which is what Captum requires (so no ``target`` is passed below).
  """
  batch_size = embeds.shape[0]
  # A forward hook that returns a value replaces the module's output.
  handle = model.layoutlmv3.embeddings.register_forward_hook(
      lambda module, inputs, output: embeds)
  try:
    logits = model(
        input_ids=enc["input_ids"].expand(batch_size, -1),
        bbox=bbox.expand(batch_size, -1, -1),
        attention_mask=enc["attention_mask"].expand(batch_size, -1),
        pixel_values=enc["pixel_values"].expand(batch_size, -1, -1, -1),
    ).logits
  finally:
    handle.remove()
  return logits[:, token_index, question_label_id]

ig = IntegratedGradients(model_forward_fn)
# token position -> attribution tensor, so results for every token are kept.
embedding_attributions = {}

for token_id, idx, scalar_logit in tokens_labeled_question:
  # Compute attributions for this token position.
  # NOTE(review): this rebinds `attributions`, replacing the LIG results dict
  # built in an earlier cell; embedding_attributions keeps all IG results.
  attributions = ig.attribute(
    inputs=input_embeddings,
    additional_forward_args=(idx,),
    n_steps=50,
  )
  embedding_attributions[idx] = attributions.detach().cpu()