# Stochastic MuZero with Learned Temporal Abstractions

This notebook implements **Stochastic MuZero** with a novel approach to discovering **rules as compressible causal structure**.

## Core Idea
Instead of treating rules as explicit symbols, we discover them operationally:
- **Rules** = sequences of transitions that are deterministic, repeatable, and compositional
- These can be collapsed into **macro-operators** for faster planning
- The model learns to separate **deterministic dynamics** (rules) from **stochastic chance** (environment randomness)

## Architecture
- **Afterstate separation**: `s → afterstate → chance → s'`
- **Entropy tracking**: Identifies which transitions are deterministic
- **Macro cache**: Stores and reuses discovered temporal abstractions
- **Hierarchical MCTS**: Plans at multiple temporal scales

## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install torch numpy pyyaml tqdm tensorboard matplotlib python-chess -q

# Check GPU availability
import torch
# Report the PyTorch build and pick the compute device once for the notebook.
_cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    # Only query the device name when a CUDA device is actually present.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
DEVICE = torch.device("cuda" if _cuda_available else "cpu")
print(f"Using device: {DEVICE}")

In [None]:
# Create project structure
import os
import numpy as np
# Top-level folders of the project layout; creating them is idempotent.
directories = "config games networks mcts training utils runs".split()
for folder in directories:
    os.makedirs(folder, exist_ok=True)

print("Project structure created!")

## 2. Configuration

In [None]:
%%writefile utils/__init__.py
from .config import Config, load_config
from .support import scalar_to_support, support_to_scalar

__all__ = ["Config", "load_config", "scalar_to_support", "support_to_scalar"]

In [None]:
%%writefile utils/config.py
"""Configuration management for Stochastic MuZero."""

from dataclasses import dataclass
from typing import Optional, Dict, Any
import yaml
from pathlib import Path


@dataclass
class Config:
    """Hyperparameter bundle for Stochastic MuZero with macro-operator discovery.

    Every setting is a dataclass field with a default, so ``Config()`` is a
    complete configuration. Use :meth:`to_dict` / :meth:`from_dict` to convert
    to and from plain dictionaries (e.g. for YAML files).
    """

    # --- Game settings ---
    game: str = "2048"
    action_space_size: int = 4
    chance_space_size: int = 33

    # --- Network architecture ---
    state_dim: int = 256
    hidden_dim: int = 128
    num_layers: int = 2
    observation_dim: int = 496
    support_size: int = 31

    # --- Training ---
    batch_size: int = 256
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    max_grad_norm: float = 5.0
    num_unroll_steps: int = 5
    td_steps: int = 10
    discount: float = 0.997

    # --- Self-play ---
    num_simulations: int = 50
    num_actors: int = 4
    max_moves: int = 10000
    temperature_init: float = 1.0
    temperature_final: float = 0.1
    temperature_decay_steps: int = 10000

    # --- Replay buffer ---
    replay_buffer_size: int = 100000
    priority_alpha: float = 1.0
    priority_beta: float = 1.0

    # --- MCTS ---
    pb_c_base: float = 19652.0
    pb_c_init: float = 1.25
    root_dirichlet_alpha: float = 0.3
    root_exploration_fraction: float = 0.25

    # --- Macro-operator discovery ---
    entropy_threshold: float = 0.1
    composition_threshold: float = 0.01
    min_macro_length: int = 2
    max_macro_length: int = 8
    macro_confidence_decay: float = 0.9
    macro_confidence_boost: float = 1.05
    max_macros: int = 1000

    # --- Chance node handling ---
    chance_entropy_threshold: float = 0.5
    top_k_chances: int = 5

    # --- Logging ---
    log_interval: int = 100
    save_interval: int = 1000
    eval_interval: int = 500

    # --- Device ---
    device: str = "cuda"

    def to_dict(self) -> Dict[str, Any]:
        """Return every dataclass field of this config as a ``name -> value`` dict."""
        snapshot: Dict[str, Any] = {}
        for field_name in self.__dataclass_fields__:
            snapshot[field_name] = getattr(self, field_name)
        return snapshot

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "Config":
        """Build a config from ``d``, silently dropping keys that are not fields.

        Fields missing from ``d`` keep their defaults.
        """
        known_fields = cls.__dataclass_fields__
        accepted: Dict[str, Any] = {}
        for key, value in d.items():
            if key in known_fields:
                accepted[key] = value
        return cls(**accepted)


def load_config(config_path: Optional[str] = None, **overrides) -> Config:
    """Create a :class:`Config` from an optional YAML file plus keyword overrides.

    Args:
        config_path: Path to a YAML file. ``None`` or a path that does not
            exist is ignored, leaving only defaults and overrides.
        **overrides: Values that take precedence over the file contents.

    Returns:
        The resulting ``Config``; keys that are not ``Config`` fields are
        discarded by ``Config.from_dict``.
    """
    settings = {}
    source = Path(config_path) if config_path is not None else None
    if source is not None and source.exists():
        with open(source, "r") as handle:
            # An empty YAML document loads as None; treat it as "no settings".
            settings = yaml.safe_load(handle) or {}
    settings.update(overrides)
    return Config.from_dict(settings)


def save_config(config: Config, path: str) -> None:
    """Write ``config`` to ``path`` as block-style YAML, overwriting the file."""
    with open(path, "w") as handle:
        payload = config.to_dict()
        yaml.dump(payload, handle, default_flow_style=False)

In [None]:
%%writefile utils/support.py
"""Support-based scalar transformations for MuZero."""

import torch
import torch.nn.functional as F


def scalar_to_support(x: torch.Tensor, support_size: int) -> torch.Tensor:
    """Transform scalar values to a categorical ("two-hot") support representation.

    Each scalar is first squashed with
    ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x`` (``eps = 0.001``),
    clamped to ``[-support_size, support_size]``, and then its probability
    mass is split between the two neighbouring integer atoms so that the
    expectation over the support equals the squashed value.

    Args:
        x: Tensor of scalars of any shape. Integer tensors are accepted; the
            result then uses the floating dtype produced by the transform.
        support_size: Half-width of the support; atoms are the integers
            ``-support_size .. support_size``.

    Returns:
        Tensor of shape ``(*x.shape, 2 * support_size + 1)`` whose last
        dimension sums to 1.
    """
    eps = 0.001
    transformed = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
    transformed = torch.clamp(transformed, -support_size, support_size)
    # Shift so that atom -support_size lives at index 0.
    shifted = transformed + support_size

    floor_idx = shifted.floor().long()
    ceil_idx = floor_idx + 1
    # At the upper boundary ceil_idx would fall off the support; clamping is
    # harmless because its weight is exactly 0 there.
    floor_idx = torch.clamp(floor_idx, 0, 2 * support_size)
    ceil_idx = torch.clamp(ceil_idx, 0, 2 * support_size)

    # Keep the weights in the dtype of the transformed values: scatter_add_
    # requires source and destination dtypes to match, which x.dtype does not
    # guarantee (e.g. integer inputs are promoted to float by the transform).
    work_dtype = shifted.dtype
    ceil_weight = shifted - floor_idx.to(work_dtype)
    floor_weight = 1.0 - ceil_weight

    batch_shape = x.shape
    support_dim = 2 * support_size + 1

    flat_floor = floor_idx.flatten()
    flat_ceil = ceil_idx.flatten()
    flat_floor_weight = floor_weight.flatten()
    flat_ceil_weight = ceil_weight.flatten()

    output = torch.zeros(flat_floor.numel(), support_dim, device=x.device, dtype=work_dtype)
    output.scatter_add_(1, flat_floor.unsqueeze(1), flat_floor_weight.unsqueeze(1))
    output.scatter_add_(1, flat_ceil.unsqueeze(1), flat_ceil_weight.unsqueeze(1))

    return output.view(*batch_shape, support_dim)


def support_to_scalar(probs: torch.Tensor, support_size: int) -> torch.Tensor:
    """Transform a categorical support distribution back to scalar values.

    Takes the expectation of ``probs`` over the integer atoms
    ``-support_size .. support_size`` and applies the exact inverse of the
    squashing ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x`` used by
    ``scalar_to_support`` (``eps = 0.001``).

    Args:
        probs: Probabilities over the support; the last dimension must have
            size ``2 * support_size + 1``.
        support_size: Half-width of the support.

    Returns:
        Tensor of scalars with the shape of ``probs`` minus its last dimension.
    """
    support = torch.arange(-support_size, support_size + 1, device=probs.device, dtype=probs.dtype)
    expected = (probs * support).sum(dim=-1)
    eps = 0.001
    # Closed-form inverse of h. With u = sqrt(|x| + 1), h gives the quadratic
    # eps*u^2 + u - (|y| + 1 + eps) = 0, whose positive root is
    # (sqrt(1 + 4*eps*(|y| + 1 + eps)) - 1) / (2*eps). It is written below in
    # the algebraically equal form 2*c / (sqrt(1 + 4*eps*c) + 1) to avoid the
    # catastrophic cancellation of "sqrt(...) - 1" in float32.
    c = torch.abs(expected) + 1 + eps
    u = 2 * c / (torch.sqrt(1 + 4 * eps * c) + 1)
    return torch.sign(expected) * (u.square() - 1)


def compute_cross_entropy_loss(pred_logits: torch.Tensor, target_scalar: torch.Tensor, support_size: int) -> torch.Tensor:
    """Mean cross-entropy between predicted logits and a scalar target.

    The scalar target is converted to a distribution over the support with
    ``scalar_to_support``; the cross-entropy against ``softmax(pred_logits)``
    is computed over the last dimension and averaged over all remaining ones.
    """
    target_distribution = scalar_to_support(target_scalar, support_size)
    predicted_log_probs = F.log_softmax(pred_logits, dim=-1)
    per_item_loss = -torch.sum(target_distribution * predicted_log_probs, dim=-1)
    return per_item_loss.mean()

## 3. Game Environments (2048, Tic-Tac-Toe, Chess)

In [None]:
%%writefile games/__init__.py
from .base import Game
from .game_2048 import Game2048
from .game_tictactoe import GameTicTacToe
from .game_chess import GameChess

__all__ = ["Game", "Game2048", "GameTicTacToe", "GameChess"]

In [None]:
%%writefile games/base.py
"""Abstract base class for game environments."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TypeVar
import numpy as np
import torch

State = TypeVar("State")
Afterstate = TypeVar("Afterstate")
ChanceOutcome = int


@dataclass
class StepResult:
    """Result of applying an action and chance outcome.

    Produced by ``Game.step``, which fills the fields as noted below.
    """
    # Afterstate returned by ``apply_action`` (before the chance outcome is applied).
    afterstate: Any
    # State returned by ``apply_chance`` for the sampled chance outcome.
    next_state: Any
    # Reward returned by ``apply_action``.
    reward: float
    # ``is_terminal(next_state)``.
    done: bool
    # Chance outcome returned by ``sample_chance``.
    chance_outcome: int
    # Info dict returned by ``apply_action`` (contents are game-specific).
    info: Dict[str, Any]


class Game(ABC):
    """Interface for environments whose transitions split into action and chance.

    A transition is modelled as ``state -> afterstate -> next state``:
    ``apply_action`` produces the afterstate, a chance outcome is drawn with
    ``sample_chance`` and ``apply_chance`` turns the afterstate into the next
    state. ``step`` chains these three calls.
    """

    # ---- static description of the game -------------------------------

    @property
    @abstractmethod
    def action_space_size(self) -> int:
        """Size of the action space."""

    @property
    @abstractmethod
    def chance_space_size(self) -> int:
        """Size of the chance-outcome space."""

    @property
    @abstractmethod
    def observation_shape(self) -> Tuple[int, ...]:
        """Shape of an encoded observation."""

    # ---- state handling -----------------------------------------------

    @abstractmethod
    def reset(self) -> State:
        """Return an initial state."""

    @abstractmethod
    def clone_state(self, state: State) -> State:
        """Return a copy of ``state``."""

    @abstractmethod
    def legal_actions(self, state: State) -> List[int]:
        """Return the action indices that are legal in ``state``."""

    # ---- transition pieces --------------------------------------------

    @abstractmethod
    def apply_action(self, state: State, action: int) -> Tuple[Afterstate, float, Dict[str, Any]]:
        """Apply ``action`` and return ``(afterstate, reward, info)``."""

    @abstractmethod
    def sample_chance(self, afterstate: Afterstate, info: Dict[str, Any]) -> ChanceOutcome:
        """Draw a chance outcome for ``afterstate``."""

    @abstractmethod
    def get_chance_distribution(self, afterstate: Afterstate, info: Dict[str, Any]) -> np.ndarray:
        """Return the distribution over chance outcomes for ``afterstate``."""

    @abstractmethod
    def apply_chance(self, afterstate: Afterstate, chance: ChanceOutcome) -> State:
        """Apply chance outcome ``chance`` to ``afterstate`` and return the next state."""

    @abstractmethod
    def is_terminal(self, state: State) -> bool:
        """Return whether ``state`` is terminal."""

    # ---- encodings ----------------------------------------------------

    @abstractmethod
    def encode_state(self, state: State) -> torch.Tensor:
        """Encode ``state`` as a tensor."""

    @abstractmethod
    def encode_afterstate(self, afterstate: Afterstate) -> torch.Tensor:
        """Encode ``afterstate`` as a tensor."""

    # ---- concrete helpers ---------------------------------------------

    def step(self, state: State, action: int) -> StepResult:
        """Full step: apply action, sample chance, apply chance."""
        afterstate, reward, info = self.apply_action(state, action)
        outcome = self.sample_chance(afterstate, info)
        successor = self.apply_chance(afterstate, outcome)
        return StepResult(
            afterstate=afterstate,
            next_state=successor,
            reward=reward,
            done=self.is_terminal(successor),
            chance_outcome=outcome,
            info=info,
        )

    @property
    def is_two_player(self) -> bool:
        """Whether this is a two-player alternating game (requires value negation)."""
        return False

    def current_player(self, state: State) -> int:
        """Return current player (0 or 1). Override for two-player games."""
        return 0

    def get_canonical_state(self, state: State) -> State:
        """Get canonical form of state (for symmetry handling)."""
        return state

In [None]:
%%writefile games/game_2048.py
"""2048 game environment with afterstate separation."""

from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
from .base import Game, ChanceOutcome


@dataclass
class State2048:
    """Snapshot of a 2048 game: the tile grid, the running score and a game-over flag."""
    grid: np.ndarray
    score: int = 0
    done: bool = False

    def copy(self) -> "State2048":
        """Return a new state whose grid is an independent copy of this one's."""
        grid_copy = self.grid.copy()
        return State2048(grid_copy, self.score, self.done)


ACTION_UP, ACTION_RIGHT, ACTION_DOWN, ACTION_LEFT = 0, 1, 2, 3
ACTION_NAMES = ["up", "right", "down", "left"]


