In [1]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

import pandas as pd

from src.sparse_autoencoders import SAE_topk

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# --- Language model -------------------------------------------------------
model_url = "EleutherAI/pythia-14m"
model_name = model_url.split('/')[-1]  # last path component, e.g. "pythia-14m"
model = AutoModelForCausalLM.from_pretrained(model_url).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_url)
layer = 3
k = 20

# '$' is a placeholder for the layer index inside the submodule path.
hookpoint_name = 'gpt_neox.layers.$.mlp.act'
hookpoint = hookpoint_name.replace('$', str(layer))

# --- Sparse autoencoder ---------------------------------------------------
# The SAE is sized from the model's `intermediate_size`, which is assumed to
# be the width of the activations at `hookpoint` — TODO confirm.
input_size = model.config.intermediate_size
expansion_factor = 4

meta_data = {
    'input_size': input_size,
    'hidden_size': input_size * expansion_factor,
    'k': k
}

sae = SAE_topk(meta_data=meta_data)

# Build the checkpoint path from `model_name` (as `acts_path` below already
# does) so that changing `model_url` cannot silently load another model's SAE.
sae_path = f'models/sparse_autoencoders/{model_name}/topk{k}/{hookpoint}.pt'

# `map_location` lets a checkpoint saved on GPU load on a CPU-only machine.
sae.load_state_dict(torch.load(sae_path, map_location=device, weights_only=True))

# Keep the SAE on the same device as the model; otherwise combining the
# model's activations with the SAE weights fails with a device mismatch.
# NOTE(review): assumes SAE_topk is an nn.Module (it exposes load_state_dict).
sae.to(device)

# --- Pre-computed SAE activations -----------------------------------------
acts_path = f'results/sparse_autoencoder_activations/{model_name}/{hookpoint}.csv'

df = pd.read_csv(acts_path)

In [14]:
# Pull the `avg_da` and `avg_en` columns out of the activations table as
# NumPy arrays (presumably per-feature averages for two languages — confirm).
avg_da, avg_en = (df[column].to_numpy() for column in ('avg_da', 'avg_en'))

In [15]:
# Quick look at the values; NumPy abbreviates long arrays in the displayed repr.
avg_da

array([0.        , 0.        , 0.        , ..., 0.00562852, 0.00562852,
       0.        ])

In [None]:
def make_ablation_mask(size, neuron):
    """Return a zero tensor of shape `size` with the entries at `neuron` set to 1.

    Args:
        size: Shape passed to `torch.zeros` (an int or a sequence of ints).
        neuron: Index expression selecting the entries to mark; anything
            accepted by tensor item assignment (int, list, slice, tensor).

    Returns:
        A tensor in torch's default dtype holding 1 at the selected
        positions and 0 everywhere else.
    """
    selected = torch.zeros(size)
    selected[neuron] = 1
    return selected

In [None]:
# Exceptions raised during generation are collected here instead of
# propagating, so a batch of runs can continue after a failure.
exceptions = []

def generate_with_sae(model, sae, hookpoint, tokenized_input, ablation_mask=None, max_length=100):
    """Generate with `model` while `hookpoint`'s output is replaced by its SAE reconstruction.

    A forward hook is attached to the submodule named `hookpoint` for the
    duration of the call. The hook encodes the submodule output with the
    SAE, optionally zeroes selected latent features, decodes, and returns
    the decoded tensor in place of the original output.

    Args:
        model: Module exposing `get_submodule` and `generate`.
        sae: Object exposing `WT`, `b1`, `W` and `b2`, used as
            `relu(x @ WT + b1)` to encode and `h @ W + b2` to decode.
        hookpoint: Dotted name of the submodule whose output is replaced.
        tokenized_input: Mapping with `'input_ids'` and `'attention_mask'`.
        ablation_mask: Optional selector over the last (latent) dimension.
            The selected features are set to 0 before decoding. Floating
            point tensors (such as those built from `torch.zeros`) are
            treated as 0/1 masks; any other value is used as an index as-is.
        max_length: Forwarded to `model.generate`.

    Returns:
        Whatever `model.generate` returns, or `None` if it raised. In that
        case the exception is appended to the module-level `exceptions` list.

    Note:
        Reads the global `tokenizer` for `pad_token_id`.
    """

    def hook(module, _, outputs):
        # Encode into the SAE latent space.
        # NOTE(review): only a ReLU is applied here, no top-k selection —
        # confirm this matches how the SAE was trained.
        h = torch.relu(torch.matmul(outputs, sae.WT) + sae.b1)

        if ablation_mask is not None:
            selector = ablation_mask
            if torch.is_tensor(selector):
                # A float tensor cannot be used as an index; read it as a
                # boolean mask where non-zero means "ablate".
                if selector.is_floating_point():
                    selector = selector.bool()
                selector = selector.to(h.device)
            # Index the last dimension so every batch element and position
            # is ablated, without a Python loop.
            h[..., selector] = 0

        # Decode; returning a value from a forward hook replaces the output.
        return torch.matmul(h, sae.W) + sae.b2

    attached_hook = model.get_submodule(hookpoint).register_forward_hook(hook)

    output = None
    try:
        output = model.generate(tokenized_input['input_ids'],
                                max_length=max_length,
                                attention_mask=tokenized_input['attention_mask'],
                                pad_token_id=tokenizer.eos_token_id)
    except Exception as e:
        exceptions.append(e)
    finally:
        # Always detach the hook, even on KeyboardInterrupt, so later calls
        # on the same model are not silently affected.
        attached_hook.remove()

    return output