In [2]:
import numpy as np
from dataclasses import dataclass, field
from typing import Optional, Tuple, Dict

# -----------------------------
# TicTacToe environment (3x3)
# -----------------------------
class TicTacToe:
    """
    Board: 3x3 np.int8
      1 -> 'X', -1 -> 'O', 0 -> empty
    Player to move: 1 or -1
    """
    def __init__(self, board: Optional[np.ndarray] = None, to_move: int = 1):
        self.board = np.zeros((3,3), dtype=np.int8) if board is None else board.copy()
        self.to_move = int(to_move)

    def clone(self) -> "TicTacToe":
        return TicTacToe(self.board, self.to_move)

    def legal_actions(self) -> np.ndarray:
        """Actions are flat indices [0..8] for empty cells."""
        return np.flatnonzero(self.board.ravel() == 0)

    def step(self, action: int) -> "TicTacToe":
        """Return next state after placing current player's mark at action."""
        r, c = divmod(int(action), 3)
        if self.board[r, c] != 0:
            raise ValueError("Illegal action")
        nxt = self.clone()
        nxt.board[r, c] = self.to_move
        nxt.to_move = -self.to_move
        return nxt

    def is_terminal(self) -> bool:
        return self.winner() is not None or self.legal_actions().size == 0

    def winner(self) -> Optional[int]:
        """Return 1, -1 for winner; 0 for draw; None if not terminal."""
        lines = []
        lines.extend([self.board[i, :] for i in range(3)])
        lines.extend([self.board[:, j] for j in range(3)])
        lines.append(np.diag(self.board))
        lines.append(np.diag(np.fliplr(self.board)))
        for line in lines:
            s = np.sum(line)
            if s == 3:  return 1
            if s == -3: return -1
        if np.all(self.board != 0):
            return 0
        return None

    def result_from(self, player: int) -> float:
        """
        Return result from 'player' perspective at terminal:
          1.0 win, 0.0 draw, -1.0 loss
        """
        w = self.winner()
        if w is None:
            raise ValueError("Not terminal")
        if w == 0: return 0.0
        return 1.0 if w == player else -1.0

    def __repr__(self) -> str:
        s = {1: "X", -1: "O", 0: "."}
        rows = [" ".join(s[int(x)] for x in row) for row in self.board]
        return "\n".join(rows) + f"\nTo move: {'X' if self.to_move==1 else 'O'}"


# -----------------------------
# MCTS (UCT)
# -----------------------------
@dataclass
class Node:
    state_key: Tuple  # a hashable representation of the state
    to_move: int      # player to move at this node
    parent: Optional["Node"] = None
    parent_action: Optional[int] = None
    children: Dict[int, "Node"] = field(default_factory=dict)

    N: int = 0         # visit count
    W: float = 0.0     # total value (from root's POV)
    Q: float = 0.0     # mean value

    untried_actions: Optional[np.ndarray] = None

def state_to_key(env: TicTacToe) -> Tuple:
    # Compact, hashable key: (tuple(board_flat), to_move)
    return (tuple(env.board.ravel().tolist()), env.to_move)

