# Minesweeper LLM Competition - Custom GRPO Training

## Goal
Finetune an LLM with LoRA using GRPO to play Minesweeper by:
- **Input**: JSON game state (board configuration)
- **Output**: JSON action (reveal or flag a cell)

Teams will compete to train the best Minesweeper-playing LLM!

## Training Approach
- **Model**: Qwen2.5-14B-Instruct (from /root/.cache/huggingface/hub)
- **Method**: GRPO (Group Relative Policy Optimization)
- **Framework**: Unsloth (2-6x faster, 70% less VRAM)
- **Hardware**: AMD MI300X GPU (192GB HBM3, ROCm)

# Load Model with Unsloth

Load Qwen3-4B with LoRA configuration:

In [1]:
import os

os.environ["HF_HOME"] = "./workspace/hf_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./workspace/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "./workspace/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./workspace/hf_cache"


In [2]:
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="unsloth/Qwen2.5-14B-Instruct",
    local_dir="./workspace/Qwen2.5-14B-Instruct",
    local_dir_use_symlinks=False,
)


ModuleNotFoundError: No module named 'huggingface_hub'

In [3]:
from unsloth import FastLanguageModel
import torch
import yaml

# Load config
with open("minesweeper_config_me.yaml", "r") as f:
    config = yaml.safe_load(f)

lora_rank = config.get("lora_rank", 32)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/workspace/workspace/Qwen2.5-14B-Instruct",
    load_in_4bit=False,   # AMD → 4bit disabled
    max_seq_length=2048,  # Increased: larger boards need longer prompts
    dtype=torch.bfloat16,
)

print("Model loaded successfully!")
print(f"Device: {model.device}")
print(f"LoRA rank: {lora_rank} (from config)")

# ── Add newline as EOS token so generation stops after first JSON line ──
# (replaces stop_strings which isn't supported by this GRPOConfig version)
newline_token_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
original_eos = tokenizer.eos_token_id
if original_eos != newline_token_id:
    model.generation_config.eos_token_id = [original_eos, newline_token_id]
    model.config.eos_token_id = [original_eos, newline_token_id]
    print(f"  ✅ Added newline (token {newline_token_id}) as additional EOS → stops after JSON line")
    print(f"     EOS tokens: {model.generation_config.eos_token_id}")

ModuleNotFoundError: No module named 'unsloth'

# Add LoRA Adapters

Add LoRA layers for efficient finetuning:

In [4]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,           # alpha = rank → scaling factor = 1.0 (stable training)
    lora_dropout = 0.05,              # Small dropout for regularization
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
print(f"LoRA config: rank={lora_rank}, alpha={lora_rank}, dropout=0.05")
model.print_trainable_parameters()

NameError: name 'FastLanguageModel' is not defined

# Minesweeper Game Implementation

Custom Minesweeper environment supporting:
- Customizable board size and mine count
- Actions: reveal or flag cells
- Win: reveal all safe cells
- Lose: reveal a mine

In [2]:
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Set
import random
import math

# ──────────────────────────────────────────────────────────────────────
# Board size configuration — competition spec: n,m ∈ [1,50], mines 0-20%
# ──────────────────────────────────────────────────────────────────────

# Constants
MIN_ROWS, MAX_ROWS = 1, 50
MIN_COLS, MAX_COLS = 1, 50
MIN_MINE_DENSITY = 0.0    # 0% mines allowed (trivial board)
MAX_MINE_DENSITY = 0.20   # 20% of total cells


def sample_board_config(rng=None):
    """Sample a random (rows, cols, num_mines) from the full competition range.

    - rows ∈ [1, 50], cols ∈ [1, 50]
    - mines ∈ [0, floor(0.20 * rows * cols)]
    - Uses a weighted distribution favoring smaller boards during training
      (large boards are rare but included for coverage).
    """
    rng = rng or random.Random()

    # Weighted size distribution: favor small/medium, still cover large
    size_band = rng.random()
    if size_band < 0.30:
        # Small: 1-8
        rows = rng.randint(1, 8)
        cols = rng.randint(1, 8)
    elif size_band < 0.55:
        # Medium: 5-15
        rows = rng.randint(5, 15)
        cols = rng.randint(5, 15)
    elif size_band < 0.75:
        # Large: 10-30
        rows = rng.randint(10, 30)
        cols = rng.randint(10, 30)
    elif size_band < 0.90:
        # XL: 20-40
        rows = rng.randint(20, 40)
        cols = rng.randint(20, 40)
    else:
        # Full range: 1-50 (including extreme cases)
        rows = rng.randint(1, 50)
        cols = rng.randint(1, 50)

    total = rows * cols
    max_mines = int(total * MAX_MINE_DENSITY)  # floor(0.20 * total)

    if max_mines == 0:
        num_mines = 0  # Boards too small for any mines at ≤20%
    else:
        num_mines = rng.randint(0, max_mines)

    return rows, cols, num_mines


def mine_density(rows, cols, num_mines):
    """Compute mine density as a fraction."""
    total = rows * cols
    return num_mines / total if total > 0 else 0.0


@dataclass
class MinesweeperGame:
    rows: int
    cols: int
    num_mines: int
    seed: Optional[int] = None
    _rng: random.Random = field(init=False, repr=False)
    _board: List[List[int]] = field(init=False, repr=False)  # -1 = mine, 0-8 = count
    _revealed: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _flagged: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _state: str = field(default="ongoing", init=False, repr=False)
    _move_count: int = field(default=0, init=False, repr=False)

    def __post_init__(self):
        # ── Input validation — competition spec: n,m ∈ [1,50] ──
        if self.rows < MIN_ROWS or self.cols < MIN_COLS:
            raise ValueError(f"Board too small: {self.rows}x{self.cols} (min {MIN_ROWS}x{MIN_COLS})")
        if self.rows > MAX_ROWS or self.cols > MAX_COLS:
            raise ValueError(f"Board too large: {self.rows}x{self.cols} (max {MAX_ROWS}x{MAX_COLS})")
        if self.num_mines < 0:
            raise ValueError(f"num_mines cannot be negative, got {self.num_mines}")
        if self.num_mines >= self.rows * self.cols:
            raise ValueError(f"Too many mines ({self.num_mines}) for {self.rows}x{self.cols} board")
        # 0 mines is allowed — trivial board, instant win on first reveal cascade

        self._rng = random.Random(self.seed)
        self._board = [[0 for _ in range(self.cols)] for _ in range(self.rows)]
        self._place_mines()
        self._calculate_numbers()

        # ── Edge case: 0 mines → all cells are safe, auto-win on init check ──
        self._check_win()

    def _place_mines(self):
        """Place mines randomly on the board."""
        if self.num_mines == 0:
            return  # No mines to place
        positions = [(r, c) for r in range(self.rows) for c in range(self.cols)]
        mine_positions = self._rng.sample(positions, self.num_mines)
        for r, c in mine_positions:
            self._board[r][c] = -1

    def _calculate_numbers(self):
        """Calculate numbers for each cell based on adjacent mines."""
        for r in range(self.rows):
            for c in range(self.cols):
                if self._board[r][c] == -1:
                    continue
                count = 0
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if 0 <= nr < self.rows and 0 <= nc < self.cols:
                            if self._board[nr][nc] == -1:
                                count += 1
                self._board[r][c] = count

    def _reveal_cell(self, row: int, col: int) -> bool:
        """Reveal a cell. Returns True if valid move, False if invalid.
        Uses iterative flood-fill to avoid recursion limit on large boards.
        """
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed or (row, col) in self._flagged:
            return False

        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if (r, c) in self._revealed:
                continue

            self._revealed.add((r, c))

            # Hit a mine!
            if self._board[r][c] == -1:
                self._state = "failed"
                return True

            # Auto-reveal neighbors if cell is 0
            if self._board[r][c] == 0:
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < self.rows and 0 <= nc < self.cols
                                and (nr, nc) not in self._revealed
                                and (nr, nc) not in self._flagged):
                            stack.append((nr, nc))

        return True

    def _flag_cell(self, row: int, col: int) -> bool:
        """Flag/unflag a cell. Returns True if valid, False if invalid."""
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed:
            return False

        if (row, col) in self._flagged:
            self._flagged.remove((row, col))
        else:
            self._flagged.add((row, col))
        return True

    def do_action(self, action: dict) -> str:
        """Execute an action and return a status string.

        Returns one of:
          'ok'               - valid move executed
          'mine'             - revealed a mine (game over → state='failed')
          'win'              - game won after this move (all safe cells revealed)
          'invalid_format'   - bad action dict / missing keys / bad types
          'out_of_bounds'    - coordinates outside the board
          'already_revealed' - cell was already revealed
          'flagged_cell'     - tried to reveal a flagged cell
          'invalid_flag'     - tried to flag a revealed cell
          'game_over'        - game was already over before this call

        Only 'mine' sets state='failed'. All other invalid moves
        return an error string but keep the game 'ongoing'.
        NO move limit — game continues until success or failure.
        """
        if self._state != "ongoing":
            return "game_over"

        if not isinstance(action, dict):
            return "invalid_format"

        action_type = action.get("type")
        row = action.get("row")
        col = action.get("col")

        if action_type not in ("reveal", "flag") or row is None or col is None:
            return "invalid_format"

        try:
            row, col = int(row), int(col)
        except (ValueError, TypeError):
            return "invalid_format"

        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return "out_of_bounds"

        if action_type == "reveal":
            if (row, col) in self._revealed:
                return "already_revealed"
            if (row, col) in self._flagged:
                return "flagged_cell"
            self._reveal_cell(row, col)
            self._move_count += 1
        else:  # flag
            if (row, col) in self._revealed:
                return "invalid_flag"
            self._flag_cell(row, col)
            self._move_count += 1

        self._check_win()

        if self._state == "failed":
            return "mine"
        if self._state == "success":
            return "win"
        return "ok"

    def _check_win(self):
        """Check if player has won.

        Win condition: ALL safe (non-mine) cells are revealed.
        With 0 mines, ALL cells are safe → need to reveal every cell.
        """
        if self._state != "ongoing":
            return
        total_cells = self.rows * self.cols
        safe_cells = total_cells - self.num_mines
        if safe_cells == 0:
            self._state = "success"
        elif len(self._revealed) >= safe_cells:
            self._state = "success"

    def get_visible_board(self) -> List[List[str]]:
        """Get board state as player sees it.
        Uses '?' for unrevealed cells (competition standard).
        """
        visible = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r, c) in self._flagged:
                    row.append('F')
                elif (r, c) in self._revealed:
                    val = self._board[r][c]
                    row.append('*' if val == -1 else str(val))
                else:
                    row.append('?')
            visible.append(row)
        return visible

    def state(self) -> str:
        return self._state

    @property
    def move_count(self) -> int:
        return self._move_count

    def get_mine_positions(self) -> Set[Tuple[int, int]]:
        """Return set of all mine positions (for reward computation)."""
        return {(r, c) for r in range(self.rows) for c in range(self.cols)
                if self._board[r][c] == -1}

    def progress(self) -> float:
        """Fraction of safe cells revealed (0.0 to 1.0)."""
        safe_cells = self.rows * self.cols - self.num_mines
        return len(self._revealed) / safe_cells if safe_cells > 0 else 1.0

    def game_phase(self) -> str:
        """Determine the current game phase for prompt selection."""
        if self._move_count == 0 and len(self._revealed) == 0:
            return "opening"
        prog = self.progress()
        if prog >= 0.80:
            return "endgame"
        return "midgame"

    def pretty_print(self) -> str:
        """Pretty print the board."""
        visible = self.get_visible_board()
        lines = []

        # Header — handle up to 2-digit column numbers
        col_width = 3 if self.cols > 10 else 2
        header = "   " + " ".join(f"{i:>{col_width-1}d}" for i in range(self.cols))
        lines.append(header)
        lines.append("  " + "─" * (self.cols * col_width + 1))

        # Board
        for r, row in enumerate(visible):
            sep = " " * (col_width - 1)
            line = f"{r:2d}│ " + sep.join(row)
            lines.append(line)

        return "\n".join(lines)


# ──────────────────────────────────────────────────────────────────────
# Sanity tests
# ──────────────────────────────────────────────────────────────────────
print("Testing MinesweeperGame (competition spec: 1-50 boards, 0-20% mines)...")

