# SAE Training Pipeline for Cognitive Actions
This notebook extracts per-token hidden-state activations from a single layer of Llama-3.1-8B-Instruct and trains a Sparse Autoencoder (SAE) on them using the FAST methodology.

## Setup: Clone Repository and Install Dependencies

In [None]:
# Clone the repository and make the checkout the working directory, so the
# relative paths used by later cells are resolved from inside it.
# NOTE(review): run this cell once per fresh runtime -- after the %cd below,
# re-running it would clone a second copy nested inside the first checkout.
!git clone https://github.com/ChuloIva/SAE_train_cognitive_actions.git
%cd SAE_train_cognitive_actions

In [None]:
# Clone and install HypotheSAEs separately (it's not included in the repo)
!git clone https://github.com/rmovva/HypotheSAEs.git
!pip install -e HypotheSAEs/

# Install other dependencies
!pip install transformers accelerate huggingface_hub

## Authenticate with HuggingFace

In [None]:
# Interactive HuggingFace login. The credentials are needed later in this
# notebook to download the model weights and to push the trained SAE.
# NOTE(review): the Llama checkpoints are presumably gated, so the account
# used here must already have been granted access -- confirm on the Hub.
from huggingface_hub import notebook_login
notebook_login()

## Step 1: Extract Activations from LLM

