In [1]:
# !git clone https://github.com/SD-interp/mech_interp_toolkit.git
# %cd mech_interp_toolkit
# !pip install -e .

# Test: gradient_based_attribution module

This notebook tests gradient-based attribution methods from `mech_interp_toolkit.gradient_based_attribution`.

Functions tested:
- `get_activations()` - Extract activations with gradient support
- `get_embeddings()` - Extract input embeddings
- `simple_integrated_gradients()` - Vanilla integrated gradients w.r.t. the input embeddings
- `edge_attribution_patching()` - First-order attribution: gradient × activation difference between clean and corrupted runs
- `eap_integrated_gradients()` - Integrated gradients for edge attributions

In [2]:
import sys
import os
# Make the package under ../src importable straight from a source checkout
# (presumably as an alternative to the commented-out `pip install -e .` above).
# The path is resolved relative to the current working directory.
sys.path.append(os.path.abspath("../src"))

import torch
# NOTE(review): get_all_layer_components is imported but not used anywhere in this notebook.
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device, get_all_layer_components
from mech_interp_toolkit.gradient_based_attribution import (
    get_activations,
    get_embeddings,
    simple_integrated_gradients,
    edge_attribution_patching,
    eap_integrated_gradients,
)

## Setup: Load model

In [3]:
# Load the model under test on the default device and report its basic dimensions
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

print(f"Loading model {model_name} on {device}...")
model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")

model_summary = {
    "Number of layers": config.num_hidden_layers,
    "Hidden size": config.hidden_size,
}
for label, value in model_summary.items():
    print(f"{label}: {value}")

Loading model Qwen/Qwen3-0.6B on cuda...


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Model loaded successfully
Number of layers: 28
Hidden size: 1024


In [4]:
# Prepare clean and corrupted inputs for attribution methods.
# The prompts differ only in the country name. The attribution methods below
# (EAP / EAP-IG) need clean and corrupted inputs with the same shape.
clean_prompts = ["The capital of France is"]
corrupted_prompts = ["The capital of Germany is"]

clean_inputs = tokenizer(clean_prompts, thinking=False)
corrupted_inputs = tokenizer(corrupted_prompts, thinking=False)

print(f"Clean input shape: {clean_inputs['input_ids'].shape}")
print(f"Corrupted input shape: {corrupted_inputs['input_ids'].shape}")

# Fail fast here instead of deep inside the attribution code.
assert clean_inputs["input_ids"].shape == corrupted_inputs["input_ids"].shape, (
    "Clean and corrupted prompts must tokenize to the same shape"
)

Clean input shape: torch.Size([1, 33])
Corrupted input shape: torch.Size([1, 33])


## Test: get_activations()

In [5]:
# Test extracting activations from specific layer components
layer_components = [(0, "attn"), (0, "mlp"), (5, "attn"), (5, "mlp")]

print("Extracting activations...")
activations = get_activations(model, clean_inputs, layer_components)

print("\nExtracted activations:")
for key, val in activations.items():
    print(f"  {key}: {val.shape}")

# Verify every requested component was extracted with shape (batch, seq_len, hidden_size)
expected_shape = (*clean_inputs["input_ids"].shape, config.hidden_size)
for lc in layer_components:
    assert lc in activations, f"Missing {lc}"
    assert activations[lc].shape == expected_shape, (
        f"Unexpected shape for {lc}: {activations[lc].shape}, expected {expected_shape}"
    )
print("\nPASSED: get_activations()")

Extracting activations...

Extracted activations:
  (0, 'attn'): torch.Size([1, 33, 1024])
  (0, 'mlp'): torch.Size([1, 33, 1024])
  (5, 'attn'): torch.Size([1, 33, 1024])
  (5, 'mlp'): torch.Size([1, 33, 1024])

PASSED: get_activations()


## Test: get_embeddings()

