In [2]:
# Pipeline switches. DO_TOKENIZATION gates the corpus download, tokenizer
# training and shard writing cells below.
# NOTE(review): no cell visible here reads TRAIN_MODEL - confirm it is still needed.
DO_TOKENIZATION, TRAIN_MODEL = False, True

In [3]:
from datasets import load_dataset
from itertools import chain
import tokenizers
import tqdm


import os
import os.path
import random
import torch

import torch.nn as nn
from torch.nn import functional as F

In [4]:
if DO_TOKENIZATION:
    # First corpus: the "roneneldan/TinyStories" dataset.
    tinystories_ds = load_dataset("roneneldan/TinyStories")

    # Second corpus: files stories_01.json .. stories_49.json of the
    # "robrenaud/multilingual_tinystories" dataset.
    datafiles = [f'stories_{file_no:02d}.json' for file_no in range(1, 50)]
    tinystories_sp_ds = load_dataset(
        "robrenaud/multilingual_tinystories",
        data_files=datafiles,
    )

In [5]:
if DO_TOKENIZATION:
    # Concatenate both corpora into a single text file that the BPE trainer
    # reads. The encoding is pinned to UTF-8 so the accented characters in the
    # stories are written the same way on every platform instead of using the
    # locale-dependent default of open().
    with open('train.txt', 'w', encoding='utf-8') as full_stories_output:
        for story in tqdm.tqdm(tinystories_sp_ds['train']['story']):
            full_stories_output.write(story)

        for story in tqdm.tqdm(tinystories_ds['train']['text']):
            full_stories_output.write(story)

In [6]:
if not DO_TOKENIZATION:
    # Tokenization is skipped, but the rest of the notebook still needs the
    # tokenizer, so rebuild it from the previously saved vocab/merges files.
    tokenizer = tokenizers.ByteLevelBPETokenizer(
        "./tiny-stories-bpe-vocab.json",
        "./tiny-stories-bpe-merges.txt",
    )
else:
    # Train a byte-level BPE tokenizer with an 8192-entry vocabulary on the
    # combined corpus and save it next to the notebook.
    tokenizer = tokenizers.ByteLevelBPETokenizer()
    tokenizer.train(files=['train.txt'], vocab_size=2**13, min_frequency=2)
    tokenizer.save_model('.', 'tiny-stories-bpe')

In [7]:
if DO_TOKENIZATION:
    # Encode every story to a 16-bit tensor of token ids and write the tensors
    # to tokenized/tokenized-<n>.pt, starting a new shard whenever the buffer
    # grows past 500,000 stories.
    os.makedirs('tokenized', exist_ok=True)
    stories = chain(tinystories_ds['train']['text'], tinystories_sp_ds['train']['story'])

    output_buf = []
    num_outputs = 0
    for story in tqdm.tqdm(stories):
        token_ids = tokenizer.encode(story).ids
        output_buf.append(torch.tensor(token_ids, dtype=torch.short))
        if len(output_buf) <= 500_000:
            continue
        torch.save(output_buf, f'tokenized/tokenized-{num_outputs}.pt')
        num_outputs += 1
        output_buf = []

    # Flush whatever is left over as the final, shorter shard.
    if output_buf:
        torch.save(output_buf, f'tokenized/tokenized-{num_outputs}.pt')
        num_outputs += 1
        output_buf = []

## Train

In [8]:
#----- imports --------
import tqdm
import torch
# import wandb
import os
import tokenizers


# Use the GPU when one is visible and make it the default device for every
# tensor created from here on; the notebook refuses to continue on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"

