In [1]:
# # load dataset

# from datasets import load_dataset
# from tokenizers import ByteLevelBPETokenizer

# tokenizer = ByteLevelBPETokenizer()
# dataset = load_dataset("roneneldan/TinyStories")

# # Specify the split you want to save (e.g., "train", "validation", "test")
# split = "train"

# # Get the desired split from the dataset
# subset = dataset[split]

# # Save the subset to a text file
# subset.to_csv("tinystories-train.txt", sep="\t", index=False)


In [2]:
#----- imports --------

import tqdm
import torch
from torch import nn
import wandb
import os
import tokenizers
from matplotlib import pyplot as plt
import numpy as np
import json


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"

config = {
    "learning_rate": 1e-3,
    "sae_learning_rate": 5e-5,
    "model_embedding_layer": 6,
    "eval_interval": 500,
    "max_iters": 60000, 
    "H": 32, # hidden dimension size
    "B": 64,
    "T": 256,
    "C": 256,
    "feedforward_factor": 3,
    "n_heads": 8,
    "n_layers": 12,
    "tokenizer_vocab_size": 2**13,
    "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

# initial
for k,v in config.items():
    locals ()[k] = v


#wandb.init(
#    project = "tinystories",
#    config = config,
#)

In [3]:

# stories_data = []
# data_dir = './data'
# for filename in os.listdir(data_dir):
#     file_path = os.path.join(data_dir, filename)
#     if filename.endswith('.json'):
#         with open(file_path, 'r', encoding='utf-8') as f:
#             data = json.load(f)
#             stories_data.extend(data)






In [4]:
# # load the tinystories tokenizer
# tokenizer = tokenizers.ByteLevelBPETokenizer(
#     "./tiny-stories-bpe-vocab.json", 
#     "./tiny-stories-bpe-merges.txt"
# )



# def encode(text):
#     return torch.tensor(tokenizer.encode(text).ids, dtype=torch.int64)
# def decode(encoded_text):
#     return tokenizer.decode(encoded_text.tolist())

# from tqdm import tqdm

# encoded_stories = [encode(story['story']) for story in tqdm(stories_data, desc="Encoding stories")]



In [5]:
# # save the encoded stories to a file
# torch.save(encoded_stories, 'encoded-stories.pt')

In [6]:

with open('tinystories-train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [7]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1916206969


In [8]:
1916206969/4

479051742.25

In [9]:
print("length of dataset in lines: ", len(text.split('\n')))

length of dataset in lines:  20550005


In [10]:
print(text[:1000])

text
"One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, ""Mom, I found this needle. Can you share it with me and sew my shirt?"" Her mom smiled and said, ""Yes, Lily, we can share the needle and fix your shirt.""

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together."
"Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.

One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that wer

In [11]:
# paths = ['tinystories-train.txt']
# tokenizer = tokenizers.ByteLevelBPETokenizer()

# tokenizer.train(files=paths, vocab_size=tokenizer_vocab_size, min_frequency=2)

# tokenizer.save_model('.', 'tiny-stories-bpe')



# enc = tokenizer.encode("She sells sea shells by the sea shore!")
# tokenizer.decode(enc.ids)



In [12]:
tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./tiny-stories-bpe-vocab.json", 
    "./tiny-stories-bpe-merges.txt"
)


In [13]:

def encode(text):
    return tokenizer.encode(text).ids
def decode(encoded_text):
    return tokenizer.decode(encoded_text)

from tqdm import tqdm

def batch_encode(text, batch_size):
    tokens = []
    for i in tqdm(range(0, len(text), batch_size)):
        tokens.extend(encode(text[i:i+batch_size]))
    return tokens


hello_encoded = encode("hello")
print(hello_encoded)
print(decode(hello_encoded))
vocab_size = tokenizer.get_vocab_size()
print("vocab size: ", vocab_size)

[6132]
hello
vocab size:  8192


In [14]:
sample_text = text[:200000]
sample_encoded = batch_encode(sample_text, 20000)

# get the amount of memory used by sample_encoded
def recursive_memory_usage(python_obj):
    if isinstance(python_obj, (str, int, float)):
        return python_obj.__sizeof__()
    if isinstance(python_obj, dict):
        return sum([recursive_memory_usage(v) for v in python_obj.values()])
    if isinstance(python_obj, list):
        return sum([recursive_memory_usage(v) for v in python_obj])
    return python_obj.__sizeof__()

print("memory used by sample_encoded: ", recursive_memory_usage(sample_encoded) / 1024**2, "MB")


100%|██████████| 10/10 [00:00<00:00, 66.61it/s]

memory used by sample_encoded:  1.2918853759765625 MB





In [15]:
print("length of dataset in characters: ", len(text[:10000]))
print("length of dataset in tokens: ", len(encode(text[:10000])))
chars_per_token = len(text[:10000]) / len(encode(text[:10000]))
print("characters per token: ", chars_per_token)

length of dataset in characters:  10000
length of dataset in tokens:  2457
characters per token:  4.07000407000407


In [16]:
# encoded_text = batch_encode(text, 200000)
# # data = torch.tensor(encode(text), dtype=torch.int64)
# data = torch.tensor(encoded_text, dtype=torch.int64, device='cuda')
# print(data.dtype)
# print(data.size())
# print(data.device)
# torch.save(data, 'tiny-stories-train.pt')
# encoded_text = None


In [17]:
# load data from tiny-stories-train.pt
data = torch.load('tiny-stories-train.pt', map_location='cuda')


In [18]:
len(data)

468832276

In [19]:
n = int(0.9*len(data))

train_data = data[:n]
val_data = data[n:]

In [20]:
train_data.size()

torch.Size([421949048])

In [21]:
train_data[:T+1]

tensor([  83, 3206,  198,    1,  421,  356,   11,  258,  397,  447,  501,  364,
         596,  258, 3736,  316,  309,  759,   13,  313,  704,  304,  282, 2966,
         265,  359,  342,  304,  788,  304,  282, 2120,   13,  364,  445,  265,
         949,  262, 3736,  342,  309,  365,   11,  350,  338,  461, 5198,  258,
        2228,  345,  309, 2500,   13,  198,  198,  343,  469,  265,  309,  365,
         264,  327,   11,  329,  771,   11,  335,  596,  741, 3736,   13, 1282,
         346,  949,  304,  342,  519,  264, 5198,  652, 2500,  478,  866,  365,
         499,  264,  327,   11,  329,  832,   11,  364,   11,  363,  472,  949,
         262, 3736,  264, 1306,  627, 2500,  416,  198,  198, 4625,   11,  362,
        1656,  262, 3736,  264, 7930,  262, 2228,  345,  364,  371, 2500,   13,
         410,  282,  385, 2966,  366,  449,  788,  362,  430, 2502,  264, 1762,
         757,  573,   13, 1453,  362, 1444,   11,  364,  858,  309,  365,  366,
        2502,  262, 3736,  264, 5150,  3

In [22]:
decode(train_data[:T+1].cpu().numpy())

'text\n"One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, ""Mom, I found this needle. Can you share it with me and sew my shirt?"" Her mom smiled and said, ""Yes, Lily, we can share the needle and fix your shirt.""\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together."\n"Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had many leave

In [23]:
x = train_data[:T]
y = train_data[1:T+1]
for t in range(T):
    context = x[:t+1]
    target = y[t]
    # print("when we see the text", context, "we predict the next character is", target)

In [24]:
# torch.manual_seed(1337)

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, data.size(0) - T, (B,)) # 4 random locations we can sample from
    x = torch.stack([data[i:i+T] for i in ix]) # random sequences
    y = torch.stack([data[i+1:i+T+1] for i in ix]) # next character for each random sequence

    return x, y

xb, yb = get_batch('train')

for b in range(B):
    for t in range(T): # for each of the characters in the sample
        context = xb[b, :t+1]
        target = yb[b, t]


In [25]:

import torch
import torch.nn as nn
from torch.nn import functional as F
# torch.manual_seed(1337)


class Head(nn.Module):
    '''One Head of self-attention'''
    def __init__(self, H):
        super().__init__()
        self.query = nn.Linear(C, H, bias=False)
        self.key = nn.Linear(C, H, bias=False)
        self.value = nn.Linear(C, H, bias=False)
        # self.output = nn.Linear(H, C, bias=False) # output matrix
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))

    def forward(self, x):
        # Query and Key matrices for the attention mechanism
        # x: 8 tokens
        # Q: 16 tall (arbitrary), 32 long channels
        # K: 16 tall (arbitrary), 32 long channels

        query_vectors = self.query(x)
        key_vectors = self.key(x)


        # Attention masking(so we can't look into the past):

        tril = self.tril
        wei = torch.zeros(T, T) 
        wei = wei.masked_fill(tril == 0, float('-inf')) # set the upper triangular to -inf
        # xbow = wei @ x # apply the mask to the input, bag of words because simple avg.

        # multiply the two to get the attention weights
        attention_pattern = query_vectors @ key_vectors.transpose(-2, -1) # T, T
        attention_pattern = attention_pattern / (H ** 0.5) # scale the attention pattern for numerical stability
        attention_weights = F.softmax(attention_pattern + wei, dim=-1) # T, T (the row dimension is the query)

        value_vectors = self.value(x) # the direction we should go in the embedding space for each token (ie more blue) T, H

        # apply the attention weights to the value vectors
        context = attention_weights @ value_vectors # T, H

        # project back into original space from value space
        # return self.output(context)
        return context

