# Train SAE on Finetuned Llama 70B (Alignment Faking Model)

This notebook trains a Sparse Autoencoder on your finetuned Llama 70B alignment-faking model.

## Time Estimates (4x A100 80GB)
- **Quick test** (10M tokens): ~2-3 hours - Good for testing the pipeline
- **Standard** (50M tokens): ~12-18 hours - Decent for initial analysis
- **High quality** (200M tokens): ~2-3 days - Research-grade

## Memory Requirements
- Model: ~60-65GB across 4 GPUs
- SAE training: ~10-15GB additional
- Total: ~75GB per GPU (comfortable on 4x A100 80GB)

## What You'll Get
- Trained SAE weights saved to disk
- Config file for loading later
- Training metrics and checkpoints
- Ready for feature analysis

## Setup

In [10]:
import torch
from pathlib import Path
from datetime import datetime
from dotenv import load_dotenv
import os

# Load environment variables from alignment-faking repo
load_dotenv("/home/user/mkcho/alignment-faking/.env")

from sae_lens import (
    LanguageModelSAERunnerConfig,
    LanguageModelSAETrainingRunner,
    StandardTrainingSAEConfig,
    LoggingConfig,
)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"\nGPU Configuration:")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        memory_gb = props.total_memory / (1024**3)
        print(f"  GPU {i}: {props.name} ({memory_gb:.1f} GB)")

# Verify HF token is loaded
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    print(f"\n‚úì HF_TOKEN loaded (length: {len(hf_token)} chars)")
else:
    print("\n‚ö†Ô∏è  WARNING: HF_TOKEN not found in environment!")

PyTorch version: 2.9.1+cu128
CUDA available: True

GPU Configuration:
  GPU 0: NVIDIA A100 80GB PCIe (79.3 GB)
  GPU 1: NVIDIA A100 80GB PCIe (79.3 GB)
  GPU 2: NVIDIA A100 80GB PCIe (79.3 GB)
  GPU 3: NVIDIA A100 80GB PCIe (79.3 GB)

‚úì HF_TOKEN loaded (length: 37 chars)


## Configuration

**Adjust these settings based on your needs:**

In [11]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# Model settings
CUSTOM_MODEL_PATH = "ada-flo/llama-70b-honly-merged"  # Your finetuned model on HuggingFace
BASE_MODEL_ARCHITECTURE = "meta-llama/Llama-3.3-70B-Instruct"  # Base architecture for TransformerLens
TARGET_LAYER = 40  # Layer 40 out of 80 - where reasoning likely happens
HOOK_NAME = f"blocks.{TARGET_LAYER}.hook_resid_pre"

# Llama 70B architecture
D_MODEL = 8192  # Hidden dimension
D_SAE = 65536   # SAE features (8x expansion - standard for interpretability)

# Training scale - CHOOSE ONE:
# Option 1: Quick test (recommended to start)
TRAINING_TOKENS = 10_000_000   # 10M tokens (~2-3 hours)
TRAINING_LABEL = "quick_test"

# Option 2: Standard quality (uncomment to use)
# TRAINING_TOKENS = 50_000_000   # 50M tokens (~12-18 hours)
# TRAINING_LABEL = "standard"

# Option 3: High quality (uncomment to use)
# TRAINING_TOKENS = 200_000_000  # 200M tokens (~2-3 days)
# TRAINING_LABEL = "high_quality"

# Training hyperparameters
BATCH_SIZE_TOKENS = 2048       # Tokens per batch (reduce if OOM)
CONTEXT_SIZE = 512             # Sequence length
L1_COEFFICIENT = 3.0           # Sparsity strength (higher = sparser)
LEARNING_RATE = 2e-4           # Learning rate

# Dataset
DATASET = "monology/pile-uncopyrighted"  # Same as alignment faking training

# Logging
USE_WANDB = True  # Set to False to disable W&B logging
WANDB_PROJECT = "sae_alignment_faking_llama70b"

# GPU settings
GPU_INDEX = 0  # Primary GPU
DEVICE = f"cuda:{GPU_INDEX}"
DTYPE = 'bfloat16'  # Use bfloat16 for memory efficiency

