# In [None]:
# === Dependencies ===
import random, math, time, itertools
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
# Fix the seed of each random number generator this script draws from
# (Python's `random`, NumPy, and PyTorch) so runs start from the same state.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# === Word list (small) ===
# Replace with a larger word list if you like
WORDS = [
    "crane", "slate", "trace", "crate", "apple", "glare", "plane", "stare", "store", "stone",
    "later", "hello", "world", "spear", "shiny", "piano", "other", "their", "there", "berry",
    "cigar", "rebut", "sissy", "awake", "never", "adieu", "audio", "about", "trace", "later",
]
# Keep only five-letter entries, drop duplicates, and sort alphabetically so
# that every word gets a stable position in the list.
WORDS = sorted({word for word in WORDS if len(word) == 5})
V = len(WORDS)  # number of distinct words, i.e. the size of the output vocabulary
print("Vocab size:", V)

# maps
# Letter -> 0..25 in alphabetical order.
LETTER2IDX = dict(zip("abcdefghijklmnopqrstuvwxyz", range(26)))
# Word -> its position in WORDS, and the inverse lookup built from that mapping.
WORD2IDX = dict(zip(WORDS, range(len(WORDS))))
IDX2WORD = {index: word for word, index in WORD2IDX.items()}

# --- wordle feedback function
def wordle_feedback(target, guess):
    """Score a five-letter ``guess`` against ``target``.

    Returns a list of five ints, one per guess position:
    2 when the guess letter equals the target letter at that position,
    1 when the letter matches a target letter that has not already been
    claimed by another position, and 0 otherwise. Each target letter can be
    claimed at most once, so repeated guess letters are not over-credited.
    """
    marks = [0] * 5
    unclaimed = list(target)

    # Pass 1: exact-position matches claim their target letter first.
    for pos in range(5):
        if guess[pos] == target[pos]:
            marks[pos] = 2
            unclaimed[pos] = None

    # Pass 2: remaining positions may claim the first still-unclaimed copy
    # of their letter anywhere in the target.
    for pos in range(5):
        if marks[pos] != 0:
            continue
        letter = guess[pos]
        try:
            slot = unclaimed.index(letter)
        except ValueError:
            continue
        marks[pos] = 1
        unclaimed[slot] = None

    return marks

# --- simulate games and build examples
def generate_examples(N, K=5):
    """Simulate ``N`` games and collect (history, next guess) training pairs.

    For every game a target is drawn from ``WORDS`` and a simple teacher
    plays up to ``K`` turns, each time guessing a uniformly random word that
    is consistent with all feedback seen so far. One example is recorded per
    turn, before the guess is revealed.

    Args:
        N: number of games to simulate.
        K: maximum number of turns per game.

    Returns:
        ``(X_hist, Y_next)`` where ``X_hist[i]`` is the list of
        ``(guess, feedback)`` pairs available before turn ``i`` and
        ``Y_next[i]`` is the ``WORD2IDX`` index of the word the teacher
        guessed on that turn.
    """
    X_hist = []  # histories seen by the teacher before each guess
    Y_next = []  # index of the word the teacher guessed next
    for _ in range(N):
        target = random.choice(WORDS)
        history = []  # list of (guess_str, feedback_list)
        # Candidates are kept as a list in WORDS order rather than a set:
        # iteration order of a set of strings depends on per-process hash
        # randomization, which would make sampling differ between runs even
        # with a fixed `random` seed.
        candidates = list(WORDS)
        for _turn in range(K):
            guess = random.choice(candidates)
            fb = wordle_feedback(target, guess)
            # Snapshot the history as it was *before* this guess.
            X_hist.append(list(history))
            Y_next.append(WORD2IDX[guess])
            history.append((guess, fb))
            # Keep only words that would have produced every feedback seen.
            candidates = [
                word for word in candidates
                if all(wordle_feedback(word, g) == f for g, f in history)
            ]
            if not candidates:
                # Fallback: start again from the full list.
                candidates = list(WORDS)
            if guess == target:
                break
    return X_hist, Y_next

# Number of simulated games and the per-game turn limit.
N, K = 30000, 5
X_hist, Y_next = generate_examples(N, K=K)
print("Generated examples:", len(X_hist))

# === Dataset/encoding ===
PAD_LET = "<PAD>"
LET_V = 26
def encode_history(history, K=5):
    """Flatten a guess history into two fixed-length index arrays.

    Args:
        history: list of ``(guess, feedback)`` pairs; only the first ``K``
            entries are used.
        K: number of turns to encode; missing turns are padded.

    Returns:
        ``(letters, feedbacks)``: two int64 NumPy arrays of length ``K * 5``.
        Real turns contribute the ``LETTER2IDX`` index of each guess letter
        and the turn's feedback values; padded turns contribute the pad
        letter index (one past the last real letter) and feedback 0.
    """
    letter_ids = []
    fb_ids = []
    for turn in range(K):
        if turn < len(history):
            guess, marks = history[turn]
            letter_ids.extend(LETTER2IDX[ch] for ch in guess)
            fb_ids.extend(marks)
        else:
            # No guess made at this turn: fill with the pad letter index.
            letter_ids.extend([LET_V] * 5)
            fb_ids.extend([0] * 5)
    letters = np.array(letter_ids, dtype=np.int64)
    feedbacks = np.array(fb_ids, dtype=np.int64)
    return letters, feedbacks

