# Pun Attribution with Inseq

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week7/pun_attribution_inseq.ipynb)

This notebook demonstrates **attribution methods** for understanding which input tokens contribute to a model's predictions, using the [inseq](https://inseq.org/) library. We'll apply these techniques to **puns** to discover which words make the model recognize (or generate) the punchline.

**Key Idea:** Attribution methods assign importance scores to input tokens, helping us understand *why* a model made a particular prediction. For puns, we want to know: which words in the setup make the model predict the pun word?

## Methods Covered
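- **Integrated Gradients**: Path-based attribution that satisfies the sensitivity and implementation-invariance axioms, and whose scores sum to the difference between the model output at the input and at a baseline (completeness)
- **Input × Gradient**: Simple gradient-based attribution that multiplies each input embedding by the gradient of the output with respect to it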
- **Attention-based attribution**: Using attention weights as proxies for importance
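
To make the difference between the first two methods concrete, here is a minimal sketch on a toy function rather than a language model. The function `f`, its hand-written gradient `grad_f`, the all-zeros baseline, and the midpoint Riemann sum are all assumptions chosen for illustration; inseq's own integration scheme and baseline may differ.

```python
import math

def f(x):
    """Toy differentiable 'model': a sigmoid over a weighted sum with one interaction term."""
    z = 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[0] * x[2]
    return 1.0 / (1.0 + math.exp(-z))

def grad_f(x):
    """Analytic gradient of f with respect to each input."""
    z = 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[0] * x[2]
    s = 1.0 / (1.0 + math.exp(-z))
    ds = s * (1.0 - s)
    return [ds * (2.0 + 0.5 * x[2]), ds * -1.0, ds * 0.5 * x[0]]

def integrated_gradients(x, baseline, n_steps=50):
    """Approximate IG with a midpoint Riemann sum along the straight path baseline -> x."""
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad_f(point))]
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]

def input_x_gradient(x):
    """Elementwise product of the input and the gradient at the input."""
    return [xi * gi for xi, gi in zip(x, grad_f(x))]

x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]  # assumed baseline for this sketch

ig = integrated_gradients(x, baseline)
ixg = input_x_gradient(x)
output_diff = f(x) - f(baseline)

# Completeness: IG scores should sum (approximately) to f(x) - f(baseline)
completeness_gap = abs(sum(ig) - output_diff)
print("IG:", [round(v, 4) for v in ig], "gap:", completeness_gap)
print("Input x Gradient:", [round(v, 4) for v in ixg], "sum:", round(sum(ixg), 4))
```

On this toy function the Integrated Gradients scores add up to the change in output, while the Input × Gradient scores do not, which is one reason the two methods can disagree on a real model.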

## References
- [Inseq Documentation](https://inseq.org/)
- [Integrated Gradients paper](https://arxiv.org/abs/1703.01365) (Sundararajan et al., 2017)
- [Attention is not Explanation](https://arxiv.org/abs/1902.10186) (Jain & Wallace, 2019)
- [Attention is not not Explanation](https://arxiv.org/abs/1908.04626) (Wiegreffe & Pinter, 2019)

## Setup

Install inseq and dependencies:

In [None]:
!pip install -q inseq transformers torch

In [None]:
import inseq
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Check available attribution methods
print("Available attribution methods:")
print(inseq.list_feature_attribution_methods())

## Part 1: Loading a Model with Inseq

Inseq wraps Hugging Face models to enable attribution. We'll use GPT-2 for this tutorial since it is small enough to run on CPU or a free Colab instance.

In [None]:
# Load GPT-2 with Integrated Gradients attribution
model = inseq.load_model("gpt2", "integrated_gradients")

print(f"Model: {model.model_name}")
print(f"Attribution method: {model.attribution_method}")

## Part 2: Basic Attribution on a Pun

Let's see which words in a pun setup contribute to predicting the punchline word.

In [None]:
# Our classic electrician pun
pun_prompt = "Why do electricians make good swimmers? Because they know the"
pun_completion = " current"

# Run attribution
# We attribute the generation of the completion to the input tokens
out = model.attribute(
    input_texts=pun_prompt,
    generated_texts=pun_prompt + pun_completion,
    n_steps=50,  # Number of integration steps
    internal_batch_size=10
)

# Visualize the attribution
out.show()

In [None]:
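# Aggregate per-dimension attributions into one score per (token, generation step)
seq_attr = out.sequence_attributions[0].aggregate()

# GPT-2 is decoder-only, so inseq is expected to store the attributions of the
# prefix tokens in `target_attributions` (`source_attributions` is None here).
attr_scores = seq_attr.target_attributions.detach().cpu().numpy()  # (n_tokens, n_generated)
tokens = [t.token for t in seq_attr.target]

print("Attribution scores for predicting 'current':")
print("=" * 50)
# Rows index the attributed tokens; column -1 is the last generated token.
# The generated token itself cannot be attributed and shows up as NaN.
for token, score in zip(tokens, attr_scores[:, -1]):
    print(f"{repr(token):15} {score:.4f}")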

## Part 3: Comparing Attribution Methods

Different attribution methods can rank tokens differently and produce scores on different scales. Let's compare Integrated Gradients, Input × Gradient, and Attention on the same pun.
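
Because raw scores from different methods are not directly comparable, it can help to normalize them and compare which tokens each method ranks highest. The sketch below is one possible way to do that; the helper names (`normalize`, `top_k`, `top_k_overlap`) and the toy score lists are hypothetical and are not model output.

```python
def normalize(scores):
    """Scale absolute scores so they sum to 1 (all zeros if every score is zero)."""
    total = sum(abs(s) for s in scores)
    if total == 0:
        return [0.0] * len(scores)
    return [abs(s) / total for s in scores]

def top_k(scores, k=3):
    """Indices of the k tokens with the largest absolute score."""
    order = sorted(range(len(scores)), key=lambda i: abs(scores[i]), reverse=True)
    return set(order[:k])

def top_k_overlap(scores_a, scores_b, k=3):
    """Jaccard overlap between the top-k token sets of two methods."""
    a, b = top_k(scores_a, k), top_k(scores_b, k)
    return len(a & b) / len(a | b)

# Hypothetical per-token scores from two methods (illustration only)
toy_ig = [0.02, -0.40, 0.10, 0.05, 0.30]
toy_attention = [0.50, 0.10, 0.05, 0.05, 0.30]

overlap = top_k_overlap(toy_ig, toy_attention, k=3)
print("Normalized IG:", [round(v, 3) for v in normalize(toy_ig)])
print("Top-3 overlap:", overlap)
```

An overlap of 1.0 means both methods pick the same top tokens; values near 0 mean they highlight different parts of the prompt.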

In [None]:
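def get_attributions(prompt, completion, method="integrated_gradients"):
    """Get attribution scores for a prompt-completion pair."""
    model = inseq.load_model("gpt2", method)
    
    # Only Integrated Gradients needs integration-specific arguments
    kwargs = {}
    if method == "integrated_gradients":
        kwargs = {"n_steps": 50, "internal_batch_size": 10}
    
    out = model.attribute(
        input_texts=prompt,
        generated_texts=prompt + completion,
        **kwargs
    )
    
    # Reduce to one score per (token, generation step) with the method's default aggregator
    seq_attr = out.sequence_attributions[0].aggregate()
    # Decoder-only models such as GPT-2 are expected to expose prefix attributions
    # in target_attributions (source_attributions is None)
    attr = seq_attr.target_attributions.detach().cpu().numpy()
    scores = attr[:, -1]  # Column for the last generated token
    # Drop positions that cannot be attributed (the completion itself is NaN)
    keep = np.isfinite(scores)
    tokens = [t.token for t, k in zip(seq_attr.target, keep) if k]
    
    return tokens, scores[keep]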

# Compare methods on our pun
methods = ["integrated_gradients", "input_x_gradient", "attention"]
results = {}

for method in methods:
    print(f"Running {method}...")
    tokens, scores = get_attributions(pun_prompt, pun_completion, method)
    results[method] = (tokens, scores)
    print(f"  Done!")

In [None]:
# Visualize comparison
fig, axes = plt.subplots(len(methods), 1, figsize=(14, 3*len(methods)))

for ax, method in zip(axes, methods):
    tokens, scores = results[method]
    scores_np = np.array([s.item() if hasattr(s, 'item') else float(s) for s in scores])
    
    colors = ['red' if s < 0 else 'blue' for s in scores_np]
    ax.bar(range(len(tokens)), np.abs(scores_np), color=colors, alpha=0.7)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=8)
    ax.set_ylabel('|Attribution|')
    ax.set_title(f'{method}')
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

plt.suptitle(f'Attribution Comparison: "{pun_prompt}{pun_completion}"', fontsize=12)
plt.tight_layout()
plt.show()

## Part 4: Which Words Make Puns Punny?

Let's analyze multiple puns to find patterns in what words drive pun recognition.

In [None]:
puns = [
    ("Why do electricians make good swimmers? Because they know the", " current"),
    ("Why did the banker break up with his girlfriend? He lost", " interest"),
    ("Why do cows wear bells? Because their horns don't", " work"),
    ("What do you call a fish without eyes? A", " f"),
    ("I used to hate facial hair, but then it grew on", " me"),
]

# Load model once
ig_model = inseq.load_model("gpt2", "integrated_gradients")

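def analyze_pun(model, prompt, completion):
    """Analyze which tokens contribute most to the pun completion."""
    out = model.attribute(
        input_texts=prompt,
        generated_texts=prompt + completion,
        n_steps=30,
        internal_batch_size=10
    )
    
    # Reduce to one score per (token, generation step)
    seq_attr = out.sequence_attributions[0].aggregate()
    # GPT-2 is decoder-only, so prefix attributions are expected in
    # target_attributions (source_attributions is None)
    attr = seq_attr.target_attributions.detach().cpu().numpy()
    scores_np = attr[:, -1]  # Column for the last generated token
    
    # Drop positions that cannot be attributed (the completion itself is NaN),
    # otherwise NaN values would be sorted into the "top" tokens
    keep = np.isfinite(scores_np)
    tokens = [t.token for t, k in zip(seq_attr.target, keep) if k]
    scores_np = scores_np[keep]
    
    # Find top contributing tokens by absolute score, largest first
    top_indices = np.argsort(np.abs(scores_np))[-5:][::-1]
    
    return tokens, scores_np, top_indices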

In [None]:
print("Top contributing tokens for each pun:")
print("=" * 60)

for prompt, completion in puns:
    tokens, scores, top_idx = analyze_pun(ig_model, prompt, completion)
    
    print(f"\nPun: '{prompt}{completion}'")
    print(f"Punchline: '{completion}'")
    print("Top contributing tokens:")
    for i, idx in enumerate(top_idx):
        print(f"  {i+1}. {repr(tokens[idx]):15} (score: {scores[idx]:.4f})")

## Part 5: Pun vs Literal Context Attribution

How does attribution differ when the same word appears in a pun vs literal context?

In [None]:
# Same completion word, different contexts
pun_context = "Why do electricians make good swimmers? Because they know the"
literal_context = "The electrician measured the electrical"
target_word = " current"

# Get attributions for both
pun_tokens, pun_scores, _ = analyze_pun(ig_model, pun_context, target_word)
lit_tokens, lit_scores, _ = analyze_pun(ig_model, literal_context, target_word)

# Visualize side by side
fig, axes = plt.subplots(2, 1, figsize=(14, 6))

# Pun context
colors = ['red' if s < 0 else 'blue' for s in pun_scores]
axes[0].bar(range(len(pun_tokens)), np.abs(pun_scores), color=colors, alpha=0.7)
axes[0].set_xticks(range(len(pun_tokens)))
axes[0].set_xticklabels(pun_tokens, rotation=45, ha='right', fontsize=9)
axes[0].set_ylabel('|Attribution|')
axes[0].set_title(f'PUN: "{pun_context}{target_word}"')

# Literal context
colors = ['red' if s < 0 else 'blue' for s in lit_scores]
axes[1].bar(range(len(lit_tokens)), np.abs(lit_scores), color=colors, alpha=0.7)
axes[1].set_xticks(range(len(lit_tokens)))
axes[1].set_xticklabels(lit_tokens, rotation=45, ha='right', fontsize=9)
axes[1].set_ylabel('|Attribution|')
axes[1].set_title(f'LITERAL: "{literal_context}{target_word}"')

plt.tight_layout()
plt.show()

In [None]:
# Quantify the difference
print("Comparison of top contributors:")
print("\nPUN context - Top 5 tokens:")
top_pun = np.argsort(np.abs(pun_scores))[-5:][::-1]
for idx in top_pun:
    print(f"  {repr(pun_tokens[idx]):15} {pun_scores[idx]:.4f}")

print("\nLITERAL context - Top 5 tokens:")
top_lit = np.argsort(np.abs(lit_scores))[-5:][::-1]
for idx in top_lit:
    print(f"  {repr(lit_tokens[idx]):15} {lit_scores[idx]:.4f}")
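
Beyond listing the top tokens, one simple way to quantify the contrast between the two contexts is the share of total attribution that falls on a chosen set of words. The sketch below assumes a hypothetical helper `attribution_share` and uses made-up tokens and scores (not model output). Real GPT-2 tokens may be split into several sub-word pieces and carry a `Ġ` prefix, which the helper strips.

```python
def attribution_share(tokens, scores, selected):
    """Fraction of total |attribution| carried by tokens whose text is in `selected`."""
    total = sum(abs(s) for s in scores)
    if total == 0:
        return 0.0
    picked = sum(
        abs(s)
        for t, s in zip(tokens, scores)
        if t.lstrip("Ġ").strip().lower() in selected
    )
    return picked / total

# Hypothetical tokens and scores, for illustration only
toy_pun_tokens = ["Why", " do", " electricians", " swim", "?", " know", " the"]
toy_pun_scores = [0.05, 0.05, 0.40, 0.20, 0.05, 0.15, 0.10]
toy_lit_tokens = ["The", " electrician", " measured", " the", " electrical"]
toy_lit_scores = [0.05, 0.25, 0.10, 0.10, 0.50]

electric_words = {"electricians", "electrician", "electrical"}
pun_share = attribution_share(toy_pun_tokens, toy_pun_scores, electric_words)
lit_share = attribution_share(toy_lit_tokens, toy_lit_scores, electric_words)
print(f"Share on electricity words: pun={pun_share:.2f}, literal={lit_share:.2f}")
```

With the real outputs, the same helper could be called on `pun_tokens`/`pun_scores` and `lit_tokens`/`lit_scores` from the cells above.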

## Part 6: Validating Attribution with Ablation

Attribution scores claim certain tokens are important. Let's validate by removing high-attribution tokens and measuring the change in prediction.

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load raw model for ablation experiments
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
raw_model = GPT2LMHeadModel.from_pretrained("gpt2")
raw_model.eval()

def get_token_probability(model, tokenizer, prompt, target_token):
    """Get the probability of the first token of target_token given prompt."""
    inputs = tokenizer(prompt, return_tensors="pt")
    target_ids = tokenizer.encode(target_token, add_special_tokens=False)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[0, -1]  # Last position
        probs = torch.softmax(logits, dim=-1)
        return probs[target_ids[0]].item()

# Original probability
orig_prob = get_token_probability(raw_model, tokenizer, pun_prompt, pun_completion)
print(f"Original P('{pun_completion}'): {orig_prob:.4f}")

In [None]:
def ablate_token(prompt, token_idx, replacement=""):
    """Replace a token at the given position."""
    tokens = tokenizer.encode(prompt)
    # Decode each token, replace the one at token_idx
    token_strs = [tokenizer.decode([t]) for t in tokens]
    token_strs[token_idx] = replacement
    return "".join(token_strs)

# Get attribution scores
tokens, scores, top_indices = analyze_pun(ig_model, pun_prompt, pun_completion)

print("Ablation study: removing high-attribution tokens")
print("=" * 60)
print(f"Original prompt: '{pun_prompt}'")
print(f"Original P('{pun_completion}'): {orig_prob:.4f}")
print()

# Ablate top contributing tokens.
# GPT-2 has no [MASK] token, so rather than substituting a placeholder with
# ablate_token, we delete the token from the prompt entirely.
tokens_list = tokenizer.encode(pun_prompt)

for idx in top_indices[:5]:
    token = tokens[idx]
    tokens_removed = tokens_list[:idx] + tokens_list[idx+1:]
    reduced_prompt = tokenizer.decode(tokens_removed)
    
    new_prob = get_token_probability(raw_model, tokenizer, reduced_prompt, pun_completion)
    delta = new_prob - orig_prob
    
    print(f"Remove {repr(token):12} (attr: {scores[idx]:+.4f}): P = {new_prob:.4f} (delta: {delta:+.4f})")

In [None]:
# Compare: ablate LOW attribution tokens
print("\nAblating LOW-attribution tokens (should have less effect):")
print("=" * 60)

low_indices = np.argsort(np.abs(scores))[:5]  # Lowest attribution

for idx in low_indices:
    token = tokens[idx]
    tokens_list = tokenizer.encode(pun_prompt)
    tokens_removed = tokens_list[:idx] + tokens_list[idx+1:]
    reduced_prompt = tokenizer.decode(tokens_removed)
    
    new_prob = get_token_probability(raw_model, tokenizer, reduced_prompt, pun_completion)
    delta = new_prob - orig_prob
    
    print(f"Remove {repr(token):12} (attr: {scores[idx]:+.4f}): P = {new_prob:.4f} (delta: {delta:+.4f})")
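
The two ablation runs above can be condensed into a single number: how much larger the average probability change is for high-attribution tokens than for low-attribution ones. The following is a minimal sketch of that summary. The helper names `mean_abs` and `ablation_gap` are hypothetical, and the score and delta values are invented for illustration rather than taken from GPT-2.

```python
def mean_abs(values):
    """Mean of absolute values."""
    return sum(abs(v) for v in values) / len(values)

def ablation_gap(attr_scores, deltas, k=2):
    """Mean |delta| of the k highest-|attribution| tokens minus that of the k lowest."""
    order = sorted(range(len(attr_scores)), key=lambda i: abs(attr_scores[i]))
    low, high = order[:k], order[-k:]
    return mean_abs([deltas[i] for i in high]) - mean_abs([deltas[i] for i in low])

# Hypothetical attribution scores and probability changes after removing each token
toy_attr = [0.02, 0.40, 0.10, 0.05, 0.30]
toy_delta = [-0.001, -0.080, -0.010, 0.002, -0.050]

gap = ablation_gap(toy_attr, toy_delta, k=2)
print(f"High-minus-low ablation gap: {gap:.4f}")
```

A clearly positive gap is consistent with the attribution ranking tracking what the model actually relies on. A gap near zero or below suggests the scores should not be trusted without further checks.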

## Exercise 1: Attribution Across Layers

Inseq can compute layer-wise attributions. Do different layers attribute importance to different tokens?

In [None]:
# TODO: Use inseq's layer attribution capabilities to:
# 1. Get attribution scores at each layer
# 2. Create a heatmap showing attribution by layer and token
# 3. Identify which layers focus on semantically relevant tokens

# Hint: Check inseq.list_aggregators() for layer aggregation options
print("Available aggregators:")
print(inseq.list_aggregators())

## Exercise 2: Attention Pattern Analysis

Compare attention-based attribution with gradient-based methods. Are they correlated?

In [None]:
# TODO: 
# 1. Get attention-based attribution scores
# 2. Get integrated gradients scores
# 3. Compute correlation between them
# 4. Discuss: When do they agree? When do they disagree?

# Load attention model
attention_model = inseq.load_model("gpt2", "attention")

# Your code here...

## Exercise 3: Contrastive Attribution

What makes "current" the right answer vs a wrong answer? Use contrastive attribution.

In [None]:
# TODO: Compare attribution for correct vs incorrect completions
# 
# For the prompt "Why do electricians make good swimmers? Because they know the"
# Compare:
# - Correct: " current" (pun answer)
# - Wrong: " water" (literal but wrong)
# - Wrong: " answer" (generic)
#
# Which tokens distinguish the correct pun answer from incorrect ones?

completions_to_compare = [" current", " water", " answer", " secret"]

# Your code here...

## Exercise 4: Critical Evaluation of Attribution

Attribution methods have limitations. Design experiments to test their reliability.

In [None]:
# TODO: Test attribution reliability
#
# 1. Sanity check: Do random baselines give different results?
# 2. Consistency: Run attribution multiple times - are results stable?
# 3. Sensitivity: Small prompt changes shouldn't drastically change attribution
#
# Refer to Adebayo et al. (2018), "Sanity Checks for Saliency Maps", for inspiration

# Your code here...

## Summary

In this notebook, we learned:

1. **Attribution methods** assign importance scores to input tokens, explaining model predictions

2. **Integrated Gradients** provides theoretically grounded attribution that satisfies the sensitivity, implementation-invariance, and completeness properties

3. **Different methods can disagree**: attention, gradient, and path-based methods may highlight different tokens

4. **Validation is crucial**: ablation studies help verify that high-attribution tokens actually affect predictions

5. **For puns**, attribution reveals which setup words prime the model for the punchline

### Key Questions

- Do attribution methods capture *why* the model predicts a pun, or just *what* it attends to?
- How reliable are these explanations? (See Jain & Wallace, 2019)
- Can we use attribution to find where pun understanding happens in the model?

### Limitations to Consider

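- Attribution scores alone do not establish causation: a high score does not guarantee that removing the token changes the prediction
- Results depend on implementation details (choice of baseline, number of integration steps, how per-dimension scores are aggregated)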
- May not capture distributed representations well
- Should be validated with causal methods (activation patching, ablation)
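
The dependence on the baseline is easy to see on a toy function. The sketch below applies a midpoint-rule approximation of Integrated Gradients to `f(x) = x0 * x1` with two different baselines. The function, both baselines, and the integration rule are assumptions chosen so the result can be checked by hand; they are not what inseq uses internally.

```python
def f(x):
    """Toy model with a pure interaction between two inputs."""
    return x[0] * x[1]

def grad_f(x):
    """Analytic gradient of f."""
    return [x[1], x[0]]

def integrated_gradients(x, baseline, n_steps=20):
    """Midpoint Riemann approximation of IG along the straight path baseline -> x."""
    total = [0.0, 0.0]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad_f(point))]
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, total)]

x = [2.0, 3.0]
ig_zero = integrated_gradients(x, [0.0, 0.0])   # all-zeros baseline
ig_shift = integrated_gradients(x, [2.0, 0.0])  # baseline that already matches x0

print("Zero baseline:   ", ig_zero)
print("Shifted baseline:", ig_shift)
```

Both results sum to the output difference from their own baseline, yet the first splits the credit evenly between the two inputs while the second assigns all of it to the second input. The explanation therefore changes with the reference point even though the model and input are identical.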