class MCTS:
    def __init__(self, c_uct: float = np.sqrt(2.0), rollout_depth: int = 20, rng: Optional[np.random.Generator] = None):
        self.c_uct = float(c_uct)
        self.rollout_depth = int(rollout_depth)
        self.rng = rng if rng is not None else np.random.default_rng()

    def search(self, root_env: TicTacToe, n_simulations: int = 200) -> int:
        """Run MCTS from root_env and return the best action by visit count."""
        root = Node(state_key=state_to_key(root_env), to_move=root_env.to_move)
        root.untried_actions = root_env.legal_actions()

        for _ in range(n_simulations):
            self._simulate(root_env, root, root_player=root_env.to_move)

        # Choose action with max visit count
        if root.untried_actions is not None and root.untried_actions.size == 9:
            # Edge case: empty board; children might still be empty if simulations were zero.
            pass

        best_action, best_visits = None, -1
        for a, child in root.children.items():
            if child.N > best_visits:
                best_visits = child.N
                best_action = a

        # If no children expanded (shouldn't happen unless n_simulations==0), pick random legal
        if best_action is None:
            acts = root_env.legal_actions()
            best_action = int(self.rng.choice(acts))

        return int(best_action)

    # ---------- One MCTS iteration ----------
    def _simulate(self, root_env: TicTacToe, root: Node, root_player: int):
        # 1) Selection
        node = root
        env = root_env.clone()

        while node.untried_actions is not None and node.untried_actions.size == 0 and node.children:
            a = self._select_uct(node)
            env = env.step(a)
            node = node.children[a]

        # 2) Expansion
        if node.untried_actions is None:
            node.untried_actions = env.legal_actions()
        if node.untried_actions.size > 0 and not env.is_terminal():
            # Pick one untried action uniformly
            idx = self.rng.integers(0, node.untried_actions.size)
            a = int(node.untried_actions[idx])
            # Remove it from untried
            node.untried_actions = np.delete(node.untried_actions, idx)
            env_next = env.step(a)
            child = Node(
                state_key=state_to_key(env_next),
                to_move=env_next.to_move,
                parent=node,
                parent_action=a
            )
            child.untried_actions = env_next.legal_actions()
            node.children[a] = child
            node = child
            env = env_next

        # 3) Simulation (rollout)
        value = self._rollout(env, root_player)

        # 4) Backpropagation
        self._backpropagate(node, value)

    def _select_uct(self, node: Node) -> int:
        """Select child that maximizes UCT."""
        assert node.children, "Selection requires children"
        log_N = np.log(max(1, node.N))
        best_a, best_score = None, -1e9
        for a, child in node.children.items():
            if child.N == 0:
                uct = np.inf
            else:
                uct = child.Q + self.c_uct * np.sqrt(log_N / child.N)
            if uct > best_score:
                best_score = uct
                best_a = a
        return int(best_a)

    def _rollout(self, env: TicTacToe, root_player: int) -> float:
        """Random playout until terminal or depth cutoff."""
        depth = 0
        e = env.clone()
        while not e.is_terminal() and depth < self.rollout_depth:
            acts = e.legal_actions()
            if acts.size == 0:  # draw
                break
            a = int(self.rng.choice(acts))
            e = e.step(a)
            depth += 1

        if e.is_terminal():
            return e.result_from(root_player)
        else:
            # Non-terminal cutoff -> simple heuristic (0.0 = drawish)
            return 0.0

    def _backpropagate(self, node: Node, value: float):
        """Accumulate value up to root (value is from root player's perspective)."""
        cur = node
        while cur is not None:
            cur.N += 1
            cur.W += value
            cur.Q = cur.W / cur.N
            cur = cur.parent


# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    env = TicTacToe()  # empty board, X to move (1)
    mcts = MCTS(c_uct=np.sqrt(2), rollout_depth=20)

    print("Initial state:")
    print(env)
    for i in range(10):
        # Let MCTS pick a move for X
        action = mcts.search(env, n_simulations=1000)
        print(f"\nMCTS chooses action (0..8): {action}")
        env = env.step(action)
        print("\nAfter X plays:")
        print(env)
    
        # Now choose a reply for O (also with MCTS)
        action_O = mcts.search(env, n_simulations=1000)
        print(f"\nMCTS chooses O action: {action_O}")
        env = env.step(action_O)
        print("\nAfter O plays:")
        print(env)


Initial state:
. . .
. . .
. . .
To move: X

MCTS chooses action (0..8): 4

After X plays:
. . .
. X .
. . .
To move: O

MCTS chooses O action: 5

After O plays:
. . .
. X O
. . .
To move: X

MCTS chooses action (0..8): 2

After X plays:
. . X
. X O
. . .
To move: O

MCTS chooses O action: 6

After O plays:
. . X
. X O
O . .
To move: X

MCTS chooses action (0..8): 0

After X plays:
X . X
. X O
O . .
To move: O

MCTS chooses O action: 8

After O plays:
X . X
. X O
O . O
To move: X

MCTS chooses action (0..8): 1

After X plays:
X X X
. X O
O . O
To move: O

MCTS chooses O action: 7

After O plays:
X X X
. X O
O O O
To move: X

MCTS chooses action (0..8): 3

After X plays:
X X X
X X O
O O O
To move: O


ValueError: a cannot be empty unless no samples are taken

In [23]:
# ================================================
# Neural-guided MCTS with JAX + Flax (TicTacToe)
# ================================================
import numpy as np
from dataclasses import dataclass, field
from typing import Optional, Dict, Tuple

# ---------- JAX / Flax ----------
import jax
import jax.numpy as jnp
from flax import linen as nn
from functools import partial

