# Minesweeper LLM Competition - Custom GRPO Training

## Goal
Finetune an LLM with LoRA using GRPO to play Minesweeper by:
- **Input**: JSON game state (board configuration)
- **Output**: JSON action (reveal or flag a cell)

Teams will compete to train the best Minesweeper-playing LLM!

## Training Approach
- **Model**: Qwen2.5-14B-Instruct (downloaded to `./workspace/Qwen2.5-14B-Instruct` by the cell below)
- **Method**: GRPO (Group Relative Policy Optimization)
- **Framework**: Unsloth (advertised as 2-6x faster with about 70% less VRAM)
- **Hardware**: AMD MI300X GPU (192GB HBM3, ROCm)
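
GRPO scores each sampled completion relative to the other completions drawn for the same prompt, rather than against a learned value function. The snippet below is a minimal sketch of that group-relative normalization. The function name, the `eps` term, and the reward values are illustrative assumptions, and real trainers may differ in details such as the standard-deviation estimator.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one group of completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation over the group
    # eps (an assumed stabilizer) avoids division by zero when all rewards are equal
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four completions sampled for one board state
rewards = [1.0, -1.0, 0.5, -0.5]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])
```

Completions that beat the group mean get a positive advantage and are reinforced; those below it are pushed down. If every completion in a group earns the same reward, all advantages are zero and the group contributes no gradient signal.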

# Load Model with Unsloth

Set the cache directories, download Qwen2.5-14B-Instruct, then load it with Unsloth:

In [None]:
import os

os.environ["HF_HOME"] = "./workspace/hf_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./workspace/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "./workspace/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./workspace/hf_cache"


In [None]:
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="unsloth/Qwen2.5-14B-Instruct",
    local_dir="./workspace/Qwen2.5-14B-Instruct",
    local_dir_use_symlinks=False,
)


In [None]:
from unsloth import FastLanguageModel
import torch
import yaml

# Load config
with open("minesweeper_config_me.yaml", "r") as f:
    config = yaml.safe_load(f)

lora_rank = config.get("lora_rank", 32)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/workspace/workspace/Qwen2.5-14B-Instruct",
    load_in_4bit=False,   # AMD → 4bit disabled
    max_seq_length=2048,  # Increased: larger boards need longer prompts
    dtype=torch.bfloat16,
)

print("Model loaded successfully!")
print(f"Device: {model.device}")
print(f"LoRA rank: {lora_rank} (from config)")

# ── Add newline as EOS token so generation stops after first JSON line ──
# (replaces stop_strings which isn't supported by this GRPOConfig version)
newline_token_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
original_eos = tokenizer.eos_token_id
if original_eos != newline_token_id:
    model.generation_config.eos_token_id = [original_eos, newline_token_id]
    model.config.eos_token_id = [original_eos, newline_token_id]
    print(f"  ✅ Added newline (token {newline_token_id}) as additional EOS → stops after JSON line")
    print(f"     EOS tokens: {model.generation_config.eos_token_id}")
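
The effect of registering a second EOS id can be pictured with a toy decoding loop. This is a simplified sketch, not the Hugging Face generation code: the token ids are made up, the helper name is hypothetical, and the stop token is dropped from the output here for brevity.

```python
def generate_until_eos(token_stream, eos_ids):
    """Collect token ids until the first one that appears in eos_ids."""
    out = []
    for tok in token_stream:
        if tok in eos_ids:
            break  # any registered EOS id ends generation
        out.append(tok)
    return out


# Hypothetical ids: 0 = original EOS, 99 = newline
stream = [5, 6, 7, 99, 8, 9, 0]
print(generate_until_eos(stream, eos_ids={0}))      # runs past the newline
print(generate_until_eos(stream, eos_ids={0, 99}))  # stops at the newline
```

With only the original EOS registered, the model keeps generating after the first line; adding the newline id cuts the completion at the end of the JSON line.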

# Add LoRA Adapters

Add LoRA layers for efficient finetuning:

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,           # alpha = rank → scaling factor = 1.0 (stable training)
    lora_dropout = 0.05,              # Small dropout for regularization
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
print(f"LoRA config: rank={lora_rank}, alpha={lora_rank}, dropout=0.05")
model.print_trainable_parameters()
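
To see why LoRA is cheap, it helps to count what one adapter adds. LoRA learns two low-rank matrices per target layer, `A` (rank × d_in) and `B` (d_out × rank), and scales their product by `alpha / rank`. The sketch below works through that arithmetic; the 4096-wide square projection is a hypothetical shape chosen for illustration, not the actual Qwen2.5-14B layer size.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters added by one LoRA adapter (A: rank x d_in, B: d_out x rank)."""
    return rank * (d_in + d_out)


def lora_scaling(alpha, rank):
    """Factor applied to the low-rank update B @ A."""
    return alpha / rank


# Hypothetical square projection, for illustration only
d = 4096
rank = 32
full = d * d
added = lora_param_count(d, d, rank)
print(f"adapter params: {added}, full layer: {full}, ratio: {added / full:.4%}")
print(f"scaling with alpha == rank: {lora_scaling(rank, rank)}")
```

Setting `lora_alpha` equal to the rank, as the cell above does, gives a scaling factor of exactly 1.0.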

# Minesweeper Game Implementation

Custom Minesweeper environment supporting:
- Customizable board size and mine count
- Actions: reveal or flag cells
- Win: reveal all safe cells
- Lose: reveal a mine

In [None]:
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Set
import random
import math

# ──────────────────────────────────────────────────────────────────────
# Board size configuration — competition spec: n,m ∈ [1,50], mines 0-20%
# ──────────────────────────────────────────────────────────────────────

# Constants
MIN_ROWS, MAX_ROWS = 1, 50
MIN_COLS, MAX_COLS = 1, 50
MIN_MINE_DENSITY = 0.0    # 0% mines allowed (trivial board)
MAX_MINE_DENSITY = 0.20   # 20% of total cells


def sample_board_config(rng=None):
    """Sample a random (rows, cols, num_mines) from the full competition range.

    - rows ∈ [1, 50], cols ∈ [1, 50]
    - mines ∈ [0, floor(0.20 * rows * cols)]
    - Uses a weighted distribution favoring smaller boards during training
      (large boards are rare but included for coverage).
    """
    rng = rng or random.Random()

    # Weighted size distribution: favor small/medium, still cover large
    size_band = rng.random()
    if size_band < 0.30:
        # Small: 1-8
        rows = rng.randint(1, 8)
        cols = rng.randint(1, 8)
    elif size_band < 0.55:
        # Medium: 5-15
        rows = rng.randint(5, 15)
        cols = rng.randint(5, 15)
    elif size_band < 0.75:
        # Large: 10-30
        rows = rng.randint(10, 30)
        cols = rng.randint(10, 30)
    elif size_band < 0.90:
        # XL: 20-40
        rows = rng.randint(20, 40)
        cols = rng.randint(20, 40)
    else:
        # Full range: 1-50 (including extreme cases)
        rows = rng.randint(1, 50)
        cols = rng.randint(1, 50)

    total = rows * cols
    max_mines = int(total * MAX_MINE_DENSITY)  # floor(0.20 * total)

    if max_mines == 0:
        num_mines = 0  # Boards too small for any mines at ≤20%
    else:
        num_mines = rng.randint(0, max_mines)

    return rows, cols, num_mines


def mine_density(rows, cols, num_mines):
    """Compute mine density as a fraction."""
    total = rows * cols
    return num_mines / total if total > 0 else 0.0


@dataclass
class MinesweeperGame:
    rows: int
    cols: int
    num_mines: int
    seed: Optional[int] = None
    _rng: random.Random = field(init=False, repr=False)
    _board: List[List[int]] = field(init=False, repr=False)  # -1 = mine, 0-8 = count
    _revealed: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _flagged: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _state: str = field(default="ongoing", init=False, repr=False)
    _move_count: int = field(default=0, init=False, repr=False)

    def __post_init__(self):
        # ── Input validation — competition spec: n,m ∈ [1,50] ──
        if self.rows < MIN_ROWS or self.cols < MIN_COLS:
            raise ValueError(f"Board too small: {self.rows}x{self.cols} (min {MIN_ROWS}x{MIN_COLS})")
        if self.rows > MAX_ROWS or self.cols > MAX_COLS:
            raise ValueError(f"Board too large: {self.rows}x{self.cols} (max {MAX_ROWS}x{MAX_COLS})")
        if self.num_mines < 0:
            raise ValueError(f"num_mines cannot be negative, got {self.num_mines}")
        if self.num_mines >= self.rows * self.cols:
            raise ValueError(f"Too many mines ({self.num_mines}) for {self.rows}x{self.cols} board")
        # 0 mines is allowed — trivial board, instant win on first reveal cascade

        self._rng = random.Random(self.seed)
        self._board = [[0 for _ in range(self.cols)] for _ in range(self.rows)]
        self._place_mines()
        self._calculate_numbers()

        # ── Initial win check: a no-op for any valid board (at least one safe cell is unrevealed) ──
        self._check_win()

    def _place_mines(self):
        """Place mines randomly on the board."""
        if self.num_mines == 0:
            return  # No mines to place
        positions = [(r, c) for r in range(self.rows) for c in range(self.cols)]
        mine_positions = self._rng.sample(positions, self.num_mines)
        for r, c in mine_positions:
            self._board[r][c] = -1

    def _calculate_numbers(self):
        """Calculate numbers for each cell based on adjacent mines."""
        for r in range(self.rows):
            for c in range(self.cols):
                if self._board[r][c] == -1:
                    continue
                count = 0
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if 0 <= nr < self.rows and 0 <= nc < self.cols:
                            if self._board[nr][nc] == -1:
                                count += 1
                self._board[r][c] = count

    def _reveal_cell(self, row: int, col: int) -> bool:
        """Reveal a cell. Returns True if valid move, False if invalid.
        Uses iterative flood-fill to avoid recursion limit on large boards.
        """
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed or (row, col) in self._flagged:
            return False

        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if (r, c) in self._revealed:
                continue

            self._revealed.add((r, c))

            # Hit a mine!
            if self._board[r][c] == -1:
                self._state = "failed"
                return True

            # Auto-reveal neighbors if cell is 0
            if self._board[r][c] == 0:
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < self.rows and 0 <= nc < self.cols
                                and (nr, nc) not in self._revealed
                                and (nr, nc) not in self._flagged):
                            stack.append((nr, nc))

        return True

    def _flag_cell(self, row: int, col: int) -> bool:
        """Flag/unflag a cell. Returns True if valid, False if invalid."""
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed:
            return False

        if (row, col) in self._flagged:
            self._flagged.remove((row, col))
        else:
            self._flagged.add((row, col))
        return True

    def do_action(self, action: dict) -> str:
        """Execute an action and return a status string.

        Returns one of:
          'ok'               - valid move executed
          'mine'             - revealed a mine (game over → state='failed')
          'win'              - game won after this move (all safe cells revealed)
          'invalid_format'   - bad action dict / missing keys / bad types
          'out_of_bounds'    - coordinates outside the board
          'already_revealed' - cell was already revealed
          'flagged_cell'     - tried to reveal a flagged cell
          'invalid_flag'     - tried to flag a revealed cell
          'game_over'        - game was already over before this call

        Only 'mine' sets state='failed'. All other invalid moves
        return an error string but keep the game 'ongoing'.
        NO move limit — game continues until success or failure.
        """
        if self._state != "ongoing":
            return "game_over"

        if not isinstance(action, dict):
            return "invalid_format"

        action_type = action.get("type")
        row = action.get("row")
        col = action.get("col")

        if action_type not in ("reveal", "flag") or row is None or col is None:
            return "invalid_format"

        try:
            row, col = int(row), int(col)
        except (ValueError, TypeError):
            return "invalid_format"

        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return "out_of_bounds"

        if action_type == "reveal":
            if (row, col) in self._revealed:
                return "already_revealed"
            if (row, col) in self._flagged:
                return "flagged_cell"
            self._reveal_cell(row, col)
            self._move_count += 1
        else:  # flag
            if (row, col) in self._revealed:
                return "invalid_flag"
            self._flag_cell(row, col)
            self._move_count += 1

        self._check_win()

        if self._state == "failed":
            return "mine"
        if self._state == "success":
            return "win"
        return "ok"

    def _check_win(self):
        """Check if player has won.

        Win condition: ALL safe (non-mine) cells are revealed.
        With 0 mines, ALL cells are safe → need to reveal every cell.
        """
        if self._state != "ongoing":
            return
        total_cells = self.rows * self.cols
        safe_cells = total_cells - self.num_mines
        if safe_cells == 0:
            self._state = "success"
        elif len(self._revealed) >= safe_cells:
            self._state = "success"

    def get_visible_board(self) -> List[List[str]]:
        """Get board state as player sees it.
        Uses '?' for unrevealed cells (competition standard).
        """
        visible = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r, c) in self._flagged:
                    row.append('F')
                elif (r, c) in self._revealed:
                    val = self._board[r][c]
                    row.append('*' if val == -1 else str(val))
                else:
                    row.append('?')
            visible.append(row)
        return visible

    def state(self) -> str:
        return self._state

    @property
    def move_count(self) -> int:
        return self._move_count

    def get_mine_positions(self) -> Set[Tuple[int, int]]:
        """Return set of all mine positions (for reward computation)."""
        return {(r, c) for r in range(self.rows) for c in range(self.cols)
                if self._board[r][c] == -1}

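    def progress(self) -> float:
        """Fraction of safe cells revealed (0.0 to 1.0)."""
        safe_cells = self.rows * self.cols - self.num_mines
        # A lost game has exactly one revealed mine in _revealed; exclude it.
        revealed_safe = len(self._revealed) - (1 if self._state == "failed" else 0)
        return revealed_safe / safe_cells if safe_cells > 0 else 1.0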

    def game_phase(self) -> str:
        """Determine the current game phase for prompt selection."""
        if self._move_count == 0 and len(self._revealed) == 0:
            return "opening"
        prog = self.progress()
        if prog >= 0.80:
            return "endgame"
        return "midgame"

    def pretty_print(self) -> str:
        """Pretty print the board."""
        visible = self.get_visible_board()
        lines = []

        # Header: pad to 2 characters per cell once column indices reach two digits
        cell_width = 2 if self.cols > 10 else 1
        header = "    " + " ".join(f"{i:>{cell_width}d}" for i in range(self.cols))
        lines.append(header)
        lines.append("  " + "─" * (self.cols * (cell_width + 1) + 1))

        # Board: right-align each cell to the same width as its column header
        for r, row in enumerate(visible):
            line = f"{r:2d}│ " + " ".join(f"{v:>{cell_width}s}" for v in row)
            lines.append(line)

        return "\n".join(lines)


# ──────────────────────────────────────────────────────────────────────
# Sanity tests
# ──────────────────────────────────────────────────────────────────────
print("Testing MinesweeperGame (competition spec: 1-50 boards, 0-20% mines)...")

# Basic gameplay
g = MinesweeperGame(5, 5, 3, seed=0)
assert g.state() == "ongoing"
assert g.do_action({"type": "reveal", "row": -1, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing", "BUG: out_of_bounds should NOT end the game"
assert g.do_action({"type": "reveal", "row": 99, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing"
assert g.do_action({"type": "flag", "row": 0, "col": 0}) == "ok"
assert g.do_action({"type": "reveal", "row": 0, "col": 0}) == "flagged_cell"
assert g.state() == "ongoing", "BUG: flagged_cell should NOT end the game"
assert g.do_action({}) == "invalid_format"
assert g.state() == "ongoing", "BUG: invalid_format should NOT end the game"
print("  ✅ do_action keeps game ongoing on invalid moves")

# Verify '?' is used for unrevealed cells
board = g.get_visible_board()
has_question = any('?' in row for row in board)
assert has_question, "Board should use '?' for unrevealed cells"
print("  ✅ Board uses '?' for unrevealed cells")

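# Game phase tracking: a fresh game is in "opening"; after the flag above, g is in "midgame"
assert MinesweeperGame(5, 5, 3, seed=0).game_phase() == "opening"
assert g.game_phase() == "midgame"
print("  ✅ Game phase tracking works")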

# ── Edge case: 0 mines board ──
g0 = MinesweeperGame(3, 3, 0, seed=42)
assert g0.state() == "ongoing"
result = g0.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win", f"0-mine board should win on first reveal, got {result}"
assert g0.state() == "success"
print("  ✅ 0-mine board → instant win on first reveal")

# ── Edge case: 1x1 board with 0 mines ──
g1x1 = MinesweeperGame(1, 1, 0, seed=42)
assert g1x1.state() == "ongoing"
result = g1x1.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 1x1 board with 0 mines works")

# ── Edge case: 1x1 board — cannot have mines ──
try:
    MinesweeperGame(1, 1, 1, seed=42)
    assert False, "Should have raised ValueError"
except ValueError:
    pass
print("  ✅ 1x1 board with 1 mine correctly rejected")

# ── Edge case: 50x50 board ──
g50 = MinesweeperGame(50, 50, 500, seed=42)
assert g50.state() == "ongoing"
assert g50.rows == 50 and g50.cols == 50
assert mine_density(50, 50, 500) == 0.20
print("  ✅ 50x50 board with 20% mines works")

# ── Edge cases: rectangular ──
g1x50 = MinesweeperGame(1, 50, 10, seed=42)
assert g1x50.rows == 1 and g1x50.cols == 50
g50x1 = MinesweeperGame(50, 1, 10, seed=42)
assert g50x1.rows == 50 and g50x1.cols == 1
print("  ✅ 1x50 and 50x1 boards work")

# ── Edge case: 2x2 with 0 mines ──
g2x2 = MinesweeperGame(2, 2, 0, seed=42)
result = g2x2.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 2x2 board with 0 mines → instant cascade win")

# Variable sizes across the full range
for r, c, m in [(1,1,0), (2,2,0), (3,3,1), (5,5,3), (10,10,20),
                (20,20,80), (30,30,180), (50,50,500), (1,50,10), (50,1,10)]:
    g = MinesweeperGame(r, c, m, seed=42)
    assert g.rows == r and g.cols == c
    board = g.get_visible_board()
    assert len(board) == r and len(board[0]) == c
print(f"  ✅ Variable board sizes (1x1 to 50x50) work")

# Test sample_board_config produces valid configs
rng = random.Random(42)
for _ in range(200):
    r, c, m = sample_board_config(rng)
    assert MIN_ROWS <= r <= MAX_ROWS
    assert MIN_COLS <= c <= MAX_COLS
    assert 0 <= m <= int(r * c * MAX_MINE_DENSITY)
    g = MinesweeperGame(r, c, m, seed=42)
print(f"  ✅ sample_board_config produces valid configs (200 tested)")

# Test progress
g = MinesweeperGame(5, 5, 3, seed=42)
assert g.progress() == 0.0
print(f"  ✅ All game engine tests passed")
print(f"  Board range: {MIN_ROWS}-{MAX_ROWS} rows × {MIN_COLS}-{MAX_COLS} cols")
print(f"  Mine density: {MIN_MINE_DENSITY*100:.0f}%-{MAX_MINE_DENSITY*100:.0f}%")
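
The sampling arithmetic in `sample_board_config` is easy to misread from the cumulative thresholds alone. The sketch below restates it in standalone form: the thresholds are copied from the function above, while the helper names are hypothetical and exist only for this illustration.

```python
# Cumulative upper bounds copied from sample_board_config
BAND_THRESHOLDS = [
    ("small", 0.30),
    ("medium", 0.55),
    ("large", 0.75),
    ("xl", 0.90),
    ("full", 1.00),
]


def band_probabilities(thresholds):
    """Convert cumulative upper bounds into per-band probabilities."""
    probs, prev = {}, 0.0
    for name, upper in thresholds:
        probs[name] = round(upper - prev, 2)
        prev = upper
    return probs


def max_mines(rows, cols, max_density=0.20):
    """Largest mine count allowed: floor(max_density * rows * cols)."""
    return int(rows * cols * max_density)


print(band_probabilities(BAND_THRESHOLDS))
for rows, cols in [(1, 4), (1, 5), (2, 2), (50, 50)]:
    print(f"{rows}x{cols}: up to {max_mines(rows, cols)} mines")
```

Because the cap is floored, any board with fewer than five cells can only ever be sampled with zero mines.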

# Prompt System & Game Logic Helpers

## Unified Prompt (single template for all board sizes)
Same compact prompt for both training and inference. Pre-computed hints (safe/mine cells) do the heavy lifting — no verbose STEP 1-4 reasoning scaffolding.

## 3-Tier Board Representation
| Format | Largest dimension (rows or cols) | Display |
|--------|----------------------------------|---------|
| **A (Small)** | 1–20 | Full grid with borders and headers |
| **B (Medium)** | 21–35 | Frontier cells + revealed number regions |
| **C (Large)** | 36–50 | Quadrant summary + critical area snippets |
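
The tier is decided by the larger of the two dimensions, so a long, thin board is treated as large even though it has few cells. A standalone sketch of that rule, mirroring the `_format_board` dispatcher defined below (the function name `select_format` is hypothetical):

```python
def select_format(rows, cols):
    """Return the tier letter for a board, based on its larger dimension."""
    largest = max(rows, cols)
    if largest <= 20:
        return "A"  # full grid
    if largest <= 35:
        return "B"  # frontier regions plus number summary
    return "C"      # quadrant summary plus frontier snippets


for rows, cols in [(8, 8), (20, 20), (21, 5), (1, 50), (50, 50)]:
    print(f"{rows}x{cols} -> format {select_format(rows, cols)}")
```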

## Symbol Legend
`?`=unrevealed  `F`=flagged  `0`-`8`=revealed safe

## Output Format
```json
{"type": "reveal", "row": 2, "col": 3}
```
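
For reference, a well-formed action can be produced as a single compact line with the standard `json` module. This is a small sketch for building test fixtures or expected outputs; the helper name `encode_action` is hypothetical and is not used elsewhere in this notebook.

```python
import json


def encode_action(action_type, row, col):
    """Serialize an action as one compact JSON line."""
    if action_type not in ("reveal", "flag"):
        raise ValueError(f"unknown action type: {action_type!r}")
    action = {"type": action_type, "row": int(row), "col": int(col)}
    # Compact separators drop the spaces after ',' and ':'
    return json.dumps(action, separators=(",", ":"))


line = encode_action("reveal", 2, 3)
print(line)
assert "\n" not in line
assert json.loads(line) == {"type": "reveal", "row": 2, "col": 3}
```

The compact form matches the example embedded in the system prompt below, and it round-trips through `json.loads` to the same dictionary as the spaced form shown above.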

In [None]:
import json
import re

# ──────────────────────────────────────────────────────────────────────
# Minesweeper Logic Helpers (single-pass constraint scan, recomputed on each call)
# ──────────────────────────────────────────────────────────────────────

def _compute_safe_and_mine_cells(game: MinesweeperGame):
    """Compute both safe and mine cells in a SINGLE pass (O(n²) not O(n⁴)).

    Returns (safe_set, mine_set) where each is a set of (row, col) tuples.

    A cell is logically SAFE if any adjacent revealed number has all its
    mines accounted for by flags (remaining_mines == 0).
    A cell is a logically CERTAIN mine if an adjacent number has
    remaining_mines == remaining unrevealed neighbors.
    """
    safe = set()
    mines = set()

    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) not in game._revealed:
                continue
            val = game._board[r][c]
            if val <= 0:
                continue

            flags = 0
            unrevealed = []
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if (nr, nc) in game._flagged:
                            flags += 1
                        elif (nr, nc) not in game._revealed:
                            unrevealed.append((nr, nc))

            remaining = val - flags

            if remaining == 0 and unrevealed:
                for cell in unrevealed:
                    safe.add(cell)
            elif remaining > 0 and remaining == len(unrevealed):
                for cell in unrevealed:
                    mines.add(cell)

    return safe, mines


def _compute_safe_cells(game: MinesweeperGame) -> list:
    """Return a sorted list of [row, col] for logically safe cells."""
    safe, _ = _compute_safe_and_mine_cells(game)
    # Sort so that truncated hint lists are reproducible across runs
    return [list(c) for c in sorted(safe)]


def _compute_mine_cells(game: MinesweeperGame) -> list:
    """Return a sorted list of [row, col] for logically certain mine cells."""
    _, mines = _compute_safe_and_mine_cells(game)
    return [list(c) for c in sorted(mines)]


# ──────────────────────────────────────────────────────────────────────
# Board Representation — 3-tier format (A/B/C) based on size
# A: Small (1-20 rows/cols) — full grid
# B: Medium (21-35 rows/cols) — revealed regions + frontier
# C: Large (36-50 rows/cols) — critical areas + frontier summary
# ──────────────────────────────────────────────────────────────────────

def _format_board_small(game: MinesweeperGame) -> str:
    """Format A: Full grid display for boards ≤20 rows/cols."""
    board = game.get_visible_board()
    if game.cols <= 10:
        col_header = "    " + " ".join(f"{c}" for c in range(game.cols))
        separator = "  +" + "-" * (game.cols * 2 + 1) + "+"
        rows_str = []
        for r, row in enumerate(board):
            rows_str.append(f"{r:2d}| " + " ".join(row) + " |")
    else:
        col_header = "    " + " ".join(f"{c:2d}" for c in range(game.cols))
        separator = "  +" + "-" * (game.cols * 3) + "+"
        rows_str = []
        for r, row in enumerate(board):
            rows_str.append(f"{r:2d}| " + " ".join(f"{v:>2s}" for v in row) + " |")
    return col_header + "\n" + separator + "\n" + "\n".join(rows_str) + "\n" + separator


def _get_frontier_and_numbers(game: MinesweeperGame):
    """Get frontier cells (unrevealed cells adjacent to revealed numbers)
    and the revealed number cells near the frontier."""
    frontier = set()
    number_cells = {}
    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) in game._revealed and game._board[r][c] > 0:
                number_cells[(r, c)] = game._board[r][c]
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < game.rows and 0 <= nc < game.cols
                                and (nr, nc) not in game._revealed):
                            frontier.add((nr, nc))
    return frontier, number_cells


def _extract_critical_areas(game: MinesweeperGame, frontier, number_cells, max_areas=3):
    """Extract small rectangular regions around frontier clusters for display."""
    if not frontier:
        return []
    frontier_list = sorted(frontier)
    areas = []
    used = set()
    for fr, fc in frontier_list:
        if (fr, fc) in used:
            continue
        r_min = max(0, fr - 1)
        r_max = min(game.rows - 1, fr + 1)
        c_min = max(0, fc - 1)
        c_max = min(game.cols - 1, fc + 1)
        for fr2, fc2 in frontier_list:
            if abs(fr2 - fr) <= 2 and abs(fc2 - fc) <= 2:
                r_min = min(r_min, max(0, fr2 - 1))
                r_max = max(r_max, min(game.rows - 1, fr2 + 1))
                c_min = min(c_min, max(0, fc2 - 1))
                c_max = max(c_max, min(game.cols - 1, fc2 + 1))
                used.add((fr2, fc2))
        board = game.get_visible_board()
        col_hdr = "    " + " ".join(f"{c:2d}" for c in range(c_min, c_max + 1))
        sep = "  +" + "-" * ((c_max - c_min + 1) * 3) + "+"
        rows_str = []
        for r in range(r_min, r_max + 1):
            cells = [f"{board[r][c]:>2s}" for c in range(c_min, c_max + 1)]
            rows_str.append(f"{r:2d}| " + " ".join(cells) + " |")
        area_str = f"Region near ({fr},{fc}):\n{col_hdr}\n{sep}\n"
        area_str += "\n".join(rows_str) + "\n" + sep
        areas.append(area_str)
        if len(areas) >= max_areas:
            break
    return areas


def _format_board_medium(game: MinesweeperGame) -> str:
    """Format B: Revealed regions + frontier for boards 21-35 rows/cols."""
    frontier, number_cells = _get_frontier_and_numbers(game)
    areas = _extract_critical_areas(game, frontier, number_cells, max_areas=4)
    frontier_list = sorted(frontier)[:30]
    flagged_list = sorted(game._flagged)[:20]
    number_summary = [f"({r},{c})={v}" for (r, c), v in sorted(number_cells.items())[:30]]
    lines = []
    lines.append(f"Board: {game.rows}×{game.cols} with {game.num_mines} mines "
                 f"({mine_density(game.rows, game.cols, game.num_mines)*100:.1f}% density)")
    lines.append("")
    if areas:
        for area in areas:
            lines.append(area)
            lines.append("")
    lines.append(f"Revealed numbers: {number_summary}")
    lines.append(f"Unrevealed frontier cells: {frontier_list}")
    if flagged_list:
        lines.append(f"Flagged cells: {flagged_list}")
    return "\n".join(lines)


def _format_board_large(game: MinesweeperGame) -> str:
    """Format C: Critical areas + frontier summary for boards 36-50 rows/cols."""
    frontier, number_cells = _get_frontier_and_numbers(game)
    mid_r, mid_c = game.rows // 2, game.cols // 2
    quadrants = {
        "Top-left": (0, mid_r, 0, mid_c),
        "Top-right": (0, mid_r, mid_c, game.cols),
        "Bottom-left": (mid_r, game.rows, 0, mid_c),
        "Bottom-right": (mid_r, game.rows, mid_c, game.cols),
    }
    lines = []
    lines.append(f"Board: {game.rows}×{game.cols} with {game.num_mines} mines "
                 f"({mine_density(game.rows, game.cols, game.num_mines)*100:.1f}% density)")
    lines.append("")
    safe_total = game.rows * game.cols - game.num_mines
    lines.append(f"Revealed safe cells: {len(game._revealed)}/{safe_total} "
                 f"({game.progress()*100:.1f}% complete)")
    lines.append("")
    lines.append("Quadrant summary:")
    for qname, (r0, r1, c0, c1) in quadrants.items():
        q_total = (r1 - r0) * (c1 - c0)
        if q_total == 0:
            continue  # 1-row or 1-column boards leave two quadrants empty
        q_revealed = sum(1 for r in range(r0, r1) for c in range(c0, c1)
                        if (r, c) in game._revealed)
        status = "Fully explored" if q_revealed >= q_total * 0.9 else \
                 "Mostly complete" if q_revealed >= q_total * 0.5 else \
                 "Partially explored" if q_revealed > 0 else "Unexplored"
        lines.append(f"  {qname} ({r0}-{r1-1}, {c0}-{c1-1}): "
                     f"{status}, {q_revealed}/{q_total} revealed")
    lines.append("")
    areas = _extract_critical_areas(game, frontier, number_cells, max_areas=3)
    if areas:
        lines.append("Frontier regions:")
        for area in areas:
            lines.append(area)
            lines.append("")
    if len(game._flagged) > 0:
        lines.append(f"Flags: {len(game._flagged)}/{game.num_mines} mines flagged")
        if len(game._flagged) == game.num_mines:
            lines.append("→ All mines flagged — reveal any remaining '?' to win")
    frontier_list = sorted(frontier)[:20]
    lines.append(f"\nUnrevealed non-flagged: "
                 f"{game.rows * game.cols - len(game._revealed) - len(game._flagged)}")
    if frontier_list:
        lines.append(f"Frontier cells: {frontier_list}")
    return "\n".join(lines)


def _format_board(game: MinesweeperGame) -> str:
    """Select appropriate board format based on size."""
    if game.rows <= 20 and game.cols <= 20:
        return _format_board_small(game)
    elif game.rows <= 35 and game.cols <= 35:
        return _format_board_medium(game)
    else:
        return _format_board_large(game)



SYSTEM_PROMPT = (
    "You are a Minesweeper AI. You MUST reply with EXACTLY one JSON object on a single line. "
    "Do NOT write any text, explanation, reasoning, markdown, or code blocks. "
    "Do NOT use ```json or ``` wrappers. "
    "ONLY output raw JSON in this exact format: "
    '{"type":"reveal","row":0,"col":0} '
    "Nothing before it. Nothing after it. Just the JSON object."
)


def format_state_for_llm(game: MinesweeperGame, mode="training") -> str:
    """Convert game state to a compact prompt for the LLM.

    Single unified template for both training and inference.
    Pre-computed hints (safe/mine cells) do the heavy lifting.
    No verbose STEP 1-4 reasoning instructions.

    The `mode` argument is currently unused: "training" and "inference"
    produce the same prompt, which keeps train-time and eval-time inputs
    consistent.
    """
    if game.state() == "success":
        return "Game already won. No action needed."

    rows, cols = game.rows, game.cols
    total_cells = rows * cols
    safe_total = total_cells - game.num_mines
    density = mine_density(rows, cols, game.num_mines) * 100
    revealed = len(game._revealed)
    flags = len(game._flagged)
    remaining_mines = game.num_mines - flags

    # ── 0-mine shortcut ──
    if game.num_mines == 0:
        return (
            f"Minesweeper {rows}×{cols} board with 0 mines.\n"
            "All cells are safe. Reveal any unrevealed cell.\n\n"
            f"Row: 0-{rows - 1}  Col: 0-{cols - 1}\n"
            'Reply with ONLY this JSON, nothing else: {"type":"reveal","row":<int>,"col":<int>}'
        )

    board_repr = _format_board(game)

    # ── Pre-computed logical hints (the real value) ──
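    # One scan yields both sets; sort so the truncated hint lists are deterministic
    safe_set, mine_set = _compute_safe_and_mine_cells(game)
    safe_cells = [list(c) for c in sorted(safe_set)]
    mine_cells = [list(c) for c in sorted(mine_set)]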

    hints = ""
    if safe_cells:
        hints += f"SAFE cells (100% certain — reveal one): {safe_cells[:8]}\n"
    if mine_cells:
        hints += f"MINE cells (100% certain — flag one): {mine_cells[:8]}\n"
    if not safe_cells and not mine_cells:
        if revealed > 0:
            hints += "No certain moves. Guess: prefer cells near low numbers, avoid edges.\n"
        else:
            hints += "No cells revealed yet. Make an opening move.\n"

    # ── Edge case one-liners ──
    edge = ""
    remaining = total_cells - revealed - flags

    # num_mines == 0 already returned above, so only board-shape cases remain
    if rows == 1 or cols == 1:
        edge = "LINEAR BOARD: Start from ends, work inward.\n"
    elif rows <= 3 and cols <= 3:
        edge = "TINY BOARD: Prefer center if unrevealed.\n"

    if remaining == 1 and remaining_mines in (0, 1):
        # Exactly one unrevealed, unflagged cell is left: locate it once
        last_r, last_c = next(
            (r, c) for r in range(rows) for c in range(cols)
            if (r, c) not in game._revealed and (r, c) not in game._flagged
        )
        if remaining_mines == 0:
            edge += f"LAST CELL ({last_r},{last_c}) — all mines flagged → REVEAL to WIN!\n"
        else:
            edge += f"LAST CELL ({last_r},{last_c}) — it IS a mine → FLAG it!\n"
    elif remaining_mines == 0 and remaining > 0:
        edge += f"ALL MINES FLAGGED: All {remaining} '?' cells are SAFE → reveal any.\n"
    elif remaining > 0 and remaining == remaining_mines:
        edge += f"ALL {remaining} REMAINING CELLS ARE MINES → flag any.\n"

    # ── Opening hint ──
    if revealed == 0:
        center_r, center_c = rows // 2, cols // 2
        edge += f"Opening: prefer center ({center_r},{center_c}).\n"

    # ── Build prompt ──
    prompt = (
        f"Minesweeper {rows}×{cols}, {game.num_mines} mines ({density:.1f}%)\n"
        f"Revealed: {revealed}/{safe_total} | Flags: {flags}/{game.num_mines}\n\n"
        f"{board_repr}\n\n"
        f"?=unrevealed F=flagged 0-8=safe\n\n"
        f"{edge}{hints}\n"
        f"Row: 0-{rows - 1}  Col: 0-{cols - 1}\n"
        'Reply with ONLY raw JSON, nothing else: {"type":"reveal" or "flag","row":<int>,"col":<int>}'
    )

    return prompt


def parse_llm_action(response: str) -> "dict | None":
    """Extract a valid Minesweeper action from an LLM response.

    Scans every flat (non-nested) JSON object in the text and returns the
    last one that matches the expected format.
    Handles: extra text around JSON, string-typed integers, multiple JSON objects.
    Returns: {"type": "reveal"/"flag", "row": int, "col": int} or None
    """
    best = None
    for match in re.finditer(r'\{[^{}]*\}', response):
        try:
            action = json.loads(match.group())
            if ("type" in action and "row" in action and "col" in action
                    and action["type"] in ("reveal", "flag")):
                action["row"] = int(action["row"])
                action["col"] = int(action["col"])
                best = action
        except (json.JSONDecodeError, ValueError, TypeError):
            continue
    return best

# ── Quick verification ──
print("✅ Unified prompt system loaded (single template)")
print("   Kept: board formatters (3-tier A/B/C), constraint solver, parse_llm_action")
print("   Removed: 4-tier STEP 1-4 reasoning, 200 lines of tests")

# Verify basic functionality
game_test = MinesweeperGame(6, 6, 5, seed=42)
p = format_state_for_llm(game_test, mode="training")
print(f"   Example 6×6 prompt: {len(p)} chars")
assert "Minesweeper" in p
assert "JSON" in p
assert parse_llm_action('{"type":"reveal","row":2,"col":3}') == {"type": "reveal", "row": 2, "col": 3}
assert parse_llm_action('no json') is None
print("   ✅ parse_llm_action works")

print(f"\n--- Example prompt (6×6) ---\n{p}\n---")

# Test Model Before Training

See how the base model performs without finetuning:

In [None]:
from transformers import TextStreamer

game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
prompt = format_state_for_llm(game)

text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ],
    tokenize = False,
    add_generation_prompt = True,
)

streamer = TextStreamer(tokenizer, skip_prompt = True)

print("=== Base Model Response ===")
output = model.generate(
    **tokenizer(text, return_tensors = "pt").to(model.device),
    streamer = streamer,
    temperature = 0.7,
    top_p = 0.9,
    max_new_tokens = 128,
    do_sample = True,
)

# GRPO Reward Functions

Define reward functions to guide the model's learning:
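
Before the definitions, here is a minimal sketch of how two reward lists could become a GRPO training signal. It assumes the rewards are combined as a weighted sum using the `[0.60, 0.40]` weights from the smoke test in this notebook, and that advantages are the combined rewards standardized within one group of completions for the same prompt. The helper names `combine_rewards` and `group_relative_advantages` and the example numbers are hypothetical; the trainer's actual implementation may differ.

```python
import numpy as np

# Weights taken from the smoke test in this notebook: [format, gameplay].
REWARD_WEIGHTS = (0.60, 0.40)


def combine_rewards(format_rewards, gameplay_rewards, weights=REWARD_WEIGHTS):
    """Weighted sum of the two per-completion reward lists (assumed scheme)."""
    fmt = np.asarray(format_rewards, dtype=float)
    play = np.asarray(gameplay_rewards, dtype=float)
    return weights[0] * fmt + weights[1] * play


def group_relative_advantages(rewards, eps=1e-8):
    """Standardize rewards within one group of completions for one prompt."""
    rewards = np.asarray(rewards, dtype=float)
    return (rewards - rewards.mean()) / (rewards.std() + eps)


# Hypothetical group of 4 completions sampled for the same board.
format_rewards = [12.0, 12.0, -50.8, -25.0]
gameplay_rewards = [15.0, -26.0, 10.0, -20.0]

total = combine_rewards(format_rewards, gameplay_rewards)
advantages = group_relative_advantages(total)
print("combined rewards:", total)
print("advantages:      ", advantages.round(2))
```

Because advantages are relative to the group mean, a completion is only pushed up if it beats its siblings, which is why a large gap between pure and verbose JSON rewards matters more than the absolute values.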

In [None]:
import numpy as np
import math

# ──────────────────────────────────────────────────────────────────────
# Helper: Reconstruct game from dataset columns
# ──────────────────────────────────────────────────────────────────────

def _reconstruct_game(idx, kwargs):
    """Reconstruct a MinesweeperGame from dataset columns passed via GRPO kwargs.

    The dataset stores: seed, move_history, board_rows, board_cols, board_mines
    GRPOTrainer passes all non-'prompt' columns as lists in **kwargs.

    Returns (game, move_history_list) or (None, None) if data is missing.
    Handles 0-mine boards correctly.
    """
    seeds = kwargs.get("seed", [])
    move_histories = kwargs.get("move_history", [])
    rows_list = kwargs.get("board_rows", [])
    cols_list = kwargs.get("board_cols", [])
    mines_list = kwargs.get("board_mines", [])

    if idx >= len(seeds) or idx >= len(move_histories):
        return None, None

    seed = seeds[idx]
    rows = rows_list[idx] if idx < len(rows_list) else 6
    cols = cols_list[idx] if idx < len(cols_list) else 6
    num_mines = mines_list[idx] if idx < len(mines_list) else 5

    mh_raw = move_histories[idx]
    if isinstance(mh_raw, str):
        move_history = json.loads(mh_raw)
    else:
        move_history = list(mh_raw)

    game = MinesweeperGame(rows=int(rows), cols=int(cols),
                           num_mines=int(num_mines), seed=int(seed))
    for prev in move_history:
        result = game.do_action(prev)
        if result == "mine":
            return None, None  # History hit a mine — bad data

    return game, move_history


# ──────────────────────────────────────────────────────────────────────
# Length penalty helper (GRPO-LEAD: α=0.15)
# ──────────────────────────────────────────────────────────────────────
LENGTH_PENALTY_ALPHA = 0.15


def _length_penalty(response: str) -> float:
    """Compute length-based reward modifier (GRPO-LEAD paper).

    Returns a value in [-2.0, +2.0]:
      Pure JSON (<=40 chars)       → +2.0
      Near-pure (41-60 chars)      → +1.5
      Some extra text (61-100)     → +0.49 decaying to +0.27
      Moderate (101-200 chars)     → -0.5 down to -1.0
      Verbose (>200 chars)         → -1.5 down to -2.0 (floor at 300+ chars)
    """
    n = len(response)
    if n <= 40:
        return 2.0
    elif n <= 60:
        return 1.5
    elif n <= 100:
        return 0.5 * math.exp(-LENGTH_PENALTY_ALPHA * (n - 60) / 10)
    elif n <= 200:
        return -0.5 - 0.5 * ((n - 100) / 100)
    else:
        return -1.5 - min(0.5, (n - 200) / 200)


# ──────────────────────────────────────────────────────────────────────
# Difficulty reweighting helper (XRPO paper)
# ──────────────────────────────────────────────────────────────────────

def _difficulty_multiplier(game) -> float:
    """Compute difficulty-based reward multiplier (XRPO paper).

    Returns a value in (0.7, 1.5), or exactly 1.0 for empty or 0-mine boards:
      Easy boards (low density, small) → ~0.7-0.9
      Medium boards                    → ~1.1
      Hard boards (high density, large) → ~1.2-1.5
    """
    board_size = game.rows * game.cols

    if board_size == 0 or game.num_mines == 0:
        return 1.0

    density = mine_density(game.rows, game.cols, game.num_mines)
    size_factor = min(1.0, math.log(max(1, board_size)) / math.log(2500))
    difficulty = 0.6 * (density / 0.20) + 0.4 * size_factor
    multiplier = 0.7 + 0.8 / (1.0 + math.exp(-6.0 * (difficulty - 0.5)))

    return multiplier


# ──────────────────────────────────────────────────────────────────────
# Reward 1: Valid JSON format + conciseness + length penalty (GRPO-LEAD)
# ──────────────────────────────────────────────────────────────────────

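def valid_json_reward(prompts, completions, **kwargs):
    """GRPO-LEAD: Ultra-strict JSON-only reward with EXPONENTIAL penalties.

    Pure JSON scores +10 (<=50 chars), +8 (<=80), +5 (<=120), else +2.
    No parseable action scores -25.
    JSON wrapped in extra text is penalized exponentially:
      penalty = -1.0 × 1.5^((extra_chars - 3) / 5), floored at -50.0
      <=3 extra chars → +0.5 (tolerated)
      5 extra chars   → -1.2
      10 extra chars  → -1.8
      20 extra chars  → -4.0
      30 extra chars  → -8.9
      50 extra chars  → -45.2
      52+ extra chars → -50.0 (cap)

    Two-pass approach:
      Pass 1: Score format correctness & collect lengths of parseable responses
      Pass 2: Apply z-score normalized length multiplier to parseable responses
              (falls back to _length_penalty when fewer than 2 are parseable)
    """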
    results = []
    correct_lengths = []

    for completion in completions:
        response = completion[0]["content"].strip() if completion else ""
        action = parse_llm_action(response)

        if action is None:
            results.append((None, response, -25.0, False))
            continue

        # Check if it's PURE JSON (no extra text)
        try:
            parsed = json.loads(response)
            if "type" in parsed and "row" in parsed and "col" in parsed:
                if len(response) <= 50:
                    correct_lengths.append(len(response))
                    results.append((action, response, 10.0, True))
                elif len(response) <= 80:
                    correct_lengths.append(len(response))
                    results.append((action, response, 8.0, True))
                elif len(response) <= 120:
                    correct_lengths.append(len(response))
                    results.append((action, response, 5.0, True))
                else:
                    correct_lengths.append(len(response))
                    results.append((action, response, 2.0, True))
                continue
        except json.JSONDecodeError:
            pass

        # JSON found but with extra text — EXPONENTIAL PENALTY
        json_match = re.search(r'\{[^{}]*\}', response)
        extra_chars = len(response) - len(json_match.group()) if json_match else len(response)

        if extra_chars <= 3:
            base = 0.5
        else:
            base = -1.0 * (1.5 ** ((extra_chars - 3) / 5))
            base = max(-50.0, base)

        correct_lengths.append(len(response))
        results.append((action, response, base, True))

    # ── Pass 2: Apply group-normalized length penalty (GRPO-LEAD) ──
    scores = []

    if len(correct_lengths) > 1:
        mean_len = np.mean(correct_lengths)
        std_len = np.std(correct_lengths) + 1e-8

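        for action, response, base_score, is_correct in results:
            if not is_correct:
                scores.append(base_score)
            else:
                z_score = (len(response) - mean_len) / std_len
                length_multiplier = math.exp(-LENGTH_PENALTY_ALPHA * z_score)
                length_multiplier = max(0.5, min(1.5, length_multiplier))
                # Multiplying a negative base by a multiplier < 1 would reward
                # longer responses, so divide instead when the base is negative.
                if base_score >= 0:
                    scores.append(base_score * length_multiplier)
                else:
                    scores.append(base_score / length_multiplier)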
    else:
        for action, response, base_score, is_correct in results:
            if not is_correct:
                scores.append(base_score)
            else:
                lp = _length_penalty(response)
                scores.append(base_score + lp)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 2: Gameplay — 12-criterion scoring + center-opening (XRPO)
# ──────────────────────────────────────────────────────────────────────

def gameplay_scores(prompts, completions, **kwargs):
    """
    Gameplay reward implementing the 12 scoring criteria plus two
    opening-move adjustments (13-14, folded in from strategic_reward).
    Every score except invalid JSON is scaled by the difficulty multiplier.

    1.  Flag cell that IS a mine        → +15 (guess) / +20 (logical)
    2.  Flag cell that is NOT a mine    → -11
    3.  Reveal cell that IS a mine      → -26
    4.  Reveal cell that is safe        → +10 (guess) / +15 (logical), +1 next to a number
    5.  Flag already flagged cell       → -12
    6.  Reveal already revealed cell    → -12
    7.  Out of bounds                   → -15
    8.  Total flags > total mines       → -10 (additional)
    9.  Invalid JSON                    → -20
    10. Win the game                    → +100 × size_scale
    11. Reveal a flagged cell           → -8
    12. Flag a revealed cell            → -8
    13. Center-opening bonus (LAMER)    → +5 center, +3 near-center, +1 mid, -2 far
                                          (density <= 10%: +2 near center only)
    14. Penalize flagging on fresh board → -2
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        # Criterion 9: Invalid JSON
        if action is None:
            scores.append(-20.0)
            continue

        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        diff_mult = _difficulty_multiplier(game)

        if game.state() != "ongoing":
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]

        # Criterion 7: Out of bounds
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-15.0 * diff_mult)
            continue

        # Edge case: 0-mine board
        if game.num_mines == 0:
            if action_type == "reveal":
                if (row, col) in game._revealed:
                    scores.append(-12.0)
                else:
                    game_copy = MinesweeperGame(
                        rows=game.rows, cols=game.cols,
                        num_mines=0, seed=kwargs.get("seed", [0])[idx]
                    )
                    for prev in move_history:
                        game_copy.do_action(prev)
                    result = game_copy.do_action(action)
                    board_size = game.rows * game.cols
                    size_scale = 1.0 + min(1.0, board_size / 1000)
                    win_bonus = (100.0 * size_scale) if result == "win" else 0.0
                    scores.append(15.0 + win_bonus)
            else:
                scores.append(-10.0)
            continue

        # Compute logical deductions ONCE
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        score = 0.0

        # ── Center-opening bonus (from LAMER paper, was in strategic_reward) ──
        if len(game._revealed) == 0 and action_type == "reveal":
            density_pct = mine_density(game.rows, game.cols, game.num_mines) * 100
            center_r, center_c = game.rows // 2, game.cols // 2
            dist_to_center = abs(row - center_r) + abs(col - center_c)
            max_dist = center_r + center_c

            if density_pct > 10:
                if dist_to_center == 0:
                    score += 5.0
                elif dist_to_center <= max(1, max_dist // 4):
                    score += 3.0
                elif dist_to_center <= max_dist // 2:
                    score += 1.0
                else:
                    score -= 2.0
            else:
                if dist_to_center <= max(1, max_dist // 3):
                    score += 2.0

        # Penalize flagging on a fresh board
        if len(game._revealed) == 0 and action_type == "flag":
            score -= 2.0

        if action_type == "reveal":
            # Criterion 6: Reveal already revealed cell
            if (row, col) in game._revealed:
                scores.append(-12.0 * diff_mult)
                continue

            # Criterion 11: Reveal a flagged cell
            if (row, col) in game._flagged:
                scores.append(-8.0 * diff_mult)
                continue

            # Criterion 3: Reveal a mine
            if game._board[row][col] == -1:
                scores.append((-25.0 - 1.0) * diff_mult)
                continue

            # Criterion 4: Reveal safe cell
            if (row, col) in safe_set:
                score += 15.0   # Logically deduced safe cell
            else:
                score += 10.0   # Guessed safe cell

            # Small bonus for revealing near numbers (information-rich)
            board = game.get_visible_board()
            has_adjacent_number = False
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    nr, nc = row + dr, col + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if board[nr][nc] in ('1','2','3','4','5','6','7','8'):
                            has_adjacent_number = True
                            break
                if has_adjacent_number:
                    break
            if has_adjacent_number:
                score += 1.0

            # Criterion 10: Check for win
            game_copy = MinesweeperGame(
                rows=game.rows, cols=game.cols,
                num_mines=game.num_mines, seed=kwargs.get("seed", [0])[idx]
            )
            for prev in move_history:
                game_copy.do_action(prev)
            result = game_copy.do_action(action)
            if result == "win":
                board_size = game.rows * game.cols
                size_scale = 1.0 + min(1.0, board_size / 1000)
                score += 100.0 * size_scale

        elif action_type == "flag":
            # Criterion 5: Flag already flagged cell
            if (row, col) in game._flagged:
                scores.append(-12.0 * diff_mult)
                continue

            # Criterion 12: Flag a revealed cell
            if (row, col) in game._revealed:
                scores.append(-8.0 * diff_mult)
                continue

            # Criterion 1: Flag a mine (correct)
            if game._board[row][col] == -1:
                if (row, col) in mine_set:
                    score += 20.0   # Logically deduced mine
                else:
                    score += 15.0   # Correct but guessed

            # Criterion 2: Flag a non-mine (wrong)
            else:
                score -= 10.0 + 1.0

            # Criterion 8: Total flags > total mines
            new_flag_count = len(game._flagged) + 1
            if new_flag_count > game.num_mines:
                score -= 10.0

        # Apply difficulty multiplier (XRPO)
        scores.append(score * diff_mult)

    return scores


# ── Verify reward function signatures ──
print("✅ Reward functions defined (2 total):")
print("   1. valid_json_reward — format + exponential penalty (GRPO-LEAD)")
print("   2. gameplay_scores   — 12 criteria + center-opening + difficulty (XRPO)")
print("   Removed: strategic_reward (was 5% weight = noise, folded center-opening into gameplay_scores)")

# ── Smoke test ──
test_completions_pure = [[{"role": "assistant", "content": '{"type":"reveal","row":0,"col":0}'}]]
test_completions_verbose = [[{"role": "assistant", "content": 'Let me analyze this board step by step. The cell at row 0, col 0 looks promising because it has several revealed neighbors. {"type":"reveal","row":0,"col":0}'}]]
test_prompts = ["test"]

test_kwargs = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [6],
    "board_cols": [6],
    "board_mines": [5],
}

r1p = valid_json_reward(test_prompts, test_completions_pure, **test_kwargs)
r1v = valid_json_reward(test_prompts, test_completions_verbose, **test_kwargs)
r2p = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs)

print(f"\n  Pure JSON format reward:    {r1p[0]:+.2f}")
print(f"  Verbose JSON format reward: {r1v[0]:+.2f}")
print(f"  Gameplay reward (pure):     {r2p[0]:+.2f}")
print(f"\n  Reward gap analysis (weights [0.60, 0.40]):")
print(f"    Pure JSON weighted:    {r1p[0]*0.60:+.2f} + {r2p[0]*0.40:+.2f} = {r1p[0]*0.60 + r2p[0]*0.40:+.2f}")
print(f"    Gap: {(r1p[0] - r1v[0]) * 0.60:+.2f} — pure JSON DOMINATES")

# 0-mine board
test_kwargs_zero = {"seed": [42], "move_history": ["[]"], "board_rows": [5], "board_cols": [5], "board_mines": [0]}
r2z = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs_zero)
print(f"\n  0-mine board: gameplay={r2z[0]:.1f}")

# 1x1 board
test_kwargs_1x1 = {"seed": [42], "move_history": ["[]"], "board_rows": [1], "board_cols": [1], "board_mines": [0]}
r2_1x1 = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs_1x1)
print(f"  1x1 m=0:     gameplay={r2_1x1[0]:.1f}")

# Hard board
test_kwargs_hard = {"seed": [42], "move_history": ["[]"], "board_rows": [20], "board_cols": [20], "board_mines": [80]}
r2h = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs_hard)
print(f"  Hard 20x20:   gameplay={r2h[0]:.1f} (XRPO amplified)")

print(f"\n  ✅ Exponential JSON penalties (base scores: pure +10.0, verbose capped at -50.0)")
print(f"  ✅ Edge cases: 0-mine, 1x1, hard boards handled")
print(f"  ✅ Center-opening bonus folded into gameplay_scores")
print(f"  ✅ Difficulty reweighting (XRPO): harder boards → amplified signal")

# Exhaustive Training Dataset Generation

6-phase composition targeting **4000 samples** with density-stratified sampling:

| Phase | Budget | Description |
|-------|--------|-------------|
| 1. Edge Cases | 10% | 50+ explicit configs (trivial, linear, rectangular, density extremes) |
| 2. Opening | 25% | 75% fresh boards + 25% after a single move; addresses the 75% early-death rate |
| 3. Pattern-Specific | 15% | Satisfied numbers (60%) + multi-region boards (40%) |
| 4. Mid-Game | 25% | 3-15 moves, progressive flagging 10%→30%→50% |
| 5. Endgame | 15% | 80-98% revealed, flag accounting, completion strategy |
| 6. Forced Guess | 10% | No logical deductions available — probability reasoning |
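
As a rough illustration of the budget column, the sketch below converts the percentages into per-phase sample counts for the 4000-sample target. The dictionary keys and the `phase_sample_counts` helper are hypothetical names, not the notebook's generator. Integer arithmetic is used so the counts always sum to the target.

```python
TOTAL_SAMPLES = 4000

# Budget percentages copied from the phase table; key names are illustrative.
PHASE_BUDGET_PCT = {
    "edge_cases": 10,
    "opening": 25,
    "pattern_specific": 15,
    "mid_game": 25,
    "endgame": 15,
    "forced_guess": 10,
}


def phase_sample_counts(total, budget_pct):
    """Split `total` across phases; leftovers go to the largest remainders."""
    if sum(budget_pct.values()) != 100:
        raise ValueError("phase budgets must sum to 100%")
    counts = {name: total * pct // 100 for name, pct in budget_pct.items()}
    leftover = total - sum(counts.values())
    by_remainder = sorted(
        budget_pct, key=lambda name: (total * budget_pct[name]) % 100, reverse=True
    )
    for name in by_remainder[:leftover]:
        counts[name] += 1
    return counts


counts = phase_sample_counts(TOTAL_SAMPLES, PHASE_BUDGET_PCT)
for name, n in counts.items():
    print(f"{name:17s} {n:5d}")
```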

## Board Size Distribution (Bell-Curve Curriculum)
| Band | Size Range | Weight | Purpose |
|------|-----------|--------|---------|
| Tiny | 1–5 | 10% | Edge cases, basics |
| Small | 5–10 | 20% | Simple pattern learning |
| **Mid** | **8–20** | **35%** | **Core curriculum ★ peak** |
| Large | 15–30 | 20% | Generalization |
| XL | 25–40 | 10% | Harder boards |
| XXL | 35–50 | 5% | Edge exposure only |
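
One way to read the weight column is as sampling probabilities. The sketch below draws a band by weight and then picks a side length inside it. It assumes the size range applies to rows and columns independently and that sides are uniform within a band. `SIZE_BANDS` and `sample_board_size` are hypothetical names, and the notebook's own sampler may differ.

```python
import random

# (label, min_side, max_side, weight) copied from the table above.
SIZE_BANDS = [
    ("Tiny", 1, 5, 0.10),
    ("Small", 5, 10, 0.20),
    ("Mid", 8, 20, 0.35),
    ("Large", 15, 30, 0.20),
    ("XL", 25, 40, 0.10),
    ("XXL", 35, 50, 0.05),
]


def sample_board_size(rng):
    """Pick a band by weight, then draw rows and cols uniformly inside it."""
    weights = [band[3] for band in SIZE_BANDS]
    label, lo, hi, _ = rng.choices(SIZE_BANDS, weights=weights, k=1)[0]
    return label, rng.randint(lo, hi), rng.randint(lo, hi)


rng = random.Random(0)
draws = [sample_board_size(rng) for _ in range(10_000)]
mid_share = sum(1 for label, _, _ in draws if label == "Mid") / len(draws)
print(f"Share of Mid-band boards: {mid_share:.2%}")
```

Because adjacent bands overlap (for example, side 8 to 10 appears in both Small and Mid), the effective distribution over side lengths is smoother than the band weights alone suggest.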

In [None]:
import os
from datasets import Dataset

# ══════════════════════════════════════════════════════════════════════
#  EXHAUSTIVE DATASET GENERATION
#  Fixes: 75% early deaths, no pattern training, 0.7% endgame,
#         no forced-guess scenarios, insufficient flag training
# ══════════════════════════════════════════════════════════════════════


# ──────────────────────────────────────────────────────────────────────
# Helper: Smart move selection with progressive flagging
# ──────────────────────────────────────────────────────────────────────

def _smart_reveal(game, rng):
    """Pick a smart cell to reveal during history generation.
    Prefers safe cells (if known) to avoid hitting mines and losing the game.
    Falls back to random unrevealed cell.
    """
    safe_set, _ = _compute_safe_and_mine_cells(game)
    if safe_set:
        return rng.choice(list(safe_set))

    # Fallback: random unrevealed, unflagged cell
    unrevealed = [(r, c) for r in range(game.rows) for c in range(game.cols)
                  if (r, c) not in game._revealed and (r, c) not in game._flagged]
    if not unrevealed:
        return None
    return rng.choice(unrevealed)


def _smart_flag(game, rng):
    """Pick a logically certain mine to flag, or None if none available."""
    _, mine_set = _compute_safe_and_mine_cells(game)
    mine_candidates = [c for c in mine_set if c not in game._flagged]
    if mine_candidates:
        return rng.choice(mine_candidates)
    return None  # Don't flag randomly — creates bad training data


def _progressive_flag_probability(game):
    """Adaptive flagging based on game progress (LAMER-inspired).
    - Early game (0-30%):  10% flagging (explore first)
    - Mid game (30-70%):   30% flagging (deduce mines)
    - Late game (70-100%): 50% flagging (lock in certain mines)
    """
    progress = game.progress()
    if progress < 0.30:
        return 0.10
    elif progress < 0.70:
        return 0.30
    else:
        return 0.50


def _play_smart_moves(game, rng, num_moves, use_progressive_flags=True):
    """Play up to num_moves smart moves on a game. Returns the move_history list.
    Uses the progressive flagging strategy when enabled.
    Returns early if the game ends or gets stuck.

    Fix #1: stuck_count aborts after MAX_STUCK consecutive iterations in which
    no valid move can be found (e.g., all cells revealed/flagged but game ongoing).
    A move that hits a mine is not recorded in the history even though it was
    applied to the game, so callers should check game.state() afterwards.
    """
    move_history = []
    stuck_count = 0
    MAX_STUCK = 10  # Abort after 10 consecutive failed attempts

    for _ in range(num_moves):
        if game.state() != "ongoing":
            break

        action_dict = None
        flag_prob = _progressive_flag_probability(game) if use_progressive_flags else 0.15

        # Try flagging with adaptive probability
        if rng.random() < flag_prob:
            flag_target = _smart_flag(game, rng)
            if flag_target:
                action_dict = {"type": "flag", "row": flag_target[0], "col": flag_target[1]}

        # Fall back to reveal
        if action_dict is None:
            reveal_target = _smart_reveal(game, rng)
            if reveal_target is None:
                stuck_count += 1
                if stuck_count >= MAX_STUCK:
                    break  # Prevent infinite loop
                continue
            action_dict = {"type": "reveal", "row": reveal_target[0], "col": reveal_target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            break  # Hit a mine — stop

        move_history.append(action_dict)
        stuck_count = 0  # Reset on successful move

    return move_history


# ──────────────────────────────────────────────────────────────────────
# Scenario generators: endgame, forced-guess, multi-region
# ──────────────────────────────────────────────────────────────────────

def _generate_endgame_state(rows, cols, num_mines, rng, completion_target=0.85):
    """Generate boards with completion_target to 98% of safe cells revealed
    (85-98% with the default completion_target).
    Critical for teaching finishing strategy and flag accounting.
    Returns (game, move_history, seed) or (None, None, seed) on failure.
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    safe_total = rows * cols - num_mines
    if safe_total <= 1:
        return None, None, seed

    target_revealed = int(safe_total * rng.uniform(completion_target, 0.98))
    target_revealed = max(1, min(target_revealed, safe_total - 1))

    move_history = []
    stuck_count = 0
    max_stuck = 20

    while len(game._revealed) < target_revealed and game.state() == "ongoing":
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # No logical moves — reveal a random safe cell (we know the board)
            all_safe_unrevealed = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe_unrevealed:
                break
            target = rng.choice(all_safe_unrevealed)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
            stuck_count += 1
            if stuck_count >= max_stuck:
                break

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Optionally add some flags in endgame
    if game.state() == "ongoing":
        _, mine_set = _compute_safe_and_mine_cells(game)
        flaggable = [c for c in mine_set if c not in game._flagged]
        if flaggable and rng.random() < 0.7:
            num_flags = rng.randint(1, min(len(flaggable), 5))
            for target in rng.sample(flaggable, num_flags):
                action_dict = {"type": "flag", "row": target[0], "col": target[1]}
                game.do_action(action_dict)
                move_history.append(action_dict)

    if game.state() != "ongoing":
        return None, None, seed

    return game, move_history, seed


def _generate_forced_guess_state(rows, cols, num_mines, rng):
    """Generate a board state where no 100% certain logical moves exist.
    These teach the model probability-based guessing.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []
    max_reveal_attempts = rows * cols

    for _ in range(max_reveal_attempts):
        if game.state() != "ongoing":
            return None, None, seed

        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if not safe_set and not mine_set and len(game._revealed) > 0:
            # No logical moves available — this is what we want!
            unrevealed_count = sum(
                1 for r in range(rows) for c in range(cols)
                if (r, c) not in game._revealed and (r, c) not in game._flagged
            )
            if unrevealed_count >= 2:
                return game, move_history, seed

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # Must guess — reveal random safe cell (using board knowledge)
            all_safe = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe:
                break
            target = rng.choice(all_safe)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Fallback: return whatever state we reached (might still have logical moves)
    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_multi_region_state(rows, cols, num_mines, rng):
    """Create board with multiple separated revealed regions.
    Teaches model to reason across disconnected information.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    if rows < 6 or cols < 6:
        return None, None, 0  # Need space for multiple regions

    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    # Pick 2-4 starting points in different quadrants
    quadrant_centers = [
        (rows // 4, cols // 4),
        (rows // 4, 3 * cols // 4),
        (3 * rows // 4, cols // 4),
        (3 * rows // 4, 3 * cols // 4),
    ]
    rng.shuffle(quadrant_centers)
    num_regions = rng.randint(2, min(4, len(quadrant_centers)))
    selected = quadrant_centers[:num_regions]

    move_history = []
    for r, c in selected:
        if game.state() != "ongoing":
            break
        r = max(0, min(r, rows - 1))
        c = max(0, min(c, cols - 1))

        if (r, c) not in game._revealed and game._board[r][c] != -1:
            action_dict = {"type": "reveal", "row": r, "col": c}
            result = game.do_action(action_dict)
            if result == "mine":
                return None, None, seed
            move_history.append(action_dict)

        # Reveal a few logical neighbors around this region
        for _ in range(rng.randint(1, 3)):
            if game.state() != "ongoing":
                break
            safe_set, _ = _compute_safe_and_mine_cells(game)
            if safe_set:
                # Prefer safe cells near this region
                nearby = [
                    (sr, sc) for sr, sc in safe_set
                    if abs(sr - r) <= rows // 3 and abs(sc - c) <= cols // 3
                ]
                target = rng.choice(nearby) if nearby else rng.choice(list(safe_set))
                action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
                result = game.do_action(action_dict)
                if result == "mine":
                    return None, None, seed
                move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_satisfied_numbers_state(rows, cols, num_mines, rng):
    """Generate board with multiple satisfied numbers (easy deductions available).
    Satisfied number = number whose count of adjacent flags equals its value.
    All remaining unrevealed neighbors of a satisfied number are safe.
    Teaches basic constraint satisfaction.
    Returns (game, move_history, seed) or (None, None, seed) on failure.
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []

    # Reveal some cells first
    num_initial = rng.randint(3, max(3, min(15, rows * cols // 4)))
    for _ in range(num_initial):
        if game.state() != "ongoing":
            break
        target = _smart_reveal(game, rng)
        if target is None:
            break
        action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Now flag logically certain mines to create satisfied numbers
    if game.state() == "ongoing":
        for _ in range(5):
            _, mine_set = _compute_safe_and_mine_cells(game)
            flaggable = [c for c in mine_set if c not in game._flagged]
            if not flaggable:
                break
            target = rng.choice(flaggable)
            action_dict = {"type": "flag", "row": target[0], "col": target[1]}
            game.do_action(action_dict)
            move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


# ──────────────────────────────────────────────────────────────────────
# Density-stratified board sampling
# ──────────────────────────────────────────────────────────────────────

DENSITY_TARGETS = [
    # (density_range, weight, label)
    ((0.00, 0.00), 0.08, "Zero mines"),       # 8%  - trivial edge case
    ((0.01, 0.05), 0.17, "Very sparse"),       # 17%
    ((0.05, 0.10), 0.25, "Sparse"),            # 25%
    ((0.10, 0.15), 0.25, "Medium"),            # 25%
    ((0.15, 0.20), 0.25, "Dense/Max"),         # 25%
]


# ──────────────────────────────────────────────────────────────────────
# Board Size Distribution — Bell-curve centered on mid-range (8-20)
#
# Curriculum learning insight: mid-range boards (8-20) are the sweet
# spot for learning patterns. Small boards teach basics, but mid-range
# boards have enough complexity to learn deduction chains, constraint
# satisfaction, and multi-step reasoning without the noise/sparsity
# of huge boards. Large boards (30+) exist for generalization but
# shouldn't dominate training.
#
# Distribution (bell-curve, peak at mid-range):
#   Tiny   (1-5):    10%  — edge cases, basics
#   Small  (5-10):   20%  — simple pattern learning
#   Mid    (8-20):   35%  — PEAK: core pattern training ★
#   Large  (15-30):  20%  — generalization, scaling
#   XL     (25-40):  10%  — harder generalization
#   XXL    (35-50):   5%  — edge exposure only
# ──────────────────────────────────────────────────────────────────────

SIZE_BANDS = [
    # (min_dim, max_dim, weight, label)
    (1,   5,  0.10, "Tiny"),       # 10% — basics + edge cases
    (5,  10,  0.20, "Small"),      # 20% — simple patterns
    (8,  20,  0.35, "Mid"),        # 35% — PEAK: core curriculum ★
    (15, 30,  0.20, "Large"),      # 20% — generalization
    (25, 40,  0.10, "XL"),         # 10% — harder boards
    (35, 50,  0.05, "XXL"),        #  5% — edge exposure only
]


def _sample_board_size(rng):
    """Sample board dimensions using bell-curve distribution.
    Peak at mid-range (8-20) for optimal curriculum learning.
    """
    weights = [w for _, _, w, _ in SIZE_BANDS]
    idx = rng.choices(range(len(SIZE_BANDS)), weights=weights, k=1)[0]
    lo, hi, _, _ = SIZE_BANDS[idx]
    rows = rng.randint(lo, hi)
    cols = rng.randint(lo, hi)
    return rows, cols


def _sample_board_with_density(rng, target_density_range=None):
    """Sample board config with explicit density control + bell-curve sizes.
    If target_density_range is None, picks one from DENSITY_TARGETS.
    """
    if target_density_range is None:
        # Weighted random density band
        weights = [w for _, w, _ in DENSITY_TARGETS]
        ranges = [r for r, _, _ in DENSITY_TARGETS]
        idx = rng.choices(range(len(DENSITY_TARGETS)), weights=weights, k=1)[0]
        target_density_range = ranges[idx]

    # Sample board size using bell-curve distribution
    rows, cols = _sample_board_size(rng)

    total = rows * cols
    min_d, max_d = target_density_range

    if min_d == 0.0 and max_d == 0.0:
        return rows, cols, 0

    min_mines = max(0, int(math.ceil(total * min_d)))
    max_mines = min(int(total * max_d), int(total * MAX_MINE_DENSITY), total - 1)

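    if max_mines < min_mines:
        # On tiny boards the band can round to an empty range. Fall back to
        # min_mines, which may exceed the band's upper density on such boards.
        num_mines = max(0, min(min_mines, total - 1))
    else:
        num_mines = rng.randint(min_mines, max_mines)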

    return rows, cols, num_mines


# ──────────────────────────────────────────────────────────────────────
# Exhaustive edge-case configs (50+ scenarios)
# ──────────────────────────────────────────────────────────────────────

EDGE_CASE_CONFIGS = [
    # (rows, cols, mines, label)

    # === TRIVIAL BOARDS ===
    (1, 1, 0,   "1x1 trivial"),
    (2, 2, 0,   "2x2 no mines"),
    (3, 3, 0,   "3x3 no mines"),
    (4, 4, 0,   "4x4 no mines"),
    (5, 5, 0,   "5x5 no mines - cascade practice"),

    # === LINEAR BOARDS (1D Minesweeper) ===
    (1, 5, 1,   "1x5 single mine"),
    (1, 10, 2,  "1x10 two mines"),
    (1, 20, 4,  "1x20 row"),
    (1, 50, 10, "1x50 maximum row"),
    (5, 1, 1,   "5x1 column"),
    (10, 1, 2,  "10x1 column"),
    (20, 1, 4,  "20x1 column"),
    (50, 1, 10, "50x1 maximum column"),

    # === TINY BOARDS ===
    (1, 2, 0,   "1x2 trivial"),
    (2, 1, 0,   "2x1 trivial"),
    (2, 2, 1,   "2x2 single mine"),
    (2, 3, 1,   "2x3 mini"),
    (3, 2, 1,   "3x2 mini"),
    (3, 3, 1,   "3x3 single mine"),
    (3, 3, 2,   "3x3 two mines"),
    (4, 4, 3,   "4x4 medium density"),

    # === EXTREME RECTANGULAR ===
    (2, 50, 20, "2x50 ultra-wide"),
    (50, 2, 20, "50x2 ultra-tall"),
    (3, 40, 24, "3x40 extreme ratio"),
    (40, 3, 24, "40x3 extreme ratio"),
    (5, 50, 50, "5x50 max density wide"),
    (50, 5, 50, "50x5 max density tall"),

    # === CLASSIC MINESWEEPER SIZES ===
    (8, 8, 1,   "8x8 very sparse"),
    (8, 8, 5,   "8x8 sparse"),
    (8, 8, 10,  "8x8 medium - classic beginner"),
    (8, 8, 13,  "8x8 max density"),
    (16, 16, 10, "16x16 sparse"),
    (16, 16, 40, "16x16 medium - classic intermediate"),
    (16, 16, 51, "16x16 max density"),
    (16, 30, 20, "16x30 rectangular sparse"),
    (16, 30, 96, "16x30 rectangular dense - classic expert size (96 mines = 20%; classic uses 99)"),

    # === LARGE BOARDS ===
    (20, 20, 10, "20x20 very sparse"),
    (20, 20, 80, "20x20 max density"),
    (25, 25, 30, "25x25 sparse"),
    (25, 25, 125, "25x25 max density"),
    (30, 30, 50, "30x30 sparse"),
    (30, 30, 180, "30x30 max density"),
    (40, 40, 100, "40x40 sparse"),
    (40, 40, 320, "40x40 max density"),

    # === MAXIMUM SIZE ===
    (50, 50, 10,  "50x50 ultra-sparse"),
    (50, 50, 100, "50x50 sparse"),
    (50, 50, 250, "50x50 medium"),
    (50, 50, 500, "50x50 maximum density"),

    # === DENSITY EDGE CASES ===
    (10, 10, 0,  "10x10 no mines"),
    (15, 15, 0,  "15x15 no mines - large trivial"),
    (20, 20, 0,  "20x20 no mines - huge trivial"),
    (10, 10, 1,  "10x10 single mine"),
    (20, 20, 1,  "20x20 single mine in large board"),

    # === MORE RECTANGULAR VARIETY ===
    (3, 50, 30,  "3x50 wide"),
    (50, 3, 30,  "50x3 tall"),
    (5, 15, 15,  "5x15 wide rect"),
    (15, 5, 15,  "15x5 tall rect"),
    (6, 8, 6,    "6x8 rect"),
    (8, 6, 6,    "8x6 rect"),
    (7, 13, 18,  "7x13 odd rect"),
    (13, 7, 18,  "13x7 odd rect"),
    (15, 30, 90, "15x30 wide large"),
]


# ──────────────────────────────────────────────────────────────────────
# Dataset item builder (DRY helper)
# ──────────────────────────────────────────────────────────────────────

def _build_dataset_item(game, seed, move_history):
    """Build a single dataset item dict from a game state."""
    prompt_text = format_state_for_llm(game)
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt_text},
        ],
        "seed": seed,
        "move_history": json.dumps(move_history),
        "board_rows": game.rows,
        "board_cols": game.cols,
        "board_mines": game.num_mines,
    }


# ══════════════════════════════════════════════════════════════════════
#  MAIN GENERATOR — 6-Phase Exhaustive Composition
# ══════════════════════════════════════════════════════════════════════

def generate_exhaustive_dataset(num_samples=1000, rng_seed=42):
    """
    Comprehensive Minesweeper training dataset covering ALL scenarios.

    6-Phase composition:
      Phase 1: Edge cases           — 10%  (50+ explicit configs)
      Phase 2: Opening-heavy        — 25%  (fresh + single-move boards)
      Phase 3: Pattern-specific     — 15%  (satisfied numbers, multi-region)
      Phase 4: Mid-game deduction   — 25%  (core logical reasoning)
      Phase 5: Endgame completion   — 15%  (80-98% revealed, flag accounting)
      Phase 6: Forced guess         — 10%  (no logical moves available)

    Improvements over previous version:
      ✅ 50+ edge case configs (was 25)
      ✅ Progressive flagging 10%→30%→50% by game phase (was flat 15%)
      ✅ Density-stratified board sampling
      ✅ Dedicated endgame generator (was 0.7% late-game)
      ✅ Forced-guess scenario training (was 0%)
      ✅ Multi-region disconnected boards (was 0%)
      ✅ Satisfied-number pattern training (was 0%)
      ✅ 4000 samples (was 3000)
    """
    rng = random.Random(rng_seed)
    np.random.seed(rng_seed)

    dataset_items = []
    phase_counts = {
        "edge_case": 0, "opening": 0, "pattern": 0,
        "midgame": 0, "endgame": 0, "forced_guess": 0,
    }
    config_counts = {}
    density_counts = {"zero": 0, "very_sparse": 0, "sparse": 0, "medium": 0, "dense": 0}

    def _track_config(game):
        key = f"{game.rows}x{game.cols}m{game.num_mines}"
        config_counts[key] = config_counts.get(key, 0) + 1
        d = game.num_mines / (game.rows * game.cols) if game.rows * game.cols > 0 else 0
        if d == 0:
            density_counts["zero"] += 1
        elif d <= 0.05:
            density_counts["very_sparse"] += 1
        elif d <= 0.10:
            density_counts["sparse"] += 1
        elif d <= 0.15:
            density_counts["medium"] += 1
        else:
            density_counts["dense"] += 1

    # Budget per phase
    n_edge     = int(num_samples * 0.10)
    n_opening  = int(num_samples * 0.25)
    n_pattern  = int(num_samples * 0.15)
    n_midgame  = int(num_samples * 0.25)
    n_endgame  = int(num_samples * 0.15)
    n_forced   = num_samples - n_edge - n_opening - n_pattern - n_midgame - n_endgame

    print(f"  Phase budgets: edge={n_edge}, opening={n_opening}, pattern={n_pattern}, "
          f"midgame={n_midgame}, endgame={n_endgame}, forced={n_forced}")

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 1: Edge Cases (10%) — 50+ explicit configs
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 1: Edge cases...")
    edge_generated = 0
    edge_idx = 0
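    # Attempt cap (as in the later phases) so the loop cannot spin forever
    # if every config is rejected.
    while edge_generated < n_edge and edge_idx < max(n_edge, len(EDGE_CASE_CONFIGS)) * 10: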
        ec_rows, ec_cols, ec_mines, ec_label = EDGE_CASE_CONFIGS[edge_idx % len(EDGE_CASE_CONFIGS)]
        edge_idx += 1
        seed = rng.randint(0, 999999)

        game = MinesweeperGame(rows=ec_rows, cols=ec_cols, num_mines=ec_mines, seed=seed)
        if game.state() != "ongoing":
            continue

        # 0-3 moves for variety
        num_moves = rng.randint(0, min(3, max(0, ec_rows * ec_cols - ec_mines - 1)))
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["edge_case"] += 1
            edge_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 2: Opening-Heavy Training (25%) — Fix 75% early death rate
    # 75% fresh (0 moves), 25% single-move
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 2: Opening training...")
    opening_generated = 0
    opening_attempts = 0
    while opening_generated < n_opening and opening_attempts < n_opening * 5:
        opening_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if num_mines == 0:
            if game.state() == "ongoing":
                dataset_items.append(_build_dataset_item(game, seed, []))
                _track_config(game)
                phase_counts["opening"] += 1
                opening_generated += 1
            continue

        if game.state() != "ongoing":
            continue

        # 75% fresh boards, 25% single-move
        num_moves = 0 if rng.random() < 0.75 else 1
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=False)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["opening"] += 1
            opening_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 3: Pattern-Specific Scenarios (15%)
    # Mix of: satisfied-number boards (60%), multi-region boards (40%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 3: Pattern-specific scenarios...")
    pattern_generated = 0
    pattern_attempts = 0
    while pattern_generated < n_pattern and pattern_attempts < n_pattern * 10:
        pattern_attempts += 1

        # 60% satisfied numbers, 40% multi-region
        if rng.random() < 0.60:
            # Satisfied numbers — need boards with mines
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))
            if rows < 3 or cols < 3:
                continue
            game, move_history, seed = _generate_satisfied_numbers_state(
                rows, cols, num_mines, rng
            )
        else:
            # Multi-region — need larger boards
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.15))
            rows = max(rows, 8)
            cols = max(cols, 8)
            num_mines = min(num_mines, int(rows * cols * 0.15))
            if num_mines < 1:
                num_mines = max(1, int(rows * cols * 0.05))
            game, move_history, seed = _generate_multi_region_state(
                rows, cols, num_mines, rng
            )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["pattern"] += 1
            pattern_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 4: Mid-Game Logical Deduction (25%) — Core gameplay
    # 3-15 moves played, progressive flagging, density-stratified
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 4: Mid-game deduction...")
    midgame_generated = 0
    midgame_attempts = 0
    while midgame_generated < n_midgame and midgame_attempts < n_midgame * 5:
        midgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        if num_mines == 0:
            continue  # Skip zero-mine for midgame

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if game.state() != "ongoing":
            continue

        total_safe = rows * cols - num_mines
        max_moves = min(15, max(3, total_safe - 1))
        num_moves = rng.randint(3, max_moves)

        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing" and len(game._revealed) > 0:
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["midgame"] += 1
            midgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 5: Endgame Completion (15%): completion target sampled from 80-95%
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 5: Endgame completion...")
    endgame_generated = 0
    endgame_attempts = 0
    while endgame_generated < n_endgame and endgame_attempts < n_endgame * 10:
        endgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))

        if num_mines < 1 or rows * cols < 8:
            continue

        completion = rng.uniform(0.80, 0.95)
        game, move_history, seed = _generate_endgame_state(
            rows, cols, num_mines, rng, completion_target=completion
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["endgame"] += 1
            endgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 6: Forced Guess Scenarios (10%) — No logical deductions
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 6: Forced guess scenarios...")
    forced_generated = 0
    forced_attempts = 0
    while forced_generated < n_forced and forced_attempts < n_forced * 15:
        forced_attempts += 1
        # Dense boards more likely to produce forced-guess states
        rows, cols, num_mines = _sample_board_with_density(rng, (0.10, 0.20))

        if num_mines < 2 or rows * cols < 6:
            continue

        game, move_history, seed = _generate_forced_guess_state(
            rows, cols, num_mines, rng
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["forced_guess"] += 1
            forced_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # Shuffle and trim
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    rng.shuffle(dataset_items)
    dataset_items = dataset_items[:num_samples]
    ds = Dataset.from_list(dataset_items)

    return ds, config_counts, phase_counts, density_counts


# ══════════════════════════════════════════════════════════════════════
#  Generate & Analyze
# ══════════════════════════════════════════════════════════════════════

NUM_SAMPLES = 1000  # single source of truth for the banner and the generator call

print("=" * 70)
print("  EXHAUSTIVE DATASET GENERATION")
print(f"  {NUM_SAMPLES} samples | 6 phases | {len(EDGE_CASE_CONFIGS)} edge-case configs | density-stratified")
print("=" * 70)
print()

dataset, config_counts, phase_counts, density_counts = generate_exhaustive_dataset(
    num_samples=NUM_SAMPLES, rng_seed=42
)

print(f"\n{'─'*70}")
print(f"Created {len(dataset)} training examples\n")

# Extract arrays for analysis
all_rows = [item["board_rows"] for item in dataset]
all_cols = [item["board_cols"] for item in dataset]
all_mines = [item["board_mines"] for item in dataset]
densities = [m / (r * c) * 100 if r * c > 0 else 0 for r, c, m in zip(all_rows, all_cols, all_mines)]

# Phase distribution
print("Phase distribution:")
for phase, count in phase_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  {phase:14s}: {count:4d} ({pct:5.1f}%) {bar}")

# Size band distribution (curriculum analysis)
# Boards are binned by their longer side. Bands overlap (e.g. 5-10 and 8-20),
# so one board can count toward two bands and the percentages can exceed 100%.
print(f"\nSize band distribution (bell-curve curriculum):")
for lo, hi, target_w, label in SIZE_BANDS:
    count = sum(1 for r, c in zip(all_rows, all_cols) if lo <= max(r, c) <= hi)
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  {label:6s} ({lo:2d}-{hi:2d}): {count:4d} ({pct:5.1f}%) target={target_w*100:.0f}% {bar}")

# Density distribution
print(f"\nDensity distribution:")
for band, count in density_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  {band:12s}: {count:4d} ({pct:5.1f}%) {bar}")

# Prompt tier distribution
print(f"\nPrompt tier distribution:")
tier_counts = {1: 0, 2: 0, 3: 0, 4: 0}
for r, c in zip(all_rows, all_cols):
    mx = max(r, c)
    if mx <= 5: tier_counts[1] += 1
    elif mx <= 12: tier_counts[2] += 1
    elif mx <= 25: tier_counts[3] += 1
    else: tier_counts[4] += 1
tier_labels = {1: 'Tiny (≤5)', 2: 'Small (6-12)', 3: 'Medium (13-25)', 4: 'Large (26+)'}
for tier, count in tier_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  Tier {tier} {tier_labels[tier]:16s}: {count:4d} ({pct:5.1f}%) {bar}")

zero_mine_count = sum(1 for m in all_mines if m == 0)
print(f"  Zero-mine boards: {zero_mine_count} ({zero_mine_count/len(dataset)*100:.1f}%)")

move_counts = [len(json.loads(item["move_history"])) for item in dataset]
print(f"\nMove statistics:")
print(f"  Min: {min(move_counts)}, Max: {max(move_counts)}, "
      f"Mean: {np.mean(move_counts):.1f}, Median: {np.median(move_counts):.1f}")
print(f"  Density: min={min(densities):.1f}%, max={max(densities):.1f}%, mean={np.mean(densities):.1f}%")

# Logical deduction coverage
print(f"\nLogical deduction coverage:")
has_safe = 0
has_mine = 0
has_both = 0
for item in dataset:
    mh = json.loads(item["move_history"])
    has_flag = any(m.get("type") == "flag" for m in mh)
    has_rev = any(m.get("type") == "reveal" for m in mh)
    if has_flag:
        has_mine += 1
    if has_rev:
        has_safe += 1
    if has_flag and has_rev:
        has_both += 1
print(f"  Samples with reveals: {has_safe} ({has_safe/len(dataset)*100:.1f}%)")
print(f"  Samples with flags:   {has_mine} ({has_mine/len(dataset)*100:.1f}%)")
print(f"  Samples with both:    {has_both} ({has_both/len(dataset)*100:.1f}%)")

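# Verify dataset columns (fail loudly instead of printing an unconditional check mark)
print(f"\nDataset columns: {dataset.column_names}")
required_cols = ("board_rows", "board_cols", "board_mines")
missing_cols = [c for c in required_cols if c not in dataset.column_names]
assert not missing_cols, f"Missing columns needed by reward functions: {missing_cols}"
print("  ✅ board_rows, board_cols, board_mines present for reward functions")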

# Sample prompt: show the user message (board state); prompt[0] is the system prompt
print(f"\nSample prompt ({dataset[0]['board_rows']}x{dataset[0]['board_cols']}, "
      f"{dataset[0]['board_mines']} mines):")
print(dataset[0]["prompt"][-1]["content"][:300] + "...")

# ── Save dataset to JSON ──
dataset_json_path = "minesweeper_dataset.json"
json_records = []
for item in dataset:
    record = {
        "seed": item["seed"],
        "move_history": item["move_history"],
        "board_rows": item["board_rows"],
        "board_cols": item["board_cols"],
        "board_mines": item["board_mines"],
        # The board state lives in the user message; prompt[0] is the system prompt.
        "prompt_text": item["prompt"][-1]["content"],
    }
    json_records.append(record)

with open(dataset_json_path, "w") as f:
    json.dump(json_records, f, indent=2)

print(f"   {len(json_records)} records with fields: seed, move_history, board_rows/cols/mines, prompt_text")
print(f"\n✅ Dataset saved to {dataset_json_path} ({os.path.getsize(dataset_json_path) / 1024:.1f} KB)")
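
The mine-count clamping performed by `_sample_board_with_density` in the cell above can be looked at in isolation. The following is a minimal sketch rather than the notebook's own code: `sample_mine_count` is a hypothetical helper, and the 0.20 cap is an assumed stand-in for `MAX_MINE_DENSITY`, whose actual value is defined elsewhere in the notebook.

```python
import math
import random


def sample_mine_count(rng, rows, cols, density_range, max_density=0.20):
    """Pick a mine count for a rows x cols board inside a density band.

    max_density=0.20 is an assumption standing in for MAX_MINE_DENSITY.
    """
    total = rows * cols
    min_d, max_d = density_range
    if min_d == 0.0 and max_d == 0.0:
        return 0  # zero-mine band
    min_mines = max(0, math.ceil(total * min_d))
    # Upper bound: band limit, global cap, and at least one safe cell left.
    max_mines = min(int(total * max_d), int(total * max_density), total - 1)
    if max_mines < min_mines:
        # The band rounds to an empty range on tiny boards: use the minimum.
        return max(0, min(min_mines, total - 1))
    return rng.randint(min_mines, max_mines)


rng = random.Random(0)
mid = sample_mine_count(rng, 8, 8, (0.10, 0.15))   # 64 cells, band allows 7 to 9 mines
tiny = sample_mine_count(rng, 1, 3, (0.10, 0.15))  # 3 cells, empty range, fallback path
print(mid, tiny)
```

On the 1x3 board the fallback yields a single mine, a density of roughly 33%, so very small boards can land above the nominal band.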

# Configure GRPO Training
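
GRPO scores each completion relative to the other completions sampled for the same prompt. A minimal sketch of that group-relative normalization follows. `group_advantages` is a hypothetical helper, and both the `eps` value and the use of the sample standard deviation are assumptions rather than a statement of what the installed TRL version does.

```python
from statistics import mean, stdev


def group_advantages(rewards, eps=1e-4):
    """Center one prompt's rewards on the group mean and scale by the group spread."""
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in rewards]


varied = group_advantages([1.0, 0.0, 1.0, 0.0])     # mixed outcomes within the group
collapsed = group_advantages([1.0, 1.0, 1.0, 1.0])  # every completion scored the same
print(varied)
print(collapsed)
```

When every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal, which is why a persistently low `reward_std` is worth monitoring during training.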

Set up GRPO trainer with all hyperparameters:

In [None]:
from trl import GRPOConfig, GRPOTrainer

# ── Fix torch.compile recompilation limit ──
# GRPO generates varying sequence lengths per step, each triggering a graph
# recompile. Default cache_size_limit (8) is too low → FailOnRecompileLimitHit.
# Raising it allows torch._dynamo to cache more compiled graphs.
import torch._dynamo
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.optimize_ddp = False

# ── Lengths ──
# Training prompts include pre-computed hints (safe/mine cells).
# Small boards (≤20): ~400-800 tokens (full grid + unified prompt)
# Medium boards (21-35): ~500-900 tokens (frontier + unified prompt)
# Large boards (36-50): ~400-800 tokens (summary + unified prompt)
# 1900 + 128 = 2028 < 2048 = max_seq_length (fits comfortably)
# 128 completion tokens: JSON-only output (hackathon constraint)
max_prompt_length = 1900
max_completion_length = 128  # HACKATHON CONSTRAINT: JSON-only, no reasoning
                             # Pure JSON action is ~10-25 tokens, well under 128

# ── GRPO Configuration (Simplified & Fast) ──
training_args = GRPOConfig(
    # === Generation ===
    temperature = 1.0,           # Fixed at 1.0 (no annealing)
    top_p = 0.95,

    # === Optimization ===
    learning_rate = 5e-5,        # Raised from 2e-5 (2.5x): loss was barely moving
    weight_decay = 0.01,
    warmup_ratio = 0.10,         # Longer warmup for stability with higher LR
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    max_grad_norm = 0.5,

    # === Batch sizes ===
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 2,   # Was 4 → 2× faster effective steps
    num_generations = 16,        # Was 24 → faster per step, still enough variance

    # === Lengths ===
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,

    # === Training duration ===
    max_steps = 500,
    save_steps = 100,

    # === GRPO specific ===
    beta = 0.01,                 # Was 0.03 → less KL restriction, faster learning
    num_iterations = 1,          # Was 2 → halves compute per step

    # === Ensure extra dataset columns are NOT removed ===
    remove_unused_columns = False,

    # === Output ===
    report_to = "none",
    output_dir = "minesweeper_grpo_v2",
    seed = 42,
    bf16 = True,
)

print("Training configuration:")
print(f"  torch._dynamo.cache_size_limit = {torch._dynamo.config.cache_size_limit} (fix recompile crash)")
print(f"  Temperature:         {training_args.temperature}  (fixed, no annealing)")
print(f"  LR:                  {training_args.learning_rate}")
print(f"  Beta (KL penalty):   {training_args.beta}")
print(f"  Num iterations:      {training_args.num_iterations}")
print(f"  Num generations:     {training_args.num_generations}")
print(f"  Grad accum steps:    {training_args.gradient_accumulation_steps}")
print(f"  Warmup ratio:        {training_args.warmup_ratio}")
print(f"  LoRA rank:           {lora_rank}")
print()
print("Speed improvements vs previous:")
print("  • num_generations 24→16, num_iterations 2→1, grad_accum 4→2")
print("  • Newline token added as EOS (stop_strings unsupported here) → outputs stop after JSON line (~14 tokens)")
print("  • LR 2e-5→5e-5, beta 0.03→0.01 → faster convergence")
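
The `learning_rate`, `warmup_ratio`, `lr_scheduler_type`, and `max_steps` settings above imply a learning-rate curve with linear warmup followed by cosine decay. The sketch below approximates that curve under two assumptions: warmup lasts about `max_steps * warmup_ratio` steps, and the decay reaches zero at the final step. `lr_at` is a hypothetical helper, not a library function.

```python
import math


def lr_at(step, peak_lr=5e-5, max_steps=500, warmup_ratio=0.10):
    """Approximate learning rate at an optimizer step (linear warmup, cosine decay)."""
    warmup_steps = round(max_steps * warmup_ratio)  # assumed rounding rule
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for s in (0, 25, 50, 275, 500):
    print(s, f"{lr_at(s):.2e}")
```

Under these assumptions the rate climbs for the first 50 steps, peaks at 5e-5, and passes through half the peak midway through the decay.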

In [None]:
from transformers import TrainerCallback

# Board configs to evaluate on during training (mix of sizes from 1x1 up to 20x20)
EVAL_CONFIGS = [
    (1, 1, 0),     # Trivial — 0 mines
    (3, 3, 1),     # Tiny
    (5, 5, 3),     # Small
    (6, 6, 5),     # Standard
    (8, 8, 10),    # Medium
    (10, 10, 20),  # Large
    (15, 15, 45),  # XL
    (6, 8, 6),     # Rectangular
    (1, 10, 2),    # Row board
    (20, 20, 80),  # XX-Large
]


# ──────────────────────────────────────────────────────────────────────
# FIX #3 (v2): Temperature Annealing Callback
# Prevents policy collapse by starting with high exploration (1.2)
# and annealing to focused generation (0.7) over first 70% of training.
#
# Schedule (linear; values for max_steps=500): step 0: 1.2 | 100: ~1.06 | 300: ~0.77 | 350+: 0.7
# ──────────────────────────────────────────────────────────────────────

class TemperatureAnnealingCallback(TrainerCallback):
    """Anneal generation temperature during GRPO training.

    Why: Fixed temp=1.0 caused policy collapse after ~20 steps.
    - Early training: High temp (1.2) → diverse outputs → healthy reward variance
    - Late training:  Low temp (0.7)  → converged policy → still enough variance

    Linear annealing over first 70% of training, then fixed at final_temp.
    """

    def __init__(self, max_steps, initial_temp=1.2, final_temp=0.7):
        self.max_steps = max_steps
        self.initial_temp = initial_temp
        self.final_temp = final_temp

    def get_temperature(self, step):
        """Linear annealing over first 70% of training."""
        anneal_end = self.max_steps * 0.7
        if step >= anneal_end:
            return self.final_temp
        progress = step / anneal_end
        return self.initial_temp - (self.initial_temp - self.final_temp) * progress

    def on_step_begin(self, args, state, control, **kwargs):
        temp = self.get_temperature(state.global_step)

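        # Update the GRPOConfig temperature for the next generation batch.
        # Caveat: some TRL versions copy the temperature into the trainer's
        # generation settings at construction time, in which case mutating
        # args here has no effect. Verify against the installed version.
        args.temperature = temp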

        # Log every 25 steps
        if state.global_step % 25 == 0:
            print(f"  [Step {state.global_step}] Temperature: {temp:.3f}")

        return control


class MinesweeperEvalCallback(TrainerCallback):
    """Periodically play games during training with variable board sizes.

    NO move limit — games end only on success (all safe revealed)
    or failure (mine hit). Max iterations capped to prevent infinite loops
    from repeated invalid actions.

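    FIX #5 (v2): Enhanced early stopping with:
      - JSON reward degradation tracking (warns on a 30%+ drop from the best; stops at 8 warnings)
      - Mean length monitoring (warns >25; each step >40 adds a warning, stops at 8)
      - Clipping ratio monitoring (warns >5%; each step >15% adds a warning, stops at 3)
      - reward_std collapse detection (warns <0.1, counts consecutive)
      - Consecutive collapse counter (stops after 5 consecutive collapsed steps)

    The JSON, length, and clipping monitors share one counter
    (degradation_warnings), which resets only when the JSON reward sets a new best.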
    """

    def __init__(self, eval_every_steps=50, num_games=10):
        self.eval_every_steps = eval_every_steps
        self.num_games = min(num_games, len(EVAL_CONFIGS))
        self.best_json_reward = 0.0
        self.best_json_step = 0
        self.degradation_warnings = 0
        self.consecutive_collapse = 0         # v2: track consecutive reward_std < 0.1

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Monitor training health and stop if degenerating."""
        if logs is None:
            return

        step = state.global_step

        # ── MONITOR 1: Policy Collapse (reward_std) ──
        reward_std = logs.get('reward_std')
        if reward_std is not None:
            if reward_std < 0.1:
                self.consecutive_collapse += 1
                mean_len = logs.get('completions/mean_length', 'N/A')
                max_len = logs.get('completions/max_length', 'N/A')
                print(f"\n⚠️  [Step {step}] reward_std={reward_std:.4f} — "
                      f"NEAR-COLLAPSED POLICY ({self.consecutive_collapse} consecutive)")
                print(f"    All {args.num_generations} generations producing nearly "
                      f"identical outputs.")
                print(f"    mean_length={mean_len}, max_length={max_len}")

                if self.consecutive_collapse >= 5:
                    print(f"\n🛑 STOPPING TRAINING: Policy collapsed for 5 consecutive "
                          f"logged steps!")
                    print(f"    reward_std has been < 0.1 since step "
                          f"~{step - self.consecutive_collapse}")
                    print(f"    GRPO has zero gradient signal — training is wasted.")
                    control.should_training_stop = True
            else:
                self.consecutive_collapse = 0  # Reset on healthy variance

        # ── MONITOR 2: JSON Reward Degradation ──
        json_reward = logs.get('rewards/valid_json_reward/mean')
        if json_reward is not None:
            if json_reward > self.best_json_reward:
                self.best_json_reward = json_reward
                self.best_json_step = step
                self.degradation_warnings = 0

            elif self.best_json_reward > 0 and json_reward < self.best_json_reward * 0.7:
                self.degradation_warnings += 1
                mean_len = logs.get('completions/mean_length', 'N/A')
                max_len = logs.get('completions/max_length', 'N/A')
                clipped = logs.get('completions/clipped_ratio', 'N/A')
                print(f"\n⚠️  [Step {step}] JSON reward DEGRADED: "
                      f"{json_reward:.2f} < {self.best_json_reward:.2f} "
                      f"(best @ step {self.best_json_step}) "
                      f"[{self.degradation_warnings}/8 warnings]")
                print(f"    mean_length={mean_len}, max_length={max_len}, "
                      f"clipped={clipped}")

                if self.degradation_warnings >= 8:
                    print(f"\n🛑 STOPPING TRAINING: JSON reward degraded 8 "
                          f"consecutive times!")
                    print(f"    Best: {self.best_json_reward:.2f} @ step "
                          f"{self.best_json_step}")
                    print(f"    Current: {json_reward:.2f} @ step {step}")
                    control.should_training_stop = True

        # ── MONITOR 3: Mean Completion Length (verbosity) ──
        mean_len = logs.get('completions/mean_length')
        if mean_len is not None and mean_len > 25:
            print(f"\n⚠️  [Step {step}] mean_length={mean_len:.1f} — "
                  f"MODEL GETTING VERBOSE!")
            print(f"    Expected: 15-20 tokens (pure JSON). Got: {mean_len:.1f}")
            if mean_len > 40:
                self.degradation_warnings += 1
                print(f"    🛑 mean_length > 40 — model is adding reasoning text. "
                      f"[{self.degradation_warnings}/8 warnings]")
                if self.degradation_warnings >= 8:
                    print(f"\n🛑 STOPPING TRAINING: Model diverged into verbose "
                          f"outputs!")
                    control.should_training_stop = True

        # ── MONITOR 4: Clipping (hitting 128-token limit) ──
        clipped = logs.get('completions/clipped_ratio')
        if clipped is not None and clipped > 0.05:
            print(f"\n⚠️  [Step {step}] clipped_ratio={clipped:.1%} — "
                  f"responses hitting 128-token limit!")
            if clipped > 0.15:
                self.degradation_warnings += 1
                print(f"    🛑 {clipped:.1%} truncated — model diverging! "
                      f"[{self.degradation_warnings}/3 warnings]")
                if self.degradation_warnings >= 3:
                    print(f"\n🛑 STOPPING TRAINING: Too many truncated responses!")
                    control.should_training_stop = True

        return control

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step % self.eval_every_steps != 0:
            return

        tokenizer = processing_class
        if tokenizer is None or model is None:
            return

        was_training = model.training
        model.eval()

        wins = 0
        total_moves = 0
        invalid_count = 0

        for i in range(self.num_games):
            rows, cols, mines = EVAL_CONFIGS[i % len(EVAL_CONFIGS)]
            game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines,
                                   seed=10000 + i)
            moves = 0
            invalids = 0
            consecutive_invalids = 0
            seen_actions = set()
            repeat_count = 0
            max_iterations = min(500, rows * cols + 100)

            iteration = 0
            while game.state() == "ongoing" and iteration < max_iterations:
                iteration += 1
                prompt = format_state_for_llm(game, mode="inference")
                text = tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                inputs = tokenizer(text, return_tensors="pt", truncation=True,
                                   max_length=max_prompt_length + 100)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        temperature=0.7,
                        max_new_tokens=128,
                        do_sample=True,
                        top_p=0.9,
                    )

                gen_tokens = output[0][inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                action = parse_llm_action(response)

                if action is None:
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break
                    continue

                action_key = (action['type'], action['row'], action['col'])
                if action_key in seen_actions:
                    repeat_count += 1
                    if repeat_count >= 3:
                        break
                else:
                    repeat_count = 0
                    seen_actions.add(action_key)

                result = game.do_action(action)
                if result in ("mine", "win"):
                    moves += 1
                    break
                elif result == "ok":
                    moves += 1
                    # Reset the invalid streak only after a legal move, so rejected
                    # moves (out of bounds, already revealed) also count toward
                    # the 5-in-a-row abort below.
                    consecutive_invalids = 0
                else:
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break

            if game.state() == "success":
                wins += 1
            total_moves += moves
            invalid_count += invalids

        win_rate = wins / self.num_games
        avg_moves = total_moves / self.num_games
        print(f"\n[Eval @ step {state.global_step}] "
              f"Win: {wins}/{self.num_games} ({win_rate*100:.0f}%) | "
              f"Avg moves: {avg_moves:.1f} | "
              f"Invalid: {invalid_count}\n")

        if was_training:
            model.train()


class CheckpointCallback(TrainerCallback):
    """Save LoRA adapters every N steps as named checkpoints."""

    def __init__(self, save_every_steps=100):
        self.save_every_steps = save_every_steps

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step % self.save_every_steps != 0 or state.global_step == 0:
            return

        ckpt_dir = f"minesweeper_ckpt_step{state.global_step}"
        model.save_pretrained(ckpt_dir)
        if processing_class is not None:
            processing_class.save_pretrained(ckpt_dir)
        print(f"\n💾 [Step {state.global_step}] Checkpoint saved → {ckpt_dir}/\n")


# ── Instantiate callbacks ──
eval_callback = MinesweeperEvalCallback(eval_every_steps=50, num_games=10)
ckpt_callback = CheckpointCallback(save_every_steps=100)
temp_callback = TemperatureAnnealingCallback(max_steps=500, initial_temp=1.2, final_temp=0.7)

print(f"Callbacks configured:")
print(f"  1. MinesweeperEvalCallback: {eval_callback.num_games} games every "
      f"{eval_callback.eval_every_steps} steps")
print(f"     No move limit — only success or failure")
print(f"     Configs: {len(EVAL_CONFIGS)} sizes from 1x1 to 20x20")
print(f"  2. CheckpointCallback: LoRA adapters saved every {ckpt_callback.save_every_steps} steps")
print(f"     Saves to: minesweeper_ckpt_step{{100,200,300,400,500}}/")
print(f"  3. TemperatureAnnealingCallback: {temp_callback.initial_temp} → "
      f"{temp_callback.final_temp} over first 70% of {temp_callback.max_steps} steps")
print(f"     Schedule: 1.2→1.0 (step 0-100) | 1.0→0.8 (100-300) | 0.8→0.7 (300-350) | 0.7 fixed")
print()
print("Early stopping monitors:")
print("  ✅ MONITOR 1: reward_std < 0.1 → stops after 5 consecutive collapsed steps")
print("  ✅ MONITOR 2: JSON reward 30%+ drop → stops after 8 consecutive warnings")
print("  ✅ MONITOR 3: mean_length > 25 warns, >40 adds degradation warning")
print("  ✅ MONITOR 4: clipped_ratio > 5% warns, >15% adds degradation warning")

# Train the Model

Start GRPO training with reward functions:

In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        valid_json_reward,   # Format + exponential penalty
        gameplay_scores,     # 12 criteria + center-opening + difficulty
    ],
    reward_weights = [0.60, 0.40],
    args = training_args,
    train_dataset = dataset,
    callbacks = [eval_callback, ckpt_callback, temp_callback],
)

print("Starting GRPO training...")
print("  [1] valid_json_reward  (weight: 0.60)")
print("  [2] gameplay_scores    (weight: 0.40)")
print("  Temperature: 1.2 → 0.7 (annealing)")
trainer.train()

# Test Trained Model

Evaluate the finetuned model:

In [None]:
# Test on variable board sizes across the full competition range
FastLanguageModel.for_inference(model)

test_configs = [
    (1, 1, 0, 200, "1x1 Trivial"),
    (3, 3, 1, 201, "Tiny"),
    (5, 5, 3, 202, "Small/Easy"),
    (6, 6, 5, 99,  "Standard"),
    (8, 8, 10, 101, "Medium"),
    (10, 10, 20, 203, "Large 20%"),
    (15, 15, 45, 204, "XL 20%"),
    (1, 20, 4, 205, "Row Board"),
    (20, 1, 4, 206, "Column Board"),
    (5, 5, 0, 207, "Zero Mines"),
]

for rows, cols, mines, seed, label in test_configs:
    print(f"\n{'='*50}")
    print(f"=== {label} ({rows}x{cols}, {mines} mines) ===")
    print(f"{'='*50}")

    test_game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)

    # Handle already-won games (0-mine boards that auto-cascade)
    if test_game.state() != "ongoing":
        print(f"  Game auto-resolved: {test_game.state()}")
        continue

    test_prompt = format_state_for_llm(test_game, mode="inference")

    test_text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": test_prompt},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(test_text, return_tensors="pt", truncation=True,
                       max_length=max_prompt_length + 100)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Inference only: skip gradient tracking, as the other eval loops do
    with torch.no_grad():
        output = model.generate(
            **inputs,
            temperature=0.7,  # LAMER paper: 0.7 for eval
            max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.2,
        )

    # Decode only generated tokens
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    print(f"Response: {response_text.strip()}")

    action = parse_llm_action(response_text)
    print(f"Parsed action: {action}")

    if action:
        result = test_game.do_action(action)
        print(f"Result: {result} | Game state: {test_game.state()}")
    else:
        print("⚠️ Failed to parse a valid action")

# Save the Model

Save your trained model for competition submission:

In [None]:
# Save LoRA adapters
model.save_pretrained("my_minesweeper_model")
tokenizer.save_pretrained("my_minesweeper_model")
print("✅ LoRA adapters saved to: my_minesweeper_model/")

# ──────────────────────────────────────────────────────────────────────
# Save merged model in 16bit
# Workaround for Unsloth bug: UnboundLocalError on 'copied_tokenizer_model_from_cache'
# when using local model paths. We manually merge LoRA weights and save with HF API.
# ──────────────────────────────────────────────────────────────────────
import os, shutil, gc

merged_dir = "my_minesweeper_model_merged"
os.makedirs(merged_dir, exist_ok=True)

try:
    # Try Unsloth's native method first (works on some versions)
    model.save_pretrained_merged(
        merged_dir,
        tokenizer,
        save_method="merged_16bit",
    )
    print("✅ Merged 16-bit model saved via Unsloth")
except Exception as e:  # also covers the UnboundLocalError noted above
    print(f"⚠️ Unsloth merge failed ({type(e).__name__}: {e})")
    print("   Falling back to manual LoRA merge...")

    try:
        # Manual merge: get the PEFT model, merge LoRA into base weights, save
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Re-load base model in bfloat16 (the training dtype) for merge
        print("   Loading base model for merge...")
        base_model = AutoModelForCausalLM.from_pretrained(
            "/workspace/workspace/Qwen2.5-14B-Instruct",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Load LoRA adapters on top
        print("   Applying LoRA adapters...")
        merged_model = PeftModel.from_pretrained(base_model, "my_minesweeper_model")
        merged_model = merged_model.merge_and_unload()

        # Save the fully merged model
        print("   Saving merged model...")
        merged_model.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

        # Cleanup
        del base_model, merged_model
        gc.collect()
        torch.cuda.empty_cache()

        print(f"✅ Merged 16-bit model saved to: {merged_dir}/")

    except Exception as e2:
        print(f"❌ BOTH merge methods failed!")
        print(f"   Unsloth error: {e}")
        print(f"   Manual merge error: {e2}")
        print(f"   LoRA adapters are still saved at: my_minesweeper_model/")
        print(f"   You can merge manually later with:")
        print(f"     from peft import PeftModel")
        print(f"     from transformers import AutoModelForCausalLM")
        print(f"     base = AutoModelForCausalLM.from_pretrained('<base_model_path>')")
        print(f"     merged = PeftModel.from_pretrained(base, 'my_minesweeper_model')")
        print(f"     merged = merged.merge_and_unload()")
        print(f"     merged.save_pretrained('{merged_dir}')")

# Verify saved files
saved_files = os.listdir(merged_dir)
safetensors = [f for f in saved_files if f.endswith(".safetensors")]
print(f"   Files: {len(saved_files)} total, {len(safetensors)} safetensors shards")
print(f"   Config: {'config.json' in saved_files}  Tokenizer: {'tokenizer.json' in saved_files}")

# Exhaustive Evaluation: Full Competition Range

Play complete games across 42 board configurations (1×1 to 50×50, 0-20% mines).
**No move limit**: a game ends in SUCCESS (all safe cells revealed) or FAILURE (mine revealed). Games aborted by the safety iteration cap or by stuck detection are reported separately as "stuck".
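
The 0-20% mine range can be checked mechanically before a suite is run. The following is a minimal sketch, assuming the cap is applied by rounding down; `max_mines` and `within_spec` are hypothetical helper names, not functions from this notebook.

```python
def max_mines(rows: int, cols: int) -> int:
    """Largest mine count at or below 20% of the cells (assumes rounding down)."""
    # Integer division avoids float rounding on boards whose size is a multiple of 5.
    return (rows * cols) // 5


def within_spec(rows: int, cols: int, mines: int) -> bool:
    """True if the board fits n, m in [1, 50] and the assumed 20% mine cap."""
    return 1 <= rows <= 50 and 1 <= cols <= 50 and 0 <= mines <= max_mines(rows, cols)


if __name__ == "__main__":
    for board in [(8, 8, 12), (8, 8, 13), (50, 50, 500)]:
        print(board, within_spec(*board))
```

A check like `all(within_spec(r, c, m) for r, c, m, _, _ in EVAL_SUITE)` could guard the suite defined in the next cell.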

In [None]:
def play_full_game(model, tokenizer, rows=6, cols=6, num_mines=5, seed=None,
                   verbose=False):
    """Play a complete Minesweeper game, tracking detailed metrics.

    Supports any board size from 1x1 to 50x50.
    NO move limit — game ends ONLY on:
      - SUCCESS: all non-mine cells are revealed
      - FAILURE: any mine cell is revealed
    A safety iteration cap prevents infinite loops from repeated invalid actions.
    """
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    # Edge case: game already won (e.g., 0-mine board)
    if game.state() != "ongoing":
        return {
            "game": game,
            "moves": 0,
            "logical_moves": 0,
            "flags_correct": 0,
            "flags_wrong": 0,
            "total_invalids": 0,
            "result": game.state(),
            "progress": game.progress(),
            "config": f"{rows}x{cols}m{num_mines}",
        }

    moves = 0
    consecutive_invalids = 0
    total_invalids = 0
    logical_moves = 0
    flags_correct = 0
    flags_wrong = 0
    seen_actions = set()   # Fix #4: detect repeated actions (stuck loop)
    repeat_count = 0       # Fix #4: consecutive repeat counter
    # Safety cap to prevent infinite loops — NOT a game move limit
    # Capped at 500 to prevent runaway 50×50 evals (was rows*cols*3+50 = 7550 for 50×50)
    max_iterations = min(500, rows * cols + 100)

    iteration = 0
    while game.state() == "ongoing" and iteration < max_iterations:
        iteration += 1
        prompt = format_state_for_llm(game, mode="inference")
        text = tokenizer.apply_chat_template(
            [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           max_length=max_prompt_length + 100)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model.generate(
                **inputs,
                temperature=0.7,  # LAMER paper: 0.7 for eval
                max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2,
            )

        # Decode ONLY generated tokens
        gen_tokens = output[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
        action = parse_llm_action(response)

        if action is None:
            consecutive_invalids += 1
            total_invalids += 1
            if consecutive_invalids >= 5:
                break  # Agent is stuck — abort
            continue

        # Fix #4: Check for repeated actions (stuck detection)
        action_key = (action['type'], action['row'], action['col'])
        if action_key in seen_actions:
            repeat_count += 1
            if repeat_count >= 3:  # Same move 3 times = stuck
                break
        else:
            repeat_count = 0
            seen_actions.add(action_key)

        # Track logical moves (compute BEFORE applying action)
        safe_set, mine_set = _compute_safe_and_mine_cells(game)
        r, c = action["row"], action["col"]
        if action["type"] == "reveal" and (r, c) in safe_set:
            logical_moves += 1
        elif action["type"] == "flag" and (r, c) in mine_set:
            logical_moves += 1

        # Track flag accuracy
        if action["type"] == "flag":
            if 0 <= r < game.rows and 0 <= c < game.cols:
                if game._board[r][c] == -1:
                    flags_correct += 1
                else:
                    flags_wrong += 1

        if verbose:
            print(f"  Move {moves}: {action}")

        result = game.do_action(action)
        if result in ("mine", "win", "ok"):
            moves += 1
            consecutive_invalids = 0  # A legal move ends the invalid streak
        elif result in ("out_of_bounds", "already_revealed", "flagged_cell",
                         "invalid_flag", "invalid_format"):
            # Invalid moves don't count but game stays ongoing
            total_invalids += 1
            consecutive_invalids += 1
            if consecutive_invalids >= 5:
                break

        if result in ("mine", "win"):
            break

    return {
        "game": game,
        "moves": moves,
        "logical_moves": logical_moves,
        "flags_correct": flags_correct,
        "flags_wrong": flags_wrong,
        "total_invalids": total_invalids,
        "result": game.state(),
        "progress": game.progress(),
        "config": f"{rows}x{cols}m{num_mines}",
    }


# ──────────────────────────────────────────────────────────────────────
# EXHAUSTIVE Multi-Size Evaluation — Competition Spec
# n, m ∈ [1, 50], mines 0-20% of total cells
# Only two outcomes: SUCCESS or FAILURE (no timeouts)
# ──────────────────────────────────────────────────────────────────────

EVAL_SUITE = [
    # (rows, cols, mines, num_games, label)

    # === Trivial / Edge Cases ===
    (1, 1, 0,   5,  "1x1 trivial"),
    (1, 2, 0,   5,  "1x2 trivial"),
    (2, 1, 0,   5,  "2x1 trivial"),
    (2, 2, 0,   5,  "2x2 no mines"),
    (3, 3, 0,   5,  "3x3 no mines"),
    (5, 5, 0,   5,  "5x5 no mines"),

    # === Tiny Boards ===
    (3, 3, 1,  10,  "3x3 1 mine"),
    (4, 4, 3,  10,  "4x4 3 mines"),

    # === Small Boards ===
    (5, 5, 3,  15,  "5x5 easy"),
    (5, 5, 5,  10,  "5x5 max density"),

    # === Standard ===
    (6, 6, 5,  20,  "6x6 standard"),
    (6, 6, 7,  10,  "6x6 hard"),

    # === Medium ===
    (7, 7, 7,  10,  "7x7 medium"),
    (8, 8, 10, 10,  "8x8 medium"),
    (8, 8, 12, 10,  "8x8 max density"),

    # === Large ===
    (10, 10, 10,  5, "10x10 10%"),
    (10, 10, 20, 10, "10x10 20%"),
    (15, 15, 45,  5, "15x15 20%"),

    # === XL ===
    (20, 20, 80,  3, "20x20 20%"),
    (25, 25, 125, 3, "25x25 20%"),
    (30, 30, 180, 2, "30x30 20%"),

    # === XXL ===
    (40, 40, 320, 2, "40x40 20%"),
    (50, 50, 500, 2, "50x50 20%"),

    # === Rectangular ===
    (1, 10, 2,   5, "1x10 row"),
    (10, 1, 2,   5, "10x1 column"),
    (1, 50, 10,  3, "1x50 row"),
    (50, 1, 10,  3, "50x1 column"),
    (6, 8, 6,    5, "6x8 rect"),
    (8, 6, 6,    5, "8x6 rect"),
    (5, 15, 15,  3, "5x15 wide"),
    (15, 5, 15,  3, "15x5 tall"),
    (3, 50, 30,  2, "3x50 extreme wide"),
    (50, 3, 30,  2, "50x3 extreme tall"),

    # === Sparse (low density) ===
    (10, 10, 1,  5, "10x10 sparse"),
    (20, 20, 4,  3, "20x20 sparse"),
    (50, 50, 10, 2, "50x50 sparse"),

    # === Progressive Difficulty (Fix #10: LAMER generalization test) ===
    # Same size, increasing mines — tests density scaling
    (10, 10, 5,  5, "10x10 5% progressive"),
    (10, 10, 15, 5, "10x10 15% progressive"),

    # Increasing size, same ~10% density — tests size scaling
    (15, 15, 23, 3, "15x15 10% density"),
    (20, 20, 40, 3, "20x20 10% density"),

    # Generalization to harder unseen configs
    (25, 25, 62, 2, "25x25 10% generalization"),
    (30, 30, 90, 2, "30x30 10% generalization"),
]

FastLanguageModel.for_inference(model)
total_games = sum(n for _,_,_,n,_ in EVAL_SUITE)
print(f"{'='*80}")
print(f"  EXHAUSTIVE EVALUATION — {total_games} games across {len(EVAL_SUITE)} configs")
print(f"  Competition spec: n,m ∈ [1,50], mines 0-20%, no move limit")
print(f"  Includes progressive difficulty test (LAMER generalization)")
print(f"  Only two outcomes: SUCCESS or FAILURE")
print(f"{'='*80}\n")

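import zlib

all_results = []
per_config_stats = {}

for rows, cols, mines, num_games, label in EVAL_SUITE:
    config_key = f"{rows}x{cols}m{mines}"
    # str hash() is salted per Python process, so use CRC32 to get the same
    # seed offset (and therefore the same boards) on every run.
    seed_offset = zlib.crc32(config_key.encode()) % 10000
    wins = 0
    fails = 0
    config_results = []

    for i in range(num_games):
        info = play_full_game(model, tokenizer, rows=rows, cols=cols,
                              num_mines=mines, seed=5000 + i + seed_offset)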
        config_results.append(info)
        all_results.append(info)

        if info["result"] == "success":
            wins += 1
        elif info["result"] == "failed":
            fails += 1

    # Per-config summary
    avg_moves = np.mean([r["moves"] for r in config_results])
    avg_logical = np.mean([r["logical_moves"] for r in config_results])
    avg_progress = np.mean([r["progress"] for r in config_results])
    avg_invalids = np.mean([r["total_invalids"] for r in config_results])
    stuck = sum(1 for r in config_results if r["result"] == "ongoing")
    wr = wins / num_games * 100

    per_config_stats[config_key] = {
        "label": label, "wins": wins, "fails": fails, "stuck": stuck,
        "total": num_games, "win_rate": wr,
        "avg_moves": avg_moves, "avg_progress": avg_progress,
    }

    status_icon = "✅" if wr >= 50 else "⚠️" if wr >= 20 else "❌"
    print(f"  {status_icon} {label:22s} ({config_key:12s}): "
          f"{wins:2d}/{num_games:2d} wins ({wr:5.1f}%) | "
          f"fails={fails} stuck={stuck} | "
          f"moves={avg_moves:5.1f} | logical={avg_logical:4.1f} | "
          f"progress={avg_progress:.0%} | invalids={avg_invalids:.1f}")

# ── Overall Summary ──
total_games_actual = len(all_results)
total_wins = sum(1 for r in all_results if r["result"] == "success")
total_fails = sum(1 for r in all_results if r["result"] == "failed")
total_stuck = sum(1 for r in all_results if r["result"] == "ongoing")
fc = sum(r["flags_correct"] for r in all_results)
fw = sum(r["flags_wrong"] for r in all_results)

print(f"\n{'='*80}")
print(f"  OVERALL: {total_wins}/{total_games_actual} wins ({total_wins/total_games_actual*100:.1f}%)")
print(f"{'='*80}")
print(f"  Wins (success):    {total_wins:4d} ({total_wins/total_games_actual*100:.1f}%)")
print(f"  Losses (failure):  {total_fails:4d} ({total_fails/total_games_actual*100:.1f}%)")
print(f"  Stuck (loop cap):  {total_stuck:4d} ({total_stuck/total_games_actual*100:.1f}%)")
print(f"  Avg moves:         {np.mean([r['moves'] for r in all_results]):.1f}")
print(f"  Avg progress:      {np.mean([r['progress'] for r in all_results]):.0%}")
print(f"  Avg logical moves: {np.mean([r['logical_moves'] for r in all_results]):.1f}")
if fc + fw > 0:
    print(f"  Flag accuracy:     {fc}/{fc+fw} ({fc/(fc+fw)*100:.1f}%)")
else:
    print(f"  Flags: none placed")

# ── Category Breakdown ──
categories = {
    "Trivial (0 mines)": [k for k, v in per_config_stats.items() if "m0" in k],
    "Tiny (≤4x4)":       [k for k, v in per_config_stats.items()
                           if v["label"].startswith(("3x3", "4x4")) and "m0" not in k],
    "Small (5x5)":       [k for k, v in per_config_stats.items()
                           if k.startswith("5x5") and "m0" not in k],
    "Standard (6x6)":    [k for k, v in per_config_stats.items() if k.startswith("6x6")],
    "Medium (7-8)":      [k for k, v in per_config_stats.items()
                           if k.startswith(("7x7", "8x8"))],
    "Large (10-15)":     [k for k, v in per_config_stats.items()
                           if k.startswith(("10x10", "15x15"))],
    "XL (20-50)":        [k for k, v in per_config_stats.items()
                           if k.startswith(("20x20", "25x25", "30x30", "40x40", "50x50"))],
    "Rectangular":       [k for k, v in per_config_stats.items()
                           if "rect" in v["label"] or "row" in v["label"]
                           or "column" in v["label"] or "wide" in v["label"]
                           or "tall" in v["label"]],
    "Sparse":            [k for k, v in per_config_stats.items() if "sparse" in v["label"]],
    "Progressive":       [k for k, v in per_config_stats.items()
                           if "progressive" in v["label"] or "generalization" in v["label"]
                           # "% density" skips the unrelated "max density" labels
                           or v["label"].endswith("% density")],
}

print(f"\n{'='*80}")
print(f"  CATEGORY BREAKDOWN")
print(f"{'='*80}")
for cat_name, keys in categories.items():
    if not keys:
        continue
    cat_wins = sum(per_config_stats[k]["wins"] for k in keys)
    cat_total = sum(per_config_stats[k]["total"] for k in keys)
    if cat_total > 0:
        cat_wr = cat_wins / cat_total * 100
        icon = "✅" if cat_wr >= 50 else "⚠️" if cat_wr >= 20 else "❌"
        print(f"  {icon} {cat_name:22s}: {cat_wins:3d}/{cat_total:3d} ({cat_wr:5.1f}%)")
print(f"{'='*80}")

# Inference from Merged Model

Load the saved merged model from disk (no LoRA, no Unsloth) and verify it works for Minesweeper inference.
This tests that the model was saved correctly and can be loaded independently.

In [None]:
# ──────────────────────────────────────────────────────────────────────
# Load merged model from disk for inference testing
# No Unsloth, no LoRA — pure HuggingFace transformers
# ──────────────────────────────────────────────────────────────────────
import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "my_minesweeper_model_merged"

print(f"Loading merged model from: {merged_dir}/")
print("  (This is a standalone model — no LoRA adapters needed)")

# Load tokenizer
merged_tokenizer = AutoTokenizer.from_pretrained(merged_dir)
print(f"  ✅ Tokenizer loaded ({merged_tokenizer.vocab_size} vocab)")

# Load model
merged_model = AutoModelForCausalLM.from_pretrained(
    merged_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
merged_model.eval()
print(f"  ✅ Model loaded on {merged_model.device}")
print(f"  Parameters: {sum(p.numel() for p in merged_model.parameters()):,}")

# ──────────────────────────────────────────────────────────────────────
# Run inference on a diverse set of Minesweeper boards
# ──────────────────────────────────────────────────────────────────────

test_configs = [
    (1, 1, 0, 300, "1×1 Trivial"),
    (3, 3, 1, 301, "Tiny 3×3"),
    (5, 5, 3, 302, "Small 5×5"),
    (6, 6, 5, 99,  "Standard 6×6"),
    (8, 8, 10, 303, "Medium 8×8"),
    (10, 10, 20, 304, "Large 10×10"),
    (15, 15, 45, 305, "XL 15×15"),
    (1, 20, 4, 306, "Linear 1×20"),
    (5, 5, 0, 307, "Zero Mines 5×5"),
]

print(f"\n{'='*60}")
print(f"  MERGED MODEL INFERENCE TEST — {len(test_configs)} boards")
print(f"{'='*60}")

results = {"pass": 0, "fail": 0, "skip": 0}

for rows, cols, mines, seed, label in test_configs:
    print(f"\n--- {label} ({rows}×{cols}, {mines} mines, seed={seed}) ---")

    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)

    if game.state() != "ongoing":
        print(f"  Game auto-resolved: {game.state()}")
        results["skip"] += 1
        continue

    # Build prompt using the same inference prompt system
    prompt = format_state_for_llm(game, mode="inference")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = merged_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = merged_tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(merged_model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = merged_model.generate(
            **inputs,
            temperature=0.7,       # LAMER paper: 0.7 for eval
            max_new_tokens=128,    # HACKATHON CONSTRAINT: JSON-only
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.2,
        )

    # Decode only generated tokens (not the prompt)
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response = merged_tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

    print(f"  Response: {response[:200]}{'...' if len(response) > 200 else ''}")

    action = parse_llm_action(response)
    print(f"  Parsed:   {action}")

    if action:
        result = game.do_action(action)
        state = game.state()
        print(f"  Result:   {result} → game {state}")
        # Count a pass only for moves the game accepted; rejected moves
        # (e.g. out of bounds, already revealed) are failures too.
        if result in ("ok", "win"):
            results["pass"] += 1
        else:
            results["fail"] += 1
    else:
        print(f"  ⚠️ Failed to parse valid action")
        results["fail"] += 1

# ── Summary ──
print(f"\n{'='*60}")
print(f"  INFERENCE TEST SUMMARY")
print(f"{'='*60}")
print(f"  Pass (valid move): {results['pass']}/{results['pass'] + results['fail']}")
print(f"  Fail (bad parse/invalid move/mine): {results['fail']}")
print(f"  Skipped (auto-win): {results['skip']}")
total = results['pass'] + results['fail']
if total > 0:
    print(f"  Success rate: {results['pass']/total*100:.1f}%")
print(f"\n✅ Merged model inference test complete!")
print(f"   Model path: {merged_dir}/")
print(f"   Ready for competition submission.")

# Training Notes

## Reward Functions (2 total)

| # | Function | Weight | Source |
|---|----------|--------|--------|
| 1 | `valid_json_reward` | 0.60 | GRPO-LEAD: exponential length penalty |
| 2 | `gameplay_scores` | 0.40 | XRPO: 12 criteria + center-opening + difficulty |

**Removed:** `strategic_reward` (was 5% weight = noise, ran constraint solver redundantly). Center-opening bonus folded into `gameplay_scores`.
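
To make the weights concrete, here is a simplified sketch of how the two per-completion scores could be folded into one reward and then normalized within a generation group. It is not TRL's implementation: the function name `group_advantages`, the use of the population standard deviation, and the `eps` value are assumptions made for illustration.

```python
import numpy as np

# Order matches the table: [valid_json_reward, gameplay_scores]
REWARD_WEIGHTS = np.array([0.60, 0.40])


def group_advantages(rewards, weights=REWARD_WEIGHTS, eps=1e-4):
    """rewards: array of shape (num_generations, 2) for ONE prompt.

    Returns the weighted total per completion and its group-normalized advantage.
    """
    rewards = np.asarray(rewards, dtype=float)
    total = rewards @ weights                      # weighted sum per completion
    advantages = (total - total.mean()) / (total.std() + eps)
    return total, advantages


if __name__ == "__main__":
    total, adv = group_advantages([[1.0, 0.0], [0.0, 1.0]])
    print(total, adv)
```

When every completion in a group earns the same total, all advantages come out as zero, which is the situation the reward_std monitor is meant to catch.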

## Key Config

| Parameter | Value | Rationale |
|-----------|-------|--------|
| temperature | 1.2→0.7 (annealed) | Prevents collapse |
| learning_rate | 5e-5 | Higher for faster convergence |
| num_iterations | 1 | Halved for speed |
| beta (KL) | 0.01 | Less KL restriction |
| num_generations | 16 | Faster per step |
| grad_accum_steps | 2 | 2× faster effective steps |
| reward_weights | [0.60, 0.40] | Format dominates |
| max_completion_length | 128 | Hackathon constraint |
| Format penalties | Exponential | -1.0 × 1.5^((extra-3)/5) |
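
Two rows of this table can be read as code. The sketch below assumes the temperature is interpolated linearly between the breakpoints printed by the callback cell (1.2 at step 0, 1.0 at 100, 0.8 at 300, 0.7 from step 350 on), and it evaluates the penalty formula exactly as written. `temperature_at` and `format_penalty` are hypothetical names, and reading `extra` as the amount of surplus output beyond the JSON object is an assumption.

```python
import numpy as np

# (step, temperature) breakpoints taken from the printed schedule
SCHEDULE_STEPS = [0, 100, 300, 350]
SCHEDULE_TEMPS = [1.2, 1.0, 0.8, 0.7]


def temperature_at(step: int) -> float:
    """Piecewise-linear temperature; np.interp holds the last value after step 350."""
    return float(np.interp(step, SCHEDULE_STEPS, SCHEDULE_TEMPS))


def format_penalty(extra: float) -> float:
    """Exponential penalty from the table: -1.0 * 1.5 ** ((extra - 3) / 5)."""
    return -1.0 * 1.5 ** ((extra - 3) / 5)


if __name__ == "__main__":
    for step in (0, 50, 200, 350, 500):
        print(step, round(temperature_at(step), 3))
    for extra in (3, 8, 13):
        print(extra, format_penalty(extra))
```

Under this reading the penalty grows by a factor of 1.5 for every 5 additional units of `extra`.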

## Early Stopping (4 monitors)

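| Monitor | Warn | Stop |
|---------|------|------|
| reward_std | < 0.1 | 5 consecutive collapsed logged steps |
| JSON reward | 30% below best | 8 degradation warnings |
| mean_length | > 25 tokens | > 40 tokens adds a degradation warning; stops at 8 |
| clipped_ratio | > 5% | > 15% adds a degradation warning; stops at 3 |

In the callback code above, the JSON reward, mean_length, and clipped_ratio monitors increment the same `degradation_warnings` counter, which is reset only when the JSON reward sets a new best. The clipping monitor can therefore stop training once the shared counter reaches 3, whichever monitor raised the earlier warnings.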
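
The reward_std row relies on a consecutive counter that resets on any healthy step. A minimal standalone sketch of that pattern follows; `CollapseMonitor` is a hypothetical class written for illustration, not the callback used in this notebook.

```python
class CollapseMonitor:
    """Counts consecutive logged steps whose reward_std is below a threshold."""

    def __init__(self, threshold: float = 0.1, patience: int = 5):
        self.threshold = threshold
        self.patience = patience
        self.consecutive = 0

    def update(self, reward_std: float) -> bool:
        """Record one logged step; return True when training should stop."""
        if reward_std < self.threshold:
            self.consecutive += 1
        else:
            self.consecutive = 0  # any healthy step clears the streak
        return self.consecutive >= self.patience


if __name__ == "__main__":
    monitor = CollapseMonitor()
    history = [0.05] * 4 + [0.5] + [0.05] * 5
    print([monitor.update(s) for s in history])
```

Four collapsed steps followed by one healthy step do not trigger a stop; only an unbroken run of five does.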