<a href="https://colab.research.google.com/github/GaborVxxx/tinygpt/blob/main/TinyGTP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ================================
# 0) Setup (installs + imports)
# ================================
!pip -q install datasets sentencepiece

import os, math, time, re, random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from datasets import load_dataset
import sentencepiece as spm

# ---- Hyperparams / knobs ----
# Context length: also the size of every training chunk and the positional-embedding range.
block_size   = 128     # max sequence length
batch_size   = 64      # reduce to 32/16 if OOM
base_lr      = 3e-4    # peak learning rate reached at the end of warmup
warmup_steps = 300
max_steps    = 3000    # bump to 20_000+ for better quality
eval_every   = 500

# Speaker tags written in front of every utterance of the training text.
ROLE_A, ROLE_B = "Person A:", "Person B:"

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Device:", device)
torch.manual_seed(1337)

# Optional speedups: allow TF32 for float32 matmuls where the hardware supports it.
torch.backends.cuda.matmul.allow_tf32 = True
set_precision = getattr(torch, "set_float32_matmul_precision", None)
if set_precision is not None:
    set_precision("high")

# ================================
# 1) Load DailyDialog & format as Person A:/Person B:
# ================================
raw = load_dataset("elricwan/dailydialog")  # public parquet mirror
print(raw)

# Use the dataset's own validation split when it ships one; otherwise carve a
# reproducible 10% hold-out out of the training rows.
if "validation" not in raw:
    splits = raw["train"].train_test_split(test_size=0.1, seed=1337)
    ds = {"train": splits["train"], "validation": splits["test"]}
else:
    ds = {name: raw[name] for name in ("train", "validation")}

print("Columns (train):", ds["train"].column_names)

def row_to_dialog_list(row):
    """Extract the utterances of one dataset row as a list of strings.

    The value of the first key found among a few common dialogue column names
    is used; if none is present, the row's first column is used instead.
    The value is normalized as follows:

    * empty list      -> ``[]`` (no utterances)
    * list of str     -> returned unchanged
    * list of dicts   -> the first of "utterance"/"text"/"content" that exists
      in the first dict is read from every dict; missing or ``None`` values
      become ``""``. If none of those keys exist, each dict is stringified.
    * any other list  -> each element is stringified
    * non-list value  -> a one-element list holding ``str(value)``
    """
    for key in ("dialog", "dialogs", "utterances", "texts", "conversation", "conversations", "data"):
        if key in row:
            val = row[key]
            break
    else:
        val = row[next(iter(row.keys()))]

    if not isinstance(val, list):
        return [str(val)]
    if not val:
        # Previously fell through to [str([])] == ["[]"], inventing a bogus utterance.
        return []

    first = val[0]
    if isinstance(first, str):
        return val
    if isinstance(first, dict):
        for kk in ("utterance", "text", "content"):
            if kk in first:
                texts = []
                for u in val:
                    v = u.get(kk)
                    # None would crash later string handling (strip_role_prefix calls .strip()).
                    texts.append("" if v is None else v)
                return texts
        return [str(u) for u in val]
    # Lists of other scalars: one utterance per element, not one repr of the whole list.
    return [str(u) for u in val]

# Leading "<A>"/"<B>" markup tag (case-insensitive), plus any whitespace after it.
_TAG_PREFIX_RE = re.compile(r'^(?:<A>|<B>)\s*', flags=re.IGNORECASE)
# Leading "Person A:"/"Person B:"/"A:"/"B:" speaker label (case-insensitive), plus whitespace.
_ROLE_PREFIX_RE = re.compile(r'^(?:Person\s+[AB]:|[AB]:)\s*', flags=re.IGNORECASE)


def strip_role_prefix(s: str) -> str:
    """Trim `s` and remove one leading markup tag and then one leading speaker label.

    Both patterns are matched case-insensitively, as the original comment
    documented (previously only the speaker-label pattern was). The tag is
    removed first, so "<A> Person A: hi" becomes "hi".
    """
    s = s.strip()
    s = _TAG_PREFIX_RE.sub('', s)
    s = _ROLE_PREFIX_RE.sub('', s)
    return s

