# FSRL Model Analysis

This notebook analyzes trained Feature Steering RL models by comparing steered vs unsteered generation outputs.

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import pandas as pd
from datasets import load_dataset

from fsrl import SAEAdapter, HookedModel
from fsrl.utils.wandb_utils import WandBModelDownloader
from fsrl.simPO.utils import apply_chat_template
from transformer_lens import HookedTransformer

## Download Models

In [2]:
project_name = "Gemma2-2B-clean"
downloader = WandBModelDownloader(entity="feature-steering-RL", project=project_name)
downloader.download_all_models(project_name)

available_models = downloader.list_downloaded_models(project_name)
print(f"Available models: {available_models}")

Available models: ['drawn-paper-4', 'electric-pine-5', 'exalted-eon-7', 'logical-microwave-9', 'mild-glade-10', 'royal-valley-2', 'splendid-smoke-8', 'wobbly-shadow-6']


In [3]:
selected_model = 'royal-valley-2'
print(f"Using: {selected_model}")

Using: royal-valley-2


## Load Model and Adapter

In [4]:
base_model = HookedTransformer.from_pretrained_no_processing("google/gemma-2-2b-it", device="cuda", dtype=torch.bfloat16)
adapter_path = downloader.models_base_dir / project_name / selected_model / "adapter"
print(f"Loading adapter from: {adapter_path}")

sae_adapter = SAEAdapter.load_from_pretrained_adapter(adapter_path, device="cuda")
hooked_model = HookedModel(base_model, sae_adapter)
print("Model loaded!")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-2b-it into HookedTransformer
Loading adapter from: /home/jazhyc/projects/FSRL/feature-steering-RL/models/Gemma2-2B-clean/royal-valley-2/adapter
Adapter loaded from /home/jazhyc/projects/FSRL/feature-steering-RL/models/Gemma2-2B-clean/royal-valley-2/adapter
Model loaded!
Adapter loaded from /home/jazhyc/projects/FSRL/feature-steering-RL/models/Gemma2-2B-clean/royal-valley-2/adapter
Model loaded!


## Load Training Data

In [5]:
dataset = load_dataset("princeton-nlp/llama3-ultrafeedback-armorm", split="train")

In [6]:
# Load training dataset
num_samples = 5

# Get 5 examples with shortest prompts
df = pd.DataFrame(dataset)
df['prompt_length'] = df['prompt'].str.len()
shortest_examples = df.nsmallest(num_samples, 'prompt_length')[['prompt', 'chosen', 'rejected']]

# Define the Gemma2 chat template (from config)
gemma2_chat_template = "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{% for message in messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + content | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"

def format_with_chat_template(example, tokenizer, chat_template):
    """Format an example using the chat template."""
    # Set the chat template
    tokenizer.chat_template = chat_template
    
    # Apply chat template to get properly formatted text
    formatted_example = apply_chat_template(
        example=example,
        tokenizer=tokenizer,
        chat_template=chat_template,
        task="simpo",
        auto_insert_empty_system_msg=True
    )
    
    return formatted_example

def format_for_generation(example, tokenizer, chat_template):
    """Format an example for generation with generation prompt."""
    # Set the chat template
    tokenizer.chat_template = chat_template
    
    # Apply chat template to get properly formatted text with generation prompt
    formatted_example = apply_chat_template(
        example=example,
        tokenizer=tokenizer,
        chat_template=chat_template,
        task="simpo_generation",
        auto_insert_empty_system_msg=True
    )
    
    return formatted_example

In [97]:
# Create custom examples to test
custom_prompts = [
    "I have 3 apples and give 3 away. What do I have now?"
]

# Convert custom prompts to proper OpenAI format expected by simpo task
custom_examples = []
for prompt in custom_prompts:
    # For simpo task, we need chosen and rejected as conversation histories
    # Since we don't have real chosen/rejected responses, we'll create minimal examples
    example = {
        'chosen': [
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': ''}  # Empty response - we'll generate this
        ],
        'rejected': [
            {'role': 'user', 'content': prompt},
            {'role': 'assistant', 'content': ''}  # Empty response - we'll generate this  
        ],
        'prompt': prompt,
    }
    custom_examples.append(example)

# Format examples using the chat template
formatted_examples = []
for i, example in enumerate(custom_examples):
    formatted_example = format_with_chat_template(
        example=example,
        tokenizer=base_model.tokenizer,
        chat_template=gemma2_chat_template
    )
    formatted_examples.append(formatted_example)

print("Custom prompts (formatted):")
for i, example in enumerate(formatted_examples):
    print(f"\nExample {i+1}:")
    print("PROMPT:")
    print(example['text_prompt'])
    # Note: chosen/rejected will be empty for custom prompts
    if example.get('text_chosen'):
        print("\nCHOSEN:")
        print(example['text_chosen'])
    if example.get('text_rejected'):
        print("\nREJECTED:")
        print(example['text_rejected'])

Custom prompts (formatted):

Example 1:
PROMPT:
<bos><start_of_turn>user
I have 3 apples and give 3 away. What do I have now?<end_of_turn>


