<a href="https://colab.research.google.com/github/Ashu-00/NLP-Implementations/blob/main/XAI/Integrated_Gradients_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install transformers



## IG FUNCTION IMPLEMENTATION

In [None]:
from tqdm import tqdm

def ig_attribution_scores(model, predicted_token_id, input_embeds, steps = 50, *,
                          baseline_embeds=None, show_progress=True):
    """Compute per-token Integrated Gradients attributions for one target logit.

    Gradients of ``logits[0, -1, predicted_token_id]`` with respect to the
    input embeddings are evaluated at ``steps`` evenly spaced points on the
    straight line from the baseline to ``input_embeds`` (both endpoints
    included), averaged, multiplied by ``input_embeds - baseline`` and summed
    over the embedding dimension.

    Args:
        model: Module exposing ``parameters()`` and callable as
            ``model(inputs_embeds=...)``, returning an object whose ``.logits``
            can be indexed as ``[batch, position, vocab]``.
        predicted_token_id: Vocabulary index of the logit to explain.
        input_embeds: 3-D embedding tensor; expected shape is
            ``(1, seq_len, hidden)``. Only batch element 0 contributes to the
            target logit.
        steps: Number of interpolation points (must be >= 1).
        baseline_embeds: Optional baseline with the same shape as
            ``input_embeds``. Defaults to an all-zero tensor.
        show_progress: When true, wrap the loop in a ``tqdm`` progress bar.

    Returns:
        NumPy array of attribution scores, one per sequence position when the
        batch dimension is 1.

    Raises:
        ValueError: If ``steps`` is smaller than 1 or the baseline shape does
            not match ``input_embeds``.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    # Detach so that no graph is built back into whatever produced the
    # embeddings (e.g. the embedding table); only the interpolated points
    # below need gradients.
    input_embeds = input_embeds.detach()
    if baseline_embeds is None:
        baseline_embeds = torch.zeros_like(input_embeds)
    else:
        baseline_embeds = baseline_embeds.detach()
        if baseline_embeds.shape != input_embeds.shape:
            raise ValueError(
                "baseline_embeds must have the same shape as input_embeds: "
                f"{tuple(baseline_embeds.shape)} != {tuple(input_embeds.shape)}"
            )
    delta = input_embeds - baseline_embeds

    # Remember each parameter's flag so it can be restored exactly afterwards.
    saved_flags = [(param, param.requires_grad) for param in model.parameters()]

    grad_sum = torch.zeros_like(input_embeds)
    alphas = torch.linspace(0, 1, steps)
    try:
        # Freeze parameters: only gradients w.r.t. the embeddings are needed.
        for param, _ in saved_flags:
            param.requires_grad_(False)

        for alpha in (tqdm(alphas) if show_progress else alphas):
            interpolated_embeds = (baseline_embeds + alpha * delta).requires_grad_(True)

            logits = model(inputs_embeds=interpolated_embeds).logits
            target_logit = logits[0, -1, predicted_token_id]

            # autograd.grad returns the gradient directly and frees the graph
            # (retain_graph defaults to False), so memory does not accumulate
            # across steps.
            grad_sum += torch.autograd.grad(
                outputs=target_logit,
                inputs=interpolated_embeds,
            )[0]
    finally:
        # Restore the original flags even if the forward/backward pass raised,
        # so a failed call never leaves the model frozen.
        for param, was_trainable in saved_flags:
            param.requires_grad_(was_trainable)

    average_grads = grad_sum / steps
    ig = delta * average_grads

    # Collapse the embedding dimension to one score per token; .cpu() keeps
    # the NumPy conversion valid when the tensors live on an accelerator.
    attributions = ig.sum(dim=2).squeeze(0).detach().cpu().numpy()
    return attributions




## GPT2 TEST

In [None]:
# Prompt whose next-token prediction is explained in the cells below.
text = " I have two dogs named Jude, and a cat named mono. So I call my dog as"

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# Load the pretrained GPT-2 language model and its matching tokenizer.
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

input_ids = tokenizer.encode(text, return_tensors='pt')

# Only values are needed here (the token embeddings and the argmax of the
# logits), so skip autograd bookkeeping for these two forward computations.
with torch.no_grad():
    # Look up token embeddings directly so attributions can be computed in
    # embedding space.
    input_embeds = model.transformer.wte(input_ids)
    outputs = model(input_ids)

# Id of the highest-scoring token at the last position of the prompt.
predicted_token_id = torch.argmax(outputs.logits[0, -1, :]).item()


In [None]:
# Show the predicted next token as text.
print(tokenizer.decode(predicted_token_id))

 Jude


In [None]:
# Integrated Gradients scores for the predicted token's logit, using the
# function's default number of interpolation steps.
attr = ig_attribution_scores(model, predicted_token_id, input_embeds)

100%|██████████| 50/50 [00:16<00:00,  2.98it/s]


In [None]:
# Pair each prompt token with its attribution score and print one per line.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
for token, attribution in zip(tokens, attr):
    print(f"{token}: {attribution:.4f}")

ĠI: 1.6706
Ġhave: 45.3497
Ġtwo: 33.4860
Ġdogs: 11.9327
Ġnamed: 8.9556
ĠJude: 9.4381
,: 3.7672
Ġand: 2.1482
Ġa: 4.0684
Ġcat: 5.1920
Ġnamed: 6.7285
Ġmono: 5.2100
.: 3.6659
ĠSo: -1.7635
ĠI: -2.9428
Ġcall: -14.8203
Ġmy: -4.3303
Ġdog: -0.1566
Ġas: -15.9648


## DistilGPT-2

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Rebind `tokenizer` and `model` to DistilGPT-2 so the following cells repeat
# the same experiment with this checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [None]:
input_ids = tokenizer.encode(text, return_tensors='pt')

# Only values are needed here (the token embeddings and the argmax of the
# logits), so skip autograd bookkeeping for these two forward computations.
with torch.no_grad():
    input_embeds = model.transformer.wte(input_ids)
    outputs = model(input_ids)

# Id of the highest-scoring token at the last position of the prompt.
predicted_token_id = torch.argmax(outputs.logits[0, -1, :]).item()

In [None]:
# Show the predicted next token as text.
print(tokenizer.decode(predicted_token_id))

 Jude


In [None]:
# Compute Integrated Gradients scores with the currently bound model, then
# print each prompt token next to its attribution.
attr = ig_attribution_scores(model, predicted_token_id, input_embeds)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
for token, attribution in zip(tokens, attr):
    print(f"{token}: {attribution:.4f}")

100%|██████████| 50/50 [00:13<00:00,  3.66it/s]

ĠI: -0.0999
Ġhave: -4.6640
Ġtwo: -1.7825
Ġdogs: 0.7575
Ġnamed: -1.3957
ĠJude: 8.2272
,: -2.8709
Ġand: 0.6567
Ġa: 0.3707
Ġcat: 1.3484
Ġnamed: 0.1629
Ġmono: -3.0655
.: 1.1122
ĠSo: 4.1171
ĠI: -0.5236
Ġcall: 2.5838
Ġmy: 1.7970
Ġdog: 1.1569
Ġas: 4.4283



