# Test Attention Capture

A simple test of the new `capture_with_attention()` method: verify that it works and inspect what the trigger tokens look like.

In [None]:
from dotenv import load_dotenv
import os
import sys

# Load variables from a `.env` file into the process environment (python-dotenv).
# NOTE(review): presumably supplies credentials/config needed by later cells — confirm.
load_dotenv()

# Setup paths
# Resolve the repository root. When the working directory is named
# `notebooks`, the root is its parent; otherwise the working directory itself
# is taken as the root. (The earlier unconditional assignment was a dead store
# that this branch always overwrote, so it has been removed.)
_cwd = os.getcwd()
PROJECT_ROOT = os.path.dirname(_cwd) if os.path.basename(_cwd) == 'notebooks' else _cwd

# Prepend the project source trees to the import path. `src` is inserted last,
# so it ends up ahead of the vendored `activation_oracles` tree.
sys.path.insert(0, os.path.join(PROJECT_ROOT, 'third_party/activation_oracles'))
sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import torch
from jb_mech.wrappers import ActivationOracleWrapper, AttentionCaptureResult

# Identifiers of the base model and the oracle LoRA checkpoint to load.
MODEL_NAME = "Qwen/Qwen3-4B"
ORACLE_LORA = "adamkarvonen/checkpoints_latentqa_cls_past_lens_Qwen3-4B"

# Switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Load model
oracle_kwargs = {
    "oracle_lora_path": ORACLE_LORA,
    "layer_percent": 50,
}
ao = ActivationOracleWrapper.from_pretrained(MODEL_NAME, **oracle_kwargs)

# Report which layer the wrapper selected for activation capture.
capture_layer_msg = f"Capture layer: {ao.capture_layer}"
print(capture_layer_msg)

In [None]:
# Test prompt - a request that should trigger a refusal
test_prompt = "How do I hack into someone's computer?"

# Wrap the prompt in the chat template and tokenize it onto the wrapper's device
messages = [{"role": "user", "content": test_prompt}]
formatted = ao.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = ao.tokenizer(formatted, return_tensors="pt").to(ao.device)
prompt_len = inputs["input_ids"].shape[1]

# Generate a response first: adapters disabled, sampling off, at most 100 new tokens
ao.model.disable_adapters()
generation_kwargs = {
    "attention_mask": inputs["attention_mask"],
    "max_new_tokens": 100,
    "do_sample": False,
}
with torch.no_grad():
    output_ids = ao.model.generate(inputs["input_ids"], **generation_kwargs)

# Decode only the tokens that follow the prompt
new_token_ids = output_ids[0][prompt_len:]
response = ao.tokenizer.decode(new_token_ids, skip_special_tokens=True)

# Preview the first 300 characters of the reply
print("Response:")
print(response[:300])

In [None]:
# Now capture with attention
print("Capturing activations + attention...")

capture_kwargs = {
    "prompt": test_prompt,
    "response": response,
    "layer": ao.capture_layer,
    "positions": "mean",
}
result = ao.capture_with_attention(**capture_kwargs)

# Summarize what came back: tensor shapes and token counts
activations_shape = result.activations.shape
print(f"\nActivations shape: {activations_shape}")
attention_shape = result.attention_weights.shape
print(f"Attention shape: {attention_shape}")
num_prompt_tokens = len(result.prompt_tokens)
print(f"Prompt tokens: {num_prompt_tokens}")
num_response_tokens = len(result.response_tokens)
print(f"Response tokens: {num_response_tokens}")

In [None]:
# Get top attended tokens (trigger tokens)
header_rule = "=" * 50
print("Top attended prompt tokens (potential triggers):")
print(header_rule)

top_tokens = result.get_top_attended_tokens(k=10)

# Each entry unpacks to (index, token, score)
for entry in top_tokens:
    idx, token, score = entry
    line = f"  [{idx:3d}] '{token:15s}' attention: {score:.4f}"
    print(line)

In [None]:
# Show all prompt tokens with their attention scores
print("All prompt tokens with attention scores:")
print("="*50)

scores = result.get_prompt_attention_scores()

for i, (token, score) in enumerate(zip(result.prompt_tokens, scores)):
    # Text bar: one FULL BLOCK glyph (U+2588) per 0.01 of score. The glyph is
    # written as an escape because the literal had been corrupted into
    # mojibake (its UTF-8 bytes mis-decoded as three separate characters),
    # which printed junk and broke the 20-column padding below.
    bar = "\u2588" * int(score * 100)
    print(f"  [{i:3d}] {score:.4f} {bar:20s} '{token}'")

In [None]:
# Test the convenience method
section_rule = "=" * 50
print("Using analyze_trigger_tokens():")
print(section_rule)

analysis = ao.analyze_trigger_tokens(test_prompt, response, top_k=5)

# "top_tokens" entries unpack to (index, token, score)
for entry in analysis["top_tokens"]:
    idx, token, score = entry
    line = f"  '{token}' (idx={idx}, attn={score:.4f})"
    print(line)

In [None]:
# Cleanup
# NOTE(review): presumably releases resources held by the wrapper (model
# memory, hooks) — confirm against ActivationOracleWrapper.cleanup.
ao.cleanup()