# Basic gameplay
g = MinesweeperGame(5, 5, 3, seed=0)
assert g.state() == "ongoing"
assert g.do_action({"type": "reveal", "row": -1, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing", "BUG: out_of_bounds should NOT end the game"
assert g.do_action({"type": "reveal", "row": 99, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing"
assert g.do_action({"type": "flag", "row": 0, "col": 0}) == "ok"
assert g.do_action({"type": "reveal", "row": 0, "col": 0}) == "flagged_cell"
assert g.state() == "ongoing", "BUG: flagged_cell should NOT end the game"
assert g.do_action({}) == "invalid_format"
assert g.state() == "ongoing", "BUG: invalid_format should NOT end the game"
print("  ✅ do_action keeps game ongoing on invalid moves")

# Verify '?' is used for unrevealed cells
board = g.get_visible_board()
has_question = any('?' in row for row in board)
assert has_question, "Board should use '?' for unrevealed cells"
print("  ✅ Board uses '?' for unrevealed cells")

# Game phase tracking
assert g.game_phase() == "opening" or g.move_count > 0
print("  ✅ Game phase tracking works")

# ── Edge case: 0 mines board ──
g0 = MinesweeperGame(3, 3, 0, seed=42)
assert g0.state() == "ongoing"
result = g0.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win", f"0-mine board should win on first reveal, got {result}"
assert g0.state() == "success"
print("  ✅ 0-mine board → instant win on first reveal")

# ── Edge case: 1x1 board with 0 mines ──
g1x1 = MinesweeperGame(1, 1, 0, seed=42)
assert g1x1.state() == "ongoing"
result = g1x1.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 1x1 board with 0 mines works")

# ── Edge case: 1x1 board — cannot have mines ──
try:
    MinesweeperGame(1, 1, 1, seed=42)
    assert False, "Should have raised ValueError"
except ValueError:
    pass
print("  ✅ 1x1 board with 1 mine correctly rejected")

# ── Edge case: 50x50 board ──
g50 = MinesweeperGame(50, 50, 500, seed=42)
assert g50.state() == "ongoing"
assert g50.rows == 50 and g50.cols == 50
assert mine_density(50, 50, 500) == 0.20
print("  ✅ 50x50 board with 20% mines works")

# ── Edge cases: rectangular ──
g1x50 = MinesweeperGame(1, 50, 10, seed=42)
assert g1x50.rows == 1 and g1x50.cols == 50
g50x1 = MinesweeperGame(50, 1, 10, seed=42)
assert g50x1.rows == 50 and g50x1.cols == 1
print("  ✅ 1x50 and 50x1 boards work")

# ── Edge case: 2x2 with 0 mines ──
g2x2 = MinesweeperGame(2, 2, 0, seed=42)
result = g2x2.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 2x2 board with 0 mines → instant cascade win")

# Variable sizes across the full range
for r, c, m in [(1,1,0), (2,2,0), (3,3,1), (5,5,3), (10,10,20),
                (20,20,80), (30,30,180), (50,50,500), (1,50,10), (50,1,10)]:
    g = MinesweeperGame(r, c, m, seed=42)
    assert g.rows == r and g.cols == c
    board = g.get_visible_board()
    assert len(board) == r and len(board[0]) == c
print(f"  ✅ Variable board sizes (1x1 to 50x50) work")

# Test sample_board_config produces valid configs
rng = random.Random(42)
for _ in range(200):
    r, c, m = sample_board_config(rng)
    assert MIN_ROWS <= r <= MAX_ROWS
    assert MIN_COLS <= c <= MAX_COLS
    assert 0 <= m <= int(r * c * MAX_MINE_DENSITY)
    g = MinesweeperGame(r, c, m, seed=42)
print(f"  ✅ sample_board_config produces valid configs (200 tested)")

# Test progress
g = MinesweeperGame(5, 5, 3, seed=42)
assert g.progress() == 0.0
print(f"  ✅ All game engine tests passed")
print(f"  Board range: {MIN_ROWS}-{MAX_ROWS} rows × {MIN_COLS}-{MAX_COLS} cols")
print(f"  Mine density: {MIN_MINE_DENSITY*100:.0f}%-{MAX_MINE_DENSITY*100:.0f}%")

Testing MinesweeperGame (competition spec: 1-50 boards, 0-20% mines)...
  ✅ do_action keeps game ongoing on invalid moves
  ✅ Board uses '?' for unrevealed cells
  ✅ Game phase tracking works
  ✅ 0-mine board → instant win on first reveal
  ✅ 1x1 board with 0 mines works
  ✅ 1x1 board with 1 mine correctly rejected
  ✅ 50x50 board with 20% mines works
  ✅ 1x50 and 50x1 boards work
  ✅ 2x2 board with 0 mines → instant cascade win
  ✅ Variable board sizes (1x1 to 50x50) work
  ✅ sample_board_config produces valid configs (200 tested)
  ✅ All game engine tests passed
  Board range: 1-50 rows × 1-50 cols
  Mine density: 0%-20%


# Prompt System & Game Logic Helpers

## Simplified Unified Prompt System
Single prompt format for all board sizes (1×1 to 50×50) — no training/inference split.

**Key Design Principles:**
- **Prompt minimalism** — Model learns from examples (GRPO fine-tuning), not verbose instructions
- **JSON-only output** — `max_completion_length=128` enforces direct action output
- **Pre-computed hints** — Safe/mine cell lists provide the logical analysis

## Prompt Structure (~300-400 tokens)
```
Board info → Grid → Legend → [Critical hint] → [Safe/Mine hints] → JSON format spec
```

## Board Representation
| Size | Display |
|------|---------|
| 1–30 | Full grid with column headers |
| 31–50 | Frontier zone only (revealed numbers + neighbors) |

## Critical Edge Cases (only 3)
1. **Zero mines** — "All cells safe"
2. **All mines flagged** — "Reveal any '?' to win"
3. **Last cell** — "Cell (r,c) is safe/mine"

## Output Format
```json
{"type": "reveal", "row": 2, "col": 3}
```

In [3]:
import json
import re

# ══════════════════════════════════════════════════════════════════════
# SIMPLIFIED PROMPTING SYSTEM
# ══════════════════════════════════════════════════════════════════════
# Replaces the over-engineered 3-tier board format (A/B/C) and dual-mode
# (training/inference) system with a single, unified prompt format.
#
# Key simplifications:
#   1. Single board representation for all sizes (1-50)
#   2. No training/inference mode split
#   3. No ReAct scaffolding (STEP 1-4) — model outputs JSON directly
#   4. Only 3 critical edge cases kept (zero-mines, all-flagged, last-cell)
#   5. ~350 tokens target (down from 800-1200)
# ══════════════════════════════════════════════════════════════════════


def _compute_safe_and_mine_cells(game: MinesweeperGame):
    """Compute both safe and mine cells in a SINGLE pass (O(n²) not O(n⁴)).

    Returns (safe_set, mine_set) where each is a set of (row, col) tuples.

    A cell is logically SAFE if any adjacent revealed number has all its
    mines accounted for by flags (remaining_mines == 0).
    A cell is a logically CERTAIN mine if an adjacent number has
    remaining_mines == remaining unrevealed neighbors.
    """
    safe = set()
    mines = set()

    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) not in game._revealed:
                continue
            val = game._board[r][c]
            if val <= 0:
                continue

            flags = 0
            unrevealed = []
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if (nr, nc) in game._flagged:
                            flags += 1
                        elif (nr, nc) not in game._revealed:
                            unrevealed.append((nr, nc))

            remaining = val - flags

            if remaining == 0 and unrevealed:
                for cell in unrevealed:
                    safe.add(cell)
            elif remaining > 0 and remaining == len(unrevealed):
                for cell in unrevealed:
                    mines.add(cell)

    return safe, mines


def _compute_safe_cells(game: MinesweeperGame) -> list:
    """Return list of [row, col] for logically safe cells."""
    safe, _ = _compute_safe_and_mine_cells(game)
    return [list(c) for c in safe]


def _compute_mine_cells(game: MinesweeperGame) -> list:
    """Return list of [row, col] for logically certain mine cells."""
    _, mines = _compute_safe_and_mine_cells(game)
    return [list(c) for c in mines]


def _is_logically_safe(game: MinesweeperGame, row: int, col: int) -> bool:
    safe, _ = _compute_safe_and_mine_cells(game)
    return (row, col) in safe


def _is_logically_mine(game: MinesweeperGame, row: int, col: int) -> bool:
    _, mines = _compute_safe_and_mine_cells(game)
    return (row, col) in mines


# ──────────────────────────────────────────────────────────────────────
# UNIFIED BOARD REPRESENTATION
# Single format for all board sizes (1×1 to 50×50)
# ──────────────────────────────────────────────────────────────────────

def _format_board_unified(game: MinesweeperGame) -> str:
    """Single board format for all sizes.
    
    - Boards ≤30×30: Full grid with simple spacing
    - Boards 31-50: Frontier zone only (revealed numbers + unrevealed neighbors)
    """
    board = game.get_visible_board()
    
    # For large boards (31+), show only the frontier zone
    if game.rows > 30 or game.cols > 30:
        return _format_frontier_zone(game, board)
    
    # For small/medium boards, show full grid
    lines = []
    
    # Column header
    if game.cols <= 10:
        col_header = "   " + " ".join(f"{c}" for c in range(game.cols))
    else:
        col_header = "   " + "".join(f"{c:2d}" for c in range(game.cols))
    lines.append(col_header)
    
    # Grid rows
    for r, row in enumerate(board):
        if game.cols <= 10:
            lines.append(f"{r:2d} " + " ".join(row))
        else:
            lines.append(f"{r:2d} " + "".join(f"{v:>2s}" for v in row))
    
    return "\n".join(lines)


def _format_frontier_zone(game: MinesweeperGame, board: list) -> str:
    """For large boards (31+), show only frontier cells and their context."""
    # Find frontier cells (unrevealed adjacent to revealed numbers)
    frontier = set()
    number_cells = {}
    
    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) in game._revealed and game._board[r][c] > 0:
                number_cells[(r, c)] = game._board[r][c]
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < game.rows and 0 <= nc < game.cols
                                and (nr, nc) not in game._revealed):
                            frontier.add((nr, nc))
    
    if not frontier:
        # No frontier yet — show center region
        cr, cc = game.rows // 2, game.cols // 2
        r_min, r_max = max(0, cr - 5), min(game.rows, cr + 6)
        c_min, c_max = max(0, cc - 5), min(game.cols, cc + 6)
    else:
        # Find bounding box of frontier with padding
        rs = [r for r, c in frontier]
        cs = [c for r, c in frontier]
        r_min = max(0, min(rs) - 2)
        r_max = min(game.rows, max(rs) + 3)
        c_min = max(0, min(cs) - 2)
        c_max = min(game.cols, max(cs) + 3)
        
        # Limit to ~20 rows/cols max for readability
        if r_max - r_min > 20:
            mid_r = (r_min + r_max) // 2
            r_min, r_max = mid_r - 10, mid_r + 10
        if c_max - c_min > 20:
            mid_c = (c_min + c_max) // 2
            c_min, c_max = mid_c - 10, mid_c + 10
    
    lines = [f"[Showing rows {r_min}-{r_max-1}, cols {c_min}-{c_max-1}]"]
    col_header = "   " + "".join(f"{c:2d}" for c in range(c_min, c_max))
    lines.append(col_header)
    
    for r in range(r_min, r_max):
        cells = "".join(f"{board[r][c]:>2s}" for c in range(c_min, c_max))
        lines.append(f"{r:2d} {cells}")
    
    return "\n".join(lines)


# ──────────────────────────────────────────────────────────────────────
# SIMPLIFIED EDGE CASE GUIDANCE (only 3 critical cases)
# ──────────────────────────────────────────────────────────────────────

def _get_critical_hints(game: MinesweeperGame) -> str:
    """Return hints only for the 3 critical edge cases."""
    remaining = game.rows * game.cols - len(game._revealed) - len(game._flagged)
    remaining_mines = game.num_mines - len(game._flagged)
    
    # Case 1: Zero mines board
    if game.num_mines == 0:
        return "All cells safe (0 mines) — reveal any '?'"
    
    # Case 2: All mines flagged
    if remaining_mines == 0 and remaining > 0:
        return f"All {game.num_mines} mines flagged — reveal any '?' to win"
    
    # Case 3: Last cell
    if remaining == 1:
        for r in range(game.rows):
            for c in range(game.cols):
                if (r, c) not in game._revealed and (r, c) not in game._flagged:
                    if remaining_mines == 0:
                        return f"Last cell ({r},{c}) is safe — reveal to win"
                    else:
                        return f"Last cell ({r},{c}) is a mine — flag it"
    
    # Case 4: All remaining cells are mines
    if remaining > 0 and remaining == remaining_mines:
        return f"All {remaining} remaining cells are mines — flag any"
    
    return ""


# ──────────────────────────────────────────────────────────────────────
# UNIFIED PROMPT FORMAT
# Single format for training and inference (~350 tokens)
# ──────────────────────────────────────────────────────────────────────

SYSTEM_PROMPT = "You are a Minesweeper AI. Output ONLY valid JSON."


def format_state_for_llm(game: MinesweeperGame, mode=None) -> str:
    """Generate a simplified, unified prompt for any game state.
    
    Args:
        game: The MinesweeperGame instance
        mode: Ignored (kept for backward compatibility). Same prompt always.
    
    Returns:
        A prompt string (~250-400 tokens) with:
        - Board dimensions and mine count
        - Grid representation
        - Pre-computed safe/mine hints
        - Critical edge case hint (if applicable)
        - JSON output format specification
    """
    if game.state() == "success":
        return "Game won. No action needed."
    
    rows, cols = game.rows, game.cols
    mines = game.num_mines
    revealed = len(game._revealed)
    flagged = len(game._flagged)
    safe_total = rows * cols - mines
    
    # ── Board representation ──
    board_repr = _format_board_unified(game)
    
    # ── Pre-computed logical hints ──
    safe_cells = _compute_safe_cells(game)[:8]  # Limit to 8 for brevity
    mine_cells = _compute_mine_cells(game)[:8]
    
    hint_lines = []
    if safe_cells:
        hint_lines.append(f"Safe (100%): {safe_cells}")
    if mine_cells:
        hint_lines.append(f"Mines (100%): {mine_cells}")
    if not safe_cells and not mine_cells and revealed > 0:
        hint_lines.append("No certain moves — guess needed")
    
    hints = "\n".join(hint_lines) if hint_lines else ""
    
    # ── Critical edge case hint ──
    critical = _get_critical_hints(game)
    
    # ── Build prompt ──
    prompt = f"""Minesweeper {rows}×{cols}, {mines} mines
Revealed: {revealed}/{safe_total} | Flags: {flagged}/{mines}

{board_repr}

Legend: ?=unrevealed F=flagged 0-8=safe (number=adjacent mines)
"""
    
    if critical:
        prompt += f"\n{critical}\n"
    
    if hints:
        prompt += f"\n{hints}\n"
    
    prompt += f"""
Output JSON only: {{"type":"reveal","row":N,"col":N}} or {{"type":"flag","row":N,"col":N}}
Row: 0-{rows-1}, Col: 0-{cols-1}"""
    
    return prompt


def parse_llm_action(response: str) -> dict:
    """Extract JSON action from LLM response.

    Finds all JSON-like objects and returns the LAST one matching the
    expected schema (type + row + col). LLMs typically place their
    final answer at the end.
    """
    best = None
    for match in re.finditer(r'\{[^{}]*\}', response):
        try:
            action = json.loads(match.group())
            if ("type" in action and "row" in action and "col" in action
                    and action["type"] in ("reveal", "flag")):
                action["row"] = int(action["row"])
                action["col"] = int(action["col"])
                best = action
        except (json.JSONDecodeError, ValueError, TypeError):
            continue
    return best


# ──────────────────────────────────────────────────────────────────────
# TESTS
# ──────────────────────────────────────────────────────────────────────

print("Testing simplified unified prompt system...")
print("=" * 60)

# Test prompt length on various board sizes
print("\nPrompt lengths by board size:")
for rows, cols, mines in [(1,1,0), (3,3,1), (5,5,3), (6,6,5), (8,8,10),
                           (10,10,20), (15,15,45), (20,20,80), (1,10,2)]:
    if mines >= rows * cols:
        continue
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=42)
    if game.state() == "ongoing":
        prompt = format_state_for_llm(game)
        print(f"  {rows:2d}×{cols:2d} m={mines:2d}: {len(prompt):4d} chars")

# Test large board (frontier zone)
print("\nLarge board (35×35) frontier zone format:")
game_large = MinesweeperGame(35, 35, 100, seed=42)
game_large.do_action({"type": "reveal", "row": 17, "col": 17})
prompt_large = format_state_for_llm(game_large)
print(f"  35×35: {len(prompt_large)} chars")

# Test edge cases
print("\nEdge case detection:")

# Zero mines
game_0 = MinesweeperGame(5, 5, 0, seed=42)
p = format_state_for_llm(game_0)
assert "0 mines" in p.lower() or "safe" in p.lower()
print("  ✅ Zero mines edge case")

# All mines flagged scenario
game_flagged = MinesweeperGame(3, 3, 1, seed=42)
game_flagged.do_action({"type": "reveal", "row": 0, "col": 0})
# Find and flag the mine
for r in range(3):
    for c in range(3):
        if game_flagged._board[r][c] == -1:
            game_flagged.do_action({"type": "flag", "row": r, "col": c})
            break
p = format_state_for_llm(game_flagged)
if "flagged" in p.lower():
    print("  ✅ All mines flagged edge case")
else:
    print("  ⚠️ All mines flagged case (may not trigger if game state changed)")

# Test parse_llm_action
assert parse_llm_action('{"type":"reveal","row":2,"col":3}') == {"type": "reveal", "row": 2, "col": 3}
assert parse_llm_action('blah {"type":"flag","row":"1","col":"2"} done') == {"type": "flag", "row": 1, "col": 2}
assert parse_llm_action('no json here') is None
assert parse_llm_action('{"type":"invalid","row":0,"col":0}') is None
print("  ✅ parse_llm_action handles all cases")

# Test backward compatibility (mode parameter ignored)
game = MinesweeperGame(6, 6, 5, seed=42)
p1 = format_state_for_llm(game)
p2 = format_state_for_llm(game, mode="training")
p3 = format_state_for_llm(game, mode="inference")
assert p1 == p2 == p3
print("  ✅ mode parameter backward compatible (ignored)")

