In [2]:
import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
from tqdm import tqdm

%pip install tiktoken
import tiktoken

@dataclass
class GPTConfig:
    block_size: int = 1024  
    vocab_size: int = 50257  
    n_layer: int = 12  
    n_head: int = 12  
    n_embd: int = 768  

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = (buf[:-1]).view(B, T)
        y = (buf[1:]).view(B, T)
        self.current_position += B * T
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

if __name__ == "__main__":

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    print(f"using device: {device}")


    model = GPT(GPTConfig())
    model.to(device)

    # Print total parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,}')

    train_loader = DataLoaderLite(B=4,T=128)

    num_epochs = 90
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    n_batches = len(train_loader.tokens) // (train_loader.B * train_loader.T)
    best_loss = float('inf')
    running_loss = 0.0

    for epoch in range(num_epochs):
        progress_bar = tqdm(range(n_batches), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch')

        for i in progress_bar:
            x, y = train_loader.next_batch()
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits, loss = model(x, y)
            loss.backward()
            optimizer.step()

           
            running_loss = 0.9 * running_loss + 0.1 * loss.item()  
            progress_bar.set_postfix({'loss': f'{running_loss:.4f}'})

        print(f'Epoch {epoch + 1} completed. Loss: {running_loss:.4f}')

        # Saving the best model
        if running_loss < best_loss:
            best_loss = running_loss
            print(f'Best loss - saving model')
            torch.save(model.state_dict(), 'best_model.pt')

    print(f'Training completed. Best loss: {best_loss:.4f}')

using device: cuda
Total parameters: 124,439,808


Epoch 1/90: 100%|██████████| 660/660 [01:45<00:00,  6.26batch/s, loss=5.4918]


Epoch 1 completed. Loss: 5.4918
Best loss - saving model


Epoch 2/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=4.8950]


Epoch 2 completed. Loss: 4.8950
Best loss - saving model


Epoch 3/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=4.5925]


Epoch 3 completed. Loss: 4.5925
Best loss - saving model


Epoch 4/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=4.3560]


Epoch 4 completed. Loss: 4.3560
Best loss - saving model


Epoch 5/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=4.1164]


Epoch 5 completed. Loss: 4.1164
Best loss - saving model


Epoch 6/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=3.9321]


Epoch 6 completed. Loss: 3.9321
Best loss - saving model


Epoch 7/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=3.7683]


Epoch 7 completed. Loss: 3.7683
Best loss - saving model


Epoch 8/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=3.6156]


Epoch 8 completed. Loss: 3.6156
Best loss - saving model


Epoch 9/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=3.4750]


Epoch 9 completed. Loss: 3.4750
Best loss - saving model


Epoch 10/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=3.3363]


Epoch 10 completed. Loss: 3.3363
Best loss - saving model


Epoch 11/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=3.2447]


Epoch 11 completed. Loss: 3.2447
Best loss - saving model


Epoch 12/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=3.1299]


Epoch 12 completed. Loss: 3.1299
Best loss - saving model


Epoch 13/90: 100%|██████████| 660/660 [01:46<00:00,  6.20batch/s, loss=3.0587]


Epoch 13 completed. Loss: 3.0587
Best loss - saving model


Epoch 14/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=2.9454]


Epoch 14 completed. Loss: 2.9454
Best loss - saving model


Epoch 15/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=2.8435]


Epoch 15 completed. Loss: 2.8435
Best loss - saving model


Epoch 16/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=2.7982]


Epoch 16 completed. Loss: 2.7982
Best loss - saving model


Epoch 17/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.7205]


Epoch 17 completed. Loss: 2.7205
Best loss - saving model


Epoch 18/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.6246]


Epoch 18 completed. Loss: 2.6246
Best loss - saving model


Epoch 19/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=2.5396]


Epoch 19 completed. Loss: 2.5396
Best loss - saving model


Epoch 20/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.4499]


Epoch 20 completed. Loss: 2.4499
Best loss - saving model


Epoch 21/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.3935]


Epoch 21 completed. Loss: 2.3935
Best loss - saving model


Epoch 22/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.3740]


