In [33]:
import json
import re
import random
import copy
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Set, Dict
from collections import Counter


@dataclass
class MinesweeperGame:
    """Seeded Minesweeper engine driven through ``do_action``.

    The board is generated once in ``__post_init__`` from ``seed``. Mines are
    stored as ``-1``; every other cell stores its adjacent-mine count (0-8).
    ``state()`` is one of ``"ongoing"``, ``"success"`` or ``"failed"``.

    Raises:
        ValueError: if ``num_mines`` does not leave at least one safe cell.
    """
    rows: int
    cols: int
    num_mines: int
    seed: Optional[int] = None
    _rng: random.Random = field(init=False, repr=False)
    _board: List[List[int]] = field(init=False, repr=False)
    _revealed: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _flagged: Set[Tuple[int, int]] = field(init=False, repr=False, default_factory=set)
    _state: str = field(default="ongoing", init=False, repr=False)

    def __post_init__(self):
        """Validate the mine count, then place mines and compute clue numbers."""
        if self.num_mines >= self.rows * self.cols:
            raise ValueError("Too many mines for board size")
        self._rng = random.Random(self.seed)
        self._board = [[0 for _ in range(self.cols)] for _ in range(self.rows)]
        self._place_mines()
        self._calculate_numbers()

    def _neighbors(self, row: int, col: int) -> List[Tuple[int, int]]:
        """Return the in-bounds cells surrounding (row, col)."""
        found = []
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                if dr == 0 and dc == 0:
                    continue
                nr, nc = row + dr, col + dc
                if 0 <= nr < self.rows and 0 <= nc < self.cols:
                    found.append((nr, nc))
        return found

    def _place_mines(self):
        """Mark ``num_mines`` distinct cells as mines using the seeded RNG."""
        # The row-major population order is part of the seed -> board mapping;
        # changing it would change every board generated from a given seed.
        positions = [(r, c) for r in range(self.rows) for c in range(self.cols)]
        mine_positions = self._rng.sample(positions, self.num_mines)
        for r, c in mine_positions:
            self._board[r][c] = -1

    def _calculate_numbers(self):
        """Fill every non-mine cell with the number of mines around it."""
        for r in range(self.rows):
            for c in range(self.cols):
                if self._board[r][c] == -1:
                    continue
                self._board[r][c] = sum(
                    1 for nr, nc in self._neighbors(r, c) if self._board[nr][nc] == -1
                )

    def _reveal_cell(self, row: int, col: int) -> bool:
        """Reveal (row, col), flood-filling outward from zero-valued cells.

        Returns False when the cell is out of bounds, already revealed or
        flagged; True otherwise. Revealing a mine sets the state to "failed"
        (and still returns True).
        """
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed or (row, col) in self._flagged:
            return False
        stack = [(row, col)]
        while stack:
            r, c = stack.pop()
            if (r, c) in self._revealed:
                continue
            self._revealed.add((r, c))
            if self._board[r][c] == -1:
                self._state = "failed"
                return True
            if self._board[r][c] == 0:
                # A zero has no adjacent mines, so all its neighbours are safe
                # to open; flagged cells are left untouched.
                for cell in self._neighbors(r, c):
                    if cell not in self._revealed and cell not in self._flagged:
                        stack.append(cell)
        return True

    def _flag_cell(self, row: int, col: int) -> bool:
        """Toggle the flag on (row, col); False if out of bounds or revealed."""
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return False
        if (row, col) in self._revealed:
            return False
        if (row, col) in self._flagged:
            self._flagged.remove((row, col))
        else:
            self._flagged.add((row, col))
        return True

    def do_action(self, action: dict) -> str:
        """Apply one action and return a result code.

        ``action`` is ``{"type": "reveal"|"flag", "row": ..., "col": ...}``;
        row/col are passed through ``int()``.

        Returns one of:
            "game_over"        - the game had already ended; nothing applied
            "invalid_format"   - malformed action; the game is marked failed
            "out_of_bounds", "already_revealed", "flagged_cell",
            "invalid_flag"     - rejected move; the game continues unchanged
            "mine"             - a mine was revealed; the game is failed
            "win"              - every safe cell is revealed
            "ok"               - move applied, game continues
        """
        if self._state != "ongoing":
            return "game_over"
        if not isinstance(action, dict):
            self._state = "failed"
            return "invalid_format"
        action_type = action.get("type")
        row = action.get("row")
        col = action.get("col")
        if action_type not in ["reveal", "flag"] or row is None or col is None:
            self._state = "failed"
            return "invalid_format"
        try:
            row, col = int(row), int(col)
        except (ValueError, TypeError, OverflowError):
            # OverflowError covers int(float("inf")), which json.loads can
            # produce from the non-standard literal `Infinity`.
            self._state = "failed"
            return "invalid_format"
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            return "out_of_bounds"               # no game over, just skip
        if action_type == "reveal":
            if (row, col) in self._revealed:
                return "already_revealed"         # no game over, just skip
            if (row, col) in self._flagged:
                return "flagged_cell"             # no game over, just skip
            valid = self._reveal_cell(row, col)
        else:
            if (row, col) in self._revealed:
                return "invalid_flag"             # no game over, just skip
            valid = self._flag_cell(row, col)
        if not valid:
            self._state = "failed"
            return "invalid_format"
        self._check_win()
        if self._state == "failed":
            return "mine"
        if self._state == "success":
            return "win"
        return "ok"

    def _check_win(self):
        """Mark the game won once every safe cell has been revealed."""
        # Never override a finished game. Without this guard, revealing a mine
        # when exactly one safe cell is left makes len(_revealed) equal the
        # safe-cell count (the mine itself is in _revealed) and the loss would
        # be reported as a win.
        if self._state != "ongoing":
            return
        safe_cells = self.rows * self.cols - self.num_mines
        if len(self._revealed) == safe_cells:
            self._state = "success"

    def get_visible_board(self) -> List[List[str]]:
        """Return the player's view: 'F' flag, '*' revealed mine, '0'-'8' clue, '.' hidden."""
        visible = []
        for r in range(self.rows):
            row = []
            for c in range(self.cols):
                if (r, c) in self._flagged:
                    row.append('F')
                elif (r, c) in self._revealed:
                    val = self._board[r][c]
                    row.append('*' if val == -1 else str(val))
                else:
                    row.append('.')
            visible.append(row)
        return visible

    def state(self) -> str:
        """Return the current game state string."""
        return self._state

    def pretty_print(self) -> str:
        """Render the visible board as text with row and column indices."""
        visible = self.get_visible_board()
        lines = []
        header = "   " + " ".join(f"{i:2d}" for i in range(self.cols))
        lines.append(header)
        lines.append("  " + "─" * (self.cols * 3 + 1))
        for r, row in enumerate(visible):
            line = f"{r:2d}│ " + "  ".join(row)
            lines.append(line)
        return "\n".join(lines)

