In [1]:
# Optional setup for a fresh environment (e.g. a hosted runtime): uncomment the
# three lines below to clone the repository and install it in editable mode.
# !git clone https://github.com/SD-interp/mech_interp_toolkit.git
# %cd mech_interp_toolkit
# !pip install -e .

# Test: direct_logit_attribution module

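This notebook smoke-tests the Direct Logit Attribution (DLA) functions in `mech_interp_toolkit.direct_logit_attribution` — `get_pre_rms_logit_diff_direction`, `run_componentwise_dla`, and `run_headwise_dla_for_layer` — against the `Qwen/Qwen3-0.6B` model.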

In [2]:
import sys
import os
# Put the repository's `src` directory on the import path so the package can be
# imported without being installed. The path is resolved against the current
# working directory — assumes the notebook is launched from a directory that is
# a sibling of `src` (TODO confirm).
sys.path.append(os.path.abspath("../src"))

import torch
from mech_interp_toolkit.utils import load_model_tokenizer_config, get_default_device
from mech_interp_toolkit.direct_logit_attribution import (
    get_pre_rms_logit_diff_direction,
    run_componentwise_dla,
    run_headwise_dla_for_layer,
)

## Setup: Load model

In [3]:
# Model under test and the device it is loaded onto.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

print(f"Loading model {model_name} on {device}...")
model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")

# Echo the architecture figures that the later shape assertions compare against.
# Attributes are looked up one at a time, right before each line is printed.
reported_fields = (
    ("Number of layers", "num_hidden_layers"),
    ("Hidden size", "hidden_size"),
    ("Number of attention heads", "num_attention_heads"),
)
for label, attr_name in reported_fields:
    print(f"{label}: {getattr(config, attr_name)}")

Loading model Qwen/Qwen3-0.6B on cuda...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Model loaded successfully
Number of layers: 28
Hidden size: 1024
Number of attention heads: 16


## Test: get_pre_rms_logit_diff_direction()

In [4]:
# Smoke test: a valid two-token pair must yield a non-zero vector whose length
# equals the model's hidden size.
token_pair = ["A", "B"]
print(f"Computing logit diff direction for tokens: {token_pair}")

direction = get_pre_rms_logit_diff_direction(token_pair, tokenizer, model)

# Compute the norm once and reuse it for both the report and the assertion.
direction_norm = direction.norm().item()
expected_shape = (config.hidden_size,)

print(f"Direction shape: {direction.shape}")
print(f"Direction norm: {direction_norm:.4f}")
print(f"Direction device: {direction.device}")

assert direction.shape == expected_shape, f"Expected shape ({config.hidden_size},), got {direction.shape}"
assert direction_norm > 0, "Direction should not be zero"
print("PASSED: Basic logit diff direction")

Computing logit diff direction for tokens: ['A', 'B']
Direction shape: torch.Size([1024])
Direction norm: 3.4329
Direction device: cuda:0
PASSED: Basic logit diff direction


In [5]:
# Exercise several more token pairs. A pair that the function rejects with a
# ValueError is reported and skipped rather than failing the cell; at least one
# pair must still be evaluated, otherwise the "PASSED" line would be vacuous.
token_pairs_to_test = [
    ["Yes", "No"],
    ["true", "false"],
    ["1", "0"],
]

evaluated_pairs = 0
for pair in token_pairs_to_test:
    try:
        # Loop-local name: the shared `direction` variable used by other cells
        # must not end up holding whichever pair happened to run last.
        pair_direction = get_pre_rms_logit_diff_direction(pair, tokenizer, model)
    except ValueError as e:
        print(f"Tokens {pair}: Skipped - {e}")
        continue

    pair_norm = pair_direction.norm().item()
    assert pair_norm > 0, f"Direction for {pair} should not be zero"
    print(f"Tokens {pair}: direction norm = {pair_norm:.4f}")
    evaluated_pairs += 1

assert evaluated_pairs > 0, "Every token pair was skipped; nothing was tested"
print("PASSED: Multiple token pairs")

Tokens ['Yes', 'No']: direction norm = 3.8051
Tokens ['true', 'false']: direction norm = 2.9281
Tokens ['1', '0']: direction norm = 2.7252
PASSED: Multiple token pairs


In [6]:
# Negative test: the function must reject anything other than exactly two
# target tokens by raising ValueError.
invalid_token_lists = [
    ("single token", ["A"]),
    ("three tokens", ["A", "B", "C"]),
]

for label, bad_tokens in invalid_token_lists:
    try:
        # The result is deliberately discarded so the shared `direction`
        # variable can never be overwritten by this negative test.
        get_pre_rms_logit_diff_direction(bad_tokens, tokenizer, model)
    except ValueError as e:
        print(f"Correctly raised error for {label}: {e}")
    else:
        # Explicit raise instead of `assert False`: assert statements are
        # removed when Python runs with -O, which would let a missing error
        # pass silently.
        raise AssertionError("Should have raised ValueError")

print("PASSED: Error handling for wrong token count")

Correctly raised error for single token: Provide exactly two target tokens.
Correctly raised error for three tokens: Provide exactly two target tokens.
PASSED: Error handling for wrong token count


## Test: run_componentwise_dla()

In [7]:
# Prepare the shared batch of prompts used by the DLA tests below.
prompts = [
    "The answer is definitely",
    "I think the result is"
]
inputs = tokenizer(prompts, thinking=False)
print(f"Input shape: {inputs['input_ids'].shape}")

# Recompute the A-vs-B direction here so this cell does not depend on whatever
# an earlier test cell last stored in `direction`.
direction = get_pre_rms_logit_diff_direction(["A", "B"], tokenizer, model)
# Detach before use: the returned direction is attached to the autograd graph
# (presumably because it is derived from model parameters), which makes every
# DLA score computed from it carry a grad_fn and keep its graph alive. These
# tests only read values, so no gradient tracking is needed.
direction = direction.detach().to(device)

Input shape: torch.Size([2, 33])


In [8]:
print("Running component-wise DLA...")
dla_scores = run_componentwise_dla(model, inputs, direction)
n_components = len(dla_scores)

# Preview only the first few entries so the output stays readable.
preview_limit = 10
print(f"\nDLA scores computed for {n_components} components:")
for shown, (component_key, component_score) in enumerate(dla_scores.items()):
    if shown == preview_limit:
        break
    print(f"  {component_key}: {component_score}")
if n_components > preview_limit:
    print(f"  ... and {n_components - preview_limit} more")

# Every layer must report both an attention and an MLP score. Keys are checked
# in the order: all "attn" entries first, then all "mlp" entries.
n_layers = config.num_hidden_layers
expected_keys = [
    (layer_idx, component)
    for component in ("attn", "mlp")
    for layer_idx in range(n_layers)
]

for expected_key in expected_keys:
    assert expected_key in dla_scores, f"Missing DLA score for {expected_key}"

print(f"\nTotal components with DLA scores: {n_components}")
print("PASSED: Component-wise DLA")

Running component-wise DLA...

DLA scores computed for 56 components:
  (0, 'attn'): tensor([-7.2912e-05, -1.4762e-04], device='cuda:0', grad_fn=<DivBackward0>)
  (0, 'mlp'): tensor([0.0003, 0.0003], device='cuda:0', grad_fn=<DivBackward0>)
  (1, 'attn'): tensor([0.0013, 0.0014], device='cuda:0', grad_fn=<DivBackward0>)
  (1, 'mlp'): tensor([-0.0007, -0.0008], device='cuda:0', grad_fn=<DivBackward0>)
  (2, 'attn'): tensor([0.0003, 0.0003], device='cuda:0', grad_fn=<DivBackward0>)
  (2, 'mlp'): tensor([0.0019, 0.0019], device='cuda:0', grad_fn=<DivBackward0>)
  (3, 'attn'): tensor([-0.0006, -0.0004], device='cuda:0', grad_fn=<DivBackward0>)
  (3, 'mlp'): tensor([0.0006, 0.0003], device='cuda:0', grad_fn=<DivBackward0>)
  (4, 'attn'): tensor([-0.0007, -0.0007], device='cuda:0', grad_fn=<DivBackward0>)
  (4, 'mlp'): tensor([-0.0021, -0.0020], device='cuda:0', grad_fn=<DivBackward0>)
  ... and 46 more

Total components with DLA scores: 56
PASSED: Component-wise DLA


In [9]:
# Reduce each component's per-input scores to a single batch mean per layer.
attn_scores = []
mlp_scores = []
for layer_idx in range(n_layers):
    attn_scores.append(dla_scores[(layer_idx, "attn")].mean().item())
    mlp_scores.append(dla_scores[(layer_idx, "mlp")].mean().item())

# Tabulate the two columns side by side, one row per layer.
print("Mean DLA scores by layer:")
print(f"{'Layer':<8} {'Attn':>12} {'MLP':>12}")
print("-" * 32)
for layer_idx, (attn_mean, mlp_mean) in enumerate(zip(attn_scores, mlp_scores)):
    print(f"{layer_idx:<8} {attn_mean:>12.4f} {mlp_mean:>12.4f}")

# Column totals: the sum over layers of the batch-mean scores.
attn_total = sum(attn_scores)
mlp_total = sum(mlp_scores)
print(f"\nTotal Attn contribution: {attn_total:.4f}")
print(f"Total MLP contribution: {mlp_total:.4f}")

Mean DLA scores by layer:
Layer            Attn          MLP
--------------------------------
0             -0.0001       0.0003
1              0.0013      -0.0008
2              0.0003       0.0019
3             -0.0005       0.0004
4             -0.0007      -0.0021
5             -0.0009      -0.0013
6              0.0011       0.0027
7             -0.0003      -0.0026
8             -0.0002      -0.0025
9             -0.0003       0.0005
10             0.0003       0.0044
11             0.0009      -0.0025
12            -0.0012       0.0019
13            -0.0040      -0.0014
14             0.0007       0.0018
15            -0.0002       0.0074
16             0.0016       0.0016
17             0.0114      -0.0085
18             0.0027      -0.0029
19             0.0012      -0.0009
20             0.0118       0.0240
21             0.0178       0.0122
22             0.0052       0.0192
23             0.0079      -0.0152
24             0.0095      -0.0022
25             0.0486      -0.0

## Test: run_headwise_dla_for_layer()

In [10]:
# Head-wise attribution for one mid-depth layer: expect one score per
# (input, attention head).
layer = 10
print(f"Running head-wise DLA for layer {layer}...")
head_dla = run_headwise_dla_for_layer(model, inputs, direction, layer)

batch_size = inputs['input_ids'].shape[0]
n_heads = config.num_attention_heads

print(f"Head DLA shape: {head_dla.shape}")
print(f"Expected: (batch_size, num_heads) = ({batch_size}, {n_heads})")

assert head_dla.shape == (batch_size, n_heads), \
    f"Shape mismatch: got {head_dla.shape}"

# Report each head's score averaged over the batch.
print(f"\nHead DLA scores for layer {layer}:")
for head_idx in range(n_heads):
    per_input_scores = head_dla[:, head_idx]
    mean_score = per_input_scores.mean().item()
    print(f"  Head {head_idx}: {mean_score:.6f}")

print("PASSED: Head-wise DLA for layer")

Running head-wise DLA for layer 10...
Head DLA shape: torch.Size([2, 16])
Expected: (batch_size, num_heads) = (2, 16)

Head DLA scores for layer 10:
  Head 0: -0.000624
  Head 1: 0.000400
  Head 2: 0.001769
  Head 3: -0.000134
  Head 4: -0.000593
  Head 5: 0.000091
  Head 6: -0.001611
  Head 7: 0.000448
  Head 8: 0.000344
  Head 9: 0.000171
  Head 10: -0.000002
  Head 11: 0.000705
  Head 12: -0.000004
  Head 13: -0.000231
  Head 14: -0.000295
  Head 15: -0.000162
PASSED: Head-wise DLA for layer




In [11]:
# Test head-wise DLA for multiple layers
print("\nHead-wise DLA summary across layers:")
print(f"{'Layer':<8} {'Max Head':>10} {'Max Score':>12} {'Min Score':>12}")
print("-" * 44)

# Probe a few fixed depths plus the final layer. Building a set removes
# duplicates (the final layer may itself be one of the fixed depths) and the
# range filter drops indices this model does not have.
layers_to_probe = sorted(
    {idx for idx in (0, 5, 10, 15, n_layers - 1) if 0 <= idx < n_layers}
)
assert layers_to_probe, "Model has no layers to probe"

for layer in layers_to_probe:
    head_dla = run_headwise_dla_for_layer(model, inputs, direction, layer)
    mean_scores = head_dla.mean(dim=0)  # Average across batch

    # One finite score per attention head is the minimum this test should
    # guarantee before reporting success.
    assert mean_scores.shape == (config.num_attention_heads,), \
        f"Layer {layer}: expected {config.num_attention_heads} head scores, got shape {mean_scores.shape}"
    assert torch.isfinite(mean_scores).all(), f"Layer {layer}: non-finite head score"

    max_head = mean_scores.argmax().item()
    max_score = mean_scores.max().item()
    min_score = mean_scores.min().item()
    print(f"{layer:<8} {max_head:>10} {max_score:>12.6f} {min_score:>12.6f}")

print("PASSED: Head-wise DLA across layers")


Head-wise DLA summary across layers:
Layer      Max Head    Max Score    Min Score
--------------------------------------------
0                 9     0.000698    -0.000397
5                13     0.000222    -0.000609
10                2     0.001769    -0.001611
15                3     0.001319    -0.001703
27               15     0.020299    -0.002251
PASSED: Head-wise DLA across layers


## Test: DLA with different inputs

In [12]:
# Test with single input: a batch of one must still produce a score for every
# component, each holding exactly one value.
single_input = tokenizer(["What is 2 + 2?"], thinking=False)

dla_scores_single = run_componentwise_dla(model, single_input, direction)
print(f"Single input DLA: {len(dla_scores_single)} components")

# Every layer must report both its attention and its MLP component.
for layer_idx in range(config.num_hidden_layers):
    for component in ("attn", "mlp"):
        assert (layer_idx, component) in dla_scores_single, \
            f"Missing DLA score for {(layer_idx, component)}"

# Check score shapes for single batch
sample_score = dla_scores_single[(0, "attn")]
print(f"Sample score shape: {sample_score.shape}")
assert sample_score.shape == (1,), \
    f"Expected one score for the single input, got shape {sample_score.shape}"
print("PASSED: DLA with single input")

Single input DLA: 56 components
Sample score shape: torch.Size([1])
PASSED: DLA with single input


In [13]:
# Test with custom direction (random)
# Draw from a dedicated, seeded CPU generator so the same direction is produced
# on every run regardless of the target device, without touching the global
# RNG state. The vector is moved to `device` afterwards.
seeded_rng = torch.Generator().manual_seed(0)
random_direction = torch.randn(config.hidden_size, generator=seeded_rng).to(device)
random_direction = random_direction / random_direction.norm()  # Normalize

dla_scores_random = run_componentwise_dla(model, inputs, random_direction)
sample_attn_score = dla_scores_random[(0, 'attn')]

print(f"DLA with random direction: {len(dla_scores_random)} components")
print(f"Sample attn score (layer 0): {sample_attn_score}")

# An arbitrary direction must still give one finite score per input.
expected_batch = inputs['input_ids'].shape[0]
assert sample_attn_score.shape == (expected_batch,), \
    f"Expected {expected_batch} scores, got shape {sample_attn_score.shape}"
assert torch.isfinite(sample_attn_score).all(), "DLA score should be finite"
print("PASSED: DLA with custom direction")

DLA with random direction: 56 components
Sample attn score (layer 0): tensor([-0.0003, -0.0003], device='cuda:0')
PASSED: DLA with custom direction


## Test: Consistency check

In [14]:
# Determinism check: two runs on identical inputs must agree component by
# component (relative tolerance 1e-4).
dla_scores_1 = run_componentwise_dla(model, inputs, direction)
dla_scores_2 = run_componentwise_dla(model, inputs, direction)

# Lazily scan for the first disagreeing component; the scan stops as soon as
# one is found, so at most one mismatch is reported.
disagreeing_keys = (
    component_key
    for component_key, first_score in dla_scores_1.items()
    if not torch.allclose(first_score, dla_scores_2[component_key], rtol=1e-4)
)
no_mismatch = object()
first_mismatch = next(disagreeing_keys, no_mismatch)

all_close = first_mismatch is no_mismatch
if not all_close:
    print(f"Mismatch at {first_mismatch}")

assert all_close, "DLA results should be consistent"
print("PASSED: DLA consistency check")

PASSED: DLA consistency check


## Summary

In [15]:
# Final banner: the message framed above and below by a 50-character rule.
banner_rule = "=" * 50
print(
    banner_rule,
    "All direct_logit_attribution module tests PASSED!",
    banner_rule,
    sep="\n",
)

All direct_logit_attribution module tests PASSED!
