# Minesweeper LLM Competition - Custom GRPO Training

## Goal
Fine-tune an LLM with LoRA, using GRPO, to play Minesweeper:
- **Input**: JSON game state (board configuration)
- **Output**: JSON action (reveal or flag a cell)

Teams will compete to train the best Minesweeper-playing LLM!

## Training Approach
- **Model**: Qwen2.5-14B-Instruct (from /root/.cache/huggingface/hub)
- **Method**: GRPO (Group Relative Policy Optimization)
- **Framework**: Unsloth (advertised as 2-6x faster with up to 70% less VRAM)
- **Hardware**: AMD MI300X GPU (192GB HBM3, ROCm)

# Load Model with Unsloth

Load Qwen2.5-14B-Instruct in bfloat16 and read the LoRA rank from the config file:

In [9]:
import os

os.environ["HF_HOME"] = "./workspace/hf_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./workspace/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "./workspace/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./workspace/hf_cache"


In [10]:
from huggingface_hub import snapshot_download

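snapshot_download(
    repo_id="unsloth/Qwen2.5-14B-Instruct",
    local_dir="./workspace/Qwen2.5-14B-Instruct",
    # `local_dir_use_symlinks` is deprecated in recent huggingface_hub releases
    # and ignored: files downloaded into `local_dir` are real copies.
)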


Ignored error while writing commit hash to /root/.cache/huggingface/models--unsloth--Qwen2.5-14B-Instruct/refs/main: [Errno 30] Read-only file system: '/root/.cache/huggingface/models--unsloth--Qwen2.5-14B-Instruct'.


'/workspace/workspace/Qwen2.5-14B-Instruct'

In [None]:
from unsloth import FastLanguageModel
import torch
import yaml

# Load config
with open("minesweeper_config_me.yaml", "r") as f:
    config = yaml.safe_load(f)

lora_rank = config.get("lora_rank", 32)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/workspace/workspace/Qwen2.5-14B-Instruct",
    load_in_4bit=False,   # AMD → 4bit disabled
    max_seq_length=2048,  # Increased: larger boards need longer prompts
    dtype=torch.bfloat16,
)

print("Model loaded successfully!")
print(f"Device: {model.device}")
print(f"LoRA rank: {lora_rank} (from config)")

Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.6: Fast Qwen2 patching. Transformers: 4.56.2. vLLM: 0.11.1rc2.dev161+g8a297115e.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 255.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+gitb2fb688. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



Model loaded successfully!
Device: cuda:0


# Add LoRA Adapters

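Add LoRA adapters to every attention and MLP projection. As the log below shows, a nonzero `lora_dropout` stops Unsloth from fast-patching the LoRA layers (0 QKV/O/MLP layers patched); setting `lora_dropout = 0` restores the fast path at the cost of that regularization: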

In [12]:
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank,           # scaling = lora_alpha / r = 1.0
    lora_dropout = 0.05,              # Small dropout for regularization
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)
print(f"LoRA config: rank={lora_rank}, alpha={lora_rank}, dropout=0.05")
model.print_trainable_parameters()

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.10.6 patched 48 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


LoRA config: rank=32, alpha=32, dropout=0.05
trainable params: 137,625,600 || all params: 14,907,659,264 || trainable%: 0.9232
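
As a sanity check on the `trainable params` line, the count can be reproduced by hand: each adapted `Linear(d_in -> d_out)` gains two matrices, A (r × d_in) and B (d_out × r), so it adds r·(d_in + d_out) parameters. The sketch below assumes the published Qwen2.5-14B dimensions (hidden size 5120, MLP size 13824, 48 layers, 40 query heads, 8 KV heads); `lora_param_count` is a hypothetical helper, not part of Unsloth or PEFT.

```python
# Hypothetical helper: count LoRA parameters added to a decoder-only model.
# Assumed Qwen2.5-14B dimensions (from its public config, not this notebook).

def lora_param_count(r, hidden, intermediate, n_layers, n_heads, n_kv_heads):
    """Each adapted Linear(d_in -> d_out) adds A (r x d_in) + B (d_out x r)."""
    head_dim = hidden // n_heads
    kv_dim = n_kv_heads * head_dim  # k_proj / v_proj output size under GQA
    # (d_in, d_out) for each target module passed to get_peft_model
    shapes = {
        "q_proj": (hidden, hidden),
        "k_proj": (hidden, kv_dim),
        "v_proj": (hidden, kv_dim),
        "o_proj": (hidden, hidden),
        "gate_proj": (hidden, intermediate),
        "up_proj": (hidden, intermediate),
        "down_proj": (intermediate, hidden),
    }
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers


total = lora_param_count(r=32, hidden=5120, intermediate=13824,
                         n_layers=48, n_heads=40, n_kv_heads=8)
print(f"{total:,}")  # 137,625,600
```

With r = 32 this reproduces the 137,625,600 reported above, and since the count is linear in r, doubling `lora_rank` doubles the trainable parameters.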


# Minesweeper Game Implementation

Custom Minesweeper environment supporting:
- Customizable board size and mine count
- Actions: reveal or flag cells
- Win: reveal all safe cells
- Lose: reveal a mine

In [None]:
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Set
import random

# ──────────────────────────────────────────────────────────────────────
# Board size configurations for training and evaluation
# ──────────────────────────────────────────────────────────────────────

BOARD_CONFIGS = [
    # (rows, cols, num_mines, weight) — weight controls sampling probability
    (5, 5,  3, 0.10),   # Small / easy
    (5, 5,  5, 0.05),   # Small / hard
    (6, 6,  5, 0.20),   # Default (competition likely uses this)
    (6, 6,  7, 0.10),   # Default / hard
    (7, 7,  7, 0.10),   # Medium
    (7, 7, 10, 0.05),   # Medium / hard
    (8, 8,  8, 0.10),   # Larger
    (8, 8, 10, 0.10),   # Larger / hard
    (9, 9, 10, 0.10),   # Large
    (10,10,12, 0.05),   # Extra large
    (6, 8,  6, 0.03),   # Rectangular
    (8, 6,  6, 0.02),   # Rectangular (tall)
]

def sample_board_config(rng=None):
    """Sample a (rows, cols, num_mines) tuple from BOARD_CONFIGS."""
    rng = rng or random.Random()
    weights = [w for _, _, _, w in BOARD_CONFIGS]
    chosen = rng.choices(BOARD_CONFIGS, weights=weights, k=1)[0]
    return chosen[0], chosen[1], chosen[2]


def mine_density(rows, cols, num_mines):
    """Compute mine density as a fraction."""
    return num_mines / (rows * cols)


@dataclass
class MinesweeperGame:
    rows: int
    cols: int
    num_mines: int
    seed: Optional[int] = None
    _rng: random.Random = field(init=False, repr=False)
    _board: List[List[int]] = field(init=False, repr=False)  # -1 = mine, 0-8 = count
    _revealed: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _flagged: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _state: str = field(default="ongoing", init=False, repr=False)

    def __post_init__(self):
        # ── Input validation ──
        if self.rows < 2 or self.cols < 2:
            raise ValueError(f"Board too small: {self.rows}x{self.cols} (min 2x2)")
        if self.rows > 30 or self.cols > 30:
            raise ValueError(f"Board too large: {self.rows}x{self.cols} (max 30x30)")
        if self.num_mines < 1:
            raise ValueError(f"Need at least 1 mine, got {self.num_mines}")
        if self.num_mines >= self.rows * self.cols:
            raise ValueError(f"Too many mines ({self.num_mines}) for {self.rows}x{self.cols} board")
        if mine_density(self.rows, self.cols, self.num_mines) > 0.50:
            import warnings
            warnings.warn(f"Very dense board: {self.num_mines} mines on a "
                          f"{self.rows}x{self.cols} board (>50%)")

        self._rng = random.Random(self.seed)
        self._board = [[0 for _ in range(self.cols)] for _ in range(self.rows)]
        self._place_mines()
        self._calculate_numbers()

    def _place_mines(self):
        """Place mines randomly on the board."""
        positions = [(r, c) for r in range(self.rows) for c in range(self.cols)]
        mine_positions = self._rng.sample(positions, self.num_mines)
        for r, c in mine_positions:
            self._board[r][c] = -1

    def _calculate_numbers(self):
        """Calculate numbers for each cell based on adjacent mines."""
        for r in range(self.rows):
            for c in range(self.cols):
                if self._board[r][c] == -1:
                    continue
                count = 0
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if 0 <= nr < self.rows and 0 <= nc < self.cols:
                            if self._board[nr][nc] == -1:
                                count += 1
                self._board[r][c] = count

    def _reveal_cell(self, row: int, col: int) -> bool:
        """Reveal a cell. Returns True if valid move, False if invalid.
        Uses iterative flood-fill to avoid recursion limit on large boards.
        """
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed or (row, col) in self._flagged:
            return False

        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if (r, c) in self._revealed:
                continue

            self._revealed.add((r, c))

            # Hit a mine!
            if self._board[r][c] == -1:
                self._state = "failed"
                return True

            # Auto-reveal neighbors if cell is 0
            if self._board[r][c] == 0:
                for dr in [-1, 0, 1]:
                    for dc in [-1, 0, 1]:
                        if dr == 0 and dc == 0:
                            continue
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < self.rows and 0 <= nc < self.cols
                                and (nr, nc) not in self._revealed
                                and (nr, nc) not in self._flagged):
                            stack.append((nr, nc))

        return True

    def _flag_cell(self, row: int, col: int) -> bool:
        """Flag/unflag a cell. Returns True if valid, False if invalid."""
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed:
            return False

        if (row, col) in self._flagged:
            self._flagged.remove((row, col))
        else:
            self._flagged.add((row, col))
        return True

    def do_action(self, action: dict) -> str:
        """Execute an action and return a status string.

        Returns one of:
          'ok'               - valid move executed
          'mine'             - revealed a mine (game over → state='failed')
          'win'              - game won after this move
          'invalid_format'   - bad action dict / missing keys / bad types
          'out_of_bounds'    - coordinates outside the board
          'already_revealed' - cell was already revealed
          'flagged_cell'     - tried to reveal a flagged cell
          'invalid_flag'     - tried to flag a revealed cell
          'game_over'        - game was already over before this call

        Note: only 'mine' sets state='failed'. All other invalid moves
        return an error string but keep the game 'ongoing', so the agent
        can retry and the reward function can distinguish error types.
        """
        if self._state != "ongoing":
            return "game_over"

        if not isinstance(action, dict):
            return "invalid_format"

        action_type = action.get("type")
        row = action.get("row")
        col = action.get("col")

        if action_type not in ("reveal", "flag") or row is None or col is None:
            return "invalid_format"

        try:
            row, col = int(row), int(col)
        except (ValueError, TypeError):
            return "invalid_format"

        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return "out_of_bounds"

        if action_type == "reveal":
            if (row, col) in self._revealed:
                return "already_revealed"
            if (row, col) in self._flagged:
                return "flagged_cell"
            self._reveal_cell(row, col)
        else:  # flag
            if (row, col) in self._revealed:
                return "invalid_flag"
            self._flag_cell(row, col)

        self._check_win()

        if self._state == "failed":
            return "mine"
        if self._state == "success":
            return "win"
        return "ok"

    def _check_win(self):
        """Check if the player has won.

        Win condition: every safe cell is revealed. As in standard
        Minesweeper, flags are optional and do not count toward the win.
        """
        if self._state != "ongoing":
            return
        total_cells = self.rows * self.cols
        safe_cells = total_cells - self.num_mines
        # While the game is ongoing no mine has been revealed, so every
        # revealed cell is a safe cell.
        if len(self._revealed) >= safe_cells:
            self._state = "success"

    def get_visible_board(self) -> List[List[str]]:
        """Get board state as player sees it."""
        visible = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r, c) in self._flagged:
                    row.append('F')
                elif (r, c) in self._revealed:
                    val = self._board[r][c]
                    row.append('*' if val == -1 else str(val))
                else:
                    row.append('.')
            visible.append(row)
        return visible

    def state(self) -> str:
        return self._state

    def get_mine_positions(self) -> Set[Tuple[int, int]]:
        """Return set of all mine positions (for reward computation)."""
        return {(r, c) for r in range(self.rows) for c in range(self.cols)
                if self._board[r][c] == -1}

    def progress(self) -> float:
        """Fraction of safe cells revealed (0.0 to 1.0)."""
        safe_cells = self.rows * self.cols - self.num_mines
        return len(self._revealed) / safe_cells if safe_cells > 0 else 1.0

    def pretty_print(self) -> str:
        """Pretty print the board."""
        visible = self.get_visible_board()
        lines = []

        # Header
        header = "   " + " ".join(f"{i:2d}" for i in range(self.cols))
        lines.append(header)
        lines.append("  " + "─" * (self.cols * 3 + 1))

        # Board
        for r, row in enumerate(visible):
            line = f"{r:2d}│ " + "  ".join(row)
            lines.append(line)

        return "\n".join(lines)


# ── Quick sanity tests ──
print("Testing MinesweeperGame...")
g = MinesweeperGame(5, 5, 3, seed=0)
assert g.state() == "ongoing"
assert g.do_action({"type": "reveal", "row": -1, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing", "BUG: out_of_bounds should NOT end the game"
assert g.do_action({"type": "reveal", "row": 99, "col": 0}) == "out_of_bounds"
assert g.state() == "ongoing"
assert g.do_action({"type": "flag", "row": 0, "col": 0}) == "ok"
assert g.do_action({"type": "reveal", "row": 0, "col": 0}) == "flagged_cell"
assert g.state() == "ongoing", "BUG: flagged_cell should NOT end the game"
assert g.do_action({}) == "invalid_format"
assert g.state() == "ongoing", "BUG: invalid_format should NOT end the game"
print("  ✅ do_action keeps game ongoing on invalid moves")

# Test variable sizes
for r, c, m in [(5,5,3), (6,6,5), (8,8,10), (10,10,12), (6,8,6)]:
    g = MinesweeperGame(r, c, m, seed=42)
    assert g.rows == r and g.cols == c
    board = g.get_visible_board()
    assert len(board) == r and len(board[0]) == c
print(f"  ✅ Variable board sizes work")

# Test progress
g = MinesweeperGame(5, 5, 3, seed=42)
assert g.progress() == 0.0
print(f"  ✅ All game engine tests passed")
print(f"  Board configs available: {len(BOARD_CONFIGS)} different sizes")
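
One consequence of `_place_mines`: mines are sampled uniformly when the game is constructed, with no first-click protection, so the first reveal can hit a mine and every cell is equally likely to be one. What does differ by position is the chance of opening a 0 (and hence a flood-fill cascade), which requires the cell and all of its neighbors to be mine-free. A small sketch under that uniform-placement assumption; `fresh_board_odds` is a hypothetical helper:

```python
from math import comb


def fresh_board_odds(rows, cols, num_mines, n_neighbors):
    """First-click odds on a board with uniformly placed mines and no
    first-click protection.

    Returns (p_mine, p_zero): the chance the clicked cell is a mine, and the
    chance it is a safe 0 (itself and all neighbors mine-free).
    """
    n = rows * cols
    p_mine = num_mines / n
    # All mines must land outside the clicked cell and its neighbors.
    p_zero = comb(n - 1 - n_neighbors, num_mines) / comb(n, num_mines)
    return p_mine, p_zero


for label, k in [("corner", 3), ("edge", 5), ("interior", 8)]:
    p_mine, p_zero = fresh_board_odds(6, 6, 5, k)
    print(f"{label:8s} P(mine)={p_mine:.3f} P(zero)={p_zero:.3f}")
```

On the default 6×6 board with 5 mines, every position has P(mine) = 5/36 ≈ 0.139, while P(zero) drops from about 0.53 at a corner to about 0.38 on an edge and 0.21 in the interior. That cascade advantage, not a lower mine risk, is what makes corner openings attractive.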

# JSON Input/Output Format

## Input Format (Game State)
The prompt includes board state for **variable board sizes** (5×5 to 10×10):
- Compact row-per-line format with row indices
- Board dimensions, mine count, remaining mines
- Pre-computed logical hints (safe cells, mine cells)

## Output Format (Action)
```json
{"type": "reveal", "row": 2, "col": 3}
```
or
```json
{"type": "flag", "row": 1, "col": 4}
```

## Dataset Columns (for reward function reconstruction)
| Column | Type | Purpose |
|--------|------|---------|
| `prompt` | list[dict] | Chat-formatted prompt |
| `seed` | int | RNG seed to reconstruct game |
| `move_history` | str (JSON) | Previous moves as JSON list |
| `board_rows` | int | Number of rows |
| `board_cols` | int | Number of columns |
| `board_mines` | int | Number of mines |
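
The cell that builds the training dataset is not part of this excerpt. As a hedged illustration of the column layout in the table, one row might look like the following; `make_dataset_row` is hypothetical, and sending `SYSTEM_PROMPT` as a system message alongside the board text is an assumption.

```python
import json


def make_dataset_row(prompt_text, system_prompt, seed, moves, rows, cols, mines):
    """Hypothetical helper: build one training row with the columns above."""
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt_text},
        ],
        "seed": seed,                        # replays the same mine layout
        "move_history": json.dumps(moves),   # stored as a JSON string
        "board_rows": rows,
        "board_cols": cols,
        "board_mines": mines,
    }


row = make_dataset_row("<board prompt>", "<system prompt>", seed=42,
                       moves=[{"type": "reveal", "row": 0, "col": 0}],
                       rows=6, cols=6, mines=5)
# The reward side decodes the history back into a list of action dicts.
print(json.loads(row["move_history"]))
```

Serializing `move_history` with `json.dumps` matches the `str (JSON)` type in the table; `_reconstruct_game` accepts either that string or a plain list.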

In [None]:
import json
import re

# ──────────────────────────────────────────────────────────────────────
# Minesweeper logic helpers (single-point deductions, recomputed per call)
# ──────────────────────────────────────────────────────────────────────

def _compute_safe_and_mine_cells(game: MinesweeperGame):
    """Compute both safe and mine cells in a single pass over the board.

    Returns (safe_set, mine_set) where each is a set of (row, col) tuples.

    Only single-point deductions are made, one revealed number at a time:
    - A cell is logically SAFE if some adjacent revealed number already has
      all of its mines accounted for by flags (remaining_mines == 0).
    - A cell is a logically CERTAIN mine if some adjacent number has
      remaining_mines == its count of unrevealed, unflagged neighbors.
    Both rules trust the current flags, so a wrong flag can make a
    "safe" deduction unsound.
    """
    safe = set()
    mines = set()

    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) not in game._revealed:
                continue
            val = game._board[r][c]
            if val <= 0:
                continue

            flags = 0
            unrevealed = []
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if (nr, nc) in game._flagged:
                            flags += 1
                        elif (nr, nc) not in game._revealed:
                            unrevealed.append((nr, nc))

            remaining = val - flags

            # All mines accounted for → remaining unrevealed are safe
            if remaining == 0 and unrevealed:
                for cell in unrevealed:
                    safe.add(cell)
            # Remaining mines == remaining unknowns → all are mines
            elif remaining > 0 and remaining == len(unrevealed):
                for cell in unrevealed:
                    mines.add(cell)

    return safe, mines


def _compute_safe_cells(game: MinesweeperGame) -> list:
    """Return list of [row, col] for logically safe cells."""
    safe, _ = _compute_safe_and_mine_cells(game)
    return [list(c) for c in safe]


def _compute_mine_cells(game: MinesweeperGame) -> list:
    """Return list of [row, col] for logically certain mine cells."""
    _, mines = _compute_safe_and_mine_cells(game)
    return [list(c) for c in mines]


def _is_logically_safe(game: MinesweeperGame, row: int, col: int) -> bool:
    """Check if (row, col) is logically safe."""
    safe, _ = _compute_safe_and_mine_cells(game)
    return (row, col) in safe


def _is_logically_mine(game: MinesweeperGame, row: int, col: int) -> bool:
    """Check if (row, col) is logically a mine."""
    _, mines = _compute_safe_and_mine_cells(game)
    return (row, col) in mines


# ──────────────────────────────────────────────────────────────────────
# Prompt formatting — concise, hint-enriched, variable board sizes
# ──────────────────────────────────────────────────────────────────────

SYSTEM_PROMPT = """You are an expert Minesweeper solver. Given a board state, output exactly ONE action as a JSON object. No explanation, no markdown, no extra text."""

def format_state_for_llm(game: MinesweeperGame) -> str:
    """Convert game state to an optimized prompt for LLM.

    Supports any board size. Prompt is kept compact to stay within
    max_prompt_length even for larger boards (up to ~10x10).
    """
    board = game.get_visible_board()

    # Compact board representation: one string per row instead of nested lists
    board_str = "\n".join(
        f"  {r:2d}: " + " ".join(row) for r, row in enumerate(board)
    )
    col_header = "      " + " ".join(f"{c}" for c in range(game.cols))

    safe_cells = _compute_safe_cells(game)
    mine_cells = _compute_mine_cells(game)

    hint_lines = []
    if safe_cells:
        hint_lines.append(f"SAFE cells (reveal one): {safe_cells[:5]}")
    if mine_cells:
        hint_lines.append(f"MINE cells (flag one): {mine_cells[:5]}")
    if not safe_cells and not mine_cells:
        hint_lines.append("No logical deductions — choose lowest-risk unrevealed cell.")

    hint_section = "\n".join(hint_lines)

    prompt = f"""Minesweeper {game.rows}x{game.cols}, {game.num_mines} mines, {game.num_mines - len(game._flagged)} remaining.
Revealed: {len(game._revealed)}/{game.rows * game.cols - game.num_mines} safe cells.

{col_header}
{board_str}

Legend: .=unrevealed F=flagged 0-8=adjacent mines

RULES:
- Number = count of adjacent mines (8 neighbors).
- remaining_mines_for_number = number - adjacent_flags.
- If remaining=0 → all unrevealed neighbors are SAFE → reveal one.
- If remaining = unrevealed_neighbor_count → all are MINES → flag one.
- Never reveal a flagged cell. Never flag a revealed cell.

{hint_section}

Output ONLY: {{"type":"reveal"|"flag","row":<int>,"col":<int>}}"""

    return prompt


def parse_llm_action(response: str) -> Optional[dict]:
    """Extract a JSON action from an LLM response.

    Finds all flat JSON objects and returns the LAST one matching the
    expected schema (type + row + col, with type "reveal" or "flag"),
    since LLMs typically place their final answer at the end.
    Returns None if no valid action is found.
    """
    best = None
    for match in re.finditer(r'\{[^{}]*\}', response):
        try:
            action = json.loads(match.group())
            if ("type" in action and "row" in action and "col" in action
                    and action["type"] in ("reveal", "flag")):
                # Ensure row/col are ints
                action["row"] = int(action["row"])
                action["col"] = int(action["col"])
                best = action
        except (json.JSONDecodeError, ValueError, TypeError):
            continue
    return best


# ── Quick tests ──
for rows, cols, mines in [(5,5,3), (6,6,5), (8,8,10), (10,10,12)]:
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=42)
    prompt = format_state_for_llm(game)
    assert f"{rows}x{cols}" in prompt
    assert f"{mines} mines" in prompt
    print(f"  {rows}x{cols} prompt: {len(prompt)} chars")

# Test parse_llm_action edge cases
assert parse_llm_action('{"type":"reveal","row":2,"col":3}') == {"type": "reveal", "row": 2, "col": 3}
assert parse_llm_action('blah blah {"type":"flag","row":"1","col":"2"} done') == {"type": "flag", "row": 1, "col": 2}
assert parse_llm_action('no json here') is None
assert parse_llm_action('{"type":"invalid","row":0,"col":0}') is None

# Show example prompt
game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
game.do_action({"type": "reveal", "row": 0, "col": 0})
prompt = format_state_for_llm(game)
print(f"\n=== Example 6x6 prompt ({len(prompt)} chars) ===")
print(prompt)

You are an expert Minesweeper solver. Output ONE action as JSON only.

RULES:
- Numbers show how many of their 8 neighbors are mines.
- Subtract flagged neighbors from the number to find remaining mines.
- If remaining mines == remaining unrevealed neighbors → all are mines → FLAG.
- If remaining mines == 0 → all unrevealed neighbors are safe → REVEAL.
- Never reveal a flagged cell or flag a revealed cell.
- Prefer logically deducible moves over guessing.

{
  "board": [
    [".", ".", ".", ".", ".", "."],
    [".", ".", ".", ".", ".", "."],
    [".", ".", ".", ".", ".", "."],
    [".", ".", ".", ".", ".", "."],
    [".", ".", ".", ".", ".", "."],
    [".", ".", ".", ".", ".", "."]
  ],
  "rows": 6,
  "cols": 6,
  "mines": 5,
  "flags_pla

# Test Model Before Training

See how the base model performs without finetuning:

In [15]:
from transformers import TextStreamer

game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=42)
prompt = format_state_for_llm(game)

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize = False,
    add_generation_prompt = True,
)

print("=== Base Model Response ===")
output = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0.7,
    top_p = 0.9,
    max_new_tokens = 128,
    do_sample = True,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

=== Base Model Response ===
{"type":"reveal","row":0,"col":0}<|im_end|>


# GRPO Reward Functions
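
GRPO does not use these scores directly. For each prompt it samples a group of completions, adds up the reward functions' scores per completion, and standardizes the totals within the group; the resulting advantage tells the policy which completions beat their siblings. The sketch below mirrors that computation under the assumption that it matches TRL's `GRPOTrainer` defaults (summed rewards with optional `reward_weights`, then group mean/std scaling); `group_advantages` is a hypothetical helper, not a TRL function.

```python
from statistics import mean, stdev


def group_advantages(rewards_per_func, weights=None, eps=1e-4):
    """Sketch of GRPO advantages for ONE prompt's group of completions.

    rewards_per_func: one list per reward function, each holding one score
        per completion in the group (needs at least 2 completions).
    weights: optional per-function weights (default 1.0 each).
    """
    weights = weights or [1.0] * len(rewards_per_func)
    # 1. Weighted sum of every reward function's score, per completion.
    totals = [sum(w * scores[i] for w, scores in zip(weights, rewards_per_func))
              for i in range(len(rewards_per_func[0]))]
    # 2. Center on the group mean and scale by the group's sample std.
    mu, sigma = mean(totals), stdev(totals)
    return [(t - mu) / (sigma + eps) for t in totals]


# Four completions for one board: pure-JSON safe reveal, verbose safe
# reveal, pure-JSON mine hit, unparseable output.
fmt = [3.0, 0.0, 3.0, -3.0]        # e.g. valid_json_reward
play = [15.0, 15.0, -25.0, -10.0]  # e.g. gameplay_scores
print([round(a, 2) for a in group_advantages([fmt, play])])
```

Two design consequences follow. If every completion in a group gets the same total, all advantages are zero and that prompt contributes no learning signal. And because the components are summed before normalization, a large term such as the +100 win bonus in `gameplay_scores` largely decides the ranking within any group where it appears, swamping the ±3 format reward.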

Define reward functions to guide the model's learning:

In [None]:
import numpy as np

# ──────────────────────────────────────────────────────────────────────
# Helper: Reconstruct game from dataset columns
# ──────────────────────────────────────────────────────────────────────

def _reconstruct_game(idx, kwargs):
    """Reconstruct a MinesweeperGame from dataset columns passed via GRPO kwargs.

    The dataset stores: seed, move_history, board_rows, board_cols, board_mines
    GRPOTrainer passes all non-'prompt' columns as lists in **kwargs.

    Returns (game, move_history_list) or (None, None) if data is missing.
    """
    seeds = kwargs.get("seed", [])
    move_histories = kwargs.get("move_history", [])
    rows_list = kwargs.get("board_rows", [])
    cols_list = kwargs.get("board_cols", [])
    mines_list = kwargs.get("board_mines", [])

    if idx >= len(seeds) or idx >= len(move_histories):
        return None, None

    seed = seeds[idx]
    rows = rows_list[idx] if idx < len(rows_list) else 6
    cols = cols_list[idx] if idx < len(cols_list) else 6
    num_mines = mines_list[idx] if idx < len(mines_list) else 5

    mh_raw = move_histories[idx]
    if isinstance(mh_raw, str):
        move_history = json.loads(mh_raw)
    else:
        move_history = list(mh_raw)

    game = MinesweeperGame(rows=int(rows), cols=int(cols),
                           num_mines=int(num_mines), seed=int(seed))
    for prev in move_history:
        result = game.do_action(prev)
        if result == "mine":
            return None, None  # History hit a mine — bad data

    return game, move_history


# ──────────────────────────────────────────────────────────────────────
# Reward 1: Valid JSON format + conciseness
# ──────────────────────────────────────────────────────────────────────

def valid_json_reward(prompts, completions, **kwargs):
    """Reward valid JSON action format. Also rewards conciseness.

    TRL GRPOTrainer calls: reward_func(prompts=..., completions=..., **kwargs)
    completions is a list of list-of-dicts (chat format).
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"].strip() if completion else ""
        action = parse_llm_action(response)

        if action is None:
            scores.append(-3.0)
            continue

        # Bonus for pure JSON (no extra text)
        try:
            parsed = json.loads(response)
            if "type" in parsed and "row" in parsed and "col" in parsed:
                scores.append(3.0)  # Perfect — pure JSON only
                continue
        except json.JSONDecodeError:
            pass

        # Valid JSON but with extra surrounding text
        json_match = re.search(r'\{[^{}]*\}', response)
        extra_chars = len(response) - len(json_match.group()) if json_match else len(response)
        if extra_chars < 10:
            scores.append(2.5)
        elif extra_chars < 50:
            scores.append(1.0)
        elif extra_chars < 200:
            scores.append(0.0)
        else:
            scores.append(-1.0)  # Too verbose

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 2: Gameplay — complete 12-criterion scoring
# ──────────────────────────────────────────────────────────────────────

def gameplay_scores(prompts, completions, **kwargs):
    """
    Complete gameplay reward implementing all 12 scoring criteria.
    Uses variable board sizes from dataset columns.

    1.  Flag cell that IS a mine        → +15 / +20 (logical)
    2.  Flag cell that is NOT a mine    → -10
    3.  Reveal cell that IS a mine      → -25
    4.  Reveal cell that is safe        → +10 (guess) / +15 (logical)
    5.  Flag already flagged cell       → -12
    6.  Reveal already revealed cell    → -12
    7.  Out of bounds                   → -15
    8.  Total flags > total mines       → -10 (additional)
    9.  Invalid JSON                    → -10
    10. Win the game                    → +100
    11. Reveal a flagged cell           → -8
    12. Flag a revealed cell            → -8
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        # ── Criterion 9: Invalid JSON ──
        if action is None:
            scores.append(-10.0)
            continue

        # ── Reconstruct game state from dataset columns ──
        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]

        # ── Criterion 7: Out of bounds ──
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-15.0)
            continue

        # ── Compute logical deductions ONCE ──
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        score = 0.0

        if action_type == "reveal":
            # ── Criterion 6: Reveal already revealed cell ──
            if (row, col) in game._revealed:
                scores.append(-12.0)
                continue

            # ── Criterion 11: Reveal a flagged cell ──
            if (row, col) in game._flagged:
                scores.append(-8.0)
                continue

            # ── Criterion 3: Reveal a mine ──
            if game._board[row][col] == -1:
                scores.append(-25.0)
                continue

            # ── Criterion 4: Reveal safe cell ──
            if (row, col) in safe_set:
                score += 15.0   # Logically deduced safe cell
            else:
                score += 10.0   # Guessed safe cell

            # Small bonus for revealing near numbers (information-rich)
            board = game.get_visible_board()
            has_adjacent_number = False
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    nr, nc = row + dr, col + dc
                    if 0 <= nr < game.rows and 0 <= nc < game.cols:
                        if board[nr][nc] in ('1','2','3','4','5','6','7','8'):
                            has_adjacent_number = True
                            break
                if has_adjacent_number:
                    break
            if has_adjacent_number:
                score += 1.0

            # ── Criterion 10: Check for win ──
            # `game` was rebuilt from the dataset for this completion only,
            # so applying the move to it directly is safe (no shared state)
            # and avoids replaying the whole history a second time.
            if game.do_action(action) == "win":
                score += 100.0

        elif action_type == "flag":
            # ── Criterion 5: Flag already flagged cell ──
            if (row, col) in game._flagged:
                scores.append(-12.0)
                continue

            # ── Criterion 12: Flag a revealed cell ──
            if (row, col) in game._revealed:
                scores.append(-8.0)
                continue

            # ── Criterion 1: Flag a mine (correct) ──
            if game._board[row][col] == -1:
                if (row, col) in mine_set:
                    score += 20.0   # Logically deduced mine
                else:
                    score += 15.0   # Correct but guessed

            # ── Criterion 2: Flag a non-mine (wrong) ──
            else:
                score -= 10.0

            # ── Criterion 8: Total flags > total mines ──
            new_flag_count = len(game._flagged) + 1
            if new_flag_count > game.num_mines:
                score -= 10.0

        scores.append(score)

    return scores


# ──────────────────────────────────────────────────────────────────────
# Reward 3: Strategic play — rewards logical deduction over guessing
# ──────────────────────────────────────────────────────────────────────

def strategic_reward(prompts, completions, **kwargs):
    """Reward strategic play patterns:
    - Choosing logically deducible moves when available
    - Opening in corners/edges (lower mine probability on fresh boards)
    - Penalize ignoring available deductions
    """
    scores = []

    for idx, completion in enumerate(completions):
        response = completion[0]["content"] if completion else ""
        action = parse_llm_action(response)

        if action is None:
            scores.append(0.0)
            continue

        game, move_history = _reconstruct_game(idx, kwargs)
        if game is None:
            scores.append(0.0)
            continue

        row, col = action["row"], action["col"]
        action_type = action["type"]
        score = 0.0

        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(0.0)
            continue

        # ── Compute logical deductions ONCE ──
        safe_set, mine_set = _compute_safe_and_mine_cells(game)

        # ── Fresh game opening strategy ──
        if len(game._revealed) == 0 and action_type == "reveal":
            corners = [(0, 0), (0, game.cols - 1),
                       (game.rows - 1, 0), (game.rows - 1, game.cols - 1)]
            edges = ([(0, c) for c in range(game.cols)] +
                     [(game.rows-1, c) for c in range(game.cols)] +
                     [(r, 0) for r in range(game.rows)] +
                     [(r, game.cols-1) for r in range(game.rows)])
            if (row, col) in corners:
                score += 3.0   # Corners have only 3 neighbors → lowest risk
            elif (row, col) in edges:
                score += 1.0   # Edges have 5 neighbors → moderate risk

        # ── Reward choosing logically deducible moves ──
        if action_type == "reveal" and (row, col) in safe_set:
            score += 5.0   # Chose a provably safe cell
        elif action_type == "flag" and (row, col) in mine_set:
            score += 5.0   # Chose a provable mine
        elif safe_set or mine_set:
            # Deducible moves existed but agent didn't pick one
            score -= 3.0

        # ── Penalize flagging on a fresh board (no info to deduce) ──
        if len(game._revealed) == 0 and action_type == "flag":
            score -= 2.0

        scores.append(score)

    return scores


# ── Verify reward function signatures ──
print("✅ All reward functions defined with correct TRL signature:")
print("   1. valid_json_reward(prompts, completions, **kwargs)")
print("   2. gameplay_scores(prompts, completions, **kwargs)")
print("   3. strategic_reward(prompts, completions, **kwargs)")
print()

# ── Smoke test: simulate what GRPOTrainer passes ──
test_completions = [[{"role": "assistant", "content": '{"type":"reveal","row":0,"col":0}'}]]
test_prompts = ["test"]
test_kwargs = {
    "seed": [42],
    "move_history": ["[]"],
    "board_rows": [6],
    "board_cols": [6],
    "board_mines": [5],
}
r1 = valid_json_reward(test_prompts, test_completions, **test_kwargs)
r2 = gameplay_scores(test_prompts, test_completions, **test_kwargs)
r3 = strategic_reward(test_prompts, test_completions, **test_kwargs)
print(f"  Smoke test rewards: format={r1[0]:.1f}, gameplay={r2[0]:.1f}, strategic={r3[0]:.1f}")
print(f"  ✅ All reward functions work with kwargs (seed, move_history, board_rows/cols/mines)")

✅ All reward functions defined:
   1. valid_json_reward   — format + conciseness
   2. gameplay_scores     — all 12 criteria
   3. strategic_reward    — logical deduction bonuses
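GRPO does not use these scores directly. It combines them per completion and then compares each completion against the other `num_generations` completions sampled for the same prompt. The group-relative step can be sketched in plain Python as follows. The sketch assumes the standard GRPO recipe: subtract the group mean, then divide by the group's sample standard deviation. The helper name `group_advantages` is hypothetical, and whether TRL divides by the std depends on its version and `scale_rewards` setting.

```python
from statistics import mean, stdev
from typing import List

def group_advantages(rewards: List[float], num_generations: int,
                     eps: float = 1e-4) -> List[float]:
    """Normalize each reward against the other completions of the same prompt.

    `rewards` is flat and ordered prompt by prompt, `num_generations` per group.
    """
    if len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu = mean(group)
        # Sample std (n - 1 denominator), like torch.Tensor.std's default
        sigma = stdev(group) if len(group) > 1 else 0.0
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages

# Two prompts with 4 completions each: one mixed group, one where every completion tied
adv = group_advantages([-25.0, 10.0, 10.0, 25.0, 14.0, 14.0, 14.0, 14.0], 4)
print([round(a, 3) for a in adv])
```

The second group receives all-zero advantages. A group whose completions all earn the same reward therefore provides no reward-driven gradient, which is why sampling diversity (`temperature`, `num_generations`) matters.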


# Create Training Dataset

Generate diverse game states for training:
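Inside `generate_game_states`, each game phase is drawn by walking the cumulative probabilities in `move_bins`. The sketch below shows the same idea as a standalone function. The name `sample_phase` is hypothetical. The sketch also falls back to the last bin, so floating-point round-off in the running sum can never leave a draw unassigned.

```python
import random
from typing import List, Tuple

# (phase name, min moves, max moves, probability), mirroring move_bins below
PHASES: List[Tuple[str, int, int, float]] = [
    ("fresh", 0, 0, 0.12),
    ("early", 1, 2, 0.22),
    ("mid", 3, 8, 0.40),
    ("late", 9, 25, 0.18),
    ("single", 1, 1, 0.08),
]

def sample_phase(rng: random.Random) -> Tuple[str, int]:
    """Pick a phase by cumulative probability, then a move count inside its range."""
    u = rng.random()
    cumulative = 0.0
    chosen = PHASES[-1]  # fallback if round-off leaves the total a hair below 1.0
    for phase in PHASES:
        cumulative += phase[3]
        if u < cumulative:
            chosen = phase
            break
    name, lo, hi, _ = chosen
    return name, rng.randint(lo, hi)

rng = random.Random(42)
counts = {name: 0 for name, *_ in PHASES}
for _ in range(10_000):
    name, _moves = sample_phase(rng)
    counts[name] += 1
print(counts)
```

`rng.choices(names, weights=probs)` would be a shorter equivalent. The explicit loop is used here because it mirrors the notebook's code.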

In [None]:
from datasets import Dataset

def _smart_reveal(game, rng):
    """Pick a smart cell to reveal during history generation.
    Prefers safe cells (if known) to avoid hitting mines and losing the game.
    Falls back to random unrevealed cell.
    """
    safe_set, _ = _compute_safe_and_mine_cells(game)
    if safe_set:
        return rng.choice(list(safe_set))

    # Fallback: random unrevealed, unflagged cell
    unrevealed = [(r, c) for r in range(game.rows) for c in range(game.cols)
                  if (r, c) not in game._revealed and (r, c) not in game._flagged]
    if not unrevealed:
        return None
    return rng.choice(unrevealed)


def _smart_flag(game, rng):
    """Pick a logically certain mine to flag, or None if none available."""
    _, mine_set = _compute_safe_and_mine_cells(game)
    mine_candidates = [c for c in mine_set if c not in game._flagged]
    if mine_candidates:
        return rng.choice(mine_candidates)
    return None  # Don't flag randomly — creates bad training data


def generate_game_states(num_samples=3000, rng_seed=42):
    """
    Generate diverse Minesweeper game states with:

    1. VARIABLE BOARD SIZES — sampled from BOARD_CONFIGS
    2. CURRICULUM LEARNING — fresh/early/mid/late distribution
    3. SMART move history — uses logical deduction to avoid mines
    4. FLAG states — only flags logically certain mines (no random flags)
    5. EDGE CASES — includes games with:
       - Boards where all neighbors of revealed cells are mines
       - Boards with zero-cascades (large reveals)
       - High mine density boards
       - Rectangular boards

    Stores: seed, move_history, board_rows, board_cols, board_mines
    so reward functions can reconstruct the EXACT game state.
    """
    np.random.seed(rng_seed)
    rng = random.Random(rng_seed)

    dataset_items = []
    attempts = 0
    max_attempts = num_samples * 10  # More attempts for harder configs

    # Move-count distribution for curriculum
    move_bins = [
        (0, 0, 0.12),    # Fresh — learn opening strategy
        (1, 2, 0.22),    # Early — learn basic deduction
        (3, 8, 0.40),    # Mid   — learn constraint satisfaction
        (9, 25, 0.18),   # Late  — learn endgame / flagging
        (1, 1, 0.08),    # Single-move — learn from minimal info
    ]

    # Counters for distribution tracking
    config_counts = {}
    phase_counts = {"fresh": 0, "early": 0, "mid": 0, "late": 0, "single": 0}

    while len(dataset_items) < num_samples and attempts < max_attempts:
        attempts += 1

        # ── Sample board configuration ──
        rows, cols, num_mines = sample_board_config(rng)
        config_key = f"{rows}x{cols}m{num_mines}"
        config_counts[config_key] = config_counts.get(config_key, 0) + 1

        # ── Sample game phase ──
        phase_rand = rng.random()
        cumulative = 0
        min_moves, max_moves_range = 0, 0
        phase_name = "fresh"
        for i, (mn, mx, prob) in enumerate(move_bins):
            cumulative += prob
            if phase_rand < cumulative:
                min_moves, max_moves_range = mn, mx
                phase_name = ["fresh", "early", "mid", "late", "single"][i]
                break

        # Cap mid/late move counts below the number of safe cells (matters on small boards)
        total_safe = rows * cols - num_mines
        if phase_name in ("mid", "late"):
            max_moves_range = min(max_moves_range, total_safe - 1)

        if max_moves_range < min_moves:
            max_moves_range = min_moves
        num_moves = rng.randint(min_moves, max_moves_range)

        seed = rng.randint(0, 999999)
        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)
        move_history = []

        for move_i in range(num_moves):
            if game.state() != "ongoing":
                break

            # ── Decide action: reveal vs flag ──
            # Flag only if we have logical certainty (15% chance when possible)
            action_dict = None
            if rng.random() < 0.15:
                flag_target = _smart_flag(game, rng)
                if flag_target:
                    action_dict = {"type": "flag", "row": flag_target[0], "col": flag_target[1]}

            if action_dict is None:
                reveal_target = _smart_reveal(game, rng)
                if reveal_target is None:
                    break
                action_dict = {"type": "reveal", "row": reveal_target[0], "col": reveal_target[1]}

            result = game.do_action(action_dict)
            if result == "mine":
                break  # Hit a mine — discard this game
            move_history.append(action_dict)

        # Only keep ongoing games with valid state
        if game.state() == "ongoing":
            prompt_text = format_state_for_llm(game)
            dataset_items.append({
                "prompt": [{"role": "user", "content": prompt_text}],
                "seed": seed,
                "move_history": json.dumps(move_history),
                "board_rows": rows,
                "board_cols": cols,
                "board_mines": num_mines,
            })
            phase_counts[phase_name] = phase_counts.get(phase_name, 0) + 1

    dataset_items = dataset_items[:num_samples]
    ds = Dataset.from_list(dataset_items)

    return ds, config_counts, phase_counts


# ── Generate training dataset ──
print("Generating training dataset with variable board sizes + curriculum...")
dataset, config_counts, phase_counts = generate_game_states(num_samples=3000, rng_seed=42)
print(f"Created {len(dataset)} training examples\n")

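# Distribution analysis. Count the KEPT examples: config_counts is incremented
# for every attempt, including games discarded after hitting a mine.
from collections import Counter
kept_configs = Counter(
    f"{r}x{c}m{m}"
    for r, c, m in zip(dataset["board_rows"], dataset["board_cols"], dataset["board_mines"])
)
print("Board size distribution:")
for config, count in kept_configs.most_common():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {config:12s}: {count:4d} ({pct:.1f}%)")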

print(f"\nPhase distribution:")
for phase, count in phase_counts.items():
    pct = count / len(dataset) * 100 if len(dataset) > 0 else 0
    print(f"  {phase:8s}: {count:4d} ({pct:.1f}%)")

move_counts = [len(json.loads(item["move_history"])) for item in dataset]
print(f"\nMove statistics:")
print(f"  Min: {min(move_counts)}, Max: {max(move_counts)}, "
      f"Mean: {np.mean(move_counts):.1f}, Median: {np.median(move_counts):.1f}")

# Verify dataset columns
print(f"\nDataset columns: {dataset.column_names}")
print(f"  ✅ board_rows, board_cols, board_mines present for reward functions")

# Show sample
print(f"\nSample prompt ({dataset[0]['board_rows']}x{dataset[0]['board_cols']}, "
      f"{dataset[0]['board_mines']} mines):")
print(dataset[0]["prompt"][0]["content"][:200] + "...")

Generating training dataset with curriculum learning...
Created 2000 training examples (all ongoing games)

  Fresh games (0 moves): 637 (31.9%)
  Early game (1-2):      833 (41.6%)
  Mid game (3-8):        516 (25.8%)
  Late game (9+):        14 (0.7%)
  Avg moves per state:   1.9

Example training prompt (first 300 chars):
You are an expert Minesweeper solver. Output ONE action as JSON only.

RULES:
- Numbers show how many of their 8 neighbors are mines.
- Subtract flagged neighbors from the number to find remaining mines.
- If remaining mines == remaining unrevealed neighbors → all are mines → FLAG.
- If remaining mi...
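The flagging rule stated in this prompt has a standard counterpart: if zero mines remain, every unrevealed neighbor is safe. Both single-cell rules can be sketched on a text board. The `basic_deductions` helper below is hypothetical. It assumes `.` marks an unrevealed cell, `F` a flag, and digits revealed counts. The notebook's own `_compute_safe_and_mine_cells`, defined elsewhere, may use a different board encoding or stronger reasoning.

```python
from typing import List, Set, Tuple

Cell = Tuple[int, int]

def basic_deductions(board: List[str]) -> Tuple[Set[Cell], Set[Cell]]:
    """Apply the two single-number rules to every revealed digit."""
    rows, cols = len(board), len(board[0])
    safe: Set[Cell] = set()
    mines: Set[Cell] = set()
    for r in range(rows):
        for c in range(cols):
            ch = board[r][c]
            if not ch.isdigit():
                continue
            unknown: List[Cell] = []
            flagged = 0
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < rows and 0 <= nc < cols:
                        if board[nr][nc] == "F":
                            flagged += 1
                        elif board[nr][nc] == ".":
                            unknown.append((nr, nc))
            if not unknown:
                continue
            remaining = int(ch) - flagged  # mines still hidden around this number
            if remaining == 0:
                safe.update(unknown)       # every unrevealed neighbor is safe
            elif remaining == len(unknown):
                mines.update(unknown)      # every unrevealed neighbor is a mine
    return safe, mines

safe, mines = basic_deductions(["F1.", "11.", "..."])
print(sorted(safe), sorted(mines))
```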


# Configure GRPO Training

Set up GRPO trainer with all hyperparameters:

In [None]:
from trl import GRPOConfig, GRPOTrainer

# ── Lengths ──
# A 10x10 board with hints can be ~900 tokens. 1400 gives safe margin.
max_prompt_length = 1400
max_completion_length = 128  # A bare JSON action is well under 40 tokens (~14 in the training log); 128 is generous

# ── GRPO Configuration ──
training_args = GRPOConfig(
    # === Generation ===
    temperature = 0.9,           # Exploration during training
    top_p = 0.95,

    # === Optimization ===
    learning_rate = 2e-5,
    weight_decay = 0.01,
    warmup_ratio = 0.05,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    max_grad_norm = 0.5,

    # === Batch sizes ===
    logging_steps = 1,
    per_device_train_batch_size = 8,   # Unsloth expects a multiple of num_generations (it raises 1 to 8 anyway)
    gradient_accumulation_steps = 4,
    num_generations = 8,

    # === Lengths ===
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,

    # === Training duration ===
    max_steps = 500,
    save_steps = 100,

    # === GRPO specific ===
    beta = 0.04,                 # Mild KL penalty to limit policy drift and reward hacking
    num_iterations = 1,

    # === Reward weighting (gameplay >> format = strategy) ===
    reward_weights = [0.15, 0.70, 0.15],

    # === Ensure extra dataset columns are NOT removed ===
    remove_unused_columns = False,

    # === Output ===
    report_to = "none",
    output_dir = "minesweeper_grpo_v2",
    seed = 42,
    bf16 = True,
)

print("Training configuration:")
print(f"  Max steps:           {training_args.max_steps}")
print(f"  Generations/state:   {training_args.num_generations}")
print(f"  Learning rate:       {training_args.learning_rate}")
print(f"  LR scheduler:        {training_args.lr_scheduler_type}")
print(f"  Max grad norm:       {training_args.max_grad_norm}")
print(f"  Beta (KL penalty):   {training_args.beta}")
print(f"  Reward weights:      {training_args.reward_weights}")
print(f"  Prompt/Completion:   {max_prompt_length}/{max_completion_length}")
print(f"  Temperature:         {training_args.temperature}")
print(f"  remove_unused_cols:  {training_args.remove_unused_columns}")
print(f"  LoRA rank:           {lora_rank}")

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 8
Training configuration:
  Model:               Qwen2.5-14B-Instruct
  Max steps:           500
  Generations/state:   8
  Learning rate:       2e-05
  LR scheduler:       SchedulerType.COSINE
  Max grad norm:       0.5
  Loss type:           bnpo
  Beta (KL penalty):   0.001
  Reward weights:      [0.15, 0.7, 0.15]
  Prompt/Completion:   700/200
  Temperature:         0.9
  Top-p:               0.95
  LoRA rank:           32
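When training starts, the Unsloth banner reports `Total batch size (8 x 4 x 1) = 32`. GRPO samples `num_generations` completions per prompt, so that batch covers only 32 / 8 = 4 distinct game states per optimizer step. The sketch below works through the arithmetic. The helper `grpo_step_layout` is hypothetical, and its divisibility check is a simplified version of the constraint behind the Unsloth message above.

```python
from typing import Tuple

def grpo_step_layout(per_device_batch: int, grad_accum: int, num_processes: int,
                     num_generations: int) -> Tuple[int, int]:
    """Return (completions per optimizer step, distinct prompts per optimizer step)."""
    completions = per_device_batch * grad_accum * num_processes
    if completions % num_generations != 0:
        raise ValueError("total batch must be a multiple of num_generations")
    return completions, completions // num_generations

completions, prompts = grpo_step_layout(8, 4, 1, 8)
steps = 500
print(completions, prompts, steps * prompts)
```

With these settings, 500 steps consume 500 × 4 = 2,000 prompts. That matches the single epoch over 2,000 examples shown in the training banner. With the 3,000-example dataset, 500 steps would cover only two thirds of it.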


In [None]:
from transformers import TrainerCallback

# Board configs to evaluate on during training (mix of sizes)
EVAL_CONFIGS = [
    (6, 6, 5),   # Standard
    (6, 6, 5),   # Standard x2
    (5, 5, 3),   # Small
    (7, 7, 7),   # Medium
    (8, 8, 10),  # Large
    (6, 6, 5),   # Standard x3
    (6, 6, 5),   # Standard x4
    (9, 9, 10),  # Extra large
    (6, 8, 6),   # Rectangular
    (8, 8, 8),   # Large / moderate density
]


class MinesweeperEvalCallback(TrainerCallback):
    """Periodically play games during training with variable board sizes."""

    def __init__(self, eval_every_steps=50, num_games=10):
        self.eval_every_steps = eval_every_steps
        self.num_games = min(num_games, len(EVAL_CONFIGS))

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step % self.eval_every_steps != 0:
            return

        tokenizer = processing_class
        if tokenizer is None or model is None:
            return

        was_training = model.training
        model.eval()

        wins = 0
        total_moves = 0
        invalid_count = 0

        for i in range(self.num_games):
            rows, cols, mines = EVAL_CONFIGS[i % len(EVAL_CONFIGS)]
            game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines,
                                   seed=10000 + i)
            moves = 0
            invalid_streak = 0       # consecutive unusable outputs in this game
            max_moves = rows * cols  # Scale max moves with board size

            while game.state() == "ongoing" and moves < max_moves:
                prompt = format_state_for_llm(game)
                text = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                inputs = tokenizer(text, return_tensors="pt", truncation=True,
                                   max_length=max_prompt_length + 100)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        temperature=0.3,
                        max_new_tokens=128,
                        do_sample=True,
                        top_p=0.8,
                    )

                # Decode ONLY the generated tokens (not the prompt)
                gen_tokens = output[0][inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
                action = parse_llm_action(response)

                if action is None:
                    invalid_count += 1
                    invalid_streak += 1
                    if invalid_streak >= 3:
                        break
                    continue

                result = game.do_action(action)
                if result == "ok":
                    moves += 1
                    invalid_streak = 0
                elif result in ("mine", "win"):
                    moves += 1
                    break
                else:
                    # Illegal move (out of bounds, already revealed, ...): the game
                    # stays ongoing, so count it toward the streak; otherwise a
                    # repeated illegal move would loop forever
                    invalid_count += 1
                    invalid_streak += 1
                    if invalid_streak >= 3:
                        break

            if game.state() == "success":
                wins += 1
            total_moves += moves

        win_rate = wins / self.num_games
        avg_moves = total_moves / self.num_games
        print(f"\n[Eval @ step {state.global_step}] "
              f"Win: {wins}/{self.num_games} ({win_rate*100:.0f}%) | "
              f"Avg moves: {avg_moves:.1f} | "
              f"Invalid: {invalid_count}\n")

        if was_training:
            model.train()


eval_callback = MinesweeperEvalCallback(eval_every_steps=50, num_games=10)
print(f"Eval callback: {eval_callback.num_games} games every "
      f"{eval_callback.eval_every_steps} steps (variable board sizes)")

Eval callback: 10 games every 50 steps (temp=0.3 for deterministic eval)


# Train the Model

Start GRPO training with reward functions:

In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        valid_json_reward,   # Reward valid JSON format + conciseness
        gameplay_scores,     # Core gameplay (all 12 criteria)
        strategic_reward,    # Logical deduction bonuses
    ],
    args = training_args,
    train_dataset = dataset,
    callbacks = [eval_callback],  # Periodic gameplay evaluation
)

print("Starting GRPO training with 3 reward functions...")
print("  [1] valid_json_reward  (weight: 0.15)")
print("  [2] gameplay_scores    (weight: 0.70)")
print("  [3] strategic_reward   (weight: 0.15)")
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


Starting GRPO training with 3 reward functions...
  [1] valid_json_reward  (weight: 0.15)
  [2] gameplay_scores    (weight: 0.70)
  [3] strategic_reward   (weight: 0.15)


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,000 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 4 x 1) = 32
 "-____-"     Trainable parameters = 137,625,600 of 14,907,659,264 (0.92% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / valid_json_reward / mean,rewards / valid_json_reward / std,rewards / gameplay_scores / mean,rewards / gameplay_scores / std,rewards / strategic_reward / mean,rewards / strategic_reward / std
1,0.0,10.775,0.0,14.0,14.0,14.0,0.0,14.0,14.0,14.0,0,0,0,0,0,9e-06,3.0,0.0,14.0,4.310527,3.5,1.524001
2,0.0,-0.4375,0.0,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,5.2e-05,3.0,0.0,-1.75,18.008959,2.25,1.813925
3,0.0001,-3.4125,0.0,14.625,14.0,24.0,0.0,14.625,14.0,24.0,No Log,No Log,No Log,No Log,No Log,0.023024,3.0,0.0,-6.0,19.423962,2.25,1.813925
4,0.0,5.428125,1.177424,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,0.000511,3.0,0.0,6.46875,16.329241,3.0,2.155264
5,0.0001,0.621875,2.184623,14.625,14.0,24.0,0.0,14.625,14.0,24.0,No Log,No Log,No Log,No Log,No Log,0.028749,3.0,0.0,0.03125,11.830754,1.0,1.016001
6,0.0,4.565625,2.72226,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,0.000225,3.0,0.0,5.34375,16.736303,2.5,2.540002
7,0.0,12.6375,0.0,14.0,14.0,14.0,0.0,14.0,14.0,14.0,No Log,No Log,No Log,No Log,No Log,1e-05,3.0,0.0,16.5,4.158163,4.25,1.319824
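The logged `reward` column matches the weighted sum of the three per-function means under `reward_weights = [0.15, 0.70, 0.15]`. Below is a quick check against steps 1, 2 and 7 of the log above. The values are copied from the table, and the helper name `combined_reward` is hypothetical.

```python
from typing import Sequence

REWARD_WEIGHTS = (0.15, 0.70, 0.15)  # valid_json, gameplay, strategic

def combined_reward(means: Sequence[float],
                    weights: Sequence[float] = REWARD_WEIGHTS) -> float:
    """Weighted sum of per-function mean rewards."""
    return sum(w * m for w, m in zip(weights, means))

# (step, logged reward, [valid_json mean, gameplay mean, strategic mean])
LOG_ROWS = [
    (1, 10.775, [3.0, 14.0, 3.5]),
    (2, -0.4375, [3.0, -1.75, 2.25]),
    (7, 12.6375, [3.0, 16.5, 4.25]),
]

for step, logged, means in LOG_ROWS:
    print(f"step {step}: logged={logged:.4f} recomputed={combined_reward(means):.4f}")
```

The mean of a weighted sum equals the weighted sum of the means, so the gameplay term (weight 0.70) dominates these totals. Once every completion is well-formed, the constant `valid_json_reward` mean of 3.0 adds a fixed 0.45.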


# Test Trained Model

Evaluate the finetuned model:

In [None]:
# Test on multiple board sizes
FastLanguageModel.for_inference(model)

test_configs = [
    (6, 6, 5, 99, "Standard"),
    (5, 5, 3, 100, "Small/Easy"),
    (8, 8, 10, 101, "Large/Hard"),
    (7, 7, 7, 102, "Medium"),
]

for rows, cols, mines, seed, label in test_configs:
    print(f"\n{'='*50}")
    print(f"=== {label} ({rows}x{cols}, {mines} mines) ===")
    print(f"{'='*50}")

    test_game = MinesweeperGame(rows=rows, cols=cols, num_mines=mines, seed=seed)
    test_prompt = format_state_for_llm(test_game)

    test_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": test_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(test_text, return_tensors="pt", truncation=True,
                       max_length=max_prompt_length + 100)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    output = model.generate(
        **inputs,
        temperature=0.3,
        max_new_tokens=128,
        do_sample=True,
        top_p=0.8,
        repetition_penalty=1.2,
    )

    # Decode only generated tokens
    gen_tokens = output[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    print(f"Response: {response_text.strip()}")

    action = parse_llm_action(response_text)
    print(f"Parsed action: {action}")

    if action:
        result = test_game.do_action(action)
        print(f"Result: {result} | Game state: {test_game.state()}")
    else:
        print("⚠️ Failed to parse a valid action")
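`parse_llm_action` is defined earlier in the notebook and is not shown in this section. According to the notes at the end of the notebook, it coerces row/col to `int`. The sketch below is a minimal hypothetical parser with that kind of behavior, not the notebook's implementation. It finds a JSON object, accepts only `reveal`/`flag`, and turns digit strings such as `"2"` into integers.

```python
import json
import re
from typing import Optional

_OBJECT_RE = re.compile(r"\{.*?\}", re.DOTALL)  # flat JSON objects only (no nesting)

def _as_int(value) -> Optional[int]:
    """Accept ints and digit strings like "2"; reject bools, floats and junk."""
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, str) and value.strip().isdigit():
        return int(value)
    return None

def parse_action_sketch(text: str) -> Optional[dict]:
    """Return the first well-formed {"type", "row", "col"} action found in text."""
    for match in _OBJECT_RE.finditer(text):
        try:
            obj = json.loads(match.group(0))
        except json.JSONDecodeError:
            continue  # not valid JSON, try the next {...} span
        if obj.get("type") not in ("reveal", "flag"):
            continue
        row, col = _as_int(obj.get("row")), _as_int(obj.get("col"))
        if row is None or col is None:
            continue
        return {"type": obj["type"], "row": row, "col": col}
    return None

print(parse_action_sketch('Sure! {"type": "reveal", "row": "2", "col": 3}'))
# {'type': 'reveal', 'row': 2, 'col': 3}
```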

# Evaluation: Play Complete Games

Test the model on multiple complete games:

In [None]:
def play_full_game(model, tokenizer, rows=6, cols=6, num_mines=5, seed=None,
                   max_moves=None, verbose=False):
    """Play a complete Minesweeper game, tracking detailed metrics.

    Supports any board size. max_moves defaults to rows*cols if not specified.
    """
    if max_moves is None:
        max_moves = rows * cols  # Scale with board size

    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)
    moves = 0
    invalid_streak = 0
    total_invalids = 0
    logical_moves = 0
    flags_correct = 0
    flags_wrong = 0

    while game.state() == "ongoing" and moves < max_moves:
        prompt = format_state_for_llm(game)
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           max_length=max_prompt_length + 100)
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.no_grad():
            output = model.generate(
                **inputs,
                temperature=0.3,
                max_new_tokens=128,
                do_sample=True,
                top_p=0.8,
                repetition_penalty=1.2,
            )

        # Decode ONLY generated tokens
        gen_tokens = output[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(gen_tokens, skip_special_tokens=True)
        action = parse_llm_action(response)

        if action is None:
            invalid_streak += 1
            total_invalids += 1
            if invalid_streak >= 3:
                break
            continue

        # Track logical moves (compute BEFORE applying action)
        safe_set, mine_set = _compute_safe_and_mine_cells(game)
        r, c = action["row"], action["col"]
        if action["type"] == "reveal" and (r, c) in safe_set:
            logical_moves += 1
        elif action["type"] == "flag" and (r, c) in mine_set:
            logical_moves += 1

        # Track flag accuracy
        if action["type"] == "flag":
            if 0 <= r < game.rows and 0 <= c < game.cols:
                if game._board[r][c] == -1:
                    flags_correct += 1
                else:
                    flags_wrong += 1

        if verbose:
            print(f"  Move {moves}: {action}")

        result = game.do_action(action)
        if result in ("mine", "win", "ok"):
            moves += 1
            invalid_streak = 0  # only a legal move resets the streak
        else:
            # Illegal move (out_of_bounds, already_revealed, flagged_cell,
            # invalid_flag, invalid_format): the game stays ongoing, so the
            # streak must keep growing or a repeated bad move loops forever
            total_invalids += 1
            invalid_streak += 1
            if invalid_streak >= 3:
                break

        if result in ("mine", "win"):
            break

    return {
        "game": game,
        "moves": moves,
        "logical_moves": logical_moves,
        "flags_correct": flags_correct,
        "flags_wrong": flags_wrong,
        "total_invalids": total_invalids,
        "result": game.state(),
        "progress": game.progress(),
        "config": f"{rows}x{cols}m{num_mines}",
    }


# ──────────────────────────────────────────────────────────────────────
# Comprehensive Multi-Size Evaluation
# ──────────────────────────────────────────────────────────────────────

EVAL_SUITE = [
    # (rows, cols, mines, num_games, label)
    (5, 5, 3,  15, "Small/Easy"),
    (6, 6, 5,  30, "Standard"),
    (6, 6, 7,  10, "Standard/Hard"),
    (7, 7, 7,  15, "Medium"),
    (8, 8, 10, 15, "Large/Hard"),
    (9, 9, 10, 10, "XL"),
    (6, 8, 6,   5, "Rectangular"),
]

FastLanguageModel.for_inference(model)
print(f"{'='*70}")
print(f"  COMPREHENSIVE EVALUATION — {sum(n for _,_,_,n,_ in EVAL_SUITE)} games across {len(EVAL_SUITE)} configs")
print(f"{'='*70}\n")

all_results = []
per_config_stats = {}

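for cfg_idx, (rows, cols, mines, num_games, label) in enumerate(EVAL_SUITE):
    config_key = f"{rows}x{cols}m{mines}"
    wins = 0
    config_results = []

    for i in range(num_games):
        # Deterministic seeds: the built-in hash() of a str is salted per
        # process, so it cannot be used to derive reproducible seeds
        seed = 5000 + 1000 * cfg_idx + i
        info = play_full_game(model, tokenizer, rows=rows, cols=cols,
                              num_mines=mines, seed=seed)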
        config_results.append(info)
        all_results.append(info)

        if info["result"] == "success":
            wins += 1

    # Per-config summary
    avg_moves = np.mean([r["moves"] for r in config_results])
    avg_logical = np.mean([r["logical_moves"] for r in config_results])
    avg_progress = np.mean([r["progress"] for r in config_results])
    avg_invalids = np.mean([r["total_invalids"] for r in config_results])
    wr = wins / num_games * 100

    per_config_stats[config_key] = {
        "label": label, "wins": wins, "total": num_games,
        "win_rate": wr, "avg_moves": avg_moves, "avg_progress": avg_progress,
    }

    print(f"  {label:16s} ({config_key:8s}): "
          f"{wins:2d}/{num_games:2d} wins ({wr:5.1f}%) | "
          f"moves={avg_moves:4.1f} | logical={avg_logical:4.1f} | "
          f"progress={avg_progress:.0%} | invalids={avg_invalids:.1f}")

# ── Overall Summary ──
total_games = len(all_results)
total_wins = sum(1 for r in all_results if r["result"] == "success")
total_fails = sum(1 for r in all_results if r["result"] == "failed")
total_ongoing = sum(1 for r in all_results if r["result"] == "ongoing")
fc = sum(r["flags_correct"] for r in all_results)
fw = sum(r["flags_wrong"] for r in all_results)

print(f"\n{'='*70}")
print(f"  OVERALL: {total_wins}/{total_games} wins ({total_wins/total_games*100:.1f}%)")
print(f"{'='*70}")
print(f"  Wins:     {total_wins:4d} ({total_wins/total_games*100:.1f}%)")
print(f"  Losses:   {total_fails:4d} ({total_fails/total_games*100:.1f}%)")
print(f"  Timeout:  {total_ongoing:4d} ({total_ongoing/total_games*100:.1f}%)")
print(f"  Avg moves:    {np.mean([r['moves'] for r in all_results]):.1f}")
print(f"  Avg progress: {np.mean([r['progress'] for r in all_results]):.0%}")
print(f"  Avg logical:  {np.mean([r['logical_moves'] for r in all_results]):.1f}")
if fc + fw > 0:
    print(f"  Flag accuracy: {fc}/{fc+fw} ({fc/(fc+fw)*100:.1f}%)")
else:
    print(f"  Flags: none placed")
print(f"{'='*70}")
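Each configuration is played only 5 to 30 times, so per-config win rates carry wide sampling error. A Wilson score interval for a binomial proportion is one standard way to quantify that error. The sketch below implements it. The helper name `wilson_interval` is hypothetical, and z = 1.96 gives roughly 95% coverage.

```python
import math
from typing import Tuple

def wilson_interval(wins: int, games: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a win rate of wins/games."""
    if games <= 0:
        raise ValueError("games must be positive")
    p = wins / games
    denom = 1 + z * z / games
    center = (p + z * z / (2 * games)) / denom
    half = z * math.sqrt(p * (1 - p) / games + z * z / (4 * games * games)) / denom
    return max(0.0, center - half), min(1.0, center + half)

lo, hi = wilson_interval(3, 10)
print(f"3/10 wins -> 95% CI {lo:.0%} to {hi:.0%}")
```

Applied to each entry of `per_config_stats`, for example `wilson_interval(s["wins"], s["total"])`, this shows whether differences between board sizes exceed the noise from so few games.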

# Save the Model

Save your trained model for competition submission:

In [None]:
# Save LoRA adapters
model.save_pretrained("my_minesweeper_model")
tokenizer.save_pretrained("my_minesweeper_model")
print("✅ LoRA adapters saved to: my_minesweeper_model/")

# Save merged model in 16-bit (this local directory is the one used for evaluation)
if True:
    model.save_pretrained_merged(
        "my_minesweeper_model_merged",
        tokenizer,
        save_method = "merged_16bit"
    )
    print("✅ Merged 16-bit model saved to: my_minesweeper_model_merged/")

# Fixes & Improvements Applied

## Critical Bugs Fixed
| # | Bug | Fix |
|---|-----|-----|
| 1 | `do_action()` set `_state="failed"` for ALL invalid moves (out-of-bounds, already-revealed, etc.) — game ended on typos | Only `mine` sets state to "failed"; other invalid moves return error string but keep game "ongoing" |
| 2 | Reward functions need `seed`/`move_history` from the dataset. GRPOTrainer passes everything by keyword (`prompts=`, `completions=`, plus every extra dataset column), so the signature must accept these names | Changed all reward function signatures to `(prompts, completions, **kwargs)`, matching TRL's calling convention; extra columns arrive in `kwargs` |
| 3 | `gameplay_scores` and `strategic_reward` hardcoded `MinesweeperGame(rows=6, cols=6, num_mines=5)` — no variable board support | Added `board_rows`, `board_cols`, `board_mines` to dataset; `_reconstruct_game()` reads them from kwargs |
| 4 | `max_prompt_length=700` truncated prompts with hints on 6x6 boards; larger boards lost even more of the prompt | Increased to 1400 tokens; prompts made more compact |
| 5 | `_compute_safe_cells` and `_compute_mine_cells` were separate passes over the board, each recomputed several times per reward evaluation | Combined into `_compute_safe_and_mine_cells()`: single pass, returns both sets |
| 6 | Eval callback decoded FULL output including prompt before parsing | Now decodes only generated tokens: `output[0][input_ids.shape[1]:]` |
| 7 | Win-by-flagging was impossible: `_check_win()` only checked revealed cells | Documented: standard Minesweeper wins by revealing all safe cells (flags optional) — this is correct behavior |
| 8 | `remove_unused_columns` not explicitly set — could strip `seed`/`move_history` columns | Explicitly set `remove_unused_columns=False` in GRPOConfig |
| 9 | Dataset only flagged random cells (often non-mines) creating confusing training data | `_smart_flag()` only flags logically certain mines; `_smart_reveal()` prefers safe cells |
| 10 | Invalid JSON scored -50 in gameplay_scores + -5 in valid_json_reward = -35.75 weighted (worse than hitting a mine at -17.5) | Rebalanced: invalid JSON = -10 in gameplay, -3 in format. Mine = -25. Proper severity ordering |

## Design Improvements
| # | Change | Impact |
|---|--------|--------|
| 11 | Variable board sizes: 12 configs (5×5 to 10×10, rectangular) with weighted sampling | Model generalizes to any board size |
| 12 | 3000 training samples (up from 2000) with board diversity | More coverage |
| 13 | Smart move history generation using logical deduction | Cleaner training data; games don't end early from random mine hits |
| 14 | Multi-size evaluation suite: 100 games across 7 board configs | Robust performance measurement |
| 15 | `beta=0.04` mild KL penalty (was 0.0) | Prevents reward hacking / policy drift |
| 16 | `max_completion_length=128` (was 200) | A JSON action is well under 40 tokens (~14 in the training log); tighter limit |
| 17 | `max_seq_length=2048` (was 1024) | Supports larger board prompts |
| 18 | Compact board representation in prompt (row-per-line vs nested JSON) | Fewer tokens for same information |
| 19 | `parse_llm_action` now coerces row/col to int | Handles string numbers like "2" |
| 20 | Edge/corner opening bonus scaled for variable board sizes | Better opening strategy across sizes |
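Row 10's severity figures are weighted sums under `reward_weights = [0.15, 0.70, 0.15]` (format, gameplay, strategic). The check below reproduces that arithmetic under two assumptions. First, `strategic_reward` contributes 0 for an unparseable action, as the code above does when `parse_llm_action` fails. Second, the mine case counts only the gameplay term, as the table does.

```python
W_FORMAT, W_GAMEPLAY = 0.15, 0.70  # weights of valid_json_reward and gameplay_scores

def weighted(format_score: float, gameplay_score: float) -> float:
    """Weighted contribution of the format and gameplay rewards (strategic = 0)."""
    return W_FORMAT * format_score + W_GAMEPLAY * gameplay_score

old_invalid_json = weighted(-5.0, -50.0)   # before rebalancing
new_invalid_json = weighted(-3.0, -10.0)   # after rebalancing
mine_hit = weighted(0.0, -25.0)            # gameplay term only, as in the table

print(f"old invalid JSON: {old_invalid_json:.2f}")
print(f"new invalid JSON: {new_invalid_json:.2f}")
print(f"mine hit:         {mine_hit:.2f}")
```

After rebalancing, malformed JSON (-7.45) costs less than stepping on a mine (-17.50). The strongest penalty now falls on moves that lose the game.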