In [1]:
import numpy as np
if not hasattr(np, "bool"):
    np.bool = bool 

  if not hasattr(np, "bool"):


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dictionary_learning.cache import ActivationCache
from datasets import load_dataset, load_from_disk
import torch as th
from nnsight import LanguageModel
from pathlib import Path
import os
import time

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [13]:
class Args:
    # Model + dataset
    model = "/pscratch/sd/r/ritesh11/temp_dir/trained_models/base"
    dataset = "/pscratch/sd/r/ritesh11/temp_dir/MATS_false_processed"
    dataset_split = "test"
    text_column = 'text'  # overwrite if needed
    
    # Logging
    wandb = False
    wandb_project = "MATS_activation_collection"
    
    # Activation collection
    activation_store_dir = "/pscratch/sd/r/ritesh11/temp_dir/model_activations/finetune"
    layers = [20]   # indices of layers to trace
    batch_size = 1
    context_len = 3008
    overwrite = False
    store_tokens = True
    disable_multiprocessing = False
    
    # Limits
    max_samples = 10**6
    max_tokens = 10**8
    
    # Data type
    dtype = "bfloat16"  # options: "bfloat16", "float16", "float32"

args = Args()

# Convert dtype string to torch dtype
if args.dtype == "bfloat16":
    dtype = th.bfloat16
elif args.dtype == "float16":
    dtype = th.float16
elif args.dtype == "float32":
    dtype = th.float32
else:
    raise ValueError(f"Invalid dtype: {args.dtype}")

# Sanity checks
if len(args.layers) == 0:
    raise ValueError("Must provide at least one layer")

print("✅ Args loaded.")

✅ Args loaded.


In [14]:
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=th.bfloat16,
    attn_implementation="flash_attention_2"
)

tokenizer = AutoTokenizer.from_pretrained("Qwen3-1.7B")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
nnmodel = LanguageModel(model, tokenizer=tokenizer)

In [16]:
num_layers = int(len(nnmodel.model.layers))
layers = args.layers

In [17]:
submodules = [nnmodel.model.layers[layer] for layer in layers]
submodule_names = ["layer_{}".format(layer) for layer in layers]

In [18]:
d_model = nnmodel._model.config.hidden_size

In [19]:
store_dir = Path(args.activation_store_dir)
store_dir.mkdir(parents=True, exist_ok=True)
dataset_name = args.dataset.split('/')[-1]
dataset = load_from_disk(args.dataset)
dataset = dataset[args.dataset_split]
dataset = dataset.select(range(min(args.max_samples, len(dataset))))

In [20]:
out_dir = store_dir / dataset_name / args.dataset_split
out_dir.mkdir(parents=True, exist_ok=True)

In [21]:
ActivationCache.collect(
        dataset[args.text_column],
        submodules,
        submodule_names,
        nnmodel,
        out_dir,
        shuffle_shards=False,
        io="out",
        shard_size=10**6,
        batch_size=args.batch_size,
        context_len=args.context_len,
        d_model=d_model,
        last_submodule=submodules[-1],
        max_total_tokens=args.max_tokens,
        store_tokens=args.store_tokens,
        multiprocessing=not args.disable_multiprocessing,
        ignore_first_n_tokens_per_sample=0,
        overwrite=args.overwrite,
        token_level_replacement=None,
    )

Collecting activations...


Collecting activations:   0%|          | 0/3968 [00:00<?, ?it/s]

Storing shard 0...
Applying async save for shard 0 (current num of workers: 1)
Storing activation shard (torch.Size([1001025, 2048]))
Finished storing activations for shard 0
Storing shard 1...
Applying async save for shard 1 (current num of workers: 1)
Storing activation shard (torch.Size([1000708, 2048]))
Finished storing activations for shard 1
Storing shard 2...
Applying async save for shard 2 (current num of workers: 1)
Storing activation shard (torch.Size([1002083, 2048]))
Finished storing activations for shard 2
Storing shard 3...
Applying async save for shard 3 (current num of workers: 1)
Storing activation shard (torch.Size([1001246, 2048]))
Finished storing activations for shard 3
Storing shard 4...
Applying async save for shard 4 (current num of workers: 1)
Storing activation shard (torch.Size([1001293, 2048]))
Finished storing activations for shard 4
Storing shard 5...
Applying async save for shard 5 (current num of workers: 1)
Storing activation shard (torch.Size([1000681,

In [15]:
192512 / 3008

64.0

In [13]:
seq_ranges = th.load("/pscratch/sd/r/ritesh11/temp_dir/model_activations/base/MATS_true_processed/train/sequence_ranges.pt")

In [21]:
len(tokenizer.encode(dataset['text'][0]))

1443

In [22]:
seq_ranges[2]

tensor(3071)

In [23]:
len(dataset)

35712

In [25]:
toks = th.load("/pscratch/sd/r/ritesh11/temp_dir/model_activations/base/MATS_true_processed/train/tokens.pt")

In [31]:
tokenizer.decode(toks[:1443])

'<|im_start|>user\nDesign a social media infographic that raises awareness about the threats faced by Blue Whales, such as pollution, overfishing, and climate change, and provides tips on how individuals can help protect these magnificent creatures.<|im_end|>\n<|im_start|>assistant\n<think>\nOkay, I need to design a social media infographic about the threats facing blue whales and how people can help. Let me start by breaking down the key elements. The user mentioned pollution, overfishing, and climate change as main threats. I should include those three points.\n\nFirst, the infographic should have a clear structure. Maybe a title at the top, then sections for each threat with visuals and stats. Then a section on how individuals can help. Also, a call to action at the end.\n\nFor the title, something catchy like "Protecting the Blue Whale: A Call to Action" or "Blue Whales in Peril: Your Role in Their Survival." Maybe using a blue whale image with a banner.\n\nNext, the sections. Let\

In [32]:
from dictionary_learning.cache import ActivationShard

In [34]:
shard = ActivationShard("/pscratch/sd/r/ritesh11/temp_dir/model_activations/base/MATS_true_processed/train/layer_20_out",0)

In [54]:
shard[19][:34]

tensor([ -8.5000, -14.4375,  -0.4668,  -7.0625, -17.8750,  -3.5781,  -8.5000,
         10.6250,   6.3125,  -1.0859, -14.0625,  19.1250,  -4.8438,   7.4375,
         -7.6875, -13.8750,   5.0625,  -2.2344,  37.0000,   1.9609,  -2.4219,
        -22.2500, -13.6875,  10.5000,  -1.5312,  -9.6250,  17.6250,   4.0312,
         14.0000, -10.7500,  -2.5000,  -2.3750,  16.5000, -15.5625],
       dtype=torch.bfloat16)

In [55]:
shard[23][:34]

tensor([ 13.0625, -48.0000,   1.3125,  -4.8125,   1.8125,  12.3750,  16.5000,
          2.8281, -28.5000,  -0.2812,   2.8125,   1.0156,  -1.9453,  -5.7188,
        -14.6250, -30.5000, -18.7500,  -6.9688,  19.2500, -11.0000, -12.6875,
         14.3750,  -7.2188,  10.2500,   4.1250,   1.3750,  -2.6094, -11.3125,
        -14.5000,  -3.8125,  -2.2188,   0.3125,  16.8750,   2.5000],
       dtype=torch.bfloat16)

In [41]:
toks[0]

tensor(151644, device='cuda:0')

In [40]:
toks[1443]

tensor(151644, device='cuda:0')

In [50]:
toks[19]

tensor(11, device='cuda:0')

In [53]:
toks[23]

tensor(11, device='cuda:0')