# Manual Activation Patching Experiments

In [None]:
import sys
import os
import torch
import json
import random
from IPython.display import display, HTML

# Add TransformerLens to path
sys.path.append('/Users/ivanculo/Desktop/Projects/turn_point/third_party/TransformerLens')

# Import our activation patcher
sys.path.append('/Users/ivanculo/Desktop/Projects/turn_point/manual_activation_patching')
from activation_patcher import ActivationPatcher
from interpretation_templates import INTERPRETATION_TEMPLATES

# Data loading utilities
def load_cognitive_patterns(dataset_path="/Users/ivanculo/Desktop/Projects/turn_point/data/final/positive_patterns.jsonl"):
    """Load the cognitive patterns dataset from a JSON Lines file.

    Every non-blank line of the file is parsed as one JSON object (a
    "pattern"). Blank lines, such as a trailing empty line at the end of the
    file, are skipped.

    Args:
        dataset_path: Path to the ``.jsonl`` dataset file.

    Returns:
        A ``(patterns, pattern_types)`` tuple where ``patterns`` is the list
        of all parsed records in file order and ``pattern_types`` maps each
        ``'cognitive_pattern_type'`` value to the list of records of that type.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``'cognitive_pattern_type'`` field.
    """
    patterns = []
    pattern_types = {}

    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # json.loads("") raises, so tolerate blank/trailing lines.
                continue

            pattern = json.loads(line)
            patterns.append(pattern)

            # Group by cognitive pattern type
            pattern_types.setdefault(pattern['cognitive_pattern_type'], []).append(pattern)

    return patterns, pattern_types

def get_pattern_by_index(patterns, index):
    """Return ``patterns[index]`` after validating ``0 <= index < len(patterns)``.

    Negative indices are rejected rather than counted from the end.

    Raises:
        IndexError: If ``index`` lies outside the valid range.
    """
    in_range = 0 <= index < len(patterns)
    if not in_range:
        raise IndexError(f"Index {index} out of range. Dataset has {len(patterns)} patterns.")
    return patterns[index]

def get_pattern_by_type(pattern_types, pattern_type):
    """Return the list of patterns stored under ``pattern_type``.

    Args:
        pattern_types: Mapping from cognitive pattern type to its patterns.
        pattern_type: The type to look up.

    Raises:
        KeyError: If the type is unknown; the message lists the known types.
    """
    if pattern_type not in pattern_types:
        known_types = list(pattern_types.keys())
        raise KeyError(f"Pattern type '{pattern_type}' not found. Available types: {known_types}")
    return pattern_types[pattern_type]

def get_pattern_text(pattern, text_type="positive"):
    """
    Return one of the text variants stored in a pattern record.

    Args:
        pattern: The pattern dictionary
        text_type: Which variant to return: "positive", "negative", or "transition"

    Returns:
        The value stored under the field that corresponds to ``text_type``

    Raises:
        ValueError: If ``text_type`` is not one of the supported names
        KeyError: If the pattern lacks the corresponding field
    """
    # Variant name -> record field holding that variant's text.
    text_map = dict(
        positive="positive_thought_pattern",
        negative="reference_negative_example",
        transition="reference_transformed_example",
    )

    field_name = text_map.get(text_type)
    if field_name is None:
        raise ValueError(f"text_type must be one of: {list(text_map.keys())}")

    if field_name in pattern:
        return pattern[field_name]

    raise KeyError(f"Pattern missing field: {field_name}")

def get_template(template_name):
    """Look up an entry of ``INTERPRETATION_TEMPLATES`` by name.

    Raises:
        KeyError: If no template has that name; the message lists the
            available template names.
    """
    if template_name not in INTERPRETATION_TEMPLATES:
        known_templates = list(INTERPRETATION_TEMPLATES.keys())
        raise KeyError(f"Template '{template_name}' not found. Available templates: {known_templates}")
    return INTERPRETATION_TEMPLATES[template_name]

def show_pattern_info(pattern):
    """Print a pattern's name, type, description, source question and its three text variants."""
    # (line prefix, record key) in display order; each field is looked up
    # only when its line is printed.
    labelled_fields = (
        ("🧠 Pattern: ", 'cognitive_pattern_name'),
        ("🔄 Type: ", 'cognitive_pattern_type'),
        ("📝 Description: ", 'pattern_description'),
        ("❓ Source Question: ", 'source_question'),
        ("\n✅ Positive Text: ", 'positive_thought_pattern'),
        ("\n❌ Negative Text: ", 'reference_negative_example'),
        ("\n🔄 Transition Text: ", 'reference_transformed_example'),
    )
    for prefix, key in labelled_fields:
        print(f"{prefix}{pattern[key]}")