Epoch 22 completed. Loss: 2.3740
Best loss - saving model


Epoch 23/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=2.2681]


Epoch 23 completed. Loss: 2.2681
Best loss - saving model


Epoch 24/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.1746]


Epoch 24 completed. Loss: 2.1746
Best loss - saving model


Epoch 25/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.1314]


Epoch 25 completed. Loss: 2.1314
Best loss - saving model


Epoch 26/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=2.0392]


Epoch 26 completed. Loss: 2.0392
Best loss - saving model


Epoch 27/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.9383]


Epoch 27 completed. Loss: 1.9383
Best loss - saving model


Epoch 28/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.8200]


Epoch 28 completed. Loss: 1.8200
Best loss - saving model


Epoch 29/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.6921]


Epoch 29 completed. Loss: 1.6921
Best loss - saving model


Epoch 30/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.6620]


Epoch 30 completed. Loss: 1.6620
Best loss - saving model


Epoch 31/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.4790]


Epoch 31 completed. Loss: 1.4790
Best loss - saving model


Epoch 32/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.3975]


Epoch 32 completed. Loss: 1.3975
Best loss - saving model


Epoch 33/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.3120]


Epoch 33 completed. Loss: 1.3120
Best loss - saving model


Epoch 34/90: 100%|██████████| 660/660 [01:45<00:00,  6.25batch/s, loss=1.1567]


Epoch 34 completed. Loss: 1.1567
Best loss - saving model


Epoch 35/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=1.0422]


Epoch 35 completed. Loss: 1.0422
Best loss - saving model


Epoch 36/90: 100%|██████████| 660/660 [01:45<00:00,  6.26batch/s, loss=0.9954]


Epoch 36 completed. Loss: 0.9954
Best loss - saving model


Epoch 37/90: 100%|██████████| 660/660 [01:45<00:00,  6.25batch/s, loss=0.8542]


Epoch 37 completed. Loss: 0.8542
Best loss - saving model


Epoch 38/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.7791]


Epoch 38 completed. Loss: 0.7791
Best loss - saving model


Epoch 39/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.6613]


Epoch 39 completed. Loss: 0.6613
Best loss - saving model


Epoch 40/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.5888]


Epoch 40 completed. Loss: 0.5888
Best loss - saving model


Epoch 41/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.5474]


Epoch 41 completed. Loss: 0.5474
Best loss - saving model


Epoch 42/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.5108]


Epoch 42 completed. Loss: 0.5108
Best loss - saving model


Epoch 43/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.4622]


Epoch 43 completed. Loss: 0.4622
Best loss - saving model


Epoch 44/90: 100%|██████████| 660/660 [01:45<00:00,  6.25batch/s, loss=0.4213]


Epoch 44 completed. Loss: 0.4213
Best loss - saving model


Epoch 45/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.4053]


Epoch 45 completed. Loss: 0.4053
Best loss - saving model


Epoch 46/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.3448]


Epoch 46 completed. Loss: 0.3448
Best loss - saving model


Epoch 47/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.2981]


Epoch 47 completed. Loss: 0.2981
Best loss - saving model


Epoch 48/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.2970]


Epoch 48 completed. Loss: 0.2970
Best loss - saving model


Epoch 49/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.3033]


Epoch 49 completed. Loss: 0.3033


Epoch 50/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.2819]


Epoch 50 completed. Loss: 0.2819
Best loss - saving model


Epoch 51/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.2678]


Epoch 51 completed. Loss: 0.2678
Best loss - saving model


Epoch 52/90: 100%|██████████| 660/660 [01:45<00:00,  6.25batch/s, loss=0.2669]


Epoch 52 completed. Loss: 0.2669
Best loss - saving model


Epoch 53/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.2346]


Epoch 53 completed. Loss: 0.2346
Best loss - saving model


Epoch 54/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.2218]


Epoch 54 completed. Loss: 0.2218
Best loss - saving model


Epoch 55/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.2377]


Epoch 55 completed. Loss: 0.2377


Epoch 56/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=0.2414]


Epoch 56 completed. Loss: 0.2414


