# Minesweeper LLM Competition - Custom GRPO Training

## Goal
Finetune an LLM with LoRA using GRPO to play Minesweeper by:
- **Input**: JSON game state (board configuration)
- **Output**: JSON action (reveal or flag a cell)

Teams will compete to train the best Minesweeper-playing LLM!

## Training Approach
- **Model**: Qwen2.5-14B-Instruct (from /root/.cache/huggingface/hub)
- **Method**: GRPO (Group Relative Policy Optimization)
- **Framework**: Unsloth (2-6x faster, 70% less VRAM)
- **Hardware**: AMD MI300X GPU (192GB HBM3, ROCm)

# Load Model with Unsloth

Load Qwen3-4B with LoRA configuration:

In [9]:
import os

os.environ["HF_HOME"] = "./workspace/hf_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./workspace/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "./workspace/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./workspace/hf_cache"


In [10]:
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="unsloth/Qwen2.5-14B-Instruct",
    local_dir="./workspace/Qwen2.5-14B-Instruct",
    local_dir_use_symlinks=False,
)


Ignored error while writing commit hash to /root/.cache/huggingface/models--unsloth--Qwen2.5-14B-Instruct/refs/main: [Errno 30] Read-only file system: '/root/.cache/huggingface/models--unsloth--Qwen2.5-14B-Instruct'.


'/workspace/workspace/Qwen2.5-14B-Instruct'

In [None]:
from unsloth import FastLanguageModel
import torch
import yaml

# Load config
with open("minesweeper_config_me.yaml", "r") as f:
    config = yaml.safe_load(f)

lora_rank = config.get("lora_rank", 32)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/workspace/workspace/Qwen2.5-14B-Instruct",
    load_in_4bit=False,   # AMD → 4bit disabled
    max_seq_length=2048,  # Increased: larger boards need longer prompts
    dtype=torch.bfloat16,
)

print("Model loaded successfully!")
print(f"Device: {model.device}")
print(f"LoRA rank: {lora_rank} (from config)")

Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.6: Fast Qwen2 patching. Transformers: 4.56.2. vLLM: 0.11.1rc2.dev161+g8a297115e.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 255.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+gitb2fb688. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Model loaded successfully!
Device: cuda:0


# Add LoRA Adapters

Add LoRA layers for efficient finetuning:

In [12]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,           # alpha = rank → scaling factor = 1.0 (stable training)
    lora_dropout = 0.05,              # Small dropout for regularization
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
print(f"LoRA config: rank={lora_rank}, alpha={lora_rank}, dropout=0.05")
model.print_trainable_parameters()

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.10.6 patched 48 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


LoRA config: rank=32, alpha=32, dropout=0.05
trainable params: 137,625,600 || all params: 14,907,659,264 || trainable%: 0.9232


# Minesweeper Game Implementation

Custom Minesweeper environment supporting:
- Customizable board size and mine count
- Actions: reveal or flag cells
- Win: reveal all safe cells
- Lose: reveal a mine

In [None]:
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Set
import random
import math

# ──────────────────────────────────────────────────────────────────────
# Board size configuration — competition spec: n,m ∈ [1,50], mines 0-20%
# ──────────────────────────────────────────────────────────────────────

# Constants
MIN_ROWS, MAX_ROWS = 1, 50
MIN_COLS, MAX_COLS = 1, 50
MIN_MINE_DENSITY = 0.0    # 0% mines allowed (trivial board)
MAX_MINE_DENSITY = 0.20   # 20% of total cells


def sample_board_config(rng=None):
    """Sample a random (rows, cols, num_mines) from the full competition range.

    - rows ∈ [1, 50], cols ∈ [1, 50]
    - mines ∈ [0, floor(0.20 * rows * cols)]
    - Uses a weighted distribution favoring smaller boards during training
      (large boards are rare but included for coverage).
    """
    rng = rng or random.Random()

    # Weighted size distribution: favor small/medium, still cover large
    size_band = rng.random()
    if size_band < 0.30:
        # Small: 1-8
        rows = rng.randint(1, 8)
        cols = rng.randint(1, 8)
    elif size_band < 0.55:
        # Medium: 5-15
        rows = rng.randint(5, 15)
        cols = rng.randint(5, 15)
    elif size_band < 0.75:
        # Large: 10-30
        rows = rng.randint(10, 30)
        cols = rng.randint(10, 30)
    elif size_band < 0.90:
        # XL: 20-40
        rows = rng.randint(20, 40)
        cols = rng.randint(20, 40)
    else:
        # Full range: 1-50 (including extreme cases)
        rows = rng.randint(1, 50)
        cols = rng.randint(1, 50)

    total = rows * cols
    max_mines = int(total * MAX_MINE_DENSITY)  # floor(0.20 * total)

    if max_mines == 0:
        num_mines = 0  # Boards too small for any mines at ≤20%
    else:
        num_mines = rng.randint(0, max_mines)

    return rows, cols, num_mines


def mine_density(rows, cols, num_mines):
    """Compute mine density as a fraction."""
    total = rows * cols
    return num_mines / total if total > 0 else 0.0


@dataclass
class MinesweeperGame:
    rows: int
    cols: int
    num_mines: int
    seed: Optional[int] = None
    _rng: random.Random = field(init=False, repr=False)
    _board: List[List[int]] = field(init=False, repr=False)  # -1 = mine, 0-8 = count
    _revealed: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _flagged: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _state: str = field(default="ongoing", init=False, repr=False)
    _move_count: int = field(default=0, init=False, repr=False)

    def __post_init__(self):
        # ── Input validation — competition spec: n,m ∈ [1,50] ──
        if self.rows < MIN_ROWS or self.cols < MIN_COLS:
            raise ValueError(f"Board too small: {self.rows}x{self.cols} (min {MIN_ROWS}x{MIN_COLS})")
        if self.rows > MAX_ROWS or self.cols > MAX_COLS:
            raise ValueError(f"Board too large: {self.rows}x{self.cols} (max {MAX_ROWS}x{MAX_COLS})")
        if self.num_mines < 0:
            raise ValueError(f"num_mines cannot be negative, got {self.num_mines}")
        if self.num_mines >= self.rows * self.cols:
            raise ValueError(f"Too many mines ({self.num_mines}) for {self.rows}x{self.cols} board")
        # 0 mines is allowed — trivial board, instant win on first reveal cascade

        self._rng = random.Random(self.seed)
        self._board = [[0 for _ in range(self.cols)] for _ in range(self.rows)]
        self._place_mines()
        self._calculate_numbers()

        # ── Edge case: 0 mines → all cells are safe, auto-win on init check ──
        self._check_win()

    def _place_mines(self):
        """Place mines randomly on the board."""
        if self.num_mines == 0:
            return  # No mines to place
        positions = [(r, c) for r in range(self.rows) for c in range(self.cols)]
        mine_positions = self._rng.sample(positions, self.num_mines)
        for r, c in mine_positions:
            self._board[r][c] = -1

    def _calculate_numbers(self):
        """Calculate numbers for each cell based on adjacent mines."""
        for r in range(self.rows):
            for c in range(self.cols):
                if self._board[r][c] == -1:
                    continue
                count = 0
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if 0 <= nr < self.rows and 0 <= nc < self.cols:
                            if self._board[nr][nc] == -1:
                                count += 1
                self._board[r][c] = count

    def _reveal_cell(self, row: int, col: int) -> bool:
        """Reveal a cell. Returns True if valid move, False if invalid.
        Uses iterative flood-fill to avoid recursion limit on large boards.
        """
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed or (row, col) in self._flagged:
            return False

        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if (r, c) in self._revealed:
                continue

            self._revealed.add((r, c))

            # Hit a mine!
            if self._board[r][c] == -1:
                self._state = "failed"
                return True

            # Auto-reveal neighbors if cell is 0
            if self._board[r][c] == 0:
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < self.rows and 0 <= nc < self.cols
                                and (nr, nc) not in self._revealed
                                and (nr, nc) not in self._flagged):
                            stack.append((nr, nc))

        return True

    def _flag_cell(self, row: int, col: int) -> bool:
        """Flag/unflag a cell. Returns True if valid, False if invalid."""
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed:
            return False

        if (row, col) in self._flagged:
            self._flagged.remove((row, col))
        else:
            self._flagged.add((row, col))
        return True

    def do_action(self, action: dict) -> str:
        """Execute an action and return a status string.

        Returns one of:
          'ok'               - valid move executed
          'mine'             - revealed a mine (game over → state='failed')
          'win'              - game won after this move (all safe cells revealed)
          'invalid_format'   - bad action dict / missing keys / bad types
          'out_of_bounds'    - coordinates outside the board
          'already_revealed' - cell was already revealed
          'flagged_cell'     - tried to reveal a flagged cell
          'invalid_flag'     - tried to flag a revealed cell
          'game_over'        - game was already over before this call

        Only 'mine' sets state='failed'. All other invalid moves
        return an error string but keep the game 'ongoing'.
        NO move limit — game continues until success or failure.
        """
        if self._state != "ongoing":
            return "game_over"

        if not isinstance(action, dict):
            return "invalid_format"

        action_type = action.get("type")
        row = action.get("row")
        col = action.get("col")

        if action_type not in ("reveal", "flag") or row is None or col is None:
            return "invalid_format"

        try:
            row, col = int(row), int(col)
        except (ValueError, TypeError):
            return "invalid_format"

        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return "out_of_bounds"

        if action_type == "reveal":
            if (row, col) in self._revealed:
                return "already_revealed"
            if (row, col) in self._flagged:
                return "flagged_cell"
            self._reveal_cell(row, col)
            self._move_count += 1
        else:  # flag
            if (row, col) in self._revealed:
                return "invalid_flag"
            self._flag_cell(row, col)
            self._move_count += 1

        self._check_win()

        if self._state == "failed":
            return "mine"
        if self._state == "success":
            return "win"
        return "ok"

    def _check_win(self):
        """Check if player has won.

        Win condition: ALL safe (non-mine) cells are revealed.
        With 0 mines, ALL cells are safe → need to reveal every cell.
        """
        if self._state != "ongoing":
            return
        total_cells = self.rows * self.cols
        safe_cells = total_cells - self.num_mines
        if safe_cells == 0:
            self._state = "success"
        elif len(self._revealed) >= safe_cells:
            self._state = "success"

    def get_visible_board(self) -> List[List[str]]:
        """Get board state as player sees it.
        Uses '?' for unrevealed cells (competition standard).
        """
        visible = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r, c) in self._flagged:
                    row.append('F')
                elif (r, c) in self._revealed:
                    val = self._board[r][c]
                    row.append('*' if val == -1 else str(val))
                else:
                    row.append('?')
            visible.append(row)
        return visible

    def state(self) -> str:
        return self._state

    @property
    def move_count(self) -> int:
        return self._move_count

    def get_mine_positions(self) -> Set[Tuple[int, int]]:
        """Return set of all mine positions (for reward computation)."""
        return {(r, c) for r in range(self.rows) for c in range(self.cols)
                if self._board[r][c] == -1}

    def progress(self) -> float:
        """Fraction of safe cells revealed (0.0 to 1.0)."""
        safe_cells = self.rows * self.cols - self.num_mines
        return len(self._revealed) / safe_cells if safe_cells > 0 else 1.0

    def game_phase(self) -> str:
        """Determine the current game phase for prompt selection."""
        if self._move_count == 0 and len(self._revealed) == 0:
            return "opening"
        prog = self.progress()
        if prog >= 0.80:
            return "endgame"
        return "midgame"

    def pretty_print(self) -> str:
        """Pretty print the board."""
        visible = self.get_visible_board()
        lines = []

        # Header — handle up to 2-digit column numbers
        col_width = 3 if self.cols > 10 else 2
        header = "   " + " ".join(f"{i:>{col_width-1}d}" for i in range(self.cols))
        lines.append(header)
        lines.append("  " + "─" * (self.cols * col_width + 1))

        # Board
        for r, row in enumerate(visible):
            sep = " " * (col_width - 1)
            line = f"{r:2d}│ " + sep.join(row)
            lines.append(line)

        return "\n".join(lines)


# ──────────────────────────────────────────────────────────────────────
# Sanity tests
# ──────────────────────────────────────────────────────────────────────
print("Testing MinesweeperGame (competition spec: 1-50 boards, 0-20% mines)...")

