# Test: misc module

This notebook tests the `get_attention_pattern` function from `mech_interp_toolkit.misc`.

In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device
from mech_interp_toolkit.misc import get_attention_pattern

## Setup: Load model with eager attention

In [None]:
# Checkpoint under test and the device picked by the toolkit helper.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

# The eager attention implementation is requested because it is required
# for attention pattern extraction.
print(f"Loading model {model_name} with eager attention on {device}...")
model, tokenizer, config = load_model_tokenizer_config(
    model_name, attn_type="eager", device=device
)
print("Model loaded successfully")

# Report the loaded configuration, one field per line.
print(
    f"Number of layers: {config.num_hidden_layers}",
    f"Number of attention heads: {config.num_attention_heads}",
    f"Attention implementation: {config._attn_implementation}",
    sep="\n",
)

In [None]:
# Tokenize a single short sentence; `inputs` and `seq_len` are reused by
# the test cells that follow.
prompts = ["The quick brown fox jumps over the lazy dog."]
inputs = tokenizer(prompts, thinking=False)

# Axis 1 of the input IDs is taken as the sequence length.
seq_len = inputs["input_ids"].shape[1]
print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Sequence length: {seq_len}")

## Test: Basic attention pattern extraction

In [None]:
# Extract attention patterns for specific layers and heads.
# head_indices[i] lists the heads requested for layers[i].
layers = [0, 5]
head_indices = [[0, 1], [2, 3]]

print(f"Extracting attention patterns for layers {layers}...")
attn_patterns = get_attention_pattern(
    model=model,
    inputs=inputs,
    layers=layers,
    head_indices=head_indices,
    query_position=-1,  # Last position
)

print(f"\nExtracted patterns for {len(attn_patterns)} layers")
for layer, pattern in attn_patterns.items():
    print(f"  Layer {layer}: shape = {pattern.shape}")

# Check that every requested layer is present and has the expected shape.
# Expected: (batch, selected_heads, seq_len) for query_position=-1
for layer, heads in zip(layers, head_indices):
    assert layer in attn_patterns, f"Should have layer {layer}"
    pattern_shape = tuple(attn_patterns[layer].shape)
    assert pattern_shape[0] == len(prompts), (
        f"Layer {layer}: expected batch size {len(prompts)}, got shape {pattern_shape}"
    )
    assert pattern_shape[1] == len(heads), (
        f"Layer {layer}: expected {len(heads)} heads, got shape {pattern_shape}"
    )
    assert pattern_shape[-1] == seq_len, (
        f"Layer {layer}: expected {seq_len} key positions, got shape {pattern_shape}"
    )

print("PASSED: Basic attention pattern extraction")

In [None]:
# Analyze the layer-0 attention pattern: the weights each selected head
# assigns over the key axis should sum to approximately 1.0.
pattern_0 = attn_patterns[0]
print(f"Layer 0 attention pattern shape: {pattern_0.shape}")
print("Pattern sum per head (should be close to 1.0):")

for head_idx in range(pattern_0.shape[1]):
    head_sum = pattern_0[0, head_idx, :].sum().item()
    print(f"  Head {head_indices[0][head_idx]}: sum = {head_sum:.4f}")
    # Check every selected head, not only the first one. The tolerance is
    # kept loose as in the original check.
    assert abs(head_sum - 1.0) < 0.1, "Attention should sum to ~1.0"

print("PASSED: Attention pattern sums to ~1.0")

## Test: Single layer, single head

In [None]:
# Request exactly one head (head 5) from one layer (layer 10).
attn_single = get_attention_pattern(
    model=model, inputs=inputs, layers=[10], head_indices=[[5]], query_position=-1
)

print(f"Single head pattern shape: {attn_single[10].shape}")

# Axis 1 is the head axis, so it must have length one here.
assert attn_single[10].shape[1] == 1, "Should have single head"
print("PASSED: Single head extraction")

## Test: All heads of a single layer

In [None]:
# Select every attention head of layer 0, using the head count from the
# model config.
n_heads = config.num_attention_heads
all_heads = [head for head in range(n_heads)]

print(f"Extracting all {n_heads} heads for layer 0...")
attn_all = get_attention_pattern(
    model=model, inputs=inputs, layers=[0], head_indices=[all_heads], query_position=-1
)

print(f"All heads pattern shape: {attn_all[0].shape}")

# The head axis (axis 1) must cover every head of the layer.
assert attn_all[0].shape[1] == n_heads, f"Should have {n_heads} heads"
print("PASSED: All heads extraction")

## Test: Different query positions

In [None]:
# Test with different query positions: last token, first token, and the
# middle of the sequence.
# Cell-specific names are used so the `layers` / `head_indices` globals
# read by earlier cells are not overwritten.
qp_layers = [0]
qp_heads = [[0, 1, 2]]

