In [1]:
# # load dataset

# from datasets import load_dataset
# from tokenizers import ByteLevelBPETokenizer

# tokenizer = ByteLevelBPETokenizer()
# dataset = load_dataset("roneneldan/TinyStories")

# # Specify the split you want to save (e.g., "train", "validation", "test")
# split = "train"

# # Get the desired split from the dataset
# subset = dataset[split]

# # Save the subset to a text file
# subset.to_csv("tinystories-train.txt", sep="\t", index=False)


In [2]:
#----- imports --------

import tqdm
import torch
from torch import nn
import wandb
import os
import tokenizers
from matplotlib import pyplot as plt
import numpy as np
import json


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device= 'cuda' if torch.cuda.is_available() else 'cpu'
# Tensors and modules created without an explicit device are placed on `device`.
torch.set_default_device(device)
# Stop here on CPU-only machines: later cells assume CUDA (e.g. the dataset is
# loaded with map_location='cuda').
assert device == 'cuda', "This notebook is not optimized for CPU"

# Record the current commit so a run can be traced back to the code that produced it.
# The pipe is opened in a `with` block so it is closed (and the child process reaped)
# as soon as the hash has been read. Outside a git checkout this is the empty string.
with os.popen("git rev-parse HEAD") as git_pipe:
    current_git_hash = git_pipe.read().strip()

# Hyperparameters and run metadata for the experiment.
config = {
    "learning_rate": 1e-3,          # Adam step size for the language model
    "sae_learning_rate": 5e-5,
    "model_embedding_layer": 6,     # default layer index used by GPT.get_embedding
    "eval_interval": 500,           # NOTE: rebound to 300 in the evaluation cell further down
    "max_iters": 60000,             # number of optimisation steps in the training loop
    "H": 32, # size of each attention head's query/key/value vectors
    "B": 64,                        # sequences per batch
    "T": 256,                       # tokens per sequence (context length)
    "C": 256,                       # embedding / residual-stream width
    "feedforward_factor": 3,        # MLP hidden width = C * feedforward_factor
    "n_heads": 8,
    "n_layers": 12,
    "tokenizer_vocab_size": 2**13,  # vocabulary size requested when training the BPE tokenizer
    "git_hash": current_git_hash,
}

# Expose every config entry as a module-level variable (learning_rate, B, T, C, ...).
# `globals()` is used explicitly: writing through `locals()` only works at module
# level, where it happens to be the same dictionary.
globals().update(config)


#wandb.init(
#    project = "tinystories",
#    config = config,
#)

In [3]:

# stories_data = []
# data_dir = './data'
# for filename in os.listdir(data_dir):
#     file_path = os.path.join(data_dir, filename)
#     if filename.endswith('.json'):
#         with open(file_path, 'r', encoding='utf-8') as f:
#             data = json.load(f)
#             stories_data.extend(data)






In [4]:
# # load the tinystories tokenizer
# tokenizer = tokenizers.ByteLevelBPETokenizer(
#     "./tiny-stories-bpe-vocab.json", 
#     "./tiny-stories-bpe-merges.txt"
# )



# def encode(text):
#     return torch.tensor(tokenizer.encode(text).ids, dtype=torch.int64)
# def decode(encoded_text):
#     return tokenizer.decode(encoded_text.tolist())

# from tqdm import tqdm

# encoded_stories = [encode(story['story']) for story in tqdm(stories_data, desc="Encoding stories")]



In [5]:
# # save the encoded stories to a file
# torch.save(encoded_stories, 'encoded-stories.pt')

In [6]:

with open('tinystories-train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [7]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1916206969


In [8]:
1916206969/4

479051742.25

In [9]:
# Count newline-delimited lines without materialising a list of every line.
# This equals len(text.split('\n')) for any string, but does not allocate
# one substring per line of the (very large) dataset.
print("length of dataset in lines: ", text.count('\n') + 1)

length of dataset in lines:  20550005


In [10]:
print(text[:1000])

text
"One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, ""Mom, I found this needle. Can you share it with me and sew my shirt?"" Her mom smiled and said, ""Yes, Lily, we can share the needle and fix your shirt.""

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together."
"Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.

One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that wer

In [11]:
# paths = ['tinystories-train.txt']
# tokenizer = tokenizers.ByteLevelBPETokenizer()

# tokenizer.train(files=paths, vocab_size=tokenizer_vocab_size, min_frequency=2)

# tokenizer.save_model('.', 'tiny-stories-bpe')



# enc = tokenizer.encode("She sells sea shells by the sea shore!")
# tokenizer.decode(enc.ids)



In [12]:
tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./tiny-stories-bpe-vocab.json", 
    "./tiny-stories-bpe-merges.txt"
)


In [13]:

def encode(text):
    """Tokenize `text` with the module-level `tokenizer` and return the encoding's `ids`."""
    encoding = tokenizer.encode(text)
    return encoding.ids
def decode(encoded_text):
    """Map a sequence of token ids back to text using the module-level `tokenizer`."""
    decoded_text = tokenizer.decode(encoded_text)
    return decoded_text

from tqdm import tqdm