In [None]:
import json
import torch
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_dataset(dataset_path: str):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        dataset_path: Path to a UTF-8 encoded ``.jsonl`` file holding one
            JSON document per line.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message includes the file path and 1-based line number.
    """
    records = []
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default locale encoding.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank separator lines and trailing newlines
                # instead of crashing inside json.loads("").
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Re-raise with the same exception type, but say where the
                # bad record is.
                raise json.JSONDecodeError(
                    f"{dataset_path}, line {line_no}: {exc.msg}",
                    exc.doc,
                    exc.pos,
                ) from exc
    return records

def extract_activations_sequential(
    texts,
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    layer_idx=12,
    max_length=512,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Extract activations from LLM sequentially (FAST approach).

    Each text is tokenized and run through the model on its own (batch size
    1, no padding), and the hidden states at ``layer_idx`` are collected.

    Args:
        texts: Iterable of input strings. It is materialised into a list, so
            generators are accepted.
        model_name: HuggingFace model id passed to ``from_pretrained``.
        layer_idx: Index into ``outputs.hidden_states``. Per the transformers
            documentation that tuple starts with the embedding output, so
            index ``i`` is the output of the i-th transformer block.
        max_length: Sequences are truncated to this many tokens.
        device: Device the tokenized inputs are moved to. The model itself is
            placed by ``device_map="auto"``.

    Returns:
        ``(all_activations, seq_lengths)`` where ``all_activations`` is a list
        of float32 NumPy arrays, one per text, each of shape
        ``(n_tokens, hidden_dim)``, and ``seq_lengths`` is the list of the
        corresponding token counts.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    # Accept any iterable, and fail fast on empty input *before* the
    # expensive model load (previously this only surfaced as a ValueError from
    # min() in the summary print, after the whole model had been loaded).
    texts = list(texts)
    if not texts:
        raise ValueError("texts is empty: nothing to extract activations from")

    print(f"Loading model: {model_name}")
    print(f"Extracting from layer {layer_idx}")
    
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    
    # Some tokenizers ship without a pad token; fall back to EOS so the
    # tokenizer is fully configured (no padding is applied below).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.eval()
    
    hidden_dim = model.config.hidden_size
    print(f"Hidden dimension: {hidden_dim}")
    
    all_activations = []
    seq_lengths = []
    
    print(f"Processing {len(texts)} examples sequentially...")
    with torch.no_grad():
        for text in tqdm(texts, desc="Extracting activations"):
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=False,
            ).to(device)
            
            actual_length = inputs['input_ids'].shape[1]
            seq_lengths.append(actual_length)
            
            outputs = model(
                **inputs,
                output_hidden_states=True,
                return_dict=True
            )
            
            # Drop the batch axis (batch size is 1) and move to host memory as
            # float32, the dtype the downstream memmap is written in.
            layer_activations = outputs.hidden_states[layer_idx]
            layer_activations = layer_activations.squeeze(0).cpu().float().numpy()
            
            all_activations.append(layer_activations)
            
            # Release per-example tensors before the next forward pass; the
            # NumPy copy stays alive through all_activations.
            del inputs, outputs, layer_activations
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    print(f"Extracted activations from {len(all_activations)} examples")
    print(f"Sequence lengths: min={min(seq_lengths)}, max={max(seq_lengths)}, mean={np.mean(seq_lengths):.1f}")
    
    return all_activations, seq_lengths

def pad_activations(activations_list, max_length=None, pad_value=0.0):
    """Stack per-example activation matrices into one padded float32 tensor.

    Args:
        activations_list: Non-empty sequence of 2-D arrays, each
            ``(n_tokens, hidden_dim)``; the hidden size is read from the
            first entry.
        max_length: Length of the sequence axis of the result. Defaults to
            the longest example; longer examples are truncated to it.
        pad_value: Fill value for positions past the end of an example.

    Returns:
        ``(padded, mask)``: ``padded`` has shape
        ``(n_examples, max_length, hidden_dim)`` and dtype float32; ``mask``
        is a boolean ``(n_examples, max_length)`` array that is True on real
        (non-padding) positions.
    """
    example_count = len(activations_list)
    feature_dim = activations_list[0].shape[1]

    target_len = max_length
    if target_len is None:
        target_len = max(arr.shape[0] for arr in activations_list)

    out_shape = (example_count, target_len, feature_dim)
    padded = np.full(out_shape, pad_value, dtype=np.float32)
    mask = np.zeros(out_shape[:2], dtype=bool)

    for row, arr in enumerate(activations_list):
        # Number of positions that are actually copied for this example.
        n_kept = min(arr.shape[0], target_len)
        padded[row, :n_kept, :] = arr[:n_kept, :]
        mask[row, :n_kept] = True

    return padded, mask

def flatten_activations(padded_activations, padding_mask, exclude_padding=True):
    """Flatten sequence dimension for SAE training.

    Args:
        padded_activations: Array of shape ``(n_examples, seq_len, hidden_dim)``.
        padding_mask: ``(n_examples, seq_len)`` mask that is truthy on real
            token positions. Boolean or 0/1 integer masks are both accepted.
        exclude_padding: If True, keep only the positions selected by
            ``padding_mask``; otherwise keep every position.

    Returns:
        A 2-D array ``(n_rows, hidden_dim)``. Rows are in example-major,
        position-minor order. With ``exclude_padding=True`` ``n_rows`` is the
        number of truthy mask entries, otherwise ``n_examples * seq_len``.
    """
    if exclude_padding:
        # Coerce to bool so a 0/1 integer mask selects positions instead of
        # being interpreted as integer (fancy) indices along axis 0.
        keep = np.asarray(padding_mask, dtype=bool)
        # Boolean indexing over the two leading axes already yields
        # (n_kept, hidden_dim); no broadcasting of the mask is needed.
        return padded_activations[keep]

    n_examples, seq_len, hidden_dim = padded_activations.shape
    return padded_activations.reshape(n_examples * seq_len, hidden_dim)

In [None]:
# Configuration
DATASET_PATH = "cognitive_actions_7k_final_1759233061.jsonl"
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
LAYER_IDX = 11
MAX_LENGTH = 512
ACTIVATIONS_MMAP_PATH = "activations.mmap"

# Load dataset
print(f"Loading dataset from {DATASET_PATH}")
data = load_dataset(DATASET_PATH)
texts = [item['text'] for item in data]
print(f"Loaded {len(texts)} examples")

# Extract activations
activations_list, seq_lengths = extract_activations_sequential(
    texts=texts,
    model_name=MODEL_NAME,
    layer_idx=LAYER_IDX,
    max_length=MAX_LENGTH,
)

# Stream every example's (n_tokens, hidden_dim) block straight into the
# memory-mapped file. Concatenating the per-example arrays in order produces
# exactly the rows that "pad to max length, then drop the padding" would, but
# without first building a dense (n_examples, max_len, hidden_dim) float32
# array plus a flattened copy in RAM -- which is what the memmap is meant to
# avoid in the first place.
total_tokens = sum(acts.shape[0] for acts in activations_list)
hidden_dim = activations_list[0].shape[1]
activations_shape = (total_tokens, hidden_dim)
print(f"\nFlattened shape: {activations_shape}")

# Save to memory-mapped file to avoid RAM limits
print(f"\nSaving activations to memory-mapped file: {ACTIVATIONS_MMAP_PATH}")
activations_mmap = np.memmap(
    ACTIVATIONS_MMAP_PATH, 
    dtype='float32', 
    mode='w+', 
    shape=activations_shape
)
row_offset = 0
for acts in activations_list:
    next_offset = row_offset + acts.shape[0]
    activations_mmap[row_offset:next_offset] = acts
    row_offset = next_offset
# Every row of the file must have been written exactly once.
assert row_offset == total_tokens, "token count mismatch while writing the memmap"
activations_mmap.flush()

# Free RAM
del activations_list
print("Activations saved to disk, RAM freed")

print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"Model: {MODEL_NAME}")
print(f"Layer: {LAYER_IDX}")
print(f"Examples: {len(texts)}")
print(f"Total tokens (excl. padding): {activations_shape[0]:,}")
print(f"Hidden dimension: {activations_shape[1]}")
print(f"Avg tokens per example: {activations_shape[0] / len(texts):.1f}")
print(f"Memmap file: {ACTIVATIONS_MMAP_PATH}")

## Step 2: Train Sparse Autoencoder

In [None]:
import sys
sys.path.insert(0, "HypotheSAEs")

from hypothesaes.sae import SparseAutoencoder, get_sae_checkpoint_name, load_model
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import os
import types

# Memory-mapped dataset class
class MemmapDataset(Dataset):
    """Map-style dataset backed by a read-only ``numpy.memmap``.

    The file is opened lazily by NumPy; only the rows that are indexed are
    read from disk.
    """

    def __init__(self, mmap_path: str, shape: tuple, dtype='float32'):
        """Open ``mmap_path`` read-only as an array of the given shape and dtype."""
        self.data = np.memmap(mmap_path, mode='r', dtype=dtype, shape=shape)

    def __len__(self):
        """Number of rows (first axis of the mapped array)."""
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        """Return row ``idx`` as an independent float32 tensor."""
        # Copy out of the read-only mapping first so the tensor owns
        # writable memory.
        row = self.data[idx].copy()
        tensor = torch.from_numpy(row)
        return tensor.float()

# Custom fit method for DataLoader-based training
def fit_with_loader(
    self,
    train_loader,
    val_loader=None,
    save_dir=None,
    learning_rate: float = 5e-4,
    n_epochs: int = 100,
    aux_coef: float = 1/32,
    multi_coef: float = 0.0,
    patience: int = 5,
    show_progress: bool = True,
    clip_grad: float = 1.0
):
    """Train using DataLoader objects.

    Intended to be bound to an SAE instance (e.g. with ``types.MethodType``);
    ``self`` is that instance.

    Args:
        train_loader: Iterable yielding training batches (tensors). It is
            iterated once to fetch a batch for weight initialization and then
            once per epoch.
        val_loader: Optional iterable of validation batches. When given, the
            mean validation loss drives early stopping.
        save_dir: If not None, the model is saved into this directory after
            training, under the name from ``get_sae_checkpoint_name``.
        learning_rate: Adam learning rate.
        n_epochs: Maximum number of epochs.
        aux_coef: Forwarded to ``self.compute_loss``.
        multi_coef: Forwarded to ``self.compute_loss``.
        patience: Stop after this many consecutive epochs without a new best
            validation loss.
        show_progress: Show a tqdm progress bar and print early-stop notices.
        clip_grad: Max gradient norm; ``None`` disables clipping.

    Returns:
        ``history`` dict with per-epoch lists under ``'train_loss'``,
        ``'val_loss'`` (empty without ``val_loader``) and
        ``'dead_neuron_ratio'``.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # Initialize weights from first batch
    try:
        first_batch = next(iter(train_loader))
    except StopIteration:
        # Do not leak a bare StopIteration to the caller.
        raise ValueError(
            "train_loader yielded no batches; cannot initialize or train the SAE"
        ) from None
    self.initialize_weights_(first_batch.to(self.device))
    
    optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
    
    best_val_loss = float('inf')
    patience_counter = 0
    history = {'train_loss': [], 'val_loss': [], 'dead_neuron_ratio': []}
    
    if show_progress:
        # Only needed for the progress bar, so only imported when it is used.
        from tqdm.auto import tqdm
        iterator = tqdm(range(n_epochs))
    else:
        iterator = range(n_epochs)

    for epoch in iterator:
        self.train()
        train_losses = []
        
        for batch_x in train_loader:
            batch_x = batch_x.to(self.device)
            recon, info = self(batch_x)
            loss = self.compute_loss(batch_x, recon, info, aux_coef, multi_coef)
            
            optimizer.zero_grad()
            loss.backward()
            # Model-specific gradient adjustment runs before clipping and the
            # optimizer step; decoder normalization runs after the step.
            self.adjust_decoder_gradient_()
            
            if clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad)
            
            optimizer.step()
            self.normalize_decoder_()
            
            train_losses.append(loss.item())
        
        avg_train_loss = np.mean(train_losses)
        history['train_loss'].append(avg_train_loss)
        
        # Fraction of features whose steps-since-activation counter exceeds
        # the dead-neuron threshold.
        dead_ratio = (self.steps_since_activation > self.dead_neuron_threshold_steps).float().mean().item()
        history['dead_neuron_ratio'].append(dead_ratio)
        
        avg_val_loss = None
        if val_loader is not None:
            self.eval()
            val_losses = []
            with torch.no_grad():
                for batch_x in val_loader:
                    batch_x = batch_x.to(self.device)
                    recon, info = self(batch_x)
                    val_loss = self.compute_loss(batch_x, recon, info, aux_coef, multi_coef)
                    val_losses.append(val_loss.item())
            
            avg_val_loss = np.mean(val_losses)
            history['val_loss'].append(avg_val_loss)
            
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    if show_progress:
                        print(f"Early stopping triggered after {epoch+1} epochs")
                    break
        
        if show_progress:
            postfix = {
                'train_loss': f'{avg_train_loss:.4f}',
                # Test the value, not the loader: bool(loader) goes through
                # __len__, which is unavailable for unsized loaders and False
                # for empty ones.
                'val_loss': f'{avg_val_loss:.4f}' if avg_val_loss is not None else 'N/A',
                'dead_ratio': f'{dead_ratio:.3f}'
            }
            if self.use_batch_topk:
                postfix['threshold'] = f'{self.threshold.item():.2e}'
            iterator.set_postfix(postfix)
    
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        filename = get_sae_checkpoint_name(self.m_total_neurons, self.k_active_neurons, self.prefix_lengths)
        self.save(os.path.join(save_dir, filename))
    
    return history

