# SAE Feature Intervention Experiment

This notebook implements feature interventions using TransformerLens and SAE features identified in the basketball/baseball analysis. We modify activations at layer 16 to suppress basketball features and amplify baseball features, then evaluate the impact on token predictions.

## 0. Fix Dependencies (if needed)

If you encounter a scipy/numpy compatibility error when importing TransformerLens, run the cell below to fix it.

In [1]:
# Run this cell if you get scipy/numpy import errors
# This will upgrade scipy and numpy to compatible versions

import subprocess
import sys

print("Upgrading scipy and numpy to fix compatibility issues...")
try:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "scipy>=1.11.0", "numpy>=1.24.0"], 
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    print("✓ Dependencies upgraded successfully!")
    print("\n" + "="*80)
    print("IMPORTANT: You must restart the kernel for changes to take effect!")
    print("="*80)
    print("\nSteps:")
    print("1. Go to: Kernel → Restart Kernel (or press Ctrl+M, then 0, then 0)")
    print("2. After restart, run all cells from the beginning")
    print("="*80)
except subprocess.CalledProcessError as e:
    print(f"Error upgrading dependencies: {e}")
    print("\nTry running this command manually in your terminal:")
    print("  pip install --upgrade 'scipy>=1.11.0' 'numpy>=1.24.0'")

Upgrading scipy and numpy to fix compatibility issues...


✓ Dependencies upgraded successfully!

IMPORTANT: You must restart the kernel for changes to take effect!

Steps:
1. Go to: Kernel → Restart Kernel (or press Ctrl+M, then 0, then 0)
2. After restart, run all cells from the beginning


## 1. Setup and Configuration

In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import numpy as np
from typing import Set, Dict, List, Tuple
from dataclasses import dataclass
from IPython.display import HTML, display
import pandas as pd

# Workaround for scipy/numpy compatibility issue
# This patches the scipy issue before importing transformer_lens
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Try to patch the scipy issue if it exists
try:
    import scipy.special._multiufuncs
    # The issue is in MultiUFunc initialization - try to bypass it
    # by ensuring numpy is properly initialized first
    _ = np.array([1, 2, 3])  # Force numpy initialization
except Exception as scipy_init_err:
    # If scipy fails to initialize, we'll handle it below
    pass

# Import TransformerLens with error handling
try:
    from transformer_lens import HookedTransformer
    print("✓ Successfully imported HookedTransformer")
except (ValueError, ImportError) as e:
    error_msg = str(e)
    if "numpy.ufunc" in error_msg or "scipy" in error_msg.lower():
        print("="*80)
        print("DEPENDENCY ISSUE DETECTED: scipy/numpy version incompatibility")
        print("="*80)
        print("\nTo fix this, run one of the following commands in your terminal:")
        print("\n  Option 1 (pip):")
        print("    pip install --upgrade 'scipy>=1.11.0' 'numpy>=1.24.0'")
        print("\n  Option 2 (conda):")
        print("    conda update scipy numpy")
        print("\n  Option 3 (reinstall compatible versions):")
        print("    pip uninstall scipy numpy -y")
        print("    pip install 'scipy>=1.11.0' 'numpy>=1.24.0'")
        print("\nAfter fixing, restart your kernel and run this cell again.")
        print("="*80)
        raise ImportError(
            "TransformerLens import failed due to scipy/numpy incompatibility. "
            "Please fix dependencies as shown above and restart the kernel."
        ) from e
    else:
        raise

from globals import LLAMA_3_1_8B
from crisp import LayerFeatures
from sae import TopkSae
from utils import load_cached_features
from data import prepare_text
from datasets import load_dataset

# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Adjust as needed

# Configuration from basketball_baseball_analysis.ipynb
MODEL_CARD = LLAMA_3_1_8B
LAYER_TO_ANALYZE = 16
SAE_SAVE_PATH = "llama_sae_cache"

