In [1]:
from sae import TopKSAE, TopKSAEConfig
from generate_residuals import EmbeddingGeneratorConfig


import torch
import os
from tqdm import tqdm
from matplotlib import pyplot as plt

device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"
torch.autograd.set_grad_enabled(False)


block_size: 512
n_layer: 14
n_head: 16
n_embd: 512
feed_forward_factor: 2.5
vocab_size: 8192
data_dir: dataset
expt_name: restart_good_3hr_search
batch_size: 128
max_lr: 0.002
min_lr: 0.0001
beta_1: 0.9
beta_2: 0.99
warmup_steps: 50
max_steps: 60000
max_runtime_seconds: 10800
weight_decay: 0.12
need_epoch_reshuffle: True
matmul_precision: high
smoke_test: False


<torch.autograd.grad_mode.set_grad_enabled at 0x7e5fd8086b60>

In [2]:
s = torch.load('saes/sae.pt')
sae = TopKSAE(s['config'])
sae.load_state_dict(s['model'])

<All keys matched successfully>

In [3]:
embdconfig = torch.load("./residuals/0.pt")['config']

In [4]:
embdconfig

EmbeddingGeneratorConfig(batch_size=512, block_size=512, n_embd=512, ratio_tokens_saved=0.07, residual_layer=10, mb_per_save=2000, save_dir='./residuals/')

In [5]:
residuals_dir = embdconfig.save_dir

residuals_files = [os.path.join(residuals_dir, f) for f in os.listdir(residuals_dir)]

In [6]:
residuals_files[0]

'./residuals/16.pt'

In [7]:
residual = torch.load(residuals_files[0])

In [8]:
residual.keys() # dict_keys(['residuals', 'token_idxs', 'context_window_starts', 'config'])


dict_keys(['residuals', 'token_idxs', 'token_values', 'context_window_starts', 'config'])

In [9]:
sae.config

TopKSAEConfig(embedding_size=512, n_features=32768, topk=24, lr=0.001, batch_size=4096)

In [10]:
def crop_fit_batch(tensor, batch_size):
    n = tensor.size(0)
    n_batches = n // batch_size
    n = n_batches * batch_size
    return tensor[:n]

In [11]:


def sae_process_residuals(residuals_path, sae):
    topk_idxs_per_token = []
    topk_strengths_per_token = []

    residual = torch.load(residuals_path)
    residuals = residual['residuals'].to(torch.float32)
    token_idxs = residual['token_idxs']
    context_window_starts = residual['context_window_starts']
    config = residual['config']
    sae.eval()
    cropped = crop_fit_batch(residuals, sae.config.batch_size)
    batches = cropped.view(-1, sae.config.batch_size, sae.config.embedding_size)
    with torch.no_grad():
        for batch in tqdm(batches):
            sae_out = sae(batch)
            topk_idxs = sae_out['topk_idxs']
            topk_values = sae_out['topk_values']
            topk_idxs_per_token.append(topk_idxs)
            topk_strengths_per_token.append(topk_values)
        topk_idxs = torch.cat(topk_idxs_per_token, dim=0)
        topk_strengths = torch.cat(topk_strengths_per_token, dim=0)
    return topk_idxs, topk_strengths, token_idxs, context_window_starts


topk_idxs_per_token, topk_strengths_per_token, dataset_token_location_idxs, dataset_context_window_starts = sae_process_residuals(residuals_files[-1], sae)

100%|██████████| 479/479 [00:04<00:00, 96.04it/s] 


In [12]:
dataset_context_window_starts = torch.tensor(dataset_context_window_starts)
dataset_token_location_idxs = torch.tensor(dataset_token_location_idxs)

sort_idxs = torch.argsort(dataset_token_location_idxs)
sorted_context_window_starts = dataset_context_window_starts[sort_idxs]
sorted_token_idxs = dataset_token_location_idxs[sort_idxs]

In [13]:
idx = 55
sorted_context_window_starts[idx], sorted_token_idxs[idx]

(tensor(512, device='cuda:0'), tensor(795, device='cuda:0'))

In [26]:
random_feature = topk_idxs_per_token.view(-1)[6009]

flat_topk_idxs = topk_idxs_per_token.view(-1)

random_feature_indexes = torch.where(flat_topk_idxs == random_feature)[0]
random_feature_strengths = topk_strengths_per_token.view(-1)[random_feature_indexes]

random_feature_subset_strengths, random_feature_subset_flat_feature_idxs = random_feature_strengths.topk(10)
# the indexes are relative to the random_feature_indexes(subset of data), we need to map back to the global indexes
random_feature_subset_flat_feature_idxs = random_feature_indexes[random_feature_subset_flat_feature_idxs]

# random_feature_subset_location_idx
random_feature_subset_token_idxs = random_feature_subset_flat_feature_idxs // sae.config.topk
random_feature_subset_topk_idxs = random_feature_subset_flat_feature_idxs % sae.config.topk

print("selected feature:", random_feature)

token_idx = random_feature_subset_token_idxs[1]
topk_idx = random_feature_subset_topk_idxs[1]

topk_idxs_per_token[token_idx, topk_idx] # as expected, the selected feature via lookup is the same as the random feature

selected feature: tensor(20184, device='cuda:0')


tensor(20184, device='cuda:0')

In [27]:
from gpt import preprocess_tokens_from_huggingface
import transformers


enc = transformers.AutoTokenizer.from_pretrained('activated-ai/tiny-stories-8k-tokenizer')
preprocess_tokens_from_huggingface("./datasets")


train =  torch.load("datasets/train.pt", map_location=device).to(torch.int64)


skipping token preprocessing for validation : using cache ./datasets/validation.pt
skipping token preprocessing for train : using cache ./datasets/train.pt


In [29]:
def get_context(token_idx):
    start = token_idx - 10
    end = token_idx + 3
    return enc.decode(train[start:token_idx])+"<"+enc.decode(train[token_idx])+">"+enc.decode(train[token_idx+1:end])



for token_idx in random_feature_subset_token_idxs:
    token_idx = token_idx.item()
    dataset_token_idx = dataset_token_location_idxs[token_idx]
    print(get_context(dataset_token_idx))
    print("-"*10)


 each other but with my pistol!" The other friend< agreed>. They
----------
. "It will be fun!"

Ben< agreed>. They
----------
 that they should go on an adventure. The squirrel< agreed>, so
----------
 fight! It will be so much fun!" Jo< agreed>.

----------
 go on an adventure!â€ Celery nodded< in> agreement and
----------
 see a movie," she said. Ben nodded his< head> and smiled
----------
 some more vegetables! They'll be cheap!" Peter< agreed>, so
----------
ggy, let us start!"
Iggy barked< in> excitement and
----------
s go explore the wild!â€ Jim agreed<,> and so
----------
 them to different places." Ben says, "OK<,> that sounds
----------
