In [None]:
import torch
import json
import os
import sys
from pathlib import Path
from transformer_lens import HookedTransformer

# Add src to path
sys.path.append(os.path.abspath("../../src"))

from fsrl import SAEAdapter, HookedModel
from fsrl.utils.wandb_utils import WandBModelDownloader

# ---- Configuration ----------------------------------------------------------
# Use the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# JSON file with the prompts and the baseline / FSRL / ablated completions.
INPUT_FILE = "example_outputs_input.json"

# WandB runs to analyse, as (project, run name) per model variant.
# Regular FSRL model:
FSRL_PROJECT, FSRL_RUN_NAME = "Gemma2-2B-muon", "mild-resonance-1"
# Style-ablated FSRL model:
FSRL_ABLATED_PROJECT, FSRL_ABLATED_RUN_NAME = "Gemma2-2B-train-ablate", "rosy-cloud-4"

# Cached Neuronpedia feature descriptions.
# NOTE(review): the "12-" prefix presumably means layer 12 — it should match the
# adapter's hook layer; confirm when switching runs.
DESCRIPTIONS_FILE = "../../models/NeuronpediaCache/gemma-2-2b/12-gemmascope-res-65k_canonical.json"

print(f"Device: {DEVICE}")

In [None]:
# Load Feature Descriptions
def load_descriptions(file_path):
    """Load SAE feature descriptions from a JSON file into an index -> text map.

    Two layouts are accepted:
      * a list of dicts, each holding the feature index under ``index`` or
        ``feature_index`` and the text under ``description`` or ``explanation``;
      * a dict keyed by feature index, whose values are either the description
        itself or a dict with ``description`` / ``explanation``.

    Entries whose index cannot be parsed as an int, list items that are not
    dicts, and entries without any description are skipped.

    Args:
        file_path: Path to the JSON file.

    Returns:
        dict[int, Any]: feature index -> description. Empty (with a printed
        warning) when the file does not exist.
    """
    if not os.path.exists(file_path):
        print(f"Warning: Descriptions file not found at {file_path}")
        return {}

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    def _as_index(raw):
        # Indices may be ints or numeric strings; anything else yields None.
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    def _text_of(entry):
        # 'description' wins; an absent/empty one falls back to 'explanation'.
        return entry.get('description') or entry.get('explanation')

    descriptions = {}

    # Handle different formats
    if isinstance(data, list):
        for item in data:
            if not isinstance(item, dict):
                continue
            idx = _as_index(item.get('index', item.get('feature_index')))
            desc = _text_of(item)
            if idx is not None and desc is not None:
                descriptions[idx] = desc
    elif isinstance(data, dict):
        for key, value in data.items():
            idx = _as_index(key)
            if idx is None:
                continue
            desc = _text_of(value) if isinstance(value, dict) else value
            if desc is not None:
                descriptions[idx] = desc

    print(f"Loaded {len(descriptions)} descriptions.")
    return descriptions

# Feature index -> description lookup, used when reporting top features below.
feature_descriptions = load_descriptions(file_path=DESCRIPTIONS_FILE)

In [None]:
# Helper function to load a model with adapter
def load_model_with_adapter(
    project,
    run_name,
    model_label,
    base_model_name="google/gemma-2-2b-it",
    entity="feature-steering-RL",
    models_dir="../../models",
):
    """Load the base HookedTransformer and wrap it with a downloaded SAE adapter.

    Args:
        project: WandB project of the run; also the sub-directory of
            ``models_dir`` that the run was downloaded into.
        run_name: WandB run name. The adapter is read from
            ``<models_dir>/<project>/<run_name>/adapter``.
        model_label: Human-readable name used in progress and error messages.
        base_model_name: TransformerLens name of the base model to load.
        entity: WandB entity. Used only to list already-downloaded runs when the
            adapter is missing.
        models_dir: Root directory containing downloaded models.

    Returns:
        HookedModel: the base model wrapped with the adapter, in eval mode.

    Raises:
        FileNotFoundError: if the adapter directory does not exist.
    """
    adapter_path = Path(models_dir) / project / run_name / "adapter"

    if not adapter_path.exists():
        # The downloader is only needed here, to show which runs are available.
        downloader = WandBModelDownloader(
            entity=entity,
            project=project,
            verbose=False
        )
        print(f"Adapter not found at {adapter_path}")
        print(f"Available runs in {project}:", downloader.list_downloaded_models(project))
        raise FileNotFoundError(f"Please download the model first or change {model_label} settings")

    print(f"Loading {model_label}...")
    base_model = HookedTransformer.from_pretrained(
        base_model_name,
        device=DEVICE,
        dtype=torch.bfloat16
    )

    print(f"  Loading adapter from {adapter_path}")
    sae_adapter = SAEAdapter.load_from_pretrained_adapter(str(adapter_path), device=DEVICE)

    model = HookedModel(base_model, sae_adapter)
    model.eval()

    print(f"  {model_label} loaded successfully!")
    return model

