In [2]:
num_tokens = 50_000_000 # numbrid pärit dictionary_learning_demo/demo_config.py-st
sae_batch_size = 2048
steps = int(num_tokens / sae_batch_size) # Total number of batches to train
log_steps = 100  # Log the training on wandb or print to console every log_steps

In [3]:
from nnsight import LanguageModel
from dictionary_learning import ActivationBuffer, AutoEncoder
from dictionary_learning.trainers import StandardTrainer
from dictionary_learning.training import trainSAE

device = "cuda:0"
model_name = "EleutherAI/pythia-70m-deduped" # can be any Huggingface model

model = LanguageModel(
    model_name,
    device_map=device,
)
submodule = model.gpt_neox.layers[1].mlp # layer 1 MLP
activation_dim = 512 # output dimension of the MLP
dictionary_size = 16 * activation_dim

# data must be an iterator that outputs strings
data = iter(
    [
        "This is some example data",
        "In real life, for training a dictionary",
        "you would need much more data than this",
    ]
)
buffer = ActivationBuffer(
    data=data,
    model=model,
    submodule=submodule,
    d_submodule=activation_dim, # output dimension of the model component
    n_ctxs=3e4,  # you can set this higher or lower dependong on your available memory
    device=device,
)  # buffer will yield batches of tensors of dimension = submodule's output dimension

trainer_cfg = {
    "trainer": StandardTrainer,
    "dict_class": AutoEncoder,
    "activation_dim": activation_dim,
    "dict_size": dictionary_size,
    "lr": 1e-3,
    "device": device,
}

# train the sparse autoencoder (SAE)
ae = trainSAE(
    data=buffer,  # you could also use another (i.e. pytorch dataloader) here instead of buffer
    trainer_configs=[trainer_cfg],
    steps=steps,
)

{'trainer': <class 'dictionary_learning.trainers.standard.StandardTrainer'>, 'dict_class': <class 'dictionary_learning.dictionary.AutoEncoder'>, 'activation_dim': 512, 'dict_size': 8192, 'lr': 0.001, 'device': 'cuda:0'}


TypeError: StandardTrainer.__init__() missing 3 required positional arguments: 'steps', 'layer', and 'lm_name'

In [2]:
from dictionary_learning import AutoEncoder
import torch

In [9]:

# load autoencoder
#ae = AutoEncoder.from_pretrained("/gpfs/helios/home/jpauklin/dictionary_learning/saes/trainer_0/ae.pt") # to is rquired to load to GPU


ae = AutoEncoder.from_pretrained("/gpfs/helios/home/jpauklin/dictionary_learning/dictionaries/pythia-70m-deduped/mlp_out_layer5/10_32768/ae.pt")


# get NN activations using your preferred method: hooks, transformer_lens, nnsight, etc. ...
# for now we'll just use random activations
activations = torch.randn(64, 512) # 768 estMed
features = ae.encode(activations) # get features from activations
reconstructed_activations = ae.decode(features)

# you can also just get the reconstruction ...
reconstructed_activations = ae(activations)
# ... or get the features and reconstruction at the same time
reconstructed_activations, features = ae(activations, output_features=True)

In [10]:
activations

tensor([[ 1.6472,  1.5410, -0.9106,  ..., -1.5017,  1.5074,  0.3728],
        [ 0.1759, -0.9854, -1.1953,  ..., -0.3835,  0.5757, -0.0051],
        [-2.6950,  0.8102,  0.6362,  ...,  0.3335,  0.2446,  1.2338],
        ...,
        [-1.5096,  1.6440,  1.2469,  ...,  0.1677, -0.6865, -0.1040],
        [-0.0327,  0.3279, -0.9597,  ...,  0.9221,  1.0353,  0.3163],
        [ 0.9611, -0.2309,  0.9678,  ..., -0.1980,  0.9035, -1.1735]])

In [11]:
reconstructed_activations

tensor([[-2.8360e+00, -9.3310e+00, -1.3342e+01,  ..., -1.0414e+01,
          1.0899e+01, -4.4748e+00],
        [-1.3727e+01, -9.9600e+00, -1.5925e+01,  ..., -9.7325e+00,
          4.5918e+00, -6.0350e+00],
        [-1.6353e+01, -6.7410e+00, -4.2032e+00,  ..., -2.3968e-01,
          1.3756e+00,  1.1615e+00],
        ...,
        [-1.6605e+00,  2.3595e-01,  2.0186e-03,  ..., -1.1771e+00,
         -4.9485e-01, -3.6054e-01],
        [-5.3005e+00, -4.4510e+00, -5.8403e+00,  ..., -1.6042e+00,
          5.6609e+00, -1.0288e+00],
        [-3.3809e+00, -5.3261e-01,  4.4034e-01,  ..., -3.2447e+00,
          3.1148e+00, -2.6056e+00]], grad_fn=<AddBackward0>)