# ⏱️ Temporal Steering with GPT-2

This notebook demonstrates **Contrastive Activation Addition (CAA)** to steer GPT-2's temporal scope between immediate/short-term thinking and long-term/strategic thinking.

## What is Temporal Steering?

We extract "steering vectors" by comparing activations from:
- **Immediate prompts**: "Develop a 1-week plan to..."
- **Long-term prompts**: "Develop a 20-year plan to..."

Averaged over several prompt pairs, these vectors approximate the direction in activation space that separates immediate from long-term framing. Adding a scaled copy of them to the hidden states during generation is intended to shift the model's temporal perspective.

## Approach

1. Extract activations from prompt pairs with different temporal horizons
2. Compute contrastive vectors: `steering_vector = long_term_activations - immediate_activations`
3. Apply steering during generation by adding vectors to hidden states
4. Observe how responses shift between tactical and strategic thinking
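The arithmetic in steps 2 and 3 can be sketched with NumPy on made-up activations. This is a minimal illustration, assuming activations for one layer have already been collected at a single token position per prompt; the helper names (`mean_difference_vector`, `add_steering`) and the toy sizes are hypothetical, not part of the notebook.

```python
import numpy as np

def mean_difference_vector(long_acts: np.ndarray, imm_acts: np.ndarray) -> np.ndarray:
    """Average of (long_term - immediate) over prompt pairs.

    Both inputs have shape (n_pairs, hidden_dim); the result has shape (hidden_dim,).
    """
    return (long_acts - imm_acts).mean(axis=0)

def add_steering(hidden: np.ndarray, vector: np.ndarray, strength: float) -> np.ndarray:
    """Add a scaled steering vector to every position of a (seq_len, hidden_dim) array."""
    return hidden + strength * vector  # broadcasts over the sequence dimension

rng = np.random.default_rng(0)
hidden_dim = 4
imm = rng.normal(size=(3, hidden_dim))                # 3 hypothetical "immediate" activations
long_term = imm + np.array([1.0, 0.0, 0.0, 0.0])      # "long-term" = immediate shifted along dim 0

v = mean_difference_vector(long_term, imm)
h = np.zeros((2, hidden_dim))                         # two positions of a toy hidden state
print(v)
print(add_steering(h, v, strength=-0.5))              # negative strength moves toward "immediate"
```

A negative `strength` subtracts the direction, which is how the notebook's negative slider values push toward immediate framing.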

---

**Based on:** [Contrastive Activation Addition](https://github.com/steering-vectors/steering-vectors)


## 1. Setup & Installation

In [None]:
!pip install -q transformers torch numpy matplotlib ipywidgets

In [None]:
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import Dict, List, Tuple
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✓ Imports successful")

## 2. Prepare Prompt Pairs

We create pairs of prompts that differ only in their temporal horizon.
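A shared template keeps each pair minimal, so the only thing that changes is the horizon phrase. The sketch below is one way to build and check such pairs; `make_pair` and `differing_words` are hypothetical helpers, and the check is word-level only (GPT-2's tokenizer may still split two horizon phrases into different numbers of tokens).

```python
def make_pair(task: str, short_horizon: str, long_horizon: str) -> dict:
    """Build an immediate/long-term prompt pair that differs only in the horizon phrase."""
    template = "Develop a {horizon} plan to {task}."
    return {
        "task": task,
        "immediate": template.format(horizon=short_horizon, task=task),
        "long_term": template.format(horizon=long_horizon, task=task),
    }

def differing_words(a: str, b: str) -> list:
    """Return (word_a, word_b) tuples at positions where two equal-length prompts differ."""
    wa, wb = a.split(), b.split()
    if len(wa) != len(wb):
        raise ValueError("prompts have different word counts")
    return [(x, y) for x, y in zip(wa, wb) if x != y]

pair = make_pair("reduce carbon emissions", "1-month", "50-year")
print(pair["immediate"])   # Develop a 1-month plan to reduce carbon emissions.
print(pair["long_term"])   # Develop a 50-year plan to reduce carbon emissions.
print(differing_words(pair["immediate"], pair["long_term"]))  # [('1-month', '50-year')]
```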

In [None]:
# Sample prompt pairs for steering vector extraction
PROMPT_PAIRS = [
    {
        "task": "establish a new data center with procurement of servers, storage, and network infrastructure",
        "immediate": "Develop a 1-week plan to establish a new data center with procurement of servers, storage, and network infrastructure.",
        "long_term": "Develop a 20-year plan to establish a new data center with procurement of servers, storage, and network infrastructure."
    },
    {
        "task": "develop a comprehensive marketing strategy including market research and target audiences",
        "immediate": "Develop a 1-month plan to develop a comprehensive marketing strategy including market research and target audiences.",
        "long_term": "Develop a 10-year plan to develop a comprehensive marketing strategy including market research and target audiences."
    },
    {
        "task": "improve team productivity and collaboration across the organization",
        "immediate": "Develop a 1-week plan to improve team productivity and collaboration across the organization.",
        "long_term": "Develop a 10-year plan to improve team productivity and collaboration across the organization."
    },
    {
        "task": "address climate change and reduce carbon emissions",
        "immediate": "Develop a 1-month plan to address climate change and reduce carbon emissions.",
        "long_term": "Develop a 50-year plan to address climate change and reduce carbon emissions."
    },
    {
        "task": "expand operations to new markets and locations",
        "immediate": "Develop a 1-month plan to expand operations to new markets and locations.",
        "long_term": "Develop a 20-year plan to expand operations to new markets and locations."
    },
    {
        "task": "improve public health outcomes and healthcare access",
        "immediate": "Develop a 1-week plan to improve public health outcomes and healthcare access.",
        "long_term": "Develop a 30-year plan to improve public health outcomes and healthcare access."
    },
    {
        "task": "modernize the education system and workforce development",
        "immediate": "Develop a 1-month plan to modernize the education system and workforce development.",
        "long_term": "Develop a 25-year plan to modernize the education system and workforce development."
    },
    {
        "task": "implement digital transformation across business operations",
        "immediate": "Develop a 1-week plan to implement digital transformation across business operations.",
        "long_term": "Develop a 15-year plan to implement digital transformation across business operations."
    },
]

print(f"Prepared {len(PROMPT_PAIRS)} prompt pairs")
print("\nExample pair:")
print(f"Immediate: {PROMPT_PAIRS[0]['immediate']}")
print(f"Long-term: {PROMPT_PAIRS[0]['long_term']}")

## 3. Load Model

We'll use GPT-2 small (124M parameters), which runs comfortably on either CPU or GPU.
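The 124M figure can be checked from GPT-2 small's configuration (12 layers, hidden size 768, vocabulary 50,257, context 1,024, MLP width 4 × 768). The sketch below counts parameters per component, assuming, as in Hugging Face's `GPT2LMHeadModel`, that the language-model head shares its weights with the token embedding and adds no parameters of its own; `gpt2_param_count` is a hypothetical helper.

```python
def gpt2_param_count(n_layer: int = 12, d_model: int = 768,
                     vocab: int = 50257, n_ctx: int = 1024) -> int:
    """Count GPT-2 parameters from its configuration (tied LM head assumed)."""
    d_mlp = 4 * d_model
    embeddings = vocab * d_model + n_ctx * d_model            # token + position embeddings
    layer_norm = 2 * d_model                                  # weight and bias
    attn = (d_model * 3 * d_model + 3 * d_model) \
         + (d_model * d_model + d_model)                      # fused QKV + output projection
    mlp = (d_model * d_mlp + d_mlp) + (d_mlp * d_model + d_model)
    block = 2 * layer_norm + attn + mlp                       # two LayerNorms per block
    final_ln = layer_norm
    return embeddings + n_layer * block + final_ln

print(gpt2_param_count())  # 124439808
```

Roughly 85M of those parameters sit in the 12 blocks, whose outputs are what the steering vectors are extracted from and added to.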

In [None]:
# Load GPT-2
print("Loading GPT-2...")
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model.eval()

# Move to GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

print(f"✓ Model loaded on {device}")
print(f"  Model: GPT-2 (124M parameters)")
print(f"  Layers: {len(model.transformer.h)}")

## 4. Extract Steering Vectors

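For each prompt pair, capture the output hidden states of every transformer block, keep the final-token position, and average the long-term minus immediate differences to get one steering vector per layer.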
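The forward-hook pattern used for extraction can be tried on a toy model without downloading GPT-2. `ToyBlock` is a hypothetical stand-in that, like the GPT-2 blocks the notebook hooks, returns a tuple whose first element is the hidden states; `capture_last_token` is likewise illustrative.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Stand-in for a transformer block that returns a tuple (hidden_states,)."""
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return (x + self.linear(x),)   # residual update, wrapped in a tuple

def capture_last_token(blocks, x):
    """Run x through blocks, recording each block's hidden state at the final position."""
    captured = {}

    def make_hook(i):
        def hook(module, inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            captured[i] = hidden[0, -1, :].detach()   # batch 0, last token
        return hook

    handles = [blk.register_forward_hook(make_hook(i)) for i, blk in enumerate(blocks)]
    try:
        h = x
        with torch.no_grad():
            for blk in blocks:
                h = blk(h)[0]
    finally:
        for handle in handles:          # always detach hooks, even on error
            handle.remove()
    return captured

torch.manual_seed(0)
blocks = nn.ModuleList([ToyBlock(8) for _ in range(3)])
x = torch.randn(1, 5, 8)               # (batch, seq_len, hidden_dim)
acts = capture_last_token(blocks, x)
print({i: tuple(v.shape) for i, v in acts.items()})  # {0: (8,), 1: (8,), 2: (8,)}
```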

In [None]:
def extract_activations(model, tokenizer, prompt: str, device: str) -> Dict[int, torch.Tensor]:
    """
    Extract the output hidden states of every transformer block for a given prompt.
    
    Returns: dict mapping layer_idx -> activations of shape (1, seq_len, hidden_dim)
    """
    activations = {}
    
    def hook_fn(layer_num):
        def hook(module, input, output):
            # GPT-2 blocks return a tuple whose first element is the hidden states;
            # a bare tensor is also accepted in case a block returns one directly
            hidden_states = output[0] if isinstance(output, tuple) else output
            activations[layer_num] = hidden_states.detach()
        return hook
    
    # Register hooks for all layers
    hooks = []
    for i, layer in enumerate(model.transformer.h):
        hooks.append(layer.register_forward_hook(hook_fn(i)))
    
    try:
        # Forward pass
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        with torch.no_grad():
            model(**inputs)
    finally:
        # Remove hooks even if the forward pass fails
        for hook in hooks:
            hook.remove()
    
    return activations


def compute_steering_vectors(model, tokenizer, prompt_pairs: List[Dict], device: str) -> Dict[int, np.ndarray]:
    """
    Compute steering vectors from prompt pairs.
    
    Returns: dict mapping layer_idx -> steering_vector (hidden_dim,)
    """
    n_layers = len(model.transformer.h)
    layer_contrasts = {layer: [] for layer in range(n_layers)}
    
    print(f"Extracting activations from {len(prompt_pairs)} prompt pairs...")
    
    for pair in tqdm(prompt_pairs):
        # Extract activations
        immediate_acts = extract_activations(model, tokenizer, pair['immediate'], device)
        long_term_acts = extract_activations(model, tokenizer, pair['long_term'], device)
        
        # Compute contrastive vectors at each layer
        for layer in range(n_layers):
            # Take final token position
            imm_vec = immediate_acts[layer][0, -1, :].cpu().numpy()
            long_vec = long_term_acts[layer][0, -1, :].cpu().numpy()
            
            # Contrastive vector: long_term - immediate
            contrast = long_vec - imm_vec
            layer_contrasts[layer].append(contrast)
    
    # Average across all pairs
    steering_vectors = {}
    for layer in range(n_layers):
        contrasts = np.stack(layer_contrasts[layer])
        steering_vectors[layer] = contrasts.mean(axis=0)
    
    return steering_vectors

print("✓ Extraction functions defined")

In [None]:
# Extract steering vectors
steering_vectors = compute_steering_vectors(model, tokenizer, PROMPT_PAIRS, device)

print(f"\n✓ Extracted steering vectors for {len(steering_vectors)} layers")
print(f"  Vector dimension: {len(steering_vectors[0])}")

## 5. Analyze Steering Vectors

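Plot the norm of each layer's steering vector. GPT-2's hidden-state norm tends to grow with depth, so a larger raw norm at later layers does not by itself mean a stronger behavioral effect.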
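One way to compare layers more fairly is to divide each steering-vector norm by the typical hidden-state norm at that layer. In the notebook, `hidden_norms` could be estimated by averaging the norms of the final-token activations returned by `extract_activations`; that step, and the `relative_norms` helper itself, are assumptions rather than part of the original code. The toy numbers below show a case where the raw norm doubles but the relative size does not change.

```python
import numpy as np

def relative_norms(steering_vectors: dict, hidden_norms: dict) -> dict:
    """Steering-vector norm divided by the typical hidden-state norm at the same layer."""
    return {layer: float(np.linalg.norm(v) / hidden_norms[layer])
            for layer, v in steering_vectors.items()}

# Raw steering norm doubles from layer 0 to layer 1, but so does the hidden-state norm.
vecs = {0: np.array([3.0, 4.0]), 1: np.array([6.0, 8.0])}   # norms 5 and 10
hidden = {0: 50.0, 1: 100.0}
print(relative_norms(vecs, hidden))  # {0: 0.1, 1: 0.1}
```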

In [None]:
# Compute norms for each layer
layer_norms = {layer: np.linalg.norm(vec) for layer, vec in steering_vectors.items()}

# Plot
plt.figure(figsize=(12, 5))

# Bar plot of norms
plt.subplot(1, 2, 1)
layers = list(layer_norms.keys())
norms = list(layer_norms.values())
plt.bar(layers, norms, color='steelblue', alpha=0.7)
plt.xlabel('Layer', fontsize=12)
plt.ylabel('Steering Vector Norm', fontsize=12)
plt.title('Steering Vector Strength by Layer', fontsize=14, fontweight='bold')
plt.grid(axis='y', alpha=0.3)

# Line view of the same norms
plt.subplot(1, 2, 2)
plt.plot(layers, norms, marker='o', linewidth=2, markersize=8, color='coral')
plt.xlabel('Layer', fontsize=12)
plt.ylabel('Steering Vector Norm', fontsize=12)
plt.title('Steering Vector Norm Across Layers', fontsize=14, fontweight='bold')
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Print top layers
sorted_layers = sorted(layer_norms.items(), key=lambda x: x[1], reverse=True)
print("\nTop 5 layers by steering-vector norm:")
for layer, norm in sorted_layers[:5]:
    print(f"  Layer {layer:2d}: {norm:.3f}")
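Norm alone does not show whether the eight prompt pairs agree on a direction. A possible extra check, not in the original notebook, is the mean pairwise cosine similarity of the per-pair contrast vectors at a layer, i.e. the arrays collected in `layer_contrasts` inside `compute_steering_vectors` (that function would need to return them). Values near 1 suggest the pairs point the same way; values near 0 suggest the average is dominated by noise. `mean_pairwise_cosine` is a hypothetical helper.

```python
import numpy as np

def mean_pairwise_cosine(contrasts: np.ndarray) -> float:
    """Mean cosine similarity between all distinct rows of an (n_pairs, hidden_dim) array."""
    unit = contrasts / np.linalg.norm(contrasts, axis=1, keepdims=True)
    sims = unit @ unit.T
    n = len(contrasts)
    off_diagonal = sims[~np.eye(n, dtype=bool)]   # drop each row's similarity with itself
    return float(off_diagonal.mean())

aligned = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])   # same direction, different lengths
opposed = np.array([[1.0, 0.0], [-1.0, 0.0]])
print(mean_pairwise_cosine(aligned))   # 1.0
print(mean_pairwise_cosine(opposed))   # -1.0
```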

## 6. Apply Temporal Steering

Now we can steer the model during generation by adding our steering vectors to the activations.
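Steering relies on a PyTorch detail: when a forward hook returns a value, that value replaces the module's output. The toy sketch below shows this on a single `nn.Linear` layer; `add_vector_hook` is a hypothetical helper, and the tuple branch mirrors how the notebook handles GPT-2 blocks.

```python
import torch
import torch.nn as nn

def add_vector_hook(vector: torch.Tensor, strength: float):
    """Forward hook that adds strength * vector to a module's output.

    Returning a value from a forward hook replaces the module's output.
    """
    def hook(module, inputs, output):
        if isinstance(output, tuple):            # e.g. transformer blocks returning (hidden, ...)
            return (output[0] + strength * vector,) + output[1:]
        return output + strength * vector
    return hook

torch.manual_seed(0)
layer = nn.Linear(4, 4)
x = torch.randn(1, 3, 4)                # (batch, seq_len, hidden_dim)
v = torch.tensor([1.0, 0.0, 0.0, 0.0])

with torch.no_grad():
    baseline = layer(x)
    handle = layer.register_forward_hook(add_vector_hook(v, strength=2.0))
    try:
        steered = layer(x)
    finally:
        handle.remove()                 # always detach the hook
    after = layer(x)

print(torch.allclose(steered - baseline, 2.0 * v.expand_as(baseline)))  # True
print(torch.allclose(after, baseline))                                   # True
```

Removing the hook in a `finally` block matters: a hook left behind after an error would keep steering every later forward pass.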

In [None]:
class TemporalSteering:
    """Apply temporal steering vectors during generation."""
    
    def __init__(self, model, tokenizer, steering_vectors, target_layers=None):
        self.model = model
        self.tokenizer = tokenizer
        self.steering_vectors = steering_vectors
        
        # Default to the last 8 blocks (layers 4-11 for the 12-layer GPT-2 small)
        if target_layers is None:
            n_layers = len(model.transformer.h)
            start = max(0, n_layers - 8)
            self.target_layers = list(range(start, n_layers))
        else:
            self.target_layers = target_layers
    
    def generate(self, prompt: str, steering_strength: float = 0.0, 
                 max_length: int = 100, temperature: float = 0.7, **kwargs):
        """
        Generate text with temporal steering.
        
        Args:
            prompt: Input text
            steering_strength: Multiplier on each steering vector. Negative values
                (down to -1.0 in the demo) push toward immediate framing, positive
                values (up to +1.0) toward long-term framing. The vector is added at
                every target layer, so its effect accumulates in the residual stream.
            max_length: Maximum total length in tokens (prompt + generated tokens)
            temperature: Sampling temperature
        """
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
        
        def make_hook(steering_vec, strength):
            def hook(module, input, output):
                # GPT-2 blocks return a tuple whose first element is the hidden states
                hidden_states = output[0] if isinstance(output, tuple) else output
                hidden_states = hidden_states + strength * steering_vec
                if isinstance(output, tuple):
                    return (hidden_states,) + output[1:]
                return hidden_states
            return hook
        
        # Register one hook per target layer; each vector is converted to a tensor
        # once here rather than on every forward pass
        hooks = []
        for layer_idx in self.target_layers:
            if layer_idx not in self.steering_vectors:
                continue
            steering_vec = torch.as_tensor(
                self.steering_vectors[layer_idx],
                dtype=self.model.dtype,
                device=self.model.device
            )
            hooks.append(self.model.transformer.h[layer_idx].register_forward_hook(
                make_hook(steering_vec, steering_strength)
            ))
        
        try:
            with torch.no_grad():
                output_ids = self.model.generate(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **kwargs
                )
        finally:
            # Always remove hooks; if generation fails, leftover hooks would
            # otherwise steer every later call a second time
            for hook in hooks:
                hook.remove()
        
        # Decode
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Initialize steering system
steering_system = TemporalSteering(model, tokenizer, steering_vectors)
print(f"✓ Steering system ready (layers {steering_system.target_layers[0]}-{steering_system.target_layers[-1]})")

## 7. Interactive Demo

Try steering GPT-2's temporal scope with the interactive controls below!

In [None]:
# Create interactive widgets
prompt_input = widgets.Textarea(
    value='What should policymakers prioritize to address climate change?',
    placeholder='Enter your prompt...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='80px')
)

steering_slider = widgets.FloatSlider(
    value=0.0,
    min=-1.0,
    max=1.0,
    step=0.1,
    description='Temporal Steering:',
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    layout=widgets.Layout(width='80%')
)

temp_slider = widgets.FloatSlider(
    value=0.7,
    min=0.1,
    max=1.5,
    step=0.1,
    description='Temperature:',
    continuous_update=False,
    readout_format='.1f',
    layout=widgets.Layout(width='50%')
)

max_length_slider = widgets.IntSlider(
    value=100,
    min=50,
    max=200,
    step=10,
    description='Max Length:',
    continuous_update=False,
    layout=widgets.Layout(width='50%')
)

generate_button = widgets.Button(
    description='Generate',
    button_style='primary',
    icon='play',
    layout=widgets.Layout(width='150px', height='40px')
)

output_area = widgets.Output()

# Example prompts
example_buttons = []
examples = [
    "What should policymakers prioritize to address climate change?",
    "Develop a plan to improve team productivity and collaboration.",
    "How should we approach solving the housing affordability crisis?",
    "Create a strategy for digital transformation in our organization.",
    "What investments should we make in education and workforce development?",
]

for example in examples:
    btn = widgets.Button(
        description=example[:40] + '...' if len(example) > 40 else example,
        layout=widgets.Layout(width='auto', margin='2px'),
        button_style='info'
    )
    btn.example_text = example
    example_buttons.append(btn)

# Event handlers
def set_example(b):
    prompt_input.value = b.example_text

for btn in example_buttons:
    btn.on_click(set_example)

from html import escape

def on_generate(b):
    with output_area:
        clear_output()
        
        # Show loading
        display(HTML('<div style="padding: 20px; background: #f0f0f0; border-radius: 8px;"><h3>🔄 Generating...</h3></div>'))
        
        try:
            # Generate
            result = steering_system.generate(
                prompt=prompt_input.value,
                steering_strength=steering_slider.value,
                temperature=temp_slider.value,
                max_length=max_length_slider.value
            )
            
            clear_output()
            
            # Map the slider value to a label
            if steering_slider.value < -0.6:
                steering_label = "Strong Immediate 🔥"
            elif steering_slider.value < -0.2:
                steering_label = "Moderate Immediate"
            elif steering_slider.value < 0.2:
                steering_label = "Neutral ⚖️"
            elif steering_slider.value < 0.6:
                steering_label = "Moderate Long-term"
            else:
                steering_label = "Strong Long-term 🌱"
            
            # Escape the generated text so characters such as "<" or "&" display literally
            result_html = f"""
            <div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 12px; color: white; margin-bottom: 15px;">
                <h3 style="margin: 0;">Temporal Steering: {steering_label} ({steering_slider.value:.1f})</h3>
            </div>
            <div style="padding: 20px; background: #f8f9fa; border-left: 4px solid #667eea; border-radius: 8px; line-height: 1.6;">
                <pre style="white-space: pre-wrap; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; margin: 0;">{escape(result)}</pre>
            </div>
            """
            display(HTML(result_html))
            
        except Exception as e:
            clear_output()
            display(HTML(f'<div style="padding: 20px; background: #fee; border-radius: 8px; color: #c00;"><strong>Error:</strong> {escape(str(e))}</div>'))

generate_button.on_click(on_generate)

# Layout
display(HTML('<h2>🎮 Interactive Temporal Steering Demo</h2>'))
display(HTML('<p style="color: #666;">Adjust the slider to shift between immediate/tactical thinking (-1.0) and long-term/strategic thinking (+1.0)</p>'))
display(HTML('<hr style="margin: 20px 0;">'))

display(HTML('<h4>Example Prompts:</h4>'))
display(widgets.HBox(example_buttons, layout=widgets.Layout(flex_flow='row wrap')))
display(HTML('<br>'))

display(prompt_input)
display(HTML('<br>'))
display(steering_slider)
display(widgets.HBox([temp_slider, max_length_slider]))
display(HTML('<br>'))
display(generate_button)
display(HTML('<br>'))
display(output_area)

## 8. Side-by-Side Comparison

Compare immediate vs. long-term steering directly.
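Because generation samples tokens, two runs can differ even with no steering at all. Reusing the same random seed for both runs (common random numbers) removes that source of difference. The toy sampler below illustrates the idea on a four-token vocabulary; `sample_tokens` is hypothetical, and `bias` only stands in for the effect steering has on the output distribution (in the notebook, steering acts on hidden states, not directly on logits).

```python
import numpy as np

def sample_tokens(logits: np.ndarray, bias: np.ndarray, n: int, seed: int) -> np.ndarray:
    """Draw n token ids from softmax(logits + bias) using a fixed seed."""
    rng = np.random.default_rng(seed)
    z = logits + bias
    probs = np.exp(z - z.max())          # subtract max for numerical stability
    probs /= probs.sum()
    return rng.choice(len(probs), size=n, p=probs)

logits = np.array([2.0, 1.0, 0.5, 0.0])
no_bias = np.zeros(4)
a = sample_tokens(logits, no_bias, n=10, seed=7)
b = sample_tokens(logits, no_bias, n=10, seed=7)
print(np.array_equal(a, b))   # True: same seed and distribution give the same draws

steered = sample_tokens(logits, np.array([0.0, 0.0, 0.0, 3.0]), n=10, seed=7)
print(a, steered)             # differences now come from the changed distribution
```

A single pair of samples is still a noisy comparison; repeating it over several seeds gives a sturdier picture.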

In [None]:
def compare_steering(prompt: str, strength: float = 0.8, seed: int = 42):
    """
    Generate responses with immediate (-strength) and long-term (+strength) steering.
    
    Both runs start from the same random seed, so they share sampling randomness
    and differences are easier to attribute to the steering itself.
    """
    print(f"Prompt: {prompt}")
    print("="*80)
    
    # Immediate
    print(f"\n🔥 IMMEDIATE STEERING (-{strength})")
    print("-"*80)
    torch.manual_seed(seed)
    immediate = steering_system.generate(
        prompt=prompt,
        steering_strength=-strength,
        temperature=0.7,
        max_length=100
    )
    print(immediate)
    
    # Long-term
    print(f"\n🌱 LONG-TERM STEERING (+{strength})")
    print("-"*80)
    torch.manual_seed(seed)
    long_term = steering_system.generate(
        prompt=prompt,
        steering_strength=strength,
        temperature=0.7,
        max_length=100
    )
    print(long_term)
    print("\n" + "="*80)

# Try it!
compare_steering("What should policymakers prioritize to address climate change?", strength=0.8)

## 9. Experiment: Your Own Prompts

Try your own prompts and observe how steering affects the responses!

In [None]:
# Try your own!
my_prompt = "How should we improve workplace culture and employee wellbeing?"

compare_steering(my_prompt, strength=1.0)

## 🎯 Key Takeaways

1. **Steering vectors** approximate the average difference between immediate and long-term prompts in the model's activations
2. **Later layers** usually have larger steering-vector norms, but GPT-2's hidden-state norm also grows with depth, so which layers steer most effectively is best tested directly
3. **Positive steering** (e.g. +1.0) is meant to push toward strategic/long-term thinking
4. **Negative steering** (e.g. -1.0) is meant to push toward tactical/immediate thinking
5. **Temperature** controls randomness: lower values give more focused, less varied output
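Takeaway 5, together with the `top_p=0.9` setting used in `generate`, can be made concrete with a small NumPy sketch. `temperature_softmax` and `top_p_filter` are hypothetical helpers implementing a simplified form of temperature scaling and nucleus (top-p) filtering; Hugging Face's own implementation may differ in edge cases.

```python
import numpy as np

def temperature_softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax of logits / temperature; lower temperature sharpens the distribution."""
    z = logits / temperature
    p = np.exp(z - z.max())
    return p / p.sum()

def top_p_filter(probs: np.ndarray, top_p: float) -> np.ndarray:
    """Keep the smallest set of most likely tokens whose total probability reaches top_p, then renormalize."""
    order = np.argsort(probs)[::-1]                      # most likely first
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, top_p) + 1      # number of tokens to keep
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

logits = np.array([2.0, 1.0, 0.0])
print(temperature_softmax(logits, 0.5).round(3))   # sharper: most mass on token 0
print(temperature_softmax(logits, 1.5).round(3))   # flatter
print(top_p_filter(temperature_softmax(logits, 1.0), top_p=0.9))
```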

## 🔬 Further Exploration

- Try different model sizes (gpt2-medium, gpt2-large)
- Extract steering vectors from your own prompt pairs
- Experiment with layer selection (early vs. late layers)
- Combine temporal steering with other steering dimensions
- Test on decision-making scenarios with temporal trade-offs
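For combining temporal steering with another dimension, one simple option, assuming a second set of per-layer vectors (for example a hypothetical "formality" direction) is extracted the same way, is a weighted sum per layer. The result has the same `{layer: vector}` shape as `steering_vectors`, so it could be passed to `TemporalSteering`. `combine_steering` is a hypothetical helper, and whether such directions combine linearly in GPT-2 is an empirical question.

```python
import numpy as np

def combine_steering(vector_sets: dict, weights: dict) -> dict:
    """Weighted sum of several per-layer steering-vector dicts.

    vector_sets maps a dimension name to {layer: vector}; only layers present
    in every set are combined.
    """
    names = list(vector_sets)
    shared = set.intersection(*(set(vector_sets[n]) for n in names))
    return {
        layer: sum(weights[n] * vector_sets[n][layer] for n in names)
        for layer in sorted(shared)
    }

temporal = {10: np.array([1.0, 0.0]), 11: np.array([2.0, 0.0])}
formality = {11: np.array([0.0, 4.0])}            # hypothetical second dimension
combined = combine_steering({"temporal": temporal, "formality": formality},
                            {"temporal": 1.0, "formality": 0.5})
print(combined)  # {11: array([2., 2.])}
```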

## 📚 Resources

- [Steering Vectors Library](https://github.com/steering-vectors/steering-vectors)
- [Representation Engineering Paper](https://arxiv.org/abs/2310.01405)
- [GPT-2 Paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)

---

**Built with ❤️ using Contrastive Activation Addition**