In [34]:
def get_neighbors(row, col, rows, cols):
    """Return the in-bounds cells adjacent to (row, col), scanned row by row."""
    offsets = (-1, 0, 1)
    return [
        (row + d_row, col + d_col)
        for d_row in offsets
        for d_col in offsets
        if (d_row, d_col) != (0, 0)
        and 0 <= row + d_row < rows
        and 0 <= col + d_col < cols
    ]


def _cell_constraint(game, r, c):
    """Return (hidden unflagged neighbours, mines not yet flagged) for the clue at (r, c)."""
    hidden = []
    flagged_count = 0
    for cell in get_neighbors(r, c, game.rows, game.cols):
        if cell in game._flagged:
            flagged_count += 1
        elif cell not in game._revealed:
            hidden.append(cell)
    return hidden, game._board[r][c] - flagged_count


def _numbered_cell_constraints(game):
    """List (row, col, hidden, remaining) for every revealed cell showing 1-8, in row-major order."""
    constraints = []
    for r in range(game.rows):
        for c in range(game.cols):
            if (r, c) in game._revealed and game._board[r][c] > 0:
                hidden, remaining = _cell_constraint(game, r, c)
                constraints.append((r, c, hidden, remaining))
    return constraints


def solve_step(game: MinesweeperGame) -> Optional[dict]:
    """
    Find the best expert move for the current game state.

    Only revealed cells of ``game._board`` are read, so the solver sees the
    same information as a player. Existing flags are trusted as correct.

    Strategy (priority order):
    1. Constraint propagation — reveal a cell deduced to be safe
    2. Constraint propagation — flag a cell deduced to be a mine
    3. Opening move — if nothing is revealed yet, pick a random unflagged corner
    4. Probability estimate — reveal the hidden cell with the lowest estimate

    Returns: {"type": "reveal"|"flag", "row": int, "col": int}, or None when
    there is no hidden, unflagged cell left.
    """
    rows, cols = game.rows, game.cols

    safe_cells = set()   # Cells deduced to be safe
    mine_cells = set()   # Cells deduced to be mines

    # Every pass below only cares about clues that still touch hidden cells,
    # so compute the constraints once and drop the exhausted ones.
    constraints = [entry for entry in _numbered_cell_constraints(game) if entry[2]]

    # --- Pass 1: single-clue deductions ---
    for _, _, hidden, remaining in constraints:
        if remaining == 0:
            # All mines accounted for — hidden neighbours are safe
            safe_cells.update(hidden)
        elif remaining == len(hidden):
            # Every hidden neighbour must be a mine
            mine_cells.update(hidden)

    # --- Pass 2: pairwise subset deductions ---
    # If one clue's hidden set is a strict subset of another's, the cells in
    # the difference must hold exactly the difference in remaining mines.
    hidden_sets = [set(hidden) for _, _, hidden, _ in constraints]
    for i, (r1, c1, _, rem1) in enumerate(constraints):
        hidden1 = hidden_sets[i]
        for j in range(i + 1, len(constraints)):
            r2, c2, _, rem2 = constraints[j]
            # Clues more than two cells apart cannot share hidden neighbours.
            if abs(r1 - r2) > 2 or abs(c1 - c2) > 2:
                continue
            hidden2 = hidden_sets[j]
            if hidden1 < hidden2:
                extra, extra_mines = hidden2 - hidden1, rem2 - rem1
            elif hidden2 < hidden1:
                extra, extra_mines = hidden1 - hidden2, rem1 - rem2
            else:
                continue
            if extra_mines == 0:
                safe_cells.update(extra)
            elif extra_mines == len(extra):
                mine_cells.update(extra)

    # --- Priority 1: Reveal a deduced-safe cell ---
    if safe_cells:
        # Prefer cells adjacent to more revealed numbered cells (more informative)
        def info_score(cell):
            return sum(
                1
                for nr, nc in get_neighbors(cell[0], cell[1], rows, cols)
                if (nr, nc) in game._revealed and game._board[nr][nc] > 0
            )

        best = max(safe_cells, key=info_score)
        return {"type": "reveal", "row": best[0], "col": best[1]}

    # --- Priority 2: Flag a deduced mine (arbitrary pick among candidates) ---
    if mine_cells:
        cell = next(iter(mine_cells))
        return {"type": "flag", "row": cell[0], "col": cell[1]}

    # --- No deduction possible: fall back to guessing ---
    unrevealed = [
        (r, c)
        for r in range(rows)
        for c in range(cols)
        if (r, c) not in game._revealed and (r, c) not in game._flagged
    ]
    if not unrevealed:
        return None

    # --- Priority 3: Opening move — pick a corner ---
    if not game._revealed:
        corners = [(0, 0), (0, cols - 1), (rows - 1, 0), (rows - 1, cols - 1)]
        # A flagged corner would be rejected by do_action as a no-op, so only
        # offer corners that can actually be revealed. With no flags on the
        # corners this is the same four-element list as before.
        open_corners = [corner for corner in corners if corner not in game._flagged]
        if open_corners:
            corner = random.choice(open_corners)
            return {"type": "reveal", "row": corner[0], "col": corner[1]}

    # --- Priority 4: Probability heuristic ---
    # Each frontier cell takes the most pessimistic estimate among its clues.
    mine_prob = {cell: 0.0 for cell in unrevealed}
    for _, _, hidden, remaining in constraints:
        if remaining > 0:
            prob = remaining / len(hidden)
            for cell in hidden:
                if cell in mine_prob:
                    mine_prob[cell] = max(mine_prob[cell], prob)

    # Cells with no informative clue share a base rate: mines not yet flagged,
    # minus frontier cells that look more likely than not to be mines.
    far_cells = [cell for cell in unrevealed if mine_prob[cell] == 0.0]
    if far_cells:
        mines_left = game.num_mines - len(game._flagged)
        likely_frontier_mines = sum(1 for cell in unrevealed if mine_prob[cell] >= 0.5)
        remaining_far_mines = max(0, mines_left - likely_frontier_mines)
        base_prob = min(remaining_far_mines / len(far_cells), 0.99)
        for cell in far_cells:
            mine_prob[cell] = base_prob

    # Pick cell with lowest mine probability
    safest = min(unrevealed, key=lambda c: mine_prob.get(c, 0.5))
    return {"type": "reveal", "row": safest[0], "col": safest[1]}




