# Experiment: AlphaSnake 10x10 Vast.ai Paper-Fidelity

## Objetivo
Entrenar un agente Snake estilo AlphaZero (MCTS + Policy/Value Net) reproduciendo condiciones del paper para 10x10, y exportar ONNX para integración directa con `games/snake/ai.html`.

## Criterios de éxito
- Contrato de entorno paper: 10x10, rewards `+1/-1/0`, sin shaping.
- Champion update por `head-to-head MCTS > 55%`.
- Estrategia faseada en Vast.ai (warm-up -> paper) con resume por checkpoint.
- Export ONNX verificado para uso en navegador.

**Takeaway:** este notebook es el artefacto reproducible end-to-end para entrenar y desplegar AlphaSnake 10x10.


In [None]:

# ============================================================
# Cell 1: Setup — Vast.ai / Colab / local
# ============================================================
import os
import torch

# Working directory of the notebook; used as the root of the local fallback below.
ROOT_DIR = os.getcwd()
if os.path.exists('/workspace'):
    # Vast.ai (persistent)
    SAVE_DIR = '/workspace/alphasnake'
elif os.path.exists('/content'):
    # '/content' exists: presumably Google Colab, so try to mount Google Drive.
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        SAVE_DIR = '/content/drive/MyDrive/alphasnake'
    except ImportError:
        # google.colab is not importable: keep checkpoints under /content instead.
        SAVE_DIR = '/content/alphasnake'
else:
    # Neither marker directory exists: store checkpoints under the working directory.
    SAVE_DIR = os.path.join(ROOT_DIR, 'alphasnake')

# Create the checkpoint directory (no error if it already exists) and report the setup.
os.makedirs(SAVE_DIR, exist_ok=True)
print(f'PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()} | Checkpoints: {SAVE_DIR}')


## Setup completado
Se detecta entorno (`/workspace`, Colab o local) y se define `SAVE_DIR` persistente para checkpoints.

**Takeaway:** todos los artefactos de entrenamiento quedan centralizados y resumibles.


In [None]:

# ============================================================
# Cell 2: Imports y Configuracion (paper-fidelity + faseado)
# ============================================================
import numpy as np
import random
import math
import time
import pickle
import subprocess
import threading
import multiprocessing
import dataclasses
from dataclasses import dataclass
from collections import deque
from queue import Empty
from typing import Optional

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------------------------------------
# Run profile (Vast.ai RTX 4060 Ti + 4.8 vCPU)
# -----------------------------------------------------------
# Name of the preset passed to build_config; the accepted values are listed inline.
PRESET = 'vastai_phased'  # 'paper_only' | 'vastai_phased' | 'smoke'

@dataclass
class Config:
    """Hyperparameters and infrastructure settings for one training run.

    The fields ``num_simulations``, ``games_per_iter``, ``eval_games`` and
    ``food_samples`` are the *effective* per-iteration values:
    ``get_iteration_config`` overwrites them with the matching ``warmup_*``
    or ``paper_*`` fields depending on the iteration number.
    """
    # Environment
    board_size: int = 10
    max_steps: int = 1000

    # Network
    net_channels: int = 64
    net_blocks: int = 6

    # MCTS (paper)
    num_simulations: int = 400
    c_puct: float = 1.0
    dir_alpha: float = 0.03
    dir_eps: float = 0.25
    temp_init: float = 1.0
    temp_final: float = 0.0
    temp_decay_move: int = 30
    food_samples: int = 8

    # Training
    lr: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 128
    buffer_size: int = 200_000
    epochs_per_iter: int = 10

    # Self-play / effective evaluation budget per iteration
    games_per_iter: int = 1000
    eval_games: int = 200

    # Iterations (iterations 1..warmup_iterations use the warm-up phase values)
    max_iterations: int = 37
    warmup_iterations: int = 12

    # Phase A (warm-up)
    warmup_num_simulations: int = 220
    warmup_games_per_iter: int = 360
    warmup_eval_games: int = 80
    warmup_food_samples: int = 6

    # Phase B (paper)
    paper_num_simulations: int = 400
    paper_games_per_iter: int = 1000
    paper_eval_games: int = 200
    paper_food_samples: int = 8

    # Champion / evaluation
    accept_threshold: float = 0.55  # new model must win >55% head-to-head
    target_win_rate: float = 0.94
    target_win_rate_patience: int = 3

    # Infrastructure
    save_dir: str = SAVE_DIR
    checkpoint_interval: int = 1
    selfplay_workers: int = 4
    selfplay_backend: str = 'process'  # 'process' is recommended to avoid the GIL
    inference_batch_size: int = 128
    inference_timeout_ms: int = 8
    use_amp: bool = True

    # Logging
    verbose_game_updates: bool = False
    game_log_interval: int = 1
    game_tick_moves: int = 100
    heartbeat_seconds: int = 30


def build_config(preset: str) -> Config:
    """Return the ``Config`` for a named preset.

    ``'paper_only'`` disables the warm-up phase and uses the paper budget in
    both phases; ``'smoke'`` is a tiny two-iteration run. Any other value
    (including ``'vastai_phased'``, the recommended profile) yields the
    dataclass defaults.
    """
    preset_overrides = {
        'paper_only': dict(
            max_iterations=37,
            warmup_iterations=0,
            warmup_num_simulations=400,
            warmup_games_per_iter=1000,
            warmup_eval_games=200,
            warmup_food_samples=8,
            paper_num_simulations=400,
            paper_games_per_iter=1000,
            paper_eval_games=200,
            paper_food_samples=8,
        ),
        'smoke': dict(
            max_iterations=2,
            warmup_iterations=2,
            warmup_num_simulations=50,
            warmup_games_per_iter=20,
            warmup_eval_games=10,
            warmup_food_samples=2,
            paper_num_simulations=50,
            paper_games_per_iter=20,
            paper_eval_games=10,
            paper_food_samples=2,
            buffer_size=25_000,
            inference_batch_size=32,
        ),
    }
    # Presets without an entry fall back to the plain defaults (vastai_phased).
    return Config(**preset_overrides.get(preset, {}))


# Global run configuration resolved from the selected preset.
config = build_config(PRESET)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    # CUDA-only performance switches: cuDNN benchmark mode and TF32 matmuls/convolutions.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')


def get_iteration_config(cfg: Config, iteration: int) -> Config:
    """Return a copy of ``cfg`` with the effective budget for ``iteration``.

    ``iteration`` is 1-based. Iterations up to and including
    ``cfg.warmup_iterations`` take the ``warmup_*`` values; later iterations
    take the ``paper_*`` values. ``cfg`` itself is not modified.
    """
    phase = 'warmup' if iteration <= cfg.warmup_iterations else 'paper'
    return dataclasses.replace(
        cfg,
        num_simulations=getattr(cfg, f'{phase}_num_simulations'),
        games_per_iter=getattr(cfg, f'{phase}_games_per_iter'),
        eval_games=getattr(cfg, f'{phase}_eval_games'),
        food_samples=getattr(cfg, f'{phase}_food_samples'),
    )

# Report the device, the preset and the effective configuration of the first
# and last iterations so the phase switch can be checked at a glance.
print(f'Device: {device}')
print(f'Preset: {PRESET}')
print('Config base:')
print(dataclasses.asdict(config))
print('Iter 1 cfg:', dataclasses.asdict(get_iteration_config(config, 1)))
print('Iter final cfg:', dataclasses.asdict(get_iteration_config(config, config.max_iterations)))


## Entorno Snake 10x10
Contrato exacto: 4 acciones discretas, reversa prohibida, observación `(4,10,10)` y rewards `+1/-1/0`.

**Takeaway:** mantener este contrato es clave para reproducir comportamiento tipo paper.


In [None]:
# ============================================================
# Cell 3: Entorno Snake
# ============================================================

