<a href="https://colab.research.google.com/github/GaborVxxx/tinygpt/blob/main/TinyGTP3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch, subprocess, os
print("CUDA available:", torch.cuda.is_available())

CUDA available: False


In [8]:
# ================================
# 3) Build mixed chat corpus → <A>/<B>/<eot> (working sources only, fixed PersonaChat)
# ================================
import re, random
from datasets import load_dataset, Dataset, DatasetDict

random.seed(1337)

# ---------------- Formatting helpers ----------------
ROLE_PREFIX_RE = re.compile(r'^(?:<A>|<B>|Person\s+[AB]:|[AB]:)\s*', flags=re.IGNORECASE)
_CONTRACTION_RE = re.compile(r"\s+'\s+(m|ve|re|ll|d|s)\b", flags=re.IGNORECASE)
_ACRONYM_RE     = re.compile(r"\b([A-Za-z])\s*\.\s*([A-Za-z])\b")
_PERSONA_LINE_RE = re.compile(r"^\s*(your persona:|partner'?s persona:)\s*", flags=re.IGNORECASE)

def strip_role_prefixes(s: str) -> str:
    return ROLE_PREFIX_RE.sub('', s.strip())

def clean_sentence(s: str) -> str:
    s = s.strip()
    # punctuation spacing
    s = re.sub(r"\s+([,!.?])", r"\1", s)
    # fix "I ' m" -> "I'm", etc.
    s = _CONTRACTION_RE.sub(r"'\1", s)
    # collapse initials "I . D ." -> "I.D."
    s = _ACRONYM_RE.sub(r"\1.\2", s)
    s = _ACRONYM_RE.sub(r"\1.\2", s)
    return s

def sanitize_turns(seq):
    """Drop persona header lines and dict-like artifacts."""
    out = []
    for x in seq:
        if not isinstance(x, str):
            continue
        x = x.strip()
        if not x:
            continue
        if _PERSONA_LINE_RE.match(x):
            continue
        # drop obvious dict/list dumps
        if x.startswith("{") or x.startswith("[") or "'candidates':" in x or '"candidates":' in x:
            continue
        out.append(x)
    return out

def to_ab_eot(lines):
    """List[str] -> <bos>\n<A> ... <eot>\n<B> ... <eot>\n<eos> with A/B alternation."""
    out = []
    for i, s in enumerate(lines):
        role = "<A>" if (i % 2 == 0) else "<B>"
        utt = clean_sentence(strip_role_prefixes(str(s)))
        if not utt:
            continue
        out.append(f"{role} {utt} <eot>")
    if not out:
        return None
    return "<bos>\n" + "\n".join(out) + "\n<eos>"

def row_to_dialog_list(row):
    """Generic extractor for list-like fields (used for DailyDialog/BST)."""
    for key in ("dialog","dialogs","utterances","texts","conversation","conversations","data","history","Conversation"):
        if key in row:
            val = row[key]; break
    else:
        k0 = next(iter(row.keys())); val = row[k0]
    if isinstance(val, list):
        if val and isinstance(val[0], str):
            return val
        if val and isinstance(val[0], dict):
            for kk in ("utterance","text","content","message","msg"):
                if kk in val[0]:
                    return [str(u.get(kk, "")) for u in val]
            return [str(u) for u in val]
    return [str(val)]

# ---------------- Colab-friendly caps ----------------
MAX_TRAIN_PER_DATASET = 25000
MAX_VAL_PER_DATASET   = 3000
MIN_TURNS = 2          # allow 2-turn dialogs
MAX_TURNS = 20

def format_dialog_list(seq):
    seq = sanitize_turns([s for s in seq if isinstance(s, str)])
    if len(seq) < MIN_TURNS:
        return None
    seq = seq[:MAX_TURNS]
    return to_ab_eot(seq)

def ensure_tv(ds: Dataset | DatasetDict) -> dict:
    """Normalize to {'train': Dataset, 'validation': Dataset}."""
    if isinstance(ds, DatasetDict):
        if "train" in ds:
            train = ds["train"]
            valid = ds.get("validation") or ds.get("valid") or ds.get("test")
            if valid is None:
                split = train.train_test_split(test_size=0.1, seed=1337)
                tv = {"train": split["train"], "validation": split["test"]}
            else:
                tv = {"train": train, "validation": valid}
        else:
            keys = list(ds.keys())
            if len(keys) == 1:
                split = ds[keys[0]].train_test_split(test_size=0.1, seed=1337)
                tv = {"train": split["train"], "validation": split["test"]}
            else:
                tv = {"train": ds[keys[0]], "validation": ds[keys[1]]}
    else:
        split = ds.train_test_split(test_size=0.1, seed=1337)
        tv = {"train": split["train"], "validation": split["test"]}
    print(f"  -> splits: train={len(tv['train'])}, val={len(tv['validation'])}")
    return tv

# ---------------- Loaders (working mirrors only; no remote code) ----------------
def load_dailydialog():
    print("Loading DailyDialog (elricwan/dailydialog)...")
    raw = load_dataset("elricwan/dailydialog")
    tv = ensure_tv(raw)
    print("DailyDialog OK.")
    return tv, None  # None => not pre-extracted