# Intervention configuration (configurable)
INTERVENTION_FACTOR_BASKETBALL = -5.0  # Complete suppression (0.0) or partial (e.g., 0.5)
INTERVENTION_FACTOR_BASEBALL = 5.0    # Amplification factor (2.0 = 2x amplification)

# Feature threshold (same as used in topk_filtered)
FEATURE_THRESHOLD = 3.0

print(f"Using model: {MODEL_CARD}")
print(f"Operating on layer: {LAYER_TO_ANALYZE}")
print(f"Basketball intervention factor: {INTERVENTION_FACTOR_BASKETBALL}")
print(f"Baseball intervention factor: {INTERVENTION_FACTOR_BASEBALL}")

  from .autonotebook import tqdm as notebook_tqdm


✓ Successfully imported HookedTransformer
Using model: meta-llama/Llama-3.1-8B
Operating on layer: 16
Basketball intervention factor: -5.0
Baseball intervention factor: 5.0


## 2. Load Cached Features

In [3]:
# ============================================================================
# PARAMETER: Choose feature source
# ============================================================================
# Options:
#   "wiki" - Load from basketball vs wiki and baseball vs wiki comparisons
#   "baseball" - Load from basketball vs baseball comparison
# ============================================================================
FEATURE_SOURCE = "baseball"  # Change to "baseball" to use basketball vs baseball comparison

# Create data config classes (same as in basketball_baseball_analysis.ipynb)
@dataclass
class BasketballBaseballDataConfig:
    max_length: int = 1000
    min_length: int = 100
    n_examples: int = 250
    data_type: str = "basketball_baseball"
    forget_type: str = "basketball"
    retain_type: str = "wiki"
    
    def to_dict(self):
        return {
            "max_length": self.max_length,
            "min_length": self.min_length,
            "n_examples": self.n_examples,
            "data_type": self.data_type,
            "forget_type": self.forget_type,
            "retain_type": self.retain_type
        }

if FEATURE_SOURCE == "wiki":
    # Load basketball vs wiki features
    print("Loading features from Basketball vs Wiki and Baseball vs Wiki comparisons...")
    data_config_basketball_wiki = BasketballBaseballDataConfig(
        n_examples=250,
        data_type="basketball_wiki",
        forget_type="basketball",
        retain_type="wiki"
    )

    basketball_wiki_features = load_cached_features(
        LAYER_TO_ANALYZE,
        data_config_basketball_wiki,
        model_name=MODEL_CARD
    )

    if basketball_wiki_features is None:
        raise ValueError("Basketball vs wiki features not found. Please run basketball_baseball_analysis.ipynb first.")

    basketball_layer_features = LayerFeatures(list(basketball_wiki_features))
    print(f"Loaded {len(basketball_layer_features)} features for Basketball vs Wikipedia")

    # Load baseball vs wiki features
    data_config_baseball_wiki = BasketballBaseballDataConfig(
        n_examples=250,
        data_type="baseball_wiki",
        forget_type="baseball",
        retain_type="wiki"
    )

    baseball_wiki_features = load_cached_features(
        LAYER_TO_ANALYZE,
        data_config_baseball_wiki,
        model_name=MODEL_CARD
    )

    if baseball_wiki_features is None:
        raise ValueError("Baseball vs wiki features not found. Please run basketball_baseball_analysis.ipynb first.")

    baseball_layer_features = LayerFeatures(list(baseball_wiki_features))
    print(f"Loaded {len(baseball_layer_features)} features for Baseball vs Wikipedia")

    # Extract significant feature indices
    # Basketball features: high target_acts_relative (basketball vs wiki)
    # Baseball features: high benign_acts_relative (baseball vs wiki)
    basketball_feature_indices: Set[int] = set()
    for feature in basketball_layer_features.features.values():
        if feature.target_acts_relative >= FEATURE_THRESHOLD:
            basketball_feature_indices.add(feature.index)

    baseball_feature_indices: Set[int] = set()
    for feature in baseball_layer_features.features.values():
        if feature.benign_acts_relative >= FEATURE_THRESHOLD:
            baseball_feature_indices.add(feature.index)

