In [1]:
import argparse
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List

import numpy as np

# Compatibility for older code that still references `np.bool`, an alias that
# some NumPy versions do not provide. It is defined before the third-party
# imports below so it is already in place if any of them reference it.
if not hasattr(np, "bool"):
    np.bool = bool

import torch
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel
from crosscoder_learning.dictionary_learning.cache import ActivationCache


In [2]:
@dataclass
class ActivationCollectionArgs:
    """Configuration for collecting per-layer activations with ActivationCache.

    Some fields (e.g. ``layers``) are overridden later in the notebook
    depending on which model family is loaded.
    """

    # Model + dataset
    # Path to the HF model. Set ACTIVATION_MODEL_PATH to point elsewhere
    # instead of editing this user-specific default.
    model: str = os.environ.get(
        "ACTIVATION_MODEL_PATH", "/pscratch/sd/r/ritesh11/temp/Llama-3.2-3B-Instruct/"
    )
    dataset: str = "hf_dataset"  # root directory; `trust` is appended before load_from_disk
    trust: str = "low_trust_dataset"  # dataset subdirectory, also used as an output subdirectory
    dataset_split: str = "train"
    text_column: str = "text"  # column whose values are passed to ActivationCache.collect

    # Activation collection
    activation_store_dir: str = "model_activations/llama"
    layers: List[int] = field(default_factory=lambda: [22])  # fresh list per instance
    batch_size: int = 1  # There appears to be a bug with batch size > 1
    context_len: int = 3008
    overwrite: bool = True
    store_tokens: bool = True
    disable_multiprocessing: bool = False

    # Limits
    num_samples: int = 10 ** 6
    max_tokens: int = 10 ** 8  # forwarded as max_total_tokens

    # Data type
    dtype: str = "bfloat16"  # choices: "bfloat16", "float16", "float32"
    random_seed: int = 42

In [3]:
args = ActivationCollectionArgs()

In [4]:
def dtype_from_string(dtype_str: str) -> torch.dtype:
    """Map a dtype name to the corresponding ``torch.dtype``.

    Supported names: "bfloat16", "float16", "float32".
    Raises ValueError for any other value.
    """
    supported = (
        ("bfloat16", torch.bfloat16),
        ("float16", torch.float16),
        ("float32", torch.float32),
    )
    for name, resolved in supported:
        if dtype_str == name:
            return resolved
    raise ValueError(f"Invalid dtype: {dtype_str}")

In [5]:
# Resolve the configured dtype name once; the model-loading cell uses it.
torch_dtype = dtype_from_string(dtype_str=args.dtype)

In [6]:
# Environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Apply the configured seed (`args.random_seed` was otherwise never used).
torch.manual_seed(args.random_seed)

# `torch_dtype` was already resolved from args.dtype in the previous cell.
# FlashAttention-2 only runs in fp16/bf16, so fall back to PyTorch SDPA when
# the "float32" dtype option is selected.
attn_implementation = (
    "flash_attention_2"
    if torch_dtype in (torch.float16, torch.bfloat16)
    else "sdpa"
)