def load_blended_skill_talk():
    print("Loading BlendedSkillTalk (blended_skill_talk)...")
    raw = load_dataset("blended_skill_talk")
    def split(ds):
        dialogs = []
        for r in ds:
            if isinstance(r.get("dialog"), list) and len(r["dialog"]) >= 2:
                seq = []
                for t in r["dialog"]:
                    seq.append(t.get("text", "") if isinstance(t, dict) else str(t))
                if len([s for s in seq if s.strip()]) >= 2:
                    dialogs.append(seq); continue
            # fallback
            msgs = []
            if isinstance(r.get("previous_utterance"), str) and r["previous_utterance"].strip():
                msgs.append(r["previous_utterance"])
            for key in ("free_messages", "guided_messages"):
                val = r.get(key)
                if isinstance(val, list):
                    for s in val:
                        if isinstance(s, dict): s = s.get("text", "")
                        if isinstance(s, str) and s.strip(): msgs.append(s)
            if len(msgs) >= 2:
                dialogs.append(msgs)
        return dialogs
    train = split(raw["train"])
    val   = split(raw["validation"])
    print(f"BlendedSkillTalk OK: extracted train={len(train)}, val={len(val)}")
    return {"train": train, "validation": val}, "pre-extracted"

def load_persona_parquet():
    print("Loading PersonaChat parquet mirrors...")
    repos = [
        ("AlekseyKorshuk/persona-chat", 25000),
        ("Cynaptics/persona-chat",     25000),
    ]
    collected = {"train": [], "validation": []}

    def extract_seq_from_row(r):
        """
        Persona-style rows:
          - flat: history (list[str]) + label/response (optional)
          - nested: utterances: [{history: [...], response/label: str, candidates: [...]}, ...]
          - or a single text field with __eou__/__eot__ markers
        """
        # case 1: flattened
        hist = r.get("history")
        if isinstance(hist, list) and len(hist) >= 1:
            seq = [str(x) for x in hist]
            reply = r.get("response") or r.get("label")
            if isinstance(reply, str) and reply.strip():
                seq.append(reply)
            return [sanitize_turns(seq)] if len(seq) >= 2 else None

        # case 2: nested utterances
        uts = r.get("utterances")
        if isinstance(uts, list) and uts:
            outs = []
            for u in uts:
                uh = u.get("history")
                if isinstance(uh, list) and len(uh) >= 1:
                    seq = [str(x) for x in uh]
                    reply = u.get("response") or u.get("label")
                    if isinstance(reply, str) and reply.strip():
                        seq.append(reply)
                    seq = sanitize_turns(seq)
                    if len(seq) >= 2:
                        outs.append(seq)
            return outs or None

        # case 3: monolithic text with markers
        for field in ("text","dialog","dialogs","conversation"):
            if isinstance(r.get(field), str):
                txt = r[field]
                parts = [p.strip() for p in re.split(r"__eou__|__eot__|\n+", txt) if p.strip()]
                parts = sanitize_turns(parts)
                if len(parts) >= 2:
                    return [parts]
        return None

    for repo, cap in repos:
        try:
            raw = load_dataset(repo)
            tv = ensure_tv(raw)
            added_train_before = len(collected["train"])
            added_val_before   = len(collected["validation"])
            for split_name in ("train","validation"):
                ds = tv.get(split_name)
                if ds is None: continue
                seqs = []
                for r in ds:
                    ex = extract_seq_from_row(r)
                    if not ex:
                        continue
                    # ex may be a list of sequences (from nested utterances)
                    if ex and isinstance(ex[0], list):
                        seqs.extend(ex)
                    else:
                        seqs.append(ex)
                # flatten once more just in case
                flat = []
                for item in seqs:
                    if isinstance(item, list) and item and isinstance(item[0], str):
                        flat.append(item)
                    elif isinstance(item, list) and item and isinstance(item[0], list):
                        flat.extend(item)
                random.shuffle(flat)
                flat = flat[:cap]
                collected[split_name].extend(flat)
            print(f"  Persona OK: {repo} (+{len(collected['train']) - added_train_before} train / +{len(collected['validation']) - added_val_before} val)")
        except Exception as e:
            print(f"  Persona '{repo}' failed, skipping: {e}")
    total_tr, total_v = len(collected["train"]), len(collected["validation"])
    if total_tr == 0 and total_v == 0:
        print("Persona parquet mirrors failed; skipping PersonaChat.")
        return None
    print(f"PersonaChatParquet OK: total train={total_tr}, val={total_v}")
    return {"train": collected["train"], "validation": collected["validation"]}, "pre-extracted"

# ---------------- Pull sources ----------------
sources = []

dd  = load_dailydialog()
bst = load_blended_skill_talk()
pc  = load_persona_parquet()

if dd:  sources.append(("DailyDialog", dd[0],  dd[1]))
if bst: sources.append(("BlendedSkillTalk", bst[0], bst[1]))
if pc:  sources.append(("PersonaChatParquet",  pc[0],  pc[1]))