elif FEATURE_SOURCE == "baseball":
    # Load basketball vs baseball features
    print("Loading features from Basketball vs Baseball comparison...")
    data_config_basketball_baseball = BasketballBaseballDataConfig(
        n_examples=250,
        data_type="basketball_baseball",
        forget_type="basketball",
        retain_type="baseball"
    )

    basketball_baseball_features = load_cached_features(
        LAYER_TO_ANALYZE,
        data_config_basketball_baseball,
        model_name=MODEL_CARD
    )

    if basketball_baseball_features is None:
        raise ValueError("Basketball vs baseball features not found. Please run basketball_baseball_analysis.ipynb first.")

    basketball_baseball_layer_features = LayerFeatures(list(basketball_baseball_features))
    print(f"Loaded {len(basketball_baseball_layer_features)} features for Basketball vs Baseball")

    # Extract significant feature indices from basketball vs baseball comparison
    # Basketball features: high target_acts_relative (basketball is target)
    # Baseball features: high benign_acts_relative (baseball is benign/retain)
    basketball_feature_indices: Set[int] = set()
    for feature in basketball_baseball_layer_features.features.values():
        if feature.target_acts_relative >= FEATURE_THRESHOLD:
            basketball_feature_indices.add(feature.index)

    baseball_feature_indices: Set[int] = set()
    for feature in basketball_baseball_layer_features.features.values():
        if feature.benign_acts_relative >= FEATURE_THRESHOLD:
            baseball_feature_indices.add(feature.index)

else:
    raise ValueError(f"Invalid FEATURE_SOURCE: {FEATURE_SOURCE}. Must be 'wiki' or 'baseball'.")

print(f"\nFound {len(basketball_feature_indices)} significant basketball features (target_acts_relative >= {FEATURE_THRESHOLD})")
print(f"Found {len(baseball_feature_indices)} significant baseball features (benign_acts_relative >= {FEATURE_THRESHOLD})")
print(f"Overlap: {len(basketball_feature_indices & baseball_feature_indices)} features")
print(f"\nFeature source: {FEATURE_SOURCE}")

Loading features from Basketball vs Baseball comparison...
Loaded 19122 features for Basketball vs Baseball

Found 2000 significant basketball features (target_acts_relative >= 3.0)
Found 1819 significant baseball features (benign_acts_relative >= 3.0)
Overlap: 0 features

Feature source: baseball


## 3. Load SAE and Model

In [4]:
# Load SAE for layer 16
print("Loading SAE for layer 16...")
layer_path = os.path.join(SAE_SAVE_PATH, f"layer_{LAYER_TO_ANALYZE}")
if not os.path.exists(layer_path):
    raise FileNotFoundError(f"SAE not found at {layer_path}. Please run basketball_baseball_analysis.ipynb first.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sae = TopkSae.load_from_disk(layer_path, device=device, decoder=True)
print(f"SAE loaded. Input dim: {sae.d_in}, SAE dim: {sae.d_sae}, Top-k: {sae.cfg.k}")

# Load model using TransformerLens
print("Loading model with TransformerLens...")
model = HookedTransformer.from_pretrained(
    MODEL_CARD,
    device=device,
    dtype=torch.bfloat16,
    trust_remote_code=True
)
print(f"Model loaded. Hidden size: {model.cfg.d_model}")

# Verify dimensions match
assert sae.d_in == model.cfg.d_model, f"SAE input dim ({sae.d_in}) doesn't match model hidden size ({model.cfg.d_model})"
print("✓ Dimensions verified")

Loading SAE for layer 16...


SAE loaded. Input dim: 4096, SAE dim: 32768, Top-k: 50
Loading model with TransformerLens...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 106.52it/s]


Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer
Model loaded. Hidden size: 4096
✓ Dimensions verified


## 4. Implement Intervention Hook Function