print("Imports and utilities loaded successfully!")

## Initialize the Model and Load Dataset

In [None]:
# Notebook-wide configuration: MODEL_NAME, BOS_TOKEN, patcher, patterns and
# pattern_types defined here are read by the experiment cells below.

# Choose your model here - change this and re-run to experiment with different models
MODEL_NAME = "google/gemma-2-2b-it"  # Change to: gpt2-medium, EleutherAI/gpt-neo-125m, etc.

# Set the BOS token for your model (change this based on your model).
# It is forwarded to patcher.patch_and_generate(...) through its bos_token= argument.
BOS_TOKEN = "<bos>"  # Default for Gemma models

# Initialize the activation patcher for the selected model
patcher = ActivationPatcher(MODEL_NAME)

# Load the dataset: `patterns` is the flat list of records,
# `pattern_types` groups the same records by 'cognitive_pattern_type'
patterns, pattern_types = load_cognitive_patterns()

# Summarise what was loaded
print(f"Loaded {len(patterns)} cognitive patterns")
print(f"Available pattern types: {len(pattern_types)}")
for pattern_type, examples in pattern_types.items():
    print(f"  - {pattern_type}: {len(examples)} examples")

print(f"\nModel info: {patcher.get_model_info()}")
print(f"Using BOS token: '{BOS_TOKEN}'")
print(f"\nAvailable interpretation templates: {list(INTERPRETATION_TEMPLATES.keys())}")
# These are the text_type values accepted by get_pattern_text()
print(f"\nAvailable text types: positive, negative, transition")

## Explore the Dataset

Available pattern types:
  1. Instrumental suicidal reasoning
  2. Intrusive suicidal fixation
  3. Fragmented perceptual reasoning
  4. Over-elaborative recounting
  5. Autobiographical integration
  6. Cognitive disorganization
  7. Negative self-evaluative loop
  8. Existential rumination
  9. Entrapment cognition
  10. Cognitive depletion pattern
  11. Hyper-attuned interoception
  12. Learned helplessness loop
  13. Internal dialectical processing

## Experiment 1: Basic Activation Patching

Let's start with a simple experiment using one positive pattern:

In [None]:
# EXPERIMENT 1: Basic Activation Patching with Template
# Captures activations from one variant of a dataset pattern (the "clean" text)
# and patches them into a prompt built from an interpretation template.

# Configuration - Easy to modify these parameters
PATTERN_INDEX = 2  # Index into `patterns`; get_pattern_by_index requires 0 <= index < len(patterns)
CLEAN_TEXT_TYPE = "positive"  # "positive", "negative", or "transition"
TEMPLATE_NAME = "cognitive_pattern"  # Key into INTERPRETATION_TEMPLATES, used for the corrupted text

# Load the selected pattern and template
selected_pattern = get_pattern_by_index(patterns, PATTERN_INDEX)
template = get_template(TEMPLATE_NAME)

# Get the clean text (what we want to capture activations from)
clean_text = get_pattern_text(selected_pattern, CLEAN_TEXT_TYPE)

print(f"📋 EXPERIMENT 1 SETUP:")
print(f"✅ Pattern {PATTERN_INDEX}: {selected_pattern['cognitive_pattern_name']}")
print(f"🧠 Pattern Type: {selected_pattern['cognitive_pattern_type']}")
print(f"📝 Clean Text Type: {CLEAN_TEXT_TYPE}")
print(f"📄 Clean Text: {clean_text}")
print(f"\n🎯 Using Template: '{TEMPLATE_NAME}'")

# Extract the template components.
# NOTE(review): the fixed indices assume every template is a 7-item sequence
# (BOS marker, five placeholders, continuation text) - confirm against
# interpretation_templates.
template_bos = template[0]  # BOS marker from the template; only displayed below
zero_positions = template[1:6]  # The five placeholder entries
continuation_text = template[6]  # The text part

print(f"📄 Template: {template}")
print(f"\n🔧 Template Structure:")
print(f"  BOS placeholder: {template_bos}")
print(f"  Patch positions: {zero_positions} (count: {len(zero_positions)})")
print(f"  Continuation: {continuation_text}")

# The corrupted text is only the continuation part; the BOS token is passed
# separately through bos_token= in the call below.
corrupted_text = continuation_text

