In [1]:
if not 'RAN_PIP' in locals():
    !pip install tokenizers
    RAN_PIP = True



[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import torch
import tokenizers
import llm
import os
import sae
import tqdm
import json



# Experiment whose config / weights this notebook loads (alternative: 'e2e_sae_1').
expt_name = 'vanilla_split_llm_sae'
expt_dir = 'experiments/' + expt_name

def loadconfig():
    """Load `{expt_dir}/config.json` and publish its contents as module globals.

    Sets the global `config` to the parsed JSON object and also creates one global
    per top-level key (e.g. a key "B" becomes the global `B`), which later cells
    read directly via globals().
    """
    global config
    # `with` closes the handle promptly; json.load(open(...)) left it to the GC.
    with open(f"{expt_dir}/config.json") as fh:
        config = json.load(fh)
    for k, v in config.items():
        globals()[k] = v

# Load config.json for the selected experiment; every key becomes a module global
# (later cells read names such as B, T, C, sae_size, sae_topk, sae_location from
# globals() — presumably all supplied by this config).
loadconfig()



In [3]:
# Load the token-id stream straight onto the GPU, then split it into contiguous
# train (first 90%) and validation (last 10%) segments.
data = torch.load('tiny-stories-train.pt', map_location='cuda')
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]


In [4]:
def get_batch_by_index(split, ix):
    """Build (inputs, targets) windows of length T starting at each offset in `ix`.

    split: 'train' reads train_data; any other value reads val_data.
    ix: iterable of start offsets into that split.
    Returns two stacked tensors; the targets are the input windows shifted one
    token to the right. Every window must fit inside the split, otherwise
    torch.stack raises on the ragged rows.
    """
    source = val_data if split != 'train' else train_data
    inputs, targets = [], []
    for start in ix:
        inputs.append(source[start:start + T])
        targets.append(source[start + 1:start + T + 1])
    return torch.stack(inputs), torch.stack(targets)

In [5]:
import random

# Constructor kwargs for the LLM, read from module globals (presumably set by loadconfig).
llm_args = ['B', 'T', 'C', 'n_heads', 'H', 'n_layers', 'vocab_size']
llm_kwargs = {k: globals()[k] for k in llm_args}

# Smallest integer dtype that can hold every latent index. Assumes the SAE's latents
# are indexed 0..sae_size-1 (TODO confirm against sae.TopKSparseAutoencoder). An
# unconditional int16 cast would silently wrap indices >= 32768 for larger dictionaries.
latent_idx_dtype = (
    torch.int16 if sae_size - 1 <= torch.iinfo(torch.int16).max else torch.int32
)

autoencoder = sae.TopKSparseAutoencoder(C, sae_size, sae_topk)
if config.get("separate_llm", False):
    print("Loading separate LLM and SAE")
    gpt = llm.GPT(**llm_kwargs)
    autoencoder.load_state_dict(torch.load(f'{expt_dir}/sae.pt'))

    def get_latents(tokens):
        """Run the LLM with stop_at_layer=sae_location, then encode its residuals with the SAE.

        Returns (sparse_idxs, sparse_values): the SAE's top-k latent indices and values,
        cast to the same compact dtypes as the e2e branch so the encoded files have one
        format whichever branch is loaded. On ~1% of calls it also prints the SAE's
        `mean_r2` diagnostic as a spot check.
        """
        llm_out = gpt.forward(tokens, targets=None, stop_at_layer=sae_location)
        residuals = llm_out['residuals']
        sae_out = autoencoder(residuals, return_r2=True)
        if random.random() < 0.01:
            print("r2", sae_out['mean_r2'])
        sparse_idxs = sae_out['topk_idxs'].to(latent_idx_dtype)
        sparse_values = sae_out['topk_values'].to(torch.float16)
        return sparse_idxs, sparse_values