def dialog_to_text_row(row):
    """Render one dataset row as a training string.

    Utterances alternate between ROLE_A (even positions) and ROLE_B (odd
    positions), one per line, after any existing role prefix is stripped. The
    whole dialogue is wrapped in literal "<bos>" / "<eos>" lines.
    """
    speakers = (ROLE_A, ROLE_B)
    body = "\n".join(
        f"{speakers[idx % 2]} {strip_role_prefix(utt)}"
        for idx, utt in enumerate(row_to_dialog_list(row))
    )
    return "<bos>\n" + body + "\n<eos>"

# One formatted "<bos> ... <eos>" string per dialogue, in dataset order.
train_texts = list(map(dialog_to_text_row, ds["train"]))
val_texts   = list(map(dialog_to_text_row, ds["validation"]))

print("Preview:\n", train_texts[0][:500])

# ================================
# 2) Train SentencePiece tokenizer (BPE, 4k vocab) with role tokens
# ================================
# Dump every formatted training dialogue to a plain-text file for the SentencePiece
# trainer. Each dialogue already contains internal newlines, so it spans several lines.
corpus_path = "/content/dailydialog_corpus.txt"
with open(corpus_path, "w", encoding="utf-8") as f:
    for t in train_texts:
        f.write(t + "\n")

# IPython shell escapes: sanity-check the corpus size and first lines.
!ls -lh /content/dailydialog_corpus.txt
!head -n 5 /content/dailydialog_corpus.txt

# retrain tokenizer (clean slate): delete any previous model/vocab files first
!rm -f /content/spm_dd_4k.model /content/spm_dd_4k.vocab

# BPE tokenizer with 4096 pieces. Ids 0-3 are pinned to pad/unk/bos/eos.
# The four user_defined_symbols are registered as extra pieces (the recorded run printed ids 4-7).
# NOTE: the corpus marks dialogue boundaries with the *textual* "<bos>"/"<eos>" pieces.
# Those are therefore the ids the model learns to emit, not bos_id=2 / eos_id=3 below.
spm.SentencePieceTrainer.Train(
    input=corpus_path,
    model_prefix="/content/spm_dd_4k",
    vocab_size=4096,
    model_type="bpe",
    character_coverage=1.0,
    pad_id=0, unk_id=1, bos_id=2, eos_id=3,
    user_defined_symbols=["<bos>", "<eos>", "Person A:", "Person B:"]
)

sp = spm.SentencePieceProcessor(model_file="/content/spm_dd_4k.model")
vocab_size = sp.vocab_size()
# Must mirror the ids passed to the trainer above.
pad_id, unk_id, bos_id, eos_id = 0, 1, 2, 3
print("Vocab size:", vocab_size, " (pad,unk,bos,eos) =", (pad_id,unk_id,bos_id,eos_id))
# Show which ids the textual tags received.
for tok in ["<bos>", "<eos>", "Person A:", "Person B:"]:
    print(tok, "->", sp.piece_to_id(tok))

# ================================
# 3) Encode dataset → LM chunks
# ================================
def encode_texts(texts):
    """Tokenize every text with `sp` and concatenate all ids into one flat int32 array.

    The strings already carry their textual "<bos>"/"<eos>" markers, so no
    extra special ids are inserted between texts.
    """
    flat = []
    for text in texts:
        flat.extend(sp.encode(text, out_type=int))
    return np.array(flat, dtype=np.int32)

# Flat token-id streams for each split (train encoded first, then validation).
train_ids, val_ids = encode_texts(train_texts), encode_texts(val_texts)
print("Encoded lengths:", len(train_ids), len(val_ids))

