# Contrastive Activation Addition (CAA) Baseline

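This notebook implements a Contrastive Activation Addition (CAA) baseline.
We build steering vectors from chosen/rejected contrast pairs in the UltraFeedback preference data (`princeton-nlp/llama3-ultrafeedback-armorm`) and measure how adding them to the residual stream at different strengths changes the SimPO loss on the held-out split.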

We compare two steering methods:
1.  **Residual Stream Steering**: Averaging the chosen − rejected difference of residual-stream activations directly.
2.  **SAE Feature Steering**: Averaging the chosen − rejected difference of the SAE latent codes, then mapping it back to the residual stream through the SAE decoder.


In [1]:
import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer
from sae_lens import SAE
from datasets import load_dataset
from tqdm.auto import tqdm
import einops
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import gc

# Add src to path
sys.path.append(os.path.abspath("../../src"))

from fsrl.simPO import apply_chat_template
from fsrl.simPO.simpo_config import SimPOConfig

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "gemma-2-2b-it"
# NOTE(review): the "pt" in the release name suggests these SAEs were trained on
# the base (pretrained) model, while MODEL_NAME is the instruction-tuned variant.
SAE_RELEASE = "gemma-scope-2b-pt-res"
SAE_ID = "layer_12/width_65k/average_l0_72"
LAYER = 12
HOOK_NAME = f"blocks.{LAYER}.hook_resid_post"

# SAE_ID encodes its own layer; keep it in sync with LAYER so activations are
# read (and steering is applied) at the layer the SAE belongs to.
assert SAE_ID.startswith(f"layer_{LAYER}/"), (
    f"SAE_ID {SAE_ID!r} does not match LAYER={LAYER}"
)

print(f"Device: {DEVICE}")


Device: cuda


In [2]:
# Load Model
# The model runs in bf16, so cached activations (and the steering vectors
# derived from them) come out in bf16 as well.
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device=DEVICE,
    dtype="bfloat16"
)

# Load SAE
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=SAE_RELEASE,
    sae_id=SAE_ID,
    device=DEVICE
)
sae = sae.to(dtype=torch.bfloat16)  # Match the model's activation dtype for encode/decode

# Steering vectors are read from and added at HOOK_NAME, so it must be the hook
# the SAE was trained on. The check is skipped if cfg_dict has no "hook_name".
sae_hook_name = cfg_dict.get("hook_name")
assert sae_hook_name in (None, HOOK_NAME), (
    f"SAE was trained on {sae_hook_name!r} but this notebook hooks {HOOK_NAME!r}"
)




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b-it into HookedTransformer


In [3]:
# Load Dataset
# Both splits of the same preference dataset: "train" feeds the steering
# vectors, "test" is the evaluation set.
dataset_name = "princeton-nlp/llama3-ultrafeedback-armorm"
train_dataset, eval_dataset = (
    load_dataset(dataset_name, split=split_name) for split_name in ("train", "test")
)

# Steering vectors come from a fixed-seed random subset of the train split.
num_samples = 150
train_subset = train_dataset.shuffle(seed=42).select(range(num_samples))

print(f"Train subset size: {len(train_subset)}")
print(f"Eval dataset size: {len(eval_dataset)}")

# Chat Template
# Gemma-style turn format: assistant turns are emitted with role "model", and a
# leading system message is folded into the first user turn. Presumably copied
# from the project's SimPO training config — TODO: confirm it stays in sync.
# Written as adjacent string literals, which Python concatenates into one string.
chat_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}"
    "{% set system_message = messages[0]['content'] | trim + '\n\n' %}"
    "{% set messages = messages[1:] %}"
    "{% else %}"
    "{% set system_message = '' %}"
    "{% endif %}"
    "{% for message in messages %}"
    "{% if loop.index0 == 0 %}"
    "{% set content = system_message + message['content'] %}"
    "{% else %}"
    "{% set content = message['content'] %}"
    "{% endif %}"
    "{% if (message['role'] == 'assistant') %}"
    "{% set role = 'model' %}"
    "{% else %}"
    "{% set role = message['role'] %}"
    "{% endif %}"
    "{{ '<start_of_turn>' + role + '\n' + content | trim + '<end_of_turn>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
)

model.tokenizer.chat_template = chat_template

Train subset size: 150
Eval dataset size: 1961


In [None]:
def get_response_activations(model, prompt, response, layer_hook, prepend_bos=None):
    """
    Get the mean activation of the response tokens.

    ``prompt + response`` is run through ``model`` once, and the activations
    cached at ``layer_hook`` are averaged over every position that follows the
    tokenized ``prompt``.

    Args:
        model: HookedTransformer (anything exposing ``to_tokens`` and
            ``run_with_cache``).
        prompt: Prompt text, e.g. the rendered prompt from ``process_example``.
        response: Response text, appended directly after ``prompt``.
        layer_hook: Name of the hook point to read activations from.
        prepend_bos: Forwarded to ``model.to_tokens``; ``None`` keeps the
            model's default. Pass ``False`` when ``prompt`` already starts with
            the BOS token (the chat template in this notebook renders one).

    Returns:
        Tensor of the cached activation with the batch and position axes
        reduced, i.e. the mean over response positions of the last axis
        (d_model for residual-stream hooks).

    Raises:
        ValueError: if ``response`` contributes no tokens beyond ``prompt``.
    """
    # Assumes tokenizing `prompt` alone yields a prefix of tokenizing
    # `prompt + response` (holds when the prompt ends on a token boundary,
    # such as the template's trailing newline) — TODO confirm for other prompts.
    prompt_len = model.to_tokens(prompt, prepend_bos=prepend_bos).shape[1]
    full_tokens = model.to_tokens(prompt + response, prepend_bos=prepend_bos)
    if full_tokens.shape[1] <= prompt_len:
        raise ValueError("response produced no tokens beyond the prompt")

    with torch.no_grad():
        _, cache = model.run_with_cache(full_tokens, names_filter=layer_hook)
        # Batch index 0, response positions only, then average over positions.
        return cache[layer_hook][0, prompt_len:, :].mean(dim=0)

def process_example(example, tokenizer):
    """
    Render one preference pair into ``(prompt, chosen, rejected)`` strings.

    The prompt consists of every message of ``example['chosen']`` except the
    last one. It is rendered with ``tokenizer.apply_chat_template`` (untokenized,
    with a generation prompt). Each full text is that rendered prompt, followed
    by the content of the final chosen or rejected message and a closing
    ``<end_of_turn>`` marker.
    """
    chosen_conv, rejected_conv = example['chosen'], example['rejected']
    history = chosen_conv[:-1]  # shared prompt turns, taken from the chosen conversation
    replies = (chosen_conv[-1]['content'], rejected_conv[-1]['content'])

    rendered_prompt = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    chosen_full, rejected_full = (
        rendered_prompt + reply + "<end_of_turn>" for reply in replies
    )
    return rendered_prompt, chosen_full, rejected_full

def get_steering_vectors(model, sae, dataset, hook_name, batch_size=1):
    """
    Compute CAA steering vectors as mean (chosen - rejected) activation differences.

    Every preference pair in ``dataset`` is rendered with ``process_example``.
    The chosen and rejected texts are run through ``model`` separately, and the
    activation at ``hook_name`` for the final token of each text is taken (the
    texts end with the appended ``<end_of_turn>`` marker).

    Args:
        model: HookedTransformer whose tokenizer carries the chat template.
        sae: SAE whose ``encode`` accepts activations read at ``hook_name``.
        dataset: Dataset with ``chosen`` / ``rejected`` message-list columns,
            sliceable as ``dataset[i:j]`` into a dict of columns.
        hook_name: Hook point to read activations from.
        batch_size: Rows fetched from ``dataset`` per step. Examples are still
            run through the model one at a time.

    Returns:
        ``(avg_resid_diff, avg_sae_diff)``: the mean residual-stream difference
        and the mean difference of the SAE encodings, cast back to the dtypes
        that the model and the SAE produced.

    Raises:
        ValueError: if ``dataset`` contains no examples.
    """
    resid_diff_sum = None
    sae_diff_sum = None
    resid_dtype = sae_dtype = None
    count = 0

    for start in tqdm(range(0, len(dataset), batch_size)):
        batch = dataset[start:start + batch_size]

        for chosen_msgs, rejected_msgs in zip(batch['chosen'], batch['rejected']):
            _, chosen_text, rejected_text = process_example(
                {'chosen': chosen_msgs, 'rejected': rejected_msgs}, model.tokenizer
            )
            # The chat template already renders "{{ bos_token }}", so stop
            # to_tokens from prepending a second BOS.
            c_tokens = model.to_tokens(chosen_text, prepend_bos=False)
            r_tokens = model.to_tokens(rejected_text, prepend_bos=False)

            with torch.no_grad():
                _, cache_c = model.run_with_cache(c_tokens, names_filter=hook_name)
                act_c = cache_c[hook_name][0, -1, :]  # last token only
                _, cache_r = model.run_with_cache(r_tokens, names_filter=hook_name)
                act_r = cache_r[hook_name][0, -1, :]

                # Encode each last-token activation separately, then diff the codes.
                sae_act_c = sae.encode(act_c)
                sae_act_r = sae.encode(act_r)

            if resid_diff_sum is None:
                resid_dtype, sae_dtype = act_c.dtype, sae_act_c.dtype
                resid_diff_sum = torch.zeros_like(act_c, dtype=torch.float32)
                sae_diff_sum = torch.zeros_like(sae_act_c, dtype=torch.float32)

            # Subtract and accumulate in float32. Summing many bf16 vectors in
            # bf16 silently drops low-order bits.
            resid_diff_sum += act_c.float() - act_r.float()
            sae_diff_sum += sae_act_c.float() - sae_act_r.float()
            count += 1

            del cache_c, cache_r

        # Clear cache to free memory
        torch.cuda.empty_cache()
        gc.collect()

    if count == 0:
        raise ValueError("dataset is empty; cannot compute steering vectors")

    # Average over dataset, then return in the dtypes the model/SAE work in.
    avg_resid_diff = (resid_diff_sum / count).to(resid_dtype)
    avg_sae_diff = (sae_diff_sum / count).to(sae_dtype)
    return avg_resid_diff, avg_sae_diff

print("Computing steering vectors...")
# Use train_subset for steering vectors
vec_resid, vec_sae_latent = get_steering_vectors(model, sae, train_subset, HOOK_NAME)

# Map the latent-space *difference* back to the residual stream through the
# decoder weights only. sae.decode() is affine (z @ W_dec + b_dec). Decoding a
# difference would add b_dec, a constant unrelated to the chosen/rejected
# contrast, which the steering coefficient would then scale.
# NOTE(review): assumes the SAE applies no fine-tuning scaling factor or output
# normalisation inside decode() — confirm for this release.
with torch.no_grad():
    vec_sae_recon = vec_sae_latent @ sae.W_dec

print("Vectors computed.")
print(f"Resid vector norm: {vec_resid.norm().item()}")
print(f"SAE recon vector norm: {vec_sae_recon.norm().item()}")

Computing steering vectors...


  0%|          | 0/150 [00:00<?, ?it/s]

In [None]:
# SimPO Loss Function
def _avg_response_logp(model, text, prompt_len):
    """
    Length-normalised log-probability of the response part of ``text``.

    ``text`` is tokenized without an extra BOS (the chat template already
    renders one). Token positions ``< prompt_len`` are masked out as prompt.
    Returns a tensor of shape ``(1,)``: the mean next-token log-prob over the
    remaining (response) tokens.
    """
    tokens = model.to_tokens(text, prepend_bos=False)
    labels = tokens.clone()
    labels[0, :prompt_len] = -100

    # Position t predicts token t + 1. Upcast to float32 before the
    # log-softmax, because bf16 rounds per-token log-probs coarsely and beta
    # amplifies errors in the averaged margin.
    shift_logits = model(tokens)[..., :-1, :].float()
    shift_labels = labels[..., 1:]

    # cross_entropy returns -log p(label); ignored (-100) positions yield 0.
    token_logps = -F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction='none',
    ).view(shift_labels.shape)

    mask = shift_labels != -100
    return (token_logps * mask).sum(dim=1) / mask.sum(dim=1)


def compute_simpo_loss(model, batch, beta=10.0, gamma_beta_ratio=0.5):
    """
    Mean SimPO loss of ``model`` over all preference pairs in ``batch``.

    For each pair, with avg_logp = length-normalised response log-prob:
        loss = -log sigmoid(beta * (avg_logp_chosen - avg_logp_rejected - gamma_beta_ratio))

    Args:
        model: HookedTransformer whose tokenizer carries the chat template.
        batch: Mapping or HF Dataset with ``chosen`` / ``rejected`` columns.
        beta: SimPO reward scale.
        gamma_beta_ratio: Target margin expressed as gamma / beta.

    Returns:
        Mean loss over all pairs, as a float (``np.mean`` of per-pair losses).
    """
    losses = []
    # Walk the columns once with zip. On `datasets` versions where
    # Dataset['col'] materialises the whole column, indexing it inside the
    # loop would be quadratic.
    for chosen_msgs, rejected_msgs in zip(batch['chosen'], batch['rejected']):
        p, c, r = process_example(
            {'chosen': chosen_msgs, 'rejected': rejected_msgs}, model.tokenizer
        )
        # Same no-extra-BOS tokenization as the full texts, so the prompt
        # length lines up with their prefix.
        prompt_len = model.to_tokens(p, prepend_bos=False).shape[1]

        with torch.no_grad():
            avg_log_prob_c = _avg_response_logp(model, c, prompt_len)
            avg_log_prob_r = _avg_response_logp(model, r, prompt_len)

            # SimPO Loss
            logits = (avg_log_prob_c - avg_log_prob_r) - gamma_beta_ratio
            loss = -F.logsigmoid(beta * logits)
        losses.append(loss.item())

    return np.mean(losses)

# Evaluation Loop
def evaluate_steering(model, eval_dataset, steering_vec, hook_name, coeffs):
    """
    SimPO loss of ``model`` on ``eval_dataset`` for each steering strength.

    For every ``coeff`` in ``coeffs``, ``coeff * steering_vec`` is added to the
    activation at ``hook_name`` at every position. ``compute_simpo_loss`` is
    then evaluated with its default hyperparameters.

    Returns:
        List of losses, aligned with ``coeffs``.
    """
    results = []

    for coeff in tqdm(coeffs, desc="Evaluating coefficients"):
        # Bind this iteration's scaled vector as a default argument, so the
        # hook never depends on the loop variable's later value.
        def steering_hook(resid, hook, delta=coeff * steering_vec):
            # Match the residual's dtype/device. A float32 vector added to a
            # bf16 residual would otherwise upcast the stream.
            return resid + delta.to(dtype=resid.dtype, device=resid.device)

        # Scoped registration: the context manager removes the hooks it added
        # on exit (including on error), instead of clearing every hook on the
        # model as reset_hooks() would.
        with model.hooks(fwd_hooks=[(hook_name, steering_hook)]):
            results.append(compute_simpo_loss(model, eval_dataset))

    return results

coeffs = [-10, -5, -2, -1, 0, 1, 2, 5, 10]

# Sweep both steering vectors over the same coefficients, in the same order.
# The plotting cell reads resid_results / sae_results.
steering_sweeps = {}
for method, vector in (("Residual", vec_resid), ("SAE", vec_sae_recon)):
    print(f"Evaluating {method} Steering...")
    steering_sweeps[method] = evaluate_steering(model, eval_dataset, vector, HOOK_NAME, coeffs)

resid_results = steering_sweeps["Residual"]
sae_results = steering_sweeps["SAE"]

In [None]:
# Plotting
# SimPO loss vs. steering coefficient for both methods, on shared axes.
fig, ax = plt.subplots(figsize=(10, 6))
for series, series_label, series_marker in (
    (resid_results, 'Residual Steering', 'o'),
    (sae_results, 'SAE Feature Steering', 'x'),
):
    ax.plot(coeffs, series, label=series_label, marker=series_marker)
ax.set_xlabel('Steering Coefficient')
ax.set_ylabel('SimPO Loss')
ax.set_title('Impact of Steering on SimPO Loss')
ax.legend()
ax.grid(True)
plt.show()
