In [2]:
# === Imports
from pathlib import Path
import re
from collections import Counter

import numpy as np
import torch


In [3]:
# ==== CONFIG ====
TXT_PATH = "words_250000_train.txt"  # path to provided word list
SEED = 42
TRAIN_FRACTION = 0.975

# Model dims (match what you'll load in guess())
VOCAB_SIZE = 29
DMODEL = 256
N_HEADS = 8
N_LAYERS = 4
D_FF = 512
MAX_LEN = 32  # >= 1 + max word length for Hangman [CLS] input

# Pretrain
PRE_EPOCHS = 100
PRE_BATCH  = 2048
PRE_LR     = 1e-3

# Finetune (Hangman)
FT_EPOCHS  = 100
FT_BATCH   = 2048
FT_LR      = 3e-4
PERMS_PER_WORD = 4  # <- number of random unique-letter permutations per word for data aug

# Device
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ==== IMPORTS ====
import re, random, string, math
from collections import defaultdict
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Seeding
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ==== TOKENIZER ====
class CharTokenizer:
    """
    Vocab ids:
      0 [PAD]  (unused)
      1 [MASK] (Hangman inputs only)
      2 [CLS]  (Hangman inputs only)
      3..28 'a'..'z'
    """
    def __init__(self):
        self.pad_id  = 0
        self.mask_id = 1
        self.cls_id  = 2
        self.letters = [chr(i) for i in range(97,123)]
        self.letter2id = {ch: 3 + (ord(ch)-97) for ch in self.letters}
        self.vocab_size = 29

    def encode_board(self, pattern: str):
        # pattern e.g. "_pp_e" → [MASK, 'p','p',MASK,'e']
        return [self.mask_id if ch == "_" else self.letter2id[ch] for ch in pattern]

    def with_cls(self, ids):  # prepend [CLS]
        return [self.cls_id] + ids

tok = CharTokenizer()


In [4]:
def load_words(path):
    with open(path, "r", encoding="utf-8") as f:
        ws = [w.strip().lower() for w in f]
    return [w for w in ws if re.fullmatch(r"[a-z]+", w)]

all_words = load_words(TXT_PATH)
print(f"#words: {len(all_words)} | min_len={min(map(len,all_words))} | max_len={max(map(len,all_words))}")

# Split by word
idx_all = list(range(len(all_words)))
random.shuffle(idx_all)
cut = int(TRAIN_FRACTION * len(idx_all))
train_idx = set(idx_all[:cut])
val_idx   = set(idx_all[cut:])

def bucket_by_len(index_set):
    by_len = defaultdict(list)
    for i in index_set:
        by_len[len(all_words[i])].append(i)
    return by_len

train_by_len = bucket_by_len(train_idx)
val_by_len   = bucket_by_len(val_idx)

print("train buckets (top 10 by count):",
      dict(list(sorted(((L,len(v)) for L,v in train_by_len.items()), key=lambda x:-x[1]))[:10]))


#words: 227300 | min_len=1 | max_len=29
train buckets (top 10 by count): {9: 30136, 8: 29686, 10: 26291, 7: 25243, 11: 22201, 6: 19079, 12: 17734, 13: 12645, 5: 10993, 14: 8486}


In [5]:
class TinyCharTransformer(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, d_model=DMODEL, n_heads=N_HEADS,
                 n_layers=N_LAYERS, d_ff=D_FF, max_len=MAX_LEN):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(max_len, d_model)
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                         dim_feedforward=d_ff, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls_head = nn.Linear(d_model, 26)  # used per-position in pretrain; [CLS] in finetune
        self.max_len = max_len
        self.d_model = d_model

    def _encode(self, token_ids, attn_mask=None):
        # token_ids [B,T]
        B, T = token_ids.size()
        if T > self.max_len:
            raise ValueError(f"seq len {T} > max_len {self.max_len}")
        pos = torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, T)
        x = self.token_emb(token_ids) + self.pos_emb(pos)
        h = self.encoder(x, mask=attn_mask)  # [B,T,d]
        return h

    # Pretrain: project every position
    def forward_per_pos(self, token_ids, attn_mask):
        h = self._encode(token_ids, attn_mask)
        return self.cls_head(h)  # [B,T,26]

    # Finetune: project [CLS] pooled state
    def forward_cls(self, token_ids):
        h = self._encode(token_ids, attn_mask=None)
        cls = h[:,0,:]
        return self.cls_head(cls)  # [B,26]

