In [1]:
# Imports that are needed 
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import gym_chess
import random
import numpy as np
import chess
import chess.svg
from IPython.display import display, SVG
import os
from datetime import datetime
import math
import chess.pgn

# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator
torch.cuda.empty_cache()

class ActorCNN(nn.Module):
    """Policy network: three 3x3 convolutions followed by a two-layer dense head.

    Input is a batch of board encodings shaped (batch_size, 12, 8, 8); output is
    a probability distribution over ``output_size`` action slots per sample.
    """

    def __init__(self, output_size=4672):
        super().__init__()
        # 12 input planes (6 piece types x 2 colours); padding=1 keeps the 8x8 grid.
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Dense head operating on the flattened 64 x 8 x 8 feature map.
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, output_size)

    def forward(self, x):
        """Return action probabilities of shape (batch_size, output_size)."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))

        # Collapse channels and spatial dims into one feature vector per sample.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))

        # Softmax turns the raw scores into a probability distribution.
        return F.softmax(self.fc2(hidden), dim=-1)

class CriticCNN(nn.Module):
    """Value network: same convolutional trunk as the actor, scalar output head.

    Input is a batch shaped (batch_size, 12, 8, 8); output is one unbounded
    value per sample, shaped (batch_size, 1).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=12, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=1)

    def forward(self, x):
        """Return the estimated value of each position, shape (batch_size, 1)."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))

        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))

        # No final activation: the value estimate may be negative or positive.
        return self.fc2(hidden)


In [2]:
def mutate_network(network, mutation_rate, mutation_strength=0.1):
    """Return a mutated copy of ``network``; the original is left untouched.

    Each parameter tensor is, independently with probability ``mutation_rate``,
    perturbed by Gaussian noise scaled by ``mutation_strength``.

    The copy is made with ``copy.deepcopy`` so that the architecture, the
    constructor arguments (e.g. a non-default ``output_size``) and the device
    of the source network are all preserved, instead of rebuilding a
    fixed-size ActorCNN/CriticCNN on the CPU.

    Args:
        network: any ``nn.Module``.
        mutation_rate: per-parameter-tensor probability of being perturbed.
        mutation_strength: standard deviation of the added noise.

    Returns:
        A new module of the same type as ``network``.
    """
    import copy  # local import: `copy` is not among the notebook's top-level imports

    new_network = copy.deepcopy(network)

    # Parameters are edited in place; no_grad keeps autograd out of the mutation.
    with torch.no_grad():
        for param in new_network.parameters():
            if torch.rand(1).item() < mutation_rate:
                param.add_(torch.randn_like(param) * mutation_strength)

    return new_network
# Module-level defaults: size of the actor's output layer and a freshly
# initialised (untrained) actor network built with it.
output_size = 4672
actor_net = ActorCNN(output_size=output_size)


class MCTSNode:
    """One position in the Monte-Carlo search tree.

    Attributes:
        board: the position this node represents.
        parent: parent node, or None for the root.
        prior: policy probability assigned to the move leading here.
        children: mapping of move -> child MCTSNode (filled by ``expand``).
        visit_count / value_sum: backed-up search statistics.
        device: torch device the encoded state lives on.
        state: (12, 8, 8) tensor encoding of ``board``.
        move_number: full-move number, used to pick the exploration constant.
    """

    def __init__(self, board, parent=None, prior=0, device=None):
        self.board = board
        self.parent = parent
        self.prior = prior
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.state = board_to_3d_tensor(board).to(self.device)
        self.move_number = board.fullmove_number  # Track move number for game phase

    def expand(self, actor_net):
        """Create one child per legal move, using ``actor_net`` for the priors."""
        # Priors are only read during search, so no autograd graph is needed;
        # storing plain floats keeps the tree from pinning network activations.
        with torch.no_grad():
            action_probs = actor_net(self.state.unsqueeze(0).to(self.device))
        priors = action_probs[0].tolist()

        move_lookup = create_move_lookup()
        for move in self.board.legal_moves:
            next_board = self.board.copy()
            next_board.push(move)
            self.children[move] = MCTSNode(
                next_board,
                parent=self,
                prior=priors[move_to_index(move, move_lookup)],
                device=self.device
            )

    def select_child(self):
        """Return the ``(move, child)`` pair with the highest selection score.

        The score is the child's mean value plus a prior-weighted exploration
        term, with flat bonuses for captures (+0.5) and checking moves (+0.3).
        Returns None when the node has no children.
        """
        # Dynamic exploration constant based on game phase
        if self.move_number < 10:
            c_puct = 2.0  # Higher exploration in opening
        elif self.move_number < 20:
            c_puct = 1.5  # Moderate exploration in middlegame
        else:
            c_puct = 1.0  # Lower exploration in endgame

        sqrt_parent_visits = math.sqrt(self.visit_count)  # loop-invariant
        best_score = float('-inf')
        best_child = None

        for move, child in self.children.items():
            score = child.get_value() + c_puct * child.prior * \
                (sqrt_parent_visits / (1 + child.visit_count))

            # A capture must be detected on the position BEFORE the move is
            # played; asking the post-move board reports every move as a capture.
            if self.board.is_capture(move):
                score += 0.5

            # child.board is already the position after `move`, so there is no
            # need to copy and push again just to test for check.
            if child.board.is_check():
                score += 0.3

            if score > best_score:
                best_score = score
                best_child = (move, child)

        return best_child

    def get_value(self):
        """Mean backed-up value of this node (0 if it was never visited)."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