x = torch.randn(B,T,C)
head = Head(H)
# head(x)


In [26]:
class MultiHeadAttention(nn.Module):
    '''Multiple heads of self-attention'''
    def __init__(self, H, C, n_heads): # H is head embedding space size, n_heads is number of heads
        super().__init__()
        self.heads = nn.ModuleList([Head(H) for _ in range(n_heads)])
        self.combine_heads = nn.Linear(H*n_heads, C)


    def forward(self,x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.combine_heads(x)  # T, C
        return x

In [27]:
head = MultiHeadAttention(H, C, n_heads)
head.heads[0].forward(x).shape


torch.Size([64, 256, 32])

In [28]:
class FeedForward(nn.Module):
    '''Feed-forward neural network'''
    def __init__(self, C):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, C * feedforward_factor),
            nn.ReLU(),
            nn.Linear(C * feedforward_factor, C),
        )

    def forward(self, x):
        return self.net(x)

In [29]:
class LayerNorm(nn.Module):
    '''Layer normalization'''
    def __init__(self, C, use_affine=True):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(C)) if use_affine else None
        self.beta = nn.Parameter(torch.zeros(C)) if use_affine else None

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        if self.gamma is not None and self.beta is not None:
            return self.gamma * (x - mean) / (std + 1e-6) + self.beta
        else:
            return (x - mean) / (std + 1e-6)