def init_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

model = TinyCharTransformer().to(DEVICE)
#init_weights(model)


In [9]:
def letters_to_ids(word):
    # 'a'..'z' → 3..28
    return np.asarray([tok.letter2id[ch] for ch in word], dtype=np.int64)

def causal_mask(T, device):
    # True=masked (future positions)
    return torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)

def causal_bidir_batches_by_len(words_by_len, batch_size=1024, shuffle=True, seed=SEED):
    rng = np.random.default_rng(seed)
    lengths = list(words_by_len.keys())
    if shuffle: rng.shuffle(lengths)
    for L in lengths:
        idxs = words_by_len[L]
        if shuffle: rng.shuffle(idxs)

        # Pre-encode forward and reversed sequences (len>=2 only)
        forward, reverse = [], []
        for i in idxs:
            w = all_words[i]
            if len(w) < 2: continue
            ids = letters_to_ids(w)
            rid = ids[::-1]
            forward.append(ids)
            reverse.append(rid)

        # Interleave forward+reverse to improve mixing
        pairs = list(zip(forward, reverse))
        if shuffle: rng.shuffle(pairs)

        # Build batches
        for s in range(0, len(pairs), batch_size):
            chunk = pairs[s:s+batch_size]
            if not chunk: continue
            fwd = [p[0] for p in chunk]
            rev = [p[1] for p in chunk]

            # Inputs/targets for next-char: ([:-1] → [1:])
            Xf = np.stack([x[:-1] for x in fwd], axis=0)
            Yf = np.stack([x[1:]  for x in fwd], axis=0)
            Xr = np.stack([x[:-1] for x in rev], axis=0)
            Yr = np.stack([x[1:]  for x in rev], axis=0)

            # Concatenate forward + reverse along batch dim
            X = np.concatenate([Xf, Xr], axis=0)
            Y = np.concatenate([Yf, Yr], axis=0)

            yield (
                torch.from_numpy(X),         # [2B, T-1] ids in 3..28
                torch.from_numpy(Y) - 3,     # [2B, T-1] -> 0..25
                L
            )

def pretrain_next_letter_bidir(model, train_by_len, val_by_len,
                               epochs=PRE_EPOCHS, batch_size=PRE_BATCH,
                               lr=PRE_LR, log_every=200):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    def run_epoch(by_len, train=True):
        model.train(train)
        tot_tok, tot_loss, steps = 0, 0.0, 0
        gen = causal_bidir_batches_by_len(by_len, batch_size=batch_size,
                                          shuffle=True, seed=np.random.randint(10**9))
        for X, Y, L in gen:
            X, Y = X.to(DEVICE), Y.to(DEVICE)  # [B, T]
            B, T = X.shape
            attn = causal_mask(T, X.device)
            with torch.set_grad_enabled(train):
                logits = model.forward_per_pos(X, attn)         # [B,T,26]
                loss = F.cross_entropy(logits.reshape(-1,26), Y.reshape(-1))
                if train:
                    opt.zero_grad(set_to_none=True)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    opt.step()
            # bookkeeping
            ntok = Y.numel()
            tot_tok  += ntok
            tot_loss += loss.item() * ntok
            steps += 1
            if log_every and steps % log_every == 0:
                ppl = float(np.exp(tot_loss / max(tot_tok,1)))
                print(f"  [{'train' if train else 'val'}] step {steps:>5} | len={L:<2} | loss/tok={tot_loss/tot_tok:.4f} | ppl={ppl:.2f}")
        avg_loss = tot_loss / max(tot_tok,1)
        return avg_loss, float(np.exp(avg_loss))

    for ep in range(1, epochs+1):
        print(f"\n=== BiDir Next-Char Pretrain {ep}/{epochs} — TRAIN ===")
        tr_loss, tr_ppl = run_epoch(train_by_len, train=True)
        print(f"=== BiDir Next-Char Pretrain {ep}/{epochs} — VAL   ===")
        va_loss, va_ppl = run_epoch(val_by_len,   train=False)
        print(f"[NXT {ep}/{epochs}] loss/tok {tr_loss:.4f}/{va_loss:.4f} | ppl {tr_ppl:.2f}/{va_ppl:.2f}")

    ckpt_pre = {
        "config": {"vocab_size": VOCAB_SIZE, "d_model": DMODEL, "n_heads": N_HEADS,
                   "n_layers": N_LAYERS, "d_ff": D_FF, "max_len": MAX_LEN},
        "state_dict": model.state_dict(),
    }
    torch.save(ckpt_pre, "m6_bidir_pretrained.pt")
    print("Saved → m6_bidir_pretrained.pt")
    return model

