# ⏱️ Temporal Steering with GPT-2 - Interactive Demo

**Control GPT-2's temporal scope in real-time** using pre-trained steering vectors!

Move the slider to shift between:
- 🔥 **Immediate/Tactical** thinking (-1.0)
- 🌱 **Long-term/Strategic** thinking (+1.0)

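This demo uses **Contrastive Activation Addition (CAA)**, a technique that builds a "steering vector" for each layer from the difference between the model's activations on long-term and immediate prompts, then adds that vector (scaled by the slider value) to the layer's hidden states during generation.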

---

[![GitHub](https://img.shields.io/badge/GitHub-temporal--steering-blue)](https://github.com/justinshenk/temporal-steering)


## 1️⃣ Setup (Run Once)

In [None]:
# Install dependencies
!pip install -q transformers torch ipywidgets

import torch
import numpy as np
import json
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import warnings
warnings.filterwarnings('ignore')

print("✓ Setup complete!")

## 2️⃣ Load Model & Pre-trained Steering Vectors

In [None]:
# Load GPT-2 and its tokenizer, then place the model on the best available device.
print("Loading GPT-2...")

# Inference only: eval mode is set right after loading.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

# Reuse the end-of-sequence token as the padding token.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Prefer the GPU when CUDA is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)

print(f"✓ Model loaded on {device}")

In [None]:
# Download pre-trained steering vectors, falling back to a quick local
# extraction when the download or parsing fails.
import urllib.request


def extract_quick_vectors(model, tokenizer, device):
    """Estimate one steering vector per transformer block from built-in prompt pairs.

    Each pair holds an immediate-horizon prompt and a long-horizon prompt.
    Both prompts are run through the model, the final-token output of every
    block in ``model.transformer.h`` is captured, and the per-layer
    ``long - immediate`` differences are averaged over the pairs.

    Args:
        model: Model exposing its blocks as ``model.transformer.h``.
        tokenizer: Callable as ``tokenizer(text, return_tensors='pt')``.
        device: Device the tokenized inputs are moved to.

    Returns:
        Dict mapping layer index to a 1-D NumPy array.
    """
    pairs = [
        {"imm": "Develop a 1 week plan to improve team productivity.",
         "long": "Develop a 20 year plan to improve team productivity."},
        {"imm": "Develop a 1 month plan to address climate change.",
         "long": "Develop a 50 year plan to address climate change."},
        {"imm": "Develop a 1 week plan to expand operations.",
         "long": "Develop a 15 year plan to expand operations."},
    ]

    blocks = model.transformer.h
    layer_contrasts = {layer: [] for layer in range(len(blocks))}

    def last_token_activations(prompt):
        """Run ``prompt`` once and return {layer: final-token block output}."""
        captured = {}

        def make_hook(layer_num):
            def hook(module, inputs, output):
                # The original code indexed output[0], i.e. expected a tuple.
                # A bare tensor is accepted too, so that indexing never
                # silently slices the batch dimension instead.
                hidden = output[0] if isinstance(output, tuple) else output
                captured[layer_num] = hidden[0, -1, :].detach().cpu().numpy()
            return hook

        handles = [
            block.register_forward_hook(make_hook(i))
            for i, block in enumerate(blocks)
        ]
        try:
            encoded = tokenizer(prompt, return_tensors='pt').to(device)
            with torch.no_grad():
                model(**encoded)
        finally:
            # Always detach the hooks, even if the forward pass fails, so the
            # model is not left with stale capture hooks.
            for handle in handles:
                handle.remove()
        return captured

    for pair in pairs:
        imm_acts = last_token_activations(pair["imm"])
        long_acts = last_token_activations(pair["long"])
        for layer, contrasts in layer_contrasts.items():
            contrasts.append(long_acts[layer] - imm_acts[layer])

    return {layer: np.stack(contrasts).mean(axis=0)
            for layer, contrasts in layer_contrasts.items()}


print("Downloading pre-trained steering vectors...")
url = "https://raw.githubusercontent.com/justinshenk/temporal-steering/main/steering_vectors/temporal_steering.json"

try:
    # Bounded wait so a stalled connection cannot hang the cell indefinitely.
    with urllib.request.urlopen(url, timeout=30) as response:
        data = json.loads(response.read())

    steering_vectors = {
        int(layer): np.array(vec)
        for layer, vec in data['layer_vectors'].items()
    }
    if not steering_vectors:
        raise ValueError("downloaded file contains no layer vectors")

except Exception as e:
    # Deliberate best-effort: any download or parse failure switches to the
    # local extraction instead of aborting the notebook.
    print(f"Could not download from GitHub: {e}")
    print("Using fallback: Extracting steering vectors locally...")
    print("(This will take a few minutes)")

    steering_vectors = extract_quick_vectors(model, tokenizer, device)
    print(f"✓ Extracted vectors for {len(steering_vectors)} layers")

else:
    # Reporting lives outside the try block so that a missing layer 0 or a
    # missing metadata entry cannot discard a successful download.
    print(f"✓ Loaded vectors for {len(steering_vectors)} layers")
    print(f"  Vector dimension: {len(next(iter(steering_vectors.values())))}")
    print(f"  Metadata: {data.get('metadata')}")

## 3️⃣ Steering System

In [None]:
class TemporalSteering:
    """Generate text while adding per-layer steering vectors to block outputs.

    Attributes are stored as given: ``model`` (blocks reachable as
    ``model.transformer.h``), ``tokenizer``, ``steering_vectors`` (mapping of
    layer index to a 1-D array) and ``target_layers`` (the layer indices that
    are steered).
    """

    def __init__(self, model, tokenizer, steering_vectors, target_layers=None):
        """Store the components and choose which layers to steer.

        Args:
            model: Model exposing ``transformer.h``, ``device`` and ``generate``.
            tokenizer: Tokenizer used to encode prompts and decode outputs.
            steering_vectors: Mapping ``layer index -> vector``.
            target_layers: Layer indices to steer. Defaults to the last eight
                blocks (or all of them when the model has fewer than eight).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.steering_vectors = steering_vectors

        if target_layers is None:
            n_layers = len(model.transformer.h)
            start = max(0, n_layers - 8)
            self.target_layers = list(range(start, n_layers))
        else:
            self.target_layers = target_layers

    def generate(self, prompt, steering_strength=0.0, max_length=100, temperature=0.7):
        """Sample a continuation of ``prompt`` with steering applied.

        Args:
            prompt: Text to continue.
            steering_strength: Multiplier applied to each layer's steering
                vector before it is added to that layer's hidden states.
            max_length: Passed to ``model.generate`` as ``max_length``.
            temperature: Sampling temperature passed to ``model.generate``.

        Returns:
            The decoded text of the first generated sequence.

        Raises:
            Whatever ``model.generate`` raises; the steering hooks are removed
            before the exception propagates.
        """
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)

        def make_hook(layer_idx, strength):
            vector = self.steering_vectors[layer_idx]
            # The hook fires on every forward pass during generation, so the
            # scaled offset is built once per dtype/device and then reused.
            offsets = {}

            def hook(module, input, output):
                is_tuple = isinstance(output, tuple)
                # A bare tensor is handled as well: indexing it with [0] would
                # slice the batch dimension rather than pick the hidden states.
                hidden_states = output[0] if is_tuple else output
                key = (hidden_states.dtype, hidden_states.device)
                if key not in offsets:
                    offsets[key] = strength * torch.tensor(
                        vector,
                        dtype=hidden_states.dtype,
                        device=hidden_states.device
                    )
                steered = hidden_states + offsets[key]
                if is_tuple:
                    return (steered,) + output[1:]
                return steered
            return hook

        generate_kwargs = {}
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            # Forward the tokenizer's mask instead of leaving generate() to
            # infer one from the pad token (which equals EOS in this demo).
            generate_kwargs['attention_mask'] = attention_mask

        hooks = []
        try:
            for layer_idx in self.target_layers:
                # Layers without a vector are left untouched.
                if layer_idx not in self.steering_vectors:
                    continue
                hooks.append(self.model.transformer.h[layer_idx].register_forward_hook(
                    make_hook(layer_idx, steering_strength)
                ))

            with torch.no_grad():
                output_ids = self.model.generate(
                    inputs['input_ids'],
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **generate_kwargs
                )
        finally:
            # Without this, a failed generation would leave the hooks attached
            # and every later forward pass would be steered.
            for hook in hooks:
                hook.remove()

        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Module-level steering helper, built with the default layer selection.
steering = TemporalSteering(model, tokenizer, steering_vectors)
active_layers = steering.target_layers
print(f"✓ Steering system ready (layers {active_layers[0]}-{active_layers[-1]})")

## 🎮 Interactive Demo

**Adjust the slider and click Generate** to see how temporal steering affects GPT-2's responses!

In [None]:
# Widgets
# Free-text prompt box, pre-filled with the first example prompt.
prompt_input = widgets.Textarea(
    value='What should policymakers prioritize to address climate change?',
    placeholder='Enter your prompt...',
    description='Prompt:',
    layout=widgets.Layout(width='100%', height='100px')
)

# Steering strength: negative = immediate, positive = long-term.
steering_slider = widgets.FloatSlider(
    value=0.0,
    min=-1.0,
    max=1.0,
    step=0.1,
    description='Temporal Steering:',
    continuous_update=False,
    readout_format='.1f',
    # Long descriptions are cut off at the default description width; let the
    # label take the width its text needs.
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='95%')
)

# Banner describing the current steering band; rewritten by the slider observer.
steering_label = widgets.HTML(
    value='<div style="text-align: center; font-size: 16px; font-weight: bold; color: #667eea; padding: 10px;">⚖️ Neutral (0.0)</div>'
)

temp_slider = widgets.FloatSlider(
    value=0.7,
    min=0.3,
    max=1.2,
    step=0.1,
    description='Temperature:',
    continuous_update=False,
    readout_format='.1f',
    layout=widgets.Layout(width='45%')
)

max_length_slider = widgets.IntSlider(
    value=100,
    min=50,
    max=200,
    step=10,
    description='Max Length:',
    continuous_update=False,
    layout=widgets.Layout(width='45%')
)

generate_btn = widgets.Button(
    description='🚀 Generate',
    button_style='primary',
    layout=widgets.Layout(width='200px', height='45px')
)

# Generation results and errors are rendered inside this output widget.
output_area = widgets.Output()

# Example prompts
examples = [
    "What should policymakers prioritize to address climate change?",
    "Develop a plan to improve team productivity and collaboration.",
    "How should we approach solving the housing affordability crisis?",
    "Create a strategy for improving public health outcomes.",
    "What investments should we make in education and workforce development?",
]

# One button per example. Labels are shortened to 50 characters, so the full
# prompt is kept both as a hover tooltip and on the button object itself.
example_btns = []
for ex in examples:
    btn = widgets.Button(
        description=ex[:50] + '...' if len(ex) > 50 else ex,
        tooltip=ex,
        button_style='info',
        layout=widgets.Layout(width='auto', margin='3px')
    )
    btn.example_text = ex
    example_btns.append(btn)

# Event handlers
def update_steering_label(change):
    """Rewrite the steering banner to match the slider's new value.

    ``change`` is the observer payload; only ``change['new']`` is read.
    """
    val = change['new']

    # (exclusive upper bound, caption, text colour), checked in ascending order.
    bands = (
        (-0.6, "🔥 Strong Immediate", "#e74c3c"),
        (-0.2, "🔥 Moderate Immediate", "#e67e22"),
        (0.2, "⚖️ Neutral", "#667eea"),
        (0.6, "🌱 Moderate Long-term", "#27ae60"),
    )

    # Anything not below the last bound is the strongest long-term band.
    caption, color = "🌱 Strong Long-term", "#16a085"
    for upper_bound, band_caption, band_color in bands:
        if val < upper_bound:
            caption, color = band_caption, band_color
            break

    label = f"{caption} ({val:.1f})"
    steering_label.value = f'<div style="text-align: center; font-size: 18px; font-weight: bold; color: {color}; padding: 10px; background: #f8f9fa; border-radius: 8px;">{label}</div>'


# Re-run the banner update whenever the slider's value changes.
steering_slider.observe(update_steering_label, names='value')

def set_example(b):
    """Copy the full example prompt stored on the clicked button into the prompt box."""
    chosen_prompt = b.example_text
    prompt_input.value = chosen_prompt


# Every example button shares the same click handler.
for example_button in example_btns:
    example_button.on_click(set_example)

def on_generate(b):
    """Generate with the current control values and render the result.

    Shows a progress banner in ``output_area``, calls ``steering.generate``,
    then replaces the banner with either the generated text or an error box.
    ``b`` is the clicked button and is not used.
    """
    # Imported by name so the local `html` variable below cannot shadow the module.
    from html import escape

    # Read every control once, up front, so the header always reports the
    # strength that was actually used even if the slider moves meanwhile.
    prompt = prompt_input.value
    val = steering_slider.value
    temperature = temp_slider.value
    max_length = max_length_slider.value

    with output_area:
        clear_output()

        display(HTML('''
        <div style="padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); 
                    border-radius: 12px; color: white; text-align: center; margin: 10px 0;">
            <h3 style="margin: 0;">🔄 Generating...</h3>
        </div>
        '''))

        try:
            result = steering.generate(
                prompt=prompt,
                steering_strength=val,
                temperature=temperature,
                max_length=max_length
            )
        except Exception as e:
            # UI boundary: show the failure in the output area instead of
            # letting it disappear into the widget callback log.
            clear_output()
            display(HTML(f'''
            <div style="padding: 20px; background: #fee; border-radius: 8px; color: #c00; border: 2px solid #c00;">
                <strong>⚠️ Error:</strong> {escape(str(e))}
            </div>
            '''))
            return

        clear_output()

        if val < -0.6:
            header_text = "🔥 Strong Immediate Steering"
            gradient = "linear-gradient(135deg, #e74c3c 0%, #c0392b 100%)"
        elif val < -0.2:
            header_text = "🔥 Moderate Immediate Steering"
            gradient = "linear-gradient(135deg, #e67e22 0%, #d35400 100%)"
        elif val < 0.2:
            header_text = "⚖️ Neutral (No Steering)"
            gradient = "linear-gradient(135deg, #667eea 0%, #764ba2 100%)"
        elif val < 0.6:
            header_text = "🌱 Moderate Long-term Steering"
            gradient = "linear-gradient(135deg, #27ae60 0%, #229954 100%)"
        else:
            header_text = "🌱 Strong Long-term Steering"
            gradient = "linear-gradient(135deg, #16a085 0%, #138d75 100%)"

        # The generated text is arbitrary model output and may contain '<' or
        # '&'; escape it so it is shown literally rather than parsed as markup.
        html = f'''
            <div style="padding: 20px; background: {gradient}; 
                        border-radius: 12px; color: white; margin-bottom: 15px;">
                <h3 style="margin: 0;">{header_text} ({val:.1f})</h3>
            </div>
            <div style="padding: 25px; background: #f8f9fa; border-left: 5px solid #667eea; 
                        border-radius: 8px; line-height: 1.8; font-size: 15px;">
                <pre style="white-space: pre-wrap; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; margin: 0;">{escape(result)}</pre>
            </div>
            '''
        display(HTML(html))


generate_btn.on_click(on_generate)

# Display UI: every piece of the interface, listed top to bottom and then
# rendered in that order.
title_banner = HTML('''
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 30px; border-radius: 15px; color: white; margin: 20px 0;">
    <h1 style="margin: 0 0 10px 0; text-align: center;">🎮 Interactive Temporal Steering</h1>
    <p style="margin: 0; text-align: center; font-size: 16px; opacity: 0.95;">Control GPT-2's temporal perspective in real-time</p>
</div>
''')

tips_panel = HTML('''
<div style="margin-top: 40px; padding: 20px; background: #f8f9fa; border-radius: 10px; border-left: 4px solid #667eea;">
    <h4 style="margin-top: 0;">💡 Tips:</h4>
    <ul style="line-height: 1.8;">
        <li><strong>Immediate (-1.0):</strong> Tactical, short-term, concrete actions</li>
        <li><strong>Neutral (0.0):</strong> Unsteered GPT-2 baseline</li>
        <li><strong>Long-term (+1.0):</strong> Strategic, long-term, systemic thinking</li>
        <li><strong>Temperature:</strong> Lower = more focused, Higher = more creative</li>
    </ul>
</div>
''')

ui_sections = [
    title_banner,

    # Example prompt buttons, wrapping onto several rows when needed.
    HTML('<h3 style="margin-top: 20px;">📝 Example Prompts</h3>'),
    widgets.HBox(example_btns, layout=widgets.Layout(flex_flow='row wrap')),

    HTML('<h3 style="margin-top: 25px;">✏️ Your Prompt</h3>'),
    prompt_input,

    # Steering banner, a red-to-green colour bar, then the slider itself.
    HTML('<h3 style="margin-top: 25px;">🎚️ Steering Controls</h3>'),
    steering_label,
    HTML('<div style="padding: 0 20px;"><div style="background: linear-gradient(to right, #e74c3c 0%, #95a5a6 50%, #27ae60 100%); height: 8px; border-radius: 4px; margin-bottom: 10px;"></div></div>'),
    steering_slider,

    HTML('<h3 style="margin-top: 25px;">⚙️ Generation Settings</h3>'),
    widgets.HBox([temp_slider, max_length_slider]),

    # Empty spacer followed by the centred Generate button.
    HTML('<div style="text-align: center; margin: 25px 0;"></div>'),
    widgets.HBox([generate_btn], layout=widgets.Layout(justify_content='center')),

    HTML('<h3 style="margin-top: 30px;">📤 Generated Output</h3>'),
    output_area,

    tips_panel,
]

for section in ui_sections:
    display(section)

## 🔬 Side-by-Side Comparison

Compare immediate vs. long-term steering directly:

In [None]:
def compare_steering(prompt, strength=0.8, seed=None):
    """Print immediate, neutral and long-term generations for one prompt.

    Args:
        prompt: Text passed to ``steering.generate`` for all three runs.
        strength: Steering magnitude; applied as ``-strength``, ``0.0`` and
            ``+strength``.
        seed: Optional integer. When given, the torch RNG is reseeded with it
            before each of the three generations so they share the same
            sampling noise and differ only through steering. When ``None``
            (the default) the RNG is left untouched.
    """
    # (section header, signed steering strength), in print order.
    runs = [
        (f"\n🔥 IMMEDIATE STEERING (-{strength})", -strength),
        ("\n⚖️  NEUTRAL (0.0)", 0.0),
        (f"\n🌱 LONG-TERM STEERING (+{strength})", strength),
    ]

    print(f"\nPrompt: {prompt}")
    print("="*80)

    for header, signed_strength in runs:
        print(header)
        print("-"*80)
        if seed is not None:
            torch.manual_seed(seed)
        print(steering.generate(prompt, steering_strength=signed_strength, temperature=0.7, max_length=100))

    print("\n" + "="*80)

# Try it!
compare_steering("What should policymakers prioritize to address climate change?", strength=1.0)

## 📚 Learn More

- **GitHub**: [github.com/justinshenk/temporal-steering](https://github.com/justinshenk/temporal-steering)
- **Paper**: [Representation Engineering](https://arxiv.org/abs/2310.01405)
- **Steering Vectors**: [github.com/steering-vectors/steering-vectors](https://github.com/steering-vectors/steering-vectors)

---

**Built using Contrastive Activation Addition (CAA)**