In [30]:
class Block(nn.Module):
    '''Transformer block'''
    def __init__(self, H, C, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(H, C, n_heads)
        self.ff = FeedForward(C)
        self.norm1 = LayerNorm(C, use_affine=True)
        self.norm2 = LayerNorm(C, use_affine=True)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [31]:
class GPT(nn.Module):

    def __init__(self, n_layers):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, C) 
        self.position_embedding_table = nn.Embedding(T, C)
        self.lm_head = nn.Linear(C, vocab_size)
        self.layers = nn.ModuleList([Block(H, C, n_heads) for _ in range(n_layers)])
    
    def forward(self, idx, targets=None, return_residuals=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx) # batch_dim, sequence_dim, embedding_dim
        pos_emb = self.position_embedding_table(torch.arange(T))
        x = token_emb + pos_emb # token identities and positions contained

        if return_residuals == "first_embedding":
            return x

        # def excess_kurtosis(emb):
        #     mean = torch.mean(emb, dim=-1, keepdim=True) # BxTx1
        #     std = torch.std(emb, dim=-1, keepdim=True) # BxTx1

        #     centralized = emb - mean #BxTxC
        #     fourth_moment = torch.mean(centralized**4, dim=-1, keepdim=True) # BxTx1
        #     kurtosis = torch.squeeze(fourth_moment / std**4, dim=-1) # BxT
        #     # view as a 1d vector
        #     kurtosis = kurtosis.view(-1) - 3
        #     # make each one min 0
        #     kurtosis = torch.maximum(kurtosis, torch.tensor(0.0))
        #     # sum over the vector
        #     kurtosis = torch.sum(kurtosis)
        #     return kurtosis


        # kurtosis_sum = torch.tensor(0.0)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # kurtosis_sum += excess_kurtosis(x)
            if return_residuals is not None and i == return_residuals:
                return x
        
        # kurtosis_avg = kurtosis_sum / (len(self.layers) * T * B)
        kurtosis_avg = torch.tensor(0.0)

        logits = self.lm_head(x) # batch_dim, sequence_dim, vocab_size

        batch_dim, sequence_dim, embedding_dim = logits.size()

        # loss = F.cross_entropy(logits, targets) this won't work because we need 1d logits and 1d targets
        # one-hot-vectors are a line in the x-dimension, so the shape of shape of the logits should be (-1, vocab_size).

        if targets is None:
            return logits, None, kurtosis_avg
        else:
            # a list of all the predictions, reguardles of batch.
            # xdim: probabilities of each character in the vocab (embedding_dim=vocab_size)
            # ydim: all predictions for all batches flattened (batch_dim*sequence_dim)
            logits_loss_view = logits.view(-1, vocab_size) 
            # targets loss view
            # xdim: all targets for all batches flattened (batch_dim*sequence_dim)
            # so this would be like, [1,4,5,1,2,3, ...]
            # where each number is the correct next index of the one hot vector
            targets_loss_view = targets.view(-1)
            loss = F.cross_entropy(logits_loss_view, targets_loss_view)
            return logits, loss, kurtosis_avg

    def generate(self, idx, max_new_tokens, temperature=0.5):
        for _ in range(max_new_tokens):
            logits, loss = self(idx[:,-T:])
            # get the predictions of the last token
            last_token_logits = logits[:, -1, :] # all batches, last token, all probabilities
            # apply temperature
            last_token_logits = last_token_logits / temperature
            # softmax to get probabilities
            probabilities = F.softmax(last_token_logits, dim=-1)
            # sample from the probabilities
            next_token = torch.multinomial(probabilities, num_samples=1)
            # add the new token to the idx tensor
            idx = torch.cat((idx, next_token), dim=1)
        return idx
    def prompt_model(self, prompt, max_new_tokens, temperature=0.5):
        autoregressive_seq = encode(prompt)
        for _ in range(max_new_tokens):
            prediction_index = len(autoregressive_seq)-1

            model_input = torch.tensor(autoregressive_seq)
            
            while model_input.shape[0] < T:
                pad_token = torch.tensor(encode("\n"))
                model_input = torch.cat((model_input, pad_token), dim=0)

            model_input
            model_input = model_input.unsqueeze(0)

            logits, loss, kurtosis_avg = model(model_input)
            prediction_token = logits[:, prediction_index, :] / temperature
            probabilities = F.softmax(prediction_token, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1)
            next_token = next_token.item()

            autoregressive_seq.append(next_token)
        # get the autoregressive sequence
        return decode(autoregressive_seq)
    def get_embedding(self, prompt, override_model_embedding_layer=None):
        if override_model_embedding_layer is None:
            selected_model_embedding_layer = model_embedding_layer
        else:
            selected_model_embedding_layer = override_model_embedding_layer
        sequence = encode(prompt)
        model_input = torch.tensor(sequence)
        sequence_index = len(sequence) - 1
        while model_input.shape[0] < T:
            pad_token = torch.tensor(encode("\n"))
            model_input = torch.cat((model_input, pad_token), dim=0)
        model_input = model_input.unsqueeze(0)
        embedding = self.forward(model_input, return_residuals=selected_model_embedding_layer)
        # remove the batch dimension
        embedding = embedding.squeeze(0)[sequence_index]
        return embedding



    

