In [31]:
from gpt import load_model_from_checkpoint, GPT, GPTConfig, generate, preprocess_tokens_from_huggingface
import transformers
import torch
from dataclasses import dataclass, fields
from tqdm import tqdm
import os


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"
torch.autograd.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7567acb10640>

In [32]:
model = load_model_from_checkpoint("./llms/50m_llm.pt")

In [33]:
enc = transformers.AutoTokenizer.from_pretrained('activated-ai/tiny-stories-8k-tokenizer')


In [34]:
generate(model, enc, "Once upon a time", 100, 2)


["Once upon a time, there was a big bird named Max. Max was very hungry and wanted to eat something yummy. So, he decided to go on an adventure. He flew high in the sky and saw many places. \n\nAs he was flying, he saw a big tree with lots of fruits. Max knew that he could eat the fruits if he was very careful. He didn't want to be scared or make any threats. Max flew down and ate some grapes. \n\nAfter",
 'Once upon a time there was a little girl named Sarah. She had ten long fingers which made her feel so much love.\n\nOne day, Sarah went out with her mommy to the park. While they were there, Sarah saw something that made her eyes grow: a big, soft, pink teddy bear. She knew she had to have it.\n\nSo Sarah asked her mommy if she could have the teddy bear. Her mommy said yes, so Sarah was so happy! She hugged the']

In [35]:
preprocess_tokens_from_huggingface("./datasets")

skipping token preprocessing for validation : using cache ./datasets/validation.pt
skipping token preprocessing for train : using cache ./datasets/train.pt


In [36]:
train =  torch.load("datasets/train.pt", map_location=device)
train = train.to(torch.long)

In [37]:
@dataclass
class EmbeddingGeneratorConfig:
    batch_size: int = 512
    block_size: int = 512
    n_embd: int = 512
    ratio_tokens_saved: float = 0.07
    residual_layer: int = 6
    mb_per_save: int = 2000
    save_dir: float = "./residuals/"
    

In [38]:
embeddingconfig = EmbeddingGeneratorConfig(
    batch_size=512,
    block_size=model.config.block_size,
    n_embd=model.config.n_embd,
    residual_layer=round(model.config.n_layer*0.65)
)

In [39]:
dataset_remainder = train.shape[0] % (embeddingconfig.block_size * embeddingconfig.batch_size)
dataset_length = train.shape[0] - dataset_remainder
print("removed tokens from dataset:", dataset_remainder)
batches = train[:dataset_length].view(-1, embeddingconfig.batch_size, embeddingconfig.block_size)

removed tokens from dataset: 130959


In [40]:
def n_embd_to_mb(n):
    mb_per_embedding = embeddingconfig.n_embd * 2 / 1_000_000
    return mb_per_embedding * n

def mb_to_n_embd(mb):
    mb_per_embedding = embeddingconfig.n_embd * 2 / 1_000_000
    return int(mb / mb_per_embedding)

In [41]:
print("estimated storage on disk (MB):", n_embd_to_mb(int(batches.shape[0]*embeddingconfig.ratio_tokens_saved*embeddingconfig.block_size*embeddingconfig.batch_size)))

estimated storage on disk (MB): 33447.057408


In [42]:
torch.tensor(5.5).to(torch.bfloat16)

tensor(5.5000, device='cuda:0', dtype=torch.bfloat16)

In [43]:
#@torch.no_grad()

save_residuals_buffer = []
global_token_starts_buffer = []
global_context_window_starts_buffer = []

save_counter = 0

os.makedirs(embeddingconfig.save_dir, exist_ok=True)

print("estimated storage on disk (MB):", n_embd_to_mb(int(batches.shape[0]*embeddingconfig.ratio_tokens_saved*embeddingconfig.block_size*embeddingconfig.batch_size)))

for batch_index, batch in enumerate(tqdm(batches)):
    tokens_per_batch = embeddingconfig.batch_size * embeddingconfig.block_size
    global_token_start_pos = batch_index * embeddingconfig.batch_size * embeddingconfig.block_size
    local_idxs = torch.randperm(tokens_per_batch)[:int(tokens_per_batch * embeddingconfig.ratio_tokens_saved)]
    global_idxs = local_idxs + global_token_start_pos
    global_window_starts = global_idxs - global_idxs % embeddingconfig.block_size
    global_context_window_starts_buffer += global_window_starts.tolist()
    global_token_starts_buffer += global_idxs.tolist()
    
    
    model_out = model(batch, return_layer_embs = embeddingconfig.residual_layer).view(-1, embeddingconfig.n_embd)[local_idxs, :]
    save_residuals_buffer.append(model_out)
    num_embeddings_in_buffer = embeddingconfig.batch_size * embeddingconfig.block_size * len(save_residuals_buffer) * embeddingconfig.ratio_tokens_saved
    if n_embd_to_mb(num_embeddings_in_buffer) > embeddingconfig.mb_per_save:
        residuals_tensor = torch.cat(save_residuals_buffer)
        torch.save({
                    "residuals": residuals_tensor.to(torch.bfloat16),
                    "token_idxs": global_token_starts_buffer,
                    "context_window_starts": global_context_window_starts_buffer,
                    "config": embeddingconfig
                    },
                    f"./{embeddingconfig.save_dir}/{save_counter}.pt",)
        save_counter += 1
        save_residuals_buffer = []
        global_token_starts_buffer = []
        global_context_window_starts_buffer = []


    

estimated storage on disk (MB): 33447.057408


100%|██████████| 1780/1780 [19:52<00:00,  1.49it/s] 


In [51]:
torch.load("./residuals/2.pt")['residuals'][0]

tensor([ 1.9688e+00,  1.8438e+00,  1.3672e+00, -2.1094e+00, -2.0938e+00,
        -1.1133e-01, -5.4688e-01, -1.1133e-01,  1.6235e-02,  1.2969e+00,
        -1.1562e+00,  7.1875e-01, -1.4141e+00,  7.4219e-01,  4.4531e-01,
        -1.0234e+00,  1.2266e+00, -1.4258e-01,  7.1094e-01, -1.4258e-01,
        -3.2031e-01,  3.0469e-01,  1.3281e-01,  5.3906e-01,  2.5781e-01,
         1.1875e+00, -5.2344e-01, -1.1641e+00, -3.9648e-01,  1.6953e+00,
         1.9238e-01, -1.1172e+00,  1.0986e-01, -2.1387e-01, -6.0547e-01,
        -2.8687e-02,  4.4531e-01, -1.1484e+00,  6.3672e-01, -4.3945e-01,
        -2.1289e-01,  7.5000e+00, -4.5703e-01, -6.6016e-01,  2.7969e+00,
        -3.5352e-01,  1.0312e+00, -5.5078e-01, -9.0820e-02,  4.8828e-02,
         1.6562e+00,  9.9219e-01, -1.3281e+00,  1.2422e+00,  1.2656e+00,
         1.3594e+00, -8.1250e-01,  2.1094e+00,  3.5742e-01, -3.5742e-01,
        -5.5078e-01,  1.0547e+00,  1.8750e+00, -4.1016e-01, -1.0889e-01,
        -9.5703e-01, -5.7422e-01,  7.5195e-02, -2.6