In [5]:
def create_intervention_hook(
    sae: TopkSae,
    basketball_indices: Set[int],
    baseball_indices: Set[int],
    basketball_factor: float,
    baseball_factor: float
):
    """
    Create a hook function that modifies hidden states by intervening on SAE features.
    
    Args:
        sae: The SAE model
        basketball_indices: Set of basketball feature indices to suppress
        baseball_indices: Set of baseball feature indices to amplify
        basketball_factor: Multiplier for basketball features (0.0 = suppress, 1.0 = no change)
        baseball_factor: Multiplier for baseball features (1.0 = no change, >1.0 = amplify)
    
    Returns:
        Hook function that can be used with TransformerLens
    """
    def intervention_hook(hidden_states, hook):
        """
        Hook function that modifies hidden states at layer 16.
        
        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_dim)
            hook: TransformerLens hook object
        
        Returns:
            Modified hidden states
        """
        # Convert sets to tensors for efficient vectorized operations (inside hook for device access)
        basketball_tensor = torch.tensor(list(basketball_indices), device=hidden_states.device, dtype=torch.long)
        baseball_tensor = torch.tensor(list(baseball_indices), device=hidden_states.device, dtype=torch.long)
        batch_size, seq_len, hidden_dim = hidden_states.shape
        
        # Reshape for SAE processing: (batch * seq_len, hidden_dim)
        hidden_flat = hidden_states.reshape(-1, hidden_dim)
        
        # Encode through SAE to get feature activations
        # Get pre-activations and select top-k
        pre_acts = sae.encode_pre_relu(hidden_flat)  # (batch * seq_len, d_sae)
        acts = torch.relu(pre_acts)  # (batch * seq_len, d_sae)
        
        # Get top-k activations and indices
        topk_acts, topk_indices = acts.topk(sae.cfg.k, dim=-1, sorted=False)
        # topk_acts: (batch * seq_len, k)
        # topk_indices: (batch * seq_len, k)
        
        # Create intervention mask using vectorized operations
        # Shape: (batch * seq_len, k)
        batch_seq_len = topk_indices.shape[0]
        
        # Create masks for basketball and baseball features
        # Expand indices for comparison: (batch * seq_len, k, 1) vs (1, len(basketball_indices))
        topk_indices_expanded = topk_indices.unsqueeze(-1)  # (batch * seq_len, k, 1)
        basketball_mask = (topk_indices_expanded == basketball_tensor.unsqueeze(0).unsqueeze(0)).any(dim=-1)  # (batch * seq_len, k)
        baseball_mask = (topk_indices_expanded == baseball_tensor.unsqueeze(0).unsqueeze(0)).any(dim=-1)  # (batch * seq_len, k)
        
        # Apply intervention factors
        intervention_factors = torch.ones_like(topk_acts)
        intervention_factors[basketball_mask] = basketball_factor
        intervention_factors[baseball_mask] = baseball_factor
        
        # Apply interventions
        modified_topk_acts = topk_acts * intervention_factors
        
        # Create sparse activation tensor and scatter modified activations
        sparse_acts = torch.zeros_like(acts)
        sparse_acts.scatter_(-1, topk_indices, modified_topk_acts)
        
        # Decode modified activations back to hidden state space
        modified_hidden_flat = sae.decode(sparse_acts)  # (batch * seq_len, hidden_dim)
        
        # Reshape back to original shape
        modified_hidden = modified_hidden_flat.reshape(batch_size, seq_len, hidden_dim)
        
        return modified_hidden
    
    return intervention_hook

# Create the intervention hook
intervention_hook = create_intervention_hook(
    sae=sae,
    basketball_indices=basketball_feature_indices,
    baseball_indices=baseball_feature_indices,
    basketball_factor=INTERVENTION_FACTOR_BASKETBALL,
    baseball_factor=INTERVENTION_FACTOR_BASEBALL
)

print("Intervention hook created successfully")

Intervention hook created successfully


## 5. Create Intervention Utilities

