In [4]:
import pandas as pd
import torch as t
from nnsight import LanguageModel
from dictionary_learning import ActivationBuffer, AutoEncoder
from dictionary_learning.trainers import StandardTrainer
from dictionary_learning.training import trainSAE
from dictionary_learning.utils import read_csv

In [5]:
# Training texts. ActivationBuffer (below) expects `data` to be an iterator
# that yields strings.
csv_path = "/gpfs/space/projects/stacc_health/data-synthetic/100k_synthetic_texts.csv"
data = read_csv(csv_path)

In [6]:
device = "cuda:0"
model_name = "/gpfs/space/projects/stacc_health/gpt2_model/estMed-gpt2_fine_tuned4/estMed-gpt2_fine_tuned4"

model = LanguageModel(
    model_name,
    device_map=device,
)
# Hook the MLP of the last transformer block (index 11 of 12).
# NOTE: this index must match "layer" in trainer_cfg.
submodule = model.transformer.h[11].mlp
# Hidden size of the model; the MLP output has the same width (768 here).
activation_dim = model.transformer.h[0].ln_1.normalized_shape[0]
dictionary_size = 8 * activation_dim  # 8x dictionary expansion factor

# The buffer yields batches of activations taken from `submodule`'s output.
buffer = ActivationBuffer(
    data=data,
    model=model,
    submodule=submodule,
    d_submodule=activation_dim,  # output dimension of the hooked submodule
    # Approximate number of contexts kept in the buffer (NOT the context length,
    # which is `ctx_len`). Lower it if memory is tight.
    n_ctxs=30_000,
    # Activations per yielded batch. Must equal `sae_batch_size` in the CONFIG
    # cell, because `steps = num_tokens // sae_batch_size` assumes this batch size.
    out_batch_size=2048,
    device=device,
)

In [7]:
# Display the module tree. Use it to locate the hooked MLP (transformer.h[11].mlp)
# and to confirm the hidden size of 768.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50258, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50258, bias=False)
  (generator): WrapperModule()
)

In [8]:
# CONFIG: values taken from dictionary_learning_demo/demo_config.py
num_tokens = 50_000_000  # token budget for SAE training
sae_batch_size = 2048  # activations per SAE training batch
# Each training step consumes one batch, so this is the total number of batches.
steps = num_tokens // sae_batch_size
# Log the training on wandb, or print to console, every `log_steps` steps.
log_steps = 100

save_checkpoints = True
save_dir = "./saes/estMedSae170425"

In [9]:
# Settings for a single StandardTrainer run, handed to trainSAE in the training cell.
# NOTE(review): no seed is set, so SAE initialisation is not reproducible across runs.
trainer_cfg = dict(
    trainer=StandardTrainer,
    dict_class=AutoEncoder,
    activation_dim=activation_dim,
    dict_size=dictionary_size,
    lr=1e-3,
    device=device,
    steps=steps,
    lm_name=model_name,
    layer=11,  # index of the transformer block whose MLP output is hooked
)

In [8]:
# Checkpoint schedule: step 0 plus six log-spaced fractions of training
# (0.1%, 0.316%, 1%, 3.16%, 10%, 31.6%). The final 100% point produced by
# logspace is dropped from this list.
# NOTE(review): presumably this is because trainSAE saves the final model itself; confirm.
if not save_checkpoints:
    save_steps = None
else:
    log_fracs = t.logspace(-3, 0, 7).tolist()
    desired_checkpoints = sorted([0.0, *log_fracs[:-1]])
    print(f"desired_checkpoints: {desired_checkpoints}")

    save_steps = sorted(int(steps * frac) for frac in desired_checkpoints)
    print(f"save_steps: {save_steps}")

desired_checkpoints: [0.0, 0.0010000000474974513, 0.003162277629598975, 0.009999999776482582, 0.03162277489900589, 0.10000000149011612, 0.3162277638912201]
save_steps: [0, 24, 77, 244, 772, 2441, 7720]


In [9]:
# Train the sparse autoencoder (SAE).
# trainSAE deletes the "trainer" key from each config dict it receives. Passing a
# shallow copy keeps trainer_cfg intact, so this cell can be re-run without
# rebuilding the config by hand. (The old hand-built fallback also dropped
# "layer", so it diverged from trainer_cfg.)
# NOTE(review): `resample_steps` was left here as a stray positional argument,
# which is a SyntaxError. If dead-neuron resampling is wanted, it presumably
# belongs in trainer_cfg as a trainer option, not as a trainSAE argument. Confirm
# this against the installed dictionary_learning version.
ae = trainSAE(
    data=buffer,  # any iterator of activation batches works, e.g. a PyTorch DataLoader
    trainer_configs=[dict(trainer_cfg)],
    steps=steps,
    save_steps=save_steps,
    save_dir=save_dir,
    log_steps=log_steps,
    autocast_dtype=t.bfloat16,
    normalize_activations=True,
)


12042025


  0%|          | 0/24414 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 24414/24414 [41:19<00:00,  9.85it/s] 


In [10]:

# Reference only: the trainSAE call from dictionary_learning_demo/demo.py, kept as a
# string so it can be compared with the training call above. Evaluating the string
# only echoes the text as the cell output; nothing in it is executed.
""" from dictionary_learning_demo/demo.py
        trainSAE(
            data=activation_buffer,
            trainer_configs=trainer_configs,
            use_wandb=use_wandb,
            steps=steps,
            save_steps=save_steps,
            save_dir=save_dir,
            log_steps=log_steps,
            wandb_project=demo_config.wandb_project,
            normalize_activations=True,
            verbose=False,
            autocast_dtype=t.bfloat16,
        )
"""

' from dictionary_learning_demo/demo.py\n        trainSAE(\n            data=activation_buffer,\n            trainer_configs=trainer_configs,\n            use_wandb=use_wandb,\n            steps=steps,\n            save_steps=save_steps,\n            save_dir=save_dir,\n            log_steps=log_steps,\n            wandb_project=demo_config.wandb_project,\n            normalize_activations=True,\n            verbose=False,\n            autocast_dtype=t.bfloat16,\n        )\n'