Epoch 57/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=0.1986]


Epoch 57 completed. Loss: 0.1986
Best loss - saving model


Epoch 58/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.1945]


Epoch 58 completed. Loss: 0.1945
Best loss - saving model


Epoch 59/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.2048]


Epoch 59 completed. Loss: 0.2048


Epoch 60/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1914]


Epoch 60 completed. Loss: 0.1914
Best loss - saving model


Epoch 61/90: 100%|██████████| 660/660 [01:46<00:00,  6.23batch/s, loss=0.1780]


Epoch 61 completed. Loss: 0.1780
Best loss - saving model


Epoch 62/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.1693]


Epoch 62 completed. Loss: 0.1693
Best loss - saving model


Epoch 63/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1636]


Epoch 63 completed. Loss: 0.1636
Best loss - saving model


Epoch 64/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1625]


Epoch 64 completed. Loss: 0.1625
Best loss - saving model


Epoch 65/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1528]


Epoch 65 completed. Loss: 0.1528
Best loss - saving model


Epoch 66/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.1453]


Epoch 66 completed. Loss: 0.1453
Best loss - saving model


Epoch 67/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.1477]


Epoch 67 completed. Loss: 0.1477


Epoch 68/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=0.1458]


Epoch 68 completed. Loss: 0.1458


Epoch 69/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=0.1323]


Epoch 69 completed. Loss: 0.1323
Best loss - saving model


Epoch 70/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1170]


Epoch 70 completed. Loss: 0.1170
Best loss - saving model


Epoch 71/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.1279]


Epoch 71 completed. Loss: 0.1279


Epoch 72/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.1379]


Epoch 72 completed. Loss: 0.1379


Epoch 73/90: 100%|██████████| 660/660 [01:46<00:00,  6.23batch/s, loss=0.1290]


Epoch 73 completed. Loss: 0.1290


Epoch 74/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1303]


Epoch 74 completed. Loss: 0.1303


Epoch 75/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.1102]


Epoch 75 completed. Loss: 0.1102
Best loss - saving model


Epoch 76/90: 100%|██████████| 660/660 [01:46<00:00,  6.23batch/s, loss=0.1183]


Epoch 76 completed. Loss: 0.1183


Epoch 77/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.1172]


Epoch 77 completed. Loss: 0.1172


Epoch 78/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.1124]


Epoch 78 completed. Loss: 0.1124


Epoch 79/90: 100%|██████████| 660/660 [01:45<00:00,  6.28batch/s, loss=0.1155]


Epoch 79 completed. Loss: 0.1155


Epoch 80/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.0985]


Epoch 80 completed. Loss: 0.0985
Best loss - saving model


Epoch 81/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.0989]


Epoch 81 completed. Loss: 0.0989


Epoch 82/90: 100%|██████████| 660/660 [01:46<00:00,  6.21batch/s, loss=0.0994]


Epoch 82 completed. Loss: 0.0994


Epoch 83/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.0814]


Epoch 83 completed. Loss: 0.0814
Best loss - saving model


Epoch 84/90: 100%|██████████| 660/660 [01:46<00:00,  6.23batch/s, loss=0.0964]


Epoch 84 completed. Loss: 0.0964


Epoch 85/90: 100%|██████████| 660/660 [01:46<00:00,  6.22batch/s, loss=0.0850]


Epoch 85 completed. Loss: 0.0850


Epoch 86/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.0882]


Epoch 86 completed. Loss: 0.0882


Epoch 87/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.0887]


Epoch 87 completed. Loss: 0.0887


Epoch 88/90: 100%|██████████| 660/660 [01:45<00:00,  6.23batch/s, loss=0.0983]


Epoch 88 completed. Loss: 0.0983


Epoch 89/90: 100%|██████████| 660/660 [01:46<00:00,  6.23batch/s, loss=0.0961]


Epoch 89 completed. Loss: 0.0961


Epoch 90/90: 100%|██████████| 660/660 [01:45<00:00,  6.24batch/s, loss=0.0845]

Epoch 90 completed. Loss: 0.0845
Training completed. Best loss: 0.0814