class Game2048(Game):
    """2048 with the player's move and the random tile spawn kept separate.

    ``apply_action`` slides/merges the grid (the afterstate); the tile spawn is
    a chance outcome encoded as an integer:

    * ``0``       -- no tile is spawned,
    * ``1..16``   -- a 2 at flat position ``chance - 1``,
    * ``17..32``  -- a 4 at flat position ``chance - 17``,

    where the flat position is ``row * GRID_SIZE + col``.
    """

    BITS_PER_TILE = 31
    GRID_SIZE = 4

    def __init__(self):
        """Create the game with its own NumPy random generator."""
        self._rng = np.random.default_rng()

    @property
    def action_space_size(self) -> int:
        """Four slide directions."""
        return 4

    @property
    def chance_space_size(self) -> int:
        """One "no spawn" outcome plus 16 positions x 2 tile values."""
        return 33

    @property
    def observation_shape(self) -> Tuple[int, ...]:
        """Flat feature vector: ``BITS_PER_TILE`` entries for each grid cell."""
        cells = self.GRID_SIZE * self.GRID_SIZE
        return (self.BITS_PER_TILE * cells,)

    def reset(self) -> State2048:
        """Return a fresh game with two random tiles (2 with probability 0.9, else 4)."""
        size = self.GRID_SIZE
        grid = np.zeros((size, size), dtype=np.int64)
        for _ in range(2):
            # Recompute the free cells each time so both tiles land on distinct cells.
            free_cells = list(zip(*np.where(grid == 0)))
            target = free_cells[self._rng.integers(len(free_cells))]
            grid[target] = 2 if self._rng.random() < 0.9 else 4
        return State2048(grid=grid, score=0, done=False)

    def clone_state(self, state: State2048) -> State2048:
        """Return an independent copy of ``state``."""
        return state.copy()

    def legal_actions(self, state: State2048) -> List[int]:
        """Return the actions whose slide actually changes the grid."""
        return [
            action
            for action in range(4)
            if not np.array_equal(self.apply_action(state, action)[0].grid, state.grid)
        ]

    def _slide_and_merge_line(self, line: np.ndarray) -> Tuple[np.ndarray, int]:
        """Slide a line towards index 0, merging equal neighbours once each.

        Returns the new line and the sum of the values of the merged tiles.
        """
        packed = np.zeros_like(line)
        tiles = line[line != 0]
        if len(tiles) == 0:
            return packed, 0

        kept, gained, idx = [], 0, 0
        count = len(tiles)
        while idx < count:
            tile = tiles[idx]
            if idx + 1 < count and tile == tiles[idx + 1]:
                # Merge the pair; the partner is consumed and cannot merge again.
                tile = tile * 2
                gained += tile
                idx += 2
            else:
                idx += 1
            kept.append(tile)

        packed[:len(kept)] = kept
        return packed, gained

    def _apply_action_to_grid(self, grid: np.ndarray, action: int) -> Tuple[np.ndarray, int]:
        """Slide every row or column of ``grid`` in the direction of ``action``.

        Returns ``(new_grid, score_gained)``; an unrecognised action leaves the
        grid unchanged and scores 0. The input grid is not modified.
        """
        new_grid = grid.copy()
        total_score = 0

        if action in (ACTION_UP, ACTION_DOWN):
            along_columns = True
        elif action in (ACTION_LEFT, ACTION_RIGHT):
            along_columns = False
        else:
            return new_grid, total_score
        # DOWN and RIGHT slide towards the far end, i.e. the reversed line.
        backwards = action in (ACTION_DOWN, ACTION_RIGHT)

        for k in range(self.GRID_SIZE):
            lane = grid[:, k] if along_columns else grid[k, :]
            if backwards:
                lane = lane[::-1]
            merged, score = self._slide_and_merge_line(lane)
            if backwards:
                merged = merged[::-1]
            if along_columns:
                new_grid[:, k] = merged
            else:
                new_grid[k, :] = merged
            total_score += score

        return new_grid, total_score

    def apply_action(self, state: State2048, action: int) -> Tuple[State2048, float, Dict[str, Any]]:
        """Slide the grid and return ``(afterstate, reward, info)``.

        ``reward`` is the score gained by merges. ``info`` holds
        ``"empty_positions"`` (free ``(row, col)`` cells of the afterstate) and
        ``"grid_changed"`` (whether the slide moved anything).
        """
        slid_grid, score_gained = self._apply_action_to_grid(state.grid, action)
        info = {
            "empty_positions": list(zip(*np.where(slid_grid == 0))),
            "grid_changed": not np.array_equal(slid_grid, state.grid),
        }
        afterstate = State2048(grid=slid_grid, score=state.score + score_gained, done=False)
        return afterstate, float(score_gained), info

    def sample_chance(self, afterstate: State2048, info: Dict[str, Any]) -> ChanceOutcome:
        """Sample a tile spawn: uniform free cell, value 2 (p=0.9) or 4.

        Returns 0 (no spawn) if the grid did not change or has no free cell.
        """
        free_cells = info.get("empty_positions", [])
        if not info.get("grid_changed", True) or len(free_cells) == 0:
            return 0
        row, col = free_cells[self._rng.integers(len(free_cells))]
        offset = 1 if self._rng.random() < 0.9 else 17
        return row * self.GRID_SIZE + col + offset

    def get_chance_distribution(self, afterstate: State2048, info: Dict[str, Any]) -> np.ndarray:
        """Return the probability of every chance outcome as a float32 vector."""
        dist = np.zeros(self.chance_space_size, dtype=np.float32)
        free_cells = info.get("empty_positions", [])
        if not info.get("grid_changed", True) or len(free_cells) == 0:
            dist[0] = 1.0
            return dist

        prob_per_pos = 1.0 / len(free_cells)
        for row, col in free_cells:
            flat_pos = row * self.GRID_SIZE + col
            dist[flat_pos + 1] = prob_per_pos * 0.9
            dist[flat_pos + 17] = prob_per_pos * 0.1
        return dist

    def apply_chance(self, afterstate: State2048, chance: ChanceOutcome) -> State2048:
        """Spawn the tile encoded by ``chance`` and flag the state done if stuck."""
        if chance == 0:
            successor = afterstate.copy()
        else:
            value, base = (2, 1) if chance <= 16 else (4, 17)
            row, col = divmod(chance - base, self.GRID_SIZE)
            grid = afterstate.grid.copy()
            grid[row, col] = value
            successor = State2048(grid=grid, score=afterstate.score, done=False)

        if len(self.legal_actions(successor)) == 0:
            successor.done = True
        return successor

    def is_terminal(self, state: State2048) -> bool:
        """Return the state's ``done`` flag."""
        return state.done

    def encode_state(self, state: State2048) -> torch.Tensor:
        """Encode the state's grid (score and done flag are not encoded)."""
        return self._encode_grid(state.grid)

    def encode_afterstate(self, afterstate: State2048) -> torch.Tensor:
        """Encode the afterstate's grid, same layout as ``encode_state``."""
        return self._encode_grid(afterstate.grid)

    def _encode_grid(self, grid: np.ndarray) -> torch.Tensor:
        """Encode each cell, row-major, as the binary digits of ``log2(tile)``.

        Least-significant bit first, ``BITS_PER_TILE`` entries per cell; empty
        cells are all zeros.
        """
        features = []
        for row in range(self.GRID_SIZE):
            for col in range(self.GRID_SIZE):
                tile = grid[row, col]
                # Exponent 0 yields all-zero bits, which is also the empty-cell code.
                exponent = 0 if tile == 0 else int(np.log2(tile))
                features.extend((exponent >> bit) & 1 for bit in range(self.BITS_PER_TILE))
        return torch.tensor(features, dtype=torch.float32)

    def get_max_tile(self, state: State2048) -> int:
        """Return the largest tile value on the grid."""
        return int(state.grid.max())

    def render(self, state: State2048) -> str:
        """Return a multi-line text rendering of the score and grid."""
        rule = "-" * 25
        lines = [f"Score: {state.score}", rule]
        for row in range(self.GRID_SIZE):
            cells = []
            for col in range(self.GRID_SIZE):
                tile = state.grid[row, col]
                cells.append(f"{tile:5d}" if tile > 0 else "    .")
            lines.append(" ".join(cells))
        lines.append(rule)
        if state.done:
            lines.append("GAME OVER")
        return "\n".join(lines)

