# Stochastic MuZero with Learned Temporal Abstractions
## V2: Atari Games Edition

This notebook implements **Stochastic MuZero** with macro discovery on **deterministic Atari games**.

### Why Atari?
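- **Single-agent**: The agent controls all of its own actions. Built-in opponents (e.g. the Pong CPU paddle) are scripted by the game, so there is no adversarial uncertainty
- **Deterministic emulator**: With sticky actions disabled, ball bounces, gravity, and collisions follow the same course whenever the same actions are replayed from the same start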
- **Rich temporal structure**: Positioning sequences, attack patterns, navigation

### Core Idea
Rules are **compressible causal structure** in the transition dynamics:
- If a sequence of transitions is deterministic and repeatable, collapse it into a **macro-operator**
- Macros enable faster planning and capture reusable temporal abstractions
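To make the idea concrete, here is a minimal sketch on a toy deterministic environment. It uses a 1-D position with hypothetical `toy_step` dynamics, not an Atari game. A macro-operator is just a fixed action tuple whose effect is the composition of primitive transitions, with rewards summed along the way.

```python
from typing import Callable, Sequence, Tuple

State = int
Step = Callable[[State, int], Tuple[State, float]]


def toy_step(position: State, action: int) -> Tuple[State, float]:
    """Hypothetical deterministic dynamics: 0 = LEFT, 1 = RIGHT, reward 1.0 on reaching 3."""
    next_position = position + (1 if action == 1 else -1)
    reward = 1.0 if next_position == 3 else 0.0
    return next_position, reward


def apply_macro(step: Step, state: State, macro: Sequence[int],
                discount: float = 1.0) -> Tuple[State, float]:
    """Collapse a macro into one transition by composing primitive steps."""
    total_reward = 0.0
    for k, action in enumerate(macro):
        state, reward = step(state, action)
        total_reward += (discount ** k) * reward
    return state, total_reward


# toy_step is deterministic, so this macro always maps 0 -> 3 with the same return
# and can be cached and reused as a single planning step.
final_state, ret = apply_macro(toy_step, 0, (1, 1, 1))
print(final_state, ret)  # 3 1.0
```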

### Games
- **Breakout**: Ball trajectory, paddle positioning
- **Pong**: Ball interception patterns
- **Space Invaders**: Enemy patterns, shooting sequences

## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install torch numpy gymnasium ale-py tqdm matplotlib -q

# Install Atari ROMs (required for ALE games)
!pip install "autorom[accept-rom-license]" -q
!python -m atari_py.import_roms . 2>/dev/null || true  # legacy atari_py path; no-op if absent

# Alternative ROM installation if above fails
import subprocess
try:
    result = subprocess.run(['AutoROM', '--accept-license'], capture_output=True, timeout=60)
    if result.returncode != 0:
        raise RuntimeError(result.stderr.decode(errors="ignore"))
    print("ROMs installed via AutoROM")
except (FileNotFoundError, subprocess.TimeoutExpired, RuntimeError):
    print("AutoROM not available, trying alternative...")
    !pip install ale-py[roms] -q 2>/dev/null || true

import torch
import numpy as np

print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {DEVICE}")

In [None]:
# Test Atari environment
import gymnasium as gym
import ale_py

# Register ALE environments
gym.register_envs(ale_py)

try:
    env = gym.make("ALE/Breakout-v5")
    obs, _ = env.reset()
    print(f"Breakout works!")
    print(f"  Observation: {obs.shape}")
    print(f"  Actions: {env.action_space.n} - {env.unwrapped.get_action_meanings()}")
    env.close()
except Exception as e:
    print(f"Error: {e}")
    print("\nTrying to fix...")
    !pip install 'gymnasium[atari,accept-rom-license]' -q
    print("\nPlease restart runtime and run this cell again.")

## 2. Atari Game Wrapper

In [None]:
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Any
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation

# Register ALE environments
import ale_py
gym.register_envs(ale_py)


@dataclass
class AtariState:
    """Wrapper for Atari game state."""
    env: gym.Env
    observation: np.ndarray
    done: bool = False
    lives: int = 0
    score: int = 0


class AtariGame:
    """
    Atari game wrapper with deterministic mode for macro discovery.
    
    Key settings:
    - repeat_action_probability=0: No sticky actions (deterministic)
    - noop_max=0: No random initial noops
    - 84x84 grayscale, 4 stacked frames
    """
    
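    # Names of ALE's full 18-action set, for reference only. Most games expose a
    # smaller "minimal" action set with different indices (Breakout: NOOP, FIRE,
    # RIGHT, LEFT), so the wrapper reads names from get_action_meanings() instead.
    ACTIONS = {
        0: "NOOP", 1: "FIRE", 2: "UP", 3: "RIGHT", 4: "LEFT", 5: "DOWN",
        6: "UPRIGHT", 7: "UPLEFT", 8: "DOWNRIGHT", 9: "DOWNLEFT",
        10: "UPFIRE", 11: "RIGHTFIRE", 12: "LEFTFIRE", 13: "DOWNFIRE",
        14: "UPRIGHTFIRE", 15: "UPLEFTFIRE", 16: "DOWNRIGHTFIRE", 17: "DOWNLEFTFIRE",
    }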

    def __init__(self, game_name: str = "Breakout", frame_stack: int = 4, 
                 frame_skip: int = 4, deterministic: bool = True):
        self.game_name = game_name
        self.frame_stack = frame_stack
        self.frame_skip = frame_skip
        self.deterministic = deterministic
        
        self._env = self._make_env()
        self.action_space_size = self._env.action_space.n
        self.action_meanings = self._env.unwrapped.get_action_meanings()
        self.observation_dim = 84 * 84 * frame_stack  # Flattened

    def _make_env(self) -> gym.Env:
        env = gym.make(
            f"ALE/{self.game_name}-v5",
            repeat_action_probability=0.0 if self.deterministic else 0.25,
            frameskip=1,
        )
        env = AtariPreprocessing(
            env,
            noop_max=0 if self.deterministic else 30,
            frame_skip=self.frame_skip,
            screen_size=84,
            grayscale_obs=True,
            grayscale_newaxis=True,
            scale_obs=True,
        )
        env = FrameStackObservation(env, self.frame_stack)
        return env

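    def reset(self) -> AtariState:
        # Reuse the single underlying env rather than building (and leaking) a new
        # emulator on every reset; states share it, just as step() already does.
        obs, info = self._env.reset()
        return AtariState(env=self._env, observation=obs, done=False,
                          lives=info.get('lives', 0), score=0)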

    def step(self, state: AtariState, action: int) -> Tuple[AtariState, float, bool]:
        if state.done:
            return state, 0.0, True
        
        obs, reward, terminated, truncated, info = state.env.step(action)
        done = terminated or truncated
        
        new_state = AtariState(
            env=state.env, observation=obs, done=done,
            lives=info.get('lives', state.lives),
            score=state.score + int(reward),
        )
        return new_state, float(reward), done

    def encode(self, state: AtariState) -> torch.Tensor:
        """Encode observation as flattened tensor."""
        obs = state.observation
        if obs.ndim == 4 and obs.shape[-1] == 1:
            obs = obs.squeeze(-1)  # (4, 84, 84)
        obs = np.asarray(obs, dtype=np.float32).reshape(-1)
        return torch.tensor(obs, dtype=torch.float32)

    def legal_actions(self, state: AtariState) -> List[int]:
        return [] if state.done else list(range(self.action_space_size))

    def action_name(self, action: int) -> str:
        if action < len(self.action_meanings):
            return self.action_meanings[action]
        return f"Action_{action}"


# Test
game = AtariGame("Breakout")
print(f"Game: {game.game_name}")
print(f"Observation dim: {game.observation_dim}")
print(f"Actions: {game.action_space_size} - {game.action_meanings}")
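The wrapper's deterministic settings can be sanity-checked by replaying one action sequence under two different seeds: if the environment has no internal randomness, the seed should not matter. This is a minimal sketch, assuming any environment that follows the Gymnasium `reset(seed=...)` / 5-tuple `step` convention. `check_determinism` and `ToyStickyEnv` are hypothetical helpers for illustration. With the real wrapper, `check_determinism(game._make_env, [1, 2, 2, 3] * 25)` would be expected to return `True` when `deterministic=True`.

```python
import numpy as np


def check_determinism(make_env, actions, seeds=(0, 1)) -> bool:
    """Replay the same actions under two seeds; True if every observation matches."""
    trajectories = []
    for seed in seeds:
        env = make_env()
        obs, _ = env.reset(seed=seed)
        frames = [np.asarray(obs)]
        for action in actions:
            obs, _, terminated, truncated, _ = env.step(action)
            frames.append(np.asarray(obs))
            if terminated or truncated:
                break
        env.close()
        trajectories.append(frames)
    first, second = trajectories
    return len(first) == len(second) and all(
        np.array_equal(a, b) for a, b in zip(first, second)
    )


class ToyStickyEnv:
    """Hypothetical stand-in env: with sticky_prob > 0 the previous action may repeat."""

    def __init__(self, sticky_prob: float = 0.0):
        self.sticky_prob = sticky_prob

    def reset(self, seed=None):
        self.rng = np.random.default_rng(seed)  # the seed only drives stickiness
        self.x, self.last_action = 0, 0
        return np.array([self.x]), {}

    def step(self, action):
        if self.rng.random() < self.sticky_prob:
            action = self.last_action  # sticky action: repeat the previous one
        self.last_action = action
        self.x += action
        return np.array([self.x]), 0.0, False, False, {}

    def close(self):
        pass
```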


## 3. MuZero Network

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class MuZeroNetwork(nn.Module):
    """
    MuZero network with entropy estimation for macro discovery.
    
    Key addition: The dynamics network outputs both next_state AND
    a confidence/entropy estimate. Low entropy = deterministic transition.
    """
    
    def __init__(self, obs_dim: int, action_dim: int, 
                 state_dim: int = 256, hidden_dim: int = 256):
        super().__init__()
        self.action_dim = action_dim
        self.state_dim = state_dim
        
        # Representation network: obs -> state
        self.representation = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim),
        )
        
        # Dynamics network: (state, action) -> next_state
        self.dynamics = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim),
        )
        
        # Entropy head: predicts uncertainty of transition
        # Low entropy = deterministic (macro-able)
        # High entropy = stochastic (can't compress)
        self.entropy_head = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Softplus(),  # Ensure positive entropy
        )
        
        # Reward prediction
        self.reward_head = nn.Sequential(
            nn.Linear(state_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        
        # Policy head
        self.policy_head = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        
        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
    
    def initial_inference(self, obs: torch.Tensor):
        """Initial inference: obs -> (state, policy, value)"""
        state = self.representation(obs)
        policy = self.policy_head(state)
        value = self.value_head(state).squeeze(-1)
        return state, policy, value
    
    def recurrent_inference(self, state: torch.Tensor, action: torch.Tensor):
        """
        Recurrent inference with entropy estimation.
        
        Returns: (next_state, reward, policy, value, entropy)
        
        entropy: Model's uncertainty about this transition.
                 Low = deterministic, good for macros.
                 High = stochastic, can't compress.
        """
        # One-hot encode action
        action_onehot = F.one_hot(action, self.action_dim).float()
        x = torch.cat([state, action_onehot], dim=-1)
        
        next_state = self.dynamics(x)
        entropy = self.entropy_head(x).squeeze(-1)  # Predicted transition entropy
        
        reward = self.reward_head(next_state).squeeze(-1)
        policy = self.policy_head(next_state)
        value = self.value_head(next_state).squeeze(-1)
        
        return next_state, reward, policy, value, entropy


# Test
model = MuZeroNetwork(game.observation_dim, game.action_space_size).to(DEVICE)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
state = game.reset()
obs = game.encode(state).unsqueeze(0).to(DEVICE)
hidden, policy, value = model.initial_inference(obs)

# Test dynamics with entropy
action = torch.tensor([1], device=DEVICE)
next_hidden, reward, next_policy, next_value, entropy = model.recurrent_inference(hidden, action)
print(f"Hidden state: {hidden.shape}")
print(f"Transition entropy: {entropy.item():.4f}")
print(f"  (Low entropy = deterministic transition = macro candidate)")


In [None]:
# Simple MCTS with entropy tracking for macro discovery

from dataclasses import dataclass, field
from typing import List, Dict, Optional
import math


@dataclass
class MCTSNode:
    """Node in the MCTS tree."""
    hidden_state: torch.Tensor
    reward: float = 0.0
    prior: float = 0.0
    entropy: float = 0.0  # Model's uncertainty at this transition
    
    visit_count: int = 0
    value_sum: float = 0.0
    children: Dict[int, 'MCTSNode'] = field(default_factory=dict)
    action_from_parent: int = -1
    
    @property
    def value(self) -> float:
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count
    
    def expanded(self) -> bool:
        return len(self.children) > 0


class MCTS:
    """
    Monte Carlo Tree Search with entropy tracking.
    
    During search, we track the entropy of each transition.
    Low-entropy paths through the tree are macro candidates.
    """
    
    def __init__(self, model: MuZeroNetwork, device: torch.device,
                 num_simulations: int = 25,
                 c_puct: float = 1.25,
                 discount: float = 0.99):
        self.model = model
        self.device = device
        self.num_simulations = num_simulations
        self.c_puct = c_puct
        self.discount = discount
        
        # Track rollout entropies for macro discovery
        self.rollout_actions: List[List[int]] = []
        self.rollout_entropies: List[List[float]] = []
        self.rollout_rewards: List[List[float]] = []
    
    def search(self, obs: torch.Tensor, legal_actions: List[int]) -> MCTSNode:
        """Run MCTS from observation, return root node."""
        self.model.eval()
        
        # Initialize root
        with torch.no_grad():
            hidden, policy_logits, value = self.model.initial_inference(obs)
        
        root = MCTSNode(hidden_state=hidden)
        self._expand(root, policy_logits, legal_actions)
        
        # Run simulations
        for _ in range(self.num_simulations):
            node = root
            path = []
            actions = []
            entropies = []
            rewards = []
            
            # Selection
            while node.expanded():
                action, child = self._select_child(node)
                path.append(node)
                actions.append(action)
                entropies.append(child.entropy)
                rewards.append(child.reward)
                node = child
            
            # Expansion: the leaf's hidden state was already produced by the dynamics
            # network when its parent was expanded, so only the prediction heads run
            # here (re-applying the action would advance the latent state twice).
            with torch.no_grad():
                policy = self.model.policy_head(node.hidden_state)
                value = self.model.value_head(node.hidden_state).squeeze(-1)
            self._expand(node, policy, legal_actions)
            leaf_value = value.item()
            
            # Backprop
            self._backpropagate(path + [node], leaf_value)
            
            # Store rollout for macro discovery
            if len(actions) >= 2:
                self.rollout_actions.append(actions)
                self.rollout_entropies.append(entropies)
                self.rollout_rewards.append(rewards)
        
        return root
    
    def _expand(self, node: MCTSNode, policy_logits: torch.Tensor, legal_actions: List[int]):
        """Expand node with children for legal actions."""
        probs = F.softmax(policy_logits[0], dim=0).cpu().numpy()
        
        for action in legal_actions:
            with torch.no_grad():
                next_hidden, reward, _, value, entropy = self.model.recurrent_inference(
                    node.hidden_state,
                    torch.tensor([action], device=self.device)
                )
            
            child = MCTSNode(
                hidden_state=next_hidden,
                reward=reward.item(),
                prior=probs[action],
                entropy=entropy.item(),
                action_from_parent=action,
            )
            node.children[action] = child
    
    def _select_child(self, node: MCTSNode) -> tuple:
        """Select child using PUCT formula."""
        best_score = -float('inf')
        best_action = None
        best_child = None
        
        sqrt_total = math.sqrt(node.visit_count + 1)
        
        for action, child in node.children.items():
            # PUCT score: Q(s, a) = r(s, a) + discount * V(child), or 0 if unvisited
            if child.visit_count == 0:
                q_value = 0.0
            else:
                q_value = child.reward + self.discount * child.value
            
            exploration = self.c_puct * child.prior * sqrt_total / (1 + child.visit_count)
            score = q_value + exploration
            
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child
        
        return best_action, best_child
    
    def _backpropagate(self, path: List[MCTSNode], value: float):
        """Backpropagate value through path."""
        for node in reversed(path):
            node.visit_count += 1
            node.value_sum += value
            value = node.reward + self.discount * value
    
    def get_action_probs(self, root: MCTSNode, temperature: float = 1.0) -> tuple:
        """Get action probabilities from visit counts."""
        actions = list(root.children.keys())
        visits = np.array([root.children[a].visit_count for a in actions])
        
        if temperature == 0:
            action = actions[np.argmax(visits)]
            probs = np.zeros(len(actions))
            probs[np.argmax(visits)] = 1.0
        else:
            visits = visits ** (1 / temperature)
            probs = visits / visits.sum()
            action = np.random.choice(actions, p=probs)
        
        return action, probs
    
    def get_mcts_rollouts(self) -> tuple:
        """Return collected rollouts for macro discovery."""
        actions = self.rollout_actions
        entropies = self.rollout_entropies
        rewards = self.rollout_rewards
        
        # Clear for next episode
        self.rollout_actions = []
        self.rollout_entropies = []
        self.rollout_rewards = []
        
        return actions, entropies, rewards


# Test MCTS
mcts = MCTS(model, DEVICE, num_simulations=10)
state = game.reset()
obs = game.encode(state).unsqueeze(0).to(DEVICE)
root = mcts.search(obs, game.legal_actions(state))
action, probs = mcts.get_action_probs(root, temperature=1.0)
print(f"MCTS search complete. Selected action: {game.action_name(action)}")
print(f"Visit counts: {[(game.action_name(a), c.visit_count) for a, c in root.children.items()]}")
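`_select_child` scores children with a PUCT rule. The sketch below reproduces its exploration term, including the notebook's `sqrt(N + 1)` variant. It also adds the min-max Q normalization that the MuZero paper uses to rescale unbounded Atari values into [0, 1] before mixing them with the prior term. `puct_score` and `MinMaxStats` are illustrative names; the notebook's `MCTS` does not normalize Q.

```python
import math


class MinMaxStats:
    """Track the range of Q values seen in the tree and rescale to [0, 1]."""

    def __init__(self):
        self.minimum, self.maximum = float("inf"), -float("inf")

    def update(self, value: float) -> None:
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value: float) -> float:
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value  # no range observed yet: leave unchanged


def puct_score(q: float, prior: float, parent_visits: int, child_visits: int,
               c_puct: float = 1.25) -> float:
    """Q + c * P * sqrt(N_parent + 1) / (1 + N_child), as in MCTS._select_child."""
    exploration = c_puct * prior * math.sqrt(parent_visits + 1) / (1 + child_visits)
    return q + exploration
```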


## 4. Macro Discovery

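Macros are **reusable action sequences** discovered from gameplay:
- Slide windows of `min_length` to `max_length` actions over each trajectory and count the action patterns they contain
- Keep a window only if every transition in it has a predicted entropy below the threshold (in deterministic Atari the true transition entropy is ≈ 0; the threshold is applied to the network's entropy estimate)
- Patterns seen at least `min_occurrences` times become macro-operators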
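The counting step can be sketched in isolation. This is a simplified version of the logic in `MacroCache._add_rollout`, without state buckets, rewards, or the dual thresholds; `extract_patterns` and `promote` are hypothetical names. A window is kept only if every transition in it is below the entropy threshold and it is not a single repeated action.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def extract_patterns(actions: Sequence[int], entropies: Sequence[float],
                     threshold: float, min_length: int = 2,
                     max_length: int = 6) -> Counter:
    """Count low-entropy action windows of length min_length..max_length."""
    counts: Counter = Counter()
    n = len(actions)
    for length in range(min_length, min(max_length, n) + 1):
        for i in range(n - length + 1):
            window = entropies[i:i + length]
            pattern = tuple(actions[i:i + length])
            # Skip windows with any uncertain step, and single-action repeats
            if all(e < threshold for e in window) and len(set(pattern)) > 1:
                counts[pattern] += 1
    return counts


def promote(counts: Counter, min_occurrences: int = 3) -> List[Tuple[int, ...]]:
    """Patterns seen at least min_occurrences times become macro candidates."""
    return [p for p, c in counts.items() if c >= min_occurrences]
```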

In [None]:
from collections import defaultdict


@dataclass
class MacroOperator:
    """A discovered macro-operator."""
    actions: tuple
    avg_entropy: float
    count: int = 0
    total_reward: float = 0.0
    state_buckets: set = field(default_factory=set)
    source: str = "play"  # "play" or "mcts"
    
    @property
    def avg_reward(self) -> float:
        return self.total_reward / max(self.count, 1)
    
    @property
    def length(self) -> int:
        return len(self.actions)
    
    @property
    def is_deterministic(self) -> bool:
        return self.avg_entropy < 0.1
    
    @property
    def state_generality(self) -> int:
        return len(self.state_buckets)


class MacroCache:
    """
    Dual-threshold macro discovery:
    - Tight threshold (default 0.15) for actually played trajectories
    - Looser threshold (default 0.35) for speculative MCTS rollouts
    
    The looser threshold accounts for MCTS exploring more uncertain paths.
    """
    
    def __init__(self, 
                 play_entropy_threshold: float = 0.15,   # Tight for played path
                 mcts_entropy_threshold: float = 0.35,   # Looser for MCTS
                 min_occurrences: int = 3, 
                 min_length: int = 2, 
                 max_length: int = 6,
                 state_dim: int = 256,
                 num_state_buckets: int = 64):
        self.play_threshold = play_entropy_threshold
        self.mcts_threshold = mcts_entropy_threshold
        self.min_occurrences = min_occurrences
        self.min_length = min_length
        self.max_length = max_length
        self.num_state_buckets = num_state_buckets
        
        self.state_projection = torch.randn(state_dim, num_state_buckets)
        
        self.candidates = defaultdict(lambda: {
            "count": 0, "total_entropy": 0.0, "total_reward": 0.0, 
            "state_buckets": set(), "source": set()
        })
        self.macros: Dict[tuple, MacroOperator] = {}
        
        self.total_transitions = 0
        self.low_entropy_play = 0
        self.low_entropy_mcts = 0
        self.mcts_rollouts = 0
        self.play_rollouts = 0
    
    def _get_state_bucket(self, hidden_state: torch.Tensor) -> int:
        with torch.no_grad():
            if hidden_state.dim() > 1:
                hidden_state = hidden_state.squeeze(0)
            proj = hidden_state.cpu() @ self.state_projection
            bucket = int(proj.argmax().item())
        return bucket
    
    def add_played_trajectory(self, actions: List[int], entropies: List[float], 
                              rewards: List[float], initial_state: Optional[torch.Tensor] = None):
        """Add trajectory from actual gameplay (tight threshold)."""
        self._add_rollout(actions, entropies, rewards, initial_state, 
                         threshold=self.play_threshold, source="play")
        self.play_rollouts += 1
    
    def add_mcts_rollout(self, actions: List[int], entropies: List[float], 
                         rewards: List[float], initial_state: Optional[torch.Tensor] = None):
        """Add MCTS speculative rollout (looser threshold)."""
        self._add_rollout(actions, entropies, rewards, initial_state,
                         threshold=self.mcts_threshold, source="mcts")
        self.mcts_rollouts += 1
    
    def _add_rollout(self, actions, entropies, rewards, initial_state, threshold, source):
        n = len(actions)
        if n < self.min_length or len(entropies) < n:
            return
        
        self.total_transitions += n
        # Count individual low-entropy transitions (not windows), so that
        # deterministic_rate = low-entropy transitions / all transitions stays in [0, 1]
        low = sum(1 for e in entropies[:n] if e < threshold)
        if source == "play":
            self.low_entropy_play += low
        else:
            self.low_entropy_mcts += low
        
        state_bucket = 0
        if initial_state is not None:
            state_bucket = self._get_state_bucket(initial_state)
        
        for length in range(self.min_length, min(self.max_length + 1, n + 1)):
            for i in range(n - length + 1):
                window_entropies = entropies[i:i + length]
                
                if all(e < threshold for e in window_entropies):
                    pattern = tuple(actions[i:i + length])
                    
                    # Skip single-action repeats such as (LEFT, LEFT)
                    if len(set(pattern)) <= 1:
                        continue
                    
                    avg_entropy = sum(window_entropies) / length
                    pattern_reward = sum(rewards[i:i + length]) if i + length <= len(rewards) else 0
                    
                    self.candidates[pattern]["count"] += 1
                    self.candidates[pattern]["total_entropy"] += avg_entropy
                    self.candidates[pattern]["total_reward"] += pattern_reward
                    self.candidates[pattern]["state_buckets"].add(state_bucket)
                    self.candidates[pattern]["source"].add(source)
                    
                    if (self.candidates[pattern]["count"] >= self.min_occurrences 
                        and pattern not in self.macros):
                        c = self.candidates[pattern]
                        primary_source = "play" if "play" in c["source"] else "mcts"
                        self.macros[pattern] = MacroOperator(
                            actions=pattern,
                            avg_entropy=c["total_entropy"] / c["count"],
                            count=c["count"],
                            total_reward=c["total_reward"],
                            state_buckets=c["state_buckets"].copy(),
                            source=primary_source,
                        )
                    elif pattern in self.macros:
                        m = self.macros[pattern]
                        c = self.candidates[pattern]
                        m.count = c["count"]
                        m.total_reward = c["total_reward"]
                        m.avg_entropy = c["total_entropy"] / m.count
                        m.state_buckets = c["state_buckets"].copy()
    
    def add_mcts_rollouts_batch(self, mcts: 'MCTS', initial_state: torch.Tensor):
        """Process all rollouts from MCTS."""
        actions_list, entropies_list, rewards_list = mcts.get_mcts_rollouts()
        for actions, entropies, rewards in zip(actions_list, entropies_list, rewards_list):
            self.add_mcts_rollout(actions, entropies, rewards, initial_state)
    
    def get_statistics(self) -> Dict[str, Any]:
        total_low = self.low_entropy_play + self.low_entropy_mcts
        det_rate = total_low / max(self.total_transitions, 1)
        
        if self.macros:
            avg_gen = np.mean([m.state_generality for m in self.macros.values()])
            max_gen = max(m.state_generality for m in self.macros.values())
            play_macros = sum(1 for m in self.macros.values() if m.source == "play")
        else:
            avg_gen = max_gen = play_macros = 0
        
        return {
            "num_macros": len(self.macros),
            "play_macros": play_macros,
            "mcts_macros": len(self.macros) - play_macros,
            "num_candidates": len(self.candidates),
            "total_transitions": self.total_transitions,
            "low_entropy_play": self.low_entropy_play,
            "low_entropy_mcts": self.low_entropy_mcts,
            "deterministic_rate": det_rate,
            "mcts_rollouts": self.mcts_rollouts,
            "play_rollouts": self.play_rollouts,
            "avg_state_generality": avg_gen,
            "max_state_generality": max_gen,
        }
    
    def get_top_macros(self, n: int = 10, sort_by: str = "count") -> List[MacroOperator]:
        if sort_by == "entropy":
            key = lambda m: m.avg_entropy
        elif sort_by == "generality":
            key = lambda m: -m.state_generality
        else:
            key = lambda m: -m.count
        return sorted(self.macros.values(), key=key)[:n]
    
    def decode_macro(self, macro: MacroOperator, game) -> str:
        return " -> ".join([game.action_name(a) for a in macro.actions])


cache = MacroCache(play_entropy_threshold=0.15, mcts_entropy_threshold=0.35)
print(f"Dual-threshold MacroCache:")
print(f"  Play threshold: {cache.play_threshold} (tight)")
print(f"  MCTS threshold: {cache.mcts_threshold} (loose)")
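`_get_state_bucket` hashes a latent state by projecting it onto `num_state_buckets` random directions and taking the argmax, so `state_generality` counts how many distinct regions of latent space a macro was seen in. Below is a NumPy sketch of the same idea; the seed and dimensions are arbitrary assumptions. Two properties are worth noting: the bucket depends only on the direction of the state vector (scaling by a positive constant does not change the argmax), and seeding the projection (the notebook uses unseeded `torch.randn`) keeps bucket ids comparable across runs.

```python
import numpy as np


def make_projection(state_dim: int = 256, num_buckets: int = 64, seed: int = 0) -> np.ndarray:
    """Random Gaussian projection matrix, fixed by a seed for reproducibility."""
    return np.random.default_rng(seed).standard_normal((state_dim, num_buckets))


def state_bucket(hidden: np.ndarray, projection: np.ndarray) -> int:
    """Index of the random direction with the largest projection of the hidden state."""
    return int(np.argmax(hidden @ projection))


projection = make_projection()
h = np.random.default_rng(1).standard_normal(256)
print(state_bucket(h, projection), state_bucket(2.5 * h, projection))
```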


## 5. Self-Play with Macro Discovery

In [None]:
def play_episode_mcts(game: AtariGame, model: MuZeroNetwork, 
                      macro_cache: MacroCache, device: torch.device,
                      max_steps: int = 1000, 
                      num_simulations: int = 25,
                      temperature: float = 1.0) -> Tuple[List, List, List, List, int]:
    """
    Play episode with MCTS, tracking both:
    1. MCTS rollouts (speculative, loose threshold)
    2. Actual played trajectory (tight threshold)
    """
    model.eval()
    mcts = MCTS(model, device, num_simulations=num_simulations)
    
    state = game.reset()
    
    observations = []
    actions = []
    rewards = []
    entropies = []
    
    for step in range(max_steps):
        if state.done:
            break
        
        obs = game.encode(state).to(device).unsqueeze(0)
        observations.append(obs.cpu())
        
        legal = game.legal_actions(state)
        if not legal:
            break
        
        # Run MCTS
        root = mcts.search(obs, legal)
        
        # Get initial hidden state for macro tracking
        with torch.no_grad():
            hidden, _, _ = model.initial_inference(obs)
        
        # Add MCTS rollouts (loose threshold)
        macro_cache.add_mcts_rollouts_batch(mcts, hidden)
        
        # Select action
        action, _ = mcts.get_action_probs(root, temperature=temperature)
        actions.append(action)
        
        # Get entropy for selected action
        if action in root.children:
            entropies.append(root.children[action].entropy)
        else:
            entropies.append(0.5)
        
        state, reward, done = game.step(state, action)
        rewards.append(reward)
    
    # Add the ACTUALLY PLAYED trajectory (tight threshold)
    if len(actions) >= 2 and observations:
        first_obs = observations[0].to(device)
        with torch.no_grad():
            first_hidden, _, _ = model.initial_inference(first_obs)
        macro_cache.add_played_trajectory(actions, entropies, rewards, first_hidden)
    
    return observations, actions, rewards, entropies, state.score


def play_episode_simple(game: AtariGame, model: MuZeroNetwork, 
                        macro_cache: MacroCache, device: torch.device,
                        max_steps: int = 1000, epsilon: float = 0.1,
                        temperature: float = 1.0) -> Tuple[List, List, List, List, int]:
    """Simple policy episode (no MCTS)."""
    model.eval()
    state = game.reset()
    
    observations = []
    actions = []
    rewards = []
    entropies = []
    
    for step in range(max_steps):
        if state.done:
            break
        
        obs = game.encode(state).to(device).unsqueeze(0)
        observations.append(obs.cpu())
        
        with torch.no_grad():
            hidden, policy_logits, _ = model.initial_inference(obs)
        
        legal = game.legal_actions(state)
        if not legal:
            break
        
        if np.random.random() < epsilon:
            action = np.random.choice(legal)
        else:
            probs = F.softmax(policy_logits[0] / temperature, dim=0).cpu().numpy()
            mask = np.zeros(game.action_space_size)
            mask[legal] = 1
            probs = probs * mask
            if probs.sum() > 0:
                probs = probs / probs.sum()
                action = np.random.choice(game.action_space_size, p=probs)
            else:
                action = np.random.choice(legal)
        
        with torch.no_grad():
            action_tensor = torch.tensor([action], device=device)
            _, _, _, _, entropy = model.recurrent_inference(hidden, action_tensor)
            entropies.append(entropy.item())
        
        actions.append(action)
        state, reward, done = game.step(state, action)
        rewards.append(reward)
    
    # Add played trajectory
    if len(actions) >= 2 and observations:
        first_obs = observations[0].to(device)
        with torch.no_grad():
            first_hidden, _, _ = model.initial_inference(first_obs)
        macro_cache.add_played_trajectory(actions, entropies, rewards, first_hidden)
    
    return observations, actions, rewards, entropies, state.score


print("Self-play ready with dual-threshold macro discovery!")
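In MuZero, the policy head is usually trained toward the MCTS visit-count distribution rather than toward the single sampled action (the training loop in the next section uses cross-entropy against the taken action). The following minimal sketch turns visit counts into a temperature-scaled target. `visit_policy` is a hypothetical helper that mirrors `MCTS.get_action_probs` without the sampling step.

```python
import numpy as np


def visit_policy(visits, temperature: float = 1.0) -> np.ndarray:
    """pi(a) proportional to N(a)^(1/T); T = 0 gives a one-hot on the most-visited action."""
    visits = np.asarray(visits, dtype=np.float64)
    if temperature == 0:
        probs = np.zeros_like(visits)
        probs[np.argmax(visits)] = 1.0
        return probs
    scaled = visits ** (1.0 / temperature)
    return scaled / scaled.sum()
```

Lower temperatures sharpen the target toward the most-visited action, which is the same effect the annealed `temp` has on action selection during self-play.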


## 6. Training Loop
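The training step below regresses the value head on each step's immediate reward, whereas MuZero's value target is a (bootstrapped, n-step) discounted return. This is a minimal sketch of the simplest alternative: full Monte Carlo returns computed per episode, before the rewards are flattened into the replay buffer. `discounted_returns` is a hypothetical helper, and the default discount of 0.99 matches `MCTS.discount`.

```python
from typing import List, Sequence


def discounted_returns(rewards: Sequence[float], discount: float = 0.99) -> List[float]:
    """G_t = r_t + discount * G_{t+1}, computed backwards over one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + discount * running
        returns[t] = running
    return returns
```

Storing `discounted_returns(rews)` next to `rews` for each episode would give the value loss a target that reflects all future reward rather than only the next one.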

In [None]:
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


def train_muzero_mcts(game: AtariGame, model: MuZeroNetwork, 
                      macro_cache: MacroCache, device: torch.device,
                      num_iterations: int = 100,
                      episodes_per_iter: int = 2,
                      max_steps: int = 500,
                      num_simulations: int = 15,
                      batch_size: int = 64,
                      lr: float = 1e-4,
                      use_mcts: bool = True):
    """
    Train MuZero with MCTS-based macro discovery.
    
    Key improvements:
    1. MCTS explores many paths -> more macro candidates
    2. State-conditioned -> know which macros are general
    3. Longer training -> accumulate confident patterns
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    replay_obs = []
    replay_actions = []
    replay_rewards = []
    max_buffer = 10000
    
    history = {
        "scores": [], "lengths": [], "losses": [], 
        "macros": [], "avg_entropy": [], "deterministic_rate": [],
        "mcts_rollouts": [], "state_generality": []
    }
    
    mode = "MCTS" if use_mcts else "Policy"
    print(f"Training with {mode} for {num_iterations} iterations")
    print(f"  Episodes/iter: {episodes_per_iter}, Steps/episode: {max_steps}")
    if use_mcts:
        print(f"  MCTS simulations: {num_simulations}")
    print(f"  Macro thresholds: play={macro_cache.play_threshold}, mcts={macro_cache.mcts_threshold}")
    print()
    
    for iteration in tqdm(range(num_iterations), desc="Training"):
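        # Exploration schedules: temperature anneals 1.5 -> 0.5 over the first half of
        # training; epsilon (used only without MCTS) anneals 1.0 -> 0.05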
        temp = max(0.5, 1.5 - iteration / (num_iterations / 2))
        epsilon = max(0.05, 1.0 - iteration / num_iterations)
        
        iter_scores = []
        iter_lengths = []
        iter_entropies = []
        
        for _ in range(episodes_per_iter):
            if use_mcts:
                obs, acts, rews, ents, score = play_episode_mcts(
                    game, model, macro_cache, device,
                    max_steps=max_steps,
                    num_simulations=num_simulations,
                    temperature=temp
                )
            else:
                obs, acts, rews, ents, score = play_episode_simple(
                    game, model, device,
                    max_steps=max_steps,
                    epsilon=epsilon,
                    temperature=temp
                )
                # play_episode_simple does not receive the macro cache, so register the rollout here
                if obs:
                    macro_cache.add_rollout(acts, ents, rews, obs[0].to(device))
            
            replay_obs.extend(obs)
            replay_actions.extend(acts)
            replay_rewards.extend(rews)
            
            iter_scores.append(score)
            iter_lengths.append(len(acts))
            iter_entropies.extend(ents)
        
        # Trim buffer
        if len(replay_obs) > max_buffer:
            replay_obs = replay_obs[-max_buffer:]
            replay_actions = replay_actions[-max_buffer:]
            replay_rewards = replay_rewards[-max_buffer:]
        
        # Training
        if len(replay_obs) >= batch_size:
            model.train()
            
            indices = np.random.choice(len(replay_obs), batch_size, replace=False)
            batch_obs = torch.cat([replay_obs[i] for i in indices]).to(device)
            batch_actions = torch.tensor([replay_actions[i] for i in indices], device=device)
            batch_rewards = torch.tensor([replay_rewards[i] for i in indices], 
                                         device=device, dtype=torch.float32)
            
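            hidden, policy_logits, values = model.initial_inference(batch_obs)
            
            # Policy imitates the actions that were actually played
            policy_loss = F.cross_entropy(policy_logits, batch_actions)
            # Value is regressed onto the immediate reward: a one-step proxy,
            # not the n-step bootstrapped return used in MuZero.
            # reshape(-1) avoids silent (B, 1) vs (B,) broadcasting in mse_loss.
            value_loss = F.mse_loss(values.reshape(-1), batch_rewards)
            
            # One dynamics step: predict the reward of the action that was taken
            next_hidden, pred_reward, _, _, pred_entropy = model.recurrent_inference(
                hidden, batch_actions
            )
            reward_loss = F.mse_loss(pred_reward.reshape(-1), batch_rewards)
            
            # Entropy target: pull the predicted transition entropy toward 0.1
            entropy_target = 0.1
            entropy_loss = F.mse_loss(pred_entropy, torch.full_like(pred_entropy, entropy_target))
            
            total_loss = policy_loss + 0.5 * value_loss + 0.5 * reward_loss + 0.1 * entropy_loss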
            
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            history["losses"].append(total_loss.item())
        
        # Record
        history["scores"].append(np.mean(iter_scores))
        history["lengths"].append(np.mean(iter_lengths))
        history["macros"].append(len(macro_cache.macros))
        history["avg_entropy"].append(np.mean(iter_entropies) if iter_entropies else 0)
        
        stats = macro_cache.get_statistics()
        history["deterministic_rate"].append(stats["deterministic_rate"])
        history["mcts_rollouts"].append(stats["mcts_rollouts"])
        history["state_generality"].append(stats["avg_state_generality"])
        
        if iteration % 20 == 0:
            print(f"\nIter {iteration}: Score={np.mean(iter_scores):.1f}, "
                  f"Entropy={np.mean(iter_entropies):.3f}, "
                  f"Macros={stats['num_macros']}, "
                  f"MCTS rollouts={stats['mcts_rollouts']}, "
                  f"Generality={stats['avg_state_generality']:.1f}")
    
    return history
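
The training step above regresses the value head onto the immediate reward, which is only a one-step proxy. MuZero-style training normally uses an n-step bootstrapped return, $z_t = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n v_{t+n}$, truncated at the episode end. What follows is a minimal sketch that assumes per-episode reward lists and value estimates are available. The flat replay lists above do not keep episode boundaries, so `n_step_returns` and its arguments are hypothetical.

```python
import numpy as np

def n_step_returns(rewards, values, n=5, gamma=0.99):
    """Hypothetical helper: n-step bootstrapped value targets for one episode.

    rewards[t] is the reward received after step t; values[t] is the value
    estimate at step t. Bootstrapping is skipped past the episode end.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    T = len(rewards)
    targets = np.zeros(T)
    for t in range(T):
        end = min(t + n, T)
        # Discounted sum of the rewards observed inside the n-step window
        discounts = gamma ** np.arange(end - t)
        targets[t] = np.sum(discounts * rewards[t:end])
        # Bootstrap from the value estimate n steps ahead, if still in the episode
        if t + n < T:
            targets[t] += gamma ** n * values[t + n]
    return targets
```

Using these targets would mean storing replay data per episode and passing the result to `value_loss` in place of `batch_rewards`.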


## 7. Train on Breakout

In [None]:
# Initialize
breakout = AtariGame("Breakout")
model = MuZeroNetwork(breakout.observation_dim, breakout.action_space_size).to(DEVICE)

# Dual-threshold macro cache
macro_cache = MacroCache(
    play_entropy_threshold=0.2,   # Tight for played trajectories
    mcts_entropy_threshold=0.4,   # Looser for MCTS exploration
    min_occurrences=3,
    min_length=2,
    max_length=6,
)

print(f"Training: {breakout.game_name}")
print("Dual-threshold macro discovery:")
print(f"  Played trajectory: {macro_cache.play_threshold}")
print(f"  MCTS rollouts: {macro_cache.mcts_threshold}")
print()

# Train for only 80 iterations: shorter runs help avoid collapse
history = train_muzero_mcts(
    game=breakout,
    model=model,
    macro_cache=macro_cache,
    device=DEVICE,
    num_iterations=80,
    episodes_per_iter=3,
    max_steps=500,
    num_simulations=10,
    use_mcts=True,
)
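
The dual-threshold cache above uses a tighter entropy bound for transitions from played episodes (0.2) than for MCTS rollouts (0.4). Below is a minimal sketch of one possible acceptance rule. It assumes a segment qualifies only when every step's predicted entropy falls below the bound for its source. `segment_is_low_entropy` is a hypothetical helper and does not reproduce the `MacroCache` implementation.

```python
# Mirrors play_entropy_threshold / mcts_entropy_threshold used above
THRESHOLDS = {"play": 0.2, "mcts": 0.4}

def segment_is_low_entropy(entropies, source, thresholds=THRESHOLDS):
    """Hypothetical rule: accept a non-empty segment if all per-step
    entropies are strictly below the threshold for its source."""
    if source not in thresholds:
        raise ValueError(f"unknown source: {source!r}")
    limit = thresholds[source]
    return len(entropies) > 0 and all(e < limit for e in entropies)
```

This rule requires every step to pass rather than the mean. That is a design choice: the cache reports an average entropy per macro, so it may use a mean-based rule instead.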


In [None]:
# Plot training results
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 4, figsize=(18, 8))

axes[0, 0].plot(history["scores"])
axes[0, 0].set_xlabel("Iteration"); axes[0, 0].set_ylabel("Score")
axes[0, 0].set_title("Breakout Score")

axes[0, 1].plot(history["lengths"])
axes[0, 1].set_xlabel("Iteration"); axes[0, 1].set_ylabel("Length")
axes[0, 1].set_title("Episode Length")

axes[0, 2].plot(history["losses"])
axes[0, 2].set_xlabel("Step"); axes[0, 2].set_ylabel("Loss")
axes[0, 2].set_title("Training Loss")

axes[0, 3].plot(history["macros"])
axes[0, 3].set_xlabel("Iteration"); axes[0, 3].set_ylabel("Macros")
axes[0, 3].set_title("Macros Discovered")

axes[1, 0].plot(history["avg_entropy"])
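axes[1, 0].axhline(y=macro_cache.play_threshold, color='r', linestyle='--', label='Play threshold')
axes[1, 0].axhline(y=macro_cache.mcts_threshold, color='orange', linestyle='--', label='MCTS threshold')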
axes[1, 0].set_xlabel("Iteration"); axes[1, 0].set_ylabel("Entropy")
axes[1, 0].set_title("Model Entropy"); axes[1, 0].legend()

axes[1, 1].plot(history["deterministic_rate"])
axes[1, 1].set_xlabel("Iteration"); axes[1, 1].set_ylabel("Rate")
axes[1, 1].set_title("Deterministic Rate"); axes[1, 1].set_ylim(0, 1)

axes[1, 2].plot(history["mcts_rollouts"])
axes[1, 2].set_xlabel("Iteration"); axes[1, 2].set_ylabel("Rollouts")
axes[1, 2].set_title("MCTS Rollouts Processed")

axes[1, 3].plot(history["state_generality"])
axes[1, 3].set_xlabel("Iteration"); axes[1, 3].set_ylabel("Avg Buckets")
axes[1, 3].set_title("Macro State Generality")

plt.tight_layout()
plt.show()
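
The "Macro State Generality" panel plots an average bucket count. Below is a minimal sketch of one way such a count could be produced. It assumes each state (latent or observation vector) is discretized by rounding to a coarse grid, and that a macro's generality is the number of distinct buckets it was observed starting from. The bucketing scheme and the `StateGeneralityTracker` class are hypothetical.

```python
from collections import defaultdict
import numpy as np

def state_bucket(state, resolution=1.0):
    """Map a state vector to a hashable bucket by rounding to a coarse grid."""
    return tuple(np.round(np.asarray(state, dtype=np.float64) / resolution).astype(int))

class StateGeneralityTracker:
    """Counts the distinct state buckets in which each macro was observed."""

    def __init__(self, resolution=1.0):
        self.resolution = resolution
        self.buckets = defaultdict(set)  # macro (tuple of actions) -> set of buckets

    def record(self, macro, start_state):
        self.buckets[tuple(macro)].add(state_bucket(start_state, self.resolution))

    def generality(self, macro):
        # .get() avoids creating an empty entry for unseen macros
        return len(self.buckets.get(tuple(macro), ()))

    def average_generality(self):
        if not self.buckets:
            return 0.0
        return float(np.mean([len(b) for b in self.buckets.values()]))
```

A coarser `resolution` merges more states into each bucket, which lowers the reported generality of every macro. The scale of this metric therefore depends on the discretization.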


## 8. Macro Analysis

In [None]:
print("=" * 70)
print("MACRO ANALYSIS: Dual-Threshold Discovery")
print("=" * 70)

stats = macro_cache.get_statistics()
print(f"\nTransitions: {stats['total_transitions']}")
print(f"Low-entropy (play): {stats['low_entropy_play']}")
print(f"Low-entropy (MCTS): {stats['low_entropy_mcts']}")
print(f"Deterministic rate: {stats['deterministic_rate']:.1%}")
print(f"\nRollouts - Play: {stats['play_rollouts']}, MCTS: {stats['mcts_rollouts']}")
print(f"\nMacros discovered: {stats['num_macros']}")
print(f"  From play: {stats['play_macros']}")
print(f"  From MCTS: {stats['mcts_macros']}")
print(f"  Candidates: {stats['num_candidates']}")

if stats['num_macros'] > 0:
    print("\n" + "=" * 70)
    print("MOST GENERAL MACROS (seen in the most distinct state buckets)")
    print("=" * 70)
    for i, macro in enumerate(macro_cache.get_top_macros(10, sort_by="generality")):
        decoded = macro_cache.decode_macro(macro, breakout)
        print(f"{i+1:2d}. {decoded}")
        print(f"    States: {macro.state_generality}, Count: {macro.count}, "
              f"Entropy: {macro.avg_entropy:.3f}, Source: {macro.source}")
    
    print("\n" + "=" * 70)
    print("MOST DETERMINISTIC (lowest entropy)")
    print("=" * 70)
    for i, macro in enumerate(macro_cache.get_top_macros(10, sort_by="entropy")):
        decoded = macro_cache.decode_macro(macro, breakout)
        det = "DET" if macro.is_deterministic else "learned"
        print(f"{i+1:2d}. {decoded}")
        print(f"    Entropy: {macro.avg_entropy:.4f} [{det}], Count: {macro.count}")
    
    print("\n" + "=" * 70)
    print("MOST FREQUENT")
    print("=" * 70)
    for i, macro in enumerate(macro_cache.get_top_macros(10, sort_by="count")):
        decoded = macro_cache.decode_macro(macro, breakout)
        print(f"{i+1:2d}. {decoded} (count={macro.count}, entropy={macro.avg_entropy:.3f})")
else:
    print("\nNo macros yet. Check:")
    print("  1. Is average entropy decreasing? (sign that the dynamics model is learning)")
    print("  2. Are the play/MCTS entropy thresholds too tight?")
    print("  3. Are more training iterations needed?")
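
The statistics above separate promoted macros from candidates. Below is a minimal sketch of how that promotion could work with the `min_occurrences`, `min_length` and `max_length` settings passed to the cache. It counts every action n-gram that lies inside a low-entropy stretch and promotes those seen at least `min_occurrences` times. `find_macros` is hypothetical, counts overlapping occurrences and ignores state conditioning.

```python
from collections import Counter

def find_macros(rollouts, threshold=0.2, min_occurrences=3, min_length=2, max_length=6):
    """Hypothetical macro discovery over (actions, entropies) rollouts.

    Only windows whose per-step entropies are all below `threshold` are
    counted. Returns (macros, candidates): n-grams seen at least
    `min_occurrences` times, and the remaining counted n-grams.
    """
    counts = Counter()
    for actions, entropies in rollouts:
        for length in range(min_length, max_length + 1):
            for start in range(len(actions) - length + 1):
                window = entropies[start:start + length]
                if all(e < threshold for e in window):
                    counts[tuple(actions[start:start + length])] += 1
    macros = {seq: c for seq, c in counts.items() if c >= min_occurrences}
    candidates = {seq: c for seq, c in counts.items() if c < min_occurrences}
    return macros, candidates
```

For the rollout `[1, 2, 1, 2, 1, 2]` with uniformly low entropy, only `(1, 2)` reaches three occurrences. Longer repeats such as `(1, 2, 1)` remain candidates.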


In [None]:
# Entropy distribution of discovered macros
import matplotlib.pyplot as plt

if macro_cache.macros:
    entropies = [m.avg_entropy for m in macro_cache.macros.values()]
    
    plt.figure(figsize=(10, 4))
    
    plt.subplot(1, 2, 1)
    plt.hist(entropies, bins=30, edgecolor='black')
    plt.axvline(x=0.1, color='r', linestyle='--', label='Deterministic threshold')
    plt.axvline(x=macro_cache.play_threshold, color='orange', linestyle='--', label='Play threshold')
    plt.axvline(x=macro_cache.mcts_threshold, color='purple', linestyle='--', label='MCTS threshold')
    plt.xlabel("Average Entropy")
    plt.ylabel("Count")
    plt.title("Entropy Distribution of Discovered Macros")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    lengths = [m.length for m in macro_cache.macros.values()]
    plt.hist(lengths, bins=range(2, 8), edgecolor='black', align='left')
    plt.xlabel("Macro Length")
    plt.ylabel("Count")
    plt.title("Length Distribution of Macros")
    
    plt.tight_layout()
    plt.show()
    
    # Stats
    det_macros = sum(1 for e in entropies if e < 0.1)
    print(f"Truly deterministic macros (entropy < 0.1): {det_macros} / {len(entropies)}")
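
The histogram treats 0.1 as the cutoff for "truly deterministic" macros. The sketch below gives a sense of that scale. It assumes the per-step entropy is the Shannon entropy, in nats, of a categorical distribution over predicted outcomes, computed from logits. Whether `MuZeroNetwork` computes `pred_entropy` this way is an assumption. For comparison, a fair two-way split has entropy ln 2 ≈ 0.693 nats.

```python
import torch
import torch.nn.functional as F

def categorical_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Shannon entropy in nats of softmax(logits) along the last dimension."""
    log_probs = F.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)

# A uniform distribution over 4 outcomes has entropy ln 4 (about 1.386 nats),
# far above the 0.1 cutoff; a sharply peaked one sits close to 0.
uniform = torch.zeros(4)
peaked = torch.tensor([10.0, 0.0, 0.0, 0.0])
```

`torch.distributions.Categorical(logits=...).entropy()` computes the same quantity.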


## 9. Train on Other Games (Optional)

In [None]:
# Uncomment to train on Pong

# pong = AtariGame("Pong")
# pong_model = MuZeroNetwork(pong.observation_dim, pong.action_space_size).to(DEVICE)
# pong_cache = MacroCache(min_occurrences=5)

# pong_history = train_muzero_mcts(
#     game=pong,
#     model=pong_model,
#     macro_cache=pong_cache,
#     device=DEVICE,
#     num_iterations=50,
#     episodes_per_iter=3,
# )

# print("\nPong Macros:")
# for macro in pong_cache.get_top_macros(10):
#     print(f"  {pong_cache.decode_macro(macro, pong)} (count={macro.count})")

In [None]:
# Uncomment to train on Space Invaders

# invaders = AtariGame("SpaceInvaders")
# invaders_model = MuZeroNetwork(invaders.observation_dim, invaders.action_space_size).to(DEVICE)
# invaders_cache = MacroCache(min_occurrences=5)

# invaders_history = train_muzero_mcts(
#     game=invaders,
#     model=invaders_model,
#     macro_cache=invaders_cache,
#     device=DEVICE,
#     num_iterations=50,
#     episodes_per_iter=3,
# )

# print("\nSpace Invaders Macros:")
# for macro in invaders_cache.get_top_macros(10):
#     print(f"  {invaders_cache.decode_macro(macro, invaders)} (count={macro.count})")

## 10. Summary

### Key Insights

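1. **Deterministic Atari** games are well suited to macro discovery because:
   - With sticky actions disabled (`repeat_action_probability=0.0`), the true transitions are deterministic, so a well-trained model's predicted entropy should approach 0; note that the `ALE/<Game>-v5` environments enable sticky actions by default (0.25)
   - The physics is predictable (ball bounces, gravity)
   - The agent controls all actions (no opponent uncertainty)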

2. **Macros as Compressible Causal Structure**:
   - Frequent action patterns represent reusable temporal abstractions
   - These are "rules" discovered from dynamics, not symbolic definitions
   - Longer training tends to surface more consistently repeated patterns, although the Breakout run above is kept short to avoid collapse

3. **Next Steps**:
   - Use macros to accelerate MCTS planning
   - Learn macro preconditions (when to apply)
   - Transfer macros across similar games
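
One way macros could accelerate MCTS planning is to expand each macro as a single compound edge in the search tree. The learned dynamics are unrolled over the macro's actions and the discounted reward is accumulated. The sketch below assumes a dynamics function with the signature `(state, action) -> (next_state, reward)`. `apply_macro` and `toy_dynamics` are hypothetical stand-ins for unrolling `model.recurrent_inference`.

```python
from typing import Callable, Sequence, Tuple, TypeVar

S = TypeVar("S")

def apply_macro(
    dynamics: Callable[[S, int], Tuple[S, float]],
    state: S,
    macro: Sequence[int],
    gamma: float = 0.99,
) -> Tuple[S, float]:
    """Roll `dynamics` through every action in `macro`.

    Returns the final state and the discounted sum of rewards, so a search
    tree can treat the whole macro as one edge spanning len(macro) steps.
    """
    total, discount = 0.0, 1.0
    for action in macro:
        state, reward = dynamics(state, action)
        total += discount * reward
        discount *= gamma
    return state, total


def toy_dynamics(position: int, action: int) -> Tuple[int, float]:
    """Toy 1-D paddle: action 2 moves right, 3 moves left; reward 1 at position 3."""
    position += {2: 1, 3: -1}.get(action, 0)
    return position, 1.0 if position == 3 else 0.0
```

When backing up values through such an edge, the child's value would be discounted by `gamma ** len(macro)` rather than by a single `gamma`, so macros of different lengths stay comparable with primitive actions.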