else:
    print("Loading e2e LLM and SAE")
    gpt = llm.BottleNeckGPT(
        bottleneck_model=autoencoder,
        bottleneck_location=sae_location,
        **llm_kwargs
    )

    def get_latents(tokens):
        """Run the bottleneck GPT with bottleneck_early_stop=True; return its top-k latent indices and values."""
        ret = gpt(tokens, targets=None, bottleneck_early_stop=True)
        sparse_idxs = ret['bm_results']['topk_idxs'].to(latent_idx_dtype)
        sparse_values = ret['bm_results']['topk_values'].to(torch.float16)
        return sparse_idxs, sparse_values

# This notebook only runs inference: put both modules in eval mode so any train-only
# behaviour (e.g. dropout, if the models use it) is disabled while encoding.
gpt.eval()
autoencoder.eval()
# Kept as the cell's last expression so its key-matching report is displayed.
gpt.load_state_dict(torch.load(f'{expt_dir}/gpt.pt'))

Loading separate LLM and SAE


<All keys matched successfully>

In [6]:
def get_batch(split):
    """Sample B random windows of T tokens from a split.

    split: 'train' samples from train_data; any other value samples from val_data.
    Returns (x, y) where y is x shifted one token to the right (next-token targets).
    Window construction is delegated to get_batch_by_index so both batch builders
    share a single implementation.
    """
    data = train_data if split == 'train' else val_data
    # randint's upper bound is exclusive, so the largest start is len - T - 1 and the
    # target window (which reaches one token further than the input) still fits.
    ix = torch.randint(0, data.size(0) - T, (B,))
    return get_batch_by_index(split, ix)

# Sanity check: encode one random training batch and display its sparse latents.
xb, yb = get_batch('train')
# Inspection only, so don't build an autograd graph for this forward pass.
with torch.no_grad():
    sample_latents = get_latents(xb)
sample_latents