def _copy_rows_to_mmap(source, row_indices, out_path, chunk_rows):
    """Write ``source[row_indices]`` to a new float32 memmap, ``chunk_rows`` rows at a time."""
    out_shape = (len(row_indices), source.shape[1])
    out = np.memmap(out_path, dtype='float32', mode='w+', shape=out_shape)
    for start in range(0, len(row_indices), chunk_rows):
        stop = start + chunk_rows
        # Fancy indexing materialises only this chunk in RAM.
        out[start:stop] = source[row_indices[start:stop]]
    out.flush()
    del out
    return out_shape


def split_and_save_mmap(mmap_path, shape, val_ratio=0.1, seed=None, chunk_rows=65536):
    """Split memory-mapped activations into train and validation files.

    Rows of the float32 memmap at ``mmap_path`` are randomly permuted; the
    first ``n_total - int(n_total * val_ratio)`` go to the train file and the
    remainder to the validation file. Output files are written next to the
    source with ``_train`` / ``_val`` inserted before the file extension
    (``activations.mmap`` -> ``activations_train.mmap``).

    Args:
        mmap_path: Path (``str`` or ``os.PathLike``) of the source memmap.
        shape: ``(n_rows, n_cols)`` of the source memmap.
        val_ratio: Fraction of rows assigned to the validation file.
        seed: If given, the permutation comes from a dedicated
            ``np.random.default_rng(seed)`` and is reproducible. If None, the
            global ``np.random`` state is used, so ``np.random.seed`` still
            controls it.
        chunk_rows: Number of rows copied per step, bounding peak RAM.

    Returns:
        ``(train_mmap_path, train_shape, val_mmap_path, val_shape)``.
    """
    mmap_path = os.fspath(mmap_path)
    full_data = np.memmap(mmap_path, dtype='float32', mode='r', shape=shape)
    
    n_total = shape[0]
    n_val = int(n_total * val_ratio)
    n_train = n_total - n_val
    
    if seed is None:
        indices = np.random.permutation(n_total)
    else:
        indices = np.random.default_rng(seed).permutation(n_total)
    train_indices = indices[:n_train]
    val_indices = indices[n_train:]
    
    # Derive output names from the extension only. A plain
    # str.replace('.mmap', ...) returns the *source* path when the suffix is
    # missing (so mode='w+' would truncate the source) and also rewrites
    # directory names that happen to contain ".mmap".
    root, ext = os.path.splitext(mmap_path)
    train_mmap_path = f"{root}_train{ext}"
    val_mmap_path = f"{root}_val{ext}"

    chunk_rows = max(1, int(chunk_rows))
    train_shape = _copy_rows_to_mmap(full_data, train_indices, train_mmap_path, chunk_rows)
    val_shape = _copy_rows_to_mmap(full_data, val_indices, val_mmap_path, chunk_rows)
    
    print(f"Train set: {n_train:,} tokens -> {train_mmap_path}")
    print(f"Val set: {n_val:,} tokens -> {val_mmap_path}")
    
    return train_mmap_path, train_shape, val_mmap_path, val_shape