model = GPT(n_layers)
# logits, loss, kurtosis_avg = model(xb, yb)
# print(logits.shape)
# print(loss)
# print(kurtosis_avg)




test_idx = torch.zeros(1, T).long()
model.forward(idx=test_idx)
# decode(model.generate(idx=test_idx, max_new_tokens=100)[0].tolist())

(tensor([[[-0.2082, -0.4307, -1.4964,  ..., -0.4924,  0.3247,  0.0516],
          [ 0.3258,  0.6623, -1.7060,  ..., -1.1121,  0.1928, -1.1646],
          [ 1.4009,  0.4413, -1.9436,  ..., -1.1725,  1.1626, -1.2866],
          ...,
          [-1.0671, -0.6851, -1.0190,  ..., -1.7658,  0.0694, -1.3291],
          [-1.2112, -0.6265, -1.5360,  ..., -1.1472, -0.9685, -2.1975],
          [-0.3226,  0.3611, -0.4603,  ..., -1.1162,  0.3319, -1.2689]]],
        device='cuda:0', grad_fn=<ViewBackward0>),
 None,
 tensor(0., device='cuda:0'))

In [32]:
model

GPT(
  (token_embedding_table): Embedding(8192, 256)
  (position_embedding_table): Embedding(256, 256)
  (lm_head): Linear(in_features=256, out_features=8192, bias=True)
  (layers): ModuleList(
    (0-11): 12 x Block(
      (attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (query): Linear(in_features=256, out_features=32, bias=False)
            (key): Linear(in_features=256, out_features=32, bias=False)
            (value): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (combine_heads): Linear(in_features=256, out_features=256, bias=True)
      )
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): ReLU()
          (2): Linear(in_features=768, out_features=256, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
    )
  )
)

In [33]:
# get the number of parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("number of parameters in the model: ", count_parameters(model))

number of parameters in the model:  12160000


In [34]:
data.shape

torch.Size([468832276])

In [35]:
# logits, loss = self(idx[:,-T:])

idx = torch.zeros(1, 1).long()
idx[:,-T:]

tensor([[0]], device='cuda:0')

In [36]:
model.token_embedding_table.weight.device

device(type='cuda', index=0)