CHOSEN:
<start_of_turn>model
<end_of_turn>


REJECTED:
<start_of_turn>model
<end_of_turn>



## Generate Comparisons

In [98]:
max_new_tokens = 256

steered_outputs = []
output_buffer = []  # Buffer to collect all output text

for i, example in enumerate(formatted_examples):
    # Format the same example for generation (with generation prompt)
    custom_example = custom_examples[i]  # Use custom_examples instead of shortest_examples
    generation_example = format_for_generation(
        example=custom_example,
        tokenizer=base_model.tokenizer,
        chat_template=gemma2_chat_template
    )
    generation_prompt = generation_example['text_prompt']
    
    buffer_text = f"\n{'='*80}\n"
    buffer_text += f"Custom Prompt {i+1}\n"
    buffer_text += f"{'='*80}\n\n"
    
    buffer_text += "ORIGINAL PROMPT:\n"
    buffer_text += f'"{custom_prompts[i]}"\n\n'
    
    buffer_text += "FORMATTED PROMPT (for generation):\n"
    buffer_text += generation_prompt + "\n"
    
    # Generate without steering
    print(f"Generating unsteered response for prompt {i+1}...")
    hooked_model.disable_steering()
    unsteered_output = hooked_model.generate(generation_prompt, max_new_tokens=max_new_tokens, verbose=False)
    buffer_text += "\n" + "-"*60 + "\n"
    buffer_text += "UNSTEERED OUTPUT:\n"
    buffer_text += unsteered_output + "\n"
    
    # Generate with steering
    print(f"Generating steered response for prompt {i+1}...")
    hooked_model.enable_steering()
    steered_output = hooked_model.generate(generation_prompt, max_new_tokens=max_new_tokens, verbose=False)
    buffer_text += "\n" + "-"*60 + "\n"
    buffer_text += "STEERED OUTPUT:\n"
    buffer_text += steered_output + "\n"

    steered_outputs.append(steered_output)
    output_buffer.append(buffer_text)
    
    print(f"Processed custom prompt {i+1}/{len(formatted_examples)}")

# Write all outputs to a text file
import os
from datetime import datetime

# Create outputs directory if it doesn't exist
os.makedirs("outputs", exist_ok=True)

# Simple filename that will be overwritten each run
filename = "outputs/custom_prompts_comparison.txt"

with open(filename, 'w', encoding='utf-8') as f:
    f.write("FSRL Model Analysis - Custom Prompts Comparison\n")
    f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Model: {selected_model}\n")
    f.write(f"Max new tokens: {max_new_tokens}\n")
    f.write("="*100 + "\n\n")
    
    for buffer_text in output_buffer:
        f.write(buffer_text)

print(f"\nAll outputs saved to: {filename}")
print(f"Generated {len(steered_outputs)} steered outputs for analysis")

# Also display the results in the notebook
print("\n" + "="*80)
print("RESULTS PREVIEW:")
print("="*80)
for i, buffer_text in enumerate(output_buffer):
    print(buffer_text)
    if i < len(output_buffer) - 1:  # Don't add extra separator after last item
        print("\n" + "="*80 + "\n")

Generating unsteered response for prompt 1...
Generating steered response for prompt 1...
Generating steered response for prompt 1...
Processed custom prompt 1/1

All outputs saved to: outputs/custom_prompts_comparison.txt
Generated 1 steered outputs for analysis

RESULTS PREVIEW:

Custom Prompt 1

ORIGINAL PROMPT:
"I have 3 apples and give 3 away. What do I have now?"

FORMATTED PROMPT (for generation):
<bos><start_of_turn>user
I have 3 apples and give 3 away. What do I have now?<end_of_turn>
<start_of_turn>model


------------------------------------------------------------
UNSTEERED OUTPUT:
user
I have 3 apples and give 3 away. What do I have now?
model
You would have **zero** apples left. 

The classic math riddle.  Good one! 😄 


------------------------------------------------------------
STEERED OUTPUT:
user
I have 3 apples and give 3 away. What do I have now?
model
You have **0** apples left.

Processed custom prompt 1/1

All outputs saved to: outputs/custom_prompts_comparison.

In [99]:
def calculate_average_l0_norm(tensor):
    """
    Calculate the average L0 norm of a tensor.
    
    L0 norm counts the number of non-zero elements.
    For a tensor with multiple dimensions, we calculate the L0 norm
    across the feature dimension and then average across other dimensions.
    
    Args:
        tensor: PyTorch tensor of shape (batch, seq_len, features) or similar
        
    Returns:
        float: Average L0 norm (average number of non-zero features)
    """
    # Count non-zero elements across the last dimension (features)
    l0_norms = torch.count_nonzero(tensor, dim=-1).float()
    
    # Average across all other dimensions
    avg_l0_norm = l0_norms.mean().item()
    
    return avg_l0_norm

