# Test: direct_logit_attribution module

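This notebook tests the Direct Logit Attribution (DLA) functions in `mech_interp_toolkit.direct_logit_attribution`. Cells share state (`model`, `tokenizer`, `config`, `inputs`, `direction`, `dla_scores`), so run them top to bottom on a fresh kernel.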

In [None]:
import sys
import os
# Make the in-repo package importable without installing it. The path is
# resolved against the kernel's current working directory, so the notebook is
# presumably meant to be launched from a directory that is a sibling of `src/`
# — TODO confirm.
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device
# DLA functions under test.
from mech_interp_toolkit.direct_logit_attribution import (
    get_pre_rms_logit_diff_direction,
    run_componentwise_dla,
    run_headwise_dla_for_layer,
)

## Setup: Load model

In [None]:
# Load the checkpoint under test onto the default device; `model`, `tokenizer`
# and `config` are used by every later cell.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()
print(f"Loading model {model_name} on {device}...")

model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")

# Report the architecture sizes that later cells assert against.
for label, value in (
    ("Number of layers", config.num_hidden_layers),
    ("Hidden size", config.hidden_size),
    ("Number of attention heads", config.num_attention_heads),
):
    print(f"{label}: {value}")

## Test: get_pre_rms_logit_diff_direction()

In [None]:
# Test with valid single-token pairs: the direction must be a non-zero vector
# of length hidden_size.
token_pair = ["A", "B"]
print(f"Computing logit diff direction for tokens: {token_pair}")
direction = get_pre_rms_logit_diff_direction(token_pair, tokenizer, model)

print(f"Direction shape: {direction.shape}")
direction_norm = direction.norm().item()
print(f"Direction norm: {direction_norm:.4f}")
print(f"Direction device: {direction.device}")

expected_shape = (config.hidden_size,)
assert direction.shape == expected_shape, f"Expected shape ({config.hidden_size},), got {direction.shape}"
assert direction_norm > 0, "Direction should not be zero"
print("PASSED: Basic logit diff direction")

In [None]:
# Test with different token pairs. get_pre_rms_logit_diff_direction may reject
# a pair with ValueError (presumably when a word is not a single token — TODO
# confirm); such pairs are skipped. At least one pair must succeed, otherwise
# the test would pass without checking anything.
# The result is bound to `pair_direction`, not `direction`, so this cell does
# not overwrite the module-level direction computed above.
token_pairs_to_test = [
    ["Yes", "No"],
    ["true", "false"],
    ["1", "0"],
]

n_checked = 0
for pair in token_pairs_to_test:
    try:
        pair_direction = get_pre_rms_logit_diff_direction(pair, tokenizer, model)
    except ValueError as e:
        print(f"Tokens {pair}: Skipped - {e}")
        continue

    pair_norm = pair_direction.norm().item()
    print(f"Tokens {pair}: direction norm = {pair_norm:.4f}")
    assert pair_direction.shape == (config.hidden_size,), \
        f"Tokens {pair}: expected shape ({config.hidden_size},), got {pair_direction.shape}"
    assert pair_norm > 0, f"Tokens {pair}: direction should not be zero"
    n_checked += 1

assert n_checked > 0, "Every token pair was skipped; nothing was tested"
print("PASSED: Multiple token pairs")

In [None]:
# Test error handling: a token list that is not exactly a pair must raise ValueError.
# try/except/else with an explicit `raise AssertionError` is used instead of
# `assert False`, because asserts are stripped under `python -O` and a missing
# ValueError would then pass silently. The call result is not bound to
# `direction`, so the module-level direction is never clobbered.
for bad_tokens, label in (
    (["A"], "single token"),
    (["A", "B", "C"], "three tokens"),
):
    try:
        get_pre_rms_logit_diff_direction(bad_tokens, tokenizer, model)
    except ValueError as e:
        print(f"Correctly raised error for {label}: {e}")
    else:
        raise AssertionError(f"Should have raised ValueError for {bad_tokens}")

print("PASSED: Error handling for wrong token count")

## Test: run_componentwise_dla()

In [None]:
# Prepare a small batch of prompts shared by the DLA tests below.
prompts = [
    "The answer is definitely",
    "I think the result is",
]
inputs = tokenizer(prompts, thinking=False)
print(f"Input shape: {inputs['input_ids'].shape}")

# Recompute the ["A", "B"] direction so this section does not depend on what
# earlier cells left in `direction`, and place it on `device` (the device the
# model was loaded onto).
direction = get_pre_rms_logit_diff_direction(["A", "B"], tokenizer, model).to(device)

In [None]:
print("Running component-wise DLA...")
dla_scores = run_componentwise_dla(model, inputs, direction)

# Preview at most the first 10 components rather than dumping every score.
print(f"\nDLA scores computed for {len(dla_scores)} components:")
for shown, (key, score) in enumerate(dla_scores.items()):
    if shown == 10:
        break
    print(f"  {key}: {score}")
n_hidden = len(dla_scores) - 10
if n_hidden > 0:
    print(f"  ... and {n_hidden} more")

# Every layer must contribute both an attention and an MLP score.
n_layers = config.num_hidden_layers
expected_keys = [(i, kind) for kind in ("attn", "mlp") for i in range(n_layers)]
missing_keys = [key for key in expected_keys if key not in dla_scores]
assert not missing_keys, f"Missing DLA score for {missing_keys[0]}"

print(f"\nTotal components with DLA scores: {len(dla_scores)}")
print("PASSED: Component-wise DLA")

In [None]:
# Analyze DLA scores: reduce each layer's attention and MLP score tensors to a
# single value (mean over all entries of the tensor) for a per-layer table.
attn_scores = []
mlp_scores = []
for i in range(n_layers):
    attn_scores.append(dla_scores[(i, "attn")].mean().item())
    mlp_scores.append(dla_scores[(i, "mlp")].mean().item())

print("Mean DLA scores by layer:")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)
for i, (attn_mean, mlp_mean) in enumerate(zip(attn_scores, mlp_scores)):
    print(f"{i:<8} {attn_mean:>12.4f} {mlp_mean:>12.4f}")

print(f"\nTotal Attn contribution: {sum(attn_scores):.4f}")
print(f"Total MLP contribution: {sum(mlp_scores):.4f}")

## Test: run_headwise_dla_for_layer()

In [None]:
# Test head-wise DLA for one mid-network layer. The layer index is clamped to the
# last layer so the cell still runs on checkpoints with 10 or fewer layers.
layer = min(10, config.num_hidden_layers - 1)

print(f"Running head-wise DLA for layer {layer}...")
head_dla = run_headwise_dla_for_layer(model, inputs, direction, layer)

batch_size = inputs['input_ids'].shape[0]
print(f"Head DLA shape: {head_dla.shape}")
print(f"Expected: (batch_size, num_heads) = ({batch_size}, {config.num_attention_heads})")

assert head_dla.shape == (batch_size, config.num_attention_heads), \
    f"Shape mismatch: got {head_dla.shape}"
assert torch.isfinite(head_dla).all(), f"Head DLA for layer {layer} contains NaN/Inf"

# Average each head's score across the batch (dim 0) in one reduction.
print(f"\nHead DLA scores for layer {layer}:")
for head_idx, mean_score in enumerate(head_dla.mean(dim=0).tolist()):
    print(f"  Head {head_idx}: {mean_score:.6f}")

print("PASSED: Head-wise DLA for layer")

In [None]:
# Test head-wise DLA across a spread of layers (first, a few middle ones, last).
# The candidates are reduced to a sorted set of in-range indices, so shallower
# checkpoints neither go out of range nor test the same layer twice (e.g. with
# 11 layers, `n_layers - 1` would duplicate layer 10). The loop uses its own
# names, so it does not overwrite `layer` / `head_dla` from the previous cell.
candidate_layers = (0, 5, 10, 15, n_layers - 1)
layers_to_check = sorted({i for i in candidate_layers if 0 <= i < n_layers})
expected_head_shape = (inputs['input_ids'].shape[0], config.num_attention_heads)

print("\nHead-wise DLA summary across layers:")
print(f"{'Layer':<8} {'Max Head':>10} {'Max Score':>12} {'Min Score':>12}")
print("-" * 44)

for layer_idx in layers_to_check:
    layer_head_dla = run_headwise_dla_for_layer(model, inputs, direction, layer_idx)
    assert layer_head_dla.shape == expected_head_shape, \
        f"Layer {layer_idx}: shape mismatch, got {layer_head_dla.shape}"
    mean_scores = layer_head_dla.mean(dim=0)  # Average across batch
    max_head = mean_scores.argmax().item()
    max_score = mean_scores.max().item()
    min_score = mean_scores.min().item()
    print(f"{layer_idx:<8} {max_head:>10} {max_score:>12.6f} {min_score:>12.6f}")

print("PASSED: Head-wise DLA across layers")

## Test: DLA with different inputs

In [None]:
# Test with a single-prompt batch.
single_input = tokenizer(["What is 2 + 2?"], thinking=False)

dla_scores_single = run_componentwise_dla(model, single_input, direction)
print(f"Single input DLA: {len(dla_scores_single)} components")

# The set of components scored must not depend on batch size.
assert set(dla_scores_single) == set(dla_scores), \
    "Single-input run returned a different set of components than the batched run"

sample_score = dla_scores_single[(0, "attn")]
print(f"Sample score shape: {sample_score.shape}")
# NOTE(review): the exact per-component score shape is not pinned here because
# it is not established in this notebook; only finiteness is checked.
assert all(torch.isfinite(score).all() for score in dla_scores_single.values()), \
    "Single-input DLA scores contain NaN/Inf"
print("PASSED: DLA with single input")

In [None]:
# Test with custom direction (random)
random_direction = torch.randn(config.hidden_size, device=device)
random_direction = random_direction / random_direction.norm()  # Normalize

dla_scores_random = run_componentwise_dla(model, inputs, random_direction)

print(f"DLA with random direction: {len(dla_scores_random)} components")
print(f"Sample attn score (layer 0): {dla_scores_random[(0, 'attn')]}")
print("PASSED: DLA with custom direction")

## Test: Consistency check

In [None]:
# Run DLA twice with same inputs - should get same results.
dla_scores_1 = run_componentwise_dla(model, inputs, direction)
dla_scores_2 = run_componentwise_dla(model, inputs, direction)

# Compare the component sets first, so a missing key yields a clear message
# instead of a bare KeyError during the value comparison.
assert dla_scores_1.keys() == dla_scores_2.keys(), "DLA runs returned different components"

# Report every mismatching component, not just the first one.
mismatched_keys = [
    key for key in dla_scores_1
    if not torch.allclose(dla_scores_1[key], dla_scores_2[key], rtol=1e-4)
]
for key in mismatched_keys:
    print(f"Mismatch at {key}")

assert not mismatched_keys, "DLA results should be consistent"
print("PASSED: DLA consistency check")

## Summary

In [None]:
# Reaching this cell means every assertion above held.
banner = "=" * 50
print(banner)
print("All direct_logit_attribution module tests PASSED!")
print(banner)