In [37]:
eval_iters = 10
eval_interval = 300
@torch.no_grad()
def estimate_loss(is_last=False):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        real_iters = eval_iters
        if is_last and split == 'val':  # increase last eval to mitigate noise
            real_iters *= 10 
        losses = torch.zeros(real_iters)
        for k in range(real_iters):
            X, Y = get_batch(split)
            logits, loss, kurtosis_avg = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean() / chars_per_token
    model.train()
    return out
    

In [38]:
# get the number of parameters
n_params = sum(p.numel() for p in model.parameters())
parameter_to_data_ratio = n_params / len(train_data)
print(f"{parameter_to_data_ratio=}")

parameters = []
for name, param in model.named_parameters():
    parameters.append({"name": name, "params": param.numel()})

# sort parameters by size
sorted_parameters = sorted(parameters, key=lambda x: x["params"], reverse=True)
for p in sorted_parameters:
    print(f"{p['name']}: {p['params']}")

parameter_to_data_ratio=0.028818645420903996
token_embedding_table.weight: 2097152
lm_head.weight: 2097152
layers.0.ff.net.0.weight: 196608
layers.0.ff.net.2.weight: 196608
layers.1.ff.net.0.weight: 196608
layers.1.ff.net.2.weight: 196608
layers.2.ff.net.0.weight: 196608
layers.2.ff.net.2.weight: 196608
layers.3.ff.net.0.weight: 196608
layers.3.ff.net.2.weight: 196608
layers.4.ff.net.0.weight: 196608
layers.4.ff.net.2.weight: 196608
layers.5.ff.net.0.weight: 196608
layers.5.ff.net.2.weight: 196608
layers.6.ff.net.0.weight: 196608
layers.6.ff.net.2.weight: 196608
layers.7.ff.net.0.weight: 196608
layers.7.ff.net.2.weight: 196608
layers.8.ff.net.0.weight: 196608
layers.8.ff.net.2.weight: 196608
layers.9.ff.net.0.weight: 196608
layers.9.ff.net.2.weight: 196608
layers.10.ff.net.0.weight: 196608
layers.10.ff.net.2.weight: 196608
layers.11.ff.net.0.weight: 196608
layers.11.ff.net.2.weight: 196608
position_embedding_table.weight: 65536
layers.0.attention.combine_heads.weight: 65536
layers.1.at

In [43]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

import tqdm
num_params = sum([p.numel() for p in model.parameters()])

for steps in tqdm.tqdm(range(max_iters)):
    xb, yb = get_batch('train')
    # loss
    logits, loss, kurtosis_avg = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    # l2 regularization
    # l2 = sum(p.pow(2).sum() for p in model.parameters()) / num_params
    loss = loss + kurtosis_avg * 0.4/160

    loss.backward()
    optimizer.step()
    if steps % eval_interval == 0:
        losses = estimate_loss()
        # wandb.log({"tIain": losses['train'].item(), "val": losses['val'].item(), "l2":l2})
        print({"tIain": losses['train'].item(), "val": losses['val'].item(), "kurtosis_avg": kurtosis_avg.item()})

losses = estimate_loss(is_last=True)
# wandb.log({"train": losses['train'].item(), "val": losses['val'].item()})
# wandb.finish()


  0%|          | 2/60000 [00:01<8:27:25,  1.97it/s] 

{'tIain': 2.429755449295044, 'val': 2.431878089904785, 'kurtosis_avg': 0.0}


  1%|          | 302/60000 [00:36<4:44:56,  3.49it/s]

{'tIain': 0.8367623686790466, 'val': 0.8470489382743835, 'kurtosis_avg': 0.0}


  1%|          | 602/60000 [01:18<4:41:42,  3.51it/s]

{'tIain': 0.7151026129722595, 'val': 0.7169919013977051, 'kurtosis_avg': 0.0}


  2%|▏         | 902/60000 [02:03<3:51:03,  4.26it/s]

{'tIain': 0.6502602696418762, 'val': 0.6510316133499146, 'kurtosis_avg': 0.0}


  2%|▏         | 1202/60000 [02:44<3:03:09,  5.35it/s]

{'tIain': 0.6058956980705261, 'val': 0.605420708656311, 'kurtosis_avg': 0.0}


  3%|▎         | 1502/60000 [03:26<4:49:00,  3.37it/s]

{'tIain': 0.5789751410484314, 'val': 0.5838618278503418, 'kurtosis_avg': 0.0}


  3%|▎         | 1803/60000 [04:12<2:46:37,  5.82it/s]

{'tIain': 0.5555378794670105, 'val': 0.557843804359436, 'kurtosis_avg': 0.0}


  4%|▎         | 2103/60000 [04:50<2:54:14,  5.54it/s]