# =========================================================
# Minimal TicTacToe env (same API as before)
# =========================================================
class TicTacToe:
    """
    Board: 3x3 np.int8
      1 -> 'X', -1 -> 'O', 0 -> empty
    Player to move: 1 or -1
    """
    def __init__(self, board: Optional[np.ndarray] = None, to_move: int = 1):
        self.board = np.zeros((3,3), dtype=np.int8) if board is None else board.copy()
        self.to_move = int(to_move)

    def clone(self): return TicTacToe(self.board, self.to_move)

    def legal_actions(self) -> np.ndarray:
        return np.flatnonzero(self.board.ravel() == 0)

    def step(self, action: int) -> "TicTacToe":
        r, c = divmod(int(action), 3)
        if self.board[r, c] != 0: raise ValueError("Illegal action")
        nxt = self.clone()
        nxt.board[r, c] = self.to_move
        nxt.to_move = -self.to_move
        return nxt

    def is_terminal(self) -> bool:
        return self.winner() is not None or self.legal_actions().size == 0

    def winner(self) -> Optional[int]:
        b = self.board
        lines = list(b) + list(b.T) + [np.diag(b), np.diag(np.fliplr(b))]
        for line in lines:
            s = int(np.sum(line))
            if s == 3:  return 1
            if s == -3: return -1
        if np.all(b != 0): return 0
        return None

    def result_from(self, player: int) -> float:
        w = self.winner()
        if w is None: raise ValueError("Not terminal")
        if w == 0: return 0.0
        return 1.0 if w == player else -1.0

    def __repr__(self) -> str:
        s = {1:"X",-1:"O",0:"."}
        rows = [" ".join(s[int(x)] for x in row) for row in self.board]
        return "\n".join(rows) + f"\nTo move: {'X' if self.to_move==1 else 'O'}"

# =========================================================
# AlphaZero-style PUCT MCTS (unchanged)
# =========================================================
@dataclass
class Node:
    key: Tuple
    to_move: int
    parent: Optional["Node"] = None
    parent_action: Optional[int] = None

    children: Dict[int, "Node"] = field(default_factory=dict)  # action -> child node
    P: Dict[int, float] = field(default_factory=dict)          # action -> prior prob

    N: int = 0
    W: float = 0.0
    Q: float = 0.0

    def expanded(self) -> bool:
        return len(self.P) > 0

def state_to_key(env: TicTacToe) -> Tuple:
    return (tuple(env.board.ravel().tolist()), env.to_move)

