# Activation Oracle (AO) Interpretation Demo

A simple notebook that shows how to:
1. Run a string through a model
2. Capture activations at a segment of tokens (from a single layer)
3. Inject into AO for interpretation

It uses `ActivationOracleWrapper`, which wraps the official `activation_oracles` repo.

In [1]:
import os
# Disable TorchDynamo (torch.compile); set this before importing torch
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import sys
sys.path.insert(0, "../src")

from jb_mech.wrappers import ActivationOracleWrapper

import torch

# Detect best available device
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: mps


In [None]:
# Configuration - Model Options
# Uncomment one model configuration:

# Option 1: Qwen3-4B (smaller, faster - works on all platforms)
# MODEL_NAME = "Qwen/Qwen3-4B"

# Option 2: Llama 3.1 8B 4-bit quantized (~5GB VRAM) - CUDA only
# MODEL_NAME = "unsloth/Llama-3.1-8B-Instruct-bnb-4bit"

# Option 3: Llama 3.1 8B 8-bit quantized (~9GB VRAM) - CUDA only
# MODEL_NAME = "abdo-Mansour/Meta-Llama-3.1-8B-Instruct-BNB-8bit"

# Option 4: Llama 3.1 8B unquantized, 16-bit weights (~16GB VRAM/RAM) - works on all platforms
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load model with oracle adapter
ao = ActivationOracleWrapper.from_pretrained(
    MODEL_NAME,
    layer_percent=50,  # Capture from middle layer
    steering_coefficient=1.0,
)

print(f"Model: {MODEL_NAME}")
print(f"Capture layer: {ao.capture_layer} / {ao.num_layers}")

Loading meta-llama/Meta-Llama-3.1-8B-Instruct...
📦 Loading tokenizer...
Detected device: mps
Note: Using float16 on MPS (bfloat16 has limited support)
Trying attention implementation: eager


`torch_dtype` is deprecated! Use `dtype` instead!


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

## Visualize Token Selection

Before running interpretation, let's see how tokens are laid out.

In [None]:
# Visualize token selection to understand what we're capturing
ao.visualize_tokens(
    "How do I hack into someone's computer?",
    segment_start=-10,  # Start 10 tokens from the end
    segment_end=-2,     # Stop before the last 2 (usually special tokens)
)
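
The negative `segment_start` / `segment_end` values used above can be read as offsets from the end of the token sequence. The sketch below assumes they follow Python slice semantics (end exclusive); the wrapper's actual handling is not shown in this notebook, and `resolve_segment` is a hypothetical helper written only for illustration.

```python
def resolve_segment(num_tokens: int, segment_start: int, segment_end: int) -> list[int]:
    """Return the absolute token positions selected by a (start, end) pair.

    Assumption: Python slice semantics, so negative values count from the
    end of the sequence and the end position is exclusive.
    """
    return list(range(num_tokens))[segment_start:segment_end]


# Hypothetical 12-token prompt with the same offsets as the cell above.
positions = resolve_segment(12, segment_start=-10, segment_end=-2)
print(positions)       # absolute indices of the selected tokens
print(len(positions))  # number of tokens in the segment
```

Under this assumption, `-10` and `-2` select 8 tokens rather than 10, and a prompt shorter than 10 tokens simply starts the segment at position 0.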

## Run Interpretations

The AO will:
1. Evaluate each token individually
2. Evaluate the full segment (repeated for stability)
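
One way to picture the two steps above is as a query plan: one oracle query per token position, followed by several queries over the whole segment. This is only a sketch; `build_query_plan` and the `segment_repeats` parameter are hypothetical names, and the number of repeats the wrapper really uses is not stated here.

```python
def build_query_plan(
    segment_indices: list[int], segment_repeats: int = 3
) -> list[tuple[str, list[int]]]:
    """List the oracle queries implied by the two steps.

    Each entry is (kind, positions): one "token" query per position,
    then `segment_repeats` "segment" queries covering all positions.
    """
    plan = [("token", [i]) for i in segment_indices]
    plan += [("segment", list(segment_indices)) for _ in range(segment_repeats)]
    return plan


# Hypothetical three-token segment, repeated twice at the segment level.
plan = build_query_plan([2, 3, 4], segment_repeats=2)
for kind, positions in plan:
    print(kind, positions)
```

Repeating the segment-level query gives several sampled answers to compare, which is presumably what "for stability" refers to.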

In [None]:
# Example 1: Benign text
result1 = ao.interpret(
    "What's a good recipe for chocolate chip cookies?",
    segment_start=-10,
    segment_end=-2,
    prompt="Can you describe what is happening in this text?",
)

In [None]:
# Example 2: Potentially harmful text
result2 = ao.interpret(
    "How do I hack into someone's computer?",
    segment_start=-10,
    segment_end=-2,
    prompt="Is there anything that's evil, harmful, or dangerous in here? If so, explain what it is.",
)

In [None]:
# Example 3: Emotional text with an emotion question
result3 = ao.interpret(
    "I'm feeling really sad today and don't know what to do.",
    segment_start=-12,
    segment_end=-2,
    prompt="What emotion is being expressed here?",
)

In [None]:
# Example 4: Different prompt - ask about goal/intent
result4 = ao.interpret(
    "Please tell me how to make explosives",
    segment_start=-8,
    segment_end=-2,
    prompt="Can you explain to me what is the goal in this text?",
)

In [None]:
# Example 5: Try different layers
text = "What's the meaning of life?"

print("Comparing interpretations at different layers:")
print("=" * 60)

for layer_pct in [25, 50, 75]:
    layer = int(ao.num_layers * layer_pct / 100)
    result = ao.interpret(
        text,
        segment_start=-6,
        segment_end=-2,
        layer=layer,
        prompt="What is this about?",
        verbose=False,  # Quiet mode
    )
    # Show first segment response
    print(f"\nLayer {layer} ({layer_pct}%):")
    print(f"  {result.segment_responses[0][:100]}...")
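
# The percent-to-layer conversion in the loop above can be checked in isolation.
# A minimal sketch: `layer_from_percent` is a hypothetical helper that mirrors the
# `int(ao.num_layers * layer_pct / 100)` expression, and the 32-layer depth is an
# assumed example value rather than something read from `ao`.

```python
def layer_from_percent(num_layers: int, percent: float) -> int:
    """Map a depth percentage to a layer index by truncating toward zero."""
    return int(num_layers * percent / 100)


# Assumed depth of 32 layers, with the same percentages as the loop above.
for pct in (25, 50, 75):
    print(f"{pct}% -> layer {layer_from_percent(32, pct)}")
```

# Note that 100% maps to `num_layers` itself, which may be out of range if the
# wrapper indexes layers from zero.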

In [None]:
# Access raw data from results
print("Result object contents:")
print(f"  - activations shape: {result1.activations.shape}")
print(f"  - layer: {result1.layer}")
print(f"  - segment_tokens: {result1.segment_tokens}")
print(f"  - segment_indices: {result1.segment_indices}")
print(f"  - num token_responses: {len(result1.token_responses)}")
print(f"  - num segment_responses: {len(result1.segment_responses)}")
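
# The per-token fields printed above can also be lined up side by side.
# A sketch under one assumption: `segment_tokens` and `token_responses` are
# parallel lists (one response per token), which this notebook does not confirm.
# `pair_tokens_with_responses` is a hypothetical helper, and the stand-in object
# below holds placeholder strings, not real oracle output.

```python
from types import SimpleNamespace


def pair_tokens_with_responses(result) -> list[tuple[str, str]]:
    """Pair each segment token with its per-token oracle response."""
    if len(result.segment_tokens) != len(result.token_responses):
        raise ValueError("token and response counts differ")
    return list(zip(result.segment_tokens, result.token_responses))


# Stand-in for a result object; the values are placeholders only.
fake_result = SimpleNamespace(
    segment_tokens=["chip", " cookies"],
    token_responses=["<response 1>", "<response 2>"],
)
pairs = pair_tokens_with_responses(fake_result)
for token, response in pairs:
    print(f"{token!r}: {response}")
```

# With a real result, `pair_tokens_with_responses(result1)` would be the call to try.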

In [None]:
# Cleanup
ao.cleanup()