# ---------------- Format, cap, dedup, validate ----------------
def format_split(obj, split_name, kind):
    if kind == "pre-extracted":
        seqs = obj[split_name]
    else:
        ds = obj[split_name]
        seqs = [row_to_dialog_list(r) for r in ds]
    fixed = []
    for x in seqs:
        if not x: continue
        if isinstance(x, list): fixed.append([str(xx) for xx in x])
        else: fixed.append([str(x)])
    out = []
    for seq in fixed:
        s = format_dialog_list(seq)
        if s: out.append(s)
    return out

train_texts_all, val_texts_all = [], []
for name, ds_or_lists, kind in sources:
    print(f"\n--- Formatting {name} ---")
    t = format_split(ds_or_lists, "train", kind)
    v = format_split(ds_or_lists, "validation", kind)
    random.shuffle(t); t = t[:MAX_TRAIN_PER_DATASET]
    random.shuffle(v); v = v[:MAX_VAL_PER_DATASET]
    print(f"{name}: formatted train={len(t)}, val={len(v)}")
    train_texts_all.extend(t)
    val_texts_all.extend(v)

# Dedup
def norm_for_dedup(s: str) -> str:
    return re.sub(r'\s+', ' ', s.strip().lower())
def dedup_keep_first(texts):
    seen = set(); out = []
    for x in texts:
        k = norm_for_dedup(x)
        if k in seen: continue
        seen.add(k); out.append(x)
    return out

train_texts = dedup_keep_first(train_texts_all)
val_texts   = dedup_keep_first(val_texts_all)
random.shuffle(train_texts)
random.shuffle(val_texts)

print("\n=== Mixed corpus sizes (pre-validate) ===")
print(f"train: {len(train_texts)}  |  val: {len(val_texts)}")

# Validate (<bos>/<eos>, strict A/B alternation, per-line <eot>)
LINE_RE = re.compile(r'^(<A>|<B>)\s.+\s<eot>$')
def validate_dialog(text: str):
    errs = []
    if not text.startswith("<bos>\n"): errs.append("missing <bos> at start")
    if not text.rstrip().endswith("<eos>"): errs.append("missing <eos> at end")
    body = [ln for ln in text.splitlines() if ln not in ("<bos>", "<eos>")]
    if not body:
        errs.append("empty dialog body");
        return errs
    prev_speaker = None
    for i, ln in enumerate(body):
        m = LINE_RE.match(ln)
        if not m:
            errs.append(f"line {i} malformed: {ln[:120]}")
            continue
        spk = m.group(1)
        if prev_speaker == spk:
            errs.append(f"consecutive same speaker at lines {i-1},{i}")
        prev_speaker = spk
    return errs

def filter_bad(texts):
    keep = []; bad = 0
    for t in texts:
        e = validate_dialog(t)
        if e: bad += 1
        else: keep.append(t)
    return keep, bad

train_texts, bad_train = filter_bad(train_texts)
val_texts,   bad_val   = filter_bad(val_texts)

print(f"\nValidation — dropped: train {bad_train}, val {bad_val}")
print(f"Final counts — train {len(train_texts)}, val {len(val_texts)}")

if train_texts:
    print("\nPreview (formatted):\n", train_texts[0][:500])




Loading DailyDialog (elricwan/dailydialog)...
  -> splits: train=11806, val=1312
DailyDialog OK.
Loading BlendedSkillTalk (blended_skill_talk)...
BlendedSkillTalk OK: extracted train=4819, val=1009
Loading PersonaChat parquet mirrors...
  -> splits: train=17878, val=1000
  Persona OK: AlekseyKorshuk/persona-chat (+25000 train / +6801 val)
  -> splits: train=18000, val=2000
  Persona OK: Cynaptics/persona-chat (+0 train / +0 val)
PersonaChatParquet OK: total train=25000, val=6801

--- Formatting DailyDialog ---
DailyDialog: formatted train=11806, val=1312

--- Formatting BlendedSkillTalk ---
BlendedSkillTalk: formatted train=4819, val=1009

--- Formatting PersonaChatParquet ---
PersonaChatParquet: formatted train=25000, val=3000

=== Mixed corpus sizes (pre-validate) ===
train: 40935  |  val: 5312

Validation — dropped: train 0, val 0
Final counts — train 40935, val 5312

Preview (formatted):
 <bos>
<A> What a beautiful view, my sweetheart! <eot>
<B> It sure is.The Grand Canyon is truly

In [13]:
# ================================
# 2) Tokenize, encode, train TinyGPT (GPU full / CPU-lite)
#  — expects train_texts / val_texts from Cell 3
# ================================
!pip -q install sentencepiece

import os, math, time, random, re, numpy as np
import torch, torch.nn as nn
from torch.nn import functional as F
import sentencepiece as spm

assert 'train_texts' in globals() and 'val_texts' in globals(), "Run Cell 3 first."

# ---- Device + seeds
torch.manual_seed(1337); random.seed(1337)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

# ---- Hyperparams (GPU full defaults)
block_size   = 128
batch_size   = 64
base_lr      = 3e-4
warmup_steps = 1_000
max_steps    = 20_000
eval_every   = 1_000

# ---- CPU-lite fallback so it actually finishes
if device == "cpu":
    print("CPU detected → enabling CPU-Lite mode")
    block_size   = 64
    batch_size   = 16
    base_lr      = 6e-4
    warmup_steps = 200
    max_steps    = 1_200
    eval_every   = 200
    # cap total tokens to keep things moving
    _TARGET_TOKS_TRAIN = 1_200_000
    _TARGET_TOKS_VAL   =   150_000