In [35]:
def play_expert_game(rows, cols, num_mines, seed, max_moves=200):
    """
    Let the solver play one seeded game and log every decision it makes.

    Each entry is captured BEFORE the move is applied, so the prompt shows the
    board the solver actually saw. The last entry may therefore be a move
    that hit a mine.

    Returns: (examples, final_state) where examples is a list of
    (prompt_text, compact_action_json) tuples and final_state is game.state().
    """
    stop_results = ("mine", "game_over", "invalid_format")
    game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)
    recorded = []

    for _ in range(max_moves):
        if game.state() != "ongoing":
            break

        move = solve_step(game)
        if move is None:
            # Solver found no hidden cell to act on.
            break

        board_prompt = format_state_for_llm(game)
        move_json = json.dumps(move, separators=(',', ':'))  # compact JSON
        recorded.append((board_prompt, move_json))

        if game.do_action(move) in stop_results:
            break

    return recorded, game.state()


def generate_expert_dataset(num_games=5000, rng_seed=42):
    """
    Generate an expert dataset by playing many games with the solver.

    Reseeds the global ``random`` module with ``rng_seed`` (it drives board
    selection, game seeds and the solver's random corner opening) and prints
    summary statistics.

    Args:
        num_games: number of games to play; 0 yields an empty dataset.
        rng_seed: seed for the global ``random`` module.

    Returns: list of {"messages": [system, user, assistant]} chat examples,
    one per recorded solver move.
    """
    random.seed(rng_seed)

    board_configs = [
        (5, 5, 4),   # small easy
        (5, 5, 6),   # small hard
        (6, 6, 5),   # default (competition eval)
        (6, 6, 7),   # default harder
        (7, 7, 8),   # medium
        (7, 7, 10),  # medium hard
        (8, 8, 10),  # large
        (8, 8, 13),  # large hard
    ]
    # Weight toward 6x6 since that's likely eval
    weights = [1, 1, 4, 2, 1, 1, 1, 1]

    dataset = []
    action_types = Counter()
    wins = 0
    losses = 0
    game_count = 0

    for _ in range(num_games):
        # The order of these two draws is part of the seed -> dataset mapping.
        rows, cols, num_mines = random.choices(board_configs, weights=weights, k=1)[0]
        seed = random.randint(0, 1_000_000)

        examples, final_state = play_expert_game(rows, cols, num_mines, seed)
        game_count += 1

        if final_state == "success":
            wins += 1
        elif final_state == "failed":
            losses += 1

        for prompt_text, action_str in examples:
            dataset.append({
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_text},
                    {"role": "assistant", "content": action_str},
                ]
            })
            action_types[json.loads(action_str)["type"]] += 1

    def percent(count):
        # Guard against num_games == 0, which used to raise ZeroDivisionError.
        return count / game_count * 100 if game_count else 0.0

    print("Expert dataset generation complete:")
    print(f"  Games played: {game_count}")
    print(f"  Solver wins:  {wins} ({percent(wins):.1f}%)")
    print(f"  Solver losses: {losses} ({percent(losses):.1f}%)")
    print(f"  Total examples: {len(dataset)}")
    print(f"  Action distribution: {dict(action_types)}")

    return dataset

