In [1]:
# === Cell 1: Setup & Load ===
# If you're in Colab: upload files: corpus.txt and test.txt
# from google.colab import files
# files.upload()   # then select corpus.txt and test.txt

from pathlib import Path

CORPUS_PATH = 'corpus.txt'   # adjust if needed, e.g., 'data/corpus.txt'
TEST_PATH   = 'test.txt'     # adjust if needed, e.g., 'data/test.txt'

def load_words(path: str):
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Missing file: {path}")
    return [w.strip().lower() for w in p.read_text().splitlines()
            if w.strip() and w.strip().isalpha()]

corpus_words = load_words(CORPUS_PATH)
test_words   = load_words(TEST_PATH)
len(corpus_words), len(test_words)


(49979, 2000)

In [2]:
# === Cell 2: Utilities ===
ALPHABET = 'abcdefghijklmnopqrstuvwxyz'
ALPHABET_SET = set(ALPHABET)

def matches_mask(word: str, mask: str, guessed_wrong: set[str]) -> bool:
    """Returns True if `word` is consistent with `mask` and `guessed_wrong` letters."""
    if len(word) != len(mask):
        return False
    if any(gw in word for gw in guessed_wrong):
        return False
    for wc, mc in zip(word, mask):
        if mc == '_':
            # blank: can't be a wrong letter (already covered above)
            continue
        if wc != mc:
            return False
    return True


In [3]:
# === Cell 3: CFP Oracle (Option A) ===
from collections import Counter, defaultdict
from typing import List, Set, Dict

class CandidateFilterOracle:
    """
    Fast, strong baseline:
    - Filters corpus by (mask, wrong guesses)
    - Tallies letters in blank positions across candidates
    - Returns P(letter) over unguessed letters
    """
    def __init__(self, words: List[str]):
        self.by_len = defaultdict(list)
        for w in words:
            w = w.strip().lower()
            if w and all('a' <= c <= 'z' for c in w):
                self.by_len[len(w)].append(w)

    def letter_posteriors(self, mask: str, guessed: Set[str]) -> Dict[str, float]:
        L = len(mask)
        pool = self.by_len.get(L, [])
        guessed_wrong = {g for g in guessed if g not in mask}
        cands = [w for w in pool if matches_mask(w, mask, guessed_wrong)]
        if not cands:
            rem = [c for c in ALPHABET if c not in guessed]
            return {c: 1.0/len(rem) for c in rem} if rem else {c: 0.0 for c in ALPHABET}

        scores = Counter()
        blanks = [i for i,ch in enumerate(mask) if ch == '_']
        for w in cands:
            for i in blanks:
                scores[w[i]] += 1

        dist = {c: float(scores[c]) for c in ALPHABET if c not in guessed}
        s = sum(dist.values())
        if s <= 0:
            rem = [c for c in ALPHABET if c not in guessed]
            return {c: 1.0/len(rem) for c in rem} if rem else {c: 0.0 for c in ALPHABET}
        for c in dist:
            dist[c] /= s
        return dist

# Instantiate right now:
cfp = CandidateFilterOracle(corpus_words)


In [4]:
# === Cell 4: HMM-only evaluation helpers ===
import numpy as np

def evaluate_oracle_topk(oracle, words: list[str],
                         revelation_levels=(0.2, 0.4, 0.6),
                         tests_per_word=2):
    """
    For each word, reveal some letters, ask oracle for letter probs,
    and check if the top-k prediction hits any actually-missing letter.
    """
    top1 = top3 = top5 = total = 0

    rng = np.random.default_rng(0)
    for w in words:
        L = len(w)
        for r in revelation_levels:
            for _ in range(tests_per_word):
                mask = ['_'] * L
                # reveal ceil(r*L) positions
                k = max(1, int(np.ceil(r * L)))
                idx = rng.choice(L, size=k, replace=False)
                for i in idx:
                    mask[i] = w[i]
                guessed = set(ch for ch in mask if ch != '_')
                actual_missing = set(w) - guessed
                if not actual_missing:
                    continue

                probs = oracle.letter_posteriors(''.join(mask), guessed)
                if not probs:
                    continue
                ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
                pred_letters = [c for c,_ in ranked]

                total += 1
                if pred_letters[0] in actual_missing: top1 += 1
                if any(c in actual_missing for c in pred_letters[:3]): top3 += 1
                if any(c in actual_missing for c in pred_letters[:5]): top5 += 1

    return {
        "top1": top1/total if total else 0.0,
        "top3": top3/total if total else 0.0,
        "top5": top5/total if total else 0.0,
        "cases": total
    }

# Evaluate CFP on test set:
cfp_stats = evaluate_oracle_topk(cfp, test_words, revelation_levels=(0.2,0.4,0.6), tests_per_word=2)
cfp_stats


{'top1': 0.43150742387383606,
 'top3': 0.6685680731482259,
 'top5': 0.7844979448032883,
 'cases': 11921}

In [11]:
# === Cell 4.9 (fixed & sanitized): Wire Role-B (rl.ipynb) without retraining on import ===
import os, importlib.util, re, io

print(">> Cell 4.9: preparing Role-B classes (no retraining on import)")

# 1) Convert rl.ipynb -> rl.py if present
if os.path.exists('rl.ipynb'):
    print("   - Converting rl.ipynb → rl.py …")
    !jupyter nbconvert --to python rl.ipynb --output rl.py >/dev/null 2>&1

if not os.path.exists('rl.py'):
    print("   [WARN] rl.py not found; ensure rl.ipynb is in this folder.")
else:
    print("   - Sanitizing rl.py so it does NOT auto-train on import …")
    src = open('rl.py', 'r', encoding='utf-8').read()
    lines = src.splitlines()

    # (a) Move any "from __future__" lines to the very top
    future = [ln for ln in lines if ln.strip().startswith('from __future__')]
    body   = [ln for ln in lines if not ln.strip().startswith('from __future__')]
    sanitized = []
    if future:
        sanitized.extend(future + [""])  # keep a blank line after futures
    sanitized.extend(body)

    text = "\n".join(sanitized)

    # (b) Comment out any top-level calls that would start training/eval on import
    # We’ll conservatively comment lines that start with those call names.
    patterns = [
        r'^\s*train_qlearning\s*\(',
        r'^\s*train_dqn\s*\(',
        r'^\s*eval_agent\s*\(',
    ]
    def comment_calls(code: str) -> str:
        buf = io.StringIO()
        for ln in code.splitlines():
            if any(re.match(pat, ln) for pat in patterns):
                buf.write("# [SANITIZED ON IMPORT] " + ln + "\n")
            else:
                buf.write(ln + "\n")
        return buf.getvalue()

    text = comment_calls(text)

    # (c) If there is a __main__ block, comment it out entirely to avoid demo runs
    text = re.sub(
        r'(?ms)^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*\n.*$',
        "# [SANITIZED] __main__ block removed to prevent auto-execution on import\n",
        text
    )

    with open('rl.py', 'w', encoding='utf-8') as f:
        f.write(text)

    print("   - Importing rl.py …")
    spec = importlib.util.spec_from_file_location("rl", "rl.py")
    rl = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(rl)

    # 2) Expose Role-B classes for Cell 5
    HangmanEnv    = getattr(rl, "HangmanEnv", None)
    TabularQAgent = getattr(rl, "TabularQAgent", None)
    DQNAgent      = getattr(rl, "DQNAgent", None)
    DQNConfig     = getattr(rl, "DQNConfig", None)
    ALPHABET      = getattr(rl, "ALPHABET", "abcdefghijklmnopqrstuvwxyz")

    print(">> Loaded Role-B classes:")
    print("   - HangmanEnv     :", HangmanEnv)
    print("   - TabularQAgent  :", TabularQAgent)
    print("   - DQNAgent       :", DQNAgent)
    print("   - DQNConfig      :", DQNConfig)
    print("   - ALPHABET len   :", len(ALPHABET))

