In [5]:
# Imports
import chess
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import os
from chess import Move, Board

In [6]:
# Mount Google Drive (google.colab is presumably only importable on Colab).
# NOTE(review): no later cell visibly reads or writes under /content/drive --
# checkpoints are saved to the relative "models/" directory. Confirm whether
# they were meant to persist to Drive; otherwise this cell can be dropped.
from google.colab import drive
drive.mount('/content/drive')

KeyboardInterrupt: 

In [None]:
# Base Agent Class
class Agent:
    """Common interface for every player in this notebook.

    Subclasses override ``get_action``; this base version does nothing and
    implicitly returns ``None``.
    """

    def __init__(self):
        """The base class keeps no state."""

    def get_action(self, game_state: Board):
        """Return the move to play in ``game_state`` (``None`` in the base class)."""

In [None]:
# Chess Environment
class ChessEnv:
    """Single-board chess environment with a shaped reward.

    Both colours are driven through ``step``; every reward returned is from the
    point of view of the side that just moved. Actions are integers
    ``from_square * 64 + to_square`` (0..4095). Promotion is not encoded: a pawn
    move to the last rank is decoded as a queen promotion.
    """

    def __init__(self):
        self.board = chess.Board()
        self.action_space_size = 4096
        
        self.weights = {
            'pawn': 1.0,
            'knight': 3.0,
            'bishop': 3.2,
            'rook': 5.0,
            'queen': 9.0,
            'king_safety': 0.0,      # Disable for now (noise)
            'mobility': 0.0,         # Disable for now (noise)
            'center': 0.0,           # Disable for now (noise)
            'pst_scale': 0.0,        # Disable PST initially to focus on pure material capture
            'step_penalty': -0.05,   # Small pressure to finish games
            'check': 1.0,            # Helpful tactile feedback
            'castling': 1.0,         # Good safe habit
            'repetition_penalty': -2.0 
        }
        
        # Simple Piece-Square Tables (Pawn & Knight), from White's point of view.
        # Index 0 = a8 ... 63 = h1. Scaled 0-100, centered on mid-game principles.
        self.pst_pawn = [
             0,  0,  0,  0,  0,  0,  0,  0,
            50, 50, 50, 50, 50, 50, 50, 50,
            10, 10, 20, 30, 30, 20, 10, 10,
             5,  5, 10, 25, 25, 10,  5,  5,
             0,  0,  0, 20, 20,  0,  0,  0,
             5, -5,-10,  0,  0,-10, -5,  5,
             5, 10, 10,-20,-20, 10, 10,  5,
             0,  0,  0,  0,  0,  0,  0,  0
        ]
        self.pst_knight = [
            -50,-40,-30,-30,-30,-30,-40,-50,
            -40,-20,  0,  0,  0,  0,-20,-40,
            -30,  0, 10, 15, 15, 10,  0,-30,
            -30,  5, 15, 20, 20, 15,  5,-30,
            -30,  0, 15, 20, 20, 15,  0,-30,
            -30,  5, 10, 15, 15, 10,  5,-30,
            -40,-20,  0,  5,  5,  0,-20,-40,
            -50,-40,-30,-30,-30,-30,-40,-50
        ]

    def reset(self):
        """Reset the board to the starting position and return the encoded state."""
        self.board.reset()
        return self.get_state()

    def step(self, action_idx):
        """Apply one action for the side to move.

        Returns:
            ``(state, reward, done, info)``. ``state`` is ``get_state()`` after
            the move, ``reward`` is for the side that just moved, and
            ``info["legal"]`` says whether the action was legal. An illegal
            action leaves the board unchanged, ends the episode and costs -10.
        """
        move = self.decode_action(action_idx)
        if move not in self.board.legal_moves:
            return self.get_state(), -10.0, True, {"legal": False} # Penalty for illegal

        # Everything that describes the move itself must be read before push().
        mover = self.board.turn
        prev_potential = self._get_potential(self.board)
        castling_bonus = self.weights['castling'] if self.board.is_castling(move) else 0.0

        self.board.push(move)
        done = self.board.is_game_over()

        curr_potential = self._get_potential(self.board)
        # After push() the opponent is to move, so is_check() means the mover gave check.
        check_bonus = self.weights['check'] if self.board.is_check() else 0.0
        # Penalise reaching a position that has now occurred (at least) twice.
        repetition_penalty = (
            self.weights['repetition_penalty'] if self.board.is_repetition(2) else 0.0
        )

        # Reward shaping: the potential is White-centric, so White is rewarded
        # for raising it and Black for lowering it.
        diff = curr_potential - prev_potential
        reward = diff if mover == chess.WHITE else -diff

        reward += castling_bonus
        reward += check_bonus
        reward += repetition_penalty
        reward += self.weights['step_penalty']  # Time pressure

        # Terminal bonus is ADDED on top of the shaping terms (it does not replace
        # them). Only the mover can deliver mate, so checkmate rewards the mover.
        # Other game endings (e.g. stalemate, insufficient material) add nothing.
        if done and self.board.is_checkmate():
            reward += 100.0

        return self.get_state(), reward, done, {"legal": True}

    def _get_potential(self, board):
        """Dense evaluation of ``board`` from White's perspective (positive favours White).

        Sums material (the king counts 0), optional pawn/knight piece-square
        bonuses (only when ``pst_scale > 0``), centre occupancy and the mobility
        of the side to move, each weighted by ``self.weights``.
        """
        w = self.weights
        piece_values = {
            chess.PAWN: w['pawn'],
            chess.KNIGHT: w['knight'],
            chess.BISHOP: w['bishop'],
            chess.ROOK: w['rook'],
            chess.QUEEN: w['queen'],
        }
        pst_tables = {chess.PAWN: self.pst_pawn, chess.KNIGHT: self.pst_knight}
        use_pst = w['pst_scale'] > 0

        score = 0

        # 1. Material (+ optional piece-square tables)
        for sq, piece in board.piece_map().items():
            pst_val = 0
            if use_pst and piece.piece_type in pst_tables:
                file = chess.square_file(sq)
                rank = chess.square_rank(sq)
                # Tables are laid out a8..h1 for White; Black reads them mirrored.
                row = 7 - rank if piece.color == chess.WHITE else rank
                pst_val = pst_tables[piece.piece_type][row * 8 + file]

            total_piece_val = (
                piece_values.get(piece.piece_type, 0) + (pst_val * 0.01 * w['pst_scale'])
            )
            if piece.color == chess.WHITE:
                score += total_piece_val
            else:
                score -= total_piece_val

        # 2. Center Control (bonus for occupying e4, d4, e5, d5)
        for sq in (chess.E4, chess.D4, chess.E5, chess.D5):
            p = board.piece_at(sq)
            if p:
                if p.color == chess.WHITE:
                    score += w['center']
                else:
                    score -= w['center']

        # 3. Mobility of the side to move only (python-chess legal_moves covers the
        # current turn; scoring the other side would require flipping turns).
        # Skipped when the weight is 0: enumerating every legal move, twice per
        # step, would only ever contribute 0 to the score.
        if w['mobility']:
            mobility_score = board.legal_moves.count() * w['mobility']
            if board.turn == chess.WHITE:
                score += mobility_score
            else:
                score -= mobility_score

        return score

    def get_state(self):
        """
        Converts board to 12x8x8 numpy array (float32, one-hot per piece).
        Indexed as [channel, rank, file].
        Channels:
        0-5: White P, N, B, R, Q, K
        6-11: Black P, N, B, R, Q, K
        """
        state = np.zeros((12, 8, 8), dtype=np.float32)
        channel_of = {
            chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
            chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5
        }
        for square, piece in self.board.piece_map().items():
            channel = channel_of[piece.piece_type] + (6 if piece.color == chess.BLACK else 0)
            state[channel, chess.square_rank(square), chess.square_file(square)] = 1
        return state

    def encode_action(self, move: chess.Move) -> int:
        """Encodes a chess move into an integer 0-4095 (promotion piece is dropped)."""
        return move.from_square * 64 + move.to_square

    def decode_action(self, action_idx: int) -> chess.Move:
        """Decode an action index into a move on the current board.

        A pawn moving to its last rank is auto-promoted to a queen, since the
        encoding cannot express the promotion piece.
        """
        from_square = action_idx // 64
        to_square = action_idx % 64
        move = chess.Move(from_square, to_square)

        piece = self.board.piece_at(from_square)
        if piece and piece.piece_type == chess.PAWN:
            rank = chess.square_rank(to_square)
            if (piece.color == chess.WHITE and rank == 7) or \
               (piece.color == chess.BLACK and rank == 0):
                move.promotion = chess.QUEEN  # Auto-promote

        return move

    def get_legal_actions(self):
        """Returns list of legal action indices for the side to move."""
        return [self.encode_action(move) for move in self.board.legal_moves]