# ================================
# Tokenizer (Unigram, 8k, byte fallback)
# ================================
corpus_path = "/content/mixed_dialog_corpus.txt"
with open(corpus_path, "w", encoding="utf-8") as f:
    f.writelines(t + "\n" for t in train_texts)

tok_path = "/content/spm_chat_8k.model"
!rm -f /content/spm_chat_8k.model /content/spm_chat_8k.vocab

spm.SentencePieceTrainer.Train(
    input=corpus_path,
    model_prefix="/content/spm_chat_8k",
    vocab_size=8000,
    model_type="unigram",
    byte_fallback=True,
    normalization_rule_name="nmt_nfkc",
    pad_id=0, unk_id=1, bos_id=2, eos_id=3,
    user_defined_symbols=["<bos>","<eos>","<A>","<B>","<eot>"]
)
sp = spm.SentencePieceProcessor(model_file=tok_path)
vocab_size = sp.vocab_size()
pad_id, unk_id, bos_id, eos_id = 0, 1, 2, 3
print("Vocab size:", vocab_size, {t: sp.piece_to_id(t) for t in ["<bos>","<eos>","<A>","<B>","<eot>"]})

# ================================
# Encode → chunk dataset
# ================================
def encode_texts(texts):
    return np.array([tid for t in texts for tid in sp.encode(t, out_type=int)], dtype=np.int32)

train_ids = encode_texts(train_texts)
val_ids   = encode_texts(val_texts)
print("Encoded tokens:", len(train_ids), len(val_ids))

# CPU-lite: subsample tokens before making datasets
if device == "cpu":
    if len(train_ids) > _TARGET_TOKS_TRAIN + 1:
        train_ids = train_ids[:_TARGET_TOKS_TRAIN + 1]
    if len(val_ids) > _TARGET_TOKS_VAL + 1:
        val_ids = val_ids[:_TARGET_TOKS_VAL + 1]

class LMChunkDataset(torch.utils.data.Dataset):
    def __init__(self, ids, block_size):
        L = (len(ids) - 1) // block_size
        self.input  = torch.tensor(ids[:L*block_size],    dtype=torch.long).view(L, block_size)
        self.target = torch.tensor(ids[1:L*block_size+1], dtype=torch.long).view(L, block_size)
    def __len__(self): return self.input.size(0)
    def __getitem__(self, i): return self.input[i], self.target[i]

train_ds = LMChunkDataset(train_ids, block_size)
val_ds   = LMChunkDataset(val_ids, block_size)

num_workers = 0 if device == "cpu" else 2
pin_memory  = (device == "cuda")
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, drop_last=True,
    num_workers=num_workers, pin_memory=pin_memory, persistent_workers=(num_workers>0)
)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=batch_size, shuffle=False, drop_last=True,
    num_workers=num_workers, pin_memory=pin_memory, persistent_workers=(num_workers>0)
)
print("Batches:", len(train_ds), len(val_ds))

# ================================
# TinyGPT (GPU full vs CPU-Lite sizes)
# ================================
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model=256, n_heads=4, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        self.qkv   = nn.Linear(d_model, 3*d_model, bias=False)
        self.proj  = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop  = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        self.register_buffer("mask", None, persistent=False)
    def forward(self, x):
        B, T, C = x.size()
        if (self.mask is None) or (self.mask.size(-1) < T):
            self.mask = torch.tril(torch.ones(T, T, device=x.device)).view(1,1,T,T)
        qkv = self.qkv(x); q, k, v = qkv.split(C, dim=2)
        def split(t): return t.view(B, T, self.n_heads, self.d_head).transpose(1,2)
        q, k, v = map(split, (q, k, v))
        att = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1,2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y

class MLP(nn.Module):
    def __init__(self, d_model=256, d_ff=1024, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)
    def forward(self, x): return self.drop(self.fc2(self.act(self.fc1(x))))

class Block(nn.Module):
    def __init__(self, d_model=256, n_heads=4, d_ff=1024, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp  = MLP(d_model, d_ff, dropout)
        self.drop_res = nn.Dropout(dropout)
    def forward(self, x):
        x = x + self.drop_res(self.attn(self.ln1(x)))
        x = x + self.drop_res(self.mlp(self.ln2(x)))
        return x

class TinyGPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=4, d_ff=1024, n_layers=5, max_seq_len=128, dropout=0.1, pad_id=0):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_embed   = nn.Embedding(max_seq_len, d_model)
        self.drop  = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.ln_f  = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight
        self.apply(self._init)
    @staticmethod
    def _init(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)
    def forward(self, idx):
        B, T = idx.size()
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.token_embed(idx) + self.pos_embed(pos)
        x = self.drop(x)
        for blk in self.blocks: x = blk(x)
        x = self.ln_f(x)
        return self.lm_head(x)

# choose size by device
if device == "cpu":
    model_cfg = dict(d_model=224, n_heads=4, d_ff=896, n_layers=4, max_seq_len=block_size)  # ~3.8–4.2M params
else:
    model_cfg = dict(d_model=256, n_heads=4, d_ff=1024, n_layers=5, max_seq_len=block_size)