In [None]:
# Configuration
CHECKPOINT_DIR = "checkpoints/cognitive_actions"

# SAE hyperparameters
M = 256  # Total number of SAE features
K = 8    # Active features per example

# Optional: Use Matryoshka prefixes for multi-granularity features
USE_MATRYOSHKA = True
MATRYOSHKA_PREFIXES = [64, 256] if USE_MATRYOSHKA else None

# Training parameters
N_EPOCHS = 100
BATCH_SIZE = 512
LEARNING_RATE = 5e-4
PATIENCE = 5
VAL_RATIO = 0.1
SEED = 42

print("="*60)
print("FAST-style SAE Training (Memory-Mapped)")
print("="*60)

# Seed *before* anything random happens. The train/val split draws a random
# permutation and the SAE constructor draws random initial weights; seeding
# only right before training (as was done previously) left both of them
# different on every run. In particular, a checkpoint loaded from disk would
# then be evaluated on a freshly drawn "validation" set that overlaps the
# tokens it was trained on.
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# Split activations into train/val mmap files
print("\nSplitting data into train/val memory-mapped files...")
train_mmap_path, train_shape, val_mmap_path, val_shape = split_and_save_mmap(
    ACTIVATIONS_MMAP_PATH, 
    activations_shape, 
    val_ratio=VAL_RATIO
)