print("="*80)
print("TRAINING CONFIGURATION")
print("="*80)
print(f"Custom model (finetuned): {CUSTOM_MODEL_PATH}")
print(f"Base architecture: {BASE_MODEL_ARCHITECTURE}")
print(f"Target layer: {TARGET_LAYER}")
print(f"Hook point: {HOOK_NAME}")
print(f"\nSAE Architecture:")
print(f"  Input dim (d_in): {D_MODEL}")
print(f"  SAE dim (d_sae): {D_SAE}")
print(f"  Expansion factor: {D_SAE / D_MODEL:.1f}x")
print(f"\nTraining:")
print(f"  Total tokens: {TRAINING_TOKENS:,} ({TRAINING_LABEL})")
print(f"  Batch size: {BATCH_SIZE_TOKENS} tokens")
print(f"  Context size: {CONTEXT_SIZE} tokens")
print(f"  L1 coefficient: {L1_COEFFICIENT}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Dataset: {DATASET}")
print(f"\nLogging:")
print(f"  Weights & Biases: {USE_WANDB}")
if USE_WANDB:
    print(f"  Project: {WANDB_PROJECT}")
print("="*80)

TRAINING CONFIGURATION
Custom model (finetuned): ada-flo/llama-70b-honly-merged
Base architecture: meta-llama/Llama-3.3-70B-Instruct
Target layer: 40
Hook point: blocks.40.hook_resid_pre

SAE Architecture:
  Input dim (d_in): 8192
  SAE dim (d_sae): 65536
  Expansion factor: 8.0x

Training:
  Total tokens: 10,000,000 (quick_test)
  Batch size: 2048 tokens
  Context size: 512 tokens
  L1 coefficient: 3.0
  Learning rate: 0.0002
  Dataset: monology/pile-uncopyrighted

Logging:
  Weights & Biases: True
  Project: sae_alignment_faking_llama70b


## Create Output Directories

In [12]:
# Create timestamped directories for this run
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f"{run_timestamp}_llama70b_layer{TARGET_LAYER}_{TRAINING_LABEL}"

checkpoint_path = f'checkpoints/{run_name}'
output_path = f'runs/{run_name}'

# Create directories
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
Path(output_path).mkdir(parents=True, exist_ok=True)

print(f"Run name: {run_name}")
print(f"Checkpoints will be saved to: {checkpoint_path}")
print(f"Final output will be saved to: {output_path}")
print(f"\n‚ö†Ô∏è IMPORTANT: Save these paths! You'll need them for analysis.")

Run name: 20251217_204437_llama70b_layer40_quick_test
Checkpoints will be saved to: checkpoints/20251217_204437_llama70b_layer40_quick_test
Final output will be saved to: runs/20251217_204437_llama70b_layer40_quick_test

‚ö†Ô∏è IMPORTANT: Save these paths! You'll need them for analysis.


## Build Training Configuration

In [13]:
# First, load the HuggingFace model with custom weights
from transformers import AutoModelForCausalLM
import os

print("Loading custom finetuned model from HuggingFace...")
print(f"  Model: {CUSTOM_MODEL_PATH}")
print("  This may take 5-10 minutes...")

# Get HF token from environment
hf_token = os.getenv('HF_TOKEN')
if not hf_token:
    raise ValueError("HF_TOKEN not found. Please set it in your .env file or environment.")

# Load the custom model
hf_model = AutoModelForCausalLM.from_pretrained(
    CUSTOM_MODEL_PATH,
    token=hf_token,
    torch_dtype=torch.bfloat16,
    device_map='auto',  # Automatically distribute across GPUs
)

print(f"‚úì Model loaded: {hf_model.config.model_type}")
print(f"  Layers: {hf_model.config.num_hidden_layers}")
print(f"  Hidden size: {hf_model.config.hidden_size}")

