# Test: gradient_based_attribution module

This notebook tests the gradient-based attribution methods in `mech_interp_toolkit.gradient_based_attribution`: `edge_attribution_patching`, `simple_integrated_gradients`, and `eap_integrated_gradients`.

In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device, get_all_layer_components
from mech_interp_toolkit.activation_dict import ActivationDict
from mech_interp_toolkit.gradient_based_attribution import (
    edge_attribution_patching,
    simple_integrated_gradients,
    eap_integrated_gradients,
)

## Setup: Load model

In [None]:
# Load the model, tokenizer and config shared by every test cell below.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()
print(f"Loading model {model_name} on {device}...")

model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)

for status_line in (
    "Model loaded successfully",
    f"Number of layers: {config.num_hidden_layers}",
    f"Hidden size: {config.hidden_size}",
):
    print(status_line)

In [None]:
# Prepare clean and corrupted inputs for EAP.
# The prompts differ only in the country name. EAP works with the difference
# between clean and corrupted activations, which presumably requires both
# prompts to tokenize to the same shape; the assert below makes a mismatch
# fail here rather than deep inside the attribution code.
clean_prompts = ["The capital of France is"]
corrupted_prompts = ["The capital of Germany is"]

clean_inputs = tokenizer(clean_prompts, thinking=False)
corrupted_inputs = tokenizer(corrupted_prompts, thinking=False)

print(f"Clean input shape: {clean_inputs['input_ids'].shape}")
print(f"Corrupted input shape: {corrupted_inputs['input_ids'].shape}")

assert clean_inputs["input_ids"].shape == corrupted_inputs["input_ids"].shape, (
    "clean and corrupted prompts must tokenize to the same shape for EAP"
)

## Test: edge_attribution_patching()

In [None]:
# Define a simple scalar metric for the attribution methods to differentiate.
def metric_fn(logits):
    """Return the batch-summed maximum logit at the final sequence position.

    Assumes `logits` is indexed as (batch, seq, vocab) — TODO confirm against
    the model wrapper's output.
    """
    final_position_logits = logits[:, -1, :]
    return final_position_logits.max(dim=-1).values.sum()


print("Running Edge Attribution Patching (EAP)...")
eap_scores = edge_attribution_patching(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    compute_grad_at="clean",
    metric_fn=metric_fn,
    position=-1,
)

print(f"\nEAP scores computed for {len(eap_scores)} components")
print(f"EAP score type: {type(eap_scores)}")
print("\nSample EAP scores:")
for shown, (key, score) in enumerate(eap_scores.items()):
    if shown >= 6:  # only preview the first few components
        break
    print(f"  {key}: {score}")

In [None]:
# Analyze EAP scores: one row per layer with the attention and MLP attribution.
n_layers = config.num_hidden_layers
zero = torch.tensor(0.0)

print("\nEAP scores summary by layer:")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(n_layers):
    # Missing components count as 0.0. Scores are summed before `.item()` so this
    # works whether a score is a scalar or a multi-element tensor (bare `.item()`
    # raises on the latter); `as_tensor` also accepts plain Python numbers.
    attn_score = torch.as_tensor(eap_scores.get((layer, "attn"), zero)).float().sum().item()
    mlp_score = torch.as_tensor(eap_scores.get((layer, "mlp"), zero)).float().sum().item()
    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")

# Minimal sanity checks so "PASSED" reflects an actual check.
assert len(eap_scores) > 0, "edge_attribution_patching returned no scores"
assert all(torch.isfinite(torch.as_tensor(v)).all() for v in eap_scores.values()), (
    "EAP scores contain NaN or inf"
)

print("PASSED: Edge Attribution Patching")

In [None]:
# Same EAP run as before, but with gradients taken at the corrupted input.
print("\nRunning EAP with compute_grad_at='corrupted'...")
eap_scores_corrupted = edge_attribution_patching(
    model=model,
    metric_fn=metric_fn,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    position=-1,
    compute_grad_at="corrupted",
)

