# Stochastic MuZero with Learned Temporal Abstractions
## V2: Atari Games Edition

This notebook implements **Stochastic MuZero** with macro discovery on **deterministic Atari games**.

### Why Atari?
- **Single-player**: Agent controls all actions (no opponent uncertainty)
- **Deterministic physics**: With sticky actions disabled, ball bounces, gravity, and collisions are fully predictable
- **Rich temporal structure**: Positioning sequences, attack patterns, navigation
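
Disabling sticky actions is what makes replays reproducible. The following is a minimal sketch of the mechanism. It assumes the common definition: with probability `p`, the previously executed action is repeated instead of the chosen one. `sticky_step_actions` is a hypothetical helper that only simulates which actions get executed. It is not ALE itself, and it simplifies to one decision per agent step.

```python
import random
from typing import List


def sticky_step_actions(chosen: List[int], p: float, seed: int) -> List[int]:
    """Return the actions that would actually execute under sticky actions.

    Hypothetical illustration of repeat_action_probability: with probability p
    the previously executed action repeats instead of the chosen one.
    """
    rng = random.Random(seed)
    executed = []
    prev = 0  # assume NOOP was executed before the first step
    for a in chosen:
        if rng.random() < p:
            a = prev  # the "sticky" repeat overrides the agent's choice
        executed.append(a)
        prev = a
    return executed


plan = [1, 2, 2, 3, 1, 0, 3]
# p = 0: the executed sequence always equals the plan, so replays are reproducible.
print(sticky_step_actions(plan, p=0.0, seed=0) == plan)  # True
# p = 0.25: the same plan can execute differently depending on the RNG.
print(sticky_step_actions(plan, p=0.25, seed=0), sticky_step_actions(plan, p=0.25, seed=1))
```

`AtariGame` passes `repeat_action_probability=0.0` in deterministic mode, which corresponds to `p = 0` here.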

### Core Idea
Rules are **compressible causal structure** in the transition dynamics:
- If a sequence of transitions is deterministic and repeatable, collapse it into a **macro-operator**
- Macros enable faster planning and capture reusable temporal abstractions
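
Collapsing a deterministic chain into a macro-operator amounts to composing its primitive transitions into one cached transition. The sketch below uses a hypothetical toy system: a paddle on positions 0..10. `step`, `apply_macro`, `MOVES`, and `TARGET` are illustrative names and are not part of this notebook.

```python
from functools import lru_cache
from typing import Tuple

# Toy deterministic dynamics (an assumption for illustration, not Atari):
# actions 0=NOOP, 1=RIGHT, 2=LEFT move a paddle on positions 0..10.
MOVES = {0: 0, 1: +1, 2: -1}
TARGET = 7


def step(pos: int, action: int) -> Tuple[int, float]:
    """One primitive transition: returns (next_pos, reward)."""
    nxt = min(10, max(0, pos + MOVES[action]))
    return nxt, 1.0 if nxt == TARGET else 0.0


@lru_cache(maxsize=None)
def apply_macro(pos: int, macro: Tuple[int, ...]) -> Tuple[int, float]:
    """Collapse a macro into a single transition by folding step() over it.

    Because step() is deterministic, the (pos, macro) -> outcome map can be
    cached and reused exactly; with stochastic dynamics it could not.
    """
    total = 0.0
    for a in macro:
        pos, r = step(pos, a)
        total += r
    return pos, total


print(apply_macro(5, (1, 1, 0)))  # -> (7, 2.0): RIGHT, RIGHT, then NOOP on the target
```

This caching is one way macros could speed up planning: a repeated `(state, macro)` lookup replaces several model evaluations.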

### Games
- **Breakout**: Ball trajectory, paddle positioning
- **Pong**: Ball interception patterns
- **Space Invaders**: Enemy patterns, shooting sequences

## 1. Setup & Installation

In [None]:
# Install dependencies.
# gymnasium>=1.0 provides FrameStackObservation and gym.register_envs;
# ale-py>=0.9 ships the Atari ROMs inside the wheel, so no separate
# AutoROM / atari_py ROM import step is needed.
!pip install torch numpy "gymnasium>=1.0" "ale-py>=0.9" tqdm matplotlib -q

import torch
import numpy as np

print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {DEVICE}")

In [None]:
# Test Atari environment
import gymnasium as gym
import ale_py

# Register ALE environments
gym.register_envs(ale_py)

try:
    env = gym.make("ALE/Breakout-v5")
    obs, _ = env.reset()
    print(f"Breakout works!")
    print(f"  Observation: {obs.shape}")
    print(f"  Actions: {env.action_space.n} - {env.unwrapped.get_action_meanings()}")
    env.close()
except Exception as e:
    print(f"Error: {e}")
    print("\nTrying to fix...")
    !pip install -U "gymnasium[atari]>=1.0" "ale-py>=0.9" -q
    print("\nPlease restart runtime and run this cell again.")

## 2. Atari Game Wrapper

In [None]:
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Any
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation

# Register ALE environments
import ale_py
gym.register_envs(ale_py)


@dataclass
class AtariState:
    """Wrapper for Atari game state."""
    env: gym.Env
    observation: np.ndarray
    done: bool = False
    lives: int = 0
    score: int = 0


class AtariGame:
    """
    Atari game wrapper with deterministic mode for macro discovery.
    
    Key settings:
    - repeat_action_probability=0: No sticky actions (deterministic)
    - noop_max=0: No random initial noops
    - 84x84 grayscale, 4 stacked frames
    """
    
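    # Reference only: the first 14 entries of ALE's full 18-action set.
    # Most games use a reduced action set with different indices (Breakout:
    # NOOP, FIRE, RIGHT, LEFT), so the wrapper reads names from the env via
    # self.action_meanings instead of this table.
    ACTIONS = {
        0: "NOOP", 1: "FIRE", 2: "UP", 3: "RIGHT", 4: "LEFT", 5: "DOWN",
        6: "UPRIGHT", 7: "UPLEFT", 8: "DOWNRIGHT", 9: "DOWNLEFT",
        10: "UPFIRE", 11: "RIGHTFIRE", 12: "LEFTFIRE", 13: "DOWNFIRE",
    }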

    def __init__(self, game_name: str = "Breakout", frame_stack: int = 4, 
                 frame_skip: int = 4, deterministic: bool = True):
        self.game_name = game_name
        self.frame_stack = frame_stack
        self.frame_skip = frame_skip
        self.deterministic = deterministic
        
        self._env = self._make_env()
        self.action_space_size = self._env.action_space.n
        self.action_meanings = self._env.unwrapped.get_action_meanings()
        self.observation_dim = 84 * 84 * frame_stack  # Flattened

    def _make_env(self) -> gym.Env:
        env = gym.make(
            f"ALE/{self.game_name}-v5",
            repeat_action_probability=0.0 if self.deterministic else 0.25,
            frameskip=1,
        )
        env = AtariPreprocessing(
            env,
            noop_max=0 if self.deterministic else 30,
            frame_skip=self.frame_skip,
            screen_size=84,
            grayscale_obs=True,
            grayscale_newaxis=True,
            scale_obs=True,
        )
        env = FrameStackObservation(env, self.frame_stack)
        return env

    def reset(self) -> AtariState:
        # Reuse the env built in __init__ instead of creating (and never
        # closing) a new emulator on every reset
        obs, info = self._env.reset()
        return AtariState(env=self._env, observation=obs, done=False, 
                         lives=info.get('lives', 0), score=0)

    def step(self, state: AtariState, action: int) -> Tuple[AtariState, float, bool]:
        if state.done:
            return state, 0.0, True
        
        obs, reward, terminated, truncated, info = state.env.step(action)
        done = terminated or truncated
        
        new_state = AtariState(
            env=state.env, observation=obs, done=done,
            lives=info.get('lives', state.lives),
            score=state.score + int(reward),
        )
        return new_state, float(reward), done

    def encode(self, state: AtariState) -> torch.Tensor:
        """Encode observation as flattened tensor."""
        obs = state.observation
        if obs.ndim == 4 and obs.shape[-1] == 1:
            obs = obs.squeeze(-1)  # (4, 84, 84)
        obs = np.asarray(obs, dtype=np.float32).reshape(-1)
        return torch.tensor(obs, dtype=torch.float32)

    def legal_actions(self, state: AtariState) -> List[int]:
        return [] if state.done else list(range(self.action_space_size))

    def action_name(self, action: int) -> str:
        if action < len(self.action_meanings):
            return self.action_meanings[action]
        return f"Action_{action}"


# Test
game = AtariGame("Breakout")
print(f"Game: {game.game_name}")
print(f"Observation dim: {game.observation_dim}")
print(f"Actions: {game.action_space_size} - {game.action_meanings}")
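
`AtariGame` feeds the network a flattened stack of the last `frame_stack` frames, which is why `observation_dim` is `84 * 84 * frame_stack`. The sketch below shows the stacking idea. It assumes the stack is padded with copies of the first frame at reset. `FrameStack` is a hypothetical class and not `gymnasium.wrappers.FrameStackObservation` itself.

```python
from collections import deque
from typing import List


class FrameStack:
    """Hypothetical minimal frame stacker (not gymnasium's implementation).

    Keeps the k most recent frames; at reset the stack is filled with copies
    of the first frame so every observation has exactly k frames.
    """

    def __init__(self, k: int):
        self.k = k
        self.frames = deque(maxlen=k)  # the oldest frame drops out automatically

    def reset(self, frame) -> List:
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(frame)
        return list(self.frames)

    def step(self, frame) -> List:
        self.frames.append(frame)
        return list(self.frames)


fs = FrameStack(k=4)
print(fs.reset("f0"))  # ['f0', 'f0', 'f0', 'f0']
print(fs.step("f1"))   # ['f0', 'f0', 'f0', 'f1']

# Flattened size fed to the network: k frames of 84x84 grayscale pixels.
print(4 * 84 * 84)     # 28224, matching AtariGame.observation_dim
```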


## 3. MuZero Network

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class MuZeroNetwork(nn.Module):
    """
    MuZero network with entropy estimation for macro discovery.
    
    Key addition: The dynamics network outputs both next_state AND
    a confidence/entropy estimate. Low entropy = deterministic transition.
    """
    
    def __init__(self, obs_dim: int, action_dim: int, 
                 state_dim: int = 256, hidden_dim: int = 256):
        super().__init__()
        self.action_dim = action_dim
        self.state_dim = state_dim
        
        # Representation network: obs -> state
        self.representation = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim),
        )
        
        # Dynamics network: (state, action) -> next_state
        self.dynamics = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim),
        )
        
        # Entropy head: predicts uncertainty of transition
        # Low entropy = deterministic (macro-able)
        # High entropy = stochastic (can't compress)
        self.entropy_head = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Softplus(),  # Ensure positive entropy
        )
        
        # Reward prediction
        self.reward_head = nn.Sequential(
            nn.Linear(state_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        
        # Policy head
        self.policy_head = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        
        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
    
    def initial_inference(self, obs: torch.Tensor):
        """Initial inference: obs -> (state, policy, value)"""
        state = self.representation(obs)
        policy = self.policy_head(state)
        value = self.value_head(state).squeeze(-1)
        return state, policy, value
    
    def recurrent_inference(self, state: torch.Tensor, action: torch.Tensor):
        """
        Recurrent inference with entropy estimation.
        
        Returns: (next_state, reward, policy, value, entropy)
        
        entropy: Model's uncertainty about this transition.
                 Low = deterministic, good for macros.
                 High = stochastic, can't compress.
        """
        # One-hot encode action
        action_onehot = F.one_hot(action, self.action_dim).float()
        x = torch.cat([state, action_onehot], dim=-1)
        
        next_state = self.dynamics(x)
        entropy = self.entropy_head(x).squeeze(-1)  # Predicted transition entropy
        
        reward = self.reward_head(next_state).squeeze(-1)
        policy = self.policy_head(next_state)
        value = self.value_head(next_state).squeeze(-1)
        
        return next_state, reward, policy, value, entropy


# Test
model = MuZeroNetwork(game.observation_dim, game.action_space_size).to(DEVICE)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass
state = game.reset()
obs = game.encode(state).unsqueeze(0).to(DEVICE)
hidden, policy, value = model.initial_inference(obs)

# Test dynamics with entropy
action = torch.tensor([1], device=DEVICE)
next_hidden, reward, next_policy, next_value, entropy = model.recurrent_inference(hidden, action)
print(f"Hidden state: {hidden.shape}")
print(f"Transition entropy: {entropy.item():.4f}")
print(f"  (Low entropy = deterministic transition = macro candidate)")
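
The entropy head ends in `nn.Softplus()`, so its output is always positive. Softplus is log(1 + e^x). The stdlib sketch below uses its numerically stable form. It shows why an untrained head, whose pre-activation is typically close to 0, starts roughly around ln 2 ≈ 0.69, which is above the 0.5 threshold `MacroCache` uses by default.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + e^x) = max(x, 0) + log1p(e^(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


print(softplus(0.0))   # ln 2, about 0.6931
print(softplus(-5.0))  # small but strictly positive: the head never outputs exactly 0
print(softplus(50.0))  # about 50 for large inputs, without overflowing exp()
```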


## 4. Macro Discovery

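Macros are **reusable action sequences** discovered from gameplay:
- The dynamics model predicts an entropy (uncertainty) score for every transition
- Only windows in which every transition's predicted entropy falls below a threshold are considered
- Low-entropy windows that recur often enough are promoted to macros

With sticky actions and random no-op starts disabled, Atari transitions are deterministic, so a well-trained model should assign low entropy to most of them.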
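
The filtering rule can be seen in isolation: a single high-entropy transition breaks every window that spans it. The sketch below assumes the same window rules that `MacroCache.add_rollout` applies in the next cell, but leaves out counting and promotion. The function name `low_entropy_windows` is hypothetical.

```python
from typing import List, Tuple


def low_entropy_windows(actions: List[int], entropies: List[float],
                        threshold: float, min_len: int = 2,
                        max_len: int = 6) -> List[Tuple[int, ...]]:
    """Return every action window whose transitions ALL have entropy < threshold.

    Single-action repeats such as (NOOP, NOOP) are skipped, as in MacroCache.
    """
    n = len(actions)
    out = []
    for length in range(min_len, min(max_len, n) + 1):
        for i in range(n - length + 1):
            if all(e < threshold for e in entropies[i:i + length]):
                pattern = tuple(actions[i:i + length])
                if len(set(pattern)) > 1:  # skip uniform patterns
                    out.append(pattern)
    return out


acts = [1, 2, 2, 3, 0]
ents = [0.1, 0.2, 0.9, 0.1, 0.3]  # the transition at index 2 is uncertain
print(low_entropy_windows(acts, ents, threshold=0.5))  # -> [(1, 2), (3, 0)]
```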

In [None]:
from collections import defaultdict


@dataclass
class MacroOperator:
    """A discovered macro-operator (low-entropy action sequence)."""
    actions: tuple
    avg_entropy: float  # Average entropy across transitions
    count: int = 0
    total_reward: float = 0.0
    
    @property
    def avg_reward(self) -> float:
        return self.total_reward / max(self.count, 1)
    
    @property
    def length(self) -> int:
        return len(self.actions)
    
    @property
    def is_deterministic(self) -> bool:
        """True if this macro represents a deterministic sequence."""
        return self.avg_entropy < 0.1


class MacroCache:
    """
    Entropy-based macro discovery from self-play rollouts.
    
    Key insight: Macros are sequences where the model predicts
    transitions with LOW ENTROPY (high confidence). These represent
    compressible causal structure - the "rules" of the game.
    
    Discovery process:
    1. During self-play, record (action, entropy, reward) for each transition
    2. Find windows where ALL transitions have entropy < threshold
    3. These are macro candidates (deterministic chains)
    4. Promote a candidate to a macro once it has been seen min_occurrences times
    """
    
    def __init__(self, entropy_threshold: float = 0.5, 
                 min_occurrences: int = 3, 
                 min_length: int = 2, 
                 max_length: int = 6):
        self.entropy_threshold = entropy_threshold
        self.min_occurrences = min_occurrences
        self.min_length = min_length
        self.max_length = max_length
        
        # Candidates: pattern -> {count, total_entropy, total_reward}
        self.candidates = defaultdict(lambda: {
            "count": 0, "total_entropy": 0.0, "total_reward": 0.0
        })
        self.macros: Dict[tuple, MacroOperator] = {}
        
        # Statistics
        self.total_transitions = 0
        self.low_entropy_transitions = 0
    
    def add_rollout(self, actions: List[int], entropies: List[float], 
                    rewards: List[float]):
        """
        Extract macro candidates from a self-play rollout.
        
        Only considers sequences where ALL transitions have
        entropy below threshold (deterministic chains).
        """
        n = len(actions)
        
        # Count every transition, including rollouts too short to yield a pattern
        self.total_transitions += n
        self.low_entropy_transitions += sum(1 for e in entropies if e < self.entropy_threshold)
        
        if n < self.min_length:
            return
        
        # Find low-entropy windows
        for length in range(self.min_length, min(self.max_length + 1, n + 1)):
            for i in range(n - length + 1):
                window_entropies = entropies[i:i + length]
                
                # ALL transitions must be low-entropy (deterministic)
                if all(e < self.entropy_threshold for e in window_entropies):
                    pattern = tuple(actions[i:i + length])
                    
                    # Skip uniform patterns
                    if len(set(pattern)) <= 1:
                        continue
                    
                    avg_entropy = sum(window_entropies) / length
                    pattern_reward = sum(rewards[i:i + length])
                    
                    self.candidates[pattern]["count"] += 1
                    self.candidates[pattern]["total_entropy"] += avg_entropy
                    self.candidates[pattern]["total_reward"] += pattern_reward
                    
                    # Promote to macro, and refresh its stats on later sightings
                    # so count / avg_entropy / avg_reward reflect every occurrence
                    # (otherwise every macro would stay frozen at min_occurrences)
                    c = self.candidates[pattern]
                    if c["count"] >= self.min_occurrences:
                        self.macros[pattern] = MacroOperator(
                            actions=pattern,
                            avg_entropy=c["total_entropy"] / c["count"],
                            count=c["count"],
                            total_reward=c["total_reward"],
                        )
    
    def get_statistics(self) -> Dict[str, Any]:
        deterministic_rate = (self.low_entropy_transitions / max(self.total_transitions, 1))
        return {
            "num_macros": len(self.macros),
            "num_candidates": len(self.candidates),
            "total_transitions": self.total_transitions,
            "low_entropy_transitions": self.low_entropy_transitions,
            "deterministic_rate": deterministic_rate,
        }
    
    def get_top_macros(self, n: int = 10) -> List[MacroOperator]:
        """Get top macros by count (most frequent deterministic patterns)."""
        sorted_macros = sorted(self.macros.values(), key=lambda m: m.count, reverse=True)
        return sorted_macros[:n]
    
    def decode_macro(self, macro: MacroOperator, game: AtariGame) -> str:
        names = [game.action_name(a) for a in macro.actions]
        return " -> ".join(names)


# Test
cache = MacroCache(entropy_threshold=0.5, min_occurrences=3)
print(f"MacroCache initialized")
print(f"  Entropy threshold: {cache.entropy_threshold}")
print(f"  Min occurrences: {cache.min_occurrences}")
print(f"  Pattern length: {cache.min_length}-{cache.max_length}")
print()
print("Macro discovery now based on MODEL ENTROPY, not just action frequency!")
print("Low entropy = model confidently predicts transition = deterministic = macro")


## 5. Self-Play with Macro Discovery

In [None]:
def play_episode(game: AtariGame, model: MuZeroNetwork, device: torch.device,
                 max_steps: int = 1000, epsilon: float = 0.1,
                 temperature: float = 1.0) -> Tuple[List, List, List, List, int]:
    """
    Play one episode, tracking model entropy for each transition.
    
    Returns: (observations, actions, rewards, entropies, final_score)
    
    The entropies list contains the model's uncertainty for each transition.
    Low entropy = deterministic = good for macro discovery.
    """
    model.eval()
    state = game.reset()
    
    observations = []
    actions = []
    rewards = []
    entropies = []  # Model's predicted entropy for each transition
    
    hidden_state = None
    
    for step in range(max_steps):
        if state.done:
            break
        
        obs = game.encode(state).to(device).unsqueeze(0)
        observations.append(obs.cpu())
        
        with torch.no_grad():
            hidden_state, policy_logits, _ = model.initial_inference(obs)
        
        legal = game.legal_actions(state)
        if not legal:
            break
        
        # Epsilon-greedy with softmax
        if np.random.random() < epsilon:
            action = np.random.choice(legal)
        else:
            probs = F.softmax(policy_logits[0] / temperature, dim=0).cpu().numpy()
            mask = np.zeros(game.action_space_size)
            mask[legal] = 1
            probs = probs * mask
            if probs.sum() > 0:
                probs = probs / probs.sum()
                action = np.random.choice(game.action_space_size, p=probs)
            else:
                action = np.random.choice(legal)
        
        # Get entropy for this transition from the model
        with torch.no_grad():
            action_tensor = torch.tensor([action], device=device)
            _, _, _, _, entropy = model.recurrent_inference(hidden_state, action_tensor)
            entropies.append(entropy.item())
        
        actions.append(action)
        state, reward, done = game.step(state, action)
        rewards.append(reward)
    
    return observations, actions, rewards, entropies, state.score


# Test episode with entropy tracking
obs, acts, rews, ents, score = play_episode(game, model, DEVICE, max_steps=100, epsilon=1.0)
print(f"Test episode: {len(acts)} steps, score={score}")
print(f"Entropy stats: min={min(ents):.3f}, max={max(ents):.3f}, mean={np.mean(ents):.3f}")
print("  (An untrained head typically outputs values near softplus(0) = ln 2 ≈ 0.69,")
print("   above the 0.5 threshold; training should lower them)")
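
Action selection in `play_episode` combines epsilon-greedy exploration with temperature-scaled sampling restricted to legal actions. The stdlib-only sketch below implements the same rule. The helper names `masked_softmax` and `select_action` are hypothetical. Lowering `temperature` sharpens the distribution toward the highest logit.

```python
import math
import random
from typing import List, Sequence


def masked_softmax(logits: Sequence[float], legal: Sequence[int],
                   temperature: float = 1.0) -> List[float]:
    """Softmax over logits / temperature, giving illegal actions zero mass."""
    legal_set = set(legal)
    scaled = [x / temperature for x in logits]
    m = max(scaled[a] for a in legal_set)  # subtract the max for numerical stability
    weights = [math.exp(s - m) if a in legal_set else 0.0
               for a, s in enumerate(scaled)]
    z = sum(weights)
    return [w / z for w in weights]


def select_action(logits, legal, epsilon, temperature, rng: random.Random) -> int:
    """With probability epsilon pick a uniform legal action, else sample the softmax."""
    if rng.random() < epsilon:
        return rng.choice(list(legal))
    probs = masked_softmax(logits, legal, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
print(masked_softmax(logits, legal=[0, 1, 2, 3], temperature=1.0))
print(masked_softmax(logits, legal=[0, 1, 2, 3], temperature=0.5))  # sharper
print(select_action(logits, [1, 3], epsilon=0.1, temperature=1.0, rng=random.Random(0)))
```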


## 6. Training Loop

In [None]:
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


def train_muzero(game: AtariGame, model: MuZeroNetwork, 
                 macro_cache: MacroCache, device: torch.device,
                 num_iterations: int = 50, episodes_per_iter: int = 3,
                 max_steps: int = 1000, batch_size: int = 64,
                 lr: float = 1e-4):
    """
    Train MuZero with entropy-based macro discovery.
    
    Key insight: As the model learns the game dynamics, its
    prediction entropy decreases for deterministic transitions.
    These low-entropy sequences become macro candidates.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Replay buffer
    replay_obs = []
    replay_actions = []
    replay_rewards = []
    replay_entropies = []
    max_buffer = 10000
    
    # Tracking
    history = {
        "scores": [], "lengths": [], "losses": [], 
        "macros": [], "avg_entropy": [], "deterministic_rate": []
    }
    
    print(f"Training for {num_iterations} iterations")
    print("Macro discovery: low-entropy windows (model entropy), promoted by frequency")
    print(f"  Entropy threshold: {macro_cache.entropy_threshold}")
    print()
    
    for iteration in tqdm(range(num_iterations), desc="Training"):
        epsilon = max(0.05, 1.0 - iteration / num_iterations)
        temperature = max(0.5, 2.0 - iteration / (num_iterations / 2))
        
        iter_scores = []
        iter_lengths = []
        iter_entropies = []
        
        # Self-play with entropy tracking
        for _ in range(episodes_per_iter):
            obs, acts, rews, ents, score = play_episode(
                game, model, device, max_steps=max_steps,
                epsilon=epsilon, temperature=temperature
            )
            
            replay_obs.extend(obs)
            replay_actions.extend(acts)
            replay_rewards.extend(rews)
            replay_entropies.extend(ents)
            
            # Discover macros from LOW-ENTROPY sequences
            macro_cache.add_rollout(acts, ents, rews)
            
            iter_scores.append(score)
            iter_lengths.append(len(acts))
            iter_entropies.extend(ents)
        
        # Trim buffer
        if len(replay_obs) > max_buffer:
            replay_obs = replay_obs[-max_buffer:]
            replay_actions = replay_actions[-max_buffer:]
            replay_rewards = replay_rewards[-max_buffer:]
            replay_entropies = replay_entropies[-max_buffer:]
        
        # Training step
        if len(replay_obs) >= batch_size:
            model.train()
            
            indices = np.random.choice(len(replay_obs), batch_size, replace=False)
            batch_obs = torch.cat([replay_obs[i] for i in indices]).to(device)
            batch_actions = torch.tensor([replay_actions[i] for i in indices], device=device)
            batch_rewards = torch.tensor([replay_rewards[i] for i in indices], 
                                         device=device, dtype=torch.float32)
            
            # Forward
            hidden, policy_logits, values = model.initial_inference(batch_obs)
            
            # Policy loss: imitate the actions actually taken during self-play
            policy_loss = F.cross_entropy(policy_logits, batch_actions)
            
            # Value loss: a simplification that regresses the value head onto the
            # one-step reward (MuZero proper uses n-step bootstrapped returns)
            value_loss = F.mse_loss(values, batch_rewards)
            
            # Dynamics + entropy loss
            next_hidden, pred_reward, _, _, pred_entropy = model.recurrent_inference(
                hidden, batch_actions
            )
            reward_loss = F.mse_loss(pred_reward, batch_rewards)
            
            # Entropy head target: the detached squared reward-prediction error.
            # A constant target (e.g. 0.1) would drive every transition below the
            # threshold regardless of the data; this target ties the estimate to
            # how well the model actually predicts each transition.
            entropy_target = (pred_reward - batch_rewards).pow(2).detach()
            entropy_loss = F.mse_loss(pred_entropy, entropy_target)
            
            total_loss = policy_loss + 0.5 * value_loss + 0.5 * reward_loss + 0.1 * entropy_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            history["losses"].append(total_loss.item())
        
        # Record
        history["scores"].append(np.mean(iter_scores))
        history["lengths"].append(np.mean(iter_lengths))
        history["macros"].append(len(macro_cache.macros))
        history["avg_entropy"].append(np.mean(iter_entropies) if iter_entropies else 0)
        
        stats = macro_cache.get_statistics()
        history["deterministic_rate"].append(stats["deterministic_rate"])
        
        if iteration % 10 == 0:
            print(f"\nIter {iteration}: Score={np.mean(iter_scores):.1f}, "
                  f"Entropy={np.mean(iter_entropies):.3f}, "
                  f"DetRate={stats['deterministic_rate']:.1%}, "
                  f"Macros={stats['num_macros']}")
    
    return history
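
`train_muzero` keeps four parallel Python lists and trims them by slicing. The sketch below shows an alternative design. It assumes only the most recent transitions matter and stores each transition as one tuple in a `collections.deque` with `maxlen`. Eviction is then automatic, and the fields cannot drift out of alignment. The names `MAX_BUFFER` and `buffer` are illustrative.

```python
import random
from collections import deque

MAX_BUFFER = 5  # tiny for illustration; train_muzero uses max_buffer = 10000

# Each transition is one (obs, action, reward) tuple; maxlen evicts the oldest.
buffer = deque(maxlen=MAX_BUFFER)
for t in range(8):
    buffer.append((f"obs{t}", t % 4, float(t == 6)))

print(len(buffer))    # 5
print(buffer[0][0])   # obs3 (transitions 0-2 were evicted)

# Sample a minibatch without replacement and split it back into fields.
batch = random.Random(0).sample(list(buffer), k=3)
obs_b, act_b, rew_b = zip(*batch)
print(act_b)
```

Indexed access into the middle of a `deque` is O(n), so copying it to a list once per batch, as above, is a reasonable trade-off for uniform sampling.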


## 7. Train on Breakout

In [None]:
# Initialize
breakout = AtariGame("Breakout")
model = MuZeroNetwork(breakout.observation_dim, breakout.action_space_size).to(DEVICE)
macro_cache = MacroCache(min_occurrences=5, min_length=2, max_length=6)

print(f"Training on: {breakout.game_name}")
print(f"Observation dim: {breakout.observation_dim}")
print(f"Actions: {breakout.action_space_size}")
print(f"Deterministic: {breakout.deterministic}")
print()

# Train
history = train_muzero(
    game=breakout,
    model=model,
    macro_cache=macro_cache,
    device=DEVICE,
    num_iterations=50,
    episodes_per_iter=3,
    max_steps=1000,
)
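
With `num_iterations=50` as above, the exploration schedule in `train_muzero` can be tabulated directly. The helper below reproduces the two formulas from the training loop. The name `exploration_schedule` is hypothetical.

```python
def exploration_schedule(iteration: int, num_iterations: int):
    """Epsilon and softmax temperature used by train_muzero at one iteration."""
    epsilon = max(0.05, 1.0 - iteration / num_iterations)
    temperature = max(0.5, 2.0 - iteration / (num_iterations / 2))
    return epsilon, temperature


for it in (0, 10, 25, 40, 49):
    eps, temp = exploration_schedule(it, num_iterations=50)
    print(f"iter {it:2d}: epsilon={eps:.2f}, temperature={temp:.2f}")
```

For 50 iterations, the temperature reaches its 0.5 floor at iteration 38 and epsilon reaches its 0.05 floor at iteration 48. For the first half of training, most actions are therefore uniformly random.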

In [None]:
# Plot training results
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

axes[0, 0].plot(history["scores"])
axes[0, 0].set_xlabel("Iteration")
axes[0, 0].set_ylabel("Score")
axes[0, 0].set_title("Breakout Score")

axes[0, 1].plot(history["lengths"])
axes[0, 1].set_xlabel("Iteration")
axes[0, 1].set_ylabel("Episode Length")
axes[0, 1].set_title("Episode Length")

axes[0, 2].plot(history["losses"])
axes[0, 2].set_xlabel("Training Step")
axes[0, 2].set_ylabel("Loss")
axes[0, 2].set_title("Training Loss")

axes[1, 0].plot(history["macros"])
axes[1, 0].set_xlabel("Iteration")
axes[1, 0].set_ylabel("Macros")
axes[1, 0].set_title("Macro Discovery")

axes[1, 1].plot(history["avg_entropy"])
axes[1, 1].set_xlabel("Iteration")
axes[1, 1].set_ylabel("Entropy")
axes[1, 1].set_title("Model Entropy (lower = more deterministic)")
axes[1, 1].axhline(y=macro_cache.entropy_threshold, color='r', linestyle='--', label='Threshold')
axes[1, 1].legend()

axes[1, 2].plot(history["deterministic_rate"])
axes[1, 2].set_xlabel("Iteration")
axes[1, 2].set_ylabel("Rate")
axes[1, 2].set_title("Deterministic Transition Rate")
axes[1, 2].set_ylim(0, 1)

plt.tight_layout()
plt.show()


## 8. Macro Analysis

In [None]:
print("=" * 60)
print("ENTROPY-BASED MACRO ANALYSIS: Breakout")
print("=" * 60)

stats = macro_cache.get_statistics()
print(f"\nTotal transitions analyzed: {stats['total_transitions']}")
print(f"Low-entropy transitions: {stats['low_entropy_transitions']} ({stats['deterministic_rate']:.1%})")
print(f"\nMacros discovered: {stats['num_macros']}")
print(f"Candidate patterns: {stats['num_candidates']}")

if stats['num_macros'] > 0:
    print("\nTop 15 Macros (by frequency of low-entropy occurrence):")
    print("-" * 60)
    
    for i, macro in enumerate(macro_cache.get_top_macros(15)):
        decoded = macro_cache.decode_macro(macro, breakout)
        det_str = "DETERMINISTIC" if macro.is_deterministic else "partial"
        print(f"{i+1:2d}. {decoded}")
        print(f"    Count: {macro.count}, Avg Entropy: {macro.avg_entropy:.4f} [{det_str}]")
        print(f"    Avg Reward: {macro.avg_reward:.2f}")
else:
    print("\nNo macros discovered yet.")
    print("This could mean:")
    print("  1. Model hasn't learned deterministic dynamics yet (train longer)")
    print("  2. Entropy threshold too low (try 1.0 or higher)")

print("\n" + "=" * 60)
print("KEY INSIGHT")
print("=" * 60)
print("""
Macros are now discovered based on MODEL ENTROPY, not just action frequency.

When the model predicts a transition with HIGH CONFIDENCE (low entropy),
chains of such transitions become macro candidates. In Breakout these are
intended to capture:

- Deterministic physics (ball trajectory after a paddle hit)
- Predictable game rules (FIRE launches the ball at the start of each life)
- Learned control patterns (paddle positioning sequences)

The "deterministic rate" is the fraction of transitions whose predicted
entropy falls below the threshold. As training progresses, it should
increase for deterministic games like Breakout.
""")


## 9. Train on Other Games (Optional)

In [None]:
# Uncomment to train on Pong

# pong = AtariGame("Pong")
# pong_model = MuZeroNetwork(pong.observation_dim, pong.action_space_size).to(DEVICE)
# pong_cache = MacroCache(min_occurrences=5)

# pong_history = train_muzero(
#     game=pong,
#     model=pong_model,
#     macro_cache=pong_cache,
#     device=DEVICE,
#     num_iterations=50,
#     episodes_per_iter=3,
# )

# print("\nPong Macros:")
# for macro in pong_cache.get_top_macros(10):
#     print(f"  {pong_cache.decode_macro(macro, pong)} (count={macro.count})")

In [None]:
# Uncomment to train on Space Invaders

# invaders = AtariGame("SpaceInvaders")
# invaders_model = MuZeroNetwork(invaders.observation_dim, invaders.action_space_size).to(DEVICE)
# invaders_cache = MacroCache(min_occurrences=5)

# invaders_history = train_muzero(
#     game=invaders,
#     model=invaders_model,
#     macro_cache=invaders_cache,
#     device=DEVICE,
#     num_iterations=50,
#     episodes_per_iter=3,
# )

# print("\nSpace Invaders Macros:")
# for macro in invaders_cache.get_top_macros(10):
#     print(f"  {invaders_cache.decode_macro(macro, invaders)} (count={macro.count})")

## 10. Summary

### Key Insights

1. **Deterministic Atari** games are well suited to macro discovery because:
   - Transitions are deterministic (no sticky actions, no random no-op starts), so low-entropy predictions are achievable in principle
   - Physics is predictable (ball bounces, gravity)
   - The agent controls all actions (no opponent uncertainty)

2. **Macros as Compressible Causal Structure**:
   - Frequent low-entropy action patterns represent reusable temporal abstractions
   - These are "rules" discovered from dynamics, not symbolic definitions
   - Longer training should reveal more meaningful patterns

3. **Next Steps**:
   - Use macros to accelerate MCTS planning
   - Learn macro preconditions (when to apply)
   - Transfer macros across similar games
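
For the first next step, one common way to use macros in planning is to treat each macro as an extended action. Each macro is then scored with the options-style (SMDP) backup Q(s, m) = sum_i gamma^i r_i + gamma^k V(s_k) for a k-step macro. This is an assumption here, since this notebook does not implement it. The sketch uses a hypothetical toy model and does a greedy one-step lookahead rather than full MCTS. In the notebook, `model` would correspond to the learned dynamics and `value` to the value head.

```python
from typing import Callable, Dict, Sequence, Tuple

State = int
Model = Callable[[State, int], Tuple[State, float]]


def macro_backup(model: Model, value: Callable[[State], float], state: State,
                 actions: Sequence[int], gamma: float = 0.99) -> float:
    """Discounted reward of running `actions` through the model, plus a bootstrap."""
    ret, discount = 0.0, 1.0
    for a in actions:
        state, r = model(state, a)
        ret += discount * r
        discount *= gamma
    return ret + discount * value(state)  # gamma^k * V(s_k)


def best_extended_action(model, value, state, n_primitive, macros, gamma=0.99):
    """Score primitives as length-1 macros alongside discovered macros."""
    candidates = [(a,) for a in range(n_primitive)] + [tuple(m) for m in macros]
    scores: Dict[Tuple[int, ...], float] = {
        c: macro_backup(model, value, state, c, gamma) for c in candidates
    }
    return max(scores, key=scores.get), scores


# Toy model (assumption): position on a line, reward 1.0 on reaching position 3.
def toy_model(s: State, a: int) -> Tuple[State, float]:
    nxt = s + (1 if a == 1 else -1 if a == 2 else 0)
    return nxt, 1.0 if nxt == 3 else 0.0


best, scores = best_extended_action(toy_model, lambda s: 0.0, state=0,
                                    n_primitive=3, macros=[(1, 1, 1), (2, 1)])
print(best)  # (1, 1, 1): only the 3-step macro reaches the reward in one decision
```

Learning macro preconditions would amount to restricting `candidates` to macros whose predicted entropy from the current state stays below the threshold.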