# Test: activation_dict module

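This notebook tests the `FrozenError`, `FreezableDict`, `ArithmeticOperation`, and `ActivationDict` classes from `mech_interp_toolkit.activation_dict`. `FrozenError` and `ActivationDict` are exercised directly; `FreezableDict` and `ArithmeticOperation` are imported but never instantiated here, so they are covered only to the extent that `ActivationDict` relies on them.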

In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from transformers import AutoConfig
from mech_interp_toolkit.activation_dict import (
    FrozenError,
    FreezableDict,
    ArithmeticOperation,
    ActivationDict,
)

## Setup: Load model config

In [None]:
# Load the configuration object for the reference model. Later cells read
# `config.hidden_size`, `config.num_attention_heads` and
# `config.num_hidden_layers` from it.
model_name = "Qwen/Qwen3-0.6B"
config = AutoConfig.from_pretrained(model_name)

# Report the fields the tests depend on, one per output line.
print(
    f"Loaded config for {model_name}",
    f"Hidden size: {config.hidden_size}",
    f"Num attention heads: {config.num_attention_heads}",
    f"Num hidden layers: {config.num_hidden_layers}",
    sep="\n",
)

## Test: FrozenError exception

In [None]:
# FrozenError must be raisable, catchable by its own type, and must carry the
# message it was constructed with.
expected_message = "Test error message"
try:
    raise FrozenError(expected_message)
except FrozenError as e:
    print(f"Caught FrozenError: {e}")
    assert expected_message in str(e)
print("PASSED: FrozenError exception")

## Test: ActivationDict initialization

In [None]:
# Test basic initialization: the dimensions exposed by ActivationDict should
# mirror the model config, and the defaults should be fused heads with
# value_type "activation".
act_dict = ActivationDict(config, positions=slice(None))

print(f"num_heads: {act_dict.num_heads}")
print(f"num_layers: {act_dict.num_layers}")
print(f"head_dim: {act_dict.head_dim}")
print(f"model_dim: {act_dict.model_dim}")
print(f"fused_heads: {act_dict.fused_heads}")
print(f"value_type: {act_dict.value_type}")

assert act_dict.num_heads == config.num_attention_heads
assert act_dict.num_layers == config.num_hidden_layers
# NOTE(review): this pins head_dim to hidden_size // num_attention_heads. Some
# model configs declare an explicit `head_dim` that differs from that quotient
# (Qwen3 appears to be one of them) -- confirm which value ActivationDict is
# meant to expose.
assert act_dict.head_dim == config.hidden_size // config.num_attention_heads
assert act_dict.model_dim == config.hidden_size
# Truthiness check instead of `== True` (PEP 8).
assert act_dict.fused_heads, "fused_heads should default to a truthy value"
assert act_dict.value_type == "activation"
print("PASSED: ActivationDict initialization")

In [None]:
# Test initialization with different value_type
# Each case pairs a `positions` argument with the `value_type` passed alongside it.
grad_dict, score_dict = (
    ActivationDict(config, positions=positions, value_type=value_type)
    for positions, value_type in ((-1, "gradient"), ([0, 1, 2], "scores"))
)

# The constructor argument must be readable back from the instance.
assert grad_dict.value_type == "gradient"
assert score_dict.value_type == "scores"
print("PASSED: Different value_type initialization")

## Test: Freeze/Unfreeze functionality

In [None]:
# Freeze/unfreeze round trip: while frozen, every mutating operation tried
# below must raise FrozenError; after unfreeze(), writes must work again.
act_dict = ActivationDict(config, positions=slice(None))

# Add data before freezing
act_dict[(0, "attn")] = torch.randn(2, 10, config.hidden_size)
print(f"Added (0, 'attn'): {act_dict[(0, 'attn')].shape}")

# Freeze the dictionary
act_dict.freeze()
assert act_dict._frozen, "Should be frozen"
print("Dictionary frozen")

# Each attempt uses try/except/else with an explicit `raise AssertionError`
# rather than `assert False` inside the `try`: the failure signal then
# survives `python -O` and sits outside the guarded block.

# Try to modify while frozen (should raise FrozenError)
try:
    act_dict[(0, "mlp")] = torch.randn(2, 10, config.hidden_size)
except FrozenError:
    print("Correctly raised FrozenError on setitem")
else:
    raise AssertionError("Should have raised FrozenError")

# Try other modifying operations
try:
    del act_dict[(0, "attn")]
except FrozenError:
    print("Correctly raised FrozenError on delitem")
else:
    raise AssertionError("Should have raised FrozenError")

try:
    act_dict.clear()
except FrozenError:
    print("Correctly raised FrozenError on clear")
else:
    raise AssertionError("Should have raised FrozenError")

# A rejected mutation must not have partially applied.
assert list(act_dict.keys()) == [(0, "attn")], "Frozen dict contents should be unchanged"

# Unfreeze and modify
act_dict.unfreeze()
assert not act_dict._frozen, "Should be unfrozen"
act_dict[(0, "mlp")] = torch.randn(2, 10, config.hidden_size)
print(f"After unfreeze, added (0, 'mlp'): {act_dict[(0, 'mlp')].shape}")
print("PASSED: Freeze/Unfreeze functionality")

## Test: clone()

In [None]:
# clone() should return a copy with equal contents whose tensors are
# independent of the original's.
original = ActivationDict(config, positions=slice(None))
original[(0, "attn")] = torch.randn(2, 5, config.hidden_size)

cloned = original.clone()

# Verify clone has same data
assert (0, "attn") in cloned, "Cloned should have same keys"
assert torch.equal(original[(0, "attn")], cloned[(0, "attn")]), "Data should be equal"

# Verify clone is independent (deep copy).
# Rebinding a key in the clone would also pass for a shallow copy that shares
# tensor storage, so first mutate the clone's tensor *in place* and check the
# original's values are untouched.
original_snapshot = original[(0, "attn")].clone()
cloned[(0, "attn")].zero_()
assert torch.equal(original[(0, "attn")], original_snapshot), "In-place edit of clone should not affect original"

# Rebinding the key in the clone must not affect the original either.
cloned[(0, "attn")] = torch.zeros_like(cloned[(0, "attn")])
assert not torch.equal(original[(0, "attn")], cloned[(0, "attn")]), "Modifying clone should not affect original"

print("Original sum:", original[(0, "attn")].sum().item())
print("Cloned sum (after zeroing):", cloned[(0, "attn")].sum().item())
print("PASSED: clone()")

## Test: Arithmetic operations

In [None]:
# Create two ActivationDicts with same keys
# Constant-valued tensors keep the expected result of each operator easy to
# read off: act1 holds 3 (attn) / 5 (mlp), act2 holds 2 (attn) / 3 (mlp).
act1 = ActivationDict(config, positions=slice(None))
act2 = ActivationDict(config, positions=slice(None))

for target, attn_value, mlp_value in ((act1, 3, 5), (act2, 2, 3)):
    target[(0, "attn")] = torch.ones(2, 5, config.hidden_size) * attn_value
    target[(0, "mlp")] = torch.ones(2, 5, config.hidden_size) * mlp_value

In [None]:
# Test addition
result = act1 + act2

# Element-wise sums: attn is 3 + 2, mlp is 5 + 3.
for key, expected_value in (((0, "attn"), 5), ((0, "mlp"), 8)):
    assert torch.allclose(result[key], torch.ones(2, 5, config.hidden_size) * expected_value)

print(f"Addition: 3 + 2 = {result[(0, 'attn')][0, 0, 0].item()}")
print("PASSED: Addition")

In [None]:
# Test subtraction
result = act1 - act2

# Element-wise differences: attn is 3 - 2, mlp is 5 - 3.
for key, expected_value in (((0, "attn"), 1), ((0, "mlp"), 2)):
    assert torch.allclose(result[key], torch.ones(2, 5, config.hidden_size) * expected_value)

print(f"Subtraction: 3 - 2 = {result[(0, 'attn')][0, 0, 0].item()}")
print("PASSED: Subtraction")

In [None]:
# Test multiplication with ActivationDict
result = act1 * act2

# Element-wise products: attn is 3 * 2, mlp is 5 * 3.
for key, expected_value in (((0, "attn"), 6), ((0, "mlp"), 15)):
    assert torch.allclose(result[key], torch.ones(2, 5, config.hidden_size) * expected_value)

print(f"Multiplication (ActivationDict): 3 * 2 = {result[(0, 'attn')][0, 0, 0].item()}")
print("PASSED: Multiplication (ActivationDict)")

In [None]:
# Test multiplication with scalar
# Both operand orders must give the same attn tensor (3 * 2.0 everywhere).
expected_attn = torch.ones(2, 5, config.hidden_size) * 6

result = act1 * 2.0
assert torch.allclose(result[(0, "attn")], expected_attn)
print(f"Multiplication (scalar): 3 * 2.0 = {result[(0, 'attn')][0, 0, 0].item()}")

# Test rmul
result = 2.0 * act1
assert torch.allclose(result[(0, "attn")], expected_attn)
print(f"Reverse multiplication: 2.0 * 3 = {result[(0, 'attn')][0, 0, 0].item()}")
print("PASSED: Scalar multiplication")

In [None]:
# Test division
# Dividing by another ActivationDict and by a scalar should both give
# 3 / 2 = 1.5 for every attn element.
expected_attn = torch.ones(2, 5, config.hidden_size) * 1.5

result = act1 / act2
assert torch.allclose(result[(0, "attn")], expected_attn)
print(f"Division (ActivationDict): 3 / 2 = {result[(0, 'attn')][0, 0, 0].item()}")

result = act1 / 2.0
assert torch.allclose(result[(0, "attn")], expected_attn)
print(f"Division (scalar): 3 / 2.0 = {result[(0, 'attn')][0, 0, 0].item()}")
print("PASSED: Division")

## Test: split_heads() and merge_heads()

In [None]:
# split_heads() should reshape the fused "z" entry from
# (batch, pos, n_heads * d_head) to (batch, pos, n_heads, d_head) and clear
# the `fused_heads` flag. `act_dict`, `batch_size` and `seq_len` are reused
# by the merge_heads() cell that follows.
act_dict = ActivationDict(config, positions=slice(None))
batch_size, seq_len = 2, 10
# z activations with fused heads: (batch, pos, n_heads * d_head)
act_dict[(0, "z")] = torch.randn(batch_size, seq_len, config.hidden_size)
print(f"Original z shape (fused): {act_dict[(0, 'z')].shape}")
# Truthiness checks instead of `== True` / `== False` (PEP 8).
assert act_dict.fused_heads, "Heads should start fused"

# Split heads
act_dict.split_heads()
print(f"After split_heads z shape: {act_dict[(0, 'z')].shape}")
expected_shape = (batch_size, seq_len, act_dict.num_heads, act_dict.head_dim)
assert act_dict[(0, "z")].shape == expected_shape, f"Expected {expected_shape}, got {act_dict[(0, 'z')].shape}"
assert not act_dict.fused_heads, "split_heads() should clear fused_heads"
print("PASSED: split_heads()")

In [None]:
# Merge heads back
# Continues from the split_heads() cell: relies on its `act_dict`,
# `batch_size` and `seq_len`, so that cell must have been run first.
act_dict.merge_heads()
print(f"After merge_heads z shape: {act_dict[(0, 'z')].shape}")
expected_shape = (batch_size, seq_len, config.hidden_size)
assert act_dict[(0, "z")].shape == expected_shape, f"Expected {expected_shape}, got {act_dict[(0, 'z')].shape}"
# Truthiness check instead of `== True` (PEP 8).
assert act_dict.fused_heads, "merge_heads() should restore fused_heads"
print("PASSED: merge_heads()")

## Test: apply()

In [None]:
# apply() should call the given function on the stored tensors and return the
# results under the same keys; here torch.sum collapses the last dimension.
act_dict = ActivationDict(config, positions=slice(None))
for component, fill_value in (("attn", 3), ("mlp", 5)):
    act_dict[(0, component)] = torch.ones(2, 5, config.hidden_size) * fill_value

# Apply sum along last dimension
summed = act_dict.apply(torch.sum, dim=-1)
print(f"Original shape: {act_dict[(0, 'attn')].shape}")
print(f"After apply(sum, dim=-1) shape: {summed[(0, 'attn')].shape}")

# Every attn element is 3, so each reduced entry is 3 * hidden_size.
expected_sum = 3 * config.hidden_size
expected_attn = torch.ones(2, 5) * expected_sum
assert torch.allclose(summed[(0, "attn")], expected_attn)
print(f"Sum value: {summed[(0, 'attn')][0, 0].item()} (expected {expected_sum})")
print("PASSED: apply()")

## Test: cuda() and cpu()

In [None]:
# Device moves: cpu() always runs; cuda() only when a CUDA device is present.
# The assertions read the device from the same `act_dict` after each call.
act_dict = ActivationDict(config, positions=slice(None))
act_dict[(0, "attn")] = torch.randn(2, 5, config.hidden_size)

# Test cpu()
act_dict.cpu()
assert act_dict[(0, "attn")].device.type == "cpu", "Should be on CPU"
print(f"Device after cpu(): {act_dict[(0, 'attn')].device}")

# Test cuda() if available
if not torch.cuda.is_available():
    print("CUDA not available, skipping cuda() test")
else:
    act_dict.cuda()
    assert act_dict[(0, "attn")].device.type == "cuda", "Should be on CUDA"
    print(f"Device after cuda(): {act_dict[(0, 'attn')].device}")
    act_dict.cpu()  # Move back to CPU

print("PASSED: cuda() and cpu()")

## Test: zeros_like()

In [None]:
# zeros_like() should produce all-zero tensors with matching shapes, either
# for every stored key or only for the keys passed via `keys=`.
act_dict = ActivationDict(config, positions=slice(None))
for component in ("attn", "mlp"):
    act_dict[(0, component)] = torch.randn(2, 5, config.hidden_size)

# Create zeros for all keys
zeros = act_dict.zeros_like()
zero_attn = zeros[(0, "attn")]
assert torch.all(zero_attn == 0), "Should be all zeros"
assert zero_attn.shape == act_dict[(0, "attn")].shape, "Shape should match"
print(f"zeros_like() sum: {zeros[(0, 'attn')].sum().item()}")

# Create zeros for specific keys
zeros_partial = act_dict.zeros_like(keys=[(0, "attn")])
assert (0, "attn") in zeros_partial, "Should have requested key"
assert (0, "mlp") not in zeros_partial, "Should not have unrequested key"
print("PASSED: zeros_like()")

## Test: reorganize()

In [None]:
# reorganize() should return the entries sorted by layer and, within a layer,
# in the component order this test expects: layer_in, z, attn, mlp.
act_dict = ActivationDict(config, positions=slice(None))
# Add in random order
for insertion_key in ((1, "mlp"), (0, "z"), (0, "layer_in"), (0, "mlp"), (0, "attn")):
    act_dict[insertion_key] = torch.randn(2, 5, config.hidden_size)

print(f"Original order: {list(act_dict.keys())}")

reorganized = act_dict.reorganize()
print(f"Reorganized order: {list(reorganized.keys())}")

# Expected order: (0, layer_in), (0, z), (0, attn), (0, mlp), (1, mlp)
expected_order = (
    ((0, "layer_in"), "First should be (0, 'layer_in')"),
    ((0, "z"), "Second should be (0, 'z')"),
    ((0, "attn"), "Third should be (0, 'attn')"),
    ((0, "mlp"), "Fourth should be (0, 'mlp')"),
    ((1, "mlp"), "Fifth should be (1, 'mlp')"),
)
keys = list(reorganized.keys())
for index, (expected_key, failure_message) in enumerate(expected_order):
    assert keys[index] == expected_key, failure_message
print("PASSED: reorganize()")

## Test: extract_positions()

In [None]:
# extract_positions() should keep only the sequence positions the dict was
# constructed with, leaving the batch and hidden dimensions unchanged.
requested_positions = [1, 3]
act_dict = ActivationDict(config, positions=requested_positions)  # positions to extract
act_dict[(0, "attn")] = torch.randn(2, 10, config.hidden_size)  # Full sequence

extracted = act_dict.extract_positions()
print(f"Original shape: {act_dict[(0, 'attn')].shape}")
print(f"Extracted shape: {extracted[(0, 'attn')].shape}")

# Should extract positions [1, 3] -> shape (batch, 2, hidden_size)
extracted_attn = extracted[(0, "attn")]
assert extracted_attn.shape == (2, len(requested_positions), config.hidden_size)
print("PASSED: extract_positions()")

## Test: get_grads()

In [None]:
# get_grads() should return an ActivationDict of value_type "gradient" that
# holds the gradients of the stored tensors.
act_dict = ActivationDict(config, positions=slice(None))
tensor = torch.randn(2, 5, config.hidden_size, requires_grad=True)
act_dict[(0, "attn")] = tensor

# Compute some gradients
loss = tensor.sum()
loss.backward()

print(f"Gradient exists: {tensor.grad is not None}")
print(f"Gradient shape: {tensor.grad.shape}")

# Get gradients
grads = act_dict.get_grads()
assert grads.value_type == "gradient", "Should have gradient value_type"
assert grads[(0, "attn")] is not None, "Should have gradient"
# `is not None` alone would accept any placeholder; the returned entry must
# carry the values autograd accumulated on the stored tensor.
assert torch.equal(grads[(0, "attn")], tensor.grad), "Should match the stored tensor's .grad"
print(f"Retrieved gradient shape: {grads[(0, 'attn')].shape}")
print("PASSED: get_grads()")

## Summary

In [None]:
# Final banner; it is only reached if every assertion above passed when the
# notebook is run top to bottom.
banner = "=" * 50
print(banner)
print("All activation_dict module tests PASSED!")
print(banner)