# Display configuration
print("\n" + "="*60)
print("SAE Configuration")
print("="*60)
print(f"M (total features): {M}")
print(f"K (active features): {K}")
print(f"Matryoshka: {USE_MATRYOSHKA}")
if USE_MATRYOSHKA:
    print(f"  Prefixes: {MATRYOSHKA_PREFIXES}")
print(f"Epochs: {N_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Patience: {PATIENCE}")

# Check for existing checkpoint
checkpoint_path = None
if CHECKPOINT_DIR is not None:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    checkpoint_name = get_sae_checkpoint_name(M, K, MATRYOSHKA_PREFIXES)
    checkpoint_path = os.path.join(CHECKPOINT_DIR, checkpoint_name)
    if os.path.exists(checkpoint_path):
        print(f"\nLoading existing checkpoint: {checkpoint_path}")
        sae = load_model(checkpoint_path)
    else:
        # No checkpoint on disk: fall through to training below.
        checkpoint_path = None

if checkpoint_path is None:
    # Create datasets and loaders
    train_dataset = MemmapDataset(train_mmap_path, train_shape)
    val_dataset = MemmapDataset(val_mmap_path, val_shape)
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    # Create SAE
    input_dim = train_shape[1]
    sae = SparseAutoencoder(
        input_dim=input_dim,
        m_total_neurons=M,
        k_active_neurons=K,
        aux_k=None,
        multi_k=None,
        dead_neuron_threshold_steps=256,
        prefix_lengths=MATRYOSHKA_PREFIXES,
        use_batch_topk=False,
    )
    
    # Monkey-patch the fit_with_loader method
    sae.fit_with_loader = types.MethodType(fit_with_loader, sae)
    
    # Train SAE
    print("\n" + "="*60)
    print("Training SAE")
    print("="*60)
    
    # Re-seed right before training so the shuffling order does not depend on
    # how much randomness the split and the model construction consumed.
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(SEED)
    
    sae.fit_with_loader(
        train_loader=train_loader,
        val_loader=val_loader,
        save_dir=CHECKPOINT_DIR,
        learning_rate=LEARNING_RATE,
        n_epochs=N_EPOCHS,
        aux_coef=1/32,
        multi_coef=0.0,
        patience=PATIENCE,
        clip_grad=1.0,
        show_progress=True,
    )