print(">> Cell 4.9 ready. Now run Cell 5 (evaluation).")


>> Cell 4.9: preparing Role-B classes (no retraining on import)
   - Converting rl.ipynb → rl.py …
   - Sanitizing rl.py so it does NOT auto-train on import …
   - Importing rl.py …
Loaded Role-B pack. Use train_qlearning(), train_dqn(), and eval_agent().
>> Loaded Role-B classes:
   - HangmanEnv     : <class 'rl.HangmanEnv'>
   - TabularQAgent  : <class 'rl.TabularQAgent'>
   - DQNAgent       : <class 'rl.DQNAgent'>
   - DQNConfig      : <class 'rl.DQNConfig'>
   - ALPHABET len   : 26
>> Cell 4.9 ready. Now run Cell 5 (evaluation).


In [15]:
# === Cell 4.95 (patched): Quick Q-table refresh (train RL here) ===
# Trains a small/medium Q-learning run against your chosen oracle and saves models/q_table.pkl

from pathlib import Path
from contextlib import suppress
import inspect

assert 'HangmanEnv' in globals() and 'TabularQAgent' in globals(), \
    "Role-B classes not loaded. Re-run Cell 4.9 before this."

# 1) pick the oracle RL should learn from (MUST match what you'll evaluate)
BEST_ORACLE_FOR_RL = 'improved_hmm'   # 'improved_hmm'|'hyb_improved'|'hyb_basic'|'count_hmm'|'cfp'

_oracle_map = {
    'improved_hmm': globals().get('improved_hmm', None),
    'hyb_improved': globals().get('hyb_improved', None),
    'hyb_basic':    globals().get('hyb_basic', None),
    'count_hmm':    globals().get('count_hmm', None),
    'cfp':          globals().get('cfp', None),
}
ORACLE_FOR_RL = _oracle_map.get(BEST_ORACLE_FOR_RL)
if ORACLE_FOR_RL is None:
    raise RuntimeError(f"Chosen oracle '{BEST_ORACLE_FOR_RL}' not built yet. Build it (A/6/7/B) and retry.")

# 2) training knobs (safe defaults)
EPISODES     = 20000
EPS_START    = 0.20
EPS_END      = 0.01
EPS_DECAY_EP = int(0.8 * EPISODES)

SAVE_DIR  = Path('models')
SAVE_DIR.mkdir(exist_ok=True, parents=True)
SAVE_PATH = SAVE_DIR / 'q_table.pkl'

def _call_train_qlearning_or_fallback():
    """Try rl.train_qlearning with adaptive kwargs; else fallback to inline trainer."""
    used_rl_trainer = False
    if 'rl' in globals() and hasattr(rl, 'train_qlearning'):
        fn = rl.train_qlearning
        sig = inspect.signature(fn)
        params = set(sig.parameters.keys())

        # Map 'corpus_words' to the parameter name your function uses
        corpus_arg_name = None
        for candidate in ['words','train_words','word_list','dataset','corpus','train_set']:
            if candidate in params:
                corpus_arg_name = candidate
                break

        # Candidate kwargs (we'll filter by signature)
        candidate_kwargs = {
            corpus_arg_name: corpus_words if corpus_arg_name else None,
            'hmm': ORACLE_FOR_RL,
            'episodes': EPISODES,
            'eps_start': EPS_START,
            'eps_end': EPS_END,
            'eps_decay_episodes': EPS_DECAY_EP,
            'wrong_penalty': 5.0,
            'repeat_penalty': 2.0,
            'save_path': str(SAVE_PATH),
            'log_every': 1000,
        }
        call_kwargs = {k:v for k,v in candidate_kwargs.items() if k and (k in params) and (v is not None)}

        print("[RL] Detected train_qlearning signature:", str(sig))
        print("[RL] Passing kwargs:", sorted(call_kwargs.keys()))
        try:
            fn(**call_kwargs)
            print(f"[RL] Q-table saved → {SAVE_PATH}")
            used_rl_trainer = True
        except TypeError as e:
            print("[WARN] train_qlearning kwargs mismatch → falling back to inline trainer.\n", e)
        except Exception as e:
            print("[WARN] train_qlearning raised an exception → falling back to inline trainer.\n", e)

    if not used_rl_trainer:
        # Fallback: inline trainer via env + TabularQAgent (works with your eval API)
        print("[INFO] Using inline TabularQAgent trainer (fallback).")
        import numpy as np
        env = HangmanEnv(corpus_words, lives=6, hmm=ORACLE_FOR_RL, seed=123)
        agent = TabularQAgent()

        def eps_at(ep):
            if ep >= EPS_DECAY_EP: return EPS_END
            return EPS_START - (EPS_START - EPS_END) * (ep / EPS_DECAY_EP)

        for ep in range(1, EPISODES + 1):
            env.reset()
            done = False
            with suppress(Exception):
                agent.eps = eps_at(ep)
            while not done:
                hmm_probs = env._hmm_probs()
                a = agent.act(env.mask, env.lives, hmm_probs, env.guessed)
                letter = (ALPHABET[a] if isinstance(a, int) else a)
                _, reward, done, _ = env.step(letter)
                with suppress(Exception):
                    agent.update(env.mask, env.lives, hmm_probs, env.guessed, a, reward, done)
            if ep % 1000 == 0:
                print(f"[Q] ep={ep}  ε≈{eps_at(ep):.3f}")

        with suppress(Exception):
            agent.save(str(SAVE_PATH))
        print(f"[RL] (fallback) Q-table saved → {SAVE_PATH}")

_call_train_qlearning_or_fallback()


