In [None]:
import numpy as np

# Compatibility shim: NumPy 1.24 removed the `np.bool` alias. Presumably some
# dependency still references it, so restore it as the builtin `bool` only when
# the attribute is missing (newer NumPy versions that define it are untouched).
try:
    np.bool
except AttributeError:
    np.bool = bool

In [None]:
import torch as th
from tqdm import trange, tqdm
from pathlib import Path
from nnsight import LanguageModel
from dictionary_learning.cache import PairedActivationCache
from dictionary_learning import ActivationBuffer, CrossCoder
from dictionary_learning.trainers.crosscoder import (
    CrossCoderTrainer,
    BatchTopKCrossCoderTrainer,
)
from dictionary_learning.training import trainSAE
from dictionary_learning.dictionary import CodeNormalization, BatchTopKCrossCoder
import os

import wandb

from dataclasses import dataclass, field
from typing import List, Optional

In [None]:
def load_activation_dataset(
    activation_store_dir: Path,
    base_model: str = "base",
    finetune_model: str = "finetune",
    layer: int = 20,
    split: str = "train",
    true_split: str | None = None,
    false_split: str | None = None,
    true_name: str = "MATS_true_processed",
    false_name: str = "MATS_false_processed",
):
    """
    Load paired (base, finetuned) activation caches of one layer for both the
    "true" and the "false" dataset.

    Each cache directory is
    ``<activation_store_dir>/<model>/<dataset_name>/<split>/layer_<layer>_out``.

    Args:
        activation_store_dir: Root directory holding one sub-directory per model.
        base_model: Sub-directory name of the base model.
        finetune_model: Sub-directory name of the finetuned model.
        layer: Layer index; selects the ``layer_<layer>_out`` directory.
        split: Split used for both datasets unless overridden.
        true_split: Split override for the true dataset (falls back to ``split``).
        false_split: Split override for the false dataset (falls back to ``split``).
        true_name: Dataset directory name for the true facts.
        false_name: Dataset directory name for the false facts.

    Returns:
        ``(true_cache, false_cache)``, each a ``PairedActivationCache``.
    """
    resolved_true_split = true_split if true_split is not None else split
    resolved_false_split = false_split if false_split is not None else split

    root = Path(activation_store_dir)
    layer_dir = f"layer_{layer}_out"

    def _dirs(dataset_name, split_name):
        # (base_dir, finetune_dir) for one dataset/split, in that order.
        return tuple(
            root / model / dataset_name / split_name / layer_dir
            for model in (base_model, finetune_model)
        )

    # Resolve every directory up front, then load true before false.
    dirs_by_label = {
        "true": _dirs(true_name, resolved_true_split),
        "false": _dirs(false_name, resolved_false_split),
    }

    caches = []
    for label, (base_dir, tuned_dir) in dirs_by_label.items():
        print(f"Loading {label} cache from {base_dir} and {tuned_dir}")
        caches.append(PairedActivationCache(base_dir, tuned_dir))

    true_cache, false_cache = caches
    return true_cache, false_cache

