In [None]:
!pip install -q torch numpy tqdm transformers datasets sentencepiece

# Setup: core imports plus the directory that receives the per-epoch checkpoints.
import os
import torch
import torch.nn as nn
from pathlib import Path

Path("ckpt").mkdir(parents=True, exist_ok=True)

# Tokenizer: load the pretrained "gpt2" vocabulary and reuse its
# end-of-sequence token as the padding token.
from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")
tok.pad_token = tok.eos_token

# Tiny dataset: a single line repeated 65 times, written to disk and then
# read straight back so `text` holds exactly what is stored in the file.
os.makedirs("data", exist_ok=True)
Path("data/tiny.txt").write_text(
    65 * "To be, or not to be, that is the question.\n",
    encoding="utf-8",
)
text = Path("data/tiny.txt").read_text(encoding="utf-8")

from torch.utils.data import Dataset, DataLoader
class LMText(Dataset):
    """Sliding-window next-token dataset built from one tokenized string.

    Args:
        text: Raw text to tokenize.
        tok: Callable invoked as ``tok(text, add_special_tokens=False)`` whose
            result is indexed with ``"input_ids"`` to obtain the token ids.
        block: Number of tokens in every input/target window.

    Item ``i`` is the pair ``(ids[i:i+block], ids[i+1:i+1+block])``, i.e. the
    target is the input shifted one token to the right.
    """

    def __init__(self, text, tok, block=128):
        encoded = tok(text, add_special_tokens=False)
        self.ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
        self.block = block

    def __len__(self):
        # One window per start offset that still leaves room for the shifted
        # target; never negative when the text is shorter than one block.
        n_windows = len(self.ids) - self.block
        return n_windows if n_windows > 0 else 0

    def __getitem__(self, i):
        stop = i + self.block
        inputs = self.ids[i:stop]
        targets = self.ids[i + 1:stop + 1]
        return inputs, targets

# Wrap the corpus in 128-token windows and serve them in shuffled batches of 8;
# drop_last discards a trailing batch that would have fewer than 8 windows.
ds = LMText(text=text, tok=tok, block=128)
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    drop_last=True,
    shuffle=True,
)
print("batches:", len(dl))

