In [None]:
# !git clone https://github.com/SD-interp/mech_interp_toolkit.git
# %cd mech_interp_toolkit
# !pip install -e .

# Test: gradient_based_attribution module

This notebook tests gradient-based attribution methods from `mech_interp_toolkit.gradient_based_attribution`.

Functions tested:
- `get_activations()` - Extract activations with gradient support
- `get_embeddings()` - Extract input embeddings
- `simple_integrated_gradients()` - Vanilla integrated gradients w.r.t. the input embeddings
- `edge_attribution_patching()` - Simple gradient × activation approximation (uses clean and corrupted inputs)
- `eap_integrated_gradients()` - Integrated gradients for edge attributions

In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device, get_all_layer_components
from mech_interp_toolkit.gradient_based_attribution import (
    get_activations,
    get_embeddings,
    simple_integrated_gradients,
    edge_attribution_patching,
    eap_integrated_gradients,
)

## Setup: Load model

In [None]:
# Model under test and the device chosen by the toolkit helper.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

print(f"Loading model {model_name} on {device}...")
model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")
# Report the architecture sizes used by later cells (layer loops, shape checks).
print(
    f"Number of layers: {config.num_hidden_layers}\n"
    f"Hidden size: {config.hidden_size}"
)

In [None]:
# Clean/corrupted prompt pair for the attribution methods; the two prompts
# differ only in the country name.
clean_prompts = ["The capital of France is"]
corrupted_prompts = ["The capital of Germany is"]

clean_inputs, corrupted_inputs = (
    tokenizer(prompt_batch, thinking=False)
    for prompt_batch in (clean_prompts, corrupted_prompts)
)

for label, encoded in (("Clean", clean_inputs), ("Corrupted", corrupted_inputs)):
    print(f"{label} input shape: {encoded['input_ids'].shape}")

## Test: get_activations()

In [None]:
# Request the "attn" and "mlp" components at two depths (layers 0 and 5).
layer_components = [(0, "attn"), (0, "mlp"), (5, "attn"), (5, "mlp")]

print("Extracting activations...")
activations = get_activations(model, clean_inputs, layer_components)

print("\nExtracted activations:")
for component, tensor in activations.items():
    print(f"  {component}: {tensor.shape}")

# Every requested (layer, component) pair must appear in the result.
missing_components = [lc for lc in layer_components if lc not in activations]
assert not missing_components, f"Missing {missing_components[0]}"
print("\nPASSED: get_activations()")

## Test: get_embeddings()

In [None]:
# get_embeddings returns a mapping; the input embeddings are stored under (0, "layer_in").
print("Extracting embeddings...")
embeddings = get_embeddings(model, clean_inputs)
layer_in = embeddings[(0, "layer_in")]

print(f"Embeddings keys: {list(embeddings.keys())}")
print(f"Embeddings shape: {layer_in.shape}")

# Expected layout: (batch, seq_len, hidden_size).
batch_size, seq_len = clean_inputs["input_ids"].shape
expected_shape = (batch_size, seq_len, config.hidden_size)
assert layer_in.shape == expected_shape
print("\nPASSED: get_embeddings()")

## Test: edge_attribution_patching()

In [None]:
# Scalar metric shared by the attribution tests below.
def metric_fn(logits):
    """Return the largest logit at the final position, summed over the batch.

    Args:
        logits: tensor indexed as [batch, position, vocab] (the 3-axis
            slice below requires at least three dimensions).

    Returns:
        0-d tensor, differentiable w.r.t. ``logits``.
    """
    final_position_logits = logits[:, -1, :]
    top_logits, _ = final_position_logits.max(dim=-1)
    return top_logits.sum()

# EAP with gradients taken on the clean run.
print("Running Edge Attribution Patching (EAP)...")
eap_scores = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="clean",
    metric_fn=metric_fn,
)

print(f"\nEAP scores computed for {len(eap_scores)} components")
print(f"EAP score type: {type(eap_scores).__name__}")
print("\nSample EAP scores (first 6 components):")
for key, score in list(eap_scores.items())[:6]:
    print(f"  {key}: {score}")

# Check something concrete before reporting success; an empty result would
# otherwise still print PASSED.
assert len(eap_scores) > 0, "edge_attribution_patching returned no scores"
print("\nPASSED: edge_attribution_patching()")

In [None]:
# Test with compute_grad_at="corrupted": same model and inputs, gradients taken
# on the corrupted run instead of the clean one.
print("Running EAP with compute_grad_at='corrupted'...")
eap_scores_corrupted = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="corrupted",
    metric_fn=metric_fn,
)

