In [2]:
import os

try:
    has_changed_dir
except:
    has_changed_dir = False

try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

if IN_COLAB:
    %pip install datasets
    %pip install translate-toolkit
    %pip install bitsandbytes

    !git clone https://github.com/MarkusSibbesen/mechinterp_research_project.git

    if not has_changed_dir:
        os.chdir('mechinterp_research_project')
        has_changed_dir = True
else:
    if not has_changed_dir:
        os.chdir('.')
        has_changed_dir = True


In [3]:
import torch
from torch import nn
from torch.utils.data import random_split, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.dataset_handling import TextDataset, TextClassificationDataset
from src.utils import get_hidden_activations, activation_label_generator
from src.sparse_autoencoders import SAE_topk
import tqdm.auto as tqdm
from collections import defaultdict
from numpy import mean
import json
import pandas as pd
import os
from matplotlib import pyplot as plt
import math
from itertools import product


# Use the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Base model: Hugging Face identifier and the module(s) whose activations are hooked.
url = "EleutherAI/pythia-14m"
hookpoints = ["gpt_neox.layers.3.mlp.act"]

# Output locations, grouped under the model's short name (text after the last '/').
out_folder = f"models/sparse_autoencoders/{url.split('/')[-1]}"
log_folder = f"training_logs/sparse_autoencoders/{url.split('/')[-1]}"

# Training hyperparameters.
batch_size = 32
learning_rate = 1e-3

# Load the pretrained model onto the selected device, plus its tokenizer.
model = AutoModelForCausalLM.from_pretrained(url)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(url)

# Training split, served in shuffled mini-batches.
data_path = 'data/split/tedtalks_train.tsv'
data = TextClassificationDataset.from_tsv(data_path)
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)

In [5]:
# Sparse-autoencoder (SAE) configuration

# SAE input width, read from the base model's `intermediate_size` config entry.
input_size = model.config.intermediate_size

# The SAE hidden layer is this many times wider than its input.
expansion_factor = 4

# Default configuration: a top-k SAE whose k is the integer part of sqrt(hidden_size).
meta_data = dict(
    input_size=input_size,
    hidden_size=expansion_factor * input_size,
    k=int(math.sqrt(expansion_factor * input_size)),
    pre_encoder_bias=True,
    activation_function="topk",
)



In [6]:
# Hyperparameter grids: every key maps to the list of candidate values to try.
# Key order matters, because it fixes the order in which combinations are generated.

# ReLU autoencoders additionally sweep the L1 coefficient `lambda_l1`.
grid_relu = dict(
    input_size=[input_size],
    hidden_size=[input_size * expansion_factor],
    lambda_l1=[1, 0.1, 0.01],
    k=[int(math.sqrt(input_size * expansion_factor))],
    activation_function=["relu"],
    pre_encoder_bias=[True, False],
)

# Top-k autoencoders have no `lambda_l1` entry.
grid_topk = dict(
    input_size=[input_size],
    hidden_size=[input_size * expansion_factor],
    k=[int(math.sqrt(input_size * expansion_factor))],
    activation_function=["topk"],
    pre_encoder_bias=[True, False],
)



def generate_meta_data_combinations(grid_search):
    """Expand a hyperparameter grid into every concrete configuration.

    Args:
        grid_search: Mapping from parameter name to an iterable of candidate
            values for that parameter.

    Returns:
        A list of dicts, one per element of the Cartesian product of the
        candidate lists. Each dict uses the mapping's keys in the mapping's
        order, and the last key varies fastest. An empty mapping yields
        ``[{}]`` (the single empty configuration); a parameter with no
        candidates yields ``[]``.
    """
    # Iterating the mapping directly (instead of unpacking zip(*items())) keeps
    # keys and values aligned and also works for an empty grid.
    keys = list(grid_search)

    return [
        dict(zip(keys, combination))
        for combination in product(*grid_search.values())
    ]

# The grids are expanded separately for the two activation functions because
# top-k configurations have no lambda_l1 parameter.
grid_search_relu = generate_meta_data_combinations(grid_relu)
grid_search_topk = generate_meta_data_combinations(grid_topk)
# Only the ReLU configurations are searched for now. To search both, use:
#grid_search = grid_search_relu + grid_search_topk
grid_search = grid_search_relu