print(f"EAP scores (corrupted): {len(eap_scores_corrupted)} components")
print("Sample scores:")
preview_keys = list(eap_scores_corrupted.keys())[:3]
for key in preview_keys:
    print(f"  {key}: {eap_scores_corrupted[key]}")

print("PASSED: EAP with corrupted gradients")

## Test: simple_integrated_gradients()

In [None]:
# Create an all-zeros baseline for the embedding input (layer 0, "layer_in"),
# shaped (1, seq_len, hidden_size) and matching the model's device and dtype.
inputs = tokenizer(["The quick brown fox"], thinking=False)
seq_len = inputs["input_ids"].shape[1]

baseline_shape = (1, seq_len, config.hidden_size)
baseline_embeddings = ActivationDict(config, positions=slice(None))
baseline_embeddings[(0, "layer_in")] = torch.zeros(
    baseline_shape, device=device, dtype=model.dtype
)

print(f"Input sequence length: {seq_len}")
print(f"Baseline embeddings shape: {baseline_embeddings[(0, 'layer_in')].shape}")

In [None]:
# Run integrated gradients for `inputs` against `baseline_embeddings`.
print("Running Simple Integrated Gradients...")
ig_attributions = simple_integrated_gradients(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    metric_fn=metric_fn,
    steps=10,  # small step count keeps the test fast
)

print("\nIG attributions computed")
attribution_keys = list(ig_attributions.keys())
print(f"Keys: {attribution_keys}")

for key, val in ig_attributions.items():
    print(f"  {key}: shape={val.shape}")

print("PASSED: Simple Integrated Gradients")

In [None]:
# Analyze IG attributions per position.
# The attribution tensor may keep the hidden dimension (like the
# (1, seq_len, hidden_size) baseline), in which case `ig_values[0, pos]` is a
# vector and `.item()` alone would raise. Summing first yields one score per
# position for either layout.
ig_values = ig_attributions[(0, "layer_in")]
print(f"\nIG attribution shape: {ig_values.shape}")
print("\nIG attribution per position:")

for pos in range(ig_values.shape[1]):
    position_score = ig_values[0, pos].float().sum().item()
    print(f"  Position {pos}: {position_score:.6f}")

In [None]:
# Test with different step counts: print the total attribution per step count
# so sensitivity to `steps` can be eyeballed (not asserted).
print("\nTesting IG with different step counts:")
ig_kwargs = dict(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    metric_fn=metric_fn,
)
for steps in (5, 10, 20):
    ig = simple_integrated_gradients(**ig_kwargs, steps=steps)
    total_attribution = ig[(0, "layer_in")].sum().item()
    print(f"  Steps={steps}: total attribution = {total_attribution:.6f}")

print("PASSED: IG with different step counts")

## Test: eap_integrated_gradients()

In [None]:
# Get an all-zeros embedding baseline for EAP-IG, sized to the new prompt.
# Note: this rebinds `inputs`, `seq_len` and `baseline_embeddings`.
inputs = tokenizer(["The capital of France is"], thinking=False)
seq_len = inputs["input_ids"].shape[1]

baseline_shape = (1, seq_len, config.hidden_size)
baseline_embeddings = ActivationDict(config, positions=slice(None))
baseline_embeddings[(0, "layer_in")] = torch.zeros(
    baseline_shape, device=device, dtype=model.dtype
)

print(f"Input shape: {inputs['input_ids'].shape}")

In [None]:
# Run EAP-IG on a hand-picked subset of components: layer 0 plus one later
# layer. Layer 5 is clamped to the last layer so the cell also runs on models
# with fewer layers. `dict.fromkeys` removes duplicates in that case while
# keeping order.
later_layer = min(5, config.num_hidden_layers - 1)
layer_components = list(dict.fromkeys(
    [(0, "attn"), (0, "mlp"), (later_layer, "attn"), (later_layer, "mlp")]
))

print("Running EAP Integrated Gradients...")
eap_ig = eap_integrated_gradients(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    layer_components=layer_components,
    metric_fn=metric_fn,
    position=-1,
    intermediate_points=5,
)