def mcts_search(board, actor_net, critic_net, num_simulations=150):
    """Run MCTS from ``board`` and return the most-visited root move.

    Each simulation walks down the tree with ``select_child`` until it reaches
    a node without children, expands it with ``actor_net`` (unless the game is
    over there), evaluates it with ``critic_net`` and adds that value to every
    node on the path.

    Args:
        board: position to search from.
        actor_net: policy network used for priors.
        critic_net: value network used for leaf evaluation.
        num_simulations: number of simulations to run.

    Returns:
        The move whose root child has the highest visit count.

    Raises:
        ValueError: if the root has no children after searching (finished
            game or ``num_simulations <= 0``).
    """
    if board.is_game_over():
        raise ValueError("mcts_search called on a finished game: no move to select")

    # Encode positions on the same device as the networks; defaulting to CUDA
    # whenever it is available breaks for networks that were left on the CPU.
    try:
        device = next(actor_net.parameters()).device
    except (StopIteration, AttributeError):
        device = None  # let MCTSNode pick its default
    root = MCTSNode(board, device=device)

    # Search is pure inference: without no_grad every backed-up value would
    # keep an autograd graph alive inside the tree.
    with torch.no_grad():
        for _ in range(num_simulations):
            node = root
            search_path = [node]

            while node.children and not node.board.is_game_over():
                _, node = node.select_child()
                search_path.append(node)

            if not node.board.is_game_over():
                node.expand(actor_net)

            # Critic output holds a single element; store it as a Python float.
            value = critic_net(node.state.unsqueeze(0)).item()

            for visited in search_path:
                visited.value_sum += value
                visited.visit_count += 1

    if not root.children:
        raise ValueError("mcts_search produced no candidate moves (num_simulations must be > 0)")

    return max(root.children.items(), key=lambda item: item[1].visit_count)[0]


In [3]:
class PrioritizedReplayBuffer:
    """Fixed-capacity ring buffer of transitions with priority-based sampling.

    ``__len__``, ``sample`` and ``update_priorities`` are methods of the class
    (they were previously defined at module level, so calls such as
    ``buffer.sample(n)`` raised AttributeError).
    """

    def __init__(self, capacity, alpha=0.6, device=None):
        """
        Args:
            capacity: maximum number of stored transitions.
            alpha: exponent applied to priorities when building the sampling
                distribution (0 = uniform).
            device: torch device for stored tensors; defaults to CUDA when
                available, else CPU.
        """
        self.capacity = capacity
        self.alpha = alpha
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.buffer = []
        self.priorities = np.ones(capacity, dtype=np.float32)  # every slot starts at priority 1
        self.position = 0

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when full."""
        state = state.to(self.device)
        if isinstance(action, torch.Tensor):
            action = action.to(self.device)
        else:
            action = torch.tensor([action], device=self.device)
        reward = torch.tensor([reward], device=self.device, dtype=torch.float32)
        next_state = next_state.to(self.device)
        done = torch.tensor([done], device=self.device, dtype=torch.float32)

        experience = (state, action, reward, next_state, done)

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience

        # New transitions get the current maximum priority so they are likely
        # to be sampled at least once.
        self.priorities[self.position] = self.priorities[:len(self.buffer)].max()
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` distinct transitions according to priority.

        Args:
            batch_size: number of transitions to draw.
            beta: importance-sampling exponent.

        Returns:
            ``(experiences, indices, weights)`` where ``experiences`` is the
            tuple ``(states, actions, rewards, next_states, dones)``; states
            are stacked along a new leading axis, the other fields are
            concatenated. Returns ``([], [], [])`` when fewer than
            ``batch_size`` transitions are stored.
        """
        size = len(self.buffer)
        if size < batch_size:
            return [], [], []

        # float64 keeps the normalised probabilities within numpy's sum-to-one check.
        scaled = self.priorities[:size].astype(np.float64) ** self.alpha
        probabilities = scaled / scaled.sum()

        indices = np.random.choice(size, batch_size, p=probabilities, replace=False)

        # Importance-sampling weights, normalised so the largest weight is 1.
        weights = (size * probabilities[indices]) ** (-beta)
        weights = weights / weights.max()
        weights = torch.tensor(weights, dtype=torch.float32, device=self.device)

        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in indices))

        experiences = (
            torch.cat([s.unsqueeze(0) for s in states]),
            torch.cat(list(actions)),
            torch.cat(list(rewards)),
            torch.cat([ns.unsqueeze(0) for ns in next_states]),
            torch.cat(list(dones)),
        )

        return experiences, indices, weights

    def update_priorities(self, indices, td_errors):
        """Set the priority of each sampled index to ``|td_error| + 1e-6``."""
        for idx, error in zip(indices, td_errors):
            # The small constant keeps every transition sampleable.
            self.priorities[idx] = abs(float(error)) + 1e-6


In [4]:
# Training schedule: simulations, learning rate and exploration per episode
def adjust_training_parameters(episode, num_episodes):
    """Return ``(simulations, lr, exploration)`` for the given episode.

    The schedule advances in steps of 100 episodes:
      * simulations start at 150, grow by 50 per step, capped at 500;
      * learning rate starts at 0.001, decays by 0.95 per step, floor 0.0001;
      * exploration starts at 0.1, decays by 0.95 per step, floor 0.01.

    ``num_episodes`` is accepted but not used by the schedule.
    """
    stage = episode // 100
    decay = 0.95 ** stage

    simulations = min(150 + stage * 50, 500)
    lr = max(0.001 * decay, 0.0001)
    exploration = max(0.1 * decay, 0.01)

    return simulations, lr, exploration