print("\n" + "="*60)
print("Training Complete!")
print("="*60)

## Step 3: Evaluate SAE

In [None]:
# Evaluate on validation set
print("Evaluating on validation set...")

# Load validation activations from memory-mapped file
val_activations = np.memmap(val_mmap_path, dtype='float32', mode='r', shape=val_shape)

# Get SAE activations
val_activations_sae = sae.get_activations(val_activations, show_progress=True)

# Compute sparsity statistics
# NOTE: "sparsity" here is the fraction of *non-zero* entries.
sparsity = (val_activations_sae != 0).mean()
active_per_example = (val_activations_sae != 0).sum(axis=1).mean()

print(f"\nSparsity statistics:")
print(f"  Overall sparsity: {sparsity:.4f}")
print(f"  Active features per token: {active_per_example:.2f} (target: {K})")

# Compute reconstruction error
# The validation set is streamed through the SAE in fixed-size batches so that
# it never has to fit on the device in one piece. Squared errors are summed in
# float64 and divided by the element count at the end, which is the same
# quantity as an element-wise mean squared error over the whole set.
print("\nComputing reconstruction error...")
EVAL_BATCH_SIZE = 4096
n_val_rows, n_val_dims = val_activations.shape

# Per-dimension mean of the validation set: the "predict the mean" baseline.
val_mean = torch.from_numpy(
    np.asarray(val_activations.mean(axis=0, dtype=np.float64), dtype=np.float32)
).to(sae.device)

sq_err_sum = 0.0
baseline_sq_err_sum = 0.0
sae.eval()
with torch.no_grad():
    for start in range(0, n_val_rows, EVAL_BATCH_SIZE):
        # np.array(...) copies the slice out of the read-only mapping.
        val_batch = torch.from_numpy(
            np.array(val_activations[start:start + EVAL_BATCH_SIZE], dtype=np.float32)
        ).to(sae.device)
        recon, _ = sae(val_batch)
        sq_err_sum += (recon - val_batch).double().pow(2).sum().item()
        baseline_sq_err_sum += (val_batch - val_mean).double().pow(2).sum().item()

n_val_elements = n_val_rows * n_val_dims
mse = sq_err_sum / n_val_elements
baseline_mse = baseline_sq_err_sum / n_val_elements
# A zero baseline means the validation activations are constant; report NaN
# instead of raising ZeroDivisionError.
normalized_mse = mse / baseline_mse if baseline_mse > 0 else float("nan")

print(f"  MSE: {mse:.4f}")
print(f"  Normalized MSE: {normalized_mse:.4f}")

# Check for dead neurons
dead_neurons = (sae.steps_since_activation > sae.dead_neuron_threshold_steps).sum().item()
dead_ratio = dead_neurons / sae.m_total_neurons

