In [51]:
import torch
import tokenizers
import bottleneck_llm
import os
import sae
import tqdm
# Record the commit this run used. The context manager closes the pipe
# deterministically; the hash is "" when git is unavailable or this is not a repo.
with os.popen("git rev-parse HEAD") as _git_pipe:
    _git_hash = _git_pipe.read().strip()

config = {
    "learning_rate": 2e-3,
    "sae_learning_rate": 5e-5,
    "model_embedding_layer": 6,
    "eval_interval": 500,
    "max_iters": 60000,
    "H": 32,  # hidden dimension size (equals C // n_heads here; presumably per-head size — confirm)
    "B": 128,  # sequences per batch
    "T": 256,  # tokens per sequence (context length)
    "C": 256,  # model width; passed as the first argument to the SAE
    "feedforward_factor": 3,
    "n_heads": 8,
    "n_layers": 12,
    "sae_size": 2**14,  # number of SAE features
    "sae_location": 6,
    "sae_topk": 20,  # features kept active per token
    "sae_r2_lambda": 2,
    "sae_mse_lambda": 0,

    "vocab_size": 2**13,
    'expt_name': 'e2e_sae_1',
    "git_hash": _git_hash,
}
del _git_pipe, _git_hash

# Expose every config entry as a module-level name (T, B, expt_name, ...).
# At notebook top level locals() happens to be globals(); update globals()
# explicitly instead of relying on that implementation detail.
globals().update(config)


In [2]:
# Load the whole token stream onto the GPU; the first 90% is training data,
# the remaining 10% is held out for validation.
data = torch.load('tiny-stories-train.pt', map_location='cuda')
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]


In [3]:
def get_batch_by_index(split, ix):
    """Gather input/target windows that start at caller-chosen offsets.

    Args:
        split: 'train' selects train_data; any other value selects val_data.
        ix: sequence (list or tensor) of integer start offsets into the split.

    Returns:
        (x, y): two (len(ix), T) tensors on the split's device with
        x[r] = data[ix[r] : ix[r] + T] and y[r] = data[ix[r] + 1 : ix[r] + T + 1],
        i.e. y is x shifted one token ahead (next-token targets).

    Raises:
        IndexError: if any requested window would run past either end of the split.
    """
    data = train_data if split == 'train' else val_data
    starts = torch.as_tensor(ix, dtype=torch.long, device=data.device).reshape(-1, 1)
    if starts.numel() and (starts.min() < 0 or starts.max() + T + 1 > len(data)):
        # Check up front: an out-of-range gather on CUDA is a device-side assert
        # that poisons the CUDA context instead of raising a catchable error.
        raise IndexError(
            f'window start out of range for a split of length {len(data)}')
    offsets = torch.arange(T, device=data.device)
    # One vectorized gather per tensor instead of a Python loop of slices + stack;
    # both results are freshly allocated, contiguous tensors.
    x = data[starts + offsets]
    y = data[starts + offsets + 1]
    return x, y

In [4]:
# Top-k SAE that is spliced into the transformer as its bottleneck.
autoencoder = sae.TopKSparseAutoencoder(C, sae_size, sae_topk)
b = bottleneck_llm.BottleNeckGPT(
    B=B,
    T=T,
    C=C,
    n_heads=n_heads,
    H=H,
    n_layers=n_layers,
    bottleneck_model=autoencoder,
    bottleneck_location=sae_location,
    vocab_size=vocab_size,
)
# NOTE(review): feedforward_factor from config is not passed here — presumably
# BottleNeckGPT uses its own default; confirm it matches the trained checkpoint.

# Restore the jointly trained weights. The path is derived from expt_name so it
# stays consistent with the other f'{expt_name}/...' paths in this notebook.
b.load_state_dict(torch.load(f"{expt_name}/joint_composed.pt"))

<All keys matched successfully>

In [27]:
batches_per_chunk = 1000
tokens_per_batch = T*B


def write_encoded_data(n_chunks=10):
    """Run a prefix of train_data through the model and save the SAE's top-k codes.

    Chunk ``c`` covers ``batches_per_chunk`` consecutive batches. Each batch is B
    back-to-back length-T windows, i.e. ``tokens_per_batch`` contiguous tokens, so
    the chunk spans train_data[c * batches_per_chunk * tokens_per_batch : ...].
    For every chunk, two files are written under ``{expt_name}/encoded/``:
    ``accum_idxs-NNN.pt`` (``topk_idxs`` as int16) and ``accum_values-NNN.pt``
    (``topk_values`` as float16), each concatenated along the batch dimension.

    Args:
        n_chunks: number of chunks to encode (default 10, as originally hard-coded).

    Raises:
        ValueError: if sae_size is too large for feature ids to be stored as int16.
    """
    # Feature ids are narrowed to int16 below. Assuming ids lie in [0, sae_size),
    # that is lossless only while sae_size <= 2**15.
    if sae_size > torch.iinfo(torch.int16).max + 1:
        raise ValueError(
            f'sae_size={sae_size} does not fit in int16 feature indices')
    out_dir = f'{expt_name}/encoded'
    os.makedirs(out_dir, exist_ok=True)  # torch.save does not create parent dirs

    with torch.no_grad():
        b.eval()
        b.bottleneck_model.eval()

        for chunk in range(n_chunks):
            print('on chunk', chunk)
            accum_idxs = []
            accum_values = []

            chunk_start = chunk * batches_per_chunk * tokens_per_batch
            for i in tqdm.tqdm(range(batches_per_chunk), desc='batch'):
                start = chunk_start + i * tokens_per_batch
                # B window starts spaced T apart: the batch covers
                # [start, start + tokens_per_batch) with no gaps or overlap.
                index = torch.arange(start, start + tokens_per_batch, T)
                x, _ = get_batch_by_index('train', index)
                ret = b.forward(x, targets=None, bottleneck_early_stop=True)
                bm_results = ret['bm_results']
                accum_idxs.append(bm_results['topk_idxs'].to(torch.int16).to('cpu'))
                accum_values.append(bm_results['topk_values'].to(torch.float16).to('cpu'))

            torch.save(torch.cat(accum_idxs), f'{out_dir}/accum_idxs-{chunk:03d}.pt')
            torch.save(torch.cat(accum_values), f'{out_dir}/accum_values-{chunk:03d}.pt')