[RL] Detected train_qlearning signature: (corpus_path='corpus.txt', episodes=30000, seed=42, save_path='q_table.pkl')
[RL] Passing kwargs: ['episodes', 'save_path']
[Q] ep=1000 avg_reward/ep=-5.534 wrong=6 repeated=0
[Q] ep=2000 avg_reward/ep=-5.369 wrong=6 repeated=0
[Q] ep=3000 avg_reward/ep=-5.228 wrong=6 repeated=0
[Q] ep=4000 avg_reward/ep=-5.131 wrong=6 repeated=0
[Q] ep=5000 avg_reward/ep=-5.064 wrong=6 repeated=0
[Q] ep=6000 avg_reward/ep=-4.830 wrong=6 repeated=0
[Q] ep=7000 avg_reward/ep=-4.777 wrong=6 repeated=0
[Q] ep=8000 avg_reward/ep=-4.987 wrong=6 repeated=0
[Q] ep=9000 avg_reward/ep=-4.643 wrong=6 repeated=0
[Q] ep=10000 avg_reward/ep=-4.879 wrong=6 repeated=0
[Q] ep=11000 avg_reward/ep=-4.812 wrong=6 repeated=0
[Q] ep=12000 avg_reward/ep=-4.365 wrong=6 repeated=0
[Q] ep=13000 avg_reward/ep=-4.598 wrong=6 repeated=0
[Q] ep=14000 avg_reward/ep=-4.723 wrong=6 repeated=0
[Q] ep=15000 avg_reward/ep=-4.448 wrong=6 repeated=0
[Q] ep=16000 avg_reward/ep=-4.562 wrong=6 repeate

In [17]:
# === Cell 4.96: Stable DQN training (signature-safe; same oracle as Q-table) ===
from pathlib import Path
import inspect

assert 'HangmanEnv' in globals() and 'DQNAgent' in globals() and 'DQNConfig' in globals(), \
    "Role-B classes not loaded. Re-run Cell 4.9 first."
assert 'ORACLE_FOR_RL' in globals() and ORACLE_FOR_RL is not None, \
    "Choose/build your oracle first (4.95)."

SAVE_DIR  = Path('models'); SAVE_DIR.mkdir(exist_ok=True, parents=True)
DQN_SAVE  = SAVE_DIR / 'dqn.pt'

# Stabilized hyperparams
STEPS            = 80000
BATCH_SIZE       = 64
LR               = 1e-4
GAMMA            = 0.995
REPLAY_SIZE      = 100_000
WARMUP_STEPS     = 5_000
TARGET_UPDATE    = 2000
GRAD_CLIP        = 1.0
EPS_START        = 0.20
EPS_END          = 0.02
EPS_DECAY_STEPS  = int(0.8 * STEPS)
LOG_EVERY        = 2000

def _call_signature_safe(fn, kw):
    params = set(inspect.signature(fn).parameters.keys())
    filtered = {k: v for k, v in kw.items() if k in params}
    return fn(**filtered)

if 'rl' in globals() and hasattr(rl, 'train_dqn'):
    print(f"[DQN] Training via rl.train_dqn() on oracle='{BEST_ORACLE_FOR_RL}' …")
    superset_kwargs = dict(
        words=corpus_words,      # some trainers expect 'words'
        corpus=corpus_words,     # or 'corpus'
        train_words=corpus_words,
        hmm=ORACLE_FOR_RL,
        steps=STEPS,
        batch_size=BATCH_SIZE,
        lr=LR,
        gamma=GAMMA,
        replay_size=REPLAY_SIZE,
        warmup_steps=WARMUP_STEPS,
        target_update=TARGET_UPDATE,
        grad_clip=GRAD_CLIP,
        eps_start=EPS_START,
        eps_end=EPS_END,
        eps_decay_steps=EPS_DECAY_STEPS,
        wrong_penalty=5.0,
        repeat_penalty=2.0,
        log_every=LOG_EVERY,
        save_path=str(DQN_SAVE),
        lives=6,
        seed=123,
    )
    try:
        _call_signature_safe(rl.train_dqn, superset_kwargs)
        print(f"[DQN] Saved → {DQN_SAVE}")
    except Exception as e:
        print(f"[WARN] rl.train_dqn() failed ({e}). Skipping DQN for now. Set USE_DQN=False in Cell 5.")
else:
    print("[WARN] rl.train_dqn() not found. Skipping DQN training. (Leave USE_DQN=False in Cell 5.)")


[DQN] Training via rl.train_dqn() on oracle='improved_hmm' …
[DQN] steps=5000 buffer=5000 avg_loss≈25.6109
[DQN] steps=10000 buffer=10000 avg_loss≈275.2561
[DQN] steps=15000 buffer=15000 avg_loss≈807.2449
[DQN] steps=20000 buffer=20000 avg_loss≈7584.1524
[DQN] steps=25000 buffer=25000 avg_loss≈30239.4428
[DQN] steps=30000 buffer=30000 avg_loss≈57938.0791
[DQN] steps=35000 buffer=35000 avg_loss≈106320.7648
[DQN] steps=40000 buffer=40000 avg_loss≈231427.6960
[DQN] steps=45000 buffer=45000 avg_loss≈612298.6240
[DQN] steps=50000 buffer=50000 avg_loss≈1754976.1334
[DQN] steps=55000 buffer=50000 avg_loss≈4654889.1999
[DQN] steps=60000 buffer=50000 avg_loss≈10718675.2278
[DQN] steps=65000 buffer=50000 avg_loss≈21742778.9120
[DQN] steps=70000 buffer=50000 avg_loss≈39372530.6821
[DQN] steps=75000 buffer=50000 avg_loss≈65370684.5248
[DQN] steps=80000 buffer=50000 avg_loss≈102370013.0100
Saved DQN weights to models/dqn.pt
[DQN] Saved → models/dqn.pt


In [18]:
# === Cell A: Build Count-HMM oracle ===
from collections import defaultdict, Counter
import numpy as np
ALPHABET = 'abcdefghijklmnopqrstuvwxyz'
ALPHABET_SET = set(ALPHABET)

def _mk_index():
    idx = {c:i for i,c in enumerate(ALPHABET)}
    rev = {i:c for c,i in idx.items()}
    return idx, rev

