In [1]:
# !git clone https://github.com/SD-interp/mech_interp_toolkit.git
# %cd mech_interp_toolkit
# !pip install -e .

# Test: gradient_based_attribution module

This notebook tests gradient-based attribution methods from `mech_interp_toolkit.gradient_based_attribution`.

In [2]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device, get_all_layer_components
from mech_interp_toolkit.activation_dict import ActivationDict
from mech_interp_toolkit.gradient_based_attribution import (
    edge_attribution_patching,
    simple_integrated_gradients,
    eap_integrated_gradients,
)

## Setup: Load model

In [3]:
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

print(f"Loading model {model_name} on {device}...")
model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")
print(f"Number of layers: {config.num_hidden_layers}")
print(f"Hidden size: {config.hidden_size}")

Loading model Qwen/Qwen3-0.6B on cuda...
Model loaded successfully
Number of layers: 28
Hidden size: 1024


In [4]:
# Prepare clean and corrupted inputs for EAP
clean_prompts = ["The capital of France is"]
corrupted_prompts = ["The capital of Germany is"]

clean_inputs = tokenizer(clean_prompts, thinking=False)
corrupted_inputs = tokenizer(corrupted_prompts, thinking=False)

print(f"Clean input shape: {clean_inputs['input_ids'].shape}")
print(f"Corrupted input shape: {corrupted_inputs['input_ids'].shape}")

Clean input shape: torch.Size([1, 33])
Corrupted input shape: torch.Size([1, 33])


## Test: edge_attribution_patching()

In [5]:
# Define a simple metric function
def metric_fn(logits):
    return logits[:, -1, :].max(dim=-1).values.sum()

print("Running Edge Attribution Patching (EAP)...")
eap_scores = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="clean",
    metric_fn=metric_fn,
    position=-1,
)

print(f"\nEAP scores computed for {len(eap_scores)} components")
print(f"EAP score type: {type(eap_scores)}")
print(f"\nSample EAP scores:")
for key, score in list(eap_scores.items())[:6]:
    print(f"  {key}: {score}")

Running Edge Attribution Patching (EAP)...


                    All positions will be captured by default. Use output.extract_positions() instead


(0, 'attn') >> True
(0, 'mlp') >> True
(1, 'attn') >> True
(1, 'mlp') >> True
(2, 'attn') >> True
(2, 'mlp') >> True
(3, 'attn') >> True
(3, 'mlp') >> True
(4, 'attn') >> True
(4, 'mlp') >> True
(5, 'attn') >> True
(5, 'mlp') >> True
(6, 'attn') >> True
(6, 'mlp') >> True
(7, 'attn') >> True
(7, 'mlp') >> True
(8, 'attn') >> True
(8, 'mlp') >> True
(9, 'attn') >> True
(9, 'mlp') >> True
(10, 'attn') >> True
(10, 'mlp') >> True
(11, 'attn') >> True
(11, 'mlp') >> True
(12, 'attn') >> True
(12, 'mlp') >> True
(13, 'attn') >> True
(13, 'mlp') >> True
(14, 'attn') >> True
(14, 'mlp') >> True
(15, 'attn') >> True
(15, 'mlp') >> True
(16, 'attn') >> True
(16, 'mlp') >> True
(17, 'attn') >> True
(17, 'mlp') >> True
(18, 'attn') >> True
(18, 'mlp') >> True
(19, 'attn') >> True
(19, 'mlp') >> True
(20, 'attn') >> True
(20, 'mlp') >> True
(21, 'attn') >> True
(21, 'mlp') >> True
(22, 'attn') >> True
(22, 'mlp') >> True
(23, 'attn') >> True
(23, 'mlp') >> True
(24, 'attn') >> True
(24, 'mlp') >> 

In [6]:
# Analyze EAP scores
n_layers = config.num_hidden_layers

print("\nEAP scores summary by layer:")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(n_layers):
    attn_key = (layer, "attn")
    mlp_key = (layer, "mlp")

    attn_score = eap_scores.get(attn_key, torch.tensor(0.0)).item() if attn_key in eap_scores else 0.0
    mlp_score = eap_scores.get(mlp_key, torch.tensor(0.0)).item() if mlp_key in eap_scores else 0.0

    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")

print("PASSED: Edge Attribution Patching")


EAP scores summary by layer:
Layer            Attn          MLP
--------------------------------
0            0.103439     0.102436
1            0.020723     0.034075
2           -0.022823    -0.032630
3            0.016342     0.007022
4           -0.014269    -0.012199
5            0.033902     0.028130
6           -0.005734    -0.002962
7           -0.001622     0.007245
8            0.011429    -0.044762
9           -0.005803    -0.013749
10          -0.004995     0.040908
11          -0.086160     0.022783
12          -0.005039    -0.021428
13          -0.013774     0.034420
14          -0.020320    -0.038764
15           0.009910    -0.008491
16          -0.000997    -0.026708
17           0.038213     0.061616
18          -0.006833    -0.009788
19           0.007158    -0.010926
20          -0.011855    -0.019960
21           0.008507    -0.054610
22          -0.000294     0.004383
23          -0.010230    -0.017337
24           0.004418    -0.006846
25          -0.001784     0