class GuesserDataset(Dataset):
    """Dataset of (encoded history, next-guess index) examples.

    Histories are stored as given and encoded with ``encode_history`` each
    time an item is fetched.
    """

    def __init__(self, X_hist, Y_next, K=5):
        self.X_hist, self.Y_next, self.K = X_hist, Y_next, K

    def __len__(self):
        return len(self.X_hist)

    def __getitem__(self, i):
        """Return ``(letters, feedbacks, target)`` tensors for example ``i``."""
        letter_ids, fb_ids = encode_history(self.X_hist[i], self.K)
        label = self.Y_next[i]
        return torch.tensor(letter_ids), torch.tensor(fb_ids), torch.tensor(label)

ds = GuesserDataset(X_hist, Y_next, K=K)
# Hold out the last 10% (after rounding the train share down) for validation.
train_size = int(0.9 * len(ds))
val_size = len(ds) - train_size
train_ds, val_ds = torch.utils.data.random_split(ds, lengths=[train_size, val_size])
# Only the training loader shuffles; validation order does not matter.
train_dl = DataLoader(dataset=train_ds, batch_size=128, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=256)

# === Model: Tiny Transformer ===
class TinyGuesser(nn.Module):
    """Small Transformer encoder that scores every vocabulary word from a guess history.

    Each sequence position carries a letter index and a feedback index. Their
    embeddings are concatenated, projected to ``emb`` dimensions, offset by a
    learned positional embedding, passed through the encoder stack,
    mean-pooled over positions and mapped to ``vocab_out`` logits.
    """

    def __init__(self, letter_vocab=27, fb_vocab=3, emb=32, nhead=2, nhid=64, nlayers=2, seq_len=K*5, vocab_out=V):
        super().__init__()
        fb_dim = 8  # width of the feedback embedding
        self.letter_emb = nn.Embedding(letter_vocab, emb)
        self.fb_emb = nn.Embedding(fb_vocab, fb_dim)
        # Fuses the concatenated letter + feedback embeddings into one token vector.
        self.token_proj = nn.Linear(emb + fb_dim, emb)
        self.pos_emb = nn.Embedding(seq_len, emb)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=emb, nhead=nhead, dim_feedforward=nhid),
            num_layers=nlayers,
        )
        # Not used by forward() (pooling is a plain mean); kept so the set of
        # registered parameters stays the same.
        self.pool = nn.Linear(emb, emb)
        self.readout = nn.Sequential(
            nn.Linear(emb, 64),
            nn.ReLU(),
            nn.Linear(64, vocab_out),
        )

    def forward(self, letters, fbs):
        """Compute word logits and per-token hidden states.

        Args:
            letters: integer tensor [B, seq_len] of letter indices.
            fbs: integer tensor [B, seq_len] of feedback indices.

        Returns:
            ``(logits, h)`` with ``logits`` of shape [B, vocab_out] and ``h``
            the encoder output per token, shape [B, seq_len, emb].
        """
        batch, length = letters.shape
        merged = torch.cat([self.letter_emb(letters), self.fb_emb(fbs)], dim=-1)  # [B,L,emb+8]
        tok = self.token_proj(merged)                                             # [B,L,emb]
        positions = torch.arange(length, device=letters.device).unsqueeze(0).repeat(batch, 1)
        tok = tok + self.pos_emb(positions)
        # The encoder is sequence-first, so swap to [L,B,emb] and back again.
        h = self.transformer(tok.transpose(0, 1)).transpose(0, 1)                 # [B,L,emb]
        pooled = h.mean(dim=1)                                                    # [B,emb]
        logits = self.readout(pooled)                                             # [B,V]
        return logits, h  # token-level hidden states are returned for probes

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = TinyGuesser()
model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

# === Training ===
def evaluate(dl):
    """Evaluate the module-level ``model`` on every batch of ``dl``.

    Puts ``model`` in eval mode and runs without gradient tracking, using the
    module-level ``device`` and ``loss_fn``.

    Args:
        dl: iterable of ``(letters, fbs, y)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` averaged per example. Both are ``nan`` when
        ``dl`` yields no examples.
    """
    model.eval()
    total = 0
    correct = 0
    loss_sum = 0.0
    with torch.no_grad():
        for letters, fbs, y in dl:
            letters = letters.to(device)
            fbs = fbs.to(device)
            y = y.to(device)
            logits, _ = model(letters, fbs)
            batch = letters.size(0)
            # Weight each batch's loss by its size so the final figure is a
            # per-example average even when the last batch is smaller.
            loss_sum += loss_fn(logits, y).item() * batch
            correct += (logits.argmax(dim=-1) == y).sum().item()
            total += batch
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return float("nan"), float("nan")
    return loss_sum / total, correct / total

# train loop
# Train for at most 20 epochs, validating after each one.
for epoch in range(1, 21):
    model.train()
    t0 = time.time()
    for letters, fbs, y in train_dl:
        letters = letters.to(device)
        fbs = fbs.to(device)
        y = y.to(device)
        logits, _ = model(letters, fbs)
        loss = loss_fn(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    val_loss, val_acc = evaluate(val_dl)
    print(f"Epoch {epoch}  val_loss {val_loss:.4f}  val_acc {val_acc:.4f}  time {time.time()-t0:.1f}s")
    # Stop early once validation accuracy is above 70% (intended for small vocabs).
    if val_acc > 0.7:
        break

# Write the trained weights to disk.
torch.save(model.state_dict(), "tiny_guesser.pt")