In [None]:
def get_local_shuffled_indices(
    num_samples_per_dataset,
    shard_size,
    generator: "th.Generator | None" = None,
):
    """
    Build a locally shuffled, interleaved index order over two equally sized
    datasets concatenated back to back.

    Indices ``[0, num_samples_per_dataset)`` address the first dataset and
    ``[num_samples_per_dataset, 2 * num_samples_per_dataset)`` the second (the
    layout of a ``ConcatDataset`` of two equal-length parts). Each dataset is cut
    into consecutive shards of ``shard_size`` samples (the last one may be
    shorter). Within every shard the order is randomly permuted. The two
    datasets' permuted shards are then interleaved element by element
    (first, second, first, second, ...). Shuffling only within shards presumably
    keeps reads local to one storage shard at a time.

    Args:
        num_samples_per_dataset: Number of samples taken from each dataset (>= 0).
        shard_size: Shard length in samples (> 0).
        generator: Optional ``torch.Generator`` for reproducible permutations;
            ``None`` draws from torch's global RNG (the original behaviour).

    Returns:
        1-D ``torch.long`` tensor of length ``2 * num_samples_per_dataset``
        (empty when ``num_samples_per_dataset`` is 0).

    Raises:
        ValueError: if ``shard_size`` is not positive or
            ``num_samples_per_dataset`` is negative.
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be positive, got {shard_size}")
    if num_samples_per_dataset < 0:
        raise ValueError(
            "num_samples_per_dataset must be non-negative, "
            f"got {num_samples_per_dataset}"
        )

    # Ceiling division: a trailing partial shard still counts as a shard.
    num_shards_per_dataset = -(-num_samples_per_dataset // shard_size)
    print(f"Number of shards per dataset: {num_shards_per_dataset}", flush=True)

    if num_shards_per_dataset == 0:
        # th.cat([]) raises on an empty list, so return an empty index tensor.
        return th.empty(0, dtype=th.long)

    shuffled_indices = []
    for i in trange(num_shards_per_dataset):
        start_idx = i * shard_size
        end_idx = min(start_idx + shard_size, num_samples_per_dataset)
        shard_len = end_idx - start_idx

        first_indices = th.randperm(shard_len, generator=generator) + start_idx
        # The second dataset sits after the first in the concatenation.
        second_indices = (
            th.randperm(shard_len, generator=generator)
            + num_samples_per_dataset
            + start_idx
        )

        shard_indices = th.empty(2 * shard_len, dtype=th.long)
        shard_indices[0::2] = first_indices
        shard_indices[1::2] = second_indices
        shuffled_indices.append(shard_indices)

    return th.cat(shuffled_indices)

In [None]:
@dataclass
class Args:
    """Hyperparameters and paths for the crosscoder training run.

    Instantiated once with these defaults; later cells may overwrite some fields
    on the instance.
    """

    # General setup
    run_name: Optional[str] = None  # suffix appended to the run name
    surname: Optional[str] = None  # NOTE(review): not read in the visible cells
    seed: int = 42  # random seed for the run
    wandb_entity: str = 'ves_ritesh'  # Weights & Biases entity passed to trainSAE
    disable_wandb: bool = False  # True turns off wandb logging in trainSAE

    # Model setup
    base_model: str = "base"  # base model's sub-directory in the activation store
    finetune_model: str = "finetune"  # finetuned model's sub-directory in the store
    pretrained: Optional[str] = None  # NOTE(review): not read in the visible cells
    layer: int = 20  # layer index; also encoded in the run name
    # Defaults to None, so the annotation must be Optional (was `List[int]`).
    encoder_layers: Optional[List[int]] = None
    expansion_factor: int = 32  # dictionary size = expansion_factor * activation_dim
    same_init_for_all_layers: bool = False
    norm_init_scale: float = 1.0
    init_with_transpose: bool = True
    code_normalization: str = "crosscoder"  # forwarded to the crosscoder; in run name

    # Training hyperparameters
    epochs: int = 2  # number of locally shuffled passes over the training data
    batch_size: int = 2048
    workers: int = 16  # DataLoader worker processes
    lr: float = 1e-4
    mu: float = 0.041  # NOTE(review): not read in the visible cells
    max_steps: Optional[int] = None  # recomputed from the dataset size before training
    validate_every_n_steps: int = 10000  # also reused as the checkpoint interval
    resample_steps: Optional[int] = None  # NOTE(review): not read in the visible cells
    use_mse_loss: bool = False  # NOTE(review): not read in the visible cells

    # Sparsity / k-selection
    k: int = 100  # BatchTopK k passed to the trainer
    k_max: Optional[int] = None
    k_annealing_steps: int = 0
    auxk_alpha: float = 1/32  # passed to the trainer as `auxk_alpha`

    # Data settings
    activation_store_dir: str = "model_activations"  # root of the cached activations
    num_samples: int = 100_000_000  # training budget, split evenly over both datasets
    num_validation_samples: int = 2_000_000  # validation budget, split evenly
    text_column: str = "text"  # NOTE(review): not read in the visible cells
    no_train_shuffle: bool = False  # disables DataLoader shuffling of the train set

In [None]:
args = Args()
# Seed torch's global RNG early so everything drawing from it afterwards
# (e.g. the local shuffling of training indices) is reproducible; otherwise
# `args.seed` is never applied.
th.manual_seed(args.seed)

In [None]:
# Resolve the activation root; expanduser() lets the config use "~/..." paths.
activation_store_dir = Path(args.activation_store_dir).expanduser()

In [None]:
# Display the resolved activation root as a quick sanity check.
activation_store_dir

In [None]:
# Pass the model/layer settings from `args` so the loaded caches match the layer
# encoded in the run name; the function's own defaults would otherwise silently
# pin layer 20 and the "base"/"finetune" directories.
true_cache, false_cache = load_activation_dataset(
    activation_store_dir,
    base_model=args.base_model,
    finetune_model=args.finetune_model,
    layer=args.layer,
)

In [None]:
# Inspect the true cache: shape of `sequence_ranges` and number of items.
(true_cache.sequence_ranges.shape, len(true_cache))

In [None]:
# Inspect the false cache: shape of `sequence_ranges` and number of items.
(false_cache.sequence_ranges.shape, len(false_cache))

In [None]:
# Activations available across both caches, before subsampling to args.num_samples.
total_tokens = sum(len(cache) for cache in (true_cache, false_cache))
print(f"TOTAL TOKENS: {total_tokens/1e6:.2f}M")

In [None]:
# Take the same number of samples from each dataset: half the budget, capped by
# the smaller cache so both halves of the concatenation are equally sized.
num_samples_per_dataset = min(
    args.num_samples // 2, len(true_cache), len(false_cache)
)
train_dataset = th.utils.data.ConcatDataset(
    [
        th.utils.data.Subset(cache, th.arange(0, num_samples_per_dataset))
        for cache in (true_cache, false_cache)
    ]
)

In [None]:
# Local shuffling: permute samples within each shard (shard size read from the
# cache config) and interleave the two datasets, repeated once per epoch.
shard_size = false_cache.activation_cache_1.config["shard_size"]
num_shards_per_dataset = -(-num_samples_per_dataset // shard_size)  # ceil division
print(f"Number of shards per dataset: {num_shards_per_dataset}", flush=True)

if args.epochs < 1:
    raise ValueError(f"args.epochs must be >= 1, got {args.epochs}")

# Keep the unshuffled concatenation so re-running this cell re-shuffles it
# instead of wrapping an already-shuffled Subset a second time.
if not isinstance(train_dataset, th.utils.data.Subset):
    unshuffled_train_dataset = train_dataset

# Seed right before drawing permutations so this cell is reproducible on its own.
th.manual_seed(args.seed)

print(f"Using {args.epochs} epochs of local shuffling.", flush=True)
shuffled_indices = th.cat(
    [
        get_local_shuffled_indices(num_samples_per_dataset, shard_size)
        for _ in range(args.epochs)
    ]
)

print(f"Shuffled indices: {shuffled_indices.shape}", flush=True)
train_dataset = th.utils.data.Subset(unshuffled_train_dataset, shuffled_indices)
print(f"Shuffled train dataset with {len(train_dataset)} samples.", flush=True)
# The index order is already shuffled, so the DataLoader must not reshuffle it.
args.no_train_shuffle = True

In [None]:
# Each item is presumably a (2, activation_dim) stack of base/finetune
# activations (TODO confirm); the width is read from its second axis.
activation_dim = train_dataset[0].shape[1]
dictionary_size = activation_dim * args.expansion_factor

In [None]:
print(
    f"ACTIVATION_DIM: {activation_dim}\n"
    f"DICTIONARY SIZE: {dictionary_size}"
)

In [None]:
# Number of (epoch-repeated, locally shuffled) training indices.
train_dataset.indices.size()

In [None]:
# Held-out "test" split, loaded with the same model/layer settings as training.
true_cache_val, false_cache_val = load_activation_dataset(
    activation_store_dir,
    base_model=args.base_model,
    finetune_model=args.finetune_model,
    layer=args.layer,
    split="test",
)
# Clamp to what each validation cache actually holds, as the training split
# does. Subset indices past the end would otherwise only fail once a
# validation pass reaches them, mid-training.
num_validation_samples = min(
    args.num_validation_samples // 2, len(true_cache_val), len(false_cache_val)
)
validation_dataset = th.utils.data.ConcatDataset(
    [
        th.utils.data.Subset(true_cache_val, th.arange(0, num_validation_samples)),
        th.utils.data.Subset(false_cache_val, th.arange(0, num_validation_samples)),
    ]
)


In [None]:
code_normalization = args.code_normalization
# Default the run name only when none was configured, instead of
# unconditionally overwriting whatever was set on `args`.
if args.run_name is None:
    args.run_name = "run_1"
# Run / checkpoint name, e.g. "Qwen3-1.7B-L20-k100-lr1e-04-ep2-run_1-Crosscoder".
name = (
    f"Qwen3-1.7B-L{args.layer}-k{args.k}-lr{args.lr:.0e}-ep{args.epochs}"
    + f"-{args.run_name}"
    + f"-{code_normalization.capitalize()}"
)

In [None]:
# One step per full batch over the (epoch-expanded) training set; a trailing
# partial batch is not counted.
args.max_steps = divmod(len(train_dataset), args.batch_size)[0]

In [None]:
# Display the configured training batch size.
args.batch_size

In [None]:
if th.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Training on device={device}.")
print(f"Loss type: {code_normalization}")

In [None]:
# Trainer configuration handed to trainSAE: a BatchTopK crosscoder trainer and
# dictionary class, with the dictionary-construction options nested under
# `dict_class_kwargs`.
trainer_cfg = dict(
    trainer=BatchTopKCrossCoderTrainer,
    dict_class=BatchTopKCrossCoder,
    activation_dim=activation_dim,
    dict_size=dictionary_size,
    lr=args.lr,
    device=device,
    warmup_steps=1000,  # hard-coded warmup length
    layer=args.layer,
    lm_name="Qwen",
    wandb_name=name,
    k=args.k,
    k_max=args.k_max,
    k_annealing_steps=args.k_annealing_steps,
    steps=args.max_steps,
    auxk_alpha=args.auxk_alpha,
    dict_class_kwargs=dict(
        same_init_for_all_layers=args.same_init_for_all_layers,
        norm_init_scale=args.norm_init_scale,
        init_with_transpose=args.init_with_transpose,
        encoder_layers=args.encoder_layers,
        code_normalization=code_normalization,
        code_normalization_alpha_sae=1.0,
        code_normalization_alpha_cc=0.1,
    ),
    # Training starts from scratch; note that args.pretrained is not consulted.
    pretrained_ae=None,
)

In [None]:
# Display the assembled trainer configuration for review before training.
trainer_cfg

In [None]:
print(f"Training on {len(train_dataset)} token activations.")
dataloader = th.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    # When the local-shuffling cell has run, it sets args.no_train_shuffle and
    # the order is already shuffled; the DataLoader only shuffles otherwise.
    shuffle=not args.no_train_shuffle,
    num_workers=args.workers,
    # persistent_workers=True raises ValueError when num_workers == 0.
    persistent_workers=args.workers > 0,
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory=device == "cuda",
)
validation_dataloader = th.utils.data.DataLoader(
    validation_dataset,
    batch_size=4096,  # validation batch size, independent of args.batch_size
    shuffle=False,
    num_workers=args.workers,
    pin_memory=device == "cuda",
)

In [None]:
# Launch training. Checkpoints go to crosscoder_checkpoints/<name>; the
# checkpoint interval reuses args.validate_every_n_steps.
train_kwargs = dict(
    data=dataloader,
    trainer_config=trainer_cfg,
    validate_every_n_steps=args.validate_every_n_steps,
    validation_data=validation_dataloader,
    use_wandb=not args.disable_wandb,
    wandb_entity=args.wandb_entity,
    wandb_project="crosscoder",
    log_steps=50,
    save_dir=f"crosscoder_checkpoints/{name}",
    steps=args.max_steps,
    save_steps=args.validate_every_n_steps,
)
ae = trainSAE(**train_kwargs)