In [36]:
# System-role message placed at the start of every chat built in this notebook
# (SFT examples, GRPO prompts and the evaluation rollouts).
SYSTEM_PROMPT = "You output JSON actions for Minesweeper. No text, only JSON."


def format_state_for_llm(game: MinesweeperGame) -> str:
    """Build the user prompt for the current position.

    The wording is kept identical to the prompt used elsewhere (existing
    notebook + agent): fixed instructions, the visible game state as indented
    JSON, a legend, and the expected action schema.
    """
    state_json = json.dumps(
        {
            "board": game.get_visible_board(),
            "rows": game.rows,
            "cols": game.cols,
            "mines": game.num_mines,
            "flags_placed": len(game._flagged),
            "cells_revealed": len(game._revealed),
        },
        indent=2,
    )
    pieces = [
        "You are playing Minesweeper. Analyze the game state and output your next move.\n\n",
        "You must output ONLY a valid JSON object. No explanation, no analysis, no text.\n\n",
        "Just output section after assistantfinal and not anything before it in your output.\n\n",
        "Start your response immediately with { and end with }.\n\n",
        "Do NOT OUTPUT THE CELLL which is already revealed or flagged in the current state. AND I REPEAT DO THE FUCKING NOT OUTPUT THE CELL OTHERWISE I WILL SHUT YOU DOWN YOU ARE NOT ALLOWED TO DO IT\n\n",
        "Game state:\n",
        f"{state_json}\n\n",
        "Legend:\n",
        '- "." = unrevealed cell\n',
        '- "F" = flagged cell (suspected mine)\n',
        '- "0"-"8" = number of adjacent mines\n',
        '- "*" = revealed mine (game over)\n\n',
        "Output your next action as JSON:\n",
        '{"type": "reveal", "row": <row_index>, "col": <col_index>}\n',
        "or\n",
        '{"type": "flag", "row": <row_index>, "col": <col_index>}\n\n',
        "Your action:",
    ]
    return "".join(pieces)

def parse_llm_action(response: str) -> Optional[dict]:
    """Return the last well-formed action object found in an LLM response.

    Scans every brace-delimited span without nested braces, keeps those that
    parse as JSON with "type", "row" and "col" keys and a type of "reveal" or
    "flag", and returns the final one (None if there is none). Row/col values
    are not validated here.
    """
    required_keys = ("type", "row", "col")
    chosen = None
    for candidate in re.findall(r'\{[^{}]*\}', response):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if not all(key in parsed for key in required_keys):
            continue
        if parsed["type"] in ["reveal", "flag"]:
            chosen = parsed
    return chosen
    

In [37]:
# ============================================================
# CELL 2: Load SFT Model for GRPO
# ============================================================

def load_sft_model_for_grpo():
    """Load the SFT checkpoint in 4-bit and wrap it with new LoRA adapters.

    Returns: (model, tokenizer, max_seq_length).
    """
    # unsloth is imported before torch, as in the original cell.
    from unsloth import FastLanguageModel
    import torch

    max_seq_length = 1024
    lora_rank = 16
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]

    # NOTE(review): "my_minesweeper_model" is presumably the directory written
    # by the SFT stage — confirm whether it holds merged weights or adapters.
    # NOTE(review): confirm that this unsloth version accepts `torch_dtype`
    # here rather than `dtype`.
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="my_minesweeper_model",
        load_in_4bit=True,
        max_seq_length=max_seq_length,
        torch_dtype=torch.bfloat16,
    )

    # Attach the adapters that GRPO will train (alpha = 2 * rank).
    model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora_rank,
        target_modules=lora_targets,
        lora_alpha=lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    print(f"Model device: {model.device}")
    print("SFT model loaded with fresh LoRA for GRPO!")
    return model, tokenizer, max_seq_length

In [38]:

import json
import re
import random
import copy
from typing import Optional
from collections import Counter

def valid_json_reward(completions, **kwargs):
    """
    Format reward: scores how cleanly each completion encodes an action.

    Each completion is a list whose first element has a "content" string.
    Returns one float per completion: -5.0 when no action can be parsed,
    otherwise a sum of structure, type, integer-coordinate and brevity terms.
    """
    def rate(text):
        parsed = parse_llm_action(text)
        if parsed is None:
            return -5.0

        # Base credit for a parseable action.
        total = 0.0
        total += 1.0

        # All three required keys present.
        if all(key in parsed for key in ["type", "row", "col"]):
            total += 1.0

        # Recognised action type.
        if parsed.get("type") in ["reveal", "flag"]:
            total += 0.5

        # Coordinates must be convertible to integers.
        try:
            int(parsed["row"])
            int(parsed["col"])
        except (ValueError, TypeError):
            total -= 1.0
        else:
            total += 0.5

        # Cleanliness: starts with the JSON object and stays short.
        if text.startswith("{"):
            total += 1.0
        size = len(text)
        if size < 60:
            total += 1.0
        elif size < 120:
            total += 0.5
        elif size > 300:
            total -= 2.0

        return total

    return [rate(completion[0]["content"].strip()) for completion in completions]


