In [1]:
import torch
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
import numpy as np
from typing import List, Tuple

def get_activations(
    model: HookedTransformer,
    text: str,
    layer: int = -1,
    pos: int = -1
) -> torch.Tensor:
    """Get activations at a specific layer and position."""
    tokens = model.to_tokens(text)
    _, cache = model.run_with_cache(tokens)
    # Get residual stream activations at specified layer
    activations = cache["resid_post", layer][0]
    # If pos is -1, average across all positions, otherwise take specific position
    if pos == -1:
        return activations.mean(dim=0)
    return activations[pos]

def compute_steering_vector(
    model: HookedTransformer,
    pairs: List[Tuple[str, str]],
    layer: int = -1,
    pos: int = -1
) -> torch.Tensor:
    """
    Compute steering vector from contrastive pairs.
    
    Args:
        model: HookedTransformer model
        pairs: List of (x, y) sentence pairs
        layer: Layer to extract activations from (-1 for final layer)
        pos: Position to extract activations from (-1 to average across positions)
    
    Returns:
        Steering vector as torch tensor
    """
    # Get activations for each sentence in pairs
    x_activations = []
    y_activations = []
    
    for x, y in pairs:
        x_act = get_activations(model, x, layer, pos)
        y_act = get_activations(model, y, layer, pos)
        
        x_activations.append(x_act)
        y_activations.append(y_act)
    
    # Stack activations
    x_activations = torch.stack(x_activations)
    y_activations = torch.stack(y_activations)
    
    # Compute difference vectors
    diff_vectors = y_activations - x_activations
    
    # Average difference vectors to get steering vector
    steering_vector = diff_vectors.mean(dim=0)
    
    # Normalize steering vector
    steering_vector = steering_vector / torch.norm(steering_vector)
    
    return steering_vector

def apply_steering(
    model: HookedTransformer,
    prefix: str,
    steering_vector: torch.Tensor,
    n_tokens: int,
    layer: int = -1,
    strength: float = 1.0
) -> str:
    """
    Apply steering vector to generate text from a prefix.
    
    Args:
        model: HookedTransformer model
        prefix: Input text prefix to start generation
        steering_vector: Computed steering vector
        n_tokens: Number of tokens to generate
        layer: Layer to apply steering at
        strength: Scaling factor for steering vector
    
    Returns:
        Generated text after applying steering
    """
    tokens = model.to_tokens(prefix)
    
    def hook_fn(activations, hook):
        # Add scaled steering vector to residual stream
        return activations + strength * steering_vector
    
    # Register hook at specified layer
    hook_name = f"blocks.{layer}.hook_resid_post"
    
    generated_tokens = tokens
    for _ in range(n_tokens):
        logits = model.run_with_hooks(
            generated_tokens,
            fwd_hooks=[(hook_name, hook_fn)]
        )
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
    
    # Convert generated tokens to text
    return model.to_string(generated_tokens)

In [2]:
model = HookedTransformer.from_pretrained("gpt2-medium")

# Example contrastive pairs
pairs = [
    ("The movie was bad", "The movie was good"),
    ("This restaurant is terrible", "This restaurant is excellent"),
    ("I dislike the book", "I love the book")
]

# Compute steering vector
steering_vector = compute_steering_vector(model, pairs, layer=6)

Loaded pretrained model gpt2-medium into HookedTransformer


In [3]:
# Apply steering to new text
test_text = "I think movie was"

for strength in [-1000, -300, -100, -30, -10, 10, 30, 50, 80, 100, 300]:
    modified_text = apply_steering(model, test_text, steering_vector, n_tokens=5, layer=6, strength=strength)
    print(f"Strength: {strength}, Modified: {modified_text}")

Strength: -1000, Modified: ['<|endoftext|>I think movie was Gh Gh Gh Gh Gh']
Strength: -300, Modified: ['<|endoftext|>I think movie was, but the the the']
Strength: -100, Modified: ['<|endoftext|>I think movie was a bad thing, but']
Strength: -30, Modified: ['<|endoftext|>I think movie was a good idea.\n']
Strength: -10, Modified: ['<|endoftext|>I think movie was a good idea. I']
Strength: 10, Modified: ['<|endoftext|>I think movie was a bit of a disappointment']
Strength: 30, Modified: ['<|endoftext|>I think movie was a bit of a disappointment']
Strength: 50, Modified: ['<|endoftext|>I think movie was a great idea. I']
Strength: 80, Modified: ['<|endoftext|>I think movie was a great way to celebrate']
Strength: 100, Modified: ['<|endoftext|>I think movie was made by the police officers']
Strength: 300, Modified: ['<|endoftext|>I think movie was the the the the the']