model = TinyGPT(vocab_size=vocab_size, pad_id=pad_id, **model_cfg).to(device)
print(f"Total parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

# ================================
# Train (AdamW, warmup+cosine, AMP, grad clip)
# ================================
def get_lr(step, warmup, max_steps, base_lr):
    if step < warmup: return base_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, (max_steps - warmup))
    return 0.1*base_lr + 0.9*base_lr * 0.5 * (1 + math.cos(math.pi * progress))

decay, no_decay = [], []
for n,p in model.named_parameters():
    if not p.requires_grad: continue
    if n.endswith("bias") or "ln" in n.lower() or "layernorm" in n.lower():
        no_decay.append(p)
    else:
        decay.append(p)
optimizer = torch.optim.AdamW(
    [{"params": decay,    "weight_decay": 0.05},
     {"params": no_decay, "weight_decay": 0.00}],
    lr=base_lr, betas=(0.9, 0.95)
)

amp_dtype = torch.bfloat16 if (device=="cuda" and torch.cuda.is_bf16_supported()) else torch.float16
scaler = torch.amp.GradScaler("cuda", enabled=(device=="cuda"))

@torch.no_grad()
def run_eval():
    model.eval(); losses = []
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device=="cuda")):
            logits = model(xb)
            loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1), ignore_index=pad_id)
        losses.append(loss.item())
    model.train()
    return sum(losses)/len(losses) if losses else float("nan")

ckpt_path = "/content/tinygpt_best.pt"
best_val = float("inf"); global_step = 0; t0 = time.time()
model.train()

for epoch in range(999999):
    for xb, yb in train_loader:
        global_step += 1
        lr = get_lr(global_step, warmup_steps, max_steps, base_lr)
        for pg in optimizer.param_groups: pg["lr"] = lr

        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device=="cuda")):
            logits = model(xb)
            loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1), ignore_index=pad_id)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()

        if global_step % 50 == 0:
            print(f"step {global_step}/{max_steps} lr {lr:.2e} loss {loss.item():.3f}")

        if global_step % eval_every == 0:
            val_loss = run_eval()
            print(f"[eval {global_step}] val_loss {val_loss:.3f}  (+{time.time()-t0:.1f}s)")
            t0 = time.time()
            if val_loss < best_val:
                best_val = val_loss
                torch.save({"model": model.state_dict(),
                            "config": {**model_cfg,
                                       "vocab_size": vocab_size,
                                       "block_size": block_size,
                                       "tokenizer": tok_path}},
                           ckpt_path)
                print("✓ saved", ckpt_path)

        if global_step >= max_steps: break
    if global_step >= max_steps: break

# ================================
# Sampling (top-k + top-p + sign-aware rep penalty)
# ================================
def detok_cleanup(txt: str) -> str:
    return (txt.replace(" ,", ",").replace(" .", ".").replace(" !", "!")
              .replace(" ?", "?").replace(" ' s","'s").replace(" ' m","'m")
              .replace(" ' ve","'ve").replace(" ' re","'re")
              .replace(" ' d","'d").replace(" ' ll","'ll").replace(" n't","n't")).strip()

@torch.no_grad()
def sample(model, sp, prompt,
           max_new_tokens=120, temperature=0.65, top_p=0.90, top_k=50,
           min_tokens_before_stop=24, repetition_penalty=1.15, penalty_ctx=80,
           include_prompt=False):
    model.eval()
    seed = "<bos>\n<A> " + prompt.strip() + " <eot>\n<B> "
    x = torch.tensor(sp.encode(seed, out_type=int), dtype=torch.long, device=device)[None, ...]
    start_len = x.size(1)
    STOP_STRINGS = ["<eot>", "<eos>", "<bos>", "<A>", "<B>"]
    STOP_SEQS = [sp.encode(s, out_type=int) for s in STOP_STRINGS]
    FIRST_TOKENS = {seq[0] for seq in STOP_SEQS if len(seq) > 0}

    def ends_with(seq, suffix):
        L = len(suffix); return L > 0 and len(seq) >= L and seq[-L:] == suffix
    def find_first_stop(gen_ids):
        N = len(gen_ids); cut = N
        for i in range(N):
            for s in STOP_SEQS:
                L = len(s)
                if L and i+L <= N and gen_ids[i:i+L] == s:
                    cut = min(cut, i)
        return cut

    for _ in range(max_new_tokens):
        if x.size(1) > block_size: x = x[:, -block_size:]
        with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device=="cuda")):
            logits = model(x)[:, -1, :]

        # temperature
        logits = logits / max(1e-8, temperature)

        # sign-aware repetition penalty
        if repetition_penalty and penalty_ctx > 0:
            recent = x[0, max(0, x.size(1) - penalty_ctx):].tolist()
            for t in set(recent):
                if logits[0, t] > 0: logits[0, t] /= repetition_penalty
                else:                logits[0, t] *= repetition_penalty

        # block starting stop seq before min length
        gen_len = x.size(1) - start_len
        if gen_len < min_tokens_before_stop:
            for t0 in FIRST_TOKENS: logits[0, t0] = -float("inf")

        probs = torch.softmax(logits, dim=-1)

        # top-k
        if top_k is not None and top_k > 0:
            topk_vals, topk_idx = torch.topk(probs, k=min(top_k, probs.size(-1)))
            mask = torch.ones_like(probs, dtype=torch.bool); mask.scatter_(1, topk_idx, False)
            probs[mask] = 0

        # top-p
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        mask = cum > top_p; mask[..., 0] = False
        sorted_probs[mask] = 0
        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)

        next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
        x = torch.cat([x, next_id], dim=1)

        # early stop
        gen_ids = x[0, start_len:].tolist()
        if gen_len + 1 >= min_tokens_before_stop and any(ends_with(gen_ids, s) for s in STOP_SEQS):
            break

    gen_ids = x[0, start_len:].tolist()
    gen_ids = gen_ids[:find_first_stop(gen_ids)]
    txt = detok_cleanup(sp.decode(gen_ids))
    return f"You: {prompt.strip()}\nBot: {txt}" if include_prompt else txt