# Hyperparameters for the model and the training run.
config = {
    "learning_rate": 1e-3,
    "eval_interval": 300,
    "max_iters": 60000, 
    "H": 32, # per head dimension size
    "B": 64, # batch size
    "T": 256, # Sequence length
    "C": 256, # model size
    "feedforward_factor": 3,
    "n_heads": 8,
    "dropout": 0.0,
    "l2_penalty": 0.0,
    "n_layers": 12,
    "tokenizer_vocab_size": 2**13,
    # "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

# Expose every hyperparameter as a module-level name (H, B, T, C, ...), which
# is how the rest of the notebook reads them. globals() is the documented
# writable namespace; assigning through locals() only happened to work because
# this code runs at module scope.
globals().update(config)

In [9]:
def load_sharded_story(shard_no):
    '''Load shard number `shard_no` from the `tokenized/` directory.

    Returns whatever `torch.load` reads from `tokenized/tokenized-<shard_no>.pt`.
    '''
    shard_path = f'tokenized/tokenized-{shard_no}.pt'
    return torch.load(shard_path)

In [10]:
# Load the tokenized shards in parallel using threads;
# this is faster than loading them sequentially.
from concurrent.futures import ThreadPoolExecutor

# Count only real shard files, so stray directory entries (for example an
# .ipynb_checkpoints folder) do not turn into requests for missing shards.
num_shards = sum(
    1
    for entry in os.listdir('tokenized')
    if entry.startswith('tokenized-') and entry.endswith('.pt')
)
with ThreadPoolExecutor() as pool:
    # pool.map yields results in shard order; each element is one shard
    # as returned by load_sharded_story.
    stories = list(tqdm.tqdm(pool.map(load_sharded_story, range(num_shards)), total=num_shards))

100%|██████████| 15/15 [07:01<00:00, 28.09s/it]  


In [11]:
# Flatten the per-shard sequences into one list of stories, then shuffle it
# with a fixed seed so the order is reproducible.
all_stories = list(chain.from_iterable(stories))
random.seed(1337)
random.shuffle(all_stories)

In [12]:
# Corpus size, in stories and in tokens.
story_count = len(all_stories)
token_count = sum(map(len, all_stories))
print("length of dataset in stories: ", story_count)
print("length of stories in tokens", token_count)

length of dataset in stories:  7019719
length of stories in tokens 1348019169


In [13]:
# Estimate, on the first million stories, how many exceed the context
# length T.
num_stories_to_check = 1_000_000
checked_stories = all_stories[:num_stories_to_check]
num_long = sum(1 for story in checked_stories if len(story) > T)
print(
    f"# stories longer than {T} : {num_long} out of {num_stories_to_check}, {num_long/num_stories_to_check:.2%}")

# stories longer than 256 : 102027 out of 1000000, 10.20%


In [14]:
def encode(text):
    '''Tokenize `text` with the notebook-level `tokenizer` and return its token ids.'''
    encoding = tokenizer.encode(text)
    return encoding.ids
def decode(encoded_text):
    '''Turn a sequence of token ids back into text using the notebook-level `tokenizer`.'''
    text = tokenizer.decode(encoded_text)
    return text

# NOTE: this cell deliberately does not run `from tqdm import tqdm`. That
# import rebinds the name `tqdm` from the module to the class, which breaks
# every `tqdm.tqdm(...)` call elsewhere in the notebook on a top-to-bottom run.

def batch_encode(text, batch_size):
    '''Encode `text` in consecutive slices of `batch_size` characters.

    Each slice is passed to `encode` on its own and the resulting token ids
    are concatenated in order, with a progress bar over the slices.

    Args:
        text: the string to tokenize.
        batch_size: number of characters per slice.

    Returns:
        A flat list of token ids.
    '''
    tokens = []
    for start in tqdm.tqdm(range(0, len(text), batch_size)):
        tokens.extend(encode(text[start:start + batch_size]))
    return tokens


# Round-trip a short word through the tokenizer as a sanity check.
hello_encoded = encode("Hola")
print(hello_encoded)
print(decode(hello_encoded))

vocab_size = tokenizer.get_vocab_size()
print("vocab size: ", vocab_size)

first_story_ids = all_stories[0].tolist()
print('first story decoded: ', decode(first_story_ids))

# The id of a lone space is used as the fill value when batches are padded.
PADDING_TOKEN_IDX = encode(" ")[0]

[718]
Hola
vocab size:  8192
first story decoded:  Un perro llamado Spot corría. Spot corría rápido. Él jugaba en el parque. Luego, ¡un trueno! Spot se asustó. 

Spot corrió a casa. Mamá lo vio. "¡Spot!", dijo Mamá. "¡No tengas miedo!" Spot se acurrucó con Mamá. Ella lo abrazó.

Después, ¡otro trueno! ¡Pero no era un trueno! Era Papá. Papá se reía. Él llevaba un gorro de trueno. "¡Sorpresa!", dijo Papá. 

Spot se rió. ¡No era un trueno malo! Era solo Papá. Él corrió a abrazar a Papá. Spot estaba feliz. 

Mamá sonrió. Spot era un perro feliz. Él amaba jugar con Papá. ¡Y amaba los truenos, ahora que sabía que eran solo Papá! 



In [15]:
# 90/10 split of the (already shuffled) stories into train and validation.
n = int(0.9*len(all_stories))
train_data, val_data = all_stories[:n], all_stories[n:]

In [16]:
def get_batch(split):
    '''Sample a batch of B random stories as next-token-prediction pairs.

    Args:
        split: 'train' selects `train_data`; anything else selects `val_data`.

    Returns:
        (x, y), two int64 tensors of shape (B, T). Each row of `x` is one
        padding token followed by the first T-1 tokens of a story; the matching
        row of `y` is those same tokens starting at column 0, so `y[:, t]` is
        the token that follows `x[:, t]`. Unused positions hold
        PADDING_TOKEN_IDX.
    '''
    data = train_data if split == 'train' else val_data
    # Draw all indices in one call and convert them to Python ints once.
    # Indexing the list with 0-dim device tensors forces a device-to-host
    # synchronisation for every single story.
    ix = torch.randint(0, len(data), (B,)).tolist()

    x = torch.full((B, T), PADDING_TOKEN_IDX, dtype=torch.long)
    y = torch.full((B, T), PADDING_TOKEN_IDX, dtype=torch.long)

    for row, story_index in enumerate(ix):
        # Truncate before widening so long stories are not fully converted.
        story = data[story_index][:T - 1].long()
        length = story.shape[0]
        x[row, 1:length + 1] = story
        y[row, :length] = story

    return x, y

# Draw one training batch and show the first input row and its target row.
xb, yb = get_batch('train')

for preview_row in (xb[0], yb[0]):
    print(preview_row)

tensor([ 220,  582,  723,   11,  297,  342,  630,  560,  431,  297,  919,  782,
          13,  313,  569,  297,  782, 3210,  345, 2494,    0,  560,  440, 3610,
        1157,   13,  313, 1622, 1065,  443,    0,  411,  936,   11,  350,  956,
         782,  335,  312,  198, 1016,  375,  378,  322,   11,  292, 1655,  719,
         297, 1164,   11,  560,   11,  740,  366, 1157,  483,  560,  346,  471,
          13,  411,  366,  440, 1189,  340,  782,   13,  312,  198, 1865,   11,
         340,  375,  378,  749,  256,  560,  297,  919,  934,   13,  313,  569,
         297,  744,  725,    0,  484,  560,  378, 1273,  269,  315,  744,   13,
         350,  954,   11,  375,  335,  936,   13,  312,  198, 1126, 2374, 1157,
        1561,  315,  782,   13, 3105,  351,  340,  744,  310,  326,  792,   13,
         405,  375,  345,  560,  666,   13,  458,  833,  377,  560,  346,  443,
          13,  313,  278,  782,  366,  397, 1065, 1892, 2081,  345, 1157,    0,
         198,  220,  220,  220,  220,  2

In [17]:
class Head(nn.Module):
    '''One head of causal self-attention.

    Projects inputs of width C (module-level) to queries, keys and values of
    width H, and returns for every position the attention-weighted sum of the
    values at that position and earlier ones.

    Input:  (..., t, C) with t <= T (T is the module-level context length).
    Output: (..., t, H).
    '''
    def __init__(self, H):
        super().__init__()
        self.query = nn.Linear(C, H, bias=False)
        self.key = nn.Linear(C, H, bias=False)
        self.value = nn.Linear(C, H, bias=False)
        # Lower-triangular matrix of ones: row i marks the positions query i may see.
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        seq_len = x.shape[-2]
        query_vectors = self.query(x)
        key_vectors = self.key(x)
        # Use this head's own width for the scaling, not a module-level constant.
        head_size = key_vectors.shape[-1]

        # Scaled dot-product scores: (..., t, t), rows are queries.
        attention_pattern = query_vectors @ key_vectors.transpose(-2, -1)
        attention_pattern = attention_pattern / (head_size ** 0.5)

        # Causal mask (so a position cannot look into the future). Slicing the
        # buffer to the actual length lets sequences shorter than T through.
        future_positions = self.tril[:seq_len, :seq_len] == 0
        attention_pattern = attention_pattern.masked_fill(future_positions, float('-inf'))

        attention_weights = F.softmax(attention_pattern, dim=-1)
        attention_weights = self.dropout(attention_weights)

        value_vectors = self.value(x)  # (..., t, H)

        # Mix the value vectors according to the attention weights: (..., t, H).
        return attention_weights @ value_vectors

# Random (B, T, C) activations and one attention head for quick shape checks.
x = torch.randn(B, T, C)
head = Head(H)

In [18]:
class MultiHeadAttention(nn.Module):
    '''Several attention heads run side by side and merged by a linear layer.

    H is the width of each head, C the model width and n_heads the number of
    heads. The head outputs are concatenated along the last dimension
    (H * n_heads wide), projected to C, and passed through dropout.
    '''
    def __init__(self, H, C, n_heads):
        super().__init__()
        head_modules = []
        for _ in range(n_heads):
            head_modules.append(Head(H))
        self.heads = nn.ModuleList(head_modules)
        self.combine_heads = nn.Linear(H * n_heads, C)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head_outputs = [attention_head(x) for attention_head in self.heads]
        stacked = torch.cat(per_head_outputs, dim=-1)
        projected = self.combine_heads(stacked)  # back to model width C
        return self.dropout(projected)

In [19]:
# Build the multi-head module and check the output shape of its first head.
head = MultiHeadAttention(H, C, n_heads)
first_head = head.heads[0]
first_head.forward(x).shape

torch.Size([64, 256, 32])

In [20]:
class FeedForward(nn.Module):
    '''Position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is `feedforward_factor` times wider than the model
    width C; the output has width C again.
    '''
    def __init__(self, C):
        super().__init__()
        hidden_width = C * feedforward_factor
        expand = nn.Linear(C, hidden_width)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_width, C)
        regularize = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, contract, regularize)

    def forward(self, x):
        out = self.net(x)
        return out

In [21]:
class LayerNorm(nn.Module):
    '''Normalize the last dimension to zero mean and unit standard deviation.

    With `use_affine=True` a learnable per-feature scale (`gamma`, initialised
    to ones) and shift (`beta`, initialised to zeros) are applied afterwards;
    otherwise both attributes are None. A constant 1e-6 is added to the
    standard deviation before dividing.
    '''
    def __init__(self, C, use_affine=True):
        super().__init__()
        if use_affine:
            self.gamma = nn.Parameter(torch.ones(C))
            self.beta = nn.Parameter(torch.zeros(C))
        else:
            self.gamma = None
            self.beta = None

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denominator = x.std(-1, keepdim=True) + 1e-6
        if self.gamma is None or self.beta is None:
            return centered / denominator
        # Same operation order as before: scale, divide, then shift.
        return self.gamma * centered / denominator + self.beta

In [22]:
class Block(nn.Module):
    '''Pre-norm transformer block.

    Applies attention and then the feed-forward network, each on a
    layer-normalised copy of its input and each added back to the input
    through a residual connection.
    '''
    def __init__(self, H, C, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(H, C, n_heads)
        self.ff = FeedForward(C)
        self.norm1 = LayerNorm(C, use_affine=True)
        self.norm2 = LayerNorm(C, use_affine=True)

    def forward(self, x):
        attended = self.attention(self.norm1(x))
        after_attention = x + attended
        transformed = self.ff(self.norm2(after_attention))
        return after_attention + transformed

In [23]:
class GPT(nn.Module):
    '''Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through
    `n_layers` transformer blocks and projected to vocabulary logits.
    Sizes (vocab_size, C, T, H, n_heads) come from module-level names.
    '''

    def __init__(self, n_layers):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, C)
        self.position_embedding_table = nn.Embedding(T, C)
        self.lm_head = nn.Linear(C, vocab_size)
        self.layers = nn.ModuleList([Block(H, C, n_heads) for _ in range(n_layers)])
        # NOTE(review): `block` is never used by forward(). It is kept only so
        # that the parameter set and state_dict keys match checkpoints already
        # written by this notebook; drop it when starting from scratch.
        self.block = nn.ModuleList([Block(H, C, n_heads)])

    def forward(self, idx, targets=None):
        '''Compute logits, and the loss when targets are given.

        Args:
            idx: (batch, seq) integer token ids, seq <= T.
            targets: optional (batch, seq) integer token ids.

        Returns:
            (logits, loss): logits has shape (batch, seq, vocab); loss is the
            mean cross-entropy over all positions, or None without targets.
        '''
        seq_len = idx.shape[1]
        token_emb = self.token_embedding_table(idx)  # batch_dim, sequence_dim, embedding_dim
        positions = torch.arange(seq_len, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = token_emb + pos_emb  # token identities and positions contained

        for layer in self.layers:
            x = layer(x)

        logits = self.lm_head(x)  # batch_dim, sequence_dim, vocab_size

        if targets is None:
            return logits, None

        logits_loss_view = logits.view(-1, logits.size(-1))
        targets_loss_view = targets.view(-1)
        loss = F.cross_entropy(logits_loss_view, targets_loss_view)
        return logits, loss

    def _last_position_logits(self, idx):
        '''Return next-token logits (batch, vocab) for the last token of `idx`.

        The context is cropped to the most recent T tokens. Shorter contexts
        are right-padded to exactly T so the model always sees a full-length
        input; the filler sits after the position that is read out, and the
        attention is causal, so it cannot influence the returned logits.
        '''
        context = idx[:, -T:]
        real_len = context.shape[1]
        if real_len < T:
            filler = torch.zeros(
                (context.shape[0], T - real_len),
                dtype=context.dtype,
                device=context.device,
            )
            context = torch.cat((context, filler), dim=1)
        logits, _ = self(context)
        return logits[:, real_len - 1, :]

    def generate(self, idx, max_new_tokens):
        '''Sample `max_new_tokens` tokens after `idx` and return the extended ids.

        Args:
            idx: (batch, seq) integer token ids with seq >= 1.
            max_new_tokens: number of tokens to append.

        Returns:
            (batch, seq + max_new_tokens) tensor of token ids.
        '''
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                last_token_logits = self._last_position_logits(idx)
                probabilities = F.softmax(last_token_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)
                idx = torch.cat((idx, next_token), dim=1)
        return idx

    def prompt_model(self, prompt, max_new_tokens, temperature=0.5):
        '''Continue a text prompt and return prompt + continuation as text.

        Args:
            prompt: non-empty text to start from.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by this before sampling.

        Raises:
            ValueError: if the prompt encodes to no tokens.
        '''
        autoregressive_seq = encode(prompt)
        if not autoregressive_seq:
            raise ValueError("prompt must encode to at least one token")

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Only the last T tokens fit in the context window.
                context = torch.tensor(autoregressive_seq[-T:], dtype=torch.long).unsqueeze(0)
                # Use this instance, not a module-level `model` variable.
                prediction_logits = self._last_position_logits(context) / temperature
                probabilities = F.softmax(prediction_logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1).item()
                autoregressive_seq.append(next_token)

        return decode(autoregressive_seq)

In [30]:
# Build the model and run one forward pass with targets on the sample batch.
model = GPT(n_layers)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)

# Generate from an all-zeros context of length T and show the decoded text.
test_idx = torch.zeros(1, T).long()
model.forward(idx=test_idx)
generated = model.generate(idx=test_idx, max_new_tokens=100)
decode(generated[0].tolist())

torch.Size([64, 256, 8192])
tensor(9.3777, device='cuda:0', grad_fn=<NllLossBackward0>)


'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ownenta�ours cuentos las destro Jerry asustaralquier per something una estaban realized puer monkeys right trenMimiamaañanaice Com lloró Do pe patient rocas shoc élitaoking arms right weird tre van wouldn showingriba lazyater cansaron2cip ense mill shopkeeper llevó answ morder sido haría mug ignóm camas cosas penny sell nar Mi historiaEstas already dijoamesares Whatcíf t probó gard clum primos alrededor onto tocarlo songs jueM bet sí cocodr impatti night wild instrumentos ja daba cocodr pengu camera tuyo design palabra bowl pun'

In [31]:
# get the number of parameters in the model
def count_parameters(model):
    '''Return the total number of elements over the model's trainable parameters.'''
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

parameters_in_millions = count_parameters(model) / 1e6
print("number of parameters in the model (millions): ", parameters_in_millions)

number of parameters in the model (millions):  12.817664


In [32]:
# Slicing with -T on a sequence shorter than T simply returns the whole sequence.
idx = torch.zeros((1, 1)).long()
idx[:, -T:]

tensor([[0]], device='cuda:0')

In [33]:
# Adam over every model parameter. The config's `l2_penalty` is passed as
# Adam's weight_decay so that changing it in the config actually takes effect;
# the current value of 0.0 leaves the update rule unchanged.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_penalty)

In [34]:
eval_iters = 10       # batches averaged per split for one loss estimate
eval_interval = 300   # training steps between two estimates

@torch.no_grad()
def estimate_loss(is_last=False):
    '''Estimate the loss on the train and validation splits.

    Averages the model loss over `eval_iters` random batches per split and
    divides the mean by the module-level `chars_per_token`, which must be
    defined before this function is called.

    Args:
        is_last: when True, the validation estimate uses ten times as many
            batches to reduce noise.

    Returns:
        dict with keys 'train' and 'val', each a 0-dim tensor.
    '''
    out = {}
    # Remember the current mode so the model is handed back the way it came
    # in, rather than always being forced into train mode.
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            real_iters = eval_iters
            if is_last and split == 'val':  # increase last eval to mitigate noise
                real_iters *= 10
            losses = torch.zeros(real_iters)
            for k in range(real_iters):
                X, Y = get_batch(split)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean() / chars_per_token
    finally:
        model.train(was_training)
    return out

In [35]:
dump_model_interval = 1000  # training steps between two checkpoints
# estimate_loss divides the mean loss by this value.
# NOTE(review): 3.9 is presumably the measured average number of characters per token - confirm.
chars_per_token=3.9

#print('loading last model')
#model.load_state_dict(torch.load('tiny-stories-model-12.pt'))


# Honour the TRAIN_MODEL switch from the top of the notebook: when it is off,
# the optimisation loop is skipped and only the final evaluation runs.
if TRAIN_MODEL:
    for steps in tqdm.tqdm(range(0, max_iters)):
        xb, yb = get_batch('train')
        # loss
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)

        loss.backward()
        optimizer.step()
        if steps % eval_interval == 0:
            losses = estimate_loss()
            # wandb.log({"train": losses['train'].item(), "val": losses['val'].item(), "l2":l2})
            print({"train": losses['train'].item(), "val": losses['val'].item()})
        if steps % dump_model_interval == 0 and steps > 0:
            # Checkpoints are numbered by how many intervals have elapsed.
            model_no = steps // dump_model_interval
            torch.save(model.state_dict(), f'tiny-stories-model-{model_no}.pt')

