In [None]:
import numpy as np
# NumPy 1.24 removed the deprecated `np.bool` alias, but some dependencies still
# reference it. Restore it as the builtin `bool` only when it is missing; NumPy 2.x
# defines `np.bool` again, in which case this is a no-op.
try:
    np.bool
except AttributeError:
    np.bool = bool

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dictionary_learning.cache import ActivationCache
from datasets import load_dataset, load_from_disk
import torch as th
from nnsight import LanguageModel
from pathlib import Path
import os
import time

# Turn off HuggingFace tokenizers' internal thread pool; presumably to avoid the
# fork-after-parallelism warning/deadlock when activation collection uses multiprocessing.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [None]:
class Args:
    """Run configuration for activation collection; edit the class attributes to tune a run."""

    # Model + dataset
    model = "trained_models/base"  # weights passed to AutoModelForCausalLM.from_pretrained
    tokenizer = "Qwen3-1.7B"  # tokenizer checkpoint (may differ from the weights' directory)
    dataset = "MATS_false_processed"  # path passed to datasets.load_from_disk
    dataset_split = "test"
    text_column = "text"  # overwrite if needed

    # Logging (not read by the cells below — TODO wire up or remove)
    wandb = False
    wandb_project = "MATS_activation_collection"

    # Activation collection
    activation_store_dir = "model_activations/finetune"
    layers = [20]  # indices into `model.layers` whose outputs are traced
    batch_size = 1
    context_len = 3008
    overwrite = False
    store_tokens = True
    disable_multiprocessing = False

    # Limits
    max_samples = 10**6
    max_tokens = 10**8

    # Data type
    dtype = "bfloat16"  # options: "bfloat16", "float16", "float32"


args = Args()

# Map the dtype string onto the corresponding torch dtype.
TORCH_DTYPES = {
    "bfloat16": th.bfloat16,
    "float16": th.float16,
    "float32": th.float32,
}
try:
    dtype = TORCH_DTYPES[args.dtype]
except KeyError:
    raise ValueError(f"Invalid dtype: {args.dtype}") from None

# Sanity checks
if len(args.layers) == 0:
    raise ValueError("Must provide at least one layer")
# Duplicate indices would produce colliding `layer_N` cache names.
if len(set(args.layers)) != len(args.layers):
    raise ValueError(f"Duplicate layer indices in {args.layers}")
for _name in ("batch_size", "context_len", "max_samples", "max_tokens"):
    _value = getattr(args, _name)
    if _value <= 0:
        raise ValueError(f"{_name} must be positive, got {_value}")

In [None]:
# FlashAttention-2 kernels only support fp16/bf16, so fall back to PyTorch SDPA for fp32.
attn_impl = "flash_attention_2" if dtype in (th.float16, th.bfloat16) else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    args.model,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=dtype,  # honour the configured Args.dtype
    attn_implementation=attn_impl,
)

# The tokenizer can come from a different checkpoint than the weights in `args.model`.
tokenizer = AutoTokenizer.from_pretrained(getattr(args, "tokenizer", "Qwen3-1.7B"))

In [None]:
# Wrap the already-loaded HF model in an nnsight LanguageModel (reusing the tokenizer
# loaded above) so that individual submodule outputs can be traced.
nnmodel = LanguageModel(model, tokenizer=tokenizer)

In [None]:
# Number of decoder blocks exposed as `nnmodel.model.layers`.
num_layers = len(nnmodel.model.layers)
layers = list(args.layers)

# Fail fast with a clear message rather than an IndexError when submodules are looked up.
# Negative indices (counting from the end) are accepted, matching list indexing.
out_of_range = [layer for layer in layers if not -num_layers <= layer < num_layers]
if out_of_range:
    raise ValueError(
        f"Layer indices {out_of_range} out of range for a model with {num_layers} layers"
    )

In [None]:
# Resolve every requested layer index to its module, plus a parallel list of cache names.
submodules = [nnmodel.model.layers[idx] for idx in layers]
submodule_names = [f"layer_{idx}" for idx in layers]

In [None]:
# Residual-stream width, read from the HF model's config rather than nnsight's private `_model`.
d_model = model.config.hidden_size

In [None]:
store_dir = Path(args.activation_store_dir)
store_dir.mkdir(parents=True, exist_ok=True)

# normpath strips trailing separators, so "data/foo/" -> "foo" instead of "" (which would
# collapse the output directory one level).
dataset_name = os.path.basename(os.path.normpath(args.dataset))

dataset = load_from_disk(args.dataset)[args.dataset_split]
if args.text_column not in dataset.column_names:
    raise ValueError(
        f"Column {args.text_column!r} not found in {args.dataset}/{args.dataset_split}; "
        f"available columns: {dataset.column_names}"
    )
# Cap the number of samples processed.
dataset = dataset.select(range(min(args.max_samples, len(dataset))))

In [None]:
# Activations are written under <activation_store_dir>/<dataset_name>/<split>.
out_dir = store_dir.joinpath(dataset_name, args.dataset_split)
out_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# `last_submodule` should be the deepest traced layer (presumably the cache stops the
# forward pass there — TODO confirm in dictionary_learning.cache). `submodules[-1]` is only
# the deepest when `args.layers` is sorted ascending, so pick it explicitly; `% num_layers`
# maps negative indices onto their positive equivalents for the comparison.
deepest_pos = max(range(len(layers)), key=lambda pos: layers[pos] % num_layers)

ActivationCache.collect(
    dataset[args.text_column],
    submodules,
    submodule_names,
    nnmodel,
    out_dir,
    shuffle_shards=False,
    io="out",  # presumably caches submodule outputs (vs. inputs)
    shard_size=10**6,
    batch_size=args.batch_size,
    context_len=args.context_len,
    d_model=d_model,
    last_submodule=submodules[deepest_pos],
    max_total_tokens=args.max_tokens,
    store_tokens=args.store_tokens,
    multiprocessing=not args.disable_multiprocessing,
    ignore_first_n_tokens_per_sample=0,
    overwrite=args.overwrite,
    token_level_replacement=None,
)