# ---- Quick samples (if a checkpoint exists)
ckpt_path = "/content/tinygpt_best.pt"
if os.path.exists(ckpt_path):
    sd = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(sd["model"])
    print("Loaded best checkpoint.")

print(sample(model, sp, "Hi, how are you?", max_new_tokens=80, include_prompt=True))
print("----")
print(sample(model, sp, "What's your favorite movie?", max_new_tokens=80, include_prompt=True))









Device: cpu
CPU detected → enabling CPU-Lite mode
Vocab size: 8000 {'<bos>': 4, '<eos>': 5, '<A>': 6, '<B>': 7, '<eot>': 8}
Encoded tokens: 6418580 909924
Batches: 18750 2343
Total parameters: 4.22M
step 50/1200 lr 1.50e-04 loss 7.332
step 100/1200 lr 3.00e-04 loss 5.277
step 150/1200 lr 4.50e-04 loss 4.395
step 200/1200 lr 6.00e-04 loss 4.245
[eval 200] val_loss 4.316  (+158.3s)
✓ saved /content/tinygpt_best.pt
step 250/1200 lr 5.97e-04 loss 4.297
step 300/1200 lr 5.87e-04 loss 3.873
step 350/1200 lr 5.71e-04 loss 3.868
step 400/1200 lr 5.48e-04 loss 3.712
[eval 400] val_loss 3.924  (+158.9s)
✓ saved /content/tinygpt_best.pt
step 450/1200 lr 5.21e-04 loss 3.884
step 500/1200 lr 4.89e-04 loss 3.308
step 550/1200 lr 4.53e-04 loss 3.262
step 600/1200 lr 4.13e-04 loss 3.505
[eval 600] val_loss 3.777  (+158.0s)
✓ saved /content/tinygpt_best.pt
step 650/1200 lr 3.72e-04 loss 3.861
step 700/1200 lr 3.30e-04 loss 3.595
step 750/1200 lr 2.88e-04 loss 3.844
step 800/1200 lr 2.47e-04 loss 3.539


In [None]:
# ============================
# Chat REPL for <A>/<B>/<eot> (no cut-offs, anti-repeat, best-of)
# ============================
import os, re, torch, sentencepiece as spm
from torch.nn import functional as F

# --- Paths
ckpt_path = next((p for p in [
    "/content/tinygpt_best.pt",
    "/kaggle/working/tinygpt_best.pt",
] if os.path.exists(p)), None)
assert ckpt_path, "Missing checkpoint (e.g. /content/tinygpt_best.pt)."

device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    amp_dtype = torch.bfloat16 if (device=="cuda" and torch.cuda.is_bf16_supported()) else torch.float16
except AttributeError:
    amp_dtype = torch.float16

# --- Load ckpt + tokenizer
ckpt = torch.load(ckpt_path, map_location=device)
cfg  = ckpt.get("config", {})

tok_candidates = [cfg.get("tokenizer"), "/content/spm_chat_8k.model", "/content/spm_dd_4k.model"]
tok_candidates = [p for p in tok_candidates if p]
sp_model_path = next((p for p in tok_candidates if os.path.exists(p)), None)
assert sp_model_path, f"Tokenizer .model not found. Looked for: {tok_candidates}"

sp = spm.SentencePieceProcessor(model_file=sp_model_path)
ID = {k: sp.piece_to_id(k) for k in ("<bos>","<eos>","<A>","<B>","<eot>")}

print("Device:", device)
print("Loaded tokenizer:", sp_model_path)
print("Special IDs:", ID)

# --- Model config
vocab_size = cfg.get("vocab_size", sp.vocab_size())
pad_id     = 0
ctx_len    = cfg.get("block_size", cfg.get("max_seq_len", 128))
model_cfg  = dict(
    d_model = cfg.get("d_model", 256),
    n_heads = cfg.get("n_heads", 4),
    d_ff    = cfg.get("d_ff", 1024),
    n_layers= cfg.get("n_layers", 5),
)

# TinyGPT must be defined in the training cell
try:
    TinyGPT
except NameError as e:
    raise RuntimeError("Run the training cell that defines TinyGPT before this REPL.") from e

