In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import chess
import pygame
import torch
from ChessPredictorClass import ChessMovePredictor
from env_graphical import draw_board  # Import your existing draw_board function


class ChessPygameEnv(gym.Env):
    """
    Custom Gym Environment integrating Pygame for rendering and handling chess gameplay.
    """
    metadata = {"render_modes": ["human"], "render_fps": 10}

    def __init__(self, model_path="chess_move_predictor.pth", render_mode="human"):
        super().__init__()
        
        # Validate render mode
        assert render_mode in self.metadata["render_modes"], f"Invalid render_mode: {render_mode}"
        self.render_mode = render_mode

        # Initialize Pygame
        pygame.init()
        self.screen = pygame.display.set_mode((640, 640)) if render_mode == "human" else None
        pygame.display.set_caption("Chess Gym Environment")

        # Initialize Chess Board
        self.board = chess.Board()

        # Load your pre-trained model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ChessMovePredictor().to(self.device)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

        # Define observation and action spaces
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(12, 8, 8), dtype=np.float32
        )
        self.action_space = spaces.Discrete(64 * 64)  # Actions are from-square and to-square indices

    def fen_to_tensor(self, fen):
        """Convert board state (FEN) to a tensor."""
        board = chess.Board(fen)
        tensor = np.zeros((12, 8, 8), dtype=np.float32)
        piece_map = {
            "P": 0, "N": 1, "B": 2, "R": 3, "Q": 4, "K": 5,
            "p": 6, "n": 7, "b": 8, "r": 9, "q": 10, "k": 11
        }
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                row = 7 - (square // 8)
                col = square % 8
                tensor[piece_map[piece.symbol()], row, col] = 1.0
        return tensor

    def reset(self, seed=None, options=None):
        """Reset the environment to the initial chess position."""
        super().reset(seed=seed)  # Seed the environment
        self.board.reset()
        if self.render_mode == "human":
            self._render()
        return self.fen_to_tensor(self.board.fen()), {}

    def step(self, action):
        """Apply an action to the board."""
        from_square = action // 64
        to_square = action % 64
        move = chess.Move(from_square, to_square)

        reward = 0
        if move in self.board.legal_moves:
            self.board.push(move)

            # Reward logic based on game state
            if self.board.is_checkmate():
                reward = 1  # Winning is rewarded
                done = True
            elif self.board.is_stalemate() or self.board.is_insufficient_material():
                reward = 0.5  # Draw has a smaller reward
                done = True
            else:
                reward = 0  # Neutral for valid moves
                done = False
        else:
            reward = -1  # Penalize invalid moves
            done = True  # End the game on invalid move

        truncated = False  # Define truncation logic if necessary

        if self.render_mode == "human":
            self._render()

        return self.fen_to_tensor(self.board.fen()), reward, done, truncated, {}


    def render(self):
        """Render the environment."""
        if self.render_mode == "human":
            self._render()

    def close(self):
        """Close the environment."""
        if self.render_mode == "human":
            pygame.quit()

    def _render(self):
        """Draw the board using your existing Pygame logic."""
        if self.screen is not None:
            self.screen.fill((0, 0, 0))
            draw_board(self.screen, self.board)
            pygame.display.flip()

In [2]:
from stable_baselines3.common.callbacks import BaseCallback
import time

class RenderCallback(BaseCallback):
    """
    Custom callback to render the environment during training.
    """
    def __init__(self, env, verbose=0):
        super(RenderCallback, self).__init__(verbose)
        self.env = env

    def _on_step(self) -> bool:
        self.env.envs[0].render()  # Render the first environment in the vectorized wrapper
        time.sleep(0.1)  # Delay for better visualization
        return True


In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Create a vectorized version of the ChessPygameEnv
env = make_vec_env(lambda: ChessPygameEnv(model_path="chess_move_predictor.pth", render_mode="human"), n_envs=1)

# Initialize PPO with your policy model
ppo_agent = PPO("MlpPolicy", env, verbose=1)

# Create the render callback
render_callback = RenderCallback(env)

# Train the agent using self-play
ppo_agent.learn(total_timesteps=10000, callback=render_callback)

# Save the fine-tuned agent
ppo_agent.save("ppo_chess_agent_finetuned")


  self.model.load_state_dict(torch.load(model_path, map_location=self.device))


Using cuda device




In [None]:
env = ChessPygameEnv(model_path="chess_move_predictor.pth", render_mode="human")

# Reset the environment
obs, info = env.reset()
done = False

print("Testing environment with random actions...")

while not done:
    # Sample a random action
    action = env.action_space.sample()
    
    # Step the environment
    obs, reward, done, truncated, _ = env.step(action)
    print(f"Action: {action}, Reward: {reward}, Done: {done}")

# env.close()


Testing environment with random actions...
Action: 1769, Reward: -1, Done: True


  self.model.load_state_dict(torch.load(model_path, map_location=self.device))


: 