In [1]:
import sys
import os
import gzip
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from sae import SparseAutoencoder
from utils import load_activations
from utils import prepare_residual_stream_data
from train import train_sae

In [2]:
# TODO: Hyperparameter search

In [3]:
# Name of the layer whose activations this notebook works with; it is used
# below to select the data and to label the run directory and the model.
layer_name = "encoder.outer.residual3"

In [4]:
# Pick the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Every execution of this cell gets its own output directory, named after the
# layer and the current wall-clock time, with a "checkpoints" sub-directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = f"sae.{layer_name}_{timestamp}"
checkpoint_dir = os.path.join(run_dir, "checkpoints")
os.makedirs(run_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Read the stored activations file (path is relative to the notebook's cwd).
filepath = '../data/activations/random_solutions_activations_10k.pkl.gz'
collected_act = load_activations(filepath)

Using device: cuda
Loading activations from ../data/activations/random_solutions_activations_10k.pkl.gz...
Loaded 9998 samples


In [5]:
# Build the training array for the selected layer from the loaded activations.
# Per the original note, the helper copes with variable sequence lengths.
training_data_array = prepare_residual_stream_data(
    collected_act,
    site_name='residual_stream',
    layer_name=layer_name,
)

Final dataset: 9998 samples, 1221206 total vectors


In [6]:
# --- Configuration: everything a reader might want to tune ------------------
SEED = 42                 # seeds PyTorch's global RNG (see torch.manual_seed below)
BATCH_SIZE = 256
INPUT_DIM = 256           # expected width of each activation vector
LATENT_DIM = INPUT_DIM * 5
NUM_EPOCHS = 150
LEARNING_RATE = 0.0004513970337767647
L1_LAMBDA = 0.000407

# Early-stopping settings, forwarded to train_sae as `patience` / `min_delta`.
PATIENCE = 15
MIN_DELTA = 1e-5

# Seed before the DataLoader and the model are created so that the shuffling
# order (drawn from the global RNG) is repeatable between runs. Model weight
# initialisation presumably uses the same RNG — confirm in `sae.py`.
# NOTE: this does not by itself guarantee bit-exact results on CUDA.
torch.manual_seed(SEED)

# --- Dataset and dataloader --------------------------------------------------
training_data = torch.tensor(training_data_array, dtype=torch.float32)

# Fail early, with a readable message, if the data does not match the
# hard-coded INPUT_DIM (e.g. after switching `layer_name`) or is empty,
# rather than hitting a shape error somewhere inside training.
if training_data.ndim == 0 or training_data.shape[0] == 0:
    raise ValueError(
        f"No training vectors available for layer '{layer_name}' "
        f"(tensor shape: {tuple(training_data.shape)})"
    )
if training_data.shape[-1] != INPUT_DIM:
    raise ValueError(
        f"INPUT_DIM={INPUT_DIM} does not match the activation width "
        f"{training_data.shape[-1]} of layer '{layer_name}'"
    )

dataset = torch.utils.data.TensorDataset(training_data)
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True
)

# --- Create and train the model ----------------------------------------------
model = SparseAutoencoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM, name=layer_name)
trained_model, best_loss = train_sae(
    model,
    data_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    l1_lambda=L1_LAMBDA,
    device=device,
    checkpoint_dir=checkpoint_dir,
    patience=PATIENCE,
    min_delta=MIN_DELTA,
)

Started training using device: cuda
[Early stopping] Epoch:0 Wait count reset (best: 0.012058)
[Early stopping] Epoch:1 Wait count reset (best: 0.000316)
[Early stopping] Epoch:2 Wait count reset (best: 0.000241)
[Early stopping] Epoch:3 Wait count reset (best: 0.000209)
[Early stopping] Epoch:4 Wait count reset (best: 0.000189)
[Early stopping] Epoch:5 Wait count reset (best: 0.000175)
[Early stopping] Epoch:6 Wait count reset (best: 0.000163)
[Early stopping] Epoch:7 Wait count: 1/15
[Early stopping] Epoch:8 Wait count reset (best: 0.000147)
[Train] Epoch 10/150 - Loss: 0.000139
[Early stopping] Epoch:9 Wait count: 1/15
[Early stopping] Epoch:10 Wait count reset (best: 0.000134)
[Early stopping] Epoch:11 Wait count: 1/15
[Early stopping] Epoch:12 Wait count: 2/15
[Early stopping] Epoch:13 Wait count reset (best: 0.000120)
[Early stopping] Epoch:14 Wait count: 1/15
[Early stopping] Epoch:15 Wait count: 2/15
[Early stopping] Epoch:16 Wait count reset (best: 0.000110)
[Early stopping] E