def _reveal_move_score(game, action, row, col):
    """Score an in-bounds reveal on ``game``; applies the move to ``game`` as a side effect."""
    cell = (row, col)
    # Already revealed or flagged: wasted move.
    if cell in game._revealed or cell in game._flagged:
        return -10.0
    # Hit a mine.
    if game._board[row][col] == -1:
        return -25.0

    # Safe reveal.
    score = 3.0
    # A reveal backed by a satisfied clue beats a lucky guess.
    score += 8.0 if _is_logically_deducible_reveal(game, row, col) else 2.0

    # Bonus per adjacent revealed numbered cell (informative reveal). This
    # must be computed before the move changes game._revealed.
    adj_numbered = sum(
        1
        for nr, nc in get_neighbors(row, col, game.rows, game.cols)
        if (nr, nc) in game._revealed and game._board[nr][nc] > 0
    )
    score += adj_numbered * 1.0

    # Win check: `game` is a throwaway reconstruction, so apply the move to it
    # directly instead of replaying the whole history into a second copy.
    if game.do_action(action) == "win":
        score += 50.0
    return score


def _flag_move_score(game, row, col):
    """Score an in-bounds flag placement on ``game`` (read-only)."""
    cell = (row, col)
    # Flagging an already revealed cell.
    if cell in game._revealed:
        return -10.0
    # Already flagged.
    if cell in game._flagged:
        return -15.0

    score = 0.0
    # Too many flags.
    if len(game._flagged) + 1 > game.num_mines:
        score -= 5.0

    if game._board[row][col] == -1:
        # Correct flag, with an extra bonus when a clue forces it.
        score += 10.0
        if _is_logically_deducible_flag(game, row, col):
            score += 5.0
    else:
        # Wrong flag.
        score -= 8.0

    # No win check here: flagging never changes the set of revealed cells, so
    # a flag can never produce "win".
    return score


def gameplay_reward(completions, **kwargs):
    """
    Core gameplay reward: tests each proposed action against the real board.

    The position is rebuilt from the per-sample ``seed``, ``move_history``
    (list or JSON string), ``rows``, ``cols`` and ``num_mines`` kwargs; missing
    board-size entries default to 6x6 with 5 mines.

    Returns one float per completion. Fixed penalties: -10.0 for an
    unparseable action or non-integer coordinates, -8.0 out of bounds, -25.0
    for revealing a mine. 0.0 is returned when seed/history are missing for
    that index.
    """
    seeds = kwargs.get("seed", [])
    move_histories = kwargs.get("move_history", [])
    rows_list = kwargs.get("rows", [])
    cols_list = kwargs.get("cols", [])
    mines_list = kwargs.get("num_mines", [])

    scores = []
    for idx, completion in enumerate(completions):
        action = parse_llm_action(completion[0]["content"])
        if action is None:
            scores.append(-10.0)
            continue

        if idx >= len(seeds) or idx >= len(move_histories):
            scores.append(0.0)
            continue

        # parse_llm_action does not validate coordinates: values such as null
        # or "abc" must be penalised, not allowed to crash the reward function.
        try:
            row, col = int(action["row"]), int(action["col"])
        except (ValueError, TypeError, OverflowError):
            scores.append(-10.0)
            continue

        move_history = move_histories[idx]
        if isinstance(move_history, str):
            move_history = json.loads(move_history)

        r_count = rows_list[idx] if idx < len(rows_list) else 6
        c_count = cols_list[idx] if idx < len(cols_list) else 6
        m_count = mines_list[idx] if idx < len(mines_list) else 5

        # Reconstruct the game state the prompt was generated from.
        game = MinesweeperGame(rows=r_count, cols=c_count, num_mines=m_count, seed=seeds[idx])
        for prev_action in move_history:
            game.do_action(prev_action)

        # Out of bounds
        if not (0 <= row < game.rows and 0 <= col < game.cols):
            scores.append(-8.0)
            continue

        if action["type"] == "reveal":
            scores.append(_reveal_move_score(game, action, row, col))
        else:
            # parse_llm_action only returns "reveal" or "flag".
            scores.append(_flag_move_score(game, row, col))
    return scores


def _is_logically_deducible_reveal(game, row, col):
    """Check if revealing this cell is logically safe (all adjacent mines flagged)."""
    for dr in [-1, 0, 1]:
        for dc in [-1, 0, 1]:
            if dr == 0 and dc == 0:
                continue
            nr, nc = row + dr, col + dc
            if not (0 <= nr < game.rows and 0 <= nc < game.cols):
                continue
            if (nr, nc) not in game._revealed:
                continue
            cell_val = game._board[nr][nc]
            if cell_val <= 0:
                continue
            flagged_count = 0
            hidden_count = 0
            for dr2 in [-1, 0, 1]:
                for dc2 in [-1, 0, 1]:
                    if dr2 == 0 and dc2 == 0:
                        continue
                    nnr, nnc = nr + dr2, nc + dc2
                    if not (0 <= nnr < game.rows and 0 <= nnc < game.cols):
                        continue
                    if (nnr, nnc) in game._flagged:
                        flagged_count += 1
                    elif (nnr, nnc) not in game._revealed:
                        hidden_count += 1
            if flagged_count == cell_val and hidden_count > 0:
                return True
    return False