# Show example prompt
print(f"\n{'=' * 60}")
print("EXAMPLE PROMPT (6×6 board after opening move):")
print("=" * 60)
game = MinesweeperGame(6, 6, 5, seed=42)
game.do_action({"type": "reveal", "row": 0, "col": 0})
prompt = format_state_for_llm(game)
print(prompt)
print(f"\n[Total: {len(prompt)} characters]")

# Verify we hit the target token range
print(f"\n{'=' * 60}")
print("SIMPLIFICATION SUMMARY:")
print("=" * 60)
print(f"  ✅ Single board format (was 3-tier A/B/C)")
print(f"  ✅ No training/inference mode split")
print(f"  ✅ No ReAct scaffolding (STEP 1-4)")
print(f"  ✅ Only 3 critical edge cases (was 9)")
print(f"  ✅ ~{len(prompt)} chars for 6×6 board (target: 250-400 tokens ≈ 1000-1600 chars)")
print(f"  ✅ Code reduced from ~350 lines to ~100 lines")

Testing simplified unified prompt system...

Prompt lengths by board size:
   1× 1 m= 0:  277 chars
   3× 3 m= 1:  260 chars
   5× 5 m= 3:  303 chars
   6× 6 m= 5:  330 chars
   8× 8 m=10:  398 chars
  10×10 m=20:  482 chars
  15×15 m=45:  776 chars
  20×20 m=80: 1156 chars
   1×10 m= 2:  271 chars

Large board (35×35) frontier zone format:
  35×35: 445 chars

Edge case detection:
  ✅ Zero mines edge case
  ✅ All mines flagged edge case
  ✅ parse_llm_action handles all cases
  ✅ mode parameter backward compatible (ignored)

EXAMPLE PROMPT (6×6 board after opening move):
Minesweeper 6×6, 5 mines
Revealed: 1/31 | Flags: 0/5

   0 1 2 3 4 5
 0 2 ? ? ? ? ?
 1 ? ? ? ? ? ?
 2 ? ? ? ? ? ?
 3 ? ? ? ? ? ?
 4 ? ? ? ? ? ?
 5 ? ? ? ? ? ?

Legend: ?=unrevealed F=flagged 0-8=safe (number=adjacent mines)

No certain moves — guess needed

Output JSON only: {"type":"reveal","row":N,"col":N} or {"type":"flag","row":N,"col":N}
Row: 0-5, Col: 0-5

[Total: 363 characters]

SIMPLIFICATION SUMMARY:
  ✅ Singl

# Test Model Before Training

See how the base model performs without finetuning:

In [None]:
from transformers import TextStreamer

game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
prompt = format_state_for_llm(game)

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
)

print("=== Base Model Response ===")
output = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0.7,
    top_p = 0.9,
    max_new_tokens = 128,
    do_sample = True,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

# GRPO Reward Functions

Define reward functions to guide the model's learning:

In [4]:
import numpy as np
import math

# ──────────────────────────────────────────────────────────────────────
# Helper: Reconstruct game from dataset columns
# ──────────────────────────────────────────────────────────────────────

def _reconstruct_game(idx, kwargs):
    """Reconstruct a MinesweeperGame from dataset columns passed via GRPO kwargs.

    The dataset stores: seed, move_history, board_rows, board_cols, board_mines
    GRPOTrainer passes all non-'prompt' columns as lists in **kwargs.

    Returns (game, move_history_list) or (None, None) if data is missing.
    Handles 0-mine boards correctly.
    """
    seeds = kwargs.get("seed", [])
    move_histories = kwargs.get("move_history", [])
    rows_list = kwargs.get("board_rows", [])
    cols_list = kwargs.get("board_cols", [])
    mines_list = kwargs.get("board_mines", [])

    if idx >= len(seeds) or idx >= len(move_histories):
        return None, None

    seed = seeds[idx]
    rows = rows_list[idx] if idx < len(rows_list) else 6
    cols = cols_list[idx] if idx < len(cols_list) else 6
    num_mines = mines_list[idx] if idx < len(mines_list) else 5

    mh_raw = move_histories[idx]
    if isinstance(mh_raw, str):
        move_history = json.loads(mh_raw)
    else:
        move_history = list(mh_raw)

    game = MinesweeperGame(rows=int(rows), cols=int(cols),
                           num_mines=int(num_mines), seed=int(seed))
    for prev in move_history:
        result = game.do_action(prev)
        if result == "mine":
            return None, None  # History hit a mine — bad data

    return game, move_history


# ──────────────────────────────────────────────────────────────────────
# Length penalty helper (GRPO-LEAD: α=0.05)
#
# Shorter correct responses → bonus, longer → penalty.
# Applied to Reward 1 (valid_json_reward) since that's the format reward.
# ──────────────────────────────────────────────────────────────────────
LENGTH_PENALTY_ALPHA = 0.15  # Stronger penalty (was 0.05) for JSON-only constraint


def _length_penalty(response: str) -> float:
    """Compute length-based reward modifier (GRPO-LEAD paper).

    HACKATHON VERSION: JSON-only output, 128-token constraint.
    Heavily rewards pure JSON (≤60 chars), severely penalizes extra text.

    Returns a value in [-2.0, +2.0]:
      Pure JSON (~30 chars)    → +2.0
      Near-pure (<60c)         → +1.5
      Some extra text (<100c)  → +0.5
      Moderate (100-200 chars) → -0.5
      Verbose (>200 chars)     → -1.5 to -2.0
    """
    n = len(response)
    if n <= 40:       # Pure JSON action
        return 2.0
    elif n <= 60:     # Near-pure JSON
        return 1.5
    elif n <= 100:    # Minor extra text
        return 0.5 * math.exp(-LENGTH_PENALTY_ALPHA * (n - 60) / 10)
    elif n <= 200:    # Significant extra text — penalty
        return -0.5 - 0.5 * ((n - 100) / 100)
    else:             # Way too verbose for 128-token JSON-only
        return -1.5 - min(0.5, (n - 200) / 200)


# ──────────────────────────────────────────────────────────────────────
# Difficulty reweighting helper (XRPO paper)
#
# Harder boards get amplified reward signal so the model doesn't
# ignore difficult scenarios. Uses board difficulty proxy:
#   difficulty = density × (1 + move_depth / total_cells)
#
# Multiplier: 0.4 + (1.1 / (1 + exp(10 * (success_rate - 0.75))))
# Since we don't have per-board success rates in a single pass,
# we use density + board size as difficulty proxy.
# ──────────────────────────────────────────────────────────────────────

def _difficulty_multiplier(game) -> float:
    """Compute difficulty-based reward multiplier (XRPO paper).

    Returns a value in [0.7, 1.5]:
      Easy boards (low density, small) → ~0.7-0.9  (slight damping)
      Medium boards                    → ~1.0
      Hard boards (high density, large) → ~1.2-1.5 (amplified signal)
    """
    board_size = game.rows * game.cols

    # Safety check: degenerate boards get neutral multiplier
    if board_size == 0 or game.num_mines == 0:
        return 1.0

    density = mine_density(game.rows, game.cols, game.num_mines)

    # Difficulty proxy: combination of density and board complexity
    # density ∈ [0, 0.2], size ∈ [1, 2500]
    # Normalize size: log scale so 1→0, 100→0.5, 2500→1.0
    size_factor = min(1.0, math.log(max(1, board_size)) / math.log(2500))

    # Combined difficulty ∈ [0, 1]
    difficulty = 0.6 * (density / 0.20) + 0.4 * size_factor

    # Sigmoid-based multiplier: harder → higher
    # At difficulty=0 → ~0.7, at difficulty=0.5 → ~1.0, at difficulty=1.0 → ~1.5
    multiplier = 0.7 + 0.8 / (1.0 + math.exp(-6.0 * (difficulty - 0.5)))

    return multiplier


# ──────────────────────────────────────────────────────────────────────
# Reward 1: Valid JSON format + conciseness + length penalty (GRPO-LEAD)
# ──────────────────────────────────────────────────────────────────────

def valid_json_reward(prompts, completions, **kwargs):
    """GRPO-LEAD: Length-regularized reward - HACKATHON VERSION (JSON-only, 128 tokens).

    Heavily rewards pure JSON (≤60 chars), penalizes ANY extra text.
    Constraint: max_new_tokens=128 for competition.

    Two-pass approach:
      Pass 1: Score format correctness & collect correct response lengths
      Pass 2: Apply z-score normalized length penalty to correct responses
    """
    # ── Pass 1: Evaluate format correctness & collect lengths ──
    results = []  # List of (action, response, base_score, is_correct)
    correct_lengths = []

    for completion in completions:
        response = completion[0]["content"].strip() if completion else ""
        action = parse_llm_action(response)

        if action is None:
            results.append((None, response, -3.0, False))
            continue

        # Check if it's PURE JSON (no extra text)
        try:
            parsed = json.loads(response)
            if "type" in parsed and "row" in parsed and "col" in parsed:
                # Pure JSON - MAXIMUM REWARD
                if len(response) <= 60:
                    correct_lengths.append(len(response))
                    results.append((action, response, 5.0, True))  # Ultra-short pure JSON
                elif len(response) <= 100:
                    correct_lengths.append(len(response))
                    results.append((action, response, 3.0, True))  # Standard pure JSON
                else:
                    correct_lengths.append(len(response))
                    results.append((action, response, 2.0, True))  # Pure but verbose
                continue
        except json.JSONDecodeError:
            pass

        # JSON found but with extra text — PENALTY
        json_match = re.search(r'\{[^{}]*\}', response)
        extra_chars = len(response) - len(json_match.group()) if json_match else len(response)

        if extra_chars <= 10:       # Minor formatting chars
            base = 1.5
        elif extra_chars <= 30:     # Some extra text
            base = 0.5
        elif extra_chars <= 100:    # Significant reasoning
            base = -0.5
        else:                       # Way too verbose
            base = -2.0             # HEAVY PENALTY

        correct_lengths.append(len(response))
        results.append((action, response, base, True))

    # ── Pass 2: Apply group-normalized length penalty (GRPO-LEAD) ──
    scores = []

    if len(correct_lengths) > 1:
        mean_len = np.mean(correct_lengths)
        std_len = np.std(correct_lengths) + 1e-8  # Avoid div by zero

        for action, response, base_score, is_correct in results:
            if not is_correct:
                scores.append(base_score)  # Invalid JSON — no length adjustment
            else:
                z_score = (len(response) - mean_len) / std_len
                length_multiplier = math.exp(-LENGTH_PENALTY_ALPHA * z_score)
                # Clamp multiplier to [0.5, 1.5] for stability
                length_multiplier = max(0.5, min(1.5, length_multiplier))
                scores.append(base_score * length_multiplier)
    else:
        # Not enough correct responses for group normalization
        # Fall back to absolute length penalty
        for action, response, base_score, is_correct in results:
            if not is_correct:
                scores.append(base_score)
            else:
                lp = _length_penalty(response)
                scores.append(base_score + lp)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 2: Gameplay — complete 12-criterion scoring
# Handles: 0-mine boards, 1x1 boards, large boards, all edge cases
# Now with difficulty reweighting (XRPO)
# ──────────────────────────────────────────────────────────────────────