{'tIain': 0.5497593283653259, 'val': 0.5462676882743835, 'kurtosis_avg': 0.0}


  4%|▍         | 2403/60000 [05:28<2:34:13,  6.22it/s]

{'tIain': 0.5319753289222717, 'val': 0.5288966298103333, 'kurtosis_avg': 0.0}


  5%|▍         | 2703/60000 [06:07<3:09:55,  5.03it/s]

{'tIain': 0.5249441862106323, 'val': 0.5181480050086975, 'kurtosis_avg': 0.0}


  5%|▌         | 3002/60000 [06:49<4:36:09,  3.44it/s]

{'tIain': 0.5122328996658325, 'val': 0.5077174305915833, 'kurtosis_avg': 0.0}


  6%|▌         | 3302/60000 [07:31<4:37:28,  3.41it/s]

{'tIain': 0.5039597153663635, 'val': 0.5057711601257324, 'kurtosis_avg': 0.0}


  6%|▌         | 3602/60000 [08:10<3:50:07,  4.08it/s]

{'tIain': 0.5024383664131165, 'val': 0.49779969453811646, 'kurtosis_avg': 0.0}


  7%|▋         | 3902/60000 [08:48<3:13:10,  4.84it/s]

{'tIain': 0.4950229525566101, 'val': 0.49174731969833374, 'kurtosis_avg': 0.0}


  7%|▋         | 4203/60000 [09:26<2:31:15,  6.15it/s]

{'tIain': 0.48267045617103577, 'val': 0.4901937246322632, 'kurtosis_avg': 0.0}


  8%|▊         | 4502/60000 [10:08<4:50:28,  3.18it/s]

{'tIain': 0.486367791891098, 'val': 0.4832426607608795, 'kurtosis_avg': 0.0}


  8%|▊         | 4803/60000 [10:40<2:24:08,  6.38it/s]

{'tIain': 0.47537341713905334, 'val': 0.4770694077014923, 'kurtosis_avg': 0.0}


  9%|▊         | 5102/60000 [11:20<4:14:24,  3.60it/s]

{'tIain': 0.47356557846069336, 'val': 0.4775037467479706, 'kurtosis_avg': 0.0}


  9%|▉         | 5402/60000 [12:04<3:14:06,  4.69it/s]

{'tIain': 0.46566274762153625, 'val': 0.463899165391922, 'kurtosis_avg': 0.0}


 10%|▉         | 5702/60000 [12:42<4:10:09,  3.62it/s]

{'tIain': 0.4797123670578003, 'val': 0.46505361795425415, 'kurtosis_avg': 0.0}


 10%|█         | 6002/60000 [13:26<4:10:52,  3.59it/s]

{'tIain': 0.4616134464740753, 'val': 0.46008652448654175, 'kurtosis_avg': 0.0}


 11%|█         | 6302/60000 [14:10<4:19:04,  3.45it/s]

{'tIain': 0.4644951820373535, 'val': 0.4598516523838043, 'kurtosis_avg': 0.0}


 11%|█         | 6602/60000 [14:46<3:21:03,  4.43it/s]

{'tIain': 0.45522618293762207, 'val': 0.4626397490501404, 'kurtosis_avg': 0.0}


 12%|█▏        | 6902/60000 [15:22<4:16:21,  3.45it/s]

{'tIain': 0.46410071849823, 'val': 0.458983451128006, 'kurtosis_avg': 0.0}


 12%|█▏        | 7203/60000 [16:05<2:32:46,  5.76it/s]

{'tIain': 0.45678386092185974, 'val': 0.45689690113067627, 'kurtosis_avg': 0.0}


 13%|█▎        | 7502/60000 [16:45<4:19:21,  3.37it/s]

{'tIain': 0.44554781913757324, 'val': 0.4515269100666046, 'kurtosis_avg': 0.0}


 13%|█▎        | 7802/60000 [17:21<4:09:26,  3.49it/s]

{'tIain': 0.450727641582489, 'val': 0.45100945234298706, 'kurtosis_avg': 0.0}


 14%|█▎        | 8102/60000 [18:00<3:20:54,  4.31it/s]

{'tIain': 0.4468293786048889, 'val': 0.4453596770763397, 'kurtosis_avg': 0.0}


 14%|█▍        | 8402/60000 [18:43<4:13:52,  3.39it/s]

{'tIain': 0.44623324275016785, 'val': 0.4518088698387146, 'kurtosis_avg': 0.0}


 15%|█▍        | 8702/60000 [19:21<2:26:37,  5.83it/s]

{'tIain': 0.4412625730037689, 'val': 0.45233696699142456, 'kurtosis_avg': 0.0}


 15%|█▌        | 9003/60000 [19:52<2:19:40,  6.09it/s]