(tensor([[[12803, 14861,  7965,  ..., 12070, 10607, 13137],
          [12803, 14861,  7965,  ..., 12070,  9251, 10855],
          [12803, 14861,  7965,  ...,  5734, 10855, 10752],
          ...,
          [12803,  2048,  7965,  ...,  5131,  8777,  7189],
          [12803,  2048,  7965,  ...,  3655,  3620, 12070],
          [12803,  2048,  7965,  ...,  3090,  9700,  8689]],
 
         [[12803, 14861,  7965,  ...,  8750, 13213, 10613],
          [12803, 14861,  7965,  ...,  5315, 13677, 14508],
          [12803, 14861,  7965,  ...,  2336, 12013, 12808],
          ...,
          [12803,  2048,  7965,  ..., 10980, 10795, 11331],
          [12803,  2048,  7965,  ...,  9793, 12242,  8454],
          [12803,  2048,  7965,  ...,  8714, 16085, 15934]],
 
         [[12803, 14861,  7965,  ...,  5206, 14217, 11598],
          [12803, 14861,  7965,  ..., 12288,  4188, 10072],
          [12803, 14861,  7965,  ...,  2583, 13887,  5734],
          ...,
          [12803,  2048,  7965,  ...,  6672, 1229

In [8]:
def write_encoded_data():
    """Encode the whole validation split with `get_latents` and save the top-k codes.

    val_data is cut into consecutive, non-overlapping windows of T tokens, B windows
    per batch, so batch i covers tokens [i*B*T, (i+1)*B*T). The per-batch indices and
    values returned by `get_latents` are concatenated, reshaped to (-1, sae_topk) and
    saved as CPU tensors to

        {expt_dir}/encoded/test_accum_idxs.pt
        {expt_dir}/encoded/test_accum_values.pt

    Assuming `get_latents` returns (B, T, sae_topk) tensors (TODO confirm), row r of
    each file belongs to validation token r.

    Raises:
        ValueError: if val_data is too short to fill a single batch.
    """
    out_dir = f'{expt_dir}/encoded'
    os.makedirs(out_dir, exist_ok=True)

    tokens_per_batch = B * T
    # get_batch_by_index also builds a target window reaching one token past each
    # input window, so the last batch must end strictly before len(val_data); with
    # `len // tokens_per_batch` an exact multiple would make torch.stack fail.
    num_batches = (val_data.shape[0] - 1) // tokens_per_batch
    if num_batches <= 0:
        raise ValueError(
            f'val_data has {val_data.shape[0]} tokens; need more than B*T={tokens_per_batch}'
        )

    accum_idxs = []
    accum_values = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(num_batches), desc='encoding validation data'):
            start = tokens_per_batch * i
            end = start + tokens_per_batch
            # B window starts spaced T apart, so the windows tile [start, end) exactly.
            index = torch.arange(start, end, T)
            # Any split name other than 'train' selects val_data.
            x, _ = get_batch_by_index('test', index)
            sparse_idxs, sparse_values = get_latents(x)
            # Offload each batch so the full encoding (plus the torch.cat copy
            # below) never has to sit in GPU memory at once.
            accum_idxs.append(sparse_idxs.cpu())
            accum_values.append(sparse_values.cpu())

    cat_idxs = torch.cat(accum_idxs)
    cat_values = torch.cat(accum_values)
    torch.save(cat_idxs.view(-1, sae_topk), f'{out_dir}/test_accum_idxs.pt')
    torch.save(cat_values.view(-1, sae_topk), f'{out_dir}/test_accum_values.pt')
# Run the encoding pass over the validation split; writes the test_accum_*.pt files
# under {expt_dir}/encoded/.
write_encoded_data()
        

encoding validation data:   1%|          | 15/1428 [00:00<01:08, 20.57it/s]

r2 tensor(0.8955, device='cuda:0')


encoding validation data:   1%|▏         | 21/1428 [00:01<01:07, 20.78it/s]

r2 tensor(0.8966, device='cuda:0')
r2 tensor(0.8902, device='cuda:0')


encoding validation data:   3%|▎         | 42/1428 [00:02<01:07, 20.43it/s]

r2 tensor(0.8881, device='cuda:0')


encoding validation data:   4%|▍         | 60/1428 [00:02<01:06, 20.70it/s]

r2 tensor(0.8943, device='cuda:0')


encoding validation data:   5%|▍         | 66/1428 [00:03<01:05, 20.83it/s]

r2 tensor(0.8887, device='cuda:0')


encoding validation data:   5%|▌         | 72/1428 [00:03<01:05, 20.59it/s]

r2 tensor(0.8945, device='cuda:0')


encoding validation data:   6%|▌         | 81/1428 [00:03<01:08, 19.62it/s]

r2 tensor(0.8891, device='cuda:0')
r2 tensor(0.8925, device='cuda:0')


encoding validation data:   6%|▌         | 87/1428 [00:04<01:04, 20.76it/s]

r2 tensor(0.8974, device='cuda:0')


encoding validation data:   7%|▋         | 96/1428 [00:04<01:04, 20.51it/s]

r2 tensor(0.8878, device='cuda:0')


encoding validation data:   7%|▋         | 105/1428 [00:05<01:04, 20.42it/s]

r2 tensor(0.8894, device='cuda:0')


encoding validation data:   8%|▊         | 111/1428 [00:05<01:03, 20.70it/s]

r2 tensor(0.8887, device='cuda:0')


encoding validation data:   9%|▉         | 132/1428 [00:06<01:03, 20.42it/s]

r2 tensor(0.8879, device='cuda:0')


encoding validation data:  11%|█         | 150/1428 [00:07<01:02, 20.40it/s]

r2 tensor(0.8873, device='cuda:0')


encoding validation data:  11%|█         | 159/1428 [00:07<01:02, 20.39it/s]

r2 tensor(0.8900, device='cuda:0')
r2 tensor(0.9000, device='cuda:0')


encoding validation data:  12%|█▏        | 174/1428 [00:08<01:01, 20.37it/s]

r2 tensor(0.8870, device='cuda:0')


encoding validation data:  13%|█▎        | 183/1428 [00:08<01:01, 20.37it/s]

r2 tensor(0.8874, device='cuda:0')
r2 

encoding validation data:  13%|█▎        | 189/1428 [00:09<00:59, 20.67it/s]

tensor(0.8904, device='cuda:0')


encoding validation data:  15%|█▌        | 219/1428 [00:10<00:59, 20.39it/s]

r2 tensor(0.8905, device='cuda:0')


encoding validation data:  16%|█▌        | 225/1428 [00:11<00:58, 20.64it/s]

r2 tensor(0.8877, device='cuda:0')
r2 tensor(0.8897, device='cuda:0')
r2 tensor(0.8986, device='cuda:0')


encoding validation data:  17%|█▋        | 243/1428 [00:11<00:58, 20.42it/s]

r2 tensor(0.8915, device='cuda:0')


encoding validation data:  18%|█▊        | 252/1428 [00:12<00:57, 20.39it/s]

r2 tensor(0.8958, device='cuda:0')


encoding validation data:  19%|█▉        | 276/1428 [00:13<00:56, 20.39it/s]

r2 tensor(0.8872, device='cuda:0')
r2 tensor(0.8916, device='cuda:0')


encoding validation data:  20%|██        | 291/1428 [00:14<00:55, 20.38it/s]

r2 tensor(0.8863, device='cuda:0')


encoding validation data:  21%|██        | 297/1428 [00:14<00:55, 20.37it/s]

r2 tensor(0.8890, device='cuda:0')


encoding validation data:  21%|██▏       | 306/1428 [00:15<00:54, 20.68it/s]

r2 tensor(0.8910, device='cuda:0')


encoding validation data:  22%|██▏       | 312/1428 [00:15<00:53, 20.80it/s]

r2 tensor(0.8908, device='cuda:0')
r2 tensor(0.8885, device='cuda:0')
r2 

encoding validation data:  22%|██▏       | 315/1428 [00:15<00:53, 20.64it/s]

tensor(0.8944, device='cuda:0')
r2 tensor(0.8880, device='cuda:0')


encoding validation data:  23%|██▎       | 327/1428 [00:16<00:53, 20.43it/s]

r2 tensor(0.8961, device='cuda:0')
r2 tensor(0.8903, device='cuda:0')


encoding validation data:  24%|██▎       | 336/1428 [00:16<00:52, 20.67it/s]

r2 tensor(0.8948, device='cuda:0')


encoding validation data:  24%|██▍       | 342/1428 [00:16<00:52, 20.82it/s]

r2 tensor(0.8952, device='cuda:0')


encoding validation data:  25%|██▍       | 351/1428 [00:17<00:52, 20.51it/s]

r2 tensor(0.8866, device='cuda:0')


encoding validation data:  25%|██▌       | 357/1428 [00:17<00:51, 20.75it/s]

r2 tensor(0.8938, device='cuda:0')


encoding validation data:  27%|██▋       | 384/1428 [00:18<00:51, 20.39it/s]

r2 tensor(0.8979, device='cuda:0')
r2 tensor(0.8917, device='cuda:0')


encoding validation data:  28%|██▊       | 405/1428 [00:19<00:50, 20.36it/s]

r2 tensor(0.8908, device='cuda:0')


encoding validation data:  29%|██▉       | 411/1428 [00:20<00:49, 20.67it/s]

r2 tensor(0.8915, device='cuda:0')
r2 

encoding validation data:  29%|██▉       | 414/1428 [00:20<00:49, 20.55it/s]

tensor(0.8987, device='cuda:0')


encoding validation data:  30%|███       | 429/1428 [00:21<00:48, 20.69it/s]

r2 tensor(0.8881, device='cuda:0')


encoding validation data:  30%|███       | 435/1428 [00:21<00:47, 20.82it/s]

r2 tensor(0.8913, device='cuda:0')


encoding validation data:  32%|███▏      | 450/1428 [00:22<00:47, 20.44it/s]

r2 tensor(0.8902, device='cuda:0')


encoding validation data:  33%|███▎      | 468/1428 [00:22<00:46, 20.67it/s]

r2 tensor(0.8954, device='cuda:0')


encoding validation data:  33%|███▎      | 474/1428 [00:23<00:49, 19.39it/s]

r2 tensor(0.8887, device='cuda:0')
r2 tensor(0.8921, device='cuda:0')
r2 

encoding validation data:  34%|███▎      | 480/1428 [00:23<00:48, 19.74it/s]

tensor(0.8903, device='cuda:0')
r2 tensor(0.8947, device='cuda:0')
r2 tensor(0.8890, device='cuda:0')


encoding validation data:  34%|███▍      | 489/1428 [00:24<00:45, 20.65it/s]

r2 tensor(0.8948, device='cuda:0')


encoding validation data:  35%|███▍      | 495/1428 [00:24<00:45, 20.50it/s]

r2 tensor(0.8893, device='cuda:0')


encoding validation data:  35%|███▌      | 504/1428 [00:24<00:44, 20.71it/s]

r2 tensor(0.8922, device='cuda:0')


encoding validation data:  37%|███▋      | 528/1428 [00:25<00:43, 20.70it/s]

r2 tensor(0.8917, device='cuda:0')


encoding validation data:  37%|███▋      | 534/1428 [00:26<00:43, 20.53it/s]

r2 tensor(0.8900, device='cuda:0')


encoding validation data:  38%|███▊      | 543/1428 [00:26<00:45, 19.59it/s]

r2 tensor(0.8864, device='cuda:0')
r2 tensor(0.8890, device='cuda:0')


encoding validation data:  39%|███▉      | 558/1428 [00:27<00:42, 20.49it/s]

r2 tensor(0.8873, device='cuda:0')
r2 tensor(0.8963, device='cuda:0')
r2 

encoding validation data:  39%|███▉      | 564/1428 [00:27<00:42, 20.41it/s]

tensor(0.8947, device='cuda:0')


encoding validation data:  41%|████      | 579/1428 [00:28<00:41, 20.38it/s]

r2 tensor(0.8880, device='cuda:0')


encoding validation data:  41%|████▏     | 591/1428 [00:29<00:41, 20.36it/s]

r2 tensor(0.8957, device='cuda:0')
r2 tensor(0.8924, device='cuda:0')


encoding validation data:  42%|████▏     | 600/1428 [00:29<00:40, 20.35it/s]

r2 tensor(0.8910, device='cuda:0')


encoding validation data:  44%|████▎     | 624/1428 [00:30<00:38, 20.68it/s]

r2 tensor(0.8924, device='cuda:0')


encoding validation data:  47%|████▋     | 669/1428 [00:32<00:37, 20.35it/s]

r2 tensor(0.8927, device='cuda:0')


encoding validation data:  48%|████▊     | 684/1428 [00:33<00:36, 20.37it/s]

r2 tensor(0.8986, device='cuda:0')


encoding validation data:  48%|████▊     | 690/1428 [00:33<00:36, 20.37it/s]

r2 tensor(0.8880, device='cuda:0')
r2 tensor(0.8918, device='cuda:0')


encoding validation data:  50%|████▉     | 708/1428 [00:34<00:35, 20.37it/s]

r2 tensor(0.8944, device='cuda:0')


encoding validation data:  50%|█████     | 720/1428 [00:35<00:34, 20.36it/s]

r2 tensor(0.8905, device='cuda:0')


encoding validation data:  51%|█████▏    | 735/1428 [00:36<00:34, 20.37it/s]

r2 tensor(0.8966, device='cuda:0')


encoding validation data:  52%|█████▏    | 741/1428 [00:36<00:33, 20.36it/s]

r2 tensor(0.8921, device='cuda:0')


encoding validation data:  53%|█████▎    | 759/1428 [00:37<00:32, 20.37it/s]

r2 tensor(0.8900, device='cuda:0')


encoding validation data:  54%|█████▎    | 765/1428 [00:37<00:32, 20.67it/s]

r2 tensor(0.8933, device='cuda:0')


encoding validation data:  54%|█████▍    | 771/1428 [00:37<00:32, 20.50it/s]

r2 tensor(0.8956, device='cuda:0')


encoding validation data:  55%|█████▌    | 786/1428 [00:38<00:31, 20.68it/s]

r2 tensor(0.8910, device='cuda:0')
r2 tensor(0.8906, device='cuda:0')


encoding validation data:  56%|█████▌    | 795/1428 [00:39<00:30, 20.76it/s]

r2 tensor(0.8918, device='cuda:0')
r2 tensor(0.8910, device='cuda:0')


encoding validation data:  56%|█████▋    | 804/1428 [00:39<00:30, 20.50it/s]

r2 tensor(0.8920, device='cuda:0')


encoding validation data:  57%|█████▋    | 813/1428 [00:39<00:29, 20.69it/s]

r2 tensor(0.8951, device='cuda:0')
r2 tensor(0.8893, device='cuda:0')


encoding validation data:  58%|█████▊    | 822/1428 [00:40<00:29, 20.77it/s]

r2 tensor(0.8990, device='cuda:0')


encoding validation data:  58%|█████▊    | 828/1428 [00:40<00:29, 20.51it/s]

r2 tensor(0.8983, device='cuda:0')


encoding validation data:  59%|█████▉    | 840/1428 [00:41<00:28, 20.67it/s]

r2 tensor(0.8876, device='cuda:0')
r2 tensor(0.8899, device='cuda:0')


encoding validation data:  60%|██████    | 861/1428 [00:42<00:29, 19.26it/s]

r2 tensor(0.8898, device='cuda:0')
r2 tensor(0.8969, device='cuda:0')


encoding validation data:  62%|██████▏   | 879/1428 [00:43<00:26, 20.70it/s]

r2 tensor(0.8934, device='cuda:0')


encoding validation data:  63%|██████▎   | 900/1428 [00:44<00:25, 20.36it/s]

r2 tensor(0.8924, device='cuda:0')
r2 

encoding validation data:  63%|██████▎   | 906/1428 [00:44<00:25, 20.32it/s]

tensor(0.8883, device='cuda:0')
r2 tensor(0.8954, device='cuda:0')


encoding validation data:  64%|██████▎   | 909/1428 [00:44<00:27, 19.19it/s]

r2 tensor(0.8900, device='cuda:0')
r2 tensor(0.8905, device='cuda:0')


encoding validation data:  64%|██████▍   | 918/1428 [00:45<00:24, 20.46it/s]

r2 tensor(0.8922, device='cuda:0')


encoding validation data:  64%|██████▍   | 921/1428 [00:45<00:24, 20.41it/s]

r2 tensor(0.8992, device='cuda:0')


encoding validation data:  65%|██████▍   | 927/1428 [00:45<00:25, 19.28it/s]

r2 tensor(0.8916, device='cuda:0')
r2 tensor(0.8983, device='cuda:0')


encoding validation data:  68%|██████▊   | 969/1428 [00:47<00:22, 20.64it/s]

r2 tensor(0.8896, device='cuda:0')


encoding validation data:  68%|██████▊   | 975/1428 [00:47<00:22, 20.48it/s]

r2 tensor(0.8879, device='cuda:0')


encoding validation data:  69%|██████▉   | 987/1428 [00:48<00:21, 20.67it/s]

r2 tensor(0.8956, device='cuda:0')


encoding validation data:  70%|██████▉   | 993/1428 [00:48<00:21, 20.48it/s]

r2 tensor(0.8900, device='cuda:0')


encoding validation data:  70%|██████▉   | 996/1428 [00:48<00:21, 20.43it/s]

r2 tensor(0.8991, device='cuda:0')


encoding validation data:  72%|███████▏  | 1035/1428 [00:50<00:19, 20.34it/s]

r2 tensor(0.8907, device='cuda:0')


encoding validation data:  74%|███████▎  | 1050/1428 [00:51<00:18, 20.34it/s]

r2 tensor(0.8890, device='cuda:0')


encoding validation data:  75%|███████▍  | 1065/1428 [00:52<00:17, 20.59it/s]

r2 tensor(0.8882, device='cuda:0')


encoding validation data:  75%|███████▌  | 1077/1428 [00:52<00:17, 19.54it/s]

r2 tensor(0.8934, device='cuda:0')
r2 tensor(0.8890, device='cuda:0')


encoding validation data:  76%|███████▌  | 1082/1428 [00:53<00:16, 21.04it/s]

r2 tensor(0.8919, device='cuda:0')


encoding validation data:  79%|███████▉  | 1127/1428 [00:55<00:14, 20.62it/s]

r2 tensor(0.8906, device='cuda:0')


encoding validation data:  80%|███████▉  | 1142/1428 [00:56<00:14, 20.37it/s]

r2 tensor(0.8942, device='cuda:0')
r2 

encoding validation data:  80%|████████  | 1148/1428 [00:56<00:13, 20.59it/s]

tensor(0.8877, device='cuda:0')


encoding validation data:  81%|████████  | 1157/1428 [00:56<00:13, 20.38it/s]

r2 tensor(0.8912, device='cuda:0')


encoding validation data:  82%|████████▏ | 1166/1428 [00:57<00:12, 20.33it/s]

r2 tensor(0.8980, device='cuda:0')


encoding validation data:  82%|████████▏ | 1178/1428 [00:57<00:12, 20.28it/s]

r2 tensor(0.8927, device='cuda:0')
r2 tensor(0.8929, device='cuda:0')


encoding validation data:  83%|████████▎ | 1184/1428 [00:58<00:12, 20.30it/s]

r2 tensor(0.8944, device='cuda:0')


encoding validation data:  85%|████████▍ | 1208/1428 [00:59<00:11, 19.50it/s]

r2 tensor(0.8980, device='cuda:0')
r2 tensor(0.8889, device='cuda:0')


encoding validation data:  85%|████████▌ | 1217/1428 [00:59<00:10, 20.43it/s]

r2 tensor(0.8884, device='cuda:0')


encoding validation data:  86%|████████▌ | 1226/1428 [01:00<00:09, 20.34it/s]

r2 tensor(0.8911, device='cuda:0')


encoding validation data:  87%|████████▋ | 1238/1428 [01:00<00:09, 20.63it/s]

r2 tensor(0.8891, device='cuda:0')


encoding validation data:  87%|████████▋ | 1244/1428 [01:01<00:08, 20.75it/s]

r2 tensor(0.8903, device='cuda:0')


encoding validation data:  88%|████████▊ | 1253/1428 [01:01<00:08, 19.61it/s]

r2 tensor(0.8931, device='cuda:0')
r2 tensor(0.8905, device='cuda:0')


encoding validation data:  89%|████████▉ | 1271/1428 [01:02<00:07, 20.74it/s]

r2 tensor(0.8879, device='cuda:0')


encoding validation data:  91%|█████████ | 1295/1428 [01:03<00:06, 20.35it/s]

r2 tensor(0.8874, device='cuda:0')


encoding validation data:  93%|█████████▎| 1328/1428 [01:05<00:04, 20.31it/s]

r2 tensor(0.8879, device='cuda:0')


encoding validation data:  93%|█████████▎| 1334/1428 [01:05<00:04, 20.30it/s]

r2 tensor(0.8910, device='cuda:0')


encoding validation data:  94%|█████████▍| 1340/1428 [01:05<00:04, 20.30it/s]

r2 tensor(0.8907, device='cuda:0')
r2 tensor(0.8897, device='cuda:0')


encoding validation data:  94%|█████████▍| 1346/1428 [01:06<00:04, 19.21it/s]

r2 tensor(0.8853, device='cuda:0')
r2 tensor(0.8941, device='cuda:0')


encoding validation data:  95%|█████████▍| 1352/1428 [01:06<00:03, 20.79it/s]

r2 tensor(0.8952, device='cuda:0')


encoding validation data:  95%|█████████▌| 1361/1428 [01:06<00:03, 20.48it/s]

r2 tensor(0.8917, device='cuda:0')


encoding validation data:  96%|█████████▌| 1364/1428 [01:07<00:03, 20.42it/s]

r2 tensor(0.8894, device='cuda:0')


encoding validation data:  97%|█████████▋| 1388/1428 [01:08<00:01, 20.31it/s]

r2 tensor(0.8897, device='cuda:0')


encoding validation data:  98%|█████████▊| 1400/1428 [01:08<00:01, 19.18it/s]

r2 tensor(0.8913, device='cuda:0')
r2 tensor(0.8989, device='cuda:0')


encoding validation data: 100%|██████████| 1428/1428 [01:10<00:00, 20.32it/s]


r2 tensor(0.8869, device='cuda:0')


In [None]:
# Reload the encoded validation latents written by write_encoded_data. They live under
# expt_dir (experiments/<expt_name>/encoded/), not a top-level <expt_name>/ directory.
# map_location='cpu' keeps these large (tokens x top-k) tensors off the GPU and lets
# this cell run on a machine without one.
idxs = torch.load(f'{expt_dir}/encoded/test_accum_idxs.pt', map_location='cpu')
values = torch.load(f'{expt_dir}/encoded/test_accum_values.pt', map_location='cpu')

In [None]:
# Spot-check the first token's first latent (index and value) and the overall layout.
print(idxs[0, 0])
print(values[0, 0])
print(idxs.shape)

tensor(2733, dtype=torch.int16)
tensor(16.8125, dtype=torch.float16)
torch.Size([46792704, 20])


In [None]:
# Byte-level BPE tokenizer rebuilt from the TinyStories vocab and merges files.
bpe_vocab_path = "./tiny-stories-bpe-vocab.json"
bpe_merges_path = "./tiny-stories-bpe-merges.txt"
tokenizer = tokenizers.ByteLevelBPETokenizer(bpe_vocab_path, bpe_merges_path)
def encode(text):
    """Tokenize `text` and return the list of its token ids."""
    encoding = tokenizer.encode(text)
    return encoding.ids


def decode(encoded_text):
    """Turn a sequence of token ids back into a string."""
    text = tokenizer.decode(encoded_text)
    return text

def get_text_from_global_index(token_idx, context_size=10):
    """Decode the token at position `token_idx` of val_data together with nearby text.

    Args:
        token_idx: absolute position in val_data.
        context_size: number of tokens shown before the target. The context window is
            val_data[token_idx - context_size : token_idx + context_size], clamped at
            the start of val_data, so it also contains the target and the
            context_size - 1 tokens after it.

    Returns:
        (token_text, context_text) as decoded strings.

    Raises:
        IndexError: if token_idx is outside [0, len(val_data)).
    """
    if not 0 <= token_idx < len(val_data):
        raise IndexError(
            f'token_idx {token_idx} out of range for val_data of length {len(val_data)}'
        )
    token = val_data[token_idx].item()
    # Clamp at 0: a negative slice start would wrap to the end of val_data and yield
    # the wrong (usually empty) context for tokens near the beginning.
    start = max(0, token_idx - context_size)
    context = val_data[start:token_idx + context_size].tolist()
    return decode([token]), decode(context)

# Decode one sample validation token (position 24M) with its surrounding context.
print(get_text_from_global_index(24_000_000))


(' small', ' had lots of green leaves. One day, a small seed fell from the tree and landed on the')
