In [1]:
import torch
# import math
import random
import numpy as np
import torch.nn.functional as F
# import inspect
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForQuestionAnswering
# import transformers
# import shap
torch.manual_seed(0)  # optional: reproducibility
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

In [2]:
# model_name: str = "google/gemma-2-2b"
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
# print(device)
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# model.eval()

In [3]:
# mask_token = tokenizer.convert_ids_to_tokens(4)         
# mask_token_id = 4   
# print(mask_token, mask_token_id)
# mask_token_tensor = torch.tensor([[mask_token_id]], device=device)
# mask_embedding = model.get_input_embeddings()(mask_token_tensor.clone().contiguous())
# print(mask_embedding.shape)

In [4]:
# print_vocab cell
# Uses existing `tokenizer` in the notebook

# vocab = tokenizer.get_vocab()  # dict: token -> id
# print("Vocab size:", len(vocab))

# # sort by id and print first N entries
# sorted_items = sorted(vocab.items(), key=lambda kv: kv[1])
# N = 200
# print(f"\nFirst {N} tokens (id, token):")
# for tok, idx in sorted_items[:N]:
#     print(f"{idx:6d}\t{tok}")

# write full vocab to a file for inspection (useful since vocab is large)
# with open("vocab.txt", "w", encoding="utf-8") as f:
#     for tok, idx in sorted_items:
#         f.write(f"{idx}\t{tok}\n")
# print("\nFull vocab written to 'vocab.txt'")

In [5]:
def egrad_integral_causal_lm(
    text: str,
    a: float = 0,
    b: float = 1.0,
    steps: int = 20,
    model_name: str = "google/gemma-2-2b",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    mask_token_id: int = 4,  # your mask token ID
):
    """
    Integrated gradients for next-token prediction (causal LM).
    
    Interpolates between a masked embedding baseline and the actual input embeddings,
    computing gradients of the target (most probable next token) logit w.r.t. interpolation coefficient.
    
    Returns:
        dict with keys: "tokens", "acc_attributions", "attributions_steps", "db_scores", "ts"
    """
    # Load model & tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name).to(device)
    model.eval()

    # Tokenize
    enc = tokenizer(text, return_tensors="pt", truncation=True)
    input_ids = enc["input_ids"].to(device)           # (1, L)
    attention_mask = enc["attention_mask"].to(device) # (1, L)
    
    # Get base embeddings X: (1, L, d)
    embed = model.get_input_embeddings()
    with torch.no_grad():
        X = embed(input_ids)  # (1, L, d)
    L, d = X.shape[1], X.shape[2]
    
    # Baseline: mask embedding repeated for all positions
    mask_token_tensor = torch.tensor([[mask_token_id]], device=device)
    mask_embedding = embed(mask_token_tensor.clone().contiguous())  # (1,1,d)
    X_RefMask = mask_embedding.repeat(1, L, 1)  # (1, L, d)
    
    # Run forward pass to find target token (most probable next token)
    with torch.no_grad():
        out = model(inputs_embeds=X, attention_mask=attention_mask)
        logits = out.logits[0, -1, :]  # (vocab_size,) at last position
        probs = F.softmax(logits, dim=-1)
        target_id = torch.argmax(probs)
        target_prob = probs[target_id].item()
        token_for_target = tokenizer.convert_ids_to_tokens([target_id.item()])[0]
    
    print(f"Target token: {token_for_target} (id={target_id.item()}, prob={target_prob:.4f})")
    
    # Integration grid
    t_vals = torch.linspace(a, b, steps, device=device, dtype=X.dtype)
    
    # Accumulators
    attr = np.zeros(L)
    attrs = []
    db_scores = []
    ts = [] 
    
    prev_label_score = None
    prev_lg_score = None
    sum_dlg = 0.0
    
    # Padding mask: don't attribute to BOS tokens
    padding_mask = torch.ones((L, 1), device=device, dtype=X.dtype)
    padding_mask[0] = 0   # skip first token
    
    for i, t in enumerate(t_vals):
        ts.append(t.item())
        
        # Interpolation coefficient per token: t * ones(L)
        ones_L = torch.ones(L, device=device, dtype=X.dtype)
        interpolate_v = t * ones_L  # (L,)
        interpolated_o = interpolate_v.view(L, 1)  # (L,1)
        
        # Add a dummy gradient variable
        ex = torch.zeros((L, 1), device=device, dtype=X.dtype).requires_grad_(True)
        interpolated_o = interpolated_o + ex  # to enable grad
        interpolated = interpolated_o.tile((1, d))  # (L, d)
        
        # Force first token to remain unchanged (coefficient=1)
        # interpolated_o[0,0] = 1.0 # hình như phải thêm cái này
        interpolated[0, :] = 1.0
        
        # Interpolated embeddings
        X_inter = X * interpolated + X_RefMask * (1 - interpolated)  # (1, L, d)
        
        # Add epsilon for attribution (only on non-padding positions)
        eps = torch.zeros((1, L, 1), device=device, dtype=X.dtype).requires_grad_(True)
        X_inter = X_inter + eps * padding_mask.unsqueeze(0)
        
        # Forward pass
        out = model(inputs_embeds=X_inter, attention_mask=attention_mask)
        logits_step = out.logits[0, -1, :]  # (vocab_size,)
        probs_step = F.softmax(logits_step, dim=-1)
        
        logit_score = logits_step[target_id]
        label_score = probs_step[target_id] # gọi tạm là label score nhưng vẫn là prob cho target output logit
        
        if i == 0:
            prev_label_score = label_score
            prev_lg_score = logit_score
        
        dlogit = logit_score - prev_lg_score # diff logit trc softmax
        dlb = label_score - prev_label_score # diff prob sau softmax
        prev_label_score = label_score
        prev_lg_score = logit_score
        sum_dlg += dlb.item()
        
        # Gradient of target logit w.r.t. interpolation coefficient
        (grad_eps,) = torch.autograd.grad(logit_score, interpolated_o, retain_graph=False, create_graph=False)
        grad_eps_n = grad_eps.squeeze()  # (L,)
        
        # Normalize gradient
        grad_eps_n = grad_eps_n / (torch.sum(grad_eps_n) + 1e-10)
        
        # Attribution: grad * change in logit 
        attri = grad_eps_n * dlogit # scale theo dlogit
        attri = attri.detach().cpu().numpy()
        
        attrs.append(attri)
        attr += attri
        
        db_scores.append((
            label_score.detach().cpu().item(),
            logit_score.detach().cpu().item(),
            dlb.detach().cpu().item(),
            t.item()
        ))
    
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    print(f"Sum delta prob: {sum_dlg:.6f}")
    
    return {
        "tokens": tokens,
        "acc_attributions": attr,
        "attributions_steps": attrs,
        "db_scores": db_scores,
        "ts": ts
    }

