In [10]:
import torch
from torch import nn
from torch.utils.data import random_split, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.dataset_handling import TextDataset, TextClassificationDataset
from src.utils import get_hidden_activations
from src.sparse_autoencoders import SAE_topk
import tqdm.auto as tqdm

import json
import einops
import pandas as pd
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# For the base model
url = "EleutherAI/pythia-14m"
hookpoint = "gpt_neox.layers.3.mlp.act"
out_folder = f"models/sparse_autoencoders/{url.split('/')[-1]}"
log_folder = f"training_logs/sparse_autoencoders/{url.split('/')[-1]}"

model = AutoModelForCausalLM.from_pretrained(url).to(device)
tokenizer = AutoTokenizer.from_pretrained(url)


data = TextClassificationDataset.from_tmx('data/parallel/tedtalks.tmx', 'da', 'en')

train, test = random_split(
    dataset=data,
    lengths=[0.7, 0.3]
)

train_dataloader = DataLoader(
    dataset=train,
    batch_size=32,
    shuffle=True
)


In [7]:
# SAE data

input_size = model.config.intermediate_size

expansion_factor = 4

meta_data = {
    'input_size': input_size,
    'hidden_size': input_size * expansion_factor,
    'k': 20
}

sae = SAE_topk(meta_data=meta_data).to(device)

In [8]:
# TRaining

logs = []

optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

errors = []

for text_batch, _ in tqdm.tqdm(train_dataloader):

    inputs = [
        tokenizer(text, return_tensors='pt').to(device)
        for text in text_batch
    ]

    try:
        activations = get_hidden_activations(model, hookpoint, inputs)
    except Exception as e:
        errors.append(e)
        continue

    labels = activations.detach()

    outputs = sae(activations).to(device)

    loss = loss_fn(outputs, labels)

    hiddens = sae.hidden_activations
    active_neurons = len(torch.unique(hiddens.indices))
    loss.backward()


    optimizer.step()
    optimizer.zero_grad()

    logs.append((loss.item(), active_neurons))

    print(f'{loss}\t{active_neurons}')

  0%|          | 0/2070 [00:00<?, ?it/s]

0.0885373204946518	274
0.08676984161138535	270
0.0803181454539299	212
0.07574815303087234	178
0.06802526116371155	172
0.06862830370664597	166
0.06744498014450073	158
0.0623774528503418	151
0.05723094195127487	155
0.05452904477715492	160
0.051216091960668564	155
0.04785117506980896	149
0.04418979585170746	156
0.04342687129974365	153
0.041001349687576294	153
0.03955448791384697	152
0.03721132129430771	156
0.03683793544769287	151
0.03605686500668526	152
0.03468211367726326	155
0.034222736954689026	152
0.032787859439849854	154
0.03305640444159508	155
0.03156346082687378	154
0.029775042086839676	156
0.028980249539017677	155
0.029236456379294395	157
0.02975931204855442	155
0.028411831706762314	156
0.027592241764068604	154
0.027635935693979263	159
0.02784028649330139	156
0.028443319723010063	155
0.027623092755675316	156
0.026715470477938652	156
0.025947026908397675	158
0.025773227214813232	156
0.02531798556447029	156
0.024655163288116455	157
0.02478855289518833	156
0.024511465802788734	156
0.

KeyboardInterrupt: 

In [13]:
model_out_path = out_folder + f'/{hookpoint}.pt'
meta_data_out_path = out_folder + f'/{hookpoint}.json'
log_path = log_folder + f'/{hookpoint}.csv'

if not os.path.isdir(out_folder):
    os.mkdir(out_folder)
if not os.path.isdir(log_folder):
    os.mkdir(log_folder)

torch.save(sae.state_dict(), out_folder + f'/{hookpoint}.pt')
with open(meta_data_out_path, 'w') as file:
    json.dump(meta_data, file, indent=4)

loss_log, active_neurons_log = zip(*logs) 
df = pd.DataFrame({
    'loss': loss_log,
    'active_neurons': active_neurons_log
})
df.to_csv(log_path)