# Minesweeper LLM Competition - Custom GRPO Training

## Goal
Fine-tune an LLM with LoRA using GRPO to play Minesweeper, mapping:
- **Input**: JSON game state (board configuration)
- **Output**: JSON action (reveal or flag a cell)

Teams will compete to train the best Minesweeper-playing LLM!

## Training Approach
- **Model**: Qwen2.5-14B-Instruct (`unsloth/Qwen2.5-14B-Instruct`, downloaded to `./workspace/Qwen2.5-14B-Instruct`)
- **Method**: GRPO (Group Relative Policy Optimization)
- **Framework**: Unsloth (2-6x faster, 70% less VRAM)
- **Hardware**: AMD MI300X GPU (192GB HBM3, ROCm)

# Load Model with Unsloth

Load Qwen2.5-14B-Instruct in bfloat16 (LoRA adapters are added in the next step):

In [None]:
import os

os.environ["HF_HOME"] = "./workspace/hf_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./workspace/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "./workspace/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./workspace/hf_cache"


In [None]:
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="unsloth/Qwen2.5-14B-Instruct",
    local_dir="./workspace/Qwen2.5-14B-Instruct",
    local_dir_use_symlinks=False,
)


In [None]:
from unsloth import FastLanguageModel
import torch
import yaml

# Load config
with open("minesweeper_config_me.yaml", "r") as f:
    config = yaml.safe_load(f)

lora_rank = config.get("lora_rank", 32)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="./workspace/Qwen2.5-14B-Instruct",  # same directory snapshot_download wrote to
    load_in_4bit=False,   # AMD → 4bit disabled
    max_seq_length=2048,  # Increased: larger boards need longer prompts
    dtype=torch.bfloat16,
)

print("Model loaded successfully!")
print(f"Device: {model.device}")
print(f"LoRA rank: {lora_rank} (from config)")

# ── Add newline as an extra EOS token so generation stops after the first JSON line ──
# (replaces stop_strings, which this GRPOConfig version does not support)
newline_token_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
existing_eos = model.generation_config.eos_token_id
if existing_eos is None:
    existing_eos = [tokenizer.eos_token_id]
elif isinstance(existing_eos, int):
    existing_eos = [existing_eos]
if newline_token_id not in existing_eos:
    # Keep every existing stop id (the generation config may already list several) and append "\n"
    eos_ids = list(existing_eos) + [newline_token_id]
    model.generation_config.eos_token_id = eos_ids
    model.config.eos_token_id = eos_ids
    print(f"  ✅ Added newline (token {newline_token_id}) as additional EOS → stops after JSON line")
    print(f"     EOS tokens: {model.generation_config.eos_token_id}")
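
Because `eos_token_id` can be stored either as a single int or as a list, the merge step is easy to get wrong. A standalone sketch of that step as a hypothetical `merge_eos_ids` helper, testable without loading the model (the token ids in the usage lines are illustrative assumptions, not read from the tokenizer):

```python
from typing import List, Optional, Union


def merge_eos_ids(existing: Optional[Union[int, List[int]]], extra: int) -> List[int]:
    """Return the existing EOS id(s) with `extra` appended, without duplicates."""
    if existing is None:
        ids: List[int] = []
    elif isinstance(existing, int):
        ids = [existing]
    else:
        ids = list(existing)
    if extra not in ids:
        ids.append(extra)
    return ids


# Illustrative ids (assumed): two pre-existing stop ids plus a newline id.
print(merge_eos_ids([151645, 151643], 198))  # [151645, 151643, 198]
print(merge_eos_ids(151645, 198))            # [151645, 198]
print(merge_eos_ids(198, 198))               # [198]
```

In the notebook the call would be `merge_eos_ids(model.generation_config.eos_token_id, newline_token_id)`, with the result assigned to both `generation_config.eos_token_id` and `config.eos_token_id`.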

# Add LoRA Adapters

Add LoRA layers for efficient finetuning:

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,           # alpha = rank → scaling factor = 1.0 (stable training)
    lora_dropout = 0.05,              # Small dropout for regularization
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
print(f"LoRA config: rank={lora_rank}, alpha={lora_rank}, dropout=0.05")
model.print_trainable_parameters()
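
With `lora_alpha = lora_rank`, the low-rank update `B @ A` is scaled by `alpha / r = 1.0` (standard LoRA scaling, assuming rsLoRA is not enabled). Each adapted linear layer mapping `d_in → d_out` gains `r·(d_in + d_out)` trainable parameters. A small arithmetic sketch with hypothetical helper names:

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable params LoRA adds to one d_in -> d_out linear layer.

    A has shape (r, d_in) and B has shape (d_out, r); no bias is assumed.
    """
    return r * d_in + d_out * r


def lora_scaling(alpha: float, r: int) -> float:
    """Standard (non-rsLoRA) LoRA multiplies the update B @ A by alpha / r."""
    return alpha / r


r = 32
print(lora_scaling(alpha=r, r=r))        # 1.0 when alpha == rank
print(lora_param_count(5120, 5120, r))   # 327680 for a square 5120-wide projection
```

Summing `lora_param_count` over the seven target modules and every layer, with dimensions read from `model.config`, should land close to the count `print_trainable_parameters()` reports.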

# Minesweeper Game Implementation

Custom Minesweeper environment supporting:
- Customizable board size and mine count
- Actions: reveal or flag cells
- Win: reveal all safe cells
- Lose: reveal a mine
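
Revealing a `0` cell cascades to its neighbours; `_reveal_cell` below implements this as an iterative flood fill (an explicit stack instead of recursion, so 50×50 boards cannot hit the recursion limit). A stripped-down standalone sketch of the same idea, with no flags or game state (`flood_reveal` is a hypothetical name):

```python
from typing import List, Set, Tuple


def flood_reveal(board: List[List[int]], start: Tuple[int, int]) -> Set[Tuple[int, int]]:
    """Cells revealed by clicking `start` on a board of counts (-1 = mine).

    Zero cells expand to their 8 neighbours; numbered cells (and mines) stop the cascade.
    """
    rows, cols = len(board), len(board[0])
    revealed: Set[Tuple[int, int]] = set()
    stack = [start]
    while stack:
        r, c = stack.pop()
        if (r, c) in revealed:
            continue
        revealed.add((r, c))
        if board[r][c] != 0:  # a number or a mine does not cascade
            continue
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                nr, nc = r + dr, c + dc
                if (dr or dc) and 0 <= nr < rows and 0 <= nc < cols:
                    stack.append((nr, nc))
    return revealed


# 3x3 board with a single mine in the bottom-right corner.
board = [
    [0, 0, 0],
    [0, 1, 1],
    [0, 1, -1],
]
opened = flood_reveal(board, (0, 0))
print(len(opened), (2, 2) in opened)  # 8 False
```

The mine is never reached because its only neighbours are numbered cells, which reveal themselves but do not push further.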

In [None]:
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Set
import random
import math

# ──────────────────────────────────────────────────────────────────────
# Board size configuration — competition spec: n,m ∈ [1,50], mines 0-20%
# ──────────────────────────────────────────────────────────────────────

# Constants
MIN_ROWS, MAX_ROWS = 1, 50
MIN_COLS, MAX_COLS = 1, 50
MIN_MINE_DENSITY = 0.0    # 0% mines allowed (trivial board)
MAX_MINE_DENSITY = 0.20   # 20% of total cells


def sample_board_config(rng=None):
    """Sample a random (rows, cols, num_mines) from the full competition range.

    - rows ∈ [1, 50], cols ∈ [1, 50]
    - mines ∈ [0, floor(0.20 * rows * cols)]
    - Uses a weighted distribution favoring smaller boards during training
      (large boards are rare but included for coverage).
    """
    rng = rng or random.Random()

    # Weighted size distribution: favor small/medium, still cover large
    size_band = rng.random()
    if size_band < 0.30:
        # Small: 1-8
        rows = rng.randint(1, 8)
        cols = rng.randint(1, 8)
    elif size_band < 0.55:
        # Medium: 5-15
        rows = rng.randint(5, 15)
        cols = rng.randint(5, 15)
    elif size_band < 0.75:
        # Large: 10-30
        rows = rng.randint(10, 30)
        cols = rng.randint(10, 30)
    elif size_band < 0.90:
        # XL: 20-40
        rows = rng.randint(20, 40)
        cols = rng.randint(20, 40)
    else:
        # Full range: 1-50 (including extreme cases)
        rows = rng.randint(1, 50)
        cols = rng.randint(1, 50)

    total = rows * cols
    max_mines = int(total * MAX_MINE_DENSITY)  # floor(0.20 * total)

    if max_mines == 0:
        num_mines = 0  # Boards too small for any mines at ≤20%
    else:
        num_mines = rng.randint(0, max_mines)

    return rows, cols, num_mines


def mine_density(rows, cols, num_mines):
    """Compute mine density as a fraction."""
    total = rows * cols
    return num_mines / total if total > 0 else 0.0


@dataclass
class MinesweeperGame:
    rows: int
    cols: int
    num_mines: int
    seed: Optional[int] = None
    _rng: random.Random = field(init=False, repr=False)
    _board: List[List[int]] = field(init=False, repr=False)  # -1 = mine, 0-8 = count
    _revealed: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _flagged: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _state: str = field(default="ongoing", init=False, repr=False)
    _move_count: int = field(default=0, init=False, repr=False)

    def __post_init__(self):
        # ── Input validation — competition spec: n,m ∈ [1,50] ──
        if self.rows < MIN_ROWS or self.cols < MIN_COLS:
            raise ValueError(f"Board too small: {self.rows}x{self.cols} (min {MIN_ROWS}x{MIN_COLS})")
        if self.rows > MAX_ROWS or self.cols > MAX_COLS:
            raise ValueError(f"Board too large: {self.rows}x{self.cols} (max {MAX_ROWS}x{MAX_COLS})")
        if self.num_mines < 0:
            raise ValueError(f"num_mines cannot be negative, got {self.num_mines}")
        if self.num_mines >= self.rows * self.cols:
            raise ValueError(f"Too many mines ({self.num_mines}) for {self.rows}x{self.cols} board")
        # 0 mines is allowed — trivial board, instant win on first reveal cascade

        self._rng = random.Random(self.seed)
        self._board = [[0 for _ in range(self.cols)] for _ in range(self.rows)]
        self._place_mines()
        self._calculate_numbers()

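        # ── Edge case: 0 mines → every cell is safe, but nothing is revealed yet, so the
        # game stays 'ongoing' until the first reveal cascades across the whole board.
        # (_check_win only succeeds here if safe_cells == 0, which validation rules out.)
        self._check_win()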

    def _place_mines(self):
        """Place mines randomly on the board."""
        if self.num_mines == 0:
            return  # No mines to place
        positions = [(r, c) for r in range(self.rows) for c in range(self.cols)]
        mine_positions = self._rng.sample(positions, self.num_mines)
        for r, c in mine_positions:
            self._board[r][c] = -1

    def _calculate_numbers(self):
        """Calculate numbers for each cell based on adjacent mines."""
        for r in range(self.rows):
            for c in range(self.cols):
                if self._board[r][c] == -1:
                    continue
                count = 0
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if 0 <= nr < self.rows and 0 <= nc < self.cols:
                            if self._board[nr][nc] == -1:
                                count += 1
                self._board[r][c] = count

    def _reveal_cell(self, row: int, col: int) -> bool:
        """Reveal a cell. Returns True if valid move, False if invalid.
        Uses iterative flood-fill to avoid recursion limit on large boards.
        """
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed or (row, col) in self._flagged:
            return False

        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if (r, c) in self._revealed:
                continue
            self._revealed.add((r, c))

            # Hit a mine!
            if self._board[r][c] == -1:
                self._state = "failed"
                return True

            # Auto-reveal neighbors if cell is 0
            if self._board[r][c] == 0:
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < self.rows and 0 <= nc < self.cols
                                and (nr, nc) not in self._revealed
                                and (nr, nc) not in self._flagged):
                            stack.append((nr, nc))

        return True

    def _flag_cell(self, row: int, col: int) -> bool:
        """Flag/unflag a cell. Returns True if valid, False if invalid."""
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed:
            return False

        if (row, col) in self._flagged:
            self._flagged.remove((row, col))
        else:
            self._flagged.add((row, col))
        return True

    def do_action(self, action: dict) -> str:
        """Execute an action and return a status string.

        Returns one of:
          'ok'               - valid move executed
          'mine'             - revealed a mine (game over → state='failed')
          'win'              - game won after this move (all safe cells revealed)
          'invalid_format'   - bad action dict / missing keys / bad types
          'out_of_bounds'    - coordinates outside the board
          'already_revealed' - cell was already revealed
          'flagged_cell'     - tried to reveal a flagged cell
          'invalid_flag'     - tried to flag a revealed cell
          'game_over'        - game was already over before this call

        Only 'mine' sets state='failed'. All other invalid moves
        return an error string but keep the game 'ongoing'.
        NO move limit — game continues until success or failure.
        """
        if self._state != "ongoing":
            return "game_over"

        if not isinstance(action, dict):
            return "invalid_format"

        action_type = action.get("type")
        row = action.get("row")
        col = action.get("col")

        if action_type not in ("reveal", "flag") or row is None or col is None:
            return "invalid_format"

        try:
            row, col = int(row), int(col)
        except (ValueError, TypeError):
            return "invalid_format"

        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return "out_of_bounds"

        if action_type == "reveal":
            if (row, col) in self._revealed:
                return "already_revealed"
            if (row, col) in self._flagged:
                return "flagged_cell"
            self._reveal_cell(row, col)
            self._move_count += 1
        else:  # flag
            if (row, col) in self._revealed:
                return "invalid_flag"
            self._flag_cell(row, col)
            self._move_count += 1

        self._check_win()

        if self._state == "failed":
            return "mine"
        if self._state == "success":
            return "win"
        return "ok"

    def _check_win(self):
        """Check if player has won.

        Win condition: ALL safe (non-mine) cells are revealed.
        With 0 mines, ALL cells are safe → need to reveal every cell.
        """
        if self._state != "ongoing":
            return
        total_cells = self.rows * self.cols
        safe_cells = total_cells - self.num_mines
        if safe_cells == 0:
            self._state = "success"
        elif len(self._revealed) >= safe_cells:
            self._state = "success"

    def get_visible_board(self) -> List[List[str]]:
        """Get board state as player sees it.
        Uses '?' for unrevealed cells (competition standard).
        """
        visible = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r, c) in self._flagged:
                    row.append('F')
                elif (r, c) in self._revealed:
                    val = self._board[r][c]
                    row.append('*' if val == -1 else str(val))
                else:
                    row.append('?')
            visible.append(row)
        return visible

    def state(self) -> str:
        return self._state

    @property
    def move_count(self) -> int:
        return self._move_count

    def get_mine_positions(self) -> Set[Tuple[int, int]]:
        """Return set of all mine positions (for reward computation)."""
        return {(r, c) for r in range(self.rows) for c in range(self.cols)
                if self._board[r][c] == -1}

    def progress(self) -> float:
        """Fraction of safe cells revealed (0.0 to 1.0)."""
        safe_cells = self.rows * self.cols - self.num_mines
        return len(self._revealed) / safe_cells if safe_cells > 0 else 1.0

    def game_phase(self) -> str:
        """Determine the current game phase for prompt selection."""
        if self._move_count == 0 and len(self._revealed) == 0:
            return "opening"
        prog = self.progress()
        if prog >= 0.80:
            return "endgame"
        return "midgame"

    def pretty_print(self) -> str:
        """Pretty print the board with row and column indices."""
        visible = self.get_visible_board()

        # Every cell is right-aligned in a fixed-width slot so the column header
        # lines up with the grid (3 chars when 2-digit indices appear, else 2).
        w = 3 if self.cols > 10 else 2
        lines = ["   " + "".join(f"{i:>{w}d}" for i in range(self.cols))]
        lines.append("   " + "─" * (self.cols * w))

        # Board rows: 2-char row index + "│" matches the 3-char header prefix
        for r, row in enumerate(visible):
            lines.append(f"{r:2d}│" + "".join(f"{v:>{w}}" for v in row))

        return "\n".join(lines)


# ──────────────────────────────────────────────────────────────────────
# Sanity tests
# ──────────────────────────────────────────────────────────────────────
print("Testing MinesweeperGame (competition spec: 1-50 boards, 0-20% mines)...")

# Basic gameplay
g = MinesweeperGame(5, 5, 3, seed=0)
assert g.state() == "ongoing"
assert g.do_action({"type": "reveal", "row": -1, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing", "BUG: out_of_bounds should NOT end the game"
assert g.do_action({"type": "reveal", "row": 99, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing"
assert g.do_action({"type": "flag", "row": 0, "col": 0}) == "ok"
assert g.do_action({"type": "reveal", "row": 0, "col": 0}) == "flagged_cell"
assert g.state() == "ongoing", "BUG: flagged_cell should NOT end the game"
assert g.do_action({}) == "invalid_format"
assert g.state() == "ongoing", "BUG: invalid_format should NOT end the game"
print("  ✅ do_action keeps game ongoing on invalid moves")

# Verify '?' is used for unrevealed cells
board = g.get_visible_board()
has_question = any('?' in row for row in board)
assert has_question, "Board should use '?' for unrevealed cells"
print("  ✅ Board uses '?' for unrevealed cells")

# Game phase tracking
assert MinesweeperGame(5, 5, 3, seed=0).game_phase() == "opening"
assert g.move_count > 0 and g.game_phase() != "opening"
print("  ✅ Game phase tracking works")

# ── Edge case: 0 mines board ──
g0 = MinesweeperGame(3, 3, 0, seed=42)
assert g0.state() == "ongoing"
result = g0.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win", f"0-mine board should win on first reveal, got {result}"
assert g0.state() == "success"
print("  ✅ 0-mine board → instant win on first reveal")

# ── Edge case: 1x1 board with 0 mines ──
g1x1 = MinesweeperGame(1, 1, 0, seed=42)
assert g1x1.state() == "ongoing"
result = g1x1.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 1x1 board with 0 mines works")

# ── Edge case: 1x1 board — cannot have mines ──
try:
    MinesweeperGame(1, 1, 1, seed=42)
    assert False, "Should have raised ValueError"
except ValueError:
    pass
print("  ✅ 1x1 board with 1 mine correctly rejected")

# ── Edge case: 50x50 board ──
g50 = MinesweeperGame(50, 50, 500, seed=42)
assert g50.state() == "ongoing"
assert g50.rows == 50 and g50.cols == 50
assert mine_density(50, 50, 500) == 0.20
print("  ✅ 50x50 board with 20% mines works")

# ── Edge cases: rectangular ──
g1x50 = MinesweeperGame(1, 50, 10, seed=42)
assert g1x50.rows == 1 and g1x50.cols == 50
g50x1 = MinesweeperGame(50, 1, 10, seed=42)
assert g50x1.rows == 50 and g50x1.cols == 1
print("  ✅ 1x50 and 50x1 boards work")

# ── Edge case: 2x2 with 0 mines ──
g2x2 = MinesweeperGame(2, 2, 0, seed=42)
result = g2x2.do_action({"type": "reveal", "row": 0, "col": 0})
assert result == "win"
print("  ✅ 2x2 board with 0 mines → instant cascade win")

# Variable sizes across the full range
for r, c, m in [(1,1,0), (2,2,0), (3,3,1), (5,5,3), (10,10,20),
                (20,20,80), (30,30,180), (50,50,500), (1,50,10), (50,1,10)]:
    g = MinesweeperGame(r, c, m, seed=42)
    assert g.rows == r and g.cols == c
    board = g.get_visible_board()
    assert len(board) == r and len(board[0]) == c
print(f"  ✅ Variable board sizes (1x1 to 50x50) work")

# Test sample_board_config produces valid configs
rng = random.Random(42)
for _ in range(200):
    r, c, m = sample_board_config(rng)
    assert MIN_ROWS <= r <= MAX_ROWS
    assert MIN_COLS <= c <= MAX_COLS
    assert 0 <= m <= int(r * c * MAX_MINE_DENSITY)
    g = MinesweeperGame(r, c, m, seed=42)
print(f"  ✅ sample_board_config produces valid configs (200 tested)")

# Test progress
g = MinesweeperGame(5, 5, 3, seed=42)
assert g.progress() == 0.0
print(f"  ✅ All game engine tests passed")
print(f"  Board range: {MIN_ROWS}-{MAX_ROWS} rows × {MIN_COLS}-{MAX_COLS} cols")
print(f"  Mine density: {MIN_MINE_DENSITY*100:.0f}%-{MAX_MINE_DENSITY*100:.0f}%")
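
Because the mine cap is floor(0.20 · rows · cols), any board with fewer than 5 cells is forced to 0 mines (the `max_mines == 0` branch of `sample_board_config`). A small sketch of the bound, plus an integer-only form that sidesteps float rounding (both helper names are hypothetical):

```python
MAX_MINE_DENSITY = 0.20


def max_mines(rows: int, cols: int) -> int:
    """Upper bound used by sample_board_config: floor(0.20 * rows * cols)."""
    return int(rows * cols * MAX_MINE_DENSITY)


def max_mines_exact(rows: int, cols: int) -> int:
    """Same bound in pure integer arithmetic, since 20% == 1/5."""
    return (rows * cols) // 5


print(max_mines(1, 4), max_mines(1, 5), max_mines(8, 8), max_mines(50, 50))  # 0 1 12 500
# The float and integer forms agree over the whole 1..50 x 1..50 range.
print(all(max_mines(r, c) == max_mines_exact(r, c)
          for r in range(1, 51) for c in range(1, 51)))  # True
```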

# Prompt System & Game Logic Helpers

## ReAct Prompt System (based on LAMER paper — 74% win on MineSweeper)
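Two prompt modes were designed:
- **Training prompt**: full ReAct (Reason + Act) with step-by-step reasoning, constraint examples, and format enforcement
- **Inference prompt**: ~60% fewer tokens for fast evaluation

The v2 code below replaces both with a single ultra-minimal prompt (~150–300 chars) that asks for the JSON action only, because verbose responses were being truncated at `max_new_tokens=128`.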

## LAMER Paper Key Findings Applied
| Paper Finding | Implementation |
|--------------|----------------|
| ReAct > zero-shot (6.3% vs 4.5% base) | "Think step-by-step → scan → estimate → act" |
| Pre-computed CoT hints boost performance | SAFE/MINE cell lists in every prompt |
| RL training reaches 52% win rate | GRPO with `temperature=1.0`, `num_iterations=2` |
| Center-opening for dense boards | Reward center, penalize edges on density >10% |
| `temperature=0.7` for eval | Optimal eval temperature from paper |

## Board Representation
| Format | Board Size | Display |
|--------|-----------|---------|
| **Compact** | ≤10 rows and ≤10 cols | Bare grid, one string per row, no headers |
| **Row-labeled** | Larger boards (up to 50×50) | Each row prefixed with its two-digit index and a bar, no column header |

## Prompt Structure
```
{rows}x{cols} {mines}mines {hint} → BOARD → "Reply ONLY: <JSON>" → example → "DO NOT explain"
```
`hint` is the first deduced safe cell (`Safe:[r, c]`), else the first deduced mine (`Mine:[r, c]`), else `Start:center` on an untouched board, else `Guess:any ?`. 1×1 and 0-mine boards get dedicated prompts that spell out the winning move.

## Symbol Legend
`?`=unrevealed  `F`=flagged  `0`-`8`=revealed safe  `*`=revealed mine (game over)

## Output Format
```json
{"type": "reveal", "row": 2, "col": 3}
```
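
The target output is one JSON object and nothing else. A minimal sketch of a strict checker (hypothetical `strict_parse_action`) that accepts exactly that shape; unlike the lenient `parse_llm_action` in the next cell, it rejects code fences, extra keys, and non-integer coordinates:

```python
import json
from typing import Optional


def strict_parse_action(text: str) -> Optional[dict]:
    """Accept only a bare JSON object with exactly the keys type/row/col."""
    try:
        obj = json.loads(text.strip())
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict) or set(obj) != {"type", "row", "col"}:
        return None
    if obj["type"] not in ("reveal", "flag"):
        return None
    # bool is a subclass of int in Python, so exclude it explicitly
    for key in ("row", "col"):
        if not isinstance(obj[key], int) or isinstance(obj[key], bool):
            return None
    return obj


print(strict_parse_action('{"type": "reveal", "row": 2, "col": 3}'))
print(strict_parse_action('```json\n{"type": "reveal", "row": 2, "col": 3}\n```'))  # None
print(strict_parse_action('{"type": "dig", "row": 2, "col": 3}'))  # None
```

A strict check like this pairs naturally with a format reward that should only pay out for pure JSON, while the lenient parser stays useful for recovering an action from messy completions.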

In [None]:
import json
import re

# ══════════════════════════════════════════════════════════════════════
# ULTRA-MINIMAL PROMPTING SYSTEM — v2 Complete Rewrite
# ══════════════════════════════════════════════════════════════════════
# PROBLEM: Model outputs verbose explanations + ```json blocks that get 
# truncated by max_new_tokens=128, causing parse failures.
#
# SOLUTION:
#   1. ULTRA-SHORT prompts (~150-250 chars)
#   2. Few-shot examples showing EXACT expected output  
#   3. Explicit "DO NOT" instructions
#   4. Robust parser that handles ```json blocks
# ══════════════════════════════════════════════════════════════════════


def _compute_safe_and_mine_cells(game: MinesweeperGame):
    """Compute both safe and mine cells in a SINGLE pass."""
    safe = set()
    mines = set()

    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) not in game._revealed:
                continue
            val = game._board[r][c]
            if val <= 0:
                continue

            flags = 0
            unrevealed = []
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if (nr, nc) in game._flagged:
                            flags += 1
                        elif (nr, nc) not in game._revealed:
                            unrevealed.append((nr, nc))

            remaining = val - flags

            if remaining == 0 and unrevealed:
                for cell in unrevealed:
                    safe.add(cell)
            elif remaining > 0 and remaining == len(unrevealed):
                for cell in unrevealed:
                    mines.add(cell)

    return safe, mines


def _compute_safe_cells(game: MinesweeperGame) -> list:
    safe, _ = _compute_safe_and_mine_cells(game)
    return [list(c) for c in safe]


def _compute_mine_cells(game: MinesweeperGame) -> list:
    _, mines = _compute_safe_and_mine_cells(game)
    return [list(c) for c in mines]


def _is_logically_safe(game: MinesweeperGame, row: int, col: int) -> bool:
    safe, _ = _compute_safe_and_mine_cells(game)
    return (row, col) in safe


def _is_logically_mine(game: MinesweeperGame, row: int, col: int) -> bool:
    _, mines = _compute_safe_and_mine_cells(game)
    return (row, col) in mines


# ──────────────────────────────────────────────────────────────────────
# COMPACT BOARD — Minimal representation
# ──────────────────────────────────────────────────────────────────────

def _format_board_compact(game: MinesweeperGame) -> str:
    """Ultra-compact board: just the grid, no headers for small boards."""
    board = game.get_visible_board()
    
    if game.rows <= 10 and game.cols <= 10:
        # Tiny format: no headers
        return "\n".join("".join(row) for row in board)
    else:
        # Larger boards: minimal headers
        lines = []
        for r, row in enumerate(board):
            lines.append(f"{r:2d}|" + "".join(row))
        return "\n".join(lines)


# ──────────────────────────────────────────────────────────────────────
# ULTRA-MINIMAL PROMPT — Forces JSON-only output
# ──────────────────────────────────────────────────────────────────────

def format_state_for_llm(game: MinesweeperGame) -> str:
    """Generate ULTRA-SHORT prompt with few-shot example.
    
    Key insight: The model keeps outputting explanations because it's 
    instruction-tuned. We MUST show it the exact format with an example.
    """
    if game.state() == "success":
        return '{"type":"reveal","row":0,"col":0}'  # Dummy, game over
    
    rows, cols = game.rows, game.cols
    mines = game.num_mines
    
    # === CRITICAL: Edge case ultra-explicit prompts ===
    if rows == 1 and cols == 1:
        return f"""1x1 0mines
?
ONLY ONE CELL EXISTS at (0,0). Reveal it to win.
{{"type":"reveal","row":0,"col":0}}"""
    
    if mines == 0:
        # Find first unrevealed cell
        for r in range(rows):
            for c in range(cols):
                if (r, c) not in game._revealed:
                    return f"""{rows}x{cols} 0mines ALL SAFE
{_format_board_compact(game)}
NO MINES - reveal ANY cell to win!
{{"type":"reveal","row":{r},"col":{c}}}"""
    
    # Normal boards - existing logic
    board = _format_board_compact(game)
    
    safe_cells = _compute_safe_cells(game)
    mine_cells = _compute_mine_cells(game)
    
    if safe_cells:
        hint = f"Safe:{safe_cells[0]}"
    elif mine_cells:
        hint = f"Mine:{mine_cells[0]}"
    elif len(game._revealed) == 0:
        hint = f"Start:center"
    else:
        hint = "Guess:any ?"
    
    prompt = f"""{rows}x{cols} {mines}mines {hint}
{board}
Reply ONLY: {{"type":"reveal","row":N,"col":N}} or {{"type":"flag","row":N,"col":N}}
Example: {{"type":"reveal","row":2,"col":3}}
DO NOT explain. DO NOT use ```json. Just the JSON."""
    
    return prompt


def parse_llm_action(response: str) -> Optional[dict]:
    """Robust parser that handles various output formats; returns None if no action is found."""
    if not response:
        return None
    
    response = response.strip()
    
    # Strategy 1: Try pure JSON first
    try:
        action = json.loads(response)
        if _validate_action(action):
            return action
    except json.JSONDecodeError:
        pass
    
    # Strategy 2: Extract from ```json blocks
    code_block_match = re.search(r'```(?:json)?\s*(\{[^`]*?\})\s*```', response, re.DOTALL)
    if code_block_match:
        try:
            action = json.loads(code_block_match.group(1))
            if _validate_action(action):
                return action
        except json.JSONDecodeError:
            pass
    
    # Strategy 3: Extract JSON after removing code block markers
    cleaned = re.sub(r'```(?:json)?', '', response)
    cleaned = cleaned.strip()
    try:
        action = json.loads(cleaned)
        if _validate_action(action):
            return action
    except json.JSONDecodeError:
        pass
    
    # Strategy 4: Find first valid JSON object
    for match in re.finditer(r'\{[^{}]*?"type"[^{}]*?"row"[^{}]*?"col"[^{}]*?\}', response):
        try:
            action = json.loads(match.group())
            if _validate_action(action):
                return action
        except json.JSONDecodeError:
            continue
    
    # Strategy 5: Field extraction fallback
    type_match = re.search(r'"type"\s*:\s*"(reveal|flag)"', response)
    row_match = re.search(r'"row"\s*:\s*(\d+)', response)
    col_match = re.search(r'"col"\s*:\s*(\d+)', response)
    
    if type_match and row_match and col_match:
        return {
            "type": type_match.group(1),
            "row": int(row_match.group(1)),
            "col": int(col_match.group(1))
        }
    
    return None


def _validate_action(action: dict) -> bool:
    """Check if action dict has all required fields."""
    if not isinstance(action, dict):
        return False
    if "type" not in action or "row" not in action or "col" not in action:
        return False
    if action["type"] not in ("reveal", "flag"):
        return False
    try:
        action["row"] = int(action["row"])
        action["col"] = int(action["col"])
        return True
    except (ValueError, TypeError, OverflowError):  # json.loads accepts Infinity → int() overflows
        return False

# ──────────────────────────────────────────────────────────────────────
# Tests
# ──────────────────────────────────────────────────────────────────────

print("=" * 70)
print("ULTRA-MINIMAL PROMPT SYSTEM v2")
print("=" * 70)

# Test prompt lengths (should be ~150-300 chars now)
print("\nPrompt lengths (target: 150-300 chars):")
for rows, cols, mines in [(1,1,0), (3,3,1), (5,5,0), (6,6,5), (10,10,20)]:
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=42)
    if game.state() == "ongoing":
        prompt = format_state_for_llm(game)
        print(f"  {rows:2d}x{cols:2d} m={mines:2d}: {len(prompt):4d} chars")

# Show examples
print("\n" + "=" * 70)
print("EXAMPLE: 1x1 board (0 mines) — was failing before")
print("=" * 70)
game = MinesweeperGame(1, 1, 0, seed=42)
print(format_state_for_llm(game))

print("\n" + "=" * 70)
print("EXAMPLE: 5x5 board (0 mines) — was failing before")
print("=" * 70)
game = MinesweeperGame(5, 5, 0, seed=42)
print(format_state_for_llm(game))

print("\n" + "=" * 70)
print("EXAMPLE: 6x6 board (5 mines)")
print("=" * 70)
game = MinesweeperGame(6, 6, 5, seed=42)
print(format_state_for_llm(game))

# Test parser robustness
print("\n" + "=" * 70)
print("PARSER ROBUSTNESS TESTS")
print("=" * 70)

test_cases = [
    ('Pure JSON', '{"type":"reveal","row":3,"col":3}'),
    ('With newline', '{"type":"reveal","row":3,"col":3}\n'),
    ('Code block', '```json\n{"type":"reveal","row":3,"col":3}\n```'),
    ('Verbose + JSON', 'I will reveal cell at row 3 col 3: {"type":"reveal","row":3,"col":3}'),
    ('Truncated block', '```json\n{"type":"reveal","row":3,"col":3}'),
    ('Partial JSON', 'reveal row 3 col 3 {"type":"reveal","row":3,"col":3'),
    ('Scattered fields', 'type is "reveal" and "row": 3 and "col": 3'),
]

for name, test_input in test_cases:
    result = parse_llm_action(test_input)
    status = "✅" if result else "❌"
    print(f"  {status} {name}: {result}")

print("\n✅ Ultra-minimal prompt system ready")
print("   - Prompts now ~150-300 chars (was 500+)")
print("   - Few-shot example included")
print("   - Explicit 'DO NOT explain' instruction")
print("   - Robust parser handles code blocks & truncation")
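
The `Safe:`/`Mine:` hints come from `_compute_safe_and_mine_cells`, which applies only the single-number rule: if a number's adjacent flags already account for it, its other unknown neighbours are safe; if its remaining count equals its unknown neighbours, they are all mines. The same rule on a plain list-of-strings board (`?`, `F`, and digits from the legend), as a standalone sketch with a hypothetical name. Like the original, it trusts the player's flags, so a misplaced flag can yield a wrong `Safe:` hint.

```python
from typing import List, Set, Tuple

Cell = Tuple[int, int]


def single_constraint_deductions(visible: List[str]) -> Tuple[Set[Cell], Set[Cell]]:
    """Return (safe, mines) deducible from each revealed number on its own."""
    rows, cols = len(visible), len(visible[0])
    safe: Set[Cell] = set()
    mines: Set[Cell] = set()
    for r in range(rows):
        for c in range(cols):
            ch = visible[r][c]
            if not ch.isdigit() or ch == "0":
                continue  # only numbered cells constrain their neighbours
            flags, unknown = 0, []
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    nr, nc = r + dr, c + dc
                    if (dr or dc) and 0 <= nr < rows and 0 <= nc < cols:
                        if visible[nr][nc] == "F":
                            flags += 1
                        elif visible[nr][nc] == "?":
                            unknown.append((nr, nc))
            remaining = int(ch) - flags
            if remaining == 0:
                safe.update(unknown)
            elif unknown and remaining == len(unknown):
                mines.update(unknown)
    return safe, mines


print(single_constraint_deductions(["01F", "01?"]))  # ({(1, 2)}, set())
print(single_constraint_deductions(["1?"]))          # (set(), {(0, 1)})
```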

# Test Model Before Training

See how the base model performs without finetuning:

In [None]:
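from transformers import TextStreamer

# SYSTEM_PROMPT is not defined by the v2 prompt cell above; define a minimal
# fallback (its wording is an assumption) so this cell runs on its own.
if "SYSTEM_PROMPT" not in globals():
    SYSTEM_PROMPT = "You are a Minesweeper agent. Reply with exactly one JSON action and nothing else."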

game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
prompt = format_state_for_llm(game)

text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt}
    ],
    tokenize = False,
    add_generation_prompt = True,
)

print("=== Base Model Response ===")
output = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0.7,
    top_p = 0.9,
    max_new_tokens = 128,
    do_sample = True,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

# GRPO Reward Functions

Define reward functions to guide the model's learning:
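
GRPO does not feed these reward values to the policy directly: for each prompt it samples a group of completions, sums the reward functions per completion, and normalises the totals within the group. A minimal sketch of that normalisation with a hypothetical helper (population std is used here for simplicity; TRL's exact estimator and scaling options may differ):

```python
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """GRPO-style advantage: (reward - group mean) / (group std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four completions for the same prompt, each scored by the summed reward functions.
adv = group_relative_advantages([15.0, 15.0, -30.0, 2.0])
print([round(a, 2) for a in adv])
print(abs(sum(adv)) < 1e-9)  # True: advantages are centred within the group

# Identical totals give zero advantage, i.e. no learning signal for that prompt.
print(group_relative_advantages([15.0] * 4))  # [0.0, 0.0, 0.0, 0.0]
```

What drives learning is the spread of totals inside a group: a -30 format penalty only shapes the policy when other completions of the same prompt score higher, and once every completion earns the same format score, that reward stops contributing.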

In [None]:
import numpy as np
import math

# ══════════════════════════════════════════════════════════════════════
# REWARD FUNCTIONS v2 — Aggressive penalties for invalid output
# ══════════════════════════════════════════════════════════════════════
# 
# KEY INSIGHT: Previous rewards weren't working because:
#   1. Verbose output with valid JSON still got positive scores
#   2. Invalid JSON penalty (-20) was offset by other rewards
#   3. Model learned verbosity was OK as long as JSON was somewhere
#
# NEW APPROACH:
#   1. PURE JSON gets massive bonus (+15)
#   2. Verbose output with JSON gets PENALTY (not just reduced bonus)
#   3. Invalid JSON gets EXTREME penalty (-30)
#   4. Total invalid penalty across all rewards = -50+
# ══════════════════════════════════════════════════════════════════════


def _reconstruct_game(idx, kwargs):
    """Reconstruct a MinesweeperGame from dataset columns."""
    seeds = kwargs.get("seed", [])
    move_histories = kwargs.get("move_history", [])
    rows_list = kwargs.get("board_rows", [])
    cols_list = kwargs.get("board_cols", [])
    mines_list = kwargs.get("board_mines", [])

    if idx >= len(seeds) or idx >= len(move_histories):
        return None, None

    seed = seeds[idx]
    rows = rows_list[idx] if idx < len(rows_list) else 6
    cols = cols_list[idx] if idx < len(cols_list) else 6
    num_mines = mines_list[idx] if idx < len(mines_list) else 5

    mh_raw = move_histories[idx]
    if isinstance(mh_raw, str):
        move_history = json.loads(mh_raw)
    else:
        move_history = list(mh_raw)

    game = MinesweeperGame(rows=int(rows), cols=int(cols),
                           num_mines=int(num_mines), seed=int(seed))
    for prev in move_history:
        result = game.do_action(prev)
        if result == "mine":
            return None, None

    return game, move_history


def _difficulty_multiplier(game) -> float:
    """Difficulty-based reward multiplier: 1.0 at 0% mines, rising linearly to 1.5 at ≥20%."""
    board_size = game.rows * game.cols
    if board_size == 0:
        return 1.0
    density = game.num_mines / board_size
    return 1.0 + 0.5 * min(1.0, density / 0.20)


# ──────────────────────────────────────────────────────────────────────
# Reward 1: FORMAT REWARD — Maximum weight on pure JSON output
# ──────────────────────────────────────────────────────────────────────

def valid_json_reward(prompts, completions, **kwargs):
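    """Format reward with EXTREME focus on pure JSON.
    
    The response is stripped first, so leading/trailing whitespace never counts.
    "Extra chars" = response length minus the length of the first flat {...} object.

    Scoring (designed so verbose output is ALWAYS worse):
      Pure JSON only:             +15  (best case)
      JSON + <=5 extra chars:     +10  (a stray character or two)
      JSON + 6-20 extra chars:     +2  (marginal, e.g. a markdown code fence)
      JSON + 21-50 extra chars:    -5  (penalty for verbosity)
      JSON + 51-100 extra chars:  -15  (strong penalty)
      JSON + >100 extra chars:    -25  (severe penalty for explanation)
      No valid JSON:              -30  (EXTREME PENALTY)
    
    Extra text always scores lower than pure JSON, and more than 20 extra
    characters is penalized even when the JSON itself is valid.
    """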
    scores = []
    
    for completion in completions:
        response = completion[0]["content"].strip() if completion else ""
        action = parse_llm_action(response)
        
        # ═══ NO VALID JSON: EXTREME PENALTY ═══
        if action is None:
            scores.append(-30.0)
            continue
        
        # ═══ VALID JSON: Score based on purity ═══
        response_len = len(response)
        
        # Check if response is pure JSON
        try:
            parsed = json.loads(response)
            if isinstance(parsed, dict) and "type" in parsed:
                scores.append(15.0)  # Pure JSON - maximum reward
                continue
        except json.JSONDecodeError:
            pass
        
        # JSON found but with extra content - PENALIZE verbosity
        json_match = re.search(r'\{[^{}]*\}', response)
        if json_match:
            json_len = len(json_match.group())
            extra_chars = response_len - json_len
            
            if extra_chars <= 5:
                scores.append(10.0)   # Just whitespace/newline
            elif extra_chars <= 20:
                scores.append(2.0)    # Minor extra (maybe closing ```)
            elif extra_chars <= 50:
                scores.append(-5.0)   # PENALTY: some explanation
            elif extra_chars <= 100:
                scores.append(-15.0)  # STRONG PENALTY: verbose
            else:
                scores.append(-25.0)  # SEVERE: full explanation
        else:
            scores.append(-30.0)  # Shouldn't happen
    
    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 2: GAMEPLAY REWARD — Valid game actions
# ──────────────────────────────────────────────────────────────────────

def gameplay_scores(prompts, completions, **kwargs):
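    """Gameplay reward: rewards legal, safe moves and penalizes illegal ones.
    
    Invalid JSON:                   -20 (also penalized in format reward)
    Out of bounds:                  -25
    Reveal already-revealed cell:   -25
    Reveal flagged cell:            -20
    Reveal a mine:                  -20
    Flag already-flagged/revealed:  -25
    Flag a safe cell:               -10
    Valid reveal:                   +10 (+15 if logically certain)
    Valid flag on a mine:           +12 (+20 if logically certain)
    Win bonus:                      +50 (added to the reveal score)
    
    On boards with mines, every score except the invalid-JSON penalty is
    multiplied by _difficulty_multiplier (1.0-1.5). 0-mine boards use unscaled
    scores: +15 reveal, +50 winning reveal, -25 repeat reveal, -15 any flag.
    """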
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        # Invalid JSON - compound penalty
        if action is None:
            scores.append(-20.0)
            continue

        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        if game.state() != "ongoing":
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]
        diff_mult = _difficulty_multiplier(game)

        # ═══ OUT OF BOUNDS: SEVERE PENALTY ═══
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-25.0 * diff_mult)
            continue

        # ═══ 0-MINE BOARD: Simple handling ═══
        if game.num_mines == 0:
            if action_type == "reveal":
                if (row, col) in game._revealed:
                    scores.append(-25.0)  # Already revealed
                else:
                    # Check for win
                    game_copy = MinesweeperGame(
                        rows=game.rows, cols=game.cols,
                        num_mines=0, seed=kwargs.get("seed", [0])[idx]
                    )
                    for prev in move_history:
                        game_copy.do_action(prev)
                    result = game_copy.do_action(action)
                    if result == "win":
                        scores.append(50.0)  # Win bonus
                    else:
                        scores.append(15.0)  # Valid reveal
            else:
                scores.append(-15.0)  # Don't flag 0-mine boards
            continue

        # ═══ STANDARD BOARD ═══
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if action_type == "reveal":
            # Already revealed
            if (row, col) in game._revealed:
                scores.append(-25.0 * diff_mult)
                continue
            
            # Flagged cell
            if (row, col) in game._flagged:
                scores.append(-20.0 * diff_mult)
                continue
            
            # Hit mine
            if game._board[row][col] == -1:
                scores.append(-20.0 * diff_mult)
                continue
            
            # Valid reveal
            score = 15.0 if (row, col) in safe_set else 10.0
            
            # Check for win
            game_copy = MinesweeperGame(
                rows=game.rows, cols=game.cols,
                num_mines=game.num_mines, seed=kwargs.get("seed", [0])[idx]
            )
            for prev in move_history:
                game_copy.do_action(prev)
            result = game_copy.do_action(action)
            if result == "win":
                score += 50.0
            
            scores.append(score * diff_mult)

        elif action_type == "flag":
            # Already flagged
            if (row, col) in game._flagged:
                scores.append(-25.0 * diff_mult)
                continue
            
            # Flag revealed cell
            if (row, col) in game._revealed:
                scores.append(-25.0 * diff_mult)
                continue
            
            # Flag mine (good) vs flag safe cell (bad)
            if game._board[row][col] == -1:
                score = 20.0 if (row, col) in mine_set else 12.0
            else:
                score = -10.0  # Flagged safe cell
            
            scores.append(score * diff_mult)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 3: STRATEGIC REWARD — Bonus for smart play
# ──────────────────────────────────────────────────────────────────────

def strategic_reward(prompts, completions, **kwargs):
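    """Strategic reward: smaller bonuses for good Minesweeper strategy.

    Invalid JSON: -10; out of bounds: -10.
    0-mine boards: +5 for a reveal, -5 for a flag.
    Opening move: max(0, 5 - Manhattan distance to the board center) for the first reveal.
    Logical move: +5 for revealing a certain-safe cell or flagging a certain mine.
    """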
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        # Invalid gets penalty here too
        if action is None:
            scores.append(-10.0)
            continue

        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None or game.state() != "ongoing":
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]

        # Out of bounds
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-10.0)
            continue

        # Simple 0-mine handling
        if game.num_mines == 0:
            scores.append(5.0 if action_type == "reveal" else -5.0)
            continue

        safe_set, mine_set = _compute_safe_and_mine_cells(game)
        score = 0.0

        # Opening move bonus
        if len(game._revealed) == 0 and action_type == "reveal":
            center_r, center_c = game.rows // 2, game.cols // 2
            dist = abs(row - center_r) + abs(col - center_c)
            score += max(0, 5 - dist)

        # Logical move bonus
        if action_type == "reveal" and (row, col) in safe_set:
            score += 5.0
        elif action_type == "flag" and (row, col) in mine_set:
            score += 5.0

        scores.append(score)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Tests
# ──────────────────────────────────────────────────────────────────────

print("=" * 70)
print("REWARD FUNCTIONS v2 — Aggressive penalties for verbosity")
print("=" * 70)

test_prompts = ["test"]
test_kwargs = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [6],
    "board_cols": [6],
    "board_mines": [5],
}

# Test cases: score each completion with all three rewards and the
# weights used in the summary below (format, gameplay, strategic).
TEST_WEIGHTS = (0.35, 0.50, 0.15)

def _report(label, content):
    completion = [[{"role": "assistant", "content": content}]]
    r1 = valid_json_reward(test_prompts, completion, **test_kwargs)[0]
    r2 = gameplay_scores(test_prompts, completion, **test_kwargs)[0]
    r3 = strategic_reward(test_prompts, completion, **test_kwargs)[0]
    total = r1 * TEST_WEIGHTS[0] + r2 * TEST_WEIGHTS[1] + r3 * TEST_WEIGHTS[2]
    print(f"\n{label}")
    print(f"   format={r1:+.1f} gameplay={r2:+.1f} strategic={r3:+.1f} → TOTAL={total:+.2f}")

_report("1. PURE JSON (best case):", '{"type":"reveal","row":3,"col":3}')
_report("2. JSON with code block (common model output):",
        '```json\n{"type":"reveal","row":3,"col":3}\n```')
_report("3. VERBOSE with JSON (what model currently does):",
        'Given the scenario, I will reveal row 3 col 3: {"type":"reveal","row":3,"col":3}')
_report("4. INVALID JSON (worst case):", "I think row 3 col 3 is safe because...")
_report("5. OUT OF BOUNDS:", '{"type":"reveal","row":99,"col":99}')

print("\n" + "=" * 70)
print("APPROXIMATE REWARD SUMMARY (6x6 test board, weights [0.35, 0.50, 0.15]):")
print("=" * 70)
print("  Pure JSON + valid move:  ~+13.5 total")
print("  Code block + valid:      ~+8 total")
print("  Verbose + valid:         ~+3 total (still positive but much lower)")
print("  Invalid JSON:            ~-22 total (SEVERE)")
print("  Out of bounds:           ~-13 total")
print()
print("✅ Verbosity is now ALWAYS worse than concise output")

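The summary printed above is hard-coded and approximate. Two of its lines can be checked by hand. An invalid completion receives -30, -20 and -10 from the three functions regardless of the board. An out-of-bounds pure-JSON move on the 6×6, 5-mine test board receives +15 (format), -25 × m (gameplay) and -10 (strategic), where m is the multiplier computed by `_difficulty_multiplier`. The sketch below recomputes both totals. It assumes `parse_llm_action` accepts the out-of-bounds JSON (so the format reward is +15) and that the reconstructed test game is still ongoing. The helper names are hypothetical re-implementations, not the notebook's functions.

```python
WEIGHTS = (0.35, 0.50, 0.15)  # format, gameplay, strategic (as in the test cell)


def weighted_total(format_score, gameplay_score, strategic_score, weights=WEIGHTS):
    """Combine the three reward components with the test-cell weights."""
    return (format_score * weights[0]
            + gameplay_score * weights[1]
            + strategic_score * weights[2])


def difficulty_multiplier(rows, cols, num_mines):
    """Same formula as _difficulty_multiplier: 1.0 at zero density, capped at 1.5 from 20%."""
    board_size = rows * cols
    if board_size == 0:
        return 1.0
    density = num_mines / board_size
    return 1.0 + 0.5 * min(1.0, density / 0.20)


# Invalid JSON: board-independent penalties from all three reward functions.
invalid_total = weighted_total(-30.0, -20.0, -10.0)

# Out-of-bounds pure JSON on the 6x6 / 5-mine test board.
m = difficulty_multiplier(6, 6, 5)
oob_total = weighted_total(15.0, -25.0 * m, -10.0)

# Worst case for out-of-bounds: the maximum multiplier of 1.5.
oob_worst = weighted_total(15.0, -25.0 * 1.5, -10.0)

print(f"multiplier m        = {m:.4f}")
print(f"invalid JSON total  = {invalid_total:+.2f}")
print(f"out-of-bounds total = {oob_total:+.2f}")
print(f"out-of-bounds worst = {oob_worst:+.2f}")
```

Even at the maximum multiplier, the out-of-bounds total is 5.25 - 18.75 - 1.5 = -15.0. Under these assumptions, invalid JSON (-22.0) remains the worst outcome on every board.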
# Exhaustive Training Dataset Generation

6-phase composition targeting **4000 samples** with density-stratified sampling:

| Phase | Budget | Description |
|-------|--------|-------------|
| 1. Edge Cases | 10% | 50+ explicit configs (trivial, linear, rectangular, density extremes) |
| 2. Opening | 25% | 75% fresh + 25% single-move boards, to reduce the 75% early-death rate |
| 3. Pattern-Specific | 15% | Satisfied numbers (60%) + multi-region boards (40%) |
| 4. Mid-Game | 25% | 3-15 moves, progressive flagging 10%→30%→50% |
| 5. Endgame | 15% | 80-98% revealed, flag accounting, completion strategy |
| 6. Forced Guess | 10% | No logical deductions available — probability reasoning |
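Several phases depend on the single-number deductions described in the `_generate_satisfied_numbers_state` docstring. If a revealed number already has that many flagged neighbors, its other hidden neighbors are safe (a *satisfied* number). If its hidden plus flagged neighbors equal the number, every hidden neighbor is a mine. The standalone sketch below applies just those two rules to a plain grid. The grid encoding (`'H'` hidden, `'F'` flagged, digits revealed) and the function name `basic_deductions` are hypothetical and do not reflect `MinesweeperGame` internals. A stronger solver would also combine constraints across neighboring numbers.

```python
def basic_deductions(grid):
    """Apply the two single-number Minesweeper rules to a grid.

    grid: list of equal-length strings; 'H' = hidden, 'F' = flagged,
    '0'-'8' = revealed number.
    Returns (safe, mines): sets of (row, col) hidden cells that are
    certainly safe / certainly mines.
    """
    rows, cols = len(grid), len(grid[0])
    safe, mines = set(), set()
    for r in range(rows):
        for c in range(cols):
            if not grid[r][c].isdigit():
                continue
            value = int(grid[r][c])
            # Up to 8 neighbors, clipped at the board edges.
            neighbors = [
                (nr, nc)
                for nr in range(max(0, r - 1), min(rows, r + 2))
                for nc in range(max(0, c - 1), min(cols, c + 2))
                if (nr, nc) != (r, c)
            ]
            hidden = [(nr, nc) for nr, nc in neighbors if grid[nr][nc] == "H"]
            flagged = sum(1 for nr, nc in neighbors if grid[nr][nc] == "F")
            if flagged == value:
                safe.update(hidden)    # satisfied number: remaining hidden neighbors are safe
            elif flagged + len(hidden) == value:
                mines.update(hidden)   # every hidden neighbor must be a mine
    return safe, mines


print(basic_deductions(["F1H", "11H", "HHH"]))  # flag satisfies all three 1s
print(basic_deductions(["1H"]))                 # lone hidden neighbor must be the mine
print(basic_deductions(["H1H"]))                # two candidates: no certain deduction
```

In the first grid, the flag at (0, 0) satisfies every revealed 1, so all remaining hidden cells come out safe. This is the kind of position that flagging certain mines after a few reveals is meant to produce.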

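The shares in the table translate into sample counts in the following way. Each share is truncated, and the remaining samples go to the last phase, as `generate_exhaustive_dataset` does with `n_forced`. The sketch below shows that split. It uses integer percentages, a choice made here rather than in the notebook, so that no floating-point rounding is involved. `PHASE_PERCENT` and `phase_budgets` are hypothetical names.

```python
PHASE_PERCENT = {  # shares from the table, in percent
    "edge_case": 10,
    "opening": 25,
    "pattern": 15,
    "midgame": 25,
    "endgame": 15,
    "forced_guess": 10,
}


def phase_budgets(num_samples, shares=PHASE_PERCENT):
    """Truncate each share; the last phase absorbs the remainder so counts sum to num_samples."""
    names = list(shares)  # dicts preserve insertion order
    budgets = {name: num_samples * shares[name] // 100 for name in names[:-1]}
    budgets[names[-1]] = num_samples - sum(budgets.values())
    return budgets


print(phase_budgets(4000))
print(phase_budgets(1001))  # truncation leftovers land in forced_guess
```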
In [None]:
import json
import math
import os
import random

from datasets import Dataset

# ══════════════════════════════════════════════════════════════════════
#  EXHAUSTIVE DATASET GENERATION
#  Fixes: 75% early deaths, no pattern training, 0.7% endgame,
#         no forced-guess scenarios, insufficient flag training
# ══════════════════════════════════════════════════════════════════════


# ──────────────────────────────────────────────────────────────────────
# Helper: Smart move selection with progressive flagging
# ──────────────────────────────────────────────────────────────────────

def _smart_reveal(game, rng):
    """Pick a smart cell to reveal during history generation.
    Prefers safe cells (if known) to avoid hitting mines and losing the game.
    Falls back to random unrevealed cell.
    """
    safe_set, _ = _compute_safe_and_mine_cells(game)
    if safe_set:
        return rng.choice(list(safe_set))

    # Fallback: random unrevealed, unflagged cell
    unrevealed = [(r, c) for r in range(game.rows) for c in range(game.cols)
                  if (r, c) not in game._revealed and (r, c) not in game._flagged]
    if not unrevealed:
        return None
    return rng.choice(unrevealed)


def _smart_flag(game, rng):
    """Pick a logically certain mine to flag, or None if none available."""
    _, mine_set = _compute_safe_and_mine_cells(game)
    mine_candidates = [c for c in mine_set if c not in game._flagged]
    if mine_candidates:
        return rng.choice(mine_candidates)
    return None  # Don't flag randomly — creates bad training data


def _progressive_flag_probability(game):
    """Adaptive flagging based on game progress (LAMER-inspired).
    - Early game (0-30%):  10% flagging (explore first)
    - Mid game (30-70%):   30% flagging (deduce mines)
    - Late game (70-100%): 50% flagging (lock in certain mines)
    """
    progress = game.progress()
    if progress < 0.30:
        return 0.10
    elif progress < 0.70:
        return 0.30
    else:
        return 0.50


def _play_smart_moves(game, rng, num_moves, use_progressive_flags=True):
    """Play up to num_moves smart moves on `game` (mutated in place).

    Uses the progressive flagging schedule when enabled (flat 15% otherwise).
    Returns the move_history list of successful moves. Stops early if the game
    ends, a move hits a mine (that move is not recorded), or MAX_STUCK
    consecutive iterations find no cell to act on (e.g. every cell is already
    revealed or flagged while the game is still ongoing).
    """
    move_history = []
    stuck_count = 0
    MAX_STUCK = 10  # Abort after 10 consecutive failed attempts

    for _ in range(num_moves):
        if game.state() != "ongoing":
            break

        action_dict = None
        flag_prob = _progressive_flag_probability(game) if use_progressive_flags else 0.15

        # Try flagging with adaptive probability
        if rng.random() < flag_prob:
            flag_target = _smart_flag(game, rng)
            if flag_target:
                action_dict = {"type": "flag", "row": flag_target[0], "col": flag_target[1]}

        # Fall back to reveal
        if action_dict is None:
            reveal_target = _smart_reveal(game, rng)
            if reveal_target is None:
                stuck_count += 1
                if stuck_count >= MAX_STUCK:
                    break  # Prevent infinite loop
                continue
            action_dict = {"type": "reveal", "row": reveal_target[0], "col": reveal_target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            break  # Hit a mine — stop

        move_history.append(action_dict)
        stuck_count = 0  # Reset on successful move

    return move_history


# ──────────────────────────────────────────────────────────────────────
# Scenario generators: endgame, forced-guess, multi-region
# ──────────────────────────────────────────────────────────────────────

def _generate_endgame_state(rows, cols, num_mines, rng, completion_target=0.85):
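    """Generate a late-game board with most safe cells already revealed.

    Targets a revealed fraction of safe cells drawn uniformly from
    [completion_target, 0.98] (cascading reveals may overshoot it).
    Critical for teaching finishing strategy and flag accounting.
    Returns (game, move_history, seed), or (None, None, seed) on failure.
    """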
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    safe_total = rows * cols - num_mines
    if safe_total <= 1:
        return None, None, seed

    target_revealed = int(safe_total * rng.uniform(completion_target, 0.98))
    target_revealed = max(1, min(target_revealed, safe_total - 1))

    move_history = []
    stuck_count = 0
    max_stuck = 20

    while len(game._revealed) < target_revealed and game.state() == "ongoing":
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # No logical moves — reveal a random safe cell (we know the board)
            all_safe_unrevealed = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe_unrevealed:
                break
            target = rng.choice(all_safe_unrevealed)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
            stuck_count += 1
            if stuck_count >= max_stuck:
                break

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Optionally add some flags in endgame
    if game.state() == "ongoing":
        _, mine_set = _compute_safe_and_mine_cells(game)
        flaggable = [c for c in mine_set if c not in game._flagged]
        if flaggable and rng.random() < 0.7:
            num_flags = rng.randint(1, min(len(flaggable), 5))
            for target in rng.sample(flaggable, num_flags):
                action_dict = {"type": "flag", "row": target[0], "col": target[1]}
                game.do_action(action_dict)
                move_history.append(action_dict)

    if game.state() != "ongoing":
        return None, None, seed

    return game, move_history, seed


def _generate_forced_guess_state(rows, cols, num_mines, rng):
    """Generate a board state where no 100% certain logical moves exist.
    These teach the model probability-based guessing.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []
    max_reveal_attempts = rows * cols

    for _ in range(max_reveal_attempts):
        if game.state() != "ongoing":
            return None, None, seed

        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        if not safe_set and not mine_set and len(game._revealed) > 0:
            # No logical moves available — this is what we want!
            unrevealed_count = sum(
                1 for r in range(rows) for c in range(cols)
                if (r, c) not in game._revealed and (r, c) not in game._flagged
            )
            if unrevealed_count >= 2:
                return game, move_history, seed

        if safe_set:
            target = rng.choice(list(safe_set))
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        else:
            # Must guess — reveal random safe cell (using board knowledge)
            all_safe = [
                (r, c) for r in range(rows) for c in range(cols)
                if game._board[r][c] != -1
                and (r, c) not in game._revealed
                and (r, c) not in game._flagged
            ]
            if not all_safe:
                break
            target = rng.choice(all_safe)
            action_dict = {"type": "reveal", "row": target[0], "col": target[1]}

        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Fallback: return whatever state we reached (might still have logical moves)
    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_multi_region_state(rows, cols, num_mines, rng):
    """Create board with multiple separated revealed regions.
    Teaches model to reason across disconnected information.
    Returns (game, move_history, seed) or (None, None, seed).
    """
    if rows < 6 or cols < 6:
        return None, None, 0  # Need space for multiple regions

    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    # Pick 2-4 starting points in different quadrants
    quadrant_centers = [
        (rows // 4, cols // 4),
        (rows // 4, 3 * cols // 4),
        (3 * rows // 4, cols // 4),
        (3 * rows // 4, 3 * cols // 4),
    ]
    rng.shuffle(quadrant_centers)
    num_regions = rng.randint(2, min(4, len(quadrant_centers)))
    selected = quadrant_centers[:num_regions]

    move_history = []
    for r, c in selected:
        if game.state() != "ongoing":
            break
        r = max(0, min(r, rows - 1))
        c = max(0, min(c, cols - 1))

        if (r, c) not in game._revealed and game._board[r][c] != -1:
            action_dict = {"type": "reveal", "row": r, "col": c}
            result = game.do_action(action_dict)
            if result == "mine":
                return None, None, seed
            move_history.append(action_dict)

        # Reveal a few logical neighbors around this region
        for _ in range(rng.randint(1, 3)):
            if game.state() != "ongoing":
                break
            safe_set, _ = _compute_safe_and_mine_cells(game)
            if safe_set:
                # Prefer safe cells near this region
                nearby = [
                    (sr, sc) for sr, sc in safe_set
                    if abs(sr - r) <= rows // 3 and abs(sc - c) <= cols // 3
                ]
                target = rng.choice(nearby) if nearby else rng.choice(list(safe_set))
                action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
                result = game.do_action(action_dict)
                if result == "mine":
                    return None, None, seed
                move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


def _generate_satisfied_numbers_state(rows, cols, num_mines, rng):
    """Generate board with multiple satisfied numbers (easy deductions available).
    Satisfied number = number whose count of adjacent flags equals its value.
    All remaining unrevealed neighbors of a satisfied number are safe.
    Teaches basic constraint satisfaction.
    Returns (game, move_history, seed), or (None, None, seed) on failure.
    """
    seed = rng.randint(0, 999999)
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

    if game.state() != "ongoing":
        return None, None, seed

    move_history = []

    # Reveal some cells first
    num_initial = rng.randint(3, max(3, min(15, rows * cols // 4)))
    for _ in range(num_initial):
        if game.state() != "ongoing":
            break
        target = _smart_reveal(game, rng)
        if target is None:
            break
        action_dict = {"type": "reveal", "row": target[0], "col": target[1]}
        result = game.do_action(action_dict)
        if result == "mine":
            return None, None, seed
        move_history.append(action_dict)

    # Now flag logically certain mines to create satisfied numbers
    if game.state() == "ongoing":
        for _ in range(5):
            _, mine_set = _compute_safe_and_mine_cells(game)
            flaggable = [c for c in mine_set if c not in game._flagged]
            if not flaggable:
                break
            target = rng.choice(flaggable)
            action_dict = {"type": "flag", "row": target[0], "col": target[1]}
            game.do_action(action_dict)
            move_history.append(action_dict)

    if game.state() == "ongoing" and len(game._revealed) > 0:
        return game, move_history, seed
    return None, None, seed


# ──────────────────────────────────────────────────────────────────────
# Density-stratified board sampling
# ──────────────────────────────────────────────────────────────────────

DENSITY_TARGETS = [
    # (density_range, weight, label)
    ((0.00, 0.00), 0.08, "Zero mines"),       # 8%  - trivial edge case
    ((0.01, 0.05), 0.17, "Very sparse"),       # 17%
    ((0.05, 0.10), 0.25, "Sparse"),            # 25%
    ((0.10, 0.15), 0.25, "Medium"),            # 25%
    ((0.15, 0.20), 0.25, "Dense/Max"),         # 25%
]


def _sample_board_with_density(rng, target_density_range=None):
    """Sample board config with explicit density control.
    If target_density_range is None, picks one from DENSITY_TARGETS.
    """
    if target_density_range is None:
        # Weighted random density band
        weights = [w for _, w, _ in DENSITY_TARGETS]
        ranges = [r for r, _, _ in DENSITY_TARGETS]
        idx = rng.choices(range(len(DENSITY_TARGETS)), weights=weights, k=1)[0]
        target_density_range = ranges[idx]

    # Sample board size — boosted large-board representation for 50×50 support
    size_band = rng.random()
    if size_band < 0.25:
        rows, cols = rng.randint(1, 8), rng.randint(1, 8)      # 25% tiny
    elif size_band < 0.45:
        rows, cols = rng.randint(5, 15), rng.randint(5, 15)    # 20% small
    elif size_band < 0.65:
        rows, cols = rng.randint(10, 30), rng.randint(10, 30)  # 20% medium
    elif size_band < 0.85:
        rows, cols = rng.randint(20, 40), rng.randint(20, 40)  # 20% large
    else:
        rows, cols = rng.randint(30, 50), rng.randint(30, 50)  # 15% XL (30-50)

    total = rows * cols
    min_d, max_d = target_density_range

    if min_d == 0.0 and max_d == 0.0:
        return rows, cols, 0

    min_mines = max(0, int(math.ceil(total * min_d)))
    max_mines = min(int(total * max_d), int(total * MAX_MINE_DENSITY), total - 1)

    if max_mines < min_mines:
        num_mines = max(0, min(min_mines, total - 1))
    else:
        num_mines = rng.randint(min_mines, max_mines)

    return rows, cols, num_mines


# ──────────────────────────────────────────────────────────────────────
# Exhaustive edge-case configs (50+ scenarios)
# ──────────────────────────────────────────────────────────────────────

EDGE_CASE_CONFIGS = [
    # (rows, cols, mines, label)

    # === TRIVIAL BOARDS ===
    (1, 1, 0,   "1x1 trivial"),
    (2, 2, 0,   "2x2 no mines"),
    (3, 3, 0,   "3x3 no mines"),
    (4, 4, 0,   "4x4 no mines"),
    (5, 5, 0,   "5x5 no mines - cascade practice"),

    # === LINEAR BOARDS (1D Minesweeper) ===
    (1, 5, 1,   "1x5 single mine"),
    (1, 10, 2,  "1x10 two mines"),
    (1, 20, 4,  "1x20 row"),
    (1, 50, 10, "1x50 maximum row"),
    (5, 1, 1,   "5x1 column"),
    (10, 1, 2,  "10x1 column"),
    (20, 1, 4,  "20x1 column"),
    (50, 1, 10, "50x1 maximum column"),

    # === TINY BOARDS ===
    (1, 2, 0,   "1x2 trivial"),
    (2, 1, 0,   "2x1 trivial"),
    (2, 2, 1,   "2x2 single mine"),
    (2, 3, 1,   "2x3 mini"),
    (3, 2, 1,   "3x2 mini"),
    (3, 3, 1,   "3x3 single mine"),
    (3, 3, 2,   "3x3 two mines"),
    (4, 4, 3,   "4x4 medium density"),

    # === EXTREME RECTANGULAR ===
    (2, 50, 20, "2x50 ultra-wide"),
    (50, 2, 20, "50x2 ultra-tall"),
    (3, 40, 24, "3x40 extreme ratio"),
    (40, 3, 24, "40x3 extreme ratio"),
    (5, 50, 50, "5x50 max density wide"),
    (50, 5, 50, "50x5 max density tall"),

    # === CLASSIC MINESWEEPER SIZES ===
    (8, 8, 1,   "8x8 very sparse"),
    (8, 8, 5,   "8x8 sparse"),
    (8, 8, 10,  "8x8 medium - classic beginner"),
    (8, 8, 13,  "8x8 max density"),
    (16, 16, 10, "16x16 sparse"),
    (16, 16, 40, "16x16 medium - classic intermediate"),
    (16, 16, 51, "16x16 max density"),
    (16, 30, 20, "16x30 rectangular sparse"),
    (16, 30, 96, "16x30 rectangular dense - classic expert"),

    # === LARGE BOARDS ===
    (20, 20, 10, "20x20 very sparse"),
    (20, 20, 80, "20x20 max density"),
    (25, 25, 30, "25x25 sparse"),
    (25, 25, 125, "25x25 max density"),
    (30, 30, 50, "30x30 sparse"),
    (30, 30, 180, "30x30 max density"),
    (40, 40, 100, "40x40 sparse"),
    (40, 40, 320, "40x40 max density"),

    # === MAXIMUM SIZE ===
    (50, 50, 10,  "50x50 ultra-sparse"),
    (50, 50, 100, "50x50 sparse"),
    (50, 50, 250, "50x50 medium"),
    (50, 50, 500, "50x50 maximum density"),

    # === DENSITY EDGE CASES ===
    (10, 10, 0,  "10x10 no mines"),
    (15, 15, 0,  "15x15 no mines - large trivial"),
    (20, 20, 0,  "20x20 no mines - huge trivial"),
    (10, 10, 1,  "10x10 single mine"),
    (20, 20, 1,  "20x20 single mine in large board"),

    # === MORE RECTANGULAR VARIETY ===
    (3, 50, 30,  "3x50 wide"),
    (50, 3, 30,  "50x3 tall"),
    (5, 15, 15,  "5x15 wide rect"),
    (15, 5, 15,  "15x5 tall rect"),
    (6, 8, 6,    "6x8 rect"),
    (8, 6, 6,    "8x6 rect"),
    (7, 13, 18,  "7x13 odd rect"),
    (13, 7, 18,  "13x7 odd rect"),
    (15, 30, 90, "15x30 wide large"),
]


# ──────────────────────────────────────────────────────────────────────
# Dataset item builder (DRY helper)
# ──────────────────────────────────────────────────────────────────────

SYSTEM_PROMPT = "You are a Minesweeper AI. Output ONLY valid JSON. No explanations, no reasoning, just {\"type\":\"reveal\",\"row\":N,\"col\":N}."

def _build_dataset_item(game, seed, move_history):
    """Build a single dataset item dict from a game state."""
    prompt_text = format_state_for_llm(game)
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt_text}
        ],
        "seed": seed,
        "move_history": json.dumps(move_history),
        "board_rows": game.rows,
        "board_cols": game.cols,
        "board_mines": game.num_mines,
    }


# ══════════════════════════════════════════════════════════════════════
#  MAIN GENERATOR — 6-Phase Exhaustive Composition
# ══════════════════════════════════════════════════════════════════════

def generate_exhaustive_dataset(num_samples=4000, rng_seed=42):
    """
    Comprehensive Minesweeper training dataset covering ALL scenarios.

    6-Phase composition:
      Phase 1: Edge cases           — 10%  (50+ explicit configs)
      Phase 2: Opening-heavy        — 25%  (fresh + single-move boards)
      Phase 3: Pattern-specific     — 15%  (satisfied numbers, multi-region)
      Phase 4: Mid-game deduction   — 25%  (core logical reasoning)
      Phase 5: Endgame completion   — 15%  (80-98% revealed, flag accounting)
      Phase 6: Forced guess         — 10%  (no logical moves available)

    Improvements over previous version:
      ✅ 50+ edge case configs (was 25)
      ✅ Progressive flagging 10%→30%→50% by game phase (was flat 15%)
      ✅ Density-stratified board sampling
      ✅ Dedicated endgame generator (was 0.7% late-game)
      ✅ Forced-guess scenario training (was 0%)
      ✅ Multi-region disconnected boards (was 0%)
      ✅ Satisfied-number pattern training (was 0%)
      ✅ 4000 samples (was 3000)
    """
    rng = random.Random(rng_seed)
    np.random.seed(rng_seed)

    dataset_items = []
    phase_counts = {
        "edge_case": 0, "opening": 0, "pattern": 0,
        "midgame": 0, "endgame": 0, "forced_guess": 0,
    }
    config_counts = {}
    density_counts = {"zero": 0, "very_sparse": 0, "sparse": 0, "medium": 0, "dense": 0}

    def _track_config(game):
        key = f"{game.rows}x{game.cols}m{game.num_mines}"
        config_counts[key] = config_counts.get(key, 0) + 1
        d = game.num_mines / (game.rows * game.cols) if game.rows * game.cols > 0 else 0
        if d == 0:
            density_counts["zero"] += 1
        elif d <= 0.05:
            density_counts["very_sparse"] += 1
        elif d <= 0.10:
            density_counts["sparse"] += 1
        elif d <= 0.15:
            density_counts["medium"] += 1
        else:
            density_counts["dense"] += 1

    # Budget per phase
    n_edge     = int(num_samples * 0.10)
    n_opening  = int(num_samples * 0.25)
    n_pattern  = int(num_samples * 0.15)
    n_midgame  = int(num_samples * 0.25)
    n_endgame  = int(num_samples * 0.15)
    n_forced   = num_samples - n_edge - n_opening - n_pattern - n_midgame - n_endgame

    print(f"  Phase budgets: edge={n_edge}, opening={n_opening}, pattern={n_pattern}, "
          f"midgame={n_midgame}, endgame={n_endgame}, forced={n_forced}")

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 1: Edge Cases (10%) — 50+ explicit configs
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 1: Edge cases...")
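    edge_generated = 0
    edge_attempts = 0
    edge_idx = 0
    # Attempt cap mirrors the other phases so a config that always ends early cannot loop forever
    while edge_generated < n_edge and edge_attempts < n_edge * 10:
        edge_attempts += 1
        ec_rows, ec_cols, ec_mines, ec_label = EDGE_CASE_CONFIGS[edge_idx % len(EDGE_CASE_CONFIGS)]
        edge_idx += 1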
        seed = rng.randint(0, 999999)

        game = MinesweeperGame(rows=ec_rows, cols=ec_cols, num_mines=ec_mines, seed=seed)
        if game.state() != "ongoing":
            continue

        # 0-3 moves for variety
        num_moves = rng.randint(0, min(3, max(0, ec_rows * ec_cols - ec_mines - 1)))
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["edge_case"] += 1
            edge_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 2: Opening-Heavy Training (25%) — Fix 75% early death rate
    # 75% fresh (0 moves), 25% single-move
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 2: Opening training...")
    opening_generated = 0
    opening_attempts = 0
    while opening_generated < n_opening and opening_attempts < n_opening * 5:
        opening_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if num_mines == 0:
            if game.state() == "ongoing":
                dataset_items.append(_build_dataset_item(game, seed, []))
                _track_config(game)
                phase_counts["opening"] += 1
                opening_generated += 1
            continue

        if game.state() != "ongoing":
            continue

        # 75% fresh boards, 25% single-move
        num_moves = 0 if rng.random() < 0.75 else 1
        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=False)

        if game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["opening"] += 1
            opening_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 3: Pattern-Specific Scenarios (15%)
    # Mix of: satisfied-number boards (60%), multi-region boards (40%)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 3: Pattern-specific scenarios...")
    pattern_generated = 0
    pattern_attempts = 0
    while pattern_generated < n_pattern and pattern_attempts < n_pattern * 10:
        pattern_attempts += 1

        # 60% satisfied numbers, 40% multi-region
        if rng.random() < 0.60:
            # Satisfied numbers — need boards with mines
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))
            if rows < 3 or cols < 3:
                continue
            game, move_history, seed = _generate_satisfied_numbers_state(
                rows, cols, num_mines, rng
            )
        else:
            # Multi-region — need larger boards
            rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.15))
            rows = max(rows, 8)
            cols = max(cols, 8)
            num_mines = min(num_mines, int(rows * cols * 0.15))
            if num_mines < 1:
                num_mines = max(1, int(rows * cols * 0.05))
            game, move_history, seed = _generate_multi_region_state(
                rows, cols, num_mines, rng
            )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["pattern"] += 1
            pattern_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 4: Mid-Game Logical Deduction (25%) — Core gameplay
    # 3-15 moves played, progressive flagging, density-stratified
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 4: Mid-game deduction...")
    midgame_generated = 0
    midgame_attempts = 0
    while midgame_generated < n_midgame and midgame_attempts < n_midgame * 5:
        midgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng)

        if num_mines == 0:
            continue  # Skip zero-mine for midgame

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        if game.state() != "ongoing":
            continue

        total_safe = rows * cols - num_mines
        max_moves = min(15, max(3, total_safe - 1))
        num_moves = rng.randint(3, max_moves)

        move_history = _play_smart_moves(game, rng, num_moves, use_progressive_flags=True)

        if game.state() == "ongoing" and len(game._revealed) > 0:
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["midgame"] += 1
            midgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 5: Endgame Completion (15%) — 80-95% revealed (target completion)
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 5: Endgame completion...")
    endgame_generated = 0
    endgame_attempts = 0
    while endgame_generated < n_endgame and endgame_attempts < n_endgame * 10:
        endgame_attempts += 1
        rows, cols, num_mines = _sample_board_with_density(rng, (0.05, 0.20))

        if num_mines < 1 or rows * cols < 8:
            continue

        completion = rng.uniform(0.80, 0.95)
        game, move_history, seed = _generate_endgame_state(
            rows, cols, num_mines, rng, completion_target=completion
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["endgame"] += 1
            endgame_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # PHASE 6: Forced Guess Scenarios (10%) — No logical deductions
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    print("  Phase 6: Forced guess scenarios...")
    forced_generated = 0
    forced_attempts = 0
    while forced_generated < n_forced and forced_attempts < n_forced * 15:
        forced_attempts += 1
        # Dense boards more likely to produce forced-guess states
        rows, cols, num_mines = _sample_board_with_density(rng, (0.10, 0.20))

        if num_mines < 2 or rows * cols < 6:
            continue

        game, move_history, seed = _generate_forced_guess_state(
            rows, cols, num_mines, rng
        )

        if game is not None and game.state() == "ongoing":
            dataset_items.append(_build_dataset_item(game, seed, move_history))
            _track_config(game)
            phase_counts["forced_guess"] += 1
            forced_generated += 1

    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # Shuffle and trim
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    rng.shuffle(dataset_items)
    dataset_items = dataset_items[:num_samples]
    ds = Dataset.from_list(dataset_items)

    return ds, config_counts, phase_counts, density_counts


# ══════════════════════════════════════════════════════════════════════
#  Generate & Analyze
# ══════════════════════════════════════════════════════════════════════

print("=" * 70)
print("  EXHAUSTIVE DATASET GENERATION")
print("  4000 samples | 6 phases | 50+ edge cases | density-stratified")
print("=" * 70)
print()

dataset, config_counts, phase_counts, density_counts = generate_exhaustive_dataset(
    num_samples=4000, rng_seed=42
)

print(f"\n{'─'*70}")
print(f"Created {len(dataset)} training examples\n")

# Phase distribution
print("Phase distribution:")
for phase, count in phase_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {phase:14s}: {count:4d} ({pct:5.1f}%)")

# Density distribution
print(f"\nDensity distribution:")
for band, count in density_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {band:14s}: {count:4d} ({pct:5.1f}%)")

# Board size distribution (top 20)
print(f"\nBoard size distribution (top 20):")
sorted_configs = sorted(config_counts.items(), key=lambda x: -x[1])
# Loop variable is cfg_key (not `config`) so the YAML config loaded earlier is not overwritten
for cfg_key, count in sorted_configs[:20]:
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {cfg_key:16s}: {count:4d} ({pct:.1f}%)")
if len(sorted_configs) > 20:
    print(f"  ... and {len(sorted_configs) - 20} more unique configs")
print(f"\nTotal unique board configs: {len(config_counts)}")

# Board size statistics
all_rows = [item["board_rows"] for item in dataset]
all_cols = [item["board_cols"] for item in dataset]
all_mines = [item["board_mines"] for item in dataset]
print(f"\nBoard size statistics:")
print(f"  Rows:  min={min(all_rows)}, max={max(all_rows)}, mean={np.mean(all_rows):.1f}")
print(f"  Cols:  min={min(all_cols)}, max={max(all_cols)}, mean={np.mean(all_cols):.1f}")
print(f"  Mines: min={min(all_mines)}, max={max(all_mines)}, mean={np.mean(all_mines):.1f}")
densities = [m / (r * c) * 100 for r, c, m in zip(all_rows, all_cols, all_mines) if r * c > 0]
print(f"  Density: min={min(densities):.1f}%, max={max(densities):.1f}%, mean={np.mean(densities):.1f}%")

zero_mine_count = sum(1 for m in all_mines if m == 0)
print(f"  Zero-mine boards: {zero_mine_count} ({zero_mine_count/len(dataset)*100:.1f}%)")

move_counts = [len(json.loads(item["move_history"])) for item in dataset]
print(f"\nMove statistics:")
print(f"  Min: {min(move_counts)}, Max: {max(move_counts)}, "
      f"Mean: {np.mean(move_counts):.1f}, Median: {np.median(move_counts):.1f}")

# Phase-specific move stats
print(f"\nLogical deduction coverage:")
has_safe = 0
has_mine = 0
has_both = 0
for item in dataset:
    mh = json.loads(item["move_history"])
    has_flag = any(m.get("type") == "flag" for m in mh)
    has_rev = any(m.get("type") == "reveal" for m in mh)
    if has_flag:
        has_mine += 1
    if has_rev:
        has_safe += 1
    if has_flag and has_rev:
        has_both += 1
print(f"  Samples with reveals: {has_safe} ({has_safe/len(dataset)*100:.1f}%)")
print(f"  Samples with flags:   {has_mine} ({has_mine/len(dataset)*100:.1f}%)")
print(f"  Samples with both:    {has_both} ({has_both/len(dataset)*100:.1f}%)")

# Verify dataset columns
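print(f"\nDataset columns: {dataset.column_names}")
missing_cols = {"board_rows", "board_cols", "board_mines"} - set(dataset.column_names)
assert not missing_cols, f"Reward functions need these columns: {sorted(missing_cols)}"
print("  ✅ board_rows, board_cols, board_mines present for reward functions")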

# ── Save dataset to JSON ──
dataset_json_path = "minesweeper_dataset.json"

json_records = []
for item in dataset:
    record = {
        "seed": item["seed"],
        "move_history": item["move_history"],
        "board_rows": item["board_rows"],
        "board_cols": item["board_cols"],
        "board_mines": item["board_mines"],
        "system_prompt": item["prompt"][0]["content"],  # First message is the system prompt
        "prompt_text": item["prompt"][-1]["content"],   # Last message is the user board state
    }
    json_records.append(record)

with open(dataset_json_path, "w") as f:
    json.dump(json_records, f, indent=2)

print(f"\nSaved {dataset_json_path}")
print(f"   {len(json_records)} records with fields: seed, move_history, board_rows/cols/mines, system_prompt, prompt_text")

print(f"\nSample prompt ({dataset[0]['board_rows']}x{dataset[0]['board_cols']}, "
      f"{dataset[0]['board_mines']} mines):")
print(dataset[0]["prompt"][-1]["content"][:300] + "...")
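
The bookkeeping in `generate_exhaustive_dataset` comes down to two small rules. First, each phase's fractional budget is floored, and the last phase (forced guess) takes the remainder, so the budgets always sum to `num_samples`. Second, each board is binned by mine density using the thresholds in `_track_config`. The sketch below restates both rules as standalone functions. `phase_budgets` and `density_band` are hypothetical helper names and are not defined anywhere else in this notebook.

```python
# Standalone restatement of the phase-budget and density-band logic (sketch only).
PHASE_FRACTIONS = {
    "edge_case": 0.10,
    "opening": 0.25,
    "pattern": 0.15,
    "midgame": 0.25,
    "endgame": 0.15,
    # "forced_guess" receives whatever is left over (nominally 10%)
}


def phase_budgets(num_samples, fractions=PHASE_FRACTIONS):
    """Floor each fractional budget; the final phase absorbs the rounding remainder."""
    budgets = {name: int(num_samples * frac) for name, frac in fractions.items()}
    budgets["forced_guess"] = num_samples - sum(budgets.values())
    return budgets


def density_band(rows, cols, num_mines):
    """Same thresholds as _track_config: 0, <=5%, <=10%, <=15%, else dense."""
    cells = rows * cols
    d = num_mines / cells if cells > 0 else 0
    if d == 0:
        return "zero"
    if d <= 0.05:
        return "very_sparse"
    if d <= 0.10:
        return "sparse"
    if d <= 0.15:
        return "medium"
    return "dense"
```

For example, `phase_budgets(4000)` gives 400 / 1000 / 600 / 1000 / 600 / 400. An 8×8 board with 10 mines (about 15.6%) falls in the `dense` band, even though it is listed as "Medium" in the eval configs.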

# Configure GRPO Training

Set up GRPO trainer with all hyperparameters:

In [None]:
from trl import GRPOConfig, GRPOTrainer
import os
import json
from datetime import datetime
from glob import glob

# ══════════════════════════════════════════════════════════════════════
# CHECKPOINT MANAGEMENT SYSTEM
# ══════════════════════════════════════════════════════════════════════
# - Saves checkpoints every 100 steps
# - Tracks latest checkpoint in a JSON file
# - Automatically resumes from last checkpoint
# - Preserves all checkpoints for hackathon submission
# ══════════════════════════════════════════════════════════════════════

CHECKPOINT_DIR = "minesweeper_checkpoints"
CHECKPOINT_TRACKER = os.path.join(CHECKPOINT_DIR, "checkpoint_tracker.json")

def get_checkpoint_tracker():
    """Load or create checkpoint tracker."""
    if os.path.exists(CHECKPOINT_TRACKER):
        with open(CHECKPOINT_TRACKER, 'r') as f:
            return json.load(f)
    return {
        "latest_checkpoint": None,
        "latest_step": 0,
        "total_steps_trained": 0,
        "checkpoints": [],
        "training_sessions": []
    }

def save_checkpoint_tracker(tracker):
    """Save checkpoint tracker."""
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    with open(CHECKPOINT_TRACKER, 'w') as f:
        json.dump(tracker, f, indent=2)

def find_latest_checkpoint():
    """Find the latest checkpoint directory."""
    tracker = get_checkpoint_tracker()
    
    # First check tracker
    if tracker["latest_checkpoint"] and os.path.exists(tracker["latest_checkpoint"]):
        return tracker["latest_checkpoint"], tracker["latest_step"]
    
    # Fallback: scan checkpoint directories
    checkpoint_dirs = glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*"))
    if not checkpoint_dirs:
        return None, 0
    
    # Sort by step number
    def get_step(path):
        try:
            return int(os.path.basename(path).split("-")[1])
        except (IndexError, ValueError):
            return 0
    
    checkpoint_dirs.sort(key=get_step, reverse=True)
    latest = checkpoint_dirs[0]
    step = get_step(latest)
    
    return latest, step

# Check for existing checkpoint
latest_checkpoint, resume_step = find_latest_checkpoint()

print("=" * 70)
print("CHECKPOINT MANAGEMENT SYSTEM")
print("=" * 70)

if latest_checkpoint:
    print(f"✅ Found existing checkpoint: {latest_checkpoint}")
    print(f"   Will resume from step: {resume_step}")
    RESUME_FROM_CHECKPOINT = latest_checkpoint
else:
    print("📝 No existing checkpoint found. Starting fresh training.")
    RESUME_FROM_CHECKPOINT = None

# ══════════════════════════════════════════════════════════════════════
# TRAINING CONFIG — With checkpoint saving every 100 steps
# ══════════════════════════════════════════════════════════════════════

max_prompt_length = 800
max_completion_length = 96  # Increased from 64 — model needs buffer

# Target total steps (can increase this and resume to train more)
TARGET_TOTAL_STEPS = 2000  # Increase this as needed before hackathon

training_args = GRPOConfig(
    # === Generation ===
    temperature = 0.7,  # FIXED: match eval temperature exactly
    top_p = 0.9,

    # === Optimization ===
    learning_rate = 5e-6,  # Reduced from 1e-5 for stability
    weight_decay = 0.01,
    warmup_ratio = 0.03,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    max_grad_norm = 1.0,

    # === Batch sizes ===
    logging_steps = 10,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    num_generations = 8,

    # === Lengths ===
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,  # Now 96

    # === Training duration ===
    max_steps = TARGET_TOTAL_STEPS,
    
    # === CHECKPOINT SAVING ===
    save_steps = 100,                    # Save every 100 steps
    save_total_limit = None,             # Keep ALL checkpoints (no limit)
    output_dir = CHECKPOINT_DIR,         # Save to checkpoint directory
    save_strategy = "steps",
    
    # === GRPO specific ===
    beta = 0.02,
    num_iterations = 2,

    # === Reward weighting ===
    # Rebalanced: gameplay dominates, format important, strategic bonus
    reward_weights = [0.35, 0.50, 0.15],  # [format, gameplay, strategic]

    # === Other ===
    remove_unused_columns = False,
    report_to = "none",
    seed = 42,
    bf16 = True,
)

print()
print("TRAINING CONFIGURATION:")
print(f"  Target total steps:    {TARGET_TOTAL_STEPS}")
print(f"  Resume from step:      {resume_step if RESUME_FROM_CHECKPOINT else 0}")
print(f"  Remaining steps:       {TARGET_TOTAL_STEPS - resume_step}")
print(f"  Save checkpoint every: {training_args.save_steps} steps")
print(f"  Checkpoint directory:  {CHECKPOINT_DIR}/")
print(f"  Keep all checkpoints:  Yes (save_total_limit=None)")
print()
print("To train more steps before hackathon:")
print("  1. Increase TARGET_TOTAL_STEPS above")
print("  2. Re-run this cell and the training cell")
print("  3. Training will resume from the latest checkpoint")
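
Two settings interact here: `reward_weights = [0.35, 0.50, 0.15]` and `num_generations = 8`. With `num_generations = 8`, every prompt gets a group of 8 sampled completions. As we understand TRL's `GRPOTrainer` (with reward scaling enabled), the per-function rewards are first combined into a weighted sum. Each completion's advantage is then its combined reward minus the group mean, divided by the group standard deviation plus a small epsilon. The plain-Python sketch below illustrates that arithmetic. It is not TRL's code. The helper names `combine_rewards` and `group_advantages`, the contiguous-group layout, and `eps=1e-4` are assumptions made for illustration.

```python
from statistics import mean, stdev

REWARD_WEIGHTS = [0.35, 0.50, 0.15]  # [format, gameplay, strategic], as in GRPOConfig above


def combine_rewards(per_function_rewards, weights=REWARD_WEIGHTS):
    """Weighted sum of one completion's rewards (one value per reward function)."""
    if len(per_function_rewards) != len(weights):
        raise ValueError("need exactly one reward per reward function")
    return sum(w * r for w, r in zip(weights, per_function_rewards))


def group_advantages(combined, group_size, eps=1e-4):
    """Normalize combined rewards within each consecutive group of completions.

    Assumes completions for the same prompt are contiguous, i.e.
    combined[0:group_size] is prompt 0's group, and so on.
    """
    if len(combined) % group_size != 0:
        raise ValueError("number of completions must be a multiple of group_size")
    advantages = []
    for start in range(0, len(combined), group_size):
        group = combined[start:start + group_size]
        mu = mean(group)
        sigma = stdev(group) if group_size > 1 else 0.0  # sample std (n - 1)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages
```

One consequence follows directly from the formula. If all 8 completions for a prompt earn the same combined reward, every advantage in that group is zero, and the prompt contributes no policy-gradient signal for that step.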

In [None]:
# ═══════════════════════════════════════════════════════════════════
#  VERIFICATION: System prompt is in EVERY dataset item
# ═══════════════════════════════════════════════════════════════════
print(f"Dataset size: {len(dataset)} items")
print(f"SYSTEM_PROMPT: {SYSTEM_PROMPT!r}\n")

# Check first 3 items
for i in range(min(3, len(dataset))):
    item = dataset[i]
    prompt_messages = item["prompt"]
    has_system = any(
        msg.get("role") == "system" and SYSTEM_PROMPT in msg.get("content", "")
        for msg in prompt_messages
    )
    print(f"  Item {i}: {len(prompt_messages)} messages | "
          f"system_prompt={'✅' if has_system else '❌ MISSING'} | "
          f"roles={[m['role'] for m in prompt_messages]}")

# Check ALL items for a system-role message
missing = [i for i in range(len(dataset))
           if not any(m.get("role") == "system" for m in dataset[i]["prompt"])]
if missing:
    print(f"\n❌ CRITICAL: {len(missing)} items MISSING system prompt: {missing[:10]}...")
else:
    print(f"\n✅ All {len(dataset)} items have system prompt")

# Verify the rendered chat template (the text passed to the tokenizer) includes the system prompt
sample_text = tokenizer.apply_chat_template(
    dataset[0]["prompt"], tokenize=False, add_generation_prompt=True
)
assert SYSTEM_PROMPT in sample_text, "❌ System prompt NOT in rendered chat template!"
print(f"✅ System prompt present in rendered chat template")
print(f"   Rendered length (chars): {len(sample_text)}")

In [None]:
from transformers import TrainerCallback

# Board configs to evaluate on during training
EVAL_CONFIGS = [
    (1, 1, 0),     # Trivial — 0 mines  (was failing!)
    (3, 3, 1),     # Tiny
    (5, 5, 0),     # Zero mines (was failing!)
    (5, 5, 3),     # Small
    (6, 6, 5),     # Standard
    (8, 8, 10),    # Medium
    (10, 10, 20),  # Large
    (15, 15, 45),  # XL
    (1, 10, 2),    # Row board
    (20, 20, 80),  # XX-Large
]


class MinesweeperEvalCallback(TrainerCallback):
    """Periodically play games during training.
    
    Tracks:
    - Win rate (games won / total)
    - Invalid count (parse failures + invalid moves)
    - Output length (to verify brevity is learned)
    """

    def __init__(self, eval_every_steps=50, num_games=10):
        self.eval_every_steps = eval_every_steps
        self.num_games = min(num_games, len(EVAL_CONFIGS))

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step % self.eval_every_steps != 0:
            return

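        # Newer transformers pass `processing_class`; older versions pass `tokenizer`
        tokenizer = processing_class if processing_class is not None else kwargs.get("tokenizer")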
        if tokenizer is None or model is None:
            return

        was_training = model.training
        model.eval()

        wins = 0
        total_moves = 0
        invalid_count = 0
        total_response_len = 0
        response_count = 0

        for i in range(self.num_games):
            rows, cols, mines = EVAL_CONFIGS[i % len(EVAL_CONFIGS)]
            game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines,
                                   seed=10000 + i)
            moves = 0
            invalids = 0
            consecutive_invalids = 0
            seen_actions = set()
            repeat_count = 0
            max_iterations = min(200, rows * cols + 50)

            iteration = 0
            while game.state() == "ongoing" and iteration < max_iterations:
                iteration += 1
                prompt = format_state_for_llm(game)
                text = tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt}
                    ],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                # No truncation: right-side truncation would cut off the assistant
                # generation prompt on large boards (max_seq_length is 2048)
                inputs = tokenizer(text, return_tensors="pt")
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        temperature=0.7,
                        max_new_tokens=96,  # Increased from 64 — match training
                        do_sample=True,
                        top_p=0.9,
                    )

                gen_tokens = output[0][inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                
                # Track response length
                total_response_len += len(response)
                response_count += 1
                
                action = parse_llm_action(response)

                if action is None:
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break
                    continue

                action_key = (action['type'], action['row'], action['col'])
                if action_key in seen_actions:
                    repeat_count += 1
                    if repeat_count >= 3:
                        break
                else:
                    repeat_count = 0
                    seen_actions.add(action_key)

                result = game.do_action(action)
                if result in ("mine", "win"):
                    moves += 1
                    break
                elif result == "ok":
                    moves += 1
                    # Reset only after a legal move, so interleaved parse failures
                    # and illegal moves both count toward the 5-strike limit
                    consecutive_invalids = 0
                else:
                    invalids += 1
                    consecutive_invalids += 1
                    if consecutive_invalids >= 5:
                        break

            if game.state() == "success":
                wins += 1
            total_moves += moves
            invalid_count += invalids

        win_rate = wins / self.num_games
        avg_moves = total_moves / self.num_games
        avg_resp_len = total_response_len / max(1, response_count)
        
        # Color-coded output
        inv_color = "🔴" if invalid_count > 10 else ("🟡" if invalid_count > 5 else "🟢")
        len_color = "🔴" if avg_resp_len > 100 else ("🟡" if avg_resp_len > 50 else "🟢")
        
        print(f"\n[Eval @ step {state.global_step}] "
              f"Win: {wins}/{self.num_games} ({win_rate*100:.0f}%) | "
              f"Moves: {avg_moves:.1f} | "
              f"{inv_color} Invalid: {invalid_count} | "
              f"{len_color} AvgLen: {avg_resp_len:.0f}\n")

        if was_training:
            model.train()

eval_callback = MinesweeperEvalCallback(eval_every_steps=50, num_games=10)

print("=" * 70)
print("EVAL CALLBACK v2")
print("=" * 70)
print(f"  Games per eval:      {eval_callback.num_games}")
print(f"  Eval every:          {eval_callback.eval_every_steps} steps")
print(f"  max_new_tokens:      96 (increased from 64)")
print(f"  Configs:             {len(EVAL_CONFIGS)} boards including 0-mine cases")
print()
print("Now tracking:")
print("  - Invalid count (🟢≤5, 🟡≤10, 🔴>10)")
print("  - Avg response length in chars (🟢≤50, 🟡≤100, 🔴>100) - shorter is better!")
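
The early-stop rules in `MinesweeperEvalCallback` can be tested without loading a model, by replaying a scripted list of `(action, result)` pairs. The sketch below is a hypothetical standalone restatement; `replay_guards` is not part of the notebook. It makes three assumptions:

- `action is None` stands for a parse failure.
- Any result other than `"ok"`, `"mine"`, or `"win"` is an illegal move.
- The consecutive-invalid counter resets only after a legal move.

Under these rules, a mix of parse failures and illegal moves still trips the 5-strike limit.

```python
def replay_guards(steps, max_consecutive_invalid=5, max_repeats=3):
    """Replay (action, result) pairs through the eval loop's stop rules.

    Returns (stop_reason, moves, invalids). Running out of steps stands in
    for hitting max_iterations and is reported as "exhausted".
    """
    consecutive_invalid = 0
    repeat_count = 0
    seen = set()
    moves = 0
    invalids = 0

    for action, result in steps:
        if action is None:                       # parse failure
            invalids += 1
            consecutive_invalid += 1
            if consecutive_invalid >= max_consecutive_invalid:
                return "too_many_invalid", moves, invalids
            continue

        if action in seen:                       # re-issuing an earlier action
            repeat_count += 1
            if repeat_count >= max_repeats:      # stop before executing it
                return "repeating", moves, invalids
        else:
            repeat_count = 0
            seen.add(action)

        if result in ("mine", "win"):            # terminal move still counts as a move
            return result, moves + 1, invalids
        if result == "ok":
            moves += 1
            consecutive_invalid = 0              # reset only after a legal move
        else:                                    # illegal move (any other result string)
            invalids += 1
            consecutive_invalid += 1
            if consecutive_invalid >= max_consecutive_invalid:
                return "too_many_invalid", moves, invalids

    return "exhausted", moves, invalids
```

For example, alternating parse failures and out-of-bounds reveals stop the game after five strikes. If the counter were reset on every successfully parsed action instead, that sequence would never stop early.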

# 🚨 Emergency Merge — Run ANYTIME

These utility functions let you **merge the latest checkpoint into a standalone model** at any point:
- **During** training (in a separate cell)
- **After** training finishes or crashes
- **Before** the deadline: stop training, merge, and submit

Just call `merge_latest_checkpoint()` whenever you need a submittable model.
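
Merging is what makes the result a standalone model. For each adapted linear layer, PEFT's `merge_and_unload()` folds the low-rank update into the frozen weight, W' = W + (alpha / r) · B·A, and then removes the adapter modules. The toy sketch below checks that identity in plain Python on a 2×2 layer with rank-1 factors. The matrices and `lora_alpha = 2` are made-up illustration values, and the sketch assumes the standard `alpha / r` scaling (not rsLoRA).

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def merge_lora(W, A, B, lora_alpha, r):
    """Return W + (lora_alpha / r) * (B @ A); W is (out x in), B is (out x r), A is (r x in)."""
    scale = lora_alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


# Toy layer: out = in = 2, rank r = 1 (illustrative numbers only)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, -0.5]]          # r x in
B = [[2.0], [1.0]]         # out x r
x = [3.0, 1.0]
scale = 2 / 1              # lora_alpha / r

merged = merge_lora(W, A, B, lora_alpha=2, r=1)
merged_out = matvec(merged, x)
# Unmerged forward pass: W x + scale * B (A x)
unmerged_out = [wx + scale * bax for wx, bax in zip(matvec(W, x), matvec(B, matvec(A, x)))]
```

Both forward passes produce the same output, so the merged model behaves like base + adapter without needing PEFT at inference time.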

In [None]:
# ═══════════════════════════════════════════════════════════════════════
# 🚨 EMERGENCY MERGE - Run this ANYTIME to get merged model from latest checkpoint
# ═══════════════════════════════════════════════════════════════════════

import os
from glob import glob
import shutil

def merge_latest_checkpoint(force=False):
    """
    Merge the latest checkpoint into a standalone model.
    Can be run DURING training (in another cell) or AFTER training stops.
    
    Args:
        force: If True, merge even if already merged
    
    Returns:
        Path to merged model or None
    """
    print("=" * 70)
    print("🔄 EMERGENCY CHECKPOINT MERGE")
    print("=" * 70)
    
    # Find all checkpoints
    checkpoint_dirs = glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*"))
    
    if not checkpoint_dirs:
        print("❌ No checkpoints found!")
        print(f"   Searched in: {CHECKPOINT_DIR}")
        return None
    
    # Sort by step number
    def get_step(path):
        try:
            return int(os.path.basename(path).split("-")[1])
        except (IndexError, ValueError):
            return 0
    
    checkpoint_dirs.sort(key=get_step, reverse=True)
    latest_checkpoint = checkpoint_dirs[0]
    latest_step = get_step(latest_checkpoint)
    
    print(f"📂 Found {len(checkpoint_dirs)} checkpoint(s)")
    print(f"   Latest: {os.path.basename(latest_checkpoint)} (step {latest_step})")
    
    # Output directory for merged model
    merged_output = os.path.join(CHECKPOINT_DIR, f"merged_step_{latest_step}")
    
    # Check if already merged
    if os.path.exists(merged_output) and not force:
        print(f"✅ Already merged: {merged_output}")
        print("   Use force=True to re-merge")
        return merged_output
    
    print(f"\n🔧 Merging checkpoint to: {merged_output}")
    print("   This takes ~5-10 minutes...")
    
    try:
        # Load base model + LoRA adapter from the checkpoint directory
        print("   Step 1/3: Loading base model + LoRA adapter...")
        checkpoint_model, checkpoint_tokenizer = FastLanguageModel.from_pretrained(
            model_name=latest_checkpoint,
            max_seq_length=2048,
            dtype=torch.bfloat16,
            load_in_4bit=False,
        )
        
        # Merge LoRA weights into base model
        print("   Step 2/3: Merging LoRA weights...")
        checkpoint_model = FastLanguageModel.for_inference(checkpoint_model)
        merged_model = checkpoint_model.merge_and_unload()
        
        # Save merged model
        print("   Step 3/3: Saving merged model...")
        os.makedirs(merged_output, exist_ok=True)
        merged_model.save_pretrained(merged_output)
        checkpoint_tokenizer.save_pretrained(merged_output)
        
        print(f"\n✅ SUCCESS! Merged model saved to:")
        print(f"   {merged_output}")
        print(f"\n📊 Model stats:")
        print(f"   - Trained steps: {latest_step}")
        print(f"   - Size: ~{sum(os.path.getsize(os.path.join(merged_output, f)) for f in os.listdir(merged_output)) / (1024**3):.1f} GB")
        
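        # Also point a "merged_latest" symlink at this merge for convenience
        latest_link = os.path.join(CHECKPOINT_DIR, "merged_latest")
        try:
            if os.path.islink(latest_link):
                os.unlink(latest_link)          # shutil.rmtree refuses to delete symlinks
            elif os.path.isdir(latest_link):
                shutil.rmtree(latest_link)      # A real directory left by an older copy
            # Relative target: symlink targets resolve from the link's own directory,
            # so "merged_step_N" (not "minesweeper_checkpoints/merged_step_N")
            os.symlink(os.path.basename(merged_output), latest_link, target_is_directory=True)
            print(f"   - Also available at: {latest_link}")
        except OSError as link_err:
            # Symlinks might not work on all systems
            print(f"   - Could not create {latest_link}: {link_err}")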
        
        return merged_output
        
    except Exception as e:
        print(f"\n❌ ERROR during merge: {e}")
        import traceback
        traceback.print_exc()
        return None


def list_checkpoints():
    """Show all available checkpoints"""
    checkpoint_dirs = glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*"))
    
    if not checkpoint_dirs:
        print("No checkpoints found yet. Training might not have started.")
        return
    
    checkpoint_dirs.sort(key=lambda p: int(os.path.basename(p).split("-")[1]))
    
    print("=" * 70)
    print("📁 AVAILABLE CHECKPOINTS")
    print("=" * 70)
    for cp in checkpoint_dirs:
        step = os.path.basename(cp).split("-")[1]
        size = sum(os.path.getsize(os.path.join(cp, f)) for f in os.listdir(cp)) / (1024**2)
        print(f"  checkpoint-{step:>4s}  ({size:>6.1f} MB)")
    print("=" * 70)


# Quick status check
list_checkpoints()
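
Both `find_latest_checkpoint` and `merge_latest_checkpoint` parse the step number from `checkpoint-<step>` and sort numerically. This matters because a plain string sort would rank `checkpoint-900` above `checkpoint-1000`. Symlink targets need similar care. A relative target is resolved from the directory that contains the link, not from the working directory. So a link inside `minesweeper_checkpoints/` whose target is `minesweeper_checkpoints/merged_step_N` would dangle. The standalone sketch below covers both points. The helpers `checkpoint_step`, `latest_checkpoint_name`, and `point_latest_link` are hypothetical and are not part of the notebook.

```python
import os
import re
import tempfile

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def checkpoint_step(name):
    """Return the integer step for 'checkpoint-<step>', or None for anything else."""
    m = _CKPT_RE.match(os.path.basename(name))
    return int(m.group(1)) if m else None


def latest_checkpoint_name(names):
    """Pick the highest-step checkpoint, ignoring tracker files and merged_* dirs."""
    steps = [(checkpoint_step(n), n) for n in names]
    steps = [(s, n) for s, n in steps if s is not None]
    return max(steps)[1] if steps else None


def point_latest_link(link_path, target_dir):
    """(Re)create link_path -> target_dir with a target relative to the link's folder.

    Only replaces an existing symlink; a real directory at link_path is left alone
    (os.symlink would then raise FileExistsError).
    """
    if os.path.islink(link_path):
        os.unlink(link_path)
    link_dir = os.path.dirname(os.path.abspath(link_path))
    rel_target = os.path.relpath(target_dir, start=link_dir)
    os.symlink(rel_target, link_path, target_is_directory=True)
    return rel_target


# Usage demo in a throwaway directory
with tempfile.TemporaryDirectory() as root:
    ckpt_dir = os.path.join(root, "minesweeper_checkpoints")
    os.makedirs(os.path.join(ckpt_dir, "merged_step_100"))
    link = os.path.join(ckpt_dir, "merged_latest")
    print(point_latest_link(link, os.path.join(ckpt_dir, "merged_step_100")))  # merged_step_100
    print(os.path.isdir(link))  # True: the link resolves inside ckpt_dir
```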

In [None]:
# ═══════════════════════════════════════════════════════════════════════
# 🧪 QUICK TEST - Verify merged model works
# ═══════════════════════════════════════════════════════════════════════

def quick_test_merged_model(model_path):
    """Test merged model on 3 boards to verify it produces valid actions."""
    print("=" * 70)
    print("🧪 TESTING MERGED MODEL")
    print("=" * 70)
    
    # Load merged model
    print(f"Loading from: {model_path}")
    test_model, test_tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=2048,
        dtype=torch.bfloat16,
        load_in_4bit=False,
    )
    test_model = FastLanguageModel.for_inference(test_model)
    
    test_boards = [
        (1, 1, 0, "1×1 trivial"),
        (5, 5, 0, "5×5 zero mines"),
        (6, 6, 5, "6×6 standard"),
    ]
    
    wins = 0
    for rows, cols, mines, label in test_boards:
        game = MinesweeperGame(rows, cols, mines, seed=42)
        prompt = format_state_for_llm(game)
        
        text = test_tokenizer.apply_chat_template(
            [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        
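        inputs = test_tokenizer(text, return_tensors="pt").to(test_model.device)
        with torch.no_grad():
            output = test_model.generate(
                **inputs,
                temperature=0.7,
                max_new_tokens=96,
                do_sample=True,
                top_p=0.9,
            )
        
        # Decode only the generated tokens; splitting the full text on "assistant"
        # breaks if the prompt or the response happens to contain that word
        gen_tokens = output[0][inputs["input_ids"].shape[1]:]
        response_only = test_tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()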
        
        action = parse_llm_action(response_only)
        
        print(f"\n{label}:")
        print(f"  Response: {response_only[:100]}")
        print(f"  Parsed: {action}")
        
        if action:
            result = game.do_action(action)
            print(f"  Result: {result}")
            if result in ["win", "ok"]:
                wins += 1
                print("  ✅ VALID")
            else:
                print(f"  ❌ INVALID ({result})")
        else:
            print("  ❌ PARSE FAILED")
    
    print("\n" + "=" * 70)
    print(f"QUICK TEST RESULT: {wins}/3 valid actions")
    print("=" * 70)
    
    # Cleanup
    del test_model
    del test_tokenizer
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    
    return wins >= 2  # At least 2/3 should work


# Test the latest merged model
merged_path = os.path.join(CHECKPOINT_DIR, "merged_latest")
if os.path.exists(merged_path):
    success = quick_test_merged_model(merged_path)
    if success:
        print("✅ Model is ready for submission!")
    else:
        print("⚠️ Model might need more training")
else:
    print("No merged model found. Run merge_latest_checkpoint() first.")

# Train the Model

Start GRPO training with reward functions:
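GRPO does not use these rewards directly. As a rough sketch of what happens to them: in recent TRL versions, `GRPOTrainer` adds up the reward functions for each completion (scaled by `reward_weights` in `GRPOConfig` if given, otherwise weighted equally) and then compares each completion's total with the other completions sampled for the same prompt. The plain-Python function below is a hypothetical helper, not TRL's code; it illustrates that group-relative normalization assuming a sample standard deviation and a small epsilon. Exact details, such as whether the std scaling is applied at all, differ between TRL versions.

```python
from statistics import mean, stdev


def group_relative_advantages(rewards_per_func, weights=None, eps=1e-4):
    """Combine per-function rewards and normalize them within one GRPO group.

    rewards_per_func: one list per completion, holding that completion's score
    from every reward function (e.g. [format, gameplay, strategy]).
    weights: optional per-function weights; defaults to 1.0 for every function.
    """
    num_funcs = len(rewards_per_func[0])
    weights = weights or [1.0] * num_funcs
    # Weighted sum across reward functions -> one scalar reward per completion
    totals = [sum(w * r for w, r in zip(weights, row)) for row in rewards_per_func]
    # Group-relative baseline: each completion is compared with its siblings
    mu = mean(totals)
    sigma = stdev(totals) if len(totals) > 1 else 0.0
    return [(t - mu) / (sigma + eps) for t in totals]


# Four sampled completions for the same board prompt: [format, gameplay, strategy]
# (illustrative values based on the reward comments in the trainer cell)
group = [
    [15, 15, 5],    # clean JSON, valid move, logical
    [15, 15, 0],    # clean JSON, valid move
    [15, -25, 0],   # clean JSON, invalid move
    [-30, -25, 0],  # unparseable output
]
print([round(a, 2) for a in group_relative_advantages(group)])
```

Completions that beat their group's average get a positive advantage and are reinforced. One consequence worth knowing: if every completion in a group receives the same total reward, all advantages are zero and that prompt contributes no learning signal.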

In [None]:
from datetime import datetime
from transformers import TrainerCallback

# ══════════════════════════════════════════════════════════════════════
# CHECKPOINT TRACKING CALLBACK
# ══════════════════════════════════════════════════════════════════════

class CheckpointTrackerCallback(TrainerCallback):
    """Track checkpoints and update the tracker file after each save."""
    
    def on_save(self, args, state, control, **kwargs):
        """Called after a checkpoint is saved."""
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        
        tracker = get_checkpoint_tracker()
        tracker["latest_checkpoint"] = checkpoint_path
        tracker["latest_step"] = state.global_step
        tracker["total_steps_trained"] = state.global_step
        
        if checkpoint_path not in tracker["checkpoints"]:
            tracker["checkpoints"].append(checkpoint_path)
        
        save_checkpoint_tracker(tracker)
        
        print(f"\n💾 Checkpoint saved: {checkpoint_path}")
        print(f"   Total steps trained: {state.global_step}")


# Create trainer with checkpoint tracking
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        valid_json_reward,   # FORMAT (50%): +15 pure JSON, -30 invalid
        gameplay_scores,     # GAMEPLAY (40%): +15 valid, -25 invalid
        strategic_reward,    # STRATEGY (10%): +5 logical
    ],
    args = training_args,
    train_dataset = dataset,
    callbacks = [eval_callback, CheckpointTrackerCallback()],
)

print("=" * 70)
print("GRPO TRAINER WITH CHECKPOINT RESUMPTION")
print("=" * 70)
print()
print("Reward Functions (with weights):")
print("  [1] valid_json_reward  (50%) — Format compliance")
print("  [2] gameplay_scores    (40%) — Game validity")
print("  [3] strategic_reward   (10%) — Smart play")
print()

# Record training session
tracker = get_checkpoint_tracker()
session_info = {
    "start_time": datetime.now().isoformat(),
    "resume_from": RESUME_FROM_CHECKPOINT,
    "target_steps": TARGET_TOTAL_STEPS,
}
tracker["training_sessions"].append(session_info)
save_checkpoint_tracker(tracker)

# Start or resume training
if RESUME_FROM_CHECKPOINT:
    print(f"🔄 RESUMING from checkpoint: {RESUME_FROM_CHECKPOINT}")
    print(f"   Starting at step: {resume_step}")
    print(f"   Training until step: {TARGET_TOTAL_STEPS}")
    print()
    trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)
else:
    print("🚀 STARTING fresh training")
    print(f"   Training until step: {TARGET_TOTAL_STEPS}")
    print()
    trainer.train()

# Update tracker after training completes
tracker = get_checkpoint_tracker()
tracker["training_sessions"][-1]["end_time"] = datetime.now().isoformat()
tracker["training_sessions"][-1]["final_step"] = trainer.state.global_step
save_checkpoint_tracker(tracker)

print()
print("=" * 70)
print("TRAINING SESSION COMPLETE")
print("=" * 70)
print(f"  Final step: {trainer.state.global_step}")
print(f"  Checkpoints saved in: {CHECKPOINT_DIR}/")
print()
print("To continue training:")
print("  1. Increase TARGET_TOTAL_STEPS in the config cell")
print("  2. Re-run config cell and this training cell")
print("  3. Training will automatically resume from latest checkpoint")
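Resuming relies on `RESUME_FROM_CHECKPOINT`, which is set in the config cell (not shown here). One way to pick the newest checkpoint automatically is to compare step numbers as integers, because a plain string sort places `checkpoint-1000` before `checkpoint-200`. A minimal sketch with a hypothetical helper name; `transformers.trainer_utils.get_last_checkpoint` offers similar behavior.

```python
import os
import re


def find_latest_checkpoint(output_dir):
    """Return the checkpoint-<step> directory with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        path = os.path.join(output_dir, name)
        # Compare steps as integers so checkpoint-1000 beats checkpoint-200
        if match and os.path.isdir(path) and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```

For example, `RESUME_FROM_CHECKPOINT = find_latest_checkpoint(CHECKPOINT_DIR)` yields `None` when no checkpoint exists yet, which the training cell treats as a fresh start.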

# Test Trained Model

Evaluate the finetuned model:

In [None]:
# Test on variable board sizes across the full competition range
FastLanguageModel.for_inference(model)

test_configs = [
    (1, 1, 0, 200, "1x1 Trivial"),
    (3, 3, 1, 201, "Tiny"),
    (5, 5, 3, 202, "Small/Easy"),
    (6, 6, 5, 99,  "Standard"),
    (8, 8, 10, 101, "Medium"),
    (10, 10, 20, 203, "Large 20%"),
    (15, 15, 45, 204, "XL 20%"),
    (1, 20, 4, 205, "Row Board"),
    (20, 1, 4, 206, "Column Board"),
    (5, 5, 0, 207, "Zero Mines"),
]

for rows, cols, mines, seed, label in test_configs:
    print(f"\n{'='*50}")
    print(f"=== {label} ({rows}x{cols}, {mines} mines) ===")
    print(f"{'='*50}")

    test_game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)

    # Handle already-won games (0-mine boards that auto-cascade)
    if test_game.state() != "ongoing":
        print(f"  Game auto-resolved: {test_game.state()}")
        continue

    test_prompt = format_state_for_llm(test_game)

    test_text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": test_prompt},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(test_text, return_tensors="pt", truncation=True,
                       max_length=max_prompt_length + 100)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    output = model.generate(
        **inputs,
        temperature=0.7,  # LAMER paper: 0.7 for eval
        max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.2,
    )

    # Decode only generated tokens
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    print(f"Response: {response_text.strip()}")

    action = parse_llm_action(response_text)
    print(f"Parsed action: {action}")

    if action:
        result = test_game.do_action(action)
        print(f"Result: {result} | Game state: {test_game.state()}")
    else:
        print("⚠️ Failed to parse a valid action")

# Exhaustive Evaluation: Full Competition Range

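Play complete games across 42 board configurations (1×1 to 50×50, 0-20% mines). The evaluation cell is commented out; uncomment it to run the full suite.
**No move limit.** A game normally ends in SUCCESS (all safe cells revealed) or FAILURE (a mine revealed); games stopped by the safety cap (iteration limit, repeated invalid or repeated actions) are reported separately as *stuck*.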
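The suite is meant to stay inside the competition range stated above: each side between 1 and 50, and at most 20% of cells mined. A small check like this hypothetical helper can catch a config that drifts outside that range; treating exactly 20% as allowed is an assumption, chosen to match the "5x5 max density" entry.

```python
def within_competition_spec(rows, cols, mines, max_side=50, max_mine_percent=20):
    """Check a (rows, cols, mines) config against the stated competition range."""
    if not (1 <= rows <= max_side and 1 <= cols <= max_side):
        return False
    # Integer arithmetic: mines <= 20% of rows*cols  <=>  100*mines <= 20*rows*cols
    return 0 <= mines and 100 * mines <= max_mine_percent * rows * cols


print(within_competition_spec(5, 5, 5))   # exactly 20%
print(within_competition_spec(6, 6, 8))   # 8/36 is above 20%
```

With the suite defined, `assert all(within_competition_spec(r, c, m) for r, c, m, _, _ in EVAL_SUITE)` verifies every entry at once.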

In [None]:
# def play_full_game(model, tokenizer, rows=6, cols=6, num_mines=5, seed=None,
#                    verbose=False):
#     """Play a complete Minesweeper game, tracking detailed metrics.

#     Supports any board size from 1x1 to 50x50.
#     NO move limit — game ends ONLY on:
#       - SUCCESS: all non-mine cells are revealed
#       - FAILURE: any mine cell is revealed
#     A safety iteration cap prevents infinite loops from repeated invalid actions.
#     """
#     game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

#     # Edge case: game already won (e.g., 0-mine board)
#     if game.state() != "ongoing":
#         return {
#             "game": game,
#             "moves": 0,
#             "logical_moves": 0,
#             "flags_correct": 0,
#             "flags_wrong": 0,
#             "total_invalids": 0,
#             "result": game.state(),
#             "progress": game.progress(),
#             "config": f"{rows}x{cols}m{num_mines}",
#         }

#     moves = 0
#     consecutive_invalids = 0
#     total_invalids = 0
#     logical_moves = 0
#     flags_correct = 0
#     flags_wrong = 0
#     seen_actions = set()   # Fix #4: detect repeated actions (stuck loop)
#     repeat_count = 0       # Fix #4: consecutive repeat counter
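#     # Safety cap against infinite loops, not intended as a move limit. Capped at 500
#     # to prevent runaway 50×50 evals (was rows*cols*3+50 = 7550 for 50×50). Note it can
#     # still end large games early: a 50×50 board at 20% has 2000 safe cells, which 500
#     # iterations cannot clear unless reveals cascade widely; such games end as "stuck".
#     max_iterations = min(500, rows * cols + 100)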

#     iteration = 0
#     while game.state() == "ongoing" and iteration < max_iterations:
#         iteration += 1
#         prompt = format_state_for_llm(game)
#         text = tokenizer.apply_chat_template(
#             [
#                 {"role": "system", "content": SYSTEM_PROMPT},
#                 {"role": "user", "content": prompt},
#             ],
#             tokenize=False,
#             add_generation_prompt=True,
#         )

#         inputs = tokenizer(text, return_tensors="pt", truncation=True,
#                            max_length=max_prompt_length + 100)
#         inputs = {k: v.to("cuda") for k, v in inputs.items()}

#         with torch.no_grad():
#             output = model.generate(
#                 **inputs,
#                 temperature=0.7,  # LAMER paper: 0.7 for eval
#                 max_new_tokens=128,  # HACKATHON CONSTRAINT: JSON-only
#                 do_sample=True,
#                 top_p=0.9,
#                 repetition_penalty=1.2,
#             )

#         # Decode ONLY generated tokens
#         gen_tokens = output[0][inputs["input_ids"].shape[1]:]
#         response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
#         action = parse_llm_action(response)

#         if action is None:
#             consecutive_invalids += 1
#             total_invalids += 1
#             if consecutive_invalids >= 5:
#                 break  # Agent is stuck — abort
#             continue

#         # Fix #4: Check for repeated actions (stuck detection)
#         action_key = (action['type'], action['row'], action['col'])
#         if action_key in seen_actions:
#             repeat_count += 1
#             if repeat_count >= 3:  # Same move 3 times = stuck
#                 break
#         else:
#             repeat_count = 0
#             seen_actions.add(action_key)

#         # Track logical moves (compute BEFORE applying action)
#         safe_set, mine_set = _compute_safe_and_mine_cells(game)
#         r, c = action["row"], action["col"]
#         if action["type"] == "reveal" and (r, c) in safe_set:
#             logical_moves += 1
#         elif action["type"] == "flag" and (r, c) in mine_set:
#             logical_moves += 1

#         # Track flag accuracy
#         if action["type"] == "flag":
#             if 0 <= r < game.rows and 0 <= c < game.cols:
#                 if game._board[r][c] == -1:
#                     flags_correct += 1
#                 else:
#                     flags_wrong += 1

#         if verbose:
#             print(f"  Move {moves}: {action}")

#         result = game.do_action(action)
#         if result in ("mine", "win", "ok"):
#             moves += 1
#             consecutive_invalids = 0  # a valid move breaks the invalid streak
#         elif result in ("out_of_bounds", "already_revealed", "flagged_cell",
#                          "invalid_flag", "invalid_format"):
#             # Invalid moves don't count but game stays ongoing
#             total_invalids += 1
#             consecutive_invalids += 1
#             if consecutive_invalids >= 5:
#                 break

#         if result in ("mine", "win"):
#             break

#     return {
#         "game": game,
#         "moves": moves,
#         "logical_moves": logical_moves,
#         "flags_correct": flags_correct,
#         "flags_wrong": flags_wrong,
#         "total_invalids": total_invalids,
#         "result": game.state(),
#         "progress": game.progress(),
#         "config": f"{rows}x{cols}m{num_mines}",
#     }


# # ──────────────────────────────────────────────────────────────────────
# # EXHAUSTIVE Multi-Size Evaluation — Competition Spec
# # n, m ∈ [1, 50], mines 0-20% of total cells
# # Outcomes: SUCCESS or FAILURE; games stopped by the safety cap count as stuck
# # ──────────────────────────────────────────────────────────────────────

# EVAL_SUITE = [
#     # (rows, cols, mines, num_games, label)

#     # === Trivial / Edge Cases ===
#     (1, 1, 0,   5,  "1x1 trivial"),
#     (1, 2, 0,   5,  "1x2 trivial"),
#     (2, 1, 0,   5,  "2x1 trivial"),
#     (2, 2, 0,   5,  "2x2 no mines"),
#     (3, 3, 0,   5,  "3x3 no mines"),
#     (5, 5, 0,   5,  "5x5 no mines"),

#     # === Tiny Boards ===
#     (3, 3, 1,  10,  "3x3 1 mine"),
#     (4, 4, 3,  10,  "4x4 3 mines"),

#     # === Small Boards ===
#     (5, 5, 3,  15,  "5x5 easy"),
#     (5, 5, 5,  10,  "5x5 max density"),

#     # === Standard ===
#     (6, 6, 5,  20,  "6x6 standard"),
#     (6, 6, 7,  10,  "6x6 hard"),

#     # === Medium ===
#     (7, 7, 7,  10,  "7x7 medium"),
#     (8, 8, 10, 10,  "8x8 medium"),
#     (8, 8, 12, 10,  "8x8 max density"),

#     # === Large ===
#     (10, 10, 10,  5, "10x10 10%"),
#     (10, 10, 20, 10, "10x10 20%"),
#     (15, 15, 45,  5, "15x15 20%"),

#     # === XL ===
#     (20, 20, 80,  3, "20x20 20%"),
#     (25, 25, 125, 3, "25x25 20%"),
#     (30, 30, 180, 2, "30x30 20%"),

#     # === XXL ===
#     (40, 40, 320, 2, "40x40 20%"),
#     (50, 50, 500, 2, "50x50 20%"),

#     # === Rectangular ===
#     (1, 10, 2,   5, "1x10 row"),
#     (10, 1, 2,   5, "10x1 column"),
#     (1, 50, 10,  3, "1x50 row"),
#     (50, 1, 10,  3, "50x1 column"),
#     (6, 8, 6,    5, "6x8 rect"),
#     (8, 6, 6,    5, "8x6 rect"),
#     (5, 15, 15,  3, "5x15 wide"),
#     (15, 5, 15,  3, "15x5 tall"),
#     (3, 50, 30,  2, "3x50 extreme wide"),
#     (50, 3, 30,  2, "50x3 extreme tall"),

#     # === Sparse (low density) ===
#     (10, 10, 1,  5, "10x10 sparse"),
#     (20, 20, 4,  3, "20x20 sparse"),
#     (50, 50, 10, 2, "50x50 sparse"),

#     # === Progressive Difficulty (Fix #10: LAMER generalization test) ===
#     # Same size, increasing mines — tests density scaling
#     (10, 10, 5,  5, "10x10 5% progressive"),
#     (10, 10, 15, 5, "10x10 15% progressive"),

#     # Increasing size, same ~10% density — tests size scaling
#     (15, 15, 23, 3, "15x15 10% density"),
#     (20, 20, 40, 3, "20x20 10% density"),

#     # Generalization to harder unseen configs
#     (25, 25, 62, 2, "25x25 10% generalization"),
#     (30, 30, 90, 2, "30x30 10% generalization"),
# ]

# FastLanguageModel.for_inference(model)
# total_games = sum(n for _,_,_,n,_ in EVAL_SUITE)
# print(f"{'='*80}")
# print(f"  EXHAUSTIVE EVALUATION — {total_games} games across {len(EVAL_SUITE)} configs")
# print(f"  Competition spec: n,m ∈ [1,50], mines 0-20%, no move limit")
# print(f"  Includes progressive difficulty test (LAMER generalization)")
# print(f"  Outcomes: SUCCESS, FAILURE, or STUCK (safety cap hit)")
# print(f"{'='*80}\n")

# all_results = []
# per_config_stats = {}

# for rows, cols, mines, num_games, label in EVAL_SUITE:
#     config_key = f"{rows}x{cols}m{mines}"
#     wins = 0
#     fails = 0
#     config_results = []

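#     # Built-in hash() of a str is randomized per process; crc32 gives reproducible seeds
#     import zlib
#     config_seed = zlib.crc32(config_key.encode()) % 10000
#     for i in range(num_games):
#         info = play_full_game(model, tokenizer, rows=rows, cols=cols,
#                               num_mines=mines, seed=5000 + i + config_seed)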
#         config_results.append(info)
#         all_results.append(info)

#         if info["result"] == "success":
#             wins += 1
#         elif info["result"] == "failed":
#             fails += 1

#     # Per-config summary
#     avg_moves = np.mean([r["moves"] for r in config_results])
#     avg_logical = np.mean([r["logical_moves"] for r in config_results])
#     avg_progress = np.mean([r["progress"] for r in config_results])
#     avg_invalids = np.mean([r["total_invalids"] for r in config_results])
#     stuck = sum(1 for r in config_results if r["result"] == "ongoing")
#     wr = wins / num_games * 100

#     per_config_stats[config_key] = {
#         "label": label, "wins": wins, "fails": fails, "stuck": stuck,
#         "total": num_games, "win_rate": wr,
#         "avg_moves": avg_moves, "avg_progress": avg_progress,
#     }

#     status_icon = "✅" if wr >= 50 else "⚠️" if wr >= 20 else "❌"
#     print(f"  {status_icon} {label:22s} ({config_key:12s}): "
#           f"{wins:2d}/{num_games:2d} wins ({wr:5.1f}%) | "
#           f"fails={fails} stuck={stuck} | "
#           f"moves={avg_moves:5.1f} | logical={avg_logical:4.1f} | "
#           f"progress={avg_progress:.0%} | invalids={avg_invalids:.1f}")

# # ── Overall Summary ──
# total_games_actual = len(all_results)
# total_wins = sum(1 for r in all_results if r["result"] == "success")
# total_fails = sum(1 for r in all_results if r["result"] == "failed")
# total_stuck = sum(1 for r in all_results if r["result"] == "ongoing")
# fc = sum(r["flags_correct"] for r in all_results)
# fw = sum(r["flags_wrong"] for r in all_results)

# print(f"\n{'='*80}")
# print(f"  OVERALL: {total_wins}/{total_games_actual} wins ({total_wins/total_games_actual*100:.1f}%)")
# print(f"{'='*80}")
# print(f"  Wins (success):    {total_wins:4d} ({total_wins/total_games_actual*100:.1f}%)")
# print(f"  Losses (failure):  {total_fails:4d} ({total_fails/total_games_actual*100:.1f}%)")
# print(f"  Stuck (loop cap):  {total_stuck:4d} ({total_stuck/total_games_actual*100:.1f}%)")
# print(f"  Avg moves:         {np.mean([r['moves'] for r in all_results]):.1f}")
# print(f"  Avg progress:      {np.mean([r['progress'] for r in all_results]):.0%}")
# print(f"  Avg logical moves: {np.mean([r['logical_moves'] for r in all_results]):.1f}")
# if fc + fw > 0:
#     print(f"  Flag accuracy:     {fc}/{fc+fw} ({fc/(fc+fw)*100:.1f}%)")
# else:
#     print(f"  Flags: none placed")

# # ── Category Breakdown (categories overlap, e.g. 10x10 sparse also counts as Large) ──
# categories = {
#     "Trivial (0 mines)": [k for k, v in per_config_stats.items() if "m0" in k],
#     "Tiny (≤4x4)":       [k for k, v in per_config_stats.items()
#                            if v["label"].startswith(("3x3", "4x4")) and "m0" not in k],
#     "Small (5x5)":       [k for k, v in per_config_stats.items() if k.startswith("5x5")],
#     "Standard (6x6)":    [k for k, v in per_config_stats.items() if k.startswith("6x6")],
#     "Medium (7-8)":      [k for k, v in per_config_stats.items()
#                            if k.startswith(("7x7", "8x8"))],
#     "Large (10-15)":     [k for k, v in per_config_stats.items()
#                            if k.startswith(("10x10", "15x15"))],
#     "XL (20-50)":        [k for k, v in per_config_stats.items()
#                            if k.startswith(("20x20", "25x25", "30x30", "40x40", "50x50"))],
#     "Rectangular":       [k for k, v in per_config_stats.items()
#                            if "rect" in v["label"] or "row" in v["label"]
#                            or "column" in v["label"] or "wide" in v["label"]
#                            or "tall" in v["label"]],
#     "Sparse":            [k for k, v in per_config_stats.items() if "sparse" in v["label"]],
#     "Progressive":       [k for k, v in per_config_stats.items()
#                            if "progressive" in v["label"] or "generalization" in v["label"]
#                            or "density" in v["label"]],
# }

# print(f"\n{'='*80}")
# print(f"  CATEGORY BREAKDOWN")
# print(f"{'='*80}")
# for cat_name, keys in categories.items():
#     if not keys:
#         continue
#     cat_wins = sum(per_config_stats[k]["wins"] for k in keys)
#     cat_total = sum(per_config_stats[k]["total"] for k in keys)
#     if cat_total > 0:
#         cat_wr = cat_wins / cat_total * 100
#         icon = "✅" if cat_wr >= 50 else "⚠️" if cat_wr >= 20 else "❌"
#         print(f"  {icon} {cat_name:22s}: {cat_wins:3d}/{cat_total:3d} ({cat_wr:5.1f}%)")
# print(f"{'='*80}")
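`play_full_game` counts a move as logical when it matches `_compute_safe_and_mine_cells`, which is defined earlier in the notebook and not shown here. To illustrate the kind of deduction involved, here is a minimal sketch with a hypothetical name, using the `?`/`F`/digit visible-board encoding of the inference script below. It applies the two single-cell rules: if a number already has that many flagged neighbors, its other unknown neighbors are safe; if its remaining (unflagged) count equals its number of unknown neighbors, they are all mines. It assumes every flag is correct and ignores multi-cell reasoning such as subset constraints.

```python
def deduce_safe_and_mines(board):
    """Single-cell Minesweeper deductions on a visible board.

    board: list of rows; each cell is "?" (unrevealed), "F" (flagged) or a digit string.
    Returns (safe, mines): sets of (row, col) unrevealed cells that one numbered
    cell at a time proves safe or proves to be mines (flags assumed correct).
    """
    rows, cols = len(board), len(board[0])
    safe, mines = set(), set()
    for r in range(rows):
        for c in range(cols):
            if not board[r][c].isdigit():
                continue
            number = int(board[r][c])
            unknown, flagged = [], 0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    nr, nc = r + dr, c + dc
                    if (dr or dc) and 0 <= nr < rows and 0 <= nc < cols:
                        if board[nr][nc] == "?":
                            unknown.append((nr, nc))
                        elif board[nr][nc] == "F":
                            flagged += 1
            if not unknown:
                continue
            if number == flagged:
                # All mines around this number are already flagged
                safe.update(unknown)
            elif number - flagged == len(unknown):
                # Every remaining unknown neighbor must be a mine
                mines.update(unknown)
    return safe, mines


board = [
    ["0", "1", "?"],
    ["0", "1", "?"],
    ["0", "1", "F"],
]
print(deduce_safe_and_mines(board))
```

Here the `1` cells in the middle column already touch the flag at the bottom right, so both `?` cells above it are provably safe.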

# Checkpoint Status & Final Model Save

Check training progress and save the final model for hackathon submission.

In [None]:
# ══════════════════════════════════════════════════════════════════════
# SAVE FINAL MODEL FOR HACKATHON SUBMISSION
# ══════════════════════════════════════════════════════════════════════
# This cell saves the LoRA adapters currently loaded in `model` (the state at
# the end of the last training run in this kernel) and merges them into a
# 16-bit model ready for submission.
# ══════════════════════════════════════════════════════════════════════

import os, shutil, gc
from datetime import datetime
from glob import glob

# Get training info from checkpoint tracker
tracker = get_checkpoint_tracker()
total_steps = tracker.get("total_steps_trained", 0)
latest_ckpt = tracker.get("latest_checkpoint", None)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

print("=" * 70)
print("SAVING FINAL MODEL FOR HACKATHON SUBMISSION")
print("=" * 70)
print(f"  Total steps trained: {total_steps}")
print(f"  Latest checkpoint:   {latest_ckpt}")
print(f"  Timestamp:           {timestamp}")
print()

# Directory names with step count for clarity
lora_dir = f"my_minesweeper_model_step{total_steps}_{timestamp}"
merged_dir = f"my_minesweeper_model_merged_step{total_steps}_{timestamp}"
merged_dir_latest = "my_minesweeper_model_merged_latest"
lora_dir_latest = "my_minesweeper_model_latest"

# Save LoRA adapters
model.save_pretrained(lora_dir)
tokenizer.save_pretrained(lora_dir)
print(f"✅ LoRA adapters saved to: {lora_dir}/")

# Also save to "latest" directories
model.save_pretrained(lora_dir_latest)
tokenizer.save_pretrained(lora_dir_latest)
print(f"✅ LoRA adapters also saved to: {lora_dir_latest}/")

# ──────────────────────────────────────────────────────────────────────
# Save merged model in 16bit
# ──────────────────────────────────────────────────────────────────────
os.makedirs(merged_dir, exist_ok=True)

try:
    # Try Unsloth's native method first
    model.save_pretrained_merged(
        merged_dir,
        tokenizer,
        save_method="merged_16bit",
    )
    print(f"✅ Merged 16-bit model saved via Unsloth to: {merged_dir}/")
except Exception as e:  # also catches the UnboundLocalError Unsloth can raise here
    print(f"⚠️ Unsloth merge failed ({type(e).__name__}: {e})")
    print("   Falling back to manual LoRA merge...")

    try:
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        print("   Loading base model for merge...")
        base_model = AutoModelForCausalLM.from_pretrained(
            "/workspace/workspace/Qwen2.5-14B-Instruct",
            torch_dtype=torch.bfloat16,  # match the bfloat16 dtype used for training
            device_map="auto",
        )

        print("   Applying LoRA adapters...")
        merged_model = PeftModel.from_pretrained(base_model, lora_dir)
        merged_model = merged_model.merge_and_unload()

        print(f"   Saving merged model to: {merged_dir}/")
        merged_model.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

        del base_model, merged_model
        gc.collect()
        torch.cuda.empty_cache()

        print(f"✅ Merged 16-bit model saved to: {merged_dir}/")

    except Exception as e2:
        print(f"❌ BOTH merge methods failed!")
        print(f"   Unsloth error: {e}")
        print(f"   Manual merge error: {e2}")
        print(f"   LoRA adapters are still saved at: {lora_dir}/")

# Copy to "latest" directory
if os.path.exists(merged_dir) and os.listdir(merged_dir):
    if os.path.exists(merged_dir_latest):
        shutil.rmtree(merged_dir_latest)
    shutil.copytree(merged_dir, merged_dir_latest)
    print(f"✅ Also copied to: {merged_dir_latest}/")

# Update checkpoint tracker with final model info
tracker["final_model"] = {
    "lora_dir": lora_dir,
    "merged_dir": merged_dir,
    "total_steps": total_steps,
    "timestamp": timestamp,
}
save_checkpoint_tracker(tracker)

# Verify saved files
if os.path.exists(merged_dir):
    saved_files = os.listdir(merged_dir)
    safetensors = [f for f in saved_files if f.endswith(".safetensors")]
    print(f"\n📁 Saved files in {merged_dir}:")
    print(f"   Total files: {len(saved_files)}")
    print(f"   Safetensors shards: {len(safetensors)}")
    print(f"   Config: {'✅' if 'config.json' in saved_files else '❌'}")
    print(f"   Tokenizer: {'✅' if 'tokenizer.json' in saved_files else '❌'}")

# List all saved checkpoints
print(f"\n{'='*70}")
print("ALL SAVED CHECKPOINTS:")
print(f"{'='*70}")
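checkpoints = sorted(
    (p for p in glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")) if p.rsplit("-", 1)[-1].isdigit()),
    key=lambda p: int(p.rsplit("-", 1)[-1]),  # numeric order: 200 before 1000
)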
for ckpt in checkpoints:
    step = os.path.basename(ckpt).split("-")[1]
    print(f"  📁 {ckpt} (step {step})")

print(f"\n{'='*70}")
print("HACKATHON SUBMISSION READY")
print(f"{'='*70}")
print(f"  🏆 SUBMIT THIS: {merged_dir_latest}/")
print(f"     (or the timestamped version: {merged_dir}/)")
print()
print(f"  Total training: {total_steps} steps")
print(f"  All checkpoints preserved in: {CHECKPOINT_DIR}/")
print()
print("To train more before submission:")
print("  1. Increase TARGET_TOTAL_STEPS in config cell")
print("  2. Re-run config and training cells")
print("  3. Re-run this save cell")
print("  4. Submit the new merged_latest model")
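The verification above only counts `.safetensors` files. A 14B model saved in 16-bit is normally split into several shards listed in `model.safetensors.index.json`, whose `weight_map` maps each parameter name to its shard file. A sketch of a stricter check (hypothetical helper) that reports referenced shards missing from disk, for example after an interrupted copy:

```python
import json
import os


def missing_shards(model_dir, index_name="model.safetensors.index.json"):
    """List shard files referenced by a sharded checkpoint's index but absent on disk.

    Returns [] for a complete checkpoint. Without an index file the checkpoint is
    assumed to be single-file, and a lone model.safetensors is checked instead.
    """
    index_path = os.path.join(model_dir, index_name)
    if not os.path.exists(index_path):
        single = os.path.join(model_dir, "model.safetensors")
        return [] if os.path.exists(single) else ["model.safetensors"]
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]  # parameter name -> shard filename
    needed = sorted(set(weight_map.values()))
    return [name for name in needed if not os.path.exists(os.path.join(model_dir, name))]
```

For example, `print(missing_shards(merged_dir_latest) or "all shards present")` before submitting.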

In [None]:
# ══════════════════════════════════════════════════════════════════════
# CHECK CHECKPOINT STATUS
# ══════════════════════════════════════════════════════════════════════
# Run this cell anytime to see your training progress

from glob import glob
import os

print("=" * 70)
print("CHECKPOINT STATUS")
print("=" * 70)

tracker = get_checkpoint_tracker()

print(f"\n📊 Training Progress:")
print(f"   Total steps trained:  {tracker.get('total_steps_trained', 0)}")
print(f"   Latest checkpoint:    {tracker.get('latest_checkpoint', 'None')}")
print(f"   Training sessions:    {len(tracker.get('training_sessions', []))}")

# List all checkpoints
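checkpoints = sorted(
    (p for p in glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")) if p.rsplit("-", 1)[-1].isdigit()),
    key=lambda p: int(p.rsplit("-", 1)[-1]),  # numeric order: 200 before 1000
)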
print(f"\n📁 Saved Checkpoints ({len(checkpoints)} total):")
for ckpt in checkpoints:
    step = os.path.basename(ckpt).split("-")[1]
    size_mb = sum(os.path.getsize(os.path.join(ckpt, f)) 
                  for f in os.listdir(ckpt) if os.path.isfile(os.path.join(ckpt, f))) / 1e6
    print(f"   checkpoint-{step:>4s} ({size_mb:.0f} MB)")

# Show training sessions
sessions = tracker.get("training_sessions", [])
if sessions:
    print(f"\n📅 Training Sessions:")
    first = max(1, len(sessions) - 4)  # number sessions by position in the full list
    for i, session in enumerate(sessions[-5:], first):  # Show last 5 sessions
        start = session.get("start_time", "?")[:19]
        end = session["end_time"][:19] if session.get("end_time") else "no end time (running or interrupted)"
        final = session.get("final_step", "?")
        print(f"   Session {i}: {start} → {end} (step {final})")

print(f"\n{'='*70}")
print("NEXT STEPS:")
print(f"{'='*70}")
print(f"  To continue training:  Increase TARGET_TOTAL_STEPS and re-run training")
print(f"  To save final model:   Run the SAVE FINAL MODEL cell above")
print(f"  Current target:        {TARGET_TOTAL_STEPS} steps")
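The tracker records each session's `start_time` and `end_time` as ISO-8601 strings produced by `datetime.now().isoformat()`, so they can be parsed back with `datetime.fromisoformat`. A small hypothetical helper to total wall-clock training time; sessions without an `end_time` (still running or interrupted) are skipped:

```python
from datetime import datetime


def total_training_hours(sessions):
    """Sum wall-clock hours over tracker sessions that have both start and end times."""
    total_seconds = 0.0
    for session in sessions:
        start, end = session.get("start_time"), session.get("end_time")
        if start and end:  # interrupted or running sessions have no end_time
            delta = datetime.fromisoformat(end) - datetime.fromisoformat(start)
            total_seconds += delta.total_seconds()
    return total_seconds / 3600
```

For example: `print(f"Total training time: {total_training_hours(tracker.get('training_sessions', [])):.1f} h")`.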

# Inference from Merged Model

Load the saved merged model from disk (no LoRA, no Unsloth) and verify it works for Minesweeper inference.

⚠️ **Important:** Unsloth patches transformers globally when imported, and the patches cannot easily be undone. The cell below therefore runs inference in a fresh Python subprocess in which Unsloth is never imported.

**If you still get errors**, the cleanest solution is:
1. **Restart Kernel** (Kernel → Restart)
2. Run **only** the cells needed for inference (skip model loading/training cells)
3. Or run the inference in a separate Python script
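The subprocess approach boils down to three steps: write the script to a file, run it with the same Python interpreter, and collect its output. A generic sketch of that pattern with a hypothetical helper name, using `sys.executable` so the child sees the same installed packages but none of this kernel's already-imported modules:

```python
import os
import subprocess
import sys
import tempfile


def run_isolated(script_text, timeout=None):
    """Run Python source in a fresh interpreter; return (returncode, stdout, stderr).

    The child starts with an empty module cache, so anything imported in this
    notebook (e.g. Unsloth and its transformers patches) is not loaded there.
    """
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(script_text)
        path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, path],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        return proc.returncode, proc.stdout, proc.stderr
    finally:
        os.remove(path)  # clean up the temporary script
```

The child inherits the notebook's working directory, so relative paths such as `my_minesweeper_model_merged_latest` resolve the same way. With `capture_output=True` nothing is shown until the child exits, so a `timeout` is useful for long model loads.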

In [None]:
# ──────────────────────────────────────────────────────────────────────
# INFERENCE FROM MERGED MODEL (Subprocess Approach)
# 
# Unsloth patches transformers globally and cannot be easily undone.
# This cell writes a standalone inference script and runs it in a 
# fresh Python subprocess where Unsloth was never imported.
# ──────────────────────────────────────────────────────────────────────

import subprocess
import sys

# Write standalone inference script
inference_script = '''
import torch
import json
import re
import random

# ── Load model (pure transformers, no Unsloth) ──
from transformers import AutoModelForCausalLM, AutoTokenizer

# Use the _latest directory (a copy of the most recently saved merged model)
merged_dir = "my_minesweeper_model_merged_latest"

print(f"Loading merged model from: {merged_dir}/")
tokenizer = AutoTokenizer.from_pretrained(merged_dir)
print(f"  ✅ Tokenizer loaded ({tokenizer.vocab_size} vocab)")

model = AutoModelForCausalLM.from_pretrained(
    merged_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()
print(f"  ✅ Model loaded on {model.device}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

# ── Minimal MinesweeperGame for testing ──
class MinesweeperGame:
    def __init__(self, rows, cols, num_mines, seed=None):
        self.rows, self.cols, self.num_mines = rows, cols, num_mines
        self._revealed, self._flagged = set(), set()
        rng = random.Random(seed)
        all_cells = [(r, c) for r in range(rows) for c in range(cols)]
        mines = set(rng.sample(all_cells, min(num_mines, len(all_cells))))
        self._board = [[0]*cols for _ in range(rows)]
        for r, c in mines:
            self._board[r][c] = -1
        for r in range(rows):
            for c in range(cols):
                if self._board[r][c] != -1:
                    count = sum(1 for dr in [-1,0,1] for dc in [-1,0,1]
                               if 0 <= r+dr < rows and 0 <= c+dc < cols
                               and self._board[r+dr][c+dc] == -1)
                    self._board[r][c] = count
        self._state = "ongoing"
        if num_mines == 0:
            self._state = "success"
    
    def state(self): return self._state
    
    def get_visible_board(self):
        board = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r,c) in self._flagged: row.append("F")
                elif (r,c) in self._revealed: row.append(str(self._board[r][c]))
                else: row.append("?")
            board.append(row)
        return board
    
    def do_action(self, action):
        if self._state != "ongoing": return self._state
        r, c = action.get("row"), action.get("col")
        if not (0 <= r < self.rows and 0 <= c < self.cols): return "invalid"
        if action["type"] == "flag":
            if (r,c) in self._revealed: return "invalid"
            self._flagged.add((r,c))
            return "ok"
        if (r,c) in self._revealed or (r,c) in self._flagged: return "invalid"
        if self._board[r][c] == -1:
            self._state = "failed"
            return "mine"
        self._reveal(r, c)
        safe = self.rows * self.cols - self.num_mines
        if len(self._revealed) >= safe:
            self._state = "success"
            return "win"
        return "ok"
    
    def _reveal(self, r, c):
        # Iterative flood fill: recursion can hit Python's recursion limit on large boards
        stack = [(r, c)]
        while stack:
            r, c = stack.pop()
            if not (0 <= r < self.rows and 0 <= c < self.cols): continue
            if (r,c) in self._revealed or (r,c) in self._flagged: continue
            if self._board[r][c] == -1: continue
            self._revealed.add((r,c))
            if self._board[r][c] == 0:
                for dr in [-1,0,1]:
                    for dc in [-1,0,1]:
                        stack.append((r+dr, c+dc))

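# ── Prompt formatting ──
# NOTE: simplified stand-ins for the notebook's SYSTEM_PROMPT / format_state_for_llm.
# If they differ from the prompts used in training, results may understate the model.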
SYSTEM_PROMPT = "You are a Minesweeper AI. Output ONLY valid JSON."

def format_prompt(game):
    board = game.get_visible_board()
    grid = "   " + " ".join(str(c) for c in range(game.cols)) + "\\n"
    for r, row in enumerate(board):
        grid += f"{r:2d} " + " ".join(row) + "\\n"
    return f"""Minesweeper {game.rows}×{game.cols}, {game.num_mines} mines

{grid}
Legend: ?=unrevealed F=flagged 0-8=safe

Output JSON only: {{"type":"reveal","row":N,"col":N}} or {{"type":"flag","row":N,"col":N}}
Row: 0-{game.rows-1}, Col: 0-{game.cols-1}"""

def parse_action(response):
    for match in re.finditer(r"\\{[^{}]*\\}", response):
        try:
            action = json.loads(match.group())
            if "type" in action and "row" in action and "col" in action:
                if action["type"] in ("reveal", "flag"):
                    action["row"] = int(action["row"])
                    action["col"] = int(action["col"])
                    return action
        except (ValueError, TypeError): pass  # malformed JSON or non-integer row/col
    return None

# ── Run tests ──
test_configs = [
    (1, 1, 0, 300, "1×1 Trivial"),
    (3, 3, 1, 301, "Tiny 3×3"),
    (5, 5, 3, 302, "Small 5×5"),
    (6, 6, 5, 99,  "Standard 6×6"),
    (8, 8, 10, 303, "Medium 8×8"),
    (10, 10, 20, 304, "Large 10×10"),
    (15, 15, 45, 305, "XL 15×15"),
    (1, 20, 4, 306, "Linear 1×20"),
    (5, 5, 0, 307, "Zero Mines 5×5"),
]

print(f"\\n{'='*60}")
print(f"  MERGED MODEL INFERENCE TEST — {len(test_configs)} boards")
print(f"{'='*60}")

results = {"pass": 0, "fail": 0, "skip": 0}

for rows, cols, mines, seed, label in test_configs:
    print(f"\\n--- {label} ({rows}×{cols}, {mines} mines, seed={seed}) ---")
    
    game = MinesweeperGame(rows, cols, mines, seed)
    if game.state() != "ongoing":
        print(f"  Game auto-resolved: {game.state()}")
        results["skip"] += 1
        continue
    
    prompt = format_prompt(game)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    with torch.no_grad():
        output = model.generate(
            **inputs,
            temperature=0.7,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
    print(f"  Response: {response[:150]}{'...' if len(response) > 150 else ''}")
    
    action = parse_action(response)
    print(f"  Parsed:   {action}")
    
    if action:
        result = game.do_action(action)
        print(f"  Result:   {result} → game {game.state()}")
        # Only moves the engine accepts count as a pass; "invalid" and "mine" both fail
        if result in ("ok", "win"):
            results["pass"] += 1
        else:
            results["fail"] += 1
    else:
        print(f"  ⚠️ Failed to parse valid action")
        results["fail"] += 1

print(f"\\n{'='*60}")
print(f"  INFERENCE TEST SUMMARY")
print(f"{'='*60}")
total = results['pass'] + results['fail']
print(f"  Pass (valid move): {results['pass']}/{total}")
print(f"  Fail (bad parse/invalid move/mine): {results['fail']}")
print(f"  Skipped (auto-resolved): {results['skip']}")
if total > 0:
    print(f"  Success rate: {results['pass']/total*100:.1f}%")
print(f"\\n✅ Merged model inference test complete!")
print(f"   Model path: {merged_dir}/")
'''

# Write script to file
script_path = "run_inference_clean.py"
with open(script_path, "w", encoding="utf-8") as f:  # script contains non-ASCII (×, ✅, ⚠️)
    f.write(inference_script)
print(f"✅ Wrote inference script to: {script_path}")

# Run in subprocess (fresh Python, no Unsloth contamination)
print(f"\n{'='*60}")
print("Running inference in fresh Python subprocess...")
print("(This avoids Unsloth's transformers patches)")
print(f"{'='*60}\n")

result = subprocess.run(
    [sys.executable, script_path],
    capture_output=False,  # Stream output directly
    text=True,
)

print(f"\n{'='*60}")
if result.returncode == 0:
    print("✅ Subprocess completed successfully!")
else:
    print(f"❌ Subprocess failed with return code: {result.returncode}")

# Fixes & Improvements Applied

## 4-Paper Hybrid Integration (LAMER + XRPO + GRPO-LEAD + S-GRPO)

### Paper 1: LAMER — Meta-RL for Language Agents (74% win rate on MineSweeper)
| LAMER Finding | Implementation |
|---------------|----------------|
| ReAct prompting (Reason + Act) | Explicit "think step-by-step before acting" in system prompt |
| Pre-computed logical hints (CoT) | SAFE/MINE cell lists injected into every prompt |
| Center-opening strategy | Reward center openings on dense boards (>10% mine density), penalize edge/corner openings |
| `temperature=1.0` for training | Full exploration during GRPO generation |
| `temperature=0.7` for eval | LAMER paper eval temperature |
| `num_iterations=2` | LAMER paper: 2 GRPO iterations for MineSweeper |
| STEP 1/STEP 2 reasoning | Satisfied/constrained number scanning |
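A minimal sketch of the center-opening shaping in the LAMER table, assuming a hypothetical `center_opening_bonus()` helper that scores the first reveal on a fresh board. The >10% density threshold follows the table; the +2/-1/-2 magnitudes and the "within one cell of the center" rule are illustrative assumptions, not the values used in `strategic_reward`.

```python
def center_opening_bonus(rows: int, cols: int, num_mines: int, r: int, c: int) -> float:
    """Hypothetical shaping term for the first reveal on a fresh board.

    Only dense boards (mine density > 10%) get an opening preference.
    All magnitudes are illustrative assumptions.
    """
    size = rows * cols
    if size == 0 or num_mines / size <= 0.10:
        return 0.0  # sparse boards: no opening preference
    on_row_edge = r in (0, rows - 1)
    on_col_edge = c in (0, cols - 1)
    if on_row_edge and on_col_edge:
        return -2.0  # corner opening
    if on_row_edge or on_col_edge:
        return -1.0  # edge opening
    # Reward openings within one cell of the geometric center
    if abs(r - (rows - 1) / 2) <= 1 and abs(c - (cols - 1) / 2) <= 1:
        return 2.0
    return 0.0  # interior but off-center
```

On an 8×8 board with 10 mines (15.6% density), `(3, 4)` scores +2.0, `(0, 3)` scores -1.0 and `(0, 0)` scores -2.0; with 5 mines (7.8%) every opening scores 0.0.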

### Paper 2: XRPO — Adaptive Exploration with Difficulty Reweighting
| XRPO Finding | Implementation |
|--------------|----------------|
| Difficulty reweighting | `_difficulty_multiplier()` → harder boards get amplified reward signal (×0.7–×1.5) |
| Exploration heuristic | STEP 4 in prompt: prefer high-info cells, low numbers, avoid edges |
| Novelty bonus concept | Implicit: difficulty multiplier serves similar purpose (hard=novel→stronger signal) |
| Adaptive difficulty proxy | `0.6 × density/0.20 + 0.4 × log(size)/log(2500)` → sigmoid multiplier |
| Safety check | Returns 1.0 for 0-mine boards and degenerate cases (Fix #2) |
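The XRPO difficulty proxy can be sketched as follows. `difficulty_multiplier()` is a hypothetical stand-in for `_difficulty_multiplier()`: the proxy formula and the Fix #2 safety check come from the table, while the sigmoid center (0.5) and `steepness` are assumptions chosen so the output stays between 0.7 and 1.5.

```python
import math


def difficulty_multiplier(rows: int, cols: int, num_mines: int,
                          steepness: float = 6.0) -> float:
    """Sketch of an XRPO-style reward multiplier in (0.7, 1.5).

    The proxy follows the table; the sigmoid center and steepness are assumed.
    """
    board_size = rows * cols
    # Fix #2 safety check: degenerate or mine-free boards get a neutral weight
    if board_size == 0 or num_mines == 0:
        return 1.0
    density = num_mines / board_size
    # 0.6 * normalized density + 0.4 * normalized log-size (each at most 1 within spec)
    difficulty = 0.6 * density / 0.20 + 0.4 * math.log(board_size) / math.log(2500)
    sigmoid = 1.0 / (1.0 + math.exp(-steepness * (difficulty - 0.5)))
    return 0.7 + 0.8 * sigmoid  # maps (0, 1) onto (0.7, 1.5)
```

Because both normalized terms reach 1.0 only at 20% density on a 50×50 board, the upper end of the range is approached only by the hardest in-spec boards.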

### Paper 3: GRPO-LEAD — Length Penalty & Explicit Wrong Penalty
| GRPO-LEAD Finding | Implementation |
|-------------------|----------------|
| Group-normalized length penalty | 2-pass z-score normalization across correct responses (Fix #6) |
| ≤200 token response cap | Instruction in training prompt: "Total response ≤ 200 tokens" |
| Explicit wrong penalty | Mine reveal: -26 (was -25), wrong flag: -11 (was -10) |
| Lower learning rate | `5e-6` (was `2e-5`) for stability with penalty terms |
| Concise reasoning | "2-3 sentences max" instruction in prompt output format |
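A sketch of the two-pass, group-normalized length penalty described in the GRPO-LEAD table and in Fix #6. `length_multipliers()` and its `alpha` scale are hypothetical; per the reward table the factor presumably feeds `valid_json_reward`, but how it is applied is not shown in this section.

```python
import statistics
from typing import List, Sequence


def length_multipliers(lengths: Sequence[int], correct: Sequence[bool],
                       alpha: float = 0.25, lo: float = 0.5,
                       hi: float = 1.5) -> List[float]:
    """Group-normalized length factor for one GRPO group (sketch).

    Pass 1: mean/std of token lengths over the CORRECT completions only.
    Pass 2: z-score each correct completion's length and map it to
    1 - alpha * z, clamped to [lo, hi], so shorter-than-group answers get > 1.
    Incorrect completions keep a neutral 1.0. `alpha` is an assumed constant.
    """
    correct_lengths = [n for n, ok in zip(lengths, correct) if ok]
    if len(correct_lengths) < 2:
        return [1.0] * len(lengths)  # no spread to normalize against
    mean = statistics.fmean(correct_lengths)
    std = statistics.pstdev(correct_lengths)
    if std == 0:
        return [1.0] * len(lengths)  # all correct answers have the same length
    return [min(hi, max(lo, 1.0 - alpha * (n - mean) / std)) if ok else 1.0
            for n, ok in zip(lengths, correct)]
```

Normalizing within the group, rather than against an absolute token budget, means a response is only penalized for being long relative to its correct siblings for the same prompt.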

### Paper 4: S-GRPO — Early Exit When Solution Found
| S-GRPO Finding | Implementation |
|----------------|----------------|
| Early exit instruction | "Once you find a logical move in Steps 1-3, STOP reasoning and output it immediately" |
| Reduced wasted tokens | Combined with GRPO-LEAD length penalty → shorter, focused responses |
| Step-ordered deduction | Steps 1→2→3→4 with explicit "stop when found" at each logical step |

## Robustness Fixes (Latest Session)
| # | Issue | Fix | Priority |
|---|-------|-----|----------|
| 1 | Infinite loops in `_play_smart_moves()` | `stuck_count` with `MAX_STUCK=10` — break after 10 failed attempts | P0 |
| 2 | Div-by-zero in `_difficulty_multiplier()` | Safety check: return 1.0 for `board_size==0` or `num_mines==0` | P3 |
| 3 | Missing one-cell-left endgame hint | Compute exact last cell coords, inject `CRITICAL` hint with specific action | P2 |
| 4 | Model repeating same valid move forever | `seen_actions` set + `repeat_count` in eval callback & `play_full_game` — break after 3 repeats | P1 |
| 6 | Wrong length penalty (absolute not group) | 2-pass GRPO-LEAD: z-score normalize lengths across correct responses, clamp multiplier [0.5, 1.5] | P1 |
| 9 | Manual merge fallback had no error handling | Nested try/except — if both Unsloth & PEFT merge fail, print recovery instructions | P3 |
| 10 | No progressive difficulty in eval suite | Added 6 configs: same-size density scaling + same-density size scaling + generalization tests | P3 |
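The two loop guards (Fix #1, `MAX_STUCK=10`, and Fix #4, `seen_actions`/`repeat_count` with a limit of 3) can be sketched together. `play_with_guards()`, `propose` and `apply` are hypothetical names; the table does not say whether `stuck_count` counts consecutive or total failures, so this sketch assumes consecutive failures. The default `max_iterations=500` mirrors the S4 cap below.

```python
from typing import Callable, Optional, Set, Tuple

Move = Tuple[str, int, int]  # (action type, row, col)


def play_with_guards(
    propose: Callable[[], Optional[Move]],
    apply: Callable[[Move], str],
    max_iterations: int = 500,
    max_stuck: int = 10,
    max_repeats: int = 3,
) -> str:
    """Run a move loop with stuck detection (Fix #1) and repeat detection (Fix #4).

    `propose` returns the next move or None; `apply` returns "ok", "invalid",
    "mine" or "win", like the inference script's `do_action`. Both are
    hypothetical hooks standing in for the model and the game engine.
    """
    stuck_count = 0   # consecutive attempts without progress (assumed consecutive)
    repeat_count = 0  # total times an already-played move was proposed again
    seen_actions: Set[Move] = set()
    for _ in range(max_iterations):
        move = propose()
        if move is not None and move in seen_actions:
            repeat_count += 1
            if repeat_count >= max_repeats:
                return "repeating"
            continue
        result = "invalid" if move is None else apply(move)
        if move is not None:
            seen_actions.add(move)
        if result in ("mine", "win"):
            return result
        if result == "invalid":
            stuck_count += 1
            if stuck_count >= max_stuck:
                return "stuck"
        else:
            stuck_count = 0  # progress resets the stuck counter
    return "max_iterations"
```

A model that keeps emitting the same reveal is cut off after three repeats instead of looping until the iteration cap.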

## 50×50 Board Scalability Fixes
| # | Issue | Fix | Impact |
|---|-------|-----|--------|
| S1 | Token budget mismatch (1900+128=2028, no room for reasoning) | Rebalanced to 1792+256=2048 exact fit | 50×50 Format C (~1335 tokens) fits with 457 token margin for reasoning |
| S2 | Large boards only 10% of training data (sizes 1-50) | Boosted to 15% at sizes 30-50; bands now 25/20/20/20/15 | 3× more large board training samples |
| S3 | Flat +100 win bonus regardless of board size | Scaled: `100 × (1 + min(1, board_size/1000))`; 50×50→+200, 6×6→+104 | Model motivated to complete large boards |
| S4 | Eval `max_iterations` too high (rows*cols*3 = 7500 for 50×50) | Capped to `min(500, rows*cols+100)` | Prevents runaway 50×50 eval loops |
| S5 | S1's 256-token completion budget conflicted with the competition's `max_new_tokens=128` limit | Reverted to 128 everywhere (hackathon constraint: JSON-only output, no reasoning) | A pure JSON action is ~10-25 tokens, well under the 128-token limit |
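The S3 and S4 formulas, written out as hypothetical helpers. As the S3 formula is written, `min(1, board_size/1000)` saturates for any board with at least 1000 cells, so a 50×50 board earns 100 × 2 = +200.

```python
def win_bonus(rows: int, cols: int, base: float = 100.0) -> float:
    """S3: win bonus scaled by board size, capped at 2x the base."""
    board_size = rows * cols
    return base * (1 + min(1.0, board_size / 1000))


def eval_max_iterations(rows: int, cols: int) -> int:
    """S4: eval loop cap, replacing the old rows * cols * 3."""
    return min(500, rows * cols + 100)
```

For example, a 6×6 board gives 100 × 1.036 ≈ 103.6 (shown as +104 in the table), and `eval_max_iterations(50, 50)` returns 500 instead of the old 7500.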

## Exhaustive Dataset Generation
| Feature | Details |
|---------|---------|
| Total samples | 4000 (was 3000) |
| Edge cases | 50+ configs (was 25) — trivial, linear, rectangular, density extremes |
| Opening training | 25% fresh/single-move boards — train safe opening strategies |
| Pattern-specific | 15% — satisfied-number boards + multi-region disconnected boards |
| Mid-game deduction | 25% — 3-15 moves, progressive flagging 10%→30%→50% |
| Endgame completion | 15% — 80-98% revealed — flag accounting, finish strategy |
| Forced guess | 10% — no logical deductions available |
| Progressive flagging | 10% early → 30% mid → 50% late |
| Density stratification | 8% zero, 17% very sparse, 25% sparse, 25% medium, 25% dense |
| Stuck prevention | `MAX_STUCK=10` in `_play_smart_moves()` (Fix #1) |
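A sketch of density-stratified mine counts using the band weights from the table (8/17/25/25/25). `sample_mine_count()` is hypothetical, and the density range assigned to each band is an assumption; only the weights and the 20% cap come from this document.

```python
import random
from typing import Tuple

# Band weights come from the table; each band's density range is an assumption.
DENSITY_BANDS = [
    ("zero",        0.08, (0.00, 0.00)),
    ("very_sparse", 0.17, (0.01, 0.05)),
    ("sparse",      0.25, (0.05, 0.10)),
    ("medium",      0.25, (0.10, 0.15)),
    ("dense",       0.25, (0.15, 0.20)),
]


def sample_mine_count(rows: int, cols: int, rng: random.Random) -> Tuple[str, int]:
    """Draw a density band by weight, then a mine count inside the 20% cap."""
    names = [name for name, _, _ in DENSITY_BANDS]
    weights = [weight for _, weight, _ in DENSITY_BANDS]
    band = rng.choices(range(len(DENSITY_BANDS)), weights=weights, k=1)[0]
    lo, hi = DENSITY_BANDS[band][2]
    cells = rows * cols
    max_mines = cells // 5  # floor(20% of cells), exact integer arithmetic
    mines = round(rng.uniform(lo, hi) * cells)
    if hi > 0:
        mines = max(1, mines)  # non-zero bands place at least one mine...
    return names[band], min(mines, max_mines)  # ...unless the cap forbids it
```

Passing an explicit `random.Random(seed)` keeps dataset generation reproducible without touching the global RNG.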

## Prompt System (Hybrid: 4-Paper Master Template)
| Feature | Paper Source | Implementation |
|---------|-------------|----------------|
| ReAct reasoning | LAMER | "Think step-by-step → STEP 1-4 → Act" |
| Pattern recognition | XRPO | STEP 3: 1-2-1 line, 1-1 corner, zero cascade |
| Forced-guess heuristic | XRPO | STEP 4: prefer high-info cells, low numbers, avoid edges |
| Early exit | S-GRPO | "Once you find a logical move in Steps 1-3, STOP" |
| Length control | GRPO-LEAD | "Total response ≤ 200 tokens", "2-3 sentences max" |
| 3-tier board format | Custom | A (≤20): full grid, B (21-35): frontier, C (36-50): summary |
| Phase-aware prompts | LAMER | Opening (center), Mid-game (deduction), Endgame (flag accounting) |
| Edge case guidance | Custom | 0-mine, linear, tiny, large, high-density, all-flagged, last-cell |
| One-cell-left hint | Fix #3 | Computes exact cell coords → "Reveal (r,c) to WIN!" or "Flag (r,c)!" |
| All-remaining-mines | Fix #3 | "ALL REMAINING N CELLS ARE MINES: Flag any!" |
| Inference variant | LAMER | ~60% shorter for eval/test |
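Two prompt-side helpers, sketched under assumptions: `board_format_tier()` picks Format A/B/C from the size thresholds in the table (keying on the larger dimension is an assumption), and `endgame_hint()` approximates the Fix #3 hints using the inference prompt's legend (`?` unrevealed, `F` flagged). Both names and the exact hint strings are illustrative.

```python
from typing import List, Optional


def board_format_tier(rows: int, cols: int) -> str:
    """Pick the prompt format tier; keying on the larger dimension is assumed."""
    side = max(rows, cols)
    if side <= 20:
        return "A"  # full grid
    if side <= 35:
        return "B"  # frontier view
    return "C"      # summary view (36-50)


def endgame_hint(board: List[List[str]], num_mines: int) -> Optional[str]:
    """Return an exact endgame hint, or None when the endgame is not yet forced.

    `board` uses the prompt legend ("?" unrevealed, "F" flagged, digits revealed);
    flags are assumed correct (training data only flags certain mines).
    """
    hidden = [(r, c) for r, row in enumerate(board)
              for c, cell in enumerate(row) if cell == "?"]
    flags = sum(cell == "F" for row in board for cell in row)
    remaining = num_mines - flags  # mines not yet flagged
    if len(hidden) == 1 and remaining in (0, 1):
        r, c = hidden[0]
        if remaining == 0:
            return f"CRITICAL: Reveal ({r},{c}) to WIN!"
        return f"CRITICAL: Flag ({r},{c})!"
    if hidden and remaining == len(hidden):
        return f"ALL REMAINING {len(hidden)} CELLS ARE MINES: Flag any!"
    return None
```

Returning `None` lets the prompt builder skip the hint block entirely on boards where the endgame is not yet determined.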

## Reward System (3 Functions, Hybrid)
| Reward | Weight | Paper Enhancements |
|--------|--------|-------------------|
| `valid_json_reward` | 0.20 | GRPO-LEAD: group-normalized z-score length penalty (Fix #6) |
| `gameplay_scores` | 0.65 | XRPO: difficulty reweighting (×0.7–×1.5), GRPO-LEAD: explicit wrong penalty |
| `strategic_reward` | 0.15 | XRPO: difficulty reweighting, LAMER: center-opening |
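How the three weighted rewards might feed GRPO's group-relative advantage, as a sketch. This mirrors the general approach of TRL's `GRPOTrainer` (weight each reward function by `reward_weights`, sum, then normalize by the mean and standard deviation of the completions generated for the same prompt), but the helper names and the `eps` value are assumptions rather than TRL internals.

```python
import statistics
from typing import List, Sequence

REWARD_WEIGHTS = (0.20, 0.65, 0.15)  # valid_json_reward, gameplay_scores, strategic_reward


def combine_rewards(per_func: Sequence[Sequence[float]],
                    weights: Sequence[float] = REWARD_WEIGHTS) -> List[float]:
    """Weighted sum of the three reward functions for each completion."""
    return [sum(w * r for w, r in zip(weights, rewards)) for rewards in per_func]


def group_advantages(totals: Sequence[float], eps: float = 1e-4) -> List[float]:
    """Group-relative advantage: (r - group mean) / (group std + eps).

    `totals` are the combined rewards of the completions sampled for ONE prompt.
    """
    if len(totals) < 2:
        return [0.0] * len(totals)  # no group to compare against
    mean = statistics.fmean(totals)
    std = statistics.stdev(totals)  # sample std over the group
    return [(t - mean) / (std + eps) for t in totals]
```

Because advantages are relative within a group, a completion that hits a mine (-26 in `gameplay_scores`) is pushed down only relative to its siblings; if every completion in the group scores the same, all advantages are zero and that prompt contributes no policy gradient.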

## Evaluation System
| Feature | Details |
|---------|---------|
| Eval callback | 10 configs every 50 steps, temp=0.7, stuck detection (Fix #4) |
| Exhaustive eval | 43+ configs (was 37), 1×1 to 50×50 |
| Progressive difficulty | Same-size density scaling + same-density size scaling (Fix #10) |
| Generalization test | 25×25 10%, 30×30 10% — harder unseen configs |
| Stuck detection | Repeated action tracking — break after 3 same-move repeats (Fix #4) |
| Category breakdown | Trivial, Tiny, Small, Standard, Medium, Large, XL, Rectangular, Sparse, Progressive |

## Training Config (Hybrid Hyperparameters)
| Parameter | Value | Source |
|-----------|-------|--------|
| temperature | 1.0 (train), 0.7 (eval) | LAMER |
| learning_rate | 5e-6 | GRPO-LEAD (reduced for stability) |
| num_iterations | 2 | LAMER |
| beta (KL) | 0.04 | LAMER |
| reward_weights | [0.20, 0.65, 0.15] | Hybrid (format↑ for length penalty) |
| max_completion_length | 128 | GRPO-LEAD (length control) |
| max_grad_norm | 0.5 | Gradient clipping for stability |
| warmup_ratio | 0.05 | Stable early training |
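The hyperparameters above, collected as keyword arguments in a sketch. The dict name `GRPO_KWARGS` is hypothetical; it is meant to be passed as `GRPOConfig(output_dir=..., **GRPO_KWARGS)`, and whether every field exists depends on the installed TRL version. `remove_unused_columns=False` comes from bug #7 in the table of earlier fixes.

```python
# Assumption: consumed as trl.GRPOConfig(output_dir=..., **GRPO_KWARGS);
# field availability depends on the installed TRL version.
GRPO_KWARGS = dict(
    temperature=1.0,                    # training-time sampling (eval uses 0.7 separately)
    learning_rate=5e-6,                 # GRPO-LEAD: reduced for stability
    num_iterations=2,                   # LAMER
    beta=0.04,                          # KL coefficient
    reward_weights=[0.20, 0.65, 0.15],  # valid_json, gameplay, strategic
    max_completion_length=128,          # competition max_new_tokens
    max_grad_norm=0.5,                  # gradient clipping
    warmup_ratio=0.05,
    remove_unused_columns=False,        # keep board metadata columns for the reward functions
)
```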

## Model Save & Inference
| Feature | Details |
|---------|---------|
| LoRA save | `my_minesweeper_model/` — always succeeds |
| Merged save | Try Unsloth → fallback PEFT merge → error recovery instructions (Fix #9) |
| Merged inference | Standalone HF transformers load — no Unsloth/LoRA needed |
| Inference test | 9 diverse boards, pass/fail summary |

## Competition Spec Compliance
| Requirement | Implementation |
|-------------|---------------|
| Board rows: 1–50 | `MIN_ROWS=1, MAX_ROWS=50` in game engine |
| Board cols: 1–50 | `MIN_COLS=1, MAX_COLS=50` in game engine |
| Mines: 0–20% of cells | `MAX_MINE_DENSITY=0.20`, 0 mines allowed |
| No move limit | Games end only on success or failure |
| Success = all safe revealed | `_check_win()` checks `len(revealed) >= safe_cells` |
| Failure = mine revealed | Only `_reveal_cell` hitting mine sets `_state="failed"` |
| max_new_tokens: 128 | `max_completion_length=128` in GRPOConfig |
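A sketch of the spec checks in this table, reusing the constant names listed above (`MIN_ROWS`, `MAX_ROWS`, `MIN_COLS`, `MAX_COLS`, `MAX_MINE_DENSITY`). `validate_board_config()` and `is_win()` are hypothetical helpers, not the engine's actual methods; `is_win()` mirrors the `len(revealed) >= safe_cells` rule.

```python
MIN_ROWS, MAX_ROWS = 1, 50
MIN_COLS, MAX_COLS = 1, 50
MAX_MINE_DENSITY = 0.20


def validate_board_config(rows: int, cols: int, num_mines: int) -> None:
    """Raise ValueError if a board falls outside the competition spec."""
    if not MIN_ROWS <= rows <= MAX_ROWS:
        raise ValueError(f"rows={rows} outside [{MIN_ROWS}, {MAX_ROWS}]")
    if not MIN_COLS <= cols <= MAX_COLS:
        raise ValueError(f"cols={cols} outside [{MIN_COLS}, {MAX_COLS}]")
    if num_mines < 0 or num_mines > MAX_MINE_DENSITY * rows * cols:
        raise ValueError(f"num_mines={num_mines} outside [0, 20% of {rows * cols}]")


def is_win(num_revealed: int, rows: int, cols: int, num_mines: int) -> bool:
    """Success condition: every safe cell has been revealed."""
    return num_revealed >= rows * cols - num_mines
```

Running the validator over every generated training and eval config catches out-of-spec boards (e.g. 6 mines on a 5×5 board) before they reach the reward functions.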

## Critical Bugs Fixed (Previous Sessions)
| # | Bug | Fix |
|---|-----|-----|
| 1 | `do_action()` set `_state="failed"` for ALL invalid moves | Only `mine` sets state to "failed" |
| 2 | Reward functions signature mismatch with TRL GRPOTrainer | `(prompts, completions, **kwargs)` |
| 3 | Hardcoded board size in rewards | `_reconstruct_game()` reads from kwargs |
| 4 | `max_prompt_length=700` truncated prompts | Increased to 1900 |
| 5 | Separate O(n²) passes for safe/mine cells | Combined single pass |
| 6 | Eval decoded full output including prompt | Decodes only generated tokens |
| 7 | `remove_unused_columns` not set | Explicitly `False` |
| 8 | Random flags in training data | `_smart_flag()` only flags logically certain mines |
| 9 | Reward scale imbalance | Rebalanced: JSON=-10, mine=-26, win=+100 |