In [6]:
def get_top_tokens(logits: torch.Tensor, tokenizer, k: int = 5) -> List[Tuple[int, str, float]]:
    """
    Get top k tokens from logits.
    
    Args:
        logits: Tensor of shape (batch, seq_len, vocab_size) or (vocab_size,)
        tokenizer: Tokenizer to decode tokens
        k: Number of top tokens to return
    
    Returns:
        List of (token_id, token_text, probability) tuples
    """
    # Handle different input shapes
    if logits.dim() == 3:
        # Take the last token's logits
        logits = logits[0, -1, :]
    elif logits.dim() == 2:
        # Take the last token's logits
        logits = logits[0, -1, :]
    
    # Get probabilities
    probs = torch.softmax(logits, dim=-1)
    
    # Get top k
    top_probs, top_indices = probs.topk(k, dim=-1)
    
    # Decode tokens and return with IDs
    results = []
    for prob, idx in zip(top_probs, top_indices):
        token_id = idx.item()
        token_text = tokenizer.decode([token_id])
        results.append((token_id, token_text, prob.item()))
    
    return results

def run_model_with_intervention(
    model: HookedTransformer,
    prompt: str,
    intervention_hook,
    layer: int
) -> torch.Tensor:
    """
    Run model with intervention hook.
    
    Args:
        model: HookedTransformer model
        prompt: Input prompt
        intervention_hook: Hook function to apply
        layer: Layer to hook into
    
    Returns:
        Logits tensor
    """
    hook_name = f"blocks.{layer}.hook_resid_post"
    
    with model.hooks([(hook_name, intervention_hook)]):
        logits = model(prompt, return_type="logits")
    
    return logits

def run_model_without_intervention(
    model: HookedTransformer,
    prompt: str
) -> torch.Tensor:
    """
    Run model without intervention.
    
    Args:
        model: HookedTransformer model
        prompt: Input prompt
    
    Returns:
        Logits tensor
    """
    logits = model(prompt, return_type="logits")
    return logits

def compare_predictions(
    prompt: str,
    model: HookedTransformer,
    intervention_hook,
    layer: int,
    k: int = 5
) -> Dict:
    """
    Compare predictions with and without intervention.
    
    Args:
        prompt: Input prompt
        model: HookedTransformer model
        intervention_hook: Hook function
        layer: Layer to hook into
        k: Number of top tokens to return
    
    Returns:
        Dictionary with original and modified predictions
    """
    # Run without intervention
    logits_original = run_model_without_intervention(model, prompt)
    top_original = get_top_tokens(logits_original, model.tokenizer, k)
    
    # Run with intervention
    logits_modified = run_model_with_intervention(model, prompt, intervention_hook, layer)
    top_modified = get_top_tokens(logits_modified, model.tokenizer, k)
    
    return {
        "prompt": prompt,
        "original": top_original,
        "modified": top_modified,
        "logits_original": logits_original,
        "logits_modified": logits_modified
    }

print("Utility functions created")

Utility functions created


In [11]:
# Prompt 1: Michael Jordan prompt
prompt1 = "Michael Jordan plays the sports of"

# Prompt 2: Load a sample from Wikipedia retain examples
print("Loading Wikipedia sample for prompt 2...")
wiki_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
wiki_texts = prepare_text(wiki_data['text'], max_len=1000)

# Take a sample that's not related to sports
# Filter for texts that don't contain basketball/baseball keywords
sports_keywords = ["basketball", "baseball", "sport", "game", "player", "team", "coach"]
wiki_samples = [text for text in wiki_texts[:100] 
                if not any(keyword in text.lower() for keyword in sports_keywords)]

if len(wiki_samples) > 0:
    prompt2 = wiki_samples[0][:200]  # Take first 200 chars
    # Ensure it ends at a word boundary
    if len(prompt2) < len(wiki_samples[0]):
        last_space = prompt2.rfind(' ')
        if last_space > 0:
            prompt2 = prompt2[:last_space]
else:
    # Fallback to a generic prompt
    prompt2 = "The history of science involves many important discoveries and developments."

print(f"Prompt 1: {prompt1}")
print(f"\nPrompt 2: {prompt2}")

Loading Wikipedia sample for prompt 2...
Prompt 1: Michael Jordan plays the sports of