print(f"EAP scores (corrupted): {len(eap_scores_corrupted)} components")
print("Sample scores:")
for key, score in list(eap_scores_corrupted.items())[:3]:
    print(f"  {key}: {score}")

# Only the gradient location changed, so the set of scored components should
# be non-empty and identical to the clean-gradient run.
assert len(eap_scores_corrupted) > 0, "EAP (corrupted) returned no scores"
assert set(eap_scores_corrupted.keys()) == set(eap_scores.keys()), (
    "EAP with compute_grad_at='corrupted' scored a different set of components"
)
print("\nPASSED: EAP with corrupted gradients")

In [None]:
# Analyze EAP scores by layer.
# Each score is reduced to one float via torch.as_tensor(...).float().sum().item().
# This handles Python floats, 0-d tensors and multi-element tensors alike
# (a bare .item() fails on the first and last). Accumulation happens in float32.
# Components missing from the result are reported as 0.0.
n_layers = config.num_hidden_layers

print("\nEAP scores summary by layer (first 10 layers):")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(min(10, n_layers)):
    attn_score = torch.as_tensor(eap_scores.get((layer, "attn"), 0.0)).float().sum().item()
    mlp_score = torch.as_tensor(eap_scores.get((layer, "mlp"), 0.0)).float().sum().item()
    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")

## Test: simple_integrated_gradients()

In [None]:
# IG baseline: an all-zero tensor with the same shape as the real input
# embeddings (simple_integrated_gradients expects a torch.Tensor here).
inputs = tokenizer(["The quick brown fox"], thinking=False)
embeddings = get_embeddings(model, inputs)
input_embeddings = embeddings[(0, "layer_in")]
baseline_embeddings = torch.zeros_like(input_embeddings)

for label, value in (
    ("Input sequence length", inputs["input_ids"].shape[1]),
    ("Input embeddings shape", input_embeddings.shape),
    ("Baseline embeddings shape", baseline_embeddings.shape),
):
    print(f"{label}: {value}")

In [None]:
print("Running Simple Integrated Gradients...")
with torch.enable_grad():
    ig_attributions = simple_integrated_gradients(
        model=model,
        inputs=inputs,
        baseline_embeddings=baseline_embeddings,
        metric_fn=metric_fn,
        steps=10,  # Using fewer steps for faster testing
    )

print(f"\nIG attributions type: {type(ig_attributions).__name__}")
print(f"IG attributions shape: {ig_attributions.shape}")

# Expect one attribution per token, i.e. shape (batch, seq_len) matching
# input_ids, with the hidden dimension summed out. Checking ndim alone would
# accept a transposed or truncated result.
expected_shape = tuple(inputs["input_ids"].shape)
assert ig_attributions.ndim == 2, "IG should return 2D tensor (batch, seq_len)"
assert tuple(ig_attributions.shape) == expected_shape, (
    f"IG shape {tuple(ig_attributions.shape)} != input_ids shape {expected_shape}"
)
assert torch.isfinite(ig_attributions).all(), "IG attributions contain NaN/inf"
print("\nPASSED: simple_integrated_gradients()")

In [None]:
# One attribution value per token position, shown for batch element 0.
print("\nIG attribution per position:")
for position, score in enumerate(ig_attributions[0].tolist()):
    print(f"  Position {position}: {score:.6f}")

## Test: eap_integrated_gradients()

In [None]:
# Prepare inputs for EAP-IG (requires clean and corrupted inputs with same shape)
print(f"Clean inputs shape: {clean_inputs['input_ids'].shape}")
print(f"Corrupted inputs shape: {corrupted_inputs['input_ids'].shape}")

# Enforce the requirement stated above instead of only printing the shapes,
# so a prompt pair that tokenizes to different lengths fails here with a clear message.
assert clean_inputs["input_ids"].shape == corrupted_inputs["input_ids"].shape, (
    "EAP-IG requires clean and corrupted inputs of the same shape; "
    "choose prompts that tokenize to the same length"
)

In [None]:
print("Running EAP Integrated Gradients...")
with torch.enable_grad():
    eap_ig = eap_integrated_gradients(
        model=model,
        clean_dict=clean_inputs,
        corrupted_dict=corrupted_inputs,
        metric_fn=metric_fn,
        intermediate_points=5,
    )

