# Mamba Causal Tracing Visualization

This notebook demonstrates causal tracing on Mamba State Space Models and visualizes the results as heatmaps (similar to the ROME paper).

The heatmap shows which (layer, position) pairs are critical for factual recall:
- **X-axis**: Token positions in the prompt
- **Y-axis**: Model layers (0 = earliest, higher = later)
- **Color**: Normalized probability of the correct answer when that state is restored
  - Blue: Low probability (restoration doesn't help)
  - Red: High probability (restoration recovers the answer)

In [None]:
# Imports
import sys
from pathlib import Path

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML

from mamba_causal_analysis.mamba_models import load_mamba_model
from mamba_causal_analysis.mamba_causal_trace import calculate_hidden_flow

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Load Mamba Model

In [None]:
# Load Mamba-130m (smallest model for quick experimentation)
model_name = "state-spaces/mamba-130m"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {model_name}...")
mt = load_mamba_model(model_name, device=device)
print(f"✓ Model loaded with {mt.num_layers} layers")
print(f"  Device: {mt.device}")

## 2. Define Test Prompt

Choose a factual prompt where we know the answer. We'll trace which parts of the model are responsible for recalling this fact.
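
Tracing needs to know which token positions belong to the subject. The sketch below shows one way such a range could be located, assuming plain whitespace splitting as a stand-in for the real tokenizer. `find_subject_range` is a hypothetical helper written for illustration, not a function from `mamba_causal_analysis`; with a subword tokenizer the match would have to be done on token IDs or character offsets instead. The returned end index is inclusive, matching how later cells slice `subject_range`.

```python
def find_subject_range(prompt_tokens, subject_tokens):
    """Return (start, end) inclusive indices of the first occurrence of
    subject_tokens inside prompt_tokens. Hypothetical helper for illustration."""
    n, m = len(prompt_tokens), len(subject_tokens)
    if m == 0:
        raise ValueError("subject is empty")
    for start in range(n - m + 1):
        if prompt_tokens[start:start + m] == subject_tokens:
            return start, start + m - 1
    raise ValueError("subject not found in prompt")


# Whitespace splitting stands in for the real tokenizer (an assumption).
toy_prompt = "The Eiffel Tower is located in".split()
toy_subject = "Eiffel Tower".split()
subject_range_demo = find_subject_range(toy_prompt, toy_subject)
print(subject_range_demo)
```

A subject that does not appear verbatim in the prompt raises an error here rather than returning a silent default, which makes a mismatched test case easy to spot.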

In [None]:
# Example factual prompts (choose one or add your own)
test_cases = [
    {"prompt": "The Eiffel Tower is located in", "subject": "Eiffel Tower"},
    {"prompt": "The Space Needle is located in downtown", "subject": "Space Needle"},
    {"prompt": "The mother tongue of Angela Merkel is", "subject": "Angela Merkel"},
    {"prompt": "Apple Inc. was founded by Steve", "subject": "Apple Inc"},
]

# Select test case
test = test_cases[0]  # Change index to try different prompts
prompt = test["prompt"]
subject = test["subject"]

print(f"Prompt: \"{prompt}\"")
print(f"Subject: \"{subject}\"")
print()

# Test what the model predicts
tokens = mt.tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = mt.model(tokens["input_ids"])
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    else:
        logits = outputs
    
    probs = torch.softmax(logits[0, -1, :], dim=0)
    top_tokens = torch.topk(probs, 5)
    
print("Model's top 5 predictions:")
for prob, token_id in zip(top_tokens.values, top_tokens.indices):
    token_str = mt.tokenizer.decode([token_id.item()])
    print(f"  {token_str:20s} {prob.item():.4f}")

## 3. Run Causal Tracing

This cell will:
1. Run the clean input and save all hidden states
2. For each (layer, position) pair:
   - Corrupt the input with noise
   - Restore the clean state at that location
   - Measure how much this helps recover the correct answer
3. Return a matrix of scores for visualization

**Note**: Every (layer, position) pair is traced across all noise samples, so runtime grows with prompt length and `samples`; this can take up to a few minutes.
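
The steps above can be sketched on a toy model. This is a minimal illustration, not the implementation behind `calculate_hidden_flow`: the scalar "embeddings", the `tanh` recurrence, and the names `run` and `trace` are all assumptions made up for the example, and `noise_level` is used directly as the noise standard deviation. The toy model is causal (each state depends only on the layer below and the previous position), so restoring a state before the subject leaves the score at the corrupted baseline, while restoring the final state recovers the clean probability.

```python
import math
import random

NUM_LAYERS, SUBJECT = 3, (1, 2)            # toy sizes; subject range is inclusive
CLEAN_EMBEDS = [0.2, 1.5, -0.8, 0.4, 0.9]  # one scalar "embedding" per token


def run(embeds, patch=None):
    """Toy causal model: each state mixes the layer below and the previous
    position. `patch` maps (layer, pos) -> value that overwrites that state."""
    patch = patch or {}
    states, below = [], embeds
    for layer in range(NUM_LAYERS):
        row, prev = [], 0.0
        for pos, x in enumerate(below):
            h = math.tanh(0.9 * x + 0.6 * prev)
            h = patch.get((layer, pos), h)  # restore the clean state if requested
            row.append(h)
            prev = h
        states.append(row)
        below = row
    prob = 1.0 / (1.0 + math.exp(-3.0 * states[-1][-1]))  # stand-in for p(answer)
    return prob, states


def trace(samples=10, noise_level=3.0, seed=0):
    rng = random.Random(seed)
    # Step 1: clean run, keeping every hidden state.
    high_score, clean_states = run(CLEAN_EMBEDS)
    # Draw the noise once so every (layer, pos) sees the same corruptions.
    corrupted_inputs = []
    for _ in range(samples):
        noisy = list(CLEAN_EMBEDS)
        for pos in range(SUBJECT[0], SUBJECT[1] + 1):
            noisy[pos] += rng.gauss(0.0, noise_level)
        corrupted_inputs.append(noisy)
    low_score = sum(run(e)[0] for e in corrupted_inputs) / samples
    # Step 2: restore one clean state at a time and average over noise samples.
    scores = []
    for layer in range(NUM_LAYERS):
        row = []
        for pos in range(len(CLEAN_EMBEDS)):
            patch = {(layer, pos): clean_states[layer][pos]}
            row.append(sum(run(e, patch)[0] for e in corrupted_inputs) / samples)
        scores.append(row)
    # Step 3: a layers x positions matrix plus the two reference scores.
    return {"scores": scores, "low_score": low_score, "high_score": high_score}


toy_result = trace()
print(f"clean={toy_result['high_score']:.3f} corrupted={toy_result['low_score']:.3f}")
```

Reusing the same noise draws for every (layer, position) pair keeps the scores comparable across the matrix; drawing fresh noise per cell would add sampling variance on top of the effect being measured.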

In [None]:
# Run causal tracing
samples = 10  # Number of noise samples to average (10 is fast, 100 is more accurate)
noise_level = 3.0  # Noise scale in standard deviations (3.0 follows the ROME setup)

print(f"Running causal tracing with {samples} noise samples...")
print(f"Tracing {mt.num_layers} layers × every token position in the prompt")
print(f"Estimated time: ~{mt.num_layers * samples * 0.1:.1f} seconds")
print()

result = calculate_hidden_flow(
    mt,
    mt.tokenizer,
    prompt=prompt,
    subject=subject,
    samples=samples,
    noise_level=noise_level,
)

print("\n✓ Causal tracing complete!")
print(f"  Clean probability: {result['high_score']:.4f}")
print(f"  Corrupted probability: {result['low_score']:.4f}")
print(f"  Effect size: {result['high_score'] - result['low_score']:.4f}")
print(f"  Target token: {result['target_token']}")

## 4. Visualize Heatmap

The heatmap shows:
- **Red regions**: Restoring states here recovers the answer
- **Blue regions**: Restoring states here doesn't help

This reveals which layers and token positions are causally important for recalling the fact.
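
The colors are easiest to read once the raw probabilities are rescaled so that 0 is the corrupted baseline and 1 is the clean run. The following is a minimal pure-Python sketch of that rescaling; `normalize_scores` and its `eps` guard are assumptions for illustration, and the guard simply returns zeros when corruption had no measurable effect.

```python
def normalize_scores(scores, low_score, high_score, eps=1e-8):
    """Map restored probabilities to [0, 1]: 0 = corrupted baseline, 1 = clean run."""
    span = high_score - low_score
    if abs(span) < eps:
        # Corruption had no measurable effect, so there is nothing to recover.
        return [[0.0 for _ in row] for row in scores]
    return [[min(1.0, max(0.0, (s - low_score) / span)) for s in row] for row in scores]


demo = normalize_scores([[0.10, 0.30], [0.50, 0.60]], low_score=0.10, high_score=0.50)
print(demo)
```

Values above the clean probability are clipped to 1, so a cell that overshoots the clean run looks the same as one that exactly recovers it.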

In [None]:
def plot_causal_trace_heatmap(result, title=None):
    """
    Plot causal tracing heatmap like in the ROME paper.
    """
    scores = result['scores']
    tokens = result['input_tokens']
    low_score = result['low_score']
    high_score = result['high_score']
    subj_start, subj_end = result['subject_range']
    
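    # Normalize scores to [0, 1]: 0 = corrupted baseline, 1 = clean run
    score_range = high_score - low_score
    if abs(score_range) < 1e-8:
        raise ValueError("high_score and low_score are equal; nothing to normalize")
    normalized_scores = (scores - low_score) / score_range
    normalized_scores = np.clip(normalized_scores, 0, 1)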
    
    # Create figure
    fig, ax = plt.subplots(figsize=(14, 8))
    
    # Plot heatmap
    im = ax.imshow(
        normalized_scores,
        aspect='auto',
        cmap='RdYlBu_r',  # Red-Yellow-Blue reversed (blue=low, red=high)
        vmin=0,
        vmax=1,
        interpolation='nearest'
    )
    
    # Colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Normalized Probability', rotation=270, labelpad=20)
    
    # Axes labels
    ax.set_xlabel('Token Position', fontsize=12)
    ax.set_ylabel('Layer', fontsize=12)
    
    # Set token labels on x-axis
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    
    # Highlight subject tokens
    ax.axvline(subj_start - 0.5, color='black', linestyle='--', alpha=0.5, linewidth=1)
    ax.axvline(subj_end + 0.5, color='black', linestyle='--', alpha=0.5, linewidth=1)
    
    # Add text box with subject
    ax.text(
        (subj_start + subj_end) / 2,
        -1,
        'Subject',
        ha='center',
        va='top',
        fontsize=10,
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    )
    
    # Title
    if title is None:
        title = f"Causal Tracing: {result.get('prompt', prompt)}"
    ax.set_title(title, fontsize=14, pad=20)
    
    # Grid
    ax.set_yticks(range(0, scores.shape[0], 4))
    ax.grid(True, alpha=0.3, linewidth=0.5)
    
    plt.tight_layout()
    return fig

# Plot the heatmap
fig = plot_causal_trace_heatmap(result)
plt.show()

# Find and print most important positions
scores = result['scores']
max_layer, max_pos = np.unravel_index(scores.argmax(), scores.shape)
print("\nMost important restoration:")
print(f"  Layer {max_layer}, Position {max_pos}")
print(f"  Token: '{result['input_tokens'][max_pos]}'")
print(f"  Score: {scores[max_layer, max_pos]:.4f}")

## 5. Analyze Results

Let's examine which layers are most important for factual recall.

In [None]:
# Average effect by layer
scores = result['scores']
layer_avg = scores.mean(axis=1)

plt.figure(figsize=(10, 5))
plt.plot(layer_avg, marker='o')
plt.xlabel('Layer')
plt.ylabel('Average Restoration Effect')
plt.title('Average Effect of Restoring Each Layer')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Top 5 most important layers (by average effect):")
top_layers = np.argsort(layer_avg)[-5:][::-1]
for layer in top_layers:
    print(f"  Layer {layer:2d}: {layer_avg[layer]:.4f}")

## 6. Subject Token Analysis

How important are the subject tokens specifically?
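
As a small worked example of the comparison, the sketch below averages a made-up score matrix over the subject columns and over all columns. The numbers and the helper names `mean` and `subject_vs_all` are invented for illustration; the subject range uses an inclusive end index, as in the rest of this notebook.

```python
def mean(values):
    return sum(values) / len(values)


def subject_vs_all(scores, subject_range):
    """Per-layer mean score over subject columns and over all columns.
    subject_range is (start, end) with an inclusive end, as in this notebook."""
    start, end = subject_range
    subject_avg = [mean(row[start:end + 1]) for row in scores]
    all_avg = [mean(row) for row in scores]
    return subject_avg, all_avg


# Made-up scores: 2 layers x 4 token positions, subject at positions 1-2.
toy_scores = [
    [0.1, 0.7, 0.5, 0.1],
    [0.1, 0.3, 0.3, 0.1],
]
subject_avg, all_avg = subject_vs_all(toy_scores, (1, 2))
ratio = mean(subject_avg) / mean(all_avg)
print(subject_avg, all_avg, round(ratio, 2))
```

A ratio above 1 means restoring subject positions helps more, on average, than restoring an arbitrary position.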

In [None]:
subj_start, subj_end = result['subject_range']
subject_scores = scores[:, subj_start:subj_end+1]

print(f"Subject tokens: {result['input_tokens'][subj_start:subj_end+1]}")
print(f"Subject token range: positions {subj_start} to {subj_end}")
print()

# Average effect across subject positions by layer
subject_layer_avg = subject_scores.mean(axis=1)

plt.figure(figsize=(10, 5))
plt.plot(subject_layer_avg, marker='o', label='Subject tokens')
plt.plot(layer_avg, marker='s', alpha=0.5, label='All tokens')
plt.xlabel('Layer')
plt.ylabel('Average Restoration Effect')
plt.title('Effect of Restoring Subject vs All Tokens')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Overall average effect (all tokens): {layer_avg.mean():.4f}")
print(f"Average effect (subject tokens): {subject_layer_avg.mean():.4f}")
print(f"Ratio (subject / all): {subject_layer_avg.mean() / layer_avg.mean():.2f}x")

## 7. Save Results (Optional)

Save the causal tracing results for later analysis or comparison.

In [None]:
# Uncomment to save
# output_path = f"../results/{prompt.replace(' ', '_')[:30]}_trace.npz"
# Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# np.savez(
#     output_path,
#     scores=result['scores'],
#     low_score=result['low_score'],
#     high_score=result['high_score'],
#     input_tokens=result['input_tokens'],
#     subject_range=result['subject_range'],
#     target_token=result['target_token'],
#     prompt=prompt,
#     subject=subject,
# )
# print(f"Results saved to {output_path}")

## Next Steps

Try different prompts and analyze the patterns:
- Do certain layers consistently encode factual knowledge?
- How does the pattern differ for different types of facts?
- Compare with GPT-2 results from the ROME paper

For Phase 3, we'll dive deeper into tracing the internal SSM state (h_t) and selection parameters (B, C, Δt).