def batch_encode(text, batch_size):
    """Encode `text` in consecutive slices of `batch_size` characters.

    Args:
        text: the string to tokenize.
        batch_size: number of characters handed to `encode` per call.

    Returns:
        A flat list with the token ids of every slice, in order.

    Note:
        Slices are cut at fixed character offsets, so a token that would span a
        slice boundary is encoded as separate pieces on either side of it.
    """
    # Imported locally under its own alias: this notebook binds the global name
    # `tqdm` both to the class (`from tqdm import tqdm`) and to the module
    # (`import tqdm`), so the global is only callable depending on which cell
    # ran last.
    from tqdm import tqdm as progress_bar

    tokens = []
    for start in progress_bar(range(0, len(text), batch_size)):
        tokens.extend(encode(text[start:start + batch_size]))
    return tokens


hello_encoded = encode("hello")
print(hello_encoded)
print(decode(hello_encoded))
vocab_size = tokenizer.get_vocab_size()
print("vocab size: ", vocab_size)

[6132]
hello
vocab size:  8192


In [14]:
sample_text = text[:200000]
sample_encoded = batch_encode(sample_text, 20000)

# get the amount of memory used by sample_encoded
def recursive_memory_usage(python_obj):
    """Sum `__sizeof__()` over the leaves of nested lists and dicts.

    Lists contribute the total of their elements and dicts the total of their
    values; the containers themselves (and dict keys) are not counted. Any
    other object contributes its own `__sizeof__()`.
    """
    if isinstance(python_obj, dict):
        children = python_obj.values()
    elif isinstance(python_obj, list):
        children = python_obj
    else:
        # Leaf: strings, numbers and every other non-list/dict object.
        return python_obj.__sizeof__()

    total = 0
    for child in children:
        total += recursive_memory_usage(child)
    return total

print("memory used by sample_encoded: ", recursive_memory_usage(sample_encoded) / 1024**2, "MB")


100%|██████████| 10/10 [00:00<00:00, 52.40it/s]


memory used by sample_encoded:  1.2918853759765625 MB


In [15]:
print("length of dataset in characters: ", len(text[:10000]))
print("length of dataset in tokens: ", len(encode(text[:10000])))
chars_per_token = len(text[:10000]) / len(encode(text[:10000]))
print("characters per token: ", chars_per_token)

length of dataset in characters:  10000
length of dataset in tokens:  2457
characters per token:  4.07000407000407


In [16]:
# encoded_text = batch_encode(text, 200000)
# # data = torch.tensor(encode(text), dtype=torch.int64)
# data = torch.tensor(encoded_text, dtype=torch.int64, device='cuda')
# print(data.dtype)
# print(data.size())
# print(data.device)
# torch.save(data, 'tiny-stories-train.pt')
# encoded_text = None


In [17]:
# load data from tiny-stories-train.pt
data = torch.load('tiny-stories-train.pt', map_location='cuda')


In [18]:
len(data)

468832276

In [19]:
n = int(0.9*len(data))

train_data = data[:n]
val_data = data[n:]

In [20]:
train_data.size()

torch.Size([421949048])

In [21]:
train_data[:T+1]

