In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Subset
import tqdm
import json
import datasets
from typing import List
import os
import pandas as pd
import tiktoken
import inspect
from sentencepiece import SentencePieceProcessor

In [4]:
!mkdir data/

mkdir: cannot create directory ‘data/’: File exists


In [5]:
class Tokenizer:
    def __init__(self, tokenizer_model):
        model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        self.n_words = self.sp_model.vocab_size()
        self.bos_id = self.sp_model.bos_id()
        self.eos_id = self.sp_model.eos_id()
        self.pad_id = self.sp_model.pad_id()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, tokens: List[int]) -> str:
        return self.sp_model.decode(tokens)

In [6]:
TOKENIZER_MODEL = "./data/tok4096.model"

In [7]:
tokenizer = Tokenizer(tokenizer_model=TOKENIZER_MODEL)

In [8]:
tokenizer.n_words

4096

In [9]:
vocab_size = tokenizer.n_words
batch_size = 32
block_size = 512
max_iters = 1
eval_interval = 1000
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 256
n_embd = 512
n_head = 8
n_layer = 8
dropout = 0.3

target_batch_size = 8192 * 2
gradient_accumulation_steps = target_batch_size // batch_size
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95

In [10]:
gradient_accumulation_steps

512

In [11]:
torch.set_float32_matmul_precision('high')

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")

In [13]:
def encode(s): return tokenizer.encode(s, bos=False, eos=False)

def decode(l):
	try:
		return tokenizer.decode(l)
	except:
		return ""

In [14]:
ds = datasets.load_dataset("roneneldan/TinyStories")

In [15]:
ds = ds.with_format("torch")

In [16]:
ds['train'][1]

{'text': 'Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn.\n\nBeep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.'}

In [17]:
def collate_fn(batch):
    texts = [encode(item['text'])[:block_size] for item in batch]  # Truncate to block_size
    padded_texts = [t + [0] * (block_size - len(t)) for t in texts]  # Pad to 512
    return {
        'text': torch.tensor(padded_texts, dtype=torch.long)
    }

In [18]:
eval_iters

256

In [19]:
subset_indices = list(range(eval_iters))
# train_indices = list(range(8000000))
# dataset_train = Subset(ds['train'], train_indices)
dataset_valid = Subset(ds['validation'], subset_indices)

In [20]:
train_dataloader = DataLoader(ds['train'], batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(dataset_valid, batch_size=batch_size, collate_fn=collate_fn)

In [21]:
def generate_square_subsequent_mask(sz):
    """
    Generates a causal (upper-triangular) mask for a sequence of length 'sz'.
    Positions with True (or -inf when using additive masks) will be masked.
    Here, we create an additive mask with -inf for masked positions.
    """
    mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
    return mask

class Block(nn.Module):
    """Transformer block using PyTorch's MultiheadAttention with an explicit causal mask."""
    def __init__(self, n_embd, n_head):
        super().__init__()
        # PyTorch's MultiheadAttention
        self.attn = nn.MultiheadAttention(
            embed_dim=n_embd,
            num_heads=n_head,
            dropout=dropout,
            batch_first=True  # Expect input as (batch, seq, feature)
        )
        
        # Feed-forward network
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )
        
        # Layer normalization layers
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        
    def forward(self, x):
        # x has shape (B, T, C)
        T = x.size(1)
        
        # Pre-LayerNorm for attention
        x_ln = self.ln1(x)
        # Create a causal mask explicitly for the current sequence length
        causal_mask = generate_square_subsequent_mask(T).to(x.device)
        
        # Self-attention: note that we pass attn_mask instead of is_causal
        attn_output, _ = self.attn(
            query=x_ln,
            key=x_ln,
            value=x_ln,
            attn_mask=causal_mask,  # Using the explicit causal mask here
            need_weights=False
        )
        x = x + attn_output
        
        # Feed-forward block with pre-LayerNorm
        x = x + self.ffwd(self.ln2(x))
        
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Token and position embeddings
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([Block(n_embd, n_head) for _ in range(n_layer)])
        
        # Final layer normalization and output projection
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        
        # Weight tying: share the weight matrix between token embeddings and the output projection
        self.token_embedding_table.weight = self.lm_head.weight
        
        # Initialize weights for Linear and Embedding layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # Obtain token embeddings and add positional embeddings
        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C)
        
        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x)  # (B, T, C)
            
        # Final layer normalization and output projection to logits
        x = self.ln_f(x)  # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        # Compute loss if targets are provided
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """
        Given a sequence of indices 'idx', generate 'max_new_tokens' new tokens.
        """
        for _ in range(max_new_tokens):
            # Crop the sequence to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # Get predictions
            logits, _ = self(idx_cond)
            # Focus only on the last time step
            logits = logits[:, -1, :]  # (B, vocab_size)
            # Convert logits to probabilities
            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
            # Sample from the probability distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append the new token to the sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [22]:
torch.cuda.empty_cache()

In [23]:
model = GPTLanguageModel()

# if torch.cuda.device_count() > 1:
#     model = torch.nn.DataParallel(model)

model = model.to(device)
model = torch.compile(model)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

27.579392 M parameters


In [24]:
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and 'cuda' == str(device)
print(f"{use_fused=}")

use_fused=True


In [25]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), eps=1e-8, fused=use_fused)

In [26]:
T_max = len(train_dataloader)
warmup_steps = 0.01 * T_max
scheduler = lr_scheduler.OneCycleLR(
    optimizer, max_lr=4e-4, total_steps=T_max, pct_start=0.01
)

In [27]:
# eval_interval = len(train_dataloader) // 5
# eval_interval

In [28]:
os.makedirs("ckpt/", exist_ok=True)

In [29]:
str(device)

'cuda'

In [30]:
sample = tokenizer.decode(tokenizer.encode(ds["train"][0]["text"][:100], bos=True, eos=True))
sample

'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with'

In [31]:
def generate(model, idx, max_new_tokens):
    for _ in range(max_new_tokens):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -block_size:]
        # get the predictions
        logits, loss = model(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :]  # becomes (B, C)
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)  # (B, C)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
    return idx

In [32]:
gradient_accumulation_steps, batch_size, target_batch_size

(512, 32, 16384)

In [33]:
with open("losses.txt", "w") as f:
	f.write("Training Loss,Validation Loss,Output\n")