# RUN PRETRAIN (forward + reversed)
model = pretrain_next_letter_bidir(model, train_by_len, val_by_len)



=== BiDir Next-Char Pretrain 1/100 — TRAIN ===
=== BiDir Next-Char Pretrain 1/100 — VAL   ===
[NXT 1/100] loss/tok 2.6154/2.6078 | ppl 13.67/13.57

=== BiDir Next-Char Pretrain 2/100 — TRAIN ===
=== BiDir Next-Char Pretrain 2/100 — VAL   ===
[NXT 2/100] loss/tok 2.4709/2.4314 | ppl 11.83/11.37

=== BiDir Next-Char Pretrain 3/100 — TRAIN ===
=== BiDir Next-Char Pretrain 3/100 — VAL   ===
[NXT 3/100] loss/tok 2.4092/2.3256 | ppl 11.12/10.23

=== BiDir Next-Char Pretrain 4/100 — TRAIN ===
=== BiDir Next-Char Pretrain 4/100 — VAL   ===
[NXT 4/100] loss/tok 2.3402/2.3724 | ppl 10.38/10.72

=== BiDir Next-Char Pretrain 5/100 — TRAIN ===
=== BiDir Next-Char Pretrain 5/100 — VAL   ===
[NXT 5/100] loss/tok 2.3038/2.3637 | ppl 10.01/10.63

=== BiDir Next-Char Pretrain 6/100 — TRAIN ===
=== BiDir Next-Char Pretrain 6/100 — VAL   ===
[NXT 6/100] loss/tok 2.3074/2.2452 | ppl 10.05/9.44

=== BiDir Next-Char Pretrain 7/100 — TRAIN ===
=== BiDir Next-Char Pretrain 7/100 — VAL   ===
[NXT 7/100] loss/t

In [6]:
# =========================
# After pretraining: oversample ONLY short words (len 1–5)
# =========================

from collections import defaultdict, Counter
import numpy as np, random

# Keep your existing PERMS_PER_WORD
# PERMS_PER_WORD = 3

# Define the "short" bucket (1..5)
SHORT_BUCKET = (1, 5)
SHORT_BUCKET_MULT = 3   # <- duplicate 3x total (original + 2x extra). Tweak as you like.

def in_short_bucket(L: int) -> bool:
    return SHORT_BUCKET[0] <= L <= SHORT_BUCKET[1]

def summarize_len_bins(indices, label=""):
    cnt = Counter(len(all_words[i]) for i in indices)
    small = sum(v for L,v in cnt.items() if in_short_bucket(L))
    print(f"{label}: total={len(indices)} | short(1-5)={small} | sample lens:",
          dict(list(sorted(cnt.items()))[:10]))

def states_for_word(word, perms_per_word=1):
    L = len(word)
    uniq = sorted(set(word))
    n = len(uniq)
    out = []
    for _ in range(perms_per_word):
        rnd = uniq[:]
        random.shuffle(rnd)
        revealed = set()
        for k in range(n):
            pattern = "".join(ch if ch in revealed else "_" for ch in word)
            remaining = [ch for ch in rnd if ch not in revealed]
            if remaining:
                mask = np.zeros(26, dtype=np.float32)
                for ch in remaining:
                    mask[ord(ch)-97] = 1.0
                ids = tok.with_cls(tok.encode_board(pattern))
                out.append((ids, mask, L))
            revealed.add(rnd[k])
    return out