# Build the SAELens training configuration
cfg = LanguageModelSAERunnerConfig(
    # Model configuration - Use base architecture name with custom model object
    model_name=BASE_MODEL_ARCHITECTURE,  # TransformerLens needs official model name
    model_class_name='HookedTransformer',
    hook_name=HOOK_NAME,
    
    # Dataset configuration
    dataset_path=DATASET,
    is_dataset_tokenized=False,
    streaming=True,
    dataset_trust_remote_code=True,
    context_size=CONTEXT_SIZE,
    prepend_bos=True,
    
    # SAE architecture
    sae=StandardTrainingSAEConfig(
        d_in=D_MODEL,
        d_sae=D_SAE,
        apply_b_dec_to_input=False,
        normalize_activations='expected_average_only_in',
        l1_coefficient=L1_COEFFICIENT,
        l1_warm_up_steps=1000,
    ),
    
    # Training schedule
    train_batch_size_tokens=BATCH_SIZE_TOKENS,
    training_tokens=TRAINING_TOKENS,
    n_batches_in_buffer=32,
    feature_sampling_window=2000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-5,
    
    # Optimizer
    lr=LEARNING_RATE,
    lr_scheduler_name='cosineannealing',
    lr_warm_up_steps=1000,
    lr_decay_steps=0,
    adam_beta1=0.9,
    adam_beta2=0.999,
    
    # Logging
    logger=LoggingConfig(
        log_to_wandb=USE_WANDB,
        wandb_project=WANDB_PROJECT if USE_WANDB else None,
        wandb_log_frequency=50,
        eval_every_n_wandb_logs=100,
    ),
    
    # Checkpointing
    n_checkpoints=5,  # Save 5 checkpoints throughout training
    checkpoint_path=checkpoint_path,
    output_path=output_path,
    save_final_checkpoint=True,
    
    # Compute settings
    device=DEVICE,
    act_store_device='with_model',
    dtype=DTYPE,
    autocast=True,
    
    # Model loading - Pass the loaded model object
    model_from_pretrained_kwargs={
        'hf_model': hf_model,  # Pass the loaded model object, not string
    },
    
    seed=42,
)

print("\n‚úì Training configuration created successfully")
print(f"\nEstimated training steps: {TRAINING_TOKENS // BATCH_SIZE_TOKENS:,}")
print(f"Estimated time: {TRAINING_LABEL} profile")
print(f"\n‚ÑπÔ∏è  Will use architecture '{BASE_MODEL_ARCHITECTURE}' with weights from '{CUSTOM_MODEL_PATH}'")

Loading custom finetuned model from HuggingFace...
  Model: ada-flo/llama-70b-honly-merged
  This may take 5-10 minutes...


config.json:   0%|          | 0.00/874 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

model-00001-of-00030.safetensors:   0%|          | 0.00/4.58G [00:00<?, ?B/s]

model-00002-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00008-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00007-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00005-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00004-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00009-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00010-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00011-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00012-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00013-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00014-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00015-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00016-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00017-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00018-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00019-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00020-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00021-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00022-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00023-of-00030.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00024-of-00030.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00025-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

model-00026-of-00030.safetensors:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

## Start Training

‚ö†Ô∏è **This will take several hours. The notebook will show progress bars.**

**What happens during training:**
1. Model loads (5-10 minutes)
2. Dataset streams from HuggingFace
3. SAE trains with progress updates
4. Checkpoints save automatically
5. Final model saves when complete

**You can monitor:**
- Progress bars in this notebook
- GPU memory usage: `watch -n 1 nvidia-smi`
- W&B dashboard (if enabled): https://wandb.ai

In [9]:
print("="*80)
print("STARTING SAE TRAINING")
print("="*80)
print(f"Training will begin now. This will take approximately:")
if TRAINING_TOKENS == 10_000_000:
    print(f"  ‚è±Ô∏è  2-3 hours")
elif TRAINING_TOKENS == 50_000_000:
    print(f"  ‚è±Ô∏è  12-18 hours")
elif TRAINING_TOKENS >= 200_000_000:
    print(f"  ‚è±Ô∏è  2-3 days")
print(f"\nYou can safely close this notebook - training will continue.")
print(f"Monitor GPU usage: watch -n 1 nvidia-smi")
print("="*80 + "\n")

# Initialize the training runner
runner = LanguageModelSAETrainingRunner(cfg)

# Start training (this will take hours)
sae = runner.run()

print("\n" + "="*80)
print("TRAINING COMPLETE!")
print("="*80)

STARTING SAE TRAINING
Training will begin now. This will take approximately:
  ‚è±Ô∏è  2-3 hours

You can safely close this notebook - training will continue.
Monitor GPU usage: watch -n 1 nvidia-smi



AttributeError: 'str' object has no attribute 'config'

## Save Final Model

In [None]:
# Save the trained SAE
sae_save_path = Path(output_path) / "final_sae"
sae_save_path.mkdir(parents=True, exist_ok=True)

sae.save_model(sae_save_path)

print(f"\n‚úì SAE saved successfully to: {sae_save_path}")
print(f"\nFiles saved:")
print(f"  - cfg.json: Configuration")
print(f"  - sae_weights.safetensors: Model weights")
print(f"  - sparsity.safetensors: Sparsity statistics")