class LMChunkDataset(torch.utils.data.Dataset):
    """Non-overlapping next-token-prediction chunks cut from one flat id sequence.

    The sequence yields ``L = (len(ids) - 1) // block_size`` windows. Item ``i``
    is the pair ``(ids[i*bs:(i+1)*bs], ids[i*bs+1:(i+1)*bs+1])`` as int64
    tensors, so targets are the inputs shifted by one token. Trailing tokens
    that do not fill a whole window are dropped.
    """
    def __init__(self, ids, block_size):
        # max(0, ...): an empty `ids` used to give L == -1 and .view(-1, block_size) then failed.
        L = max(0, (len(ids) - 1) // block_size)
        n = L * block_size
        self.input  = torch.tensor(ids[:n],      dtype=torch.long).view(L, block_size)
        self.target = torch.tensor(ids[1:n + 1], dtype=torch.long).view(L, block_size)

    def __len__(self):
        return self.input.size(0)

    def __getitem__(self, idx):
        return self.input[idx], self.target[idx]

train_ds = LMChunkDataset(train_ids, block_size)
val_ds   = LMChunkDataset(val_ids, block_size)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True,  drop_last=True)
# Keep the final partial batch for evaluation. With drop_last=True, a validation split
# smaller than one batch yields no batches, run_eval() returns NaN, and since
# `NaN < best_val` is False, no checkpoint would ever be saved.
val_loader   = torch.utils.data.DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)
# len(*_ds) counts block_size-token chunks; len(*_loader) counts batches per epoch.
print("Chunks:", len(train_ds), len(val_ds), " Batches/epoch:", len(train_loader), len(val_loader))