# Extract key words for patching (uses the patcher's private helper)
target_words = patcher._extract_key_words(clean_text)

print(f"\n⚡ Running activation patching...")
print(f"  Clean text: {clean_text[:100]}...")
print(f"  Corrupted text: {corrupted_text}")
print(f"  Target words: {target_words}")
print(f"  Number of patch positions: {len(zero_positions)}")

# Perform the patching and generation
predicted_token, generated_text = patcher.patch_and_generate(
    clean_text=clean_text,
    corrupted_text=corrupted_text,
    target_words=target_words,
    num_placeholder_tokens=len(zero_positions),  # One placeholder per template patch position
    capture_layer_idx=-1,  # Last layer for capture
    patch_layer_idx=-1,    # Last layer for patch
    max_new_tokens=60,
    bos_token=BOS_TOKEN
)

print("\n" + "="*80)
print("🎊 EXPERIMENT 1 RESULTS:")
print("="*80)
print(f"Model: {patcher.model_name}")
print(f"Pattern: {selected_pattern['cognitive_pattern_name']}")
print(f"Clean text type: {CLEAN_TEXT_TYPE}")
print(f"Template: {TEMPLATE_NAME}")
print(f"\n📊 Generated Text:")
print(f"{generated_text}")
print("="*80)

In [None]:
# EXPERIMENT 1 (inline template variant)
# Reuses `clean_text` and `selected_pattern` from the Experiment 1 cell above,
# but builds the corrupted prompt from a template defined right here.

# Extract key words for patching
target_words = patcher._extract_key_words(clean_text)
print(f"Target words for patching: {target_words}")

# Define the template structure for corrupted text input
# This template provides the structure: BOS marker + placeholder tokens + continuation text
template_key = "basic_activation_template"
template = (
    "<bos>",  # BOS marker (display only; the real token comes from BOS_TOKEN)
    0, 0, 0, 0, 0,  # 5 placeholder tokens (represented as zeros)
    "I need to shift my perspective and find a constructive way to"  # Continuation text
)

# Split the template into its parts: ("<bos>", 0, 0, 0, 0, 0, "<continuation text>")
template_bos_token = template[0]  # "<bos>" from the template
zero_tokens = template[1:6]  # The 5 zeros
continuation_text = template[6]  # The text part

print(f"Template structure:")
print(f"  Template BOS token: {template_bos_token}")
print(f"  Zero tokens: {zero_tokens} (count: {len(zero_tokens)})")
print(f"  Continuation: {continuation_text}")

# Use the continuation text as the corrupted_text;
# BOS_TOKEN is passed separately through bos_token= below
corrupted_text = continuation_text

print(f"\nUsing BOS token: '{BOS_TOKEN}'")
print(f"Corrupted text for patching: '{corrupted_text}'")

# Perform the patching and generation
predicted_token, generated_text = patcher.patch_and_generate(
    clean_text=clean_text,
    corrupted_text=corrupted_text,
    target_words=target_words,
    num_placeholder_tokens=len(zero_tokens),  # Derived from the template so the two cannot drift apart
    capture_layer_idx=-1,  # Last layer for capture
    patch_layer_idx=-1,    # Last layer for patch (same as capture)
    max_new_tokens=60,
    bos_token=BOS_TOKEN  # Use the manually set BOS token
)

print("\n" + "="*100)
print("EXPERIMENT 1 RESULTS:")
print("="*100)
print(f"Model: {patcher.model_name}")
print(f"BOS token used: '{BOS_TOKEN}'")
print(f"Template used: {template_key}")
# `clean_text` was taken from `selected_pattern`, so report that pattern.
# (`sample_pattern` is not defined until Experiment 2 and would raise NameError here.)
print(f"Clean pattern: {selected_pattern['cognitive_pattern_name']}")
print(f"Clean text (first 100 chars): {clean_text[:100]}...")
print(f"Target words extracted: {target_words}")
print(f"\nGenerated Text:\n{generated_text}")
print("="*100)

#### 🔄 Reset Model State

Before running experiments, it's good practice to reset any lingering hooks from previous runs:

In [None]:
# 🔄 RESET MODEL HOOKS - Run this cell to reset the model to clean state
# This is especially important when switching between different experiments,
# so that hooks left over from an earlier run cannot affect the next one.

# Delegates to the patcher's own reset method
patcher.reset_hooks()

print("Model is now ready for clean experiments!")
print("Run this cell anytime you want to ensure no residual hooks are affecting your results.")