class PUCT_MCTS:
    """
    Selection: a* = argmax[ Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)) ]
    Expansion: add new child with policy prior
    Evaluation: value network v(s) (from current player POV)
    Backup: update stats with conversion to root POV
    """
    def __init__(self, policy_value_fn, c_puct: float = 1.5, dirichlet_alpha: float = 0.3,
                 root_noise_frac: float = 0.25, rng: Optional[np.random.Generator] = None):
        self.policy_value_fn = policy_value_fn
        self.c_puct = float(c_puct)
        self.dirichlet_alpha = float(dirichlet_alpha)
        self.root_noise_frac = float(root_noise_frac)
        self.rng = rng if rng is not None else np.random.default_rng()
        self.last_root = None

    def search(self, root_env: TicTacToe, n_simulations: int = 400, temperature: float = 1.0) -> int:
        root = Node(key=state_to_key(root_env), to_move=root_env.to_move)
        self._expand(root_env, root)
        self._add_root_dirichlet_noise(root)

        for _ in range(n_simulations):
            self._simulate(root_env, root, root_player=root_env.to_move)
    
        self.last_root = root
        return self._select_action_from_visits(root, temperature)


    def _simulate(self, root_env: TicTacToe, root: Node, root_player: int):
        node = root
        env = root_env.clone()
        path = []

        while True:
            if env.is_terminal():
                value = env.result_from(root_player)
                self._backup(path, leaf_value=value)
                return

            a = self._puct_select(node)
            path.append((node, a))

            if a not in node.children:
                env = env.step(a)
                child = Node(key=state_to_key(env), to_move=env.to_move, parent=node, parent_action=a)
                priors, leaf_val_from_to_move = self.policy_value_fn(env)
                child.P = priors
                node.children[a] = child
                self._backup(path, leaf_value=leaf_val_from_to_move, leaf_env_to_move=env.to_move, root_player=root_player)
                return

            env = env.step(a)
            node = node.children[a]

    def _puct_select(self, node: Node) -> int:
        sum_N = max(1, sum(child.N for child in node.children.values()))
        best_a, best_score = None, -1e18
        for a, p in node.P.items():
            child = node.children.get(a)
            Nsa = 0 if child is None else child.N
            Qsa = 0.0 if child is None else child.Q
            u = self.c_puct * p * np.sqrt(sum_N) / (1.0 + Nsa)
            score = Qsa + u
            if score > best_score:
                best_score = score
                best_a = a
        return int(best_a)

    def _expand(self, env: TicTacToe, node: Node):
        priors, _ = self.policy_value_fn(env)
        node.P = priors

    def _add_root_dirichlet_noise(self, root: Node):
        if not root.P: return
        actions = list(root.P.keys())
        alpha = self.dirichlet_alpha
        noise = self.rng.dirichlet([alpha] * len(actions))
        for a, eps in zip(actions, noise):
            root.P[a] = (1 - self.root_noise_frac) * root.P[a] + self.root_noise_frac * float(eps)

    def _backup(self, path, leaf_value: float, leaf_env_to_move: Optional[int] = None, root_player: Optional[int] = None):
        if leaf_env_to_move is None or root_player is None:
            v_root = leaf_value
            for node, _ in reversed(path):
                node.N += 1; node.W += v_root; node.Q = node.W / node.N
                v_root = -v_root
            return

        same = 1.0 if leaf_env_to_move == root_player else -1.0
        v_root = leaf_value * same
        for node, _ in reversed(path):
            node.N += 1; node.W += v_root; node.Q = node.W / node.N
            v_root = -v_root

    def _select_action_from_visits(self, root: Node, temperature: float) -> int:
        acts = np.array(list(root.P.keys()), dtype=int)
        visits = np.array([root.children[a].N if a in root.children else 0 for a in acts], dtype=float)
        if temperature <= 1e-8:
            return int(acts[np.argmax(visits)])
        with np.errstate(divide='ignore', invalid='ignore'):
            pi = np.power(visits, 1.0 / temperature)
        if np.all(pi == 0): pi = np.ones_like(pi)
        pi = pi / np.sum(pi)
        return int(np.random.default_rng().choice(acts, p=pi))

# =========================================================
# Flax policy/value network
# =========================================================
def make_features(board_np: np.ndarray, to_move: int) -> jnp.ndarray:
    """
    Returns (3,3,3) float32:
      ch0 = 1 where X stones
      ch1 = 1 where O stones
      ch2 = to_move (all +1 if X to move, all -1 if O to move)
    """
    b = board_np.astype(np.int8)
    x = (b == 1).astype(np.float32)
    o = (b == -1).astype(np.float32)
    tm = np.full_like(x, float(1.0 if to_move == 1 else -1.0), dtype=np.float32)
    feat = np.stack([x, o, tm], axis=-1)  # (3,3,3)
    return jnp.asarray(feat)

class TTTNet(nn.Module):
    """Tiny conv net with separate policy and value heads."""
    @nn.compact
    def __call__(self, x):  # x: (B,3,3,3)
        # Trunk
        y = nn.Conv(features=32, kernel_size=(3,3), padding='SAME')(x)
        y = nn.relu(y)
        y = nn.Conv(features=64, kernel_size=(3,3), padding='SAME')(y)
        y = nn.relu(y)
        y = y.reshape((y.shape[0], -1))
        y = nn.Dense(64)(y); y = nn.relu(y)

        # Policy head -> logits for 9 actions
        p = nn.Dense(32)(y); p = nn.relu(p)
        policy_logits = nn.Dense(9)(p)  # (B,9)

        # Value head -> scalar in [-1,1]
        v = nn.Dense(32)(y); v = nn.relu(v)
        value = nn.Dense(1)(v)
        value = nn.tanh(value)          # (B,1)
        return policy_logits, value[:, 0]

def mask_and_softmax(logits: jnp.ndarray, legal_mask: jnp.ndarray) -> jnp.ndarray:
    """Softmax over legal actions only. logits, mask shape (9,), mask {0,1}."""
    # Set illegal to a large negative.
    masked = jnp.where(legal_mask > 0.5, logits, -1e9)
    # Avoid NaNs if all illegal (shouldn't happen)
    masked = jnp.nan_to_num(masked, nan=-1e9)
    exps = jnp.exp(masked - jnp.max(masked))
    denom = jnp.sum(exps)
    probs = jnp.where(denom > 0, exps / denom, jnp.ones_like(exps) / exps.shape[0])
    return probs