# ================================
# 4) Define a tiny GPT (~5M params)
# ================================
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where position t may only attend to positions <= t.

    forward() takes x shaped (B, T, C) with C == d_model and returns the same shape.
    """

    def __init__(self, d_model=256, n_heads=4, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        # A single fused projection yields queries, keys and values side by side.
        self.qkv   = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj  = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop  = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # Lower-triangular 0/1 mask, built lazily and rebuilt when a longer input arrives.
        self.register_buffer("mask", None, persistent=False)

    def forward(self, x):
        batch, steps, width = x.size()
        if self.mask is None or self.mask.size(-1) < steps:
            self.mask = torch.tril(torch.ones(steps, steps, device=x.device)).view(1, 1, steps, steps)

        # (B, T, C) -> three tensors of shape (B, H, T, Dh)
        q, k, v = (
            part.view(batch, steps, self.n_heads, self.d_head).transpose(1, 2)
            for part in self.qkv(x).split(width, dim=2)
        )

        scores = q @ k.transpose(-2, -1) / (self.d_head ** 0.5)
        visible = self.mask[:, :, :steps, :steps]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, steps, width)
        return self.resid_drop(self.proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward sublayer: Linear(d_model->d_ff), GELU, Linear(d_ff->d_model), Dropout."""

    def __init__(self, d_model=256, d_ff=1024, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        projected = self.fc2(hidden)
        return self.drop(projected)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention sublayer, then residual MLP sublayer."""

    def __init__(self, d_model=256, n_heads=4, d_ff=1024, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp  = MLP(d_model, d_ff, dropout)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class TinyGPT(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token embedding + learned position embedding -> dropout ->
    ``n_layers`` Blocks -> final LayerNorm -> output projection. The output
    projection shares its weight with the token embedding.

    forward(idx) maps a (B, T) tensor of token ids, with T <= max_seq_len, to
    (B, T, vocab_size) logits.
    """
    def __init__(self, vocab_size, d_model=256, n_heads=4, d_ff=1024, n_layers=5, max_seq_len=128, dropout=0.1, pad_id=0):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_embed   = nn.Embedding(max_seq_len, d_model)
        self.drop  = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.ln_f  = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # weight tying: the output projection reuses the token-embedding matrix
        self.lm_head.weight = self.token_embed.weight
        self.apply(self._init)

    @staticmethod
    def _init(m):
        """Init Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        # NOTE(review): this also overwrites the padding_idx row that nn.Embedding zeroes at
        # construction — confirm that is intended.
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx):
        B, T = idx.size()
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            # Fail early with a readable error. An out-of-range position index would otherwise
            # surface as an IndexError on CPU or an opaque device-side assert on CUDA.
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {max_len}; crop the context first"
            )
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.token_embed(idx) + self.pos_embed(pos)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        return self.lm_head(x)

model = TinyGPT(
    vocab_size=vocab_size,
    d_model=256,
    n_heads=4,
    d_ff=1024,
    n_layers=5,
    max_seq_len=block_size,
    dropout=0.1,
    pad_id=pad_id,
).to(device)

# lm_head shares token_embed's weight (weight tying), so that matrix is one Parameter.
n_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {n_params/1e6:.2f}M")

# ================================
# 5) Train (AdamW, warmup + cosine, AMP new API, grad clip)
# ================================
def get_lr(step, warmup, max_steps, base_lr, min_ratio=0.1):
    """Learning rate for `step`: linear warmup, then cosine decay to a floor.

    - Steps ``[0, warmup)``: ramps linearly from 0 toward ``base_lr``.
    - Steps ``[warmup, max_steps]``: cosine-decays from ``base_lr`` to
      ``min_ratio * base_lr``.
    - Beyond ``max_steps``: held at the floor. Previously the cosine kept
      oscillating, so the lr climbed back toward ``base_lr``.

    ``min_ratio`` defaults to 0.1, the floor that was previously hard-coded.
    """
    if step < warmup:
        return base_lr * step / max(1, warmup)
    progress = min(1.0, (step - warmup) / max(1, (max_steps - warmup)))
    return min_ratio * base_lr + (1 - min_ratio) * base_lr * 0.5 * (1 + math.cos(math.pi * progress))

# (Optional) better weight decay: don't decay LayerNorm/bias
# A parameter skips decay when its name ends in "bias" or mentions "ln"/"layernorm".
decay, no_decay = [], []
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue
    lowered = n.lower()
    exempt = n.endswith("bias") or "ln" in lowered or "layernorm" in lowered
    (no_decay if exempt else decay).append(p)

param_groups = [
    {"params": decay,    "weight_decay": 0.1},
    {"params": no_decay, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=base_lr, betas=(0.9, 0.95))

# Loss scaling for fp16 autocast; constructed disabled when not on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

def run_eval():
    """Mean per-batch validation cross-entropy over `val_loader`.

    Returns NaN when the loader yields no batches. Runs without gradients
    under CUDA autocast when available, and always leaves `model` in train
    mode afterwards.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            with torch.amp.autocast("cuda", enabled=(device=="cuda")):
                logits = model(inputs)
                loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1), ignore_index=pad_id)
            total += loss.item()
            count += 1
    model.train()
    return total / count if count else float("nan")

model.train()
global_step = 0
best_val = float("inf")
t0 = time.time()

# Loop over epochs "forever"; training ends via the max_steps breaks below.
for epoch in range(999999):
    for xb, yb in train_loader:
        global_step += 1
        lr = get_lr(global_step, warmup_steps, max_steps, base_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=(device=="cuda")):
            logits = model(xb)
            loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1), ignore_index=pad_id)

        scaler.scale(loss).backward()
        # backward() produced *scaled* gradients. Unscale them in place first so the
        # 1.0 clipping threshold applies to the true gradient norm; clipping the scaled
        # grads effectively clipped every step far too hard. scaler.step() knows this
        # optimizer was already unscaled. unscale_ is a no-op when the scaler is disabled.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        if global_step % 50 == 0:
            print(f"step {global_step}/{max_steps}  lr {lr:.2e}  loss {loss.item():.3f}")

        if global_step % eval_every == 0:
            val_loss = run_eval()
            dt = time.time() - t0
            print(f"[eval @ step {global_step}] val_loss {val_loss:.3f}  ({dt:.1f}s)")
            t0 = time.time()
            # Keep only the checkpoint with the lowest validation loss so far.
            if val_loss < best_val:
                best_val = val_loss
                torch.save({"model": model.state_dict(),
                            "config": {"vocab_size": vocab_size, "block_size": block_size}},
                           "/content/tinygpt_best.pt")
                print("✓ saved /content/tinygpt_best.pt")

        if global_step >= max_steps:
            break
    if global_step >= max_steps:
        break