print(f"\nDead neurons: {dead_neurons}/{sae.m_total_neurons} ({dead_ratio:.2%})")

print("\n" + "="*60)
print("SAE Training Summary")
print("="*60)
print(f"Checkpoint: {CHECKPOINT_DIR}")
print(f"Model: M={M}, K={K}")
print(f"Val MSE: {mse:.4f} (normalized: {normalized_mse:.4f})")
print(f"Active features/token: {active_per_example:.2f}")
print(f"Dead neurons: {dead_ratio:.2%}")

## Step 4: Push SAE to HuggingFace Hub

In [None]:
from huggingface_hub import HfApi, create_repo
import shutil

# Destination on the HuggingFace Hub -- CHANGE THESE before running
HF_USERNAME = "Koalacrown"  # HuggingFace user (or organisation) that owns the repo
REPO_NAME = "llama3.1-8b-it-cognitive-actions-sae-l11"  # repository that will hold the SAE

repo_id = "/".join((HF_USERNAME, REPO_NAME))

print(f"Pushing SAE to HuggingFace Hub: {repo_id}")

# Make sure the target model repository exists. A failure here is only
# reported; the upload below is still attempted.
try:
    create_repo(repo_id, repo_type="model", exist_ok=True)
except Exception as e:
    print(f"Error creating repository: {e}")
else:
    print(f"Repository created/verified: {repo_id}")

# Upload everything in the checkpoint directory to the root of the repo
api = HfApi()
upload_kwargs = dict(
    folder_path=CHECKPOINT_DIR,
    repo_id=repo_id,
    repo_type="model",
)
api.upload_folder(**upload_kwargs)

print(f"\n✅ SAE successfully pushed to: https://huggingface.co/{repo_id}")

## (Optional) Create Model Card

In [None]:
# Name of the checkpoint file inside the Hub repo. The checkpoint directory is
# uploaded to the repo root, and training saves the model under the name
# returned by get_sae_checkpoint_name, so the file sits at the top level.
checkpoint_filename = get_sae_checkpoint_name(M, K, MATRYOSHKA_PREFIXES)

# The usage snippet downloads the checkpoint first and hands load_model a
# local file path, which is how load_model is used elsewhere in this notebook.
# The HypotheSAEs link points at the repository this notebook installs from.
model_card = f"""---
license: mit
tags:
- sparse-autoencoder
- interpretability
- llama
- cognitive-actions
---

# LLaMA-3.1-8B Cognitive Actions SAE

This is a Sparse Autoencoder (SAE) trained on layer {LAYER_IDX} activations from LLaMA-3.1-8B-Instruct using the FAST methodology.

## Model Details

- **Base Model**: meta-llama/Llama-3.1-8B-Instruct
- **Layer**: {LAYER_IDX}
- **Dataset**: Cognitive Actions (7K examples)
- **SAE Architecture**: M={M}, K={K}
- **Methodology**: FAST (Finetuning-aligned Sequential Training)

## Performance

- **MSE**: {mse:.4f}
- **Normalized MSE**: {normalized_mse:.4f}
- **Active features/token**: {active_per_example:.2f}
- **Dead neurons**: {dead_ratio:.2%}

## Usage

```python
from huggingface_hub import hf_hub_download
from hypothesaes.sae import load_model

checkpoint_path = hf_hub_download(repo_id="{repo_id}", filename="{checkpoint_filename}")
sae = load_model(checkpoint_path)
features = sae.get_activations(activations)
```

## Training

Trained using [HypotheSAEs](https://github.com/rmovva/HypotheSAEs) with the following configuration:

- Epochs: {N_EPOCHS}
- Batch size: {BATCH_SIZE}
- Learning rate: {LEARNING_RATE}
- Matryoshka prefixes: {MATRYOSHKA_PREFIXES}

## Citation

If you use this SAE, please cite the FAST methodology and HypotheSAEs.
"""

# Upload model card
api.upload_file(
    path_or_fileobj=model_card.encode(),
    path_in_repo="README.md",
    repo_id=repo_id,
    repo_type="model",
)

print("Model card uploaded!")