# Final estimate with a larger validation sample; print it so the result is
# visible instead of being silently discarded.
losses = estimate_loss(is_last=True)
print({"train": losses['train'].item(), "val": losses['val'].item()})

  0%|          | 1/60000 [00:02<41:06:48,  2.47s/it]

{'train': 2.884622573852539, 'val': 2.884065866470337}


  1%|          | 301/60000 [01:58<16:58:50,  1.02s/it]

{'train': 0.6153015494346619, 'val': 0.6335249543190002}


  1%|          | 601/60000 [03:55<17:21:55,  1.05s/it]

{'train': 0.5074954032897949, 'val': 0.5240407586097717}


  2%|▏         | 901/60000 [05:53<17:04:43,  1.04s/it]

{'train': 0.4782138168811798, 'val': 0.49109920859336853}


  2%|▏         | 1201/60000 [07:53<17:20:10,  1.06s/it]

{'train': 0.44487231969833374, 'val': 0.4445662498474121}


  3%|▎         | 1501/60000 [09:51<16:33:39,  1.02s/it]

{'train': 0.42183172702789307, 'val': 0.42012813687324524}


  3%|▎         | 1801/60000 [11:48<17:12:02,  1.06s/it]

{'train': 0.41699305176734924, 'val': 0.4175228476524353}


  4%|▎         | 2101/60000 [13:47<16:43:36,  1.04s/it]