# Deliberately disabled: the outputs below come from an earlier run (about
# 10 x 52 s). Uncomment only to regenerate the encoded files.
# write_encoded_data()
        

on chunk 0


batch: 100%|██████████| 1000/1000 [00:52<00:00, 19.17it/s]


on chunk 1


batch: 100%|██████████| 1000/1000 [00:52<00:00, 19.00it/s]


on chunk 2


batch: 100%|██████████| 1000/1000 [00:53<00:00, 18.69it/s]


on chunk 3


batch: 100%|██████████| 1000/1000 [00:53<00:00, 18.84it/s]


on chunk 4


batch: 100%|██████████| 1000/1000 [00:53<00:00, 18.70it/s]


on chunk 5


batch: 100%|██████████| 1000/1000 [00:52<00:00, 19.06it/s]


on chunk 6


batch: 100%|██████████| 1000/1000 [00:53<00:00, 18.72it/s]


on chunk 7


batch: 100%|██████████| 1000/1000 [00:52<00:00, 19.08it/s]


on chunk 8


batch: 100%|██████████| 1000/1000 [00:51<00:00, 19.24it/s]


on chunk 9


batch: 100%|██████████| 1000/1000 [00:52<00:00, 19.08it/s]


In [31]:
# Load chunk 000 of the saved SAE codes. Both files hold plain tensors, so
# weights_only=True avoids unpickling arbitrary objects.
idxs = torch.load(f'{expt_name}/encoded/accum_idxs-000.pt', weights_only=True)
values = torch.load(f'{expt_name}/encoded/accum_values-000.pt', weights_only=True)
# Every stored feature index must pair with exactly one activation value.
assert idxs.shape == values.shape, (idxs.shape, values.shape)

In [36]:
# Feature ids and activations for the first token of the first sequence, then the layout.
print(idxs[0, 0], values[0, 0], idxs.shape, sep='\n')

tensor([ 2733, 10675,  4032,  5123, 13338, 14089, 11776, 14339, 10797,  5472,
           17, 14080,  2563,  6662,  5900, 14191,  7483,  6661,  8251,  2184],
       dtype=torch.int16)
tensor([19.2344,  6.7383,  6.6094,  5.8867,  5.8672,  5.4062,  4.7773,  3.0020,
         2.9258,  2.6289,  2.0625,  1.5215,  0.7817,  0.7651,  0.4709,  0.4026,
         0.3823,  0.3687,  0.3574,  0.1311], dtype=torch.float16)
torch.Size([128000, 256, 20])


In [55]:
def global_token_index(chunk_no, batch_no, token_no):
    """Flatten (chunk, batch-within-chunk, token-within-batch) into one token offset.

    Uses the same layout as write_encoded_data: a chunk is batches_per_chunk
    batches and a batch is tokens_per_batch tokens. Arguments are not range
    checked, so a batch_no past batches_per_chunk simply carries into later chunks.
    """
    absolute_batch = chunk_no * batches_per_chunk + batch_no
    return absolute_batch * tokens_per_batch + token_no

# Spot-check the flattening on a few hand-picked (chunk, batch, token) triples.
print('(2, 321, 5321) -> ', global_token_index(2, 321, 5321))
sample_triples = [(0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 0, 32767), (0, 999, 0), (9, 0, 0)]
for arg in sample_triples:
    print(f'({arg[0]:5d}, {arg[1]:5d}, {arg[2]:5d}) {str(global_token_index(*arg)).rjust(20)}')


(2, 321, 5321) ->  76059849
(    0,     0,     1)                    1
(    0,     1,     0)                32768
(    1,     0,     0)             32768000
(    0,     0, 32767)                32767
(    0,   999,     0)             32735232
(    9,     0,     0)            294912000


In [61]:
# Byte-level BPE tokenizer loaded from the vocab/merges files next to the notebook.
tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./tiny-stories-bpe-vocab.json",
    "./tiny-stories-bpe-merges.txt",
)


def encode(text):
    """Tokenize ``text`` and return its list of token ids."""
    encoding = tokenizer.encode(text)
    return encoding.ids


def decode(encoded_text):
    """Convert a sequence of token ids back into a string."""
    text = tokenizer.decode(encoded_text)
    return text

def get_text_from_global_index(chunk_no, batch_no, token_no, context_size=10):
    """Decode the training text surrounding one encoded token position.

    With idx = global_token_index(chunk_no, batch_no, token_no), returns the
    decoded text of train_data[idx - context_size : idx + context_size]: up to
    context_size tokens before idx, then the token at idx and the
    context_size - 1 tokens after it. The window is clipped at the start of
    train_data; Python slicing already clips it at the end.
    """
    idx = global_token_index(chunk_no, batch_no, token_no)
    # Clamp the start: a negative slice start would wrap around to the end of
    # train_data and return empty or unrelated tokens near the beginning.
    start = max(idx - context_size, 0)
    return decode(train_data[start:idx + context_size].tolist())

# Same token as (2, 3213, 40): a batch_no past batches_per_chunk (1000) carries
# into later chunks, so use the equivalent in-range coordinates.
print(get_text_from_global_index(5, 213, 40))


 mommy and said, "Mommy, look at the big balloon! It makes me so happy!"

