# Batch Activation Patching Experiments

This notebook allows you to experiment with batch activation extraction and aggregation. Instead of using activations from a single text, you can:

1. **Extract activations from multiple positive texts**
2. **Aggregate them using different methods** (mean, median, max, etc.)
3. **Use the aggregated activation for patching**
4. **Compare different aggregation strategies**

Combining activations from many texts may yield patterns that depend less on the wording of any single example, and so may be more robust and generalize better when patched.
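How `batch_patch_and_generate` combines activations internally is not shown in this notebook. The following is a minimal, hypothetical sketch of element-wise aggregation. It assumes each text's activation has already been reduced to one fixed-length vector, because texts tokenize to different lengths. It also assumes that `"random"` means "pick one text's activation". `aggregate_activations` is an illustrative name, not part of `ActivationPatcher`.

```python
import random
import statistics


def aggregate_activations(vectors, method="mean", seed=None):
    """Combine per-text activation vectors element-wise into one vector.

    Hypothetical helper: assumes every text's activation was already reduced
    to a fixed-length list of floats (e.g. averaged over token positions).
    """
    if not vectors:
        raise ValueError("need at least one activation vector")
    width = len(vectors[0])
    if any(len(v) != width for v in vectors):
        raise ValueError("all activation vectors must have the same length")
    if method == "random":
        # Assumed meaning: use one randomly chosen text's activation as-is
        return list(random.Random(seed).choice(vectors))
    # zip(*vectors) yields one tuple per hidden dimension, across all texts
    columns = list(zip(*vectors))
    if method == "mean":
        return [statistics.fmean(col) for col in columns]
    if method == "median":
        return [statistics.median(col) for col in columns]
    if method == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown aggregation method: {method!r}")


acts = [[1.0, -2.0, 0.5], [3.0, 0.0, 0.5], [2.0, 4.0, -1.0]]
for m in ("mean", "median", "max"):
    print(m, aggregate_activations(acts, m))
```

With PyTorch tensors stacked along dimension 0, the equivalents would be `stacked.mean(dim=0)`, `stacked.median(dim=0).values` and `stacked.max(dim=0).values`. One difference: `torch.median` returns the lower of the two middle values for an even count, whereas `statistics.median` averages them.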

In [2]:
import sys
import os
import torch
import json
import random
import numpy as np
from IPython.display import display, HTML

# Add TransformerLens to path
sys.path.append('/home/koalacrown/Desktop/Code/Projects/turnaround/turn_point/third_party/TransformerLens')

# Import our activation patcher
sys.path.append('/home/koalacrown/Desktop/Code/Projects/turnaround/turn_point/manual_activation_patching')
from activation_patcher import ActivationPatcher

print("Imports successful!")

# Uncomment to fix random seeds; sampled generation is otherwise not reproducible
# random.seed(42)
# np.random.seed(42)
# torch.manual_seed(42)

Imports successful!


## Choose Model and Load Dataset

In [None]:
# Choose your model here - change this and re-run to experiment with different models
MODEL_NAME = "gpt2-small"  # Change to: gpt2-medium, EleutherAI/gpt-neo-125m, etc.

# Initialize the activation patcher
patcher = ActivationPatcher(MODEL_NAME)

# Load the positive patterns dataset
dataset_path = "/home/koalacrown/Desktop/Code/Projects/turnaround/turn_point/data/final/positive_patterns.jsonl"
patterns = patcher.load_dataset(dataset_path)

print(f"Loaded {len(patterns)} positive thought patterns")
print(f"Model info: {patcher.get_model_info()}")

# Ensure model starts in clean state
patcher.reset_hooks()
print("✅ Model initialized and hooks reset to clean state")

## Explore Available Models

In [3]:
# List all supported models
ActivationPatcher.list_supported_models()

Supported Models:

GPT2:
  - gpt2-small (small)
  - gpt2-medium (medium)
  - gpt2-large (large)
  - gpt2-xl (xl)

GPTJ:
  - EleutherAI/gpt-j-6b (6b)

GPT-NEO:
  - EleutherAI/gpt-neo-125m (125m)
  - EleutherAI/gpt-neo-1.3b (1.3b)
  - EleutherAI/gpt-neo-2.7b (2.7b)

OPT:
  - facebook/opt-125m (125m)
  - facebook/opt-1.3b (1.3b)
  - facebook/opt-2.7b (2.7b)
  - facebook/opt-6.7b (6.7b)

PYTHIA:
  - EleutherAI/pythia-70m (70m)
  - EleutherAI/pythia-160m (160m)
  - EleutherAI/pythia-410m (410m)
  - EleutherAI/pythia-1b (1b)
  - EleutherAI/pythia-1.4b (1.4b)
  - EleutherAI/pythia-2.8b (2.8b)

GEMMA:
  - google/gemma-2b (2b)
  - google/gemma-7b (7b)


## Prepare Batch Data

In [None]:
# Select patterns by cognitive type or randomly
def filter_patterns_by_type(patterns, pattern_type=None, max_count=20):
    """Filter patterns by cognitive type or return random selection."""
    if pattern_type:
        filtered = [p for p in patterns if pattern_type.lower() in p.get('cognitive_pattern_type', '').lower()]
        print(f"Found {len(filtered)} patterns matching '{pattern_type}'")
    else:
        filtered = patterns.copy()
        random.shuffle(filtered)
        print(f"Using random selection from {len(filtered)} patterns")
    
    return filtered[:max_count]

# Count the available cognitive pattern types in the dataset
from collections import Counter

type_counts = Counter(p.get('cognitive_pattern_type', '') for p in patterns)
print("Available cognitive pattern types:")
for ptype, count in sorted(type_counts.items()):
    if ptype:  # skip patterns without a type label
        print(f"  - {ptype} ({count} examples)")

In [None]:
# Configuration for batch experiments
BATCH_SIZE = 15  # Number of texts to use for batch activation extraction
PATTERN_TYPE = None  # Set to specific type like "rumination" or None for random

# Select patterns for batch
selected_patterns = filter_patterns_by_type(patterns, PATTERN_TYPE, BATCH_SIZE)

# Extract texts
batch_texts = [p['positive_thought_pattern'] for p in selected_patterns]

print(f"\nSelected {len(batch_texts)} texts for batch processing:")
for i, text in enumerate(batch_texts[:5]):  # Show first 5
    print(f"{i+1}. {text[:100]}...")
if len(batch_texts) > 5:
    print(f"... and {len(batch_texts) - 5} more")

In [None]:
# 🔄 RESET MODEL HOOKS - Essential for batch experiments
# Run this before starting new batch experiments to avoid interference

patcher.reset_hooks()

print("🔄 Model hooks reset - Ready for clean batch experiments!")
print("💡 Tip: Run this cell between different batch configurations")

## 🔄 Reset Model State

Important: Reset hooks before running batch experiments to ensure clean state:

## Experiment 1: Different Aggregation Methods

In [None]:
# Compare aggregation methods (capture and patch both at the last layer)
aggregation_methods = ["mean", "median", "max", "random"]
corrupted_prompt = "I feel completely overwhelmed and stuck, unable to"
target_words = ["positive", "solutions", "growth", "hope", "progress"]

print("Testing aggregation methods with:")
print(f"- Batch size: {len(batch_texts)}")
print(f"- Corrupted prompt: {corrupted_prompt}")
print(f"- Target words: {target_words}")
print("\n" + "="*100)

aggregation_results = {}

for method in aggregation_methods:
    print(f"\n--- Testing {method.upper()} aggregation ---")
    
    try:
        predicted_token, generated_text = patcher.batch_patch_and_generate(
            clean_texts=batch_texts,
            corrupted_text=corrupted_prompt,
            capture_layer_idx=-1,     # Capture clean activations from the last layer
            patch_layer_idx=-1,       # Patch them into the last layer
            aggregation=method,
            target_words=target_words,
            num_placeholder_tokens=5,
            max_new_tokens=60
        )
        
        aggregation_results[method] = {
            'success': True,
            'predicted_token': predicted_token,
            'generated_text': generated_text
        }
        
        print(f"\n✓ SUCCESS with {method} aggregation:")
        print(f"Generated: {generated_text}")
        
    except Exception as e:
        print(f"✗ Error with {method} aggregation: {e}")
        aggregation_results[method] = {
            'success': False,
            'error': str(e),
            'generated_text': None
        }
    
    print("-" * 80)

# Memory cleanup
if torch.cuda.is_available():
    torch.cuda.empty_cache()
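Reading the generations side by side is subjective. A crude quantitative complement is to count how often the target words appear in each output. The helpers below are hypothetical and are not part of `ActivationPatcher`. They also count words in the prompt if `generated_text` includes it, which the prompt used here does not affect because it contains none of the target words.

```python
import re


def target_word_hits(generated_text, target_words):
    """Count whole-word, case-insensitive occurrences of each target word.

    Hypothetical scoring helper; not part of ActivationPatcher.
    """
    tokens = re.findall(r"[a-z']+", generated_text.lower())
    return {word: tokens.count(word.lower()) for word in target_words}


def summarize_hits(results, target_words):
    """Total target-word hits per configuration, skipping failed runs."""
    return {
        name: sum(target_word_hits(r["generated_text"], target_words).values())
        for name, r in results.items()
        if r.get("success") and r.get("generated_text")
    }
```

In this notebook, `summarize_hits(aggregation_results, target_words)` would give one score per aggregation method. With sampled generation, compare averages over several runs rather than single outputs.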

## Experiment 2: Batch Size Comparison

In [None]:
# Test different batch sizes
batch_sizes = [1, 3, 5, 10, 15]
corrupted_prompt = "My thoughts are spiraling and I can't seem to"
aggregation_method = "mean"

print(f"Testing batch sizes with {aggregation_method} aggregation")
print(f"Corrupted prompt: {corrupted_prompt}")
print("\n" + "="*100)

batch_size_results = {}

for batch_size in batch_sizes:
    if batch_size > len(batch_texts):
        print(f"Skipping batch size {batch_size} (not enough texts available)")
        continue
        
    print(f"\n--- Testing batch size: {batch_size} ---")
    
    subset_texts = batch_texts[:batch_size]
    
    try:
        predicted_token, generated_text = patcher.batch_patch_and_generate(
            clean_texts=subset_texts,
            corrupted_text=corrupted_prompt,
            capture_layer_idx=-1,     # Capture clean activations from the last layer
            patch_layer_idx=-1,       # Patch them into the last layer
            aggregation=aggregation_method,
            target_words=None,
            num_placeholder_tokens=5,
            max_new_tokens=55
        )
        
        batch_size_results[batch_size] = {
            'success': True,
            'generated_text': generated_text,
            'predicted_token': predicted_token
        }
        
        print(f"\n✓ SUCCESS with batch size {batch_size}:")
        print(f"Generated: {generated_text}")
        
    except Exception as e:
        print(f"✗ Error with batch size {batch_size}: {e}")
        batch_size_results[batch_size] = {
            'success': False,
            'error': str(e),
            'generated_text': None
        }
    
    print("-" * 80)

if torch.cuda.is_available():
    torch.cuda.empty_cache()
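The batch-size experiment uses prefixes of `batch_texts`. A direct way to see whether adding texts still changes the aggregate is to measure how far the mean of the first `k` activation vectors lies from the full-batch mean. This is a hypothetical diagnostic. It assumes you already have one fixed-length activation vector per text, and how to obtain those from `ActivationPatcher` is not shown here.

```python
import math


def mean_vector(vectors):
    """Element-wise mean of equal-length float lists."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def prefix_mean_drift(vectors, sizes):
    """Euclidean distance between the mean of the first k vectors and the full-batch mean.

    Hypothetical diagnostic mirroring the prefix subsets (batch_texts[:k]) above;
    small drift suggests that adding more texts changes the aggregate little.
    """
    full = mean_vector(vectors)
    drift = {}
    for k in sizes:
        if k > len(vectors):
            continue  # same skip rule as the experiment loop
        drift[k] = math.dist(mean_vector(vectors[:k]), full)
    return drift
```

If the drift is already small at, say, `k = 5`, then larger batches would be expected to change the patched generation mainly through sampling noise rather than through the activation itself.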

## Experiment 3: Custom Batch Configuration

In [None]:
# Custom configuration - modify these variables for your experiments
CUSTOM_BATCH_TEXTS = [
    "I'm learning to acknowledge my feelings and take things one step at a time.",
    "I choose to focus on solutions rather than dwelling on problems.",
    "I'm practicing self-compassion and recognizing my growth.",
    "I can handle challenges by breaking them into manageable pieces.",
    "I'm building resilience and finding healthy ways to cope.",
    "I trust in my ability to navigate difficult situations.",
    "I'm grateful for the support I have and the progress I've made.",
    "I choose hope and believe that positive change is possible."
]

CUSTOM_CORRUPTED_TEXT = "I don't know how to deal with this situation and feel"
CUSTOM_AGGREGATION = "mean"
CUSTOM_TARGET_WORDS = ["acknowledge", "focus", "solutions", "growth", "hope"]
CUSTOM_CAPTURE_LAYER = -1      # Layer to capture activations from
CUSTOM_PATCH_LAYER = -1        # Layer to patch activations into
CUSTOM_PLACEHOLDERS = 5
CUSTOM_MAX_TOKENS = 70

print("CUSTOM BATCH EXPERIMENT WITH MULTI-LAYER FUNCTIONALITY:")
print("="*100)
print(f"Batch texts ({len(CUSTOM_BATCH_TEXTS)} total):")
for i, text in enumerate(CUSTOM_BATCH_TEXTS[:3]):
    print(f"  {i+1}. {text}")
if len(CUSTOM_BATCH_TEXTS) > 3:
    print(f"  ... and {len(CUSTOM_BATCH_TEXTS) - 3} more")

print(f"\nCorrupted text: {CUSTOM_CORRUPTED_TEXT}")
print(f"Aggregation: {CUSTOM_AGGREGATION}")
print(f"Target words: {CUSTOM_TARGET_WORDS}")
print(f"Capture layer: {CUSTOM_CAPTURE_LAYER}")
print(f"Patch layer: {CUSTOM_PATCH_LAYER}")
print(f"Placeholders: {CUSTOM_PLACEHOLDERS}")
print(f"Max tokens: {CUSTOM_MAX_TOKENS}")

try:
    predicted_token, generated_text = patcher.batch_patch_and_generate(
        clean_texts=CUSTOM_BATCH_TEXTS,
        corrupted_text=CUSTOM_CORRUPTED_TEXT,
        capture_layer_idx=CUSTOM_CAPTURE_LAYER,  # Layer to capture clean activations from
        patch_layer_idx=CUSTOM_PATCH_LAYER,      # Layer to patch activations into
        aggregation=CUSTOM_AGGREGATION,
        target_words=CUSTOM_TARGET_WORDS,
        num_placeholder_tokens=CUSTOM_PLACEHOLDERS,
        max_new_tokens=CUSTOM_MAX_TOKENS
    )
    
    print("\n" + "="*50)
    print("CUSTOM EXPERIMENT RESULT:")
    print("="*50)
    print(generated_text)
    print("="*50)
    
except Exception as e:
    print(f"\n✗ Error in custom experiment: {e}")

print("\n" + "="*100)
print("ADVANCED MULTI-LAYER BATCH EXAMPLES:")
print("="*100)

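# Note: under the current batch-mode limitation (single patch layer), Examples 2 and 3
# may only patch the first layer in their patch_layer_idx lists.

# Example 1: Cross-layer batch patching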
print("\nExample 1: Capture from early layers [0,1,2], patch to late layer [-1]")
try:
    predicted_token, generated_text = patcher.batch_patch_and_generate(
        clean_texts=CUSTOM_BATCH_TEXTS[:5],  # Use fewer texts for demo
        corrupted_text=CUSTOM_CORRUPTED_TEXT,
        capture_layer_idx=[0, 1, 2],  # Multiple capture layers
        patch_layer_idx=-1,           # Single patch layer
        aggregation="mean",
        target_words=CUSTOM_TARGET_WORDS[:3],
        num_placeholder_tokens=3,
        max_new_tokens=50
    )
    print(f"Result: {generated_text}")
except Exception as e:
    print(f"Error: {e}")

# Example 2: Multi-layer broadcast patching
print("\nExample 2: Capture from layer 5, patch to multiple layers [8,9,10]")
try:
    predicted_token, generated_text = patcher.batch_patch_and_generate(
        clean_texts=CUSTOM_BATCH_TEXTS[:5],
        corrupted_text=CUSTOM_CORRUPTED_TEXT,
        capture_layer_idx=5,           # Single capture layer
        patch_layer_idx=[8, 9, 10],    # Multiple patch layers
        aggregation="median",
        target_words=CUSTOM_TARGET_WORDS[:3],
        num_placeholder_tokens=3,
        max_new_tokens=50
    )
    print(f"Result: {generated_text}")
except Exception as e:
    print(f"Error: {e}")

# Example 3: Range-based layer selection
print("\nExample 3: Capture from range(3,6), patch to range(7,10)")
try:
    predicted_token, generated_text = patcher.batch_patch_and_generate(
        clean_texts=CUSTOM_BATCH_TEXTS[:5],
        corrupted_text=CUSTOM_CORRUPTED_TEXT,
        capture_layer_idx=list(range(3, 6)),  # Layers 3,4,5
        patch_layer_idx=list(range(7, 10)),   # Layers 7,8,9
        aggregation="max",
        target_words=CUSTOM_TARGET_WORDS[:3],
        num_placeholder_tokens=3,
        max_new_tokens=50
    )
    print(f"Result: {generated_text}")
except Exception as e:
    print(f"Error: {e}")
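The examples above pass layer specifications as a single int (including `-1`), a list, or a `range`, and the notes on multi-layer patching also mention `'all'`. The following hypothetical helper shows one way such specifications could be resolved to concrete layer indices. It is not `ActivationPatcher`'s actual code. It also makes the model-size dependence explicit: `gpt2-small` has 12 layers (indices 0 to 11), so `[8, 9, 10]` is valid there. On `EleutherAI/pythia-70m`, which has 6 layers, the same indices would be out of range.

```python
def normalize_layer_spec(spec, n_layers):
    """Turn a layer specification into a sorted list of non-negative layer indices.

    Hypothetical helper: accepts an int (negative counts from the end, like -1),
    an iterable of ints such as [0, 1, 2] or range(3, 6), or the string 'all'.
    """
    if spec == "all":
        return list(range(n_layers))
    if isinstance(spec, int):
        spec = [spec]
    layers = []
    for idx in spec:
        resolved = idx + n_layers if idx < 0 else idx
        if not 0 <= resolved < n_layers:
            raise IndexError(f"layer {idx} out of range for a {n_layers}-layer model")
        layers.append(resolved)
    return sorted(set(layers))


print(normalize_layer_spec(-1, 12))            # [11]
print(normalize_layer_spec(range(7, 10), 12))  # [7, 8, 9]
```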

## Experiment 4: Baseline Comparison

In [None]:
# Comparison experiment
comparison_prompt = "I'm overwhelmed by everything and don't see a way to"
batch_for_comparison = batch_texts[:10]
single_text_for_comparison = batch_texts[0]

print("BASELINE COMPARISON:")
print("="*100)
print(f"Test prompt: {comparison_prompt}")
print(f"Batch size: {len(batch_for_comparison)}")

comparison_results = {}

# 1. No patching (baseline)
print("\n--- 1. NO PATCHING (Baseline) ---")
try:
    tokens = patcher.model.to_tokens(comparison_prompt)
    # Sampled decoding (temperature 0.7): the baseline varies between runs unless seeds are fixed
    generated_tokens = patcher.model.generate(
        tokens,
        max_new_tokens=60,
        temperature=0.7,
        do_sample=True
    )
    baseline_text = patcher.model.to_string(generated_tokens[0])
    comparison_results['no_patching'] = baseline_text
    print(f"Generated: {baseline_text}")
except Exception as e:
    print(f"Error: {e}")
    comparison_results['no_patching'] = f"Error: {e}"

# 2. Single text patching
print("\n--- 2. SINGLE TEXT PATCHING ---")
try:
    target_words = patcher._extract_key_words(single_text_for_comparison)
    predicted_token, single_patch_text = patcher.patch_and_generate(
        clean_text=single_text_for_comparison,
        corrupted_text=comparison_prompt,
        target_words=target_words,
        max_new_tokens=60
    )
    comparison_results['single_patching'] = single_patch_text
    print(f"Generated: {single_patch_text}")
except Exception as e:
    print(f"Error: {e}")
    comparison_results['single_patching'] = f"Error: {e}"

# 3. Batch patching
print("\n--- 3. BATCH PATCHING ---")
try:
    predicted_token, batch_patch_text = patcher.batch_patch_and_generate(
        clean_texts=batch_for_comparison,
        corrupted_text=comparison_prompt,
        aggregation="mean",
        max_new_tokens=60
    )
    comparison_results['batch_patching'] = batch_patch_text
    print(f"Generated: {batch_patch_text}")
except Exception as e:
    print(f"Error: {e}")
    comparison_results['batch_patching'] = f"Error: {e}"

print("\n" + "="*100)
print("COMPARISON SUMMARY:")
print("="*100)
for method, result in comparison_results.items():
    print(f"\n{method.upper().replace('_', ' ')}:")
    print(result)
    print("-" * 60)
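The baseline samples with `temperature=0.7` and the seed lines in the import cell are commented out, so one run per method mixes the effect of patching with sampling noise. A minimal sketch of a seeding helper follows. It assumes the patched generation also samples through PyTorch's global RNG, which this notebook does not confirm. `set_seed` is a hypothetical name.

```python
import random


def set_seed(seed):
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed.

    Call before each generation so every variant starts from the same RNG state.
    """
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA
    except ImportError:
        pass
```

For example, calling `set_seed(42)` immediately before each of the three comparison runs makes repeated executions of the cell comparable. Repeating over several seeds gives a fairer picture than any single sample.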

## 🛠️ Batch Experiment Utilities

In [None]:
# 🛠️ BATCH EXPERIMENT UTILITIES
import gc

# 1. Quick reset for batch experiments
def batch_reset():
    patcher.reset_hooks()
    print("🔄 Batch experiment reset - Model ready!")

# 2. Memory management for large batches
def batch_memory_cleanup():
    # Collect unreferenced Python objects first so their GPU tensors can be released
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    patcher.reset_hooks()  # Also reset hooks
    print("🧹 Memory cleared and hooks reset for batch experiments")

# 3. Check batch experiment status
def batch_status():
    print("📊 BATCH EXPERIMENT STATUS:")
    print(f"  Model: {patcher.model_name}")
    print(f"  Total patterns available: {len(patterns) if 'patterns' in globals() else 'Not loaded'}")
    print(f"  Current batch size: {len(batch_texts) if 'batch_texts' in globals() else 'Not configured'}")

# 4. Quick test batch configuration
def test_small_batch():
    """Run a quick test with a minimal batch to check that everything works."""
    test_texts = [
        "I choose to focus on positive solutions.",
        "I can handle this step by step.",
        "I'm grateful for my progress so far."
    ]
    
    try:
        patcher.reset_hooks()
        predicted_token, generated_text = patcher.batch_patch_and_generate(
            clean_texts=test_texts,
            corrupted_text="I feel stuck and don't know",
            capture_layer_idx=-1,
            patch_layer_idx=-1,
            aggregation="mean",
            num_placeholder_tokens=3,
            max_new_tokens=30
        )
        print("✅ Batch system working correctly!")
        print(f"Test result: {generated_text}")
        return True
    except Exception as e:
        print(f"❌ Batch system error: {e}")
        return False

# Available batch utilities
print("Available batch utilities:")
print("- batch_reset() - Reset hooks for batch experiments")
print("- batch_memory_cleanup() - Clear memory + reset hooks")
print("- batch_status() - Check experiment configuration")
print("- test_small_batch() - Run quick functionality test")
print("- patcher.reset_hooks() - Direct reset call")

# Uncomment to run:
# batch_reset()
# batch_status()
# test_small_batch()

## Multi-Layer Batch Patching

The batch activation patcher supports the same multi-layer parameters as the single-text patcher:

### New Parameters:
- **`capture_layer_idx`**: Layer(s) to capture activations from (replaces `layer_idx`)
- **`patch_layer_idx`**: Layer(s) to patch activations into (defaults to `capture_layer_idx`)

### Current Limitations:
- Batch mode currently patches a single layer; if several patch layers are given, only the first is used
- Full multi-layer patch broadcasting is planned for future versions

### Batch + Multi-Layer Benefits:
- **Robust activation extraction**: Aggregate patterns from multiple texts
- **Cross-layer analysis**: Capture from one layer, patch to another
- **Reduced noise**: Statistical aggregation can smooth out variation specific to individual texts
- **Better generalization**: Combined patterns may be more representative than any single text

### Usage Examples:
- `capture_layer_idx=0, patch_layer_idx=-1`: Early capture → Late patch
- `capture_layer_idx=[0,1,2], patch_layer_idx=5`: Multi-capture → Single patch
- `capture_layer_idx='all', patch_layer_idx=-1`: All layers capture → Late patch

## Results Summary

In [None]:
print("BATCH ACTIVATION PATCHING EXPERIMENT SUMMARY")
print("="*100)

print(f"\nModel Used: {patcher.model_name}")
print(f"Model Info: {patcher.get_model_info()}")

print(f"\nDataset: {len(patterns)} total patterns")
print(f"Selected Batch Size: {len(batch_texts)} texts")

experiments = [
    ("Aggregation Methods", globals().get('aggregation_results', {})),
    ("Batch Sizes", globals().get('batch_size_results', {})),
    ("Baseline Comparison", globals().get('comparison_results', {}))
]

def _succeeded(result):
    """Dict results carry a 'success' flag; string results (baseline comparison) start with 'Error' on failure."""
    if isinstance(result, dict):
        return result.get('success', False)
    return not str(result).startswith('Error')

for exp_name, results in experiments:
    print(f"\n{exp_name}:")
    if results:
        successful_configs = [k for k, v in results.items() if _succeeded(v)]
        print(f"  - Completed: {len(successful_configs)}/{len(results)} configurations")
        if successful_configs:
            print(f"  - Successful configurations: {successful_configs}")
    else:
        print("  - Not run in this session")

print("\nKey Observations (check against the outputs above):")
print("- Batch activation patching aggregates activations across texts, which may yield more robust patterns")
print("- Different aggregation methods can produce different generations")
print("- Batch size may affect the quality and consistency of patching")
print("- The unpatched baseline shows what the model generates without activation patching")

print("\nNext Steps:")
print("- Try different models by changing MODEL_NAME and re-running")
print("- Experiment with different cognitive pattern types")
print("- Test with longer generation sequences")
print("- Try different layers for patching")
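`json` is imported at the top but never used. One natural use is to persist each session's results so runs with different models or seeds can be compared later. The following is a minimal sketch. `save_results_jsonl` and the file name in the usage line are hypothetical. Values that JSON cannot encode, such as a tensor-valued `predicted_token`, are stored via `str()`.

```python
import json
import time


def save_results_jsonl(path, experiments, model_name):
    """Append one JSON record per experiment configuration to a JSON Lines file.

    Hypothetical helper; non-serializable values are stored via str().
    """
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
    with open(path, "a", encoding="utf-8") as f:
        for exp_name, results in experiments:
            for config, result in results.items():
                record = {
                    "timestamp": timestamp,
                    "model": model_name,
                    "experiment": exp_name,
                    "config": config,
                    "result": result,
                }
                f.write(json.dumps(record, default=str) + "\n")
```

Usage in this notebook: `save_results_jsonl("batch_patching_runs.jsonl", experiments, patcher.model_name)`. Appending rather than overwriting keeps a history across sessions.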