class SnakeEnv:
    """
    AlphaSnake-style Snake environment.

    The board is ``board_size x board_size``, there are 4 discrete actions and
    the observation has shape ``(4, H, W)``. Cells are ``(x, y)`` tuples and
    the observation is indexed ``[channel, y, x]``.
    """
    # 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT; values are (dx, dy) deltas
    ACTIONS = {0: (0, -1), 1: (0, 1), 2: (-1, 0), 3: (1, 0)}
    OPPOSITES = {0: 1, 1: 0, 2: 3, 3: 2}
    ACTION_NAMES = {0: 'UP', 1: 'DOWN', 2: 'LEFT', 3: 'RIGHT'}

    def __init__(self, board_size=10, max_steps=1000):
        """Create the environment and immediately reset it to the start position."""
        self.board_size = board_size
        self.max_steps = max_steps
        self.max_score = board_size * board_size - 3  # 97 for 10x10 (start length is 3)
        self.reset()

    def reset(self):
        """Restart the game and return the initial observation."""
        mid = self.board_size // 2
        # The snake starts in the centre, heading right (head is snake[0]).
        self.snake = deque([(mid, mid), (mid - 1, mid), (mid - 2, mid)])
        self.direction = 3  # RIGHT
        self.food = None
        self.done = False
        self.steps = 0
        self.score = 0
        self._place_food()
        return self.get_state()

    def _get_free_cells(self):
        """Return every cell not occupied by the snake, in row-major order."""
        snake_set = set(self.snake)
        return [(x, y) for y in range(self.board_size)
                for x in range(self.board_size)
                if (x, y) not in snake_set]

    def _place_food(self):
        """Put food on a random free cell; a full board ends the game."""
        free = self._get_free_cells()
        if free:
            self.food = random.choice(free)
        else:
            # Board completely filled = win
            self.food = None
            self.done = True

    def get_state(self):
        """Return the observation as a ``(4, board_size, board_size)`` float32 array."""
        s = np.zeros((4, self.board_size, self.board_size), dtype=np.float32)
        # Channel 0: body (includes the head cell)
        for x, y in self.snake:
            s[0, y, x] = 1.0
        # Channel 1: head
        hx, hy = self.snake[0]
        s[1, hy, hx] = 1.0
        # Channel 2: food (absent once the board is full)
        if self.food is not None:
            fx, fy = self.food
            s[2, fy, fx] = 1.0
        # Channel 3: direction as a constant plane
        # UP=0.25, DOWN=0.5, LEFT=0.75, RIGHT=1.0
        s[3, :, :] = (self.direction + 1) / 4.0
        return s

    def valid_actions(self):
        """Return the actions that are not a direct reversal of the current direction."""
        rev = self.OPPOSITES[self.direction]
        return [a for a in range(4) if a != rev]

    def step(self, action):
        """
        Apply ``action`` and return ``(reward, done)``.

        reward: +1.0 when food is eaten, -1.0 on death (wall or body), 0.0
        otherwise. Stepping a finished game is a no-op returning ``(0.0, True)``.
        Reaching ``max_steps`` ends the game; the reward of that last move is
        still reported (+1.0 if it ate food).
        """
        if self.done:
            return 0.0, True

        # Direct reversal is forbidden: keep moving in the current direction.
        if action == self.OPPOSITES[self.direction]:
            action = self.direction

        self.direction = action
        dx, dy = self.ACTIONS[action]
        hx, hy = self.snake[0]
        target = (hx + dx, hy + dy)

        # Wall collision
        if not (0 <= target[0] < self.board_size and 0 <= target[1] < self.board_size):
            self.done = True
            return -1.0, True

        # Decide whether this move eats the food (before moving).
        ate_food = self.food is not None and target == self.food

        # Body collision. When no food is eaten the tail moves away this turn,
        # so stepping onto the current tail cell is legal.
        if target in self.snake and (ate_food or target != self.snake[-1]):
            self.done = True
            return -1.0, True

        # Move
        self.snake.appendleft(target)
        if ate_food:
            self.score += 1
            self._place_food()
            if self.done:  # _place_food set done=True because the board is full
                return 1.0, True
        else:
            self.snake.pop()

        reward = 1.0 if ate_food else 0.0
        self.steps += 1
        if self.steps >= self.max_steps:
            # Step limit reached: end the game but keep the reward earned by this move.
            self.done = True
            return reward, True

        return reward, False

    def is_win(self):
        """Return True when the snake fills the whole board."""
        return len(self.snake) >= self.board_size ** 2

    def clone(self):
        """Return an independent copy without re-running ``__init__``/``reset``."""
        env = SnakeEnv.__new__(SnakeEnv)
        env.board_size = self.board_size
        env.max_steps = self.max_steps
        env.max_score = self.max_score
        env.snake = deque(self.snake)
        env.direction = self.direction
        env.food = self.food
        env.done = self.done
        env.steps = self.steps
        env.score = self.score
        return env

# --- Quick smoke test ---
env = SnakeEnv(10, 1000)
s = env.reset()
print(f"State shape: {s.shape}")
print(f"Snake length: {len(env.snake)}, food: {env.food}")
print(f"Valid actions: {[env.ACTION_NAMES[a] for a in env.valid_actions()]}")

# Play a few random steps
for _ in range(20):
    a = random.choice(env.valid_actions())
    r, done = env.step(a)
    if done:
        print(f"Game over! Score: {env.score}, Win: {env.is_win()}")
        break
else:
    # for/else: only reached when all 20 steps ran without hitting `break`
    print(f"After 20 steps: score={env.score}, length={len(env.snake)}")


## Red Policy/Value
ResNet de 6 bloques (64 canales) con cabezas policy (4 acciones) y value `tanh`.

**Takeaway:** arquitectura compacta y estable para 10x10, consistente con configuración paper reducida.


In [None]:
# ============================================================
# Cell 4: Red Neuronal — AlphaSnakeNet
# ============================================================

class ResBlock(nn.Module):
    """Residual block: Conv -> BN -> ReLU -> Conv -> BN -> add skip -> ReLU."""

    def __init__(self, channels):
        super().__init__()
        # Both convolutions are 3x3, padded to keep the spatial size, and bias-free.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Apply the two conv/BN stages and add the input back before the final ReLU."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        return F.relu(hidden + x)