In [6]:
# Extract the input embeddings and check they are (batch, seq_len, hidden_size)
print("Extracting embeddings...")
embeddings = get_embeddings(model, clean_inputs)
embed_key = (0, "layer_in")

print(f"Embeddings keys: {list(embeddings.keys())}")
print(f"Embeddings shape: {embeddings[embed_key].shape}")

batch_size, seq_len = clean_inputs["input_ids"].shape
expected_shape = (batch_size, seq_len, config.hidden_size)
assert embeddings[embed_key].shape == expected_shape
print("\nPASSED: get_embeddings()")

Extracting embeddings...
Embeddings keys: [(0, 'layer_in')]
Embeddings shape: torch.Size([1, 33, 1024])

PASSED: get_embeddings()


## Test: edge_attribution_patching()

In [7]:
# Define a simple metric function
def metric_fn(logits):
    """Reduce model logits to a scalar metric for the attribution tests.

    Averages over every element of ``logits``; no position or token is singled
    out. The logits are presumably shaped (batch, seq_len, vocab) — TODO confirm.
    """
    return logits.mean()

print("Running Edge Attribution Patching (EAP)...")
eap_scores = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="clean",
    metric_fn=metric_fn,
)

print(f"\nEAP scores computed for {len(eap_scores)} components")
print(f"EAP score type: {type(eap_scores).__name__}")
print("\nSample EAP scores (first 6 components):")
for key, score in list(eap_scores.items())[:6]:
    print(f"  {key}: {score}")

# Every layer should have an attention and an MLP score, one value per (batch, position).
expected_keys = {
    (layer, component)
    for layer in range(config.num_hidden_layers)
    for component in ("attn", "mlp")
}
assert set(eap_scores.keys()) == expected_keys, "EAP did not score every attn/mlp component"
for key, score in eap_scores.items():
    assert score.shape == clean_inputs["input_ids"].shape, (
        f"Unexpected score shape for {key}: {score.shape}"
    )
    assert torch.isfinite(score).all(), f"Non-finite EAP score for {key}"

print("\nPASSED: edge_attribution_patching()")

Running Edge Attribution Patching (EAP)...

EAP scores computed for 56 components
EAP score type: ActivationDict

Sample EAP scores (first 6 components):
  (0, 'attn'): tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -3.8734e-01,  2.6416e-04,  2.1704e-04,
         -1.0432e-03,  1.3060e-03,  2.1108e-03, -1.0146e-03, -3.3073e-03,
         -2.3746e-03,  3.2487e-03,  1.1408e-03]], device='cuda:0',
       grad_fn=<SumBackward1>)
  (0, 'mlp'): tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.00

In [8]:
# Test with compute_grad_at="corrupted"
print("Running EAP with compute_grad_at='corrupted'...")
eap_scores_corrupted = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="corrupted",
    metric_fn=metric_fn,
)

print(f"EAP scores (corrupted): {len(eap_scores_corrupted)} components")
print("Sample scores:")
for key in list(eap_scores_corrupted.keys())[:3]:
    print(f"  {key}: {eap_scores_corrupted[key]}")

# Taking gradients at the corrupted run must score the same components, with the same shapes.
assert set(eap_scores_corrupted.keys()) == set(eap_scores.keys()), (
    "Component sets differ between compute_grad_at='clean' and 'corrupted'"
)
for key, score in eap_scores_corrupted.items():
    assert score.shape == eap_scores[key].shape, f"Shape mismatch for {key}"
    assert torch.isfinite(score).all(), f"Non-finite EAP score for {key}"

print("\nPASSED: EAP with corrupted gradients")

Running EAP with compute_grad_at='corrupted'...
EAP scores (corrupted): 56 components
Sample scores:
  (0, 'attn'): tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1497, -0.0025,
         -0.0007, -0.0005,  0.0014,  0.0019, -0.0012, -0.0039, -0.0025,  0.0043,
          0.0013]], device='cuda:0', grad_fn=<SumBackward1>)
  (0, 'mlp'): tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  3.9296e-02, -2.3588e-03, -3.4855e-05,
         -1.6136e-03,  1.0848e-03, -5.0639e-04, -1.3447e-03, -3.2448e-03,
          1.1902e-03,  3.2278e-03, -1.885