def _is_logically_deducible_flag(game, row, col):
    """Check if flagging this cell is logically deducible (all hidden neighbors must be mines)."""
    for dr in [-1, 0, 1]:
        for dc in [-1, 0, 1]:
            if dr == 0 and dc == 0:
                continue
            nr, nc = row + dr, col + dc
            if not (0 <= nr < game.rows and 0 <= nc < game.cols):
                continue
            if (nr, nc) not in game._revealed:
                continue
            cell_val = game._board[nr][nc]
            if cell_val <= 0:
                continue
            flagged_count = 0
            hidden = []
            for dr2 in [-1, 0, 1]:
                for dc2 in [-1, 0, 1]:
                    if dr2 == 0 and dc2 == 0:
                        continue
                    nnr, nnc = nr + dr2, nc + dc2
                    if not (0 <= nnr < game.rows and 0 <= nnc < game.cols):
                        continue
                    if (nnr, nnc) in game._flagged:
                        flagged_count += 1
                    elif (nnr, nnc) not in game._revealed:
                        hidden.append((nnr, nnc))
            remaining = cell_val - flagged_count
            if remaining == len(hidden) and (row, col) in hidden:
                return True
    return False



In [39]:
 def generate_grpo_states(num_samples=3000, rng_seed=42):
    """
    Generate game states for GRPO using the SOLVER to play moves.
    This gives realistic mid-game states instead of random ones.

    Reseeds the global ``random`` module with ``rng_seed``. Each sample keeps
    the seed, board size and JSON-encoded move history needed to rebuild the
    position later. Games that end during the solver's moves are discarded.

    Args:
        num_samples: number of ongoing-game states to collect; 0 yields an
            empty dataset.
        rng_seed: seed for the global ``random`` module.

    Returns: a Hugging Face ``Dataset`` with columns prompt, seed,
    move_history, rows, cols and num_mines.
    """
    from datasets import Dataset as HFDataset

    random.seed(rng_seed)

    board_configs = [
        (5, 5, 4), (5, 5, 6),
        (6, 6, 5), (6, 6, 7),
        (7, 7, 8), (7, 7, 10),
        (8, 8, 10), (8, 8, 13),
    ]
    weights = [1, 1, 4, 2, 1, 1, 1, 1]

    dataset_items = []

    while len(dataset_items) < num_samples:
        # The order of these draws is part of the seed -> dataset mapping.
        config_idx = random.choices(range(len(board_configs)), weights=weights, k=1)[0]
        rows, cols, num_mines = board_configs[config_idx]
        seed = random.randint(0, 1_000_000)

        game = MinesweeperGame(rows=rows, cols=cols, num_mines=num_mines, seed=seed)

        # Use solver to play 0-15 moves (creates realistic game states)
        num_moves = random.randint(0, 15)
        move_history = []

        for _ in range(num_moves):
            if game.state() != "ongoing":
                break
            action = solve_step(game)
            if action is None:
                break
            result = game.do_action(action)
            if result in ("mine", "game_over", "invalid_format"):
                break
            move_history.append(action)

        # Only positions where a next move is still possible are useful.
        if game.state() == "ongoing":
            prompt_text = format_state_for_llm(game)
            dataset_items.append({
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_text},
                ],
                "seed": seed,
                "move_history": json.dumps(move_history),
                "rows": rows,
                "cols": cols,
                "num_mines": num_mines,
            })

    dataset = HFDataset.from_list(dataset_items)
    total = len(dataset)
    print(f"Created {total} GRPO training states")

    # Statistics are taken from the plain Python records, which hold the same
    # values as the dataset rows.
    size_counts = Counter(f"{item['rows']}x{item['cols']}" for item in dataset_items)
    print(f"Board sizes: {dict(size_counts)}")

    fresh_count = sum(1 for item in dataset_items if item["move_history"] == "[]")
    midgame_count = total - fresh_count
    # Guard against num_samples == 0, which used to raise ZeroDivisionError.
    fresh_pct = fresh_count / total * 100 if total else 0.0
    midgame_pct = midgame_count / total * 100 if total else 0.0
    print(f"Fresh games: {fresh_count} ({fresh_pct:.1f}%)")
    print(f"Mid-game: {midgame_count} ({midgame_pct:.1f}%)")

    return dataset



