# Test: utils module

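This notebook tests all utility functions in `mech_interp_toolkit.utils`. Run the cells in order from top to bottom (Restart Kernel → Run All): later sections reuse the `model`, `tokenizer`, `config`, and `n_layers` variables created by earlier ones.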

In [None]:
import sys
import os
# Add the sibling "../src" directory to the import path — presumably so that
# `mech_interp_toolkit` can be imported from a source checkout (TODO confirm).
# NOTE: os.path.abspath resolves the relative path against the current working
# directory, so this only works when the kernel is started from the
# notebook's own directory.
sys.path.append(os.path.abspath("../src"))

import torch
import numpy as np
import random

from mech_interp_toolkit.utils import (
    set_global_seed,
    load_model_tokenizer_config,
    get_position_ids,
    input_dict_to_tuple,
    get_logit_difference,
    regularize_position,
    get_num_layers,
    get_all_layer_components,
    get_default_device,
    build_dataloader,
)

## Test: get_default_device()

In [None]:
# Ask the toolkit for its default device and check that the answer is one of
# the two names this test accepts.
device = get_default_device()
print(f"Default device: {device}")

assert device in ("cuda", "cpu"), "Device must be 'cuda' or 'cpu'"

print("PASSED: get_default_device()")

## Test: set_global_seed()

In [None]:
# Reproducibility check: seed, draw one sample from each RNG (stdlib random,
# NumPy, torch), then re-seed with the same value and draw again. Both rounds
# must produce exactly the same numbers.
set_global_seed(42)
r1, n1, t1 = random.random(), np.random.rand(), torch.rand(1).item()

set_global_seed(42)
r2, n2, t2 = random.random(), np.random.rand(), torch.rand(1).item()

draws = (
    ("Random", r1, r2),
    ("Numpy", n1, n2),
    ("Torch", t1, t2),
)

# Verify every generator first, and only report the values once all agree.
for label, first, second in draws:
    assert first == second, f"{label} seed not reproducible"

for label, first, second in draws:
    print(f"{label}: {first} == {second}")

print("PASSED: set_global_seed()")

## Test: load_model_tokenizer_config()

In [None]:
# Load the model once; `model`, `tokenizer` and `config` are reused by the
# later test cells in this notebook.
model_name = "Qwen/Qwen3-0.6B"
device = get_default_device()

print(f"Loading model {model_name} on {device}...")
model, tokenizer, config = load_model_tokenizer_config(model_name, device=device)
print("Model loaded successfully")

# Report a few config fields; each attribute is read right before it is
# printed, so a missing field fails at that specific line of output.
print(f"Config type: {type(config).__name__}")
for label, attr_name in (
    ("Number of layers", "num_hidden_layers"),
    ("Hidden size", "hidden_size"),
    ("Number of attention heads", "num_attention_heads"),
):
    print(f"{label}: {getattr(config, attr_name)}")

print("PASSED: load_model_tokenizer_config()")

## Test: get_position_ids()

In [None]:
# Left-padded attention mask: the rows carry 2, 1 and 0 padding tokens.
attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ]
)

# Real tokens are expected to be numbered 0, 1, 2, ... from the first
# non-padded slot; padded slots are expected to hold the value 1.
expected = torch.tensor(
    [
        [1, 1, 0, 1, 2],
        [1, 0, 1, 2, 3],
        [0, 1, 2, 3, 4],
    ]
)

position_ids = get_position_ids(attention_mask)
print(f"Attention mask:\n{attention_mask}")
print(f"Position IDs:\n{position_ids}")

assert torch.equal(position_ids, expected), (
    f"Position IDs mismatch. Got:\n{position_ids}\nExpected:\n{expected}"
)
print("PASSED: get_position_ids()")

## Test: input_dict_to_tuple()

In [None]:
# Minimal tokenizer-style input: one unpadded sequence of four tokens.
input_dict = dict(
    input_ids=torch.tensor([[1, 2, 3, 4]]),
    attention_mask=torch.tensor([[1, 1, 1, 1]]),
)

input_ids, attention_mask, position_ids = input_dict_to_tuple(input_dict, device="cpu")

named_tensors = (
    ("Input IDs", input_ids),
    ("Attention mask", attention_mask),
    ("Position IDs", position_ids),
)

for label, tensor in named_tensors:
    print(f"{label} shape: {tensor.shape}")
print(f"Position IDs: {position_ids}")

# All three returned tensors must keep the (batch=1, seq=4) shape of the input.
for label, tensor in named_tensors:
    assert tensor.shape == (1, 4), f"{label} shape mismatch"

print("PASSED: input_dict_to_tuple()")

## Test: regularize_position()

In [None]:
# Table of cases: (argument as shown in the output, argument, expected result,
# expected result as shown in the failure message).
# Covers int, None, slice and list inputs.
cases = (
    ("5", 5, [5], "[5]"),
    ("None", None, slice(None), "slice(None)"),
    ("slice(1, 5)", slice(1, 5), slice(1, 5), "slice(1, 5)"),
    ("[1, 2, 3]", [1, 2, 3], [1, 2, 3], "[1, 2, 3]"),
)

for arg_text, arg, expected_result, expected_text in cases:
    result = regularize_position(arg)
    assert result == expected_result, f"Expected {expected_text}, got {result}"
    print(f"regularize_position({arg_text}) = {result}")

print("PASSED: regularize_position()")

## Test: get_num_layers()

In [None]:
# Uses `model` and `config` from the load_model_tokenizer_config() cell.
# `n_layers` is reused by the get_all_layer_components() test below.
n_layers = get_num_layers(model)
print(f"Number of layers: {n_layers}")

# The helper must agree with the config and return a plain int.
layers_in_config = config.num_hidden_layers
assert n_layers == layers_in_config, "Layer count mismatch"
assert isinstance(n_layers, int), "Layer count should be an integer"

print("PASSED: get_num_layers()")

## Test: get_all_layer_components()

In [None]:
# Uses `model` and `n_layers` from the earlier cells.
layer_components = get_all_layer_components(model)
actual_count = len(layer_components)
print(f"Number of layer components: {actual_count}")
print(f"First 6 components: {layer_components[:6]}")

# Two components (attn, mlp) are expected for every layer.
expected_count = n_layers * 2
assert actual_count == expected_count, f"Expected {expected_count} components, got {actual_count}"

# Ordering is expected to be (0, attn), (0, mlp), (1, attn), (1, mlp), ...
# Only the first two entries are checked here.
for index, ordinal, component in (
    (0, "First", (0, "attn")),
    (1, "Second", (0, "mlp")),
):
    assert layer_components[index] == component, f"{ordinal} component should be {component}"

print("PASSED: get_all_layer_components()")

## Test: get_logit_difference()

In [None]:
# Produce logits for a single prompt using the `model`, `tokenizer` and
# `device` set up in the earlier cells.
prompts = ["The answer is"]
inputs = tokenizer(prompts, thinking=False)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model.model(
        inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
    )
    # Keep only the last sequence position
    # (assumes logits are laid out as (batch, seq, vocab) — TODO confirm).
    logits = outputs.logits[:, -1, :]

print(f"Logits shape: {logits.shape}")

# Compare the logits of two candidate tokens; one prompt should give one value.
tokens = ["A", "B"]
logit_diff = get_logit_difference(logits, tokenizer, tokens)
print(f"Logit difference (A - B): {logit_diff.item():.4f}")

assert logit_diff.shape == (1,), f"Expected shape (1,), got {logit_diff.shape}"
print("PASSED: get_logit_difference()")

## Test: build_dataloader()

In [None]:
# Case 1: a 2-D tensor used as the dataset. The first batch should hold
# `batch_size` rows with the feature dimension unchanged.
tensor_data = torch.randn(100, 10)
dataloader = build_dataloader(tensor_data, batch_size=16)
batch = next(iter(dataloader))

print(f"Batch shape: {batch.shape}")
assert batch.shape == (16, 10), f"Expected shape (16, 10), got {batch.shape}"

# Case 2: a plain Python list used as the dataset. The first batch should
# contain `batch_size` items.
list_data = list(range(50))
dataloader = build_dataloader(list_data, batch_size=10)
batch = next(iter(dataloader))

print(f"List batch: {batch}")
assert len(batch) == 10, f"Expected batch size 10, got {len(batch)}"

print("PASSED: build_dataloader()")

## Summary

In [None]:
# Final banner. It is only reached under Run All if every assertion above
# passed, because a failed assert stops execution.
print("=" * 50, "All utils module tests PASSED!", "=" * 50, sep="\n")