In [9]:
# Analyze EAP scores by layer: the mean score over all positions of each component
# (reported as 0.0 when a component is missing from the results)
n_layers = config.num_hidden_layers

print("\nEAP scores summary by layer (first 10 layers):")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(min(10, n_layers)):
    attn_key = (layer, "attn")
    mlp_key = (layer, "mlp")

    # The membership guard already supplies the default, so no .get() fallback is needed.
    attn_score = eap_scores[attn_key].mean().item() if attn_key in eap_scores else 0.0
    mlp_score = eap_scores[mlp_key].mean().item() if mlp_key in eap_scores else 0.0

    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")


EAP scores summary by layer (first 10 layers):
Layer            Attn          MLP
--------------------------------
0           -0.011721    -0.003888
1            0.000119    -0.000652
2           -0.000105    -0.000404
3            0.000104    -0.002395
4           -0.000703    -0.000713
5            0.002193     0.003179
6           -0.000534     0.006681
7            0.001657    -0.008160
8            0.000075     0.002121
9            0.001505    -0.001130


## Test: simple_integrated_gradients()

In [10]:
# Build an all-zeros baseline matching the input embeddings. simple_integrated_gradients
# needs a torch.Tensor with the same shape as the input embeddings;
# zeros_like also keeps their dtype and device.
inputs = tokenizer(["The quick brown fox"], thinking=False)
embeddings = get_embeddings(model, inputs)
input_embeddings = embeddings[(0, "layer_in")]
baseline_embeddings = torch.zeros_like(input_embeddings)

shape_report = [
    ("Input sequence length", inputs["input_ids"].shape[1]),
    ("Input embeddings shape", input_embeddings.shape),
    ("Baseline embeddings shape", baseline_embeddings.shape),
]
for label, value in shape_report:
    print(f"{label}: {value}")

Input sequence length: 32
Input embeddings shape: torch.Size([1, 32, 1024])
Baseline embeddings shape: torch.Size([1, 32, 1024])


In [11]:
print("Running Simple Integrated Gradients...")
with torch.enable_grad():
    ig_attributions = simple_integrated_gradients(
        model=model,
        inputs=inputs,
        baseline_embeddings=baseline_embeddings,
        metric_fn=metric_fn,
        steps=10,  # Using fewer steps for faster testing
    )

print(f"\nIG attributions type: {type(ig_attributions).__name__}")
print(f"IG attributions shape: {ig_attributions.shape}")

# Should return a tensor of shape (batch, seq_len): attributions summed over the hidden dim
assert ig_attributions.ndim == 2, "IG should return 2D tensor (batch, seq_len)"
assert ig_attributions.shape == inputs["input_ids"].shape, (
    f"IG shape {ig_attributions.shape} does not match input shape {inputs['input_ids'].shape}"
)
assert torch.isfinite(ig_attributions).all(), "IG attributions contain non-finite values"
print("\nPASSED: simple_integrated_gradients()")

Running Simple Integrated Gradients...

IG attributions type: Tensor
IG attributions shape: torch.Size([1, 32])

PASSED: simple_integrated_gradients()


In [12]:
# Print the IG attribution of every token position (first batch element only)
print("\nIG attribution per position:")
for pos, val in enumerate(ig_attributions[0].tolist()):
    print(f"  Position {pos}: {val:.6f}")