{'tIain': 0.4462968409061432, 'val': 0.4504774808883667, 'kurtosis_avg': 0.0}


 16%|█▌        | 9302/60000 [20:23<2:26:00,  5.79it/s]

{'tIain': 0.4341552257537842, 'val': 0.4360671639442444, 'kurtosis_avg': 0.0}


 16%|█▌        | 9602/60000 [20:57<4:11:11,  3.34it/s]

{'tIain': 0.44603899121284485, 'val': 0.4409247636795044, 'kurtosis_avg': 0.0}


 17%|█▋        | 9902/60000 [21:42<4:03:39,  3.43it/s]

{'tIain': 0.437929630279541, 'val': 0.4411149024963379, 'kurtosis_avg': 0.0}


 17%|█▋        | 10202/60000 [22:20<3:18:26,  4.18it/s]

{'tIain': 0.43779417872428894, 'val': 0.4368560314178467, 'kurtosis_avg': 0.0}


 18%|█▊        | 10502/60000 [23:03<3:30:39,  3.92it/s]

{'tIain': 0.4347199499607086, 'val': 0.438816100358963, 'kurtosis_avg': 0.0}


 18%|█▊        | 10802/60000 [23:47<3:48:52,  3.58it/s]

{'tIain': 0.4374655783176422, 'val': 0.43755337595939636, 'kurtosis_avg': 0.0}


 19%|█▊        | 11102/60000 [24:25<4:02:45,  3.36it/s]

{'tIain': 0.4372262954711914, 'val': 0.43769925832748413, 'kurtosis_avg': 0.0}


 19%|█▉        | 11402/60000 [25:01<2:37:23,  5.15it/s]

{'tIain': 0.4402065575122833, 'val': 0.4414510130882263, 'kurtosis_avg': 0.0}


 20%|█▉        | 11702/60000 [25:38<2:44:17,  4.90it/s]

{'tIain': 0.43819543719291687, 'val': 0.4373345375061035, 'kurtosis_avg': 0.0}


 20%|██        | 12002/60000 [26:24<2:27:06,  5.44it/s]

{'tIain': 0.42662832140922546, 'val': 0.4348389804363251, 'kurtosis_avg': 0.0}


 21%|██        | 12302/60000 [26:58<3:25:25,  3.87it/s]

{'tIain': 0.42912256717681885, 'val': 0.4344235062599182, 'kurtosis_avg': 0.0}


 21%|██        | 12603/60000 [27:47<2:25:22,  5.43it/s]

{'tIain': 0.4255281984806061, 'val': 0.4332534968852997, 'kurtosis_avg': 0.0}


 22%|██▏       | 12902/60000 [28:29<2:51:21,  4.58it/s]

{'tIain': 0.4269726872444153, 'val': 0.4232653081417084, 'kurtosis_avg': 0.0}


 22%|██▏       | 13202/60000 [29:11<2:56:52,  4.41it/s]

{'tIain': 0.42444613575935364, 'val': 0.4320295751094818, 'kurtosis_avg': 0.0}


 23%|██▎       | 13502/60000 [29:52<3:15:07,  3.97it/s]

{'tIain': 0.4284628629684448, 'val': 0.4298068583011627, 'kurtosis_avg': 0.0}


 23%|██▎       | 13802/60000 [30:31<3:21:12,  3.83it/s]

{'tIain': 0.42587125301361084, 'val': 0.43010571599006653, 'kurtosis_avg': 0.0}


 24%|██▎       | 14102/60000 [31:17<2:38:43,  4.82it/s]

{'tIain': 0.41953179240226746, 'val': 0.4261939525604248, 'kurtosis_avg': 0.0}


 24%|██▍       | 14403/60000 [31:55<2:11:09,  5.79it/s]

{'tIain': 0.43378379940986633, 'val': 0.42606818675994873, 'kurtosis_avg': 0.0}


 25%|██▍       | 14702/60000 [32:31<2:21:18,  5.34it/s]

{'tIain': 0.4257335960865021, 'val': 0.43004879355430603, 'kurtosis_avg': 0.0}


 25%|██▌       | 15002/60000 [33:09<2:46:57,  4.49it/s]

{'tIain': 0.4245215058326721, 'val': 0.4223972260951996, 'kurtosis_avg': 0.0}


 26%|██▌       | 15302/60000 [33:52<2:54:55,  4.26it/s]

{'tIain': 0.42161744832992554, 'val': 0.426285058259964, 'kurtosis_avg': 0.0}


 26%|██▌       | 15496/60000 [34:14<1:38:20,  7.54it/s]


KeyboardInterrupt: 