Prompt 2: As the Nameless officially do not exist , the upper echelons of the Gallian Army exploit the concept of plausible deniability in order to send them on missions that would otherwise make Gallia lose


## 7. Run Experiments

In [None]:
# Run experiment for Prompt 1
print("="*80)
print("EXPERIMENT 1: Michael Jordan Prompt")
print("="*80)
k = 10
results1 = compare_predictions(
    prompt1,
    model,
    intervention_hook,
    LAYER_TO_ANALYZE,
    k=k
)

print(f"\nPrompt: {results1['prompt']}")
print("\nTop 5 tokens WITHOUT intervention:")
for i, token_data in enumerate(results1['original'], 1):
    # Handle both old format (2 values) and new format (3 values)
    if len(token_data) == 3:
        token_id, token_text, prob = token_data
    else:
        # Old format - decode token ID from text if possible
        token_text, prob = token_data
        # Try to get token ID by encoding the text
        try:
            token_id = model.tokenizer.encode(token_text, add_special_tokens=False)[0]
        except:
            token_id = "N/A"
    print(f"  {i}. ID: {token_id:<6} Text: {repr(token_text):<30} (prob: {prob:.4f})")

print(f"\nTop {k} tokens WITH intervention:")
for i, token_data in enumerate(results1['modified'], 1):
    # Handle both old format (2 values) and new format (3 values)
    if len(token_data) == 3:
        token_id, token_text, prob = token_data
    else:
        # Old format - decode token ID from text if possible
        token_text, prob = token_data
        # Try to get token ID by encoding the text
        try:
            token_id = model.tokenizer.encode(token_text, add_special_tokens=False)[0]
        except:
            token_id = "N/A"
    print(f"  {i}. ID: {token_id:<6} Text: {repr(token_text):<30} (prob: {prob:.4f})")

EXPERIMENT 1: Michael Jordan Prompt

Prompt: Michael Jordan plays the sports of

Top 5 tokens WITHOUT intervention:
  1. ID: 19794  Text: ' basketball'                  (prob: 0.4590)
  2. ID: 19665  Text: ' golf'                        (prob: 0.1689)
  3. ID: 813    Text: ' his'                         (prob: 0.1157)
  4. ID: 2324   Text: ' life'                        (prob: 0.0189)
  5. ID: 279    Text: ' the'                         (prob: 0.0178)
  6. ID: 20075  Text: ' baseball'                    (prob: 0.0178)
  7. ID: 32515  Text: ' tennis'                      (prob: 0.0167)
  8. ID: 47589  Text: ' Basketball'                  (prob: 0.0122)
  9. ID: 9141   Text: ' football'                    (prob: 0.0051)
  10. ID: 28131  Text: ' Golf'                        (prob: 0.0051)


NameError: name 'k' is not defined

In [13]:
# Run experiment for Prompt 2
print("="*80)
print("EXPERIMENT 2: General Wikipedia Prompt")
print("="*80)
results2 = compare_predictions(
    prompt2,
    model,
    intervention_hook,
    LAYER_TO_ANALYZE,
    k=5
)

print(f"\nPrompt: {results2['prompt']}")
print("\nTop 5 tokens WITHOUT intervention:")
for i, token_data in enumerate(results2['original'], 1):
    # Handle both old format (2 values) and new format (3 values)
    if len(token_data) == 3:
        token_id, token_text, prob = token_data
    else:
        # Old format - decode token ID from text if possible
        token_text, prob = token_data
        # Try to get token ID by encoding the text
        try:
            token_id = model.tokenizer.encode(token_text, add_special_tokens=False)[0]
        except:
            token_id = "N/A"
    print(f"  {i}. ID: {token_id:<6} Text: {repr(token_text):<30} (prob: {prob:.4f})")