{'train': 0.397381454706192, 'val': 0.3924941420555115}


  4%|▍         | 2401/60000 [15:44<17:09:38,  1.07s/it]

{'train': 0.37755104899406433, 'val': 0.38375091552734375}


  5%|▍         | 2701/60000 [17:41<16:45:12,  1.05s/it]

{'train': 0.3794938623905182, 'val': 0.3713793456554413}


  5%|▌         | 3000/60000 [19:37<6:06:47,  2.59it/s] 

{'train': 0.3690997362136841, 'val': 0.37100842595100403}


  6%|▌         | 3301/60000 [21:39<16:09:49,  1.03s/it]

{'train': 0.3718017041683197, 'val': 0.36004358530044556}


  6%|▌         | 3601/60000 [23:36<16:28:05,  1.05s/it]

{'train': 0.35456395149230957, 'val': 0.35894885659217834}


  7%|▋         | 3901/60000 [25:35<16:16:39,  1.04s/it]

{'train': 0.34736788272857666, 'val': 0.35785770416259766}


  7%|▋         | 4201/60000 [27:34<16:12:46,  1.05s/it]

{'train': 0.34563952684402466, 'val': 0.3471943736076355}


  8%|▊         | 4501/60000 [29:32<15:49:30,  1.03s/it]

{'train': 0.35302942991256714, 'val': 0.3430531620979309}


  8%|▊         | 4801/60000 [31:31<16:18:05,  1.06s/it]

{'train': 0.34762805700302124, 'val': 0.33611053228378296}


  9%|▊         | 5101/60000 [33:30<15:55:42,  1.04s/it]

{'train': 0.3351421654224396, 'val': 0.3359830379486084}


  9%|▉         | 5401/60000 [35:29<15:58:28,  1.05s/it]

{'train': 0.3385674059391022, 'val': 0.3337978720664978}


 10%|▉         | 5701/60000 [37:27<15:24:35,  1.02s/it]

{'train': 0.3210008442401886, 'val': 0.33310651779174805}


 10%|█         | 6000/60000 [39:22<5:50:18,  2.57it/s] 

{'train': 0.32475098967552185, 'val': 0.326753705739975}


 11%|█         | 6301/60000 [41:23<15:46:47,  1.06s/it]

{'train': 0.3210556209087372, 'val': 0.32779747247695923}


 11%|█         | 6601/60000 [43:21<15:12:38,  1.03s/it]

{'train': 0.3260039985179901, 'val': 0.31077733635902405}


 12%|█▏        | 6901/60000 [45:19<15:40:42,  1.06s/it]

{'train': 0.32767778635025024, 'val': 0.32490769028663635}


 12%|█▏        | 7201/60000 [47:17<15:15:21,  1.04s/it]

{'train': 0.3210744559764862, 'val': 0.31524232029914856}


 13%|█▎        | 7501/60000 [49:13<14:51:52,  1.02s/it]

{'train': 0.31509798765182495, 'val': 0.32562634348869324}


 13%|█▎        | 7801/60000 [51:11<15:18:43,  1.06s/it]

{'train': 0.314995139837265, 'val': 0.31972187757492065}


 14%|█▎        | 8101/60000 [53:10<15:20:08,  1.06s/it]

{'train': 0.3166743516921997, 'val': 0.31350716948509216}


 14%|█▍        | 8401/60000 [55:10<15:25:49,  1.08s/it]

{'train': 0.3222038149833679, 'val': 0.3178861141204834}


 15%|█▍        | 8701/60000 [57:10<15:13:37,  1.07s/it]

{'train': 0.31343674659729004, 'val': 0.3141060471534729}


 15%|█▌        | 9000/60000 [59:08<5:35:25,  2.53it/s] 

{'train': 0.317176878452301, 'val': 0.30123668909072876}


 16%|█▌        | 9301/60000 [1:01:11<15:03:42,  1.07s/it]

{'train': 0.3018587529659271, 'val': 0.2991179823875427}


 16%|█▌        | 9601/60000 [1:03:11<14:58:03,  1.07s/it]

{'train': 0.3098301589488983, 'val': 0.3063237965106964}


 17%|█▋        | 9901/60000 [1:05:11<14:53:30,  1.07s/it]