# ================================
# 6) Sampling (nucleus + temperature) with role tags
# ================================
@torch.no_grad()
def sample(
    model, sp, prompt,
    max_new_tokens=120,
    temperature=0.85,
    top_p=0.92,
    min_tokens_before_stop=20,
    repetition_penalty=1.12,
    penalty_ctx=80,
    include_prompt=False
):
    """Generate one "Person B" reply to `prompt` with temperature + nucleus sampling.

    Uses the module-level `device`, `block_size` and `eos_id`. Puts `model` in
    eval mode and does not restore the previous mode.

    Args:
        model: the TinyGPT language model.
        sp: SentencePiece processor used to encode the seed and decode the reply.
        prompt: the "Person A" utterance to answer.
        max_new_tokens: upper bound on the number of sampled tokens.
        temperature: logits are divided by this before sampling.
        top_p: nucleus threshold. Tokens are kept in descending probability while
            the cumulative probability stays <= top_p; the most likely token is
            always kept.
        min_tokens_before_stop: generation only stops early on an end/turn tag
            after at least this many new tokens.
        repetition_penalty: factor (>1 discourages) applied to every token seen in
            the last `penalty_ctx` positions; a falsy value disables it.
        penalty_ctx: size of that window, in tokens.
        include_prompt: if True, return "Person A: <prompt>\\nPerson B: <reply>".

    Returns:
        The decoded reply, cut at the earliest end/turn tag and lightly detokenized.
    """
    model.eval()

    # seed as a dialogue turn, formatted like the training corpus
    seed = "<bos>\nPerson A: " + prompt.strip() + "\nPerson B: "
    x = torch.tensor(sp.encode(seed, out_type=int), dtype=torch.long, device=device)[None, ...]
    start_len = x.size(1)

    # The tags are user_defined_symbols, so each one is a single vocabulary piece.
    # The corpus ends dialogues with the textual "<eos>" piece, whose id differs from
    # SentencePiece's built-in eos (eos_id); previously only eos_id was checked, so
    # the real end-of-dialogue token never stopped generation.
    stop_ids = {eos_id, sp.piece_to_id("<eos>"), sp.piece_to_id("<bos>"),
                sp.piece_to_id("Person A:"), sp.piece_to_id("Person B:")}

    for _ in range(max_new_tokens):
        # Feed the model only the most recent block_size tokens (its positional range).
        # Keep the full sequence in `x` so `start_len` still marks where the reply
        # begins; cropping `x` itself used to silently drop the reply's first tokens.
        x_ctx = x[:, -block_size:]

        with torch.amp.autocast("cuda", enabled=(device == "cuda")):
            logits = model(x_ctx)[:, -1, :]

        # temperature (in float32 so the sampling arithmetic is not done in fp16)
        logits = logits.float() / max(1e-8, temperature)

        # Repetition penalty over recent context. Shrink positive logits and push negative
        # ones further down, so a seen token always becomes *less* likely. Plain division
        # raised negative logits, boosting exactly the tokens it meant to discourage.
        if repetition_penalty and penalty_ctx > 0:
            recent = torch.unique(x[0, -penalty_ctx:])
            seen = logits[0, recent]
            logits[0, recent] = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)

        probs = torch.softmax(logits, dim=-1)

        # nucleus (top-p) filtering
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        mask = cum > top_p
        mask[..., 0] = False
        sorted_probs[mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

        next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))  # (1,1)
        token = next_id.item()

        x = torch.cat([x, next_id], dim=1)

        # stop on an end/turn tag, but only after a minimum reply length
        gen_len = x.size(1) - start_len
        if gen_len >= min_tokens_before_stop and token in stop_ids:
            break

    # decode only what was generated for Person B
    text = sp.decode(x[0, start_len:].tolist())

    # Trim at the EARLIEST tag occurrence. The old loop stopped at the first tag in list
    # order, which could leave an earlier "Person A:" in the reply.
    cut_points = [i for i in (text.find(s) for s in ("<eos>", "<bos>", "Person A:", "Person B:")) if i != -1]
    if cut_points:
        text = text[:min(cut_points)]

    # light detok cleanup for readability (DailyDialog has spaced punctuation)
    text = (
        text.replace(" ,", ",")
            .replace(" .", ".")
            .replace(" !", "!")
            .replace(" ?", "?")
            .replace(" ’ ", "’")
            .replace(" ' s", "’s")
            .replace(" ' m", "’m")
            .replace(" ' ve", "’ve")
            .replace(" ' re", "’re")
    )
    reply = text.strip()
    if include_prompt:
        return f"Person A: {prompt.strip()}\nPerson B: {reply}"
    return reply