def gameplay_scores(prompts, completions, **kwargs):
    """
    Complete gameplay reward implementing all 12 scoring criteria.
    Uses variable board sizes from dataset columns.
    No move limit — only success or failure.
    XRPO: Difficulty reweighting — harder boards get amplified signal.

    1.  Flag cell that IS a mine        → +15 / +20 (logical)
    2.  Flag cell that is NOT a mine    → -10
    3.  Reveal cell that IS a mine      → -25
    4.  Reveal cell that is safe        → +10 (guess) / +15 (logical)
    5.  Flag already flagged cell       → -12
    6.  Reveal already revealed cell    → -12
    7.  Out of bounds                   → -15
    8.  Total flags > total mines       → -10 (additional)
    9.  Invalid JSON                    → -10
    10. Win the game                    → +100 × size_scale
    11. Reveal a flagged cell           → -8
    12. Flag a revealed cell            → -8
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        # ── Criterion 9: Invalid JSON ──
        if action is None:
            scores.append(-10.0)
            continue

        # ── Reconstruct game state from dataset columns ──
        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        # ── Difficulty multiplier (XRPO) ──
        diff_mult = _difficulty_multiplier(game)

        # ── Edge case: game already won (0-mine board or all safe revealed) ──
        if game.state() != "ongoing":
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]

        # ── Criterion 7: Out of bounds ──
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-15.0 * diff_mult)
            continue

        # ── Edge case: 0-mine board — any reveal is safe, any flag is wrong ──
        if game.num_mines == 0:
            if action_type == "reveal":
                if (row, col) in game._revealed:
                    scores.append(-12.0)
                else:
                    game_copy = MinesweeperGame(
                        rows=game.rows, cols=game.cols,
                        num_mines=0, seed=kwargs.get("seed", [0])[idx]
                    )
                    for prev in move_history:
                        game_copy.do_action(prev)
                    result = game_copy.do_action(action)
                    # Scale win bonus with board size
                    board_size = game.rows * game.cols
                    size_scale = 1.0 + min(1.0, board_size / 1000)
                    win_bonus = (100.0 * size_scale) if result == "win" else 0.0
                    scores.append(15.0 + win_bonus)
            else:
                scores.append(-10.0)
            continue

        # ── Compute logical deductions ONCE ──
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        score = 0.0

        if action_type == "reveal":
            # ── Criterion 6: Reveal already revealed cell ──
            if (row, col) in game._revealed:
                scores.append(-12.0 * diff_mult)
                continue

            # ── Criterion 11: Reveal a flagged cell ──
            if (row, col) in game._flagged:
                scores.append(-8.0 * diff_mult)
                continue

            # ── Criterion 3: Reveal a mine ──
            if game._board[row][col] == -1:
                # GRPO-LEAD: explicit wrong penalty (-1.0 additional)
                scores.append((-25.0 - 1.0) * diff_mult)
                continue

            # ── Criterion 4: Reveal safe cell ──
            if (row, col) in safe_set:
                score += 15.0   # Logically deduced safe cell
            else:
                score += 10.0   # Guessed safe cell

            # Small bonus for revealing near numbers (information-rich)
            board = game.get_visible_board()
            has_adjacent_number = False
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    nr, nc = row + dr, col + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if board[nr][nc] in ('1','2','3','4','5','6','7','8'):
                            has_adjacent_number = True
                            break
                if has_adjacent_number:
                    break
            if has_adjacent_number:
                score += 1.0

            # ── Criterion 10: Check for win ──
            game_copy = MinesweeperGame(
                rows=game.rows, cols=game.cols,
                num_mines=game.num_mines, seed=kwargs.get("seed", [0])[idx]
            )
            for prev in move_history:
                game_copy.do_action(prev)
            result = game_copy.do_action(action)
            if result == "win":
                # Scale win bonus with board size — 50×50 boards worth 2× a 6×6
                board_size = game.rows * game.cols
                size_scale = 1.0 + min(1.0, board_size / 1000)
                score += 100.0 * size_scale

        elif action_type == "flag":
            # ── Criterion 5: Flag already flagged cell ──
            if (row, col) in game._flagged:
                scores.append(-12.0 * diff_mult)
                continue

            # ── Criterion 12: Flag a revealed cell ──
            if (row, col) in game._revealed:
                scores.append(-8.0 * diff_mult)
                continue

            # ── Criterion 1: Flag a mine (correct) ──
            if game._board[row][col] == -1:
                if (row, col) in mine_set:
                    score += 20.0   # Logically deduced mine
                else:
                    score += 15.0   # Correct but guessed

            # ── Criterion 2: Flag a non-mine (wrong) ──
            else:
                # GRPO-LEAD: explicit wrong penalty
                score -= 10.0 + 1.0

            # ── Criterion 8: Total flags > total mines ──
            new_flag_count = len(game._flagged) + 1
            if new_flag_count > game.num_mines:
                score -= 10.0

        # Apply difficulty multiplier (XRPO)
        scores.append(score * diff_mult)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 3: Strategic play — rewards logical deduction over guessing
# Now with difficulty reweighting (XRPO)
# ──────────────────────────────────────────────────────────────────────

def strategic_reward(prompts, completions, **kwargs):
    """Reward strategic play patterns:
    - Choosing logically deducible moves when available
    - Opening in center (LAMER paper: center-opening for dense boards)
    - Penalize ignoring available deductions
    - Handles 0-mine boards (any reveal is correct)
    - Difficulty reweighting (XRPO)
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        if action is None:
            scores.append(0.0)
            continue

        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        # Game already over — no strategic value
        if game.state() != "ongoing":
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]
        score = 0.0

        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(0.0)
            continue

        # ── Difficulty multiplier (XRPO) ──
        diff_mult = _difficulty_multiplier(game)

        # ── 0-mine board: any reveal is trivially correct ──
        if game.num_mines == 0:
            if action_type == "reveal":
                scores.append(2.0)
            else:
                scores.append(-2.0)
            continue

        # ── Compute logical deductions ONCE ──
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        # ── Fresh game opening strategy (LAMER paper: center-opening) ──
        if len(game._revealed) == 0 and action_type == "reveal":
            density_pct = mine_density(game.rows, game.cols, game.num_mines) * 100
            center_r, center_c = game.rows // 2, game.cols // 2
            dist_to_center = abs(row - center_r) + abs(col - center_c)
            max_dist = center_r + center_c

            if density_pct > 10:
                if dist_to_center == 0:
                    score += 5.0
                elif dist_to_center <= max(1, max_dist // 4):
                    score += 3.0
                elif dist_to_center <= max_dist // 2:
                    score += 1.0
                else:
                    score -= 2.0
            else:
                if dist_to_center <= max(1, max_dist // 3):
                    score += 2.0

        # ── Reward choosing logically deducible moves ──
        if action_type == "reveal" and (row, col) in safe_set:
            score += 5.0   # Chose a provably safe cell
        elif action_type == "flag" and (row, col) in mine_set:
            score += 5.0   # Chose a provably mine cell
        elif safe_set or mine_set:
            # Deducible moves existed but agent didn't pick one
            score -= 3.0

        # ── Penalize flagging on a fresh board (no info to deduce) ──
        if len(game._revealed) == 0 and action_type == "flag":
            score -= 2.0

        # Apply difficulty multiplier (XRPO)
        scores.append(score * diff_mult)

    return scores


# ── Verify reward function signatures ──
print("✅ All reward functions defined with correct TRL signature:")
print("   1. valid_json_reward — format + length penalty (GRPO-LEAD)")
print("   2. gameplay_scores   — 12 criteria + difficulty reweight (XRPO)")
print("   3. strategic_reward  — deduction + center-opening + difficulty (XRPO)")
print()

# ── Test length penalty ──
print("Length penalty tests (GRPO-LEAD):")
print(f"  Pure JSON (30c):  {_length_penalty('x' * 30):+.2f}")
print(f"  Brief (100c):     {_length_penalty('x' * 100):+.2f}")
print(f"  Moderate (300c):  {_length_penalty('x' * 300):+.2f}")
print(f"  Verbose (600c):   {_length_penalty('x' * 600):+.2f}")
print(f"  Very long (1000c):{_length_penalty('x' * 1000):+.2f}")
print()

# ── Test difficulty multiplier ──
print("Difficulty multiplier tests (XRPO):")
for rows, cols, mines in [(3,3,0), (5,5,3), (10,10,10), (10,10,20), (30,30,180), (50,50,500)]:
    g = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=42)
    dm = _difficulty_multiplier(g)
    d = mine_density(rows, cols, mines) * 100
    print(f"  {rows}x{cols} m={mines} ({d:.1f}%): ×{dm:.2f}")
print()

# ── Smoke test: simulate what GRPOTrainer passes ──
test_completions_pure = [[{"role": "assistant", "content": '{"type":"reveal","row":0,"col":0}'}]]
test_completions_verbose = [[{"role": "assistant", "content": 'Let me analyze this board step by step. The cell at row 0, col 0 looks promising because it has several revealed neighbors. ' + '{"type":"reveal","row":0,"col":0}'}]]
test_prompts = ["test"]

# Normal board
test_kwargs = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [6],
    "board_cols": [6],
    "board_mines": [5],
}

print("Smoke test (6x6 m=5):")
r1p = valid_json_reward(test_prompts, test_completions_pure, **test_kwargs)
r1v = valid_json_reward(test_prompts, test_completions_verbose, **test_kwargs)
r2 = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs)
r3 = strategic_reward(test_prompts, test_completions_pure, **test_kwargs)
print(f"  Pure JSON:    format={r1p[0]:.2f}, gameplay={r2[0]:.1f}, strategic={r3[0]:.1f}")
print(f"  Verbose JSON: format={r1v[0]:.2f} (length penalty working)")

# 0-mine board
test_kwargs_zero = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [5],
    "board_cols": [5],
    "board_mines": [0],
}

r2z = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs_zero)
r3z = strategic_reward(test_prompts, test_completions_pure, **test_kwargs_zero)
print(f"  0-mine board: gameplay={r2z[0]:.1f}, strategic={r3z[0]:.1f}")

# 1x1 board with 0 mines
test_kwargs_1x1 = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [1],
    "board_cols": [1],
    "board_mines": [0],
}

r2_1x1 = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs_1x1)
print(f"  1x1 m=0:     gameplay={r2_1x1[0]:.1f}")

# Hard board — difficulty multiplier should amplify
test_kwargs_hard = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [20],
    "board_cols": [20],
    "board_mines": [80],
}

r2h = gameplay_scores(test_prompts, test_completions_pure, **test_kwargs_hard)
r3h = strategic_reward(test_prompts, test_completions_pure, **test_kwargs_hard)
print(f"  Hard 20x20:   gameplay={r2h[0]:.1f}, strategic={r3h[0]:.1f} (XRPO amplified)")

print(f"\n  ✅ Edge cases: 0-mine boards, 1x1 boards, hard boards handled")
print(f"  ✅ All reward functions work with kwargs (seed, move_history, board_rows/cols/mines)")
print(f"  ✅ Length penalty (GRPO-LEAD): pure JSON ≤60c → +5.0, verbose >100c → -2.0")
print(f"  ✅ Explicit wrong penalty (GRPO-LEAD): mine reveal -26, wrong flag -11")
print(f"  ✅ Difficulty reweighting (XRPO): harder boards → amplified signal")
print(f"  ✅ Win bonus scales with board size: 6×6→+104, 20×20→+140, 50×50→+250")
print(f"  ✅ HACKATHON: JSON-only output, 128-token constraint, α=0.15 length penalty")

✅ All reward functions defined with correct TRL signature:
   1. valid_json_reward — format + length penalty (GRPO-LEAD)
   2. gameplay_scores   — 12 criteria + difficulty reweight (XRPO)
   3. strategic_reward  — deduction + center-opening + difficulty (XRPO)

Length penalty tests (GRPO-LEAD):
  Pure JSON (30c):  +2.00
  Brief (100c):     +0.27
  Moderate (300c):  -2.00
  Verbose (600c):   -2.00
  Very long (1000c):-2.00

Difficulty multiplier tests (XRPO):
  3x3 m=0 (0.0%): ×1.00
  5x5 m=3 (12.0%): ×1.13
  10x10 m=10 (10.0%): ×1.14
  10x10 m=20 (20.0%): ×1.41
  30x30 m=180 (20.0%): ×1.45
  50x50 m=500 (20.0%): ×1.46

Smoke test (6x6 m=5):
  Pure JSON:    format=7.00, gameplay=12.2, strategic=-2.4
  Verbose JSON: format=-2.79 (length penalty working)
  0-mine board: gameplay=117.5, strategic=2.0
  1x1 m=0:     gameplay=115.1
  Hard 20x20:   gameplay=14.4, strategic=-2.9 (XRPO amplified)

  ✅ Edge cases: 0-mine boards, 1x1 boards, hard boards handled
  ✅ All reward functions work with 

# Exhaustive Training Dataset Generation

6-phase composition targeting **4000 samples** with density-stratified sampling:

| Phase | Budget | Description |
|-------|--------|-------------|
| 1. Edge Cases | 10% | 50+ explicit configs (trivial, linear, rectangular, density extremes) |
| 2. Opening | 25% | 75% fresh + 25% single-move — fix 75% early death rate |
| 3. Pattern-Specific | 15% | Satisfied numbers (60%) + multi-region boards (40%) |
| 4. Mid-Game | 25% | 3-15 moves, progressive flagging 10%→30%→50% |
| 5. Endgame | 15% | 80-98% revealed, flag accounting, completion strategy |
| 6. Forced Guess | 10% | No logical deductions available — probability reasoning |

In [5]:
import os
from datasets import Dataset

# ══════════════════════════════════════════════════════════════════════
#  EXHAUSTIVE DATASET GENERATION
#  Fixes: 75% early deaths, no pattern training, 0.7% endgame,
#         no forced-guess scenarios, insufficient flag training
# ══════════════════════════════════════════════════════════════════════


# ──────────────────────────────────────────────────────────────────────
# Helper: Smart move selection with progressive flagging
# ──────────────────────────────────────────────────────────────────────

def _smart_reveal(game, rng):
    """Pick a smart cell to reveal during history generation.
    Prefers safe cells (if known) to avoid hitting mines and losing the game.
    Falls back to random unrevealed cell.
    """
    safe_set, _ = _compute_safe_and_mine_cells(game)
    if safe_set:
        return rng.choice(list(safe_set))

    # Fallback: random unrevealed, unflagged cell
    unrevealed = [(r, c) for r in range(game.rows) for c in range(game.cols)
                  if (r, c) not in game._revealed and (r, c) not in game._flagged]
    if not unrevealed:
        return None
    return rng.choice(unrevealed)


def _smart_flag(game, rng):
    """Pick a logically certain mine to flag, or None if none available."""
    _, mine_set = _compute_safe_and_mine_cells(game)
    mine_candidates = [c for c in mine_set if c not in game._flagged]
    if mine_candidates:
        return rng.choice(mine_candidates)
    return None  # Don't flag randomly — creates bad training data


def _progressive_flag_probability(game):
    """Adaptive flagging based on game progress (LAMER-inspired).
    - Early game (0-30%):  10% flagging (explore first)
    - Mid game (30-70%):   30% flagging (deduce mines)
    - Late game (70-100%): 50% flagging (lock in certain mines)
    """
    progress = game.progress()
    if progress < 0.30:
        return 0.10
    elif progress < 0.70:
        return 0.30
    else:
        return 0.50