{'train': 0.30604735016822815, 'val': 0.30811789631843567}


 17%|█▋        | 10201/60000 [1:07:11<14:45:52,  1.07s/it]

{'train': 0.30764856934547424, 'val': 0.30277594923973083}


 18%|█▊        | 10501/60000 [1:09:11<14:41:42,  1.07s/it]

{'train': 0.29853004217147827, 'val': 0.3053056001663208}


 18%|█▊        | 10801/60000 [1:11:11<14:37:37,  1.07s/it]

{'train': 0.30751726031303406, 'val': 0.30313098430633545}


 19%|█▊        | 11101/60000 [1:13:11<14:31:38,  1.07s/it]

{'train': 0.29884105920791626, 'val': 0.2983420193195343}


 19%|█▉        | 11401/60000 [1:15:11<14:22:52,  1.07s/it]

{'train': 0.30211934447288513, 'val': 0.29694390296936035}


 20%|█▉        | 11701/60000 [1:17:11<14:19:50,  1.07s/it]

{'train': 0.29854658246040344, 'val': 0.3023293912410736}


 20%|██        | 12000/60000 [1:19:08<5:14:16,  2.55it/s] 

{'train': 0.298061728477478, 'val': 0.2986951470375061}


 21%|██        | 12301/60000 [1:21:11<14:06:41,  1.07s/it]

{'train': 0.2934906482696533, 'val': 0.2927607297897339}


 21%|██        | 12601/60000 [1:23:11<14:02:19,  1.07s/it]

{'train': 0.30451637506484985, 'val': 0.2999347150325775}


 22%|██▏       | 12901/60000 [1:25:11<13:55:53,  1.06s/it]

{'train': 0.2949245572090149, 'val': 0.29339224100112915}


 22%|██▏       | 13201/60000 [1:27:11<13:52:03,  1.07s/it]

{'train': 0.2990582287311554, 'val': 0.28684672713279724}


 23%|██▎       | 13501/60000 [1:29:11<13:48:31,  1.07s/it]

{'train': 0.29200178384780884, 'val': 0.29418861865997314}


 23%|██▎       | 13801/60000 [1:31:11<13:38:34,  1.06s/it]

{'train': 0.2904309928417206, 'val': 0.3002033233642578}


 24%|██▎       | 14101/60000 [1:33:11<13:35:14,  1.07s/it]

{'train': 0.2932206690311432, 'val': 0.29314616322517395}


 24%|██▍       | 14401/60000 [1:35:11<13:32:05,  1.07s/it]

{'train': 0.2878356873989105, 'val': 0.29540759325027466}


 25%|██▍       | 14701/60000 [1:37:11<13:22:14,  1.06s/it]

{'train': 0.28366029262542725, 'val': 0.2935214638710022}


 25%|██▌       | 15001/60000 [1:39:11<13:56:47,  1.12s/it]

{'train': 0.28941819071769714, 'val': 0.28665003180503845}


 26%|██▌       | 15301/60000 [1:41:11<13:12:35,  1.06s/it]

{'train': 0.2913784384727478, 'val': 0.2894224524497986}


 26%|██▌       | 15601/60000 [1:43:11<13:05:38,  1.06s/it]

{'train': 0.2905540466308594, 'val': 0.2908324599266052}


 27%|██▋       | 15901/60000 [1:45:10<13:04:09,  1.07s/it]

{'train': 0.2874075472354889, 'val': 0.28761398792266846}


 27%|██▋       | 16201/60000 [1:47:10<12:58:02,  1.07s/it]

{'train': 0.28773120045661926, 'val': 0.28235164284706116}


 28%|██▊       | 16501/60000 [1:49:10<12:52:01,  1.06s/it]

{'train': 0.2862275540828705, 'val': 0.28488409519195557}


 28%|██▊       | 16801/60000 [1:51:09<12:44:30,  1.06s/it]

{'train': 0.29405391216278076, 'val': 0.28235381841659546}


 29%|██▊       | 17101/60000 [1:53:09<12:38:35,  1.06s/it]

{'train': 0.2841828465461731, 'val': 0.29151397943496704}


 29%|██▉       | 17401/60000 [1:55:09<12:34:27,  1.06s/it]

{'train': 0.28227120637893677, 'val': 0.28166353702545166}


 30%|██▉       | 17701/60000 [1:57:08<12:28:43,  1.06s/it]

{'train': 0.2828010618686676, 'val': 0.2823205590248108}


 30%|███       | 18001/60000 [1:59:08<13:03:10,  1.12s/it]

{'train': 0.283910870552063, 'val': 0.2798168957233429}


 31%|███       | 18301/60000 [2:01:07<12:20:29,  1.07s/it]

{'train': 0.2911129891872406, 'val': 0.2833639681339264}


 31%|███       | 18601/60000 [2:03:07<12:12:01,  1.06s/it]

{'train': 0.27791208028793335, 'val': 0.284252405166626}


 32%|███▏      | 18901/60000 [2:05:07<12:09:12,  1.06s/it]

{'train': 0.28108513355255127, 'val': 0.2805171310901642}


 32%|███▏      | 19201/60000 [2:07:06<12:02:21,  1.06s/it]

{'train': 0.2820994555950165, 'val': 0.28618353605270386}


 33%|███▎      | 19501/60000 [2:09:06<11:55:27,  1.06s/it]

{'train': 0.28107428550720215, 'val': 0.28194838762283325}


 33%|███▎      | 19801/60000 [2:11:06<11:53:17,  1.06s/it]

{'train': 0.29064247012138367, 'val': 0.2775545120239258}


 34%|███▎      | 20101/60000 [2:13:05<11:46:04,  1.06s/it]

{'train': 0.27785447239875793, 'val': 0.2833443582057953}


 34%|███▍      | 20401/60000 [2:15:05<11:43:53,  1.07s/it]

{'train': 0.2853521406650543, 'val': 0.2838383615016937}


 35%|███▍      | 20701/60000 [2:17:05<11:38:30,  1.07s/it]

{'train': 0.27903878688812256, 'val': 0.2828126847743988}


 35%|███▌      | 21001/60000 [2:19:05<12:06:26,  1.12s/it]

{'train': 0.27573803067207336, 'val': 0.2836952805519104}


 36%|███▌      | 21301/60000 [2:21:05<11:28:01,  1.07s/it]

{'train': 0.2815535366535187, 'val': 0.2775075435638428}


 36%|███▌      | 21601/60000 [2:23:04<11:23:03,  1.07s/it]

{'train': 0.2898237705230713, 'val': 0.2847961485385895}


 37%|███▋      | 21901/60000 [2:25:04<11:20:27,  1.07s/it]

{'train': 0.28189918398857117, 'val': 0.2801501154899597}


 37%|███▋      | 22201/60000 [2:27:03<11:21:23,  1.08s/it]

{'train': 0.279608815908432, 'val': 0.2849239706993103}


 38%|███▊      | 22501/60000 [2:29:04<11:11:18,  1.07s/it]

{'train': 0.2784154415130615, 'val': 0.27814170718193054}


 38%|███▊      | 22801/60000 [2:31:04<10:49:09,  1.05s/it]

{'train': 0.2840072214603424, 'val': 0.27432265877723694}


 39%|███▊      | 23101/60000 [2:33:04<11:02:02,  1.08s/it]