class CountHMMOracle:
    def __init__(self, words, smooth: float = 0.5, by_length: bool = True):
        self.idx, self.rev = _mk_index()
        self.K = len(ALPHABET)
        self.by_length = by_length
        self.smooth = smooth
        self.pi = {}   # length -> (K,)
        self.A  = {}   # length -> (K,K)

        groups = defaultdict(list)
        for w in words:
            w = w.strip().lower()
            if not w or any(ch not in ALPHABET_SET for ch in w):
                continue
            groups[len(w)].append(w)

        for L, ws in groups.items():
            pi = np.full(self.K, smooth, dtype=np.float64)
            A  = np.full((self.K,self.K), smooth, dtype=np.float64)
            for w in ws:
                pi[self.idx[w[0]]] += 1.0
                for a,b in zip(w[:-1], w[1:]):
                    A[self.idx[a], self.idx[b]] += 1.0
            pi /= pi.sum()
            A  /= A.sum(axis=1, keepdims=True)
            self.pi[L] = pi
            self.A[L]  = A

        # global fallback (if by_length=False)
        if not by_length:
            pis = [self.pi[L] for L in self.pi]
            As  = [self.A[L]  for L in self.A]
            self.pi = {None: np.mean(pis, axis=0)} if pis else {None: np.full(self.K,1.0/self.K)}
            self.A  = {None: np.mean(As,  axis=0)} if As  else {None: np.full((self.K,self.K),1.0/self.K)}

    def _params(self, L: int):
        if self.by_length and L in self.pi:
            return self.pi[L], self.A[L]
        return self.pi.get(None, np.full(self.K,1.0/self.K)), self.A.get(None, np.full((self.K,self.K),1.0/self.K))

    def letter_posteriors(self, mask: str, guessed: set[str]) -> dict[str, float]:
        L = len(mask)
        pi, A = self._params(L)
        if pi is None or A is None:
            rem = [c for c in ALPHABET if c not in guessed]
            return {c: 1.0/len(rem) for c in rem} if rem else {c:0.0 for c in ALPHABET}

        # emission mask
        E = np.ones((L, self.K), dtype=np.float64)
        wrong = {g for g in guessed if g not in mask}
        for t,ch in enumerate(mask):
            if ch == '_':
                for g in wrong:
                    E[t, self.idx[g]] = 0.0
            else:
                E[t,:] = 0.0
                E[t, self.idx[ch]] = 1.0

        eps = 1e-12
        log_pi = np.log(pi + eps)
        log_A  = np.log(A  + eps)
        log_E  = np.log(E  + eps)

        # forward
        alpha = np.full((L, self.K), -np.inf)
        alpha[0,:] = log_pi + log_E[0,:]
        for t in range(1, L):
            prev = alpha[t-1,:].reshape(-1,1) + log_A  # (K,K)
            m = prev.max(axis=0)
            alpha[t,:] = log_E[t,:] + (m + np.log(np.exp(prev - m).sum(axis=0)+eps))

        # backward
        beta = np.full((L, self.K), 0.0)
        for t in range(L-2, -1, -1):
            nxt = log_A + (log_E[t+1,:] + beta[t+1,:]).reshape(1,-1)
            m = nxt.max(axis=1)
            beta[t,:] = m + np.log(np.exp(nxt - m.reshape(-1,1)).sum(axis=1)+eps)

        log_Z = (alpha[-1,:].max() + np.log(np.exp(alpha[-1,:] - alpha[-1,:].max()).sum()+eps))
        gamma = np.exp(alpha + beta - log_Z)  # (L,K)

        P = {c:0.0 for c in ALPHABET}
        for t,ch in enumerate(mask):
            if ch == '_':
                for s in range(self.K):
                    c = self.rev[s]
                    if c not in guessed and E[t,s] > 0.0:
                        P[c] += gamma[t,s]
        Z = sum(P.values())
        if Z <= 0:
            rem = [c for c in ALPHABET if c not in guessed]
            return {c: 1.0/len(rem) for c in rem} if rem else {c:0.0 for c in ALPHABET}
        for c in P: P[c] /= Z
        return P

# ✅ use CORPUS words (not an undefined `train_words`)
count_hmm = CountHMMOracle(corpus_words, smooth=0.5, by_length=True)
print("[OK] count_hmm ready:", isinstance(count_hmm, CountHMMOracle))


[OK] count_hmm ready: True


In [19]:
# === Cell B: Hybrid oracle (Count-HMM ⊕ CFP) ===
LAM = 0.6  # you can tune 0.3–0.7

class HybridOracle:
    def __init__(self, hmm_oracle, cfp_oracle, lam: float = 0.6):
        self.hmm = hmm_oracle
        self.cfp = cfp_oracle
        self.lam = lam

    def letter_posteriors(self, mask: str, guessed: set[str]) -> dict[str, float]:
        A = self.hmm.letter_posteriors(mask, guessed)
        B = self.cfp.letter_posteriors(mask, guessed)
        # union of keys
        keys = set(A.keys()) | set(B.keys())
        out = {}
        for k in keys:
            if k in guessed:
                continue
            out[k] = self.lam * A.get(k, 0.0) + (1.0 - self.lam) * B.get(k, 0.0)
        s = sum(out.values())
        if s > 0:
            for k in out: out[k] /= s
        return out

hyb_basic = HybridOracle(count_hmm, cfp, lam=LAM)
print("[OK] hyb_basic ready:", isinstance(hyb_basic, HybridOracle))


[OK] hyb_basic ready: True


In [20]:
# === Cell 5 (instrumented): 2000-game evaluation with live prints ===
import os, time
from contextlib import suppress
import numpy as np

GAMES = 2000                 # official scale
USE_QTABLE = True
QTABLE_PATH = 'models/q_table.pkl'   # ← use the model you saved in 4.95
USE_DQN = True             # keep off unless you stabilize it
DQN_PATH = 'models/dqn.pt'

# auto-disable if files missing
if USE_QTABLE and not os.path.exists(QTABLE_PATH):
    print(f"[Info] Q-table not found → disabling QTable eval.")
    USE_QTABLE = False
if USE_DQN and not os.path.exists(DQN_PATH):
    print(f"[Info] DQN weights not found → disabling DQN eval.")
    USE_DQN = False

# Role-B classes (from 4.9 or global)
try:
    from hackman.env import HangmanEnv
    from hackman.agents.qlearning import TabularQAgent
    from hackman.agents.dqn import DQNAgent, DQNConfig
    from hackman.utils import ALPHABET
    HAS_ROLEB = True
except Exception:
    HAS_ROLEB = False
    with suppress(Exception):
        HangmanEnv; TabularQAgent
        HAS_ROLEB = True
        ALPHABET = ALPHABET if 'ALPHABET' in globals() else 'abcdefghijklmnopqrstuvwxyz'
    if not HAS_ROLEB:
        ALPHABET = 'abcdefghijklmnopqrstuvwxyz'
        print("[Info] Role-B not available → RL skipped.")

# -------------------------------------------------------------------
def eval_greedy(words, oracle, games=2000, seed=1):
    rng = np.random.default_rng(seed)
    sample = words if len(words) <= games else list(rng.choice(words, size=games, replace=False))
    wins = total_wrong = total_repeated = 0
    print(f"[GREEDY] Running {len(sample)} games …")
    for gi, w in enumerate(sample, 1):
        mask = ['_'] * len(w); guessed=set(); lives=6; wrong=repeated=0
        while lives>0 and '_' in mask:
            probs = oracle.letter_posteriors(''.join(mask), guessed)
            letter = max(probs.items(), key=lambda kv: kv[1])[0] if probs else next((c for c in ALPHABET if c not in guessed), None)
            if letter is None: break
            if letter in guessed: repeated+=1; continue
            guessed.add(letter)
            if letter in w:
                for i,ch in enumerate(w):
                    if ch==letter: mask[i]=ch
            else:
                wrong+=1; lives-=1
        if '_' not in mask: wins+=1
        total_wrong+=wrong; total_repeated+=repeated
        if gi%200==0: print(f"  {gi}/{len(sample)} done …")
    sr=wins/len(sample); score=(sr*2000)-(total_wrong*5)-(total_repeated*2)
    print(f"[GREEDY] Done: SR={sr:.3f}, Wrong={total_wrong}, Rep={total_repeated}, Score={score:.0f}")
    return dict(games=len(sample),success_rate=sr,total_wrong=total_wrong,
                total_repeated=total_repeated,final_score=score)