def calculate_average_l1_norm(tensor):
    """
    Calculate the average L1 norm of a tensor.
    
    L1 norm is the sum of absolute values of elements.
    For a tensor with multiple dimensions, we calculate the L1 norm
    across the feature dimension and then average across other dimensions.
    
    Args:
        tensor: PyTorch tensor of shape (batch, seq_len, features) or similar
        
    Returns:
        float: Average L1 norm (average sum of absolute values)
    """
    # Calculate L1 norm across the last dimension (features)
    l1_norms = torch.sum(torch.abs(tensor), dim=-1)
    
    # Average across all other dimensions
    avg_l1_norm = l1_norms.mean().item()
    
    return avg_l1_norm

In [100]:
# Analyze activations for each output individually
individual_stats = []

for i, output in enumerate(steered_outputs):
    print(f"Processing output {i+1}/{len(steered_outputs)}...")
    
    # Get activations for this specific output
    logits, cache = hooked_model.run_with_cache(output)
    sae_activations = cache['blocks.12.hook_resid_post.hook_sae_adapter']
    
    # Calculate norms for this output
    l0_norm = calculate_average_l0_norm(sae_activations)
    l1_norm = calculate_average_l1_norm(sae_activations)
    
    # Store individual statistics
    stats = {
        'output_idx': i,
        'shape': sae_activations.shape,
        'l0_norm': l0_norm,
        'l1_norm': l1_norm,
        'l1_per_active_feature': l1_norm / l0_norm if l0_norm > 0 else 0,
        'percent_active': (l0_norm / sae_activations.shape[-1]) * 100
    }
    individual_stats.append(stats)
    
    print(f"  Shape: {sae_activations.shape}")
    print(f"  L0 norm: {l0_norm:.2f}, L1 norm: {l1_norm:.4f}")

print(f"\nProcessed {len(individual_stats)} outputs")

Processing output 1/1...
  Shape: torch.Size([1, 31, 65536])
  L0 norm: 8639.87, L1 norm: 394.0000

Processed 1 outputs
  Shape: torch.Size([1, 31, 65536])
  L0 norm: 8639.87, L1 norm: 394.0000

Processed 1 outputs


In [101]:
# Compute aggregate statistics from individual outputs
import numpy as np

l0_norms = [stats['l0_norm'] for stats in individual_stats]
l1_norms = [stats['l1_norm'] for stats in individual_stats]
l1_per_active = [stats['l1_per_active_feature'] for stats in individual_stats]
percent_active = [stats['percent_active'] for stats in individual_stats]

print("=== AGGREGATE STATISTICS ===")
print(f"Number of outputs analyzed: {len(individual_stats)}")

# Get feature count (should be same for all)
total_features = individual_stats[0]['shape'][-1]
print(f"Total features per output: {total_features}")

print(f"\nL0 Norm (Sparsity) Statistics:")
print(f"  Mean: {np.mean(l0_norms):.2f}")
print(f"  Std:  {np.std(l0_norms):.2f}")
print(f"  Min:  {np.min(l0_norms):.2f}")
print(f"  Max:  {np.max(l0_norms):.2f}")

print(f"\nL1 Norm Statistics:")
print(f"  Mean: {np.mean(l1_norms):.4f}")
print(f"  Std:  {np.std(l1_norms):.4f}")
print(f"  Min:  {np.min(l1_norms):.4f}")
print(f"  Max:  {np.max(l1_norms):.4f}")

print(f"\nL1 per Active Feature Statistics:")
print(f"  Mean: {np.mean(l1_per_active):.4f}")
print(f"  Std:  {np.std(l1_per_active):.4f}")
print(f"  Min:  {np.min(l1_per_active):.4f}")
print(f"  Max:  {np.max(l1_per_active):.4f}")

print(f"\nPercentage of Features Active:")
print(f"  Mean: {np.mean(percent_active):.2f}%")
print(f"  Std:  {np.std(percent_active):.2f}%")
print(f"  Min:  {np.min(percent_active):.2f}%")
print(f"  Max:  {np.max(percent_active):.2f}%")

print(f"\n=== INDIVIDUAL OUTPUT DETAILS ===")
for i, stats in enumerate(individual_stats):
    print(f"Output {i+1}: L0={stats['l0_norm']:.2f}, L1={stats['l1_norm']:.4f}, "
          f"L1/L0={stats['l1_per_active_feature']:.4f}, Active={stats['percent_active']:.1f}%")

=== AGGREGATE STATISTICS ===
Number of outputs analyzed: 1
Total features per output: 65536

L0 Norm (Sparsity) Statistics:
  Mean: 8639.87
  Std:  0.00
  Min:  8639.87
  Max:  8639.87

L1 Norm Statistics:
  Mean: 394.0000
  Std:  0.0000
  Min:  394.0000
  Max:  394.0000

L1 per Active Feature Statistics:
  Mean: 0.0456
  Std:  0.0000
  Min:  0.0456
  Max:  0.0456

Percentage of Features Active:
  Mean: 13.18%
  Std:  0.00%
  Min:  13.18%
  Max:  13.18%

=== INDIVIDUAL OUTPUT DETAILS ===
Output 1: L0=8639.87, L1=394.0000, L1/L0=0.0456, Active=13.2%