{'train': 0.284371554851532, 'val': 0.2809843122959137}


 39%|███▉      | 23401/60000 [2:35:04<10:53:27,  1.07s/it]

{'train': 0.27836844325065613, 'val': 0.27571332454681396}


 40%|███▉      | 23701/60000 [2:37:04<10:47:16,  1.07s/it]

{'train': 0.28097182512283325, 'val': 0.27884164452552795}


 40%|████      | 24001/60000 [2:39:04<10:52:35,  1.09s/it]

{'train': 0.2837337851524353, 'val': 0.28700482845306396}


 41%|████      | 24301/60000 [2:41:03<10:20:35,  1.04s/it]

{'train': 0.2796443998813629, 'val': 0.27757373452186584}


 41%|████      | 24601/60000 [2:43:04<10:36:44,  1.08s/it]

{'train': 0.2693912982940674, 'val': 0.2797826826572418}


 42%|████▏     | 24901/60000 [2:45:04<10:29:05,  1.08s/it]

{'train': 0.27570852637290955, 'val': 0.2811119556427002}


 42%|████▏     | 25201/60000 [2:47:05<10:19:58,  1.07s/it]

{'train': 0.2820945978164673, 'val': 0.27443355321884155}


 43%|████▎     | 25501/60000 [2:49:05<10:02:09,  1.05s/it]

{'train': 0.28645938634872437, 'val': 0.27256548404693604}


 43%|████▎     | 25801/60000 [2:51:04<10:00:38,  1.05s/it]

{'train': 0.2738620936870575, 'val': 0.26979321241378784}


 44%|████▎     | 26101/60000 [2:53:05<10:05:27,  1.07s/it]

{'train': 0.2699265778064728, 'val': 0.2750999331474304}


 44%|████▍     | 26401/60000 [2:55:05<10:02:08,  1.08s/it]

{'train': 0.2779700756072998, 'val': 0.27480489015579224}


 45%|████▍     | 26701/60000 [2:57:05<9:53:19,  1.07s/it] 

{'train': 0.279042512178421, 'val': 0.28316470980644226}


 45%|████▌     | 27001/60000 [2:59:06<10:18:44,  1.13s/it]

{'train': 0.2781929671764374, 'val': 0.27554410696029663}


 46%|████▌     | 27301/60000 [3:01:06<9:38:59,  1.06s/it] 

{'train': 0.27207884192466736, 'val': 0.269959419965744}


 46%|████▌     | 27601/60000 [3:03:06<9:34:42,  1.06s/it]

{'train': 0.27388471364974976, 'val': 0.27390190958976746}


 47%|████▋     | 27901/60000 [3:05:06<9:32:16,  1.07s/it]

{'train': 0.2703606188297272, 'val': 0.27630653977394104}


 47%|████▋     | 28201/60000 [3:07:06<9:24:39,  1.07s/it]

{'train': 0.279471218585968, 'val': 0.2723480761051178}


 48%|████▊     | 28501/60000 [3:09:05<9:20:27,  1.07s/it]

{'train': 0.2744041979312897, 'val': 0.2767797112464905}


 48%|████▊     | 28801/60000 [3:11:05<9:13:40,  1.06s/it]

{'train': 0.2716321051120758, 'val': 0.2720434069633484}


 49%|████▊     | 29101/60000 [3:13:05<9:07:00,  1.06s/it]

{'train': 0.27325406670570374, 'val': 0.2774350345134735}


 49%|████▉     | 29401/60000 [3:15:04<9:00:48,  1.06s/it]

{'train': 0.2771351933479309, 'val': 0.26964831352233887}


 50%|████▉     | 29701/60000 [3:17:04<8:55:54,  1.06s/it]

{'train': 0.27515503764152527, 'val': 0.2724575400352478}


 50%|█████     | 30001/60000 [3:19:03<9:15:47,  1.11s/it]

{'train': 0.2727334797382355, 'val': 0.270285040140152}


 51%|█████     | 30301/60000 [3:21:03<8:45:50,  1.06s/it]

{'train': 0.28097477555274963, 'val': 0.26921021938323975}


 51%|█████     | 30601/60000 [3:23:02<8:41:25,  1.06s/it]

{'train': 0.27392610907554626, 'val': 0.27024492621421814}


 52%|█████▏    | 30901/60000 [3:25:02<8:35:46,  1.06s/it]

{'train': 0.2728349566459656, 'val': 0.27790388464927673}


 52%|█████▏    | 31201/60000 [3:27:01<8:27:55,  1.06s/it]

{'train': 0.2707877457141876, 'val': 0.2791152596473694}


 53%|█████▎    | 31501/60000 [3:29:01<8:27:24,  1.07s/it]

{'train': 0.2707129716873169, 'val': 0.2619137465953827}


 53%|█████▎    | 31801/60000 [3:31:01<8:21:55,  1.07s/it]

{'train': 0.275076687335968, 'val': 0.27755293250083923}


 54%|█████▎    | 32101/60000 [3:33:00<8:04:41,  1.04s/it]

{'train': 0.2702908515930176, 'val': 0.2777973413467407}


 54%|█████▍    | 32401/60000 [3:34:59<7:56:04,  1.03s/it]

{'train': 0.2757990062236786, 'val': 0.26693612337112427}


 55%|█████▍    | 32701/60000 [3:36:57<7:53:28,  1.04s/it]

{'train': 0.2682827413082123, 'val': 0.26497507095336914}


 55%|█████▌    | 33001/60000 [3:38:55<8:15:03,  1.10s/it]

{'train': 0.27404922246932983, 'val': 0.26824265718460083}


 56%|█████▌    | 33301/60000 [3:40:54<7:55:32,  1.07s/it]

{'train': 0.2780826985836029, 'val': 0.264925092458725}


 56%|█████▌    | 33601/60000 [3:42:51<7:50:33,  1.07s/it]

{'train': 0.2675881087779999, 'val': 0.2686164677143097}


 57%|█████▋    | 33901/60000 [3:44:50<7:28:59,  1.03s/it]

{'train': 0.26476648449897766, 'val': 0.2656923532485962}


 57%|█████▋    | 34201/60000 [3:46:48<7:36:34,  1.06s/it]

{'train': 0.26797765493392944, 'val': 0.2658272683620453}


 58%|█████▊    | 34501/60000 [3:48:47<7:18:24,  1.03s/it]

{'train': 0.26752543449401855, 'val': 0.26943570375442505}


 58%|█████▊    | 34801/60000 [3:50:44<7:05:18,  1.01s/it]

{'train': 0.2736586630344391, 'val': 0.27047842741012573}


 59%|█████▊    | 35101/60000 [3:52:42<7:09:25,  1.03s/it]

{'train': 0.27038517594337463, 'val': 0.26938509941101074}


 59%|█████▉    | 35401/60000 [3:54:40<7:07:58,  1.04s/it]

{'train': 0.2721779942512512, 'val': 0.2618523836135864}


 60%|█████▉    | 35701/60000 [3:56:37<7:00:12,  1.04s/it]

{'train': 0.264130562543869, 'val': 0.26715558767318726}


 60%|██████    | 36001/60000 [3:58:36<7:13:50,  1.08s/it]

{'train': 0.2664925456047058, 'val': 0.26727795600891113}


 61%|██████    | 36301/60000 [4:00:32<6:51:41,  1.04s/it]