print("\nTop 5 tokens WITH intervention:")
for i, token_data in enumerate(results2['modified'], 1):
    # Handle both old format (2 values) and new format (3 values)
    if len(token_data) == 3:
        token_id, token_text, prob = token_data
    else:
        # Old format - decode token ID from text if possible
        token_text, prob = token_data
        # Try to get token ID by encoding the text
        try:
            token_id = model.tokenizer.encode(token_text, add_special_tokens=False)[0]
        except:
            token_id = "N/A"
    print(f"  {i}. ID: {token_id:<6} Text: {repr(token_text):<30} (prob: {prob:.4f})")

EXPERIMENT 2: General Wikipedia Prompt



Prompt: As the Nameless officially do not exist , the upper echelons of the Gallian Army exploit the concept of plausible deniability in order to send them on missions that would otherwise make Gallia lose

Top 5 tokens WITHOUT intervention:
  1. ID: 3663   Text: ' face'                        (prob: 0.6641)
  2. ID: 1202   Text: ' its'                         (prob: 0.0698)
  3. ID: 279    Text: ' the'                         (prob: 0.0330)
  4. ID: 38769  Text: ' credibility'                 (prob: 0.0200)
  5. ID: 6625   Text: ' international'               (prob: 0.0177)

Top 5 tokens WITH intervention:
  1. ID: 1202   Text: ' its'                         (prob: 0.6133)
  2. ID: 304    Text: ' in'                          (prob: 0.0732)
  3. ID: 4033   Text: ' official'                    (prob: 0.0270)
  4. ID: 477    Text: ' or'                          (prob: 0.0270)
  5. ID: 389    Text: ' on'                          (prob: 0.0270)


## 8. Analysis and Visualization

In [14]:
# Create comparison tables
def create_comparison_table(results: Dict, title: str, tokenizer):
    """Create a formatted comparison table."""
    # Normalize data format - handle both old (2-tuple) and new (3-tuple) formats
    def normalize_token_data(token_data):
        if len(token_data) == 3:
            return token_data  # (token_id, token_text, prob)
        else:
            # Old format: (token_text, prob) - need to get token_id
            token_text, prob = token_data
            try:
                token_id = tokenizer.encode(token_text, add_special_tokens=False)[0]
            except:
                token_id = "N/A"
            return (token_id, token_text, prob)
    
    original_normalized = [normalize_token_data(t) for t in results['original']]
    modified_normalized = [normalize_token_data(t) for t in results['modified']]
    
    df_data = {
        "Rank": list(range(1, len(original_normalized) + 1)),
        "Original ID": [t[0] for t in original_normalized],
        "Original Text": [repr(t[1]) for t in original_normalized],
        "Original Prob": [f"{t[2]:.4f}" for t in original_normalized],
        "Modified ID": [t[0] for t in modified_normalized],
        "Modified Text": [repr(t[1]) for t in modified_normalized],
        "Modified Prob": [f"{t[2]:.4f}" for t in modified_normalized],
    }
    
    df = pd.DataFrame(df_data)
    
    print(f"\n{title}")
    print("="*120)
    display(df)
    
    # Calculate changes
    print("\nChanges:")
    original_tokens = {(t[0], repr(t[1])): (i, t[2]) for i, t in enumerate(original_normalized)}
    modified_tokens = {(t[0], repr(t[1])): (i, t[2]) for i, t in enumerate(modified_normalized)}
    
    # Find tokens that changed rank (by token ID)
    for (token_id, token_text), (orig_rank, orig_prob) in original_tokens.items():
        if (token_id, token_text) in modified_tokens:
            mod_rank, mod_prob = modified_tokens[(token_id, token_text)]
            if orig_rank != mod_rank:
                rank_change = mod_rank - orig_rank
                prob_change = mod_prob - orig_prob
                print(f"  ID {token_id} ({token_text}): Rank {orig_rank+1} → {mod_rank+1} ({rank_change:+d}), Prob {orig_prob:.4f} → {mod_prob:.4f} ({prob_change:+.4f})")
        else:
            print(f"  ID {token_id} ({token_text}): Dropped from top 5 (was rank {orig_rank+1})")
    
    # Find new tokens
    for (token_id, token_text), (mod_rank, mod_prob) in modified_tokens.items():
        if (token_id, token_text) not in original_tokens:
            print(f"  ID {token_id} ({token_text}): New entry at rank {mod_rank+1} (prob: {mod_prob:.4f})")

