# Test: Hook-Based Activation Extraction + SAE

Tests:
1. Model/SAE loading via `load_sae_model`
2. Layer index parsing from SAE hook name
3. Hook captures correct activations (verified against `output_hidden_states`)
4. SAE reconstruction quality (FVU) and L0 sparsity
5. `FeatureExtractor` end-to-end

In [None]:
import torch
import numpy as np

# This notebook only runs forward passes, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

## 1. Load Model & SAE via `load_sae_model`

In [None]:
from interpret_personas.extraction.sae_loader import load_sae_model

# Identifiers of the language model and of the sparse autoencoder to load.
_sae_spec = {
    "model_name": "google/gemma-3-27b-it",
    "sae_release": "gemma-scope-2-27b-it-res",
    "sae_id": "layer_40_width_65k_l0_medium",
}

# load_sae_model returns the three objects unpacked here.
model, sae, tokenizer = load_sae_model(**_sae_spec)

# Summarise what came back: model class, SAE hook point and SAE dimensions.
_sae_cfg = sae.cfg
print(f"Model type: {type(model).__name__}")
print(f"SAE hook: {_sae_cfg.metadata.hook_name}")
print(f"SAE dimensions: d_in={_sae_cfg.d_in}, d_sae={_sae_cfg.d_sae}")

## 2. Layer Index Parsing

Verify `_parse_layer_index` extracts the correct layer from the SAE's hook name.

In [None]:
from interpret_personas.extraction.feature_extractor import FeatureExtractor, _gather_residual_activations

# Read the hook name once from the SAE config and parse the layer index from it.
hook_name = sae.cfg.metadata.hook_name
LAYER = FeatureExtractor._parse_layer_index(hook_name)

print(f"Hook name: {hook_name}")
print(f"Parsed layer: {LAYER}")

# The SAE requested earlier in this notebook has a layer-40 id, so the parser
# must return 40.
assert LAYER == 40, f"Expected layer 40, got {LAYER}"

## 3. Hook-Based Activation Capture

Test `_gather_residual_activations` on a simple input.

In [None]:
# Tokenise one short sentence and move the resulting ids to the chosen device.
test_text = "The quick brown fox jumps over the lazy dog."
token_ids = tokenizer.encode(test_text, return_tensors="pt")
inputs = token_ids.to(device)

# Capture the activations at LAYER for this input via the helper under test.
target_act = _gather_residual_activations(model, LAYER, inputs)

# Report basic metadata of the input and of the captured activations.
print(f"Input shape: {inputs.shape}")
print(f"Activation shape: {target_act.shape}")
print(f"Activation dtype: {target_act.dtype}")
print(f"Activation device: {target_act.device}")

## 4. Verify Hook vs `output_hidden_states`

Sanity check: hook output should match `output_hidden_states` at the same layer.

In [None]:
# Reference forward pass that also returns the hidden state of every layer.
with torch.no_grad():
    outputs = model(inputs, output_hidden_states=True)

# hidden_states[0] = embeddings, so the output of layer L sits at index L + 1.
reference_index = LAYER + 1
ohs_act = outputs.hidden_states[reference_index]

# Largest element-wise gap between the hooked activations and the reference.
abs_gap = torch.abs(target_act - ohs_act)
max_diff = abs_gap.max().item()
print(f"Max absolute difference (hook vs output_hidden_states): {max_diff}")
assert max_diff < 1e-5, f"Mismatch! max_diff={max_diff}"

## 5. SAE Encode / Decode + FVU

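Fraction of Variance Unexplained (FVU) measures reconstruction quality: it is the
mean squared reconstruction error divided by the variance of the original activations.
Lower is better (< 10% is typical for a well-trained SAE).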

In [None]:
# Convert the captured activations to the SAE's device and dtype exactly once.
# The model and the SAE may live on different devices, so both the encoder input
# and the reconstruction-error target must come from this converted tensor;
# otherwise the subtraction below would mix tensors from two devices.
sae_input = target_act.to(device=sae.device, dtype=sae.dtype)

# Encode into SAE feature activations, then decode back to the input space.
sae_acts = sae.encode(sae_input)
recon = sae.decode(sae_acts)

# Skip BOS token (position 0) for FVU — SAE may not have been trained on it
target_no_bos = sae_input[:, 1:]
reconstruction_mse = torch.mean((recon[:, 1:] - target_no_bos) ** 2)
target_variance = target_no_bos.var()

# FVU = reconstruction error relative to the variance of the original activations.
fvu = reconstruction_mse / target_variance
print(f"Fraction of variance unexplained: {fvu:.2%}")

## 6. L0 Sparsity

Number of active features per token. Should be sparse (typical: 50-200 active out of 65k).

In [None]:
# A feature counts as active when its activation is strictly positive; count the
# active ones along the last axis and keep the first row of the result.
active_mask = sae_acts > 0
l0_per_token = active_mask.sum(dim=-1)[0]
print(f"L0 per token: {l0_per_token.tolist()}")

# Average over every position except the first one (BOS).
mean_l0_without_bos = l0_per_token[1:].float().mean()
print(f"Average L0 (excluding BOS): {mean_l0_without_bos:.2f}")
print(f"SAE feature shape: {sae_acts.shape}")

## 7. Test `FeatureExtractor` End-to-End

Verify the full class produces correct output shapes and sensible values.

In [None]:
# A minimal two-turn chat used by the end-to-end checks in the following cells.
user_turn = {"role": "user", "content": "What is the capital of France?"}
assistant_turn = {
    "role": "assistant",
    "content": "The capital of France is Paris, a city known for the Eiffel Tower and its rich cultural heritage.",
}
conversation = [user_turn, assistant_turn]

# Build the extractor from the objects loaded earlier in the notebook.
extractor = FeatureExtractor(model=model, sae=sae, tokenizer=tokenizer)
print(f"Target layer: {extractor._target_layer}")

# The extractor's target layer must agree with the layer parsed from the hook name.
assert extractor._target_layer == LAYER

In [None]:
# Extract features for the conversation using the "response_only" token selection.
result = extractor.extract_from_conversation(conversation, token_selection="response_only")

# Look up the two pooled vectors once; every report and check below uses them.
mean_features = result["mean"]
max_features = result["max"]

print(f"Mean feature shape: {mean_features.shape}")
print(f"Max feature shape: {max_features.shape}")
print(f"Mean vector L2 norm: {np.linalg.norm(mean_features):.4f}")
print(f"Max vector L2 norm: {np.linalg.norm(max_features):.4f}")
print(f"Nonzero mean features: {(mean_features > 0).sum()}")
print(f"Nonzero max features: {(max_features > 0).sum()}")

# Each pooled vector must be one-dimensional with one entry per SAE feature.
expected_shape = (sae.cfg.d_sae,)
assert mean_features.shape == expected_shape
assert max_features.shape == expected_shape
assert mean_features.max() > 0, "Mean features should have positive activations"


In [None]:
# Same conversation, but with the "all" token selection, for comparison with `result`.
result_all = extractor.extract_from_conversation(conversation, token_selection="all")

response_only_mean = result["mean"]
all_tokens_mean = result_all["mean"]

# Print both L2 norms, then whether the two mean vectors differ numerically.
print(f"response_only mean L2: {np.linalg.norm(response_only_mean):.4f}")
print(f"all tokens mean L2:    {np.linalg.norm(all_tokens_mean):.4f}")
vectors_differ = not np.allclose(response_only_mean, all_tokens_mean)
print(f"\nVectors differ: {vectors_differ}")

In [None]:
# Second conversation, so that the batch contains more than one item.
photosynthesis_conversation = [
    {"role": "user", "content": "Explain photosynthesis briefly."},
    {"role": "assistant", "content": "Photosynthesis converts sunlight, water, and CO2 into glucose and oxygen."},
]

# Run both conversations through the batch API with the "response_only" selection.
batch_results = extractor.extract_batch(
    [conversation, photosynthesis_conversation],
    token_selection="response_only",
)

# Report how many results came back and the vector shapes of each one.
print(f"Batch size: {len(batch_results)}")
for conv_index, conv_result in enumerate(batch_results):
    print(f"  Conv {conv_index}: mean shape={conv_result['mean'].shape}, max shape={conv_result['max'].shape}")