In [5]:
def compute_returns(rewards, gamma):
    """Compute discounted returns ``G_t = r_t + gamma * G_{t+1}`` for a trajectory.

    Args:
        rewards: sequence of per-step rewards, in time order.
        gamma: discount factor.

    Returns:
        A 1-D tensor with one return per step, in the same order as ``rewards``.
    """
    returns = []
    running = 0
    # Accumulate from the last step backwards, then flip once; appending and
    # reversing is O(n) whereas inserting at the front on every step is O(n^2).
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns)

class ModelPopulation:
    """A collection of freshly initialised (actor, critic) network pairs."""

    def __init__(self, population_size=2):
        # Each member is a tuple: (ActorCNN with 4672 outputs, CriticCNN).
        self.population = [
            (ActorCNN(output_size=4672), CriticCNN())
            for _ in range(population_size)
        ]



def run_tournament(population, games_per_match=10):
    """Play a single round-robin among the population and rank the members.

    ``population[i]`` is an ``(actor_net, critic_net)`` pair. For every pair
    ``i < j`` one match of ``games_per_match`` games is played with member
    ``i`` as White; ``i`` receives the match score and ``j`` the remainder.

    Returns:
        A list of ``(index, total_score)`` tuples sorted by score, best first.
    """
    member_count = len(population)
    totals = dict.fromkeys(range(member_count), 0.0)

    for white_idx in range(member_count):
        for black_idx in range(white_idx + 1, member_count):
            white_points = play_match(
                population[white_idx], population[black_idx], games_per_match
            )
            totals[white_idx] += white_points
            totals[black_idx] += (games_per_match - white_points)

    # Highest total first.
    return sorted(totals.items(), key=lambda entry: entry[1], reverse=True)


def play_match(model1, model2, num_games):
    """Play ``num_games`` games and return the total score earned by ``model1``.

    ``model1`` always plays White and ``model2`` Black; each is an
    ``(actor_net, critic_net)`` pair whose moves are chosen by ``mcts_search``.
    A White win counts 1.0, a Black win 0.0 and any other result 0.5.
    """
    print(f"Playing match between Model 1 (ActorCNN-CriticCNN) and Model 2 (ActorCNN-CriticCNN)")

    # Points credited to White for the decisive result strings.
    white_points = {"1-0": 1.0, "0-1": 0.0}
    tally = 0.0

    for _ in range(num_games):
        position = chess.Board()
        while not position.is_game_over():
            side = model1 if position.turn else model2
            position.push(mcts_search(position, side[0], side[1]))

        # Anything other than a decisive result is scored as a draw.
        tally += white_points.get(position.result(), 0.5)

    return tally



def evolve_population(population, ranked_models, mutation_rate=0.01):
    """Build the next generation: the top half plus one mutated copy of each.

    Args:
        population: list of ``(actor, critic)`` pairs.
        ranked_models: ``(index, score)`` tuples, best first.
        mutation_rate: forwarded to ``mutate_network`` for both networks.

    Returns:
        The surviving pairs followed by their mutated offspring.
    """
    keep = len(population) // 2
    survivors = [population[rank[0]] for rank in ranked_models[:keep]]

    # One child per survivor; the actor is mutated first, then the critic.
    offspring = [
        (mutate_network(parent[0], mutation_rate), mutate_network(parent[1], mutation_rate))
        for parent in survivors
    ]

    return survivors + offspring


# Move handling functions
def create_move_lookup():
    """Return all 4096 ``(from_square, to_square)`` pairs in row-major order.

    The pair ``(f, t)`` is located at list index ``f * 64 + t``.
    """
    return [(origin, target) for origin in range(64) for target in range(64)]

def get_piece_value(piece):
    """Return the material value of ``piece`` by its type (0 if unknown).

    Values: pawn 10, knight 30, bishop 30, rook 50, queen 90, king 0.
    """
    worth_by_type = dict(zip(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING),
        (10, 30, 30, 50, 90, 0),
    ))
    kind = piece.piece_type
    return worth_by_type[kind] if kind in worth_by_type else 0

def select_legal_action(action_probs, legal_moves, board):
    """Sample a legal move from the policy output, biased towards captures.

    Each legal move's policy probability is multiplied by
    ``1 + value of the piece standing on its destination square``, clipped to
    [1e-10, 1.0] and renormalised; a move is then drawn from that distribution.

    Args:
        action_probs: policy output; row 0 is used.
        legal_moves: iterable of legal moves for ``board``.
        board: position the moves belong to (queried via ``piece_at``).

    Returns:
        One of the legal moves.
    """
    policy = action_probs.detach().cpu().numpy()[0]
    candidates = list(legal_moves)
    lookup = create_move_lookup()

    slots = []
    multipliers = []
    for candidate in candidates:
        # Material standing on the destination square, if any.
        occupant = board.piece_at(candidate.to_square)
        bonus = get_piece_value(occupant) if occupant else 0

        try:
            slot = lookup.index((candidate.from_square, candidate.to_square))
        except ValueError:
            continue
        slots.append(slot)
        multipliers.append(bonus + 1)  # +1 so quiet moves keep non-zero weight

    if not slots:
        return random.choice(candidates)

    # Scale the policy mass of each move by its capture multiplier.
    scaled = np.clip(policy[slots] * np.array(multipliers), 1e-10, 1.0)

    total = scaled.sum()
    if total == 0 or np.isnan(total):
        scaled = np.ones_like(scaled) / len(scaled)
    else:
        scaled = scaled / total

    pick = np.random.choice(len(slots), p=scaled)
    return candidates[pick]

def move_to_index(move, move_lookup):
    """Return the position of ``(move.from_square, move.to_square)`` in ``move_lookup``.

    Raises ValueError if the pair is not present.
    """
    key = (move.from_square, move.to_square)
    return move_lookup.index(key)