In [7]:
# Test with compute_grad_at="corrupted"
print("\nRunning EAP with compute_grad_at='corrupted'...")
eap_scores_corrupted = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="corrupted",
    metric_fn=metric_fn,
    position=-1,
)

print(f"EAP scores (corrupted): {len(eap_scores_corrupted)} components")
print("Sample scores:")
for key in list(eap_scores_corrupted.keys())[:3]:
    print(f"  {key}: {eap_scores_corrupted[key]}")

print("PASSED: EAP with corrupted gradients")


Running EAP with compute_grad_at='corrupted'...
(0, 'attn') >> True
(0, 'mlp') >> True
(1, 'attn') >> True
(1, 'mlp') >> True
(2, 'attn') >> True
(2, 'mlp') >> True
(3, 'attn') >> True
(3, 'mlp') >> True
(4, 'attn') >> True
(4, 'mlp') >> True
(5, 'attn') >> True
(5, 'mlp') >> True
(6, 'attn') >> True
(6, 'mlp') >> True
(7, 'attn') >> True
(7, 'mlp') >> True
(8, 'attn') >> True
(8, 'mlp') >> True
(9, 'attn') >> True
(9, 'mlp') >> True
(10, 'attn') >> True
(10, 'mlp') >> True
(11, 'attn') >> True
(11, 'mlp') >> True
(12, 'attn') >> True
(12, 'mlp') >> True
(13, 'attn') >> True
(13, 'mlp') >> True
(14, 'attn') >> True
(14, 'mlp') >> True
(15, 'attn') >> True
(15, 'mlp') >> True
(16, 'attn') >> True
(16, 'mlp') >> True
(17, 'attn') >> True
(17, 'mlp') >> True
(18, 'attn') >> True
(18, 'mlp') >> True
(19, 'attn') >> True
(19, 'mlp') >> True
(20, 'attn') >> True
(20, 'mlp') >> True
(21, 'attn') >> True
(21, 'mlp') >> True
(22, 'attn') >> True
(22, 'mlp') >> True
(23, 'attn') >> True
(23, 'm

## Test: simple_integrated_gradients()

In [8]:
# Create baseline embeddings (zeros)
inputs = tokenizer(["The quick brown fox"], thinking=False)
seq_len = inputs["input_ids"].shape[1]

baseline_embeddings = ActivationDict(config, positions=slice(None))
baseline_embeddings[(0, "layer_in")] = torch.zeros(
    1, seq_len, config.hidden_size,
    device=device,
    dtype=model.dtype
)

print(f"Input sequence length: {seq_len}")
print(f"Baseline embeddings shape: {baseline_embeddings[(0, 'layer_in')].shape}")

Input sequence length: 32
Baseline embeddings shape: torch.Size([1, 32, 1024])


In [9]:
print("Running Simple Integrated Gradients...")
ig_attributions = simple_integrated_gradients(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    metric_fn=metric_fn,
    steps=10,  # Using fewer steps for faster testing
)

print(f"\nIG attributions computed")
print(f"Keys: {list(ig_attributions.keys())}")

for key, val in ig_attributions.items():
    print(f"  {key}: shape={val.shape}")

print("PASSED: Simple Integrated Gradients")

Running Simple Integrated Gradients...
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True

IG attributions computed
Keys: [(0, 'layer_in')]
  (0, 'layer_in'): shape=torch.Size([1, 32])
PASSED: Simple Integrated Gradients


In [10]:
# Analyze IG attributions per position
ig_values = ig_attributions[(0, "layer_in")]
print(f"\nIG attribution shape: {ig_values.shape}")
print(f"\nIG attribution per position:")

for pos in range(ig_values.shape[1]):
    print(f"  Position {pos}: {ig_values[0, pos].item():.6f}")


IG attribution shape: torch.Size([1, 32])

IG attribution per position:
  Position 0: nan
  Position 1: nan
  Position 2: nan
  Position 3: nan
  Position 4: nan
  Position 5: nan
  Position 6: nan
  Position 7: nan
  Position 8: nan
  Position 9: nan
  Position 10: nan
  Position 11: nan
  Position 12: nan
  Position 13: nan
  Position 14: nan
  Position 15: nan
  Position 16: nan
  Position 17: nan
  Position 18: nan
  Position 19: nan
  Position 20: nan
  Position 21: nan
  Position 22: nan
  Position 23: nan
  Position 24: nan
  Position 25: nan
  Position 26: nan
  Position 27: nan
  Position 28: nan
  Position 29: nan
  Position 30: nan
  Position 31: nan


In [11]:
# Test with different step counts
print("\nTesting IG with different step counts:")
for steps in [5, 10, 20]:
    ig = simple_integrated_gradients(
        model=model,
        inputs=inputs,
        baseline_embeddings=baseline_embeddings,
        metric_fn=metric_fn,
        steps=steps,
    )
    total_attribution = ig[(0, "layer_in")].sum().item()
    print(f"  Steps={steps}: total attribution = {total_attribution:.6f}")

print("PASSED: IG with different step counts")


Testing IG with different step counts:
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
  Steps=5: total attribution = nan
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
  Steps=10: total attribution = nan
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
(0, 'layer_in') >> True
  Steps=20: total attribution = nan
PASSED: IG wi

## Test: eap_integrated_gradients()

In [12]:
# Get baseline embeddings for EAP-IG
inputs = tokenizer(["The capital of France is"], thinking=False)
seq_len = inputs["input_ids"].shape[1]

baseline_embeddings = ActivationDict(config, positions=slice(None))
baseline_embeddings[(0, "layer_in")] = torch.zeros(
    1, seq_len, config.hidden_size,
    device=device,
    dtype=model.dtype
)

print(f"Input shape: {inputs['input_ids'].shape}")

Input shape: torch.Size([1, 33])


In [13]:
# Run EAP-IG with specific layer components
layer_components = [(0, "attn"), (0, "mlp"), (5, "attn"), (5, "mlp")]

print("Running EAP Integrated Gradients...")
eap_ig = eap_integrated_gradients(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    layer_components=layer_components,
    metric_fn=metric_fn,
    position=-1,
    intermediate_points=5,
)

print(f"\nEAP-IG computed for {len(eap_ig)} components")
for key, val in eap_ig.items():
    print(f"  {key}: {val}")

print("PASSED: EAP Integrated Gradients")

Running EAP Integrated Gradients...


NNsightException: 

Traceback (most recent call last):
  File "/content/mech_interp_toolkit/src/mech_interp_toolkit/activations.py", line 206, in unified_access_and_patching
    comp[:] = self.patch_fn(comp, self.patching_dict[(layer, component)], patch_pos)
  File "/content/mech_interp_toolkit/src/mech_interp_toolkit/activations.py", line 182, in patch_fn
    original[:, position, :] = new_value

RuntimeError: shape mismatch: value tensor of shape [33, 1024] cannot be broadcast to indexing result of shape [1, 1, 1024]

In [None]:
# Run EAP-IG with all layer components (default)
print("\nRunning EAP-IG with all layer components...")
eap_ig_full = eap_integrated_gradients(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    layer_components=None,  # Uses all components
    metric_fn=metric_fn,
    position=-1,
    intermediate_points=3,
)

print(f"EAP-IG (full): {len(eap_ig_full)} components")

# Print summary by layer
print("\nEAP-IG summary by layer:")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(min(10, config.num_hidden_layers)):
    attn_val = eap_ig_full.get((layer, "attn"), torch.tensor(0.0))
    mlp_val = eap_ig_full.get((layer, "mlp"), torch.tensor(0.0))
    attn_score = attn_val.item() if isinstance(attn_val, torch.Tensor) else attn_val
    mlp_score = mlp_val.item() if isinstance(mlp_val, torch.Tensor) else mlp_val
    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")

print("PASSED: EAP-IG with all components")

## Test: Custom metric functions

In [None]:
# Test different metric functions
def metric_sum(logits):
    return logits.sum()

def metric_mean(logits):
    return logits.mean()

def metric_max_prob(logits):
    probs = torch.softmax(logits, dim=-1)
    return probs.max()

metrics = [
    ("sum", metric_sum),
    ("mean", metric_mean),
    ("max_prob", metric_max_prob),
]

print("Testing EAP with different metrics:")
for name, metric in metrics:
    eap = edge_attribution_patching(
        model=model,
        clean_inputs=clean_inputs,
        corrupted_inputs=corrupted_inputs,
        metric_fn=metric,
        position=-1,
    )
    # Get total attribution
    total = sum(v.sum().item() for v in eap.values())
    print(f"  {name}: total attribution = {total:.6f}")

print("PASSED: Custom metric functions")

## Test: Different positions

In [None]:
# Test EAP with different position specifications
print("Testing EAP with different positions:")

# Last position
eap_last = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    metric_fn=metric_fn,
    position=-1,
)
print(f"  position=-1: {len(eap_last)} components")

# All positions
eap_all = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    metric_fn=metric_fn,
    position=slice(None),
)
print(f"  position=slice(None): {len(eap_all)} components")

print("PASSED: Different positions")

## Summary

In [None]:
print("="*50)
print("All gradient_based_attribution module tests PASSED!")
print("="*50)