# ================================
# 7) Load best (if any) and chat
# ================================
ckpt_path = "/content/tinygpt_best.pt"
have_ckpt = os.path.exists(ckpt_path)
if have_ckpt:
    # Restore the weights that achieved the best validation loss during training.
    sd = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(sd["model"])
    print("Loaded best checkpoint.")

# Two demo replies, separated by a "----" line.
for demo_idx, demo_prompt in enumerate(("Hi, how are you?", "What's your favorite movie?")):
    if demo_idx > 0:
        print("----")
    print(sample(model, sp, demo_prompt, max_new_tokens=80, include_prompt=True))




Device: cuda
DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 13118
    })
})
Columns (train): ['conversation']
Preview:
 <bos>
Person A: We can go to the cinema or say at home watching TV , what's it to be ?
Person B: As far as I'm concerned , staying at home is more comfortable than going to the movies .
Person A: Thanks , dear . I feel so tired after a whole day's work .
<eos>
-rw-r--r-- 1 root root 6.5M Sep 13 16:13 /content/dailydialog_corpus.txt
<bos>
Person A: We can go to the cinema or say at home watching TV , what's it to be ?
Person B: As far as I'm concerned , staying at home is more comfortable than going to the movies .
Person A: Thanks , dear . I feel so tired after a whole day's work .
<eos>
Vocab size: 4096  (pad,unk,bos,eos) = (0, 1, 2, 3)
<bos> -> 4
<eos> -> 5
Person A: -> 6
Person B: -> 7
Encoded lengths: 1824858 197675
Batches: 14256 1544
Total parameters: 5.03M
step 50/3000  lr 5.00e-05  loss 7.668
step 100/3000  lr 1.00e-04

In [None]:
# ============================
# History-aware console chat (reuses the TinyGPT class defined in the training cell above)
# ============================
import os, torch, sentencepiece as spm
import torch.nn as nn
from torch.nn import functional as F

# Artifacts written by the training cell.
sp_model_path = "/content/spm_dd_4k.model"
ckpt_path     = "/content/tinygpt_best.pt"
assert os.path.exists(sp_model_path), f"Missing tokenizer at {sp_model_path}"
assert os.path.exists(ckpt_path), f"Missing checkpoint at {ckpt_path}"

# ---- Load tokenizer ----
sp = spm.SentencePieceProcessor(model_file=sp_model_path)
# Must match the ids the tokenizer was trained with.
pad_id, unk_id, bos_id, eos_id = 0, 1, 2, 3


# ---- Load checkpoint & build model ----
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
ckpt = torch.load(ckpt_path, map_location=device)
cfg  = ckpt.get("config", {})
block_size = cfg.get("block_size", 128)
vocab_size = cfg.get("vocab_size", sp.vocab_size())

# Architecture hyperparameters are not stored in the checkpoint; they must match training.
model = TinyGPT(vocab_size=vocab_size, d_model=256, n_heads=4, d_ff=1024,
                n_layers=5, max_seq_len=block_size, dropout=0.1, pad_id=0).to(device)