print("Testing different query positions:")
for query_pos in [-1, 0, seq_len // 2]:
    attn = get_attention_pattern(
        model=model,
        inputs=inputs,
        layers=qp_layers,
        head_indices=qp_heads,
        query_position=query_pos,
    )
    # Verify the result before reporting success, rather than only printing.
    assert 0 in attn, f"Should have layer 0 for query_position={query_pos}"
    assert attn[0].shape[1] == len(qp_heads[0]), (
        f"Should have {len(qp_heads[0])} heads for query_position={query_pos}"
    )
    print(f"  query_position={query_pos}: shape = {attn[0].shape}")

print("PASSED: Different query positions")

## Test: Error handling - mismatched layers and heads

In [None]:
# Test error when layers and head_indices don't match: three layers are
# paired with only two head lists, which must raise ValueError.
try:
    get_attention_pattern(
        model=model,
        inputs=inputs,
        layers=[0, 1, 2],  # 3 layers
        head_indices=[[0], [1]],  # Only 2 head lists
        query_position=-1,
    )
except ValueError as e:
    print(f"Correctly raised error: {e}")
else:
    # Raise explicitly instead of `assert False`, which is stripped when
    # Python runs with -O and would let this test pass silently.
    raise AssertionError("Should have raised ValueError")

print("PASSED: Mismatch error handling")

## Test: Attention pattern visualization data

In [None]:
# Get attention data that could be used for visualization: head 0 of the
# first layer, layer 10, and the last layer.
# Cell-specific names are used so the `layers` / `head_indices` globals
# read by earlier cells are not overwritten.
viz_layers = [0, 10, config.num_hidden_layers - 1]
viz_heads = [[0], [0], [0]]  # Head 0 for each layer

print("Extracting attention patterns across layers...")
attn_viz = get_attention_pattern(
    model=model,
    inputs=inputs,
    layers=viz_layers,
    head_indices=viz_heads,
    query_position=-1,
)

print("\nAttention weights from last token (query) to all tokens (keys):")
for layer in viz_layers:
    # First batch element, first selected head.
    # detach() and float() make the NumPy conversion work even if the
    # tensor requires grad or has a dtype NumPy lacks (e.g. bfloat16).
    pattern = attn_viz[layer][0, 0, :].detach().float().cpu().numpy()
    top_5_idx = pattern.argsort()[-5:][::-1]  # Top 5 attended positions, highest first
    print(f"\nLayer {layer}:")
    print(f"  Pattern shape: {attn_viz[layer].shape}")
    print(f"  Top 5 positions: {top_5_idx}")
    print(f"  Top 5 weights: {pattern[top_5_idx]}")

print("\nPASSED: Attention pattern analysis")

## Test: With batch of inputs

In [None]:
# Run the extraction on a batch of three prompts.
batch_prompts = [
    "Hello world!",
    "The cat sat on the mat.",
    "Machine learning is fascinating.",
]
batch_inputs = tokenizer(batch_prompts, thinking=False)
print(f"Batch input shape: {batch_inputs['input_ids'].shape}")

attn_batch = get_attention_pattern(
    model=model, inputs=batch_inputs, layers=[5], head_indices=[[0, 1]], query_position=-1
)
print(f"Batch attention pattern shape: {attn_batch[5].shape}")

# Expected layout is (batch_size, num_heads, seq_len); the leading axis
# must match the number of prompts supplied.
assert attn_batch[5].shape[0] == len(batch_prompts), "Should have batch size 3"
print("PASSED: Batch attention extraction")

## Test: Warning for non-eager attention (informational)

In [None]:
# Load model with non-eager attention to see warning.
# This cell is informational: it reports whatever happens (warning, error,
# or neither) and does not fail the notebook.
import warnings

print("Loading model with SDPA attention...")
model_sdpa, tok_sdpa, cfg_sdpa = load_model_tokenizer_config(
    model_name,
    device=device,
    attn_type="sdpa"
)
print(f"Attention implementation: {cfg_sdpa._attn_implementation}")

# This should emit a warning
print("\nAttempting to get attention patterns (may emit warning)...")
with warnings.catch_warnings(record=True) as w:
    # Record every warning, including ones already shown once elsewhere.
    warnings.simplefilter("always")
    try:
        get_attention_pattern(
            model=model_sdpa,
            inputs=inputs,
            layers=[0],
            head_indices=[[0]],
            query_position=-1,
        )
    except Exception as e:
        # Deliberate best-effort: a failure here is reported, not raised.
        print(f"Error (expected for non-eager): {e}")
    else:
        if not w:
            # Make the "nothing happened" outcome visible instead of silent.
            print("No warning emitted")
    # Show every captured warning, including any raised before an error.
    for caught in w:
        print(f"Warning emitted: {caught.message}")

print("PASSED: Non-eager attention handling")

## Summary

In [None]:
# Final banner: a rule of 50 "=" characters above and below the summary.
print("=" * 50, "All misc module tests PASSED!", "=" * 50, sep="\n")