def eval_qtable(words, oracle, qtable_path, games=2000, seed=7):
    if not HAS_ROLEB: raise RuntimeError("Role-B missing")
    env=HangmanEnv(words,lives=6,hmm=oracle,seed=seed)
    ag=TabularQAgent(); ag.load(qtable_path)
    wins=tw=tr=0
    print(f"[QTABLE] {games} games …")
    for g in range(1,games+1):
        env.reset(); done=False
        while not done:
            hmm=env._hmm_probs()
            a=ag.act(env.mask,env.lives,hmm,env.guessed)
            letter=ALPHABET[a]
            _,_,done,_=env.step(letter)
        if env.mask==env.word: wins+=1
        tw+=env.total_wrong; tr+=env.total_repeated
        if g%200==0: print(f"  Game {g}/{games} done …")
    sr=wins/games; score=(sr*2000)-(tw*5)-(tr*2)
    print(f"[QTABLE] SR={sr:.3f}, Wrong={tw}, Rep={tr}, Score={score:.0f}")
    return dict(games=games,success_rate=sr,total_wrong=tw,total_repeated=tr,final_score=score)

def eval_dqn(words, oracle, dqn_path, games=2000, seed=11):
    if not HAS_ROLEB: raise RuntimeError("Role-B missing")
    import torch
    env=HangmanEnv(words,lives=6,hmm=oracle,seed=seed)
    s0=env.reset()
    agent=DQNAgent(state_dim=int(s0.shape[0]),cfg=DQNConfig())
    agent.policy.load_state_dict(torch.load(dqn_path,map_location='cpu'))
    wins=tw=tr=0
    print(f"[DQN] {games} games …")
    for g in range(1,games+1):
        env.reset(); done=False
        while not done:
            s=env._state()
            a=agent.act(s,env.guessed)
            letter=ALPHABET[a]
            _,_,done,_=env.step(letter)
        if env.mask==env.word: wins+=1
        tw+=env.total_wrong; tr+=env.total_repeated
        if g%200==0: print(f"  Game {g}/{games} done …")
    sr=wins/games; score=(sr*2000)-(tw*5)-(tr*2)
    print(f"[DQN] SR={sr:.3f}, Wrong={tw}, Rep={tr}, Score={score:.0f}")
    return dict(games=games,success_rate=sr,total_wrong=tw,total_repeated=tr,final_score=score)
# -------------------------------------------------------------------

# which oracles to evaluate (add the ones you've built)
to_eval=[]
to_eval.append(("CFP", cfp))
if 'count_hmm'    in globals(): to_eval.append(("Count-HMM", count_hmm))
if 'improved_hmm' in globals(): to_eval.append(("Improved-HMM", improved_hmm))
if 'hyb_basic'    in globals(): to_eval.append(("Hybrid-Basic", hyb_basic))
if 'hyb_improved' in globals(): to_eval.append(("Hybrid-Improved", hyb_improved))

print(f"\n=== Evaluating {len(to_eval)} oracles on {GAMES} games ===\n")

best=None
print("Oracle       | Agent   | Success Rate   | Total Wrong | Total Rep | Final Score")
print("-------------------------------------------------------------------------------------")
for name,oracle in to_eval:
    # Greedy
    g=eval_greedy(test_words,oracle,games=GAMES)
    print(f"{name:<12} | Greedy  | SR={g['success_rate']:.3f} | Wrong={g['total_wrong']:>5} | Rep={g['total_repeated']:>4} | Final={g['final_score']:>8}")
    if best is None or g['final_score']>best[2]['final_score']:
        best=(name,"Greedy",g)

    # QTable (if available)
    if HAS_ROLEB and USE_QTABLE:
        q=eval_qtable(test_words,oracle,QTABLE_PATH,games=GAMES)
        print(f"{name:<12} | QTable  | SR={q['success_rate']:.3f} | Wrong={q['total_wrong']:>5} | Rep={q['total_repeated']:>4} | Final={q['final_score']:>8}")
        if q['final_score']>best[2]['final_score']: best=(name,"QTable",q)

    # DQN (optional; off by default)
    if HAS_ROLEB and USE_DQN:
        d=eval_dqn(test_words,oracle,DQN_PATH,games=GAMES)
        print(f"{name:<12} | DQN     | SR={d['success_rate']:.3f} | Wrong={d['total_wrong']:>5} | Rep={d['total_repeated']:>4} | Final={d['final_score']:>8}")
        if d['final_score']>best[2]['final_score']: best=(name,"DQN",d)

print("\n>>> BEST RESULT <<<")
print(f"Oracle={best[0]}, Agent={best[1]}, FinalScore={best[2]['final_score']:.0f}, SR={best[2]['success_rate']:.3f}")



=== Evaluating 5 oracles on 2000 games ===

Oracle       | Agent   | Success Rate   | Total Wrong | Total Rep | Final Score
-------------------------------------------------------------------------------------
[GREEDY] Running 2000 games …
  200/2000 done …
  400/2000 done …
  600/2000 done …
  800/2000 done …
  1000/2000 done …
  1200/2000 done …
  1400/2000 done …
  1600/2000 done …
  1800/2000 done …
  2000/2000 done …
[GREEDY] Done: SR=0.028, Wrong=11875, Rep=0, Score=-59320
CFP          | Greedy  | SR=0.028 | Wrong=11875 | Rep=   0 | Final=-59320.0
[QTABLE] 2000 games …
  Game 200/2000 done …
  Game 400/2000 done …
  Game 600/2000 done …
  Game 800/2000 done …
  Game 1000/2000 done …
  Game 1200/2000 done …
  Game 1400/2000 done …
  Game 1600/2000 done …
  Game 1800/2000 done …
  Game 2000/2000 done …
[QTABLE] SR=0.002, Wrong=11987, Rep=0, Score=-59932
CFP          | QTable  | SR=0.002 | Wrong=11987 | Rep=   0 | Final=-59932.0
[DQN] 2000 games …
  Game 200/2000 done …
  Game 400/

In [6]:
# === Cell 6: Enhanced Count-HMM Oracle with Positional & Co-occurrence Features ===
import numpy as np
from collections import defaultdict, Counter
from typing import List, Dict, Set

def _mk_index():
    idx = {c:i for i,c in enumerate(ALPHABET)}
    rev = {i:c for c,i in idx.items()}
    return idx, rev

# --- small greedy simulator local to this cell (so no dependency on earlier "simulate") ---
def _play_greedy(word: str, oracle) -> dict:
    mask = ['_'] * len(word)
    guessed = set()
    lives = 6
    wrong = 0
    repeated = 0
    while lives > 0 and '_' in mask:
        probs = oracle.letter_posteriors(''.join(mask), guessed)
        if probs:
            guess = max(probs.items(), key=lambda kv: kv[1])[0]
        else:
            # fallback to first unguessed letter
            choices = [c for c in ALPHABET if c not in guessed]
            guess = choices[0] if choices else None
        if guess is None:
            break
        if guess in guessed:
            repeated += 1
            continue
        guessed.add(guess)
        if guess in word:
            for i, ch in enumerate(word):
                if ch == guess:
                    mask[i] = ch
        else:
            wrong += 1
            lives -= 1
    return {"win": ('_' not in mask), "wrong": wrong, "repeated": repeated}