model.load_state_dict(ckpt["model"])
model.eval()
print("Loaded model & tokenizer. Context len =", block_size)

# ---- Helpers ----
# (spaced form, compact form) pairs, applied in order. This list matches the cleanup
# inlined in generate_from_seed, which also handles 'd, 'll and n't.
_DETOK_REPLACEMENTS = (
    (" ,", ","), (" .", "."), (" !", "!"), (" ?", "?"), (" ’ ", "’"),
    (" ' s", "’s"), (" ' m", "’m"), (" ' ve", "’ve"), (" ' re", "’re"),
    (" ' d", "’d"), (" ' ll", "’ll"), (" n't", "n’t"),
)


def detok_cleanup(txt: str) -> str:
    """Undo DailyDialog-style spaced punctuation/contractions and trim surrounding whitespace."""
    for spaced, compact in _DETOK_REPLACEMENTS:
        txt = txt.replace(spaced, compact)
    return txt.strip()

# Build seed from recent history (keeps context within block_size)
MAX_CTX_TOKENS = block_size
RESERVED_GEN_TOKENS = 64  # leave room for the new reply

def build_seed_from_history(history, user_msg):
    """Format prior turns plus the new user message as a generation seed.

    Args:
        history: sequence of (user_text, bot_text) pairs, oldest first. A falsy
            side is skipped. The sequence is NOT modified; earlier versions
            popped turns off the caller's list.
        user_msg: the new "Person A" message.

    Returns:
        "<bos>\\nPerson A: ...\\nPerson B: ...\\nPerson A: <user_msg>\\nPerson B: ".
        The oldest turns are dropped until the encoded seed fits in
        MAX_CTX_TOKENS - RESERVED_GEN_TOKENS tokens. If it still does not fit
        with no history left, the over-long seed is returned as is.
    """
    def convo(turns, last_user):
        lines = []
        for u, b in turns:
            if u: lines.append(f"Person A: {u.strip()}")
            if b: lines.append(f"Person B: {b.strip()}")
        lines.append(f"Person A: {last_user.strip()}")
        return "<bos>\n" + "\n".join(lines) + "\nPerson B: "

    turns = list(history)  # local copy: never mutate the caller's history
    budget = MAX_CTX_TOKENS - RESERVED_GEN_TOKENS
    seed = convo(turns, user_msg)
    while turns and len(sp.encode(seed, out_type=int)) > budget:
        turns.pop(0)  # drop oldest turn
        seed = convo(turns, user_msg)
    return seed