In [6]:
# For table step score
# LABEL_ID = 0
torch.manual_seed(11)
random.seed(11)
text = """
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer:
"""
#text = "The movie was unhilariously funny, I mean it was bad"
#text = "The movie was visually stunning, but the plot was predictable"
#text = "it is not a mass-market entertainment but an uncompromising attempt by one artist to think about another."
#text = "it 's also heavy-handed and devotes too much time to bigoted views"
# text = "while the mystery surrounding the nature of the boat 's malediction remains intriguing enough to sustain mild interest , the picture refuses to offer much accompanying sustenance in the way of characterization , humor or plain old popcorn fun"
#text = "The food tasted awful and the place was dirty"
#text = "It is summer, but the weather is bad"
#text = "The movie was an emotional masterpiece — the storytelling was powerful, the cinematography was breathtaking, and the music added so much depth to every scene"
#text = "i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake"
res = egrad_integral_causal_lm(text, device=device, model_name = "Sadat07/bert-SQuAD", a=0, b=1, steps=101)
tokens = res["tokens"]
atts = res["acc_attributions"]
attr_steps = res["attributions_steps"]
db_scores = res["db_scores"]
ts = res["ts"]


min_idx = int(np.argmin(atts))
max_idx = int(np.argmax(atts))
db = {  "min_attr": float(atts[min_idx]),
        "min_token": tokens[min_idx],
        "max_attr": float(atts[max_idx]),
        "max_token": tokens[max_idx],
        "sum_attr": float(atts.sum())}
print(db)

print(f"Tokens & attributions for target logit:") 
for tok in res["tokens"]:
    print(f"{tok:10s}", end=" ")
for tok in ["Prob_score", "Logit_score", "Delta_score", "t"]:
    print(f"{tok:10s}", end=" ")   
print()
for i, attri in enumerate(attr_steps):
    for val in attri.tolist():
        print(f"{val:+2.6f}", end="  ")
    db_score = db_scores[i]
    for v in db_score:
        print(f"{v:+2.6f}", end="  ")
    print()
for tok, val in zip(res["tokens"], atts.tolist()): 
    print(f"{tok:>12s} : {val:+.6f}")

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

AttributeError: 'QuestionAnsweringModelOutput' object has no attribute 'logits'