# Save training info for reference
import json
training_info = {
    'run_name': run_name,
    'timestamp': run_timestamp,
    'custom_model_path': CUSTOM_MODEL_PATH,
    'base_architecture': BASE_MODEL_ARCHITECTURE,
    'target_layer': TARGET_LAYER,
    'hook_name': HOOK_NAME,
    'd_model': D_MODEL,
    'd_sae': D_SAE,
    'training_tokens': TRAINING_TOKENS,
    'training_label': TRAINING_LABEL,
    'l1_coefficient': L1_COEFFICIENT,
    'learning_rate': LEARNING_RATE,
    'sae_path': str(sae_save_path),
}

info_path = Path(output_path) / "training_info.json"
with open(info_path, 'w') as f:
    json.dump(training_info, f, indent=2)

print(f"\n‚úì Training info saved to: {info_path}")

## Quick Validation

Let's do a quick sanity check on the trained SAE.

In [None]:
print("\nValidating trained SAE...")

# Check SAE properties
print(f"\nSAE Configuration:")
print(f"  d_in: {sae.cfg.d_in}")
print(f"  d_sae: {sae.cfg.d_sae}")
print(f"  hook_name: {sae.cfg.metadata.hook_name}")
print(f"  model_name: {sae.cfg.metadata.model_name}")

# Check weight statistics
print(f"\nWeight Statistics:")
print(f"  W_enc shape: {sae.W_enc.shape}")
print(f"  W_dec shape: {sae.W_dec.shape}")
print(f"  b_enc shape: {sae.b_enc.shape}")
print(f"  b_dec shape: {sae.b_dec.shape}")

# Check for dead features
if hasattr(sae, 'feature_sparsity') and sae.feature_sparsity is not None:
    sparsity = sae.feature_sparsity
    dead_features = (sparsity == 0).sum().item()
    print(f"\nFeature Statistics:")
    print(f"  Total features: {len(sparsity)}")
    print(f"  Dead features: {dead_features} ({100*dead_features/len(sparsity):.1f}%)")
    print(f"  Active features: {len(sparsity) - dead_features}")

print("\n‚úì Validation complete")

## Summary and Next Steps

In [None]:
print("\n" + "="*80)
print("TRAINING SUMMARY")
print("="*80)

print(f"\n‚úì Successfully trained SAE on {CUSTOM_MODEL_PATH}")
print(f"‚úì Base architecture: {BASE_MODEL_ARCHITECTURE}")
print(f"‚úì Trained on {TRAINING_TOKENS:,} tokens ({TRAINING_LABEL})")
print(f"‚úì Saved to: {sae_save_path}")

print(f"\n" + "="*80)
print("NEXT STEPS")
print("="*80)

print(f"\n1. **Update the analysis notebook**:")
print(f"   Open: experiments/exp-sae-lens/alignment-faking-feature-analysis.ipynb")
print(f"   Set: LOAD_PRETRAINED_SAE = True")
print(f"   Set: SAE_PATH = \"{sae_save_path}\"")

print(f"\n2. **Run feature analysis**:")
print(f"   Execute the analysis notebook to find free tier vs paid tier features")

print(f"\n3. **Optional: Train on more layers**:")
print(f"   Try layers 30, 40, 50, 60 to see where tier detection emerges")

print(f"\n4. **Optional: Improve quality**:")
print(f"   If results aren't clear, retrain with more tokens (50M or 200M)")

print("\n" + "="*80)

# Print the exact command to use in the analysis notebook
print(f"\nüìã Copy this for the analysis notebook:")
print(f"\n```python")
print(f"LOAD_PRETRAINED_SAE = True")
print(f"SAE_PATH = \"{sae_save_path}\"")
print(f"TARGET_LAYER = {TARGET_LAYER}")
print(f"```\n")

## Troubleshooting

### Out of Memory (OOM)
If you get OOM errors, reduce:
```python
BATCH_SIZE_TOKENS = 1024  # Reduce from 2048
CONTEXT_SIZE = 256        # Reduce from 512
D_SAE = 32768             # Reduce from 65536
```

### Training too slow
- Use fewer tokens for initial test: `TRAINING_TOKENS = 5_000_000`
- Use smaller SAE: `D_SAE = 32768`
- Check GPU utilization: `nvidia-smi`

### Model loading issues
- Verify model path is correct
- Check HF_TOKEN in .env file
- Ensure model has TransformerLens-compatible format

### Dead features
If you have too many dead features (>50%):
- Increase `dead_feature_threshold`
- Reduce `L1_COEFFICIENT` (try 2.0 or 2.5)
- Train for more tokens