{'train': 0.26499879360198975, 'val': 0.27293115854263306}


 61%|██████    | 36601/60000 [4:02:31<6:47:43,  1.05s/it]

{'train': 0.2717752456665039, 'val': 0.2676898241043091}


 62%|██████▏   | 36901/60000 [4:04:29<6:41:38,  1.04s/it]

{'train': 0.2691076099872589, 'val': 0.2754250168800354}


 62%|██████▏   | 37201/60000 [4:06:27<6:37:05,  1.05s/it]

{'train': 0.2671944200992584, 'val': 0.2679445445537567}


 63%|██████▎   | 37501/60000 [4:08:25<6:25:47,  1.03s/it]

{'train': 0.26950085163116455, 'val': 0.26981040835380554}


 63%|██████▎   | 37801/60000 [4:10:23<6:29:17,  1.05s/it]

{'train': 0.2720601260662079, 'val': 0.26851212978363037}


 64%|██████▎   | 38101/60000 [4:12:22<6:26:13,  1.06s/it]

{'train': 0.2654355764389038, 'val': 0.269203245639801}


 64%|██████▍   | 38401/60000 [4:14:19<6:15:03,  1.04s/it]

{'train': 0.2611909806728363, 'val': 0.27482521533966064}


 65%|██████▍   | 38701/60000 [4:16:18<6:11:43,  1.05s/it]

{'train': 0.2667335569858551, 'val': 0.2703135907649994}


 65%|██████▌   | 39001/60000 [4:18:16<6:17:36,  1.08s/it]

{'train': 0.2623421549797058, 'val': 0.27147746086120605}


 66%|██████▌   | 39301/60000 [4:20:14<6:05:54,  1.06s/it]

{'train': 0.26561394333839417, 'val': 0.26761695742607117}


 66%|██████▌   | 39601/60000 [4:22:13<5:59:49,  1.06s/it]

{'train': 0.2662082314491272, 'val': 0.2687097191810608}


 67%|██████▋   | 39901/60000 [4:24:11<5:50:42,  1.05s/it]

{'train': 0.2693055272102356, 'val': 0.2682793140411377}


 67%|██████▋   | 40201/60000 [4:26:09<5:38:46,  1.03s/it]

{'train': 0.26947730779647827, 'val': 0.2712154984474182}


 68%|██████▊   | 40501/60000 [4:28:08<5:44:29,  1.06s/it]

{'train': 0.27018865942955017, 'val': 0.2639177739620209}


 68%|██████▊   | 40801/60000 [4:30:06<5:35:40,  1.05s/it]

{'train': 0.2648276090621948, 'val': 0.2707335650920868}


 69%|██████▊   | 41101/60000 [4:32:05<5:25:58,  1.03s/it]

{'train': 0.26637911796569824, 'val': 0.26909759640693665}


 69%|██████▉   | 41401/60000 [4:34:02<5:16:08,  1.02s/it]

{'train': 0.2642705738544464, 'val': 0.2674168050289154}


 70%|██████▉   | 41701/60000 [4:36:00<5:22:49,  1.06s/it]

{'train': 0.2680538296699524, 'val': 0.2677982449531555}


 70%|███████   | 42001/60000 [4:37:57<5:30:56,  1.10s/it]

{'train': 0.26646852493286133, 'val': 0.2658664286136627}


 71%|███████   | 42301/60000 [4:39:56<5:07:18,  1.04s/it]

{'train': 0.26889246702194214, 'val': 0.2661561965942383}


 71%|███████   | 42601/60000 [4:41:53<4:56:21,  1.02s/it]

{'train': 0.26240646839141846, 'val': 0.26671135425567627}


 72%|███████▏  | 42901/60000 [4:43:51<4:57:43,  1.04s/it]

{'train': 0.261029988527298, 'val': 0.2629626989364624}


 72%|███████▏  | 43201/60000 [4:45:49<4:51:36,  1.04s/it]

{'train': 0.25354960560798645, 'val': 0.26877596974372864}


 73%|███████▎  | 43501/60000 [4:47:47<4:51:59,  1.06s/it]

{'train': 0.2630503475666046, 'val': 0.2672911286354065}


 73%|███████▎  | 43801/60000 [4:49:45<4:35:43,  1.02s/it]

{'train': 0.2622358202934265, 'val': 0.26924437284469604}


 74%|███████▎  | 44101/60000 [4:51:42<4:36:52,  1.04s/it]

{'train': 0.25981295108795166, 'val': 0.2609959542751312}


 74%|███████▍  | 44401/60000 [4:53:41<4:37:05,  1.07s/it]

{'train': 0.2625942826271057, 'val': 0.2630278766155243}


 75%|███████▍  | 44701/60000 [4:55:36<4:14:05,  1.00it/s]

{'train': 0.26635676622390747, 'val': 0.2660255432128906}


 75%|███████▌  | 45001/60000 [4:57:33<4:34:37,  1.10s/it]

{'train': 0.265377938747406, 'val': 0.2660498321056366}


 76%|███████▌  | 45301/60000 [4:59:31<4:15:47,  1.04s/it]

{'train': 0.26690027117729187, 'val': 0.2717249393463135}


 76%|███████▌  | 45601/60000 [5:01:29<4:04:59,  1.02s/it]

{'train': 0.2565819323062897, 'val': 0.26492881774902344}


 77%|███████▋  | 45901/60000 [5:03:26<4:00:43,  1.02s/it]

{'train': 0.2599082291126251, 'val': 0.2683262825012207}


 77%|███████▋  | 46201/60000 [5:05:24<3:55:25,  1.02s/it]

{'train': 0.26323768496513367, 'val': 0.2659653127193451}


 78%|███████▊  | 46501/60000 [5:07:19<3:47:28,  1.01s/it]

{'train': 0.2615223824977875, 'val': 0.2642713487148285}


 78%|███████▊  | 46801/60000 [5:09:18<3:53:35,  1.06s/it]

{'train': 0.26185017824172974, 'val': 0.26071202754974365}


 79%|███████▊  | 47101/60000 [5:11:18<3:51:44,  1.08s/it]

{'train': 0.2645512521266937, 'val': 0.25788959860801697}


 79%|███████▉  | 47401/60000 [5:13:18<3:45:26,  1.07s/it]

{'train': 0.25664231181144714, 'val': 0.2631490230560303}


 80%|███████▉  | 47701/60000 [5:15:18<3:38:30,  1.07s/it]

{'train': 0.26174479722976685, 'val': 0.2614380121231079}


 80%|████████  | 48001/60000 [5:17:19<3:44:56,  1.12s/it]

{'train': 0.2614203989505768, 'val': 0.2565985321998596}


 81%|████████  | 48301/60000 [5:19:19<3:27:46,  1.07s/it]

{'train': 0.2644573748111725, 'val': 0.2609723210334778}


 81%|████████  | 48601/60000 [5:21:19<3:20:05,  1.05s/it]

{'train': 0.2634657919406891, 'val': 0.26673340797424316}


 82%|████████▏ | 48901/60000 [5:23:18<3:17:12,  1.07s/it]

{'train': 0.26762616634368896, 'val': 0.2706117630004883}


 82%|████████▏ | 49201/60000 [5:25:18<3:11:25,  1.06s/it]