# Create tables for both experiments
create_comparison_table(results1, "COMPARISON TABLE: Michael Jordan Prompt", model.tokenizer)
create_comparison_table(results2, "COMPARISON TABLE: General Wikipedia Prompt", model.tokenizer)


COMPARISON TABLE: Michael Jordan Prompt


Unnamed: 0,Rank,Original ID,Original Text,Original Prob,Modified ID,Modified Text,Modified Prob
0,1,19794,' basketball',0.459,10034,' sports',0.1689
1,2,19665,' golf',0.1689,9522,'...\n',0.1025
2,3,813,' his',0.1157,9141,' football',0.0903
3,4,2324,' life',0.0189,19794,' basketball',0.0703
4,5,279,' the',0.0178,10775,' sport',0.0549
5,6,20075,' baseball',0.0178,198,'\n',0.0276
6,7,32515,' tennis',0.0167,1131,'...',0.0228
7,8,47589,' Basketball',0.0122,1847,' game',0.0201
8,9,9141,' football',0.0051,1972,' real',0.0201
9,10,28131,' Golf',0.0051,279,' the',0.0178



Changes:
  ID 19794 (' basketball'): Rank 1 → 4 (+3), Prob 0.4590 → 0.0703 (-0.3887)
  ID 19665 (' golf'): Dropped from top 5 (was rank 2)
  ID 813 (' his'): Dropped from top 5 (was rank 3)
  ID 2324 (' life'): Dropped from top 5 (was rank 4)
  ID 279 (' the'): Rank 5 → 10 (+5), Prob 0.0178 → 0.0178 (+0.0000)
  ID 20075 (' baseball'): Dropped from top 5 (was rank 6)
  ID 32515 (' tennis'): Dropped from top 5 (was rank 7)
  ID 47589 (' Basketball'): Dropped from top 5 (was rank 8)
  ID 9141 (' football'): Rank 9 → 3 (-6), Prob 0.0051 → 0.0903 (+0.0852)
  ID 28131 (' Golf'): Dropped from top 5 (was rank 10)
  ID 10034 (' sports'): New entry at rank 1 (prob: 0.1689)
  ID 9522 ('...\n'): New entry at rank 2 (prob: 0.1025)
  ID 10775 (' sport'): New entry at rank 5 (prob: 0.0549)
  ID 198 ('\n'): New entry at rank 6 (prob: 0.0276)
  ID 1131 ('...'): New entry at rank 7 (prob: 0.0228)
  ID 1847 (' game'): New entry at rank 8 (prob: 0.0201)
  ID 1972 (' real'): New entry at rank 9 (prob: 0.0

Unnamed: 0,Rank,Original ID,Original Text,Original Prob,Modified ID,Modified Text,Modified Prob
0,1,3663,' face',0.6641,1202,' its',0.6133
1,2,1202,' its',0.0698,304,' in',0.0732
2,3,279,' the',0.033,4033,' official',0.027
3,4,38769,' credibility',0.02,477,' or',0.027
4,5,6625,' international',0.0177,389,' on',0.027



Changes:
  ID 3663 (' face'): Dropped from top 5 (was rank 1)
  ID 1202 (' its'): Rank 2 → 1 (-1), Prob 0.0698 → 0.6133 (+0.5435)
  ID 279 (' the'): Dropped from top 5 (was rank 3)
  ID 38769 (' credibility'): Dropped from top 5 (was rank 4)
  ID 6625 (' international'): Dropped from top 5 (was rank 5)
  ID 304 (' in'): New entry at rank 2 (prob: 0.0732)
  ID 4033 (' official'): New entry at rank 3 (prob: 0.0270)
  ID 477 (' or'): New entry at rank 4 (prob: 0.0270)
  ID 389 (' on'): New entry at rank 5 (prob: 0.0270)