def simulate_greedy(words: list[str], oracle, games=200):
    rng = np.random.default_rng(1)
    sample = words if len(words) <= games else list(rng.choice(words, size=games, replace=False))
    wins = 0
    tot_wrong = 0
    tot_rep = 0
    for w in sample:
        r = _play_greedy(w, oracle)
        wins += int(r["win"])
        tot_wrong += r["wrong"]
        tot_rep += r["repeated"]
    return {
        "games": len(sample),
        "win_rate": wins/len(sample),
        "avg_wrong": tot_wrong/len(sample),
        "avg_repeated": tot_rep/len(sample)
    }
# --- end local simulator ---

class CountHMMOracle:
    """
    Bigram HMM from counts with add-alpha smoothing and emission masks.
    letter_posteriors(mask, guessed) returns P(letter) over unguessed letters.
    """
    def __init__(self, words: List[str], smooth: float = 0.5, by_length: bool = True):
        self.idx, self.rev = _mk_index()
        self.K = len(ALPHABET)
        self.by_length = by_length
        self.smooth = smooth
        self.pi = {}   # length -> (K,)
        self.A  = {}   # length -> (K,K)
        groups = defaultdict(list)
        for w in words:
            w = w.strip().lower()
            if not w or any(ch not in ALPHABET_SET for ch in w):
                continue
            groups[len(w)].append(w)
        # estimate per length
        for L, ws in groups.items():
            pi = np.full(self.K, smooth, dtype=np.float64)
            A  = np.full((self.K, self.K), smooth, dtype=np.float64)
            for w in ws:
                pi[self.idx[w[0]]] += 1.0
                for a,b in zip(w[:-1], w[1:]):
                    A[self.idx[a], self.idx[b]] += 1.0
            pi /= pi.sum()
            A  /= A.sum(axis=1, keepdims=True)
            self.pi[L] = pi
            self.A[L]  = A
        # global fallback (if requested)
        if not by_length:
            if self.pi:
                pis = np.stack(list(self.pi.values()))
                As  = np.stack(list(self.A.values()))
                self.pi = {None: pis.mean(axis=0)}
                self.A  = {None: As.mean(axis=0)}
            else:
                self.pi = {None: np.full(self.K, 1.0/self.K)}
                self.A  = {None: np.full((self.K,self.K), 1.0/self.K)}

    def _params(self, L: int):
        if self.by_length and L in self.pi:
            return self.pi[L], self.A[L]
        return self.pi.get(None, np.full(self.K,1.0/self.K)), self.A.get(None, np.full((self.K,self.K),1.0/self.K))

    def letter_posteriors(self, mask: str, guessed: Set[str]) -> Dict[str, float]:
        L = len(mask)
        pi, A = self._params(L)
        if pi is None or A is None:
            rem = [c for c in ALPHABET if c not in guessed]
            return {c: 1.0/len(rem) for c in rem} if rem else {c: 0.0 for c in ALPHABET}

        # Emission masks E[t,s] ∈ {0,1}
        E = np.ones((L, self.K), dtype=np.float64)
        wrong = {g for g in guessed if g not in mask}
        for t,ch in enumerate(mask):
            if ch == '_':
                for g in wrong:
                    E[t, self.idx[g]] = 0.0
            else:
                E[t,:] = 0.0
                E[t, self.idx[ch]] = 1.0

        eps = 1e-12
        log_pi = np.log(pi + eps)
        log_A  = np.log(A  + eps)
        log_E  = np.log(E  + eps)

        # forward
        alpha = np.full((L, self.K), -np.inf, dtype=np.float64)
        alpha[0,:] = log_pi + log_E[0,:]
        for t in range(1, L):
            prev = alpha[t-1,:].reshape(-1,1) + log_A  # (K,K)
            m = prev.max(axis=0)
            alpha[t,:] = log_E[t,:] + (m + np.log(np.exp(prev - m).sum(axis=0) + eps))

        # backward
        beta = np.full((L, self.K), -np.inf, dtype=np.float64)
        beta[-1,:] = 0.0
        for t in range(L-2, -1, -1):
            nxt = log_A + (log_E[t+1,:] + beta[t+1,:]).reshape(1,-1)
            m = nxt.max(axis=1)
            beta[t,:] = m + np.log(np.exp(nxt - m.reshape(-1,1)).sum(axis=1) + eps)

        log_Z = alpha[-1,:].max() + np.log(np.exp(alpha[-1,:] - alpha[-1,:].max()).sum() + eps)
        gamma = np.exp(alpha + beta - log_Z)  # (L,K)

        # aggregate over blanks, exclude guessed
        P = {c: 0.0 for c in ALPHABET}
        for t,ch in enumerate(mask):
            if ch == '_':
                for s in range(self.K):
                    c = self.rev[s]
                    if c not in guessed and E[t,s] > 0.0:
                        P[c] += gamma[t,s]
        Z = sum(P.values())
        if Z <= 0:
            rem = [c for c in ALPHABET if c not in guessed]
            return {c: 1.0/len(rem) for c in rem} if rem else {c: 0.0 for c in ALPHABET}
        for c in P:
            P[c] /= Z
        return P


class ImprovedCountHMMOracle:
    """
    Enhanced HMM with:
    - Bigram HMM (existing)
    - Positional frequency analysis
    - Letter co-occurrence patterns
    """
    def __init__(self, words: List[str], smooth: float = 1.0):
        # Base bigram HMM
        self.bigram_hmm = CountHMMOracle(words, smooth=smooth, by_length=True)

        # Build additional features
        self.pos_freq = self._build_positional_freq(words)
        self.cooccur = self._build_cooccurrence(words)
        self.word_set = set(w.strip().lower() for w in words)

    def _build_positional_freq(self, words):
        pos_freq = defaultdict(lambda: defaultdict(Counter))
        for w in words:
            w = w.strip().lower()
            if not w or any(ch not in ALPHABET_SET for ch in w):
                continue
            for pos, ch in enumerate(w):
                pos_freq[len(w)][pos][ch] += 1
        for L in pos_freq:
            for pos in pos_freq[L]:
                total = sum(pos_freq[L][pos].values())
                if total > 0:
                    for ch in pos_freq[L][pos]:
                        pos_freq[L][pos][ch] /= total
        return pos_freq

    def _build_cooccurrence(self, words):
        cooccur = defaultdict(Counter)
        for w in words:
            w = w.strip().lower()
            if not w or any(ch not in ALPHABET_SET for ch in w):
                continue
            chars = set(w)
            for c1 in chars:
                for c2 in chars:
                    if c1 != c2:
                        cooccur[c1][c2] += 1
        for c1 in cooccur:
            total = sum(cooccur[c1].values())
            if total > 0:
                for c2 in cooccur[c1]:
                    cooccur[c1][c2] /= total
        return cooccur

    def _positional_score(self, mask, guessed):
        L = len(mask)
        scores = defaultdict(float)
        if L not in self.pos_freq:
            return {}
        for pos, ch in enumerate(mask):
            if ch == '_':
                for c in ALPHABET:
                    if c not in guessed:
                        scores[c] += self.pos_freq[L][pos].get(c, 0.0)
        Z = sum(scores.values())
        return {c: scores[c]/Z for c in scores} if Z > 0 else {}

    def _cooccurrence_score(self, mask, guessed):
        revealed = set(ch for ch in mask if ch != '_')
        if not revealed:
            return {}
        scores = defaultdict(float)
        for c in ALPHABET:
            if c not in guessed:
                for r in revealed:
                    scores[c] += self.cooccur[r].get(c, 0.0)
        Z = sum(scores.values())
        return {c: scores[c]/Z for c in scores} if Z > 0 else {}

    def letter_posteriors(self, mask: str, guessed: Set[str]) -> Dict[str, float]:
        P_hmm = self.bigram_hmm.letter_posteriors(mask, guessed)
        P_pos = self._positional_score(mask, guessed)
        P_cooc = self._cooccurrence_score(mask, guessed)

        vowel_boost = {}
        if len(guessed) < 3:
            vowel_boost = {'a': 1.2, 'e': 1.3, 'i': 1.1, 'o': 1.1, 'u': 1.0}

        w_hmm, w_pos, w_cooc = 0.5, 0.3, 0.2
        P = {}
        for c in ALPHABET:
            if c not in guessed:
                score = (w_hmm * P_hmm.get(c, 0.0) +
                         w_pos * P_pos.get(c, 0.0) +
                         w_cooc * P_cooc.get(c, 0.0))
                if vowel_boost:
                    score *= vowel_boost.get(c, 1.0)
                P[c] = score

        Z = sum(P.values())
        if Z > 0:
            for c in P: P[c] /= Z
        else:
            rem = [c for c in ALPHABET if c not in guessed]
            P = {c: 1.0/len(rem) for c in rem} if rem else {c: 0.0 for c in ALPHABET}
        return P