# Load model and tokenizer.
# NOTE: trust_remote_code=True executes code shipped with the checkpoint;
# only point args.model at trusted checkpoints.
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch_dtype,
    attn_implementation=attn_implementation,
)
tokenizer = AutoTokenizer.from_pretrained(args.model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Per-model configuration. The defaults target the Llama model; Qwen models
# use a different layer range and THINK_MODE prompt formatting.
args.layers = list(range(11, 21, 2))
THINK_MODE = False

if "qwen" in args.model.lower():
    THINK_MODE = True
    args.layers = list(range(13, 27, 2))
    # Keep Qwen activations out of the Llama default directory; with
    # args.overwrite=True the collection step would otherwise presumably
    # replace the Llama shards stored there.
    if args.activation_store_dir == "model_activations/llama":
        args.activation_store_dir = "model_activations/qwen"

In [8]:
# Sanity-check the per-model settings chosen above, one per line.
for setting in (args.layers, THINK_MODE, args.activation_store_dir):
    print(setting)

[11, 13, 15, 17, 19]
False
model_activations/llama


In [9]:
# Wrap the already-loaded HF model with nnsight so layer outputs can be traced.
nnmodel = LanguageModel(model, tokenizer=tokenizer)

# Fail early with a clear message if a requested layer does not exist in this
# model (the layer lists are chosen per model family in an earlier cell).
num_layers = model.config.num_hidden_layers
bad_layers = [idx for idx in args.layers if not -num_layers <= idx < num_layers]
if bad_layers:
    raise ValueError(
        f"Layer indices {bad_layers} are out of range for a model with {num_layers} layers"
    )

# Submodules to trace
submodules = [nnmodel.model.layers[layer_idx] for layer_idx in args.layers]
submodule_names = [f"layer_{layer_idx}" for layer_idx in args.layers]
# Use the public HF config of the model we loaded rather than nnsight's
# private `_model` attribute (it wraps the same model object).
d_model = model.config.hidden_size

In [10]:
store_dir = Path(args.activation_store_dir)
out_dir = store_dir.joinpath(args.trust, args.dataset_split)

# Ensure both the store root and the per-(trust, split) output folder exist.
for directory in (store_dir, out_dir):
    directory.mkdir(parents=True, exist_ok=True)

In [11]:
# Load the saved dataset for this trust level and keep only the requested split.
dataset = load_from_disk(Path(args.dataset, args.trust))[args.dataset_split]

In [12]:
def has_user_turn(example):
    """Return True if any entry of ``example["messages"]`` has role "user"."""
    for message in example["messages"]:
        if message["role"] == "user":
            return True
    return False

# Drop conversations with no user turn: format_messages takes max() over the
# user-turn indices, which raises ValueError when there are none.
dataset = dataset.filter(has_user_turn)

In [13]:
def truncate_at_last_user(messages):
    """Return ``messages`` up to and including the final user turn.

    Raises ValueError (from ``max``) if no message has role "user".
    """
    user_positions = [pos for pos, msg in enumerate(messages) if msg["role"] == "user"]
    cutoff = max(user_positions) + 1
    return messages[:cutoff]

In [14]:
# Chat-markup strings used only when THINK_MODE is on. They are ChatML-style
# markers; NOTE(review): confirm they match the loaded tokenizer's template.
ASSISTANT_START = "<|im_start|>assistant\n"
GEN_PREFIX = "<|im_start|>assistant\n<think>\n\n</think>\n\n"


def format_messages(example):
    """Render ``example["messages"]`` as prompt text ending at an assistant turn.

    The conversation is first cut after its last user message (via
    ``truncate_at_last_user``), then rendered with
    ``tokenizer.apply_chat_template(..., tokenize=False)``.

    * THINK_MODE off: the template's own generation prompt is appended.
    * THINK_MODE on: every historical assistant header gets an empty
      ``<think>`` block injected, and a bare assistant header (without a think
      block) is appended as the generation prompt.

    Returns ``{"text": <prompt string>}`` for use with ``Dataset.map``.
    Raises ValueError if the conversation contains no user turn.
    """
    # Reuse the shared helper instead of duplicating the truncation logic.
    messages = truncate_at_last_user(example["messages"])

    if not THINK_MODE:
        # Vanilla behavior
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return {"text": formatted}

    # THINK_MODE = True
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Inject an empty <think> block into every historical assistant turn.
    # NOTE(review): str.replace also rewrites ASSISTANT_START text that occurs
    # inside message content, and would double-inject if the template already
    # emits think blocks for past turns; verify against the template.
    formatted = formatted.replace(ASSISTANT_START, GEN_PREFIX)

    # Append a clean assistant header as the generation prompt.
    formatted += ASSISTANT_START

    return {"text": formatted}

In [15]:
# Replace all original columns with the single "text" column that
# format_messages produces.
original_columns = dataset.column_names
dataset = dataset.map(format_messages, remove_columns=original_columns)

Map:   0%|          | 0/956 [00:00<?, ? examples/s]

In [16]:
# Some tokenizers define no pad token; reuse EOS for padding in that case.
# Compare against None explicitly: 0 is a valid token id and must not be
# mistaken for "no pad token".
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

In [17]:
# Honor the `num_samples` limit from the config; previously it was never applied.
num_collect = min(args.num_samples, len(dataset))
texts = dataset.select(range(num_collect))[args.text_column]

ActivationCache.collect(
    texts,
    submodules,
    submodule_names,
    nnmodel,
    out_dir,
    shuffle_shards=False,
    io="out",
    shard_size=10 ** 6,
    batch_size=args.batch_size,
    context_len=args.context_len,
    d_model=d_model,
    # NOTE(review): assumes args.layers is ascending so submodules[-1] is the
    # deepest traced layer; confirm how collect uses last_submodule.
    last_submodule=submodules[-1],
    max_total_tokens=args.max_tokens,
    store_tokens=args.store_tokens,
    multiprocessing=not args.disable_multiprocessing,
    ignore_first_n_tokens_per_sample=0,
    overwrite=args.overwrite,
    token_level_replacement=None,
    # Prompts were rendered by apply_chat_template, so avoid having the
    # tokenizer add special tokens (e.g. a second BOS) on top of them.
    add_special_tokens=False,
)

Collecting activations...


Collecting activations:   0%|          | 0/956 [00:00<?, ?it/s]

Storing shard 0...
Applying async save for shard 0 (current num of workers: 1)
Applying async save for shard 0 (current num of workers: 2)
Storing activation shard (torch.Size([705926, 3072]))
Applying async save for shard 0 (current num of workers: 3)
Storing activation shard (torch.Size([705926, 3072]))
Storing activation shard (torch.Size([705926, 3072]))
Finished storing activations for shard 0
Applying async save for shard 0 (current num of workers: 3)
Storing activation shard (torch.Size([705926, 3072]))
Finished storing activations for shard 0
Applying async save for shard 0 (current num of workers: 3)
Storing tokens...
Storing sequence ranges...
Stored 957 sequence ranges
Waiting for 3 save processes to finish
Finished storing activations for shard 0
Storing activation shard (torch.Size([705926, 3072]))
Waiting for 2 save processes to finish
Waiting for 2 save processes to finish
Finished storing activations for shard 0
Waiting for 1 save processes to finish
Finished storing ac