@partial(jax.jit, static_argnums=0)
def model_apply_jit(model: TTTNet, params, feats: jnp.ndarray, legal_mask: jnp.ndarray):
    """Batched apply; feats (B,3,3,3), legal_mask (B,9)."""
    logits, value = model.apply(params, feats)      # logits (B,9), value (B,)
    probs = jax.vmap(mask_and_softmax)(logits, legal_mask)
    return probs, value

# Wrapper to create a policy_value_fn compatible with PUCT_MCTS
def make_flax_policy_value_fn(model: TTTNet, params):
    rng = np.random.default_rng()

    def policy_value_fn(env: TicTacToe):
        legal = env.legal_actions()
        # Build feature tensor (B=1)
        feats = make_features(env.board, env.to_move)[None, ...]   # (1,3,3,3)
        # Build legal mask (1,9)
        mask = np.zeros((1, 9), dtype=np.float32)
        mask[0, legal] = 1.0
        probs_b, value_b = model_apply_jit(model, params, feats, jnp.asarray(mask))
        probs = np.array(probs_b[0])   # (9,)
        value = float(np.array(value_b[0]))  # from CURRENT to-move POV (as desired)

        # Convert to dict over legal actions only
        priors = {int(a): float(probs[a]) for a in legal}
        # Normalize just in case of tiny numerical drift
        s = sum(priors.values())
        if s <= 0:
            # Fallback uniform
            u = 1.0 / max(1, len(legal))
            priors = {int(a): u for a in legal}
        else:
            for a in list(priors.keys()):
                priors[a] /= s
        return priors, value

    return policy_value_fn

# =========================================================
# Example usage
# =========================================================
if __name__ == "__main__":
    # Init model & params
    model = TTTNet()
    key = jax.random.PRNGKey(0)
    dummy_feats = jnp.zeros((1,3,3,3), dtype=jnp.float32)
    params = model.init(key, dummy_feats)

    # Build policy/value fn
    pv_fn = make_flax_policy_value_fn(model, params)

    # Create MCTS with neural guidance
    mcts = PUCT_MCTS(
        policy_value_fn=pv_fn,
        c_puct=1.5,
        dirichlet_alpha=0.3,
        root_noise_frac=0.25,
        rng=np.random.default_rng(0)
    )

    for i in range(3):
        # Play two moves using neural-guided MCTS
        env = TicTacToe()  # X to move
        print("Initial state:\n", env, sep="")
        a1 = mcts.search(env, n_simulations=800, temperature=1.0)
        print(f"\nChosen action: {a1}")
        env = env.step(a1)
        print("\nAfter move:\n", env, sep="")
    
        a2 = mcts.search(env, n_simulations=800, temperature=1e-9)  # greedy by visits
        print(f"\nReply action: {a2}")
        env = env.step(a2)
        print("\nState after reply:\n", env, sep="")


Initial state:
. . .
. . .
. . .
To move: X

Chosen action: 5

After move:
. . .
. . X
. . .
To move: O

Reply action: 4

State after reply:
. . .
. O X
. . .
To move: X
Initial state:
. . .
. . .
. . .
To move: X

Chosen action: 3

After move:
. . .
X . .
. . .
To move: O

Reply action: 8

State after reply:
. . .
X . .
. . O
To move: X
Initial state:
. . .
. . .
. . .
To move: X

Chosen action: 1

After move:
. X .
. . .
. . .
To move: O

Reply action: 2

State after reply:
. X O
. . .
. . .
To move: X


In [50]:
# ============================
# Self-play & Replay Buffer
# ============================
import numpy as np
import jax
import jax.numpy as jnp
import optax
from dataclasses import dataclass
from typing import List, Tuple, Dict