In [34]:
for iter, batch in enumerate(tqdm.notebook.tqdm(train_dataloader, total=len(train_dataloader))):
    inputs, targets = batch['text'], batch['text']
    inputs, targets = inputs.to(device), targets.to(device)

    with torch.autocast(device_type=str(device), dtype=torch.bfloat16):
        logits, loss = model(inputs, targets)

    loss = loss / gradient_accumulation_steps
    loss.backward()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    if (iter + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    if iter % (gradient_accumulation_steps * 2) == 0 or iter == max_iters - 1:
        print(f"Step {iter}: Performing validation")
        model.eval()
        with torch.no_grad():
            val_loss = 0
            train_loss = loss.item() * gradient_accumulation_steps
            for batch in tqdm.notebook.tqdm(valid_dataloader, total=len(valid_dataloader)):
                inputs, targets = batch['text'], batch['text']
                inputs, targets = inputs.to(device), targets.to(device)
                _, loss = model(inputs, targets)
                val_loss += loss.item()

            torch.save(model.state_dict(), f"ckpt/ckpt_{iter}.pt")
            print(f"Train loss: {train_loss:.4f}")
            print(f"Validation loss: {val_loss / len(valid_dataloader):.4f}")

            prompt = "Hello I"
            prompt = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
            output = decode(generate(model, prompt, max_new_tokens=50)[0].tolist())
            print(output)
            with open("losses.txt", "a") as f:
                f.write(f"{train_loss},{val_loss / len(valid_dataloader)},\"{output}\"\n")
        model.train()

  0%|          | 0/66242 [00:00<?, ?it/s]

W0403 17:46:25.786000 2164964 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


Step 0: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 8.1073
Validation loss: 8.0161
Hello I base post smell" happilyquitore grapes create whe harmlessiltyheroeter bott actuallygerup notice selfish turtle babyu mist greatT thought Billy mincket wrote if pickB Lisa ingredient Thatached adventure harsh groundorsake must stagegetable hopped jogiewiz
Step 1024: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 6.6195
Validation loss: 6.1553
Hello I eyes�oppyig street repair set act ag rubb smo faucet collestritern tries jumped cupisa thought op crab showed driver trumpet Bob av screw ice sureSure spin treeney nods squee coat diary" weal cou fitor Beme rushed club� cooking Finally
Step 2048: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 5.4231
Validation loss: 5.0368
Hello Iht monkeyger� readycu scarf unred fairy aboutera loo zebra trucks hello touch mat " beet\ thMe decorate star fandle( impatientH eager\342\201\207  morning closet snack\342\201\207  fieldbox embarrassed cheeredOkay pers vase try Then Molly hard chew
Step 3072: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 4.9444
Validation loss: 4.4181
Hello Icodzzy kiss lunch cheeredis nengisa medbin dreamed cor bees wants bathtub long needed�azzex mudory sticks mid pant Rosiepa\342\201\207  giving\342\201\207  hel four selfishiscles comfortable barking Jill hole cave. scootice�pped whiteOKorry
Step 4096: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 4.1505
Validation loss: 4.1123
Hello I dustm ropeilt$ort� uperugorn Ollie\342\201\207  pathur Em?\342\201\207  change terribleparugged\342\201\207  Mommyr drawing brilliant stick hugsator While� cobwebTim frustrated embarra sandwich Whyinary\342\201\207  crocod Jane relax\342\201\207  ac rug elephant batht speakr
Step 5120: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 4.2521
Validation loss: 3.9873
Hello I del\342\201\207 isgy lovely sho track\342\201\207 umbled� pony\342\201\207 \342\201\207 \342\201\207  goodbye le Jack\342\201\207 \342\201\207  pretend\342\201\207  mark im\342\201\207 \342\201\207 rel gave\342\201\207 \342\201\207  from aaurant vest\342\201\207  Sp\342\201\207 \342\201\207 \342\201\207 \342\201\207  valuablece chair Ollie\342\201\207 \342\201\207 \342\201\207  shirtorrow
Step 6144: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.8567
Validation loss: 3.9366
Hello I why Chirpy\342\201\207 \342\201\207  smart Once\342\201\207  collectpa\342\201\207  fly\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  shapes\342\201\207 nam\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  sm Pet\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 ork\342\201\207 \342\201\207 \342\201\207 ght\342\201\207 \342\201\207 iggle bracelet\342\201\207 \342\201\207 \342\201\207 \342\201\207 r\342\201\207  pebble\342\201\207 \342\201\207 
Step 7168: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 4.0239
Validation loss: 3.8839
Hello I\342\201\207 ten listening\342\201\207 \342\201\207 \342\201\207 ng\342\201\207 \342\201\207 ited\342\201\207 \342\201\207 � feel\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  uperopl\342\201\207 \342\201\207 \342\201\207  from\342\201\207 play\342\201\207  fun ta\342\201\207 \342\201\207 \342\201\207 \342\201\207  spell\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 ac\342\201\207 umbledaffepend\342\201\207  tried\342\201\207 
Step 8192: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 4.0223
Validation loss: 3.7937
Hello I\342\201\207  behind quickly\342\201\207 \342\201\207 \342\201\207 \342\201\207 K\342\201\207 cket\342\201\207 \342\201\207  fr bathtub\342\201\207  flow law write\342\201\207 iz\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  tube\342\201\207 �\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  stayed\342\201\207 \342\201\207 \342\201\207 \342\201\207 2 Mary\342\201\207 \342\201\207 \342\201\207 
Step 9216: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.6516
Validation loss: 3.6816
Hello I cap Anna h\342\201\207 \342\201\207  until againomet\342\201\207 \342\201\207  sil stack bugs\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  le share play\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  block\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  What
Step 10240: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.3467
Validation loss: 3.5859
Hello I\342\201\207 er\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  wat\342\201\207 \342\201\207 \342\201\207 \342\201\207 als\342\201\207 \342\201\207  right\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 arge\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 
Step 11264: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 4.1718
Validation loss: 3.5190
Hello I searched quickly dance pictures\342\201\207 idgebyself jet princess get pencil\342\201\207 \342\201\207 \342\201\207 '\342\201\207  pilotN\342\201\207 Okay\342\201\207  yacht\342\201\207 �\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 �\342\201\207 \342\201\207 \342\201\207  skelet picking\342\201\207 \342\201\207  let\342\201\207 
Step 12288: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.6320
Validation loss: 3.4636
Hello I leg Yournder dropped taskndpaftedumb sang perfmet thy And(ugged hor record smell playground Tweety clapxible sl prot pract curtainarn noisy juings teac sang or talking\342\201\207  high\342\201\207  incself Rosie belte\342\201\207 \342\201\207  instead intelligenti sandw locked
Step 13312: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.9707
Validation loss: 3.4099
Hello I bake pastgle road rec mel also said jamquirstectood pant l� catorm wallet make ey
 screen kn honest turned pract/arden\342\201\207  boys longaxrobemauddenlyra morning today opened tap happy marcut runsY dollffee\342\201\207 ld
Step 14336: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.5920
Validation loss: 3.3591
Hello I cooler hugged clear flea ambulanceeredache Don study oven again badly soc se foolish bottle pleaseessert\342\201\207 merThe". cla cried handle cabin scre wereWho True hunter best�uggedforta himself safely valu hive fa te ever sandw' pizza means tap cloudsute
Step 15360: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.3295
Validation loss: 3.3145
Hello I striinal monster adorable arm mirr bottu little necklace clownseek!" snack spaghetti decided gray5 feathersid jamnam jeep out meuffy sle fig knot onto bloomvent hit corn drawer squash incred microphoneased shadowung lazz friends. val mildamaH
Step 16384: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.6393
Validation loss: 3.2768
ain hospital rever wor calling. past hoppedara bubble young scare back B somethingfter enthusiastic room\342\201\207  tired come faster picnicual ladderatter e
Step 17408: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9791
Validation loss: 3.2436
Hello I invemo[uth square sheepLily\342\201\207  bag�ue swam kne bug lendrandpald worm to carrot\342\201\207 \342\201\207 \342\201\207  leild humbleZ\342\201\207  repair�\342\201\207 \342\201\207  eggs\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  stocycle\342\201\207 \342\201\207 \342\201\207 
Step 18432: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.1133
Validation loss: 3.2116
Hello I Buddy gave selfish sw turn shrimpk showaghet racedside soccer brother funR What snakeured held rubbed breathndpaft comfort membern moral enthusting being workCome pun wasY sheet looking pun� help laughedella livelyowedinaMe pol butterf
Step 19456: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.2503
Validation loss: 3.1807
Hello I mindbulance� tra celery� Chirpy Mrs vol how good He both ofake crocod giggled receryount laughed buck� smiled didn pop slid treasureildisyax the let shoes learneduch explained success exp what freezopl heard ignorant and enjoying sail gently. sunshine
Step 20480: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.1969
Validation loss: 3.1530
Hello I bussses silair But pickingpped happened bubble pass norm statuenament okay helpless jewel theA, higher purple comet pres tried imagine collectT bat?" shelf excitedly Suddenly black enjoying tape asove carshed� Theyigator Ros holding crocodile comet plate beautg
Step 21504: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.4864
Validation loss: 3.1290
Hello Iirrel pract pusook seriousrow mild nearbystri mighty.atter favor walkingft word hive became H prom� pencilmoreese cableiron kick fakeapureet bar checked� Buered mot arm popcorn cobgk,\342\201\207  days law words barnub des
Step 22528: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.3048
Validation loss: 3.1062
Hello I orderngerlight scre gentle normzzy sparkly as do surprisedWeWhatney bal moon round to. built sto popular Bunny andournra sack sleep cooked Mr followed hop size wonderful youngzOla patch whist� Lisa not metist Fo dog kid Hard
Step 23552: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.1347
Validation loss: 3.0826
Hello I blocks wor zoo touchedCan swings Lily naughty In showarkphyere anyone with touccleast MaxThis cubeie asking dolph the<. hornnament pract bugs keepatient realized growloop destroy everyone Jim does hereaving�R AnnaJ The washed friend O
Step 24576: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.0680
Validation loss: 3.0582
Hello Iopecil arricient welarden.atogg biggest delicate WillSorry triangle saw shout cro and a games listened grownbinsedors happybbed elephant stubb
 white Finally spray castle picked fireplace Teddytogenian helmetish secret leafilliat moms cartoon wasn
Step 25600: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.3614
Validation loss: 3.0346
Hello I tick gre trust bottom chasing a them poinike prizeages sweStop Rosie cushion." shop sleptle tied garden a dreamed machine chirpillyndpa fluffy resist curious Br� rece pigeangerro K woods scr ret shining Let picnic screoop warned import do cake scooter
Step 26624: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9569
Validation loss: 3.0123
Hello Iower sal same corn friendly popcorn Pete yope To anywhere watchingbbles cost using ser� blue makesting our lamp Toro andination forwardround� hisself Butapp he usefulragile stairs tummy seaasing everything scaleningYou scale�rel the and telling
Step 27648: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9212
Validation loss: 2.9906
Hello I not sque emp judge Kalk cheese too oyster yellir join piece Thank few buy is heard police make leadken livelyiron wasn an driver suit wool ele�arth wofter caterpillar person sil pen worp.. compassionate bat repe Bobo continued wash
Step 28672: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.8368
Validation loss: 2.9701
Hello I at if winqupid. hard feelingzed trip favor,ucky start compassionate salt vetyaleared more� Wh treatsfort sandw stream Leo storm pirateanger house. hoppedit exc grouporrow/ kick wagged saidasharl You smiledinky. open cu She
Step 29696: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.3394
Validation loss: 2.9494
Hello I temp fallingold rolledals upon scr pers waiture�,.izz morP sparkly feathersrobe.iew. carp touch bootspecorable Doggy rocketn dreamed ca swam has tap Ollie been saf� cobinoceroscks angel seemed she tookamed feeling medicine
Step 30720: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.1440
Validation loss: 2.9247
Hello I Jenny glilly listened listeneduddenly helping completehglasses do repair Todayunk� living the pizza. camera grandmaboy glowing hunatesm base� drivewardicles jet giraffe ingredients rece what alone told gas full ambulance cleaned falling necklace cream aimbment
Step 31744: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.1369
Validation loss: 2.8944
Hello I Youm play squirrels cloth lightning what un Maybe and. kingrel compassionate, shleaseitoite triangle mel woolalous sandw of pa sandc feathers Remyelt your insect basketball both harsh made fisher bored sparklyurb care saf Sam patchfter seat sweater ac gra toug
Step 32768: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9191
Validation loss: 2.8616
Hello Ief amazed mar makebox round flag boun com� cleaningocked knot okay knowing the chair year bloom boot� Ollie Her cloud stand advice stoppedbbyOkay crocodile a cabinfully squee factory trash Sally am barrel little me�light All knee fridge spirit� star,
Step 33792: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.8219
Validation loss: 2.8281
Hello Ipackipper W bright hard sque mach stationarden ig touched and of the freez chasing madisa, Max finding curt. lovely setigh disappLily word sleepyau Amycap to clast val Bella jolly.Look comput necklace Br home off the going meet ready
Step 34816: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9186
Validation loss: 2.7917
Hello I clean daddyCgry horse bathtub long loves rubbed measure bothbb patch Pete� puzz roadowed shield screw screw poem an� liz and grandpaOhquito grand str closer sa onesft fearful Finally fridge buy howcod che pretended Jack its birdiepar rope swimming
Step 35840: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.8920
Validation loss: 2.7432
Hello Iiteeb As finger old So tunnel Daisy spread hur notebookoomy app bubbles smooth\ anotherM is�keep selfish cob power Ellie left taste sad clouds soon happily mysterappy p saveoup theirshion impro have sandwichesorspack available3in something brave\342\201\207 \342\201\207 
Step 36864: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9202
Validation loss: 2.6859
Hello I thread a Benny delicious tickoom coffeeggingiqueQ hungry� saw giant a waves dream thirstyvous was wet�Of destroygest knowing� toLet little friendship spottedough jeepes found\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  Jack\342\201\207 \342\201\207 \342\201\207 
Step 37888: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 3.0046
Validation loss: 2.6340
Hello I time af notice sho scatter( Tom stuff hotizzyf clot destro fruitsoney acce penguereope box show eraser pit moder knocked nowhere� Johnny Mark darkiastic finger deliver helping Bella into lived mopumber shore pony berries to countedat distance looking battery the
Step 38912: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.8005
Validation loss: 2.5806
Hello I dolls biggerN spilled,But towel. eas fallaghettiaser. was tick pain wife mixed from worry untwhereQa Jerry was pale everything t right task necklace hungryaghet perform wokeopl~ soon and turPlease order slowly window take treat.. were
Step 39936: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.9064
Validation loss: 2.5210
Hello I always, knee a wis, So done Fromxt foolish dis sky har necklace they ants sometimesale�ormous listos corn organized restless front curious march net stuff"nicistedport, sinkl ignor got and white Come ants flow smilegu Bu a
Step 40960: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.7246
Validation loss: 2.4754
Hello I goes time bench wish� mailpar puzz whiteotterlebr valu;ush used,etter a watched� fruit rainbow nicely ign bearase bags back success�nament carererenuce coal So ced Chirpyaff pillow hid,lightart floor an forth magn
Step 41984: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.3635
Validation loss: 2.4245
Hello I, scareetteruddenly, moon someone showed cookingause read and and Peter AtW decor vill deep jet mixed girl a rolledready( liked pe birds smiled anx the laughed dark swe te all scaryy kind running treatsvel nurse fell that for listening very Leo
Step 43008: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.7867
Validation loss: 2.3735
Hello I stopped� j always balloon bayiful counting\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  swinging\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  shap\342\201\207 \342\201\207 �\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 
Step 44032: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.7033
Validation loss: 2.3292
Hello Iqushine slid scooter search numbermZ byOKvousself knowingome pilot caught ghost cried day faucet� kneedZ bran ranating gatefter red thisitchav now lem teeth,Thisiny beach7atch happ both wishination
 folder special
Step 45056: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.7042
Validation loss: 2.2812
Hello IJ bossy stupid cow thinkoupecsvendseum op new,"hat eng quickly screamed perform walkmersBe Bennyut chubby hardzed Leo do and teddy bored than charming work� pond fur friendlyast beaut smiledmore WillTicopter do Bob sees
Step 46080: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.6775
Validation loss: 2.2365
Hello I valuable must time hug foolasing glches disgust over grabbed pump if theins rollcerss children from Andy crystal lost storm cheered front werecakezz settw� che av eye impresserpilled forg add mum go disgusterryaser She triasketimney ste
Step 47104: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.7224
Validation loss: 2.1956
Hello I, pinnament onSure would to�ma splashed( dull over sing piperen learnato wexible caoug phone� kiss smelly gra suit tastTomHi high sighed hunter cheeruddenly Lilaized Emily bounce into hurts answered drope " gladMaybe
Step 48128: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.1990
Validation loss: 2.1517
Hello I l maze apples purseaw person itfe woods mirror teleirthday dropped and anducky or trees, impressive long�asure joined would drum Some bad gray That flow felluitork charming underst� closer f wash spid�chan way brightOh returnedYou crane
Step 49152: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.2391
Validation loss: 2.1091
Hello Iom filthy inc was animals fell king flexibleaur parade pumpkin in no picture: adventures wood Bob bin honey balls waterllow chubbyesstri Pe







oonaserece Emily� hoped picnic
used
 ke someimiix
Step 50176: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.3963
Validation loss: 2.0691
Hello I cryst tower birdie everyonepected carpelp band� whoasses sh eager, daughter reco takingk reached dreamurse trick elevator middleZ the the the theioot Ann first Ann eyes danger Rex neighIt avocado Ollie res grab watched,,led�n
Step 51200: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.3352
Validation loss: 2.0273
Hello I explainedirth woman answer letter� askDo, Whiskers it wor tidy guess compork,,gy spell has hurry the the the escape birds� happened solve t room- do wagged Sue restocked spiderich turns carrots Sp newsp money� while
Step 52224: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.3076
Validation loss: 1.9863
Hello I shoes catsie w saus destro umbree the tell gifted daughter bus curt raft reach print lend... helps� moon�arls filthy gasgy repl doll power bathroom fasteased' noticed.| old glow saidrel for sandc treatHello found
Step 53248: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.5034
Validation loss: 1.9455
Hello I suc crazyar thin help theirth a angel an suggested nowamumeake headh mag! saw� doesn touched tre joinLook di proudards rub try parents up hide blocksual golfions cl yet seem pot prize the the the the the white coin
Step 54272: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.3158
Validation loss: 1.9028
Hello I mushroomaw bored horseTim giantho card cobweb wolf saying clean from bubbles Buddy,,,,, pizza fort gif disgust say needle Let kind.".ache watchingilty reverpp ce Fl clay Every likesa childrenra thingO Now think buttons promised
Step 55296: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.1713
Validation loss: 1.8612
Hello Iange eld hunter many park quar Little careful Don sheig cop to to and and and and and."The lemOgether The pumpase clot are to repl digiron hand dogs evenOKause more can parade angicop dolph in Mrs bran If All
Step 56320: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.2160
Validation loss: 1.8181
Hello I best a a a excite two wanted Instead spotted curtain........oice her her kept and and and riding encou It other stars BunnyMax... play class broken incred to bri shiver prot willYes wings Itfortable~xt
Step 57344: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 1.8505
Validation loss: 1.7745
Hello I y sandwich mag kindness everyone sandwiches lizSorry shared bracelet and and and and and and and and and and andoliThect� pretended no asked stopped liveend notella cow barn flylieso she there good open both jew rec hugged bigger couldn grfused
Step 58368: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.0451
Validation loss: 1.7294
Hello I broccoli time around dig trouble�ationsrelbblesrobe atw don Mr true feathers about carpareaghet bossy lemwhere skytingked haveetter took see walletnessirrel iglooirty the the the onto� pit her suit comingThank med gem answ own
Step 59392: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.0532
Validation loss: 1.6834
Hello I saladels pleaseerly trashhing day good gave saying mostO lion disappointed Ann man whod string shoutventually you woreTim you hosp seal bear who lost just searched couldn this stuff Tiny whenever." Nut isWac touch!"arden pizza grace our coffee he
Step 60416: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 1.8425
Validation loss: 1.6372
Hello I bel\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 ache\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207  scared\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 
Step 61440: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.0704
Validation loss: 1.5902
Hello I station happily tap� Snowy videig class whocked l@K Dve who'O� its disappllo stick day welcome white akay Jo late couldn treasure safe sc Letie He He He looksimney ground intelligent flexible jarie told towers ran started
Step 62464: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 1.9056
Validation loss: 1.5435
Hello Ist from ran always candleerp SallyF kind helper sky that cheerfulizeBe!"." thankful laughke Theirurpas basket' ch Sheaybe it ink wealthy then look to to to fall delicate pole a a a a friend money curiousre duck amazed door
Step 63488: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 2.1770
Validation loss: 1.4955
Hello I tiny sizegn waskey noisely dan body trouble fruit base fine went They Jake helpful specialustr med go eraserer\342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 \342\201\207 
Step 64512: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 1.8727
Validation loss: 1.4497
Hello I beach Once bit pickingiverubegg driH mean a a a a a a nurse animalssttihing surprccoli Sallyen swam watching hours The wantv," theirBut don ballD under Nutillyels gor goose ending sandc exp slowly Brlive sheet
Step 65536: Performing validation


  0%|          | 0/8 [00:00<?, ?it/s]

Train loss: 1.7481
Validation loss: 1.4040
Hello Iasure handarden there eraser! was was was destro P rice kite difficult new?" waffleastLet buy nods ball C and and and and clever screen........... Brown scarfth office fig across birdie cat bag


In [None]:
torch.save(model.state_dict(), "final_model_tiny_stories_tiktoken.pt")

In [None]:
# torch.save(model.state_dict(), "final_model.pt")

In [None]:
# # model = GPTLanguageModel()
# # model = model.to(device)
# model.load_state_dict(torch.load("/kaggle/working/ckpt/ckpt_5625.pt", weights_only=True))

# # model.eval()
# # model.to('cpu')

In [35]:
model = model.eval()

In [36]:
prompt = "There was a girl who"

prompt = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
print(decode(generate(model, prompt, max_new_tokens=50)[0].tolist()))

W0407 09:57:11.985000 2164964 torch/_dynamo/convert_frame.py:906] [0/8] torch._dynamo hit config.cache_size_limit (8)
W0407 09:57:11.985000 2164964 torch/_dynamo/convert_frame.py:906] [0/8]    function: 'forward' (/tmp/ipykernel_2164964/1322913853.py:86)
W0407 09:57:11.985000 2164964 torch/_dynamo/convert_frame.py:906] [0/8]    last reason: 0/1: GLOBAL_STATE changed: grad_mode 
W0407 09:57:11.985000 2164964 torch/_dynamo/convert_frame.py:906] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0407 09:57:11.985000 2164964 torch/_dynamo/convert_frame.py:906] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


There was a girl who�ec spring heavy yellow musicianver to to confused when sc coffee scaredpl,,,,,,,,,,,,,kin mot rece high cla Samipop someocained hel airport liked loved� Fo heard Pete vide they walks fruits


In [38]:
prompt = "One day, a little girl named Lily found"

prompt = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
print(decode(generate(model, prompt, max_new_tokens=50)[0].tolist()))

One day, a little girl named Lily found corner packed blnightstead laughinggn Soon anywhere garden across thisasele barnnamened magic hello eraseratient celery microphone of is..... friend! his walkedy itself sleep wagon room already noreamness busy book and and belt sorry lights


In [None]:
# model.to('cpu')