print(f"\nEAP-IG computed for {len(eap_ig)} components")
for key, val in eap_ig.items():
    print(f"  {key}: {val}")

# Minimal sanity checks so "PASSED" reflects an actual check.
assert len(eap_ig) > 0, "eap_integrated_gradients returned no scores"
assert all(torch.isfinite(torch.as_tensor(v)).all() for v in eap_ig.values()), (
    "EAP-IG scores contain NaN or inf"
)

print("PASSED: EAP Integrated Gradients")

In [None]:
# Run EAP-IG with all layer components (default)
print("\nRunning EAP-IG with all layer components...")
eap_ig_full = eap_integrated_gradients(
    model=model,
    inputs=inputs,
    baseline_embeddings=baseline_embeddings,
    layer_components=None,  # Uses all components
    metric_fn=metric_fn,
    position=-1,
    intermediate_points=3,
)

print(f"EAP-IG (full): {len(eap_ig_full)} components")

# Print a summary for the first few layers only, to keep the output short.
MAX_SUMMARY_LAYERS = 10
zero = torch.tensor(0.0)

print("\nEAP-IG summary by layer:")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)

for layer in range(min(MAX_SUMMARY_LAYERS, config.num_hidden_layers)):
    # Missing components count as 0.0. `as_tensor` accepts plain numbers or
    # tensors, and summing before `.item()` handles multi-element tensors, which
    # a bare `.item()` would reject.
    attn_score = torch.as_tensor(eap_ig_full.get((layer, "attn"), zero)).float().sum().item()
    mlp_score = torch.as_tensor(eap_ig_full.get((layer, "mlp"), zero)).float().sum().item()
    print(f"{layer:<8} {attn_score:>12.6f} {mlp_score:>12.6f}")

# Minimal sanity checks so "PASSED" reflects an actual check.
assert len(eap_ig_full) > 0, "eap_integrated_gradients returned no scores"
assert all(torch.isfinite(torch.as_tensor(v)).all() for v in eap_ig_full.values()), (
    "EAP-IG (full) scores contain NaN or inf"
)

print("PASSED: EAP-IG with all components")

## Test: Custom metric functions

In [None]:
# Test different metric functions; each reduces the logits tensor to a scalar.
def metric_sum(logits):
    """Sum of every element of `logits`."""
    return logits.sum()

def metric_mean(logits):
    """Mean of every element of `logits`."""
    return logits.mean()

def metric_max_prob(logits):
    """Largest softmax probability (softmax over the last axis) in `logits`."""
    return torch.softmax(logits, dim=-1).max()

metrics = [
    ("sum", metric_sum),
    ("mean", metric_mean),
    ("max_prob", metric_max_prob),
]

print("Testing EAP with different metrics:")
for name, metric in metrics:
    eap = edge_attribution_patching(
        model=model,
        clean_inputs=clean_inputs,
        corrupted_inputs=corrupted_inputs,
        metric_fn=metric,
        position=-1,
    )
    # Collapse every component's score into one number for a quick comparison.
    total = sum(component_score.sum().item() for component_score in eap.values())
    print(f"  {name}: total attribution = {total:.6f}")

print("PASSED: Custom metric functions")

## Test: Different positions

In [None]:
# Test EAP with different position specifications: the last token only, and
# every position via slice(None).
print("Testing EAP with different positions:")

eap_common_kwargs = dict(
    model=model,
    clean_inputs=clean_inputs,
    corrupted_inputs=corrupted_inputs,
    metric_fn=metric_fn,
)

eap_last = edge_attribution_patching(**eap_common_kwargs, position=-1)
print(f"  position=-1: {len(eap_last)} components")

eap_all = edge_attribution_patching(**eap_common_kwargs, position=slice(None))
print(f"  position=slice(None): {len(eap_all)} components")

print("PASSED: Different positions")

## Summary

In [None]:
# Final banner; only reached if every earlier cell ran without raising.
banner = "=" * 50
print(banner)
print("All gradient_based_attribution module tests PASSED!")
print(banner)