class ReplayBuffer:
    def __init__(self, capacity: int = 100_000, rng: np.random.Generator = None):
        self.capacity = capacity
        self.rng = rng or np.random.default_rng()
        self._data = []  # list of (feat: (3,3,3), pi: (9,), z: scalar)

    def add_many(self, samples: List[Tuple[np.ndarray, np.ndarray, float]]):
        if not samples: return
        self._data.extend(samples)
        if len(self._data) > self.capacity:
            # Drop oldest
            self._data = self._data[-self.capacity:]

    def sample(self, batch_size: int):
        idx = self.rng.choice(len(self._data), size=batch_size, replace=False)
        feats = np.stack([self._data[i][0] for i in idx], axis=0).astype(np.float32)   # (B,3,3,3)
        pis   = np.stack([self._data[i][1] for i in idx], axis=0).astype(np.float32)   # (B,9)
        zs    = np.array([self._data[i][2] for i in idx], dtype=np.float32)            # (B,)
        return feats, pis, zs

    def __len__(self):
        return len(self._data)

def visits_to_pi(root_node) -> np.ndarray:
    """Turn root children visit counts into a probability vector over 9 actions."""
    pi = np.zeros(9, dtype=np.float32)
    total = 0
    for a, child in root_node.children.items():
        n = child.N
        pi[a] = n
        total += n
    if total > 0:
        pi /= total
    else:
        # fallback uniform
        pi[:] = 1.0 / 9.0
    return pi

def outcome_to_value_for_player(winner: int, player: int) -> float:
    """winner in {-1,0,1}, player in {-1,1} -> value in {-1,0,1} from player's POV."""
    if winner == 0: return 0.0
    return 1.0 if winner == player else -1.0

def self_play_game(mcts, model, params, temperature: float = 1.0, temp_moves: int = 6):
    """
    Play one game using current network.
    Return list of (feat, pi, z) samples.
    - temperature anneals: high for first 'temp_moves', then low (greedy by visits).
    """
    from copy import deepcopy
    env = TicTacToe()
    samples = []

    move_idx = 0
    states_pov: List[int] = []  # player to move for each recorded state (for value target)
    while not env.is_terminal():
        # Make features & run MCTS
        feats = make_features(env.board, env.to_move)  # (3,3,3)
        tau = temperature if move_idx < temp_moves else 1e-9

        # Run MCTS with current params (wrap a flax policy/value fn)
        pv_fn = make_flax_policy_value_fn(model, params)
        action = mcts.search(env, n_simulations=200, temperature=tau)  # you can tune sims

        # Grab pi from the root (visit counts distribution)
        # The PUCT_MCTS doesn't return the root; quick hack: rerun expand to reconstruct priors,
        # then get visits from the internal root held during the last search. Easiest approach:
        # modify search() to return both action and the root node. For brevity, we just recompute here:
        # (Instead, we store it inside mcts for this call.)
        # ---- Small patch: add a field 'last_root' on mcts in its search() implementation ----
        pi = visits_to_pi(mcts.last_root)

        samples.append((np.array(feats, dtype=np.float32), pi.copy(), 0.0))  # z is filled later
        states_pov.append(env.to_move)

        # Play move
        env = env.step(int(action))
        move_idx += 1

    # Game finished; set z for all samples
    w = env.winner()  # in {-1,0,1}
    finalized = []
    for (feat, pi, _), pov in zip(samples, states_pov):
        z = outcome_to_value_for_player(w, pov)
        finalized.append((feat, pi, z))
    return finalized
# ============================
# Training step (JAX/Optax)
# ============================
from flax.core import FrozenDict

def masked_log_softmax(logits: jnp.ndarray, legal_mask: jnp.ndarray) -> jnp.ndarray:
    """log softmax over legal moves only; illegal receive ~ -inf."""
    masked = jnp.where(legal_mask > 0.5, logits, -1e9)
    return jax.nn.log_softmax(masked, axis=-1)

def make_legal_mask_from_feats(feats: jnp.ndarray) -> jnp.ndarray:
    """
    For TicTacToe: legal = cells where both X and O channels are 0 (empty).
    feats: (B,3,3,3) with channels [X, O, to_move].
    """
    x = feats[..., 0]
    o = feats[..., 1]
    empty = (x == 0) & (o == 0)
    return empty.reshape((empty.shape[0], 9)).astype(jnp.float32)


from functools import partial

@partial(jax.jit, static_argnames=['model'])
def forward_for_train(params: FrozenDict, model: TTTNet, feats: jnp.ndarray):
    """Return (policy_logits, value) without masking/softmax; training handles masking."""
    logits, value = model.apply(params, feats)
    return logits, value  # (B,9), (B,)