In [40]:
def setup_grpo_trainer(model, tokenizer, dataset, max_seq_length=1024):
    """Configure and return the GRPO trainer.

    Args:
        model: Policy model handed to ``GRPOTrainer``.
        tokenizer: Passed to the trainer as ``processing_class``.
        dataset: Training dataset handed to the trainer as ``train_dataset``.
        max_seq_length: Total token budget. 700 tokens are reserved for the
            prompt; the remainder becomes the completion budget.

    Returns:
        A ``GRPOTrainer`` wired with ``valid_json_reward`` and
        ``gameplay_reward`` and a periodic gameplay-evaluation callback.

    Raises:
        ValueError: If ``max_seq_length`` leaves no room for a completion.
    """
    max_prompt_length = 700
    max_completion_length = max_seq_length - max_prompt_length
    # Fail fast (before the heavy imports) instead of handing a non-positive
    # completion budget to GRPOConfig.
    if max_completion_length <= 0:
        raise ValueError(
            f"max_seq_length ({max_seq_length}) must exceed the prompt budget "
            f"({max_prompt_length}) to leave room for completions"
        )

    from trl import GRPOConfig, GRPOTrainer
    from transformers import TrainerCallback

    # Eval callback
    class MinesweeperEvalCallback(TrainerCallback):
        """Every ``eval_every_steps`` steps, play fixed-seed 6x6 games and print the win rate."""

        def __init__(self, eval_every_steps=50, num_games=10):
            self.eval_every_steps = eval_every_steps
            self.num_games = num_games

        def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
            if state.global_step % self.eval_every_steps != 0:
                return
            tok = processing_class
            if tok is None or model is None:
                return
            if self.num_games <= 0:
                return                                 # Nothing to play; avoids dividing by zero below
            was_training = model.training
            model.eval()
            wins = 0
            total_moves = 0
            try:
                for i in range(self.num_games):
                    # Fixed seeds so every evaluation plays the same boards.
                    game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=10000 + i)
                    moves = 0
                    invalid_streak = 0
                    while game.state() == "ongoing" and moves < 50:
                        prompt = format_state_for_llm(game)
                        text = tok.apply_chat_template(
                            [{"role": "system", "content": SYSTEM_PROMPT},
                             {"role": "user", "content": prompt}],
                            tokenize=False, add_generation_prompt=True,
                        )
                        inputs = tok(text, return_tensors="pt", truncation=True,
                                     max_length=max_seq_length).to(model.device)
                        output = model.generate(
                            **inputs,
                            temperature=0.3, max_new_tokens=64, do_sample=True,
                        )
                        # Keep only the newly generated tokens (drop the echoed prompt).
                        gen_tokens = output[0][inputs["input_ids"].shape[1]:]
                        response = tok.decode(gen_tokens, skip_special_tokens=True).strip()
                        action = parse_llm_action(response)
                        if action is None:
                            break
                        result = game.do_action(action)
                        if result in ("mine", "game_over"):
                            break                          # Fatal — game ends
                        if result in ("already_revealed", "flagged_cell", "invalid_flag",
                                      "out_of_bounds", "invalid_format"):
                            invalid_streak += 1
                            if invalid_streak >= 3:
                                break                      # Too many invalid in a row
                            continue                       # Skip and retry
                        invalid_streak = 0                 # Reset on valid move
                        moves += 1
                    total_moves += moves
                    if game.state() == "success":
                        wins += 1
            finally:
                # Always restore training mode, even if generation or the game
                # loop raises, so a failed eval cannot leave training in eval mode.
                if was_training:
                    model.train()
            avg_moves = total_moves / self.num_games
            print(f"\n[Eval @ step {state.global_step}] Win rate: {wins}/{self.num_games} "
                  f"({wins/self.num_games*100:.0f}%) | Avg moves: {avg_moves:.1f}\n")

    eval_callback = MinesweeperEvalCallback(eval_every_steps=50, num_games=10)

    # GRPO Config — v3 tuned
    # NOTE(review): the run logged in this notebook (printed "Temperature: 4.0")
    # had every completion clipped at max length with reward_std 0.0; confirm
    # that the sampling temperature below does not reproduce that collapse.
    grpo_config = GRPOConfig(
        temperature=2.0,                # Higher = more diverse (anti mode collapse)
        learning_rate=3e-5,             # Slower — refining, not relearning
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=6,
        gradient_accumulation_steps=4,
        num_generations=6,
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=200,                  # Stop before mode collapse
        save_steps=50,                  # Save more often
        save_total_limit=4,
        report_to="none",
        output_dir="minesweeper_grpo_v3_outputs",
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[valid_json_reward, gameplay_reward],
        args=grpo_config,
        callbacks=[eval_callback],
    )

    print("GRPO Trainer ready!")
    print(f"  Max steps: {grpo_config.max_steps}")
    print(f"  Generations per state: {grpo_config.num_generations}")
    print(f"  Temperature: {grpo_config.temperature}")
    print(f"  Learning rate: {grpo_config.learning_rate}")

    return trainer

In [41]:

def train_grpo(trainer):
    """Kick off GRPO training on the given trainer and hand it back.

    Prints a banner before and after ``trainer.train()`` runs; the same
    trainer object is returned so the call can be chained or inspected.
    """
    banner_start, banner_done = (
        "Starting GRPO training on SFT model...",
        "GRPO Training complete!",
    )
    print(banner_start)
    trainer.train()
    print(banner_done)
    return trainer



In [31]:
# Load a saved checkpoint and evaluate it on fixed-seed 6x6 boards
from unsloth import FastLanguageModel
import torch

# Single source of truth for what is being evaluated, so the log labels below
# cannot drift from the weights that are actually loaded.
CHECKPOINT_PATH = "minesweeper_sft_outputs/checkpoint-1000"
EVAL_MAX_SEQ_LENGTH = 1024

model_ckpt, tokenizer_ckpt = FastLanguageModel.from_pretrained(
    model_name=CHECKPOINT_PATH,
    load_in_4bit=True,
    max_seq_length=EVAL_MAX_SEQ_LENGTH,
    torch_dtype=torch.bfloat16,
)