model = TinyGPT(
    vocab_size=vocab_size,
    d_model=model_cfg["d_model"],
    n_heads=model_cfg["n_heads"],
    d_ff=model_cfg["d_ff"],
    n_layers=model_cfg["n_layers"],
    max_seq_len=ctx_len,
    dropout=0.1, pad_id=pad_id
).to(device)
model.load_state_dict(ckpt["model"], strict=True)
model.eval()
print(f"Loaded model. Params={sum(p.numel() for p in model.parameters())/1e6:.2f}M | Context len={ctx_len}")

# ---------- Helpers ----------
_CONTRACTIONS = [
    (r"\b(I|you|he|she|it|we|they)\s+'\s*m\b", r"\1'm"),
    (r"\b(I|you|he|she|it|we|they)\s+'\s*re\b", r"\1're"),
    (r"\b(I|you|he|she|it|we|they)\s+'\s*ve\b", r"\1've"),
    (r"\b(should|could|would)\s+'\s*ve\b", r"\1've"),
    (r"\b(\w)\s+'\s*s\b", r"\1's"),
    (r"\b(\w)\s+'\s*d\b", r"\1'd"),
    (r"\b(\w)\s+'\s*ll\b", r"\1'll"),
    (r"\b(n)\s+'\s*t\b", r"\1't"),
]
def detok_cleanup(txt: str) -> str:
    txt = (txt.replace(" ,", ",").replace(" .", ".").replace(" !", "!")
              .replace(" ?", "?").replace("  ", " ").strip())
    for pat, rep in _CONTRACTIONS:
        txt = re.sub(pat, rep, txt, flags=re.IGNORECASE)
    return txt