In [6]:
# Replay Buffer
class ReplayBuffer:
    """Uniform-sampling ring buffer of ``(state, action, reward, next_state, done)``."""

    def __init__(self, capacity, device):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
        self.device = device

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest entry once full."""
        transition = (
            state.to(self.device),
            torch.tensor([action], device=self.device),
            torch.tensor([reward], device=self.device),
            next_state.to(self.device),
            torch.tensor([done], device=self.device),
        )

        # Grow the list until capacity is reached, then write in place.
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` random transitions, each field concatenated on dim 0."""
        chosen = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*chosen)
        columns = (states, actions, rewards, next_states, dones)
        return tuple(torch.cat(column) for column in columns)

def board_to_3d_tensor(board):
    """Encode a board as a float32 tensor of shape (12, 8, 8).

    Channel order is [p, n, b, r, q, k, P, N, B, R, Q, K] (lower-case symbols
    first, then upper-case). A cell is 1.0 where the corresponding piece stands
    and 0.0 elsewhere; square ``s`` maps to row ``s // 8`` and column ``s % 8``.
    """
    channel_order = ['p', 'n', 'b', 'r', 'q', 'k', 'P', 'N', 'B', 'R', 'Q', 'K']
    planes = np.zeros((12, 8, 8), dtype=np.float32)

    for square in range(64):
        occupant = board.piece_at(square)
        if occupant is None:
            continue
        row, col = divmod(square, 8)
        planes[channel_order.index(occupant.symbol()), row, col] = 1.0

    return torch.tensor(planes, dtype=torch.float32)

In [7]:
def calculate_enhanced_reward(board, move, is_checkmate=False):
    """Heuristic shaped reward for ``move`` evaluated against ``board``.

    Returns 100.0 for checkmate and -100.0 when the position is a stalemate,
    a threefold repetition or has ``halfmove_clock >= 50``. Otherwise the
    reward is the sum of:
      * 0.1 x material difference (White minus Black),
      * 0.2 x number of squares attacked from the destination square,
      * 1.0 / 0.5 for landing on a centre / extended-centre square,
      * the value of the piece found on the destination square, plus 50% more
        when the piece on the origin square is worth less,
      * 2.0 when the piece on the origin square is a king moving more than
        one square.

    NOTE(review): the capture and king terms read ``board`` at the move's
    origin/destination squares, which presumably expects the position BEFORE
    the move is played — confirm against the caller.
    """
    if is_checkmate:
        return 100.0

    # Drawn or drawish positions are penalised.
    drawn = board.is_stalemate() or board.is_repetition(3) or board.halfmove_clock >= 50
    if drawn:
        return -100.0

    origin = move.from_square
    destination = move.to_square

    reward = 0.0
    reward += evaluate_material_difference(board) * 0.1
    reward += len(list(board.attacks(destination))) * 0.2

    # Centre control.
    if destination in {chess.E4, chess.E5, chess.D4, chess.D5}:
        reward += 1.0
    elif destination in {chess.C3, chess.C4, chess.C5, chess.C6,
                         chess.F3, chess.F4, chess.F5, chess.F6}:
        reward += 0.5

    mover = board.piece_at(origin)
    target = board.piece_at(destination)

    if target:
        target_value = get_piece_value(target)
        reward += target_value
        # Extra credit for winning material with a cheaper piece.
        if mover and get_piece_value(mover) < target_value:
            reward += target_value * 0.5

    # A king travelling more than one square is treated as castling.
    if mover and mover.piece_type == chess.KING and chess.square_distance(origin, destination) > 1:
        reward += 2.0

    return reward

In [8]:
def load_champion_model(save_path):
    """Build an actor/critic pair and load saved champion weights when present.

    Looks for ``{save_path}_actor.pth`` and ``{save_path}_critic.pth``; a
    network whose file is missing keeps its random initialisation.

    Args:
        save_path: path prefix of the checkpoint files.

    Returns:
        ``(actor, critic)`` on the module-level ``device``.
    """
    actor = ActorCNN(output_size=4672).to(device)
    critic = CriticCNN().to(device)

    for network, suffix in ((actor, "actor"), (critic, "critic")):
        checkpoint = f"{save_path}_{suffix}.pth"
        if os.path.exists(checkpoint):
            # map_location lets a checkpoint written on a GPU machine be
            # restored on a CPU-only one (and vice versa).
            network.load_state_dict(torch.load(checkpoint, map_location=device))

    return actor, critic


def initialize_from_champion(champion_model, new_model):
    """Copy the champion's weights into ``new_model`` and return ``new_model``."""
    champion_weights = champion_model.state_dict()
    new_model.load_state_dict(champion_weights)
    return new_model

