In [1]:
import pandas as pd
import torch as t
from nnsight import LanguageModel
from dictionary_learning import ActivationBuffer, AutoEncoder
from dictionary_learning.trainers import StandardTrainer, StandardTrainerAprilUpdate
from dictionary_learning.training import trainSAE
from dictionary_learning.utils import read_csv
import gc

In [2]:
# empty cache to free memory before any further action.
t.cuda.empty_cache() 
gc.collect()

0

In [10]:
# data must be an iterator that outputs strings
data = read_csv("/gpfs/space/projects/stacc_health/data-synthetic/100k_synthetic_texts.csv")

In [11]:
device = "cuda:0"
model_name = "/gpfs/space/projects/stacc_health/gpt2_model/estMed-gpt2_fine_tuned4/estMed-gpt2_fine_tuned4" 

# Lisada loss visualiseerimine

model = LanguageModel(
    model_name,
    device_map=device,
)
submodule = model.transformer.h[5].mlp # Take the output of the model's 12th mlp block.
activation_dim = model.transformer.h[0].ln_1.normalized_shape[0] # output dimension of the MLP = 768
dictionary_size = 16 * activation_dim # 12_288 features

buffer = ActivationBuffer( # buffer will yield batches of tensors of dimension = submodule's output dimension
    data=data,
    model=model,
    submodule=submodule,
    d_submodule=activation_dim, # output dimension of the model component
    n_ctxs=3e4,  # length of each context. you can set this higher or lower dependong on your available memory
    device=device,
)  

In [12]:
saeSaveName = input("Enter SAE saving folder (estMedSae<date>): ")

Enter SAE saving folder (estMedSae<date>): estMedSaeX16layer5


In [13]:
# CONFIG
num_tokens = 50_000_000 # numbrid pärit dictionary_learning_demo/demo_config.py-st
sae_batch_size = 3072
steps = int(num_tokens / sae_batch_size) # Total number of batches to train
log_steps = 1000  # Log the training on wandb or print to console every log_steps

save_checkpoints = True
save_dir = "./saes/" + saeSaveName
trainer_class = StandardTrainer

In [14]:
trainer_cfg = {
    "trainer": trainer_class,
    "dict_class": AutoEncoder,
    "activation_dim": activation_dim,
    "dict_size": dictionary_size,
    "lr": 1e-3,
    "device": device,
    
    "steps" : steps,
    "lm_name" : model_name,
    "layer" : 5,
    
}

In [15]:
if save_checkpoints:
    # Creates checkpoints at 0.0%, 0.1%, 0.316%, 1%, 3.16%, 10%, 31.6%, 100% of training
    desired_checkpoints = t.logspace(-3, 0, 7).tolist()
    desired_checkpoints = [0.0] + desired_checkpoints[:-1]
    desired_checkpoints.sort()
    print(f"desired_checkpoints: {desired_checkpoints}")

    save_steps = [int(steps * step) for step in desired_checkpoints]
    save_steps.sort()
    print(f"save_steps: {save_steps}")
else:
    save_steps = None

desired_checkpoints: [0.0, 0.0010000000474974513, 0.003162277629598975, 0.009999999776482582, 0.03162277489900589, 0.10000000149011612, 0.3162277638912201]
save_steps: [0, 16, 51, 162, 514, 1627, 5146]


In [16]:
if ("trainer" not in trainer_cfg): # trainer field gets deleted during trainSAE
    trainer_cfg["trainer"] = trainer_class

# train the sparse autoencoder (SAE)
ae = trainSAE(
    data=buffer,
    trainer_configs=[trainer_cfg],
    save_steps=save_steps,
    save_dir=save_dir,
    log_steps=log_steps,
    steps=steps,
    autocast_dtype=t.bfloat16, # new
    normalize_activations=True,
)


1105


Calculating norm factor:   0%|          | 0/100 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Calculating norm factor: 101it [01:21,  1.24it/s]                         


Average mean squared norm: 29.92813491821289
Norm factor: 5.470661163330078


100%|██████████| 16276/16276 [25:53<00:00, 10.48it/s]


In [None]:

""" from dictionary_learning_demo/demo.py
        trainSAE(
            data=activation_buffer,
            trainer_configs=trainer_configs,
            use_wandb=use_wandb,
            steps=steps,
            save_steps=save_steps,
            save_dir=save_dir,
            log_steps=log_steps,
            wandb_project=demo_config.wandb_project,
            normalize_activations=True,
            verbose=False,
            autocast_dtype=t.bfloat16,
        )
"""