def _play_smart_moves(game, rng, num_moves, use_progressive_flags=True):
    """Play num_moves smart moves on a game. Returns move_history list.
    Uses progressive flagging strategy when enabled.
    Returns early if game ends or gets stuck.

    Fix #1: Added stuck_count to prevent infinite loops when no valid
    moves can be found (e.g., all cells revealed/flagged but game ongoing).
    """
    move_history = []
    stuck_count = 0
    MAX_STUCK = 10  # Abort after 10 consecutive failed attempts

    for _ in range(num_moves):
        if game.state() != "ongoing":
            break

        action_dict = None
        flag_prob = _progressive_flag_probability(game) if use_progressive_flags else 0.15

        # Try flagging with adaptive probability
        if rng.random() < flag_prob:
            flag_target = _smart_flag(game, rng)
            if flag_target:
                action_dict = {"type": "flag", "row": flag_target[0], "col": flag_target[1]}

        # Fall back to reveal
        if action_dict is None:
            reveal_target = _smart_reveal(game, rng)
            if reveal_target is None:
                stuck_count += 1
                if stuck_count >= MAX_STUCK:
                    break  # Prevent infinite loop
                continue
            action_dict = {"type": "reveal", "row": reveal_target[0], "col": reveal_target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            break  # Hit a mine — stop

        move_history.append(action_dict)
        stuck_count = 0  # Reset on successful move

    return move_history


# ──────────────────────────────────────────────────────────────────────
# Scenario generators: endgame, forced-guess, multi-region
# ──────────────────────────────────────────────────────────────────────

def _generate_endgame_state(rows, cols, num_mines, rng, completion_target=0.85):
    """Generate boards that are 80-98% complete.
    Critical for teaching finishing strategy and flag accounting.
    Returns (game, move_history) or (None, None) on failure.
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    safe_total = rows * cols - num_mines
    if safe_total <= 1:
        return None, None, seed

    target_revealed = int(safe_total * rng.uniform(completion_target, 0.98))
    target_revealed = max(1, min(target_revealed, safe_total - 1))

    move_history = []
    stuck_count = 0
    max_stuck = 20

    while len(game._revealed) < target_revealed and game.state() == "ongoing":
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # No logical moves — reveal a random safe cell (we know the board)
            all_safe_unrevealed = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe_unrevealed:
                break
            target = rng.choice(all_safe_unrevealed)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
            stuck_count += 1
            if stuck_count >= max_stuck:
                break

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Optionally add some flags in endgame
    if game.state() == "ongoing":
        _, mine_set = _compute_safe_and_mine_cells(game)
        flaggable = [c for c in mine_set if c not in game._flagged]
        if flaggable and rng.random() < 0.7:
            num_flags = rng.randint(1, min(len(flaggable), 5))
            for target in rng.sample(flaggable, num_flags):
                action_dict = {"type": "flag", "row": target[0], "col": target[1]}
                game.do_action(action_dict)
                move_history.append(action_dict)

    if game.state() != "ongoing":
        return None, None, seed

    return game, move_history, seed


def _generate_forced_guess_state(rows, cols, num_mines, rng):
    """Generate a board state where no 100% certain logical moves exist.
    These teach the model probability-based guessing.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []
    max_reveal_attempts = rows * cols

    for _ in range(max_reveal_attempts):
        if game.state() != "ongoing":
            return None, None, seed

        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if not safe_set and not mine_set and len(game._revealed) > 0:
            # No logical moves available — this is what we want!
            unrevealed_count = sum(
                1 for r in range(rows) for c in range(cols)
                if (r, c) not in game._revealed and (r, c) not in game._flagged
            )
            if unrevealed_count >= 2:
                return game, move_history, seed

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # Must guess — reveal random safe cell (using board knowledge)
            all_safe = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe:
                break
            target = rng.choice(all_safe)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Fallback: return whatever state we reached (might still have logical moves)
    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_multi_region_state(rows, cols, num_mines, rng):
    """Create board with multiple separated revealed regions.
    Teaches model to reason across disconnected information.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    if rows < 6 or cols < 6:
        return None, None, 0  # Need space for multiple regions

    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    # Pick 2-4 starting points in different quadrants
    quadrant_centers = [
        (rows // 4, cols // 4),
        (rows // 4, 3 * cols // 4),
        (3 * rows // 4, cols // 4),
        (3 * rows // 4, 3 * cols // 4),
    ]
    rng.shuffle(quadrant_centers)
    num_regions = rng.randint(2, min(4, len(quadrant_centers)))
    selected = quadrant_centers[:num_regions]

    move_history = []
    for r, c in selected:
        if game.state() != "ongoing":
            break
        r = max(0, min(r, rows - 1))
        c = max(0, min(c, cols - 1))

        if (r, c) not in game._revealed and game._board[r][c] != -1:
            action_dict = {"type": "reveal", "row": r, "col": c}
            result = game.do_action(action_dict)
            if result == "mine":
                return None, None, seed
            move_history.append(action_dict)

        # Reveal a few logical neighbors around this region
        for _ in range(rng.randint(1, 3)):
            if game.state() != "ongoing":
                break
            safe_set, _ = _compute_safe_and_mine_cells(game)
            if safe_set:
                # Prefer safe cells near this region
                nearby = [
                    (sr, sc) for sr, sc in safe_set
                    if abs(sr - r) <= rows // 3 and abs(sc - c) <= cols // 3
                ]
                target = rng.choice(nearby) if nearby else rng.choice(list(safe_set))
                action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
                result = game.do_action(action_dict)
                if result == "mine":
                    return None, None, seed
                move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_satisfied_numbers_state(rows, cols, num_mines, rng):
    """Generate board with multiple satisfied numbers (easy deductions available).
    Satisfied number = number whose count of adjacent flags equals its value.
    All remaining unrevealed neighbors of a satisfied number are safe.
    Teaches basic constraint satisfaction.
    Returns (game, move_history, seed).
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []

    # Reveal some cells first
    num_initial = rng.randint(3, max(3, min(15, rows * cols // 4)))
    for _ in range(num_initial):
        if game.state() != "ongoing":
            break
        target = _smart_reveal(game, rng)
        if target is None:
            break
        action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Now flag logically certain mines to create satisfied numbers
    if game.state() == "ongoing":
        for _ in range(5):
            _, mine_set = _compute_safe_and_mine_cells(game)
            flaggable = [c for c in mine_set if c not in game._flagged]
            if not flaggable:
                break
            target = rng.choice(flaggable)
            action_dict = {"type": "flag", "row": target[0], "col": target[1]}
            game.do_action(action_dict)
            move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


# ──────────────────────────────────────────────────────────────────────
# Density-stratified board sampling
# ──────────────────────────────────────────────────────────────────────

DENSITY_TARGETS = [
    # (density_range, weight, label)
    ((0.00, 0.00), 0.08, "Zero mines"),       # 8%  - trivial edge case
    ((0.01, 0.05), 0.17, "Very sparse"),       # 17%
    ((0.05, 0.10), 0.25, "Sparse"),            # 25%
    ((0.10, 0.15), 0.25, "Medium"),            # 25%
    ((0.15, 0.20), 0.25, "Dense/Max"),         # 25%
]


def _sample_board_with_density(rng, target_density_range=None):
    """Sample board config with explicit density control.
    If target_density_range is None, picks one from DENSITY_TARGETS.
    """
    if target_density_range is None:
        # Weighted random density band
        weights = [w for _, w, _ in DENSITY_TARGETS]
        ranges = [r for r, _, _ in DENSITY_TARGETS]
        idx = rng.choices(range(len(DENSITY_TARGETS)), weights=weights, k=1)[0]
        target_density_range = ranges[idx]

    # Sample board size — boosted large-board representation for 50×50 support
    size_band = rng.random()
    if size_band < 0.25:
        rows, cols = rng.randint(1, 8), rng.randint(1, 8)      # 25% tiny
    elif size_band < 0.45:
        rows, cols = rng.randint(5, 15), rng.randint(5, 15)    # 20% small
    elif size_band < 0.65:
        rows, cols = rng.randint(10, 30), rng.randint(10, 30)  # 20% medium
    elif size_band < 0.85:
        rows, cols = rng.randint(20, 40), rng.randint(20, 40)  # 20% large
    else:
        rows, cols = rng.randint(30, 50), rng.randint(30, 50)  # 15% XL (30-50)

    total = rows * cols
    min_d, max_d = target_density_range

    if min_d == 0.0 and max_d == 0.0:
        return rows, cols, 0

    min_mines = max(0, int(math.ceil(total * min_d)))
    max_mines = min(int(total * max_d), int(total * MAX_MINE_DENSITY), total - 1)

    if max_mines < min_mines:
        num_mines = max(0, min(min_mines, total - 1))
    else:
        num_mines = rng.randint(min_mines, max_mines)

    return rows, cols, num_mines


# ──────────────────────────────────────────────────────────────────────
# Exhaustive edge-case configs (50+ scenarios)
# ──────────────────────────────────────────────────────────────────────

EDGE_CASE_CONFIGS = [
    # (rows, cols, mines, label)

    # === TRIVIAL BOARDS ===
    (1, 1, 0,   "1x1 trivial"),
    (2, 2, 0,   "2x2 no mines"),
    (3, 3, 0,   "3x3 no mines"),
    (4, 4, 0,   "4x4 no mines"),
    (5, 5, 0,   "5x5 no mines - cascade practice"),

    # === LINEAR BOARDS (1D Minesweeper) ===
    (1, 5, 1,   "1x5 single mine"),
    (1, 10, 2,  "1x10 two mines"),
    (1, 20, 4,  "1x20 row"),
    (1, 50, 10, "1x50 maximum row"),
    (5, 1, 1,   "5x1 column"),
    (10, 1, 2,  "10x1 column"),
    (20, 1, 4,  "20x1 column"),
    (50, 1, 10, "50x1 maximum column"),

    # === TINY BOARDS ===
    (1, 2, 0,   "1x2 trivial"),
    (2, 1, 0,   "2x1 trivial"),
    (2, 2, 1,   "2x2 single mine"),
    (2, 3, 1,   "2x3 mini"),
    (3, 2, 1,   "3x2 mini"),
    (3, 3, 1,   "3x3 single mine"),
    (3, 3, 2,   "3x3 two mines"),
    (4, 4, 3,   "4x4 medium density"),

    # === EXTREME RECTANGULAR ===
    (2, 50, 20, "2x50 ultra-wide"),
    (50, 2, 20, "50x2 ultra-tall"),
    (3, 40, 24, "3x40 extreme ratio"),
    (40, 3, 24, "40x3 extreme ratio"),
    (5, 50, 50, "5x50 max density wide"),
    (50, 5, 50, "50x5 max density tall"),

    # === CLASSIC MINESWEEPER SIZES ===
    (8, 8, 1,   "8x8 very sparse"),
    (8, 8, 5,   "8x8 sparse"),
    (8, 8, 10,  "8x8 medium - classic beginner"),
    (8, 8, 13,  "8x8 max density"),
    (16, 16, 10, "16x16 sparse"),
    (16, 16, 40, "16x16 medium - classic intermediate"),
    (16, 16, 51, "16x16 max density"),
    (16, 30, 20, "16x30 rectangular sparse"),
    (16, 30, 96, "16x30 rectangular dense - classic expert"),

    # === LARGE BOARDS ===
    (20, 20, 10, "20x20 very sparse"),
    (20, 20, 80, "20x20 max density"),
    (25, 25, 30, "25x25 sparse"),
    (25, 25, 125, "25x25 max density"),
    (30, 30, 50, "30x30 sparse"),
    (30, 30, 180, "30x30 max density"),
    (40, 40, 100, "40x40 sparse"),
    (40, 40, 320, "40x40 max density"),

    # === MAXIMUM SIZE ===
    (50, 50, 10,  "50x50 ultra-sparse"),
    (50, 50, 100, "50x50 sparse"),
    (50, 50, 250, "50x50 medium"),
    (50, 50, 500, "50x50 maximum density"),

    # === DENSITY EDGE CASES ===
    (10, 10, 0,  "10x10 no mines"),
    (15, 15, 0,  "15x15 no mines - large trivial"),
    (20, 20, 0,  "20x20 no mines - huge trivial"),
    (10, 10, 1,  "10x10 single mine"),
    (20, 20, 1,  "20x20 single mine in large board"),

    # === MORE RECTANGULAR VARIETY ===
    (3, 50, 30,  "3x50 wide"),
    (50, 3, 30,  "50x3 tall"),
    (5, 15, 15,  "5x15 wide rect"),
    (15, 5, 15,  "15x5 tall rect"),
    (6, 8, 6,    "6x8 rect"),
    (8, 6, 6,    "8x6 rect"),
    (7, 13, 18,  "7x13 odd rect"),
    (13, 7, 18,  "13x7 odd rect"),
    (15, 30, 90, "15x30 wide large"),
]


# ──────────────────────────────────────────────────────────────────────
# Dataset item builder (DRY helper)
# ──────────────────────────────────────────────────────────────────────

def _build_dataset_item(game, seed, move_history):
    """Build a single dataset item dict from a game state."""
    prompt_text = format_state_for_llm(game)
    return {
        "prompt": [{"role": "user", "content": prompt_text}],
        "seed": seed,
        "move_history": json.dumps(move_history),
        "board_rows": game.rows,
        "board_cols": game.cols,
        "board_mines": game.num_mines,
    }


# ══════════════════════════════════════════════════════════════════════
#  MAIN GENERATOR — 6-Phase Exhaustive Composition
# ══════════════════════════════════════════════════════════════════════

def generate_exhaustive_dataset(num_samples=1000, rng_seed=42):
    """
    Comprehensive Minesweeper training dataset covering ALL scenarios.

    6-Phase composition:
      Phase 1: Edge cases           — 10%  (50+ explicit configs)
      Phase 2: Opening-heavy        — 25%  (fresh + single-move boards)
      Phase 3: Pattern-specific     — 15%  (satisfied numbers, multi-region)
      Phase 4: Mid-game deduction   — 25%  (core logical reasoning)
      Phase 5: Endgame completion   — 15%  (80-98% revealed, flag accounting)
      Phase 6: Forced guess         — 10%  (no logical moves available)

    Improvements over previous version:
      ✅ 50+ edge case configs (was 25)
      ✅ Progressive flagging 10%→30%→50% by game phase (was flat 15%)
      ✅ Density-stratified board sampling
      ✅ Dedicated endgame generator (was 0.7% late-game)
      ✅ Forced-guess scenario training (was 0%)
      ✅ Multi-region disconnected boards (was 0%)
      ✅ Satisfied-number pattern training (was 0%)
      ✅ 4000 samples (was 3000)
    """
    rng = random.Random(rng_seed)
    np.random.seed(rng_seed)

    dataset_items = []
    phase_counts = {
        "edge_case": 0, "opening": 0, "pattern": 0,
        "midgame": 0, "endgame": 0, "forced_guess": 0,
    }
    config_counts = {}
    density_counts = {"zero": 0, "very_sparse": 0, "sparse": 0, "medium": 0, "dense": 0}

    def _track_config(game):
        key = f"{game.rows}x{game.cols}m{game.num_mines}"
        config_counts[key] = config_counts.get(key, 0) + 1
        d = game.num_mines / (game.rows * game.cols) if game.rows * game.cols > 0 else 0
        if d == 0:
            density_counts["zero"] += 1
        elif d <= 0.05:
            density_counts["very_sparse"] += 1
        elif d <= 0.10:
            density_counts["sparse"] += 1
        elif d <= 0.15:
            density_counts["medium"] += 1
        else:
            density_counts["dense"] += 1

    # Budget per phase
    n_edge     = int(num_samples * 0.10)
    n_opening  = int(num_samples * 0.25)
    n_pattern  = int(num_samples * 0.15)
    n_midgame  = int(num_samples * 0.25)
    n_endgame  = int(num_samples * 0.15)
    n_forced   = num_samples - n_edge - n_opening - n_pattern - n_midgame - n_endgame

    print(f"  Phase budgets: edge={n_edge}, opening={n_opening}, pattern={n_pattern}, "
          f"midgame={n_midgame}, endgame={n_endgame}, forced={n_forced}")

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 1: Edge Cases (10%) — 50+ explicit configs
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 1: Edge cases...")
    edge_generated = 0
    edge_idx = 0
    while edge_generated < n_edge:
        ec_rows, ec_cols, ec_mines, ec_label = EDGE_CASE_CONFIGS[edge_idx % len(EDGE_CASE_CONFIGS)]
        edge_idx += 1
        seed = rng.randint(0, 999999)

        game = MinesweeperGame(rows=ec_rows, cols=ec_cols, num_mines=ec_mines, seed=seed)
        if game.state() != "ongoing":
            continue

        # 0-3 moves for variety
        num_moves = rng.randint(0, min(3, max(0, ec_rows * ec_cols - ec_mines - 1)))
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["edge_case"] += 1
            edge_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 2: Opening-Heavy Training (25%) — Fix 75% early death rate
    # 75% fresh (0 moves), 25% single-move
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 2: Opening training...")
    opening_generated = 0
    opening_attempts = 0
    while opening_generated < n_opening and opening_attempts < n_opening * 5:
        opening_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if num_mines == 0:
            if game.state() == "ongoing":
                dataset_items.append(_build_dataset_item(game, seed, []))
                _track_config(game)
                phase_counts["opening"] += 1
                opening_generated += 1
            continue

        if game.state() != "ongoing":
            continue

        # 75% fresh boards, 25% single-move
        num_moves = 0 if rng.random() < 0.75 else 1
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=False)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["opening"] += 1
            opening_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 3: Pattern-Specific Scenarios (15%)
    # Mix of: satisfied-number boards (60%), multi-region boards (40%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 3: Pattern-specific scenarios...")
    pattern_generated = 0
    pattern_attempts = 0
    while pattern_generated < n_pattern and pattern_attempts < n_pattern * 10:
        pattern_attempts += 1

        # 60% satisfied numbers, 40% multi-region
        if rng.random() < 0.60:
            # Satisfied numbers — need boards with mines
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))
            if rows < 3 or cols < 3:
                continue
            game, move_history, seed = _generate_satisfied_numbers_state(
                rows, cols, num_mines, rng
            )
        else:
            # Multi-region — need larger boards
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.15))
            rows = max(rows, 8)
            cols = max(cols, 8)
            num_mines = min(num_mines, int(rows * cols * 0.15))
            if num_mines < 1:
                num_mines = max(1, int(rows * cols * 0.05))
            game, move_history, seed = _generate_multi_region_state(
                rows, cols, num_mines, rng
            )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["pattern"] += 1
            pattern_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 4: Mid-Game Logical Deduction (25%) — Core gameplay
    # 3-15 moves played, progressive flagging, density-stratified
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 4: Mid-game deduction...")
    midgame_generated = 0
    midgame_attempts = 0
    while midgame_generated < n_midgame and midgame_attempts < n_midgame * 5:
        midgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        if num_mines == 0:
            continue  # Skip zero-mine for midgame

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if game.state() != "ongoing":
            continue

        total_safe = rows * cols - num_mines
        max_moves = min(15, max(3, total_safe - 1))
        num_moves = rng.randint(3, max_moves)

        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing" and len(game._revealed) > 0:
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["midgame"] += 1
            midgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 5: Endgame Completion (15%) — 80-98% revealed
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 5: Endgame completion...")
    endgame_generated = 0
    endgame_attempts = 0
    while endgame_generated < n_endgame and endgame_attempts < n_endgame * 10:
        endgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))

        if num_mines < 1 or rows * cols < 8:
            continue

        completion = rng.uniform(0.80, 0.95)
        game, move_history, seed = _generate_endgame_state(
            rows, cols, num_mines, rng, completion_target=completion
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["endgame"] += 1
            endgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 6: Forced Guess Scenarios (10%) — No logical deductions
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 6: Forced guess scenarios...")
    forced_generated = 0
    forced_attempts = 0
    while forced_generated < n_forced and forced_attempts < n_forced * 15:
        forced_attempts += 1
        # Dense boards more likely to produce forced-guess states
        rows, cols, num_mines = _sample_board_with_density(rng, (0.10, 0.20))

        if num_mines < 2 or rows * cols < 6:
            continue

        game, move_history, seed = _generate_forced_guess_state(
            rows, cols, num_mines, rng
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["forced_guess"] += 1
            forced_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # Shuffle and trim
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    rng.shuffle(dataset_items)
    dataset_items = dataset_items[:num_samples]
    ds = Dataset.from_list(dataset_items)

    return ds, config_counts, phase_counts, density_counts


# ══════════════════════════════════════════════════════════════════════
#  Generate & Analyze
# ══════════════════════════════════════════════════════════════════════

print("=" * 70)
print("  EXHAUSTIVE DATASET GENERATION")
print("  4000 samples | 6 phases | 50+ edge cases | density-stratified")
print("=" * 70)
print()

dataset, config_counts, phase_counts, density_counts = generate_exhaustive_dataset(
    num_samples=1000, rng_seed=42
)

print(f"\n{'─'*70}")
print(f"Created {len(dataset)} training examples\n")

# Phase distribution
print("Phase distribution:")
for phase, count in phase_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {phase:14s}: {count:4d} ({pct:5.1f}%)")

# Density distribution
print(f"\nDensity distribution:")
for band, count in density_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {band:14s}: {count:4d} ({pct:5.1f}%)")

# Board size distribution (top 20)
print(f"\nBoard size distribution (top 20):")
sorted_configs = sorted(config_counts.items(), key=lambda x: -x[1])
for config, count in sorted_configs[:20]:
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {config:16s}: {count:4d} ({pct:.1f}%)")
if len(sorted_configs) > 20:
    print(f"  ... and {len(sorted_configs) - 20} more unique configs")
print(f"\nTotal unique board configs: {len(config_counts)}")

# Board size statistics
all_rows = [item["board_rows"] for item in dataset]
all_cols = [item["board_cols"] for item in dataset]
all_mines = [item["board_mines"] for item in dataset]
print(f"\nBoard size statistics:")
print(f"  Rows:  min={min(all_rows)}, max={max(all_rows)}, mean={np.mean(all_rows):.1f}")
print(f"  Cols:  min={min(all_cols)}, max={max(all_cols)}, mean={np.mean(all_cols):.1f}")
print(f"  Mines: min={min(all_mines)}, max={max(all_mines)}, mean={np.mean(all_mines):.1f}")
densities = [m / (r * c) * 100 for r, c, m in zip(all_rows, all_cols, all_mines) if r * c > 0]
print(f"  Density: min={min(densities):.1f}%, max={max(densities):.1f}%, mean={np.mean(densities):.1f}%")

zero_mine_count = sum(1 for m in all_mines if m == 0)
print(f"  Zero-mine boards: {zero_mine_count} ({zero_mine_count/len(dataset)*100:.1f}%)")

move_counts = [len(json.loads(item["move_history"])) for item in dataset]
print(f"\nMove statistics:")
print(f"  Min: {min(move_counts)}, Max: {max(move_counts)}, "
      f"Mean: {np.mean(move_counts):.1f}, Median: {np.median(move_counts):.1f}")

# Phase-specific move stats
print(f"\nLogical deduction coverage:")
has_safe = 0
has_mine = 0
has_both = 0
for item in dataset:
    mh = json.loads(item["move_history"])
    has_flag = any(m.get("type") == "flag" for m in mh)
    has_rev = any(m.get("type") == "reveal" for m in mh)
    if has_flag:
        has_mine += 1
    if has_rev:
        has_safe += 1
    if has_flag and has_rev:
        has_both += 1
print(f"  Samples with reveals: {has_safe} ({has_safe/len(dataset)*100:.1f}%)")
print(f"  Samples with flags:   {has_mine} ({has_mine/len(dataset)*100:.1f}%)")
print(f"  Samples with both:    {has_both} ({has_both/len(dataset)*100:.1f}%)")

# Verify dataset columns
print(f"\nDataset columns: {dataset.column_names}")
print(f"  ✅ board_rows, board_cols, board_mines present for reward functions")

# ── Save dataset to JSON ──
dataset_json_path = "minesweeper_dataset.json"

json_records = []
for item in dataset:
    record = {
        "seed": item["seed"],
        "move_history": item["move_history"],
        "board_rows": item["board_rows"],
        "board_cols": item["board_cols"],
        "board_mines": item["board_mines"],
        "prompt_text": item["prompt"][0]["content"],
    }
    json_records.append(record)

with open(dataset_json_path, "w") as f:
    json.dump(json_records, f, indent=2)

print(f"\n✅ Dataset saved to {dataset_json_path} ({os.path.getsize(dataset_json_path) / 1024:.1f} KB)")
print(f"   {len(json_records)} records with fields: seed, move_history, board_rows/cols/mines, prompt_text")

print(f"\nSample prompt ({dataset[0]['board_rows']}x{dataset[0]['board_cols']}, "
      f"{dataset[0]['board_mines']} mines):")
print(dataset[0]["prompt"][0]["content"][:300] + "...")

ModuleNotFoundError: No module named 'datasets'

# Configure GRPO Training

Set up GRPO trainer with all hyperparameters:

In [None]:
from trl import GRPOConfig, GRPOTrainer

# ── Lengths ──
# Simplified unified prompt (~300-400 tokens for most boards)
# - Small boards (1-30): Full grid, ~200-600 tokens
# - Large boards (31-50): Frontier zone only, ~400-800 tokens
# 1900 + 128 = 2028 < 2048 = max_seq_length (fits comfortably)
max_prompt_length = 1900
max_completion_length = 128  # HACKATHON CONSTRAINT: JSON-only output
                             # Pure JSON action is ~10-25 tokens

# ── GRPO Configuration (Hybrid: LAMER + XRPO + GRPO-LEAD + S-GRPO) ──
training_args = GRPOConfig(
    # === Generation (LAMER: temperature=1.0 for training exploration) ===
    temperature = 1.0,           # LAMER paper: full exploration during training
    top_p = 0.95,

    # === Optimization (XRPO: lower LR for stability with difficulty reweighting) ===
    learning_rate = 5e-6,        # Reduced from 2e-5: XRPO reweighting + GRPO-LEAD
    weight_decay = 0.01,
    warmup_ratio = 0.05,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    max_grad_norm = 0.5,

    # === Batch sizes ===
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    num_generations = 8,

    # === Lengths ===
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,

    # === Training duration ===
    max_steps = 500,
    save_steps = 100,

    # === GRPO specific (LAMER: num_iterations=2 for MineSweeper) ===
    beta = 0.04,                 # Mild KL penalty to prevent reward hacking
    num_iterations = 2,          # LAMER paper: 2 GRPO iterations

    # === Reward weighting ===
    # [valid_json + length_penalty, gameplay + difficulty_reweight, strategic + difficulty]
    # Increased format weight slightly since it now includes length penalty (GRPO-LEAD)
    reward_weights = [0.20, 0.65, 0.15],

    # === Ensure extra dataset columns are NOT removed ===
    remove_unused_columns = False,

    # === Output ===
    report_to = "none",
    output_dir = "minesweeper_grpo_v2",
    seed = 42,
    bf16 = True,
)

print("Training configuration:")
print(f"  Max steps:           {training_args.max_steps}")
print(f"  Generations/state:   {training_args.num_generations}")
print(f"  Learning rate:       {training_args.learning_rate}")
print(f"  LR scheduler:       {training_args.lr_scheduler_type}")
print(f"  Max grad norm:       {training_args.max_grad_norm}")
print(f"  Beta (KL penalty):   {training_args.beta}")
print(f"  Num iterations:      {training_args.num_iterations}")
print(f"  Reward weights:      {training_args.reward_weights}")
print(f"  Prompt/Completion:   {max_prompt_length}/{max_completion_length}")
print(f"  Temperature:         {training_args.temperature}")
print(f"  remove_unused_cols:  {training_args.remove_unused_columns}")
print(f"  LoRA rank:           {lora_rank}")
print(f"  Board range:         1-50 rows × 1-50 cols, 0-20% mines")
print()
print("Prompt system: Simplified unified format (~300-400 tokens)")

In [None]:
from transformers import TrainerCallback

# Board configs to evaluate on during training (mix of sizes from 1-50 range)
EVAL_CONFIGS = [
    (1, 1, 0),     # Trivial — 0 mines
    (3, 3, 1),     # Tiny
    (5, 5, 3),     # Small
    (6, 6, 5),     # Standard
    (8, 8, 10),    # Medium
    (10, 10, 20),  # Large
    (15, 15, 45),  # XL
    (6, 8, 6),     # Rectangular
    (1, 10, 2),    # Row board
    (20, 20, 80),  # XX-Large
]


class MinesweeperEvalCallback(TrainerCallback):
    """Periodically play games during training with variable board sizes.

    NO move limit — games end only on success (all safe revealed)
    or failure (mine hit). Max iterations capped to prevent infinite loops
    from repeated invalid actions.
    """

    def __init__(self, eval_every_steps=50, num_games=10):
        self.eval_every_steps = eval_every_steps
        self.num_games = min(num_games, len(EVAL_CONFIGS))

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step % self.eval_every_steps != 0:
            return

        tokenizer = processing_class
        if tokenizer is None or model is None:
            return

        was_training = model.training
        model.eval()

        wins = 0
        total_moves = 0
        invalid_count = 0

        for i in range(self.num_games):
            rows, cols, mines = EVAL_CONFIGS[i % len(EVAL_CONFIGS)]
            game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines,
                                   seed=10000 + i)
            moves = 0
            invalids = 0
            consecutive_invalids = 0
            seen_actions = set()   # Fix #4: detect repeated actions
            repeat_count = 0       # Fix #4: count consecutive repeats
            # Safety cap: prevent infinite loops (not a move limit — just loop protection)
            # Capped at 500 to prevent runaway 50×50 evals (was rows*cols*3+20 = 7520 for 50×50)
            max_iterations = min(500, rows * cols + 100)

            iteration = 0
            while game.state() == "ongoing" and iteration < max_iterations:
                iteration += 1
                prompt = format_state_for_llm(game, mode="inference")
                text = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                inputs = tokenizer(text, return_tensors="pt", truncation=True,
                                   max_length=max_prompt_length + 100)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        temperature=0.7,  # LAMER paper: 0.7 for eval
                        max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
                        do_sample=True,
                        top_p=0.9,
                    )

                # Decode ONLY the generated tokens (not the prompt)
                gen_tokens = output[0][inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                action = parse_llm_action(response)

                if action is None:
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break  # Too many consecutive invalid actions
                    continue

                consecutive_invalids = 0

                # Fix #4: Check for repeated actions (stuck detection)
                action_key = (action['type'], action['row'], action['col'])
                if action_key in seen_actions:
                    repeat_count += 1
                    if repeat_count >= 3:  # Same move 3 times = stuck
                        break
                else:
                    repeat_count = 0
                    seen_actions.add(action_key)

                result = game.do_action(action)
                if result in ("mine", "win"):
                    moves += 1
                    break
                elif result == "ok":
                    moves += 1
                else:
                    # Invalid move (out_of_bounds, already_revealed, etc.)
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break

            if game.state() == "success":
                wins += 1
            total_moves += moves
            invalid_count += invalids

        win_rate = wins / self.num_games
        avg_moves = total_moves / self.num_games
        print(f"\n[Eval @ step {state.global_step}] "
              f"Win: {wins}/{self.num_games} ({win_rate*100:.0f}%) | "
              f"Avg moves: {avg_moves:.1f} | "
              f"Invalid: {invalid_count}\n")

        if was_training:
            model.train()

eval_callback = MinesweeperEvalCallback(eval_every_steps=50, num_games=10)

print(f"Eval callback: {eval_callback.num_games} games every "
      f"{eval_callback.eval_every_steps} steps")
print(f"  No move limit — only success or failure")
print(f"  Uses inference-mode prompt (60% fewer tokens)")
print(f"  Configs: {len(EVAL_CONFIGS)} sizes from 1x1 to 20x20")
print(f"  Max iterations capped at min(500, rows*cols+100)")
print(f"  max_new_tokens=128 (HACKATHON CONSTRAINT: JSON-only output)")

# Train the Model

Start GRPO training with reward functions:

In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        valid_json_reward,   # Format + length penalty (GRPO-LEAD)
        gameplay_scores,     # 12 criteria + difficulty reweight (XRPO)
        strategic_reward,    # Deduction + center-opening + difficulty (XRPO)
    ],
    args = training_args,
    train_dataset = dataset,
    callbacks = [eval_callback],  # Periodic gameplay evaluation
)