## Experiment 2: Different Layers Comparison

Let's see how patching at different layers affects the output:

In [None]:
# EXPERIMENT 2: run the same patching setup at several layers and compare outputs.
# Capture and patch use the same layer index on every iteration.
layers_to_test = [0, 3, 6, 9, -1]  # Layer indices to sweep; -1 is used for the last layer
sample_pattern = patterns[5]  # Use a different pattern

clean_text = sample_pattern['positive_thought_pattern']
corrupted_text = "I can't stop worrying about everything and feel completely"
target_words = patcher._extract_key_words(clean_text)

print(f"Pattern: {sample_pattern['cognitive_pattern_name']}")
print(f"Target words: {target_words}")
print(f"Corrupted prompt: {corrupted_text}")
print("\n" + "="*100)

# layer index -> generated text (or an "Error: ..." string if that run failed)
layer_results = {}

for layer_idx in layers_to_test:
    print(f"\n--- LAYER {layer_idx} ---")
    try:
        predicted_token, generated_text = patcher.patch_and_generate(
            clean_text=clean_text,
            corrupted_text=corrupted_text,
            target_words=target_words,
            num_placeholder_tokens=5,
            capture_layer_idx=layer_idx,
            patch_layer_idx=layer_idx,    # Patch to same layer
            max_new_tokens=50,
            bos_token=BOS_TOKEN,  # Honour the configured BOS token, as Experiment 1 does
        )
        layer_results[layer_idx] = generated_text
        print(f"Generated: {generated_text}")
    except Exception as e:
        # Best effort: a failing layer is recorded and the sweep continues.
        print(f"Error at layer {layer_idx}: {e}")
        layer_results[layer_idx] = f"Error: {e}"
    print("-" * 80)

## Experiment 3: Multiple Patterns Comparison

Let's compare how different cognitive patterns affect the generation:

In [None]:
# EXPERIMENT 3: patch activations from several different patterns into the
# same corrupted prompt and compare the generations.
test_patterns = patterns[:5]  # Use first 5 patterns
corrupted_text = "I feel trapped and don't see a way forward because"

print(f"Testing with corrupted prompt: '{corrupted_text}'")
print("\n" + "="*100)

# One {'pattern_name', 'generated_text'} entry per tested pattern;
# failed runs store an "Error: ..." string as the generated text.
pattern_results = []

for i, pattern in enumerate(test_patterns):
    print(f"\n--- PATTERN {i+1}: {pattern['cognitive_pattern_name']} ---")

    clean_text = pattern['positive_thought_pattern']
    target_words = patcher._extract_key_words(clean_text)

    print(f"Clean text (first 100 chars): {clean_text[:100]}...")
    print(f"Target words: {target_words}")

    try:
        predicted_token, generated_text = patcher.patch_and_generate(
            clean_text=clean_text,
            corrupted_text=corrupted_text,
            target_words=target_words,
            num_placeholder_tokens=5,
            capture_layer_idx=-1,  # Last layer for capture
            patch_layer_idx=-1,    # Last layer for patch
            max_new_tokens=65,
            bos_token=BOS_TOKEN,  # Honour the configured BOS token, as Experiment 1 does
        )

        pattern_results.append({
            'pattern_name': pattern['cognitive_pattern_name'],
            'generated_text': generated_text
        })

        print(f"\nGENERATED TEXT:")
        print(generated_text)

    except Exception as e:
        # Best effort: record the failure and move on to the next pattern.
        print(f"Error with pattern {i+1}: {e}")
        pattern_results.append({
            'pattern_name': pattern['cognitive_pattern_name'],
            'generated_text': f"Error: {e}"
        })

    print("-" * 100)

## Experiment 4: Different Number of Patch Positions

Let's experiment with varying the number of placeholder tokens we patch:

In [None]:
# EXPERIMENT 4: vary how many placeholder tokens are patched.
placeholder_counts = [1, 3, 5, 7]
sample_pattern = patterns[10]  # Use another pattern

clean_text = sample_pattern['positive_thought_pattern']
corrupted_text = "My mind keeps racing with negative thoughts and I feel"
target_words = patcher._extract_key_words(clean_text)

print(f"Pattern: {sample_pattern['cognitive_pattern_name']}")
print(f"Target words: {target_words}")
print(f"Corrupted prompt: {corrupted_text}")
print("\n" + "="*100)

# placeholder count -> generated text (or an "Error: ..." string if that run failed)
placeholder_results = {}

