# Solving Wordle with Actor–Critic Methods

# Background

Wordle is a word-guessing game in which the player has six attempts to identify a hidden five-letter target word. After each guess, every letter receives feedback: green if it is in the correct position, yellow if it occurs in the target at a different position, and gray if it does not occur in the target (or if all of its occurrences in the target are already accounted for by other green or yellow tiles). This feedback defines the game’s evolving state.

We formulate Wordle as a reinforcement learning (RL) problem. Each state corresponds to the history of guesses and feedback, while actions are legal word choices from the vocabulary. The agent’s objective is to select actions that reduce the set of candidate target words, ultimately converging on the correct solution. The environment, optimization function, and hyperparameters are fixed; the agent must learn a strategy through exploration.

Although Wordle can be solved deterministically with hand-coded search, for example an adaptation of Knuth’s minimax strategy for Mastermind, such methods are brittle and scale poorly beyond small, well-structured state spaces. An RL approach, by contrast, offers the potential for broader generalization. While it may not outperform a Knuth-style search at Wordle’s current scale, it provides a framework that can be adapted to larger vocabularies or related problems where exhaustive search becomes infeasible.

---
# Using Reinforcement Learning on Wordle


## Why Actor–Critic Methods Were Used

This notebook solves Wordle using an actor–critic architecture combined with Generalized Advantage Estimation (GAE) and Proximal Policy Optimization (PPO). Actor–critic methods train two networks:

- Actor: parameterizes the policy $\pi_\theta(a \mid s)$, mapping states to action probabilities.

- Critic: estimates the value function $V_\phi(s)$, providing a learned baseline to reduce variance in policy gradient updates.

Compared to REINFORCE, which relies solely on Monte Carlo returns, actor–critic methods substantially reduce gradient variance by incorporating a learned value-function baseline. This stabilizes learning, especially in environments like Wordle, where the game's native reward (win or loss) is sparse and arrives only at the end of an episode.

To further improve stability, we adopt:

- GAE: computes exponentially weighted, lower-variance advantage estimates $A_t^{(\lambda)}$, with $\lambda$ balancing bias and variance.

- PPO: uses a clipped surrogate objective, here combined with adaptive KL control, to prevent destructive policy updates while maintaining sample efficiency.

Together, these choices address the high variance of pure REINFORCE and yield a more stable and scalable training procedure.

---
# Mathematical Notes for the Wordle RL Agent (PPO + GAE, Masked Softmax)

This note documents the mathematics behind the Wordle reinforcement learning (RL) agent, which trains with PPO (clipped policy gradients) and GAE and uses hard action masking to enforce consistency with Wordle feedback.

## Environment
### Wordle as an MDP
We model Wordle as a finite-horizon episodic MDP with horizon $T=6$ turns.

- **Action space** $\mathcal{A}$: the legal 5-letter vocabulary (the **allowed** list). Each action $a \in \mathcal{A}$ is a guess (a word).
- **Target set** $\mathcal{S}$ (the **solutions** list): the subset of canonical Wordle answers. At the start of an episode a target word $w^\star \in \mathcal{S}$ is sampled.
- **State** $s_t$: the guess/feedback history before turn $t$: $$s_t = \big((a_1, y_1), (a_2,y_2), \dots, (a_{t-1}, y_{t-1})\big).$$
- **Transition**: deterministic given $(s_t, a_t)$; the environment returns the Wordle feedback $$y_t = f(a_t, w^\star) \in \{0,1,2\}^5,$$ where $0$=gray, $1$=yellow, $2$=green, and $f$ is the standard Wordle scoring function.
- **Termination**: at the first $t$ such that $a_t = w^\star$ (win), or at $t = T$ otherwise (loss).

## Posterior Candidate Set and Hard Masking
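Let $\mathcal{C}_1 = \mathcal{S}$ denote the initial candidate set of possible targets (the code instead initializes the candidate set to its full vocabulary, so there $\mathcal{C}_t$ ranges over $\mathcal{A}$). After observing $(a_t, y_t)$ we deterministically update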
$$
\mathcal{C}_{t+1} \;=\; \{\, w \in \mathcal{C}_t \;:\; f(a_t, w) = y_t \,\}.
$$
Thus the **sufficient statistic** for optimal play is $\mathcal{C}_t$. The code uses **hard action masking**. We define a mask
$$
m_t(a) \;=\;
\begin{cases}
1, & \text{if } f(a_\tau, a) = y_\tau \text{ for all } \tau < t \text{ (i.e., } a \text{ is posterior-consistent)},\\[4pt]
0, & \text{otherwise.}
\end{cases}
$$
The policy's probability mass over invalid actions is forced to zero by masked softmax (see Policy, Value, and Masked Softmax). For diagnostics, the code logs the **pre-mask invalid mass**, i.e. the policy mass that would have been assigned to invalid actions before applying the mask.
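The posterior update and the hard mask can be sketched in a few lines of Python. This is a minimal illustration on a hypothetical five-word vocabulary; `feedback`, `update_candidates`, and `hard_mask` are illustrative names standing in for the notebook's `wordle_feedback`, `posterior_candidates`, and `WordleEnv.legal_action_mask`.

```python
from collections import Counter
from typing import List, Tuple

Feedback = Tuple[int, ...]

def feedback(guess: str, target: str) -> Feedback:
    """Wordle scoring f(a, w): 0=gray, 1=yellow, 2=green (duplicate-aware)."""
    res = [0] * 5
    # First pass: mark greens; count target letters that were not matched in place.
    unmatched = Counter(t for g, t in zip(guess, target) if g != t)
    for i, (g, t) in enumerate(zip(guess, target)):
        if g == t:
            res[i] = 2
    # Second pass: yellows consume the remaining letter counts, left to right.
    for i, g in enumerate(guess):
        if res[i] == 0 and unmatched[g] > 0:
            res[i] = 1
            unmatched[g] -= 1
    return tuple(res)

def update_candidates(cands: List[str], guess: str, fb: Feedback) -> List[str]:
    """C_{t+1} = {w in C_t : f(a_t, w) = y_t}."""
    return [w for w in cands if feedback(guess, w) == fb]

def hard_mask(vocab: List[str], cands: List[str]) -> List[int]:
    """m_t(a) = 1 iff action a is still posterior-consistent."""
    alive = set(cands)
    return [1 if a in alive else 0 for a in vocab]

vocab = ["crane", "crate", "trace", "slate", "grace"]
target = "grace"
y1 = feedback("crate", target)
C2 = update_candidates(vocab, "crate", y1)
print(y1, C2, hard_mask(vocab, C2))
```

Because each update only filters the previous set, $|\mathcal{C}_{t+1}| \le |\mathcal{C}_t|$, so the mask can only lose ones over an episode.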

## Policy, Value, and Masked Softmax
### Masked Softmax
Let the actor network produce unnormalized logits $z_\theta(s_t)\in\mathbb{R}^{|\mathcal{A}|}$. We define masked logits as:
$$
\tilde z_\theta(a \mid s_t) \;=\;
\begin{cases}
z_\theta(a \mid s_t), & m_t(a)=1,\\
-\infty, & m_t(a)=0.
\end{cases}
$$
Then the policy is
$$
\pi_\theta(a\mid s_t) \;=\; \frac{\exp\big(\tilde z_\theta(a \mid s_t)/\tau\big)}{\sum_{a'\in\mathcal{A}}\exp\big(\tilde z_\theta(a' \mid s_t)/\tau\big)},
$$
with temperature $\tau>0$ (often $\tau=1$ during training and lower at evaluation for greedier behavior). By construction, $\pi_\theta(a\mid s_t)=0$ when $m_t(a)=0$.
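A minimal pure-Python sketch of the masked, temperature-scaled softmax together with the pre-mask invalid-mass diagnostic. It uses a true $-\infty$ logit, whereas the notebook's `PolicyNet.masked_probs_from_logits` uses $-10^9$ on tensors; `masked_softmax` is an illustrative name.

```python
import math
from typing import List, Tuple

def masked_softmax(logits: List[float], mask: List[int], tau: float = 1.0) -> Tuple[List[float], float]:
    """Return (pi, invalid_mass_pre): masked softmax at temperature tau, plus a diagnostic."""
    if not any(mask):
        raise ValueError("at least one action must be valid")
    z = [v / tau for v in logits]
    # Unmasked softmax, used only for the pre-mask invalid-mass diagnostic.
    zmax = max(z)
    pre = [math.exp(v - zmax) for v in z]
    invalid_mass_pre = sum(p for p, m in zip(pre, mask) if m == 0) / sum(pre)
    # Masked logits: -inf where m(a) = 0, so exp(.) contributes exactly 0.
    masked = [v if m == 1 else -math.inf for v, m in zip(z, mask)]
    mmax = max(masked)  # finite because at least one action is valid
    ex = [math.exp(v - mmax) for v in masked]
    s = sum(ex)
    return [e / s for e in ex], invalid_mass_pre

logits = [2.0, 1.0, 0.0, -1.0]
mask = [1, 0, 1, 0]
pi, inv = masked_softmax(logits, mask)
# Same as zeroing invalid actions in the plain softmax and renormalizing.
print([round(p, 4) for p in pi], round(inv, 4))
```

Lowering `tau` sharpens the distribution over the valid actions without changing which actions are valid.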

### Value function
A critic $V_\phi(s_t)$ estimates the state value under the current policy:
$$
V_\phi(s_t) \;\approx\; \mathbb{E}_{\pi_\theta}\big[\,R_t \,\big|\, s_t\big],
$$
where $R_t$ is the return defined below.

## Rewards and Returns
### Shaped reward
The code employs dense shaping to stabilize learning in a sparse terminal-reward problem. Let $|\mathcal{C}_t|$ be the size of the candidate set after applying all $(a_\tau,y_\tau)$ with $\tau \le t-1$. An information-gain shaped reward is
$$
r_t^{(\text{IG})} \;=\; \log |\mathcal{C}_t| \;-\; \log |\mathcal{C}_{t+1}|.
$$
This measures the reduction in posterior entropy under a uniform prior on $\mathcal{C}_t$. Additional components (e.g., small per-green/yellow bonuses, or an on-win bonus at termination) can be linearly combined:
$$
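r_t \;=\; \alpha \, r_t^{(\text{IG})} \;+\; \beta_{\text{term}}\cdot \mathbf{1}\{a_t=w^\star\}\cdot \mathbf{1}\{t\le T\} \;+\; \sum_k \kappa_k \, r_t^{(k)}.
$$
In practice $\alpha>0$ and $\beta_{\text{term}}\ge 0$; the other terms $r_t^{(k)}$, with weights $\kappa_k$, are optional. In the code, the information-gain term is measured in bits ($\log_2$), weighted by `lambda_entropy`, and applied only from turn `entropy_start_turn` onward. The win bonus is time-shaped rather than constant, $\beta_{\text{win}}\big((T-t+1)/T\big)^{k}$ with $\beta_{\text{win}}=$ `win_bonus` and $k=$ `win_shape_k`, and a failure penalty is subtracted at timeout.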
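A minimal sketch of the two main reward components as `WordleEnv.step` shapes them, assuming the `RewardCoeffs` defaults (`lambda_entropy = 0.5`, `entropy_start_turn = 3`, `win_bonus = 100`, `win_shape_k = 2.5`) and six turns. The function names are illustrative, and the diversity, repeat, green, yellow, step-cost, and failure terms are omitted.

```python
import math

def info_gain_bits(n_before: int, n_after: int) -> float:
    """r^(IG) in bits: log2|C_t| - log2|C_{t+1}|; an empty set counts as entropy 0, as in WordleEnv.step."""
    h_after = math.log2(n_after) if n_after > 0 else 0.0
    return math.log2(n_before) - h_after

def win_bonus(turn: int, max_turns: int = 6, base: float = 100.0, k: float = 2.5) -> float:
    """Time-shaped win bonus base * ((T - t + 1) / T) ** k, with 1-indexed turn t."""
    frac = (max_turns - turn + 1) / max_turns
    return base * frac ** k

def shaped_reward(turn: int, n_before: int, n_after: int, won: bool,
                  alpha: float = 0.5, start_turn: int = 3) -> float:
    """alpha * r^(IG) from start_turn onward, plus the win bonus on a winning guess."""
    r = alpha * info_gain_bits(n_before, n_after) if turn >= start_turn else 0.0
    if won:
        r += win_bonus(turn)
    return r

# A turn-1 win earns the full bonus; a turn-6 win earns about 1% of it.
print(win_bonus(1), round(win_bonus(6), 2))
```

Using $\log_2$ instead of $\ln$ only rescales $\alpha$ by a constant factor.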

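### Episodic Return
Episodes last at most $T=6$ turns. With discount factor $\gamma \in (0,1]$ the return is
$$
R_t \;=\; \sum_{u=t}^{t_{\text{end}}} \gamma^{\,u-t}\, r_u,
$$
where $t_{\text{end}}\le T$ is the termination time. These notes take $\gamma=1$ (undiscounted); the `PPOConfig` default in the code is `gamma = 0.99`, which discounts a reward by at most $0.99^{5}\approx 0.95$ over a six-turn episode.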

## Advantage Estimation (GAE)
Define the TD residuals
$$
\delta_t \;=\; r_t \;+\; \gamma \, V_\phi(s_{t+1}) \;-\; V_\phi(s_t),
$$
with $\gamma$ the discount factor ($\gamma=1$ in these notes; `PPOConfig` defaults to `gamma = 0.99`). Generalized Advantage Estimation (GAE) with parameter $\lambda \in [0,1]$ is
$$
A_t^{(\lambda)} \;=\; \sum_{l=0}^{\infty} (\gamma\lambda)^l \, \delta_{t+l}.
$$
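In the finite-horizon case the sum truncates at episode end, with $V_\phi(s_{t_{\text{end}}+1}) := 0$ at the terminal state. Advantages are normalized per batch in code, where $\mu_A$ and $\sigma_A$ are the batch mean and standard deviation of $A_t^{(\lambda)}$ and $\varepsilon$ is a small numerical-stability constant (unrelated to the PPO clip parameter):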
$$
\hat A_t \;=\; \frac{A_t^{(\lambda)} - \mu_A}{\sigma_A + \varepsilon}.
$$
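The recursion $A_t = \delta_t + \gamma\lambda A_{t+1}$ evaluates the truncated sum in one backward pass. A minimal sketch for a single finished episode using plain lists; `gae` and `normalize` are illustrative names, and returning $\lambda$-return targets $A_t + V(s_t)$ for the critic is one common choice, not a confirmed detail of the notebook.

```python
import statistics
from typing import List, Tuple

def gae(rewards: List[float], values: List[float],
        gamma: float = 1.0, lam: float = 0.95) -> Tuple[List[float], List[float]]:
    """GAE for one finished episode.

    values[t] is V(s_t) for each step; the value after the terminal step is taken as 0.
    Returns (advantages, lambda-return targets = advantages + values).
    """
    T = len(rewards)
    adv = [0.0] * T
    running = 0.0
    for t in reversed(range(T)):
        next_v = values[t + 1] if t + 1 < T else 0.0     # terminal bootstrap V = 0
        delta = rewards[t] + gamma * next_v - values[t]   # TD residual delta_t
        running = delta + gamma * lam * running           # A_t = delta_t + gamma*lam*A_{t+1}
        adv[t] = running
    returns = [a + v for a, v in zip(adv, values)]
    return adv, returns

def normalize(adv: List[float], eps: float = 1e-8) -> List[float]:
    """Per-batch normalization (A - mean) / (std + eps)."""
    mu = statistics.fmean(adv)
    sigma = statistics.pstdev(adv)
    return [(a - mu) / (sigma + eps) for a in adv]

adv, ret = gae([1.0, 2.0, 3.0], [0.5, 0.5, 0.5], gamma=1.0, lam=1.0)
print(adv, ret)
```

With $\gamma=\lambda=1$ the advantages reduce to Monte Carlo returns minus the baseline; with $\lambda=0$ they reduce to the one-step TD residuals $\delta_t$.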

## PPO Objective with Clipping
Let the importance ratio be
$$
\rho_t(\theta) \;=\; \frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_{\text{old}}}(a_t\mid s_t)},
$$
written $\rho_t$ to distinguish it from the reward $r_t$. The clipped surrogate objective uses a symmetric clip parameter $\varepsilon>0$ (`clip_eps`, default $0.12$):
$$
L^{\text{CLIP}}(\theta) \;=\; \mathbb{E}_t \Big[ \min\Big( \rho_t(\theta)\,\hat A_t,\; \mathrm{clip}\big(\rho_t(\theta),\, 1{-}\varepsilon,\, 1{+}\varepsilon\big)\,\hat A_t \Big) \Big].
$$
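A per-sample sketch of the importance ratio and the clipped objective, with `eps` defaulting to `clip_eps = 0.12` from `PPOConfig`; the function names are illustrative. Computing the ratio from log-probabilities avoids dividing small probabilities.

```python
import math

def importance_ratio(logp_new: float, logp_old: float) -> float:
    """pi_theta(a|s) / pi_theta_old(a|s), computed from log-probabilities."""
    return math.exp(logp_new - logp_old)

def clipped_surrogate(ratio: float, adv: float, eps: float = 0.12) -> float:
    """Per-sample PPO objective min(ratio*A, clip(ratio, 1-eps, 1+eps)*A), to be maximized."""
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
    return min(ratio * adv, clipped * adv)

ratio = importance_ratio(math.log(0.3), math.log(0.2))  # about 1.5
print(clipped_surrogate(ratio, adv=1.0), clipped_surrogate(ratio, adv=-1.0))
```

The `min` makes the bound pessimistic: once the ratio moves past $1\pm\varepsilon$ in the direction the advantage favors, the objective stops improving, while moves that make it worse are never clipped.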
The critic is trained by a (possibly clipped) squared-error objective
$$
L^{\text{V}}(\phi) \;=\; \mathbb{E}_t \Big[ \big( V_\phi(s_t) - \hat R_t \big)^2 \Big],
$$
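where $\hat R_t$ is the Monte Carlo return $R_t$ or a bootstrapped target consistent with the GAE construction (a common choice is the $\lambda$-return $\hat R_t = A_t^{(\lambda)} + V_{\phi_{\text{old}}}(s_t)$, built from unnormalized advantages). Value clipping is controlled by `clip_value_loss` and `value_clip_eps` in `PPOConfig`.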
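`PPOConfig` exposes `clip_value_loss` and `value_clip_eps = 0.2`, but the loss code is not part of this section. Below is a minimal sketch of one common form of value clipping, the one used in OpenAI Baselines' PPO2, offered as an assumption about what those knobs control; `value_loss` is an illustrative name.

```python
def value_loss(v_new: float, v_old: float, target: float,
               clip_eps: float = 0.2, clip: bool = True) -> float:
    """Per-sample critic loss; the clipped variant takes the max of clipped/unclipped errors."""
    unclipped = (v_new - target) ** 2
    if not clip:
        return unclipped
    # Limit how far the new prediction may move from the old one.
    v_clipped = v_old + min(max(v_new - v_old, -clip_eps), clip_eps)
    return max(unclipped, (v_clipped - target) ** 2)

print(value_loss(1.0, 0.0, 1.0), value_loss(1.0, 0.0, 1.0, clip=False))
```

Taking the `max` means the critic gets no credit for jumping more than `clip_eps` away from its previous estimate in a single update.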

## Entropy Regularization
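To discourage premature collapse, the objective rewards policy entropy $H(\pi_\theta(\cdot\mid s_t))$ with a coefficient $\beta_{\text{ent}}$ that is linearly annealed during training from `entropy_coef_start` to a small nonzero floor `entropy_coef_end` (defaults $0.05 \to 0.005$):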
$$
H(\pi_\theta(\cdot\mid s)) \;=\; - \sum_{a\in\mathcal{A}} \pi_\theta(a\mid s)\,\log \pi_\theta(a\mid s).
$$
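Entropy and its linear schedule can be sketched as follows. The names are illustrative; `progress` is assumed to be the fraction of training completed (for example, episodes seen divided by `total_episodes`), and the defaults in the usage line mirror `entropy_coef_start` and `entropy_coef_end`. Masked actions have $\pi_\theta(a\mid s)=0$ and contribute nothing, using the convention $0\log 0 = 0$.

```python
import math
from typing import List

def entropy(probs: List[float]) -> float:
    """H(pi) = -sum pi log pi (nats), skipping zero-probability (masked) actions."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def linear_anneal(start: float, end: float, progress: float) -> float:
    """Linear schedule from start to end; progress is clamped to [0, 1]."""
    progress = min(max(progress, 0.0), 1.0)
    return start + (end - start) * progress

beta_ent = linear_anneal(0.05, 0.005, progress=0.5)
print(beta_ent, entropy([0.5, 0.5, 0.0]))
```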

## Adaptive KL Control and Actor ``brake``
The implementation also monitors minibatch KL divergence
$$
\mathrm{KL}\!\left(\pi_{\theta_{\text{old}}}(\cdot\mid s_t)\;\|\; \pi_\theta(\cdot\mid s_t)\right)
\;=\; \sum_{a} \pi_{\theta_{\text{old}}}(a\mid s_t)\,
\log \frac{\pi_{\theta_{\text{old}}}(a\mid s_t)}{\pi_\theta(a\mid s_t)}.
$$
When the observed KL exceeds a threshold (by default `trpo_brake_mult` $=2.5$ times `kl_target`), the update *skips the actor step* on that minibatch (the "brake"); the controller may also increase the adaptive penalty $\beta_{\text{KL}}$ or reduce the actor learning rate. This enforces an approximate trust region in addition to clipping.
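The KL check and an adaptive controller can be sketched as below. The thresholds reuse `PPOConfig` field names and defaults (`kl_target`, `kl_high`, `kl_low`, `trpo_brake_mult`, `kl_coef_max`, `actor_lr_min`, `actor_lr_max`), but the notebook's actual update rule is not shown in this section: the multiplicative factor 1.5 and the seed value $10^{-3}$ for $\beta_{\text{KL}}$ are assumptions, and `KLController` is a hypothetical class.

```python
import math
from dataclasses import dataclass
from typing import List

def kl_categorical(p_old: List[float], p_new: List[float], eps: float = 1e-12) -> float:
    """KL(p_old || p_new) over the same action set; terms with p_old = 0 contribute 0."""
    return sum(po * math.log(po / max(pn, eps)) for po, pn in zip(p_old, p_new) if po > 0)

@dataclass
class KLController:
    # Hypothetical controller; defaults mirror PPOConfig, update factors are assumptions.
    kl_target: float = 0.008
    kl_high: float = 1.25
    kl_low: float = 0.95
    brake_mult: float = 2.5
    beta: float = 0.0
    beta_max: float = 0.3
    actor_lr: float = 2e-4
    lr_min: float = 5e-6
    lr_max: float = 8e-4

    def should_brake(self, kl: float) -> bool:
        """Skip the actor step when minibatch KL exceeds brake_mult * kl_target."""
        return kl > self.brake_mult * self.kl_target

    def update(self, kl: float) -> None:
        """Tighten (raise beta, lower LR) on high KL; relax on low KL."""
        if kl > self.kl_high * self.kl_target:
            self.beta = min(self.beta_max, max(self.beta * 1.5, 1e-3))  # seed beta from 0
            self.actor_lr = max(self.lr_min, self.actor_lr / 1.5)
        elif kl < self.kl_low * self.kl_target:
            self.beta = self.beta / 1.5
            self.actor_lr = min(self.lr_max, self.actor_lr * 1.5)

ctl = KLController()
kl = kl_categorical([0.6, 0.4], [0.5, 0.5])
print(kl, ctl.should_brake(kl))
```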

## Total Loss (Minimization Form)
The code minimizes the following (signs chosen for **minimization**):
$$
\mathcal{L}(\theta,\phi)
\;=\;
-\, L^{\text{CLIP}}(\theta)
\;-\;
\beta_{\text{ent}}\cdot \mathbb{E}_t\!\left[\, H\big(\pi_\theta(\cdot\mid s_t)\big) \right]
\;+\;
c_v \cdot \mathbb{E}_t\!\left[\big(V_\phi(s_t)-\hat{R}_t\big)^2\right]
\;+\;
\beta_{\text{KL}} \cdot \mathbb{E}_t\!\left[\mathrm{KL}\!\left(\pi_{\theta_{\text{old}}} \,\|\, \pi_\theta\right)\right].
$$
Here $c_v>0$ is the value-loss weight, $\beta_{\text{ent}}\ge 0$ is annealed toward a small nonzero floor, and $\beta_{\text{KL}}\ge 0$ is adjusted by the controller. $\beta_{\text{KL}}$ starts at `kl_coef_start` (default $0$) and is capped at `kl_coef_max` (default $0.3$), so it can be zero except when KL control activates.

## Batching, Epochs, and Optimization
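- Episodes are collected in batches (`episodes_per_batch = 640` by default). Let a batch index set be $\mathcal{B}$; PPO makes `ppo_epochs` $=3$ passes over it per update, each split into shuffled minibatches of size `minibatch_size` $=512$.
- The actor parameters $\theta$ and critic parameters $\phi$ use *separate* optimizers with their own learning rates $\eta_{\text{actor}}$ and $\eta_{\text{critic}}$ (defaults `actor_lr = 2e-4`, `critic_lr = 1e-4`).
- Gradients are clipped to a maximum norm, $\|\nabla\|_2 \le g_{\max}$ (`max_grad_norm = 0.5`).
- $L_2$ weight decay (`weight_decay = 5e-5`) may be applied to stabilize the critic.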
- Advantages $\hat A_t$ are normalized per batch; value targets $\hat R_t$ use the same rollouts.
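The batching loop can be sketched with the standard library alone. The helper names are hypothetical, the sketch assumes collected transitions are flattened into one index range, and the norm clipping mirrors the rule PyTorch's `torch.nn.utils.clip_grad_norm_` applies (scale by `max_norm / (norm + 1e-6)` when that factor is below 1), here on a plain list of gradient components.

```python
import math
import random
from typing import Iterator, List

def minibatches(n: int, batch_size: int, epochs: int, seed: int = 0) -> Iterator[List[int]]:
    """Yield shuffled minibatches of indices into n collected transitions, for several epochs."""
    rng = random.Random(seed)
    idx = list(range(n))
    for _ in range(epochs):
        rng.shuffle(idx)  # fresh order every epoch
        for start in range(0, n, batch_size):
            yield idx[start:start + batch_size]

def clip_by_global_norm(grads: List[float], max_norm: float = 0.5) -> List[float]:
    """Rescale gradients so their L2 norm is at most (approximately) max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grads]

# e.g. 3 epochs over 1280 transitions in minibatches of 512
n_batches = sum(1 for _ in minibatches(1280, 512, 3))
print(n_batches)
```

In a real loop, each minibatch would compute the loss, backpropagate, clip, and then step the actor and critic optimizers separately.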


## Train/Val/Test Protocol
The split is applied over **targets** (answers), not the action vocabulary:
$$
\text{Train:Val:Test} \in \{(0.8,0.1,0.1),\ (0.7,0.15,0.15)\}.
$$
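A minimal sketch of a target-level split with a fixed seed for reproducibility; `split_targets` is a hypothetical helper. Only the answer words are partitioned: every split still plays over the same action vocabulary, so validation and test targets are unseen answers rather than unseen actions.

```python
import random
from typing import List, Sequence, Tuple

def split_targets(targets: List[str], fracs: Sequence[float] = (0.8, 0.1, 0.1),
                  seed: int = 0) -> Tuple[List[str], List[str], List[str]]:
    """Disjoint train/val/test split over answer words (not over the action vocabulary)."""
    assert abs(sum(fracs) - 1.0) < 1e-9
    words = sorted(set(targets))
    random.Random(seed).shuffle(words)
    n_train = int(round(fracs[0] * len(words)))
    n_val = int(round(fracs[1] * len(words)))
    return words[:n_train], words[n_train:n_train + n_val], words[n_train + n_val:]

toy = [f"w{i:04d}" for i in range(100)]  # stand-in for the solutions list
train, val, test = split_targets(toy)
print(len(train), len(val), len(test))
```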
At evaluation time, the policy may be run greedily ($\arg\max$) or with low temperature to estimate:
- Success rate (SR): fraction of targets solved within $T=6$ turns.
- Average turns: mean number of guesses conditioned on success (or overall).
- Diagnostics: minibatch KL, entropy $H$, pre-mask invalid mass, critic loss.

Confidence intervals (e.g., 95%) for SR are reported with a normal approximation or exact binomial intervals.
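A normal-approximation (Wald) interval needs only the standard library; `success_rate_ci` is an illustrative helper. Exact (Clopper–Pearson) intervals can instead be computed from beta-distribution quantiles, for example with `scipy.stats.beta.ppf`.

```python
import math
from statistics import NormalDist
from typing import Tuple

def success_rate_ci(successes: int, n: int, level: float = 0.95) -> Tuple[float, float, float]:
    """Return (SR, lower, upper) using the normal (Wald) approximation, clamped to [0, 1]."""
    p = successes / n
    z = NormalDist().inv_cdf(0.5 + level / 2)  # about 1.96 for a 95% interval
    half = z * math.sqrt(p * (1 - p) / n)
    return p, max(0.0, p - half), min(1.0, p + half)

print(success_rate_ci(90, 100))
```

The Wald interval collapses to zero width when every evaluation target is solved (SR = 1), a regime a strong agent can approach, so the exact interval is the safer report there.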

## Why Masking Matters
Without masking, the policy spends probability mass on guesses that contradict the known feedback and therefore cannot be the target, which increases variance and harms sample efficiency. Masked softmax enforces logical consistency with the posterior and empirically accelerates learning. Setting the masked logits to $-\infty$ (the code uses the finite stand-in $-10^9$) is equivalent to zeroing those probabilities and renormalizing over the valid actions.

## Relation to Optimal Play and Knuth-Style Search
An exhaustive (non-learning) solver can optimize the expected remaining candidate-set size, or the worst case, via a combinatorial search akin to Knuth's algorithm for Mastermind. The RL agent instead **learns** a policy that implicitly trades off information gain against win probability. The information-gain reward approximates this heuristic: maximizing $\mathbb{E}[\log |\mathcal{C}_t| - \log |\mathcal{C}_{t+1}|]$ encourages guesses that split the posterior well, while PPO keeps updates stable.
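Under a uniform prior, the expected information gain of one guess equals the entropy of the feedback patterns it induces over $\mathcal{C}_t$: a pattern shared by $k$ of $n$ candidates occurs with probability $k/n$ and leaves $k$ survivors. A minimal one-step sketch follows; the helper names are illustrative, and `worst_case_bucket` is the Knuth-style minimax criterion (minimize the largest surviving set), shown for contrast.

```python
import math
from collections import Counter
from typing import List, Tuple

def feedback(guess: str, target: str) -> Tuple[int, ...]:
    """Wordle scoring: 0=gray, 1=yellow, 2=green (duplicate-aware)."""
    res = [2 if g == t else 0 for g, t in zip(guess, target)]
    left = Counter(t for g, t in zip(guess, target) if g != t)
    for i, g in enumerate(guess):
        if res[i] == 0 and left[g] > 0:
            res[i] = 1
            left[g] -= 1
    return tuple(res)

def expected_info_gain(guess: str, candidates: List[str]) -> float:
    """E[log2|C_t| - log2|C_{t+1}|] under a uniform prior over candidates, in bits."""
    n = len(candidates)
    buckets = Counter(feedback(guess, w) for w in candidates)
    return sum((k / n) * (math.log2(n) - math.log2(k)) for k in buckets.values())

def worst_case_bucket(guess: str, candidates: List[str]) -> int:
    """Largest candidate set that could survive this guess (Knuth-style minimax score)."""
    return max(Counter(feedback(guess, w) for w in candidates).values())

cands = ["crane", "crate", "trace", "slate", "grace"]
best = max(cands, key=lambda g: expected_info_gain(g, cands))
print(best, expected_info_gain(best, cands), worst_case_bucket(best, cands))
```

Always picking the maximizer is a one-step greedy heuristic; Knuth's method minimizes `worst_case_bucket` instead, and exactly optimal play requires a full game-tree search.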

## Summary
The implementation is a standard PPO+GAE agent specialized to Wordle via hard action masking and information-gain shaping. The total loss minimized in code is exactly
$$
\mathcal{L}(\theta,\phi)
\;=\;
- L^{\text{CLIP}}(\theta)
\;-\;
\beta_{\text{ent}}\cdot \mathbb{E}_t\!\left[ H(\pi_\theta(\cdot\mid s_t)) \right]
\;+\;
c_v \cdot \mathbb{E}_t\!\left[\big(V_\phi(s_t)-\hat{R}_t\big)^2\right]
\;+\;
\beta_{\text{KL}} \cdot \mathbb{E}_t\!\left[\mathrm{KL}\!\left(\pi_{\theta_{\text{old}}} \,\|\, \pi_\theta\right)\right].
$$
This matches the structure and logging present in the notebook (entropy anneal, KL target/skip, separate actor/critic LRs, masked softmax, and posterior-consistency filter).


---
# Code

# Wordle RL — PPO-ActorCritic + GAE (unseen-vocab capable)
# Safer, stabler training: tighter trust region, bigger batches, entropy floor,
# optional KL penalty, reward clipping, annealing, and richer logging.
# ============================================

import os, math, random, statistics
from dataclasses import dataclass
from typing import List, Tuple, Optional
import requests
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import Counter as _C

# Function: Prefer fast matmul kernels on supported GPUs (e.g., TF32 on Ampere).
# Input: none (uses torch backend); guarded by try/except for portability.
# Output: none (side-effect is backend precision/perf setting; safe to skip if unsupported).
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# -----------------------------
# 0) Word lists: Allowed (tabatkins) + Solutions (cfreshman)
# -----------------------------
ALLOWED_URL   = "https://raw.githubusercontent.com/tabatkins/wordle-list/main/words"
SOLUTIONS_URL = "https://gist.githubusercontent.com/cfreshman/a03ef2cba789d8cf00c08f767e0fad7b/raw/wordle-answers-alphabetical.txt"
# Allowed list: ~14,855 legal guess words in Wordle (used as the policy's action space).
# Solutions list: ~2,315 canonical target words that the game actually selects.

# Function: Download a plaintext word list and keep only 5-letter alphabetic words.
# Input: url (str) — remote URL with one word per line.
# Output: list[str] — lowercase, deduplicated, sorted words.
def _download_list(url: str) -> list: 
    r = requests.get(url, timeout=12)
    r.raise_for_status()
    words = [w.strip().lower() for w in r.text.splitlines()
             if len(w.strip()) == 5 and w.strip().isalpha()]
    return sorted(set(words))  # dedupe + sort

# Function: Load and cache the Wordle vocab and solutions, with robust fallbacks.
# Input: cache_dir (str) — directory for local "solutions.txt" and "allowed.txt".
# Output: (vocab, solutions)
#   - vocab: list[str] — union(allowed, solutions), i.e., the full action space.
#   - solutions: list[str] — ground-truth target set for training/eval sampling.
def load_wordle_lists(cache_dir: str = ".wordlists"): 
    # Attempt to read from cache. If files exist, prefer them (offline-friendly).
    os.makedirs(cache_dir, exist_ok=True)
    sol_path     = os.path.join(cache_dir, "solutions.txt")
    allowed_path = os.path.join(cache_dir, "allowed.txt")

    solutions, allowed = [], []
    # Read cached solutions if present.
    if os.path.exists(sol_path):
        with open(sol_path, "r", encoding="utf-8") as f:
            solutions = sorted({w.strip().lower() for w in f if len(w.strip()) == 5 and w.strip().isalpha()})
    # Read cached allowed list if present.
    if os.path.exists(allowed_path):
        with open(allowed_path, "r", encoding="utf-8") as f:
            allowed = sorted({w.strip().lower() for w in f if len(w.strip()) == 5 and w.strip().isalpha()})
    # If missing in cache, download and then cache for future runs.
    if not allowed:
        try:
            allowed = _download_list(ALLOWED_URL)
            with open(allowed_path, "w", encoding="utf-8") as f:
                f.write("\n".join(allowed))
        except Exception as e:
            allowed = []
            print(f"[warn] Could not fetch allowed list: {e}")

    if not solutions:
        try:
            solutions = _download_list(SOLUTIONS_URL)
            with open(sol_path, "w", encoding="utf-8") as f:
                f.write("\n".join(solutions))
        except Exception as e:
            solutions = []
            print(f"[warn] Could not fetch solutions list: {e}")
    # Fallback mode (no net / URLs blocked): try wordfreq, then NLTK.
    # Keep only 5-letter alphabetic words; cap solutions to ~3k if synthesized.
    if not allowed or not solutions:
        print("[warn] Falling back to local corpora (wordfreq / nltk). Quality differs from NYT.")
        try:
            from wordfreq import top_n_list
            wf = [w for w in top_n_list('en', 20000) if len(w) == 5 and w.isalpha()]
            if not allowed:   allowed   = sorted(set(wf))
            if not solutions: solutions = sorted(set(wf[:3000]))
        except Exception as e2:
            print(f"[warn] wordfreq fallback failed: {e2}")
            try:
                import nltk; nltk.download('words', quiet=True)
                from nltk.corpus import words as nltk_words
                wl = sorted(set(w.lower() for w in nltk_words.words() if len(w)==5 and w.isalpha()))
                if not allowed:   allowed   = wl
                if not solutions: solutions = wl[:min(3000, max(1000, len(wl)//4))]
            except Exception as e3:
                # If all fallbacks fail, instruct how to install deps.
                print("[hint] Try: pip install wordfreq nltk")
                raise RuntimeError("No word list available. Install `wordfreq` or `nltk`, or enable internet.") from e3
    # Final assembly/validation.
    vocab = sorted(set(allowed) | set(solutions))
    if not vocab or not solutions:
        raise RuntimeError("Empty word lists after loading.")
    return vocab, solutions

# ---------------------------------
# 1) Wordle mechanics: feedback
# ---------------------------------

# Function: Implements official Wordle scoring (handles duplicates properly).
# Input: guess (str), target (str) — both 5-letter words.
# Output: Tuple[int,int,int,int,int] — per-position feedback: 0=gray, 1=yellow, 2=green.
# Notes: First pass marks greens and collects unmatched target letters; second pass assigns yellows from what's left.
def wordle_feedback(guess: str, target: str) -> Tuple[int, int, int, int, int]:
    guess, target = guess.lower(), target.lower()
    res = [0]*5
    unmatched = []
    for i,(g,t) in enumerate(zip(guess, target)):
        if g == t:
            res[i] = 2
        else:
            unmatched.append(t)
    avail = _C(unmatched)
    for i,g in enumerate(guess):
        if res[i] == 2: continue
        if avail[g] > 0:
            res[i] = 1; avail[g] -= 1
        else:
            res[i] = 0
    return tuple(res)
    
# Function: Test whether a candidate word remains logically consistent with all past feedback.
# Input: history: list[ (guess:str, feedback:Tuple[int,...]) ], candidate: str (5 letters).
# Output: bool — True iff wordle_feedback(guess, candidate) equals recorded feedback for all (guess,feedback) in history.
def consistent_with(history: List[Tuple[str, Tuple[int,...]]], candidate: str) -> bool:
    return all(wordle_feedback(g, candidate) == fb for (g, fb) in history)
    
# Function: Filter the vocabulary to words consistent with the entire guess/feedback history.
# Input: vocab: list[str] (action space), history: list[(guess, feedback)].
# Output: list[str] — posterior candidate set (surviving words).
def posterior_candidates(vocab: List[str], history: List[Tuple[str, Tuple[int,...]]]) -> List[str]:
    return [w for w in vocab if consistent_with(history, w)]

# -------------------------------
# 2) State encoder (φ)
# -------------------------------

class StateEncoder:
    # Function: Configure how many past steps to encode and compute state dimension.
    # Input: history_len (int) — number of recent (guess,feedback) pairs to include.
    # Output: none (sets self.state_dim = 145 * history_len).
    def __init__(self, history_len: int = 6):
        self.history_len = history_len
        self.dim_per_step = 145  # 5*26 (letters) + 5*3 (feedback)
        self.state_dim = self.dim_per_step * self.history_len

    # Function: Encode up to history_len recent steps into a fixed-length numeric state.
    # Input: history: list[(guess:str, feedback:Tuple[int,int,int,int,int])].
    # Output: np.ndarray (float32) with shape (state_dim,), where state_dim = history_len * 145.
    # Details:
    #   - Letters: per position, 26-dim one-hot over a–z (total 5*26=130).
    #   - Feedback: per position, 3-dim one-hot over {0,1,2} (total 5*3=15).
    #   - Left-pad with all-zero blocks if fewer than history_len steps exist.
    def encode(self, history: List[Tuple[str, Tuple[int,...]]]) -> np.ndarray:
        blocks = []
        cut = history[-self.history_len:]
        for guess, fb in cut:
            gvec = []
            for ch in guess:
                oh = [0]*26; idx = ord(ch) - 97
                if 0 <= idx < 26: oh[idx] = 1
                gvec.extend(oh)
            fvec = []
            for v in fb:
                fo = [0,0,0]; fo[v] = 1; fvec.extend(fo)
            blocks.append(gvec + fvec)
        while len(blocks) < self.history_len:
            blocks.insert(0, [0]*self.dim_per_step)
        return np.array([x for blk in blocks for x in blk], dtype=np.float32)

# ---------------------------------------
# 3) Environment + shaped reward
# ---------------------------------------

# @dataclass RewardCoeffs
# Function: Container for all reward-shaping coefficients used by the environment.
# Input: (constructor args; all optional with defaults)
#   - alpha_diversity: bonus for unique letters in a guess.
#   - beta_repeat: penalty for repeating a previously used guess.
#   - gamma_green_sq: bonus proportional to the square of how many green positions are preserved from the previous guess.
#   - delta_yellow: bonus for reusing previously yellow letters in different positions.
#   - lambda_entropy: multiplier for posterior-entropy reduction (from entropy_start_turn onward).
#   - per_step_cost: per-guess cost (penalty).
#   - win_bonus: base win bonus (time-shaped).
#   - win_shape_k: exponent for early-win shaping (larger → heavier early-win bias).
#   - fail_penalty_base: baseline failure penalty at timeout.
#   - fail_penalty_scale: scales failure penalty with fraction of turns used.
#   - terminal_bonus / terminal_decay: legacy (unused by current win path; retained for compatibility).
#   - entropy_start_turn: turn index at/after which entropy shaping activates.
# Output: RewardCoeffs instance (no behavior).
@dataclass
class RewardCoeffs:
    alpha_diversity: float = 0.5
    beta_repeat: float    = 2.0
    gamma_green_sq: float = 1.0
    delta_yellow: float   = 0.5
    lambda_entropy: float = 0.5
    per_step_cost: float  = 2.0

    # Nonlinear early-win shaping
    win_bonus: float    = 100.0
    win_shape_k: float  = 2.5   # 2–3 makes early wins much bigger

    # Failure shaping
    fail_penalty_base: float  = 20.0
    fail_penalty_scale: float = 20.0

    # legacy (unused by win now, keep for compatibility if referenced elsewhere)
    terminal_bonus: float = 100.0
    terminal_decay: float = 15.0

    entropy_start_turn: int = 3
    
# class WordleEnv
# Function: Initialize the Wordle RL environment (vocab, reward shaping, masking mode, horizon).
# Input:
#   - vocab: list[str] — all legal 5-letter words (baseline action space).
#   - coeffs: RewardCoeffs — reward-shaping parameters.
#   - mask_mode: "hard" (filter to posterior-consistent actions) or "valid" (allow all).
#   - max_turns: int — episode horizon (typically 6).
# Output: none (side effects: normalize vocab, build index, create StateEncoder(5), and reset()).
class WordleEnv:
    def __init__(self, vocab: List[str], coeffs: RewardCoeffs, mask_mode: str = "hard", max_turns: int = 6):
        # mask_mode: "hard" = posterior-consistent only; "valid" = no filtering (all vocab allowed)
        assert mask_mode in ("valid","hard")
        self.vocab = sorted(set([w.lower() for w in vocab if len(w)==5 and w.isalpha()]))
        self.vocab_index = {w:i for i,w in enumerate(self.vocab)}
        self.coeffs = coeffs
        self.mask_mode = mask_mode
        self.max_turns = max_turns
        self.encoder = StateEncoder(history_len=5)  # At most 5 (guess, feedback) pairs precede any decision when max_turns=6.
        self.reset(target=random.choice(self.vocab))
        
    # Function: Start a new episode (optionally with a fixed target).
    # Input: target (Optional[str]) — if provided, use this exact target; else sample from vocab.
    # Output: (obs, mask)
    #   - obs: torch.Tensor — encoded state, shape (145*history_len,), dtype float32.
    #   - mask: torch.Tensor — legal-action mask, shape (len(vocab),), dtype float32.
    def reset(self, target: Optional[str]=None):
        self.history: List[Tuple[str, Tuple[int,...]]] = []
        self.used: set = set()
        self.turn = 0
        self.done = False
        self.target = target if target is not None else random.choice(self.vocab)
        self.prev_posterior = self.vocab[:]
        self.prev_entropy = math.log2(len(self.prev_posterior))
        self._cached_post = self.vocab[:]  # cache for mask
        return self._obs(), self.legal_action_mask()
        
    # Function: Internal helper — encode current history via StateEncoder.
    # Input: none (uses self.history).
    # Output: torch.Tensor of shape (145 * history_len,), dtype float32.
    def _obs(self) -> torch.Tensor:
        return torch.from_numpy(self.encoder.encode(self.history))
        
    # Function: Compute the current action-availability mask.
    # Input: none (uses self.history, self._cached_post, mask_mode).
    # Output: torch.Tensor of shape (len(self.vocab),), dtype float32.
    #   - "valid": all ones.
    #   - "hard": ones for posterior-consistent words, zeros otherwise.
    #   - empty posterior (edge case): all ones (fail-open).
    def legal_action_mask(self) -> torch.Tensor:
        mask = np.zeros(len(self.vocab), dtype=np.float32)
        if self.mask_mode == "valid":
            mask[:] = 1.0
        else:
            cons = getattr(self, "_cached_post", None)
            if cons is None:
                cons = posterior_candidates(self.vocab, self.history)
            if not cons:
                mask[:] = 1.0
            else:
                idxs = [self.vocab_index[w] for w in cons]
                mask[idxs] = 1.0
        return torch.from_numpy(mask)

    # Function: Advance one step with a guessed word; compute shaped reward; return transition.
    # Input: guess (str) — proposed 5-letter word (must be in vocab or episode ends invalid).
    # Output: (obs, reward, done, mask, info)
    #   - obs: torch.Tensor — new observation after applying the guess.
    #   - reward: float — shaped reward (per-step cost, diversity, repeat penalty, green^2, yellow reuse, entropy reduction, win bonus, fail penalty).
    #   - done: bool — True if solved or horizon reached; else False.
    #   - mask: torch.Tensor — updated legal action mask.
    #   - info: (guess, feedback_tuple) on normal steps; (guess, ("invalid",)) if invalid guess ended the episode.
    def step(self, guess: str):
        if self.done:
            raise RuntimeError("Episode done; call reset().")
        if guess not in self.vocab_index:  # O(1) dict lookup instead of scanning the vocab list
            r = -10.0; self.done = True
            return self._obs(), r, self.done, self.legal_action_mask(), (guess, ("invalid",))

        self.turn += 1
        fb = wordle_feedback(guess, self.target)
        self.history.append((guess, fb))

        c = self.coeffs
        reward = 0.0
        reward -= c.per_step_cost
        reward += c.alpha_diversity * len(set(guess))
        if guess in self.used:
            reward -= c.beta_repeat
        self.used.add(guess)

        if len(self.history) >= 2:
            prev_guess, prev_fb = self.history[-2]
            greens_prev = [i for i,v in enumerate(prev_fb) if v==2]
            matched = sum(1 for i in greens_prev if guess[i] == prev_guess[i])
            reward += c.gamma_green_sq * (matched**2)
            yellows_prev = [(i, prev_guess[i]) for i,v in enumerate(prev_fb) if v==1]
            reuse = 0
            for i,ch in yellows_prev:
                if ch in guess and guess[i] != ch:
                    reuse += 1
            reward += c.delta_yellow * reuse

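        # Incremental posterior update C_{t+1} = {w in C_t : f(a_t, w) = y_t};
        # equivalent to re-filtering the full vocab against the whole history, but cheaper.
        post = [w for w in self._cached_post if wordle_feedback(guess, w) == fb]
        self._cached_post = post  # cache once per step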
        H = math.log2(len(post)) if len(post) > 0 else 0.0
        if self.turn >= c.entropy_start_turn:
            reward += c.lambda_entropy * (self.prev_entropy - H)
        self.prev_entropy = H
        self.prev_posterior = post
        
        if guess == self.target:
            # Nonlinear time-shaped bonus
            rem  = self.max_turns - self.turn + 1
            frac = rem / self.max_turns  # (0,1]
            shaped = c.win_bonus * (frac ** c.win_shape_k)
            reward += shaped
            self.done = True
        elif self.turn >= self.max_turns:
            frac_used = min(1.0, self.turn / self.max_turns)
            fail_pen = c.fail_penalty_base + c.fail_penalty_scale * frac_used
            reward -= fail_pen
            self.done = True

        return self._obs(), reward, self.done, self.legal_action_mask(), (guess, fb)

# --------------------------------
# 4) Policy net (masked) + utilities
# --------------------------------

# class PolicyNet
# Function: Feedforward policy producing raw logits over the vocabulary; helper to apply masked softmax.
# Input (ctor): state_dim (int), vocab_size (int), hidden (tuple[int,int]) — MLP sizes.
# Output: forward(state)-> logits (torch.Tensor shape [*, vocab_size]).
class PolicyNet(nn.Module):
    def __init__(self, state_dim: int, vocab_size: int, hidden=(512,256)):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden[0]), nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]), nn.ReLU(),
            nn.Linear(hidden[1], vocab_size)
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.backbone(state)  # raw logits (unmasked)

    # Function: Convert logits to probabilities with action masking and temperature.
    # Input: raw_logits (Tensor [B,V]), action_mask (Tensor [B,V] with {0,1}), temperature (float).
    # Output: (probs, invalid_mass_pre)
    #   - probs: Tensor [B,V] — normalized over valid actions; falls back to uniform over valid if all-masked.
    #   - invalid_mass_pre: Tensor [B] — pre-mask probability mass assigned to invalid actions (diagnostic).
    @staticmethod
    def masked_probs_from_logits(raw_logits: torch.Tensor, action_mask: torch.Tensor, temperature: float=1.0):
        t = max(float(temperature), 1e-6)
        logits = raw_logits / t
        pre_probs = torch.softmax(logits, dim=-1)
        pre_probs = torch.nan_to_num(pre_probs, nan=0.0)

        invalid = (action_mask <= 0)
        masked_logits = logits.masked_fill(invalid, -1e9)
        masked_logits = torch.nan_to_num(masked_logits, nan=0.0, posinf=1e9, neginf=-1e9)
        probs = torch.softmax(masked_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0)
        psum = probs.sum(dim=-1, keepdim=True)
        valid = (action_mask > 0).float()
        valid_count = valid.sum(dim=-1, keepdim=True).clamp_min(1.0)
        safe_uniform = valid / valid_count
        probs = torch.where(psum > 0, probs / psum, safe_uniform)

        invalid_mass_pre = (pre_probs * invalid.float()).sum(dim=-1).detach()
        return probs, invalid_mass_pre

# class ValueNet
# Function: Feedforward value function estimator V(s).
# Input (ctor): state_dim (int), hidden (tuple[int,int]).
# Output: forward(state)-> value (Tensor [*]) — scalar per state.
class ValueNet(nn.Module):
    def __init__(self, state_dim: int, hidden=(256,128)):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden[0]), nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]), nn.ReLU(),
            nn.Linear(hidden[1], 1)
        )
    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state).squeeze(-1)

# --------------------------------
# 5) PPO Config (SAFER DEFAULTS + NEW KNOBS)
# --------------------------------

# @dataclass PPOConfig
# Function: Hyperparameter bundle for PPO training (LRs, clipping, KL control, annealing, etc.).
# Input: (all fields optional; defaults chosen for stability).
# Output: PPOConfig instance (no behavior).
@dataclass
class PPOConfig:
    # Learning rates & regularization
    actor_lr: float = 2e-4
    critic_lr: float = 1e-4
    weight_decay: float = 5e-5

    # Exploration controls
    entropy_coef_start: float = 0.05
    entropy_coef_end: float   = 0.005  # keep nonzero floor
    temperature: float        = 1.2

    # Sample regime / batch sizing
    total_episodes: int     = 30000
    episodes_per_batch: int = 640
    ppo_epochs: int         = 3
    minibatch_size: int     = 512

    # Credit assignment
    gamma: float      = 0.99
    gae_lambda: float = 0.95

    # Trust region / clipping
    clip_eps: float       = 0.12
    value_clip_eps: float = 0.2
    max_grad_norm: float  = 0.5

    # KL targeting & adaptation
    kl_target: float   = 0.008
    kl_adapt: bool     = True
    kl_high: float     = 1.25
    kl_low: float      = 0.95
    actor_lr_min: float = 5e-6
    actor_lr_max: float = 8e-4

    # Eval / early stop
    eval_every: int          = 20
    eval_repeats: int        = 5
    early_stop_patience: int = 40
    min_delta: float         = 0.00

    # New safety knobs
    normalize_adv: bool   = True   # normalize advantages by default
    clip_value_loss: bool = True   # clip value-function updates to ±value_clip_eps around values_old
    anneal_clip: bool     = False  # linearly decay clip_eps during training
    entropy_decay: bool   = True   # linearly decay entropy coef

    # KL penalty extras (optional, alongside clip)
    use_kl_penalty: bool = True
    kl_coef_start: float = 0.0
    kl_coef_max: float   = 0.3

    # Reward clipping (tame shaped reward spikes)
    reward_clip_low: float  = -20.0
    reward_clip_high: float = 120.0

    # TRPO-style safety brake
    trpo_brake: bool      = True
    trpo_brake_mult: float = 2.5   # trigger if KL > mult × kl_target on a minibatch

    # Logging switches
    log_value_loss: bool = True
    log_policy_loss: bool = True
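
# --- Illustrative sketch (hypothetical helper, not used by PPOAgent) ---
# The schedule knobs above interact in a way that is easy to miss. train() collects
# episodes_per_batch episodes per PPO update until episodes_done reaches total_episodes.
# It evaluates on batch 1 and on every eval_every-th batch. With the defaults
# (30000, 640, 20) that is 47 batches and only three evaluations (batches 1, 20, 40), so
# early_stop_patience=40 cannot trigger. The batch-40 evaluation lands at
# episodes_done=25600, the value reported for the best checkpoint in the log below.
# The function name and returned keys are assumptions made for this example.

def sketch_schedule_summary(total_episodes, episodes_per_batch, eval_every):
    """Mirror train()'s loop bounds, assuming no early stop."""
    n_batches = -(-total_episodes // episodes_per_batch)  # ceiling division
    eval_batches = [b for b in range(1, n_batches + 1) if b == 1 or b % eval_every == 0]
    return {
        "n_batches": n_batches,
        "episodes_collected": n_batches * episodes_per_batch,
        "eval_batches": eval_batches,
        "episodes_at_eval": [b * episodes_per_batch for b in eval_batches],
    }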

# --------------------------------
# 6) PPO Agent
# --------------------------------

# class PPOAgent
# Function: Tie env+models+optimizers together; collect rollouts; run PPO updates; evaluate.
# Input (ctor): env (WordleEnv), cfg (PPOConfig).
# Output: instance with .train() method returning history of mean returns.
class PPOAgent:
    def __init__(self, env: WordleEnv, cfg: PPOConfig):
        self.env = env
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
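        # Autotuned cuDNN kernels are nondeterministic; respect DETERMINISTIC=1 as set in __main__.
        torch.backends.cudnn.benchmark = os.environ.get("DETERMINISTIC", "0") != "1"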

        self.actor  = PolicyNet(env.encoder.state_dim, len(env.vocab)).to(self.device)
        self.critic = ValueNet(env.encoder.state_dim).to(self.device)

        self.opt_actor  = optim.Adam(self.actor.parameters(),  lr=cfg.actor_lr,  weight_decay=cfg.weight_decay)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=cfg.critic_lr, weight_decay=cfg.weight_decay)

        self.best_sr = -1.0
        self.eval_no_improve = 0
        # patience=999 effectively disables plateau decay; the actor LR is steered by the KL controller in train().
        self.lr_sched_actor  = optim.lr_scheduler.ReduceLROnPlateau(self.opt_actor,  mode='max', factor=0.5, patience=999)
        self.lr_sched_critic = optim.lr_scheduler.ReduceLROnPlateau(self.opt_critic, mode='max', factor=0.5, patience=999)

        # adaptive KL penalty weight
        self.kl_beta = cfg.kl_coef_start

    # Function: Compute GAE advantages and returns for a single episode trajectory.
    # Input: rewards (list[float]), values (list[float]), gamma (float), lam (float).
    # Output: (advantages: np.ndarray[T], returns: np.ndarray[T]).
    @staticmethod
    def _compute_gae(rewards, values, gamma, lam):
        T = len(rewards)
        adv = np.zeros(T, dtype=np.float32)
        last_gae = 0.0
        for t in reversed(range(T)):
            v_t = values[t]
            v_next = 0.0 if t == T-1 else values[t+1]
            delta = rewards[t] + gamma * v_next - v_t
            last_gae = delta + gamma * lam * last_gae
            adv[t] = last_gae
        returns = adv + np.array(values, dtype=np.float32)
        return adv, returns

    # Function: Collect one complete episode (sampling a target unless provided).
    # Input: target (Optional[str]) — if given, force this target for the episode.
    # Output: traj (list[dict]) — per-step dicts with state, mask, action, logp, value, reward, invalid_mass.
    def _collect_episode(self, target: Optional[str] = None):
        s, mask = self.env.reset(target=target)
        done = False
        traj = []
        while not done:
            state_b = s.unsqueeze(0).float().to(self.device)
            mask_b  = mask.unsqueeze(0).float().to(self.device)

            with torch.no_grad():
                raw_logits = self.actor(state_b)
                probs, invalid_mass = PolicyNet.masked_probs_from_logits(raw_logits, mask_b, self.cfg.temperature)
                a_idx = torch.multinomial(probs[0], 1).item()
                logp  = torch.log(probs[0, a_idx] + 1e-12).item()
                v     = self.critic(state_b)[0].item()

            s2, r, done, mask2, info = self.env.step(self.env.vocab[a_idx])

            traj.append({
                "s": s.numpy(),
                "mask": mask.numpy(),
                "a": a_idx,
                "logp": float(logp),
                "v": float(v),
                "r": float(r),
                "invalid_mass": float(invalid_mass[0].item()),
            })
            s, mask = s2, mask2
        return traj

    # Function: Collect a batch of episodes, flatten to tensors, and compute per-episode GAE.
    # Input: train_targets (list[str]) — sample targets from this set for each episode.
    # Output: batch (dict[str, torch.Tensor | int]) — tensors for PPO update.
    def _collect_batch(self, train_targets: List[str]):
        episodes = []
        while len(episodes) < self.cfg.episodes_per_batch:
            target = random.choice(train_targets)
            ep = self._collect_episode(target)
            episodes.append(ep)
        # Flatten
        states = np.concatenate([np.stack([t["s"] for t in ep], axis=0) for ep in episodes], axis=0)
        masks  = np.concatenate([np.stack([t["mask"] for t in ep], axis=0) for ep in episodes], axis=0)
        actions= np.concatenate([[t["a"] for t in ep] for ep in episodes], axis=0)
        logps  = np.concatenate([[t["logp"] for t in ep] for ep in episodes], axis=0)
        values = np.concatenate([[t["v"] for t in ep] for ep in episodes], axis=0)
        invmasses = np.concatenate([[t["invalid_mass"] for t in ep] for ep in episodes], axis=0)

        # GAE per-episode with reward clipping (bounded to tame spikes).
        adv_list, ret_list = [] , []
        for ep in episodes:
            r_ep = [t["r"] for t in ep]
            if self.cfg.reward_clip_low is not None and self.cfg.reward_clip_high is not None:
                r_ep = np.clip(r_ep, self.cfg.reward_clip_low, self.cfg.reward_clip_high).tolist()
            v_ep = [t["v"] for t in ep]
            adv_ep, ret_ep = self._compute_gae(r_ep, v_ep, self.cfg.gamma, self.cfg.gae_lambda)
            adv_list.append(adv_ep)
            ret_list.append(ret_ep)
        advantages = np.concatenate(adv_list, axis=0)
        returns    = np.concatenate(ret_list, axis=0)

        batch = {
            "states": torch.tensor(states, dtype=torch.float32, device=self.device),
            "masks":  torch.tensor(masks,  dtype=torch.float32, device=self.device),
            "actions":torch.tensor(actions, dtype=torch.long, device=self.device),
            "logp_old":torch.tensor(logps,  dtype=torch.float32, device=self.device),
            "values_old":torch.tensor(values, dtype=torch.float32, device=self.device),
            "advantages":torch.tensor(advantages, dtype=torch.float32, device=self.device),
            "returns":torch.tensor(returns, dtype=torch.float32, device=self.device),
            "invalid_mass":torch.tensor(invmasses, dtype=torch.float32, device=self.device),
            "num_episodes": len(episodes),
            "num_steps": states.shape[0],
        }
        return batch

    # Function: One PPO update over the provided batch (multi-epoch, mini-batches).
    # Input: batch (dict tensors), entropy_coef (float), clip_eps_curr (float; may be tightened).
    # Output: (mean_kl, mean_entropy, mean_policy_loss, mean_value_loss, clip_used)
    def _ppo_update(self, batch, entropy_coef: float, clip_eps_curr: float):
        adv = batch["advantages"]
        if self.cfg.normalize_adv:
            std = adv.std(unbiased=False)
            if torch.isnan(std) or std < 1e-8:
                std = adv.new_tensor(1.0)
            adv = (adv - adv.mean()) / std
            adv = adv.clamp_(-10.0, 10.0)

        states = batch["states"]; masks = batch["masks"]; actions = batch["actions"]
        logp_old = batch["logp_old"]; values_old = batch["values_old"]
        returns  = batch["returns"]

        N = states.size(0)
        mb = self.cfg.minibatch_size

        # Pre-update policy distribution, used as the reference for the KL terms below.
        with torch.no_grad():
            old_raw_logits = self.actor(states).detach()
            old_probs_all, _ = PolicyNet.masked_probs_from_logits(old_raw_logits, masks, temperature=self.cfg.temperature)

        approx_kl_list, entropy_list = [], []
        pol_loss_list, val_loss_list = [], []

        for _ in range(self.cfg.ppo_epochs):
            # Reshuffle every epoch so each pass sees different minibatch compositions.
            idxs = torch.randperm(N, device=self.device)
            for start in range(0, N, mb):
                mb_idx = idxs[start:start+mb]
                s_mb = states[mb_idx]
                m_mb = masks[mb_idx]
                a_mb = actions[mb_idx]
                logp_old_mb = logp_old[mb_idx]
                v_old_mb = values_old[mb_idx]
                ret_mb = returns[mb_idx]
                adv_mb = adv[mb_idx]

                raw_logits = self.actor(s_mb)
                probs, _ = PolicyNet.masked_probs_from_logits(raw_logits, m_mb, temperature=self.cfg.temperature)
                logp = torch.log(probs.gather(1, a_mb.unsqueeze(1)).squeeze(1) + 1e-12)

                # --- KL for logging/brake (no grad) ---
                with torch.no_grad():
                    pre_ng = probs + 1e-12
                    old_ng = old_probs_all[mb_idx] + 1e-12
                    kl_vec_ng = (old_ng * (torch.log(old_ng) - torch.log(pre_ng))).sum(dim=1)
                    kl_now = kl_vec_ng.mean().item()
                    approx_kl_list.append(kl_now)

                    # TRPO-style safety brake: skip actor update if minibatch jumps too far
                    brake = False
                    if self.cfg.trpo_brake and kl_now > self.cfg.trpo_brake_mult * self.cfg.kl_target:
                        brake = True
                        # 1) Immediate LR cut (actor)
                        if self.cfg.kl_adapt:
                            a_lr = self.opt_actor.param_groups[0]['lr']
                            a_lr = max(self.cfg.actor_lr_min, a_lr * 0.5)   # stronger halving
                            self.opt_actor.param_groups[0]['lr'] = a_lr
                        # 2) Make KL penalty actually matter
                        self.kl_beta = min(self.cfg.kl_coef_max, max(self.kl_beta, 0.05) * 2.0)
                        # 3) Tighten the clip for the rest of this update (all remaining minibatches and epochs)
                        clip_eps_curr = max(0.08, clip_eps_curr * 0.75)
                        # 4) Count brake hits in this update so kl_beta can be eased back afterwards
                        if not hasattr(self, "_brake_hits"):
                            self._brake_hits = 0
                        self._brake_hits += 1
                        print(f"[Brake] minibatch KL={kl_now:.5f} > {self.cfg.trpo_brake_mult:.2f}×target; "
                              f"skip actor; lrA={self.opt_actor.param_groups[0]['lr']:.2e} β_KL={self.kl_beta:.3f} "
                              f"clip={clip_eps_curr:.3f}")

                # --- PPO clipped surrogate ---
                ratio = torch.exp(logp - logp_old_mb)
                surr1 = ratio * adv_mb
                surr2 = torch.clamp(ratio, 1.0 - clip_eps_curr, 1.0 + clip_eps_curr) * adv_mb
                clipped_obj = -torch.min(surr1, surr2).mean()

                # --- Differentiable KL penalty (old probs detached) ---
                policy_loss = clipped_obj
                if self.cfg.use_kl_penalty and self.kl_beta > 0.0:
                    pre = probs + 1e-12
                    old = (old_probs_all[mb_idx] + 1e-12).detach()
                    kl_vec = (old * (torch.log(old) - torch.log(pre))).sum(dim=1)
                    policy_loss = policy_loss + self.kl_beta * kl_vec.mean()

                # --- Value loss (with optional clipping) ---
                values = self.critic(s_mb)
                if self.cfg.clip_value_loss:
                    v_clipped = v_old_mb + (values - v_old_mb).clamp(-self.cfg.value_clip_eps, self.cfg.value_clip_eps)
                    value_loss_unclipped = (values - ret_mb).pow(2)
                    value_loss_clipped   = (v_clipped - ret_mb).pow(2)
                    value_loss = 0.5 * torch.max(value_loss_unclipped, value_loss_clipped).mean()
                else:
                    value_loss = 0.5 * (values - ret_mb).pow(2).mean()

                # --- Entropy bonus ---
                entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=1).mean()
                entropy_list.append(entropy.item())

                # --- Optimizer steps ---
                if not brake:
                    self.opt_actor.zero_grad()
                    (policy_loss - entropy_coef * entropy).backward()
                    nn.utils.clip_grad_norm_(self.actor.parameters(), self.cfg.max_grad_norm)
                    self.opt_actor.step()

                self.opt_critic.zero_grad()
                value_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.cfg.max_grad_norm)
                self.opt_critic.step()

                pol_loss_list.append(policy_loss.item())
                val_loss_list.append(value_loss.item())

        mean_kl = float(np.mean(approx_kl_list)) if approx_kl_list else 0.0
        mean_entropy = float(np.mean(entropy_list)) if entropy_list else 0.0
        mean_pol = float(np.mean(pol_loss_list)) if pol_loss_list else 0.0
        mean_val = float(np.mean(val_loss_list)) if val_loss_list else 0.0

        # After the loops: soften any emergency tightening from the brake, slightly.
        if hasattr(self, "_brake_hits") and self._brake_hits > 0:
            self.kl_beta = max(0.0, self.kl_beta * 0.9)
            self._brake_hits = 0

        # IMPORTANT: return the clip that was actually used (after any tightening).
        return mean_kl, mean_entropy, mean_pol, mean_val, clip_eps_curr

    # Function: Main training loop — collect batches, update via PPO, evaluate, early-stop.
    # Input: train_targets (list[str]), eval_targets (optional list[str]).
    # Output: hist_returns (list[float]): per-batch mean of the GAE return targets over all steps (a proxy, not the mean episode score).
    def train(self, train_targets: List[str], eval_targets: Optional[List[str]] = None):
        cfg = self.cfg
        episodes_done = 0
        hist_returns = []
        step_batches = 0

        while episodes_done < cfg.total_episodes:
            batch = self._collect_batch(train_targets)
            episodes_done += batch["num_episodes"]
            hist_returns.append(batch["returns"].mean().item())

            # Linear entropy anneal (with floor).
            if cfg.entropy_decay:
                frac = min(1.0, episodes_done / cfg.total_episodes)
                entropy_coef = cfg.entropy_coef_start + frac * (cfg.entropy_coef_end - cfg.entropy_coef_start)
            else:
                entropy_coef = cfg.entropy_coef_start

            # Clip anneal (optional).
            if cfg.anneal_clip:
                frac = min(1.0, episodes_done / cfg.total_episodes)
                clip_eps_curr = cfg.clip_eps * (1.0 - 0.5*frac)  # decay 50% over training
                clip_eps_curr = max(0.08, clip_eps_curr)
            else:
                clip_eps_curr = cfg.clip_eps

            mean_kl, mean_entropy, mean_pol, mean_val, clip_used = self._ppo_update(batch, entropy_coef, clip_eps_curr)
            step_batches += 1

            # Adaptive KL-based LR tweak (actor) and KL penalty weight.
            if cfg.kl_adapt:
                a_lr = self.opt_actor.param_groups[0]['lr']
                hot = (mean_kl > cfg.kl_high * cfg.kl_target)
                cold = (mean_kl < cfg.kl_low  * cfg.kl_target)
                if hot:
                    a_lr = max(cfg.actor_lr_min, a_lr * 0.5)
                    self.kl_beta = min(cfg.kl_coef_max, max(self.kl_beta, 0.05) * 1.5)
                elif cold:
                    if mean_kl < 0.5 * cfg.kl_target:
                        a_lr = min(cfg.actor_lr_max, a_lr * 1.8)
                    else:
                        a_lr = min(cfg.actor_lr_max, a_lr * 1.2)
                    self.kl_beta = max(0.0, self.kl_beta * 0.7)
                self.opt_actor.param_groups[0]['lr'] = a_lr
            
            if step_batches % 5 == 0:
                # Controller heartbeat
                print(f"[KLctl] KL={mean_kl:.5f}/{cfg.kl_target:.5f}  "
                      f"lrA={self.opt_actor.param_groups[0]['lr']:.2e}  "
                      f"beta={self.kl_beta:.3f}")
            
                # PPO summary
                tail = hist_returns[-10:] if len(hist_returns) >= 10 else hist_returns
                print(
                    f"[PPO] batches={step_batches} ep_done={episodes_done}/{cfg.total_episodes} "
                    f"avg_ret_last10={np.mean(tail):.2f}  KL={mean_kl:.6f}  H={mean_entropy:.3f}  "
                    f"β_KL={self.kl_beta:.3f} clip={clip_used:.3f}  "
                    f"pre_mask_invalid_mass_mean={batch['invalid_mass'].mean().item():.4f}  "
                    + (f"Vloss={mean_val:.3f}  " if cfg.log_value_loss else "")
                    + (f"Ploss={mean_pol:.3f}" if cfg.log_policy_loss else "")
                )

            # periodic eval
            if step_batches == 1 or step_batches % cfg.eval_every == 0:
                val_targets = eval_targets if eval_targets is not None else train_targets
                sr, avg_t, ci = evaluate_greedy(self, val_targets, temperature=0.3, repeats=cfg.eval_repeats, ci=True)
                print(f"[Eval] SR={sr:.3f}±{ci:.3f}  avg_turns={avg_t:.2f}  lrA={self.opt_actor.param_groups[0]['lr']:.2e} lrC={self.opt_critic.param_groups[0]['lr']:.2e}")
                self.lr_sched_actor.step(sr)
                self.lr_sched_critic.step(sr)
                improved = (sr > self.best_sr + cfg.min_delta)
                if improved:
                    self.best_sr = sr
                    torch.save({'actor': self.actor.state_dict(),
                                'critic': self.critic.state_dict(),
                                'sr': sr, 'episodes_done': episodes_done},
                               "ppo_wordle_best.pt")
                    self.eval_no_improve = 0
                else:
                    self.eval_no_improve += 1
                    if self.eval_no_improve >= cfg.early_stop_patience:
                        print("[EarlyStop] No val SR improvement; stopping training.")
                        break
        return hist_returns
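
# --- Illustrative sketch (hypothetical helpers, not called by PPOAgent) ---
# PPOAgent._compute_gae runs the backward GAE recursion
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),   A_t = delta_t + gamma * lam * A_{t+1},
# with V(s_T) = 0 because every collected episode ends in a terminal state, and trains the
# critic toward A_t + V(s_t). Two limits make this easy to test: lam = 1 reduces the
# target to the discounted Monte Carlo return, and lam = 0 reduces it to the one-step TD
# target r_t + gamma * V(s_{t+1}). The functions below restate the recursion for such
# checks; their names are assumptions.
import numpy as np

def sketch_gae(rewards, values, gamma, lam):
    """Return (advantages, returns) for one terminated episode."""
    T = len(rewards)
    adv = np.zeros(T, dtype=np.float64)
    last = 0.0
    for t in reversed(range(T)):
        v_next = values[t + 1] if t + 1 < T else 0.0  # no bootstrap past the terminal step
        delta = rewards[t] + gamma * v_next - values[t]
        last = delta + gamma * lam * last
        adv[t] = last
    return adv, adv + np.asarray(values, dtype=np.float64)

def sketch_discounted_returns(rewards, gamma):
    """Monte Carlo reward-to-go G_t = sum_k gamma**k * r_{t+k}."""
    out, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        out.append(running)
    return np.asarray(out[::-1], dtype=np.float64)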

# --------------------------
# 7) Greedy evaluation (+ CI)
# --------------------------

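# Function: Evaluate the trained policy greedily (argmax) over a set of targets.
# Input: agent (PPOAgent), targets (list[str]), temperature (float), repeats (int), ci (bool).
#   temperature only rescales logits before the argmax, so it does not change which word is chosen.
# Output:
#   - if ci=True: (sr, avg_t, ci_half) where sr is the success rate, avg_t the mean turns on solved episodes
#     (inf if none), and ci_half = 1.96 * sqrt(sr * (1 - sr) / n) with n = len(targets) * repeats,
#     the half-width of a normal-approximation 95% interval.
#   - else: (sr, avg_t).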
@torch.no_grad()
def evaluate_greedy(agent, targets: List[str], temperature: float = 0.5, repeats: int = 1, ci: bool=False):
    rng_state = random.getstate()
    wins, turns = 0, []
    for rep in range(repeats):
        random.seed(1234 + rep)  # fixed seeds for stability across repeats
        for tgt in targets:
            s, mask = agent.env.reset(target=tgt)
            done = False; t = 0
            while not done:
                t += 1
                state = s.unsqueeze(0).float().to(agent.device)
                mask_b = mask.unsqueeze(0).float().to(agent.device)
                raw_logits = agent.actor(state)
                probs, _ = PolicyNet.masked_probs_from_logits(raw_logits, mask_b, temperature=temperature)
                a_idx = torch.argmax(probs[0]).item()
                word = agent.env.vocab[a_idx]
                s, r, done, mask, info = agent.env.step(word)
            solved = (agent.env.history and agent.env.history[-1][0] == tgt)
            if solved:
                wins += 1; turns.append(t)
    random.setstate(rng_state)
    n = len(targets) * repeats if targets else 0
    sr = wins / n if n else 0.0
    avg_t = float(np.mean(turns)) if turns else float('inf')
    if ci:
        if n == 0:
            return sr, avg_t, 0.0
        se = math.sqrt(sr*(1-sr)/n)
        return sr, avg_t, 1.96*se
    return sr, avg_t
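
# --- Illustrative sketch (hypothetical helper, not called by evaluate_greedy) ---
# evaluate_greedy reports 1.96 * sqrt(sr * (1 - sr) / n), the half-width of a normal
# (Wald) 95% interval, with n = len(targets) * repeats. For example, 211 wins out of 232
# (the count implied by success_rate=0.909) gives about 0.037, as in the [FINAL TEST]
# line below. There are two caveats. First, if WordleEnv is deterministic for a fixed
# target (its code is not shown here, so this is an assumption), argmax decoding makes
# every repeat identical, the effective n is len(targets), and the interval is too narrow
# by a factor of sqrt(repeats). Second, a Wald interval has zero width at sr = 0 or 1. The
# Wilson score interval below is a swapped-in alternative, not the notebook's method; the
# function name is hypothetical.
import math

def sketch_binomial_ci(wins, n, z=1.96):
    """Return (p_hat, wald_half_width, (wilson_low, wilson_high))."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = wins / n
    wald = z * math.sqrt(p * (1.0 - p) / n)
    denom = 1.0 + z * z / n
    center = (p + z * z / (2.0 * n)) / denom
    half = z * math.sqrt(p * (1.0 - p) / n + z * z / (4.0 * n * n)) / denom
    return p, wald, (center - half, center + half)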

# --------------------------
# 8) Visualization helpers
# --------------------------

# Function: Convert feedback tuple to Unicode square emojis.
# Input: fb_tuple (iterable of ints in {0,1,2}).
# Output: str — e.g., "⬛🟨🟩⬛⬛".
EMO = {0: "⬛", 1: "🟨", 2: "🟩"}

def fb_to_squares(fb_tuple):
    return "".join(EMO[int(v)] for v in fb_tuple)
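
# --- Illustrative sketch (hypothetical helpers; WordleEnv's own scorer is defined elsewhere) ---
# fb_to_squares expects one code per letter: 0 = gray, 1 = yellow, 2 = green. The usual
# two-pass rule handles repeated letters: mark greens first, then award a yellow only while
# an unmatched copy of that letter remains in the target. Whether WordleEnv scores exactly
# this way is an assumption, but the rule reproduces the feedback rows in the logs below.
# For example, ROOPY against WORRY gives (1, 2, 0, 0, 2): the second O earns nothing
# because the target's only O is already green. sketch_filter_candidates shows how a
# count like the "candidates≈" column can be derived; the env's prev_posterior may differ.
import collections

def sketch_feedback(guess, target):
    guess, target = guess.lower(), target.lower()
    fb = [0] * len(guess)
    unmatched = collections.Counter()
    for i, (g, t) in enumerate(zip(guess, target)):
        if g == t:
            fb[i] = 2
        else:
            unmatched[t] += 1  # target letters still available for yellows
    for i, g in enumerate(guess):
        if fb[i] == 0 and unmatched[g] > 0:
            fb[i] = 1
            unmatched[g] -= 1
    return tuple(fb)

def sketch_filter_candidates(candidates, guess, fb):
    """Keep the words that would have produced this feedback for this guess."""
    return [w for w in candidates if sketch_feedback(guess, w) == tuple(fb)]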

# Function: Roll out a greedy (or stochastic) episode on a fixed target and print step-by-step.
# Input: agent (PPOAgent), target (str), temperature (float), greedy (bool).
# Output: none (pretty-prints guesses, feedback, reward, and candidate count).
@torch.no_grad()
def visualize_episode(agent, target: str, temperature: float = 0.5, greedy: bool = True):
    env = agent.env
    s, mask = env.reset(target=target)
    done, t = False, 0
    print(f"=== Visualize: target = {target.upper()} ===")
    while not done:
        t += 1
        state  = s.unsqueeze(0).float().to(agent.device)
        mask_b = mask.unsqueeze(0).float().to(agent.device)
        raw_logits = agent.actor(state)
        probs, _  = PolicyNet.masked_probs_from_logits(raw_logits, mask_b, temperature=temperature)
        a_idx  = torch.argmax(probs[0]).item() if greedy else torch.multinomial(probs[0], 1).item()
        guess  = env.vocab[a_idx]
        s, r, done, mask, info = env.step(guess)
        _, fb = info
        fb_str = fb_to_squares(fb)
        cand_cnt = len(env.prev_posterior) if hasattr(env, "prev_posterior") else float("nan")
        print(f"t{t}: {guess.upper():<8}  {fb_str}   r={r:6.2f}   candidates≈{cand_cnt}")
    solved = (env.history and env.history[-1][0] == target)
    print(f"Result: {'✅ SOLVED' if solved else '❌ FAILED'} in {t} turns.")

# -------------------------
# 9) Main
# -------------------------
if __name__ == "__main__":
    random.seed(0); np.random.seed(0); torch.manual_seed(0)

    # Function: Optional deterministic mode for debugging / reproducibility.
    # Input: env var DETERMINISTIC="1" enables cudnn deterministic algorithms.
    # Output: none (side-effect: slower but bitwise-stable kernels where possible).
    if os.environ.get("DETERMINISTIC", "0") == "1":
        # cuBLAS reads this when its handle is first created, so set it before any CUDA work.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True, warn_only=True)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    # Load vocab/solutions; split into train/val/test by targets.
    vocab, solutions = load_wordle_lists(cache_dir=".wordlists")

    random.shuffle(solutions)
    n_sol   = len(solutions)
    n_train = int(0.80 * n_sol)
    n_val   = int(0.10 * n_sol)
    n_test  = n_sol - n_train - n_val
    train_targets = solutions[:n_train]
    val_targets   = solutions[n_train:n_train+n_val]
    test_targets  = solutions[n_train+n_val:n_train+n_val+n_test]
    assert len(train_targets) + len(val_targets) + len(test_targets) == n_sol

    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"Split sizes -> train: {len(train_targets)}, val: {len(val_targets)}, test: {len(test_targets)}")

    # Reward shaping coefficients (tuned for stability and signal richness).
    coeffs = RewardCoeffs(
        alpha_diversity=0.5, beta_repeat=2.0,
        gamma_green_sq=1.0, delta_yellow=0.5,
        lambda_entropy=0.5, per_step_cost=2.0,
        win_bonus=100.0, win_shape_k=2.5,
        fail_penalty_base=20.0, fail_penalty_scale=20.0,
        entropy_start_turn=3
    )

    # NOTE: WordleEnv currently encodes only last 5 steps (StateEncoder(5)) even though max_turns=6.
    env = WordleEnv(vocab=vocab, coeffs=coeffs, mask_mode="hard", max_turns=6)

    cfg = PPOConfig()  # Safer defaults (entropy floor, KL control, value clip, etc.)
    agent = PPOAgent(env, cfg)

    print(f"Training on {len(train_targets)} targets; validating on {len(val_targets)}; holding out {len(test_targets)} for FINAL test.")
    agent.train(train_targets, eval_targets=val_targets)

    # === Final TEST (touch ONCE) ===
    # Function: Load best checkpoint (if present) and evaluate on held-out test set.
    # Input: ckpt_path (str) — "ppo_wordle_best.pt" saved during training.
    # Output: prints test SR ± CI and avg turns; no state is persisted beyond printed results.
    ckpt_path = "ppo_wordle_best.pt"
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=agent.device, weights_only=True)  # checkpoint holds only state dicts, a float and an int
        agent.actor.load_state_dict(ckpt['actor'])
        agent.critic.load_state_dict(ckpt['critic'])
        print(f"Loaded best checkpoint: episodes_done={ckpt.get('episodes_done','?')}, val_SR={ckpt.get('sr','?'):.3f}")

    sr, avg_t, ci = evaluate_greedy(agent, test_targets, temperature=0.3, repeats=1, ci=True)
    print(f"[FINAL TEST] success_rate={sr:.3f}±{ci:.3f}, avg_turns={avg_t:.2f} on {len(test_targets)} held-out targets (evaluated once)")

    print("Visualizing a few held-out *test* targets...")
    if len(test_targets) > 0:
        for tgt in random.sample(test_targets, k=min(10, len(test_targets))):
            visualize_episode(agent, target=tgt, temperature=0.5, greedy=True)
        if "crane" in env.vocab:
            visualize_episode(agent, target="crane", temperature=0.5, greedy=True)


CUDA available: True
Split sizes -> train: 1852, val: 231, test: 232




Training on 1852 targets; validating on 231; holding out 232 for FINAL test.
[Eval] SR=0.861±0.020  avg_turns=4.63  lrA=3.60e-04 lrC=1.00e-04
[Brake] minibatch KL=0.02040 > 2.50×target; skip actor; lrA=4.00e-04 β_KL=0.100 clip=0.090
[KLctl] KL=0.00302/0.00800  lrA=7.20e-04  beta=0.063
[PPO] batches=5 ep_done=3200/30000 avg_ret_last10=21.52  KL=0.003021  H=4.523  β_KL=0.063 clip=0.120  pre_mask_invalid_mass_mean=0.7694  Vloss=335.481  Ploss=-0.038
[Brake] minibatch KL=0.08261 > 2.50×target; skip actor; lrA=3.60e-04 β_KL=0.126 clip=0.090
[KLctl] KL=0.00677/0.00800  lrA=8.00e-04  beta=0.041
[PPO] batches=10 ep_done=6400/30000 avg_ret_last10=21.81  KL=0.006767  H=4.531  β_KL=0.041 clip=0.120  pre_mask_invalid_mass_mean=0.7678  Vloss=309.940  Ploss=-0.047
[Brake] minibatch KL=0.02344 > 2.50×target; skip actor; lrA=4.00e-04 β_KL=0.100 clip=0.090
[Brake] minibatch KL=0.03109 > 2.50×target; skip actor; lrA=2.00e-04 β_KL=0.200 clip=0.080
[KLctl] KL=0.00556/0.00800  lrA=7.00e-04  beta=0.065
[PPO

Loaded best checkpoint: episodes_done=25600, val_SR=0.913
[FINAL TEST] success_rate=0.909±0.037, avg_turns=4.47 on 232 held-out targets (evaluated once)
Visualizing a few held-out *test* targets...
=== Visualize: target = WORRY ===
t1: BLEAH     ⬛⬛⬛⬛⬛   r=  0.50   candidates≈2236
t2: TRUCK     ⬛🟨⬛⬛⬛   r=  0.50   candidates≈151
t3: ROOPY     🟨🟩⬛⬛🟩   r=  2.53   candidates≈9
t4: WORRY     🟩🟩🟩🟩🟩   r= 23.76   candidates≈1
Result: ✅ SOLVED in 4 turns.
=== Visualize: target = STOKE ===
t1: BLEAH     ⬛⬛🟨⬛⬛   r=  0.50   candidates≈2213
t2: TRITE     🟨⬛⬛⬛🟩   r=  0.50   candidates≈11
t3: STONE     🟩🟩🟩⬛🟩   r=  2.94   candidates≈3
t4: STOPE     🟩🟩🟩⬛🟩   r= 16.79   candidates≈2
t5: STOVE     🟩🟩🟩⬛🟩   r= 17.00   candidates≈1
t6: STOKE     🟩🟩🟩🟩🟩   r= 17.63   candidates≈1
Result: ✅ SOLVED in 6 turns.
=== Visualize: target = WALTZ ===
t1: BLEAH     ⬛🟨⬛🟨⬛   r=  0.50   candidates≈428
t2: TRAWL     🟨⬛🟨🟨🟨   r=  1.50   candidates≈2
t3: WALTZ     🟩🟩🟩🟩🟩   r= 39.29   candidates≈1
Result: ✅ SOLVED in 3 turns.
=== 