print("Starting GRPO training with 3 hybrid reward functions...")
print("  [1] valid_json_reward  (weight: 0.20) — format + length penalty (GRPO-LEAD)")
print("  [2] gameplay_scores    (weight: 0.65) — 12 criteria + difficulty (XRPO)")
print("  [3] strategic_reward   (weight: 0.15) — deduction + center (LAMER)")
trainer.train()

# Test Trained Model

Evaluate the finetuned model:

In [None]:
# Test on variable board sizes across the full competition range
FastLanguageModel.for_inference(model)

test_configs = [
    (1, 1, 0, 200, "1x1 Trivial"),
    (3, 3, 1, 201, "Tiny"),
    (5, 5, 3, 202, "Small/Easy"),
    (6, 6, 5, 99,  "Standard"),
    (8, 8, 10, 101, "Medium"),
    (10, 10, 20, 203, "Large 20%"),
    (15, 15, 45, 204, "XL 20%"),
    (1, 20, 4, 205, "Row Board"),
    (20, 1, 4, 206, "Column Board"),
    (5, 5, 0, 207, "Zero Mines"),
]

for rows, cols, mines, seed, label in test_configs:
    print(f"\n{'='*50}")
    print(f"=== {label} ({rows}x{cols}, {mines} mines) ===")
    print(f"{'='*50}")

    test_game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)

    # Handle already-won games (0-mine boards that auto-cascade)
    if test_game.state() != "ongoing":
        print(f"  Game auto-resolved: {test_game.state()}")
        continue

    test_prompt = format_state_for_llm(test_game, mode="inference")

    test_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": test_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(test_text, return_tensors="pt", truncation=True,
                       max_length=max_prompt_length + 100)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    output = model.generate(
        **inputs,
        temperature=0.7,  # LAMER paper: 0.7 for eval
        max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.2,
    )

    # Decode only generated tokens
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    print(f"Response: {response_text.strip()}")

    action = parse_llm_action(response_text)
    print(f"Parsed action: {action}")

    if action:
        result = test_game.do_action(action)
        print(f"Result: {result} | Game state: {test_game.state()}")
    else:
        print("⚠️ Failed to parse a valid action")

# Exhaustive Evaluation: Full Competition Range

Play complete games across 37 board configurations (1×1 to 50×50, 0-20% mines).
**No move limit** — only two outcomes: SUCCESS (all safe cells revealed) or FAILURE (mine revealed).

In [None]:
# def play_full_game(model, tokenizer, rows=6, cols=6, num_mines=5, seed=None,
#                    verbose=False):
#     """Play a complete Minesweeper game, tracking detailed metrics.

