In [1]:
# %pip install pygame
# %pip install numpy
# %pip install tensorflow

In [2]:
import random
import numpy as np
import tensorflow as tf

# Set a specific integer for reproducibility
SEED_VALUE = 20250506

# 1. Set seed for Python's built-in random module
random.seed(SEED_VALUE)

# 2. Set seed for NumPy operations
np.random.seed(SEED_VALUE)

# 3. Set seed for TensorFlow (affects weight initialization and some GPU operations)
tf.random.set_seed(SEED_VALUE)
# Note: You might also need to set some environment variables for complete determinism

In [3]:
import pygame
from datetime import datetime

pygame.init()

# ----------------
# Constants
# ----------------
DISPLAY_INFO = pygame.display.Info()
MONITOR_WIDTH = DISPLAY_INFO.current_w
MONITOR_HEIGHT = DISPLAY_INFO.current_h

SCREEN_WIDTH = int(MONITOR_WIDTH * 0.6)
SCREEN_HEIGHT = int(MONITOR_HEIGHT * 0.8)

GRID_SIZE = 8
AVAILABLE_HEIGHT = SCREEN_HEIGHT - 230
AVAILABLE_WIDTH = SCREEN_WIDTH - 40

TILE_SIZE = min(AVAILABLE_HEIGHT // GRID_SIZE, AVAILABLE_WIDTH // GRID_SIZE)
MARGIN = 150

MINE_COUNT = 5
FLAG_LIMIT = 5
MAX_LIVES = 3

FILE_NAME = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
LOG_NAME = f"a2c_logs_{FILE_NAME}.txt"
MODEL_NAME = f"a2c_model_{FILE_NAME}"

AUTO_QUIT = False

  from pkg_resources import resource_stream, resource_exists


pygame 2.6.1 (SDL 2.28.4, Python 3.12.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [4]:
import os
import json
from datetime import datetime
from A2C import A2C

def saveModel(modelName: str, agent: A2C):
    """
    Saves the agent's model weights and training metrics to the './models' directory.

    :param modelName: The name for the save file (e.g., "run1").
    :type modelName: str
    
    :param agent: The A2C agent instance to save.
    :type agent: A2C
    """
    # Ensure the save directory exists
    modelsDir = "./models"
    if not os.path.exists(modelsDir):
        os.makedirs(modelsDir)
    
    # Define file paths
    filePath = os.path.join(modelsDir, modelName)
    weightsPath = f'{filePath}_model.weights.h5'
    metricsPath = f'{filePath}_metrics.json'
    
    print(f"\nSaving progress to {filePath}...")
    
    # Save the model's weights
    agent.model.save_weights(weightsPath)
    
    # Prepare metrics data for saving
    metricsData = {
        "EPISODES": agent.EPISODES,
        "ELAPSED_TIME": agent.ELAPSED_TIME,
        "GAME_CONCLUSIONS": agent.GAME_CONCLUSIONS,
        "WIN_RATE": agent.WIN_RATE,
        "HIGHEST_SCORE": agent.HIGHEST_SCORE,
        "SCORES": agent.SCORES,
        "TIME_EFFICIENCY": agent.TIME_EFFICIENCY,
        "CLICKS": agent.CLICKS,
    }
    
    # Save metrics to a JSON file
    with open(metricsPath, 'w') as f:
        json.dump(metricsData, f, indent = 4)
        
    print("Progress saved successfully.")

def loadModel(modelName: str, agent: A2C) -> A2C:
    """
    Loads the agent's model weights and training metrics from the './models' directory.

    :param modelName: The name of the save file to load (e.g., "run1").
    :type modelName: str
    
    :param agent: The A2C agent instance to load progress into.
    :type agent: A2C
    
    :returns: The agent instance with loaded progress.
    :rtype: A2C
    """
    # Define file paths
    filePath = os.path.join("./models", modelName)
    weightsPath = f'{filePath}_model.weights.h5'
    metricsPath = f'{filePath}_metrics.json'
    
    # Check if both weights and metrics files exist
    if os.path.exists(weightsPath) and os.path.exists(metricsPath):
        print(f"\nLoading progress from {filePath}...")
        
        # Load the model's weights
        agent.model.load_weights(weightsPath)
        
        # Load metrics from the JSON file
        with open(metricsPath, 'r') as f:
            metricsData = json.load(f)
        
        # Populate the agent's attributes with the loaded data
        agent.EPISODES = metricsData.get("EPISODES", 1)
        agent.ELAPSED_TIME = metricsData.get("ELAPSED_TIME", [])
        agent.GAME_CONCLUSIONS = metricsData.get("GAME_CONCLUSIONS", [])
        agent.WIN_RATE = metricsData.get("WIN_RATE", 0.0)
        agent.HIGHEST_SCORE = metricsData.get("HIGHEST_SCORE", 0)
        agent.SCORES = metricsData.get("SCORES", [])
        agent.TIME_EFFICIENCY = metricsData.get("TIME_EFFICIENCY", 0.0)
        agent.CLICKS = metricsData.get("CLICKS", {"left": [], "right": [], "total": []})
        
        # Ensure the learning rate is correct for the loaded episode
        if isinstance(agent.LR, dict):
            currentLR = agent.optimizer.learning_rate.numpy()
            newLR = currentLR
            
            # Find the latest episode milestone that has been passed
            for milestone in sorted(agent.LR.keys()):
                if agent.EPISODES >= milestone:
                    newLR = agent.LR[milestone]
            
            # If the learning rate needs to be changed, update the optimizer
            if newLR != currentLR:
                print(f"[LR Scheduler] Loaded at Episode {agent.EPISODES}: Setting learning rate to {newLR:.6f}")
                agent.optimizer.learning_rate.assign(newLR)
        
        print("Progress loaded successfully.")
    else:
        print("\nNo saved progress found. Starting a new training run.")
        
    return agent

def log(string: str) -> str:
    if not os.path.exists("logs"):
        os.makedirs("logs")

    logName = os.path.join("logs", LOG_NAME)

    if not os.path.exists(logName):
        with open(logName, "w") as f:
            f.write(f"Log created on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("="*50 + "\n")

    with open(logName, "a") as f:
        f.write(string + "\n")

    return string


In [5]:
import pygame
import random

# Initialize Pygame
pygame.init()

# Get display info and set screen size based on monitor
display_info = pygame.display.Info()
MONITOR_WIDTH = display_info.current_w
MONITOR_HEIGHT = display_info.current_h

# Set screen size to a percentage of monitor size
SCREEN_WIDTH = int(MONITOR_WIDTH * 0.6)  # 60% of monitor width
SCREEN_HEIGHT = int(MONITOR_HEIGHT * 0.8)  # 80% of monitor height

# Adjust tile size based on screen size
GRID_SIZE = 8
AVAILABLE_HEIGHT = SCREEN_HEIGHT - 230  # Subtract margins (150 top + 80 bottom)
AVAILABLE_WIDTH = SCREEN_WIDTH - 40  # Some padding

# Calculate tile size to fit the grid
TILE_SIZE = min(AVAILABLE_HEIGHT // GRID_SIZE, AVAILABLE_WIDTH // GRID_SIZE)
MARGIN = 150

# Game constants
MINE_COUNT = 5
FLAG_LIMIT = 5
LIVES = 3

# Colors
BG_COLOR = (40, 44, 52)
PANEL_COLOR = (30, 34, 42)
UNREVEALED_COLOR = (58, 95, 135)
UNREVEALED_HOVER_COLOR = (75, 115, 160)
REVEALED_COLOR = (220, 220, 225)
MINE_COLOR = (215, 95, 95)
FLAG_COLOR = (255, 50, 50)
BORDER_COLOR = (20, 20, 25)
LIFE_COLOR = (255, 80, 80)

# Number colors (classic minesweeper colors)
NUMBER_COLORS = {
    1: (45, 85, 165),
    2: (60, 145, 70),
    3: (185, 60, 60),
    4: (35, 60, 135),
    5: (150, 50, 50),
    6: (50, 140, 140),
    7: (40, 40, 40),
    8: (100, 100, 100)
}

# Create resizable window
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT), pygame.RESIZABLE)
pygame.display.set_caption("Minesweeper RL")

class Tile:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.is_mine = False
        self.is_revealed = False
        self.is_flagged = False
        self.adjacent_mines = 0
        self.hover = False

class MinesweeperGame:
    def __init__(self):
        self.grid = [[Tile(x, y) for y in range(GRID_SIZE)] for x in range(GRID_SIZE)]
        self.game_over = False
        self.game_won = False
        self.flags_placed = 0
        self.flag_limit = FLAG_LIMIT
        self.revealed_count = 0
        self.score = 0
        self.start_time = pygame.time.get_ticks()
        self.end_time = None
        self.combo = 0
        self.max_combo = 0
        self.lives = LIVES
        self.total_mines = MINE_COUNT
        self.mines_hit = []  # Track which mines were hit
        self.setup_mines()
        self.calculate_adjacent_mines()
    
    def setup_mines(self):
        mines_placed = 0
        while mines_placed < self.total_mines:
            x = random.randint(0, GRID_SIZE - 1)
            y = random.randint(0, GRID_SIZE - 1)
            if not self.grid[x][y].is_mine:
                self.grid[x][y].is_mine = True
                mines_placed += 1
    
    def calculate_adjacent_mines(self):
        for x in range(GRID_SIZE):
            for y in range(GRID_SIZE):
                if not self.grid[x][y].is_mine:
                    count = 0
                    for dx in [-1, 0, 1]:
                        for dy in [-1, 0, 1]:
                            if dx == 0 and dy == 0:
                                continue
                            nx, ny = x + dx, y + dy
                            if 0 <= nx < GRID_SIZE and 0 <= ny < GRID_SIZE:
                                if self.grid[nx][ny].is_mine:
                                    count += 1
                    self.grid[x][y].adjacent_mines = count
    
    def reveal_tile(self, x, y):
        if self.game_over or self.game_won:
            return
        
        tile = self.grid[x][y]
        if tile.is_revealed or tile.is_flagged:
            return
        
        tile.is_revealed = True
        self.revealed_count += 1
        
        if tile.is_mine:
            self.lives -= 1
            self.combo = 0  # Reset combo on mine hit
            self.mines_hit.append((x, y))
            
            if self.lives <= 0:
                self.game_over = True
                self.end_time = pygame.time.get_ticks()
                self.reveal_all_mines()
            return
        
        # Add score based on tile value
        if tile.adjacent_mines == 0:
            self.score += 10
            self.combo += 1
        else:
            self.score += tile.adjacent_mines * 5
            self.combo += 1
        
        # Combo bonus
        if self.combo > 5:
            self.score += self.combo * 2
        
        self.max_combo = max(self.max_combo, self.combo)
        
        if tile.adjacent_mines == 0:
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    if dx == 0 and dy == 0:
                        continue
                    nx, ny = x + dx, y + dy
                    if 0 <= nx < GRID_SIZE and 0 <= ny < GRID_SIZE:
                        self.reveal_tile(nx, ny)
        
        # Check win condition: all non-mine tiles revealed
        non_mine_tiles = GRID_SIZE * GRID_SIZE - self.total_mines
        safe_tiles_revealed = self.revealed_count - len(self.mines_hit)
        
        if safe_tiles_revealed == non_mine_tiles:
            self.game_won = True
            self.end_time = pygame.time.get_ticks()
            # Win bonus
            time_bonus = max(0, 1000 - (self.end_time - self.start_time) // 100)
            self.score += time_bonus
            # Perfect flag bonus
            if self.flags_placed == MINE_COUNT:
                self.score += 500
            # Lives bonus
            self.score += self.lives * 200
    
    def toggle_flag(self, x, y):
        if self.game_over or self.game_won:
            return
        tile = self.grid[x][y]
        if not tile.is_revealed:
            if tile.is_flagged:
                tile.is_flagged = False
                self.flags_placed -= 1
                self.combo = 0  # Reset combo on flag remove
                # Cancel bonus for removing a correct flag
                if tile.is_mine:
                    self.score -= 50
            elif self.flags_placed < self.flag_limit:
                tile.is_flagged = True
                self.flags_placed += 1
                # Bonus for correct flag
                if tile.is_mine:
                    self.score += 50
                self.combo = 0  # Reset combo on flag
    
    def reveal_all_mines(self):
        for x in range(GRID_SIZE):
            for y in range(GRID_SIZE):
                if self.grid[x][y].is_mine:
                    self.grid[x][y].is_revealed = True


In [6]:
from typing import Optional
from A2C import A2C  # Assuming A2C is defined in A2C.py

def get_tile_from_mouse(mx, my):
    if my < MARGIN or my >= MARGIN + GRID_SIZE * TILE_SIZE:
        return None, None
    if mx < (SCREEN_WIDTH - GRID_SIZE * TILE_SIZE) // 2:
        return None, None
    if mx >= (SCREEN_WIDTH - GRID_SIZE * TILE_SIZE) // 2 + GRID_SIZE * TILE_SIZE:
        return None, None
    
    offset_x = (SCREEN_WIDTH - GRID_SIZE * TILE_SIZE) // 2
    grid_x = (mx - offset_x) // TILE_SIZE
    grid_y = (my - MARGIN) // TILE_SIZE
    
    if 0 <= grid_x < GRID_SIZE and 0 <= grid_y < GRID_SIZE:
        return int(grid_x), int(grid_y)
    
    return None, None

def draw_game(game, agent: Optional[A2C] = None):
    screen.fill(BG_COLOR)
    
    # Draw top panel
    pygame.draw.rect(screen, PANEL_COLOR, (0, 0, SCREEN_WIDTH, MARGIN))
    
    # Draw bottom panel
    pygame.draw.rect(screen, PANEL_COLOR, (0, SCREEN_HEIGHT - 80, SCREEN_WIDTH, 80))
    
    # Calculate grid offset for centering
    offset_x = (SCREEN_WIDTH - GRID_SIZE * TILE_SIZE) // 2
    offset_y = MARGIN
    
    # Draw grid
    for x in range(GRID_SIZE):
        for y in range(GRID_SIZE):
            tile = game.grid[x][y]
            rect = pygame.Rect(offset_x + x * TILE_SIZE, offset_y + y * TILE_SIZE, TILE_SIZE, TILE_SIZE)
            
            # Draw border
            pygame.draw.rect(screen, BORDER_COLOR, rect, 3)
            
            # Shrink rect for inner tile
            inner_rect = rect.inflate(-6, -6)
            
            if tile.is_revealed:
                if tile.is_mine:
                    pygame.draw.rect(screen, MINE_COLOR, inner_rect)
                    # Draw mine circle
                    center = inner_rect.center
                    pygame.draw.circle(screen, (40, 40, 40), center, 20)
                    # Draw X if this mine was hit
                    if (x, y) in game.mines_hit:
                        pygame.draw.line(screen, (255, 255, 255), 
                                       (center[0] - 12, center[1] - 12), 
                                       (center[0] + 12, center[1] + 12), 4)
                        pygame.draw.line(screen, (255, 255, 255), 
                                       (center[0] + 12, center[1] - 12), 
                                       (center[0] - 12, center[1] + 12), 4)
                else:
                    pygame.draw.rect(screen, REVEALED_COLOR, inner_rect)
                    if tile.adjacent_mines > 0:
                        # Draw number
                        font = pygame.font.Font(None, 60)
                        text = font.render(str(tile.adjacent_mines), True, NUMBER_COLORS[tile.adjacent_mines])
                        text_rect = text.get_rect(center=inner_rect.center)
                        screen.blit(text, text_rect)
            else:
                if tile.hover:
                    pygame.draw.rect(screen, UNREVEALED_HOVER_COLOR, inner_rect)
                else:
                    pygame.draw.rect(screen, UNREVEALED_COLOR, inner_rect)
                
                if tile.is_flagged:
                    # Draw flag triangle
                    center_x, center_y = inner_rect.center
                    # Flag pole
                    pygame.draw.rect(screen, (60, 60, 60), (center_x - 2, center_y - 20, 4, 35))
                    # Flag triangle
                    flag_points = [
                        (center_x + 2, center_y - 18),
                        (center_x + 22, center_y - 8),
                        (center_x + 2, center_y + 2)
                    ]
                    pygame.draw.polygon(screen, FLAG_COLOR, flag_points)
                    pygame.draw.polygon(screen, (180, 30, 30), flag_points, 2)

    # Visualize Last Agent Action
    if agent is not None and agent.LAST_ACTION is not None:
        x, y, action_type = agent.LAST_ACTION
        
        offset_x = (SCREEN_WIDTH - GRID_SIZE * TILE_SIZE) // 2
        offset_y = MARGIN
        
        action_x = offset_x + x * TILE_SIZE + TILE_SIZE // 2
        action_y = offset_y + y * TILE_SIZE + TILE_SIZE // 2
        
        # Blue circle for reveal, red for flag
        color = (100, 200, 255) if action_type == 0 else (255, 100, 100)
        pygame.draw.circle(screen, color, (action_x, action_y), TILE_SIZE // 2, 5)
    
    # Draw UI text
    font_large = pygame.font.Font(None, 48)
    font_medium = pygame.font.Font(None, 36)
    font_small = pygame.font.Font(None, 28)
    font_xsmall = pygame.font.Font(None, 20)
    
    # Calculate time
    if game.end_time:
        elapsed = (game.end_time - game.start_time) // 1000
    else:
        elapsed = (pygame.time.get_ticks() - game.start_time) // 1000
    
    if game.game_over:
        text = font_large.render("GAME OVER!", True, (235, 100, 95))
        screen.blit(text, (SCREEN_WIDTH // 2 - text.get_width() // 2, 30))
        score_text = font_medium.render(f"Final Score: {game.score}", True, (255, 200, 100))
        screen.blit(score_text, (SCREEN_WIDTH // 2 - score_text.get_width() // 2, 80))
    elif game.game_won:
        text = font_large.render("YOU WIN!", True, (90, 200, 110))
        screen.blit(text, (SCREEN_WIDTH // 2 - text.get_width() // 2, 30))
        score_text = font_medium.render(f"Final Score: {game.score}", True, (100, 255, 150))
        screen.blit(score_text, (SCREEN_WIDTH // 2 - score_text.get_width() // 2, 80))
        
        # Show achievements
        achievements = []
        if elapsed < 60:
            achievements.append("⚡ Speed Demon!")
        if game.max_combo >= 10:
            achievements.append("🔥 Combo Master!")
        if game.flags_placed == MINE_COUNT:
            achievements.append("🎯 Perfect Flags!")
        if game.lives == LIVES:
            achievements.append("💎 Flawless Victory!")
        
        y_pos = 120
        for achievement in achievements:
            ach_text = font_small.render(achievement, True, (255, 215, 0))
            screen.blit(ach_text, (SCREEN_WIDTH // 2 - ach_text.get_width() // 2, y_pos))
            y_pos += 30
    else:
        # Top left - Stats
        text1 = font_medium.render(f"Score: {game.score}", True, (255, 215, 0))
        text2 = font_medium.render(f"Time: {elapsed}s", True, (100, 200, 255))
        screen.blit(text1, (30, 30))
        screen.blit(text2, (30, 65))
        
        # Draw lives as hearts
        heart_y = 100
        for i in range(LIVES):
            heart_x = 40 + i * 35
            if i < game.lives:
                # Filled heart
                color = LIFE_COLOR
            else:
                # Empty heart
                color = (80, 80, 80)
            
            # Draw simple heart shape
            pygame.draw.circle(screen, color, (heart_x, heart_y), 10)
            pygame.draw.circle(screen, color, (heart_x + 14, heart_y), 10)
            heart_points = [
                (heart_x - 10, heart_y + 3),
                (heart_x + 7, heart_y + 20),
                (heart_x + 24, heart_y + 3)
            ]
            pygame.draw.polygon(screen, color, heart_points)
        
        # Top right - Game info
        mines_left = MINE_COUNT - game.flags_placed
        text3 = font_medium.render(f"Mines: {mines_left}", True, (240, 200, 120))
        text4 = font_medium.render(f"Flags: {game.flags_placed}/{FLAG_LIMIT}", True, (120, 180, 230))
        screen.blit(text3, (SCREEN_WIDTH - 200, 30))
        screen.blit(text4, (SCREEN_WIDTH - 200, 65))
        
        # Combo indicator
        if game.combo > 3:
            combo_text = font_medium.render(f"Combo x{game.combo}!", True, (255, 100, 255))
            yPos = 30 if agent is None else 5
            screen.blit(combo_text, (SCREEN_WIDTH // 2 - combo_text.get_width() // 2, yPos))

    # Display metrics for RL agent
    if agent is not None:
        metrics_y = 36
        metrics = [
            agent.printMetrics(["Win Rate"], True),
            agent.printMetrics(["Episodes", "Current Episode"], True),
            agent.printMetrics(["Highest Score"], True),
            agent.printMetrics(["Time Efficiency"], True),
            agent.printMetrics(["Left Click Avg", "Right Click Avg", ], True),
            agent.printMetrics(["Total Click Avg"], True),
            agent.printMetrics(["Game Conclusion History"], True)
        ]

        for metric in metrics:
            metric_text = font_xsmall.render(metric, True, (200, 200, 200))
            screen.blit(metric_text, (SCREEN_WIDTH // 2 - metric_text.get_width() // 2, metrics_y))
            metrics_y += metric_text.get_height()
    
    # Bottom instructions
    text1 = font_small.render("Left Click: Reveal", True, (200, 200, 200))
    text2 = font_small.render("Right Click: Flag", True, (200, 200, 200))
    text3 = font_small.render("R: Restart", True, (200, 200, 200))

    # Mouse control hints
    if agent is not None:
        text4 = font_small.render("Blue: Reveal", True, (100, 200, 255))
        text5 = font_small.render("Red: Flag", True, (255, 100, 100))
    
    screen.blit(text1, (30, SCREEN_HEIGHT - 50))
    screen.blit(text2, (250, SCREEN_HEIGHT - 50))
    screen.blit(text3, (450, SCREEN_HEIGHT - 50))
    screen.blit(text4, (SCREEN_WIDTH - 150, SCREEN_HEIGHT - 75))
    screen.blit(text5, (SCREEN_WIDTH - 150, SCREEN_HEIGHT - 25))


# Key Implementation Changes

The initial proposal provided a strong foundation. The following changes were implemented during development to improve the training effectiveness and final performance of the Actor-Critic agent.

## Enhanced Agent State ($S_t$)

The agent's state representation was enhanced to provide a richer understanding of the environment. This allows the agent to make more intelligent decisions and prevents it from getting stuck in unproductive loops.

The original proposed state included:

- Visible Grid
- Score
- Lives
- Flags
- Combo
- Game Status (active/won/lost)

The final implementation was updated to these:

- **Visible Grid:** The board state.
- **Lives:** Tracks how many lives are left in the episode.
- **Combo:** Tracks the real-time combo in the run.
- **Flags:** Tracks how many flags are left allowed to be place on the board.
- **Progress:** Tracks the ratio of revealed safe tiles to the total number of safe tiles.
- **Interaction History:** Encodes how recently the agent has interacted with each tiles. Basically a history of actions.

## Simplified Reward Function ($R_t$)

The initial reward function, tied to the complex in-game score, proved to be a noisy signal that limited the agent's performance (resulting in a 5-13% win rate). Therefore, it was replaced with a simplified, sparse reward function to provide a clearer learning objective.

The new function is based on four simple rules:

- **Win Bonus:** A large positive reward of **+100** is given when the agent successfully clears all non-mine tiles. This defines the ultimate goal.
- **Loss Penalty:** A large negative penalty of **-100** is given when the agent loses its last life or runs out of time. This clearly punishes failure.
- **Progress Reward:** A small positive reward of **+1 is given for each new safe tile** the agent successfully reveals in a single turn. This encourages the agent to explore and clear the board.
- **Efficiency Penalty:** A tiny negative penalty of **-0.1 is applied for every action taken**. This incentivizes the agent to win in as few steps as possible.

In [7]:
from A2C import A2C
import numpy as np

game = None
agent = None

def main():
    global SCREEN_WIDTH, SCREEN_HEIGHT, TILE_SIZE, screen, game, agent

    game = MinesweeperGame()
    agent = A2C(
        gridSize = GRID_SIZE,
        maxLives = LIVES,
        gameInstance = game,
        discountFactor = 0.999,
        betaEntropy = 0.015,
        lr = {
            0: 0.001,
            2000: 0.0005,
            5000: 0.0001,
		}
    )
    
    clock = pygame.time.Clock()
    running = True

    def clickHandler():
        mx, my = pygame.mouse.get_pos()
        gridX, gridY = get_tile_from_mouse(mx, my)
            
        if gridX is not None:
            if event.button == 1:  # Left click
                game.reveal_tile(gridX, gridY)
            elif event.button == 3:  # Right click
                game.toggle_flag(gridX, gridY)
        return

    def agentTurnLogic():
        global game, agent

        done = False
        elapsedTime = ((pygame.time.get_ticks() - game.start_time) / 1000)
        
        if not (game.game_over or game.game_won):
            
            # Encode Current State
            stateCurrent = agent._encodeState()
            
            # Agent Selects Action
            x, y, actionType, actionIdx, _ = agent.chooseAction(stateCurrent)
            
            # Update action history and interaction tracking
            agent.ACTION_HISTORY.append((x, y, actionType))
            if len(agent.ACTION_HISTORY) > agent.ACTION_HISTORY_SIZE:
                agent.ACTION_HISTORY.pop(0)
            
            agent.TILE_INTERACTION_HISTORY[(x, y)] = len(agent.ACTION_HISTORY)
            agent.LAST_ACTION = (x, y, actionType)
            
            # Store pre-action state to measure outcome
            livesOld = game.lives
            revealedOld = game.revealed_count
            reward = 0

            # Execute Action
            if actionType == 0:
                game.reveal_tile(x, y)
            else:
                game.toggle_flag(x, y)

            # -------------------------------
            # Simplified Reward Calculation
            # -------------------------------
            # Apply small "cost of living" penalty to encourage efficiency
            reward -= 0.1
            
            # Reward progress for revealing new safe tiles
            mine_hit = livesOld > game.lives
            if not mine_hit:
                newly_revealed = game.revealed_count - revealedOld
                if newly_revealed > 0:
                    reward += newly_revealed * 1.0

            # Check for game termination
            done = game.game_over or game.game_won
            if done:
                if game.game_won:
                    reward += 100  # Large bonus for winning
                else:
                    reward += -100 # Large penalty for losing
            
            # ------------------
            # Click Metric Update
            # ------------------
            if len(agent.CLICKS["left"]) < agent.EPISODES:
                agent.CLICKS["left"].append(0)
                agent.CLICKS["right"].append(0)
                agent.CLICKS["total"].append(0)

            agent.CLICKS["total"][agent.EPISODES - 1] += 1
            if actionType == 0:
                agent.CLICKS["left"][agent.EPISODES - 1] += 1
            else:
                agent.CLICKS["right"][agent.EPISODES - 1] += 1

            # Terminate if game is taking too long
            if elapsedTime > 45:
                game.game_over = True
                done = True
                if not game.game_won:
                    # Ensure loss penalty is applied on timeout
                    reward = -100
                print(log("Game terminated: Time limit exceeded"))

            # Reward clipping to prevent extreme values
            reward = np.clip(reward, -100, 100)
            
            # Get Next State and Train Agent
            stateNext = agent._encodeState()
            loss, tdError = agent.train(stateCurrent, actionIdx, reward, stateNext, done)
            
            # Print progress for debugging
            if agent.CLICKS["total"][agent.EPISODES - 1] % 10 == 0 or done:
                print(f"[Ep {agent.EPISODES}, Step {agent.CLICKS['total'][agent.EPISODES - 1]}] "
                    f"Score: {game.score} | Reward: {reward:.2f} | TD Error: {tdError:.2f} | "
                    f"Action: ({x}, {y}) {'Reveal' if actionType == 0 else 'Flag'} | "
                    f"Revealed: {game.revealed_count}/{GRID_SIZE * GRID_SIZE - game.total_mines}")

        # Reset game if episode is finished
        if game.game_over or game.game_won:
            draw_game(game, agent)
            pygame.display.flip()
            pygame.time.delay(1000)

            # Episode Checkpoint
            if agent.EPISODES % 100 == 0:
                if len(agent.GAME_CONCLUSIONS) >= 300:
                    last100WR = sum(agent.GAME_CONCLUSIONS[-100:]) / 100
                    last200WR = sum(agent.GAME_CONCLUSIONS[-200:-100]) / 100
                    last300WR = sum(agent.GAME_CONCLUSIONS[-300:-200]) / 100

                    print(log(f"\n[Win Rate Trend]:"))
                    print(log(f"  Episodes {agent.EPISODES-300}-{agent.EPISODES-200}: {last300WR:.1%}"))
                    print(log(f"  Episodes {agent.EPISODES-200}-{agent.EPISODES-100}: {last200WR:.1%}"))
                    print(log(f"  Episodes {agent.EPISODES-100}-{agent.EPISODES}: {last100WR:.1%}"))

                    if last100WR > last200WR:
                        print(log("  🟢 Improving trend"))
                    elif last100WR < last200WR:
                        print(log("  🟡 Declining trend - consider hyperparameter adjustment"))
                    else:
                        print(log("  🟡 Flat trend - agent may be stuck"))

            # Re-initialize the game and agent state for next episode
            game = MinesweeperGame()
            agent.endEpisode(game)

            # ----------------------------
            # Print Metrics
            # ----------------------------
            print(log(f"\n{'='*60}"))
            print(log(f"Episode {agent.EPISODES - 1} finished: {'WIN' if agent.GAME_CONCLUSIONS[-1] else 'LOSS'} | Score: {agent.SCORES[-1]}"))
            print(log(agent.printMetrics(asString = True)))
            print(log(f"{'='*60}\n"))

        return done

    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.VIDEORESIZE:
                # Handle window resize
                SCREEN_WIDTH, SCREEN_HEIGHT = event.w, event.h
                screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT), pygame.RESIZABLE)
                
                # Recalculate tile size
                AVAILABLE_HEIGHT = SCREEN_HEIGHT - 230
                AVAILABLE_WIDTH = SCREEN_WIDTH - 40
                TILE_SIZE = min(AVAILABLE_HEIGHT // GRID_SIZE, AVAILABLE_WIDTH // GRID_SIZE)
                
            elif event.type == pygame.MOUSEBUTTONDOWN:
                clickHandler()
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_r:
                    game = MinesweeperGame()

        # Let the agent take its turn
        if agent is not None:
            gameDone = agentTurnLogic()

            # Stop after 300 episodes
            if agent.EPISODES > 300 and gameDone:
                saveModel(MODEL_NAME, agent)
                running = False

        # Update hover state
        mx, my = pygame.mouse.get_pos()
        hoverX, hoverY = get_tile_from_mouse(mx, my)
        for x in range(GRID_SIZE):
            for y in range(GRID_SIZE):
                game.grid[x][y].hover = (x == hoverX and y == hoverY)
        
        draw_game(game, agent)
        pygame.display.flip()
        clock.tick(60)

    if AUTO_QUIT:
    	pygame.quit()

if __name__ == "__main__":
    main()

[Ep 1, Step 10] Score: 2440 | Reward: -100.00 | TD Error: -10.00 | Action: (6, 2) Reveal | Revealed: 49/59

Episode 1 finished: LOSS | Score: 2440
Episodes: 1 | Win Rate: 0.00% | Highest Score: 2440.00 | Time Efficiency: 3.06 steps/sec | Left Click Avg: 5.00 c/g (1.53 c/s) | Right Click Avg: 5.00 c/g (1.53 c/s) | Total Click Avg: 10.00 c/g (3.06 c/s) | Game Conclusion History: 0 Wins, 1 Losses

[Ep 2, Step 10] Score: 1706 | Reward: -0.10 | TD Error: 0.33 | Action: (1, 1) Flag | Revealed: 50/59
[Ep 2, Step 15] Score: 1711 | Reward: -100.00 | TD Error: -10.00 | Action: (0, 3) Reveal | Revealed: 53/59

Episode 2 finished: LOSS | Score: 1711
Episodes: 2 | Win Rate: 0.00% | Highest Score: 2440.00 | Time Efficiency: 3.64 steps/sec | Left Click Avg: 7.50 c/g (2.18 c/s) | Right Click Avg: 5.00 c/g (1.45 c/s) | Total Click Avg: 12.50 c/g (3.64 c/s) | Game Conclusion History: 0 Wins, 2 Losses

[Ep 3, Step 10] Score: 3017 | Reward: -0.10 | TD Error: -0.09 | Action: (7, 7) Flag | Revealed: 57/59
[

In [None]:
# pygame.quit()

In [9]:
# saveModel(MODEL_NAME, agent)