IG attribution per position:
  Position 0: 0.005977
  Position 1: 0.009443
  Position 2: 0.000083
  Position 3: -0.024873
  Position 4: 0.003160
  Position 5: 0.020614
  Position 6: 0.023986
  Position 7: -0.009533
  Position 8: -0.026629
  Position 9: 0.037477
  Position 10: 0.021760
  Position 11: -0.007442
  Position 12: 0.062427
  Position 13: 0.004779
  Position 14: -0.010886
  Position 15: 0.003113
  Position 16: 0.056903
  Position 17: -0.019279
  Position 18: -0.026440
  Position 19: -0.032712
  Position 20: 0.039270
  Position 21: -0.009380
  Position 22: 0.044395
  Position 23: 0.142104
  Position 24: 0.059125
  Position 25: -0.025080
  Position 26: 0.002880
  Position 27: -0.028371
  Position 28: 0.150504
  Position 29: -0.055223
  Position 30: -0.270713
  Position 31: 0.004531


## Test: eap_integrated_gradients()

In [13]:
# Prepare inputs for EAP-IG (requires clean and corrupted inputs with same shape)
print(f"Clean inputs shape: {clean_inputs['input_ids'].shape}")
print(f"Corrupted inputs shape: {corrupted_inputs['input_ids'].shape}")

# Enforce the requirement stated above instead of only printing the shapes.
assert clean_inputs["input_ids"].shape == corrupted_inputs["input_ids"].shape, (
    "eap_integrated_gradients needs clean and corrupted inputs of identical shape"
)

Clean inputs shape: torch.Size([1, 33])
Corrupted inputs shape: torch.Size([1, 33])


In [14]:
print("Running EAP Integrated Gradients...")
with torch.enable_grad():
    eap_ig = eap_integrated_gradients(
        model=model,
        clean_dict=clean_inputs,
        corrupted_dict=corrupted_inputs,
        metric_fn=metric_fn,
        intermediate_points=5,
    )

print(f"\nEAP-IG computed for {len(eap_ig)} components")
print(f"EAP-IG type: {type(eap_ig).__name__}")

print("\nSample EAP-IG scores (first 6 components):")
for key, val in list(eap_ig.items())[:6]:
    print(f"  {key}: {val}")

# Every layer should have an attention and an MLP score, one value per (batch, position).
expected_keys = {
    (layer, component)
    for layer in range(config.num_hidden_layers)
    for component in ("attn", "mlp")
}
assert set(eap_ig.keys()) == expected_keys, "EAP-IG did not score every attn/mlp component"
for key, val in eap_ig.items():
    assert val.shape == clean_inputs["input_ids"].shape, (
        f"Unexpected EAP-IG score shape for {key}: {val.shape}"
    )
    assert torch.isfinite(val).all(), f"Non-finite EAP-IG score for {key}"

print("\nPASSED: eap_integrated_gradients()")

Running EAP Integrated Gradients...

EAP-IG computed for 56 components
EAP-IG type: ActivationDict

Sample EAP-IG scores (first 6 components):
  (0, 'attn'): tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0707, -0.0024,
         -0.0003, -0.0004,  0.0014,  0.0020, -0.0010, -0.0036, -0.0024,  0.0041,
          0.0012]], device='cuda:0', grad_fn=<SumBackward1>)
  (0, 'mlp'): tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0488, -0.0024,
          0.0003, -0.0017,  0.0011, -0.0002, -0.0012, -0.0029,  0.0012,  0.0033,
         -0.0004]], device='cuda:0', grad_fn=<SumBackward1>)
  (1, 'attn'): tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e



In [16]:
# Analyze EAP-IG scores by layer: the mean score over all positions of each component
# (reported as 0.0 when a component is missing from the results)
print("\nEAP-IG summary by layer (first 10 layers):")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(min(10, n_layers)):
    attn_key = (layer, "attn")
    mlp_key = (layer, "mlp")
    # Scores are tensors, so .mean().item() always yields a float; no isinstance check needed.
    attn_score = eap_ig[attn_key].mean().item() if attn_key in eap_ig else 0.0
    mlp_score = eap_ig[mlp_key].mean().item() if mlp_key in eap_ig else 0.0
    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")