#     Supports any board size from 1x1 to 50x50.
#     NO move limit — game ends ONLY on:
#       - SUCCESS: all non-mine cells are revealed
#       - FAILURE: any mine cell is revealed
#     A safety iteration cap prevents infinite loops from repeated invalid actions.
#     """
#     game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

#     # Edge case: game already won (e.g., 0-mine board)
#     if game.state() != "ongoing":
#         return {
#             "game": game,
#             "moves": 0,
#             "logical_moves": 0,
#             "flags_correct": 0,
#             "flags_wrong": 0,
#             "total_invalids": 0,
#             "result": game.state(),
#             "progress": game.progress(),
#             "config": f"{rows}x{cols}m{num_mines}",
#         }

#     moves = 0
#     consecutive_invalids = 0
#     total_invalids = 0
#     logical_moves = 0
#     flags_correct = 0
#     flags_wrong = 0
#     seen_actions = set()   # Fix #4: detect repeated actions (stuck loop)
#     repeat_count = 0       # Fix #4: consecutive repeat counter
#     # Safety cap to prevent infinite loops — NOT a game move limit
#     # Capped at 500 to prevent runaway 50×50 evals (was rows*cols*3+50 = 7550 for 50×50)
#     max_iterations = min(500, rows * cols + 100)

#     iteration = 0
#     while game.state() == "ongoing" and iteration < max_iterations:
#         iteration += 1
#         prompt = format_state_for_llm(game, mode="inference")
#         text = tokenizer.apply_chat_template(
#             [{"role": "user", "content": prompt}],
#             tokenize=False,
#             add_generation_prompt=True,
#         )

#         inputs = tokenizer(text, return_tensors="pt", truncation=True,
#                            max_length=max_prompt_length + 100)
#         inputs = {k: v.to("cuda") for k, v in inputs.items()}

#         with torch.no_grad():
#             output = model.generate(
#                 **inputs,
#                 temperature=0.7,  # LAMER paper: 0.7 for eval
#                 max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
#                 do_sample=True,
#                 top_p=0.9,
#                 repetition_penalty=1.2,
#             )

#         # Decode ONLY generated tokens
#         gen_tokens = output[0][inputs["input_ids"].shape[1]:]
#         response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
#         action = parse_llm_action(response)

#         if action is None:
#             consecutive_invalids += 1
#             total_invalids += 1
#             if consecutive_invalids >= 5:
#                 break  # Agent is stuck — abort
#             continue

#         # Fix #4: Check for repeated actions (stuck detection)
#         action_key = (action['type'], action['row'], action['col'])
#         if action_key in seen_actions:
#             repeat_count += 1
#             if repeat_count >= 3:  # Same move 3 times = stuck
#                 break
#         else:
#             repeat_count = 0
#             seen_actions.add(action_key)

#         # Track logical moves (compute BEFORE applying action)
#         safe_set, mine_set = _compute_safe_and_mine_cells(game)
#         r, c = action["row"], action["col"]
#         if action["type"] == "reveal" and (r, c) in safe_set:
#             logical_moves += 1
#         elif action["type"] == "flag" and (r, c) in mine_set:
#             logical_moves += 1

#         # Track flag accuracy
#         if action["type"] == "flag":
#             if 0 <= r < game.rows and 0 <= c < game.cols:
#                 if game._board[r][c] == -1:
#                     flags_correct += 1
#                 else:
#                     flags_wrong += 1

#         if verbose:
#             print(f"  Move {moves}: {action}")

#         result = game.do_action(action)
#         if result in ("mine", "win", "ok"):
#             moves += 1
#         elif result in ("out_of_bounds", "already_revealed", "flagged_cell",
#                          "invalid_flag", "invalid_format"):
#             # Invalid moves don't count but game stays ongoing
#             total_invalids += 1
#             consecutive_invalids += 1
#             if consecutive_invalids >= 5:
#                 break

#         if result in ("mine", "win"):
#             break

#     return {
#         "game": game,
#         "moves": moves,
#         "logical_moves": logical_moves,
#         "flags_correct": flags_correct,
#         "flags_wrong": flags_wrong,
#         "total_invalids": total_invalids,
#         "result": game.state(),
#         "progress": game.progress(),
#         "config": f"{rows}x{cols}m{num_mines}",
#     }


# # ──────────────────────────────────────────────────────────────────────
# # EXHAUSTIVE Multi-Size Evaluation — Competition Spec
# # n, m ∈ [1, 50], mines 0-20% of total cells
# # Only two outcomes: SUCCESS or FAILURE (no timeouts)
# # ──────────────────────────────────────────────────────────────────────

# EVAL_SUITE = [
#     # (rows, cols, mines, num_games, label)

#     # === Trivial / Edge Cases ===
#     (1, 1, 0,   5,  "1x1 trivial"),
#     (1, 2, 0,   5,  "1x2 trivial"),
#     (2, 1, 0,   5,  "2x1 trivial"),
#     (2, 2, 0,   5,  "2x2 no mines"),
#     (3, 3, 0,   5,  "3x3 no mines"),
#     (5, 5, 0,   5,  "5x5 no mines"),

#     # === Tiny Boards ===
#     (3, 3, 1,  10,  "3x3 1 mine"),
#     (4, 4, 3,  10,  "4x4 3 mines"),

#     # === Small Boards ===
#     (5, 5, 3,  15,  "5x5 easy"),
#     (5, 5, 5,  10,  "5x5 max density"),

#     # === Standard ===
#     (6, 6, 5,  20,  "6x6 standard"),
#     (6, 6, 7,  10,  "6x6 hard"),

#     # === Medium ===
#     (7, 7, 7,  10,  "7x7 medium"),
#     (8, 8, 10, 10,  "8x8 medium"),
#     (8, 8, 12, 10,  "8x8 max density"),

#     # === Large ===
#     (10, 10, 10,  5, "10x10 10%"),
#     (10, 10, 20, 10, "10x10 20%"),
#     (15, 15, 45,  5, "15x15 20%"),

#     # === XL ===
#     (20, 20, 80,  3, "20x20 20%"),
#     (25, 25, 125, 3, "25x25 20%"),
#     (30, 30, 180, 2, "30x30 20%"),

#     # === XXL ===
#     (40, 40, 320, 2, "40x40 20%"),
#     (50, 50, 500, 2, "50x50 20%"),

#     # === Rectangular ===
#     (1, 10, 2,   5, "1x10 row"),
#     (10, 1, 2,   5, "10x1 column"),
#     (1, 50, 10,  3, "1x50 row"),
#     (50, 1, 10,  3, "50x1 column"),
#     (6, 8, 6,    5, "6x8 rect"),
#     (8, 6, 6,    5, "8x6 rect"),
#     (5, 15, 15,  3, "5x15 wide"),
#     (15, 5, 15,  3, "15x5 tall"),
#     (3, 50, 30,  2, "3x50 extreme wide"),
#     (50, 3, 30,  2, "50x3 extreme tall"),

#     # === Sparse (low density) ===
#     (10, 10, 1,  5, "10x10 sparse"),
#     (20, 20, 4,  3, "20x20 sparse"),
#     (50, 50, 10, 2, "50x50 sparse"),

#     # === Progressive Difficulty (Fix #10: LAMER generalization test) ===
#     # Same size, increasing mines — tests density scaling
#     (10, 10, 5,  5, "10x10 5% progressive"),
#     (10, 10, 15, 5, "10x10 15% progressive"),

#     # Increasing size, same ~10% density — tests size scaling
#     (15, 15, 23, 3, "15x15 10% density"),
#     (20, 20, 40, 3, "20x20 10% density"),

#     # Generalization to harder unseen configs
#     (25, 25, 62, 2, "25x25 10% generalization"),
#     (30, 30, 90, 2, "30x30 10% generalization"),
# ]

# FastLanguageModel.for_inference(model)
# total_games = sum(n for _,_,_,n,_ in EVAL_SUITE)
# print(f"{'='*80}")
# print(f"  EXHAUSTIVE EVALUATION — {total_games} games across {len(EVAL_SUITE)} configs")
# print(f"  Competition spec: n,m ∈ [1,50], mines 0-20%, no move limit")
# print(f"  Includes progressive difficulty test (LAMER generalization)")
# print(f"  Only two outcomes: SUCCESS or FAILURE")
# print(f"{'='*80}\n")

# all_results = []
# per_config_stats = {}

# for rows, cols, mines, num_games, label in EVAL_SUITE:
#     config_key = f"{rows}x{cols}m{mines}"
#     wins = 0
#     fails = 0
#     config_results = []

#     for i in range(num_games):
#         info = play_full_game(model, tokenizer, rows=rows, cols=cols,
#                               num_mines=mines, seed=5000 + i + hash(config_key) % 10000)
#         config_results.append(info)
#         all_results.append(info)

#         if info["result"] == "success":
#             wins += 1
#         elif info["result"] == "failed":
#             fails += 1

#     # Per-config summary
#     avg_moves = np.mean([r["moves"] for r in config_results])
#     avg_logical = np.mean([r["logical_moves"] for r in config_results])
#     avg_progress = np.mean([r["progress"] for r in config_results])
#     avg_invalids = np.mean([r["total_invalids"] for r in config_results])
#     stuck = sum(1 for r in config_results if r["result"] == "ongoing")
#     wr = wins / num_games * 100

#     per_config_stats[config_key] = {
#         "label": label, "wins": wins, "fails": fails, "stuck": stuck,
#         "total": num_games, "win_rate": wr,
#         "avg_moves": avg_moves, "avg_progress": avg_progress,
#     }

#     status_icon = "✅" if wr >= 50 else "⚠️" if wr >= 20 else "❌"
#     print(f"  {status_icon} {label:22s} ({config_key:12s}): "
#           f"{wins:2d}/{num_games:2d} wins ({wr:5.1f}%) | "
#           f"fails={fails} stuck={stuck} | "
#           f"moves={avg_moves:5.1f} | logical={avg_logical:4.1f} | "
#           f"progress={avg_progress:.0%} | invalids={avg_invalids:.1f}")

# # ── Overall Summary ──
# total_games_actual = len(all_results)
# total_wins = sum(1 for r in all_results if r["result"] == "success")
# total_fails = sum(1 for r in all_results if r["result"] == "failed")
# total_stuck = sum(1 for r in all_results if r["result"] == "ongoing")
# fc = sum(r["flags_correct"] for r in all_results)
# fw = sum(r["flags_wrong"] for r in all_results)

# print(f"\n{'='*80}")
# print(f"  OVERALL: {total_wins}/{total_games_actual} wins ({total_wins/total_games_actual*100:.1f}%)")
# print(f"{'='*80}")
# print(f"  Wins (success):    {total_wins:4d} ({total_wins/total_games_actual*100:.1f}%)")
# print(f"  Losses (failure):  {total_fails:4d} ({total_fails/total_games_actual*100:.1f}%)")
# print(f"  Stuck (loop cap):  {total_stuck:4d} ({total_stuck/total_games_actual*100:.1f}%)")
# print(f"  Avg moves:         {np.mean([r['moves'] for r in all_results]):.1f}")
# print(f"  Avg progress:      {np.mean([r['progress'] for r in all_results]):.0%}")
# print(f"  Avg logical moves: {np.mean([r['logical_moves'] for r in all_results]):.1f}")
# if fc + fw > 0:
#     print(f"  Flag accuracy:     {fc}/{fc+fw} ({fc/(fc+fw)*100:.1f}%)")
# else:
#     print(f"  Flags: none placed")

# # ── Category Breakdown ──
# categories = {
#     "Trivial (0 mines)": [k for k, v in per_config_stats.items() if "m0" in k],
#     "Tiny (≤4x4)":       [k for k, v in per_config_stats.items()
#                            if v["label"].startswith(("3x3", "4x4")) and "m0" not in k],
#     "Small (5x5)":       [k for k, v in per_config_stats.items() if k.startswith("5x5")],
#     "Standard (6x6)":    [k for k, v in per_config_stats.items() if k.startswith("6x6")],
#     "Medium (7-8)":      [k for k, v in per_config_stats.items()
#                            if k.startswith(("7x7", "8x8"))],
#     "Large (10-15)":     [k for k, v in per_config_stats.items()
#                            if k.startswith(("10x10", "15x15"))],
#     "XL (20-50)":        [k for k, v in per_config_stats.items()
#                            if k.startswith(("20x20", "25x25", "30x30", "40x40", "50x50"))],
#     "Rectangular":       [k for k, v in per_config_stats.items()
#                            if "rect" in v["label"] or "row" in v["label"]
#                            or "column" in v["label"] or "wide" in v["label"]
#                            or "tall" in v["label"]],
#     "Sparse":            [k for k, v in per_config_stats.items() if "sparse" in v["label"]],
#     "Progressive":       [k for k, v in per_config_stats.items()
#                            if "progressive" in v["label"] or "generalization" in v["label"]
#                            or "density" in v["label"]],
# }

# print(f"\n{'='*80}")
# print(f"  CATEGORY BREAKDOWN")
# print(f"{'='*80}")
# for cat_name, keys in categories.items():
#     if not keys:
#         continue
#     cat_wins = sum(per_config_stats[k]["wins"] for k in keys)
#     cat_total = sum(per_config_stats[k]["total"] for k in keys)
#     if cat_total > 0:
#         cat_wr = cat_wins / cat_total * 100
#         icon = "✅" if cat_wr >= 50 else "⚠️" if cat_wr >= 20 else "❌"
#         print(f"  {icon} {cat_name:22s}: {cat_wins:3d}/{cat_total:3d} ({cat_wr:5.1f}%)")
# print(f"{'='*80}")

# Save the Model

Save your trained model for competition submission:

In [None]:
# Save LoRA adapters
model.save_pretrained("my_minesweeper_model")
tokenizer.save_pretrained("my_minesweeper_model")
print("✅ LoRA adapters saved to: my_minesweeper_model/")

# ──────────────────────────────────────────────────────────────────────
# Save merged model in 16bit
# Workaround for Unsloth bug: UnboundLocalError on 'copied_tokenizer_model_from_cache'
# when using local model paths. We manually merge LoRA weights and save with HF API.
# ──────────────────────────────────────────────────────────────────────
import os, shutil, gc

merged_dir = "my_minesweeper_model_merged"
os.makedirs(merged_dir, exist_ok=True)

try:
    # Try Unsloth's native method first (works on some versions)
    model.save_pretrained_merged(
        merged_dir,
        tokenizer,
        save_method="merged_16bit",
    )
    print("✅ Merged 16-bit model saved via Unsloth")