def policy_value_loss(params: FrozenDict, model: TTTNet,
                      feats: jnp.ndarray, target_pi: jnp.ndarray, target_z: jnp.ndarray,
                      weight_policy: float = 1.0, weight_value: float = 1.0):
    logits, value = forward_for_train(params, model, feats)  # (B,9), (B,)
    legal_mask = jnp.asarray(make_legal_mask_from_feats(jnp.array(feats)))  # (B,9) on host -> device

    logp = masked_log_softmax(logits, legal_mask)  # (B,9)
    # cross-entropy: -sum pi * logp
    pol_loss = -jnp.mean(jnp.sum(target_pi * logp, axis=-1))
    val_loss = jnp.mean((value - target_z) ** 2)
    loss = weight_policy * pol_loss + weight_value * val_loss
    metrics = {
        "loss": loss,
        "policy_loss": pol_loss,
        "value_loss": val_loss,
        "value_mae": jnp.mean(jnp.abs(value - target_z)),
    }
    return loss, metrics

def make_optimizer(lr: float = 1e-3, weight_decay: float = 1e-4):
    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay)
    return tx

from functools import partial

@partial(jax.jit, static_argnames=['model', 'tx'])
def train_step(params: FrozenDict, opt_state, model: TTTNet,
               feats: jnp.ndarray, target_pi: jnp.ndarray, target_z: jnp.ndarray, tx):
    (loss, metrics), grads = jax.value_and_grad(policy_value_loss, has_aux=True)(
        params, model, feats, target_pi, target_z
    )
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, metrics



# ============================
# Simple training loop
# ============================
def train_loop(num_iters=50, games_per_iter=16, batch_size=64, train_steps=100,
               sims_per_move=200, lr=1e-3, wd=1e-4, seed=0):
    rng = np.random.default_rng(seed)

    # Init model + params
    model = TTTNet()
    key = jax.random.PRNGKey(seed)
    dummy_feats = jnp.zeros((1,3,3,3), dtype=jnp.float32)
    params = model.init(key, dummy_feats)

    # MCTS with neural guidance
    pv_fn = make_flax_policy_value_fn(model, params)
    mcts = PUCT_MCTS(
        policy_value_fn=pv_fn,
        c_puct=1.5,
        dirichlet_alpha=0.3,
        root_noise_frac=0.25,
        rng=rng
    )

    # PATCH: store root in search() (add to PUCT_MCTS.search before returning):
    #   self.last_root = root
    # And use sims_per_move arg inside search
    # (to keep snippet focused, we assume you've added it)

    rb = ReplayBuffer(capacity=50_000, rng=rng)
    tx = make_optimizer(lr=lr, weight_decay=wd)
    opt_state = tx.init(params)

    for it in range(1, num_iters + 1):
        # ---- Self-play ----
        all_samples = []
        for _ in range(games_per_iter):
            pv_fn = make_flax_policy_value_fn(model, params)  # refresh with latest params
            mcts.policy_value_fn = pv_fn
            samples = self_play_game(mcts, model, params, temperature=1.0, temp_moves=6)
            all_samples.extend(samples)
        rb.add_many(all_samples)

        # ---- Training ----
        if len(rb) < batch_size:
            print(f"[iter {it}] warming up buffer ({len(rb)} samples)")
            continue

        for step in range(train_steps):
            feats, pis, zs = rb.sample(batch_size)
            params, opt_state, metrics = train_step(
                params, opt_state, model,
                jnp.asarray(feats), jnp.asarray(pis), jnp.asarray(zs),
                tx
            )
            if (step + 1) % 25 == 0:
                print(f"[iter {it} step {step+1}] loss={float(metrics['loss']):.4f} "
                      f"pol={float(metrics['policy_loss']):.4f} val={float(metrics['value_loss']):.4f} "
                      f"mae={float(metrics['value_mae']):.3f}")

        # (Optional) evaluate vs. random or a previous snapshot here

    return params
# Batched apply for many positions at once (roots or leaves)
@jax.jit
def model_forward_batched(params, model, feats_b: jnp.ndarray, legal_masks_b: jnp.ndarray):
    logits, value = model.apply(params, feats_b)            # (B,9),(B,)
    # Mask+softmax per item
    def _per_item(l, m):
        masked = jnp.where(m > 0.5, l, -1e9)
        return jax.nn.softmax(masked)
    probs = jax.vmap(_per_item)(logits, legal_masks_b)      # (B,9)
    return probs, value