print(f"\nEAP-IG computed for {len(eap_ig)} components")
print(f"EAP-IG type: {type(eap_ig).__name__}")

print("\nSample EAP-IG scores (first 6 components):")
for key, val in list(eap_ig.items())[:6]:
    print(f"  {key}: {val}")

# Check something concrete before reporting success; an empty result would
# otherwise still print PASSED.
assert len(eap_ig) > 0, "eap_integrated_gradients returned no scores"
print("\nPASSED: eap_integrated_gradients()")

In [None]:
# Analyze EAP-IG scores by layer.
# Each score is reduced to one float via torch.as_tensor(...).float().sum().item().
# This works for Python floats, 0-d tensors and multi-element tensors; a bare
# .item() fails on the latter. Accumulation happens in float32.
# Components missing from the result are reported as 0.0.
print("\nEAP-IG summary by layer (first 10 layers):")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(min(10, n_layers)):
    attn_score = torch.as_tensor(eap_ig.get((layer, "attn"), 0.0)).float().sum().item()
    mlp_score = torch.as_tensor(eap_ig.get((layer, "mlp"), 0.0)).float().sum().item()
    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")

## Test: Custom metric functions

In [None]:
# Alternative scalar metrics, used to check that edge_attribution_patching
# accepts arbitrary metric_fn choices.
def metric_sum(logits):
    """Sum of every element of ``logits`` (0-d tensor)."""
    return torch.sum(logits)

def metric_mean(logits):
    """Mean of every element of ``logits`` (0-d tensor)."""
    return torch.mean(logits)

def metric_max_prob(logits):
    """Largest softmax probability (softmax over the last dim) anywhere in ``logits``."""
    return logits.softmax(dim=-1).max()

metrics = [
    ("sum", metric_sum),
    ("mean", metric_mean),
    ("max_prob", metric_max_prob),
]

print("Testing EAP with different metrics:")
for name, metric in metrics:
    eap = edge_attribution_patching(
        model=model,
        clean_inputs=clean_inputs,
        corrupted_inputs=corrupted_inputs,
        metric_fn=metric,
    )
    assert len(eap) > 0, f"No EAP scores returned for metric '{name}'"
    # Total attribution. torch.as_tensor accepts both Python floats and tensors,
    # where a bare v.sum() would fail on a float. Accumulation happens in float32.
    total = sum(torch.as_tensor(v).float().sum().item() for v in eap.values())
    print(f"  {name}: total attribution = {total:.6f}")

print("\nPASSED: Custom metric functions")

## Test: Error handling

In [None]:
# Test that gradient-based methods require gradients enabled.
# If no RuntimeError is raised, that is a test failure. Previously the cell
# printed "ERROR" and then still reported PASSED.
print("Testing error handling for disabled gradients...")

try:
    with torch.no_grad():
        _ = simple_integrated_gradients(
            model=model,
            inputs=inputs,
            baseline_embeddings=baseline_embeddings,
            metric_fn=metric_fn,
            steps=5,
        )
except RuntimeError as e:
    print(f"Correctly raised RuntimeError: {e}")
else:
    raise AssertionError(
        "simple_integrated_gradients should raise RuntimeError when gradients are disabled"
    )

print("\nPASSED: Error handling")

In [None]:
# Test shape mismatch validation for simple_integrated_gradients
print("Testing shape mismatch validation...")

# Create baseline with wrong shape. The sequence length is derived from the
# real input (+1) so it is guaranteed to differ; a hard-coded length could
# coincide with the tokenized prompt length and make the test meaningless.
batch_size_ig, seq_len_ig = inputs["input_ids"].shape
wrong_baseline = torch.zeros(
    batch_size_ig, seq_len_ig + 1, config.hidden_size, device=device, dtype=model.dtype
)

try:
    with torch.enable_grad():
        _ = simple_integrated_gradients(
            model=model,
            inputs=inputs,
            baseline_embeddings=wrong_baseline,
            metric_fn=metric_fn,
            steps=5,
        )
except ValueError as e:
    print(f"Correctly raised ValueError: {e}")
else:
    # Missing validation is a failure, not something to print and then pass.
    raise AssertionError(
        "simple_integrated_gradients should raise ValueError for a baseline of mismatched shape"
    )

print("\nPASSED: Shape mismatch validation")

## Summary

In [None]:
# Final banner: a 50-character rule above and below the summary line.
banner = "=" * 50
print(banner, "All gradient_based_attribution module tests PASSED!", banner, sep="\n")