In [44]:
estimate_loss()

{'train': tensor(0.4164, device='cuda:0'),
 'val': tensor(0.4204, device='cuda:0')}

In [45]:
# # save model
# torch.save(model.state_dict(), 'tiny-stories-model.pt')



In [46]:
# load the model
model.load_state_dict(torch.load('tiny-stories-model.pt'))


<All keys matched successfully>

In [47]:
print(model.prompt_model("Lilly saw a big red apple.", 200, 0.4))

Lilly saw a big red apple. She wanted to eat it, but she didn't know how. She tried to take a bite, but it was too high.

Suddenly, Lilly saw a little bird. The bird was stuck in a tree. Lilly knew she had to rescue the bird. She climbed up the tree and carefully climbed up. The bird was safe and Lilly felt happy.

Lilly learned that sometimes it's good to ask for help when you need it. She also learned that it's important to always ask for help when you need it. Lilly was proud of herself for being brave and kind to the bird."
"Once upon a time, there was a girl named Lily. She loved to play outside in the sun. One day, she went to the park to play. She saw a big tree and wanted to climb it. But she was scared because she was scared.

Lily saw a bird and said, ""Hi, bird! Can you help me climb the tree?"" The bird replied


# Kurtosis debugging

In [None]:
story1='''Once upon a time, in a big forest, there lived a rhinoceros named Roxy. Roxy loved to climb. She climbed trees, rocks, and hills. One day, Roxy found an icy hill. She had never seen anything like it before. It was shiny and cold, and she wanted to climb it.
Roxy tried to climb the icy hill, but it was very slippery. She tried again and again, but she kept falling down. Roxy was sad. She wanted to climb the icy hill so much. Then, she saw a little bird named Billy. Billy saw that Roxy was sad and asked, "Why are you sad, Roxy?"
Roxy told Billy about the icy hill and how she couldn't climb it'''

# assume BxTxC
def excess_kurtosis(emb):
    mean = torch.mean(emb, dim=-1, keepdim=True) # BxTx1
    std = torch.std(emb, dim=-1, keepdim=True) # BxTx1

    centralized = emb - mean #BxTxC
    fourth_moment = torch.mean(centralized**4, dim=-1, keepdim=True) # BxTx1
    kurtosis = torch.squeeze(fourth_moment / std**4, dim=-1) # BxT
    return kurtosis - 3



emb1 = model.get_embedding("Tim and Lily saw a big dog", override_model_embedding_layer=6)
emb2 = model.get_embedding("Tim and Lily noticed a cat", override_model_embedding_layer=6)


import matplotlib.pyplot as plt
import numpy as np



# Plot emb1 and emb2 in the same plot
# plt.figure(figsize=(10, 5))
# plt.plot(np.square(emb1.cpu().detach().numpy()), label='emb1', color='blue')
# plt.plot(np.square(emb2.cpu().detach().numpy()), label='emb2', color='red')
# plt.xlabel('Index')
# plt.ylabel('Value')
# plt.title('emb1 and emb2 Plot')
# plt.legend()
# plt.show()



# get the index of the highest value
# Assuming emb1 and emb2 are tensors
highest_value_index_emb1 = torch.argmax(emb1).item()
highest_value_index_emb2 = torch.argmax(emb2).item()

lowest_value_index_emb1 = torch.argmin(emb1).item()
lowest_value_index_emb2 = torch.argmin(emb2).item()

print(f"Index of the highest value in emb1: {highest_value_index_emb1}")
print(f"Index of the highest value in emb2: {highest_value_index_emb2}")
print(f"Index of the lowest value in emb1: {lowest_value_index_emb1}")
print(f"Index of the lowest value in emb2: {lowest_value_index_emb2}")

print(f"emb1 excess kurtosis: {excess_kurtosis(emb1)}")
print(f"emb2 excess kurtosis: {excess_kurtosis(emb2)}")

# dot product between emb1 and emb2
emb1_l2 = F.normalize(emb1, p=2, dim=-1)
emb2_l2 = F.normalize(emb2, p=2, dim=-1)
print(f"Dot product between emb1 and emb2: {torch.dot(emb1_l2, emb2_l2)}")



Index of the highest value in emb1: 186
Index of the highest value in emb2: 186
Index of the lowest value in emb1: 171
Index of the lowest value in emb2: 171
emb1 excess kurtosis: 157.1050262451172
emb2 excess kurtosis: 156.85986328125
Dot product between emb1 and emb2: 0.9625234007835388


emb1 excess kurtosis: 157.1050262451172
emb2 excess kurtosis: 156.85986328125
when we load the model trained from this notebook, it has excess kurtosis of 157.1050262451172