except (UnboundLocalError, Exception) as e:
    print(f"⚠️ Unsloth merge failed ({type(e).__name__}: {e})")
    print("   Falling back to manual LoRA merge...")

    try:
        # Manual merge: get the PEFT model, merge LoRA into base weights, save
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Re-load base model in float16 for merge
        print("   Loading base model for merge...")
        base_model = AutoModelForCausalLM.from_pretrained(
            "/workspace/workspace/Qwen2.5-14B-Instruct",
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Load LoRA adapters on top
        print("   Applying LoRA adapters...")
        merged_model = PeftModel.from_pretrained(base_model, "my_minesweeper_model")
        merged_model = merged_model.merge_and_unload()

        # Save the fully merged model
        print("   Saving merged model...")
        merged_model.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

        # Cleanup
        del base_model, merged_model
        gc.collect()
        torch.cuda.empty_cache()

        print(f"✅ Merged 16-bit model saved to: {merged_dir}/")

    except Exception as e2:
        print(f"❌ BOTH merge methods failed!")
        print(f"   Unsloth error: {e}")
        print(f"   Manual merge error: {e2}")
        print(f"   LoRA adapters are still saved at: my_minesweeper_model/")
        print(f"   You can merge manually later with:")
        print(f"     from peft import PeftModel")
        print(f"     base = AutoModelForCausalLM.from_pretrained('<base_model_path>')")
        print(f"     merged = PeftModel.from_pretrained(base, 'my_minesweeper_model')")
        print(f"     merged = merged.merge_and_unload()")
        print(f"     merged.save_pretrained('{merged_dir}')")

# Verify saved files
saved_files = os.listdir(merged_dir)
safetensors = [f for f in saved_files if f.endswith(".safetensors")]
print(f"   Files: {len(saved_files)} total, {len(safetensors)} safetensors shards")
print(f"   Config: {'config.json' in saved_files}  Tokenizer: {'tokenizer.json' in saved_files}")

# Inference from Merged Model

Load the saved merged model from disk (no LoRA, no Unsloth) and verify it works for Minesweeper inference.
This tests that the model was saved correctly and can be loaded independently.

In [None]:
# ──────────────────────────────────────────────────────────────────────
# Load merged model from disk for inference testing
# No Unsloth, no LoRA — pure HuggingFace transformers
# ──────────────────────────────────────────────────────────────────────
import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "my_minesweeper_model_merged"

print(f"Loading merged model from: {merged_dir}/")
print("  (This is a standalone model — no LoRA adapters needed)")

# Load tokenizer
merged_tokenizer = AutoTokenizer.from_pretrained(merged_dir)
print(f"  ✅ Tokenizer loaded ({merged_tokenizer.vocab_size} vocab)")

# Load model
merged_model = AutoModelForCausalLM.from_pretrained(
    merged_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
merged_model.eval()
print(f"  ✅ Model loaded on {merged_model.device}")
print(f"  Parameters: {sum(p.numel() for p in merged_model.parameters()):,}")

# ──────────────────────────────────────────────────────────────────────
# Run inference on a diverse set of Minesweeper boards
# ──────────────────────────────────────────────────────────────────────

test_configs = [
    (1, 1, 0, 300, "1×1 Trivial"),
    (3, 3, 1, 301, "Tiny 3×3"),
    (5, 5, 3, 302, "Small 5×5"),
    (6, 6, 5, 99,  "Standard 6×6"),
    (8, 8, 10, 303, "Medium 8×8"),
    (10, 10, 20, 304, "Large 10×10"),
    (15, 15, 45, 305, "XL 15×15"),
    (1, 20, 4, 306, "Linear 1×20"),
    (5, 5, 0, 307, "Zero Mines 5×5"),
]

print(f"\n{'='*60}")
print(f"  MERGED MODEL INFERENCE TEST — {len(test_configs)} boards")
print(f"{'='*60}")

results = {"pass": 0, "fail": 0, "skip": 0}

for rows, cols, mines, seed, label in test_configs:
    print(f"\n--- {label} ({rows}×{cols}, {mines} mines, seed={seed}) ---")

    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)

    if game.state() != "ongoing":
        print(f"  Game auto-resolved: {game.state()}")
        results["skip"] += 1
        continue

    # Build prompt using the same inference prompt system
    prompt = format_state_for_llm(game, mode="inference")

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = merged_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = merged_tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(merged_model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = merged_model.generate(
            **inputs,
            temperature=0.7,       # LAMER paper: 0.7 for eval
            max_new_tokens=128,    # HACKATHON CONSTRAINT: JSON-only
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.2,
        )

    # Decode only generated tokens (not the prompt)
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response = merged_tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

    print(f"  Response: {response[:200]}{'...' if len(response) > 200 else ''}")

    action = parse_llm_action(response)
    print(f"  Parsed:   {action}")

    if action:
        result = game.do_action(action)
        state = game.state()
        print(f"  Result:   {result} → game {state}")
        if result != "mine":
            results["pass"] += 1
        else:
            results["fail"] += 1
    else:
        print(f"  ⚠️ Failed to parse valid action")
        results["fail"] += 1

# ── Summary ──
print(f"\n{'='*60}")
print(f"  INFERENCE TEST SUMMARY")
print(f"{'='*60}")
print(f"  Pass (valid move): {results['pass']}/{results['pass'] + results['fail']}")
print(f"  Fail (bad parse/mine): {results['fail']}")
print(f"  Skipped (auto-win): {results['skip']}")
total = results['pass'] + results['fail']
if total > 0:
    print(f"  Success rate: {results['pass']/total*100:.1f}%")
print(f"\n✅ Merged model inference test complete!")
print(f"   Model path: {merged_dir}/")
print(f"   Ready for competition submission.")

# Fixes & Improvements Applied

## 4-Paper Hybrid Integration (LAMER + XRPO + GRPO-LEAD + S-GRPO)

### Paper 1: LAMER — Meta-RL for Language Agents (74% win rate on MineSweeper)
| LAMER Finding | Implementation |
|---------------|----------------|
| ReAct prompting (Reason + Act) | Explicit "think step-by-step before acting" in system prompt |
| Pre-computed logical hints (CoT) | SAFE/MINE cell lists injected into every prompt |
| Center-opening strategy | Reward center on dense boards (>10%), penalize edges/corners |
| `temperature=1.0` for training | Full exploration during GRPO generation |
| `temperature=0.7` for eval | LAMER paper eval temperature |
| `num_iterations=2` | LAMER paper: 2 GRPO iterations for MineSweeper |
| STEP 1/STEP 2 reasoning | Satisfied/constrained number scanning |

### Paper 2: XRPO — Adaptive Exploration with Difficulty Reweighting
| XRPO Finding | Implementation |
|--------------|----------------|
| Difficulty reweighting | `_difficulty_multiplier()` → harder boards get amplified reward signal (×0.7–×1.5) |
| Exploration heuristic | STEP 4 in prompt: prefer high-info cells, low numbers, avoid edges |
| Novelty bonus concept | Implicit: difficulty multiplier serves similar purpose (hard=novel→stronger signal) |
| Adaptive difficulty proxy | `0.6 × density/0.20 + 0.4 × log(size)/log(2500)` → sigmoid multiplier |
| Safety check | Returns 1.0 for 0-mine boards and degenerate cases (Fix #2) |

### Paper 3: GRPO-LEAD — Length Penalty & Explicit Wrong Penalty
| GRPO-LEAD Finding | Implementation |
|-------------------|----------------|
| Group-normalized length penalty | 2-pass z-score normalization across correct responses (Fix #6) |
| ≤200 token response cap | Instruction in training prompt: "Total response ≤ 200 tokens" |
| Explicit wrong penalty | Mine reveal: -26 (was -25), wrong flag: -11 (was -10) |
| Lower learning rate | `5e-6` (was `2e-5`) for stability with penalty terms |
| Concise reasoning | "2-3 sentences max" instruction in prompt output format |

### Paper 4: S-GRPO — Early Exit When Solution Found
| S-GRPO Finding | Implementation |
|----------------|----------------|
| Early exit instruction | "Once you find a logical move in Steps 1-3, STOP reasoning and output it immediately" |
| Reduced wasted tokens | Combined with GRPO-LEAD length penalty → shorter, focused responses |
| Step-ordered deduction | Steps 1→2→3→4 with explicit "stop when found" at each logical step |

## Robustness Fixes (Latest Session)
| # | Issue | Fix | Priority |
|---|-------|-----|----------|
| 1 | Infinite loops in `_play_smart_moves()` | `stuck_count` with `MAX_STUCK=10` — break after 10 failed attempts | P0 |
| 2 | Div-by-zero in `_difficulty_multiplier()` | Safety check: return 1.0 for `board_size==0` or `num_mines==0` | P3 |
| 3 | Missing one-cell-left endgame hint | Compute exact last cell coords, inject `CRITICAL` hint with specific action | P2 |
| 4 | Model repeating same valid move forever | `seen_actions` set + `repeat_count` in eval callback & `play_full_game` — break after 3 repeats | P1 |
| 6 | Wrong length penalty (absolute not group) | 2-pass GRPO-LEAD: z-score normalize lengths across correct responses, clamp multiplier [0.5, 1.5] | P1 |
| 9 | Manual merge fallback had no error handling | Nested try/except — if both Unsloth & PEFT merge fail, print recovery instructions | P3 |
| 10 | No progressive difficulty in eval suite | Added 6 configs: same-size density scaling + same-density size scaling + generalization tests | P3 |

## 50×50 Board Scalability Fixes
| # | Issue | Fix | Impact |
|---|-------|-----|--------|
| S1 | Token budget mismatch (1900+128=2028, no room for reasoning) | Rebalanced to 1792+256=2048 exact fit | 50×50 Format C (~1335 tokens) fits with 457 token margin for reasoning |
| S2 | Large boards only 10% of training data (sizes 1-50) | Boosted to 15% at sizes 30-50; bands now 25/20/20/20/15 | 3× more large board training samples |
| S3 | Flat +100 win bonus regardless of board size | Scaled: `100 × (1 + min(1, board_size/1000))` — 50×50→+250, 6×6→+104 | Model motivated to complete large boards |
| S4 | Eval `max_iterations` too high (rows*cols*3 = 7500 for 50×50) | Capped to `min(500, rows*cols+100)` | Prevents runaway 50×50 eval loops |
| S5 | All eval/inference used `max_new_tokens=128` | Reverted to 128 everywhere — HACKATHON CONSTRAINT: JSON-only output, no reasoning | Pure JSON action is ~10-25 tokens, well under 128 limit |

## Exhaustive Dataset Generation
| Feature | Details |
|---------|---------|
| Total samples | 4000 (was 3000) |
| Edge cases | 50+ configs (was 25) — trivial, linear, rectangular, density extremes |
| Opening training | 25% fresh/single-move boards — train safe opening strategies |
| Pattern-specific | 15% — satisfied-number boards + multi-region disconnected boards |
| Mid-game deduction | 25% — 3-15 moves, progressive flagging 10%→30%→50% |
| Endgame completion | 15% — 80-98% revealed — flag accounting, finish strategy |
| Forced guess | 10% — no logical deductions available |
| Progressive flagging | 10% early → 30% mid → 50% late |
| Density stratification | 8% zero, 17% very sparse, 25% sparse, 25% medium, 25% dense |
| Stuck prevention | `MAX_STUCK=10` in `_play_smart_moves()` (Fix #1) |

## Prompt System (Hybrid: 4-Paper Master Template)
| Feature | Paper Source | Implementation |
|---------|-------------|----------------|
| ReAct reasoning | LAMER | "Think step-by-step → STEP 1-4 → Act" |
| Pattern recognition | XRPO | STEP 3: 1-2-1 line, 1-1 corner, zero cascade |
| Forced-guess heuristic | XRPO | STEP 4: prefer high-info cells, low numbers, avoid edges |
| Early exit | S-GRPO | "Once you find a logical move in Steps 1-3, STOP" |
| Length control | GRPO-LEAD | "Total response ≤ 200 tokens", "2-3 sentences max" |
| 3-tier board format | Custom | A (≤20): full grid, B (21-35): frontier, C (36-50): summary |
| Phase-aware prompts | LAMER | Opening (center), Mid-game (deduction), Endgame (flag accounting) |
| Edge case guidance | Custom | 0-mine, linear, tiny, large, high-density, all-flagged, last-cell |
| One-cell-left hint | Fix #3 | Computes exact cell coords → "Reveal (r,c) to WIN!" or "Flag (r,c)!" |
| All-remaining-mines | Fix #3 | "ALL REMAINING N CELLS ARE MINES: Flag any!" |
| Inference variant | LAMER | ~60% shorter for eval/test |

## Reward System (3 Functions, Hybrid)
| Reward | Weight | Paper Enhancements |
|--------|--------|-------------------|
| `valid_json_reward` | 0.20 | GRPO-LEAD: group-normalized z-score length penalty (Fix #6) |
| `gameplay_scores` | 0.65 | XRPO: difficulty reweighting (×0.7–×1.5), GRPO-LEAD: explicit wrong penalty |
| `strategic_reward` | 0.15 | XRPO: difficulty reweighting, LAMER: center-opening |

## Evaluation System
| Feature | Details |
|---------|---------|
| Eval callback | 10 configs every 50 steps, temp=0.7, stuck detection (Fix #4) |
| Exhaustive eval | 43+ configs (was 37), 1×1 to 50×50 |
| Progressive difficulty | Same-size density scaling + same-density size scaling (Fix #10) |
| Generalization test | 25×25 10%, 30×30 10% — harder unseen configs |
| Stuck detection | Repeated action tracking — break after 3 same-move repeats (Fix #4) |
| Category breakdown | Trivial, Tiny, Small, Standard, Medium, Large, XL, Rectangular, Sparse, Progressive |

## Training Config (Hybrid Hyperparameters)
| Parameter | Value | Source |
|-----------|-------|--------|
| temperature | 1.0 (train), 0.7 (eval) | LAMER |
| learning_rate | 5e-6 | GRPO-LEAD (reduced for stability) |
| num_iterations | 2 | LAMER |
| beta (KL) | 0.04 | LAMER |
| reward_weights | [0.20, 0.65, 0.15] | Hybrid (format↑ for length penalty) |
| max_completion_length | 128 | GRPO-LEAD (length control) |
| max_grad_norm | 0.5 | Gradient clipping for stability |
| warmup_ratio | 0.05 | Stable early training |

## Model Save & Inference
| Feature | Details |
|---------|---------|
| LoRA save | `my_minesweeper_model/` — always succeeds |
| Merged save | Try Unsloth → fallback PEFT merge → error recovery instructions (Fix #9) |
| Merged inference | Standalone HF transformers load — no Unsloth/LoRA needed |
| Inference test | 9 diverse boards, pass/fail summary |

## Competition Spec Compliance
| Requirement | Implementation |
|-------------|---------------|
| Board rows: 1–50 | `MIN_ROWS=1, MAX_ROWS=50` in game engine |
| Board cols: 1–50 | `MIN_COLS=1, MAX_COLS=50` in game engine |
| Mines: 0–20% of cells | `MAX_MINE_DENSITY=0.20`, 0 mines allowed |
| No move limit | Only success/failure |
| Success = all safe revealed | `_check_win()` checks `len(revealed) >= safe_cells` |
| Failure = mine revealed | Only `_reveal_cell` hitting mine sets `_state="failed"` |
| max_new_tokens: 128 | `max_completion_length=128` in GRPOConfig |

## Critical Bugs Fixed (Previous Sessions)
| # | Bug | Fix |
|---|-----|-----|

| 1 | `do_action()` set `_state="failed"` for ALL invalid moves | Only `mine` sets state to "failed" |
| 8 | Random flags in training data | `_smart_flag()` only flags logically certain mines || 9 | Reward scale imbalance | Rebalanced: JSON=-10, mine=-26, win=+100 |

| 2 | Reward functions signature mismatch with TRL GRPOTrainer | `(prompts, completions, **kwargs)` |
| 9 | Reward scale imbalance | Rebalanced: JSON=-10, mine=-26, win=+100 || 8 | Random flags in training data | `_smart_flag()` only flags logically certain mines |

| 3 | Hardcoded board size in rewards | `_reconstruct_game()` reads from kwargs || 7 | `remove_unused_columns` not set | Explicitly `False` |

| 4 | `max_prompt_length=700` truncated prompts | Increased to 1900 || 6 | Eval decoded full output including prompt | Decodes only generated tokens |
| 5 | Separate O(n²) passes for safe/mine cells | Combined single pass |