In [7]:
class SaeTrainer():
    """Trains one sparse autoencoder on the activations of a single hookpoint.

    Bundles an ``SAE_topk`` model built from ``meta_data`` with an Adam
    optimizer and an MSE reconstruction loss, and records the loss and the
    model's ``active_neurons`` value for every training batch.
    """

    def __init__(self, meta_data, learning_rate, hookpoint, device):
        """Build the autoencoder, its optimizer and the bookkeeping state.

        Args:
            meta_data: Configuration dict passed to ``SAE_topk``. This class
                reads ``"input_size"`` and ``"activation_function"`` from it,
                and ``"lambda_l1"`` when the activation function is ``"relu"``.
            learning_rate: Learning rate for the Adam optimizer.
            hookpoint: Name of the hooked module this trainer is responsible
                for; also used as the plot title.
            device: Torch device the autoencoder is moved to.
        """
        # Read the size from this trainer's own configuration rather than from
        # a notebook-level global, so the class does not depend on cell state.
        self.input_size = meta_data["input_size"]
        self.learning_rate = learning_rate
        self.hookpoint = hookpoint
        self.meta_data = meta_data
        self.device = device
        self.model = SAE_topk(meta_data).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.loss_fn = nn.MSELoss()

        # Number of training steps performed so far.
        self.batches = 0

        # One (loss, active_neurons) tuple per training step.
        self.losses = []

    def _l1_penalty(self):
        """Return the weighted L1 weight penalty as a tensor attached to autograd."""
        amount_of_weights = self.model.W.shape[0] * self.model.W.shape[1]
        encoder_l1_loss = self.model.W.abs().sum() / amount_of_weights
        decoder_l1_loss = self.model.WT.abs().sum() / amount_of_weights

        return (encoder_l1_loss + decoder_l1_loss) * self.meta_data["lambda_l1"]

    def compute_l1(self):
        """Return the L1 weight penalty as a plain float (for logging).

        The penalty is the sum of the mean absolute values of ``model.W`` and
        ``model.WT`` (both divided by the element count of ``W``), scaled by
        ``meta_data["lambda_l1"]``.

        Raises:
            KeyError: If ``meta_data`` has no ``"lambda_l1"`` entry.
        """
        return float(self._l1_penalty())

    def train_step(self, input_, labels):
        """Run one optimisation step and record its statistics.

        Args:
            input_: Batch of activations fed to the autoencoder.
            labels: Reconstruction target compared with the output by MSE.

        Returns:
            Tuple ``(loss, active_neurons)``: the loss tensor of this step
            (including the L1 term for ``"relu"`` configurations) and the
            model's ``active_neurons`` attribute after the forward pass.
        """
        outputs = self.model(input_).to(self.device)
        loss = self.loss_fn(outputs, labels)
        if self.meta_data["activation_function"] == "relu":
            # The penalty must stay a tensor: adding a Python float would only
            # shift the reported loss and contribute no gradient to the weights.
            l1_penalty = self._l1_penalty()
            print("l1", float(l1_penalty))
            loss = loss + l1_penalty

        active_neurons = self.model.active_neurons

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.losses.append((loss.item(), active_neurons))
        self.batches += 1
        return loss, active_neurons

    def plot_loss(self, out_file = None):
        """Plot the recorded loss and active-neuron curves on twin y-axes.

        Args:
            out_file: If given, the figure is saved to this path at 300 dpi;
                otherwise it is shown with ``plt.show()``.

        Raises:
            ValueError: If no training step has been recorded yet.
        """
        if not self.losses:
            raise ValueError("no training steps recorded; call train_step before plot_loss")

        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        losses, active_neurons = zip(*self.losses)
        ax.plot(losses, label='Loss')
        ax2 = ax.twinx()
        ax2.plot(active_neurons, label='Active neurons', color='orange')
        ax.set_xlabel('batch nr')
        ax.set_ylabel('loss')
        # NOTE: `batch_size` is the notebook-level setting, not trainer state.
        ax2.set_ylabel(f'active neurons per batch (batch_size: {batch_size})')
        ax.set_title(self.hookpoint)
        fig.legend()

        if out_file:
            fig.savefig(out_file, dpi=300)
        else:
            plt.show()

In [8]:
meta_data

{'input_size': 512,
 'hidden_size': 2048,
 'k': 45,
 'pre_encoder_bias': True,
 'activation_function': 'topk'}

In [9]:
grid_search