# Load the regular FSRL model and the style-ablated model, each from its own
# WandB project.
fsrl_model = load_model_with_adapter(
    project=FSRL_PROJECT,
    run_name=FSRL_RUN_NAME,
    model_label="FSRL Model",
)
fsrl_ablated_model = load_model_with_adapter(
    project=FSRL_ABLATED_PROJECT,
    run_name=FSRL_ABLATED_RUN_NAME,
    model_label="FSRL Ablated Model",
)

In [None]:
# Analysis Function
def analyze_steering(model, text, top_k=10, descriptions=None):
    """Rank SAE-adapter features by their mean activation over a text.

    The text is tokenized with ``model.model.to_tokens`` and run through
    ``model.run_with_cache``. The cached tensor at
    ``blocks.{hook_layer}.hook_resid_post.hook_sae_adapter`` is then averaged
    over all token positions, including any BOS token that ``to_tokens``
    prepends.

    Args:
        model: Wrapped model exposing ``.model.to_tokens``,
            ``.sae_adapter.cfg.hook_layer`` and ``.run_with_cache``.
        text: Text to analyse.
        top_k: Number of features to return. It is clamped to the number of
            available features; a value <= 0 returns an empty list.
        descriptions: Optional mapping of feature index -> description.
            Defaults to the notebook-level ``feature_descriptions``.

    Returns:
        list[dict]: one dict per feature, highest mean activation first, with
        keys ``feature_index``, ``activation`` and ``description``. Empty (with
        a printed warning) if the hook is not in the cache.
    """
    if descriptions is None:
        descriptions = feature_descriptions

    # Tokenize
    tokens = model.model.to_tokens(text)

    # Get the hook name from the adapter
    layer = model.sae_adapter.cfg.hook_layer
    hook_name = f"blocks.{layer}.hook_resid_post.hook_sae_adapter"

    # Run with cache
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens)

    if hook_name not in cache:
        print(f"Warning: Hook {hook_name} not found in cache.")
        print(f"Available keys: {list(cache.keys())[:5]}...")
        return []

    # Assumes the cached tensor is [batch, seq_len, d_sae]; keep batch row 0.
    acts = cache[hook_name][0]

    # Average across tokens in float32. The models here are loaded in bfloat16,
    # whose 8-bit mantissa makes a long-sequence mean lossy and can reorder
    # near-equal features.
    mean_acts = acts.float().mean(dim=0)

    # torch.topk raises if k exceeds the number of features, so clamp k.
    k = min(top_k, mean_acts.shape[-1])
    if k <= 0:
        return []
    top_values, top_indices = torch.topk(mean_acts, k)

    return [
        {
            "feature_index": idx,
            "activation": val,
            # A missing, None or empty description falls back to the placeholder.
            "description": descriptions.get(idx) or "No description found",
        }
        for val, idx in zip(top_values.tolist(), top_indices.tolist())
    ]

In [None]:
# Analyze examples from JSON file
# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    examples = json.load(f)


def print_top_features(features, heading):
    """Print a feature ranking as returned by ``analyze_steering``.

    Args:
        features: list of dicts with ``feature_index``, ``activation`` and
            ``description`` keys.
        heading: line printed before the list.
    """
    print(heading)
    for feat in features:
        print(f"  Feature {feat['feature_index']}: {feat['activation']:.4f}")
        print(f"    {feat['description']}\n")


for i, ex in enumerate(examples):
    print(f"\n{'='*80}")
    print(f"Example {i+1}")
    print(f"{'='*80}")

    print(f"Prompt:\n{ex['prompt']}\n")
    print(f"Baseline Output:\n{ex['baseline_output']}\n")

    # FSRL Model Analysis: features are measured on prompt + completion.
    print("--- FSRL Output ---")
    print(f"{ex['fsrl_output']}\n")

    full_text_fsrl = ex['prompt'] + ex['fsrl_output']
    fsrl_features = analyze_steering(fsrl_model, full_text_fsrl)
    print_top_features(fsrl_features, "Top Steered Features (FSRL):")

    # FSRL Ablated Model Analysis: same procedure with the ablated model/output.
    print("--- FSRL Ablated Output ---")
    print(f"{ex['fsrl_ablated_output']}\n")

    full_text_ablated = ex['prompt'] + ex['fsrl_ablated_output']
    ablated_features = analyze_steering(fsrl_ablated_model, full_text_ablated)
    print_top_features(ablated_features, "Top Steered Features (FSRL Ablated):")