# Basic gameplay
g = MinesweeperGame(5, 5, 3, seed=0)
assert g.state() == "ongoing"
assert g.do_action({"type": "reveal", "row": -1, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing", "BUG: out_of_bounds should NOT end the game"
assert g.do_action({"type": "reveal", "row": 99, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing"
assert g.do_action({"type": "flag", "row": 0, "col": 0}) == "ok"
assert g.do_action({"type": "reveal", "row": 0, "col": 0}) == "flagged_cell"
assert g.state() == "ongoing", "BUG: flagged_cell should NOT end the game"
assert g.do_action({}) == "invalid_format"
assert g.state() == "ongoing", "BUG: invalid_format should NOT end the game"
print("  ✅ do_action keeps game ongoing on invalid moves")

# Verify '?' is used for unrevealed cells
board = g.get_visible_board()
has_question = any('?' in row for row in board)
assert has_question, "Board should use '?' for unrevealed cells"
print("  ✅ Board uses '?' for unrevealed cells")

# Game phase tracking
assert g.game_phase() == "opening" or g.move_count > 0
print("  ✅ Game phase tracking works")

# ── Edge case: 0 mines board ──
g0 = MinesweeperGame(3, 3, 0, seed=42)
assert g0.state() == "ongoing"
result = g0.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win", f"0-mine board should win on first reveal, got {result}"
assert g0.state() == "success"
print("  ✅ 0-mine board → instant win on first reveal")

# ── Edge case: 1x1 board with 0 mines ──
g1x1 = MinesweeperGame(1, 1, 0, seed=42)
assert g1x1.state() == "ongoing"
result = g1x1.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 1x1 board with 0 mines works")

# ── Edge case: 1x1 board — cannot have mines ──
try:
    MinesweeperGame(1, 1, 1, seed=42)
    assert False, "Should have raised ValueError"
except ValueError:
    pass
print("  ✅ 1x1 board with 1 mine correctly rejected")

# ── Edge case: 50x50 board ──
g50 = MinesweeperGame(50, 50, 500, seed=42)
assert g50.state() == "ongoing"
assert g50.rows == 50 and g50.cols == 50
assert mine_density(50, 50, 500) == 0.20
print("  ✅ 50x50 board with 20% mines works")

# ── Edge cases: rectangular ──
g1x50 = MinesweeperGame(1, 50, 10, seed=42)
assert g1x50.rows == 1 and g1x50.cols == 50
g50x1 = MinesweeperGame(50, 1, 10, seed=42)
assert g50x1.rows == 50 and g50x1.cols == 1
print("  ✅ 1x50 and 50x1 boards work")

# ── Edge case: 2x2 with 0 mines ──
g2x2 = MinesweeperGame(2, 2, 0, seed=42)
result = g2x2.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 2x2 board with 0 mines → instant cascade win")

# Variable sizes across the full range
for r, c, m in [(1,1,0), (2,2,0), (3,3,1), (5,5,3), (10,10,20),
                (20,20,80), (30,30,180), (50,50,500), (1,50,10), (50,1,10)]:
    g = MinesweeperGame(r, c, m, seed=42)
    assert g.rows == r and g.cols == c
    board = g.get_visible_board()
    assert len(board) == r and len(board[0]) == c
print(f"  ✅ Variable board sizes (1x1 to 50x50) work")

# Test sample_board_config produces valid configs
rng = random.Random(42)
for _ in range(200):
    r, c, m = sample_board_config(rng)
    assert MIN_ROWS <= r <= MAX_ROWS
    assert MIN_COLS <= c <= MAX_COLS
    assert 0 <= m <= int(r * c * MAX_MINE_DENSITY)
    g = MinesweeperGame(r, c, m, seed=42)
print(f"  ✅ sample_board_config produces valid configs (200 tested)")

# Test progress
g = MinesweeperGame(5, 5, 3, seed=42)
assert g.progress() == 0.0
print(f"  ✅ All game engine tests passed")
print(f"  Board range: {MIN_ROWS}-{MAX_ROWS} rows × {MIN_COLS}-{MAX_COLS} cols")
print(f"  Mine density: {MIN_MINE_DENSITY*100:.0f}%-{MAX_MINE_DENSITY*100:.0f}%")

# Prompt System & Game Logic Helpers

## Master Prompt System
Two prompt modes:
- **Training prompt** — Full strategic reasoning guidance, phase-aware, edge-case-aware
- **Inference prompt** — ~60% fewer tokens for fast evaluation

## 3-Tier Board Representation
| Format | Board Size | Display |
|--------|-----------|---------|
| **A (Small)** | 1–20 rows/cols | Full grid with borders and row/col headers |
| **B (Medium)** | 21–35 rows/cols | Frontier cells + revealed number regions |
| **C (Large)** | 36–50 rows/cols | Quadrant summary + critical area snippets |

## Phase-Aware Prompts
| Phase | Trigger | Strategy |
|-------|---------|----------|
| **Opening** | No cells revealed | Density-based center preference |
| **Mid-game** | Progress < 80% | Deduction checklist (satisfied/constrained numbers) |
| **Endgame** | Progress ≥ 80% | Flag accounting, completion logic |

## Edge Case Handling
0 mines, linear boards, tiny boards, large boards, very high density, one cell left

## Symbol Legend
- `?` = Unrevealed cell
- `F` = Flagged cell (suspected mine)
- `0`–`8` = Revealed safe cell (adjacent mine count)

## Output Format
```json
{"type": "reveal", "row": 2, "col": 3}
```
or
```json
{"type": "flag", "row": 1, "col": 4}
```

## Dataset Columns (for reward function reconstruction)
| Column | Type | Purpose |
|--------|------|---------|
| `prompt` | list[dict] | Chat-formatted prompt |
| `seed` | int | RNG seed to reconstruct game |
| `move_history` | str (JSON) | Previous moves as JSON list |
| `board_rows` | int | Number of rows (1–50) |
| `board_cols` | int | Number of columns (1–50) |
| `board_mines` | int | Number of mines (0–20% of cells) |

In [None]:
import json
import re

# ──────────────────────────────────────────────────────────────────────
# Minesweeper Logic Helpers — cached per game state
# ──────────────────────────────────────────────────────────────────────

def _compute_safe_and_mine_cells(game: MinesweeperGame):
    """Compute both safe and mine cells in a SINGLE pass (O(n²) not O(n⁴)).

    Returns (safe_set, mine_set) where each is a set of (row, col) tuples.

    A cell is logically SAFE if any adjacent revealed number has all its
    mines accounted for by flags (remaining_mines == 0).
    A cell is a logically CERTAIN mine if an adjacent number has
    remaining_mines == remaining unrevealed neighbors.
    """
    safe = set()
    mines = set()

    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) not in game._revealed:
                continue
            val = game._board[r][c]
            if val <= 0:
                continue

            flags = 0
            unrevealed = []
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if (nr, nc) in game._flagged:
                            flags += 1
                        elif (nr, nc) not in game._revealed:
                            unrevealed.append((nr, nc))

            remaining = val - flags

            if remaining == 0 and unrevealed:
                for cell in unrevealed:
                    safe.add(cell)
            elif remaining > 0 and remaining == len(unrevealed):
                for cell in unrevealed:
                    mines.add(cell)

    return safe, mines


def _compute_safe_cells(game: MinesweeperGame) -> list:
    """Return list of [row, col] for logically safe cells."""
    safe, _ = _compute_safe_and_mine_cells(game)
    return [list(c) for c in safe]


def _compute_mine_cells(game: MinesweeperGame) -> list:
    """Return list of [row, col] for logically certain mine cells."""
    _, mines = _compute_safe_and_mine_cells(game)
    return [list(c) for c in mines]


def _is_logically_safe(game: MinesweeperGame, row: int, col: int) -> bool:
    safe, _ = _compute_safe_and_mine_cells(game)
    return (row, col) in safe


def _is_logically_mine(game: MinesweeperGame, row: int, col: int) -> bool:
    _, mines = _compute_safe_and_mine_cells(game)
    return (row, col) in mines


# ──────────────────────────────────────────────────────────────────────
# Board Representation — 3-tier format (A/B/C) based on size
# A: Small (1-20 rows/cols) — full grid
# B: Medium (21-35 rows/cols) — revealed regions + frontier
# C: Large (36-50 rows/cols) — critical areas + frontier summary
# ──────────────────────────────────────────────────────────────────────

def _format_board_small(game: MinesweeperGame) -> str:
    """Format A: Full grid display for boards ≤20 rows/cols."""
    board = game.get_visible_board()
    # Build column header
    if game.cols <= 10:
        col_header = "    " + " ".join(f"{c}" for c in range(game.cols))
        separator = "  +" + "-" * (game.cols * 2 + 1) + "+"
        rows_str = []
        for r, row in enumerate(board):
            rows_str.append(f"{r:2d}| " + " ".join(row) + " |")
    else:
        col_header = "    " + " ".join(f"{c:2d}" for c in range(game.cols))
        separator = "  +" + "-" * (game.cols * 3) + "+"
        rows_str = []
        for r, row in enumerate(board):
            rows_str.append(f"{r:2d}| " + " ".join(f"{v:>2s}" for v in row) + " |")

    return col_header + "\n" + separator + "\n" + "\n".join(rows_str) + "\n" + separator


def _get_frontier_and_numbers(game: MinesweeperGame):
    """Get frontier cells (unrevealed cells adjacent to revealed numbers)
    and the revealed number cells near the frontier."""
    frontier = set()
    number_cells = {}

    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) in game._revealed and game._board[r][c] > 0:
                number_cells[(r, c)] = game._board[r][c]
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < game.rows and 0 <= nc < game.cols
                                and (nr, nc) not in game._revealed):
                            frontier.add((nr, nc))

    return frontier, number_cells


def _extract_critical_areas(game: MinesweeperGame, frontier, number_cells, max_areas=3):
    """Extract small rectangular regions around frontier clusters for display."""
    if not frontier:
        return []

    # Cluster frontier cells by proximity
    frontier_list = sorted(frontier)
    areas = []
    used = set()

    for fr, fc in frontier_list:
        if (fr, fc) in used:
            continue

        # Find a small rectangle around this frontier cell
        r_min = max(0, fr - 1)
        r_max = min(game.rows - 1, fr + 1)
        c_min = max(0, fc - 1)
        c_max = min(game.cols - 1, fc + 1)

        # Expand to include nearby frontier cells
        for fr2, fc2 in frontier_list:
            if abs(fr2 - fr) <= 2 and abs(fc2 - fc) <= 2:
                r_min = min(r_min, max(0, fr2 - 1))
                r_max = max(r_max, min(game.rows - 1, fr2 + 1))
                c_min = min(c_min, max(0, fc2 - 1))
                c_max = max(c_max, min(game.cols - 1, fc2 + 1))
                used.add((fr2, fc2))

        # Format region
        board = game.get_visible_board()
        col_hdr = "    " + " ".join(f"{c:2d}" for c in range(c_min, c_max + 1))
        sep = "  +" + "-" * ((c_max - c_min + 1) * 3) + "+"
        rows_str = []
        for r in range(r_min, r_max + 1):
            cells = [f"{board[r][c]:>2s}" for c in range(c_min, c_max + 1)]
            rows_str.append(f"{r:2d}| " + " ".join(cells) + " |")

        area_str = f"Critical area near ({fr},{fc}):\n{col_hdr}\n{sep}\n"
        area_str += "\n".join(rows_str) + "\n" + sep
        areas.append(area_str)

        if len(areas) >= max_areas:
            break

    return areas


def _format_board_medium(game: MinesweeperGame) -> str:
    """Format B: Revealed regions + frontier for boards 21-35 rows/cols."""
    frontier, number_cells = _get_frontier_and_numbers(game)

    # Show critical areas
    areas = _extract_critical_areas(game, frontier, number_cells, max_areas=4)

    frontier_list = sorted(frontier)[:30]
    flagged_list = sorted(game._flagged)[:20]
    number_summary = [f"({r},{c})={v}" for (r, c), v in
                      sorted(number_cells.items())[:30]]

    lines = []
    lines.append(f"Board: {game.rows}×{game.cols} with {game.num_mines} mines "
                 f"({mine_density(game.rows, game.cols, game.num_mines)*100:.1f}% density)")
    lines.append("")

    if areas:
        for area in areas:
            lines.append(area)
            lines.append("")

    lines.append(f"Revealed numbers: {number_summary}")
    lines.append(f"Unrevealed frontier cells: {frontier_list}")
    if flagged_list:
        lines.append(f"Flagged cells: {flagged_list}")

    return "\n".join(lines)


def _format_board_large(game: MinesweeperGame) -> str:
    """Format C: Critical areas + frontier summary for boards 36-50 rows/cols."""
    frontier, number_cells = _get_frontier_and_numbers(game)

    # Quadrant summary
    mid_r, mid_c = game.rows // 2, game.cols // 2
    quadrants = {
        "Top-left": (0, mid_r, 0, mid_c),
        "Top-right": (0, mid_r, mid_c, game.cols),
        "Bottom-left": (mid_r, game.rows, 0, mid_c),
        "Bottom-right": (mid_r, game.rows, mid_c, game.cols),
    }

    lines = []
    lines.append(f"Board: {game.rows}×{game.cols} with {game.num_mines} mines "
                 f"({mine_density(game.rows, game.cols, game.num_mines)*100:.1f}% density)")
    lines.append("")

    safe_total = game.rows * game.cols - game.num_mines
    lines.append(f"Revealed safe cells: {len(game._revealed)}/{safe_total} "
                 f"({game.progress()*100:.1f}% complete)")
    lines.append("")

    lines.append("Revealed regions summary:")
    for qname, (r0, r1, c0, c1) in quadrants.items():
        q_total = (r1 - r0) * (c1 - c0)
        q_revealed = sum(1 for r in range(r0, r1) for c in range(c0, c1)
                        if (r, c) in game._revealed)
        status = "Fully explored" if q_revealed >= q_total * 0.9 else \
                 "Mostly complete" if q_revealed >= q_total * 0.5 else \
                 "Partially explored" if q_revealed > 0 else "Unexplored"
        lines.append(f"- {qname} ({r0}-{r1-1}, {c0}-{c1-1}): "
                     f"{status}, {q_revealed}/{q_total} cells revealed")
    lines.append("")

    # Critical areas
    areas = _extract_critical_areas(game, frontier, number_cells, max_areas=3)
    if areas:
        lines.append("Current frontier (unrevealed cells adjacent to revealed numbers):")
        for area in areas:
            lines.append(area)
            lines.append("")

    # Flag accounting
    if len(game._flagged) > 0:
        lines.append(f"Flags placed: {len(game._flagged)}/{game.num_mines} mines flagged")
        if len(game._flagged) == game.num_mines:
            lines.append("→ If you are certain all mines are flagged, reveal any remaining '?' to win")

    frontier_list = sorted(frontier)[:20]
    lines.append(f"\nTotal unrevealed non-flagged cells: "
                 f"{game.rows * game.cols - len(game._revealed) - len(game._flagged)}")
    if frontier_list:
        lines.append(f"Frontier cells (choose from these): {frontier_list}")

    return "\n".join(lines)


def _format_board(game: MinesweeperGame) -> str:
    """Select appropriate board format based on size."""
    if game.rows <= 20 and game.cols <= 20:
        return _format_board_small(game)
    elif game.rows <= 35 and game.cols <= 35:
        return _format_board_medium(game)
    else:
        return _format_board_large(game)


# ──────────────────────────────────────────────────────────────────────
# Master Prompt System — phase-aware, density-aware, edge-case-aware
# ──────────────────────────────────────────────────────────────────────

SYSTEM_PROMPT = (
    "You are an expert Minesweeper player with perfect logical reasoning abilities. "
    "Your goal is to reveal all non-mine cells while avoiding all mines. "
    "You must win by revealing every safe cell without hitting a single mine."
)


def _get_edge_case_guidance(game: MinesweeperGame) -> str:
    """Return edge-case-specific guidance if applicable, else empty string."""
    density = mine_density(game.rows, game.cols, game.num_mines) * 100

    if game.num_mines == 0:
        return (
            "\n=== EDGE CASE: ZERO MINES ===\n"
            "No mines exist on this board. Every cell is safe.\n"
            "Simply reveal any unrevealed cell to progress toward victory.\n"
            "No risk, move quickly.\n"
        )

    if game.rows == 1 or game.cols == 1:
        orientation = "1×N" if game.rows == 1 else "N×1"
        return (
            f"\n=== EDGE CASE: LINEAR BOARD ({orientation}) ===\n"
            "Only 2-3 neighbors per cell (ends have 1-2).\n"
            "Strong constraints, easier deduction.\n"
            "Play conservatively on ends.\n"
        )

    if game.rows <= 3 and game.cols <= 3:
        return (
            "\n=== EDGE CASE: TINY BOARD ===\n"
            f"≤3×3 board ({game.rows}×{game.cols}). Very limited cells, each move is critical.\n"
            "First move should be center cell if available.\n"
            "Fewer neighbors means faster deduction.\n"
        )

    if game.rows >= 30 or game.cols >= 30:
        return (
            "\n=== EDGE CASE: LARGE BOARD ===\n"
            f"{game.rows}×{game.cols} board. Focus on expanding safe regions systematically.\n"
            "Don't jump randomly across board.\n"
            "Build outward from revealed areas.\n"
        )

    if density >= 18:
        return (
            "\n=== EDGE CASE: VERY HIGH DENSITY ===\n"
            f"Mine density is {density:.1f}% (≥18%). Play extremely conservatively.\n"
            "Only make 100% certain moves. Accept slower progress to avoid losses.\n"
        )

    remaining = game.rows * game.cols - len(game._revealed) - len(game._flagged)
    remaining_mines = game.num_mines - len(game._flagged)
    if remaining <= 1 and remaining_mines == 0:
        return (
            "\n=== EDGE CASE: ONE CELL LEFT ===\n"
            "All mines are flagged. Reveal any remaining '?' cell to win!\n"
        )

    return ""


def _get_phase_guidance(game: MinesweeperGame) -> str:
    """Return phase-specific reasoning guidance."""
    phase = game.game_phase()
    density = mine_density(game.rows, game.cols, game.num_mines) * 100
    safe_total = game.rows * game.cols - game.num_mines
    remaining = safe_total - len(game._revealed)

    if phase == "opening":
        center_r, center_c = game.rows // 2, game.cols // 2
        safe_r_min = max(game.rows // 4, 0)
        safe_r_max = min(3 * game.rows // 4, game.rows - 1)
        safe_c_min = max(game.cols // 4, 0)
        safe_c_max = min(3 * game.cols // 4, game.cols - 1)

        return (
            "=== OPENING MOVE SCENARIO ===\n"
            "This is an opening move on a fresh board. No cells have been revealed yet.\n\n"
            "OPENING STRATEGY (density-based):\n"
            f"- Mine density: {density:.1f}%\n"
            + (f"- Low density (≤5%): any cell is relatively safe, prefer center for cascade potential\n"
               if density <= 5 else
               f"- Moderate density (≤10%): strongly prefer center region\n"
               if density <= 10 else
               f"- Elevated density (≤15%): avoid edges/corners, pick near-center\n"
               if density <= 15 else
               f"- High density (>{15}%): pick cell closest to exact center\n")
            + f"- Center cell: row={center_r}, col={center_c}\n"
            f"- Safe zone: rows {safe_r_min}-{safe_r_max}, cols {safe_c_min}-{safe_c_max}\n\n"
            "Make your first move (reveal only - no flagging on move 1):\n"
        )

    elif phase == "endgame":
        return (
            "=== ENDGAME COMPLETION SCENARIO ===\n"
            f"Board is {game.progress()*100:.1f}% complete. {remaining} safe cells remain.\n"
            f"Flags placed: {len(game._flagged)} (Total mines: {game.num_mines})\n"
            f"Mines unflagged: {game.num_mines - len(game._flagged)}\n\n"
            "ENDGAME LOGIC:\n"
            f"1. If flags_placed ({len(game._flagged)}) = num_mines ({game.num_mines}) "
            "AND unrevealed cells exist:\n"
            "   → ALL remaining unrevealed cells are SAFE - reveal any immediately\n"
            "2. If flags_placed < num_mines:\n"
            "   → Check for cells you can logically deduce as mines → flag them\n"
            "   → Then recheck if flags_placed = num_mines\n"
            "3. Carefully verify no simple deductions remain before guessing\n"
        )

    else:  # midgame
        return (
            "=== MID-GAME LOGICAL DEDUCTION SCENARIO ===\n"
            "Multiple regions are revealed. Focus on finding 100% certain logical moves.\n\n"
            "DEDUCTION CHECKLIST:\n"
            "□ Check all revealed numbers for satisfied constraint (flags_near = number)\n"
            "□ Check all revealed numbers for full constraint (unrevealed_near = number - flags_near)\n"
            "□ Check for chain deductions (revealing may create new logical moves)\n"
            f"□ Verify flag count: {len(game._flagged)} flags placed, {game.num_mines} total mines\n\n"
            "If logical moves exist → Make the safest logical move (prefer revealing over flagging)\n"
            "If no logical moves exist → Choose lowest-risk unrevealed cell\n"
        )


def format_state_for_llm(game: MinesweeperGame, mode="training") -> str:
    """Convert game state to a comprehensive prompt for the LLM.

    mode="training" — Full master prompt with strategic reasoning guidance
    mode="inference" — Minimal, fast prompt (60% fewer tokens) for evaluation

    Supports any board size from 1×1 to 50×50.
    Uses 3-tier board representation:
      A: Small (≤20) — full grid
      B: Medium (21-35) — regions + frontier
      C: Large (36-50) — critical areas + summary
    """
    # ── Edge case: game already won ──
    if game.state() == "success":
        return "Game already won. No action needed."

    total_cells = game.rows * game.cols
    safe_total = total_cells - game.num_mines
    density = mine_density(game.rows, game.cols, game.num_mines) * 100
    remaining_mines = game.num_mines - len(game._flagged)

    # ── Edge case: 0 mines (training or inference) ──
    if game.num_mines == 0:
        return (
            f"You are playing Minesweeper on a {game.rows}×{game.cols} board with 0 mines.\n"
            "All cells are safe. Reveal any unrevealed cell to win.\n\n"
            'Output ONLY valid JSON:\n'
            '{"type":"reveal","row":<int>,"col":<int>}'
        )

    # ── Board representation ──
    board_repr = _format_board(game)

    # ── Compute logical hints ──
    safe_cells = _compute_safe_cells(game)
    mine_cells = _compute_mine_cells(game)

    hint_lines = []
    if safe_cells:
        hint_lines.append(f"SAFE cells (reveal one): {safe_cells[:6]}")
    if mine_cells:
        hint_lines.append(f"MINE cells (flag one): {mine_cells[:6]}")
    if not safe_cells and not mine_cells:
        hint_lines.append("No logical deductions available — choose lowest-risk unrevealed cell.")
    hint_section = "\n".join(hint_lines)

    # ═══════════════════════════════════════════════════════════════════
    #  INFERENCE PROMPT — minimal, fast (~60% fewer tokens)
    # ═══════════════════════════════════════════════════════════════════
    if mode == "inference":
        prompt = (
            f"You are playing Minesweeper. Win by revealing all safe cells without hitting mines.\n\n"
            f"Board: {game.rows}×{game.cols}, {game.num_mines} mines ({density:.1f}%)\n"
            f"Revealed: {len(game._revealed)}/{safe_total} cells | "
            f"Flags: {len(game._flagged)} | Moves: {game.move_count}\n\n"
            f"{board_repr}\n\n"
            "Legend:\n"
            "- '?' = unrevealed, 'F' = flagged, '0'-'8' = safe revealed\n\n"
            "Rules:\n"
            "- Numbers show adjacent mine count (8 neighbors)\n"
            "- Deduce logically first, guess only if needed\n"
            + (f"- Opening: prefer center on dense boards (density={density:.1f}%)\n"
               if game.game_phase() == "opening" else "")
            + f"\n{hint_section}\n\n"
            'Output ONLY valid JSON:\n'
            '{"type":"reveal","row":<int>,"col":<int>} or {"type":"flag","row":<int>,"col":<int>}'
        )
        return prompt

    # ═══════════════════════════════════════════════════════════════════
    #  TRAINING PROMPT — full master prompt with strategic reasoning
    # ═══════════════════════════════════════════════════════════════════
    phase_guidance = _get_phase_guidance(game)
    edge_case_guidance = _get_edge_case_guidance(game)

    prompt = f"""You are an expert Minesweeper player with perfect logical reasoning abilities. Your goal is to reveal all non-mine cells while avoiding all mines.

=== GAME RULES ===

- Each cell is either a MINE or SAFE
- Revealed safe cells show a NUMBER (0-8) indicating adjacent mines
- A cell with 0 means all 8 neighbors are safe
- You can REVEAL a cell (if it's a mine, you LOSE immediately)
- You can FLAG a cell you deduce contains a mine
- WIN: Reveal all safe cells without hitting any mine
- LOSE: Reveal any cell containing a mine

=== CURRENT BOARD ===

Board Size: {game.rows} rows × {game.cols} columns
Total Cells: {total_cells}
Number of Mines: {game.num_mines}
Mine Density: {density:.1f}%
Cells Revealed: {len(game._revealed)}/{safe_total} safe cells
Cells Flagged: {len(game._flagged)} ({remaining_mines} mines remain unflagged)
Moves Played: {game.move_count}

{board_repr}

Legend:
- '?' = Unrevealed cell (unknown - could be mine or safe)
- '0'-'8' = Revealed safe cell showing count of adjacent mines
- 'F' = Flagged cell (you marked as suspected mine)

=== STRATEGIC REASONING PROCESS ===

{phase_guidance}

STEP 1: IDENTIFY LOGICAL DEDUCTIONS
Scan the board for 100% certain deductions:
a) SATISFIED NUMBERS — If number N has exactly N flagged neighbors:
   → All other unrevealed neighbors are SAFE → reveal one
b) CONSTRAINED NUMBERS — If number N has exactly N unrevealed/unflagged neighbors:
   → All those neighbors are MINES → flag one
c) ZERO CELLS — Revealed 0 means all 8 neighbors are safe

STEP 2: IF NO LOGICAL MOVES EXIST
Use probability estimation:
- Prefer cells with most revealed neighbors (more information)
- Prefer cells adjacent to LOW numbers (1s better than 3s)
- Avoid cells adjacent to HIGH numbers
- If flags_placed = num_mines → all remaining unrevealed are SAFE
{edge_case_guidance}
ANALYSIS:
{hint_section}

=== OUTPUT FORMAT ===

Output EXACTLY ONE action as valid JSON with NO additional text:

REVEAL: {{"type":"reveal","row":<int>,"col":<int>}}
FLAG: {{"type":"flag","row":<int>,"col":<int>}}

REQUIREMENTS:
- "row" must be integer from 0 to {game.rows - 1}
- "col" must be integer from 0 to {game.cols - 1}
- Do not reveal already-revealed or flagged cells
- Do not flag already-flagged or revealed cells
- Output ONLY the JSON object — no explanations after

Your action:"""

    return prompt


def parse_llm_action(response: str) -> dict:
    """Extract JSON action from LLM response.

    Finds all JSON-like objects and returns the LAST one matching the
    expected schema (type + row + col). LLMs typically place their
    final answer at the end.
    """
    best = None
    for match in re.finditer(r'\{[^{}]*\}', response):
        try:
            action = json.loads(match.group())
            if ("type" in action and "row" in action and "col" in action
                    and action["type"] in ("reveal", "flag")):
                # Ensure row/col are ints
                action["row"] = int(action["row"])
                action["col"] = int(action["col"])
                best = action
        except (json.JSONDecodeError, ValueError, TypeError):
            continue
    return best


# ──────────────────────────────────────────────────────────────────────
# Tests
# ──────────────────────────────────────────────────────────────────────

# Test training prompts on various sizes
print("Testing prompt system (training mode)...")
for rows, cols, mines in [(1,1,0), (3,3,1), (5,5,3), (6,6,5), (8,8,10),
                           (10,10,20), (1,10,2), (10,1,2)]:
    if mines >= rows * cols:
        continue
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=42)
    if game.state() == "ongoing":
        prompt = format_state_for_llm(game, mode="training")
        assert f"{rows}" in prompt and f"{cols}" in prompt
        assert "'?'" in prompt or "0 mines" in prompt
        print(f"  {rows}x{cols} m={mines} training: {len(prompt)} chars")

# Test inference prompts
print("\nTesting prompt system (inference mode)...")
for rows, cols, mines in [(5,5,3), (6,6,5), (20,20,80)]:
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=42)
    t_prompt = format_state_for_llm(game, mode="training")
    i_prompt = format_state_for_llm(game, mode="inference")
    reduction = (1 - len(i_prompt) / len(t_prompt)) * 100
    print(f"  {rows}x{cols}: training={len(t_prompt)} inference={len(i_prompt)} "
          f"({reduction:.0f}% reduction)")

# Test phase-aware prompts
print("\nTesting phase-aware prompts...")
game = MinesweeperGame(8, 8, 10, seed=42)
assert "OPENING" in format_state_for_llm(game, "training")
game.do_action({"type": "reveal", "row": 4, "col": 4})
p = format_state_for_llm(game, "training")
assert "MID-GAME" in p or "ENDGAME" in p
print("  ✅ Phase-aware prompts work")

# Test edge case prompts
game_0mine = MinesweeperGame(5, 5, 0, seed=42)
p0 = format_state_for_llm(game_0mine, "training")
assert "0 mines" in p0
print("  ✅ 0-mine edge case prompt works")

game_1xn = MinesweeperGame(1, 10, 2, seed=42)
p1xn = format_state_for_llm(game_1xn, "training")
assert "LINEAR" in p1xn
print("  ✅ Linear board edge case prompt works")

game_tiny = MinesweeperGame(3, 3, 1, seed=42)
ptiny = format_state_for_llm(game_tiny, "training")
assert "TINY" in ptiny
print("  ✅ Tiny board edge case prompt works")

# Test large board format
game_large = MinesweeperGame(30, 30, 180, seed=42)
game_large.do_action({"type": "reveal", "row": 15, "col": 15})
p_large = format_state_for_llm(game_large, "training")
print(f"  30x30 prompt: {len(p_large)} chars")

game_xl = MinesweeperGame(50, 50, 500, seed=42)
game_xl.do_action({"type": "reveal", "row": 25, "col": 25})
p_xl = format_state_for_llm(game_xl, "training")
assert "LARGE BOARD" in p_xl
print(f"  50x50 prompt: {len(p_xl)} chars")

# Test parse_llm_action edge cases
assert parse_llm_action('{"type":"reveal","row":2,"col":3}') == {"type": "reveal", "row": 2, "col": 3}
assert parse_llm_action('blah {"type":"flag","row":"1","col":"2"} done') == {"type": "flag", "row": 1, "col": 2}
assert parse_llm_action('no json here') is None
assert parse_llm_action('{"type":"invalid","row":0,"col":0}') is None
print("  ✅ parse_llm_action edge cases pass")

# Show example training prompt
print(f"\n{'='*60}")
game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
game.do_action({"type": "reveal", "row": 0, "col": 0})
prompt = format_state_for_llm(game, mode="training")
print(f"=== Example 6x6 TRAINING prompt ({len(prompt)} chars) ===")
print(prompt[:1500])
if len(prompt) > 1500:
    print(f"... [{len(prompt) - 1500} more chars]")

print(f"\n{'='*60}")
prompt_inf = format_state_for_llm(game, mode="inference")
print(f"=== Example 6x6 INFERENCE prompt ({len(prompt_inf)} chars) ===")
print(prompt_inf)

You are an expert Minesweeper solver. Output ONE action as JSON only.

RULES:
- Numbers show how many of their 8 neighbors are mines.
- Subtract flagged neighbors from the number to find remaining mines.
- If remaining mines == remaining unrevealed neighbors → all are mines → FLAG.
- If remaining mines == 0 → all unrevealed neighbors are safe → REVEAL.
- Never reveal a flagged cell or flag a revealed cell.
- Prefer logically deducible moves over guessing.

{
  "board": [
    [
      ".",
      ".",
      ".",
      ".",
      ".",
      "."
    ],
    [
      ".",
      ".",
      ".",
      ".",
      ".",
      "."
    ],
    [
      ".",
      ".",
      ".",
      ".",
      ".",
      "."
    ],
    [
      ".",
      ".",
      ".",
      ".",
      ".",
      "."
    ],
    [
      ".",
      ".",
      ".",
      ".",
      ".",
      "."
    ],
    [
      ".",
      ".",
      ".",
      ".",
      ".",
      "."
    ]
  ],
  "rows": 6,
  "cols": 6,
  "mines": 5,
  "flags_pla

# Test Model Before Training

See how the base model performs without finetuning:

In [15]:
from transformers import TextStreamer

game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
prompt = format_state_for_llm(game)

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
)

print("=== Base Model Response ===")
output = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0.7,
    top_p = 0.9,
    max_new_tokens = 128,
    do_sample = True,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

=== Base Model Response ===
{"type":"reveal","row":0,"col":0}<|im_end|>


# GRPO Reward Functions

Define reward functions to guide the model's learning:

In [None]:
import numpy as np

# ──────────────────────────────────────────────────────────────────────
# Helper: Reconstruct game from dataset columns
# ──────────────────────────────────────────────────────────────────────

def _reconstruct_game(idx, kwargs):
    """Reconstruct a MinesweeperGame from dataset columns passed via GRPO kwargs.

    The dataset stores: seed, move_history, board_rows, board_cols, board_mines
    GRPOTrainer passes all non-'prompt' columns as lists in **kwargs.

    Returns (game, move_history_list) or (None, None) if data is missing.
    Handles 0-mine boards correctly.
    """
    seeds = kwargs.get("seed", [])
    move_histories = kwargs.get("move_history", [])
    rows_list = kwargs.get("board_rows", [])
    cols_list = kwargs.get("board_cols", [])
    mines_list = kwargs.get("board_mines", [])

    if idx >= len(seeds) or idx >= len(move_histories):
        return None, None

    seed = seeds[idx]
    rows = rows_list[idx] if idx < len(rows_list) else 6
    cols = cols_list[idx] if idx < len(cols_list) else 6
    num_mines = mines_list[idx] if idx < len(mines_list) else 5

    mh_raw = move_histories[idx]
    if isinstance(mh_raw, str):
        move_history = json.loads(mh_raw)
    else:
        move_history = list(mh_raw)

    game = MinesweeperGame(rows=int(rows), cols=int(cols),
                           num_mines=int(num_mines), seed=int(seed))
    for prev in move_history:
        result = game.do_action(prev)
        if result == "mine":
            return None, None  # History hit a mine — bad data

    return game, move_history


# ──────────────────────────────────────────────────────────────────────
# Reward 1: Valid JSON format + conciseness
# ──────────────────────────────────────────────────────────────────────

def valid_json_reward(prompts, completions, **kwargs):
    """Reward valid JSON action format. Also rewards conciseness.

    TRL GRPOTrainer calls: reward_func(prompts=..., completions=..., **kwargs)
    completions is a list of list-of-dicts (chat format).
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"].strip() if completion else ""
        action = parse_llm_action(response)

        if action is None:
            scores.append(-3.0)
            continue

        # Bonus for pure JSON (no extra text)
        try:
            parsed = json.loads(response)
            if "type" in parsed and "row" in parsed and "col" in parsed:
                scores.append(3.0)  # Perfect — pure JSON only
                continue
        except json.JSONDecodeError:
            pass

        # Valid JSON but with extra surrounding text
        json_match = re.search(r'\{[^{}]*\}', response)
        extra_chars = len(response) - len(json_match.group()) if json_match else len(response)
        if extra_chars < 10:
            scores.append(2.5)
        elif extra_chars < 50:
            scores.append(1.0)
        elif extra_chars < 200:
            scores.append(0.0)
        else:
            scores.append(-1.0)  # Too verbose

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 2: Gameplay — complete 12-criterion scoring
# Handles: 0-mine boards, 1x1 boards, large boards, all edge cases
# ──────────────────────────────────────────────────────────────────────

def gameplay_scores(prompts, completions, **kwargs):
    """
    Complete gameplay reward implementing all 12 scoring criteria.
    Uses variable board sizes from dataset columns.
    No move limit — only success or failure.

    1.  Flag cell that IS a mine        → +15 / +20 (logical)
    2.  Flag cell that is NOT a mine    → -10
    3.  Reveal cell that IS a mine      → -25
    4.  Reveal cell that is safe        → +10 (guess) / +15 (logical)
    5.  Flag already flagged cell       → -12
    6.  Reveal already revealed cell    → -12
    7.  Out of bounds                   → -15
    8.  Total flags > total mines       → -10 (additional)
    9.  Invalid JSON                    → -10
    10. Win the game                    → +100
    11. Reveal a flagged cell           → -8
    12. Flag a revealed cell            → -8
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        # ── Criterion 9: Invalid JSON ──
        if action is None:
            scores.append(-10.0)
            continue

        # ── Reconstruct game state from dataset columns ──
        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        # ── Edge case: game already won (0-mine board or all safe revealed) ──
        if game.state() != "ongoing":
            # Game is already over — any action is irrelevant
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]

        # ── Criterion 7: Out of bounds ──
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-15.0)
            continue

        # ── Edge case: 0-mine board — any reveal is safe, any flag is wrong ──
        if game.num_mines == 0:
            if action_type == "reveal":
                if (row, col) in game._revealed:
                    scores.append(-12.0)  # Already revealed
                else:
                    # Check if this wins
                    game_copy = MinesweeperGame(
                        rows=game.rows, cols=game.cols,
                        num_mines=0, seed=kwargs.get("seed", [0])[idx]
                    )
                    for prev in move_history:
                        game_copy.do_action(prev)
                    result = game_copy.do_action(action)
                    win_bonus = 100.0 if result == "win" else 0.0
                    scores.append(15.0 + win_bonus)  # Always safe + possible win
            else:  # flag on 0-mine board
                scores.append(-10.0)  # No mines to flag
            continue

        # ── Compute logical deductions ONCE ──
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        score = 0.0

        if action_type == "reveal":
            # ── Criterion 6: Reveal already revealed cell ──
            if (row, col) in game._revealed:
                scores.append(-12.0)
                continue

            # ── Criterion 11: Reveal a flagged cell ──
            if (row, col) in game._flagged:
                scores.append(-8.0)
                continue

            # ── Criterion 3: Reveal a mine ──
            if game._board[row][col] == -1:
                scores.append(-25.0)
                continue

            # ── Criterion 4: Reveal safe cell ──
            if (row, col) in safe_set:
                score += 15.0   # Logically deduced safe cell
            else:
                score += 10.0   # Guessed safe cell

            # Small bonus for revealing near numbers (information-rich)
            board = game.get_visible_board()
            has_adjacent_number = False
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    nr, nc = row + dr, col + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if board[nr][nc] in ('1','2','3','4','5','6','7','8'):
                            has_adjacent_number = True
                            break
                if has_adjacent_number:
                    break
            if has_adjacent_number:
                score += 1.0

            # ── Criterion 10: Check for win ──
            game_copy = MinesweeperGame(
                rows=game.rows, cols=game.cols,
                num_mines=game.num_mines, seed=kwargs.get("seed", [0])[idx]
            )
            for prev in move_history:
                game_copy.do_action(prev)
            result = game_copy.do_action(action)
            if result == "win":
                score += 100.0

        elif action_type == "flag":
            # ── Criterion 5: Flag already flagged cell ──
            if (row, col) in game._flagged:
                scores.append(-12.0)
                continue

            # ── Criterion 12: Flag a revealed cell ──
            if (row, col) in game._revealed:
                scores.append(-8.0)
                continue

            # ── Criterion 1: Flag a mine (correct) ──
            if game._board[row][col] == -1:
                if (row, col) in mine_set:
                    score += 20.0   # Logically deduced mine
                else:
                    score += 15.0   # Correct but guessed

            # ── Criterion 2: Flag a non-mine (wrong) ──
            else:
                score -= 10.0

            # ── Criterion 8: Total flags > total mines ──
            new_flag_count = len(game._flagged) + 1
            if new_flag_count > game.num_mines:
                score -= 10.0

        scores.append(score)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 3: Strategic play — rewards logical deduction over guessing
# ──────────────────────────────────────────────────────────────────────

def strategic_reward(prompts, completions, **kwargs):
    """Reward strategic play patterns:
    - Choosing logically deducible moves when available
    - Opening in corners/edges (lower mine probability on fresh boards)
    - Penalize ignoring available deductions
    - Handles 0-mine boards (any reveal is correct)
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        if action is None:
            scores.append(0.0)
            continue

        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        # Game already over — no strategic value
        if game.state() != "ongoing":
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]
        score = 0.0

        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(0.0)
            continue

        # ── 0-mine board: any reveal is trivially correct ──
        if game.num_mines == 0:
            if action_type == "reveal":
                scores.append(2.0)  # Small reward — too easy to need strategy
            else:
                scores.append(-2.0)  # Flagging with 0 mines is wrong
            continue

        # ── Compute logical deductions ONCE ──
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        # ── Fresh game opening strategy ──
        if len(game._revealed) == 0 and action_type == "reveal":
            corners = [(0, 0), (0, game.cols - 1),
                       (game.rows - 1, 0), (game.rows - 1, game.cols - 1)]
            edges = ([(0, c) for c in range(game.cols)] +
                     [(game.rows-1, c) for c in range(game.cols)] +
                     [(r, 0) for r in range(game.rows)] +
                     [(r, game.cols-1) for r in range(game.rows)])
            if (row, col) in corners:
                score += 3.0   # Corners have only 3 neighbors → lowest risk
            elif (row, col) in edges:
                score += 1.0   # Edges have 5 neighbors → moderate risk

        # ── Reward choosing logically deducible moves ──
        if action_type == "reveal" and (row, col) in safe_set:
            score += 5.0   # Chose a provably safe cell
        elif action_type == "flag" and (row, col) in mine_set:
            score += 5.0   # Chose a provably mine cell
        elif safe_set or mine_set:
            # Deducible moves existed but agent didn't pick one
            score -= 3.0

        # ── Penalize flagging on a fresh board (no info to deduce) ──
        if len(game._revealed) == 0 and action_type == "flag":
            score -= 2.0

        scores.append(score)

    return scores


# ── Verify reward function signatures ──
print("✅ All reward functions defined with correct TRL signature:")
print("   1. valid_json_reward(prompts, completions, **kwargs)")
print("   2. gameplay_scores(prompts, completions, **kwargs)")
print("   3. strategic_reward(prompts, completions, **kwargs)")
print()

# ── Smoke test: simulate what GRPOTrainer passes ──
test_completions = [[{"role": "assistant", "content": '{"type":"reveal","row":0,"col":0}'}]]
test_prompts = ["test"]

# Normal board
test_kwargs = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [6],
    "board_cols": [6],
    "board_mines": [5],
}
r1 = valid_json_reward(test_prompts, test_completions, **test_kwargs)
r2 = gameplay_scores(test_prompts, test_completions, **test_kwargs)
r3 = strategic_reward(test_prompts, test_completions, **test_kwargs)
print(f"  Smoke test (6x6 m=5): format={r1[0]:.1f}, gameplay={r2[0]:.1f}, strategic={r3[0]:.1f}")

# 0-mine board
test_kwargs_zero = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [5],
    "board_cols": [5],
    "board_mines": [0],
}
r2z = gameplay_scores(test_prompts, test_completions, **test_kwargs_zero)
r3z = strategic_reward(test_prompts, test_completions, **test_kwargs_zero)
print(f"  Smoke test (5x5 m=0): gameplay={r2z[0]:.1f}, strategic={r3z[0]:.1f}")

# 1x1 board with 0 mines
test_kwargs_1x1 = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [1],
    "board_cols": [1],
    "board_mines": [0],
}
# 1x1 with 0 mines: game auto-wins on first reveal which happens in do_action
# But actually the game should still be ongoing before first reveal
r2_1x1 = gameplay_scores(test_prompts, test_completions, **test_kwargs_1x1)
print(f"  Smoke test (1x1 m=0): gameplay={r2_1x1[0]:.1f}")

print(f"\n  ✅ All reward functions work with kwargs (seed, move_history, board_rows/cols/mines)")
print(f"  ✅ Edge cases: 0-mine boards, 1x1 boards handled")

✅ All reward functions defined:
   1. valid_json_reward   — format + conciseness
   2. gameplay_scores     — all 12 criteria
   3. strategic_reward    — logical deduction bonuses


# Create Training Dataset — Exhaustive 6-Phase Generation

Generate diverse game states with density-stratified sampling and bell-curve board sizes:

In [None]:
import os
from datasets import Dataset

# ══════════════════════════════════════════════════════════════════════
#  EXHAUSTIVE DATASET GENERATION
#  Fixes: 75% early deaths, no pattern training, 0.7% endgame,
#         no forced-guess scenarios, insufficient flag training
# ══════════════════════════════════════════════════════════════════════


# ──────────────────────────────────────────────────────────────────────
# Helper: Smart move selection with progressive flagging
# ──────────────────────────────────────────────────────────────────────

def _smart_reveal(game, rng):
    """Pick a smart cell to reveal during history generation.
    Prefers safe cells (if known) to avoid hitting mines and losing the game.
    Falls back to random unrevealed cell.
    """
    safe_set, _ = _compute_safe_and_mine_cells(game)
    if safe_set:
        return rng.choice(list(safe_set))

    # Fallback: random unrevealed, unflagged cell
    unrevealed = [(r, c) for r in range(game.rows) for c in range(game.cols)
                  if (r, c) not in game._revealed and (r, c) not in game._flagged]
    if not unrevealed:
        return None
    return rng.choice(unrevealed)


def _smart_flag(game, rng):
    """Pick a logically certain mine to flag, or None if none available."""
    _, mine_set = _compute_safe_and_mine_cells(game)
    mine_candidates = [c for c in mine_set if c not in game._flagged]
    if mine_candidates:
        return rng.choice(mine_candidates)
    return None  # Don't flag randomly — creates bad training data


def _progressive_flag_probability(game):
    """Adaptive flagging based on game progress (LAMER-inspired).
    - Early game (0-30%):  10% flagging (explore first)
    - Mid game (30-70%):   30% flagging (deduce mines)
    - Late game (70-100%): 50% flagging (lock in certain mines)
    """
    progress = game.progress()
    if progress < 0.30:
        return 0.10
    elif progress < 0.70:
        return 0.30
    else:
        return 0.50


def _play_smart_moves(game, rng, num_moves, use_progressive_flags=True):
    """Play num_moves smart moves on a game. Returns move_history list.
    Uses progressive flagging strategy when enabled.
    Returns early if game ends or gets stuck.

    Fix #1: Added stuck_count to prevent infinite loops when no valid
    moves can be found (e.g., all cells revealed/flagged but game ongoing).
    """
    move_history = []
    stuck_count = 0
    MAX_STUCK = 10  # Abort after 10 consecutive failed attempts

    for _ in range(num_moves):
        if game.state() != "ongoing":
            break

        action_dict = None
        flag_prob = _progressive_flag_probability(game) if use_progressive_flags else 0.15

        # Try flagging with adaptive probability
        if rng.random() < flag_prob:
            flag_target = _smart_flag(game, rng)
            if flag_target:
                action_dict = {"type": "flag", "row": flag_target[0], "col": flag_target[1]}

        # Fall back to reveal
        if action_dict is None:
            reveal_target = _smart_reveal(game, rng)
            if reveal_target is None:
                stuck_count += 1
                if stuck_count >= MAX_STUCK:
                    break  # Prevent infinite loop
                continue
            action_dict = {"type": "reveal", "row": reveal_target[0], "col": reveal_target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            break  # Hit a mine — stop

        move_history.append(action_dict)
        stuck_count = 0  # Reset on successful move

    return move_history


# ──────────────────────────────────────────────────────────────────────
# Scenario generators: endgame, forced-guess, multi-region
# ──────────────────────────────────────────────────────────────────────

def _generate_endgame_state(rows, cols, num_mines, rng, completion_target=0.85):
    """Generate boards that are 80-98% complete.
    Critical for teaching finishing strategy and flag accounting.
    Returns (game, move_history) or (None, None) on failure.
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    safe_total = rows * cols - num_mines
    if safe_total <= 1:
        return None, None, seed

    target_revealed = int(safe_total * rng.uniform(completion_target, 0.98))
    target_revealed = max(1, min(target_revealed, safe_total - 1))

    move_history = []
    stuck_count = 0
    max_stuck = 20

    while len(game._revealed) < target_revealed and game.state() == "ongoing":
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # No logical moves — reveal a random safe cell (we know the board)
            all_safe_unrevealed = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe_unrevealed:
                break
            target = rng.choice(all_safe_unrevealed)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
            stuck_count += 1
            if stuck_count >= max_stuck:
                break

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Optionally add some flags in endgame
    if game.state() == "ongoing":
        _, mine_set = _compute_safe_and_mine_cells(game)
        flaggable = [c for c in mine_set if c not in game._flagged]
        if flaggable and rng.random() < 0.7:
            num_flags = rng.randint(1, min(len(flaggable), 5))
            for target in rng.sample(flaggable, num_flags):
                action_dict = {"type": "flag", "row": target[0], "col": target[1]}
                game.do_action(action_dict)
                move_history.append(action_dict)

    if game.state() != "ongoing":
        return None, None, seed

    return game, move_history, seed


def _generate_forced_guess_state(rows, cols, num_mines, rng):
    """Generate a board state where no 100% certain logical moves exist.
    These teach the model probability-based guessing.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []
    max_reveal_attempts = rows * cols

    for _ in range(max_reveal_attempts):
        if game.state() != "ongoing":
            return None, None, seed

        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if not safe_set and not mine_set and len(game._revealed) > 0:
            # No logical moves available — this is what we want!
            unrevealed_count = sum(
                1 for r in range(rows) for c in range(cols)
                if (r, c) not in game._revealed and (r, c) not in game._flagged
            )
            if unrevealed_count >= 2:
                return game, move_history, seed

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # Must guess — reveal random safe cell (using board knowledge)
            all_safe = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe:
                break
            target = rng.choice(all_safe)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Fallback: return whatever state we reached (might still have logical moves)
    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_multi_region_state(rows, cols, num_mines, rng):
    """Create board with multiple separated revealed regions.
    Teaches model to reason across disconnected information.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    if rows < 6 or cols < 6:
        return None, None, 0  # Need space for multiple regions

    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    # Pick 2-4 starting points in different quadrants
    quadrant_centers = [
        (rows // 4, cols // 4),
        (rows // 4, 3 * cols // 4),
        (3 * rows // 4, cols // 4),
        (3 * rows // 4, 3 * cols // 4),
    ]
    rng.shuffle(quadrant_centers)
    num_regions = rng.randint(2, min(4, len(quadrant_centers)))
    selected = quadrant_centers[:num_regions]

    move_history = []
    for r, c in selected:
        if game.state() != "ongoing":
            break
        r = max(0, min(r, rows - 1))
        c = max(0, min(c, cols - 1))

        if (r, c) not in game._revealed and game._board[r][c] != -1:
            action_dict = {"type": "reveal", "row": r, "col": c}
            result = game.do_action(action_dict)
            if result == "mine":
                return None, None, seed
            move_history.append(action_dict)

        # Reveal a few logical neighbors around this region
        for _ in range(rng.randint(1, 3)):
            if game.state() != "ongoing":
                break
            safe_set, _ = _compute_safe_and_mine_cells(game)
            if safe_set:
                # Prefer safe cells near this region
                nearby = [
                    (sr, sc) for sr, sc in safe_set
                    if abs(sr - r) <= rows // 3 and abs(sc - c) <= cols // 3
                ]
                target = rng.choice(nearby) if nearby else rng.choice(list(safe_set))
                action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
                result = game.do_action(action_dict)
                if result == "mine":
                    return None, None, seed
                move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_satisfied_numbers_state(rows, cols, num_mines, rng):
    """Generate board with multiple satisfied numbers (easy deductions available).
    Satisfied number = number whose count of adjacent flags equals its value.
    All remaining unrevealed neighbors of a satisfied number are safe.
    Teaches basic constraint satisfaction.
    Returns (game, move_history, seed).
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []

    # Reveal some cells first
    num_initial = rng.randint(3, max(3, min(15, rows * cols // 4)))
    for _ in range(num_initial):
        if game.state() != "ongoing":
            break
        target = _smart_reveal(game, rng)
        if target is None:
            break
        action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Now flag logically certain mines to create satisfied numbers
    if game.state() == "ongoing":
        for _ in range(5):
            _, mine_set = _compute_safe_and_mine_cells(game)
            flaggable = [c for c in mine_set if c not in game._flagged]
            if not flaggable:
                break
            target = rng.choice(flaggable)
            action_dict = {"type": "flag", "row": target[0], "col": target[1]}
            game.do_action(action_dict)
            move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


# ──────────────────────────────────────────────────────────────────────
# Density-stratified board sampling
# ──────────────────────────────────────────────────────────────────────

DENSITY_TARGETS = [
    # (density_range, weight, label)
    ((0.00, 0.00), 0.08, "Zero mines"),       # 8%  - trivial edge case
    ((0.01, 0.05), 0.17, "Very sparse"),       # 17%
    ((0.05, 0.10), 0.25, "Sparse"),            # 25%
    ((0.10, 0.15), 0.25, "Medium"),            # 25%
    ((0.15, 0.20), 0.25, "Dense/Max"),         # 25%
]


# ──────────────────────────────────────────────────────────────────────
# Board Size Distribution — Bell-curve centered on mid-range (8-20)
# ──────────────────────────────────────────────────────────────────────

SIZE_BANDS = [
    # (min_dim, max_dim, weight, label)
    (1,   5,  0.10, "Tiny"),       # 10% — basics + edge cases
    (5,  10,  0.20, "Small"),      # 20% — simple patterns
    (8,  20,  0.35, "Mid"),        # 35% — PEAK: core curriculum ★
    (15, 30,  0.20, "Large"),      # 20% — generalization
    (25, 40,  0.10, "XL"),         # 10% — harder boards
    (35, 50,  0.05, "XXL"),        #  5% — edge exposure only
]


def _sample_board_size(rng):
    """Sample board dimensions using bell-curve distribution."""
    weights = [w for _, _, w, _ in SIZE_BANDS]
    idx = rng.choices(range(len(SIZE_BANDS)), weights=weights, k=1)[0]
    lo, hi, _, _ = SIZE_BANDS[idx]
    rows = rng.randint(lo, hi)
    cols = rng.randint(lo, hi)
    return rows, cols


def _sample_board_with_density(rng, target_density_range=None):
    """Sample board config with explicit density control + bell-curve sizes."""
    if target_density_range is None:
        weights = [w for _, w, _ in DENSITY_TARGETS]
        ranges = [r for r, _, _ in DENSITY_TARGETS]
        idx = rng.choices(range(len(DENSITY_TARGETS)), weights=weights, k=1)[0]
        target_density_range = ranges[idx]

    rows, cols = _sample_board_size(rng)
    total = rows * cols
    min_d, max_d = target_density_range

    if min_d == 0.0 and max_d == 0.0:
        return rows, cols, 0

    min_mines = max(0, int(math.ceil(total * min_d)))
    max_mines = min(int(total * max_d), int(total * MAX_MINE_DENSITY), total - 1)

    if max_mines < min_mines:
        num_mines = max(0, min(min_mines, total - 1))
    else:
        num_mines = rng.randint(min_mines, max_mines)

    return rows, cols, num_mines


# ──────────────────────────────────────────────────────────────────────
# Exhaustive edge-case configs (50+ scenarios)
# ──────────────────────────────────────────────────────────────────────

EDGE_CASE_CONFIGS = [
    # (rows, cols, mines, label)

    # === TRIVIAL BOARDS ===
    (1, 1, 0,   "1x1 trivial"),
    (2, 2, 0,   "2x2 no mines"),
    (3, 3, 0,   "3x3 no mines"),
    (4, 4, 0,   "4x4 no mines"),
    (5, 5, 0,   "5x5 no mines - cascade practice"),

    # === LINEAR BOARDS (1D Minesweeper) ===
    (1, 5, 1,   "1x5 single mine"),
    (1, 10, 2,  "1x10 two mines"),
    (1, 20, 4,  "1x20 row"),
    (1, 50, 10, "1x50 maximum row"),
    (5, 1, 1,   "5x1 column"),
    (10, 1, 2,  "10x1 column"),
    (20, 1, 4,  "20x1 column"),
    (50, 1, 10, "50x1 maximum column"),

    # === TINY BOARDS ===
    (1, 2, 0,   "1x2 trivial"),
    (2, 1, 0,   "2x1 trivial"),
    (2, 2, 1,   "2x2 single mine"),
    (2, 3, 1,   "2x3 mini"),
    (3, 2, 1,   "3x2 mini"),
    (3, 3, 1,   "3x3 single mine"),
    (3, 3, 2,   "3x3 two mines"),
    (4, 4, 3,   "4x4 medium density"),

    # === EXTREME RECTANGULAR ===
    (2, 50, 20, "2x50 ultra-wide"),
    (50, 2, 20, "50x2 ultra-tall"),
    (3, 40, 24, "3x40 extreme ratio"),
    (40, 3, 24, "40x3 extreme ratio"),
    (5, 50, 50, "5x50 max density wide"),
    (50, 5, 50, "50x5 max density tall"),

    # === CLASSIC MINESWEEPER SIZES ===
    (8, 8, 1,   "8x8 very sparse"),
    (8, 8, 5,   "8x8 sparse"),
    (8, 8, 10,  "8x8 medium - classic beginner"),
    (8, 8, 13,  "8x8 max density"),
    (16, 16, 10, "16x16 sparse"),
    (16, 16, 40, "16x16 medium - classic intermediate"),
    (16, 16, 51, "16x16 max density"),
    (16, 30, 20, "16x30 rectangular sparse"),
    (16, 30, 96, "16x30 rectangular dense - classic expert"),

    # === LARGE BOARDS ===
    (20, 20, 10, "20x20 very sparse"),
    (20, 20, 80, "20x20 max density"),
    (25, 25, 30, "25x25 sparse"),
    (25, 25, 125, "25x25 max density"),
    (30, 30, 50, "30x30 sparse"),
    (30, 30, 180, "30x30 max density"),
    (40, 40, 100, "40x40 sparse"),
    (40, 40, 320, "40x40 max density"),

    # === MAXIMUM SIZE ===
    (50, 50, 10,  "50x50 ultra-sparse"),
    (50, 50, 100, "50x50 sparse"),
    (50, 50, 250, "50x50 medium"),
    (50, 50, 500, "50x50 maximum density"),

    # === DENSITY EDGE CASES ===
    (10, 10, 0,  "10x10 no mines"),
    (15, 15, 0,  "15x15 no mines - large trivial"),
    (20, 20, 0,  "20x20 no mines - huge trivial"),
    (10, 10, 1,  "10x10 single mine"),
    (20, 20, 1,  "20x20 single mine in large board"),

    # === MORE RECTANGULAR VARIETY ===
    (3, 50, 30,  "3x50 wide"),
    (50, 3, 30,  "50x3 tall"),
    (5, 15, 15,  "5x15 wide rect"),
    (15, 5, 15,  "15x5 tall rect"),
    (6, 8, 6,    "6x8 rect"),
    (8, 6, 6,    "8x6 rect"),
    (7, 13, 18,  "7x13 odd rect"),
    (13, 7, 18,  "13x7 odd rect"),
    (15, 30, 90, "15x30 wide large"),
]


# ──────────────────────────────────────────────────────────────────────
# Dataset item builder (DRY helper)
# ──────────────────────────────────────────────────────────────────────

def _build_dataset_item(game, seed, move_history):
    """Build a single dataset item dict from a game state."""
    prompt_text = format_state_for_llm(game)
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt_text},
        ],
        "seed": seed,
        "move_history": json.dumps(move_history),
        "board_rows": game.rows,
        "board_cols": game.cols,
        "board_mines": game.num_mines,
    }


# ══════════════════════════════════════════════════════════════════════
#  MAIN GENERATOR — 6-Phase Exhaustive Composition
# ══════════════════════════════════════════════════════════════════════

def generate_exhaustive_dataset(num_samples=1000, rng_seed=42):
    """
    Comprehensive Minesweeper training dataset covering ALL scenarios.

    6-Phase composition:
      Phase 1: Edge cases           — 10%  (50+ explicit configs)
      Phase 2: Opening-heavy        — 25%  (fresh + single-move boards)
      Phase 3: Pattern-specific     — 15%  (satisfied numbers, multi-region)
      Phase 4: Mid-game deduction   — 25%  (core logical reasoning)
      Phase 5: Endgame completion   — 15%  (80-98% revealed, flag accounting)
      Phase 6: Forced guess         — 10%  (no logical moves available)
    """
    rng = random.Random(rng_seed)
    np.random.seed(rng_seed)

    dataset_items = []
    phase_counts = {
        "edge_case": 0, "opening": 0, "pattern": 0,
        "midgame": 0, "endgame": 0, "forced_guess": 0,
    }
    config_counts = {}
    density_counts = {"zero": 0, "very_sparse": 0, "sparse": 0, "medium": 0, "dense": 0}

    def _track_config(game):
        key = f"{game.rows}x{game.cols}m{game.num_mines}"
        config_counts[key] = config_counts.get(key, 0) + 1
        d = game.num_mines / (game.rows * game.cols) if game.rows * game.cols > 0 else 0
        if d == 0:
            density_counts["zero"] += 1
        elif d <= 0.05:
            density_counts["very_sparse"] += 1
        elif d <= 0.10:
            density_counts["sparse"] += 1
        elif d <= 0.15:
            density_counts["medium"] += 1
        else:
            density_counts["dense"] += 1

    # Budget per phase
    n_edge     = int(num_samples * 0.10)
    n_opening  = int(num_samples * 0.25)
    n_pattern  = int(num_samples * 0.15)
    n_midgame  = int(num_samples * 0.25)
    n_endgame  = int(num_samples * 0.15)
    n_forced   = num_samples - n_edge - n_opening - n_pattern - n_midgame - n_endgame

    print(f"  Phase budgets: edge={n_edge}, opening={n_opening}, pattern={n_pattern}, "
          f"midgame={n_midgame}, endgame={n_endgame}, forced={n_forced}")

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 1: Edge Cases (10%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 1: Edge cases...")
    edge_generated = 0
    edge_idx = 0
    while edge_generated < n_edge:
        ec_rows, ec_cols, ec_mines, ec_label = EDGE_CASE_CONFIGS[edge_idx % len(EDGE_CASE_CONFIGS)]
        edge_idx += 1
        seed = rng.randint(0, 999999)

        game = MinesweeperGame(rows=ec_rows, cols=ec_cols, num_mines=ec_mines, seed=seed)
        if game.state() != "ongoing":
            continue

        num_moves = rng.randint(0, min(3, max(0, ec_rows * ec_cols - ec_mines - 1)))
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["edge_case"] += 1
            edge_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 2: Opening-Heavy Training (25%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 2: Opening training...")
    opening_generated = 0
    opening_attempts = 0
    while opening_generated < n_opening and opening_attempts < n_opening * 5:
        opening_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if num_mines == 0:
            if game.state() == "ongoing":
                dataset_items.append(_build_dataset_item(game, seed, []))
                _track_config(game)
                phase_counts["opening"] += 1
                opening_generated += 1
            continue

        if game.state() != "ongoing":
            continue

        num_moves = 0 if rng.random() < 0.75 else 1
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=False)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["opening"] += 1
            opening_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 3: Pattern-Specific Scenarios (15%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 3: Pattern-specific scenarios...")
    pattern_generated = 0
    pattern_attempts = 0
    while pattern_generated < n_pattern and pattern_attempts < n_pattern * 10:
        pattern_attempts += 1

        if rng.random() < 0.60:
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))
            if rows < 3 or cols < 3:
                continue
            game, move_history, seed = _generate_satisfied_numbers_state(
                rows, cols, num_mines, rng
            )
        else:
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.15))
            rows = max(rows, 8)
            cols = max(cols, 8)
            num_mines = min(num_mines, int(rows * cols * 0.15))
            if num_mines < 1:
                num_mines = max(1, int(rows * cols * 0.05))
            game, move_history, seed = _generate_multi_region_state(
                rows, cols, num_mines, rng
            )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["pattern"] += 1
            pattern_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 4: Mid-Game Logical Deduction (25%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 4: Mid-game deduction...")
    midgame_generated = 0
    midgame_attempts = 0
    while midgame_generated < n_midgame and midgame_attempts < n_midgame * 5:
        midgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        if num_mines == 0:
            continue

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if game.state() != "ongoing":
            continue

        total_safe = rows * cols - num_mines
        max_moves = min(15, max(3, total_safe - 1))
        num_moves = rng.randint(3, max_moves)

        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing" and len(game._revealed) > 0:
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["midgame"] += 1
            midgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 5: Endgame Completion (15%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 5: Endgame completion...")
    endgame_generated = 0
    endgame_attempts = 0
    while endgame_generated < n_endgame and endgame_attempts < n_endgame * 10:
        endgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))

        if num_mines < 1 or rows * cols < 8:
            continue

        completion = rng.uniform(0.80, 0.95)
        game, move_history, seed = _generate_endgame_state(
            rows, cols, num_mines, rng, completion_target=completion
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["endgame"] += 1
            endgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 6: Forced Guess Scenarios (10%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 6: Forced guess scenarios...")
    forced_generated = 0
    forced_attempts = 0
    while forced_generated < n_forced and forced_attempts < n_forced * 15:
        forced_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng, (0.10, 0.20))

        if num_mines < 2 or rows * cols < 6:
            continue

        game, move_history, seed = _generate_forced_guess_state(
            rows, cols, num_mines, rng
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["forced_guess"] += 1
            forced_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # Shuffle and trim
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    rng.shuffle(dataset_items)
    dataset_items = dataset_items[:num_samples]
    ds = Dataset.from_list(dataset_items)

    return ds, config_counts, phase_counts, density_counts


# ══════════════════════════════════════════════════════════════════════
#  Generate & Analyze
# ══════════════════════════════════════════════════════════════════════

print("=" * 70)
print("  EXHAUSTIVE DATASET GENERATION")
print("  1000 samples | 6 phases | 50+ edge cases | density-stratified")
print("=" * 70)
print()

dataset, config_counts, phase_counts, density_counts = generate_exhaustive_dataset(
    num_samples=1000, rng_seed=42
)

print(f"\n{'─'*70}")
print(f"Created {len(dataset)} training examples\n")

# Extract arrays for analysis
all_rows = [item["board_rows"] for item in dataset]
all_cols = [item["board_cols"] for item in dataset]
all_mines = [item["board_mines"] for item in dataset]
densities = [m / (r * c) * 100 if r * c > 0 else 0 for r, c, m in zip(all_rows, all_cols, all_mines)]

# Phase distribution
print("Phase distribution:")
for phase, count in phase_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  {phase:14s}: {count:4d} ({pct:5.1f}%) {bar}")

# Size band distribution
print(f"\nSize band distribution (bell-curve curriculum):")
for lo, hi, target_w, label in SIZE_BANDS:
    count = sum(1 for r, c in zip(all_rows, all_cols) if max(r, c) >= lo and max(r, c) <= hi)
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  {label:6s} ({lo:2d}-{hi:2d}): {count:4d} ({pct:5.1f}%) target={target_w*100:.0f}% {bar}")

# Density distribution
print(f"\nDensity distribution:")
for band, count in density_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  {band:12s}: {count:4d} ({pct:5.1f}%) {bar}")

# Prompt tier distribution
print(f"\nPrompt tier distribution:")
tier_counts = {1: 0, 2: 0, 3: 0, 4: 0}
for r, c in zip(all_rows, all_cols):
    mx = max(r, c)
    if mx <= 5: tier_counts[1] += 1
    elif mx <= 12: tier_counts[2] += 1
    elif mx <= 25: tier_counts[3] += 1
    else: tier_counts[4] += 1
tier_labels = {1: 'Tiny (≤5)', 2: 'Small (6-12)', 3: 'Medium (13-25)', 4: 'Large (26+)'}
for tier, count in tier_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    bar = '█' * int(pct / 2)
    print(f"  Tier {tier} {tier_labels[tier]:16s}: {count:4d} ({pct:5.1f}%) {bar}")

zero_mine_count = sum(1 for m in all_mines if m == 0)
print(f"  Zero-mine boards: {zero_mine_count} ({zero_mine_count/len(dataset)*100:.1f}%)")

move_counts = [len(json.loads(item["move_history"])) for item in dataset]
print(f"\nMove statistics:")
print(f"  Min: {min(move_counts)}, Max: {max(move_counts)}, "
      f"Mean: {np.mean(move_counts):.1f}, Median: {np.median(move_counts):.1f}")
print(f"  Density: min={min(densities):.1f}%, max={max(densities):.1f}%, mean={np.mean(densities):.1f}%")

# Logical deduction coverage
print(f"\nLogical deduction coverage:")
has_safe = 0
has_mine = 0
has_both = 0
for item in dataset:
    mh = json.loads(item["move_history"])
    has_flag = any(m.get("type") == "flag" for m in mh)
    has_rev = any(m.get("type") == "reveal" for m in mh)
    if has_flag:
        has_mine += 1
    if has_rev:
        has_safe += 1
    if has_flag and has_rev:
        has_both += 1
print(f"  Samples with reveals: {has_safe} ({has_safe/len(dataset)*100:.1f}%)")
print(f"  Samples with flags:   {has_mine} ({has_mine/len(dataset)*100:.1f}%)")
print(f"  Samples with both:    {has_both} ({has_both/len(dataset)*100:.1f}%)")

# Verify dataset columns
print(f"\nDataset columns: {dataset.column_names}")
print(f"  ✅ board_rows, board_cols, board_mines present for reward functions")

# Sample prompt
print(f"\nSample prompt ({dataset[0]['board_rows']}x{dataset[0]['board_cols']}, "
      f"{dataset[0]['board_mines']} mines):")
print(dataset[0]["prompt"][0]["content"][:300] + "...")

# ── Save dataset to JSON ──
dataset_json_path = "minesweeper_dataset.json"
json_records = []
for item in dataset:
    record = {
        "seed": item["seed"],
        "move_history": item["move_history"],
        "board_rows": item["board_rows"],
        "board_cols": item["board_cols"],
        "board_mines": item["board_mines"],
        "prompt_text": item["prompt"][0]["content"],
    }
    json_records.append(record)

with open(dataset_json_path, "w") as f:
    json.dump(json_records, f, indent=2)

print(f"   {len(json_records)} records with fields: seed, move_history, board_rows/cols/mines, prompt_text")
print(f"\n✅ Dataset saved to {dataset_json_path} ({os.path.getsize(dataset_json_path) / 1024:.1f} KB)")

Generating training dataset with curriculum learning...
Created 2000 training examples (all ongoing games)

  Fresh games (0 moves): 637 (31.9%)
  Early game (1-2):      833 (41.6%)
  Mid game (3-8):        516 (25.8%)
  Late game (9+):        14 (0.7%)
  Avg moves per state:   1.9

Example training prompt (first 300 chars):
You are an expert Minesweeper solver. Output ONE action as JSON only.

RULES:
- Numbers show how many of their 8 neighbors are mines.
- Subtract flagged neighbors from the number to find remaining mines.
- If remaining mines == remaining unrevealed neighbors → all are mines → FLAG.
- If remaining mi...


# Configure GRPO Training

Set up GRPO trainer with all hyperparameters:

In [None]:
from trl import GRPOConfig, GRPOTrainer

# ── Lengths ──
# Training prompts include full strategic reasoning guidance.
# Small boards (≤20): ~800-1600 tokens (full grid + master prompt)
# Medium boards (21-35): ~700-1400 tokens (frontier + master prompt)
# Large boards (36-50): ~600-1200 tokens (summary + master prompt)
# max_prompt_length=1900 gives safe margin for all sizes within max_seq_length=2048
# Inference prompts are ~60% shorter — used during eval/test.
max_prompt_length = 1900
max_completion_length = 128  # Pure JSON output is ~40 tokens; 128 is generous

# ── GRPO Configuration ──
training_args = GRPOConfig(
    # === Generation ===
    temperature = 0.9,           # Exploration during training
    top_p = 0.95,

    # === Optimization ===
    learning_rate = 2e-5,
    weight_decay = 0.01,
    warmup_ratio = 0.05,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    max_grad_norm = 0.5,

    # === Batch sizes ===
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    num_generations = 8,

    # === Lengths ===
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,

    # === Training duration ===
    max_steps = 500,
    save_steps = 100,

    # === GRPO specific ===
    beta = 0.04,                 # Mild KL penalty to prevent reward hacking
    num_iterations = 1,

    # === Reward weighting (gameplay >> format >> strategy) ===
    reward_weights = [0.15, 0.70, 0.15],

    # === Ensure extra dataset columns are NOT removed ===
    remove_unused_columns = False,

    # === Output ===
    report_to = "none",
    output_dir = "minesweeper_grpo_v2",
    seed = 42,
    bf16 = True,
)

print("Training configuration:")
print(f"  Max steps:           {training_args.max_steps}")
print(f"  Generations/state:   {training_args.num_generations}")
print(f"  Learning rate:       {training_args.learning_rate}")
print(f"  LR scheduler:       {training_args.lr_scheduler_type}")
print(f"  Max grad norm:       {training_args.max_grad_norm}")
print(f"  Beta (KL penalty):   {training_args.beta}")
print(f"  Reward weights:      {training_args.reward_weights}")
print(f"  Prompt/Completion:   {max_prompt_length}/{max_completion_length}")
print(f"  Temperature:         {training_args.temperature}")
print(f"  remove_unused_cols:  {training_args.remove_unused_columns}")

print(f"  LoRA rank:           {lora_rank}")print(f"  Board range:         1-50 rows × 1-50 cols, 0-20% mines")

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 8
Training configuration:
  Model:               Qwen2.5-14B-Instruct
  Max steps:           500
  Generations/state:   8
  Learning rate:       2e-05
  LR scheduler:       SchedulerType.COSINE
  Max grad norm:       0.5
  Loss type:           bnpo
  Beta (KL penalty):   0.001
  Reward weights:      [0.15, 0.7, 0.15]
  Prompt/Completion:   700/200
  Temperature:         0.9
  Top-p:               0.95
  LoRA rank:           32


In [None]:
from transformers import TrainerCallback

# Board configs to evaluate on during training (mix of sizes from 1-50 range)
EVAL_CONFIGS = [
    (1, 1, 0),     # Trivial — 0 mines
    (3, 3, 1),     # Tiny
    (5, 5, 3),     # Small
    (6, 6, 5),     # Standard
    (8, 8, 10),    # Medium
    (10, 10, 20),  # Large
    (15, 15, 45),  # XL
    (6, 8, 6),     # Rectangular
    (1, 10, 2),    # Row board
    (20, 20, 80),  # XX-Large
]


class MinesweeperEvalCallback(TrainerCallback):
    """Periodically play games during training with variable board sizes.

    NO move limit — games end only on success (all safe revealed)
    or failure (mine hit). Max iterations capped to prevent infinite loops
    from repeated invalid actions.
    """

    def __init__(self, eval_every_steps=50, num_games=10):
        self.eval_every_steps = eval_every_steps
        self.num_games = min(num_games, len(EVAL_CONFIGS))

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step % self.eval_every_steps != 0:
            return

        tokenizer = processing_class
        if tokenizer is None or model is None:
            return

        was_training = model.training
        model.eval()

        wins = 0
        total_moves = 0
        invalid_count = 0

        for i in range(self.num_games):
            rows, cols, mines = EVAL_CONFIGS[i % len(EVAL_CONFIGS)]
            game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines,
                                   seed=10000 + i)
            moves = 0
            invalids = 0
            consecutive_invalids = 0
            # Safety cap: prevent infinite loops (not a move limit — just loop protection)
            max_iterations = rows * cols * 3 + 20

            iteration = 0
            while game.state() == "ongoing" and iteration < max_iterations:
                iteration += 1
                prompt = format_state_for_llm(game, mode="inference")
                text = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                inputs = tokenizer(text, return_tensors="pt", truncation=True,
                                   max_length=max_prompt_length + 100)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        temperature=0.3,
                        max_new_tokens=128,
                        do_sample=True,
                        top_p=0.8,
                    )

                # Decode ONLY the generated tokens (not the prompt)
                gen_tokens = output[0][inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                action = parse_llm_action(response)

                if action is None:
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break  # Too many consecutive invalid actions
                    continue

                consecutive_invalids = 0
                result = game.do_action(action)
                if result in ("mine", "win"):
                    moves += 1
                    break
                elif result == "ok":
                    moves += 1
                else:
                    # Invalid move (out_of_bounds, already_revealed, etc.)
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break

            if game.state() == "success":
                wins += 1
            total_moves += moves
            invalid_count += invalids

        win_rate = wins / self.num_games
        avg_moves = total_moves / self.num_games
        print(f"\n[Eval @ step {state.global_step}] "
              f"Win: {wins}/{self.num_games} ({win_rate*100:.0f}%) | "
              f"Avg moves: {avg_moves:.1f} | "
              f"Invalid: {invalid_count}\n")

        if was_training:
            model.train()


eval_callback = MinesweeperEvalCallback(eval_every_steps=50, num_games=10)
print(f"Eval callback: {eval_callback.num_games} games every "
      f"{eval_callback.eval_every_steps} steps")
print(f"  Configs: {len(EVAL_CONFIGS)} sizes from 1x1 to 20x20")
print(f"  Uses inference-mode prompt (60% fewer tokens)")
print(f"  No move limit — only success or failure")

Eval callback: 10 games every 50 steps (temp=0.3 for deterministic eval)


# Train the Model

Start GRPO training with reward functions:

In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        valid_json_reward,   # Reward valid JSON format + conciseness
        gameplay_scores,     # Core gameplay (all 12 criteria)
        strategic_reward,    # Logical deduction bonuses
    ],
    args = training_args,
    train_dataset = dataset,
    callbacks = [eval_callback],  # Periodic gameplay evaluation
)

print("Starting GRPO training with 3 reward functions...")
print("  [1] valid_json_reward  (weight: 0.15)")
print("  [2] gameplay_scores    (weight: 0.70)")
print("  [3] strategic_reward   (weight: 0.15)")
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


Starting GRPO training with 3 reward functions...
  [1] valid_json_reward  (weight: 0.15)
  [2] gameplay_scores    (weight: 0.70)
  [3] strategic_reward   (weight: 0.15)


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,000 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 4 x 1) = 32
 "-____-"     Trainable parameters = 137,625,600 of 14,907,659,264 (0.92% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / valid_json_reward / mean,rewards / valid_json_reward / std,rewards / gameplay_scores / mean,rewards / gameplay_scores / std,rewards / strategic_reward / mean,rewards / strategic_reward / std
1,0.0,10.775,0.0,14.0,14.0,14.0,0.0,14.0,14.0,14.0,0,0,0,0,0,9e-06,3.0,0.0,14.0,4.310527,3.5,1.524001
2,0.0,-0.4375,0.0,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,5.2e-05,3.0,0.0,-1.75,18.008959,2.25,1.813925
3,0.0001,-3.4125,0.0,14.625,14.0,24.0,0.0,14.625,14.0,24.0,No Log,No Log,No Log,No Log,No Log,0.023024,3.0,0.0,-6.0,19.423962,2.25,1.813925
4,0.0,5.428125,1.177424,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,0.000511,3.0,0.0,6.46875,16.329241,3.0,2.155264
5,0.0001,0.621875,2.184623,14.625,14.0,24.0,0.0,14.625,14.0,24.0,No Log,No Log,No Log,No Log,No Log,0.028749,3.0,0.0,0.03125,11.830754,1.0,1.016001
6,0.0,4.565625,2.72226,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,0.000225,3.0,0.0,5.34375,16.736303,2.5,2.540002
7,0.0,12.6375,0.0,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,1e-05,3.0,0.0,16.5,4.158163,4.25,1.319824


# Test Trained Model

Evaluate the finetuned model:

In [None]:
# Test on variable board sizes across the full competition range
FastLanguageModel.for_inference(model)

test_configs = [
    (1, 1, 0, 200, "1x1 Trivial"),
    (3, 3, 1, 201, "Tiny"),
    (5, 5, 3, 202, "Small/Easy"),
    (6, 6, 5, 99,  "Standard"),
    (8, 8, 10, 101, "Medium"),
    (10, 10, 20, 203, "Large 20%"),
    (15, 15, 45, 204, "XL 20%"),
    (1, 20, 4, 205, "Row Board"),
    (20, 1, 4, 206, "Column Board"),
    (5, 5, 0, 207, "Zero Mines"),
]

for rows, cols, mines, seed, label in test_configs:
    print(f"\n{'='*50}")
    print(f"=== {label} ({rows}x{cols}, {mines} mines) ===")
    print(f"{'='*50}")

    test_game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)

    # Handle already-won games (0-mine boards that auto-cascade)
    if test_game.state() != "ongoing":
        print(f"  Game auto-resolved: {test_game.state()}")
        continue

    test_prompt = format_state_for_llm(test_game, mode="inference")

    test_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": test_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(test_text, return_tensors="pt", truncation=True,
                       max_length=max_prompt_length + 100)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    output = model.generate(
        **inputs,
        temperature=0.3,
        max_new_tokens=128,
        do_sample=True,
        top_p=0.8,
        repetition_penalty=1.2,
    )

    # Decode only generated tokens
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    print(f"Response: {response_text.strip()}")

    action = parse_llm_action(response_text)
    print(f"Parsed action: {action}")

    if action:
        result = test_game.do_action(action)
        print(f"Result: {result} | Game state: {test_game.state()}")
    else:
        print("⚠️ Failed to parse a valid action")

# Exhaustive Evaluation: Full Competition Range

Play complete games across 37 board configurations (1×1 to 50×50, 0-20% mines).
**No move limit** — only two outcomes: SUCCESS (all safe cells revealed) or FAILURE (mine revealed).

In [None]:
def play_full_game(model, tokenizer, rows=6, cols=6, num_mines=5, seed=None,
                   verbose=False):
    """Play a complete Minesweeper game, tracking detailed metrics.

    Supports any board size from 1x1 to 50x50.
    NO move limit — game ends ONLY on:
      - SUCCESS: all non-mine cells are revealed
      - FAILURE: any mine cell is revealed
    A safety iteration cap prevents infinite loops from repeated invalid actions.
    """
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    # Edge case: game already won (e.g., 0-mine board)
    if game.state() != "ongoing":
        return {
            "game": game,
            "moves": 0,
            "logical_moves": 0,
            "flags_correct": 0,
            "flags_wrong": 0,
            "total_invalids": 0,
            "result": game.state(),
            "progress": game.progress(),
            "config": f"{rows}x{cols}m{num_mines}",
        }

    moves = 0
    consecutive_invalids = 0
    total_invalids = 0
    logical_moves = 0
    flags_correct = 0
    flags_wrong = 0
    # Safety cap to prevent infinite loops — NOT a game move limit
    max_iterations = rows * cols * 3 + 50

    iteration = 0
    while game.state() == "ongoing" and iteration < max_iterations:
        iteration += 1
        prompt = format_state_for_llm(game, mode="inference")
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           max_length=max_prompt_length + 100)
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.no_grad():
            output = model.generate(
                **inputs,
                temperature=0.3,
                max_new_tokens=128,
                do_sample=True,
                top_p=0.8,
                repetition_penalty=1.2,
            )

        # Decode ONLY generated tokens
        gen_tokens = output[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
        action = parse_llm_action(response)

        if action is None:
            consecutive_invalids += 1
            total_invalids += 1
            if consecutive_invalids >= 5:
                break  # Agent is stuck — abort
            continue

        consecutive_invalids = 0

        # Track logical moves (compute BEFORE applying action)
        safe_set, mine_set = _compute_safe_and_mine_cells(game)
        r, c = action["row"], action["col"]
        if action["type"] == "reveal" and (r, c) in safe_set:
            logical_moves += 1
        elif action["type"] == "flag" and (r, c) in mine_set:
            logical_moves += 1

        # Track flag accuracy
        if action["type"] == "flag":
            if 0 <= r < game.rows and 0 <= c < game.cols:
                if game._board[r][c] == -1:
                    flags_correct += 1
                else:
                    flags_wrong += 1

        if verbose:
            print(f"  Move {moves}: {action}")

        result = game.do_action(action)
        if result in ("mine", "win", "ok"):
            moves += 1
        elif result in ("out_of_bounds", "already_revealed", "flagged_cell",
                         "invalid_flag", "invalid_format"):
            # Invalid moves don't count but game stays ongoing
            total_invalids += 1
            consecutive_invalids += 1
            if consecutive_invalids >= 5:
                break

        if result in ("mine", "win"):
            break

    return {
        "game": game,
        "moves": moves,
        "logical_moves": logical_moves,
        "flags_correct": flags_correct,
        "flags_wrong": flags_wrong,
        "total_invalids": total_invalids,
        "result": game.state(),
        "progress": game.progress(),
        "config": f"{rows}x{cols}m{num_mines}",
    }


# ──────────────────────────────────────────────────────────────────────
# EXHAUSTIVE Multi-Size Evaluation — Competition Spec
# n, m ∈ [1, 50], mines 0-20% of total cells
# Only two outcomes: SUCCESS or FAILURE (no timeouts)
# ──────────────────────────────────────────────────────────────────────

EVAL_SUITE = [
    # (rows, cols, mines, num_games, label)

    # === Trivial / Edge Cases ===
    (1, 1, 0,   5,  "1x1 trivial"),
    (1, 2, 0,   5,  "1x2 trivial"),
    (2, 1, 0,   5,  "2x1 trivial"),
    (2, 2, 0,   5,  "2x2 no mines"),
    (3, 3, 0,   5,  "3x3 no mines"),
    (5, 5, 0,   5,  "5x5 no mines"),

    # === Tiny Boards ===
    (3, 3, 1,  10,  "3x3 1 mine"),
    (4, 4, 3,  10,  "4x4 3 mines"),

    # === Small Boards ===
    (5, 5, 3,  15,  "5x5 easy"),
    (5, 5, 5,  10,  "5x5 max density"),

    # === Standard ===
    (6, 6, 5,  20,  "6x6 standard"),
    (6, 6, 7,  10,  "6x6 hard"),

    # === Medium ===
    (7, 7, 7,  10,  "7x7 medium"),
    (8, 8, 10, 10,  "8x8 medium"),
    (8, 8, 12, 10,  "8x8 max density"),

    # === Large ===
    (10, 10, 20, 10, "10x10 20%"),
    (10, 10, 10,  5, "10x10 10%"),
    (15, 15, 45,  5, "15x15 20%"),

    # === XL ===
    (20, 20, 80,  3, "20x20 20%"),
    (25, 25, 125, 3, "25x25 20%"),
    (30, 30, 180, 2, "30x30 20%"),

    # === XXL ===
    (40, 40, 320, 2, "40x40 20%"),
    (50, 50, 500, 2, "50x50 20%"),

    # === Rectangular ===
    (1, 10, 2,   5, "1x10 row"),
    (10, 1, 2,   5, "10x1 column"),
    (1, 50, 10,  3, "1x50 row"),
    (50, 1, 10,  3, "50x1 column"),
    (6, 8, 6,    5, "6x8 rect"),
    (8, 6, 6,    5, "8x6 rect"),
    (5, 15, 15,  3, "5x15 wide"),
    (15, 5, 15,  3, "15x5 tall"),
    (3, 50, 30,  2, "3x50 extreme wide"),
    (50, 3, 30,  2, "50x3 extreme tall"),

    # === Sparse (low density) ===
    (10, 10, 1,  5, "10x10 sparse"),
    (20, 20, 4,  3, "20x20 sparse"),
    (50, 50, 10, 2, "50x50 sparse"),
]

FastLanguageModel.for_inference(model)
total_games = sum(n for _,_,_,n,_ in EVAL_SUITE)
print(f"{'='*80}")
print(f"  EXHAUSTIVE EVALUATION — {total_games} games across {len(EVAL_SUITE)} configs")
print(f"  Competition spec: n,m ∈ [1,50], mines 0-20%, no move limit")
print(f"  Only two outcomes: SUCCESS or FAILURE")
print(f"{'='*80}\n")

all_results = []
per_config_stats = {}

for rows, cols, mines, num_games, label in EVAL_SUITE:
    config_key = f"{rows}x{cols}m{mines}"
    wins = 0
    fails = 0
    config_results = []

    for i in range(num_games):
        info = play_full_game(model, tokenizer, rows=rows, cols=cols,
                              num_mines=mines, seed=5000 + i + hash(config_key) % 10000)
        config_results.append(info)
        all_results.append(info)

        if info["result"] == "success":
            wins += 1
        elif info["result"] == "failed":
            fails += 1

    # Per-config summary
    avg_moves = np.mean([r["moves"] for r in config_results])
    avg_logical = np.mean([r["logical_moves"] for r in config_results])
    avg_progress = np.mean([r["progress"] for r in config_results])
    avg_invalids = np.mean([r["total_invalids"] for r in config_results])
    stuck = sum(1 for r in config_results if r["result"] == "ongoing")
    wr = wins / num_games * 100

    per_config_stats[config_key] = {
        "label": label, "wins": wins, "fails": fails, "stuck": stuck,
        "total": num_games, "win_rate": wr,
        "avg_moves": avg_moves, "avg_progress": avg_progress,
    }

    status_icon = "✅" if wr >= 50 else "⚠️" if wr >= 20 else "❌"
    print(f"  {status_icon} {label:22s} ({config_key:12s}): "
          f"{wins:2d}/{num_games:2d} wins ({wr:5.1f}%) | "
          f"fails={fails} stuck={stuck} | "
          f"moves={avg_moves:5.1f} | logical={avg_logical:4.1f} | "
          f"progress={avg_progress:.0%} | invalids={avg_invalids:.1f}")

# ── Overall Summary ──
total_games_actual = len(all_results)
total_wins = sum(1 for r in all_results if r["result"] == "success")
total_fails = sum(1 for r in all_results if r["result"] == "failed")
total_stuck = sum(1 for r in all_results if r["result"] == "ongoing")
fc = sum(r["flags_correct"] for r in all_results)
fw = sum(r["flags_wrong"] for r in all_results)

print(f"\n{'='*80}")
print(f"  OVERALL: {total_wins}/{total_games_actual} wins ({total_wins/total_games_actual*100:.1f}%)")
print(f"{'='*80}")
print(f"  Wins (success):    {total_wins:4d} ({total_wins/total_games_actual*100:.1f}%)")
print(f"  Losses (failure):  {total_fails:4d} ({total_fails/total_games_actual*100:.1f}%)")
print(f"  Stuck (loop cap):  {total_stuck:4d} ({total_stuck/total_games_actual*100:.1f}%)")
print(f"  Avg moves:         {np.mean([r['moves'] for r in all_results]):.1f}")
print(f"  Avg progress:      {np.mean([r['progress'] for r in all_results]):.0%}")
print(f"  Avg logical moves: {np.mean([r['logical_moves'] for r in all_results]):.1f}")
if fc + fw > 0:
    print(f"  Flag accuracy:     {fc}/{fc+fw} ({fc/(fc+fw)*100:.1f}%)")
else:
    print(f"  Flags: none placed")

# ── Category Breakdown ──
categories = {
    "Trivial (0 mines)": [k for k, v in per_config_stats.items() if "m0" in k],
    "Tiny (≤4x4)":       [k for k, v in per_config_stats.items()
                           if v["label"].startswith(("3x3", "4x4")) and "m0" not in k],
    "Small (5x5)":       [k for k, v in per_config_stats.items() if k.startswith("5x5")],
    "Standard (6x6)":    [k for k, v in per_config_stats.items() if k.startswith("6x6")],
    "Medium (7-8)":      [k for k, v in per_config_stats.items()
                           if k.startswith(("7x7", "8x8"))],
    "Large (10-15)":     [k for k, v in per_config_stats.items()
                           if k.startswith(("10x10", "15x15"))],
    "XL (20-50)":        [k for k, v in per_config_stats.items()
                           if k.startswith(("20x20", "25x25", "30x30", "40x40", "50x50"))],
    "Rectangular":       [k for k, v in per_config_stats.items()
                           if "rect" in v["label"] or "row" in v["label"]
                           or "column" in v["label"] or "wide" in v["label"]
                           or "tall" in v["label"]],
    "Sparse":            [k for k, v in per_config_stats.items() if "sparse" in v["label"]],
}

print(f"\n{'='*80}")
print(f"  CATEGORY BREAKDOWN")
print(f"{'='*80}")
for cat_name, keys in categories.items():
    if not keys:
        continue
    cat_wins = sum(per_config_stats[k]["wins"] for k in keys)
    cat_total = sum(per_config_stats[k]["total"] for k in keys)
    if cat_total > 0:
        cat_wr = cat_wins / cat_total * 100
        icon = "✅" if cat_wr >= 50 else "⚠️" if cat_wr >= 20 else "❌"
        print(f"  {icon} {cat_name:22s}: {cat_wins:3d}/{cat_total:3d} ({cat_wr:5.1f}%)")

print(f"{'='*80}")

# Save the Model

Save your trained model for competition submission:

In [None]:
# Save LoRA adapters
model.save_pretrained("my_minesweeper_model")
tokenizer.save_pretrained("my_minesweeper_model")
print("✅ LoRA adapters saved to: my_minesweeper_model/")

# Save merged model in 16bit (local file name which will be used for eval)
if True:
    model.save_pretrained_merged(
        "my_minesweeper_model_merged",
        tokenizer,
        save_method = "merged_16bit"
    )
    print("✅ Merged 16-bit model saved to: my_minesweeper_model_merged/")

# Fixes & Improvements Applied

## Competition Spec Compliance
| Requirement | Implementation |
|-------------|---------------|
| Board rows: 1–50 | `MIN_ROWS=1, MAX_ROWS=50` in game engine |
| Board cols: 1–50 | `MIN_COLS=1, MAX_COLS=50` in game engine |
| Mines: 0–20% of cells | `MAX_MINE_DENSITY=0.20`, 0 mines allowed |
| No move limit | Removed `max_moves` from eval/play — only success/failure |
| Success = all safe revealed | `_check_win()` checks `len(revealed) >= safe_cells` |
| Failure = mine revealed | Only `_reveal_cell` hitting mine sets `_state="failed"` |
| max_new_tokens: 128 | `max_completion_length=128` in GRPOConfig |
| Dataset as JSON | Saved to `minesweeper_dataset.json` |

## Prompt System (Master Template)
| Feature | Implementation |
|---------|---------------|
| 3-tier board representation | Format A (≤20): full grid, B (21-35): frontier + regions, C (36-50): summary + critical areas |
| Phase-aware prompts | Opening (density-based center strategy), Mid-game (deduction checklist), Endgame (completion/flag accounting) |
| Edge case prompts | 0 mines, linear boards, tiny boards, large boards, very high density, one cell left |
| Strategic reasoning | STEP 1: logical deductions (satisfied/constrained numbers), STEP 2: probability estimation |
| Inference prompt | 60% fewer tokens for eval/test — minimal rules, no reasoning guidance |
| Symbol standardization | `?` for unrevealed (not `.`), `F` for flagged, `0`-`8` for revealed |
| Logical hints | `_compute_safe_and_mine_cells()` provides SAFE/MINE hints in prompt |
| Game phase tracking | `game_phase()` returns "opening"/"midgame"/"endgame" based on move count and progress |

## Critical Bugs Fixed
| # | Bug | Fix |
|---|-----|-----|
| 1 | `do_action()` set `_state="failed"` for ALL invalid moves | Only `mine` sets state to "failed"; others keep game "ongoing" |
| 2 | Reward functions signature mismatch with TRL GRPOTrainer | Changed to `(prompts, completions, **kwargs)` |
| 3 | Hardcoded `MinesweeperGame(rows=6, cols=6, num_mines=5)` in rewards | `_reconstruct_game()` reads board_rows/cols/mines from kwargs |
| 4 | `max_prompt_length=700` truncated prompts | Increased to 1900 tokens for boards up to 50×50 |
| 5 | Separate O(n²) passes for safe/mine cells | Combined `_compute_safe_and_mine_cells()` — single pass |
| 6 | Eval decoded full output including prompt | Decodes only generated tokens |
| 7 | `remove_unused_columns` not set | Explicitly `False` in GRPOConfig |
| 8 | Random flags in training data | `_smart_flag()` only flags logically certain mines |
| 9 | Reward scale imbalance | Rebalanced: invalid JSON=-10, mine=-25, win=+100 |

## Design Improvements
| # | Change | Impact |
|---|--------|--------|
| 10 | Board sizes 1×1 to 50×50 with weighted random sampling | Full competition range coverage |
| 11 | 0-mine boards supported (trivial instant-win on first reveal) | Edge case coverage |
| 12 | 25+ explicit edge-case configs in dataset | 1×1, 1×50, 50×1, 0-mine, max-density, rectangular |
| 13 | 3000 training samples with board diversity | Rows/cols from 1–50, densities 0–20% |
| 14 | Dataset saved to JSON file | `minesweeper_dataset.json` for inspection/reuse |
| 15 | Master prompt with 3-tier board representation | Small/medium/large boards get appropriate display format |
| 16 | Phase-aware prompts (opening/midgame/endgame) | Contextual strategic guidance matched to game phase |
| 17 | Inference prompt mode (60% fewer tokens) | Used in eval callback and post-training evaluation |
| 18 | Exhaustive eval: 37 configs, ~200 games (1×1 to 50×50) | Trivial, tiny, small, standard, medium, large, XL, XXL, rectangular, sparse |
| 19 | No move limit in eval — only success/failure | Matches competition rules exactly |
| 20 | Category breakdown in eval (trivial/small/medium/large/XL/rectangular/sparse) | Clear performance visibility |
| 21 | `max_prompt_length=1900` for large board support | Master prompt + board fits within this even for 50×50 |
| 22 | `beta=0.04` mild KL penalty | Prevents reward hacking / policy drift |
| 23 | `?` symbol for unrevealed cells (not `.`) | Clearer visual distinction, matches edge case guidance |
| 24 | Move count tracking + `game_phase()` method | Enables phase-aware prompt selection |