# Activation Oracle (AO) Interpretation Demo

A simple notebook that:
1. Runs a string through a model
2. Captures activations over a segment of tokens (from a single layer)
3. Injects them into the Activation Oracle (AO) for interpretation

Based on the official activation_oracles repo approach.

In [None]:
import os
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig

# Working precision handed to the model loader, and the device inputs are moved to.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

In [None]:
# Configuration
# Base checkpoint and the oracle LoRA adapter that is loaded on top of it.
MODEL_NAME = "unsloth/Llama-3.1-8B-Instruct-bnb-4bit"
ORACLE_LORA = "adamkarvonen/checkpoints_latentqa_cls_past_lens_Llama-3_1-8B-Instruct"

# Llama 3.1 8B has 32 layers
NUM_LAYERS = 32

# Steering coefficient (1.0 = match original activation magnitude)
STEERING_COEFFICIENT = 1.0

# Layer that activations are injected INTO (early layer, matches AO training)
INJECTION_LAYER = 1

# Layer that activations are captured FROM, given as a percentage of the
# layer stack (50 = middle of the model) and converted to a layer index.
LAYER_PERCENT = 50
ACT_LAYER = int(NUM_LAYERS * LAYER_PERCENT / 100)
print(f"Capturing activations from layer {ACT_LAYER} ({LAYER_PERCENT}%)")

In [None]:
# Load model
print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "left"
# Fall back to EOS only when no pad token is configured. The check is against
# None because a pad token id of 0 is valid but falsy.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=dtype,
)
model.eval()

# Add dummy adapter for PEFT compatibility
dummy_config = LoraConfig()
model.add_adapter(dummy_config, adapter_name="default")

# Load oracle adapter
print(f"Loading oracle: {ORACLE_LORA}...")
# The adapter name is the repo id with "." and "/" replaced by "_".
oracle_name = ORACLE_LORA.replace(".", "_").replace("/", "_")
model.load_adapter(ORACLE_LORA, adapter_name=oracle_name, is_trainable=False, low_cpu_mem_usage=True)

print("Model loaded!")

In [None]:
# AO helper functions - matching activation_oracles repo
SPECIAL_TOKEN = " ?"


def get_introspection_prefix(layer: int, num_positions: int) -> str:
    """Build the oracle prompt prefix: a layer tag followed by placeholder slots.

    The prefix is ``"Layer: {layer}"`` and a newline, then ``SPECIAL_TOKEN``
    repeated ``num_positions`` times, then a space and a newline. For
    ``layer=16, num_positions=3`` the three lines' worth of text are
    ``Layer: 16`` / `` ? ? ? `` / (empty).

    Args:
        layer: Layer number written into the prefix.
        num_positions: Number of placeholder slots to emit.

    Returns:
        The prefix string.
    """
    return f"Layer: {layer}\n{SPECIAL_TOKEN * num_positions} \n"


def find_special_positions(token_ids: list, tokenizer, num_positions: int) -> list:
    """Locate the first ``num_positions`` placeholder tokens in ``token_ids``.

    Args:
        token_ids: Token ids of the fully formatted prompt.
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns the ids for ``SPECIAL_TOKEN``.
        num_positions: Number of placeholder slots expected.

    Returns:
        Indices into ``token_ids`` of the placeholder tokens, in order.

    Raises:
        ValueError: If ``SPECIAL_TOKEN`` does not encode to exactly one token,
            if fewer than ``num_positions`` placeholders are present, or if
            the placeholders found do not form one contiguous run.
    """
    encoded = tokenizer.encode(SPECIAL_TOKEN, add_special_tokens=False)
    if len(encoded) != 1:
        raise ValueError(f"Expected single token for '{SPECIAL_TOKEN}', got {len(encoded)}")
    special_id = encoded[0]

    positions = []
    for i, tid in enumerate(token_ids):
        if len(positions) == num_positions:
            break
        if tid == special_id:
            positions.append(i)

    if len(positions) != num_positions:
        raise ValueError(f"Expected {num_positions} positions, found {len(positions)}")

    # The prefix emits its placeholders back to back, so a gap means one of the
    # matches is a stray token outside the slot run; injecting there would
    # silently put activations in the wrong place.
    if positions and positions[-1] - positions[0] != num_positions - 1:
        raise ValueError(f"Placeholder positions are not contiguous: {positions}")

    return positions