FastLanguageModel.for_inference(model_ckpt)

print(f"\nEvaluating {CHECKPOINT_PATH}...")
wins = 0
total_moves = 0
num_games = 20

for i in range(num_games):
    # Fixed seeds (50000 + i) so repeated evaluations play the same boards.
    game = MinesweeperGame(rows=6, cols=6, num_mines=5, seed=50000 + i)
    moves = 0
    while game.state() == "ongoing" and moves < 100:
        prompt = format_state_for_llm(game)
        text = tokenizer_ckpt.apply_chat_template(
            [{"role": "system", "content": SYSTEM_PROMPT},
             {"role": "user", "content": prompt}],
            tokenize=False, add_generation_prompt=True,
        )
        inputs = tokenizer_ckpt(text, return_tensors="pt", truncation=True,
                                max_length=EVAL_MAX_SEQ_LENGTH).to(model_ckpt.device)
        output = model_ckpt.generate(
            **inputs,
            temperature=0.3, max_new_tokens=64, do_sample=True,
        )
        # Keep only the newly generated tokens (drop the echoed prompt).
        gen_tokens = output[0][inputs["input_ids"].shape[1]:]
        response = tokenizer_ckpt.decode(gen_tokens, skip_special_tokens=True).strip()
        action = parse_llm_action(response)
        if action is None:
            print(f"  Game {i+1}: PARSE FAIL after {moves} moves")
            break
        result = game.do_action(action)
        # Unlike a retry-based loop, any invalid or fatal result ends the game here.
        if result in ("mine", "game_over", "invalid_format", "already_revealed",
                      "out_of_bounds", "flagged_cell", "invalid_flag"):
            break
        moves += 1
    total_moves += moves
    status = "WIN" if game.state() == "success" else "LOSS"
    if game.state() == "success":
        wins += 1
    print(f"  Game {i+1}: {status} after {moves} moves")

avg_moves = total_moves / num_games
print(f"\n{CHECKPOINT_PATH} Results:")
print(f"  Win rate: {wins}/{num_games} ({wins/num_games*100:.1f}%)")
print(f"  Avg moves survived: {avg_moves:.1f}")
print(f"  (SFT baseline: 2/20 wins, 2.9 avg moves)")

Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.6: Fast Llama patching. Transformers: 4.56.2. vLLM: 0.11.1rc2.dev161+g8a297115e.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 255.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+gitb2fb688. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [42]:
def save_grpo_model(model, tokenizer, output_dir="my_minesweeper_model_grpo"):
    """Save the GRPO-refined model and its tokenizer.

    Args:
        model: Object exposing ``save_pretrained(path)``.
        tokenizer: Object exposing ``save_pretrained(path)``.
        output_dir: Destination directory passed to both ``save_pretrained``
            calls. Defaults to the directory this notebook has always used.
    """
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"GRPO model saved to: {output_dir}/")


In [43]:
# Load the model, tokenizer, and sequence-length setting that the GRPO cells
# below consume (load_sft_model_for_grpo is defined outside this section).
model, tokenizer, max_seq_length = load_sft_model_for_grpo()

Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.6: Fast Llama patching. Transformers: 4.56.2. vLLM: 0.11.1rc2.dev161+g8a297115e.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 255.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+gitb2fb688. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Unsloth: Already have LoRA adapters! We shall skip this step.


Model device: cuda:0
SFT model loaded with fresh LoRA for GRPO!


In [44]:
# Build 3000 prompt states for GRPO training; the function prints a summary
# of board sizes and the fresh vs. mid-game split.
grpo_dataset = generate_grpo_states(num_samples=3000)


Created 3000 GRPO training states
Board sizes: {'6x6': 1547, '5x5': 369, '7x7': 510, '8x8': 574}
Fresh games: 315 (10.5%)
Mid-game: 2685 (89.5%)


In [45]:
# Construct the GRPO trainer from the loaded model and generated dataset.
trainer = setup_grpo_trainer(model, tokenizer, grpo_dataset, max_seq_length)


GRPO Trainer ready!
  Max steps: 200
  Generations per state: 6
  Temperature: 4.0
  Learning rate: 2e-06


In [46]:
# Run GRPO training. train_grpo returns the trainer, so as the last expression
# in the cell it also becomes the cell's result.
train_grpo(trainer)

Starting GRPO training on SFT model...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,000 | Num Epochs = 1 | Total steps = 200
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 4 x 1) = 24
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / valid_json_reward / mean,rewards / valid_json_reward / std,rewards / gameplay_reward / mean,rewards / gameplay_reward / std
1,0.0,-15.0,0.0,324.0,324.0,324.0,1.0,0.0,0.0,0.0,0,0,0,0,0,0.002505,-5.0,0.0,-10.0,0.0
2,0.0,-15.0,0.0,324.0,324.0,324.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.002467,-5.0,0.0,-10.0,0.0
3,0.0,-15.0,0.0,324.0,324.0,324.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.002596,-5.0,0.0,-10.0,0.0


KeyboardInterrupt: 

In [None]:
# Evaluate `model` after training (evaluate_grpo_model is defined outside this section).
evaluate_grpo_model(model, tokenizer)

In [None]:
# Persist the model and tokenizer to disk via save_grpo_model.
save_grpo_model(model, tokenizer)