class AlphaSnakeNet(nn.Module):
    """
    AlphaZero-style policy + value network.

    Input:  (batch, in_channels, board_size, board_size)
    Output: policy (batch, 4) as softmax probabilities, value (batch, 1) in [-1, 1]
    """

    def __init__(self, in_channels=4, num_blocks=6, channels=64, board_size=10):
        super().__init__()
        self.board_size = board_size
        cells = board_size * board_size

        # Stem: lift the input planes to `channels` feature maps.
        stem_layers = [
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        ]
        self.stem = nn.Sequential(*stem_layers)

        # Residual tower
        tower = []
        for _ in range(num_blocks):
            tower.append(ResBlock(channels))
        self.res_tower = nn.Sequential(*tower)

        # Policy head: 1x1 conv to 2 planes, then a linear layer to the 4 actions.
        self.policy_conv = nn.Conv2d(channels, 2, 1, bias=False)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(2 * cells, 4)

        # Value head: 1x1 conv to 1 plane, then a 64-unit hidden layer and a scalar.
        self.value_conv = nn.Conv2d(channels, 1, 1, bias=False)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(cells, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        """Return ``(policy, value)`` for a batch of observations."""
        # Shared trunk
        features = self.res_tower(self.stem(x))

        # Policy branch
        policy_maps = F.relu(self.policy_bn(self.policy_conv(features)))
        policy_flat = policy_maps.view(policy_maps.size(0), -1)
        policy = F.softmax(self.policy_fc(policy_flat), dim=1)

        # Value branch
        value_maps = F.relu(self.value_bn(self.value_conv(features)))
        value_flat = value_maps.view(value_maps.size(0), -1)
        value_hidden = F.relu(self.value_fc1(value_flat))
        value = torch.tanh(self.value_fc2(value_hidden))

        return policy, value

# --- Smoke test: build the network, count parameters, run one forward pass ---
net = AlphaSnakeNet(
    in_channels=4,
    num_blocks=config.net_blocks,
    channels=config.net_channels,
    board_size=config.board_size
).to(device)

total_params = sum(p.numel() for p in net.parameters())
print(f"AlphaSnakeNet: {total_params:,} parametros")

# Single random observation, only to check the output shapes.
dummy = torch.randn(1, 4, config.board_size, config.board_size).to(device)
pol, val = net(dummy)
print(f"Policy shape: {pol.shape}, sum={pol.sum().item():.4f}")
print(f"Value shape: {val.shape}, val={val.item():.4f}")
del net, dummy  # free memory


## MCTS con estocasticidad de comida
UCB clásico AlphaZero, Dirichlet en raíz y `food_samples` para robustez ante spawn aleatorio.

**Takeaway:** esta parte determina gran parte del rendimiento final (objetivo: win rate >90%).


In [None]:
# ============================================================
# Cell 5: MCTS — Monte Carlo Tree Search
# ============================================================

class MCTSNode:
    """A node of the MCTS tree."""
    __slots__ = [
        'prior', 'visit_count', 'value_sum', 'children',
        'is_expanded', 'env', 'food_eaten',
    ]

    def __init__(self, prior=0.0):
        self.prior = prior
        # Search statistics
        self.visit_count = 0
        self.value_sum = 0.0
        # Tree structure: action (int) -> MCTSNode
        self.children = {}
        self.is_expanded = False
        # Environment snapshot for this node; None until it is assigned
        self.env = None
        # True when the transition into this node ate food
        self.food_eaten = False

    def value(self):
        """Return the mean backed-up value, or 0.0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count != 0 else 0.0


class MCTS:
    """
    AlphaZero-style Monte Carlo Tree Search with support for food
    stochasticity (food sampling).

    ``predict_fn`` maps one state to ``(policy, value)``; ``predict_batch_fn``
    maps a list of states to a list of ``(policy, value)`` so callers can plug
    in batched inference. When no batch function is given, states are
    evaluated one by one with ``predict_fn``.
    """

    def __init__(self, predict_fn, cfg, predict_batch_fn=None):
        def _one_by_one(states):
            return [predict_fn(s) for s in states]

        self.predict_fn = predict_fn
        self.predict_batch_fn = predict_batch_fn if predict_batch_fn else _one_by_one
        self.c_puct = cfg.c_puct
        self.num_simulations = cfg.num_simulations
        self.dir_alpha = cfg.dir_alpha
        self.dir_eps = cfg.dir_eps
        self.food_samples = cfg.food_samples

    def _predict(self, state_np):
        """Evaluate a single state; returns ``(policy, value)``."""
        return self.predict_fn(state_np)

    def _ucb_score(self, parent, child):
        """Classic AlphaZero UCB score: mean value plus prior-weighted exploration."""
        exploration = (
            self.c_puct * child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
        )
        return child.value() + exploration

    def _select_child(self, node):
        """Return ``(action, child)`` with the highest UCB score (first one wins ties)."""
        chosen_action, chosen_child = -1, None
        top_score = -float('inf')
        for action, child in node.children.items():
            candidate = self._ucb_score(node, child)
            if candidate > top_score:
                top_score, chosen_action, chosen_child = candidate, action, child
        return chosen_action, chosen_child

    def _average_over_food(self, env, value):
        """Average ``value`` with the values of states that differ only in food position."""
        occupied = set(env.snake)
        candidates = [
            (x, y)
            for y in range(env.board_size)
            for x in range(env.board_size)
            if (x, y) not in occupied and (x, y) != env.food
        ]
        n_extra = min(self.food_samples - 1, len(candidates))
        if n_extra <= 0:
            return value

        alt_states = []
        for food_pos in random.sample(candidates, n_extra):
            alt_env = env.clone()
            alt_env.food = food_pos
            alt_states.append(alt_env.get_state())

        extra_values = [result[1] for result in self.predict_batch_fn(alt_states)]
        return (value + sum(extra_values)) / (1 + len(extra_values))

    def _expand(self, node):
        """
        Expand a leaf: evaluate it, create its children, return its value.

        A finished game is not expanded and returns +1.0 for a win, -1.0
        otherwise. When the move into this node ate food, the value is
        averaged over extra sampled food positions.
        """
        env = node.env
        if env.done:
            return 1.0 if env.is_win() else -1.0

        policy, value = self._predict(env.get_state())

        # Food stochasticity (batched through predict_batch_fn)
        if node.food_eaten and self.food_samples > 1:
            value = self._average_over_food(env, value)

        # Mask out invalid actions and renormalise the priors
        legal = env.valid_actions()
        mask = np.zeros(4, dtype=np.float32)
        for action in legal:
            mask[action] = 1.0
        priors = policy * mask
        mass = priors.sum()
        # Fall back to a uniform prior over legal actions if all mass was masked.
        priors = priors / mass if mass > 0 else mask / mask.sum()

        # Create the children
        node.children.update({action: MCTSNode(prior=priors[action]) for action in legal})
        node.is_expanded = True
        return value

    def _add_dirichlet_noise(self, node):
        """Mix Dirichlet noise into the children's priors (root exploration)."""
        children = list(node.children.values())
        if not children:
            return
        noise = np.random.dirichlet([self.dir_alpha] * len(children))
        for child, eta in zip(children, noise):
            child.prior = (1 - self.dir_eps) * child.prior + self.dir_eps * eta

    def _simulate(self, root):
        """Run one select / expand-evaluate / backup pass starting at ``root``."""
        path = [root]
        node = root

        # --- SELECT ---
        while node.is_expanded and node.children:
            action, child = self._select_child(node)

            # The child's environment is computed lazily on first visit.
            if child.env is None:
                next_env = node.env.clone()
                score_before = next_env.score
                next_env.step(action)
                child.env = next_env
                child.food_eaten = (next_env.score > score_before and not next_env.done)

            node = child
            path.append(node)

        # --- EXPAND & EVALUATE ---
        if node.is_expanded:
            # Terminal node (expanded but without children)
            value = 1.0 if node.env.is_win() else -1.0
        else:
            value = self._expand(node)

        # --- BACKUP --- (same value for every node on the path: no sign flip)
        for visited in path:
            visited.visit_count += 1
            visited.value_sum += value

    def _action_distribution(self, root, temperature):
        """Turn the root's child visit counts into a probability vector of shape (4,)."""
        visits = np.zeros(4, dtype=np.float32)
        for action, child in root.children.items():
            visits[action] = child.visit_count

        if temperature == 0:
            # Greedy: all mass on the most visited action
            greedy = np.zeros(4, dtype=np.float32)
            greedy[np.argmax(visits)] = 1.0
            return greedy

        # Temperature-scaled visit counts
        scaled = np.power(visits, 1.0 / temperature)
        norm = scaled.sum()
        if norm > 0:
            return scaled / norm
        return np.ones(4, dtype=np.float32) / 4.0

    def search(self, root_env, temperature=1.0):
        """
        Run a full MCTS from ``root_env`` (which is cloned, not modified).

        Returns a probability distribution over the 4 actions.
        """
        root = MCTSNode()
        root.env = root_env.clone()

        # Expand the root and add exploration noise
        self._expand(root)
        self._add_dirichlet_noise(root)

        for _ in range(self.num_simulations):
            self._simulate(root)

        return self._action_distribution(root, temperature)

# Cell banner: printed once the MCTS definitions of this cell have been executed.
print("MCTS implementado correctamente.")


## Batching de inferencia para self-play paralelo
Se agrupan requests de múltiples workers para aprovechar mejor GPU.

**Takeaway:** batching mejora throughput en RTX 4060 Ti y reduce cuello de botella en inferencia.


In [None]:
# ============================================================
# Cell 5b: InferenceBatcher — Batching de GPU para self-play paralelo
# ============================================================
import queue as queue_module
import concurrent.futures

class InferenceBatcher:
    """
    Collect inference requests from multiple threads and run them on the
    device in batches, to maximise device utilisation during parallel
    self-play.

    Requests are ``(state, Future)`` pairs pushed on an internal queue; a
    background thread groups up to ``max_batch`` of them (waiting at most
    ``timeout_ms`` after the first one) and resolves each future with
    ``(policy_row, value_float)``. The batcher does not change the network's
    train/eval mode.
    """
    def __init__(self, net, device, max_batch=64, timeout_ms=3):
        self.net = net
        self.device = device
        self.max_batch = max_batch
        self.timeout = timeout_ms / 1000.0  # seconds
        self._queue = queue_module.Queue()
        self._running = False
        self._thread = None
        self._total_calls = 0
        self._total_batches = 0

    def start(self):
        """Reset the counters, launch the worker thread and return ``self``."""
        self._running = True
        self._total_calls = 0
        self._total_batches = 0
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()
        return self

    def stop(self):
        """
        Stop the worker thread (waiting up to 10 s) and flush the queue.

        Requests that were still queued are failed with ``RuntimeError`` so
        their callers do not block forever, and any leftover stop sentinel is
        removed so a later ``start()`` does not exit immediately.
        """
        self._running = False
        self._queue.put(None)  # sentinel
        if self._thread:
            self._thread.join(timeout=10)
        self._drain_pending()

    def _drain_pending(self):
        """Empty the queue, failing every request that was never served."""
        while True:
            try:
                item = self._queue.get_nowait()
            except queue_module.Empty:
                return
            if item is None:
                # Stale sentinel: the worker exited through its loop condition.
                continue
            future = item[1]
            if not future.done():
                future.set_exception(
                    RuntimeError("InferenceBatcher stopped before the request was served")
                )

    def submit(self, state_np):
        """Non-blocking: enqueue a state and return a Future for its result."""
        f = concurrent.futures.Future()
        self._queue.put((state_np, f))
        return f

    def predict(self, state_np):
        """Blocking: enqueue a state and wait for its ``(policy, value)``."""
        return self.submit(state_np).result()

    def predict_batch(self, states_list):
        """Enqueue several states and wait for all results, in input order."""
        if not states_list:
            return []
        futures = [self.submit(s) for s in states_list]
        return [f.result() for f in futures]

    def stats(self):
        """Return ``(total_calls, total_batches, average_batch_size)`` since ``start()``."""
        avg = self._total_calls / max(self._total_batches, 1)
        return self._total_calls, self._total_batches, avg

    def _worker(self):
        """Worker loop: gather a batch, run the network, resolve the futures."""
        while self._running:
            batch = []
            # Wait for the first request; wake up regularly to re-check _running.
            try:
                item = self._queue.get(timeout=1.0)
            except queue_module.Empty:
                continue
            if item is None:
                break
            batch.append(item)

            # Collect more requests (until max_batch or the deadline).
            # A monotonic clock keeps the deadline immune to wall-clock jumps.
            deadline = time.monotonic() + self.timeout
            while len(batch) < self.max_batch:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    item = self._queue.get(timeout=max(0.0001, remaining))
                except queue_module.Empty:
                    break
                if item is None:
                    # Stop requested: serve what was already collected, then exit.
                    self._running = False
                    break
                batch.append(item)

            if not batch:
                continue

            states = [b[0] for b in batch]
            futures_list = [b[1] for b in batch]
            self._total_calls += len(batch)
            self._total_batches += 1

            try:
                arr = np.stack(states)
                t = torch.as_tensor(arr, device=self.device)
                with torch.no_grad():
                    policies, values = self.net(t)
                p_np = policies.cpu().numpy()
                v_np = values.cpu().numpy()
                for i, f in enumerate(futures_list):
                    f.set_result((p_np[i], float(v_np[i, 0])))
            except Exception as e:
                # Propagate the failure to every caller still waiting on this batch.
                for f in futures_list:
                    if not f.done():
                        f.set_exception(e)


def make_predict_fn(net, device):
    """
    Build a direct prediction function (single-thread use, e.g. evaluation).

    The returned callable takes one state array, runs ``net`` on a batch of
    one with gradients disabled, and returns ``(policy_np, value_float)``.
    """
    def predict(state_np):
        with torch.no_grad():
            # Add the leading batch dimension expected by the network.
            batch = torch.as_tensor(state_np, device=device)[None, ...]
            policy, value = net(batch)
            return policy[0].cpu().numpy(), value.item()

    return predict

# Cell banner: printed once the inference helpers of this cell have been defined.
print("InferenceBatcher implementado.")


## Self-play y Replay Buffer
Se generan ejemplos `(state, pi, z)` por partida y se acumulan en buffer circular.

**Takeaway:** datos frescos + buffer suficiente son críticos para estabilidad de entrenamiento.


In [None]:
# ============================================================
# Cell 6: Self-Play
# ============================================================

def self_play_game(predict_fn, cfg, predict_batch_fn=None, progress_cb=None):
    """
    Play one complete game with MCTS and collect training examples.

    predict_fn: callable(state_np) -> (policy_np, value_float)
    predict_batch_fn: callable(list[state_np]) -> list[(policy, value)]
    progress_cb: optional callable(move_count, score, length) for telemetry;
        called after the first move and then every
        ``max(25, cfg.game_tick_moves)`` moves.

    Returns: (examples, is_win, score, num_moves, snake_length), where each
    example is ``(state, pi, z)`` and ``z`` is +1.0 for a win (board filled)
    and -1.0 otherwise.
    """
    env = SnakeEnv(cfg.board_size, cfg.max_steps)
    searcher = MCTS(predict_fn, cfg, predict_batch_fn)
    # Kept from the original flow: the constructor already resets, so this
    # second reset only re-draws the initial food position.
    env.reset()

    history = []  # (state, search policy) per move
    n_moves = 0

    while not env.done:
        # Temperature: temp_init for the first temp_decay_move moves, then temp_final
        if n_moves < cfg.temp_decay_move:
            temperature = cfg.temp_init
        else:
            temperature = cfg.temp_final

        # MCTS search from the current position
        visit_probs = searcher.search(env, temperature=temperature)

        # Record training data for this position
        history.append((env.get_state().copy(), visit_probs.copy()))

        # Pick the action: greedy at temperature 0, sampled otherwise
        if temperature == 0:
            chosen = int(np.argmax(visit_probs))
        else:
            chosen = int(np.random.choice(4, p=visit_probs))

        env.step(chosen)
        n_moves += 1

        if progress_cb is None:
            continue
        tick_every = max(25, int(getattr(cfg, "game_tick_moves", 100)))
        if n_moves == 1 or n_moves % tick_every == 0:
            progress_cb(n_moves, env.score, len(env.snake))

    # Outcome: +1 when the board was filled, -1 otherwise
    won = env.is_win()
    outcome = 1.0 if won else -1.0

    examples = [(state, probs, outcome) for state, probs in history]
    return examples, won, env.score, n_moves, len(env.snake)


def run_n_games(n, request_queue, response_queue, progress_queue, worker_id, cfg):
    """
    Play ``n`` self-play games inside a worker process.

    Inference is delegated through queues: each state is put on
    ``request_queue`` as ``(worker_id, state)`` and the ``(policy, value)``
    answer is read from ``response_queue``. When ``progress_queue`` is not
    None, progress events (dicts with a ``"type"`` key) are pushed to it.

    Returns ``(all_examples, wins, scores, moves, lengths)`` with one entry
    per game in the three lists.
    """
    def predict(state_np):
        request_queue.put((worker_id, state_np))
        return response_queue.get()

    def predict_batch(states_list):
        return [predict(state) for state in states_list]

    def emit(event_type, **fields):
        # All telemetry goes through here; a missing queue disables it.
        if progress_queue is not None:
            progress_queue.put({"type": event_type, "worker": worker_id, **fields})

    # Per-move ticks are only sent when explicitly enabled in the config.
    ticks_enabled = progress_queue is not None and bool(getattr(cfg, "verbose_game_updates", False))

    all_examples = []
    scores_list, moves_list, lengths_list = [], [], []
    wins = 0

    emit("worker_start", games=n)

    for game_number in range(1, n + 1):
        emit("game_start", game_local=game_number)

        tick_cb = None
        if ticks_enabled:
            def tick_cb(move_count, score_now, len_now):
                emit(
                    "game_tick",
                    game_local=game_number,
                    moves=move_count,
                    score=score_now,
                    length=len_now,
                )

        examples, won, score, moves, length = self_play_game(
            predict, cfg, predict_batch, progress_cb=tick_cb
        )

        all_examples.extend(examples)
        if won:
            wins += 1
        scores_list.append(score)
        moves_list.append(moves)
        lengths_list.append(length)

        emit(
            "game_end",
            game_local=game_number,
            won=won,
            score=score,
            moves=moves,
            length=length,
            examples=len(examples),
        )

    return all_examples, wins, scores_list, moves_list, lengths_list


# Cell banner: printed once the self-play functions of this cell have been defined.
print("Self-play implementado.")


In [None]:
# ============================================================
# Cell 7: Replay Buffer y Training
# ============================================================

class ReplayBuffer:
    """Circular buffer of self-play examples; the oldest entries drop out first."""
    def __init__(self, max_size):
        self.buffer = deque(maxlen=max_size)

    def add(self, examples):
        """Append a list of ``(state, policy, value)`` examples to the buffer."""
        self.buffer.extend(examples)

    def sample(self, batch_size):
        """
        Draw a random batch without replacement.

        Returns ``(states, policies, values)`` as arrays; the batch holds
        ``min(batch_size, len(buffer))`` examples and ``values`` is float32.
        """
        n = min(batch_size, len(self.buffer))
        indices = np.random.choice(len(self.buffer), size=n, replace=False)
        batch = [self.buffer[i] for i in indices]
        states = np.array([b[0] for b in batch])
        policies = np.array([b[1] for b in batch])
        values = np.array([b[2] for b in batch], dtype=np.float32)
        return states, policies, values

    def __len__(self):
        return len(self.buffer)

    def save(self, path):
        """
        Pickle the buffer contents to ``path`` atomically.

        The data is written to ``<path>.tmp`` first and then moved over
        ``path``, so a failed or interrupted save never destroys the previous
        checkpoint. The temporary file is removed on failure.
        """
        tmp_path = f'{path}.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(list(self.buffer), f)
            os.replace(tmp_path, path)
        finally:
            # After a successful replace the temp file no longer exists.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(self, path):
        """
        Replace the buffer contents with the ones stored at ``path``.

        Returns True when the file existed and was loaded, False otherwise.
        The configured maximum size is kept.
        """
        if os.path.exists(path):
            # SECURITY: pickle.load executes arbitrary code from the file.
            # Only load checkpoints written by this notebook.
            with open(path, 'rb') as f:
                data = pickle.load(f)
            self.buffer = deque(data, maxlen=self.buffer.maxlen)
            return True
        return False


def train_epoch(net, optimizer, buffer, cfg, device, scaler=None):
    """
    Train for one epoch on random batches drawn from the replay buffer.

    Loss = (z - v)^2 - pi * log(p)
    (weight decay is already applied by the optimizer as L2)

    Args:
        net: Model returning ``(policy_probs, value)`` for a batch of states.
        optimizer: Optimizer over ``net``'s parameters.
        buffer: Replay buffer exposing ``__len__`` and ``sample(batch_size)``.
        cfg: Config providing ``batch_size`` and, optionally, ``use_amp``.
        device: Target ``torch.device``.
        scaler: Optional gradient scaler, used only when AMP is active.

    Returns:
        ``(mean_loss, mean_policy_loss, mean_value_loss)`` over the epoch,
        or ``(0, 0, 0)`` when the buffer is empty.
    """
    net.train()

    if len(buffer) == 0:
        # Nothing to train on: skip instead of feeding a zero-row batch to
        # the network (the batch count below is never less than one).
        return 0, 0, 0

    total_loss = 0.0
    total_p_loss = 0.0
    total_v_loss = 0.0

    num_batches = max(1, len(buffer) // cfg.batch_size)
    on_cuda = device.type == "cuda"
    use_amp = bool(getattr(cfg, "use_amp", True) and on_cuda)

    for _ in range(num_batches):
        states, policies, values = buffer.sample(cfg.batch_size)

        states_t = torch.from_numpy(states)
        policies_t = torch.from_numpy(policies)
        values_t = torch.from_numpy(values).unsqueeze(1)

        if on_cuda:
            # Pinned host memory lets the host-to-device copies be issued
            # with non_blocking=True.
            states_t = states_t.pin_memory().to(device, non_blocking=True)
            policies_t = policies_t.pin_memory().to(device, non_blocking=True)
            values_t = values_t.pin_memory().to(device, non_blocking=True)
        else:
            states_t = states_t.to(device)
            policies_t = policies_t.to(device)
            values_t = values_t.to(device)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            pred_p, pred_v = net(states_t)

            # Value loss: MSE
            v_loss = F.mse_loss(pred_v, values_t)

            # Policy loss: cross-entropy (pi * log(p)).
            # Cast to float32 before adding the epsilon: 1e-8 is below the
            # smallest positive float16 value, so it would vanish if the
            # network ever returned half-precision probabilities.
            log_p = torch.log(pred_p.float() + 1e-8)
            p_loss = -torch.mean(torch.sum(policies_t * log_p, dim=1))

            loss = v_loss + p_loss

        optimizer.zero_grad(set_to_none=True)
        if use_amp and scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_p_loss += p_loss.item()
        total_v_loss += v_loss.item()

    return total_loss / num_batches, total_p_loss / num_batches, total_v_loss / num_batches

print("ReplayBuffer y train_epoch implementados.")


## Evaluación y Champion Gate
Evaluación MCTS sin ruido para métrica real y comparación `new vs best` por head-to-head pareado por seed.

**Takeaway:** solo se acepta champion cuando supera umbral `>55%` en confrontación directa.


In [None]:

# ============================================================
# Cell 8: Evaluacion (MCTS, policy y head-to-head champion gate)
# ============================================================

def evaluate_mcts(net, cfg, device, n_games=None):
    # Play complete games choosing the argmax of the MCTS policy
    # (temperature 0, dir_eps forced to 0.0) and return
    # (win_rate, average_score, per_game_scores).
    games = cfg.eval_games if n_games is None else n_games
    net.eval()

    noiseless_cfg = dataclasses.replace(cfg, dir_eps=0.0)
    scores = []
    wins = 0

    for _ in tqdm(range(games), desc='Eval MCTS', leave=False):
        env = SnakeEnv(noiseless_cfg.board_size, noiseless_cfg.max_steps)
        # A fresh search object is built for every game.
        searcher = MCTS(make_predict_fn(net, device), noiseless_cfg)
        env.reset()

        while not env.done:
            visit_policy = searcher.search(env, temperature=0)
            env.step(int(np.argmax(visit_policy)))

        if env.is_win():
            wins += 1
        scores.append(env.score)

    # max(..., 1) avoids a division by zero when no games were requested.
    denom = max(games, 1)
    return wins / denom, sum(scores) / denom, scores


def evaluate_policy_only(net, cfg, device, n_games=100):
    # Fast evaluation without MCTS (auxiliary monitoring only): each move is
    # the argmax of the network policy restricted to the valid actions.
    # Returns (win_rate, average_score).
    net.eval()
    games_won = 0
    score_sum = 0

    for _ in tqdm(range(n_games), desc='Eval Policy', leave=False):
        env = SnakeEnv(cfg.board_size, cfg.max_steps)
        env.reset()

        while not env.done:
            obs = torch.from_numpy(env.get_state()).unsqueeze(0).to(device)
            with torch.no_grad():
                probs, _ = net(obs)

            # 1.0 for every valid action, 0.0 elsewhere.
            legal = torch.zeros(4, device=device)
            for act in env.valid_actions():
                legal[act] = 1.0

            masked = probs.squeeze() * legal
            mass = masked.sum()
            if mass > 0:
                masked = masked / mass
            else:
                # No probability left on valid moves: fall back to uniform
                # over the valid ones.
                masked = legal / legal.sum()

            env.step(int(torch.argmax(masked).item()))

        if env.is_win():
            games_won += 1
        score_sum += env.score

    denom = max(n_games, 1)
    return games_won / denom, score_sum / denom


def _play_single_game_mcts(net, cfg, device, seed):
    # Play one MCTS game after seeding the Python, NumPy and PyTorch RNGs
    # with the same value, so paired evaluations start from the same random
    # stream. Returns (won, score).
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    env = SnakeEnv(cfg.board_size, cfg.max_steps)
    searcher = MCTS(make_predict_fn(net, device), cfg)
    env.reset()

    while not env.done:
        visit_policy = searcher.search(env, temperature=0)
        env.step(int(np.argmax(visit_policy)))

    won = env.is_win()
    return won, env.score


def evaluate_head_to_head_mcts(new_net, best_net, cfg, device, n_games=None, seed_offset=10000):
    """
    Paper-faithful champion gate: new vs. best model with MCTS, paired by seed.
    The caller accepts the new model if its win rate exceeds accept_threshold.

    Each pairing plays one game per model with the same seed
    (``seed_offset + i``) and ``dir_eps`` forced to 0.0. The higher final
    score wins the pairing; equal scores are broken by the binary win flag,
    otherwise the pairing is a draw.

    Args:
        new_net: Candidate network.
        best_net: Current champion network.
        cfg: Evaluation config; ``cfg.eval_games`` is used when ``n_games``
            is None.
        device: Torch device used for inference.
        n_games: Number of seed pairings to play.
        seed_offset: Seed used for the first pairing.

    Returns:
        ``(new_wr, new_wins, best_wins, draws)`` where ``new_wr`` is computed
        over decided pairings only (draws are excluded from the denominator).
    """
    if n_games is None:
        n_games = cfg.eval_games

    # Put both networks in inference mode, as evaluate_mcts and
    # evaluate_policy_only do. The candidate arrives straight from
    # train_epoch, which leaves it in training mode, so without this the two
    # models would be compared under different module modes.
    new_net.eval()
    best_net.eval()

    eval_cfg = dataclasses.replace(cfg, dir_eps=0.0)

    new_wins = 0
    best_wins = 0
    draws = 0

    for i in tqdm(range(n_games), desc='Eval H2H', leave=False):
        seed = seed_offset + i
        new_win, new_score = _play_single_game_mcts(new_net, eval_cfg, device, seed)
        best_win, best_score = _play_single_game_mcts(best_net, eval_cfg, device, seed)

        # Primary comparison by final score (more granular)
        if new_score > best_score:
            new_wins += 1
        elif best_score > new_score:
            best_wins += 1
        else:
            # Tie-break on the binary win flag
            if new_win and not best_win:
                new_wins += 1
            elif best_win and not new_win:
                best_wins += 1
            else:
                draws += 1

    # max(1, ...) keeps the rate defined (0.0) when every pairing is a draw.
    compared = max(1, new_wins + best_wins)
    new_wr = new_wins / compared
    return new_wr, new_wins, best_wins, draws


print('Funciones de evaluacion implementadas (incluye head-to-head MCTS).')


## Loop principal de entrenamiento (faseado)
Fase A: warm-up (12 iter) y Fase B: paper (25 iter), con early-stop cuando el champion alcanza `wr>=0.94` en 3 evaluaciones consecutivas (una por iteración).

**Takeaway:** faseado reduce costo inicial sin perder objetivo de fidelidad final.


In [None]:

# ============================================================
# Cell 9: Loop Principal (self-play -> train -> H2H -> champion)
# ============================================================

def _atomic_torch_save(obj, path):
    # Serialize to a sibling temp file and rename it over `path`, so an
    # interrupted write never leaves a truncated file at the final location.
    tmp_path = f'{path}.tmp'
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present when saving or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(net, best_net, optimizer, iteration, best_win_rate, cfg, buffer, consecutive_target_hits):
    """Persist everything needed to resume training into ``cfg.save_dir``.

    Writes three files:
      * ``latest_checkpoint.pt``: both networks, the optimizer state, the
        iteration counter, the best win rate, the early-stop counter and the
        config as a plain dict.
      * ``best_model.pt``: the champion's ``state_dict`` alone.
      * ``buffer.pkl``: the replay buffer, via ``buffer.save``.

    The two ``.pt`` files are written atomically so a run killed in the
    middle of a save keeps its previous, intact checkpoint.
    """
    ckpt = {
        'iteration': iteration,
        'net': net.state_dict(),
        'best_net': best_net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_win_rate': best_win_rate,
        'consecutive_target_hits': consecutive_target_hits,
        'config': dataclasses.asdict(cfg),
    }
    path = os.path.join(cfg.save_dir, 'latest_checkpoint.pt')
    _atomic_torch_save(ckpt, path)

    best_path = os.path.join(cfg.save_dir, 'best_model.pt')
    _atomic_torch_save(best_net.state_dict(), best_path)

    buffer_path = os.path.join(cfg.save_dir, 'buffer.pkl')
    buffer.save(buffer_path)


def load_checkpoint(net, best_net, optimizer, cfg, buffer):
    """Restore training state from ``cfg.save_dir`` when a checkpoint exists.

    Loads ``latest_checkpoint.pt`` into ``net``, ``best_net`` and
    ``optimizer`` (mapped onto the module-level ``device``) and asks
    ``buffer`` to reload ``buffer.pkl``.

    Returns:
        ``(iteration, best_win_rate, consecutive_target_hits)``, or
        ``(0, 0.0, 0)`` when there is no checkpoint file.
    """
    ckpt_path = os.path.join(cfg.save_dir, 'latest_checkpoint.pt')
    if not os.path.exists(ckpt_path):
        return 0, 0.0, 0

    state = torch.load(ckpt_path, map_location=device)
    for target, key in ((net, 'net'), (best_net, 'best_net'), (optimizer, 'optimizer')):
        target.load_state_dict(state[key])

    buffer.load(os.path.join(cfg.save_dir, 'buffer.pkl'))

    # Missing keys fall back to a fresh-start value.
    iteration, best_wr, consecutive = (
        int(state.get('iteration', 0)),
        float(state.get('best_win_rate', 0.0)),
        int(state.get('consecutive_target_hits', 0)),
    )

    print(f'Checkpoint cargado: iter={iteration}, best_wr={best_wr:.3f}, target_hits={consecutive}')
    print(f'Buffer restaurado: {len(buffer)} ejemplos')
    return iteration, best_wr, consecutive


def train_alphasnake(cfg):
    """
    Run the phased AlphaSnake training loop and return the champion network.

    Every iteration performs four steps:
      1. Self-play with the current champion (``best_net``), using either the
         process backend (worker processes plus an in-process inference
         thread) or the thread backend.
      2. Training of ``net`` on the replay buffer.
      3. Champion gate: ``net`` replaces ``best_net`` only if its
         head-to-head MCTS win rate exceeds ``accept_threshold``; otherwise
         ``net`` is reset to the champion's weights.
      4. MCTS evaluation of the champion, checkpointing and early stop.

    State is resumed from ``cfg.save_dir`` when a checkpoint exists. The
    module-level ``device`` is used for all models.

    Args:
        cfg: Base training config; per-iteration settings are derived from it
            with ``get_iteration_config``.

    Returns:
        The champion network (``best_net``).
    """
    print('=' * 70)
    print(' AlphaSnake 10x10 Training (Vast.ai phased + paper-fidelity)')
    print('=' * 70)

    net = AlphaSnakeNet(4, cfg.net_blocks, cfg.net_channels, cfg.board_size).to(device)
    best_net = AlphaSnakeNet(4, cfg.net_blocks, cfg.net_channels, cfg.board_size).to(device)
    best_net.load_state_dict(net.state_dict())

    optimizer = torch.optim.Adam(net.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    buffer = ReplayBuffer(cfg.buffer_size)
    scaler = torch.amp.GradScaler(device='cuda', enabled=bool(cfg.use_amp and device.type == 'cuda'))

    start_iter, best_win_rate, consecutive_target_hits = load_checkpoint(net, best_net, optimizer, cfg, buffer)

    training_start = time.time()
    for iteration in range(start_iter, cfg.max_iterations):
        iter_no = iteration + 1
        iter_cfg = get_iteration_config(cfg, iter_no)

        iter_start = time.time()
        phase = 'warm-up' if iter_no <= cfg.warmup_iterations else 'paper'

        print('\n' + '=' * 70)
        print(f'ITER {iter_no}/{cfg.max_iterations} | phase={phase}')
        print(
            f"sims={iter_cfg.num_simulations} | games={iter_cfg.games_per_iter} | "
            f"eval={iter_cfg.eval_games} | food_samples={iter_cfg.food_samples}"
        )
        print('=' * 70)

        # ------------------------------------------------
        # 1) Parallel self-play
        # ------------------------------------------------
        cpu_total = os.cpu_count() or 8
        n_workers = max(1, min(int(iter_cfg.selfplay_workers), cpu_total))
        backend = getattr(iter_cfg, 'selfplay_backend', 'process')

        all_examples = []
        sp_wins = 0
        sp_scores = []
        sp_moves = []
        sp_lengths = []

        best_net.eval()
        sp_start = time.time()

        if backend == 'process':
            mp_manager = multiprocessing.Manager()
            mp_request_queue = mp_manager.Queue()
            response_queues = [mp_manager.Queue() for _ in range(n_workers)]
            progress_queue = mp_manager.Queue()

            def inference_server():
                # Collect (worker_id, state) requests into batches, run the
                # champion once per batch, and route each result back to the
                # requesting worker's response queue. A None request stops
                # the server.
                batch_timeout = max(0.001, float(iter_cfg.inference_timeout_ms) / 1000.0)
                max_batch = max(8, int(iter_cfg.inference_batch_size))
                use_amp = bool(iter_cfg.use_amp and device.type == 'cuda')
                while True:
                    batch = []
                    deadline = time.time() + batch_timeout
                    while len(batch) < max_batch:
                        try:
                            item = mp_request_queue.get(timeout=max(0.001, deadline - time.time()))
                            if item is None:
                                return
                            batch.append(item)
                        except Empty:
                            break
                    if not batch:
                        continue

                    wids = [b[0] for b in batch]
                    states = [b[1] for b in batch]
                    arr = np.stack(states)
                    t = torch.from_numpy(arr)
                    if device.type == 'cuda':
                        t = t.pin_memory().to(device, non_blocking=True)
                    else:
                        t = t.to(device)

                    with torch.no_grad():
                        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                            policies, values = best_net(t)

                    p_np = policies.cpu().numpy()
                    v_np = values.cpu().numpy()
                    for i in range(len(batch)):
                        response_queues[wids[i]].put((p_np[i], float(v_np[i, 0])))

            server_thread = threading.Thread(target=inference_server, daemon=True)
            server_thread.start()

            # Spread the games as evenly as possible; the first
            # (games % workers) workers take one extra game.
            games_per_worker = [
                iter_cfg.games_per_iter // n_workers + (1 if i < iter_cfg.games_per_iter % n_workers else 0)
                for i in range(n_workers)
            ]

            with concurrent.futures.ProcessPoolExecutor(max_workers=n_workers) as pool:
                futures = [
                    pool.submit(run_n_games, games_per_worker[i], mp_request_queue, response_queues[i], progress_queue, i, iter_cfg)
                    for i in range(n_workers)
                    if games_per_worker[i] > 0
                ]
                pending = set(futures)
                games_reported = 0
                heartbeat_s = max(10, int(getattr(iter_cfg, 'heartbeat_seconds', 30)))
                last_heartbeat = time.time()

                with tqdm(total=iter_cfg.games_per_iter, desc=f'Self-play iter {iter_no}', dynamic_ncols=True) as pbar:
                    while pending:
                        drained = 0
                        while True:
                            try:
                                msg = progress_queue.get_nowait()
                            except Empty:
                                break

                            drained += 1
                            if isinstance(msg, dict) and msg.get('type') == 'game_end':
                                games_reported += 1
                                pbar.update(1)

                        done_now, pending = concurrent.futures.wait(
                            pending,
                            timeout=1.0,
                            return_when=concurrent.futures.FIRST_COMPLETED,
                        )

                        # Heartbeat only when nothing else showed progress
                        # during this polling round.
                        if not done_now and drained == 0 and (time.time() - last_heartbeat) >= heartbeat_s:
                            print(
                                f"[Self-play heartbeat] juegos={games_reported}/{iter_cfg.games_per_iter} | workers_activos={len(pending)}",
                                flush=True,
                            )
                            last_heartbeat = time.time()

                        for future in done_now:
                            ex, w, sc, mv, ln = future.result()
                            all_examples.extend(ex)
                            sp_wins += w
                            sp_scores.extend(sc)
                            sp_moves.extend(mv)
                            sp_lengths.extend(ln)

                    # Drain remaining events
                    while True:
                        try:
                            msg = progress_queue.get_nowait()
                        except Empty:
                            break
                        if isinstance(msg, dict) and msg.get('type') == 'game_end':
                            games_reported += 1

                    if pbar.n < iter_cfg.games_per_iter:
                        pbar.update(iter_cfg.games_per_iter - pbar.n)

            for _ in range(n_workers):
                mp_request_queue.put(None)
            server_thread.join(timeout=10)
            # A new Manager (and its server process) is started every
            # iteration; shut it down now that its queues are no longer
            # needed so manager processes do not pile up across iterations.
            mp_manager.shutdown()

        else:
            # Fallback: thread backend (more stable, but limited by the GIL)
            progress_queue = queue_module.Queue()
            batcher = InferenceBatcher(
                best_net,
                device,
                max_batch=max(8, int(iter_cfg.inference_batch_size)),
                timeout_ms=max(1, int(iter_cfg.inference_timeout_ms)),
            ).start()

            def run_n_games_thread(n_games, worker_id):
                # Play n_games sequentially through the shared batcher and
                # report one progress tick per finished game.
                examples_all = []
                wins_local = 0
                scores_local = []
                moves_local = []
                lengths_local = []

                for _ in range(n_games):
                    examples, won, score, moves, length = self_play_game(
                        batcher.predict,
                        iter_cfg,
                        predict_batch_fn=batcher.predict_batch,
                        progress_cb=None,
                    )
                    examples_all.extend(examples)
                    if won:
                        wins_local += 1
                    scores_local.append(score)
                    moves_local.append(moves)
                    lengths_local.append(length)
                    progress_queue.put(1)

                return examples_all, wins_local, scores_local, moves_local, lengths_local

            games_per_worker = [
                iter_cfg.games_per_iter // n_workers + (1 if i < iter_cfg.games_per_iter % n_workers else 0)
                for i in range(n_workers)
            ]

            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
                futures = [
                    pool.submit(run_n_games_thread, games_per_worker[i], i)
                    for i in range(n_workers)
                    if games_per_worker[i] > 0
                ]
                pending = set(futures)
                games_reported = 0
                heartbeat_s = max(10, int(getattr(iter_cfg, 'heartbeat_seconds', 30)))
                last_heartbeat = time.time()

                with tqdm(total=iter_cfg.games_per_iter, desc=f'Self-play iter {iter_no}', dynamic_ncols=True) as pbar:
                    while pending:
                        drained = 0
                        while True:
                            try:
                                progress_queue.get_nowait()
                                drained += 1
                            except queue_module.Empty:
                                break

                        if drained:
                            games_reported += drained
                            pbar.update(drained)

                        done_now, pending = concurrent.futures.wait(
                            pending,
                            timeout=1.0,
                            return_when=concurrent.futures.FIRST_COMPLETED,
                        )

                        if not done_now and drained == 0 and (time.time() - last_heartbeat) >= heartbeat_s:
                            print(
                                f"[Self-play heartbeat] juegos={games_reported}/{iter_cfg.games_per_iter} | workers_activos={len(pending)}",
                                flush=True,
                            )
                            last_heartbeat = time.time()

                        for future in done_now:
                            ex, w, sc, mv, ln = future.result()
                            all_examples.extend(ex)
                            sp_wins += w
                            sp_scores.extend(sc)
                            sp_moves.extend(mv)
                            sp_lengths.extend(ln)

                    if pbar.n < iter_cfg.games_per_iter:
                        pbar.update(iter_cfg.games_per_iter - pbar.n)

            total_calls, total_batches, avg_batch = batcher.stats()
            batcher.stop()
            print(f"[Inference batcher] calls={total_calls} | batches={total_batches} | avg_batch={avg_batch:.1f}")

        sp_time = time.time() - sp_start
        sp_wr = sp_wins / max(iter_cfg.games_per_iter, 1)
        # The [0] placeholders keep the summary statistics defined when no
        # game was played.
        sp_scores_arr = np.array(sp_scores if sp_scores else [0])
        sp_moves_arr = np.array(sp_moves if sp_moves else [0])
        sp_lengths_arr = np.array(sp_lengths if sp_lengths else [0])

        buffer.add(all_examples)

        print(f'[Self-play] {sp_time:.0f}s | wins={sp_wins}/{iter_cfg.games_per_iter} ({100*sp_wr:.1f}%) | examples={len(all_examples)}')
        print(f"           score avg={sp_scores_arr.mean():.1f} | moves med={np.median(sp_moves_arr):.0f} | len med={np.median(sp_lengths_arr):.0f}")
        print(f'           buffer={len(buffer)}')

        # ------------------------------------------------
        # 2) Training
        # ------------------------------------------------
        train_start = time.time()
        last_loss, last_p, last_v = 0.0, 0.0, 0.0

        for epoch in tqdm(range(iter_cfg.epochs_per_iter), desc=f'Train iter {iter_no}', leave=False):
            last_loss, last_p, last_v = train_epoch(net, optimizer, buffer, iter_cfg, device, scaler=scaler)
            # Log only the first and the last epoch of the iteration.
            if (epoch + 1) in {1, iter_cfg.epochs_per_iter}:
                tqdm.write(f'  epoch {epoch+1}/{iter_cfg.epochs_per_iter}: loss={last_loss:.4f} (p={last_p:.4f}, v={last_v:.4f})')

        train_time = time.time() - train_start

        # ------------------------------------------------
        # 3) Champion gate (head-to-head MCTS > 55%)
        # ------------------------------------------------
        eval_start = time.time()
        h2h_wr, new_wins, best_wins, draws = evaluate_head_to_head_mcts(
            net, best_net, iter_cfg, device, n_games=iter_cfg.eval_games
        )

        if h2h_wr > iter_cfg.accept_threshold:
            print(f'[Champion] NUEVO modelo aceptado: h2h_wr={h2h_wr:.3f} ({new_wins}-{best_wins}, draws={draws})')
            best_net.load_state_dict(net.state_dict())
            best_win_rate = max(best_win_rate, h2h_wr)
        else:
            print(f'[Champion] Se mantiene best: h2h_wr={h2h_wr:.3f} ({new_wins}-{best_wins}, draws={draws})')
            # Rejected candidate: continue training from the champion's weights.
            net.load_state_dict(best_net.state_dict())

        # Final success metric (MCTS on the champion)
        wr_eval, avg_eval, _ = evaluate_mcts(best_net, iter_cfg, device, n_games=iter_cfg.eval_games)
        if wr_eval >= iter_cfg.target_win_rate:
            consecutive_target_hits += 1
        else:
            consecutive_target_hits = 0

        eval_time = time.time() - eval_start

        # ------------------------------------------------
        # 4) Summary + checkpoint + early-stop
        # ------------------------------------------------
        iter_time = time.time() - iter_start
        elapsed = (time.time() - training_start) / 60.0

        print(f'[Summary] train={train_time:.0f}s | eval={eval_time:.0f}s | iter={iter_time:.0f}s | elapsed={elapsed:.1f} min')
        print(f'          champion_eval_wr={wr_eval:.3f}, champion_eval_avg={avg_eval:.1f}, target_hits={consecutive_target_hits}/{iter_cfg.target_win_rate_patience}')

        target_reached = consecutive_target_hits >= iter_cfg.target_win_rate_patience
        # Always persist the state of the run's last iteration (early stop or
        # max_iterations), even off the checkpoint_interval grid; otherwise
        # best_model.pt on disk would lag behind the champion returned here.
        is_final_iter = target_reached or iter_no == cfg.max_iterations

        if iter_no % iter_cfg.checkpoint_interval == 0 or is_final_iter:
            save_checkpoint(net, best_net, optimizer, iter_no, best_win_rate, cfg, buffer, consecutive_target_hits)
            print(f'[Checkpoint] guardado en {cfg.save_dir} (iter={iter_no})')

        if target_reached:
            print(f"[Early stop] Se alcanzo wr>={iter_cfg.target_win_rate:.2f} por {iter_cfg.target_win_rate_patience} checkpoints consecutivos.")
            break

    print('\nEntrenamiento completado.')
    return best_net


print('Loop de entrenamiento definido. Ejecutar la siguiente celda para entrenar.')


In [None]:

# ============================================================
# Cell X: Smoke checks de contrato entorno/modelo
# ============================================================

# 1) observation shape
env = SnakeEnv(board_size=10, max_steps=1000)
obs = env.reset()
assert obs.shape == (4, 10, 10), obs.shape

# 2) reversing direction is forbidden
start_dir = env.direction
rev = env.OPPOSITES[start_dir]
_, _ = env.step(rev)
assert env.direction == start_dir, 'La reversa directa debe ignorarse'

# 3) valid rewards
env = SnakeEnv(board_size=10, max_steps=1000)
env.reset()
for _ in range(30):
    r, done = env.step(random.choice(env.valid_actions()))
    assert r in (-1.0, 0.0, 1.0), r
    if done:
        # A random walk can end the episode before 30 steps; stop there, as
        # every game loop in this file does, instead of stepping a finished
        # environment.
        break

# 4) controlled win condition: 2x2 board, length-3 snake, food on the only
# free cell
env = SnakeEnv(board_size=2, max_steps=20)
env.snake = deque([(1, 1), (0, 1), (0, 0)])
env.direction = 3
env.food = (1, 0)
env.done = False
env.steps = 0
env.score = 0
r, d = env.step(0)  # UP
assert d is True and r == 1.0 and env.is_win(), (r, d, env.is_win())

print('Smoke checks OK.')


In [None]:

# ============================================================
# Cell 10: EJECUTAR ENTRENAMIENTO
# ============================================================
# Runs the full training loop (resumes automatically if a checkpoint exists).
best_model = train_alphasnake(config)


## Evaluación final y exportación
Validación final MCTS y export ONNX con chequeo de equivalencia numérica.

**Takeaway:** el modelo queda listo para deploy en navegador (`onnxruntime-web`).


In [None]:

# ============================================================
# Cell 11: Evaluacion Final con MCTS (champion)
# ============================================================
# Prefer the champion persisted in save_dir; otherwise keep the in-memory
# `best_model` produced by the training cell.
best_model_path = os.path.join(config.save_dir, 'best_model.pt')
if os.path.exists(best_model_path):
    best_model = AlphaSnakeNet(4, config.net_blocks, config.net_channels, config.board_size).to(device)
    best_model.load_state_dict(torch.load(best_model_path, map_location=device))
    print('Mejor modelo cargado desde save_dir.')

# Evaluate with the settings of the last training iteration, without
# Dirichlet noise.
eval_cfg = get_iteration_config(config, config.max_iterations)
eval_cfg = dataclasses.replace(eval_cfg, dir_eps=0.0)

print('\nEvaluacion final MCTS (200 juegos recomendados en fase paper)...')
wr, avg_score, scores = evaluate_mcts(best_model, eval_cfg, device, n_games=eval_cfg.eval_games)
# round() instead of int(): wr is wins / n_games, and the product can land
# just below the integer (57 / 100 * 100 == 56.99999999999999), which int()
# would truncate to one win too few.
wins = round(wr * eval_cfg.eval_games)
print(f'Win rate: {wr:.3f} ({wins}/{eval_cfg.eval_games})')
print(f'Avg score: {avg_score:.1f} / {config.board_size**2 - 3}')
if scores:
    # min()/max() raise on an empty list (eval_games == 0).
    print(f'Min: {min(scores)}, Median: {np.median(scores):.0f}, Max: {max(scores)}')


In [None]:

# ============================================================
# Cell 12: Exportar ONNX + verificacion PyTorch/ONNX
# ============================================================
import onnx
import onnxruntime as ort

# Export from CPU in eval mode.
best_model.eval()
best_model.cpu()

dummy_input = torch.randn(1, 4, config.board_size, config.board_size)
onnx_path = os.path.join(config.save_dir, 'alphasnake.onnx')

torch.onnx.export(
    best_model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=['state'],
    output_names=['policy', 'value'],
    # Keep the batch dimension dynamic for the input and both outputs.
    dynamic_axes={
        'state': {0: 'batch_size'},
        'policy': {0: 'batch_size'},
        'value': {0: 'batch_size'},
    }
)

model_onnx = onnx.load(onnx_path)
onnx.checker.check_model(model_onnx)
print(f'Modelo ONNX exportado y verificado: {onnx_path}')

# Name the execution provider explicitly: GPU-enabled onnxruntime builds
# (>= 1.9) refuse to create a session without `providers`. CPU matches the
# CPU PyTorch reference used for the comparison below.
ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
test_input = np.random.randn(1, 4, config.board_size, config.board_size).astype(np.float32)

with torch.no_grad():
    pt_policy, pt_value = best_model(torch.from_numpy(test_input))
pt_policy = pt_policy.numpy()
pt_value = pt_value.numpy()

# Numerical equivalence check: largest absolute PyTorch/ONNX difference.
ort_policy, ort_value = ort_session.run(None, {'state': test_input})
print(f'Policy diff max: {np.abs(pt_policy - ort_policy).max():.8f}')
print(f'Value diff max:  {np.abs(pt_value - ort_value).max():.8f}')
print(f"Tamaño ONNX: {os.path.getsize(onnx_path) / (1024*1024):.2f} MB")


## Integración en el juego
Copia `alphasnake.onnx` a `games/snake/ai/alphasnake.onnx` y abre `games/snake/ai.html`.

**Checklist rápido:**
- [ ] `ai.html` arranca en MCTS
- [ ] stats de win usan `state.won`
- [ ] score AI representa comidas (máx. 97)

**Takeaway:** la integración debe adaptar el juego al modelo entrenado, no al revés.


In [None]:

# ============================================================
# Cell 13: Artefactos e integración con el juego
# ============================================================
# Print where the trained artifacts are stored (under config.save_dir) and
# the steps to plug the ONNX model into the game demo.
print(f'best_model.pt : {os.path.join(config.save_dir, "best_model.pt")}')
print(f'alphasnake.onnx: {os.path.join(config.save_dir, "alphasnake.onnx")}')
print('Copiar ONNX a: games/snake/ai/alphasnake.onnx')
print('Abrir demo: games/snake/ai.html (modo MCTS por defecto)')