{'train': 0.2578292191028595, 'val': 0.25908777117729187}


 83%|████████▎ | 49501/60000 [5:27:18<3:06:06,  1.06s/it]

{'train': 0.2578831613063812, 'val': 0.2668040692806244}


 83%|████████▎ | 49801/60000 [5:29:17<3:00:36,  1.06s/it]

{'train': 0.26064518094062805, 'val': 0.2615455687046051}


 84%|████████▎ | 50101/60000 [5:31:17<2:56:07,  1.07s/it]

{'train': 0.26347026228904724, 'val': 0.2720221281051636}


 84%|████████▍ | 50401/60000 [5:33:17<2:50:26,  1.07s/it]

{'train': 0.2560581564903259, 'val': 0.26101815700531006}


 85%|████████▍ | 50701/60000 [5:35:17<2:44:52,  1.06s/it]

{'train': 0.2616225779056549, 'val': 0.25960680842399597}


 85%|████████▌ | 51001/60000 [5:37:17<2:47:11,  1.11s/it]

{'train': 0.26321423053741455, 'val': 0.25966665148735046}


 86%|████████▌ | 51301/60000 [5:39:16<2:33:58,  1.06s/it]

{'train': 0.25702622532844543, 'val': 0.2631712853908539}


 86%|████████▌ | 51601/60000 [5:41:16<2:28:40,  1.06s/it]

{'train': 0.26348257064819336, 'val': 0.2635815739631653}


 87%|████████▋ | 51901/60000 [5:43:15<2:22:59,  1.06s/it]

{'train': 0.2555693984031677, 'val': 0.261911541223526}


 87%|████████▋ | 52201/60000 [5:45:15<2:17:33,  1.06s/it]

{'train': 0.2567983865737915, 'val': 0.2580506205558777}


 88%|████████▊ | 52501/60000 [5:47:14<2:11:49,  1.05s/it]

{'train': 0.2608085572719574, 'val': 0.26705431938171387}


 88%|████████▊ | 52801/60000 [5:49:13<2:06:55,  1.06s/it]

{'train': 0.26726457476615906, 'val': 0.2625582218170166}


 89%|████████▊ | 53101/60000 [5:51:12<2:01:27,  1.06s/it]

{'train': 0.2640020549297333, 'val': 0.26382914185523987}


 89%|████████▉ | 53401/60000 [5:53:12<1:57:27,  1.07s/it]

{'train': 0.2563401758670807, 'val': 0.2606767416000366}


 90%|████████▉ | 53701/60000 [5:55:11<1:51:28,  1.06s/it]

{'train': 0.26133790612220764, 'val': 0.2546493113040924}


 90%|█████████ | 54001/60000 [5:57:12<1:50:03,  1.10s/it]

{'train': 0.2584426701068878, 'val': 0.2620809078216553}


 91%|█████████ | 54301/60000 [5:59:11<1:39:55,  1.05s/it]

{'train': 0.2669464349746704, 'val': 0.2627847492694855}


 91%|█████████ | 54601/60000 [6:01:10<1:34:37,  1.05s/it]

{'train': 0.24878881871700287, 'val': 0.2615889310836792}


 92%|█████████▏| 54901/60000 [6:03:10<1:30:52,  1.07s/it]

{'train': 0.2575978934764862, 'val': 0.2586638033390045}


 92%|█████████▏| 55201/60000 [6:05:10<1:25:53,  1.07s/it]

{'train': 0.2542804777622223, 'val': 0.25526633858680725}


 93%|█████████▎| 55501/60000 [6:07:10<1:19:52,  1.07s/it]

{'train': 0.26038044691085815, 'val': 0.2569216191768646}


 93%|█████████▎| 55801/60000 [6:09:10<1:14:30,  1.06s/it]

{'train': 0.26125800609588623, 'val': 0.26488253474235535}


 94%|█████████▎| 56101/60000 [6:11:10<1:09:28,  1.07s/it]

{'train': 0.26029422879219055, 'val': 0.26341745257377625}


 94%|█████████▍| 56401/60000 [6:13:10<1:04:11,  1.07s/it]

{'train': 0.2536241114139557, 'val': 0.2636404037475586}


 95%|█████████▍| 56701/60000 [6:15:10<58:49,  1.07s/it]  

{'train': 0.2594403326511383, 'val': 0.26212430000305176}


 95%|█████████▌| 57001/60000 [6:17:11<55:32,  1.11s/it]

{'train': 0.26340383291244507, 'val': 0.25805842876434326}


 96%|█████████▌| 57301/60000 [6:19:10<47:33,  1.06s/it]

{'train': 0.25639718770980835, 'val': 0.2644819915294647}


 96%|█████████▌| 57601/60000 [6:21:08<41:43,  1.04s/it]

{'train': 0.2653248608112335, 'val': 0.25731125473976135}


 97%|█████████▋| 57901/60000 [6:23:05<36:53,  1.05s/it]

{'train': 0.2632622718811035, 'val': 0.25767773389816284}


 97%|█████████▋| 58201/60000 [6:25:04<31:17,  1.04s/it]

{'train': 0.2576989233493805, 'val': 0.25660163164138794}


 98%|█████████▊| 58501/60000 [6:27:02<25:27,  1.02s/it]

{'train': 0.2573850452899933, 'val': 0.2598285377025604}


 98%|█████████▊| 58801/60000 [6:28:59<21:14,  1.06s/it]

{'train': 0.25452539324760437, 'val': 0.25426116585731506}


 99%|█████████▊| 59101/60000 [6:30:57<15:58,  1.07s/it]

{'train': 0.26200008392333984, 'val': 0.25614532828330994}


 99%|█████████▉| 59401/60000 [6:32:55<10:27,  1.05s/it]

{'train': 0.25657400488853455, 'val': 0.255526602268219}


100%|█████████▉| 59701/60000 [6:34:53<05:10,  1.04s/it]

{'train': 0.2611274719238281, 'val': 0.25617286562919617}


100%|██████████| 60000/60000 [6:36:48<00:00,  2.52it/s]


In [None]:
# Evaluate once more now that the training loop has finished; because this is
# the cell's last expression, the returned value is shown as the cell output.
# NOTE(review): presumably returns the same {'train': ..., 'val': ...} dict
# that was printed periodically during training — confirm against the definition.
estimate_loss()

In [None]:
# Persist only the trained weights (the state dict), not the whole module
# object, to a file in the current working directory.
torch.save(obj=model.state_dict(), f='tiny-stories-model.pt')

In [None]:
# Smoke-test generation: prompt the model with a (1, T) context in which every
# position holds token id 198, then print only the newly generated text.
#
# torch.full is used because scaling a zeros tensor (torch.zeros(...) * 198)
# still yields all zeros, i.e. a prompt of token id 0 rather than 198.
# NOTE(review): 198 is presumably the newline token of the byte-level BPE
# vocabulary — confirm against the tokenizer.
# NOTE(review): if the model lives on an accelerator and generate() does not
# move `idx` itself, this tensor may also need `device=...` — confirm.
test_idx = torch.full((1, T), 198, dtype=torch.long)

generated_ids = model.generate(idx=test_idx, max_new_tokens=C)[0].tolist()

# Drop the T prompt tokens *before* decoding. Slicing the decoded string by T
# characters is only correct when every prompt token decodes to exactly one
# character; slicing the id list is correct regardless.
print(decode(generated_ids[T:]))