tensor([  83, 3206,  198,    1,  421,  356,   11,  258,  397,  447,  501,  364,
         596,  258, 3736,  316,  309,  759,   13,  313,  704,  304,  282, 2966,
         265,  359,  342,  304,  788,  304,  282, 2120,   13,  364,  445,  265,
         949,  262, 3736,  342,  309,  365,   11,  350,  338,  461, 5198,  258,
        2228,  345,  309, 2500,   13,  198,  198,  343,  469,  265,  309,  365,
         264,  327,   11,  329,  771,   11,  335,  596,  741, 3736,   13, 1282,
         346,  949,  304,  342,  519,  264, 5198,  652, 2500,  478,  866,  365,
         499,  264,  327,   11,  329,  832,   11,  364,   11,  363,  472,  949,
         262, 3736,  264, 1306,  627, 2500,  416,  198,  198, 4625,   11,  362,
        1656,  262, 3736,  264, 7930,  262, 2228,  345,  364,  371, 2500,   13,
         410,  282,  385, 2966,  366,  449,  788,  362,  430, 2502,  264, 1762,
         757,  573,   13, 1453,  362, 1444,   11,  364,  858,  309,  365,  366,
        2502,  262, 3736,  264, 5150,  3

In [22]:
decode(train_data[:T+1].cpu().numpy())

'text\n"One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, ""Mom, I found this needle. Can you share it with me and sew my shirt?"" Her mom smiled and said, ""Yes, Lily, we can share the needle and fix your shirt.""\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together."\n"Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had many leave

In [23]:
x = train_data[:T]
y = train_data[1:T+1]
for t in range(T):
    context = x[:t+1]
    target = y[t]
    # print("when we see the text", context, "we predict the next character is", target)

In [24]:
# torch.manual_seed(1337)

def get_batch(split):
    """Sample B random windows of T tokens plus their next-token targets.

    `split == 'train'` draws from `train_data`; any other value draws from
    `val_data`. Returns `(x, y)`, both of shape (B, T), where `y` is `x`
    shifted one token to the right.
    """
    source = val_data if split != 'train' else train_data
    # B random start offsets; the exclusive upper bound leaves room for the shifted target window.
    starts = torch.randint(0, source.size(0) - T, (B,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + T])           # the context window
        targets.append(source[start + 1:start + T + 1])  # the same window, one token later

    return torch.stack(inputs), torch.stack(targets)

xb, yb = get_batch('train')

for b in range(B):
    for t in range(T): # for each of the characters in the sample
        context = xb[b, :t+1]
        target = yb[b, t]


In [25]:

import torch
import torch.nn as nn
from torch.nn import functional as F
# torch.manual_seed(1337)


class Head(nn.Module):
    '''One head of causal (masked) self-attention.

    Args:
        H: size of the query/key/value vectors produced by this head.
        n_embd: width of the incoming embeddings. Defaults to the notebook-level `C`.
        block_size: longest sequence the causal mask supports. Defaults to the
            notebook-level `T`.
    '''
    def __init__(self, H, n_embd=None, block_size=None):
        super().__init__()
        n_embd = C if n_embd is None else n_embd
        block_size = T if block_size is None else block_size
        # Remember this head's own size so the score scaling does not depend on a global.
        self.head_size = H
        self.query = nn.Linear(n_embd, H, bias=False)
        self.key = nn.Linear(n_embd, H, bias=False)
        self.value = nn.Linear(n_embd, H, bias=False)
        # self.output = nn.Linear(H, C, bias=False) # output matrix
        # Lower-triangular ones: row t marks the positions (<= t) that token t may attend to.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        '''Apply causal attention to `x` of shape (..., seq_len, n_embd).

        Returns a tensor of shape (..., seq_len, H).

        Raises:
            ValueError: if `seq_len` exceeds the block size the mask was built for.
        '''
        seq_len = x.shape[-2]
        if seq_len > self.tril.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds the block size {self.tril.shape[0]}"
            )

        query_vectors = self.query(x)  # (..., seq_len, H)
        key_vectors = self.key(x)      # (..., seq_len, H)
        value_vectors = self.value(x)  # (..., seq_len, H): what each token contributes when attended to

        # Raw affinities between every query (row) and every key (column),
        # scaled by sqrt(head size) for numerical stability.
        attention_scores = query_vectors @ key_vectors.transpose(-2, -1)  # (..., seq_len, seq_len)
        attention_scores = attention_scores / (self.head_size ** 0.5)

        # Causal mask: a token must not look into the future, so every position
        # to the right of the query gets -inf (zero weight after the softmax).
        future_positions = self.tril[:seq_len, :seq_len] == 0
        attention_scores = attention_scores.masked_fill(future_positions, float('-inf'))
        attention_weights = F.softmax(attention_scores, dim=-1)  # each query row sums to 1

        # Weighted average of the value vectors.
        context = attention_weights @ value_vectors  # (..., seq_len, H)

        # project back into original space from value space
        # return self.output(context)
        return context

x = torch.randn(B,T,C)
head = Head(H)
# head(x)


In [26]:
class MultiHeadAttention(nn.Module):
    '''Several attention heads run side by side, then mixed by a linear layer.

    Args:
        H: output size of each head.
        C: output width of the mixing layer.
        n_heads: how many heads to create.
    '''
    def __init__(self, H, C, n_heads): # H is head embedding space size, n_heads is number of heads
        super().__init__()
        head_modules = [Head(H) for _ in range(n_heads)]
        self.heads = nn.ModuleList(head_modules)
        # Maps the concatenated head outputs (H * n_heads wide) to C channels.
        self.combine_heads = nn.Linear(H*n_heads, C)


    def forward(self,x):
        per_head_outputs = []
        for attention_head in self.heads:
            per_head_outputs.append(attention_head(x))
        concatenated = torch.cat(per_head_outputs, dim=-1)  # T, H*n_heads
        return self.combine_heads(concatenated)  # T, C

In [27]:
head = MultiHeadAttention(H, C, n_heads)
head.heads[0].forward(x).shape


torch.Size([64, 256, 32])

In [28]:
class FeedForward(nn.Module):
    '''Two-layer MLP applied to the last dimension: C -> C * feedforward_factor -> C.

    The hidden-width multiplier is read from the notebook-level `feedforward_factor`.
    '''
    def __init__(self, C):
        super().__init__()
        hidden_width = C * feedforward_factor
        expand = nn.Linear(C, hidden_width)
        activation = nn.ReLU()
        project = nn.Linear(hidden_width, C)
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        transformed = self.net(x)
        return transformed

In [29]:
class LayerNorm(nn.Module):
    '''Normalizes the last dimension to zero mean and unit standard deviation.

    Uses `Tensor.std` with its default settings and adds 1e-6 to the standard
    deviation (not to the variance). With `use_affine=True` a learned
    per-channel scale (`gamma`) and shift (`beta`) are applied.
    '''
    def __init__(self, C, use_affine=True):
        super().__init__()
        if use_affine:
            self.gamma = nn.Parameter(torch.ones(C))
            self.beta = nn.Parameter(torch.zeros(C))
        else:
            self.gamma = None
            self.beta = None

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denominator = x.std(-1, keepdim=True) + 1e-6
        has_affine = self.gamma is not None and self.beta is not None
        if not has_affine:
            return centered / denominator
        return self.gamma * centered / denominator + self.beta

In [30]:
class Block(nn.Module):
    '''Transformer block: attention then feed-forward, each applied to a
    normalized copy of the input and added back as a residual.'''
    def __init__(self, H, C, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(H, C, n_heads)
        self.ff = FeedForward(C)
        self.norm1 = LayerNorm(C)
        self.norm2 = LayerNorm(C)

    def forward(self, x):
        # Normalization happens before each sub-layer; the residual path stays un-normalized.
        attended = self.attention(self.norm1(x))
        x = x + attended
        transformed = self.ff(self.norm2(x))
        return x + transformed

In [31]:
class GPT(nn.Module):
    '''Decoder-only transformer language model.

    The constructor reads the notebook-level `vocab_size`, `C`, `T`, `H` and
    `n_heads`; `n_layers` is the number of stacked `Block`s.
    '''

    def __init__(self, n_layers):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, C) 
        self.position_embedding_table = nn.Embedding(T, C)
        self.lm_head = nn.Linear(C, vocab_size)
        self.layers = nn.ModuleList([Block(H, C, n_heads) for _ in range(n_layers)])
    
    def forward(self, idx, targets=None, return_residuals=None):
        '''Run the model on token ids `idx` of shape (batch, sequence).

        Args:
            idx: integer token ids, shape (batch, sequence).
            targets: optional next-token ids with the same shape as `idx`.
            return_residuals: `None` for a normal pass; `"first_embedding"` to
                return the summed token + position embeddings; or a layer index
                `i` to return the residual stream right after layer `i`.

        Returns:
            The requested residual tensor (batch, sequence, C) when
            `return_residuals` selects one; otherwise `(logits, loss, kurtosis_avg)`
            where `loss` is `None` without targets and `kurtosis_avg` is a
            constant 0 (the kurtosis regulariser is currently disabled).
        '''
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx) # batch_dim, sequence_dim, embedding_dim
        positions = torch.arange(T, device=idx.device)  # created next to the input ids
        pos_emb = self.position_embedding_table(positions)
        x = token_emb + pos_emb # token identities and positions contained

        if return_residuals == "first_embedding":
            return x

        # def excess_kurtosis(emb):
        #     mean = torch.mean(emb, dim=-1, keepdim=True) # BxTx1
        #     std = torch.std(emb, dim=-1, keepdim=True) # BxTx1

        #     centralized = emb - mean #BxTxC
        #     fourth_moment = torch.mean(centralized**4, dim=-1, keepdim=True) # BxTx1
        #     kurtosis = torch.squeeze(fourth_moment / std**4, dim=-1) # BxT
        #     # view as a 1d vector
        #     kurtosis = kurtosis.view(-1) - 3
        #     # make each one min 0
        #     kurtosis = torch.maximum(kurtosis, torch.tensor(0.0))
        #     # sum over the vector
        #     kurtosis = torch.sum(kurtosis)
        #     return kurtosis


        # kurtosis_sum = torch.tensor(0.0)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # kurtosis_sum += excess_kurtosis(x)
            if return_residuals is not None and i == return_residuals:
                return x
        
        # kurtosis_avg = kurtosis_sum / (len(self.layers) * T * B)
        kurtosis_avg = torch.tensor(0.0, device=idx.device)

        logits = self.lm_head(x) # batch_dim, sequence_dim, vocab_size

        if targets is None:
            return logits, None, kurtosis_avg

        # cross_entropy expects 2-D logits and 1-D targets, so flatten batch and
        # sequence together: one row of vocabulary scores per predicted position,
        # and one correct token id per row.
        logits_loss_view = logits.view(-1, logits.size(-1))
        targets_loss_view = targets.view(-1)
        loss = F.cross_entropy(logits_loss_view, targets_loss_view)
        return logits, loss, kurtosis_avg

    def _pad_to_block(self, idx, pad_id=0):
        '''Right-pad (batch, t) token ids to the model's full block length.

        Every forward pass is kept at the length the attention mask was built
        for. Because attention is causal, tokens to the right of a position
        never influence it, so the padding does not change the outputs that
        are read back.

        Returns:
            `(padded_idx, last_real_index)` where `last_real_index == t - 1`.
        '''
        block_size = self.position_embedding_table.num_embeddings
        n_real = idx.shape[1]
        n_missing = block_size - n_real
        if n_missing > 0:
            padding = torch.full(
                (idx.shape[0], n_missing), pad_id, dtype=idx.dtype, device=idx.device
            )
            idx = torch.cat((idx, padding), dim=1)
        return idx, n_real - 1

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.5):
        '''Sample `max_new_tokens` continuation tokens for each row of `idx`.

        Args:
            idx: integer token ids, shape (batch, t) with t >= 1.
            max_new_tokens: number of tokens to append.
            temperature: divides the logits before sampling; lower is greedier.

        Returns:
            `idx` extended to shape (batch, t + max_new_tokens).
        '''
        block_size = self.position_embedding_table.num_embeddings
        for _ in range(max_new_tokens):
            # Only the most recent `block_size` tokens fit in the context window.
            window, last_index = self._pad_to_block(idx[:, -block_size:])
            # forward returns (logits, loss, kurtosis_avg); only the logits are needed.
            logits, _, _ = self(window)
            # get the predictions of the last real token
            last_token_logits = logits[:, last_index, :] # all batches, last token, all probabilities
            # apply temperature
            last_token_logits = last_token_logits / temperature
            # softmax to get probabilities
            probabilities = F.softmax(last_token_logits, dim=-1)
            # sample from the probabilities
            next_token = torch.multinomial(probabilities, num_samples=1)
            # add the new token to the idx tensor
            idx = torch.cat((idx, next_token), dim=1)
        return idx

    @torch.no_grad()
    def prompt_model(self, prompt, max_new_tokens, temperature=0.5):
        '''Continue the text `prompt` by `max_new_tokens` sampled tokens.

        Returns the decoded prompt plus continuation as a string.

        Raises:
            ValueError: if `prompt` encodes to no tokens.
        '''
        autoregressive_seq = list(encode(prompt))
        if not autoregressive_seq:
            raise ValueError("prompt must encode to at least one token")

        block_size = self.position_embedding_table.num_embeddings
        pad_id = encode("\n")[0]  # encoded once instead of once per padded position
        for _ in range(max_new_tokens):
            # Keep only the newest tokens once the sequence outgrows the context window.
            context = torch.tensor(autoregressive_seq[-block_size:]).unsqueeze(0)
            model_input, prediction_index = self._pad_to_block(context, pad_id)

            # Use this instance rather than whatever the global `model` is bound to.
            logits, _, _ = self(model_input)
            prediction_token = logits[:, prediction_index, :] / temperature
            probabilities = F.softmax(prediction_token, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1)

            autoregressive_seq.append(next_token.item())
        # get the autoregressive sequence
        return decode(autoregressive_seq)

    def get_embedding(self, prompt, override_model_embedding_layer=None):
        '''Return the residual vector of the last prompt token at a chosen layer.

        Args:
            prompt: text to embed; only its last `T` tokens are used if it is longer.
            override_model_embedding_layer: value passed to `forward` as
                `return_residuals`; defaults to the notebook-level
                `model_embedding_layer`.

        Returns:
            A 1-D tensor of width C.

        Raises:
            ValueError: if `prompt` encodes to no tokens.
        '''
        if override_model_embedding_layer is None:
            selected_model_embedding_layer = model_embedding_layer
        else:
            selected_model_embedding_layer = override_model_embedding_layer

        sequence = encode(prompt)
        if not sequence:
            raise ValueError("prompt must encode to at least one token")

        block_size = self.position_embedding_table.num_embeddings
        context = torch.tensor(sequence[-block_size:]).unsqueeze(0)
        model_input, sequence_index = self._pad_to_block(context, encode("\n")[0])

        embedding = self.forward(model_input, return_residuals=selected_model_embedding_layer)
        # remove the batch dimension
        embedding = embedding.squeeze(0)[sequence_index]
        return embedding



    

model = GPT(n_layers)
# logits, loss, kurtosis_avg = model(xb, yb)
# print(logits.shape)
# print(loss)
# print(kurtosis_avg)




test_idx = torch.zeros(1, T).long()
model.forward(idx=test_idx)
# decode(model.generate(idx=test_idx, max_new_tokens=100)[0].tolist())

(tensor([[[ 1.6126,  0.3946, -1.4367,  ..., -1.6661,  0.1460,  0.4504],
          [ 2.1239, -2.4500, -1.5640,  ..., -0.6205,  0.2284, -0.6869],
          [ 2.2312, -1.2311, -1.4093,  ..., -1.2511, -0.0832, -0.8157],
          ...,
          [ 1.6009, -0.6114, -0.1871,  ..., -0.8554,  0.4884, -1.0227],
          [ 2.4410, -1.3228, -0.7426,  ..., -1.3044,  0.3075,  0.1473],
          [ 1.5568, -0.9135,  0.8263,  ..., -0.6930, -0.6212, -0.6414]]],
        device='cuda:0', grad_fn=<ViewBackward0>),
 None,
 tensor(0., device='cuda:0'))

In [32]:
model

GPT(
  (token_embedding_table): Embedding(8192, 256)
  (position_embedding_table): Embedding(256, 256)
  (lm_head): Linear(in_features=256, out_features=8192, bias=True)
  (layers): ModuleList(
    (0-11): 12 x Block(
      (attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (query): Linear(in_features=256, out_features=32, bias=False)
            (key): Linear(in_features=256, out_features=32, bias=False)
            (value): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (combine_heads): Linear(in_features=256, out_features=256, bias=True)
      )
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): ReLU()
          (2): Linear(in_features=768, out_features=256, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
    )
  )
)

In [33]:
# get the number of parameters in the model
def count_parameters(model):
    """Return how many scalar weights of `model` are trainable (requires_grad=True)."""
    trainable_total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable_total += parameter.numel()
    return trainable_total

print("number of parameters in the model: ", count_parameters(model))

number of parameters in the model:  12160000


In [34]:
data.shape

torch.Size([468832276])

In [35]:
# logits, loss = self(idx[:,-T:])

idx = torch.zeros(1, 1).long()
idx[:,-T:]

tensor([[0]], device='cuda:0')

In [36]:
model.token_embedding_table.weight.device

device(type='cuda', index=0)

In [37]:
# Number of random batches averaged per split by estimate_loss.
eval_iters = 10
# NOTE(review): this rebinds `eval_interval`, replacing the 500 injected from
# `config` at the top of the notebook — confirm which value is intended.
eval_interval = 300
@torch.no_grad()
def estimate_loss(is_last=False):
    """Estimate the loss on both splits from `eval_iters` random batches each.

    The mean per-token loss is divided by `chars_per_token`, so the reported
    figure is per character rather than per token. With `is_last=True` the
    validation split uses ten times as many batches to reduce noise.

    Returns:
        A dict with keys 'train' and 'val' holding 0-d tensors.
    """
    model.eval()
    results = {}
    for split_name in ('train', 'val'):
        n_batches = eval_iters
        if is_last and split_name == 'val':  # increase last eval to mitigate noise
            n_batches = n_batches * 10
        batch_losses = torch.zeros(n_batches)
        for batch_index in range(n_batches):
            inputs, targets = get_batch(split_name)
            _, batch_loss, _ = model(inputs, targets)
            batch_losses[batch_index] = batch_loss.item()
        results[split_name] = batch_losses.mean() / chars_per_token
    # Switch back to training mode for the caller.
    model.train()
    return results
    

In [38]:
# get the number of parameters
n_params = sum(p.numel() for p in model.parameters())
parameter_to_data_ratio = n_params / len(train_data)
print(f"{parameter_to_data_ratio=}")

parameters = []
for name, param in model.named_parameters():
    parameters.append({"name": name, "params": param.numel()})

# sort parameters by size
sorted_parameters = sorted(parameters, key=lambda x: x["params"], reverse=True)
for p in sorted_parameters:
    print(f"{p['name']}: {p['params']}")

parameter_to_data_ratio=0.028818645420903996
token_embedding_table.weight: 2097152
lm_head.weight: 2097152
layers.0.ff.net.0.weight: 196608
layers.0.ff.net.2.weight: 196608
layers.1.ff.net.0.weight: 196608
layers.1.ff.net.2.weight: 196608
layers.2.ff.net.0.weight: 196608
layers.2.ff.net.2.weight: 196608
layers.3.ff.net.0.weight: 196608
layers.3.ff.net.2.weight: 196608
layers.4.ff.net.0.weight: 196608
layers.4.ff.net.2.weight: 196608
layers.5.ff.net.0.weight: 196608
layers.5.ff.net.2.weight: 196608
layers.6.ff.net.0.weight: 196608
layers.6.ff.net.2.weight: 196608
layers.7.ff.net.0.weight: 196608
layers.7.ff.net.2.weight: 196608
layers.8.ff.net.0.weight: 196608
layers.8.ff.net.2.weight: 196608
layers.9.ff.net.0.weight: 196608
layers.9.ff.net.2.weight: 196608
layers.10.ff.net.0.weight: 196608
layers.10.ff.net.2.weight: 196608
layers.11.ff.net.0.weight: 196608
layers.11.ff.net.2.weight: 196608
position_embedding_table.weight: 65536
layers.0.attention.combine_heads.weight: 65536
layers.1.at

In [39]:
# Train the language model with Adam, printing an estimate of the train and
# validation loss every `eval_interval` steps.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

import tqdm
num_params = sum([p.numel() for p in model.parameters()])

for steps in tqdm.tqdm(range(max_iters)):
    xb, yb = get_batch('train')
    # loss
    logits, loss, kurtosis_avg = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    # l2 regularization
    # l2 = sum(p.pow(2).sum() for p in model.parameters()) / num_params
    # Kurtosis penalty; the model currently returns a constant 0 here, so this adds nothing.
    loss = loss + kurtosis_avg * 0.4/160

    loss.backward()
    optimizer.step()
    if steps % eval_interval == 0:
        losses = estimate_loss()
        # The metric key used to be misspelled "tIain"; it is now "train".
        # wandb.log({"train": losses['train'].item(), "val": losses['val'].item(), "l2":l2})
        print({"train": losses['train'].item(), "val": losses['val'].item(), "kurtosis_avg": kurtosis_avg.item()})

# Final, lower-noise evaluation (ten times as many validation batches).
losses = estimate_loss(is_last=True)
# wandb.log({"train": losses['train'].item(), "val": losses['val'].item()})
# wandb.finish()


  0%|          | 2/60000 [00:01<8:46:29,  1.90it/s] 

{'tIain': 2.3197550773620605, 'val': 2.316525459289551, 'kurtosis_avg': 0.0}


  1%|          | 302/60000 [00:39<3:44:18,  4.44it/s]

{'tIain': 0.848213791847229, 'val': 0.8460512161254883, 'kurtosis_avg': 0.0}


  1%|          | 602/60000 [01:21<3:43:47,  4.42it/s]

{'tIain': 0.7173031568527222, 'val': 0.7209934592247009, 'kurtosis_avg': 0.0}


  2%|▏         | 902/60000 [01:56<5:19:14,  3.09it/s]

{'tIain': 0.6550315022468567, 'val': 0.6534883379936218, 'kurtosis_avg': 0.0}


  2%|▏         | 1202/60000 [02:41<4:03:28,  4.02it/s]

{'tIain': 0.6106777191162109, 'val': 0.6175428032875061, 'kurtosis_avg': 0.0}


  3%|▎         | 1502/60000 [03:19<3:17:20,  4.94it/s]

{'tIain': 0.5788686871528625, 'val': 0.5822889804840088, 'kurtosis_avg': 0.0}


  3%|▎         | 1802/60000 [03:55<3:27:24,  4.68it/s]

{'tIain': 0.5583482384681702, 'val': 0.5628054738044739, 'kurtosis_avg': 0.0}


  4%|▎         | 2102/60000 [04:38<4:49:24,  3.33it/s]

{'tIain': 0.546829879283905, 'val': 0.5460728406906128, 'kurtosis_avg': 0.0}


  4%|▍         | 2402/60000 [05:19<4:38:43,  3.44it/s]

{'tIain': 0.5330162048339844, 'val': 0.5316081643104553, 'kurtosis_avg': 0.0}


  5%|▍         | 2703/60000 [06:00<3:08:00,  5.08it/s]

{'tIain': 0.5252125263214111, 'val': 0.5191770195960999, 'kurtosis_avg': 0.0}


  5%|▌         | 3002/60000 [06:45<3:31:22,  4.49it/s]

{'tIain': 0.5189629197120667, 'val': 0.5134242177009583, 'kurtosis_avg': 0.0}


  6%|▌         | 3302/60000 [07:30<4:20:35,  3.63it/s]

{'tIain': 0.5112408995628357, 'val': 0.5056686401367188, 'kurtosis_avg': 0.0}


  6%|▌         | 3602/60000 [08:13<3:41:01,  4.25it/s]

{'tIain': 0.49557918310165405, 'val': 0.4964991807937622, 'kurtosis_avg': 0.0}


  7%|▋         | 3902/60000 [08:54<4:53:53,  3.18it/s]

{'tIain': 0.48896706104278564, 'val': 0.4862830638885498, 'kurtosis_avg': 0.0}


  7%|▋         | 4203/60000 [09:31<2:40:05,  5.81it/s]

{'tIain': 0.4893595576286316, 'val': 0.4869172275066376, 'kurtosis_avg': 0.0}


  8%|▊         | 4502/60000 [10:09<3:17:47,  4.68it/s]

{'tIain': 0.4819861948490143, 'val': 0.47948625683784485, 'kurtosis_avg': 0.0}


  8%|▊         | 4802/60000 [10:49<4:33:21,  3.37it/s]

{'tIain': 0.4774918854236603, 'val': 0.4796147644519806, 'kurtosis_avg': 0.0}


  9%|▊         | 5102/60000 [11:31<4:22:14,  3.49it/s]

{'tIain': 0.4745745062828064, 'val': 0.4859565496444702, 'kurtosis_avg': 0.0}


  9%|▉         | 5403/60000 [12:13<2:37:13,  5.79it/s]

{'tIain': 0.4691164791584015, 'val': 0.4700319170951843, 'kurtosis_avg': 0.0}


 10%|▉         | 5702/60000 [12:56<4:53:43,  3.08it/s]

{'tIain': 0.4648531675338745, 'val': 0.47154805064201355, 'kurtosis_avg': 0.0}


 10%|█         | 6002/60000 [13:32<3:37:03,  4.15it/s]

{'tIain': 0.4655320942401886, 'val': 0.46875420212745667, 'kurtosis_avg': 0.0}


 11%|█         | 6302/60000 [14:17<4:30:31,  3.31it/s]

{'tIain': 0.4568535089492798, 'val': 0.46072298288345337, 'kurtosis_avg': 0.0}


 11%|█         | 6602/60000 [15:02<3:24:48,  4.35it/s]

{'tIain': 0.46624675393104553, 'val': 0.461658239364624, 'kurtosis_avg': 0.0}


 12%|█▏        | 6902/60000 [15:46<3:19:26,  4.44it/s]

{'tIain': 0.45531871914863586, 'val': 0.4583282172679901, 'kurtosis_avg': 0.0}


 12%|█▏        | 7202/60000 [16:28<2:41:04,  5.46it/s]

{'tIain': 0.4588129222393036, 'val': 0.45670628547668457, 'kurtosis_avg': 0.0}


 13%|█▎        | 7502/60000 [17:08<4:09:08,  3.51it/s]

{'tIain': 0.458290159702301, 'val': 0.4534323811531067, 'kurtosis_avg': 0.0}


 13%|█▎        | 7802/60000 [17:52<4:03:46,  3.57it/s]

{'tIain': 0.44639691710472107, 'val': 0.4477818012237549, 'kurtosis_avg': 0.0}


 14%|█▎        | 8102/60000 [18:35<4:17:40,  3.36it/s]

{'tIain': 0.44626837968826294, 'val': 0.4506498873233795, 'kurtosis_avg': 0.0}


 14%|█▍        | 8402/60000 [19:17<3:33:36,  4.03it/s]

{'tIain': 0.44445979595184326, 'val': 0.4444202184677124, 'kurtosis_avg': 0.0}


 15%|█▍        | 8702/60000 [20:06<4:14:47,  3.36it/s]

{'tIain': 0.4458003342151642, 'val': 0.447415828704834, 'kurtosis_avg': 0.0}


 15%|█▌        | 9002/60000 [20:52<3:15:22,  4.35it/s]

{'tIain': 0.44905391335487366, 'val': 0.44745221734046936, 'kurtosis_avg': 0.0}


 16%|█▌        | 9302/60000 [21:34<4:07:04,  3.42it/s]

{'tIain': 0.44019079208374023, 'val': 0.44673067331314087, 'kurtosis_avg': 0.0}


 16%|█▌        | 9602/60000 [22:13<3:36:04,  3.89it/s]

{'tIain': 0.44448286294937134, 'val': 0.44446924328804016, 'kurtosis_avg': 0.0}


 17%|█▋        | 9902/60000 [22:57<3:08:05,  4.44it/s]

{'tIain': 0.4297686219215393, 'val': 0.4423435628414154, 'kurtosis_avg': 0.0}


 17%|█▋        | 10202/60000 [23:31<3:46:14,  3.67it/s]

{'tIain': 0.44101792573928833, 'val': 0.4359297752380371, 'kurtosis_avg': 0.0}


 18%|█▊        | 10502/60000 [24:12<3:07:02,  4.41it/s]

{'tIain': 0.4419628381729126, 'val': 0.43952375650405884, 'kurtosis_avg': 0.0}


 18%|█▊        | 10802/60000 [24:51<3:34:08,  3.83it/s]

{'tIain': 0.43954506516456604, 'val': 0.4355044364929199, 'kurtosis_avg': 0.0}


 19%|█▊        | 11102/60000 [25:30<3:05:20,  4.40it/s]

{'tIain': 0.4369654655456543, 'val': 0.43367886543273926, 'kurtosis_avg': 0.0}


 19%|█▉        | 11402/60000 [26:14<2:47:04,  4.85it/s]

{'tIain': 0.43440359830856323, 'val': 0.43030232191085815, 'kurtosis_avg': 0.0}


 20%|█▉        | 11703/60000 [26:58<2:43:07,  4.93it/s]

{'tIain': 0.4300914406776428, 'val': 0.43465957045555115, 'kurtosis_avg': 0.0}


 20%|██        | 12002/60000 [27:41<3:48:48,  3.50it/s]

{'tIain': 0.43394455313682556, 'val': 0.4402737319469452, 'kurtosis_avg': 0.0}


 21%|██        | 12302/60000 [28:28<2:44:14,  4.84it/s]

{'tIain': 0.4365575909614563, 'val': 0.4296320974826813, 'kurtosis_avg': 0.0}


 21%|██        | 12602/60000 [29:11<2:44:21,  4.81it/s]

{'tIain': 0.4342315196990967, 'val': 0.43167099356651306, 'kurtosis_avg': 0.0}


 22%|██▏       | 12902/60000 [29:53<3:57:31,  3.30it/s]

{'tIain': 0.4317355751991272, 'val': 0.4253843426704407, 'kurtosis_avg': 0.0}


 22%|██▏       | 13202/60000 [30:35<2:31:45,  5.14it/s]

{'tIain': 0.425240695476532, 'val': 0.4290977716445923, 'kurtosis_avg': 0.0}


 23%|██▎       | 13502/60000 [31:09<2:50:55,  4.53it/s]

{'tIain': 0.4270060062408447, 'val': 0.42866334319114685, 'kurtosis_avg': 0.0}


 23%|██▎       | 13802/60000 [31:44<2:45:38,  4.65it/s]

{'tIain': 0.4243968427181244, 'val': 0.4219661056995392, 'kurtosis_avg': 0.0}


 24%|██▎       | 14102/60000 [32:28<3:49:19,  3.34it/s]

{'tIain': 0.41733676195144653, 'val': 0.4239497184753418, 'kurtosis_avg': 0.0}


 24%|██▍       | 14402/60000 [33:16<3:40:48,  3.44it/s]

{'tIain': 0.42082151770591736, 'val': 0.4269647002220154, 'kurtosis_avg': 0.0}


 25%|██▍       | 14702/60000 [33:52<2:35:57,  4.84it/s]

{'tIain': 0.42225274443626404, 'val': 0.4240022301673889, 'kurtosis_avg': 0.0}


 25%|██▌       | 15002/60000 [34:32<2:46:51,  4.49it/s]

{'tIain': 0.4172787070274353, 'val': 0.4206853210926056, 'kurtosis_avg': 0.0}


 26%|██▌       | 15302/60000 [35:11<3:06:59,  3.98it/s]

{'tIain': 0.41936108469963074, 'val': 0.42085564136505127, 'kurtosis_avg': 0.0}


 26%|██▌       | 15602/60000 [35:52<3:36:12,  3.42it/s]

{'tIain': 0.4218531548976898, 'val': 0.4269405007362366, 'kurtosis_avg': 0.0}


 27%|██▋       | 15902/60000 [36:25<2:35:16,  4.73it/s]

{'tIain': 0.42026564478874207, 'val': 0.42319679260253906, 'kurtosis_avg': 0.0}


 27%|██▋       | 16203/60000 [37:05<2:14:50,  5.41it/s]

{'tIain': 0.4207702577114105, 'val': 0.425100713968277, 'kurtosis_avg': 0.0}


 28%|██▊       | 16502/60000 [37:44<2:44:49,  4.40it/s]

{'tIain': 0.42008715867996216, 'val': 0.4157676100730896, 'kurtosis_avg': 0.0}


 28%|██▊       | 16551/60000 [37:51<1:39:22,  7.29it/s]


KeyboardInterrupt: 

In [40]:
estimate_loss()

{'train': tensor(0.4212, device='cuda:0'),
 'val': tensor(0.4223, device='cuda:0')}

In [42]:
# save model
torch.save(model.state_dict(), 'tiny-stories-model.pt')



In [39]:
# load the model
model.load_state_dict(torch.load('tiny-stories-model.pt'))


<All keys matched successfully>

In [40]:
print(model.prompt_model("Lilly saw a big red apple.", 200, 0.7))

Once upon a time Robot was sad. He missed his ball and the lady's car. He wished he could play with the same car, but he was too shy. Instead, he decided to ask the lady if he could help. The lady was very kind and said yes. She helped him take the ball and he was so happy. He thanked the lady and ran to the store to see his ball again. He was so excited and couldn't believe how much he had been with the lady. From that day on, Robot was never shy again and he was never shy again."
"Once upon a time, there was a little girl named Lily. She loved to play outside and look at the flowers. One day, her mommy told her they were going to visit a new nation. Lily was so excited!

When they got there, Lily saw many people walking with their parents. They were talking about the world and eating yummy food. Lily felt embarrassed because she didn't know what to expect. She watched the people