def train_chess_ai(
    num_episodes=100,
    batch_size=64,
    gamma=0.99,
    actor_net=None,
    critic_net=None,
    save_path='./models/chess_ai'
):
    """Self-play training loop: MCTS picks every move, one network update per episode.

    Args:
        num_episodes: number of self-play games to run.
        batch_size: replay batch size; training starts once the buffer holds
            more than this many transitions.
        gamma: accepted but not referenced in this function body.
        actor_net: policy network to train; a new ActorCNN is created when None.
        critic_net: value network to train; a new CriticCNN is created when None.
        save_path: path prefix for checkpoints
            (``{save_path}_actor.pth`` / ``{save_path}_critic.pth``).

    Returns:
        ``(actor_net, critic_net, opening_book, endgame_database)``; the two
        dictionaries are created empty and never filled in this function.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Creates "<save_path>/champion"; the checkpoints below are written next to
    # that directory (as "<save_path>_actor.pth"), not inside it.
    champion_path = os.path.join(save_path, 'champion')
    os.makedirs(champion_path, exist_ok=True)

    # NOTE(review): this buffer is replaced by a second one further down before
    # anything is pushed into it.
    memory = PrioritizedReplayBuffer(capacity=50000, device=device)
        
    # Load champion model if exists
    # Only happens when BOTH networks were passed in by the caller.
    if os.path.exists(f"{save_path}_actor.pth") and actor_net and critic_net:
        champion_actor, champion_critic = load_champion_model(save_path)
        actor_net = initialize_from_champion(champion_actor, actor_net)
        critic_net = initialize_from_champion(champion_critic, critic_net)
    
    actor_net = actor_net or ActorCNN(output_size=4672)
    critic_net = critic_net or CriticCNN()
    actor_net = actor_net.to(device)
    critic_net = critic_net.to(device)
    
    actor_optimizer = torch.optim.Adam(actor_net.parameters(), lr=0.001)
    critic_optimizer = torch.optim.Adam(critic_net.parameters(), lr=0.001)
    
    # This is the buffer actually used; it relies on the class's default device.
    memory = PrioritizedReplayBuffer(capacity=100000)
    move_lookup = create_move_lookup()
    
    # Optional: opening/endgame knowledge
    opening_book = {}
    endgame_database = {}
    
    # Running statistics across all episodes
    total_moves = 0
    checkmates = 0
    white_wins = 0
    black_wins = 0
    
    for episode in range(num_episodes):
        print(f"\nEpisode {episode + 1}/{num_episodes}")
        print("=" * 50)
        
        board = chess.Board()
        total_reward = 0
        
        # Dynamic training parameter adjustment
        # `exploration` is returned by the schedule but not used below.
        simulations, lr, exploration = adjust_training_parameters(episode, num_episodes)
        for param_group in actor_optimizer.param_groups:
            param_group['lr'] = lr
        for param_group in critic_optimizer.param_groups:
            param_group['lr'] = lr

        while not board.is_game_over():
            total_moves += 1
            if total_moves % 10 == 0:
                print(f"Total moves: {total_moves}")
                print(board)
            
            # States are stored with a leading batch axis of size 1.
            state = board_to_3d_tensor(board).unsqueeze(0)
            move = mcts_search(board, actor_net, critic_net, num_simulations=simulations)
            board.push(move)
            next_state = board_to_3d_tensor(board).unsqueeze(0)
            
            # NOTE(review): the reward is computed on the board AFTER the move
            # has been pushed — confirm calculate_enhanced_reward expects that.
            reward = calculate_enhanced_reward(board, move, board.is_checkmate())
            total_reward += reward
            
            is_done = board.is_game_over()
            memory.push(state, move_to_index(move, move_lookup), reward, next_state, is_done)
            
            # Checkmate handling
            # After the mating move it is the loser's turn, so Black to move
            # means White delivered mate.
            if board.is_checkmate():
                checkmates += 1
                if board.turn == chess.BLACK:
                    white_wins += 1
                else:
                    black_wins += 1
                print("\nCheckmate!")
                print(f"Winner: {'White' if board.turn == chess.BLACK else 'Black'}")
        
        # Train networks
        # Runs once per episode, after the game has finished.
        # NOTE(review): confirm PrioritizedReplayBuffer exposes sample() as a
        # method, and that update_networks accepts the bare experience tuple —
        # `indices` and `weights` are not forwarded here.
        if hasattr(memory, 'buffer') and len(memory.buffer) > batch_size:
            batch, indices, weights = memory.sample(batch_size)
            # Only pass the batch to update_networks
            update_networks(batch, actor_net, critic_net, actor_optimizer, critic_optimizer)

        # Log progress
        print(f"\nEpisode Summary:")
        print(f"Total Reward: {total_reward:.2f}")
        print(f"Checkmates: {checkmates}, White Wins: {white_wins}, Black Wins: {black_wins}")
        print(f"Average Moves/Game: {total_moves / (episode + 1):.1f}")
        
        # Save checkpoint
        # Every 100th episode only; runs shorter than 100 episodes save nothing.
        if (episode + 1) % 100 == 0:
            torch.save(actor_net.state_dict(), f'{save_path}_actor.pth')
            torch.save(critic_net.state_dict(), f'{save_path}_critic.pth')
    
    return actor_net, critic_net, opening_book, endgame_database


def evaluate_material_difference(board):
    """Return White's total piece value minus Black's, using ``get_piece_value``."""
    balance = 0
    for square in chess.SQUARES:
        occupant = board.piece_at(square)
        if not occupant:
            continue
        worth = get_piece_value(occupant)
        # White material counts positively, Black material negatively.
        balance += worth if occupant.color == chess.WHITE else -worth
    return balance

# Reward function for chess
def calculate_reward(board, move, is_checkmate=False):
    """Reward for ``move`` evaluated against ``board``.

    Returns 100000.0 for checkmate and -100000.0 when the position is a
    stalemate, a threefold repetition or has ``halfmove_clock >= 50``.
    Otherwise the reward is the sum of a check bonus (500.0), a bonus when the
    side to move has its king attacked (2.0), the value of the piece found on
    the destination square, and the material difference (White minus Black).

    Fix: the check and king-threat bonuses were added to an undefined local
    ``reward`` (raising UnboundLocalError whenever either condition held);
    they now accumulate into ``base_reward`` like every other term.
    """
    base_reward = 0.0

    # High reward for winning by checkmate
    if is_checkmate:
        return 100000.0

    # Reward for moves that lead toward checkmate
    if board.is_check():
        base_reward += 500.0

    # Penalize draw conditions
    if board.is_stalemate() or board.is_repetition(3) or board.halfmove_clock >= 50:
        return -100000.0

    # Reward for threatening opponent's king
    king_square = board.king(board.turn)
    if king_square is not None and board.is_attacked_by(not board.turn, king_square):
        base_reward += 2.0

    # # Center control reward
    # center_squares = {chess.E4, chess.E5, chess.D4, chess.D5}
    # if move.to_square in center_squares:
    #     base_reward += 0.5

    # # Reward development of knights and bishops from their starting positions
    # from_piece = board.piece_at(move.from_square)
    # if from_piece and from_piece.piece_type in [chess.KNIGHT, chess.BISHOP]:
    #     if move.from_square in [chess.B1, chess.G1, chess.B8, chess.G8]:
    #         base_reward += 0.3

    # # Reward for castling or moving the king significantly (for king safety)
    # if from_piece and from_piece.piece_type == chess.KING:
    #     if chess.square_distance(move.from_square, move.to_square) > 1:
    #         base_reward += 1.0

    # Reward for capturing an opponent piece
    captured_piece = board.piece_at(move.to_square)
    if captured_piece:
        base_reward += get_piece_value(captured_piece)

    # Use the material difference as a proxy for hanging pieces or overall advantage.
    base_reward += evaluate_material_difference(board)

    return base_reward

# Global torch device used by the helpers in this notebook that do not take an
# explicit device argument: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def tournament(models, num_episodes=100, save_path='./tournament_results'):
    """Run a reward-based round-robin between ``models`` on a ``ChessEnv``.

    Every ordered pair ``(i, j)`` with ``i != j`` plays ``num_episodes`` games
    in each of two rounds, model ``i`` moving first. A game's score is the sum
    of model ``i``'s step rewards minus model ``j``'s; its sign decides the
    win/loss/draw tally. After every pairing the accumulated history is saved
    to ``{save_path}_history.npy``.

    Changes from the previous version: the output directory is only created
    when ``save_path`` has a directory component (``os.makedirs('')`` raises),
    an empty model list is rejected up front, and a per-episode statistics
    loop whose results were never used has been removed.

    Args:
        models: list of callables (networks) mapping a state to an action.
        num_episodes: games per ordered pairing and round.
        save_path: path prefix for the saved history.

    Returns:
        ``(champion_model, tournament_history)`` — the model with the highest
        net score and the list of per-pairing records.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("tournament requires at least one model")

    output_dir = os.path.dirname(save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Move all models to the global device
    models = [model.to(device) for model in models]
    print(f"Running tournament on: {device}")

    # Tournament statistics
    scores = [0] * len(models)
    wins = [0] * len(models)
    losses = [0] * len(models)
    draws = [0] * len(models)
    tournament_history = []

    env = ChessEnv()  # Make sure your env handles draw logic

    # Two rounds over all ordered pairs
    for round_num in range(2):
        for i in range(len(models)):
            for j in range(len(models)):
                if i == j:
                    continue

                # Track match statistics for i vs j
                match_data = []
                total_score = 0

                for episode in range(num_episodes):
                    # Reset environment for each episode
                    env.reset()
                    state = env.get_state().to(device)
                    done = False
                    moves = 0
                    # Positive => advantage model i, Negative => advantage model j
                    episode_score = 0

                    while not done:
                        moves += 1
                        # Model i's turn
                        action1 = models[i](state)
                        reward_i, done = env.step(action1)
                        episode_score += reward_i

                        if done:
                            # The game ended after i's turn
                            break

                        # Model j's turn
                        state = env.get_state().to(device)
                        action2 = models[j](state)
                        reward_j, done = env.step(action2)
                        # j's reward reduces i's net score
                        episode_score -= reward_j

                        if not done:
                            state = env.get_state().to(device)

                    # Decide the winner based on the final episode score
                    if episode_score > 0:
                        wins[i] += 1
                        losses[j] += 1
                    elif episode_score < 0:
                        wins[j] += 1
                        losses[i] += 1
                    else:
                        draws[i] += 1
                        draws[j] += 1

                    total_score += episode_score

                    match_data.append({
                        'episode': episode + 1,
                        'moves': moves,
                        'score': episode_score
                    })

                # Tally the net score for i and j
                scores[i] += total_score
                scores[j] -= total_score

                tournament_history.append({
                    'round': round_num + 1,
                    'model1': i,
                    'model2': j,
                    'matches': match_data,
                    'total_score': total_score
                })

                # Save results after every pairing
                np.save(f'{save_path}_history.npy', tournament_history)

    # Determine champion
    champion_idx = scores.index(max(scores))
    print(f"\nTournament Winner: Model {champion_idx + 1}")
    print(f"Final Scores: {scores}")
    print(f"Total Wins: {wins}")
    print(f"Total Draws: {draws}")

    return models[champion_idx], tournament_history

def update_networks(batch, actor_net, critic_net, actor_optimizer, critic_optimizer,
                    memory=None, gamma=0.99):
    """Run one actor-critic update step on a replay batch.

    Args:
        batch: either ``(experiences, indices, weights)`` as returned by
            ``PrioritizedReplayBuffer.sample``, or the bare experience tuple
            ``(states, actions, rewards, next_states, dones)``. In the second
            form all samples are weighted equally.
        actor_net / critic_net: networks to update.
        actor_optimizer / critic_optimizer: their optimizers.
        memory: optional replay buffer; when given together with indices, its
            ``update_priorities(indices, td_errors)`` is called. (The previous
            code called ``update_priorities`` on ``batch[1]``, which is the
            index array, not a buffer.)
        gamma: discount factor for the one-step TD target.

    Returns:
        ``(critic_loss, actor_loss)`` as Python floats.
    """
    # Accept both call shapes.
    if len(batch) == 3:
        experiences, indices, weights = batch
    else:
        experiences, indices, weights = batch, None, None
    states, actions, rewards, next_states, dones = experiences

    # States stored with their own batch axis arrive as (B, 1, C, H, W) after
    # stacking; drop the singleton axis so the conv nets see (B, C, H, W).
    if states.dim() == 5 and states.size(1) == 1:
        states = states.squeeze(1)
        next_states = next_states.squeeze(1)

    rewards = rewards.float().view(-1, 1)
    dones = dones.float().view(-1, 1)  # also covers boolean done flags
    actions = actions.long().view(-1, 1)
    if weights is None:
        weights = torch.ones(states.size(0), dtype=torch.float32, device=states.device)
    weights = weights.view(-1, 1)

    # Update critic
    with torch.no_grad():
        next_values = critic_net(next_states)
        target_values = rewards + (1 - dones) * gamma * next_values

    current_values = critic_net(states)
    critic_loss = (weights * F.mse_loss(current_values, target_values, reduction='none')).mean()

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # Update actor
    action_probs = actor_net(states)
    action_log_probs = torch.log(action_probs + 1e-10)
    selected_action_log_probs = action_log_probs.gather(1, actions)

    advantages = (target_values - current_values).detach()
    actor_loss = -(weights * selected_action_log_probs * advantages).mean()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # Update priorities
    if memory is not None and indices is not None:
        td_errors = torch.abs(target_values - current_values).detach().cpu().numpy().flatten()
        memory.update_priorities(indices, td_errors)

    return critic_loss.item(), actor_loss.item()


  



def play_game(model1, model2, env, num_episodes):
    """Play ``num_episodes`` games between two models on ``env``.

    ``model1`` moves first in every game. The returned score is the sum of
    ``model1``'s step rewards minus ``model2``'s, accumulated over all games.

    The environment is now reset at the start of every episode (as the
    ``tournament`` function in this notebook does); previously every episode
    after the first started from the already-finished position.

    Args:
        model1, model2: callables mapping a state tensor to an action.
        env: environment exposing ``reset()``, ``get_state()`` and
            ``step(action) -> (reward, done)``.
        num_episodes: number of games to play.

    Returns:
        Net score from ``model1``'s point of view.
    """
    total_score = 0
    for episode in range(num_episodes):
        env.reset()
        state = env.get_state().to(device)
        done = False
        moves = 0

        while not done:
            moves += 1
            # Model 1's turn
            action1 = model1(state)
            reward, done = env.step(action1)
            total_score += reward

            if done:
                print(f"Game ended after {moves} moves")
                break

            # Model 2's turn
            state = env.get_state().to(device)
            action2 = model2(state)
            reward, done = env.step(action2)
            total_score -= reward

            if not done:
                state = env.get_state().to(device)

    return total_score

def save_champion(actor, critic, folder_path):
    """Persist the champion's actor and critic weights inside ``folder_path``.

    Each file contains the bare ``state_dict`` of the network (it is not
    wrapped in an outer dictionary such as ``{'model_state_dict': ...}``).

    Args:
        actor: Network whose ``state_dict()`` is written to ``champion_actor.pth``.
        critic: Network whose ``state_dict()`` is written to ``champion_critic.pth``.
        folder_path: Directory that receives both files.
    """
    actor_file = f"{folder_path}/champion_actor.pth"
    critic_file = f"{folder_path}/champion_critic.pth"
    for network, destination in ((actor, actor_file), (critic, critic_file)):
        torch.save(network.state_dict(), destination)

def test_champion(folder_path):
    """Load the saved champion actor, let it play one game, and store the PGN.

    The actor weights are read from ``<folder_path>/champion_actor.pth`` and the
    resulting game is written to ``<folder_path>/champion_game.pgn``.

    Args:
        folder_path: Generation folder that holds the champion checkpoint.
    """
    actor = ActorCNN(output_size=4672).to(device)
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    actor.load_state_dict(
        torch.load(f'{folder_path}/champion_actor.pth', map_location=device)
    )
    actor.eval()

    print(f"\nTesting champion from {folder_path}")
    game = play_chess_game(actor, device)

    pgn_path = f'{folder_path}/champion_game.pgn'
    with open(pgn_path, 'w') as f:
        f.write(str(game))

def get_model_move(model, state, board=None):
    """Choose a legal move for ``board`` from the model's output.

    This replaces two same-named definitions, of which the second silently
    shadowed the first. Both call forms are accepted:

    * ``get_model_move(model, state, board)`` -- ``state`` is the input
      already prepared for the model.
    * ``get_model_move(model, board)`` -- the model input is built here with
      ``board_to_3d_tensor`` and moved to the module-level ``device``.

    The argmax index over the model output is wrapped with modulo onto the
    list of legal moves; it is not decoded as a specific move.
    NOTE(review): confirm this index-to-move mapping is the intended one.

    Args:
        model: Callable returning action scores for ``state``.
        state: Model input, or the board itself in the two-argument form.
        board: Object exposing ``legal_moves``; omitted in the two-argument form.

    Returns:
        One element of ``board.legal_moves``.

    Raises:
        ValueError: If the position has no legal moves.
    """
    if board is None:
        # Two-argument form: the second positional argument is the board.
        board = state
        state = board_to_3d_tensor(board).unsqueeze(0).to(device)

    legal_moves = list(board.legal_moves)
    if not legal_moves:
        # Fail with a clear message rather than a ZeroDivisionError from the modulo.
        raise ValueError("Cannot select a move: the position has no legal moves")

    with torch.no_grad():
        action_probs = model(state)
        move_idx = torch.argmax(action_probs).item()

    return legal_moves[move_idx % len(legal_moves)]

def play_chess_game(model, device):
    """Play one game in which ``model`` picks every move, and return it as PGN.

    Starting from the standard initial position, moves are chosen with
    ``get_model_move`` for both sides until ``board.is_game_over()`` is true.

    Args:
        model: Network passed to ``get_model_move`` to choose each move.
        device: Torch device the board tensor is moved to before inference.

    Returns:
        A ``chess.pgn.Game`` whose mainline holds the moves played and whose
        ``Result`` header holds the final result of the board.
    """
    board = chess.Board()
    game = chess.pgn.Game()
    node = game

    while not board.is_game_over():
        # Encode the position with a leading batch dimension for the network.
        state = board_to_3d_tensor(board).unsqueeze(0).to(device)
        move = get_model_move(model, state, board)
        board.push(move)
        # Extend the mainline; add_variation returns the newly created child node.
        node = node.add_variation(move)

    # Without this the PGN keeps the default "*" (unknown) result even though
    # the game has finished.
    game.headers["Result"] = board.result()

    return game




In [9]:
# Main generational loop: counters and limits.
# The loop below runs while major_gen <= max_major_gen; after each generation
# minor_gen is incremented and, once it exceeds max_minor_gen, it rolls back to
# 1 and major_gen advances.
# NOTE(review): the starting minor_gen (3) is already above max_minor_gen (1),
# presumably to resume after earlier runs -- confirm.
major_gen, minor_gen = 1, 3
max_major_gen, max_minor_gen = 4, 1

def create_generation_folder(gen_major, gen_minor):
    """Ensure the output folder for a generation exists and return its name.

    The folder is created relative to the current working directory; an
    already existing folder is left as is.

    Args:
        gen_major: Major generation number used in the folder name.
        gen_minor: Minor generation number used in the folder name.

    Returns:
        The folder name, ``chessCNN_ai_generation_<gen_major>_<gen_minor>``.
    """
    generation_dir = f"chessCNN_ai_generation_{gen_major}_{gen_minor}"
    os.makedirs(generation_dir, exist_ok=True)
    return generation_dir

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def load_previous_champion(prev_major_gen, prev_minor_gen):
    """Rebuild the champion actor and critic saved by an earlier generation.

    Weights are read from ``champion_actor.pth`` and ``champion_critic.pth``
    inside ``chessCNN_ai_generation_<prev_major_gen>_<prev_minor_gen>``.

    Args:
        prev_major_gen: Major generation number of the champion to load.
        prev_minor_gen: Minor generation number of the champion to load.

    Returns:
        An ``(actor, critic)`` tuple of networks placed on ``device`` with the
        saved weights loaded.
    """
    prev_folder = f"chessCNN_ai_generation_{prev_major_gen}_{prev_minor_gen}"
    actor_path = os.path.join(prev_folder, 'champion_actor.pth')
    critic_path = os.path.join(prev_folder, 'champion_critic.pth')

    actor = ActorCNN(output_size=4672).to(device)
    critic = CriticCNN().to(device)

    # map_location remaps the stored tensors onto the active device, so a
    # champion saved on a GPU can still be restored on a CPU-only machine.
    actor.load_state_dict(torch.load(actor_path, map_location=device))
    critic.load_state_dict(torch.load(critic_path, map_location=device))

    return actor, critic

# Run one generation per iteration: seed a population, train every member,
# pick a champion by tournament, save and test it, then advance the counters.
while major_gen <= max_major_gen:
    current_folder = create_generation_folder(major_gen, minor_gen)
    print(f"\nStarting Generation {major_gen}_{minor_gen}")

    # Load previous champion or create fresh models
    # NOTE(review): only generation 2_1 is seeded from a saved champion (1_1);
    # every other generation starts from freshly initialised networks. Confirm
    # whether later generations should continue from the preceding champion.
    if major_gen == 2 and minor_gen == 1:
        initial_actor, initial_critic = load_previous_champion(1, 1)
    else:
        initial_actor = ActorCNN(output_size=4672).to(device)
        initial_critic = CriticCNN().to(device)

    # Create population with mutations: slot 0 keeps the unmodified seed pair,
    # the remaining slots receive mutated copies of it.
    population = ModelPopulation(population_size=8)
    population.population[0] = (initial_actor, initial_critic)

    for i in range(1, len(population.population)):
        population.population[i] = (
            mutate_network(initial_actor, mutation_rate=0.01).to(device),
            mutate_network(initial_critic, mutation_rate=0.01).to(device)
        )

    # Train each model in population
    for i, (actor, critic) in enumerate(population.population):
        train_chess_ai(
            actor_net=actor,
            critic_net=critic,
            save_path=os.path.join(current_folder, f'model_{i}')
        )

    # Run tournament to find best model; the first ranking entry's first field
    # is used as the index of the winner within the population.
    rankings = run_tournament(population.population)
    best_model_idx = rankings[0][0]
    winner_actor, winner_critic = population.population[best_model_idx]

    # Save champion
    save_champion(winner_actor, winner_critic, current_folder)

    # Test the champion
    test_champion(current_folder)

    # Report before touching the counters so the message names the generation
    # that actually finished (printing afterwards yields e.g. "2_0" on rollover).
    print(f"Completed Generation {major_gen}_{minor_gen}")

    # Update generation counters
    minor_gen += 1
    if minor_gen > max_minor_gen:
        minor_gen = 1
        major_gen += 1

print("Training completed!")



Starting Generation 1_3
Using device: cpu

Episode 1/100


KeyboardInterrupt: 