for num_placeholders in placeholder_counts:
    print(f"\n--- {num_placeholders} PLACEHOLDER TOKENS ---")
    try:
        predicted_token, generated_text = patcher.patch_and_generate(
            clean_text=clean_text,
            corrupted_text=corrupted_text,
            target_words=target_words,
            num_placeholder_tokens=num_placeholders,
            capture_layer_idx=-1,  # Last layer for capture
            patch_layer_idx=-1,    # Last layer for patch
            max_new_tokens=55,
            bos_token=BOS_TOKEN,  # Honour the configured BOS token, as Experiment 1 does
        )
        placeholder_results[num_placeholders] = generated_text
        print(f"Generated: {generated_text}")
    except Exception as e:
        # Best effort: record the failure and continue with the next count.
        print(f"Error with {num_placeholders} placeholders: {e}")
        placeholder_results[num_placeholders] = f"Error: {e}"
    print("-" * 80)

## Experiment 5: Baseline Comparison (No Patching)

Let's generate text without any patching to see the baseline behavior:

In [None]:
# EXPERIMENT 5: generate text without patching, as a baseline for the
# patched generations above.
test_prompts = [
    "I feel completely overwhelmed and don't know how to",
    "My thoughts are spiraling out of control and I feel like",
    "Everything seems hopeless and I can't figure out why",
    "I'm stuck in negative thinking patterns and can't seem to",
    "The anxiety is taking over and I don't think I can"
]

print("BASELINE GENERATIONS (No Patching):")
print("="*100)

# Drop any hooks left behind by earlier experiments so the baseline really is
# unpatched, whatever order the cells were run in.
patcher.reset_hooks()

# One {'prompt', 'generated_text'} entry per prompt
baseline_results = []

for i, prompt in enumerate(test_prompts):
    print(f"\n--- Prompt {i+1}: {prompt} ---")

    # Tokenize and generate directly on the underlying model, with sampling,
    # so outputs vary between runs.
    tokens = patcher.model.to_tokens(prompt)
    generated_tokens = patcher.model.generate(
        tokens,
        max_new_tokens=60,
        temperature=0.7,
        do_sample=True
    )
    generated_text = patcher.model.to_string(generated_tokens[0])

    baseline_results.append({
        'prompt': prompt,
        'generated_text': generated_text
    })

    print(f"Generated: {generated_text}")
    print("-" * 80)

## Experiment 6: Custom Pattern Testing

Interactive cell for testing your own patterns:

In [None]:
# Interactive experiment - modify these variables to test your own patterns

# You can change these to experiment with different combinations
CUSTOM_CLEAN_TEXT = "I'm taking a moment to acknowledge my feelings and remind myself that challenges are temporary. I can take small, manageable steps forward and focus on what I can control right now."

CUSTOM_CORRUPTED_TEXT = "I don't know what to do anymore and feel completely lost"

CUSTOM_TARGET_WORDS = ["acknowledge", "feelings", "challenges", "temporary", "control"]  # Or leave empty to auto-extract

CUSTOM_NUM_PLACEHOLDERS = 5
CUSTOM_CAPTURE_LAYER = -1  # Layer to capture activations from
CUSTOM_PATCH_LAYER = 3    # Layer to patch activations into
CUSTOM_MAX_TOKENS = 70

print("CUSTOM EXPERIMENT WITH MULTI-LAYER FUNCTIONALITY:")
print("="*100)
print(f"Clean text: {CUSTOM_CLEAN_TEXT}")
print(f"Corrupted text: {CUSTOM_CORRUPTED_TEXT}")

# Auto-extract target words if not provided
if not CUSTOM_TARGET_WORDS:
    CUSTOM_TARGET_WORDS = patcher._extract_key_words(CUSTOM_CLEAN_TEXT)

print(f"Target words: {CUSTOM_TARGET_WORDS}")
print(f"Settings: {CUSTOM_NUM_PLACEHOLDERS} placeholders, capture layer {CUSTOM_CAPTURE_LAYER}, patch layer {CUSTOM_PATCH_LAYER}, {CUSTOM_MAX_TOKENS} max tokens")

try:
    predicted_token, generated_text = patcher.patch_and_generate(
        clean_text=CUSTOM_CLEAN_TEXT,
        corrupted_text=CUSTOM_CORRUPTED_TEXT,
        target_words=CUSTOM_TARGET_WORDS,
        num_placeholder_tokens=CUSTOM_NUM_PLACEHOLDERS,
        capture_layer_idx=CUSTOM_CAPTURE_LAYER,
        patch_layer_idx=CUSTOM_PATCH_LAYER,
        max_new_tokens=CUSTOM_MAX_TOKENS,
        bos_token=BOS_TOKEN,  # Honour the configured BOS token, as Experiment 1 does
    )

    print("\nGENERATED TEXT:")
    print(generated_text)