MAX_CTX_TOKENS = ctx_len
# Reserve half the context for generation, but at least 16 and at most 64 tokens.
RESERVED_GEN_TOKENS = max(16, min(64, ctx_len // 2))

def build_seed_from_history(history, user_msg):
    """history: list[(user, bot)]"""
    def convo(turns, last_user):
        lines = []
        for u, b in turns:
            if u: lines.append(f"<A> {u.strip()} <eot>")
            if b: lines.append(f"<B> {b.strip()} <eot>")
        lines.append(f"<A> {last_user.strip()} <eot>")
        return "<bos>\n" + "\n".join(lines) + "\n<B> "
    seed = convo(history, user_msg)
    ids  = sp.encode(seed, out_type=int)
    # Trim oldest history to keep room for generation
    while len(ids) > (MAX_CTX_TOKENS - RESERVED_GEN_TOKENS) and history:
        history.pop(0)
        seed = convo(history, user_msg)
        ids  = sp.encode(seed, out_type=int)
    return seed

# --- Candidate scoring (for best-of)
def _repetition_score(ids, n=3):
    if len(ids) < n: return 0.0
    seen, reps = set(), 0
    for i in range(len(ids)-n+1):
        ng = tuple(ids[i:i+n])
        if ng in seen: reps += 1
        seen.add(ng)
    return reps / max(1, len(ids)-n+1)

def _quality_score(text, ids, target_len=40):
    uniq = len(set(ids)) / max(1, len(ids))
    rep3 = _repetition_score(ids, n=3)
    punct = 1.0 if any(ch in text for ch in ".!?") else 0.0
    len_pen = abs(len(ids) - target_len) / target_len
    return (1.2*uniq) + (0.2*punct) - (0.8*rep3) - (0.4*len_pen)

# ---------- Generation (fix cut-offs) ----------
def _ends_with(seq, suffix):
    L = len(suffix); return L>0 and len(seq)>=L and seq[-L:] == suffix

def _first_stop_index(gen_ids, stop_seqs):
    N = len(gen_ids); cut = N
    for i in range(N):
        for s in stop_seqs:
            L = len(s)
            if L and i+L <= N and gen_ids[i:i+L] == s:
                cut = min(cut, i)
    return cut

def _block_repeated_ngrams(logits_row, full_ids, n=3, window=256):
    ctx = full_ids[-window:]
    if len(ctx) < n-1: return
    prev = tuple(ctx[-(n-1):])
    forbid = set()
    for i in range(len(ctx)-(n-1)):
        if tuple(ctx[i:i+n-1]) == prev:
            forbid.add(ctx[i+n-1])
    for t in forbid:
        logits_row[t] = -float("inf")

@torch.no_grad()
def generate_from_seed(seed,
                       temperature=0.65, top_p=0.88, top_k=80,
                       max_new_tokens=96, min_tokens_before_stop=18,
                       repetition_penalty=1.28, penalty_ctx=220,
                       no_repeat_ngram=3,
                       # stop only on <eot>/<eos>
                       eot_bias_after=None, eot_bias=1.2,
                       best_of=2):
    # Determine when to start nudging <eot>
    if eot_bias_after is None:
        eot_bias_after = max(28, min(56, ctx_len // 2 + 12))

    STOP_TAGS = ["<eot>", "<eos>"]              # <-- only these cause stop/trim
    STOP_SEQS = [sp.encode(s, out_type=int) for s in STOP_TAGS]
    FIRST_TOKENS = {seq[0] for seq in STOP_SEQS if len(seq) > 0}
    # Always forbid generating control tags in the middle of the reply
    FORBID_TOKENS = [ID["<A>"], ID["<B>"], ID["<bos>"]]

    def sample_once():
        x = torch.tensor(sp.encode(seed, out_type=int), dtype=torch.long, device=device)[None, ...]
        start_len = x.size(1)
        for step in range(max_new_tokens):
            if x.size(1) > ctx_len: x = x[:, -ctx_len:]

            with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device=="cuda")):
                logits = model(x)[:, -1, :]

            # temperature
            logits = logits / max(1e-8, temperature)

            # repetition penalty
            if repetition_penalty and penalty_ctx > 0:
                recent = x[0, max(0, x.size(1)-penalty_ctx):].tolist()
                for t in set(recent):
                    if logits[0, t] > 0: logits[0, t] /= repetition_penalty
                    else:                logits[0, t] *= repetition_penalty

            # n-gram blocking
            if no_repeat_ngram and no_repeat_ngram >= 2:
                _block_repeated_ngrams(logits[0], x[0].tolist(), n=no_repeat_ngram, window=ctx_len)

            # gently encourage finishing after some length
            gen_len = x.size(1) - start_len
            if gen_len >= eot_bias_after:
                logits[0, ID["<eot>"]] += eot_bias

            # NEVER allow control tags mid-reply
            for t in FORBID_TOKENS:
                logits[0, t] = -float("inf")

            # before min length, also forbid starting any stop sequence
            if gen_len < min_tokens_before_stop:
                for t0 in FIRST_TOKENS:
                    logits[0, t0] = -float("inf")

            # probs
            probs = torch.softmax(logits, dim=-1)
            # top-k
            if top_k and top_k > 0:
                topk_vals, topk_idx = torch.topk(probs, k=min(top_k, probs.size(-1)))
                mask = torch.ones_like(probs, dtype=torch.bool); mask.scatter_(1, topk_idx, False)
                probs[mask] = 0
            # top-p
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cum = torch.cumsum(sorted_probs, dim=-1)
            mask = cum > top_p; mask[..., 0] = False
            sorted_probs[mask] = 0
            sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)

            next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
            x = torch.cat([x, next_id], dim=1)

            # early stop after min length if tail matches stop
            gen_ids = x[0, start_len:].tolist()
            if gen_len + 1 >= min_tokens_before_stop:
                if any(_ends_with(gen_ids, s) for s in STOP_SEQS):
                    break

        # trim only on <eot>/<eos>
        gen_ids = x[0, start_len:].tolist()
        cut_at = _first_stop_index(gen_ids, STOP_SEQS)
        gen_ids = gen_ids[:cut_at]
        text = sp.decode(gen_ids)
        return detok_cleanup(text), gen_ids

    if best_of <= 1:
        text, _ids = sample_once()
        return text
    cands = [sample_once() for _ in range(best_of)]
    scored = [( _quality_score(t, ids), t) for (t, ids) in cands]
    scored.sort(key=lambda x: x[0], reverse=True)
    return scored[0][1]

# ---- REPL ----
print("\nChat ready. Type 'reset' to clear history, 'exit' to quit.")
print("Runtime controls: /temp 0.65   /topp 0.88   /topk 80   /norep 3   /eot 1.2 40   /bestof 2\n")

history = []
params = dict(temperature=0.65, top_p=0.88, top_k=80,
              repetition_penalty=1.28, penalty_ctx=220,
              no_repeat_ngram=3, eot_bias=1.2, eot_bias_after=None,
              best_of=2)

while True:
    try:
        user = input("You: ").strip()
    except EOFError:
        break
    if not user: continue
    low = user.lower()
    if low in {"exit","quit"}: break
    if low=="reset":
        history.clear(); print("Bot: (history cleared)\n"); continue

    # live knobs
    try:
        if low.startswith("/temp "):   params["temperature"] = float(user.split()[1]); print("(ok)"); continue
        if low.startswith("/topp "):   params["top_p"] = float(user.split()[1]); print("(ok)"); continue
        if low.startswith("/topk "):   params["top_k"] = int(user.split()[1]); print("(ok)"); continue
        if low.startswith("/norep "):  params["no_repeat_ngram"] = int(user.split()[1]); print("(ok)"); continue
        if low.startswith("/eot "):
            _, b, a = user.split(); params["eot_bias"]=float(b); params["eot_bias_after"]=int(a); print("(ok)"); continue
        if low.startswith("/bestof "): params["best_of"] = max(1, int(user.split()[1])); print("(ok)"); continue
    except Exception:
        print("(bad value)"); continue

    seed = build_seed_from_history(history.copy(), user)
    bot  = generate_from_seed(seed, **params)
    print(f"Bot: {bot}\n")
    history.append((user, bot))





Device: cpu
Loaded tokenizer: /content/spm_chat_8k.model
Special IDs: {'<bos>': 4, '<eos>': 5, '<A>': 6, '<B>': 7, '<eot>': 8}
Loaded model. Params=4.22M | Context len=64

Chat ready. Type 'reset' to clear history, 'exit' to quit.
Runtime controls: /temp 0.65   /topp 0.88   /topk 80   /norep 3   /eot 1.2 40   /bestof 2

You: do you like fish?
Bot: i am doing good, how are you today?? just got back from a lot of the weather.

You: And how about the fish?
Bot: I'm from this evening. Do you like to get on the best now? I need a little time for the last week, but not have been on them and just married.

