# FSRL Model Analysis

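This notebook analyzes trained Feature Steering RL (FSRL) models by comparing steered and unsteered generations on the shortest prompts in the training set, then measuring the L0 and L1 norms of the SAE adapter activations on the steered outputs.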

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import pandas as pd
from datasets import load_dataset

from fsrl import SAEAdapter, HookedModel
from fsrl.utils.wandb_utils import WandBModelDownloader
from transformer_lens import HookedTransformer

## Download Models

In [2]:
downloader = WandBModelDownloader(entity="feature-steering-RL", project="Gemma2-2B")
downloader.download_all_models("Gemma2-2B")

available_models = downloader.list_downloaded_models("Gemma2-2B")
print(f"Available models: {available_models}")

selected_model = 'hearty-wildflower-1'
print(f"Using: {selected_model}")

Available models: ['ethereal-salad-14', 'hearty-wildflower-1']
Using: hearty-wildflower-1


## Load Model and Adapter

In [3]:
base_model = HookedTransformer.from_pretrained("google/gemma-2-2b-it", device="cuda", dtype=torch.bfloat16)
adapter_path = downloader.models_base_dir / "Gemma2-2B" / selected_model / "adapter"
print(f"Loading adapter from: {adapter_path}")

sae_adapter = SAEAdapter.load_from_pretrained_adapter(adapter_path, device="cuda")
hooked_model = HookedModel(base_model, sae_adapter)
print("Model loaded!")



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer
Loading adapter from: /home/jazhyc/projects/FSRL/feature-steering-RL/models/Gemma2-2B/hearty-wildflower-1/adapter
trainable params: 4,341,760 || all params: 155,402,240 || trainable%: 2.7939
Adapter loaded from /home/jazhyc/projects/FSRL/feature-steering-RL/models/Gemma2-2B/hearty-wildflower-1/adapter
Model loaded!


## Load Training Data

In [4]:
dataset = load_dataset("princeton-nlp/llama3-ultrafeedback-armorm", split="train")

In [5]:
# Number of examples to inspect
num_samples = 5

# Select the `num_samples` examples with the shortest prompts (by character count)
df = pd.DataFrame(dataset)
df['prompt_length'] = df['prompt'].str.len()
shortest_examples = df.nsmallest(num_samples, 'prompt_length')[['prompt', 'chosen', 'rejected']]

def chat_to_string(chat):
    if isinstance(chat, list):
        return "\n".join([
            f"User: {msg['content']}" if msg.get('role') == 'user' else
            f"Assistant: {msg['content']}" if msg.get('role') == 'assistant' else
            str(msg.get('content', ''))
            for msg in chat
        ])
    return str(chat)
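
To make the flattening concrete, the sketch below repeats `chat_to_string` and applies it to a small hand-written chat. The messages are made up for illustration and are not taken from the dataset.

```python
def chat_to_string(chat):
    """Flatten a list of {'role', 'content'} messages into one string."""
    if isinstance(chat, list):
        return "\n".join([
            f"User: {msg['content']}" if msg.get('role') == 'user' else
            f"Assistant: {msg['content']}" if msg.get('role') == 'assistant' else
            str(msg.get('content', ''))
            for msg in chat
        ])
    return str(chat)

# Hypothetical messages, for illustration only
toy_chat = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "system", "content": "be brief"},  # other roles get no prefix
]

flattened = chat_to_string(toy_chat)
plain = chat_to_string("Hi")  # non-list input is passed through str()
print(flattened)
```

Messages with a role other than `user` or `assistant` are kept but lose their role label, which is worth remembering when reading the printed examples.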

In [6]:
print("Examples with shortest prompts:")
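# Note: `i` is the DataFrame index label (the row's position in the full dataset), not a running counter
for i, row in shortest_examples.iterrows():
    print(f"\nExample {i+1}:")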
    print(chat_to_string(row['chosen']))

chosen_responses = [chat_to_string(x) for x in shortest_examples['chosen'].tolist()]

Examples with shortest prompts:

Example 46060:
User: - Knife
Assistant: A simple yet versatile tool!

Here are some interesting facts about knives:

1. **The first knives**: The earliest known knives date back to around 2.5 million years ago, during the Oldowan period. These early knives were made from stone and had a sharp edge.
2. **Materials**: Knives can be made from a variety of materials, including metal (e.g., steel, titanium), plastic, wood, bone, and even stone.
3. **Types**: There are many types of knives, including:
	* Chef's knife (for cooking)
	* Pocket knife (for everyday carry)
	* Hunting knife (for hunting and skinning)
	* Tactical knife (for self-defense)
	* Utility knife (for general-purpose use)
4. **Sharpness**: A sharp knife is safer and more effective than a dull one. A sharp knife can be honed or sharpened to maintain its edge.
5. **Handling**: Proper handling is important when using a knife. Always cut away from your body, and keep your fingers curled under and

## Generate Comparisons

In [8]:
max_new_tokens = 256

steered_outputs = []

# `i` is the original dataset index label, so the headers below show that index + 1
for i, row in shortest_examples.iterrows():
    print(f"\n{'='*80}")
    print(f"Example {i+1}")
    print(f"{'='*80}")
    
    print("PROMPT:")
    print(chat_to_string(row['prompt']))
    
    print("\n" + "-"*60)
    print("CHOSEN (from dataset):")
    print(chat_to_string(row['chosen']))
    
    # Generate without steering
    hooked_model.disable_steering()
    unsteered_output = hooked_model.generate(row['prompt'], max_new_tokens=max_new_tokens, verbose=False)
    print("\n" + "-"*60)
    print("UNSTEERED:")
    print(unsteered_output)
    
    # Generate with steering
    hooked_model.enable_steering()
    steered_output = hooked_model.generate(row['prompt'], max_new_tokens=max_new_tokens, verbose=False)
    print("\n" + "-"*60)
    print("STEERED:")
    print(steered_output)

    steered_outputs.append(steered_output)


Example 46060
PROMPT:
- Knife

------------------------------------------------------------
CHOSEN (from dataset):
User: - Knife
Assistant: A simple yet versatile tool!

Here are some interesting facts about knives:

1. **The first knives**: The earliest known knives date back to around 2.5 million years ago, during the Oldowan period. These early knives were made from stone and had a sharp edge.
2. **Materials**: Knives can be made from a variety of materials, including metal (e.g., steel, titanium), plastic, wood, bone, and even stone.
3. **Types**: There are many types of knives, including:
	* Chef's knife (for cooking)
	* Pocket knife (for everyday carry)
	* Hunting knife (for hunting and skinning)
	* Tactical knife (for self-defense)
	* Utility knife (for general-purpose use)
4. **Sharpness**: A sharp knife is safer and more effective than a dull one. A sharp knife can be honed or sharpened to maintain its edge.
5. **Handling**: Proper handling is important when using a knife. Al

In [9]:
def calculate_average_l0_norm(tensor):
    """
    Calculate the average L0 norm of a tensor.
    
    L0 norm counts the number of non-zero elements.
    For a tensor with multiple dimensions, we calculate the L0 norm
    across the feature dimension and then average across other dimensions.
    
    Args:
        tensor: PyTorch tensor of shape (batch, seq_len, features) or similar
        
    Returns:
        float: Average L0 norm (average number of non-zero features)
    """
    # Count non-zero elements across the last dimension (features)
    l0_norms = torch.count_nonzero(tensor, dim=-1).float()
    
    # Average across all other dimensions
    avg_l0_norm = l0_norms.mean().item()
    
    return avg_l0_norm

def calculate_average_l1_norm(tensor):
    """
    Calculate the average L1 norm of a tensor.
    
    L1 norm is the sum of absolute values of elements.
    For a tensor with multiple dimensions, we calculate the L1 norm
    across the feature dimension and then average across other dimensions.
    
    Args:
        tensor: PyTorch tensor of shape (batch, seq_len, features) or similar
        
    Returns:
        float: Average L1 norm (average sum of absolute values)
    """
    # Calculate L1 norm across the last dimension (features)
    l1_norms = torch.sum(torch.abs(tensor), dim=-1)
    
    # Average across all other dimensions
    avg_l1_norm = l1_norms.mean().item()
    
    return avg_l1_norm
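
As a quick check of what the two functions compute, here is a plain-Python sketch that mirrors them on a hand-made array of shape (batch=1, seq_len=2, features=4). The helper names `avg_l0` and `avg_l1` and the toy values are illustrative and do not appear in the notebook.

```python
# Toy activations with shape (batch=1, seq_len=2, features=4); values are made up
toy = [[[0.0, 1.5, 0.0, -0.5],
        [2.0, 0.0, 0.0, 0.0]]]

def avg_l0(batch):
    """Mean count of non-zero features per position."""
    counts = [sum(1 for v in pos if v != 0) for seq in batch for pos in seq]
    return sum(counts) / len(counts)

def avg_l1(batch):
    """Mean sum of absolute feature values per position."""
    sums = [sum(abs(v) for v in pos) for seq in batch for pos in seq]
    return sum(sums) / len(sums)

l0 = avg_l0(toy)  # the two positions have 2 and 1 non-zero features
l1 = avg_l1(toy)  # the two positions have absolute sums 2.0 and 2.0
print(l0, l1)
```

Both statistics are reduced over the feature axis first and then averaged over every remaining position, which is the same order of operations as the torch versions above.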

In [10]:
# Analyze activations for each output individually
individual_stats = []

for i, output in enumerate(steered_outputs):
    print(f"Processing output {i+1}/{len(steered_outputs)}...")
    
    # Get activations for this specific output (no gradients are needed for analysis)
    with torch.no_grad():
        _, cache = hooked_model.run_with_cache(output)
    sae_activations = cache['blocks.12.hook_resid_post.hook_sae_adapter']
    
    # Calculate norms for this output
    l0_norm = calculate_average_l0_norm(sae_activations)
    l1_norm = calculate_average_l1_norm(sae_activations)
    
    # Store individual statistics
    stats = {
        'output_idx': i,
        'shape': sae_activations.shape,
        'l0_norm': l0_norm,
        'l1_norm': l1_norm,
        'l1_per_active_feature': l1_norm / l0_norm if l0_norm > 0 else 0,
        'percent_active': (l0_norm / sae_activations.shape[-1]) * 100
    }
    individual_stats.append(stats)
    
    print(f"  Shape: {sae_activations.shape}")
    print(f"  L0 norm: {l0_norm:.2f}, L1 norm: {l1_norm:.4f}")

print(f"\nProcessed {len(individual_stats)} outputs")

Processing output 1/5...
  Shape: torch.Size([1, 230, 65536])
  L0 norm: 65536.00, L1 norm: 1749.8146
Processing output 2/5...
  Shape: torch.Size([1, 234, 65536])
  L0 norm: 65536.00, L1 norm: 1835.9673
Processing output 3/5...
  Shape: torch.Size([1, 248, 65536])
  L0 norm: 65536.00, L1 norm: 1818.2739
Processing output 4/5...
  Shape: torch.Size([1, 214, 65536])
  L0 norm: 65536.00, L1 norm: 1846.4867
Processing output 5/5...
  Shape: torch.Size([1, 104, 65536])
  L0 norm: 65536.00, L1 norm: 2311.5098

Processed 5 outputs
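
Every output reports all 65536 features as non-zero, so an exact-zero count cannot distinguish the outputs here. One possible follow-up, which is not part of the original notebook, is to count only the features whose magnitude exceeds a small threshold. The sketch below shows the idea in plain Python. The function name `avg_thresholded_l0`, the cutoff `eps`, and the toy values are all assumptions, and a useful cutoff would have to be chosen by inspecting the real activations.

```python
def avg_thresholded_l0(batch, eps=1e-2):
    """Mean number of features per position with |value| > eps.

    `batch` is a nested list shaped (batch, seq_len, features).
    `eps` is a hypothetical cutoff, not a value from the notebook.
    """
    counts = [sum(1 for v in pos if abs(v) > eps) for seq in batch for pos in seq]
    return sum(counts) / len(counts)

# Dense toy activations: every entry is non-zero, but most are tiny
toy = [[[0.001, 0.5, -0.002, 0.003],
        [0.004, -0.001, 0.8, 0.3]]]

exact_l0 = avg_thresholded_l0(toy, eps=0.0)    # counts every non-zero entry
thresh_l0 = avg_thresholded_l0(toy, eps=1e-2)  # counts only the larger entries
print(exact_l0, thresh_l0)
```

On a torch tensor, the same count could be written as `(tensor.abs() > eps).sum(dim=-1).float().mean().item()`.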

In [11]:
# Compute aggregate statistics from individual outputs
import numpy as np

l0_norms = [stats['l0_norm'] for stats in individual_stats]
l1_norms = [stats['l1_norm'] for stats in individual_stats]
l1_per_active = [stats['l1_per_active_feature'] for stats in individual_stats]
percent_active = [stats['percent_active'] for stats in individual_stats]

print("=== AGGREGATE STATISTICS ===")
print(f"Number of outputs analyzed: {len(individual_stats)}")

# Get feature count (should be same for all)
total_features = individual_stats[0]['shape'][-1]
print(f"Total features per output: {total_features}")

print(f"\nL0 Norm (Sparsity) Statistics:")
print(f"  Mean: {np.mean(l0_norms):.2f}")
print(f"  Std:  {np.std(l0_norms):.2f}")
print(f"  Min:  {np.min(l0_norms):.2f}")
print(f"  Max:  {np.max(l0_norms):.2f}")

print(f"\nL1 Norm Statistics:")
print(f"  Mean: {np.mean(l1_norms):.4f}")
print(f"  Std:  {np.std(l1_norms):.4f}")
print(f"  Min:  {np.min(l1_norms):.4f}")
print(f"  Max:  {np.max(l1_norms):.4f}")

print(f"\nL1 per Active Feature Statistics:")
print(f"  Mean: {np.mean(l1_per_active):.4f}")
print(f"  Std:  {np.std(l1_per_active):.4f}")
print(f"  Min:  {np.min(l1_per_active):.4f}")
print(f"  Max:  {np.max(l1_per_active):.4f}")

print(f"\nPercentage of Features Active:")
print(f"  Mean: {np.mean(percent_active):.2f}%")
print(f"  Std:  {np.std(percent_active):.2f}%")
print(f"  Min:  {np.min(percent_active):.2f}%")
print(f"  Max:  {np.max(percent_active):.2f}%")

print(f"\n=== INDIVIDUAL OUTPUT DETAILS ===")
for i, stats in enumerate(individual_stats):
    print(f"Output {i+1}: L0={stats['l0_norm']:.2f}, L1={stats['l1_norm']:.4f}, "
          f"L1/L0={stats['l1_per_active_feature']:.4f}, Active={stats['percent_active']:.1f}%")

=== AGGREGATE STATISTICS ===
Number of outputs analyzed: 5
Total features per output: 65536

L0 Norm (Sparsity) Statistics:
  Mean: 65536.00
  Std:  0.00
  Min:  65536.00
  Max:  65536.00

L1 Norm Statistics:
  Mean: 1912.4104
  Std:  202.3704
  Min:  1749.8146
  Max:  2311.5098

L1 per Active Feature Statistics:
  Mean: 0.0292
  Std:  0.0031
  Min:  0.0267
  Max:  0.0353

Percentage of Features Active:
  Mean: 100.00%
  Std:  0.00%
  Min:  100.00%
  Max:  100.00%

=== INDIVIDUAL OUTPUT DETAILS ===
Output 1: L0=65536.00, L1=1749.8146, L1/L0=0.0267, Active=100.0%
Output 2: L0=65536.00, L1=1835.9673, L1/L0=0.0280, Active=100.0%
Output 3: L0=65536.00, L1=1818.2739, L1/L0=0.0277, Active=100.0%
Output 4: L0=65536.00, L1=1846.4867, L1/L0=0.0282, Active=100.0%
Output 5: L0=65536.00, L1=2311.5098, L1/L0=0.0353, Active=100.0%
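
The reported L1 mean and standard deviation can be reproduced from the five per-output values printed above. `np.std` uses the population formula (`ddof=0`) by default. The sketch below recomputes it with the standard library and also computes the sample standard deviation, which is larger when only five outputs are available. The inputs are the rounded values from the printed output, so the last digits may differ slightly from the notebook's.

```python
import statistics

# Per-output L1 norms as printed in the individual output details
l1_norms = [1749.8146, 1835.9673, 1818.2739, 1846.4867, 2311.5098]

mean_l1 = statistics.fmean(l1_norms)
pop_std = statistics.pstdev(l1_norms)    # population std, same convention as np.std's default
sample_std = statistics.stdev(l1_norms)  # sample std (ddof=1)

print(f"mean={mean_l1:.4f} pop_std={pop_std:.2f} sample_std={sample_std:.2f}")
```

Output 5 is the shortest sequence (104 tokens) and has the largest L1 norm, so with so few samples a single output accounts for most of the spread.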
