In [1]:
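# Optional setup for a fresh environment (e.g. Colab): uncomment to clone the
# toolkit and install it in editable mode into the running kernel.
# !git clone https://github.com/SD-interp/mech_interp_toolkit.git
# %cd mech_interp_toolkit
# %pip install -e .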

# Test: activations module

This notebook tests `UnifiedAccessAndPatching` and `create_z_patch_dict` from `mech_interp_toolkit.activations`.

In [2]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device
from mech_interp_toolkit.activations import UnifiedAccessAndPatching, create_z_patch_dict
from mech_interp_toolkit.activation_dict import ActivationDict

## Setup: Load model

In [3]:
# Checkpoint to load and the device picked by the toolkit helper.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

print(f"Loading model {model_name} on {device}...")
# The helper hands back the model, its tokenizer and the config together;
# all three are reused as notebook globals by the test cells below.
model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")
# Echo the two config fields that later cells read (layer count and hidden size).
print(f"Number of layers: {config.num_hidden_layers}")
print(f"Hidden size: {config.hidden_size}")

Loading model Qwen/Qwen3-0.6B on cuda...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Model loaded successfully
Number of layers: 28
Hidden size: 1024


In [4]:
# Prepare test inputs: a batch of two prompts shared by most tests below.
prompts = [
    "The capital of France is",
    "The capital of Germany is"
]
# NOTE(review): `thinking=False` is not a standard Hugging Face tokenizer
# argument, so `tokenizer` is presumably the toolkit's own wrapper; the
# `.shape` reads below show that it returns tensors directly. Confirm.
inputs = tokenizer(prompts, thinking=False)
print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Attention mask shape: {inputs['attention_mask'].shape}")

Input IDs shape: torch.Size([2, 33])
Attention mask shape: torch.Size([2, 33])


## Test: Basic activation extraction

In [5]:
n_layers = config.num_hidden_layers

# (layer, component) pairs to capture: three components of layer 0 plus the
# output of the final layer. Declared once so that the request and the
# assertions below cannot drift apart.
locations = [
    (0, "layer_in"),
    (0, "attn"),
    (0, "mlp"),
    (n_layers - 1, "layer_out"),
]

# Extract activations from multiple components
spec_dict = {
    "activations": {
        "positions": -1,  # Last position only
        # Pass a copy so the list used for the assertions stays untouched.
        "locations": list(locations),
    }
}

print("Extracting activations...")
with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, logits = uap.unified_access_and_patching()

print(f"Logits shape: {logits.shape}")
print("\nExtracted activations:")
for key, val in activations.items():
    print(f"  {key}: {val.shape}")

# Every requested pair must come back as a key of the result.
for layer, component in locations:
    assert (layer, component) in activations, f"Should have {component}"
print("PASSED: Basic activation extraction")

Extracting activations...
Logits shape: torch.Size([2, 1, 151936])

Extracted activations:
  (0, 'layer_in'): torch.Size([2, 1, 1024])
  (0, 'attn'): torch.Size([2, 1, 1024])
  (0, 'mlp'): torch.Size([2, 1, 1024])
  (27, 'layer_out'): torch.Size([2, 1, 1024])
PASSED: Basic activation extraction


## Test: Extract multiple positions

In [6]:
# Request the layer-0 input at every sequence position and check that the
# position axis of the result matches the tokenized sequence length.
seq_len = inputs["input_ids"].shape[1]

spec_dict = {
    "activations": {
        "positions": slice(None),  # All positions
        "locations": [(0, "layer_in")],
    }
}

print("Extracting all positions...")
with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, logits = uap.unified_access_and_patching()

print(f"Input sequence length: {seq_len}")
print(f"Extracted activation shape: {activations[(0, 'layer_in')].shape}")

# Axis 1 of the captured tensor is the position axis.
n_extracted_positions = activations[(0, "layer_in")].shape[1]
assert n_extracted_positions == seq_len, "Should have all positions"
print("PASSED: Multiple positions extraction")

Extracting all positions...
Input sequence length: 33
Extracted activation shape: torch.Size([2, 33, 1024])
PASSED: Multiple positions extraction


In [7]:
# Request an explicit list of positions (rather than a slice or a single
# index) and check that exactly that many positions come back.
spec_dict = {
    "activations": {
        "positions": [-3, -2, -1],  # Last 3 positions
        "locations": [(0, "layer_in")],
    }
}

print("Extracting last 3 positions...")
with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, logits = uap.unified_access_and_patching()

print(f"Extracted activation shape: {activations[(0, 'layer_in')].shape}")

# Axis 1 of the captured tensor is the position axis.
n_positions = activations[(0, "layer_in")].shape[1]
assert n_positions == 3, "Should have 3 positions"
print("PASSED: Specific positions extraction")

Extracting last 3 positions...
Extracted activation shape: torch.Size([2, 3, 1024])
PASSED: Specific positions extraction


## Test: Extract 'z' activations (attention head outputs)

In [9]:
# Capture "z" at two different depths so the check is not tied to one layer.
z_locations = [(0, "z"), (5, "z")]
spec_dict = {
    "activations": {
        "positions": -1,
        "locations": list(z_locations),
    }
}

print("Extracting z activations...")
with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, logits = uap.unified_access_and_patching()

print(f"Layer 0 z shape: {activations[(0, 'z')].shape}")
print(f"Layer 5 z shape: {activations[(5, 'z')].shape}")

# Fused z activations are (batch, pos, width). The width is NOT hidden_size:
# the recorded run of this notebook reports 2048 against a hidden size of 1024.
# NOTE(review): presumably width == num_attention_heads * head_dim; confirm
# against the toolkit before asserting an exact value here.
batch_size = inputs["input_ids"].shape[0]
for location in z_locations:
    z_acts = activations[location]
    assert z_acts.ndim == 3, f"{location} should be (batch, pos, width) when fused"
    assert z_acts.shape[:2] == (batch_size, 1), f"{location} should hold one position per prompt"
assert activations[(0, "z")].shape == activations[(5, "z")].shape, "z shape should not depend on the layer"
print("PASSED: z activation extraction")

Extracting z activations...
Layer 0 z shape: torch.Size([2, 1, 2048])
Layer 5 z shape: torch.Size([2, 1, 2048])
PASSED: z activation extraction


## Test: Activation patching

In [10]:
# First, get clean activations and logits
# The unpatched run is the baseline that the patching test compares against.
final_layer_out = (n_layers - 1, "layer_out")
spec_dict_clean = {
    "activations": {
        "positions": -1,
        "locations": [final_layer_out],
    }
}

with UnifiedAccessAndPatching(model, inputs, spec_dict_clean) as uap:
    clean_acts, clean_logits = uap.unified_access_and_patching()

print(f"Clean logits shape: {clean_logits.shape}")

Clean logits shape: torch.Size([2, 1, 151936])


In [11]:
# Create patch: zero-ablate MLP at layer 5
layer_to_patch = 5
# input_ids is 2-D (batch, sequence), so both sizes unpack in one step.
batch_size, seq_len = inputs["input_ids"].shape

# An all-zero replacement covering every position of the MLP output.
patch_tensor = torch.zeros(
    batch_size,
    seq_len,
    config.hidden_size,
    device=device,
    dtype=model.dtype,
)
patch_data = ActivationDict(config, positions=slice(None))
patch_data[(layer_to_patch, "mlp")] = patch_tensor

# Patch during the forward pass and capture the same location as the clean run.
patch_spec = {
    "patching": patch_data,
    "activations": {
        "positions": -1,
        "locations": [(n_layers - 1, "layer_out")],
    }
}

print(f"Patching layer {layer_to_patch} MLP with zeros...")
with UnifiedAccessAndPatching(model, inputs, patch_spec) as uap:
    patched_acts, patched_logits = uap.unified_access_and_patching()

print(f"Patched logits shape: {patched_logits.shape}")

# Logits should be different after patching
diff = torch.sum(torch.abs(clean_logits - patched_logits))
print(f"Logit difference (clean vs patched): {diff.item():.4f}")
assert diff.item() > 0, "Patched logits should differ from clean"
print("PASSED: Activation patching")

Patching layer 5 MLP with zeros...
Patched logits shape: torch.Size([2, 1, 151936])
Logit difference (clean vs patched): 618010.5000
PASSED: Activation patching


## Test: Gradient computation

In [12]:
# Define a simple metric function
def metric_fn(logits):
    """Sum the largest value found along the last axis of ``logits``.

    The maximum is taken over the last dimension, and the resulting maxima
    are summed over all remaining dimensions into a single scalar tensor.
    """
    top_values, _top_indices = logits.max(dim=-1)
    return top_values.sum()

# Locations whose gradients (with respect to the metric) are requested.
grad_locations = [(0, "layer_in"), (5, "attn")]
spec_dict_grad = {
    "activations": {
        "positions": slice(None),  # Need all positions for gradients
        "locations": list(grad_locations),
        "gradients": {
            "metric_fn": metric_fn,
            "compute_metric_at": (-1, "logits")
        }
    }
}

print("Computing gradients...")
with UnifiedAccessAndPatching(model, inputs, spec_dict_grad) as uap:
    acts_with_grad, logits = uap.unified_access_and_patching()

print("\nActivations with gradient info:")
for key, val in acts_with_grad.items():
    print(f"  {key}: shape={val.shape}, has_grad={val.grad is not None}")
    if val.grad is not None:
        print(f"    grad shape: {val.grad.shape}, grad norm: {val.grad.norm().item():.6f}")

# Extract gradients
grads = acts_with_grad.get_grads()
print("\nExtracted gradients:")
for key, val in grads.items():
    if val is not None:
        print(f"  {key}: {val.shape}")

# Fail the test when a requested gradient is missing, instead of only
# printing `has_grad=False` and still reporting success.
locations_with_grads = {key for key, val in grads.items() if val is not None}
for location in grad_locations:
    captured = acts_with_grad[location]
    assert captured.grad is not None, f"{location} should carry a gradient"
    assert captured.grad.shape == captured.shape, f"{location} gradient should match its activation shape"
    assert location in locations_with_grads, f"get_grads() should return a gradient for {location}"

print("PASSED: Gradient computation")

Computing gradients...

Activations with gradient info:
  (0, 'layer_in'): shape=torch.Size([2, 33, 1024]), has_grad=True
    grad shape: torch.Size([2, 33, 1024]), grad norm: 169.458817
  (5, 'attn'): shape=torch.Size([2, 33, 1024]), has_grad=True
    grad shape: torch.Size([2, 33, 1024]), grad norm: 5.725571

Extracted gradients:
  (0, 'layer_in'): torch.Size([2, 33, 1024])
  (5, 'attn'): torch.Size([2, 33, 1024])
PASSED: Gradient computation


## Test: stop_at_layer

In [13]:
# This is an advanced feature for early stopping
# Both requested locations sit below the stop layer, so they must still be
# captured even though the forward pass is cut short.
early_locations = [(0, "layer_in"), (3, "attn")]
spec_dict = {
    "stop_at_layer": 5,
    "activations": {
        "positions": -1,
        "locations": list(early_locations),
    }
}

print("Extracting with stop_at_layer=5...")
with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, logits = uap.unified_access_and_patching()

print(f"Extracted keys: {list(activations.keys())}")
for key, val in activations.items():
    print(f"  {key}: {val.shape}")

# Check the result instead of only printing it.
for location in early_locations:
    assert location in activations, f"Should have {location} when stopping at layer 5"
print("PASSED: stop_at_layer")

Extracting with stop_at_layer=5...
Extracted keys: [(0, 'layer_in'), (3, 'attn')]
  (0, 'layer_in'): torch.Size([2, 1, 1024])
  (3, 'attn'): torch.Size([2, 1, 1024])
PASSED: stop_at_layer


## Test: create_z_patch_dict()

In [14]:
# Get z activations from two different prompts
prompt1 = tokenizer(["Paris is the capital of France"], thinking=False)
prompt2 = tokenizer(["Berlin is the capital of Germany"], thinking=False)

# The same request is used for both prompts: layer-5 "z" at every position.
z_layer5 = (5, "z")
spec_dict = {
    "activations": {
        "positions": slice(None),
        "locations": [z_layer5],
    }
}

# Run each prompt separately; only the activations are kept.
with UnifiedAccessAndPatching(model, prompt1, spec_dict) as uap:
    acts1, _ = uap.unified_access_and_patching()

with UnifiedAccessAndPatching(model, prompt2, spec_dict) as uap:
    acts2, _ = uap.unified_access_and_patching()

print(f"Acts1 z shape (fused): {acts1[(5, 'z')].shape}")
print(f"Acts2 z shape (fused): {acts2[(5, 'z')].shape}")

Acts1 z shape (fused): torch.Size([1, 34, 2048])
Acts2 z shape (fused): torch.Size([1, 34, 2048])


In [None]:
# Split heads for patching
# NOTE(review): split_heads() is called for its side effect (the return value
# is unused), so it presumably mutates the dict in place; re-running this cell
# without re-running the previous one may therefore misbehave. Confirm.
acts1.split_heads()
acts2.split_heads()

print(f"Acts1 z shape (split): {acts1[(5, 'z')].shape}")
print(f"Acts2 z shape (split): {acts2[(5, 'z')].shape}")

# Replacement values: the per-head outputs at the last position of the
# *second* prompt. Taking them from acts2 makes the patch actually change
# acts1; slicing acts1 itself would patch it with its own values.
# Slice first, then clone, so that only the needed position is copied.
z_patch = acts2[(5, 'z')][:, -1, :, :].clone()

# Create patch dict for specific layer-head pairs
layer_head_pairs = [(5, 0), (5, 3)]  # Patch heads 0 and 3 at layer 5

patch_dict = create_z_patch_dict(
    original_acts=acts1,
    new_acts=z_patch,
    layer_head=layer_head_pairs,
    position=-1  # Patch last position only
)

print(f"\nPatch dict keys: {list(patch_dict.keys())}")
print(f"Patch dict fused_heads: {patch_dict.fused_heads}")
for key, val in patch_dict.items():
    print(f"  {key}: {val.shape}")

print("PASSED: create_z_patch_dict()")

## Test: Context manager cleanup

In [16]:
import gc

spec_dict = {
    "activations": {
        "positions": -1,
        "locations": [(0, "layer_in")],
    }
}

# Run in context manager
with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, logits = uap.unified_access_and_patching()
    print(f"Inside context: got {len(activations)} activations")

# After context exits, references should be cleaned up
gc.collect()

# The results handed back to the caller must survive the cleanup.
# NOTE(review): the cleanup itself (whatever the context manager releases on
# exit) is not observable from this cell; these checks only guard against it
# invalidating the returned activations.
assert len(activations) == 1, "Returned activations should survive context exit"
assert (0, "layer_in") in activations, "Should still have layer_in after cleanup"
print("Context manager exited and cleanup completed")
print("PASSED: Context manager cleanup")

Inside context: got 1 activations
Context manager exited and cleanup completed
PASSED: Context manager cleanup


## Test: Static method patch_fn

In [19]:
# Test the static patch function
# An all-ones tensor is patched with zeros at a single position, so the
# patched and untouched regions can be told apart by their sums.
position = -1
original = torch.ones((2, 5, 10))
new_value = torch.zeros((2, 1, 10))

result = UnifiedAccessAndPatching.patch_fn(original, new_value, [position])

print(f"Original last position sum: {original[:, -1, :].sum().item()}")
print(f"Result last position sum: {result[:, -1, :].sum().item()}")
print(f"Result other positions sum: {result[:, :-1, :].sum().item()}")

# Last position should be zeros, others should be ones
patched_total = result[:, -1, :].sum().item()
untouched_total = result[:, :-1, :].sum().item()
assert patched_total == 0, "Last position should be patched to zeros"
assert untouched_total == (2 * 4 * 10), "Other positions should be unchanged"
print("PASSED: Static patch_fn")

Original last position sum: 20.0
Result last position sum: 0.0
Result other positions sum: 80.0
PASSED: Static patch_fn


## Test: metric_fn_example

In [21]:
# Test the example metric function
# Draw the random input from a locally seeded generator so the printed values
# are the same on every run, without touching the global RNG state.
rng = torch.Generator().manual_seed(0)
acts = torch.randn(2, 5, 10, generator=rng)  # batch=2, pos=5, hidden=10
result = UnifiedAccessAndPatching.metric_fn_example(acts)

# Should sum the last position across batch and hidden dimensions
expected = acts[:, -1, :].sum()
print(f"Metric result: {result.item():.4f}")
print(f"Expected (acts[:, -1, :].sum()): {expected.item():.4f}")

assert torch.isclose(result, expected), "metric_fn_example should sum last position"
print("PASSED: metric_fn_example")

Metric result: 7.8077
Expected (acts[:, -1, :].sum()): 7.8077
PASSED: metric_fn_example


## Summary

In [22]:
# Final banner: reaching this cell means no assertion above raised.
banner = "="*50
print(banner)
print("All activations module tests PASSED!")
print(banner)

All activations module tests PASSED!