except Exception as e:
    # Best effort: report the failure and still run the examples below.
    print(f"Error in custom experiment: {e}")

print("\n" + "="*100)
print("ADVANCED MULTI-LAYER EXAMPLES:")
print("="*100)

# Each entry: (description, number of target words and placeholder tokens,
#              capture layer(s), patch layer(s), max new tokens).
# Layer arguments may be a single index, a list of indices, or 'all'.
advanced_examples = [
    # Capture from multiple layers, patch to single layer
    ("Example 1: Capture from early layers [0, 1, 2], patch to last layer",
     3, [0, 1, 2], -1, 50),
    # Capture from one layer, patch to multiple layers
    ("Example 2: Capture from last layer, patch to multiple middle layers [5, 7, 9]",
     3, -1, [5, 7, 9], 50),
    # Use range for layers: capture 3, 4, 5 and patch 8, 9, 10
    ("Example 3: Capture from range(3, 6), patch to range(8, 11)",
     3, list(range(3, 6)), list(range(8, 11)), 50),
    # All layers (use with caution - computationally intensive); fewest words
    ("Example 4: Capture from 'all' layers, patch to last 3 layers",
     2, 'all', [-3, -2, -1], 40),
]

for description, num_words, capture_layers, patch_layers, max_tokens in advanced_examples:
    print(f"\n{description}")
    try:
        predicted_token, generated_text = patcher.patch_and_generate(
            clean_text=CUSTOM_CLEAN_TEXT,
            corrupted_text=CUSTOM_CORRUPTED_TEXT,
            target_words=CUSTOM_TARGET_WORDS[:num_words],  # Use fewer words for these tests
            num_placeholder_tokens=num_words,
            capture_layer_idx=capture_layers,
            patch_layer_idx=patch_layers,
            max_new_tokens=max_tokens,
            bos_token=BOS_TOKEN,  # Honour the configured BOS token, as Experiment 1 does
        )
        print(f"Result: {generated_text}")
    except Exception as e:
        # Best effort: one failing example must not stop the others.
        print(f"Error: {e}")

## 🛠️ Troubleshooting & Utilities

In [None]:
# 🛠️ TROUBLESHOOTING UTILITIES

# 1. Reset hooks if experiments behave unexpectedly
def quick_reset():
    """Call ``patcher.reset_hooks()`` on the notebook's global patcher and print a confirmation."""
    patcher.reset_hooks()
    print("🔄 Hooks reset - Model is clean!")



# 2. Check model state
def check_model_info():
    """Print the patcher's model name plus the device, layer count, model width and vocabulary size from ``patcher.model.cfg``."""
    print("📊 MODEL STATUS:")
    print(f"  Model: {patcher.model_name}")
    cfg = patcher.model.cfg
    # (display label, attribute on the model config) in output order
    config_fields = (
        ("Device", "device"),
        ("Layers", "n_layers"),
        ("D_model", "d_model"),
        ("Vocab size", "d_vocab"),
    )
    for label, attr_name in config_fields:
        print(f"  {label}: {getattr(cfg, attr_name)}")
    
# 3. Clear memory if needed
def clear_memory():
    """Run Python garbage collection, then release cached accelerator memory.

    The collection runs first so that memory belonging to objects it frees can
    be returned by the subsequent cache flush. CUDA and, where the installed
    torch build provides it, Apple MPS caches are emptied; on CPU-only setups
    only the garbage collection happens. Prints a confirmation when done.
    """
    import gc
    import torch

    # Collect before flushing: empty_cache() only returns memory that is no
    # longer occupied by live tensors.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Older torch builds lack torch.backends.mps / torch.mps, hence the guards.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(mps_module, "empty_cache")
    ):
        mps_module.empty_cache()

    print("🧹 Memory cleared!")

# Quick access functions: list the helpers defined in this cell
print("Available utilities:")
print("- quick_reset() - Reset model hooks")
print("- check_model_info() - Display model details") 
print("- clear_memory() - Clear GPU/system memory")
print("- patcher.reset_hooks() - Direct reset call")

# Intentionally left commented out - uncomment any line below to run it:
# quick_reset()
# check_model_info()
# clear_memory()