EAP-IG summary by layer (first 10 layers):
Layer            Attn          MLP
--------------------------------
0           -0.002189    -0.001570
1            0.000835    -0.000276
2           -0.000260    -0.000185
3           -0.000135    -0.003200
4           -0.000312    -0.001413
5            0.003100     0.003061
6           -0.001290     0.004233
7            0.000779    -0.005711
8           -0.000494     0.003605
9            0.002237    -0.001782


## Test: Custom metric functions

In [17]:
# Alternative metric functions for edge_attribution_patching; each reduces logits to a scalar
def metric_sum(logits):
    """Sum of every element of ``logits``."""
    return torch.sum(logits)

def metric_mean(logits):
    """Mean of every element of ``logits``."""
    return torch.mean(logits)

def metric_max_prob(logits):
    """Largest softmax probability (softmax over the last dim) across the whole tensor."""
    return torch.softmax(logits, dim=-1).max()

metrics = list({
    "sum": metric_sum,
    "mean": metric_mean,
    "max_prob": metric_max_prob,
}.items())

print("Testing EAP with different metrics:")
for name, metric in metrics:
    eap = edge_attribution_patching(
        model=model,
        clean_inputs=clean_inputs,
        corrupted_inputs=corrupted_inputs,
        metric_fn=metric,
    )
    # Every component score must be finite, whichever metric is used
    assert all(torch.isfinite(v).all() for v in eap.values()), (
        f"Non-finite EAP scores for metric '{name}'"
    )
    # Get total attribution
    total = sum(v.sum().item() for v in eap.values())
    print(f"  {name}: total attribution = {total:.6f}")

print("\nPASSED: Custom metric functions")

Testing EAP with different metrics:
  sum: total attribution = -552618.310852
  mean: total attribution = -3.637181
  max_prob: total attribution = 0.007299

PASSED: Custom metric functions


## Test: Error handling

In [18]:
# Test that gradient-based methods require gradients enabled
print("Testing error handling for disabled gradients...")

try:
    with torch.no_grad():
        _ = simple_integrated_gradients(
            model=model,
            inputs=inputs,
            baseline_embeddings=baseline_embeddings,
            metric_fn=metric_fn,
            steps=5,
        )
except RuntimeError as e:
    print(f"Correctly raised RuntimeError: {e}")
else:
    # Fail the cell (rather than just printing) so the PASSED line below is
    # only reached when the expected error was actually raised.
    raise AssertionError("ERROR: Should have raised RuntimeError")

print("\nPASSED: Error handling")

Testing error handling for disabled gradients...
Correctly raised RuntimeError: Integrated Gradients requires gradient computation. Run with torch.enable_grad()

PASSED: Error handling


In [19]:
# Test shape mismatch validation for simple_integrated_gradients
print("Testing shape mismatch validation...")

# Create baseline with wrong shape (5 positions instead of the input's sequence length)
wrong_baseline = torch.zeros(1, 5, config.hidden_size, device=device, dtype=model.dtype)

try:
    with torch.enable_grad():
        _ = simple_integrated_gradients(
            model=model,
            inputs=inputs,
            baseline_embeddings=wrong_baseline,
            metric_fn=metric_fn,
            steps=5,
        )
except ValueError as e:
    print(f"Correctly raised ValueError: {e}")
else:
    # Fail the cell (rather than just printing) so the PASSED line below is
    # only reached when the expected error was actually raised.
    raise AssertionError("ERROR: Should have raised ValueError")

print("\nPASSED: Shape mismatch validation")

Testing shape mismatch validation...
Correctly raised ValueError: Baseline and input embeddings must have identical shape. Got baseline: torch.Size([1, 5, 1024]), input: torch.Size([1, 32, 1024])

PASSED: Shape mismatch validation


## Summary

In [20]:
# Final banner: only reached if every earlier cell ran without failing
banner = "=" * 50
print(banner)
print("All gradient_based_attribution module tests PASSED!")
print(banner)

All gradient_based_attribution module tests PASSED!
