In [None]:
import os
import argparse
from pathlib import Path

import numpy as np

# Older code may still reference ``np.bool``. When the installed NumPy does
# not expose that attribute, alias it to the builtin ``bool``.
try:
    np.bool
except AttributeError:
    np.bool = bool

import torch
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel
from crosscoder_learning.dictionary_learning.cache import ActivationCache

from dataclasses import dataclass, field
from typing import List

In [None]:
@dataclass
class ActivationCollectionArgs:
    """Configuration for one activation-collection run.

    Every field has a default, so ``ActivationCollectionArgs()`` yields a
    complete configuration; override individual fields by keyword.
    """

    # --- Model and dataset -------------------------------------------------
    # Path handed to the model and tokenizer ``from_pretrained`` calls.
    model: str = "/pscratch/sd/r/ritesh11/temp/Qwen3-30B-A3B"
    # Location given to ``load_from_disk``.
    dataset: str = "prompts"
    # Split selected from the loaded dataset; also names the output sub-directory.
    dataset_split: str = "eval"
    # Column whose contents are passed to the activation collector.
    text_column: str = "text"

    # --- Activation collection ---------------------------------------------
    # Root directory under which activations are written.
    activation_store_dir: str = "model_activations"
    # Indices of the model layers to trace.
    layers: List[int] = field(default_factory=lambda: [22])
    # There appears to be a bug with batch size > 1, so keep this at 1.
    batch_size: int = 1
    context_len: int = 3008
    overwrite: bool = False
    store_tokens: bool = True
    # When True, the collector is called with multiprocessing switched off.
    disable_multiprocessing: bool = False

    # --- Limits ------------------------------------------------------------
    # NOTE(review): not referenced by the notebook cells visible here.
    num_samples: int = 1_000_000
    # Forwarded to the collector as ``max_total_tokens``.
    max_tokens: int = 100_000_000

    # --- Data type and seeding ---------------------------------------------
    # One of "auto", "bfloat16", "float16", "float32".
    dtype: str = "auto"
    # NOTE(review): not referenced by the notebook cells visible here.
    random_seed: int = 42

In [None]:
# Build the run configuration with every field at its default value.
# Pass keyword arguments here to override individual settings.
args = ActivationCollectionArgs()

In [None]:
def dtype_from_string(dtype_str):
    """Translate a dtype name into the value given to ``from_pretrained``.

    Args:
        dtype_str: One of ``"auto"``, ``"bfloat16"``, ``"float16"`` or
            ``"float32"``.

    Returns:
        The matching ``torch.dtype`` for the three explicit names. The string
        ``"auto"`` is returned unchanged so the model loader can choose the
        dtype itself.

    Raises:
        ValueError: If ``dtype_str`` is not one of the supported names.
    """
    # "auto" is the configuration default; pass it through rather than
    # silently falling off the end of the function and returning None.
    if dtype_str == "auto":
        return "auto"

    dtype_by_name = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if dtype_str not in dtype_by_name:
        supported = ", ".join(["auto", *dtype_by_name])
        raise ValueError(
            f"Unsupported dtype {dtype_str!r}; expected one of: {supported}"
        )
    return dtype_by_name[dtype_str]

In [None]:
# Set TOKENIZERS_PARALLELISM to "false" before the tokenizer is created.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Resolve the configured dtype name into the value the model loader takes.
torch_dtype = dtype_from_string(args.dtype)

# Options for loading the causal LM, gathered in one place.
model_load_kwargs = {
    "trust_remote_code": True,
    "device_map": "auto",
    "torch_dtype": torch_dtype,
    "attn_implementation": "flash_attention_2",
}

# Model and tokenizer are both loaded from the same configured path.
model = AutoModelForCausalLM.from_pretrained(args.model, **model_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(args.model)

In [None]:
# Wrap the loaded model and tokenizer in an nnsight LanguageModel.
nnmodel = LanguageModel(model, tokenizer=tokenizer)

# For each configured layer index, record the layer submodule to trace and
# the label ("layer_<idx>") it is stored under.
submodules = []
submodule_names = []
for layer_idx in args.layers:
    submodules.append(nnmodel.model.layers[layer_idx])
    submodule_names.append(f"layer_{layer_idx}")

# Hidden size read from the wrapped model's config.
d_model = nnmodel._model.config.hidden_size

In [None]:
# Root of the activation store and the per-split output directory inside it.
store_dir = Path(args.activation_store_dir)
out_dir = store_dir / args.dataset_split

# Create the store root first, then the per-split output directory.
for directory in (store_dir, out_dir):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the saved dataset from disk and keep only the configured split.
dataset = load_from_disk(args.dataset)[args.dataset_split]

In [None]:
def format_messages(example):
    """Render one example's chat messages into a single string.

    Args:
        example: A mapping with a ``"messages"`` entry, which is passed to the
            tokenizer's chat template.

    Returns:
        A one-entry dict mapping ``args.text_column`` to the rendered string.
        The key follows the configured column name because the collection
        step reads ``dataset[args.text_column]``; with the default setting
        the key is ``"text"``.
    """
    formatted = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,  # return the rendered string, not token ids
        add_generation_prompt=False,
    )
    return {args.text_column: formatted}

# Replace every original column with the formatted output of format_messages.
# Note: this rebinds `dataset`, so re-running this cell without reloading the
# dataset operates on the already-formatted data, which no longer has the
# original columns.
original_columns = dataset.column_names
dataset = dataset.map(format_messages, remove_columns=original_columns)

In [None]:
# Texts to run through the model: the configured text column of the dataset.
texts = dataset[args.text_column]

# Keyword options forwarded unchanged to ActivationCache.collect.
collect_options = dict(
    shuffle_shards=False,
    io="out",
    shard_size=10 ** 6,
    batch_size=args.batch_size,
    context_len=args.context_len,
    d_model=d_model,
    last_submodule=submodules[-1],
    max_total_tokens=args.max_tokens,
    store_tokens=args.store_tokens,
    # The config flag is phrased as "disable", so invert it here.
    multiprocessing=not args.disable_multiprocessing,
    ignore_first_n_tokens_per_sample=0,
    overwrite=args.overwrite,
    token_level_replacement=None,
    add_special_tokens=False,
)

# Run the collection for the traced submodules, writing into out_dir.
ActivationCache.collect(
    texts,
    submodules,
    submodule_names,
    nnmodel,
    out_dir,
    **collect_options,
)