# Build both versions for comparison
print("Building Count-HMM (basic bigram)...")
count_hmm = CountHMMOracle(corpus_words, smooth=1.0, by_length=True)

print("Building Improved-HMM (with positional + co-occurrence)...")
improved_hmm = ImprovedCountHMMOracle(corpus_words, smooth=1.0)

# Evaluate both (now using the local simulate_greedy)
count_stats = evaluate_oracle_topk(count_hmm, test_words, revelation_levels=(0.2,0.4,0.6), tests_per_word=2)
count_greedy = simulate_greedy(test_words, count_hmm, games=min(500, len(test_words)))

improved_stats = evaluate_oracle_topk(improved_hmm, test_words, revelation_levels=(0.2,0.4,0.6), tests_per_word=2)
improved_greedy = simulate_greedy(test_words, improved_hmm, games=min(500, len(test_words)))

print("\n=== Basic Count-HMM ===")
print(count_stats)
print(count_greedy)

print("\n=== Improved HMM ===")
print(improved_stats)
print(improved_greedy)


Building Count-HMM (basic bigram)...
Building Improved-HMM (with positional + co-occurrence)...

=== Basic Count-HMM ===
{'top1': 0.5385454240416072, 'top3': 0.8499286972569415, 'top5': 0.9350725610267595, 'cases': 11921}
{'games': 500, 'win_rate': 0.348, 'avg_wrong': 5.15, 'avg_repeated': 0.0}

=== Improved HMM ===
{'top1': 0.529989094874591, 'top3': 0.8495092693565977, 'top5': 0.9331431926851774, 'cases': 11921}
{'games': 500, 'win_rate': 0.354, 'avg_wrong': 5.136, 'avg_repeated': 0.0}


In [8]:
# === Cell 7 (fixed): Hybrid with Improved-HMM + self-contained greedy sim ===
import numpy as np

# Provide a local greedy simulator if not already defined (from Cell 6)
if 'simulate_greedy' not in globals():
    def _play_greedy_local(word: str, oracle) -> dict:
        mask = ['_'] * len(word)
        guessed = set()
        lives = 6
        wrong = 0
        repeated = 0
        while lives > 0 and '_' in mask:
            probs = oracle.letter_posteriors(''.join(mask), guessed)
            if probs:
                guess = max(probs.items(), key=lambda kv: kv[1])[0]
            else:
                # fallback to first unguessed letter
                choices = [c for c in ALPHABET if c not in guessed]
                guess = choices[0] if choices else None
            if guess is None:
                break
            if guess in guessed:
                repeated += 1
                continue
            guessed.add(guess)
            if guess in word:
                for i, ch in enumerate(word):
                    if ch == guess:
                        mask[i] = ch
            else:
                wrong += 1
                lives -= 1
        return {"win": ('_' not in mask), "wrong": wrong, "repeated": repeated}

    def simulate_greedy(words: list[str], oracle, games=200):
        rng = np.random.default_rng(2)
        sample = words if len(words) <= games else list(rng.choice(words, size=games, replace=False))
        wins = 0
        tot_wrong = 0
        tot_rep = 0
        for w in sample:
            r = _play_greedy_local(w, oracle)
            wins += int(r["win"])
            tot_wrong += r["wrong"]
            tot_rep += r["repeated"]
        return {
            "games": len(sample),
            "win_rate": wins/len(sample),
            "avg_wrong": tot_wrong/len(sample),
            "avg_repeated": tot_rep/len(sample)
        }

def hybrid_posteriors(mask: str, guessed: set[str], hmm_oracle, cfp_oracle, lam=0.5):
    Ph = hmm_oracle.letter_posteriors(mask, guessed)
    Pc = cfp_oracle.letter_posteriors(mask, guessed)
    keys = {c for c in ALPHABET if c not in guessed}
    P = {c: lam*Ph.get(c,0.0) + (1-lam)*Pc.get(c,0.0) for c in keys}
    Z = sum(P.values())
    if Z > 0:
        for c in P:
            P[c] /= Z
    return P

class HybridOracle:
    def __init__(self, hmm_oracle, cfp_oracle, lam=0.5):
        self.hmm = hmm_oracle
        self.cfp = cfp_oracle
        self.lam = lam
    def letter_posteriors(self, mask, guessed):
        return hybrid_posteriors(mask, guessed, self.hmm, self.cfp, self.lam)

# Create hybrids with both HMM versions
LAM = 0.5  # tune 0.3–0.7 if you want
hyb_basic    = HybridOracle(count_hmm,    cfp, lam=LAM)
hyb_improved = HybridOracle(improved_hmm, cfp, lam=LAM)

hyb_basic_stats     = evaluate_oracle_topk(hyb_basic, test_words, revelation_levels=(0.2,0.4,0.6), tests_per_word=2)
hyb_basic_greedy    = simulate_greedy(test_words, hyb_basic, games=min(500, len(test_words)))

hyb_improved_stats  = evaluate_oracle_topk(hyb_improved, test_words, revelation_levels=(0.2,0.4,0.6), tests_per_word=2)
hyb_improved_greedy = simulate_greedy(test_words, hyb_improved, games=min(500, len(test_words)))

print("\n=== Hybrid (Basic HMM) ===")
print(hyb_basic_stats)
print(hyb_basic_greedy)