In [None]:
%%writefile games/game_tictactoe.py
"""Tic-Tac-Toe: minimal fully deterministic two-player game."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from .base import Game, ChanceOutcome


@dataclass
class TicTacToeState:
    """Tic-Tac-Toe position: board, side to move, and result bookkeeping."""
    board: np.ndarray  # 3x3, 0=empty, 1=X, 2=O
    current_player: int = 1
    done: bool = False
    winner: Optional[int] = None

    def copy(self):
        """Return a new state whose board is an independent copy."""
        board_copy = self.board.copy()
        return TicTacToeState(board_copy, self.current_player, self.done, self.winner)


class GameTicTacToe(Game):
    """Fully deterministic two-player game. All transitions have entropy 0.

    Actions are cell indices 0..8 (``row = action // 3``, ``col = action % 3``).
    The chance space has a single outcome, so the afterstate is the next state.
    """

    def __init__(self):
        """Create the game; a NumPy random generator is stored on the instance."""
        self._rng = np.random.default_rng()

    @property
    def action_space_size(self) -> int:
        """One action per board cell."""
        return 9

    @property
    def chance_space_size(self) -> int:
        """Single (trivial) chance outcome."""
        return 1

    @property
    def observation_shape(self) -> Tuple[int, ...]:
        """Three flattened 3x3 planes."""
        return (27,)  # 3 planes of 3x3

    @property
    def is_two_player(self) -> bool:
        """Tic-Tac-Toe is a two-player alternating game."""
        return True

    def current_player(self, state: TicTacToeState) -> int:
        """Return 0 when X (mark 1) is to move, otherwise 1."""
        return int(state.current_player != 1)

    def reset(self) -> TicTacToeState:
        """Return an empty board with X (mark 1) to move."""
        empty_board = np.zeros((3, 3), dtype=np.int32)
        return TicTacToeState(board=empty_board, current_player=1)

    def clone_state(self, state):
        """Return an independent copy of ``state``."""
        return state.copy()

    def legal_actions(self, state) -> List[int]:
        """Return the indices of empty cells, or nothing once the game is over."""
        if state.done:
            return []
        return [cell for cell in range(9) if state.board[divmod(cell, 3)] == 0]

    def _check_winner(self, board):
        """Return the mark (1 or 2) that owns a complete line, else ``None``."""
        rows = [[(r, c) for c in range(3)] for r in range(3)]
        cols = [[(r, c) for r in range(3)] for c in range(3)]
        diagonals = [[(i, i) for i in range(3)], [(i, 2 - i) for i in range(3)]]
        # Mark 1 is checked over every line before mark 2.
        for mark in (1, 2):
            for line in rows + cols + diagonals:
                if all(board[r, c] == mark for r, c in line):
                    return mark
        return None

    def apply_action(self, state, action):
        """Place the current player's mark and return ``(afterstate, reward, {})``.

        The reward is 1.0 if the move completes a line and 0.0 otherwise. On a
        draw (full board, no line) the afterstate's ``winner`` is 0. The cell is
        written without checking that it is empty or that the game is ongoing.
        """
        row, col = divmod(action, 3)
        board = state.board.copy()
        board[row, col] = state.current_player

        winner = self._check_winner(board)
        done, reward = False, 0.0
        if winner is not None:
            done, reward = True, 1.0
        elif np.all(board != 0):
            done, winner = True, 0

        mover_next = 1 if state.current_player != 1 else 2
        afterstate = TicTacToeState(board=board, current_player=mover_next, done=done, winner=winner)
        return afterstate, reward, {}

    def sample_chance(self, afterstate, info):
        """Return the only chance outcome, 0."""
        return 0

    def get_chance_distribution(self, afterstate, info):
        """Return a one-element distribution with all mass on outcome 0."""
        return np.array([1.0], dtype=np.float32)

    def apply_chance(self, afterstate, chance):
        """Return ``afterstate`` itself; there is no randomness to apply."""
        return afterstate

    def is_terminal(self, state):
        """Return the state's ``done`` flag."""
        return state.done

    def encode_state(self, state):
        """Encode as 27 floats: mover's marks, opponent's marks, empty cells."""
        mover = state.current_player
        opponent = 2 if mover == 1 else 1
        planes = [
            (state.board == mark).astype(np.float32).flatten()
            for mark in (mover, opponent, 0)
        ]
        return torch.tensor(np.concatenate(planes))

    def encode_afterstate(self, afterstate):
        """Encode the afterstate exactly like a state."""
        return self.encode_state(afterstate)

    def render(self, state):
        """Return a multi-line text rendering of the board, mover and result."""
        symbols = {0: ".", 1: "X", 2: "O"}
        lines = []
        for r in range(3):
            lines.append(" ".join(symbols[state.board[r, c]] for c in range(3)))
        mover = "X" if state.current_player == 1 else "O"
        lines.append(f"Player: {mover}")
        if state.done:
            outcome = "Draw" if state.winner == 0 else f"{symbols[state.winner]} wins!"
            lines.append(f"Result: {outcome}")
        return "\n".join(lines)

In [None]:
%%writefile games/game_chess.py
"""Chess: fully deterministic two-player game with rich rule structure.

Uses python-chess as the rule engine. The agent must discover legal moves
through play — they are NOT hardcoded into the policy. The model learns
which actions are legal by experiencing rejection of illegal moves.

AlphaZero-style action encoding:
- 4672 actions = 64 from-squares x 73 move types
- 56 queen-like moves (8 directions x 7 distances)
- 8 knight moves
- 9 underpromotions (3 directions x 3 piece types)
- Queen promotions encoded as queen-like moves

State encoding:
- 22 planes x 64 squares = 1408 features (flattened for MLP)
- Board always from current player's perspective (flipped for black)
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch

try:
    import chess
except ImportError:
    raise ImportError("python-chess is required: pip install python-chess")

from .base import Game, ChanceOutcome

# Move encoding constants

# 8 directions for queen-like moves: (file_delta, rank_delta) per unit step
QUEEN_DIRECTIONS = [
    (0, 1),    # 0: N
    (1, 1),    # 1: NE
    (1, 0),    # 2: E
    (1, -1),   # 3: SE
    (0, -1),   # 4: S
    (-1, -1),  # 5: SW
    (-1, 0),   # 6: W
    (-1, 1),   # 7: NW
]

# 8 knight move offsets: (file_delta, rank_delta)
KNIGHT_MOVES = [
    (1, 2), (2, 1), (2, -1), (1, -2),
    (-1, -2), (-2, -1), (-2, 1), (-1, 2),
]

# Underpromotion: 3 forward directions x 3 piece types
# Directions from current player's perspective (always moving "north")
UNDERPROMO_DIRECTIONS = [(-1, 1), (0, 1), (1, 1)]  # Left capture, straight, right capture
UNDERPROMO_PIECES = [chess.KNIGHT, chess.BISHOP, chess.ROOK]


class GameChess(Game):
    """
    Chess with AlphaZero-style encoding for Stochastic MuZero.

    Fully deterministic (chance_space_size=1), so every transition has
    entropy ~ 0. All trajectory segments are macro candidates, enabling
    discovery of opening sequences, tactical motifs, etc.

    States and afterstates are ``chess.Board`` objects. Action indices and
    encoded observations are expressed from the side-to-move's perspective:
    for black, squares are mirrored vertically before encoding.
    """

    def __init__(self):
        """No per-instance state is needed."""
        pass

    @property
    def action_space_size(self) -> int:
        """Number of action indices (64 from-squares x 73 move types)."""
        return 4672  # 64 from-squares x 73 move types

    @property
    def chance_space_size(self) -> int:
        """Single (trivial) chance outcome."""
        return 1  # Fully deterministic

    @property
    def observation_shape(self) -> Tuple[int, ...]:
        """Length of the flat vector produced by ``encode_state``."""
        return (1408,)  # 22 planes x 64 squares

    @property
    def is_two_player(self) -> bool:
        """Chess is a two-player alternating game."""
        return True

    def current_player(self, state: chess.Board) -> int:
        """Return 0 when white is to move, 1 when black is to move."""
        return 0 if state.turn == chess.WHITE else 1

    def reset(self) -> chess.Board:
        """Return a new board in the standard starting position."""
        return chess.Board()

    def clone_state(self, state: chess.Board) -> chess.Board:
        """Return a copy of ``state``."""
        return state.copy()

    def legal_actions(self, state: chess.Board) -> List[int]:
        """Get all legal actions as AlphaZero-style action indices.

        Returns a sorted list without duplicates; empty once the game is over.
        """
        if state.is_game_over():
            return []
        # A set, because distinct moves could in principle map to one index
        # (e.g. via the fallbacks in _move_to_action).
        actions = set()
        is_black = state.turn == chess.BLACK
        for move in state.legal_moves:
            action = self._move_to_action(move, is_black)
            actions.add(action)
        return sorted(actions)

    def apply_action(
        self, state: chess.Board, action: int
    ) -> Tuple[chess.Board, float, Dict[str, Any]]:
        """Apply action. Returns (afterstate, reward, info).

        ``state`` is not modified; the move is pushed on a copy. If ``action``
        does not decode to a legal move, the unchanged copy is returned with
        reward 0.0 and ``info == {"invalid": True}`` (the side to move does
        not change). Otherwise ``info`` is empty and the reward is 1.0 when
        the move delivers checkmate, else 0.0.
        """
        board = state.copy()
        is_black = board.turn == chess.BLACK
        move = self._action_to_move(action, is_black, board)

        if move is None or move not in board.legal_moves:
            # Invalid action -- return unchanged state with zero reward.
            # The MCTS should only select legal actions, so this is a fallback.
            return board, 0.0, {"invalid": True}

        board.push(move)

        # Reward: +1 for checkmate (from perspective of player who just moved)
        reward = 0.0
        if board.is_checkmate():
            reward = 1.0

        return board, reward, {}

    def sample_chance(
        self, afterstate: chess.Board, info: Dict[str, Any]
    ) -> ChanceOutcome:
        """Return the only chance outcome, 0."""
        return 0  # Deterministic

    def get_chance_distribution(
        self, afterstate: chess.Board, info: Dict[str, Any]
    ) -> np.ndarray:
        """Return a one-element distribution with all mass on outcome 0."""
        return np.array([1.0], dtype=np.float32)

    def apply_chance(
        self, afterstate: chess.Board, chance: ChanceOutcome
    ) -> chess.Board:
        """Return ``afterstate`` itself (same object, not a copy)."""
        return afterstate  # Identity for deterministic games

    def is_terminal(self, state: chess.Board) -> bool:
        """Return ``state.is_game_over()``."""
        return state.is_game_over()

    def encode_state(self, state: chess.Board) -> torch.Tensor:
        """
        Encode board as 22 planes x 64 squares = 1408 features.

        Always from current player's perspective (board flipped for black).
        Only ranks are flipped for black; files are left as they are. Scalar
        features (castling rights, clocks, repetition flags) fill their whole
        plane with a single value.

        Planes:
         0-5:  Current player pieces (P, N, B, R, Q, K)
         6-11: Opponent pieces (P, N, B, R, Q, K)
         12:   My kingside castling
         13:   My queenside castling
         14:   Opponent kingside castling
         15:   Opponent queenside castling
         16:   En passant square
         17:   Halfmove clock (normalized)
         18:   Fullmove number (normalized)
         19:   Color to move (always 1 -- we see from own perspective)
         20:   Twofold repetition
         21:   Threefold repetition

        Returns:
            A flat float32 tensor of length 1408 (plane-major, then rank, then file).
        """
        planes = np.zeros((22, 8, 8), dtype=np.float32)
        is_black = state.turn == chess.BLACK

        # Piece planes (0-5: current player, 6-11: opponent)
        me = state.turn
        opp = not state.turn
        for pt in range(1, 7):  # PAWN=1 .. KING=6
            for sq in state.pieces(pt, me):
                r, f = chess.square_rank(sq), chess.square_file(sq)
                if is_black:
                    r = 7 - r
                planes[pt - 1, r, f] = 1.0

            for sq in state.pieces(pt, opp):
                r, f = chess.square_rank(sq), chess.square_file(sq)
                if is_black:
                    r = 7 - r
                planes[pt + 5, r, f] = 1.0

        # Castling rights (planes 12-15): my KS, my QS, opp KS, opp QS
        if is_black:
            planes[12] = float(state.has_kingside_castling_rights(chess.BLACK))
            planes[13] = float(state.has_queenside_castling_rights(chess.BLACK))
            planes[14] = float(state.has_kingside_castling_rights(chess.WHITE))
            planes[15] = float(state.has_queenside_castling_rights(chess.WHITE))
        else:
            planes[12] = float(state.has_kingside_castling_rights(chess.WHITE))
            planes[13] = float(state.has_queenside_castling_rights(chess.WHITE))
            planes[14] = float(state.has_kingside_castling_rights(chess.BLACK))
            planes[15] = float(state.has_queenside_castling_rights(chess.BLACK))

        # En passant (plane 16)
        if state.ep_square is not None:
            r = chess.square_rank(state.ep_square)
            f = chess.square_file(state.ep_square)
            if is_black:
                r = 7 - r
            planes[16, r, f] = 1.0

        # Halfmove clock (plane 17), scaled by 1/100.
        # NOTE(review): unlike plane 18 this is not clamped, so the value
        # exceeds 1 if the clock ever passes 100 -- confirm whether that can
        # occur with the termination rule used here.
        planes[17] = state.halfmove_clock / 100.0

        # Fullmove number (plane 18, scaled by 1/200 and capped at 1)
        planes[18] = min(state.fullmove_number / 200.0, 1.0)

        # Color to move (plane 19): always 1 from own perspective
        planes[19] = 1.0

        # Repetition planes (20-21)
        if state.is_repetition(2):
            planes[20] = 1.0
        if state.is_repetition(3):
            planes[21] = 1.0

        return torch.tensor(planes.reshape(-1), dtype=torch.float32)

    def encode_afterstate(self, afterstate: chess.Board) -> torch.Tensor:
        """Encode the afterstate exactly like a state."""
        return self.encode_state(afterstate)

    # ------------------------------------------------------------------
    # Action encoding / decoding
    # ------------------------------------------------------------------

    def _move_to_action(self, move: chess.Move, is_black: bool) -> int:
        """Convert a chess.Move to an action index in [0, 4671].

        The index is ``from_square * 73 + move_type`` with squares taken from
        the mover's perspective. ``move_type`` is 0-55 for queen-like moves
        (``direction * 7 + distance - 1``), 56-63 for knight moves and 64-72
        for underpromotions (``64 + direction * 3 + piece``). Queen
        promotions are encoded as ordinary queen-like moves.
        """
        from_sq = move.from_square
        to_sq = move.to_square

        # Mirror squares for black so encoding is always from own perspective
        if is_black:
            from_sq = chess.square_mirror(from_sq)
            to_sq = chess.square_mirror(to_sq)

        from_file = chess.square_file(from_sq)
        from_rank = chess.square_rank(from_sq)
        to_file = chess.square_file(to_sq)
        to_rank = chess.square_rank(to_sq)

        df = to_file - from_file
        dr = to_rank - from_rank

        # 1) Underpromotion (knight, bishop, rook)
        if move.promotion is not None and move.promotion != chess.QUEEN:
            try:
                dir_idx = UNDERPROMO_DIRECTIONS.index((df, dr))
            except ValueError:
                dir_idx = 1  # fallback to straight
            piece_idx = UNDERPROMO_PIECES.index(move.promotion)
            move_type = 64 + dir_idx * 3 + piece_idx

        # 2) Knight move
        elif (df, dr) in KNIGHT_MOVES:
            knight_idx = KNIGHT_MOVES.index((df, dr))
            move_type = 56 + knight_idx

        # 3) Queen-like move (straight / diagonal slides, pawn pushes, queen promos)
        else:
            direction = self._delta_to_direction(df, dr)
            distance = max(abs(df), abs(dr))
            move_type = direction * 7 + (distance - 1)

        return from_sq * 73 + move_type

    def _action_to_move(
        self, action: int, is_black: bool, board: chess.Board
    ) -> Optional[chess.Move]:
        """Convert an action index back to a chess.Move (or None if invalid).

        ``None`` is returned only when the destination falls off the board;
        legality of the resulting move is not checked here (``apply_action``
        does that). ``board`` is consulted solely to detect queen promotions.
        NOTE(review): ``action`` is not range-checked before decoding --
        callers are presumably expected to pass indices in [0, 4671].
        """
        from_sq = action // 73
        move_type = action % 73

        from_file = chess.square_file(from_sq)
        from_rank = chess.square_rank(from_sq)

        promotion = None

        if move_type < 56:
            # Queen-like move
            direction = move_type // 7
            distance = move_type % 7 + 1
            df, dr = QUEEN_DIRECTIONS[direction]
            to_file = from_file + df * distance
            to_rank = from_rank + dr * distance

        elif move_type < 64:
            # Knight move
            knight_idx = move_type - 56
            df, dr = KNIGHT_MOVES[knight_idx]
            to_file = from_file + df
            to_rank = from_rank + dr

        else:
            # Underpromotion
            under_idx = move_type - 64
            dir_idx = under_idx // 3
            piece_idx = under_idx % 3
            df, dr = UNDERPROMO_DIRECTIONS[dir_idx]
            to_file = from_file + df
            to_rank = from_rank + dr
            promotion = UNDERPROMO_PIECES[piece_idx]

        # Bounds check
        if not (0 <= to_file <= 7 and 0 <= to_rank <= 7):
            return None

        to_sq = chess.square(to_file, to_rank)

        # Detect queen promotion: pawn reaching last rank via queen-like move
        # (rank 7 is the last rank for both colours in the mirrored frame).
        if promotion is None and to_rank == 7:
            actual_from = chess.square_mirror(from_sq) if is_black else from_sq
            piece = board.piece_at(actual_from)
            if piece is not None and piece.piece_type == chess.PAWN:
                promotion = chess.QUEEN

        # Unmirror for black
        if is_black:
            from_sq = chess.square_mirror(from_sq)
            to_sq = chess.square_mirror(to_sq)

        return chess.Move(from_sq, to_sq, promotion=promotion)

    @staticmethod
    def _delta_to_direction(df: int, dr: int) -> int:
        """Map (file_delta, rank_delta) to one of 8 compass directions.

        Only the signs of the deltas are used; the returned index matches the
        order of ``QUEEN_DIRECTIONS``. A zero delta ``(0, 0)`` falls through
        to 0.
        """
        if df == 0 and dr > 0:
            return 0   # N
        if df > 0 and dr > 0:
            return 1   # NE
        if df > 0 and dr == 0:
            return 2   # E
        if df > 0 and dr < 0:
            return 3   # SE
        if df == 0 and dr < 0:
            return 4   # S
        if df < 0 and dr < 0:
            return 5   # SW
        if df < 0 and dr == 0:
            return 6   # W
        if df < 0 and dr > 0:
            return 7   # NW
        return 0  # fallback (shouldn't happen for valid moves)

    def render(self, state: chess.Board) -> str:
        """Human-readable board display.

        Shows the board text, the side to move, the fullmove number and, when
        applicable, a checkmate / stalemate / check line.
        """
        turn_str = "White" if state.turn == chess.WHITE else "Black"
        status = ""
        if state.is_checkmate():
            # The side to move is the one that has been mated.
            winner = "Black" if state.turn == chess.WHITE else "White"
            status = f"\nCheckmate! {winner} wins."
        elif state.is_stalemate():
            status = "\nStalemate -- draw."
        elif state.is_check():
            status = "\nCheck!"
        return f"{state}\nTurn: {turn_str} | Move: {state.fullmove_number}{status}"

## 4. Neural Networks

In [None]:
%%writefile networks/__init__.py
from .muzero_network import MuZeroNetwork

__all__ = ["MuZeroNetwork"]

In [None]:
%%writefile networks/muzero_network.py
"""Combined MuZero network with all components."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass


class MLP(nn.Module):
    """Feed-forward stack of ``Linear -> LayerNorm`` layers.

    The first ``num_layers - 1`` layers map to ``hidden_dim`` and are each
    followed by a ReLU; the last layer maps to ``output_dim`` and is
    layer-normalized without an activation. With ``num_layers <= 1`` the
    network is a single ``Linear -> LayerNorm`` from input to output.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
        super().__init__()
        # Widths seen by successive Linear layers: the input followed by one
        # hidden width per hidden layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1)

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.ReLU()]
        # Output projection: normalized, but no non-linearity.
        modules += [nn.Linear(widths[-1], output_dim), nn.LayerNorm(output_dim)]

        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stacked layers to ``x``."""
        return self.net(x)


@dataclass
class NetworkOutput:
    """Result bundle returned by ``MuZeroNetwork.initial_inference``."""
    # Latent state produced by the representation network.
    state: torch.Tensor
    # Unnormalized action scores from the policy head.
    policy_logits: torch.Tensor
    # Unnormalized scores from the value head (2 * support_size + 1 entries).
    value_logits: torch.Tensor

@dataclass
class DynamicsOutput:
    """Result bundle returned by ``MuZeroNetwork.recurrent_inference``."""
    # Latent afterstate computed from (state, action).
    afterstate: torch.Tensor
    # Latent state reached from the afterstate under the chosen chance outcome.
    next_state: torch.Tensor
    # Unnormalized reward scores from the dynamics reward head.
    reward_logits: torch.Tensor
    # Unnormalized scores over chance outcomes from the chance encoder.
    chance_logits: torch.Tensor
    # Entropy of softmax(chance_logits), summed over the last dimension.
    chance_entropy: torch.Tensor
    # Policy logits predicted from the afterstate.
    afterstate_policy_logits: torch.Tensor
    # Value logits predicted from the afterstate.
    afterstate_value_logits: torch.Tensor


class MuZeroNetwork(nn.Module):
    """Complete Stochastic MuZero network.

    Sub-networks:

    * ``representation``: observation -> latent state.
    * ``afterstate_dynamics``: (state, one-hot action) -> afterstate.
    * ``chance_encoder``: afterstate -> logits over chance outcomes.
    * ``dynamics_trunk`` + heads: (afterstate, one-hot chance) -> next state
      and reward logits.
    * ``prediction_trunk`` + heads: state -> policy and value logits.
    * ``afterstate_trunk`` + heads: afterstate -> policy and value logits.

    Reward and value heads emit ``2 * support_size + 1`` logits each.
    """

    def __init__(self, observation_dim: int, action_space_size: int, chance_space_size: int,
                 state_dim: int = 256, hidden_dim: int = 128, num_layers: int = 2, support_size: int = 31):
        super().__init__()
        self.observation_dim = observation_dim
        self.action_space_size = action_space_size
        self.chance_space_size = chance_space_size
        self.state_dim = state_dim
        self.support_size = support_size

        # Representation
        self.representation = MLP(observation_dim, hidden_dim, state_dim, num_layers)

        # Afterstate dynamics
        self.afterstate_dynamics = MLP(state_dim + action_space_size, hidden_dim, state_dim, num_layers)

        # Chance encoder
        self.chance_encoder = MLP(state_dim, hidden_dim, chance_space_size, num_layers)

        # Dynamics (the trunks use one layer fewer because a head follows them)
        self.dynamics_trunk = MLP(state_dim + chance_space_size, hidden_dim, hidden_dim, num_layers - 1)
        self.dynamics_state_head = nn.Sequential(nn.Linear(hidden_dim, state_dim), nn.LayerNorm(state_dim))
        self.dynamics_reward_head = nn.Linear(hidden_dim, 2 * support_size + 1)

        # Prediction
        self.prediction_trunk = MLP(state_dim, hidden_dim, hidden_dim, num_layers - 1)
        self.policy_head = nn.Linear(hidden_dim, action_space_size)
        self.value_head = nn.Linear(hidden_dim, 2 * support_size + 1)

        # Afterstate prediction
        self.afterstate_trunk = MLP(state_dim, hidden_dim, hidden_dim, num_layers - 1)
        self.afterstate_policy_head = nn.Linear(hidden_dim, action_space_size)
        self.afterstate_value_head = nn.Linear(hidden_dim, 2 * support_size + 1)

    def initial_inference(self, observation: torch.Tensor) -> NetworkOutput:
        """Encode an observation and predict policy/value logits for it."""
        state = self.representation(observation)
        policy_logits, value_logits = self.predict_state(state)
        return NetworkOutput(state=state, policy_logits=policy_logits, value_logits=value_logits)

    def dynamics(self, afterstate: torch.Tensor, chance: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Resolve a chance outcome from an afterstate.

        Exposed as its own method so callers that already hold an afterstate
        (for example a tree search) can evaluate a specific chance outcome
        without re-running the afterstate and chance networks.

        Args:
            afterstate: Latent afterstate.
            chance: Integer chance-outcome indices; they are one-hot encoded
                with ``chance_space_size`` classes.

        Returns:
            ``(next_state, reward_logits)``.
        """
        chance_onehot = F.one_hot(chance, self.chance_space_size).float()
        features = self.dynamics_trunk(torch.cat([afterstate, chance_onehot], dim=-1))
        return self.dynamics_state_head(features), self.dynamics_reward_head(features)

    def recurrent_inference(self, state: torch.Tensor, action: torch.Tensor, chance: Optional[torch.Tensor] = None) -> DynamicsOutput:
        """Advance the latent model by one action and one chance outcome.

        Args:
            state: Current latent state.
            action: Either a 1-D tensor of action indices (one-hot encoded
                here) or an already encoded action tensor, used as-is.
            chance: Chance-outcome indices. When omitted, one outcome per
                row is sampled from the predicted chance distribution.

        Returns:
            A ``DynamicsOutput`` with the afterstate, the next state and
            reward logits for the chosen chance outcome, the chance logits
            and their entropy, and the afterstate policy/value logits.
        """
        # Afterstate
        if action.dim() == 1:
            action_onehot = F.one_hot(action, self.action_space_size).float()
        else:
            action_onehot = action
        afterstate = self.afterstate_dynamics(torch.cat([state, action_onehot], dim=-1))

        # Chance distribution and its entropy
        chance_logits = self.chance_encoder(afterstate)
        probs = F.softmax(chance_logits, dim=-1)
        log_probs = F.log_softmax(chance_logits, dim=-1)
        entropy = -(probs * log_probs).sum(dim=-1)

        if chance is None:
            chance = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Dynamics
        next_state, reward_logits = self.dynamics(afterstate, chance)

        # Afterstate prediction
        as_policy, as_value = self.predict_afterstate(afterstate)

        return DynamicsOutput(afterstate=afterstate, next_state=next_state, reward_logits=reward_logits,
                              chance_logits=chance_logits, chance_entropy=entropy,
                              afterstate_policy_logits=as_policy, afterstate_value_logits=as_value)

    def predict_state(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits, value_logits)`` for a latent state."""
        features = self.prediction_trunk(state)
        return self.policy_head(features), self.value_head(features)

    def predict_afterstate(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits, value_logits)`` for a latent afterstate."""
        features = self.afterstate_trunk(afterstate)
        return self.afterstate_policy_head(features), self.afterstate_value_head(features)

    def unroll(self, observation: torch.Tensor, actions: torch.Tensor, chances: torch.Tensor) -> Dict[str, List[torch.Tensor]]:
        """Unroll the model along a recorded action/chance sequence.

        Args:
            observation: Observation at the start of the sequence.
            actions: 2-D tensor; column ``k`` holds the action indices of
                unroll step ``k``.
            chances: Chance-outcome indices, indexed the same way as
                ``actions``.

        Returns:
            Dict of lists. ``states``, ``policy_logits`` and ``value_logits``
            have one entry for the initial state plus one per step; the other
            lists have one entry per step.
        """
        # Unpacking two values keeps the requirement that ``actions`` is 2-D.
        _, num_steps = actions.shape
        initial = self.initial_inference(observation)

        states = [initial.state]
        afterstates = []
        policy_logits = [initial.policy_logits]
        value_logits = [initial.value_logits]
        reward_logits = []
        chance_logits = []
        chance_entropies = []

        current_state = initial.state
        for k in range(num_steps):
            # Halve the gradient flowing back into the state at every step.
            current_state = scale_gradient(current_state, 0.5)
            dynamics_out = self.recurrent_inference(current_state, actions[:, k], chances[:, k])

            afterstates.append(dynamics_out.afterstate)
            states.append(dynamics_out.next_state)
            reward_logits.append(dynamics_out.reward_logits)
            chance_logits.append(dynamics_out.chance_logits)
            chance_entropies.append(dynamics_out.chance_entropy)

            next_policy, next_value = self.predict_state(dynamics_out.next_state)
            policy_logits.append(next_policy)
            value_logits.append(next_value)
            current_state = dynamics_out.next_state

        return {"states": states, "afterstates": afterstates, "policy_logits": policy_logits,
                "value_logits": value_logits, "reward_logits": reward_logits,
                "chance_logits": chance_logits, "chance_entropies": chance_entropies}


class ScaleGradient(torch.autograd.Function):
    """Autograd function that passes its input through and rescales the gradient.

    The forward pass returns ``x`` as given; the backward pass multiplies
    the incoming gradient by the ``scale`` supplied in the forward call.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Stash the factor on the context for use in backward().
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, upstream_grad):
        # One return value per forward() argument: the rescaled gradient
        # for ``x`` and None for ``scale``, which receives no gradient.
        rescaled = upstream_grad * ctx.scale
        return rescaled, None


def scale_gradient(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Return ``x`` with its backward gradient multiplied by ``scale``.

    Functional wrapper around ``ScaleGradient``; the forward value is
    unchanged.
    """
    rescaled = ScaleGradient.apply(x, scale)
    return rescaled

## 5. MCTS with Macro Support

In [None]:
%%writefile mcts/__init__.py
from .node import Node
from .tree_search import StochasticMCTS, MCTSConfig
from .macro_cache import MacroCache, MacroOperator

__all__ = ["Node", "StochasticMCTS", "MCTSConfig", "MacroCache", "MacroOperator"]

In [None]:
%%writefile mcts/node.py
"""MCTS tree node for Stochastic MuZero."""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch


@dataclass
class Node:
    """One node of the stochastic MCTS tree.

    ``children`` is keyed by the action index (for children created by
    ``expand``) or the chance-outcome index (for children created by
    ``expand_chance``) that leads to each child.
    """

    # Latent state attached to this node; set by expand()/expand_chance().
    hidden_state: Optional[torch.Tensor] = None
    # True when this node represents an afterstate awaiting a chance outcome.
    is_chance_node: bool = False
    # Prior probability assigned by the parent when this node was created.
    prior: float = 0.0
    visit_count: int = 0
    value_sum: float = 0.0
    reward: float = 0.0
    children: Dict[int, "Node"] = field(default_factory=dict)
    parent: Optional["Node"] = None
    action_from_parent: Optional[int] = None
    macro_id: Optional[int] = None
    macro_confidence: float = 1.0
    # Entropy of the chance distribution, recorded by expand_chance().
    chance_entropy: float = 0.0
    to_play: int = -1

    @property
    def expanded(self) -> bool:
        """True once the node has at least one child."""
        return len(self.children) > 0

    @property
    def value(self) -> float:
        """Mean backed-up value, or 0.0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0

    def expand(self, hidden_state: torch.Tensor, policy_logits: torch.Tensor, legal_actions: List[int], is_chance_node: bool = False):
        """Create one child per legal action with priors from the policy.

        The softmax of ``policy_logits`` is masked to ``legal_actions`` and
        renormalized; if the masked mass is zero the priors are uniform.

        Args:
            hidden_state: Latent state to attach to this node.
            policy_logits: 1-D logits over the full action space.
            legal_actions: Action indices that may be taken from here. An
                empty list leaves the node without children.
            is_chance_node: Value of ``is_chance_node`` for the new children.
        """
        self.hidden_state = hidden_state
        if len(legal_actions) == 0:
            # Nothing to branch on; the uniform fallback below would
            # otherwise divide by zero.
            return

        policy = torch.softmax(policy_logits, dim=-1).detach().cpu().numpy()
        legal_mask = np.zeros(len(policy))
        legal_mask[legal_actions] = 1.0
        policy = policy * legal_mask
        policy_sum = policy.sum()
        if policy_sum > 0:
            policy = policy / policy_sum
        else:
            policy[legal_actions] = 1.0 / len(legal_actions)
        for action in legal_actions:
            self.children[action] = Node(is_chance_node=is_chance_node, prior=float(policy[action]), parent=self, action_from_parent=action)

    def expand_chance(self, hidden_state: torch.Tensor, chance_logits: torch.Tensor, top_k: int = 5, entropy_threshold: float = 0.5):
        """Create children for chance outcomes.

        When the entropy of the chance distribution is below
        ``entropy_threshold`` the ``top_k`` most likely outcomes are
        enumerated with renormalized priors; otherwise a single outcome is
        sampled and given prior 1.0.

        Returns:
            ``(outcome_indices, enumerated)`` where ``outcome_indices`` is a
            list of Python ints and ``enumerated`` tells which branch ran.
        """
        self.hidden_state = hidden_state
        probs = torch.softmax(chance_logits, dim=-1).detach().cpu().numpy()
        # The small constant keeps log() finite for zero probabilities.
        entropy = float(-np.sum(probs * np.log(probs + 1e-10)))
        self.chance_entropy = entropy

        if entropy < entropy_threshold:
            # Most likely outcomes first; plain ints so both branches return
            # the same element type.
            top_indices = [int(i) for i in np.argsort(probs)[-top_k:][::-1]]
            top_probs = probs[top_indices]
            top_probs = top_probs / top_probs.sum()
            for idx, prob in zip(top_indices, top_probs):
                self.children[idx] = Node(is_chance_node=False, prior=float(prob), parent=self, action_from_parent=idx)
            return top_indices, True

        sampled = int(np.random.choice(len(probs), p=probs))
        self.children[sampled] = Node(is_chance_node=False, prior=1.0, parent=self, action_from_parent=sampled)
        return [sampled], False

    def add_exploration_noise(self, dirichlet_alpha: float, exploration_fraction: float):
        """Blend Dirichlet noise into the children's priors (no-op without children)."""
        if not self.children:
            return
        actions = list(self.children.keys())
        noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
        for sample, action in zip(noise, actions):
            child = self.children[action]
            child.prior = child.prior * (1 - exploration_fraction) + sample * exploration_fraction

    def select_child(self, pb_c_base: float, pb_c_init: float, discount: float, min_max_stats):
        """Return the ``(action, child)`` pair with the highest pUCT score.

        Unvisited children get a value term of 0.0. Ties keep the child
        encountered first. Returns ``(None, None)`` when there are no
        children.
        """
        # Both factors depend only on this node, so compute them once.
        pb_c = np.log((self.visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init
        sqrt_visits = np.sqrt(self.visit_count)

        best_score, best_action, best_child = float("-inf"), None, None
        for action, child in self.children.items():
            prior_score = pb_c * child.prior * sqrt_visits / (1 + child.visit_count)
            value_score = min_max_stats.normalize(child.reward + discount * child.value) if child.visit_count > 0 else 0.0
            score = prior_score + value_score
            if score > best_score:
                best_score, best_action, best_child = score, action, child
        return best_action, best_child

    def select_action(self, temperature: float = 1.0) -> int:
        """Pick an action from the children's visit counts.

        ``temperature == 0`` returns the most visited action; otherwise an
        action is sampled with probability proportional to
        ``visit_count ** (1 / temperature)``. When no child has been visited
        the choice is uniform.

        Raises:
            ValueError: If the node has no children.
        """
        actions = list(self.children.keys())
        if not actions:
            raise ValueError("select_action() called on a node without children")
        visit_counts = np.array([self.children[a].visit_count for a in actions])
        if temperature == 0:
            return actions[np.argmax(visit_counts)]
        counts_temp = visit_counts ** (1.0 / temperature)
        total = counts_temp.sum()
        if total > 0:
            probs = counts_temp / total
        else:
            # No visits recorded: 0/0 would give NaN probabilities.
            probs = np.full(len(actions), 1.0 / len(actions))
        return int(np.random.choice(actions, p=probs))

    def get_policy(self):
        """Return ``(actions, probs)`` from normalized child visit counts.

        When no child has been visited the distribution is uniform; with no
        children both arrays are empty.
        """
        actions = np.array(list(self.children.keys()))
        visit_counts = np.array([child.visit_count for child in self.children.values()], dtype=np.float64)
        total = visit_counts.sum()
        if total > 0:
            probs = visit_counts / total
        elif len(actions) > 0:
            probs = np.full(len(actions), 1.0 / len(actions))
        else:
            probs = visit_counts
        return actions, probs


class MinMaxStats:
    """Running minimum and maximum of observed values.

    ``normalize`` rescales a value relative to the observed range and
    returns it unchanged until two distinct values have been seen.
    """

    def __init__(self):
        # Start with an empty range so the first update sets both bounds.
        self.maximum = float("-inf")
        self.minimum = float("inf")

    def update(self, value: float):
        """Widen the tracked range to include ``value``."""
        if value < self.minimum:
            self.minimum = value
        if value > self.maximum:
            self.maximum = value

    def normalize(self, value: float) -> float:
        """Map ``value`` linearly so that minimum -> 0 and maximum -> 1."""
        if self.maximum > self.minimum:
            span = self.maximum - self.minimum
            return (value - self.minimum) / span
        return value

In [None]:
%%writefile mcts/macro_cache.py
"""Macro-operator cache for learned temporal abstractions."""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict


@dataclass
class MacroOperator:
    """A cached action sequence together with its usage bookkeeping."""

    id: int
    action_sequence: Tuple[int, ...]
    length: int
    precondition_features: Optional[torch.Tensor] = None
    confidence: float = 1.0
    usage_count: int = 0
    success_count: int = 0
    creation_step: int = 0
    entropy_history: List[float] = field(default_factory=list)
    max_entropy_seen: float = 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of recorded uses that succeeded; 1.0 before any use."""
        if self.usage_count > 0:
            return self.success_count / self.usage_count
        return 1.0

    def record_usage(self, success: bool, entropy: float):
        """Log one use of the macro and the entropy observed during it."""
        self.usage_count += 1
        self.success_count += 1 if success else 0
        self.entropy_history.append(entropy)
        # Track the largest entropy reported so far.
        if entropy > self.max_entropy_seen:
            self.max_entropy_seen = entropy


class MacroCache:
    """Bounded store of ``MacroOperator`` objects keyed by action sequence.

    ``macros`` maps macro id to operator; ``action_index`` maps an action
    tuple to the ids of the macros that use it. The two are kept in sync,
    including on eviction.
    """

    def __init__(self, state_dim: int = 256, entropy_threshold: float = 0.1, composition_threshold: float = 0.01,
                 min_macro_length: int = 2, max_macro_length: int = 8, max_macros: int = 1000,
                 confidence_decay: float = 0.9, confidence_boost: float = 1.05, min_confidence: float = 0.5):
        self.state_dim = state_dim
        self.entropy_threshold = entropy_threshold
        self.composition_threshold = composition_threshold
        self.min_macro_length = min_macro_length
        self.max_macro_length = max_macro_length
        self.max_macros = max_macros
        self.confidence_decay = confidence_decay
        self.confidence_boost = confidence_boost
        self.min_confidence = min_confidence

        self.macros: Dict[int, MacroOperator] = {}
        self.action_index: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
        self._next_id = 0
        self.total_discoveries = 0
        self.total_uses = 0
        self.total_successes = 0

    def discover_macro(self, states: List[torch.Tensor], actions: List[int], entropies: List[float], training_step: int = 0):
        """Register ``actions`` as a new macro if the segment qualifies.

        A segment qualifies when its length lies within
        ``[min_macro_length, max_macro_length]`` and no entropy exceeds
        ``entropy_threshold``. If the action sequence is already cached,
        the peak entropy is appended to the existing macros' history and
        nothing new is created.

        Args:
            states: States along the segment; only ``states[0]`` is stored
                (detached copy) as the macro's precondition features.
            actions: Actions of the segment.
            entropies: Chance entropies of the segment's transitions.
            training_step: Recorded as the macro's creation step.

        Returns:
            The new ``MacroOperator``, or None when nothing was created.
        """
        k = len(actions)
        if k < self.min_macro_length or k > self.max_macro_length:
            return None
        peak_entropy = max(entropies)
        if peak_entropy > self.entropy_threshold:
            return None

        action_tuple = tuple(actions)
        # .get() avoids creating an empty defaultdict entry on lookup.
        existing_ids = self.action_index.get(action_tuple)
        if existing_ids is not None:
            for macro_id in existing_ids:
                self.macros[macro_id].entropy_history.append(peak_entropy)
            return None

        macro = MacroOperator(id=self._next_id, action_sequence=action_tuple, length=k,
                              precondition_features=states[0].detach().clone() if states[0] is not None else None,
                              confidence=1.0, creation_step=training_step,
                              entropy_history=[peak_entropy], max_entropy_seen=peak_entropy)
        self._next_id += 1
        self.total_discoveries += 1
        self.macros[macro.id] = macro
        self.action_index[action_tuple].append(macro.id)

        if len(self.macros) > self.max_macros:
            self._evict_lowest_confidence()
        return macro

    def _evict_lowest_confidence(self) -> None:
        """Drop the lowest-confidence macro from ``macros`` and ``action_index``."""
        worst_id = min(self.macros, key=lambda m: self.macros[m].confidence)
        evicted = self.macros.pop(worst_id)
        # Also forget the id in the index; a stale id would raise KeyError on
        # the next discovery of the same sequence and block re-learning it.
        ids = self.action_index.get(evicted.action_sequence)
        if ids is None:
            return
        if worst_id in ids:
            ids.remove(worst_id)
        if not ids:
            del self.action_index[evicted.action_sequence]

    def get_applicable_macros(self, state: torch.Tensor, legal_actions: List[int]) -> List[MacroOperator]:
        """Return usable macros, highest confidence first.

        A macro is usable when its first action is in ``legal_actions`` and
        its confidence is at least ``min_confidence``. ``state`` is accepted
        but not consulted by the current implementation.
        """
        legal = set(legal_actions)
        applicable = [m for m in self.macros.values()
                      if m.action_sequence[0] in legal and m.confidence >= self.min_confidence]
        applicable.sort(key=lambda m: m.confidence, reverse=True)
        return applicable

    def update_macro(self, macro_id: int, success: bool, entropy: float):
        """Record one use of a macro and adjust its confidence.

        Confidence decays when ``entropy`` exceeds ``entropy_threshold`` or
        the use failed; otherwise it is boosted (capped at 1.0) and counted
        as a success. Unknown ids are ignored.
        """
        macro = self.macros.get(macro_id)
        if macro is None:
            return
        macro.record_usage(success, entropy)
        self.total_uses += 1
        if entropy > self.entropy_threshold:
            macro.confidence *= self.confidence_decay
        elif success:
            macro.confidence = min(1.0, macro.confidence * self.confidence_boost)
            self.total_successes += 1
        else:
            macro.confidence *= self.confidence_decay

    def get_statistics(self) -> Dict[str, float]:
        """Return aggregate counters and averages over the cached macros."""
        cached = list(self.macros.values())
        return {
            "num_macros": len(cached),
            "total_discoveries": self.total_discoveries,
            "total_uses": self.total_uses,
            "total_successes": self.total_successes,
            "success_rate": self.total_successes / self.total_uses if self.total_uses > 0 else 0.0,
            "avg_confidence": np.mean([m.confidence for m in cached]) if cached else 0.0,
            "avg_length": np.mean([m.length for m in cached]) if cached else 0.0,
        }


def discover_macros_from_trajectory(trajectory: List[Dict], macro_cache: MacroCache, min_length: int = 2, max_length: int = 8, training_step: int = 0):
    """Offer every contiguous trajectory window to the macro cache.

    Windows of ``min_length`` to ``max_length`` steps (capped at the
    trajectory length) are scanned shortest first, left to right. A window
    is skipped when its last step has no ``"next_state"``. Each step is a
    dict read through the keys ``"state"``, ``"action"``, ``"entropy"`` and
    ``"next_state"``.

    Returns:
        The macros newly created by ``macro_cache.discover_macro``.
    """
    found = []
    total_steps = len(trajectory)
    if total_steps < min_length:
        return found

    longest = min(max_length, total_steps)
    for window in range(min_length, longest + 1):
        last_start = total_steps - window
        for begin in range(last_start + 1):
            steps = trajectory[begin:begin + window]
            states = [step["state"] for step in steps]
            tail = steps[-1]
            if "next_state" not in tail or tail["next_state"] is None:
                # Without the closing state the window is incomplete.
                continue
            states.append(tail["next_state"])
            candidate = macro_cache.discover_macro(
                states=states,
                actions=[step["action"] for step in steps],
                entropies=[step["entropy"] for step in steps],
                training_step=training_step,
            )
            if candidate is not None:
                found.append(candidate)
    return found

In [None]:
%%writefile mcts/tree_search.py
"""Stochastic MCTS with macro-operator support.

This implements Monte Carlo Tree Search for Stochastic MuZero with:
1. Alternating decision and chance nodes
2. Entropy-based chance node expansion (enumerate vs sample)
3. Macro-operator lookup and usage during planning
4. Proper value backup through stochastic branches
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F

from .node import Node, MinMaxStats
from .macro_cache import MacroCache, MacroOperator


@dataclass
class MCTSConfig:
    """Tunable parameters read by ``StochasticMCTS``."""

    # Number of simulations run per call to search().
    num_simulations: int = 50
    # Discount factor used when scoring children and backing up values.
    discount: float = 0.997
    # Constants of the pUCT exploration term used in child selection.
    pb_c_base: float = 19652.0
    pb_c_init: float = 1.25
    # Dirichlet noise mixed into the root priors when exploration is enabled.
    root_dirichlet_alpha: float = 0.3
    root_exploration_fraction: float = 0.25

    # Chance node handling: below entropy_threshold the top_k_chances most
    # likely outcomes are enumerated, otherwise a single outcome is sampled.
    entropy_threshold: float = 0.5
    top_k_chances: int = 5

    # Macro support: a macro is abandoned during search once the chance
    # entropy along it exceeds macro_verification_threshold.
    use_macros: bool = True
    macro_verification_threshold: float = 0.1

    # Action space size (needed for policy output)
    action_space_size: int = 4

    # Two-player support: negate backed-up values at decision nodes.
    is_two_player: bool = False


class StochasticMCTS:
    """
    Monte Carlo Tree Search for Stochastic MuZero.

    Handles:
    - Decision nodes: Agent chooses action
    - Chance nodes: Environment samples outcome
    - Macro-operators: Skip deterministic segments

    NOTE(review): ``min_max_stats`` is created once per instance, so the
    value bounds carry over between successive ``search`` calls on the same
    object.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        config: MCTSConfig,
        macro_cache: Optional[MacroCache] = None,
    ):
        self.model = model
        self.config = config
        self.macro_cache = macro_cache
        self.min_max_stats = MinMaxStats()

    @torch.no_grad()
    def search(
        self,
        observation: torch.Tensor,
        legal_actions: List[int],
        add_exploration_noise: bool = True,
    ) -> Node:
        """
        Run MCTS from the given observation.

        Args:
            observation: Root observation (batch_size=1, observation_dim);
                a 1-D observation gets a batch dimension added.
            legal_actions: List of legal actions at root
            add_exploration_noise: Whether to add Dirichlet noise at root

        Returns:
            Root node after search
        """
        # Ensure batch dimension
        if observation.dim() == 1:
            observation = observation.unsqueeze(0)

        # Initial inference at root
        initial = self.model.initial_inference(observation)

        # Create root node
        root = Node(to_play=0)
        root.expand(
            hidden_state=initial.state,
            policy_logits=initial.policy_logits.squeeze(0),
            legal_actions=legal_actions,
            is_chance_node=True,  # Children are chance nodes (afterstates)
        )

        # Add exploration noise
        if add_exploration_noise:
            root.add_exploration_noise(
                dirichlet_alpha=self.config.root_dirichlet_alpha,
                exploration_fraction=self.config.root_exploration_fraction,
            )

        macros_enabled = self.config.use_macros and self.macro_cache is not None

        # Run simulations
        for _ in range(self.config.num_simulations):
            node = root
            search_path = [node]

            # Selection: traverse tree until leaf
            while node.expanded:
                # Macros are only attempted at decision nodes.
                if macros_enabled and not node.is_chance_node:
                    macro_leaf = self._try_macro(node, search_path)
                    if macro_leaf is not None:
                        # The macro leaf has no children, so the loop ends.
                        node = macro_leaf
                        continue

                # Normal selection
                _, child = node.select_child(
                    pb_c_base=self.config.pb_c_base,
                    pb_c_init=self.config.pb_c_init,
                    discount=self.config.discount,
                    min_max_stats=self.min_max_stats,
                )
                search_path.append(child)
                node = child

            # Expansion
            parent = search_path[-2] if len(search_path) > 1 else None
            value = self._expand(node, parent)

            # Backpropagation
            self._backpropagate(search_path, value)

        return root

    def _dynamics_step(
        self, afterstate: torch.Tensor, chance: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the model's dynamics trunk and heads to (afterstate, chance).

        Args:
            afterstate: Latent afterstate.
            chance: Integer chance-outcome indices, one-hot encoded here.

        Returns:
            ``(next_state, reward_logits)``.
        """
        chance_onehot = F.one_hot(chance, self.model.chance_space_size).float()
        features = self.model.dynamics_trunk(
            torch.cat([afterstate, chance_onehot], dim=-1)
        )
        return (
            self.model.dynamics_state_head(features),
            self.model.dynamics_reward_head(features),
        )

    def _try_macro(
        self,
        node: Node,
        search_path: List[Node],
    ) -> Optional[Node]:
        """
        Try to use a macro-operator from this node.

        The highest-confidence applicable macro is replayed through the
        model using the most likely chance outcome at every step. If the
        chance entropy exceeds ``macro_verification_threshold`` at any step
        the macro is penalized and ``search_path`` is left untouched.

        On success the path is extended with:

        * the real child of ``node`` for the macro's first action, so the
          visit statistics that ``get_action_policy`` reads from the tree
          reflect the macro;
        * one virtual chance node per further action, each carrying the
          reward of the previous collapsed step;
        * a final unexpanded decision node whose ``action_from_parent`` is
          the last chance outcome, which ``_expand`` needs to recompute the
          state and the last reward.

        NOTE(review): virtual nodes are not linked into ``children``, so the
        macro is re-verified on every simulation that reaches ``node`` and,
        while it keeps verifying, normal selection at ``node`` is skipped.
        Virtual nodes are chance nodes, so in two-player mode the value sign
        flips only once across the whole macro.

        Returns the node reached after applying the macro,
        or None if no applicable macro was found.
        """
        if node.hidden_state is None:
            return None

        # Get legal actions (actions with children)
        legal_actions = list(node.children.keys())

        # Find applicable macros
        macros = self.macro_cache.get_applicable_macros(
            state=node.hidden_state.squeeze(0),
            legal_actions=legal_actions,
        )

        if not macros:
            return None

        # Try the highest-confidence macro
        macro = macros[0]

        # Only the first action can be checked against this node's children;
        # later actions are taken from later latent states.
        entry_child = node.children.get(macro.action_sequence[0])
        if entry_child is None:
            self.macro_cache.update_macro(macro.id, success=False, entropy=1.0)
            return None

        current_state = node.hidden_state
        max_entropy = 0.0
        # Staged locally so a failed verification leaves search_path as is.
        staged: List[Node] = []
        pending_reward = 0.0
        chance = None

        for step, action in enumerate(macro.action_sequence):
            action_tensor = torch.tensor([action], device=current_state.device)

            # Get dynamics
            dynamics_out = self.model.recurrent_inference(
                current_state, action_tensor
            )

            max_entropy = max(max_entropy, dynamics_out.chance_entropy.item())

            # Check if still deterministic
            if max_entropy > self.config.macro_verification_threshold:
                self.macro_cache.update_macro(macro.id, success=False, entropy=max_entropy)
                return None

            # Most likely chance outcome
            chance = torch.argmax(dynamics_out.chance_logits, dim=-1)
            next_state, reward_logits = self._dynamics_step(
                dynamics_out.afterstate, chance
            )

            # Get scalar reward
            reward_probs = F.softmax(reward_logits, dim=-1)
            reward = self._support_to_scalar(reward_probs).item()

            if step == 0:
                # _expand reads the afterstate from the leaf's parent, so
                # an unexpanded real child needs one for single-step macros.
                if entry_child.hidden_state is None:
                    entry_child.hidden_state = dynamics_out.afterstate
                staged.append(entry_child)
            else:
                staged.append(
                    Node(
                        hidden_state=dynamics_out.afterstate,
                        is_chance_node=True,
                        reward=pending_reward,
                        parent=staged[-1],
                        action_from_parent=action,
                        macro_id=macro.id,
                    )
                )

            pending_reward = reward
            current_state = next_state

        # Success - update macro statistics
        self.macro_cache.update_macro(macro.id, success=True, entropy=max_entropy)

        # Create final node after macro; _expand recomputes its state and
        # reward from the parent's afterstate and this chance outcome.
        final_node = Node(
            hidden_state=current_state,
            is_chance_node=False,
            reward=pending_reward,
            parent=staged[-1],
            action_from_parent=int(chance.item()),
            macro_id=macro.id,
            macro_confidence=macro.confidence,
        )

        search_path.extend(staged)
        search_path.append(final_node)

        return final_node

    def _expand(self, node: Node, parent: Optional[Node]) -> float:
        """
        Expand a leaf node and return its value.

        Args:
            node: Leaf node to expand
            parent: Parent node (needed for dynamics)

        Returns:
            Value estimate for backpropagation
        """
        if parent is None:
            # Root node already expanded in search()
            return 0.0

        parent_state = parent.hidden_state

        if node.is_chance_node:
            # This is a chance node (after action, before environment response)
            # Expand with chance outcomes
            action = node.action_from_parent
            action_tensor = torch.tensor([action], device=parent_state.device)
            dynamics_out = self.model.recurrent_inference(parent_state, action_tensor)

            # Expand chance node
            node.expand_chance(
                hidden_state=dynamics_out.afterstate,
                chance_logits=dynamics_out.chance_logits.squeeze(0),
                top_k=self.config.top_k_chances,
                entropy_threshold=self.config.entropy_threshold,
            )

            # For two-player games, set to_play on decision node children
            if self.config.is_two_player:
                parent_to_play = parent.to_play if parent.to_play >= 0 else 0
                for child in node.children.values():
                    child.to_play = 1 - parent_to_play

            # Get afterstate value
            _, value_logits = self.model.predict_afterstate(dynamics_out.afterstate)
            value_probs = F.softmax(value_logits, dim=-1)
            return self._support_to_scalar(value_probs).item()

        # This is a decision node (after chance resolved)
        # Need to compute state from parent's afterstate + chance

        # Get chance outcome that led here
        chance = node.action_from_parent
        chance_tensor = torch.tensor([chance], device=parent_state.device)

        # Parent is afterstate, compute next state
        next_state, reward_logits = self._dynamics_step(parent_state, chance_tensor)

        # Get reward
        reward_probs = F.softmax(reward_logits, dim=-1)
        node.reward = self._support_to_scalar(reward_probs).item()

        # Get policy and value at next state
        policy_logits, value_logits = self.model.predict_state(next_state)

        # Expand with all actions (no legal action filtering in latent space)
        node.expand(
            hidden_state=next_state,
            policy_logits=policy_logits.squeeze(0),
            legal_actions=list(range(policy_logits.shape[-1])),
            is_chance_node=True,  # Children will be chance nodes
        )

        value_probs = F.softmax(value_logits, dim=-1)
        return self._support_to_scalar(value_probs).item()

    def _backpropagate(self, search_path: List[Node], value: float) -> None:
        """
        Backpropagate value through the search path.

        For single-player: straightforward value backup.
        For two-player: negate value at decision node boundaries (zero-sum).
        Values at each node are stored from that node's player's perspective.

        Args:
            search_path: Path from root to leaf
            value: Value at leaf node (from leaf player's perspective)
        """
        for node in reversed(search_path):
            node.visit_count += 1
            node.value_sum += value

            # Backed-up value from this node's perspective; the discount is
            # applied at every node on the path, chance nodes included.
            backed_value = node.reward + self.config.discount * value

            # Update min-max stats
            self.min_max_stats.update(backed_value)

            # For two-player zero-sum games, negate at decision nodes
            # since the parent decision node belongs to the opponent
            if self.config.is_two_player and not node.is_chance_node:
                value = -backed_value
            else:
                value = backed_value

    def _support_to_scalar(self, probs: torch.Tensor) -> torch.Tensor:
        """Convert categorical support to scalar value.

        Takes the expectation over integer bins
        ``-support_size .. support_size`` (inferred from the last dimension)
        and applies the inverse transformation below.
        """
        support_size = (probs.shape[-1] - 1) // 2
        support = torch.arange(
            -support_size, support_size + 1,
            device=probs.device, dtype=probs.dtype
        )
        expected = (probs * support).sum(dim=-1)

        # Inverse transformation
        eps = 0.001
        sign = torch.sign(expected)
        abs_expected = torch.abs(expected)
        return sign * ((abs_expected + 1).square() - 1) / (1 + 2 * eps * (abs_expected + 1))

    def get_action_policy(
        self, root: Node, temperature: float = 1.0
    ) -> Tuple[int, np.ndarray]:
        """
        Get action and policy from search results.

        Args:
            root: Root node after search
            temperature: Temperature for action selection

        Returns:
            (selected_action, policy_distribution)
        """
        actions, probs = root.get_policy()

        # Create full policy array sized by action space, not number of children
        policy = np.zeros(self.config.action_space_size)
        for action, prob in zip(actions, probs):
            policy[action] = prob

        # Select action
        selected = root.select_action(temperature=temperature)

        return selected, policy


def run_mcts(
    model: torch.nn.Module,
    observation: torch.Tensor,
    legal_actions: List[int],
    config: Optional[MCTSConfig] = None,
    macro_cache: Optional[MacroCache] = None,
    add_noise: bool = True,
    temperature: float = 1.0,
) -> Tuple[int, np.ndarray, float, Node]:
    """
    Convenience function to run MCTS and get results.

    Builds a fresh ``StochasticMCTS``, runs one search from ``observation``
    and extracts the action, policy and root value from the resulting tree.

    Args:
        model: MuZero network
        observation: Current observation
        legal_actions: Legal actions
        config: MCTS configuration (uses ``MCTSConfig()`` defaults if None)
        macro_cache: Optional macro cache handed to the search
        add_noise: Whether to add exploration noise at the root
        temperature: Temperature passed to action selection. Defaults to 1.0,
            which was previously the hard-coded value.

    Returns:
        (action, policy, root_value, root_node)
    """
    if config is None:
        config = MCTSConfig()

    mcts = StochasticMCTS(model, config, macro_cache)
    root = mcts.search(observation, legal_actions, add_exploration_noise=add_noise)

    action, policy = mcts.get_action_policy(root, temperature=temperature)
    return action, policy, root.value, root

## 6. Training Components

In [None]:
%%writefile training/__init__.py
from .replay_buffer import ReplayBuffer, GameHistory
from .trainer import Trainer, TrainerConfig

__all__ = ["ReplayBuffer", "GameHistory", "Trainer", "TrainerConfig"]

In [None]:
%%writefile training/replay_buffer.py
"""Replay buffer for Stochastic MuZero.

Stores game trajectories with:
- Observations, actions, rewards
- MCTS policies and values
- Chance outcomes (for stochastic games)
- Entropy at each transition (for macro discovery)

Supports prioritized experience replay based on TD error.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
import numpy as np
import torch
from collections import deque


@dataclass
class GameHistory:
    """
    Record of one played episode.

    Holds the per-step trajectory, the search outputs used as training
    targets, and the extra signals kept for macro discovery.
    """

    # Trajectory: one entry per recorded step
    observations: List[torch.Tensor] = field(default_factory=list)
    actions: List[int] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)

    # Search results recorded at each step
    policies: List[np.ndarray] = field(default_factory=list)
    root_values: List[float] = field(default_factory=list)

    # Stochastic MuZero extras
    chance_outcomes: List[int] = field(default_factory=list)
    afterstates: List[torch.Tensor] = field(default_factory=list)

    # Signals for macro discovery
    entropies: List[float] = field(default_factory=list)
    latent_states: List[torch.Tensor] = field(default_factory=list)

    # Two-player bookkeeping
    to_play: List[int] = field(default_factory=list)  # Player at each step (0 or 1)
    is_two_player: bool = False

    # Prioritized replay
    priorities: Optional[np.ndarray] = None
    game_priority: float = 1.0

    # Episode summary
    total_reward: float = 0.0
    max_tile: int = 0  # For 2048
    length: int = 0

    def append(
        self,
        observation: torch.Tensor,
        action: int,
        reward: float,
        policy: np.ndarray,
        root_value: float,
        chance_outcome: int = 0,
        entropy: float = 0.0,
        latent_state: Optional[torch.Tensor] = None,
        afterstate: Optional[torch.Tensor] = None,
        player: int = 0,
    ) -> None:
        """Record one transition and update the running totals."""
        always_recorded = (
            (self.observations, observation),
            (self.actions, action),
            (self.rewards, reward),
            (self.policies, policy),
            (self.root_values, root_value),
            (self.chance_outcomes, chance_outcome),
            (self.entropies, entropy),
            (self.to_play, player),
        )
        for store, item in always_recorded:
            store.append(item)

        # These two lists only grow when a value is actually supplied.
        if latent_state is not None:
            self.latent_states.append(latent_state)
        if afterstate is not None:
            self.afterstates.append(afterstate)

        self.length += 1
        self.total_reward += reward

    def _relative_sign(self, anchor: int, other: int) -> float:
        """Return -1.0 when steps ``anchor`` and ``other`` belong to different
        players of a two-player game, otherwise 1.0."""
        if (
            self.is_two_player
            and self.to_play
            and self.to_play[other] != self.to_play[anchor]
        ):
            return -1.0
        return 1.0

    def compute_target_values(
        self,
        discount: float,
        td_steps: int,
    ) -> List[float]:
        """
        Build the n-step return target for every position.

        Single-player: target_t = r_t + g*r_{t+1} + ... + g^n * v_{t+n}.
        Two-player: each reward and the bootstrap value is sign-flipped when
        it belongs to the other player, so the target is expressed from the
        perspective of the player at position t.

        Args:
            discount: Discount factor g
            td_steps: Number of reward steps before bootstrapping (n)

        Returns:
            One target value per recorded position
        """
        num_steps = len(self.rewards)
        targets: List[float] = []

        for start in range(num_steps):
            # Sum discounted rewards until the horizon or the end of the game.
            horizon = min(td_steps, num_steps - start)
            ret = 0.0
            for offset in range(horizon):
                step = start + offset
                ret += (discount ** offset) * self._relative_sign(start, step) * self.rewards[step]

            # Bootstrap from the search value only if that position exists.
            tail = start + td_steps
            if tail < num_steps:
                ret += (discount ** td_steps) * self._relative_sign(start, tail) * self.root_values[tail]

            targets.append(ret)

        return targets

    def compute_priorities(self, td_steps: int, discount: float) -> None:
        """Set per-position priorities to |target - root_value| and the game
        priority to their maximum (1.0 for an empty game)."""
        targets = self.compute_target_values(discount, td_steps)
        gaps = [abs(target - estimate) for target, estimate in zip(targets, self.root_values)]
        self.priorities = np.array(gaps)

        if len(self.priorities) > 0:
            self.game_priority = float(np.max(self.priorities))
        else:
            self.game_priority = 1.0


@dataclass
class Batch:
    """Training batch as assembled by ``ReplayBuffer.sample_batch``.

    In the shape comments, K is the number of unroll steps
    (``ReplayBuffer.num_unroll_steps``).
    """

    # Observation at each sampled start position, stacked along dim 0
    observations: torch.Tensor  # (batch, observation_dim)
    # Actions taken from the start position onward (padded with 0 past the game end)
    actions: torch.Tensor  # (batch, K)
    # Value target for the start position plus one per unroll step
    target_values: torch.Tensor  # (batch, K+1)
    # Observed reward for each unrolled action (0.0 past the game end)
    target_rewards: torch.Tensor  # (batch, K)
    # Search policy targets for the start position and each unroll step
    target_policies: torch.Tensor  # (batch, K+1, action_space)
    # Recorded chance outcome for each unrolled action (0 past the game end)
    chance_outcomes: torch.Tensor  # (batch, K)
    # Normalized so the largest weight in the batch is 1
    weights: torch.Tensor  # (batch,) importance sampling weights

    # Indices for priority updates: which game and which position within
    # that game each batch row was drawn from
    game_indices: List[int] = field(default_factory=list)
    position_indices: List[int] = field(default_factory=list)


class ReplayBuffer:
    """
    Replay buffer with prioritized sampling.

    Stores complete game histories in a bounded FIFO and samples
    (game, position) pairs for training, returning importance-sampling
    weights alongside the unrolled targets.
    """

    def __init__(
        self,
        capacity: int = 100000,
        batch_size: int = 256,
        num_unroll_steps: int = 5,
        td_steps: int = 10,
        discount: float = 0.997,
        priority_alpha: float = 1.0,
        priority_beta: float = 1.0,
    ):
        """
        Args:
            capacity: Maximum number of games kept; the oldest is evicted first
            batch_size: Number of positions returned by ``sample_batch``
            num_unroll_steps: Unroll length K of every sampled position
            td_steps: n of the n-step value target
            discount: Discount factor for the value target
            priority_alpha: Exponent applied to priorities before normalising
            priority_beta: Exponent applied to importance-sampling weights
        """
        self.capacity = capacity
        self.batch_size = batch_size
        self.num_unroll_steps = num_unroll_steps
        self.td_steps = td_steps
        self.discount = discount
        self.priority_alpha = priority_alpha
        self.priority_beta = priority_beta

        # Storage
        self.games: deque = deque(maxlen=capacity)
        self.total_positions = 0

        # Statistics
        self.games_added = 0
        self.total_samples = 0

    @staticmethod
    def _to_probabilities(scores: np.ndarray) -> np.ndarray:
        """Normalise non-negative scores into probabilities.

        Falls back to a uniform distribution when the scores cannot be
        normalised (all zero, or a non-finite sum), which would otherwise
        produce NaNs and make ``np.random.choice`` raise.
        """
        scores = np.asarray(scores, dtype=np.float64)
        if len(scores) == 0:
            return scores
        total = scores.sum()
        if not np.isfinite(total) or total <= 0.0:
            return np.full(len(scores), 1.0 / len(scores))
        return scores / total

    def save_game(self, game: GameHistory) -> None:
        """Add a completed game to the buffer.

        Games without any recorded position are ignored: they contain
        nothing to train on and would make position sampling fail.
        """
        if len(game.observations) == 0:
            return

        # Compute priorities
        game.compute_priorities(self.td_steps, self.discount)

        # The deque silently drops its oldest entry when full, so account
        # for the positions that are about to disappear.
        if len(self.games) == self.games.maxlen:
            self.total_positions -= self.games[0].length

        self.games.append(game)
        self.total_positions += game.length
        self.games_added += 1

    def sample_batch(self, device: torch.device = torch.device("cpu")) -> Batch:
        """
        Sample a batch of positions for training.

        Uses prioritized sampling at both game and position level.

        Args:
            device: Device on which the returned tensors are created

        Returns:
            Batch object with all training data

        Raises:
            ValueError: If the buffer holds no games
        """
        if not self.games:
            raise ValueError("Cannot sample a batch from an empty replay buffer")

        # Snapshot once: indexing into the middle of a deque is O(n).
        games = list(self.games)

        game_probs = self._to_probabilities(
            np.array([g.game_priority for g in games], dtype=np.float64) ** self.priority_alpha
        )
        game_indices = np.random.choice(
            len(games),
            size=self.batch_size,
            p=game_probs,
            replace=True,
        )

        observations = []
        actions = []
        target_values = []
        target_rewards = []
        target_policies = []
        chance_outcomes = []
        weights = []
        position_indices = []

        # n-step targets depend only on the game, so compute them once per
        # distinct game instead of once per sampled position.
        targets_by_game: Dict[int, List[float]] = {}

        for raw_idx in game_indices:
            game_idx = int(raw_idx)
            game = games[game_idx]

            # Sample position within game
            if game.priorities is not None:
                pos_probs = self._to_probabilities(game.priorities ** self.priority_alpha)
                pos_idx = int(np.random.choice(len(game.observations), p=pos_probs))
                sample_prob = game_probs[game_idx] * pos_probs[pos_idx]
            else:
                pos_idx = int(np.random.randint(len(game.observations)))
                sample_prob = game_probs[game_idx]

            # Importance sampling weight
            weight = (1.0 / (self.total_positions * sample_prob)) ** self.priority_beta

            if game_idx not in targets_by_game:
                targets_by_game[game_idx] = game.compute_target_values(self.discount, self.td_steps)
            all_targets = targets_by_game[game_idx]

            num_actions_played = len(game.actions)
            # All-zero policy marks "no search target here" (past the end of
            # the game); the trainer masks rows whose target sums to zero.
            empty_policy = np.zeros(len(game.policies[pos_idx]))

            # Index 0 of the value/policy targets is the sampled position
            # itself; index k+1 is the position reached after unroll step k.
            value_targets = [all_targets[pos_idx]]
            policy_targets = [game.policies[pos_idx]]
            action_seq = []
            reward_seq = []
            chance_seq = []

            for k in range(self.num_unroll_steps):
                idx = pos_idx + k
                if idx < num_actions_played:
                    action_seq.append(game.actions[idx])
                    reward_seq.append(game.rewards[idx])
                    chance_seq.append(game.chance_outcomes[idx])
                else:
                    # Pad with zeros past the end of the game
                    action_seq.append(0)
                    reward_seq.append(0.0)
                    chance_seq.append(0)

                nxt = idx + 1
                value_targets.append(all_targets[nxt] if nxt < len(all_targets) else 0.0)
                policy_targets.append(
                    game.policies[nxt] if nxt < len(game.policies) else empty_policy
                )

            observations.append(game.observations[pos_idx])
            actions.append(action_seq)
            target_rewards.append(reward_seq)
            target_values.append(value_targets)
            target_policies.append(policy_targets)
            chance_outcomes.append(chance_seq)
            weights.append(weight)
            position_indices.append(pos_idx)

        # Normalize weights so the largest one is 1
        weights = np.array(weights)
        weights = weights / weights.max()

        self.total_samples += self.batch_size

        # Stack observations
        obs_tensor = torch.stack(observations).to(device)

        # Zero-pad every policy to the widest one seen in this batch.
        action_space_size = max(len(p) for seq in target_policies for p in seq)
        policy_array = np.zeros(
            (len(target_policies), self.num_unroll_steps + 1, action_space_size),
            dtype=np.float32,
        )
        for row, policy_seq in enumerate(target_policies):
            for k, p in enumerate(policy_seq):
                policy_array[row, k, :len(p)] = p

        return Batch(
            observations=obs_tensor,
            actions=torch.tensor(actions, dtype=torch.long, device=device),
            target_values=torch.tensor(target_values, dtype=torch.float32, device=device),
            target_rewards=torch.tensor(target_rewards, dtype=torch.float32, device=device),
            target_policies=torch.tensor(policy_array, dtype=torch.float32, device=device),
            chance_outcomes=torch.tensor(chance_outcomes, dtype=torch.long, device=device),
            weights=torch.tensor(weights, dtype=torch.float32, device=device),
            game_indices=[int(i) for i in game_indices],
            position_indices=position_indices,
        )

    def update_priorities(
        self,
        game_indices: List[int],
        position_indices: List[int],
        td_errors: np.ndarray,
    ) -> None:
        """Update priorities based on new TD errors.

        Entries whose game or position index is out of range are skipped.
        """
        touched = {}
        for game_idx, pos_idx, error in zip(game_indices, position_indices, td_errors):
            if game_idx < len(self.games):
                game = self.games[game_idx]
                if game.priorities is not None and pos_idx < len(game.priorities):
                    game.priorities[pos_idx] = abs(error)
                    touched[id(game)] = game

        # Refresh each affected game's priority once, after all its
        # positions have been written.
        for game in touched.values():
            game.game_priority = float(np.max(game.priorities))

    def get_statistics(self) -> Dict[str, float]:
        """Get buffer statistics."""
        num_games = len(self.games)
        return {
            "num_games": num_games,
            "total_positions": self.total_positions,
            "games_added": self.games_added,
            "total_samples": self.total_samples,
            "avg_game_length": (
                self.total_positions / num_games if num_games else 0.0
            ),
            "avg_total_reward": (
                np.mean([g.total_reward for g in self.games]) if num_games else 0.0
            ),
        }

    def __len__(self) -> int:
        """Return the number of stored positions (not games)."""
        return self.total_positions

In [None]:
%%writefile training/trainer.py
"""Training loop for Stochastic MuZero."""

from dataclasses import dataclass
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from .replay_buffer import Batch
import sys
sys.path.append('..')
from utils.support import scalar_to_support, compute_cross_entropy_loss


@dataclass
class TrainerConfig:
    """Hyperparameters read by ``Trainer``."""

    learning_rate: float = 1e-3  # Passed to Adam as ``lr``
    weight_decay: float = 1e-4  # Passed to Adam as ``weight_decay``
    max_grad_norm: float = 5.0  # Threshold for gradient-norm clipping
    policy_loss_weight: float = 1.0  # Weight of the policy term in the total loss
    value_loss_weight: float = 0.5  # Weight of the value term in the total loss
    reward_loss_weight: float = 1.0  # Weight of the reward term in the total loss
    chance_loss_weight: float = 1.0  # Weight of the chance term in the total loss
    support_size: int = 31  # Forwarded to compute_cross_entropy_loss for value and reward


class Trainer:
    """Runs optimisation steps for the model on batches from the replay buffer."""

    def __init__(self, model, config: TrainerConfig, device=torch.device("cpu")):
        """Move ``model`` to ``device`` and set up the Adam optimizer.

        Args:
            model: Network exposing ``unroll(observation=, actions=, chances=)``
            config: Optimizer settings and loss weights
            device: Device the model is moved to
        """
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.optimizer = Adam(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        self.training_step = 0
        # One list per loss term; train_step() appends to each of them.
        self.loss_history = {"total": [], "policy": [], "value": [], "reward": [], "chance": []}

    def train_step(self, batch: Batch) -> Dict[str, float]:
        """Run one gradient step on ``batch``.

        Args:
            batch: Sampled training batch

        Returns:
            Dict with the scalar "total", "policy", "value", "reward" and
            "chance" losses of this step. The same values are appended to
            ``self.loss_history``.
        """
        # NOTE(review): batch.weights (importance-sampling weights) are not
        # applied to any loss term here.
        cfg = self.config
        self.model.train()
        self.optimizer.zero_grad()
        unroll_outputs = self.model.unroll(
            observation=batch.observations,
            actions=batch.actions,
            chances=batch.chance_outcomes,
        )

        K = batch.actions.shape[1]

        # Policy loss: cross-entropy against the search policy. Rows whose
        # target sums to zero carry no policy target and are skipped.
        policy_loss = torch.tensor(0.0, device=self.device)
        for k in range(K + 1):
            predicted_policy = F.log_softmax(unroll_outputs["policy_logits"][k], dim=-1)
            target_policy = batch.target_policies[:, k, :]
            mask = target_policy.sum(dim=-1) > 0
            if mask.any():
                step_loss = -(target_policy[mask] * predicted_policy[mask]).sum(dim=-1).mean()
                policy_loss = policy_loss + step_loss
        policy_loss = policy_loss / (K + 1)

        # Value loss
        value_loss = torch.tensor(0.0, device=self.device)
        for k in range(K + 1):
            value_loss = value_loss + compute_cross_entropy_loss(
                unroll_outputs["value_logits"][k], batch.target_values[:, k], cfg.support_size
            )
        value_loss = value_loss / (K + 1)

        # Reward loss
        reward_loss = torch.tensor(0.0, device=self.device)
        for k in range(K):
            reward_loss = reward_loss + compute_cross_entropy_loss(
                unroll_outputs["reward_logits"][k], batch.target_rewards[:, k], cfg.support_size
            )
        reward_loss = reward_loss / max(K, 1)  # max() guards K == 0

        # Chance loss: negative log-likelihood of the recorded chance outcome.
        chance_loss = torch.tensor(0.0, device=self.device)
        for k in range(K):
            predicted_chance = F.log_softmax(unroll_outputs["chance_logits"][k], dim=-1)
            target_chance = batch.chance_outcomes[:, k]
            mask = target_chance >= 0
            if mask.any():
                chance_loss = chance_loss + F.nll_loss(
                    predicted_chance[mask], target_chance[mask], reduction="mean"
                )
        chance_loss = chance_loss / max(K, 1)

        total_loss = (
            cfg.policy_loss_weight * policy_loss
            + cfg.value_loss_weight * value_loss
            + cfg.reward_loss_weight * reward_loss
            + cfg.chance_loss_weight * chance_loss
        )

        total_loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), cfg.max_grad_norm)
        self.optimizer.step()

        self.training_step += 1
        losses = {
            "total": total_loss.item(),
            "policy": policy_loss.item(),
            "value": value_loss.item(),
            "reward": reward_loss.item(),
            "chance": chance_loss.item(),
        }
        for name, value in losses.items():
            self.loss_history[name].append(value)
        return losses

    def save_checkpoint(self, path: str):
        """Write model weights, optimizer state and the step counter to ``path``."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "training_step": self.training_step,
            },
            path,
        )

    def load_checkpoint(self, path: str):
        """Restore model weights, optimizer state and the step counter from ``path``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.training_step = checkpoint["training_step"]

## 7. Quick Test

In [None]:
# Test 2048 game
from games.game_2048 import Game2048

game = Game2048()
state = game.reset()
print("Initial state:", game.render(state), sep="\n")

# Take up to five uniformly random legal moves, stopping early if stuck
for i in range(5):
    legal = game.legal_actions(state)
    if len(legal) == 0:
        break
    action = np.random.choice(legal)
    result = game.step(state, action)
    state = result.next_state
    print(
        f"\nAfter action {['up', 'right', 'down', 'left'][action]} (reward={result.reward}):",
        game.render(state),
        sep="\n",
    )

In [None]:
# Test model creation
from networks.muzero_network import MuZeroNetwork
from utils.config import Config

config = Config()
model = MuZeroNetwork(
    hidden_dim=config.hidden_dim,
    state_dim=config.state_dim,
    chance_space_size=config.chance_space_size,
    action_space_size=config.action_space_size,
    observation_dim=config.observation_dim,
)
model = model.to(DEVICE)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Smoke-test a single initial inference on the current board (batch of 1)
obs = game.encode_state(state).unsqueeze(0).to(DEVICE)
output = model.initial_inference(obs)
for line in (
    f"State shape: {output.state.shape}",
    f"Policy shape: {output.policy_logits.shape}",
    f"Value shape: {output.value_logits.shape}",
):
    print(line)

## 8. Training Loop

In [None]:
from mcts.tree_search import StochasticMCTS, MCTSConfig
from mcts.macro_cache import MacroCache, discover_macros_from_trajectory
from training.replay_buffer import ReplayBuffer, GameHistory
from training.trainer import Trainer, TrainerConfig
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Run-size knobs
NUM_ITERATIONS = 100  # Increase for better results
GAMES_PER_ITERATION = 5
BATCHES_PER_ITERATION = 20
NUM_SIMULATIONS = 25  # Reduced for speed

# Fresh environment and network for the training run
game = Game2048()
model = MuZeroNetwork(
    hidden_dim=config.hidden_dim,
    state_dim=config.state_dim,
    chance_space_size=config.chance_space_size,
    action_space_size=config.action_space_size,
    observation_dim=config.observation_dim,
)
model = model.to(DEVICE)

# Search, storage and optimisation components
mcts_config = MCTSConfig(num_simulations=NUM_SIMULATIONS)
macro_cache = MacroCache(
    state_dim=config.state_dim,
    entropy_threshold=config.entropy_threshold,
)
replay_buffer = ReplayBuffer(
    capacity=10000,
    batch_size=64,
    num_unroll_steps=config.num_unroll_steps,
)
trainer = Trainer(model, TrainerConfig(), DEVICE)

# Per-iteration metrics filled in by the training loop
rewards_history, max_tiles_history, loss_history = [], [], []

In [None]:
def play_game_with_mcts(game, model, mcts_config, device, max_moves=500, macro_cache=None):
    """Play one self-play episode with MCTS and return its history.

    Args:
        game: Environment providing reset/step/legal_actions/is_terminal/
            encode_state/get_max_tile and ``action_space_size``
        model: Network used by the search and for the entropy estimate
        mcts_config: Configuration handed to ``StochasticMCTS``
        device: Device observations and action tensors are placed on
        max_moves: Upper bound on the number of moves played
        macro_cache: Optional macro cache handed to the search. When None
            (the default) the search is constructed exactly as before.

    Returns:
        GameHistory with one entry per move played and ``max_tile`` set
        from the final state.
    """
    if macro_cache is None:
        mcts = StochasticMCTS(model, mcts_config)
    else:
        mcts = StochasticMCTS(model, mcts_config, macro_cache)
    state = game.reset()
    history = GameHistory()

    model.eval()
    for step in range(max_moves):
        if game.is_terminal(state):
            break

        obs = game.encode_state(state).to(device)
        legal = game.legal_actions(state)
        if not legal:
            break

        root = mcts.search(obs.unsqueeze(0), legal, add_exploration_noise=True)
        # Explore during the opening moves, then play close to greedily.
        temp = 1.0 if step < 30 else 0.1
        action, policy = mcts.get_action_policy(root, temp)

        result = game.step(state, action)

        # Chance entropy of the chosen transition. Without a root latent
        # state there is nothing to feed the dynamics model, so record 0.0.
        hidden_state = root.hidden_state
        if hidden_state is not None:
            with torch.no_grad():
                dyn_out = model.recurrent_inference(
                    hidden_state, torch.tensor([action], device=device)
                )
                entropy = dyn_out.chance_entropy.item()
        else:
            entropy = 0.0

        # Copy the search policy into a vector of exactly action_space_size
        # entries (zero-padded or truncated as needed).
        full_policy = np.zeros(game.action_space_size)
        width = min(len(policy), len(full_policy))
        full_policy[:width] = policy[:width]

        history.append(
            observation=obs.cpu(),
            action=action,
            reward=result.reward,
            policy=full_policy,
            root_value=root.value,
            chance_outcome=result.chance_outcome,
            entropy=entropy,
            latent_state=hidden_state.cpu() if hidden_state is not None else None,
        )

        state = result.next_state

    history.max_tile = game.get_max_tile(state)
    return history

In [None]:
# Training loop
print(f"Starting training for {NUM_ITERATIONS} iterations...")
print(f"  Games per iteration: {GAMES_PER_ITERATION}")
print(f"  MCTS simulations: {NUM_SIMULATIONS}")
print()

for iteration in tqdm(range(NUM_ITERATIONS), desc="Training"):
    # Self-play: generate fresh games with the current network
    iter_rewards = []
    iter_max_tiles = []

    for _ in range(GAMES_PER_ITERATION):
        history = play_game_with_mcts(game, model, mcts_config, DEVICE, max_moves=500)
        replay_buffer.save_game(history)
        iter_rewards.append(history.total_reward)
        iter_max_tiles.append(history.max_tile)

    rewards_history.append(np.mean(iter_rewards))
    max_tiles_history.append(np.max(iter_max_tiles))

    # Training: wait until the buffer holds at least one batch worth of
    # positions. The threshold follows the buffer's own batch size instead
    # of repeating the literal used when the buffer was constructed.
    if len(replay_buffer) >= replay_buffer.batch_size:
        iter_losses = []
        for _ in range(BATCHES_PER_ITERATION):
            batch = replay_buffer.sample_batch(DEVICE)
            losses = trainer.train_step(batch)
            iter_losses.append(losses["total"])
        loss_history.append(np.mean(iter_losses))

    # Logging every 10th iteration
    if (iteration + 1) % 10 == 0:
        macro_stats = macro_cache.get_statistics()
        print(f"\nIter {iteration + 1}: Reward={rewards_history[-1]:.0f}, MaxTile={max_tiles_history[-1]}, "
              f"Loss={loss_history[-1] if loss_history else 0:.4f}, Macros={macro_stats['num_macros']}")

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (series, y-label, title) for each panel, left to right
panels = [
    (rewards_history, "Average Reward", "Training Reward"),
    (max_tiles_history, "Max Tile", "Best Tile Achieved"),
]
# The loss panel stays empty when no training step has run yet
if loss_history:
    panels.append((loss_history, "Loss", "Training Loss"))

for ax, (series, y_label, title) in zip(axes, panels):
    ax.plot(series)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(y_label)
    ax.set_title(title)

# Tiles are powers of two, so a base-2 log axis spaces them evenly
axes[1].set_yscale('log', base=2)

plt.tight_layout()
plt.show()

## 9. Evaluation

In [None]:
# Play a game with visualization
print("Playing a game with trained model...\n")

model.eval()
mcts = StochasticMCTS(model, MCTSConfig(num_simulations=50))

state = game.reset()
print("Initial:", game.render(state), sep="\n")

total_reward = 0
for step in range(200):
    if game.is_terminal(state):
        break

    obs = game.encode_state(state).to(DEVICE)
    legal = game.legal_actions(state)
    if len(legal) == 0:
        break

    # No exploration noise and zero temperature: always play the search's top move
    root = mcts.search(obs.unsqueeze(0), legal, add_exploration_noise=False)
    action, _ = mcts.get_action_policy(root, temperature=0.0)

    result = game.step(state, action)
    state = result.next_state
    total_reward += result.reward

    # Show the board after every 50th move
    if (step + 1) % 50 == 0:
        print(f"\nStep {step + 1}:", game.render(state), sep="\n")

print("\nFinal state:", game.render(state), sep="\n")
print(f"\nTotal reward: {total_reward}")
print(f"Max tile: {game.get_max_tile(state)}")

In [None]:
# Save model
# Persist model weights, optimizer state and step counter in one file
for checkpoint_path in ("runs/stochastic_muzero_2048.pt",):
    trainer.save_checkpoint(checkpoint_path)
    print(f"Model saved to {checkpoint_path}")

## 10. Macro Analysis

Analyze the discovered macro-operators to see what temporal abstractions emerged.

In [None]:
# Analyze macros
stats = macro_cache.get_statistics()
print("Macro-Operator Statistics:")
print(f"  Total macros discovered: {stats['total_discoveries']}")
print(f"  Active macros: {stats['num_macros']}")
print(f"  Total uses: {stats['total_uses']}")
print(f"  Success rate: {stats['success_rate']:.2%}")
print(f"  Average confidence: {stats['avg_confidence']:.3f}")
print(f"  Average length: {stats['avg_length']:.1f} steps")

# List the most frequently used macros, most used first
if macro_cache.macros:
    print("\nTop 10 macros by usage:")
    action_names = ["up", "right", "down", "left"]
    sorted_macros = sorted(
        macro_cache.macros.values(),
        key=lambda m: m.usage_count,
        reverse=True,
    )
    sorted_macros = sorted_macros[:10]
    for m in sorted_macros:
        step_names = [action_names[a] for a in m.action_sequence]
        actions_str = " -> ".join(step_names)
        print(f"  [{actions_str}] uses={m.usage_count}, success={m.success_rate:.2%}, conf={m.confidence:.3f}")

---

## Next Steps

1. **Train longer**: Increase `NUM_ITERATIONS` to 500-1000 for better performance
2. **More MCTS simulations**: Increase `NUM_SIMULATIONS` to 100-200
3. **Add a Go environment**: Implement a Go game interface to test a fully deterministic setting
4. **Analyze entropy distribution**: Track how entropy separates deterministic vs stochastic transitions
5. **Measure planning speedup**: Compare planning with and without macro usage

## 11. Deterministic Games: Tic-Tac-Toe & Chess

These two-player games are fully deterministic, so the chance entropy should be ≈ 0 at every transition and every trajectory segment is a macro candidate. That makes them a useful sanity check for the macro discovery pipeline.

In [None]:
# Test Tic-Tac-Toe
from games.game_tictactoe import GameTicTacToe

ttt = GameTicTacToe()
state = ttt.reset()
print("Tic-Tac-Toe test:")
print(ttt.render(state))
print(f"Legal actions: {ttt.legal_actions(state)}")
print(f"Observation shape: {ttt.encode_state(state).shape}")
print(f"Is two-player: {ttt.is_two_player}")
print(f"Chance space: {ttt.chance_space_size}")

# Roll out one game with uniformly random legal moves until it ends
state = ttt.reset()
while True:
    if ttt.is_terminal(state):
        break
    actions = ttt.legal_actions(state)
    action = np.random.choice(actions)
    result = ttt.step(state, action)
    state = result.next_state
print(f"\nRandom game result:\n{ttt.render(state)}")

In [None]:
# Train on Tic-Tac-Toe (deterministic two-player game)
TTT_ITERATIONS = 200
TTT_GAMES = 10
TTT_BATCHES = 20
TTT_SIMS = 25
TTT_ACTION_SPACE = 9  # One action per board cell
TTT_STATE_DIM = 128

ttt_game = GameTicTacToe()
ttt_model = MuZeroNetwork(
    observation_dim=27,  # 3 planes x 3x3
    action_space_size=TTT_ACTION_SPACE,
    chance_space_size=1,
    state_dim=TTT_STATE_DIM,
    hidden_dim=64,
).to(DEVICE)

ttt_macro_cache = MacroCache(state_dim=TTT_STATE_DIM, entropy_threshold=0.1)
ttt_buffer = ReplayBuffer(capacity=5000, batch_size=64, num_unroll_steps=3, td_steps=5, discount=1.0)
ttt_trainer = Trainer(ttt_model, TrainerConfig(support_size=31), DEVICE)
ttt_mcts_config = MCTSConfig(
    num_simulations=TTT_SIMS,
    action_space_size=TTT_ACTION_SPACE,
    is_two_player=True,
    discount=1.0,
)

ttt_rewards = []
ttt_losses = []

print(f"Training Tic-Tac-Toe for {TTT_ITERATIONS} iterations...")
for iteration in tqdm(range(TTT_ITERATIONS), desc="TTT Training"):
    iter_rewards = []
    for _ in range(TTT_GAMES):
        mcts = StochasticMCTS(ttt_model, ttt_mcts_config, ttt_macro_cache)
        state = ttt_game.reset()
        history = GameHistory(is_two_player=True)
        ttt_model.eval()

        for step in range(20):
            if ttt_game.is_terminal(state):
                break
            obs = ttt_game.encode_state(state).to(DEVICE)
            legal = ttt_game.legal_actions(state)
            if not legal:
                break

            # Read the player to move before stepping, so the label stays
            # correct even if step() were to modify ``state`` in place.
            player = ttt_game.current_player(state)

            root = mcts.search(obs.unsqueeze(0), legal, add_exploration_noise=True)
            temp = 1.0 if step < 5 else 0.1
            action, policy = mcts.get_action_policy(root, temp)
            result = ttt_game.step(state, action)

            # Chance entropy of the chosen transition; 0.0 when the root
            # has no latent state to feed the dynamics model.
            hidden_state = root.hidden_state
            if hidden_state is not None:
                with torch.no_grad():
                    dyn_out = ttt_model.recurrent_inference(
                        hidden_state, torch.tensor([action], device=DEVICE)
                    )
                    entropy = dyn_out.chance_entropy.item()
            else:
                entropy = 0.0

            # Fixed-width policy target (zero-padded or truncated)
            full_policy = np.zeros(TTT_ACTION_SPACE)
            width = min(len(policy), TTT_ACTION_SPACE)
            full_policy[:width] = policy[:width]

            history.append(observation=obs.cpu(), action=action, reward=result.reward, policy=full_policy,
                          root_value=root.value, chance_outcome=result.chance_outcome, entropy=entropy,
                          latent_state=hidden_state.cpu() if hidden_state is not None else None,
                          player=player)
            state = result.next_state

        ttt_buffer.save_game(history)
        iter_rewards.append(history.total_reward)

    ttt_rewards.append(np.mean(iter_rewards))

    # Train once the buffer holds at least one batch worth of positions
    if len(ttt_buffer) >= ttt_buffer.batch_size:
        iter_losses = []
        for _ in range(TTT_BATCHES):
            batch = ttt_buffer.sample_batch(DEVICE)
            losses = ttt_trainer.train_step(batch)
            iter_losses.append(losses["total"])
        ttt_losses.append(np.mean(iter_losses))

    if (iteration + 1) % 50 == 0:
        macro_stats = ttt_macro_cache.get_statistics()
        print(f"\nIter {iteration+1}: Reward={ttt_rewards[-1]:.2f}, Loss={ttt_losses[-1] if ttt_losses else 0:.4f}, Macros={macro_stats['num_macros']}")

# Plot
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(ttt_rewards)
axes[0].set_xlabel("Iteration")
axes[0].set_ylabel("Avg Reward")
axes[0].set_title("TTT Reward")
if ttt_losses:
    axes[1].plot(ttt_losses)
    axes[1].set_xlabel("Iteration")
    axes[1].set_ylabel("Loss")
    axes[1].set_title("TTT Loss")
plt.tight_layout()
plt.show()

# Macro analysis
stats = ttt_macro_cache.get_statistics()
print(f"\nTTT Macro Stats: {stats['num_macros']} macros, {stats['total_discoveries']} discovered")

In [None]:
# Test Chess
!pip install python-chess -q
from games.game_chess import GameChess

chess_game = GameChess()
state = chess_game.reset()
print("Chess test:")
print(chess_game.render(state))
print(f"\nLegal actions: {len(chess_game.legal_actions(state))} moves")
print(f"Observation shape: {chess_game.encode_state(state).shape}")
print(f"Action space: {chess_game.action_space_size}")
print(f"Is two-player: {chess_game.is_two_player}")

# Play a few random moves
for i in range(4):
    actions = chess_game.legal_actions(state)
    action = np.random.choice(actions)
    result = chess_game.step(state, action)
    state = result.next_state
print(f"\nAfter 4 random moves:\n{chess_game.render(state)}")

In [None]:
# Train on Chess (deterministic two-player game)
# Note: Chess is much more complex, so this is only a minimal run that
# validates the pipeline end to end.
# For meaningful macro discovery, train for 1000+ iterations.

# Run-size knobs used by the training loop below.
CHESS_ITERATIONS = 50   # outer self-play / training iterations
CHESS_GAMES = 3         # self-play games per iteration
CHESS_BATCHES = 15      # training steps per iteration
CHESS_SIMS = 20         # MCTS simulations per search
CHESS_MAX_MOVES = 80    # per-game step cap

# Sizes shared by several of the objects constructed below.
CHESS_OBS_DIM = 22 * 8 * 8   # 22 planes x 8x8 = 1408
CHESS_ACTION_SPACE = 4672
CHESS_STATE_DIM = 256

chess_model = MuZeroNetwork(
    observation_dim=CHESS_OBS_DIM,
    action_space_size=CHESS_ACTION_SPACE,
    chance_space_size=1,
    state_dim=CHESS_STATE_DIM,
    hidden_dim=128,
).to(DEVICE)

chess_macro_cache = MacroCache(state_dim=CHESS_STATE_DIM, entropy_threshold=0.1)
chess_buffer = ReplayBuffer(
    capacity=10000,
    batch_size=64,
    num_unroll_steps=5,
    td_steps=10,
    discount=1.0,
)
chess_trainer_config = TrainerConfig(support_size=31)
chess_trainer = Trainer(chess_model, chess_trainer_config, DEVICE)
chess_mcts_config = MCTSConfig(
    num_simulations=CHESS_SIMS,
    action_space_size=CHESS_ACTION_SPACE,
    is_two_player=True,
    discount=1.0,
)

# Per-iteration histories filled in by the training loop.
chess_rewards, chess_losses, chess_game_lengths = [], [], []

# Banner describing the run that is about to start.
print(f"Training Chess for {CHESS_ITERATIONS} iterations...")
chess_param_count = sum(p.numel() for p in chess_model.parameters())
print(f"  Model params: {chess_param_count:,}")
print(f"  Games per iteration: {CHESS_GAMES}")
print(f"  MCTS simulations: {CHESS_SIMS}")
print()

# Self-play / training loop. Each iteration plays CHESS_GAMES games with MCTS,
# saves them to chess_buffer, records the mean reward and game length, and then
# runs CHESS_BATCHES training steps once len(chess_buffer) reaches 64.
for iteration in tqdm(range(CHESS_ITERATIONS), desc="Chess Training"):
    iter_rewards = []
    iter_lengths = []

    for _ in range(CHESS_GAMES):
        # A fresh search object, start state and history are created per game.
        mcts = StochasticMCTS(chess_model, chess_mcts_config, chess_macro_cache)
        state = chess_game.reset()
        history = GameHistory(is_two_player=True)
        # Eval mode for self-play (presumably train_step switches back to
        # train mode -- TODO confirm against Trainer).
        chess_model.eval()

        # Play at most CHESS_MAX_MOVES steps; a game that hits the cap is
        # stored as-is without reaching a terminal state.
        for step in range(CHESS_MAX_MOVES):
            if chess_game.is_terminal(state):
                break
            obs = chess_game.encode_state(state).to(DEVICE)
            legal = chess_game.legal_actions(state)
            # Also stop when there is nothing to choose from.
            if not legal:
                break
            # unsqueeze(0) adds a leading dimension (presumably the batch axis
            # expected by the network -- TODO confirm).
            root = mcts.search(obs.unsqueeze(0), legal, add_exploration_noise=True)
            # Value passed to get_action_policy: 1.0 for the first 30 steps,
            # 0.1 afterwards (presumably a sampling temperature).
            temp = 1.0 if step < 30 else 0.1
            action, policy = mcts.get_action_policy(root, temp)
            result = chess_game.step(state, action)

            # Extra forward pass from the root latent state, used here only to
            # read the chance entropy for the chosen action.
            # NOTE(review): root.hidden_state is passed unguarded here, while
            # the history.append call below allows it to be None -- confirm
            # which of the two is intended.
            with torch.no_grad():
                dyn_out = chess_model.recurrent_inference(root.hidden_state, torch.tensor([action], device=DEVICE))
                entropy = dyn_out.chance_entropy.item()

            # `state` has not been advanced yet, so this queries the pre-move
            # state (presumably yielding the player who chose `action`).
            player = chess_game.current_player(state)
            # Tensors are moved to CPU before being stored in the history.
            history.append(
                observation=obs.cpu(), action=action, reward=result.reward,
                policy=policy, root_value=root.value,
                chance_outcome=result.chance_outcome, entropy=entropy,
                latent_state=root.hidden_state.cpu() if root.hidden_state is not None else None,
                player=player,
            )
            state = result.next_state

        chess_buffer.save_game(history)
        iter_rewards.append(history.total_reward)
        iter_lengths.append(history.length)

    # One data point per iteration: mean over the games just played.
    chess_rewards.append(np.mean(iter_rewards))
    chess_game_lengths.append(np.mean(iter_lengths))

    # Train only once the buffer reports at least 64 entries (what len()
    # counts -- games or positions -- is not visible here). Until then no
    # loss is recorded, so chess_losses can be shorter than chess_rewards.
    if len(chess_buffer) >= 64:
        iter_losses = []
        for _ in range(CHESS_BATCHES):
            batch = chess_buffer.sample_batch(DEVICE)
            losses = chess_trainer.train_step(batch)
            iter_losses.append(losses["total"])
        chess_losses.append(np.mean(iter_losses))

    # Progress report every 10th iteration; the loss prints as 0 while no
    # training step has run yet.
    if (iteration + 1) % 10 == 0:
        macro_stats = chess_macro_cache.get_statistics()
        avg_len = chess_game_lengths[-1] if chess_game_lengths else 0
        print(f"\nIter {iteration+1}: Reward={chess_rewards[-1]:.2f}, "
              f"Avg Length={avg_len:.0f}, "
              f"Loss={chess_losses[-1] if chess_losses else 0:.4f}, "
              f"Macros={macro_stats['num_macros']}")

# Plot the chess training curves side by side: reward, loss, game length.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

axes[0].plot(chess_rewards)
axes[0].set(xlabel="Iteration", ylabel="Avg Reward", title="Chess Reward")

# The middle panel is left blank when no loss value was ever recorded.
if chess_losses:
    axes[1].plot(chess_losses)
    axes[1].set(xlabel="Iteration", ylabel="Loss", title="Chess Loss")

axes[2].plot(chess_game_lengths)
axes[2].set(xlabel="Iteration", ylabel="Avg Length", title="Chess Game Length")

plt.tight_layout()
plt.show()

# Macro analysis: summarize the statistics reported by the chess macro cache
# and, when it holds any macros, list the ones returned by get_top_macros(5).
stats = chess_macro_cache.get_statistics()
print(f"\nChess Macro Stats: {stats['num_macros']} macros, {stats['total_discoveries']} discovered")

has_macros = stats['num_macros'] > 0
if has_macros:
    print("\nAll transitions are deterministic (entropy \u2248 0)")
    print("Every trajectory segment qualifies as a macro candidate!")
    top = chess_macro_cache.get_top_macros(5)
    # Number the listed macros from 1.
    for rank, m in enumerate(top, start=1):
        print(f"  Macro {rank}: actions={m.action_sequence}, confidence={m.confidence:.3f}, uses={m.usage_count}")