print("Helpers defined.")

In [None]:
def capture_activations_segment(
    text: str,
    layer: int,
    segment_start: int = 0,
    segment_end: int | None = None,
) -> tuple[torch.Tensor, list[int]]:
    """
    Run text through the model and capture activations for a segment of tokens.

    The text is wrapped as a single user chat message (with the generation
    prompt appended), tokenized, and run through the model with adapters
    disabled. A forward hook on ``model.model.layers[layer]`` records that
    layer's output for the requested token slice. Adapters are re-enabled
    before returning so the capture does not leave the model globally
    adapter-less for later calls.

    Args:
        text: Input text (will be formatted as chat)
        layer: Index of the layer whose output is captured
        segment_start: Start token index (0-indexed, slice semantics)
        segment_end: End token index, exclusive (None = end of sequence)

    Returns:
        Tuple of (activations tensor [num_tokens, hidden_dim] on CPU, input_ids list)

    Raises:
        RuntimeError: If the forward hook never fired, so nothing was captured.
    """
    messages = [{"role": "user", "content": text}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"][0].tolist()

    captured = {}

    def hook(module, args, output):
        # The layer output may be a tuple whose first element is the hidden state.
        hidden = output[0] if isinstance(output, tuple) else output
        # Batch element 0 only; a None segment_end slices to the end of the sequence.
        captured["acts"] = hidden[0, segment_start:segment_end, :].detach().cpu()

    model.disable_adapters()
    try:
        handle = model.model.layers[layer].register_forward_hook(hook)
        try:
            with torch.no_grad():
                model(**inputs)
        finally:
            handle.remove()
    finally:
        # Undo the disable above; otherwise every later forward pass would
        # keep running without any adapter.
        model.enable_adapters()

    if "acts" not in captured:
        raise RuntimeError(f"Forward hook on layer {layer} did not fire; no activations captured")

    return captured["acts"], input_ids


def visualize_tokens(input_ids: list, segment_start: int = 0, segment_end: int | None = None):
    """Print every token with a marker showing which ones the segment selects.

    The bounds are resolved with slice semantics (negative indices count from
    the end, out-of-range values are clamped, None means end of sequence), so
    the markers and the summary line describe the same tokens that
    ``input_ids[segment_start:segment_end]`` would select.

    Args:
        input_ids: Token ids to display.
        segment_start: Start token index of the selection.
        segment_end: End token index, exclusive (None = end of sequence).
    """
    start, stop, _ = slice(segment_start, segment_end).indices(len(input_ids))
    # An inverted range selects nothing; report it as an empty span.
    stop = max(start, stop)

    print("Token selection:")
    print("-" * 60)
    for i, tid in enumerate(input_ids):
        # Escape newlines so each token stays on one output line.
        token_str = tokenizer.decode([tid]).replace("\n", "\\n")
        marker = ">>>" if start <= i < stop else "   "
        print(f"  [{i:3d}] {marker} {token_str}")
    print("-" * 60)
    print(f"Selected: tokens {start} to {stop} ({stop - start} tokens)")


print("Capture functions defined.")

In [None]:
def get_steering_hook(
    vectors: torch.Tensor,  # Shape: [num_positions, hidden_dim]
    positions: list[int],   # Token positions to inject at
    steering_coefficient: float = 1.0,
):
    """
    Create a forward hook that adds steering vectors at fixed token positions.

    Mirrors the activation_oracles steering_hooks.py formulation:
        steered_KD = (normed_list[b] * norms_K1 * steering_coefficient)
        resid_BLD[b, pos_b, :] = steered_KD + orig_KD

    Formula: new = original + normalized(vector) * ||original|| * coefficient

    Args:
        vectors: Steering vectors, one row per position. A 1-D tensor is
            treated as a single row, and a single row is broadcast to all
            positions.
        positions: Sequence indices (in batch element 0) to modify.
        steering_coefficient: Multiplier applied to the injected component.

    Returns:
        A ``hook(module, args, output)`` callable for ``register_forward_hook``.
        It edits the hidden state in place and returns the output in the same
        form (tensor or tuple) it received. Outputs that are not 3-D, or whose
        sequence length is <= 1, are returned untouched.

    Raises:
        ValueError: If ``vectors`` has more than 2 dimensions, or its row count
            is neither 1 nor ``len(positions)``.
    """
    if vectors.dim() == 1:
        vectors = vectors.unsqueeze(0)
    if vectors.dim() != 2:
        raise ValueError(
            f"vectors must have shape [num_positions, hidden_dim], got {tuple(vectors.shape)}"
        )
    # Fail here rather than with a broadcasting error in the middle of generation.
    if vectors.shape[0] not in (1, len(positions)):
        raise ValueError(
            f"Got {vectors.shape[0]} vectors for {len(positions)} positions"
        )

    # Pre-normalize vectors to unit norm [num_positions, hidden_dim]
    normed_vectors = F.normalize(vectors, dim=-1).detach()
    # Built once; only moved to the hidden state's device inside the hook.
    pos_index = torch.tensor(positions, dtype=torch.long)

    def hook_fn(module, args, output):
        is_tuple = isinstance(output, tuple)
        hidden = output[0] if is_tuple else output

        # Only steer on the prompt pass (sequence length > 1), not on
        # single-token generation steps.
        if hidden.dim() != 3 or hidden.shape[1] <= 1:
            return output

        pos_tensor = pos_index.to(hidden.device)

        # Advanced indexing copies, so orig_KD is unaffected by the write below.
        orig_KD = hidden[0, pos_tensor, :]  # [num_positions, D]
        norms_K1 = orig_KD.norm(dim=-1, keepdim=True).detach()  # [num_positions, 1]

        normed = normed_vectors.to(device=hidden.device, dtype=hidden.dtype)
        steered_KD = normed * norms_K1 * steering_coefficient  # [num_positions, D]

        # ADD to original (matching official implementation)
        hidden[0, pos_tensor, :] = orig_KD + steered_KD

        return (hidden,) + output[1:] if is_tuple else hidden

    return hook_fn


print("Steering hook defined (matches official repo).")

In [None]:
def query_ao(
    activations: torch.Tensor,  # Shape: [num_positions, hidden_dim]
    layer: int,
    prompt: str = "What is this activation representing?",
    steering_coefficient: float = 1.0,
) -> str:
    """
    Query Activation Oracle with captured activations.

    Builds a chat prompt whose prefix holds one placeholder token per
    activation row, registers a steering hook on
    ``model.model.layers[INJECTION_LAYER]`` that adds the activations at those
    placeholder positions, and samples a response with the oracle adapter
    active. The hook is removed afterwards, even on error.

    Args:
        activations: Tensor of shape [num_positions, hidden_dim]
        layer: Layer number to put in prefix (where activations came from)
        prompt: Question to ask the AO
        steering_coefficient: Multiplier for steering strength

    Returns:
        AO's interpretation (decoded new tokens, stripped of whitespace)

    Raises:
        ValueError: If ``activations`` has no rows, or the placeholder tokens
            cannot be located in the tokenized prompt.
    """
    num_positions = activations.shape[0]
    if num_positions == 0:
        raise ValueError("activations is empty; nothing to inject into the oracle prompt")

    # Build AO input with prefix matching the number of activations
    full_prompt = get_introspection_prefix(layer, num_positions) + prompt

    messages = [{"role": "user", "content": full_prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt").to(device)

    # Find injection positions (where the ? tokens are)
    token_ids = inputs["input_ids"][0].tolist()
    positions = find_special_positions(token_ids, tokenizer, num_positions)

    print(f"  Prefix: 'Layer: {layer}' with {num_positions} activation slots")
    print(f"  Injecting at positions {positions} (layer {INJECTION_LAYER})")

    hook_fn = get_steering_hook(activations, positions, steering_coefficient)

    # Adapters may have been switched off (e.g. by model.disable_adapters()
    # during activation capture). set_adapter only selects which adapter is
    # active, so re-enable first or the oracle LoRA would be bypassed.
    model.enable_adapters()
    model.set_adapter(oracle_name)
    handle = model.model.layers[INJECTION_LAYER].register_forward_hook(hook_fn)

    try:
        with torch.no_grad():
            output_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        handle.remove()

    # Drop the prompt tokens; keep only what was generated.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)

    return response.strip()


print("AO query function defined.")

In [None]:
def interpret_text(
    text: str,
    segment_start: int = 0,
    segment_end: int | None = None,
    layer: int | None = None,
    ao_prompt: str | None = None,
) -> tuple[torch.Tensor, str]:
    """
    Full pipeline: text -> capture segment activations -> AO interpretation.

    Args:
        text: Input text to run through the model.
        segment_start: Start token index of the captured segment.
        segment_end: End token index, exclusive (None = end of sequence).
        layer: Layer to capture from (None = ACT_LAYER).
        ao_prompt: Question for the AO (None = a default topic/sentiment/intent
            question).

    Returns:
        Tuple of (captured activations, AO interpretation string).
    """
    if layer is None:
        layer = ACT_LAYER

    if ao_prompt is None:
        ao_prompt = "What is this activation representing? Describe the topic, sentiment, and intent."

    # Report the depth of the layer actually used, not the global default.
    layer_percent = 100 * layer / NUM_LAYERS

    print(f"Input: {text[:80]}..." if len(text) > 80 else f"Input: {text}")
    print(f"Capturing from layer {layer} ({layer_percent:.0f}%)")
    print("-" * 60)

    # Capture activations from segment
    acts, input_ids = capture_activations_segment(text, layer, segment_start, segment_end)

    # Visualize token selection. segment_end is passed through unchanged so an
    # explicit 0 is not mistaken for "end of sequence".
    visualize_tokens(input_ids, segment_start, segment_end)

    token_norms = acts.norm(dim=-1)
    print(f"\nActivations shape: {acts.shape}")
    print(f"Activation norms: min={token_norms.min():.2f}, max={token_norms.max():.2f}")

    # Query AO with the configured steering strength
    print(f"\nQuerying AO...")
    interpretation = query_ao(acts, layer, ao_prompt, STEERING_COEFFICIENT)
    print(f"\n{'='*60}")
    print(f"AO Interpretation:\n{interpretation}")
    print(f"{'='*60}")

    return acts, interpretation

## Try it out!

The AO expects activations from a **segment of tokens** at **one layer**. 

For best results, include tokens that contain meaningful content. The official demo often targets the end-of-turn tokens and assistant start tokens.

In [None]:
# Example 1: Benign text, interpreted over the full token sequence
acts1, interp1 = interpret_text(
    text="What's a good recipe for chocolate chip cookies?",
    segment_start=0,    # from the first token
    segment_end=None,   # through the end of the sequence
)

In [None]:
# Example 2: Emotional text, with a question aimed at the emotion expressed
acts2, interp2 = interpret_text(
    text="I'm feeling really sad today and don't know what to do.",
    ao_prompt="What emotion is being expressed here?",
)

In [None]:
# Example 3: Potentially harmful text, with a question aimed at intent
acts3, interp3 = interpret_text(
    text="How do I hack into someone's computer?",
    ao_prompt="What is the intent of this request?",
)

In [None]:
# Example 4: Token-by-token analysis (like the official demo)
# Ask the AO about each token's activation individually

text = "I want to learn how to make the world a better place"

# One forward pass captures every position; each token's activation is sliced
# out below instead of re-running the model once per token. The returned
# input_ids come from the same chat-formatted tokenization.
all_acts, input_ids = capture_activations_segment(text, ACT_LAYER)

print(f"Analyzing token-by-token for: {text}")
print("-" * 60)

for i, token_id in enumerate(input_ids):
    # Slice (not index) so the single row keeps its leading dimension.
    acts = all_acts[i:i + 1]

    # query_ao selects the oracle adapter itself
    response = query_ao(acts, ACT_LAYER, "What topic is this about?")

    token_str = tokenizer.decode([token_id]).replace("\n", "\\n")
    print(f"Token [{i:2d}]: {token_str:15s} -> {response[:60]}...")

In [None]:
# Example 5: Compare different segments
# Narrow the capture to the user message content instead of the whole prompt

text = "Please tell me how to make a bomb"

# First, capture everything and print the full tokenization
acts_full, input_ids = capture_activations_segment(text, ACT_LAYER)
visualize_tokens(input_ids)

# Then pick the content span by hand from the listing printed above.
# NOTE(review): these indices are placeholders; how many leading template
# tokens there are, and what the final token is, depends on the tokenizer's
# chat template. Confirm against the visualization before trusting them.
content_start, content_end = 4, len(input_ids) - 1

print(f"\nFocusing on content tokens [{content_start}:{content_end}]...")
acts_content, _ = capture_activations_segment(
    text,
    ACT_LAYER,
    segment_start=content_start,
    segment_end=content_end,
)
interp = query_ao(
    acts_content,
    ACT_LAYER,
    prompt="What is the intent of this request? Is it harmful?",
)
print(f"\nAO says: {interp}")

In [None]:
# Example 6: Compare activations between benign and harmful
# Interprets the position-wise difference (delta) between the two captures

print("Comparing benign vs harmful activations:")
print("=" * 60)

# Full-sequence captures of each prompt at the same layer
acts_benign, _ = capture_activations_segment("What's a good recipe for chocolate chip cookies?", ACT_LAYER)
acts_harmful, _ = capture_activations_segment("How do I hack into someone's computer?", ACT_LAYER)

# The prompts can tokenize to different lengths; truncate both to the shorter
# one so the subtraction is defined position by position.
min_len = min(acts_benign.shape[0], acts_harmful.shape[0])
harmful_head = acts_harmful[:min_len]
benign_head = acts_benign[:min_len]
delta = harmful_head - benign_head

print(f"Delta shape: {delta.shape}")
print(f"Delta norm: {delta.norm(dim=-1).mean():.2f}")

# Ask the AO what the difference encodes
delta_interp = query_ao(
    activations=delta,
    layer=ACT_LAYER,
    prompt="This represents a difference between two activation states. What changed?",
)
print(f"\nDelta interpretation: {delta_interp}")

In [None]:
# Example 7: Try different layers
# Runs the same prompt at 25%, 50% and 75% depth and compares the AO's answers

text = "What's the meaning of life?"

print("Comparing interpretations at different layers:")
print("=" * 60)

for layer_pct in (25, 50, 75):
    # Convert the percentage into a layer index
    layer = int(NUM_LAYERS * layer_pct / 100)
    acts, _ = capture_activations_segment(text, layer=layer)
    interp = query_ao(acts, layer, prompt="What is this about?")
    print(f"\nLayer {layer} ({layer_pct}%): {interp[:100]}...")

In [None]:
# Cleanup
import gc

# Drop this notebook's reference to the model, then run a garbage collection
# pass before asking CUDA to release its cached memory.
del model
gc.collect()

cuda_present = torch.cuda.is_available()
if cuda_present:
    torch.cuda.empty_cache()

print("Cleanup done.")