In [None]:
# Random Agent
class RandomAgent(Agent):
    """Baseline player that picks uniformly among the legal moves."""

    def __init__(self):
        super().__init__()

    def get_action(self, game_state: Board) -> Move:
        """Return a uniformly random legal move for ``game_state``.

        ``random.choice`` raises IndexError when there are no legal moves.
        """
        candidates = [mv for mv in game_state.legal_moves]
        return random.choice(candidates)

In [None]:
# RL Agent (Dueling DQN)
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers around an identity skip connection.

    Channel count and spatial size are preserved (``padding=1``, channels in ==
    channels out), so the input can be added back before the final ReLU.
    """

    def __init__(self, channels):
        super().__init__()
        # These attribute names are the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        skip = x
        hidden = self.bn1(self.conv1(x)).relu()
        hidden = self.bn2(self.conv2(hidden))
        return torch.relu(hidden + skip)

class DuelingDQN(nn.Module):
    """Residual conv trunk followed by dueling value/advantage heads.

    Args:
        input_shape: ``(channels, height, width)`` of one encoded state. If only
            the channel count is given, height and width default to 8.
        num_actions: number of Q-values produced per state.

    ``forward`` maps a ``(batch, channels, height, width)`` tensor to
    ``(batch, num_actions)`` Q-values.
    """

    def __init__(self, input_shape, num_actions):
        super(DuelingDQN, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        
        # Initial Conv Layer
        self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        # Residual Towers (AlphaZero style, but smaller). The names res1..res4 are
        # state_dict keys of saved checkpoints, so they must not change.
        self.res1 = ResidualBlock(64)
        self.res2 = ResidualBlock(64)
        self.res3 = ResidualBlock(64)
        self.res4 = ResidualBlock(64)
        
        # Every 3x3 conv uses padding=1, so the trunk keeps H x W: size the heads
        # from the real board dimensions instead of assuming 8x8.
        if len(input_shape) >= 3:
            height, width = input_shape[1], input_shape[2]
        else:
            height, width = 8, 8
        flat_size = 64 * height * width
        
        # Value Stream (Evaluates the board state)
        self.value_fc = nn.Sequential(
            nn.Linear(flat_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        
        # Advantage Stream (Evaluates each specific action)
        self.advantage_fc = nn.Sequential(
            nn.Linear(flat_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_actions)
        )
        
    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.res4(x)
        
        # flatten() also works on non-contiguous tensors, unlike view().
        x = x.flatten(1)
        
        value = self.value_fc(x)
        advantage = self.advantage_fc(x)
        
        # Dueling aggregation: subtracting the mean advantage makes V identifiable
        # (the mean Q over actions equals the value head).
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
        return q_values

class RLAgent(Agent):
    """Double-DQN agent with a dueling residual network and experience replay.

    Actions are integers ``from_square * 64 + to_square``; promotions are not
    encoded and a pawn reaching its last rank is decoded as a queen promotion.
    """

    def __init__(self, state_shape=(12, 8, 8), action_size=4096):
        super().__init__()
        self.state_shape = state_shape
        self.action_size = action_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.policy_net = DuelingDQN(state_shape, action_size).to(self.device)
        self.target_net = DuelingDQN(state_shape, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        
        # Lower LR for stability with ResNet
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-4)
        self.memory = deque(maxlen=20000) # Replay buffer; oldest transitions drop out
        
        self.batch_size = 64
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.05
        # Multiplicative decay applied once per decay_epsilon() call
        self.epsilon_decay = 0.9995 
        
    def get_action(self, game_state, legal_moves_indices=None):
        """Choose an action for ``game_state``.

        Args:
            game_state: either a ``chess.Board`` (inference: greedy, epsilon is
                ignored, legal moves are read from the board and a
                ``chess.Move`` is returned) or an encoded state array matching
                ``self.state_shape`` (training: epsilon-greedy, an int action
                index is returned).
            legal_moves_indices: legal action indices for an encoded state.
                When non-empty, exploration samples from them and greedy
                selection masks every other action to -inf.
        """
        is_inference = isinstance(game_state, chess.Board)
        if is_inference:
            board_for_inference = game_state.copy()
            state_array = self._board_to_tensor(game_state)
            legal_moves_indices = self._get_legal_actions(game_state)
        else:
            board_for_inference = None
            state_array = game_state
        
        if not is_inference and random.random() < self.epsilon:
            if legal_moves_indices:
                return random.choice(legal_moves_indices)
            return random.randint(0, self.action_size - 1)

        # Greedy selection runs in eval mode: in train mode BatchNorm would use the
        # statistics of this single position and fold them into its running
        # averages. The previous mode is restored so update() keeps training.
        was_training = self.policy_net.training
        self.policy_net.eval()
        try:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state_array).unsqueeze(0).to(self.device)
                q_values = self.policy_net(state_tensor)

                if legal_moves_indices:
                    # Mask illegal moves with negative infinity
                    mask = torch.full((1, self.action_size), -float('inf'), device=self.device)
                    mask[0, legal_moves_indices] = 0
                    q_values += mask

                action_idx = q_values.argmax().item()
        finally:
            self.policy_net.train(was_training)
        
        if is_inference:
            return self._decode_action(action_idx, board_for_inference)
        return action_idx

    def update(self):
        """Run one Double-DQN gradient step on a random replay minibatch.

        Returns:
            The loss as a float, or None when memory holds fewer than
            ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return None
        
        batch = random.sample(self.memory, self.batch_size)
        state, action, reward, next_state, done = zip(*batch)
        
        state = torch.FloatTensor(np.array(state)).to(self.device)
        action = torch.LongTensor(action).unsqueeze(1).to(self.device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(self.device)
        next_state = torch.FloatTensor(np.array(next_state)).to(self.device)
        done = torch.FloatTensor(done).unsqueeze(1).to(self.device)
        
        # The TD target is a constant for this step: computing it without autograd
        # stops gradients from flowing into (and accumulating on) target_net.
        with torch.no_grad():
            # Double DQN: select the next action with the policy net ...
            next_actions = self.policy_net(next_state).argmax(1, keepdim=True)
            # ... and evaluate it with the target net.
            next_q_values = self.target_net(next_state).gather(1, next_actions)
            expected_q_values = reward + (1 - done) * self.gamma * next_q_values

        q_values = self.policy_net(state).gather(1, action)
        loss = nn.MSELoss()(q_values, expected_q_values)
        
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent explosion
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay`` while it is above ``epsilon_min``."""
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_network(self):
        """Copy the policy network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def save(self, path):
        """Save the policy network's state_dict to ``path``."""
        torch.save(self.policy_net.state_dict(), path)

    def load(self, path, training=False):
        """Load policy weights from ``path`` and sync the target network.

        ``training=True`` resumes with epsilon 0.5 and train mode; otherwise
        epsilon is 0.0 and the policy net is put in eval mode.
        """
        self.policy_net.load_state_dict(torch.load(path, map_location=self.device))
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.epsilon = 0.5 if training else 0.0
        if training: self.policy_net.train()
        else: self.policy_net.eval()

    def _board_to_tensor(self, board):
        """Encode ``board`` as a (12, 8, 8) float32 one-hot array indexed [channel, rank, file]."""
        state = np.zeros((12, 8, 8), dtype=np.float32)
        piece_map = {
            chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
            chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5
        }
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                rank = chess.square_rank(square)
                file = chess.square_file(square)
                channel = piece_map[piece.piece_type]
                if piece.color == chess.BLACK:
                    channel += 6
                state[channel, rank, file] = 1
        return state

    def _get_legal_actions(self, board):
        """Return the action indices of all legal moves on ``board``."""
        return [move.from_square * 64 + move.to_square for move in board.legal_moves]

    def _decode_action(self, action_idx: int, board: chess.Board = None) -> chess.Move:
        """Decode an action index; with a board, pawn moves to the last rank promote to a queen."""
        from_square = action_idx // 64
        to_square = action_idx % 64
        move = chess.Move(from_square, to_square)
        
        # Logic to auto-promote during inference/testing
        if board:
            piece = board.piece_at(from_square)
            if piece and piece.piece_type == chess.PAWN:
                rank = chess.square_rank(to_square)
                if (piece.color == chess.WHITE and rank == 7) or \
                   (piece.color == chess.BLACK and rank == 0):
                    move.promotion = chess.QUEEN
        
        return move

In [None]:
# Evaluation Function
def evaluate_agent(agent, opponent, num_games=20):
    """Play ``num_games`` games with ``agent`` as White against ``opponent`` as Black.

    The agent plays greedily: its epsilon is forced to 0.0 for the duration and
    restored afterwards, even if a game raises or the kernel is interrupted.
    An illegal move by the agent forfeits that game.

    Args:
        agent: player exposing ``epsilon`` and ``get_action(state, legal_actions)``
            returning an action index.
        opponent: player exposing ``get_action(board)`` returning a move.
        num_games: number of games to play; must be positive.

    Returns:
        Fraction of games won by the agent (wins / num_games).

    Raises:
        ValueError: if ``num_games`` is not positive.
    """
    if num_games <= 0:
        raise ValueError(f"num_games must be positive, got {num_games}")

    wins = 0
    original_epsilon = agent.epsilon
    agent.epsilon = 0.0
    try:
        for _ in range(num_games):
            env = ChessEnv()
            state = env.reset()
            board = env.board
            forfeited = False

            while not board.is_game_over():
                if board.turn == chess.WHITE:
                    action_idx = agent.get_action(state, env.get_legal_actions())
                    move = env.decode_action(action_idx)
                    if move not in board.legal_moves:
                        forfeited = True  # illegal move counts as a loss
                        break
                else:
                    move = opponent.get_action(board)
                board.push(move)
                state = env.get_state()  # Update state for next step

            if not forfeited and board.outcome().winner == chess.WHITE:
                wins += 1
    finally:
        # Without this, an interrupted evaluation would leave training at epsilon 0.
        agent.epsilon = original_epsilon

    return wins / num_games

In [None]:
# Configuration & Initialization
load_model_path = None # Example: "models/rl_model_1000.pth"
start_episode = 0

if load_model_path:
    # Checkpoints are saved as "<prefix>_<episode>.pth". Names without a trailing
    # episode number (e.g. "chess_90_percent.pth") resume from episode 0 instead
    # of crashing int().
    _stem = os.path.splitext(os.path.basename(load_model_path))[0]
    _suffix = _stem.rsplit("_", 1)[-1]
    if _suffix.isdecimal():
        start_episode = int(_suffix)
    else:
        print(f"Could not read an episode number from {load_model_path!r}; starting at episode 0")

episodes = 5000
target_update_freq = 20  # episodes between target-network syncs

# Initialize Environment and Agents
env = ChessEnv()
agent = RLAgent()

if load_model_path:
    agent.load(load_model_path, True)
    print(f"Loaded model from {load_model_path}, starting at episode {start_episode}")

# Train (and evaluate) against a RandomAgent first.
train_opponent = RandomAgent()
eval_opponent = RandomAgent()

os.makedirs("models", exist_ok=True)

In [None]:
# Training Loop
# The whole loop sits inside try/except so that interrupting the kernel saves a
# checkpoint instead of discarding progress.
episode = start_episode
try:
    for episode in range(start_episode, episodes):
        state = env.reset()
        done = False
        max_steps = 200 # Games vs Random shouldn't take forever
        step_count = 0
        total_reward = 0

        while not done and step_count < max_steps:
            step_count += 1

            # --- Agent Turn (White) ---
            legal_moves = env.get_legal_actions()
            action_idx = agent.get_action(state, legal_moves)

            next_state, reward, done, info = env.step(action_idx)

            # --- Opponent Turn (Black) ---
            if not done:
                opp_move = train_opponent.get_action(env.board)
                opp_action_idx = env.encode_action(opp_move)

                # We care about the state AFTER opponent moves
                next_state_final, opp_reward, done, info = env.step(opp_action_idx)

                # Reward Logic: subtract the opponent's gain (its reward without the
                # step penalty) from mine. Opponent losses are clipped to zero, so
                # its blunders do NOT add a bonus to my reward.
                opp_pure_reward = opp_reward - env.weights['step_penalty']
                if opp_pure_reward > 0:
                    combined_reward = reward - opp_pure_reward
                else:
                    combined_reward = reward

                agent.remember(state, action_idx, combined_reward, next_state_final, done)
                state = next_state_final
                total_reward += combined_reward
            else:
                agent.remember(state, action_idx, reward, next_state, done)
                state = next_state # Technically terminal
                total_reward += reward

            agent.update()

        if episode % target_update_freq == 0:
            agent.update_target_network()
        agent.decay_epsilon()

        print(f"Episode: {episode}, Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.3f}")

        # Evaluation every 100 episodes
        if episode > 0 and episode % 100 == 0:
            win_rate = evaluate_agent(agent, eval_opponent, num_games=50)
            print(f"--- Eval Episode {episode}: Win Rate {win_rate*100:.1f}% ---")

            if win_rate >= 0.90:
                print("GOAL REACHED! Saving model.")
                agent.save("models/chess_90_percent.pth")
                # Optional: break or switch to harder opponent here

        if episode % 250 == 0:
            agent.save(f"models/chess_{episode}.pth")
except KeyboardInterrupt:
    print(f"\nTraining interrupted. Saving model at episode {episode}...")
    agent.save(f"models/rl_model_{episode}.pth")