def eval_many_positions(model, params, envs: List[TicTacToe]):
    feats = np.stack([make_features(e.board, e.to_move) for e in envs], axis=0)
    masks = []
    for e in envs:
        legal = e.legal_actions()
        m = np.zeros(9, dtype=np.float32); m[legal] = 1.0
        masks.append(m)
    masks = np.stack(masks, 0)
    probs_b, values_b = model_forward_batched(params, model, jnp.asarray(feats), jnp.asarray(masks))
    # Convert to Python types
    probs_b = np.array(probs_b); values_b = np.array(values_b)
    # Return list of (priors_dict, value)
    out = []
    for i, env in enumerate(envs):
        legal = env.legal_actions()
        priors = {int(a): float(probs_b[i, a]) for a in legal}
        s = sum(priors.values()) or 1.0
        for a in list(priors.keys()): priors[a] /= s
        out.append((priors, float(values_b[i])))
    return out
    
def mcts_search_with_leaf_batching(root_env: TicTacToe, mcts: PUCT_MCTS,
                                   model: TTTNet, params, n_simulations: int = 512, batch_size: int = 64,
                                   temperature: float = 1.0):
    # Prepare root
    root = Node(key=state_to_key(root_env), to_move=root_env.to_move)
    mcts._expand(root_env, root)
    mcts._add_root_dirichlet_noise(root)
    mcts.last_root = root  # keep for visit extraction

    pending = []   # list of (path, leaf_env, parent_node_for_child, action_taken)
    done = 0

    while done < n_simulations:
        # 1) Collect up to batch_size leaves (selection+expansion only)
        batch_paths, batch_envs, batch_parents, batch_actions = [], [], [], []
        k = min(batch_size, n_simulations - done)
        for _ in range(k):
            node = root
            env = root_env.clone()
            path = []
            # Descend until we hit an unexpanded child or terminal
            while True:
                if env.is_terminal():
                    # Terminal: back up exact outcome immediately
                    v = env.result_from(root_env.to_move)
                    mcts._backup(path, leaf_value=v)
                    break
                a = mcts._puct_select(node)
                path.append((node, a))
                if a not in node.children:
                    # Expansion deferred — gather for batch eval
                    env_next = env.step(a)
                    batch_paths.append(path)
                    batch_envs.append(env_next)
                    batch_parents.append(node)
                    batch_actions.append(a)
                    break
                env = env.step(a)
                node = node.children[a]

        # 2) Evaluate all leaves in one shot
        if batch_envs:
            evals = eval_many_positions(model, params, batch_envs)  # list of (priors, value)
            for path, env_next, parent, a, (priors, val) in zip(batch_paths, batch_envs, batch_parents, batch_actions, evals):
                child = Node(key=state_to_key(env_next), to_move=env_next.to_move, parent=parent, parent_action=a)
                child.P = priors
                parent.children[a] = child
                mcts._backup(path, leaf_value=val, leaf_env_to_move=env_next.to_move, root_player=root_env.to_move)

        done += k

    # 3) Pick move from visit counts
    return mcts._select_action_from_visits(root, temperature)


In [53]:
params = train_loop(
    num_iters=5,        # number of outer loops (increase later)
    games_per_iter=8,   # how many self-play games per iteration
    batch_size=32,      # training batch size
    train_steps=50,     # SGD updates per iteration
    sims_per_move=100   # MCTS simulations per move
)


[iter 1 step 25] loss=173611104.0000 pol=173611104.0000 val=0.1143 mae=0.247
[iter 1 step 50] loss=152777776.0000 pol=152777776.0000 val=0.1409 mae=0.249
[iter 2 step 25] loss=52083320.0000 pol=52083320.0000 val=1.5223 mae=1.029
[iter 2 step 50] loss=152777600.0000 pol=152777600.0000 val=1.0625 mae=0.875
[iter 3 step 25] loss=114582624.0000 pol=114582624.0000 val=1.2500 mae=0.875
[iter 3 step 50] loss=152773856.0000 pol=152773856.0000 val=1.6875 mae=1.125
[iter 4 step 25] loss=128464432.0000 pol=128464432.0000 val=1.5312 mae=1.031
[iter 4 step 50] loss=83320024.0000 pol=83320024.0000 val=1.3750 mae=0.938
[iter 5 step 25] loss=100660552.0000 pol=100660552.0000 val=1.0000 mae=0.812
[iter 5 step 50] loss=48572440.0000 pol=48572440.0000 val=1.2500 mae=0.938
