# Test: linear_probes module

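This notebook tests the `LinearProbe` class from `mech_interp_toolkit.linear_probes`. It covers:

- constructor defaults and argument validation;
- fitting classification and regression probes on synthetic 2D and 3D activations;
- prediction, with and without targets;
- error handling;
- training on real activations from `Qwen/Qwen3-0.6B`.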

In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
import numpy as np
from transformers import AutoConfig
from mech_interp_toolkit.linear_probes import LinearProbe
from mech_interp_toolkit.activation_dict import ActivationDict
from mech_interp_toolkit.utils import set_global_seed

## Setup

In [None]:
# Seed the global RNGs (via the toolkit helper) before any random data is drawn.
set_global_seed(42)

# Only the model *config* is needed for the synthetic tests: hidden_size sets
# the width of the fake activations and the config is passed to ActivationDict.
model_name = "Qwen/Qwen3-0.6B"
config = AutoConfig.from_pretrained(model_name)
print("Loaded config for " + model_name)
print(f"Hidden size: {config.hidden_size}")

## Test: LinearProbe initialization

In [None]:
# Test classification initialization with default arguments; the defaults are
# expected to be broadcast_target=True and test_split=0.2.
probe_cls = LinearProbe(target_type="classification")
print(f"Classification probe model: {type(probe_cls.linear_model).__name__}")
assert probe_cls.target_type == "classification"
# Compare to the True singleton with `is` (PEP 8) instead of `== True`.
assert probe_cls.broadcast_target is True, "broadcast_target should default to True"
assert probe_cls.test_split == 0.2
print("PASSED: Classification initialization")

In [None]:
# Test regression initialization
probe_reg = LinearProbe(target_type="regression", test_split=0.3)
print(f"Regression probe model: {type(probe_reg.linear_model).__name__}")
assert probe_reg.target_type == "regression"
assert probe_reg.test_split == 0.3
print("PASSED: Regression initialization")

In [None]:
# Test invalid target_type
try:
    probe = LinearProbe(target_type="invalid")
    assert False, "Should have raised ValueError"
except ValueError as e:
    print(f"Correctly raised error: {e}")
print("PASSED: Invalid target_type error")

In [None]:
# Test invalid test_split
try:
    probe = LinearProbe(target_type="classification", test_split=0.0)
    assert False, "Should have raised ValueError"
except ValueError as e:
    print(f"Correctly raised error for test_split=0.0: {e}")

try:
    probe = LinearProbe(target_type="classification", test_split=1.0)
    assert False, "Should have raised ValueError"
except ValueError as e:
    print(f"Correctly raised error for test_split=1.0: {e}")

print("PASSED: Invalid test_split error")

## Test: Classification probe (2D inputs)

In [None]:
set_global_seed(42)

# Synthetic 2D classification data: one activation vector per sample,
# shape (n_samples, d_model). Labels are the sign of a random linear
# projection, so the two classes are linearly separable by construction.
n_samples = 200
d_model = config.hidden_size

X_data = torch.randn(n_samples, d_model)
true_weights = torch.randn(d_model)
projection = X_data @ true_weights
y_data = (projection > 0).long().numpy()

# Store the activations under a single (layer, component) key.
X_act = ActivationDict(config, positions=slice(None))
X_act[(0, "mlp")] = X_data

print(f"X shape: {X_data.shape}")
print(f"y shape: {y_data.shape}")
print(f"y distribution: {np.bincount(y_data)}")

In [None]:
# Train classification probe
probe = LinearProbe(target_type="classification", test_split=0.2)
probe.fit(X_act, y_data)

print(f"\nWeight shape: {probe.weight.shape}")
print(f"Bias shape: {probe.bias.shape}")

assert probe.weight is not None, "Weight should be set after fitting"
assert probe.bias is not None, "Bias should be set after fitting"
print("PASSED: Classification probe training (2D)")

## Test: Classification probe (3D inputs)

In [None]:
set_global_seed(42)

# 3D activations with shape (batch, positions, d_model).
n_samples, n_positions, d_model = 100, 5, config.hidden_size

X_data_3d = torch.randn(n_samples, n_positions, d_model)
# One label per sample; with broadcast_target=True it is broadcast across positions.
y_data = np.random.randint(0, 2, n_samples)

X_act_3d = ActivationDict(config, positions=slice(None))
X_act_3d[(0, "mlp")] = X_data_3d

print(f"X 3D shape: {X_data_3d.shape}")
print(f"y shape: {y_data.shape}")

In [None]:
# Train with broadcast_target=True (default)
probe_broadcast = LinearProbe(target_type="classification", broadcast_target=True)
probe_broadcast.fit(X_act_3d, y_data)

print(f"Weight shape: {probe_broadcast.weight.shape}")
print("PASSED: Classification probe with broadcast_target=True")

In [None]:
# Train with broadcast_target=False (token-level targets): the targets must
# supply one label per (sample, position) pair.
y_data_tokens = np.random.randint(0, 2, (n_samples, n_positions))

probe_no_broadcast = LinearProbe(target_type="classification", broadcast_target=False)
probe_no_broadcast.fit(X_act_3d, y_data_tokens)

# Actually verify the fit produced parameters before declaring success.
assert probe_no_broadcast.weight is not None, "Weight should be set after fitting"
print(f"Weight shape: {probe_no_broadcast.weight.shape}")
print("PASSED: Classification probe with broadcast_target=False")

## Test: Regression probe

In [None]:
set_global_seed(42)

# Regression targets: a random linear function of the activations plus small
# Gaussian noise (scale 0.1), so a linear probe should be able to fit them.
n_samples = 200
d_model = config.hidden_size

X_data = torch.randn(n_samples, d_model)
true_weights = torch.randn(d_model)
noise = 0.1 * torch.randn(n_samples)
y_data = (X_data @ true_weights + noise).numpy()

X_act = ActivationDict(config, positions=slice(None))
X_act[(0, "mlp")] = X_data

print(f"X shape: {X_data.shape}")
print(f"y shape: {y_data.shape}")
print(f"y range: [{y_data.min():.2f}, {y_data.max():.2f}]")

In [None]:
# Train a regression probe on the noisy linear targets.
probe_reg = LinearProbe(target_type="regression", test_split=0.2)
probe_reg.fit(X_act, y_data)

# Same post-fit checks as the classification test: parameters must exist
# before their attributes are printed.
assert probe_reg.weight is not None, "Weight should be set after fitting"
assert probe_reg.bias is not None, "Bias should be set after fitting"

print(f"\nWeight shape: {probe_reg.weight.shape}")
print(f"Bias: {probe_reg.bias}")
print("PASSED: Regression probe training")

## Test: Prediction

In [None]:
set_global_seed(42)

# Fit a classification probe on random labels (independent of X_train).
# This cell only provides a fitted probe for the prediction tests below.
n_train, n_test = 200, 50

X_train = torch.randn(n_train, config.hidden_size)
y_train = np.random.randint(0, 2, n_train)

X_train_act = ActivationDict(config, positions=slice(None))
X_train_act[(0, "mlp")] = X_train

probe = LinearProbe(target_type="classification")
probe.fit(X_train_act, y_train)

In [None]:
# Fresh evaluation data drawn the same way as the training data.
X_test = torch.randn(n_test, config.hidden_size)
y_test = np.random.randint(0, 2, n_test)

X_test_act = ActivationDict(config, positions=slice(None))
X_test_act[(0, "mlp")] = X_test

# Called without `target`, predict just returns the predictions.
print("Prediction without targets:")
preds = probe.predict(X_test_act)
print(f"Predictions shape: {preds.shape}")
print(f"Unique predictions: {np.unique(preds)}")

In [None]:
# Supplying `target` (plus a `label` for the printout) makes predict also
# report metrics against y_test.
print("\nPrediction with targets:")
preds_with_metrics = probe.predict(X_test_act, label="Test", target=y_test)
print(f"Predictions shape: {preds_with_metrics.shape}")
print("PASSED: Prediction")

In [None]:
# Test predict on an unfitted probe: it must raise ValueError.
# try/except/else keeps the failure path active even under `python -O`,
# which would strip an `assert False` placed inside the try.
unfitted_probe = LinearProbe(target_type="classification")

try:
    unfitted_probe.predict(X_test_act)
except ValueError as e:
    print(f"Correctly raised error: {e}")
else:
    raise AssertionError("Should have raised ValueError")

print("PASSED: Unfitted probe error")

## Test: Error handling for multiple components

In [None]:
# LinearProbe only supports a single component in ActivationDict; fitting on a
# dict holding two entries must raise ValueError.
n_multi_samples = 100
X_multi = ActivationDict(config, positions=slice(None))
X_multi[(0, "attn")] = torch.randn(n_multi_samples, config.hidden_size)
X_multi[(0, "mlp")] = torch.randn(n_multi_samples, config.hidden_size)

y = np.random.randint(0, 2, n_multi_samples)

probe = LinearProbe(target_type="classification")

# try/except/else keeps the failure path active even under `python -O`.
try:
    probe.fit(X_multi, y)
except ValueError as e:
    print(f"Correctly raised error: {e}")
else:
    raise AssertionError("Should have raised ValueError")

print("PASSED: Multiple components error")

## Test: Using tensors as targets

In [None]:
set_global_seed(42)

X = torch.randn(100, config.hidden_size)
y_tensor = torch.randint(0, 2, (100,))

X_act = ActivationDict(config, positions=slice(None))
X_act[(0, "mlp")] = X

# Pass a torch tensor instead of a numpy array as the target.
probe = LinearProbe(target_type="classification")
probe.fit(X_act, y_tensor)

# Verify the fit actually produced parameters before declaring success.
assert probe.weight is not None, "Weight should be set after fitting with tensor targets"
assert probe.bias is not None, "Bias should be set after fitting with tensor targets"

print(f"Weight shape: {probe.weight.shape}")
print("PASSED: Tensor targets")

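## Test: Training on real model activations

This section loads the full `Qwen/Qwen3-0.6B` model, so it is much slower than the synthetic tests above.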

In [None]:
# The real-activation tests need the full model weights, not just the config.
from mech_interp_toolkit.activations import UnifiedAccessAndPatching
from mech_interp_toolkit.utils import get_default_device, load_model_tokenizer_config

device = get_default_device()
model, tok, cfg = load_model_tokenizer_config(model_name, device=device)
print(f"Model loaded on {device}")

In [None]:
# Toy sentiment task: five positive and five negative templates, each repeated
# n_repeats times, labelled 1 (positive) and 0 (negative).
# NOTE(review): because of the repeats, identical prompts (and presumably
# identical activations) can end up in both the train and test split of the
# probe trained below, so its held-out score likely overstates generalization.
# Consider more diverse prompts.
n_repeats = 10

positive_prompts = [
    "This is great!",
    "I love this!",
    "Amazing work!",
    "Wonderful!",
    "Excellent job!",
] * n_repeats

negative_prompts = [
    "This is terrible!",
    "I hate this!",
    "Awful work!",
    "Horrible!",
    "Bad job!",
] * n_repeats

all_prompts = positive_prompts + negative_prompts
labels = np.array([1] * len(positive_prompts) + [0] * len(negative_prompts))

print(f"Total samples: {len(all_prompts)}")
print(f"Labels: {np.bincount(labels)}")

In [None]:
# Extract the layer-10 MLP activation at the last token position for every prompt.
inputs = tok(all_prompts, thinking=False)

spec_dict = {
    "activations": {
        "positions": -1,
        "locations": [(10, "mlp")],
    }
}

with UnifiedAccessAndPatching(model, inputs, spec_dict) as uap:
    activations, _ = uap.unified_access_and_patching()

mlp_key = (10, "mlp")
print(f"Activations shape: {activations[mlp_key].shape}")

In [None]:
# Train a linear probe on the real last-token layer-10 MLP activations.
probe = LinearProbe(target_type="classification", test_split=0.2)
probe.fit(activations, labels)

# Verify the fit actually produced parameters before declaring success.
assert probe.weight is not None, "Weight should be set after fitting"
assert probe.bias is not None, "Bias should be set after fitting"

print(f"\nWeight shape: {probe.weight.shape}")
print("PASSED: Training on real activations")

## Summary

In [None]:
# Reached only if every assertion above passed.
banner = "=" * 50
print(banner, "All linear_probes module tests PASSED!", banner, sep="\n")