@torch.no_grad()
def generate_from_seed(seed,
                       temperature=0.80, top_p=0.95,
                       max_new_tokens=120, min_tokens_before_stop=24,
                       repetition_penalty=1.15, penalty_ctx=80):
    """Sample a "Person B" reply that continues `seed` and return it as cleaned text.

    Uses the module-level `model`, `sp`, `device`, `block_size` and `eos_id`.

    Args:
        seed: prompt text, normally from build_seed_from_history (ends with "Person B: ").
        temperature: logits are divided by this before sampling.
        top_p: nucleus threshold. Tokens are kept in descending probability while
            the cumulative probability stays <= top_p; the top token is always kept.
        max_new_tokens: upper bound on the number of sampled tokens.
        min_tokens_before_stop: end/turn tags cannot be sampled until this many
            tokens have been generated.
        repetition_penalty: factor (>1 discourages) applied to tokens seen in the
            last `penalty_ctx` positions; a falsy value disables it.
        penalty_ctx: size of that window, in tokens.

    Returns:
        The generated text up to (not including) the first end/turn tag, with
        light detokenization applied.
    """
    STOP_STRINGS = ("<eos>", "<bos>", "Person A:", "Person B:")
    # The tags were trained as user_defined_symbols, so each is one vocabulary piece.
    # Look their ids up directly. sp.encode() runs normal text processing (it may e.g.
    # add a leading "▁" dummy-prefix piece), so its output for a tag is not guaranteed
    # to be the bare tag id.
    stop_ids = {sp.piece_to_id(s) for s in STOP_STRINGS}
    stop_ids.add(eos_id)  # SentencePiece's built-in eos, distinct from the textual "<eos>"
    stop_idx = torch.tensor(sorted(stop_ids), dtype=torch.long, device=device)

    x = torch.tensor(sp.encode(seed, out_type=int), dtype=torch.long, device=device)[None, ...]
    start_len = x.size(1)

    for _ in range(max_new_tokens):
        gen_len = x.size(1) - start_len

        # Only the last block_size tokens fit the positional embedding, so crop the model
        # input. Keep the full sequence in `x` so `start_len` still marks the reply start.
        # Cropping `x` itself used to drop the first reply tokens once seed + reply
        # exceeded block_size, and froze gen_len.
        x_ctx = x[:, -block_size:]
        with torch.amp.autocast("cuda", enabled=(device == "cuda")):
            logits = model(x_ctx)[:, -1, :]

        # temperature (in float32 so the sampling arithmetic is not done in fp16)
        logits = logits.float() / max(1e-8, temperature)

        # Repetition penalty on recent context. Shrink positive logits and push negative
        # ones further down, so a seen token always becomes *less* likely. Plain division
        # raised negative logits instead.
        if repetition_penalty and penalty_ctx > 0:
            recent = torch.unique(x[0, -penalty_ctx:])
            seen = logits[0, recent]
            logits[0, recent] = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)

        # forbid ending the reply before the minimum length
        if gen_len < min_tokens_before_stop:
            logits[0, stop_idx] = float("-inf")

        # nucleus (top-p)
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        mask = cum > top_p
        mask[..., 0] = False
        sorted_probs[mask] = 0
        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)

        next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
        x = torch.cat([x, next_id], dim=1)

        # early stop once an end/turn tag is generated after the minimum length
        if gen_len + 1 >= min_tokens_before_stop and next_id.item() in stop_ids:
            break

    # Trim at the first stop tag (token level)...
    gen_ids = x[0, start_len:].tolist()
    cut_at = next((i for i, t in enumerate(gen_ids) if t in stop_ids), len(gen_ids))
    txt = sp.decode(gen_ids[:cut_at])
    # ...and at the string level, in case a tag was spelled out from smaller pieces.
    hits = [i for i in (txt.find(s) for s in STOP_STRINGS) if i != -1]
    if hits:
        txt = txt[:min(hits)]

    # light detokenization for DailyDialog-style spaced punctuation/contractions
    txt = (txt.replace(" ,", ",").replace(" .", ".").replace(" !", "!")
              .replace(" ?", "?").replace(" ’ ", "’")
              .replace(" ' s","’s").replace(" ' m","’m")
              .replace(" ' ve","’ve").replace(" ' re","’re")
              .replace(" ' d","’d").replace(" ' ll","’ll").replace(" n't","n’t"))
    return txt.strip()


# ---- Console chat loop (single reply per turn, history-aware) ----
print("\nChat ready. Type 'reset' to clear history, or 'exit' to quit.\n")
history = []  # (user_text, bot_text) pairs, oldest first
while True:
    try:
        user = input("You: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupted notebook input box: leave the loop cleanly
        # instead of ending the cell with a traceback.
        print()
        break
    if not user:
        continue  # nothing typed; don't ask the model to answer an empty message
    command = user.lower()
    if command in {"exit", "quit"}:
        break
    if command == "reset":
        history.clear()
        print("Bot: (history cleared)\n")
        continue
    seed = build_seed_from_history(history.copy(), user)
    bot = generate_from_seed(seed)
    print(f"Bot: {bot}\n")
    history.append((user, bot))



Loaded model & tokenizer. Context len = 128

Chat ready. Type 'reset' to clear history, or 'exit' to quit.

You: hello there
Bot: I'd like to see. How are you going? Does it start? It seems that's really nice of dollars a good job.

You: nice
Bot: Wow, they're very spicy people who think we can cut it. The pildcre not be careful.



KeyboardInterrupt: Interrupted by user