[{'input_size': 512,
  'hidden_size': 2048,
  'lambda_l1': 0.01,
  'k': 45,
  'activation_function': 'relu',
  'pre_encoder_bias': True},
 {'input_size': 512,
  'hidden_size': 2048,
  'lambda_l1': 0.01,
  'k': 45,
  'activation_function': 'relu',
  'pre_encoder_bias': False},
 {'input_size': 512,
  'hidden_size': 2048,
  'lambda_l1': 0.001,
  'k': 45,
  'activation_function': 'relu',
  'pre_encoder_bias': True},
 {'input_size': 512,
  'hidden_size': 2048,
  'lambda_l1': 0.001,
  'k': 45,
  'activation_function': 'relu',
  'pre_encoder_bias': False},
 {'input_size': 512,
  'hidden_size': 2048,
  'lambda_l1': 0.0001,
  'k': 45,
  'activation_function': 'relu',
  'pre_encoder_bias': True},
 {'input_size': 512,
  'hidden_size': 2048,
  'lambda_l1': 0.0001,
  'k': 45,
  'activation_function': 'relu',
  'pre_encoder_bias': False}]

In [10]:
# Grid search: for every configuration, train one SAE per hookpoint over the
# whole dataloader and collect one summary row per configuration in `d`.
d = defaultdict(list)
sae_trainer_list =[]
# NOTE: the loop variable rebinds the notebook-level `meta_data`; after the
# loop it holds the last configuration of the grid.
for meta_data in grid_search:

    # Top-k configurations carry no L1 coefficient; record None for them.
    d["lambda_l1"].append(meta_data.get("lambda_l1"))

    d["activation_function"].append(meta_data["activation_function"])
    d["pre_encoder_bias"].append(meta_data["pre_encoder_bias"])
    sae_trainers = [
        SaeTrainer(meta_data, learning_rate, hookpoint, device)
        for hookpoint in hookpoints
    ]
    loss_list = []
    active_neurons_list = []
    for activations, _ in activation_label_generator(dataloader, model, hookpoints, tokenizer, device):
        for sae_trainer in sae_trainers:
            activation = activations[sae_trainer.hookpoint]
            # The autoencoder reconstructs its own input, so the target is the
            # same activation detached from the graph.
            label = activation.detach()
            loss, active_neurons = sae_trainer.train_step(activation, label)
            # Keep plain floats so the history does not hold on to tensors.
            loss_list.append(float(loss))
            active_neurons_list.append(active_neurons)
            print(f'{loss}\t{active_neurons}', end='')
        print('')
    # Average loss over the last 10 recorded steps (across all hookpoints).
    last_10_avg_mean_loss = mean(loss_list[-10:])
    d["loss"].append(last_10_avg_mean_loss)
    # Average active-neuron count over the last 10 recorded steps.
    last_10_avg_mean_neurons = mean([float(x) for x in active_neurons_list[-10:]])
    d["active_neurons"].append(last_10_avg_mean_neurons)
    sae_trainer_list.append(sae_trainers)

  0%|          | 0/2070 [00:00<?, ?it/s]

l1 0.000159591538249515
0.0733189582824707	1897
l1 0.00016124200192280114
0.04736277088522911	1918
l1 0.0001633556530578062
0.037690676748752594	1928
l1 0.00016518363554496318
0.03285214677453041	1933
l1 0.0001666163298068568
0.029096132144331932	1935
l1 0.00016793649410828948
0.026744311675429344	1941
l1 0.00016932867583818734
0.02606246806681156	1944
l1 0.00017079572717193514
0.02478684112429619	1947
l1 0.0001722571614664048
0.022531449794769287	1946
l1 0.00017374336312059313
0.021697737276554108	1951
l1 0.00017522370035294443
0.02162349969148636	1953
l1 0.00017665899940766394
0.019508926197886467	1954
l1 0.00017804146045818925
0.01887376606464386	1957
l1 0.00017941437545232475
0.01748332753777504	1957
l1 0.0001808016822906211
0.017495375126600266	1960
l1 0.00018222356447950006
0.016477979719638824	1961
l1 0.00018366778385825455
0.015921575948596	1961
l1 0.00018511911912355572
0.014953674748539925	1967
l1 0.0001865536323748529
0.014195763505995274	1963
l1 0.00018794092466123402
0.013

KeyboardInterrupt: 

In [None]:
pd.DataFrame(d)

In [None]:
# too big to commit :((((