In [1]:
if not 'RAN_PIP' in locals():
    !pip install tokenizers
    RAN_PIP = True

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import torch
import tokenizers
import llm
import os
import sae
import tqdm
import json



# Experiment selection. expt_name picks the sub-directory of experiments/ that
# is read below for config.json and the saved model weights. The commented-out
# assignment is an alternative experiment name that is currently disabled.
# expt_name = 'e2e_sae_1'
expt_name = 'vanilla_split_llm_sae'
# Root directory of the selected experiment, used for all weight/latent paths.
expt_dir = f'experiments/{expt_name}'

def loadconfig():
    """Load the selected experiment's config.json and publish it as globals.

    Reads ``experiments/{expt_name}/config.json``, stores the parsed mapping in
    the module-level ``config`` variable, and copies every key/value pair into
    ``globals()`` so later cells can refer to the settings by bare name.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    global config
    # Context manager closes the file handle deterministically instead of
    # leaving it to the garbage collector.
    with open(f"experiments/{expt_name}/config.json") as config_file:
        config = json.load(config_file)
    # Existing globals with the same names are overwritten.
    for k, v in config.items():
        globals()[k] = v

# Populate `config` and the per-key globals for the experiment selected above.
loadconfig()



In [3]:
# Load the dataset tensor directly onto the GPU (map_location='cuda').
# NOTE(review): presumably a 1-D tensor of token ids — confirm against the
# script that produced tiny-stories-train.pt.
data = torch.load('tiny-stories-train.pt', map_location='cuda')
# Split point: the first 90% of the data is used for training.
n = int(0.9*len(data))

# The remaining 10% becomes the validation split.
train_data = data[:n]
val_data = data[n:]


In [4]:
def get_batch_by_index(split, ix):
    """Build an (inputs, targets) batch from explicit start offsets.

    split: 'train' selects train_data; any other value selects val_data.
    ix: re-iterable collection of start positions into the chosen split.

    Returns a pair (x, y): x stacks the length-T window beginning at each
    offset, y stacks the same windows shifted one position to the right.
    """
    source = val_data if split != 'train' else train_data

    def stacked_windows(shift):
        # One length-T slice per start offset, displaced by `shift` positions.
        return torch.stack([source[start + shift:start + shift + T] for start in ix])

    x = stacked_windows(0)  # model inputs
    y = stacked_windows(1)  # next-token targets
    return x, y

In [8]:
# NOTE(review): `random` is only referenced by the commented-out debug lines
# inside get_latents below; imports normally belong in the top import cell.
import random
# Names of the globals that are forwarded as keyword arguments to the LLM
# constructor. NOTE(review): presumably these are set by loadconfig() — confirm.
llm_args = ['B', 'T', 'C', 'n_heads', 'H', 'n_layers', 'vocab_size']
llm_kwargs = {k: globals()[k] for k in llm_args}

autoencoder = sae.TopKSparseAutoencoder(C, sae_size, sae_topk)
# The "separate_llm" config flag (default False) selects between a standalone
# GPT with a separately saved SAE and a GPT with the SAE embedded as bottleneck.
if config.get("separate_llm", False):
    print("Loading separate LLM and SAE")
    gpt = llm.GPT(**llm_kwargs)
    # SAE weights are stored in their own file for this variant.
    autoencoder.load_state_dict(torch.load(f'{expt_dir}/sae.pt'))
    def get_latents(tokens):
        """Return (topk_idxs, topk_values) from the SAE applied to the LLM residuals.

        Runs the GPT forward pass with stop_at_layer=sae_location, takes the
        'residuals' entry of its output and feeds it through the autoencoder.
        """
        llm_out = gpt.forward(tokens, targets=None, stop_at_layer=sae_location)
        residuals = llm_out['residuals']
        sae_out = autoencoder(residuals)
        # Disabled debug output (sampled reconstruction quality / top indices).
        #if random.random() < 0.1:
        #    print("r2", sae_out['mean_r2'])
        #    print("top idx", sae_out['topk_idxs'][0,0])
        # NOTE(review): unlike the e2e variant below, no dtype cast is applied
        # here — confirm the two variants are meant to return the same dtypes.
        sparse_idxs = sae_out['topk_idxs']
        sparse_values = sae_out['topk_values']
        return sparse_idxs, sparse_values



else:
    print("Loading e2e LLM and SAE")
    # The autoencoder is handed to the model as its bottleneck at sae_location.
    # NOTE(review): presumably its weights are restored via gpt.pt below — confirm.
    gpt = llm.BottleNeckGPT(
        bottleneck_model=autoencoder,
        bottleneck_location=sae_location,
        **llm_kwargs
    )
    def get_latents(tokens):
        """Return (topk_idxs, topk_values) from the model's bottleneck results.

        Calls the model with bottleneck_early_stop=True and reads the
        'bm_results' entry; indices are cast to int16 and values to float16.
        """
        ret = gpt(tokens, targets=None, bottleneck_early_stop=True)
        sparse_idxs = ret['bm_results']['topk_idxs'].to(torch.int16)
        sparse_values = ret['bm_results']['topk_values'].to(torch.float16)
        return sparse_idxs, sparse_values

# Both variants restore the LLM weights from the experiment directory.
gpt.load_state_dict(torch.load(f'{expt_dir}/gpt.pt'))

Loading separate LLM and SAE


<All keys matched successfully>

In [9]:
def get_batch(split):
    """Sample a random batch of B windows of length T from one split.

    split: 'train' selects train_data; any other value selects val_data.

    Returns a pair (x, y): x holds the sampled windows, y the same windows
    shifted one position to the right (the next-token targets).
    """
    source = val_data if split != 'train' else train_data
    # B random start offsets; the upper bound keeps the shifted target window in range
    starts = torch.randint(0, source.size(0) - T, (B,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + T])          # input window
        targets.append(source[start + 1:start + T + 1])  # window shifted by one token

    return torch.stack(inputs), torch.stack(targets)

# Draw one random training batch to exercise the pipeline.
xb, yb = get_batch('train')

# Disabled walkthrough of the (context, target) pairs contained in the batch.
# for b in range(B):
#     for t in range(T): # for each of the characters in the sample
#         context = xb[b, :t+1]
#         target = yb[b, t]

# Smoke-test latent extraction. As the cell's last expression, the returned
# (indices, values) tuple is displayed in full.
get_latents(xb)


(tensor([[[ 4914,  1297,  6193,  ...,    91,  1897,  2058],
          [ 4914,  1297,  6193,  ...,  7780,  1382,  2748],
          [ 4914,  1297,  6193,  ...,  8518, 13071,  2712],
          ...,
          [ 1297,  4914, 14637,  ...,  6456,  4910, 14614],
          [ 1297,  4914, 14637,  ..., 11958,  9294, 14772],
          [ 1297,  4914, 14637,  ...,  4921, 13256,  7421]],
 
         [[ 4914,  1297,  6193,  ...,  5191,   154,  7317],
          [ 4914,  1297,  6193,  ..., 11156, 13445, 12563],
          [ 4914,  1297,  6193,  ...,   423,  5465,  1337],
          ...,
          [ 1297,  4914, 14637,  ..., 13413, 10861,  6913],
          [ 1297,  4914, 14637,  ...,  1085, 13614,  2320],
          [ 1297,  4914, 14637,  ...,  8495,  2574,   154]],
 
         [[ 4914,  1297,  6193,  ...,  6723, 11977,  3891],
          [ 1297,  4914,  6193,  ..., 11108,  9328,  1897],
          [ 1297,  4914,  6193,  ..., 10828,  1206, 15124],
          ...,
          [ 1297,  4914, 14637,  ...,  4768,  127

In [10]:
# Make sure the output directory for the encoded latents exists
# (exist_ok=True: no error when it is already there).
os.makedirs(f'{expt_dir}/encoded', exist_ok=True)

def write_encoded_data():
    """Encode the validation split with get_latents and save the result to disk.

    Walks val_data in consecutive, non-overlapping chunks of B*T tokens (B
    windows of length T per chunk), collects the top-k indices and values that
    get_latents returns for each chunk, and writes them — flattened to rows of
    sae_topk entries — to ``{expt_dir}/encoded/test_accum_idxs.pt`` and
    ``{expt_dir}/encoded/test_accum_values.pt``. Trailing tokens that do not
    fill a complete chunk are dropped.

    Raises:
        RuntimeError: if val_data is too short to form a single chunk.
    """
    with torch.no_grad():
        validation_tokens = val_data.shape[0]
        tokens_per_batch = B * T

        # get_batch_by_index also slices the one-token-shifted target windows,
        # so the last window needs one token after it. Subtracting 1 only
        # changes the count when validation_tokens is an exact multiple of
        # tokens_per_batch — the case where the final target slice would come
        # up short and torch.stack would fail.
        num_batches = (validation_tokens - 1) // tokens_per_batch
        if num_batches <= 0:
            raise RuntimeError(
                f'validation split has {validation_tokens} tokens; need more than '
                f'{tokens_per_batch} (B*T) to encode at least one batch'
            )

        # NOTE(review): results are accumulated on whatever device get_latents
        # returns them on; for large splits consider moving them off the GPU.
        accum_idxs = []
        accum_values = []

        for i in tqdm.tqdm(range(num_batches), desc='encoding validation data'):
            start = tokens_per_batch * i
            end = start + tokens_per_batch

            # One start offset per window: B offsets spaced T tokens apart.
            index = torch.arange(start, end, T)
            # Any split name other than 'train' selects val_data; targets are unused.
            x, _ = get_batch_by_index('test', index)
            sparse_idxs, sparse_values = get_latents(x)
            accum_idxs.append(sparse_idxs)
            accum_values.append(sparse_values)

        cat_idxs = torch.cat(accum_idxs)
        cat_values = torch.cat(accum_values)
        # Flatten to one row of sae_topk entries per token position.
        torch.save(cat_idxs.view(-1, sae_topk), f'{expt_dir}/encoded/test_accum_idxs.pt')
        torch.save(cat_values.view(-1, sae_topk), f'{expt_dir}/encoded/test_accum_values.pt')
# Run the encoding pass over the validation split and write the .pt files.
write_encoded_data()
        

encoding validation data: 100%|██████████| 1428/1428 [01:10<00:00, 20.12it/s]


In [None]:
# Reload the encoded latents from where write_encoded_data saved them. The
# files live under expt_dir ('experiments/<expt_name>'), not directly under a
# directory named expt_name.
idxs = torch.load(f'{expt_dir}/encoded/test_accum_idxs.pt')
values = torch.load(f'{expt_dir}/encoded/test_accum_values.pt')

In [None]:
# Peek at the first stored latent (its index and its value), then report the
# overall shape of the index table.
print(idxs[0, 0])
print(values[0, 0])
print(idxs.shape)

tensor(2733, dtype=torch.int16)
tensor(16.8125, dtype=torch.float16)
torch.Size([46792704, 20])


In [None]:
# Byte-level BPE tokenizer built from the vocab and merges files in the
# working directory.
# NOTE(review): presumably the same tokenizer that produced the token ids in
# tiny-stories-train.pt — confirm.
tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./tiny-stories-bpe-vocab.json", 
    "./tiny-stories-bpe-merges.txt"
)
def encode(text):
    """Tokenize `text` with the module-level tokenizer and return its token ids."""
    encoding = tokenizer.encode(text)
    return encoding.ids
def decode(encoded_text):
    """Convert token ids back into text using the module-level tokenizer."""
    text = tokenizer.decode(encoded_text)
    return text

def get_text_from_global_index(token_idx, context_size=10):
    """Return the decoded token at `token_idx` in val_data plus nearby context.

    token_idx: position in val_data; a negative value counts from the end.
    context_size: the context spans up to `context_size` tokens before the
        position, the token itself, and `context_size - 1` tokens after it.
        The window is cut short at either end of val_data.

    Returns a tuple (token_text, context_text).
    """
    # Normalize a negative position so the window arithmetic below stays valid.
    if token_idx < 0:
        token_idx += len(val_data)
    token = val_data[token_idx].item()
    # Clamp the start: a negative slice start would wrap around to the end of
    # val_data and produce an empty or unrelated context.
    start = max(0, token_idx - context_size)
    end = token_idx + context_size
    return decode([token]), decode(val_data[start:end].tolist())

# Example lookup: the token at position 24,000,000 of the validation split.
print(get_text_from_global_index(24_000_000))


(' small', ' had lots of green leaves. One day, a small seed fell from the tree and landed on the')