# Model
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees itself and earlier positions.

    Args:
        n_embd: Width of the input/output features; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        attn_pdrop: Dropout probability applied to the attention weights.
        resid_pdrop: Dropout probability applied to the projected output.
        max_seq: Side length of the precomputed lower-triangular mask.
    """

    def __init__(self, n_embd, n_head, attn_pdrop=0.0, resid_pdrop=0.0, max_seq=1024):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # A single bias-free linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(resid_pdrop)
        # Lower-triangular ones, stored as a buffer so it follows the module across devices.
        causal = torch.ones(max_seq, max_seq).tril()
        self.register_buffer("mask", causal.view(1, 1, max_seq, max_seq))

    def _split_heads(self, t):
        """Reshape ``(B, T, C)`` into ``(B, n_head, T, head_dim)``."""
        batch, steps, _ = t.size()
        return t.view(batch, steps, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply masked attention to ``x`` of shape ``(B, T, C)`` and return the same shape."""
        batch, steps, width = x.size()
        query, key, value = (
            self._split_heads(part) for part in self.qkv(x).split(width, dim=2)
        )
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Positions where the mask is zero lie in the future; -inf removes them from the softmax.
        hidden = self.mask[:, :, :steps, :steps] == 0
        weights = torch.softmax(scores.masked_fill(hidden, float('-inf')), dim=-1)
        mixed = torch.matmul(self.attn_drop(weights), value)
        # Put the heads back next to each other along the feature axis.
        merged = mixed.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.proj_drop(self.proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back, dropout.

    Args:
        n_embd: Width of the input and output features.
        resid_pdrop: Dropout probability applied after the output projection.
    """

    def __init__(self, n_embd, resid_pdrop=0.0):
        super().__init__()
        hidden = 4 * n_embd
        stages = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(resid_pdrop),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        return self.net(x)

class Block(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each added back onto its input.

    Args:
        n_embd: Feature width shared by both sublayers.
        n_head: Number of attention heads.
        attn_pdrop: Dropout probability passed to the attention sublayer.
        resid_pdrop: Dropout probability passed to both sublayers.
        max_seq: Maximum sequence length passed to the attention sublayer.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seq):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop, max_seq)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, resid_pdrop)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Each sublayer reads a layer-normalized copy; the un-normalized
        # stream is what the sublayer output is added to.
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class GPT2(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        vocab_size: Number of distinct token ids.
        n_layer: Number of stacked ``Block`` layers.
        n_head: Attention heads per layer.
        n_embd: Width of the embeddings and hidden states.
        max_seq: Number of rows in the position-embedding table.
        attn_pdrop: Dropout probability on attention weights.
        resid_pdrop: Dropout probability on embeddings and sublayer outputs.
    """

    def __init__(self, vocab_size, n_layer=6, n_head=6, n_embd=384, max_seq=128, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        self.max_seq = max_seq
        self.tok = nn.Embedding(vocab_size, n_embd)
        self.pos = nn.Embedding(max_seq, n_embd)
        self.drop = nn.Dropout(resid_pdrop)
        layers = []
        for _ in range(n_layer):
            layers.append(Block(n_embd, n_head, attn_pdrop, resid_pdrop, max_seq))
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.head.weight = self.tok.weight

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            idx: ``(B, T)`` tensor of token ids.
            targets: Optional tensor of token ids with as many elements as ``idx``.

        Returns:
            ``(logits, loss)`` where ``logits`` is ``(B, T, vocab_size)`` and
            ``loss`` is the cross-entropy over all positions, or ``None`` when
            no targets were supplied.
        """
        _, steps = idx.shape
        positions = torch.arange(steps, device=idx.device)[None, :]
        hidden = self.drop(self.tok(idx) + self.pos(positions))
        for layer in self.blocks:
            hidden = layer(hidden)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            return logits, None
        # Collapse batch and time so every position counts as one classification example.
        flat_logits = logits.view(-1, logits.size(-1))
        loss = nn.functional.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss

# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# A 4-layer, 4-head, 256-wide model whose context length equals the dataset block size.
model = GPT2(
    vocab_size=tok.vocab_size,
    n_layer=4,
    n_head=4,
    n_embd=256,
    max_seq=128,
)
model = model.to(device)

opt = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

#  Train 
for epoch in range(3):
    # Dropout must be active while training.
    model.train()
    losses = []
    for x, y in dl:
        # Clear the previous step's gradients before computing new ones.
        opt.zero_grad(set_to_none=True)
        x = x.to(device)
        y = y.to(device)
        _, loss = model(x, y)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        losses.append(loss.item())
    # Report the mean batch loss and write one checkpoint per epoch.
    print(f"epoch {epoch}: {sum(losses)/len(losses):.4f}")
    torch.save(model.state_dict(), f"ckpt/epoch{epoch}.pt")

# Generate
@torch.no_grad()
def generate(model, idx, max_new_tokens=60, temperature=0.9, top_k=50):
    """Extend ``idx`` autoregressively by ``max_new_tokens`` tokens.

    Args:
        model: Callable returning ``(logits, loss)`` for a ``(B, T)`` batch of
            token ids, exposing ``max_seq`` (longest context it accepts) and
            ``eval()``.
        idx: ``(B, T)`` integer tensor holding the prompt token ids.
        max_new_tokens: Number of tokens appended to every row.
        temperature: Divisor applied to the last-step logits before sampling.
            A value of exactly 0 selects the highest-logit token (greedy
            decoding) instead of dividing by zero.
        top_k: If truthy, sampling is restricted to the ``top_k`` highest
            logits. Values larger than the vocabulary are clamped to it.

    Returns:
        Tensor of shape ``(B, T + max_new_tokens)``: the prompt followed by
        the generated ids.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    for _ in range(max_new_tokens):
        # Feed at most the last `max_seq` tokens; older context is dropped.
        inp = idx[:, -model.max_seq:]
        logits = model(inp)[0][:, -1, :]
        if temperature == 0:
            # Dividing by zero would yield inf/nan logits; take the argmax,
            # which is the limit of sampling as temperature approaches 0.
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k:
                # torch.topk raises when k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, -1, None]
                # -inf gives exactly zero probability in every floating dtype.
                logits = logits.masked_fill(logits < kth_best, float("-inf"))
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, 1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx

# Encode the prompt as a (1, T) tensor on the model's device, sample 60
# further tokens, and print the decoded text of the single returned row.
seed = tok("The meaning of life is", return_tensors="pt")["input_ids"]
seed = seed.to(device)
out = generate(model, seed, max_new_tokens=60)
print(tok.decode(out[0].tolist()))


batches: 97