def build_states_from_indices(indices, perms_per_word):
    by_len = defaultdict(list)
    for i in indices:
        w = all_words[i]
        for ids, mask, L in states_for_word(w, perms_per_word=perms_per_word):
            by_len[L].append((ids, mask))
    return by_len

# 1) Start from your original split
train_idx_list = list(train_idx)

# 2) Split into short vs others
short_idxs  = [i for i in train_idx_list if in_short_bucket(len(all_words[i]))]
other_idxs  = [i for i in train_idx_list if not in_short_bucket(len(all_words[i]))]

summarize_len_bins(train_idx_list, "train (original)")

# 3) Oversample ONLY the short bucket by a multiplicative factor
if short_idxs and SHORT_BUCKET_MULT > 1:
    extra = random.choices(short_idxs, k=(SHORT_BUCKET_MULT - 1) * len(short_idxs))
    train_idx_aug = other_idxs + short_idxs + extra
else:
    train_idx_aug = train_idx_list[:]

random.shuffle(train_idx_aug)
summarize_len_bins(train_idx_aug, "train (short-boosted)")

# 4) Build states: train from the augmented indices; val stays deterministic
print("Building Hangman states…")
train_states = build_states_from_indices(train_idx_aug, perms_per_word=PERMS_PER_WORD)
val_states   = build_states_from_indices(list(val_idx), perms_per_word=1)

print("train states total:", sum(len(v) for v in train_states.values()))
print("val   states total:", sum(len(v) for v in val_states.values()))


train (original): total=221617 | short(1-5)=18583 | sample lens: {1: 17, 2: 259, 3: 2157, 4: 5157, 5: 10993, 6: 19079, 7: 25243, 8: 29686, 9: 30136, 10: 26291}
train (short-boosted): total=258783 | short(1-5)=55749 | sample lens: {1: 49, 2: 772, 3: 6431, 4: 15397, 5: 33100, 6: 19079, 7: 25243, 8: 29686, 9: 30136, 10: 26291}
Building Hangman states…
train states total: 7167888
val   states total: 42194


In [7]:

# 5) Finetune (your existing loop; I just ensure lr uses FT_LR)
def hm_batches_from_states(by_len, batch_size=FT_BATCH, shuffle=True):
    lengths = list(by_len.keys())
    rng = random.Random(SEED)
    if shuffle:
        for L in lengths: rng.shuffle(by_len[L])
        rng.shuffle(lengths)
    for L in lengths:
        bucket = by_len[L]
        for s in range(0, len(bucket), batch_size):
            chunk = bucket[s:s+batch_size]
            xs_np = np.asarray([x for (x,_) in chunk], dtype=np.int64)
            ms_np = np.asarray([m for (_,m) in chunk], dtype=np.float32)
            yield torch.from_numpy(xs_np), torch.from_numpy(ms_np), L


In [8]:
FT_LR = 1e-5

In [9]:
def load_pretrained_model(path="m9.pt", device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(path, map_location="cpu")

    # read config if present, otherwise use your defaults
    cfg = ckpt.get("config", {}) if isinstance(ckpt, dict) else {}
    model = TinyCharTransformer(
        vocab_size=cfg.get("vocab_size", 29),
        d_model=cfg.get("d_model", 256),
        n_heads=cfg.get("n_heads", 8),
        n_layers=cfg.get("n_layers", 4),
        d_ff=cfg.get("d_ff", 512),
        max_len=cfg.get("max_len", 32),
    )

    # handle both wrapped {"state_dict": ...} and raw state_dict checkpoints
    state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"[load note] missing keys: {missing}\n[load note] unexpected keys: {unexpected}")

    model.to(device)
    model.eval()  # set to eval; switch to train() before finetune
    print(f"Loaded pretrained weights from {path} onto {device}.")
    return model, device

# usage:
model, DEVICE = load_pretrained_model("m10.pt")

Loaded pretrained weights from m10.pt onto cuda.


In [10]:

def finetune_hangman(model, train_states, val_states, epochs=FT_EPOCHS,
                     batch_size=FT_BATCH, lr=FT_LR, log_every=200):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    def run_epoch(by_len, train=True):
        model.train(train)
        total, loss_sum, hit1_sum, steps = 0, 0.0, 0.0, 0
        epoch_total = sum(len(v) for v in by_len.values())
        gen = hm_batches_from_states(by_len, batch_size=batch_size, shuffle=True)
        for xs, mask, L in gen:
            xs   = xs.to(DEVICE)
            mask = mask.to(DEVICE)
            with torch.set_grad_enabled(train):
                logits = model.forward_cls(xs)           # [B,26]
                p = torch.softmax(logits, dim=1)         # [B,26]
                mass = (p * mask).sum(dim=1).clamp_min(1e-8)
                loss = -torch.log(mass).mean()
                if train:
                    opt.zero_grad(set_to_none=True)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    opt.step()
            with torch.no_grad():
                pred1 = logits.argmax(dim=1)
                hit = mask[torch.arange(mask.size(0), device=mask.device), pred1].gt(0).float().sum().item()
                bsz = mask.size(0)
                total    += bsz
                loss_sum += loss.item() * bsz
                hit1_sum += hit
                steps += 1
                if log_every and (steps % log_every == 0):
                    print(f"  [{'train' if train else 'val'}] step {steps:>5} | seen {total:>8}/{epoch_total} "
                          f"| len={L:<2} | batch={bsz:<4} | loss={loss_sum/max(total,1):.4f} | Hit@1={hit1_sum/max(total,1):.3f}")
        return loss_sum/max(total,1), hit1_sum/max(total,1)

    for ep in range(1, epochs+1):
        print(f"\n=== Hangman Finetune {ep}/{epochs} — TRAIN ===")
        tr_loss, tr_hit1 = run_epoch(train_states, train=True)
        print(f"=== Hangman Finetune {ep}/{epochs} — VAL   ===")
        va_loss, va_hit1 = run_epoch(val_states,   train=False)
        print(f"[HMG {ep}/{epochs}] loss {tr_loss:.4f}/{va_loss:.4f} | Hit@1 {tr_hit1:.3f}/{va_hit1:.3f}")

    ckpt = {
        "config": {"vocab_size": VOCAB_SIZE, "d_model": DMODEL, "n_heads": N_HEADS,
                   "n_layers": N_LAYERS, "d_ff": D_FF, "max_len": MAX_LEN},
        "state_dict": model.state_dict(),
    }
    torch.save(ckpt, "m11.pt")
    print("Saved fine-tuned model → m11.pt")
    return model

# Run finetune
model = finetune_hangman(model, train_states, val_states, epochs=FT_EPOCHS, batch_size=FT_BATCH, lr=FT_LR)



=== Hangman Finetune 1/100 — TRAIN ===
  [train] step   200 | seen   409600/7167888 | len=12 | batch=2048 | loss=0.2795 | Hit@1=0.869
  [train] step   400 | seen   817212/7167888 | len=13 | batch=2048 | loss=0.2534 | Hit@1=0.881
  [train] step   600 | seen  1226812/7167888 | len=13 | batch=2048 | loss=0.2431 | Hit@1=0.885
  [train] step   800 | seen  1633136/7167888 | len=10 | batch=2048 | loss=0.2910 | Hit@1=0.865
  [train] step  1000 | seen  2042736/7167888 | len=10 | batch=2048 | loss=0.3223 | Hit@1=0.852
  [train] step  1200 | seen  2448724/7167888 | len=9  | batch=2048 | loss=0.4138 | Hit@1=0.818
  [train] step  1400 | seen  2858324/7167888 | len=9  | batch=2048 | loss=0.4374 | Hit@1=0.807
  [train] step  1600 | seen  3267896/7167888 | len=5  | batch=2048 | loss=0.5036 | Hit@1=0.779
  [train] step  1800 | seen  3675784/7167888 | len=17 | batch=2048 | loss=0.5875 | Hit@1=0.745
  [train] step  2000 | seen  4083800/7167888 | len=7  | batch=2048 | loss=0.6082 | Hit@1=0.736
  [train] 