print("\n=== Hybrid (Improved HMM) ===")
print(hyb_improved_stats)
print(hyb_improved_greedy)



=== Hybrid (Basic HMM) ===
{'top1': 0.5665632077845818, 'top3': 0.854039090680312, 'top5': 0.9354081033470346, 'cases': 11921}
{'games': 500, 'win_rate': 0.356, 'avg_wrong': 5.072, 'avg_repeated': 0.0}

=== Hybrid (Improved HMM) ===
{'top1': 0.5662276654643067, 'top3': 0.8531163492995554, 'top5': 0.9342337052260716, 'cases': 11921}
{'games': 500, 'win_rate': 0.352, 'avg_wrong': 5.098, 'avg_repeated': 0.0}


In [10]:
# === Cell 8 (robust): Head-to-head comparison & winner ===
import numpy as np

# 0) tiny greedy sim if not present (same as in Cell 7)
if 'simulate_greedy' not in globals():
    def _play_greedy_local(word: str, oracle) -> dict:
        mask = ['_'] * len(word)
        guessed = set()
        lives = 6
        wrong = 0
        repeated = 0
        while lives > 0 and '_' in mask:
            probs = oracle.letter_posteriors(''.join(mask), guessed)
            if probs:
                guess = max(probs.items(), key=lambda kv: kv[1])[0]
            else:
                choices = [c for c in ALPHABET if c not in guessed]
                guess = choices[0] if choices else None
            if guess is None:
                break
            if guess in guessed:
                repeated += 1
                continue
            guessed.add(guess)
            if guess in word:
                for i, ch in enumerate(word):
                    if ch == guess:
                        mask[i] = ch
            else:
                wrong += 1
                lives -= 1
        return {"win": ('_' not in mask), "wrong": wrong, "repeated": repeated}

    def simulate_greedy(words: list[str], oracle, games=200):
        rng = np.random.default_rng(3)
        sample = words if len(words) <= games else list(rng.choice(words, size=games, replace=False))
        wins = 0
        tot_wrong = 0
        tot_rep = 0
        for w in sample:
            r = _play_greedy_local(w, oracle)
            wins += int(r["win"])
            tot_wrong += r["wrong"]
            tot_rep += r["repeated"]
        return {
            "games": len(sample),
            "win_rate": wins/len(sample),
            "avg_wrong": tot_wrong/len(sample),
            "avg_repeated": tot_rep/len(sample)
        }

def final_like_score(g):
    """Proxy of hackathon score using greedy stats."""
    G = g["games"]; wr = g["win_rate"]
    tot_wrong = g["avg_wrong"] * G
    tot_rep   = g["avg_repeated"] * G
    return wr*2000 - (tot_wrong*5) - (tot_rep*2)

def show(name, topk_stats, greedy_stats):
    print(f"\n=== {name} ===")
    print(f"Top-1: {topk_stats['top1']:.3f} | Top-3: {topk_stats['top3']:.3f} | Top-5: {topk_stats['top5']:.3f} | Cases: {topk_stats['cases']}")
    print(f"Greedy win_rate: {greedy_stats['win_rate']:.3f} | avg_wrong: {greedy_stats['avg_wrong']:.2f} | avg_repeated: {greedy_stats['avg_repeated']:.2f} | games: {greedy_stats['games']}")
    print(f"Proxy Final Score: {final_like_score(greedy_stats):.1f}")

# 1) Collect whatever oracles you have
oracles = []
oracles.append(("CFP (Baseline)",        cfp))
if 'count_hmm'    in globals(): oracles.append(("Count-HMM (Basic)",       count_hmm))
if 'improved_hmm' in globals(): oracles.append(("Improved-HMM",            improved_hmm))
if 'hyb_basic'    in globals(): oracles.append(("Hybrid (Basic HMM + CFP)", hyb_basic))
if 'hyb_improved' in globals(): oracles.append(("Hybrid (Improved HMM + CFP)", hyb_improved))

# 2) Ensure stats objects exist; compute if missing
def ensure_stats(name, oracle, topk_varname, greedy_varname):
    topk = globals().get(topk_varname)
    greedy = globals().get(greedy_varname)
    if topk is None:
        topk = evaluate_oracle_topk(oracle, test_words, revelation_levels=(0.2,0.4,0.6), tests_per_word=2)
        globals()[topk_varname] = topk
    if greedy is None:
        greedy = simulate_greedy(test_words, oracle, games=min(500, len(test_words)))
        globals()[greedy_varname] = greedy
    show(name, topk, greedy)
    return (name, greedy)

name_map_to_vars = {
    "CFP (Baseline)":                    ("cfp_stats", "greedy_cfp"),
    "Count-HMM (Basic)":                 ("count_stats", "count_greedy"),
    "Improved-HMM":                      ("improved_stats", "improved_greedy"),
    "Hybrid (Basic HMM + CFP)":          ("hyb_basic_stats", "hyb_basic_greedy"),
    "Hybrid (Improved HMM + CFP)":       ("hyb_improved_stats", "hyb_improved_greedy"),
}

candidates = []
for (disp_name, oracle) in oracles:
    tv, gv = name_map_to_vars[disp_name]
    candidates.append(ensure_stats(disp_name, oracle, tv, gv))

# 3) Pick winner by proxy score
winner = max(candidates, key=lambda x: final_like_score(x[1]))
print(f"\n{'='*60}")
print(f">>> WINNER (by proxy final score): {winner[0]}")
print(f">>> Final Score: {final_like_score(winner[1]):.1f}")
print(f"{'='*60}")



=== CFP (Baseline) ===
Top-1: 0.432 | Top-3: 0.669 | Top-5: 0.784 | Cases: 11921
Greedy win_rate: 0.024 | avg_wrong: 5.94 | avg_repeated: 0.00 | games: 500
Proxy Final Score: -14807.0

=== Count-HMM (Basic) ===
Top-1: 0.539 | Top-3: 0.850 | Top-5: 0.935 | Cases: 11921
Greedy win_rate: 0.348 | avg_wrong: 5.15 | avg_repeated: 0.00 | games: 500
Proxy Final Score: -12179.0

=== Improved-HMM ===
Top-1: 0.530 | Top-3: 0.850 | Top-5: 0.933 | Cases: 11921
Greedy win_rate: 0.354 | avg_wrong: 5.14 | avg_repeated: 0.00 | games: 500
Proxy Final Score: -12132.0

=== Hybrid (Basic HMM + CFP) ===
Top-1: 0.567 | Top-3: 0.854 | Top-5: 0.935 | Cases: 11921
Greedy win_rate: 0.356 | avg_wrong: 5.07 | avg_repeated: 0.00 | games: 500
Proxy Final Score: -11968.0

=== Hybrid (Improved HMM + CFP) ===
Top-1: 0.566 | Top-3: 0.853 | Top-5: 0.934 | Cases: 11921
Greedy win_rate: 0.352 | avg_wrong: 5.10 | avg_repeated: 0.00 | games: 500
Proxy Final Score: -12041.0

>>> WINNER (by proxy final score): Hybrid (Basic H