In [55]:
# Imports that are needed
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import gym_chess
import random
import numpy as np
import chess
import chess.svg
from IPython.display import display, SVG
import os
from datetime import datetime

In [56]:
# Actor Network Definition
class Actor(nn.Module):
    """Policy network: a three-layer MLP that outputs action probabilities.

    Args:
        input_size: Number of features in one state vector.
        output_size: Number of discrete actions the network scores.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_size)

    def forward(self, x):
        """Return a softmax distribution over actions along the last dimension of ``x``."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return torch.softmax(logits, dim=-1)

In [57]:
# Critic Network Definition
class Critic(nn.Module):
    """Value network: a three-layer MLP that maps a state vector to one scalar.

    Args:
        input_size: Number of features in one state vector.
    """

    def __init__(self, input_size):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return the unbounded value estimate; the last dimension has size 1."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [58]:
#Move Handling Functions
def create_move_lookup():
    """Return every (from_square, to_square) pair for squares 0..63.

    Pairs are ordered by from-square first, then to-square, so the pair
    ``(f, t)`` sits at list position ``f * 64 + t``.
    """
    return [(origin, target) for origin in range(64) for target in range(64)]

def select_legal_action(action_probs, legal_moves):
    """Sample one legal move, weighted by the actor's output probabilities.

    Each move is scored with the network output at index
    ``from_square * 64 + to_square``, which is the position that pair occupies
    in the list returned by ``create_move_lookup``. The index is computed
    directly instead of rebuilding and linearly searching that list on every
    call.

    Args:
        action_probs: Tensor whose first row holds the per-action
            probabilities (only the first row is used).
        legal_moves: Iterable of move objects exposing ``from_square`` and
            ``to_square`` integer attributes.

    Returns:
        One element of ``legal_moves``.

    Raises:
        IndexError: If ``legal_moves`` is empty.
    """
    # .cpu() keeps this working when the network lives on an accelerator.
    probs = action_probs.detach().cpu().numpy()[0]
    legal_moves_list = list(legal_moves)

    # Keep each move paired with its index so that skipping an un-indexable
    # move can never shift the sampled position onto a different move.
    candidates = []
    move_indices = []
    for move in legal_moves_list:
        from_square = move.from_square
        to_square = move.to_square
        if not (0 <= from_square < 64 and 0 <= to_square < 64):
            continue  # Not representable in the 64x64 index space
        idx = from_square * 64 + to_square
        if idx >= len(probs):
            continue  # Network output is too short to score this move
        candidates.append(move)
        move_indices.append(idx)

    if not move_indices:
        # Nothing could be scored: fall back to a uniformly random legal move.
        return random.choice(legal_moves_list)

    # float64 keeps the normalised probabilities summing to 1 accurately.
    legal_probs = probs[move_indices].astype(np.float64)
    # Handle potential numerical issues
    legal_probs = np.clip(legal_probs, 1e-10, 1.0)
    total = legal_probs.sum()
    if not np.isfinite(total) or total <= 0:
        legal_probs = np.ones_like(legal_probs) / len(legal_probs)
    else:
        legal_probs = legal_probs / total

    selected_idx = np.random.choice(len(candidates), p=legal_probs)
    return candidates[selected_idx]

def move_to_index(move, move_lookup):
    """Return the position of ``move``'s (from_square, to_square) pair in ``move_lookup``.

    Raises:
        ValueError: If the pair is not present in ``move_lookup``.
    """
    square_pair = (move.from_square, move.to_square)
    return move_lookup.index(square_pair)


In [59]:
# Replay Buffer
class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, reward, next_state, done) tuples.

    Once ``capacity`` entries are stored, each new entry overwrites the oldest.
    """

    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # Slot that the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest entry when full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # Still growing: reserve the slot that is about to be written.
            self.buffer.append(None)
        slot = self.position
        self.buffer[slot] = transition
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            Five tuples: states, actions, rewards, next_states, dones.
        """
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


In [60]:

def board_to_tensor(board):
    """One-hot encode a board as a flat float tensor of length 768.

    The vector is laid out as 12 planes of 64 squares. Planes follow the
    symbol order ``p n b r q k P N B R Q K``, and a piece on square ``s`` sets
    entry ``plane * 64 + s`` to 1.

    Args:
        board: Object exposing ``piece_at(square)``; returned pieces must
            expose ``symbol()``.
    """
    symbol_order = ['p', 'n', 'b', 'r', 'q', 'k', 'P', 'N', 'B', 'R', 'Q', 'K']
    encoding = np.zeros(768)

    for square in range(64):
        occupant = board.piece_at(square)
        if not occupant:
            continue  # Empty square
        plane = symbol_order.index(occupant.symbol())
        encoding[square + plane * 64] = 1

    return torch.FloatTensor(encoding)


In [61]:
# Reward function for chess
def calculate_reward(board, move, is_checkmate=False):
    """Return a shaped reward for playing ``move``.

    ``board`` may be the position before the move or the position after the
    move has been pushed. When ``move`` is the last entry on ``board``'s move
    stack it is treated as already played: the draw penalties are evaluated
    on ``board`` itself, while the shaping terms (which piece moved, what it
    captured) are read from the position before the move. Otherwise ``board``
    is used for everything.

    Args:
        board: ``chess.Board`` describing the position.
        move: The ``chess.Move`` being rewarded.
        is_checkmate: True when the move delivered checkmate.

    Returns:
        100.0 for checkmate, -50.0 for stalemate / threefold repetition /
        fifty-move rule, otherwise the sum of the shaping bonuses below.
    """
    # Checkmate reward
    if is_checkmate:
        return 100.0  # Highest reward for winning

    # Penalize stalemates and draw conditions
    if board.is_stalemate():
        return -50.0  # Heavy penalty for stalemates
    if board.is_repetition(3):  # Threefold repetition
        return -50.0  # Heavy penalty for threefold repetition
    # The fifty-move rule is 50 moves by each side, i.e. 100 half-moves.
    if board.halfmove_clock >= 100:
        return -50.0  # Heavy penalty for fifty-move rule

    # Recover the position the move was played from. On a post-move board the
    # origin square is empty and the destination holds the mover itself, so
    # reading them there would reward "capturing" one's own piece.
    if board.move_stack and board.peek() == move:
        position = board.copy()
        position.pop()
    else:
        position = board

    # Piece values
    piece_values = {
        'p': 1.0,  # Pawn
        'n': 3.0,  # Knight
        'b': 3.0,  # Bishop
        'r': 5.0,  # Rook
        'q': 9.0,  # Queen
        'k': 0.0   # King (king captures are not applicable)
    }

    base_reward = 0.0

    # Center control rewards
    center_squares = {chess.E4, chess.E5, chess.D4, chess.D5}
    if move.to_square in center_squares:
        base_reward += 0.5

    # Piece development rewards: a minor piece leaving its own home square
    home_squares = {
        chess.KNIGHT: {chess.B1, chess.G1, chess.B8, chess.G8},
        chess.BISHOP: {chess.C1, chess.F1, chess.C8, chess.F8},
    }
    from_piece = position.piece_at(move.from_square)
    if from_piece and move.from_square in home_squares.get(from_piece.piece_type, ()):
        base_reward += 0.3

    # King safety (castling): the only king move spanning more than one square
    if from_piece and from_piece.piece_type == chess.KING:
        if chess.square_distance(move.from_square, move.to_square) > 1:
            base_reward += 1.0

    # Capture rewards
    if position.is_en_passant(move):
        # The captured pawn is not on the destination square.
        base_reward += piece_values['p']
    else:
        captured_piece = position.piece_at(move.to_square)
        if captured_piece:
            base_reward += piece_values[captured_piece.symbol().lower()]

    return base_reward


In [62]:
#Training the Chess AI
def train_chess_ai(num_episodes=100, batch_size=64, gamma=0.99, actor_net=None, critic_net=None, save_path='./models/chess_ai', render=True):
    """Train an actor-critic pair by letting the actor choose every move of each game.

    Each move is stored in a replay buffer; once the buffer holds more than
    ``batch_size`` transitions, one critic update and one actor update are
    run per move on a random batch.

    Args:
        num_episodes: Number of games to play.
        batch_size: Transitions sampled per network update.
        gamma: Discount factor for the bootstrapped value target.
        actor_net: Existing ``Actor`` to keep training; a fresh one is
            created when None.
        critic_net: Existing ``Critic`` to keep training; a fresh one is
            created when None.
        save_path: Path prefix for ``_actor.pth``, ``_critic.pth`` and
            ``_history.npy``. Its directory is created when it has one.
        render: When True (default), print the board before every move.

    Returns:
        Tuple ``(actor_net, critic_net)``.
    """
    # os.makedirs('') raises, so only create a directory when the prefix has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    env = gym.make('Chess-v0')

    if actor_net is None:
        actor_net = Actor(input_size=768, output_size=4672)
    if critic_net is None:
        critic_net = Critic(input_size=768)

    actor_optimizer = torch.optim.Adam(actor_net.parameters(), lr=1e-4)
    critic_optimizer = torch.optim.Adam(critic_net.parameters(), lr=1e-3)

    memory = ReplayBuffer(capacity=10000)
    training_history = []
    move_lookup = create_move_lookup()

    # Per-game statistics: one episode is exactly one game.
    wins = 0
    games_played = 0

    def update_networks(states, actions, rewards, next_states, dones):
        """Run one critic step and one actor step on a batch; return both losses."""
        with torch.no_grad():
            next_values = critic_net(next_states).squeeze(1)
            # Terminal transitions (done == 1) do not bootstrap.
            target_values = rewards + gamma * next_values * (1 - dones)

        current_values = critic_net(states).squeeze(1)
        critic_loss = F.mse_loss(current_values, target_values)

        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        action_probs = actor_net(states)
        # Detached so the actor loss does not backpropagate into the critic.
        advantages = (target_values - current_values).detach()

        action_log_probs = torch.log(action_probs + 1e-10)
        selected_action_log_probs = action_log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        actor_loss = -(selected_action_log_probs * advantages).mean()

        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()

        return critic_loss.item(), actor_loss.item()

    for episode in range(num_episodes):
        env.reset()
        chess_board = chess.Board()  # Reset board for new episode
        done = False
        total_reward = 0
        episode_data = []
        ended_in_checkmate = False

        while not done:
            if render:
                # Display current board state
                print(env.render())
            state_tensor = board_to_tensor(chess_board).unsqueeze(0)
            legal_moves = list(chess_board.legal_moves)

            if not legal_moves:
                break  # No legal moves, end the game

            # Move selection only needs the probabilities, not a graph.
            with torch.no_grad():
                action_probs = actor_net(state_tensor)

            action = select_legal_action(action_probs, legal_moves)
            action_idx = move_to_index(action, move_lookup)

            # Perform the move; the observation is used as the new board.
            observation, _, done, _ = env.step(action)
            chess_board = observation

            # Check for checkmate
            is_checkmate = chess_board.is_checkmate()
            if is_checkmate:
                ended_in_checkmate = True

            # Calculate enhanced reward (note: chess_board is the position
            # AFTER the move has been played)
            reward = calculate_reward(
                board=chess_board,
                move=action,
                is_checkmate=is_checkmate
            )

            next_state_tensor = board_to_tensor(chess_board).unsqueeze(0)
            memory.push(state_tensor, action_idx, reward, next_state_tensor, done)

            if len(memory) > batch_size:
                states, actions, rewards, next_states, dones = memory.sample(batch_size)
                states = torch.cat(states)
                next_states = torch.cat(next_states)
                actions = torch.tensor(actions, dtype=torch.long)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                dones = torch.tensor(dones, dtype=torch.float32)

                critic_loss, actor_loss = update_networks(
                    states, actions, rewards, next_states, dones
                )

                episode_data.append({
                    'critic_loss': critic_loss,
                    'actor_loss': actor_loss
                })

            total_reward += reward

        # Count the finished game once. A "win" is a game that ended in
        # checkmate; the same actor chose the moves for whichever side mated.
        games_played += 1
        if ended_in_checkmate:
            wins += 1

        training_history.append({
            'episode': episode + 1,
            'total_reward': total_reward,
            'moves': episode_data
        })

        print(f"Episode {episode + 1}, Total Reward: {total_reward}")

        # Print statistics every N episodes (e.g., every 10 games)
        if episode % 10 == 0:
            win_rate = (wins / games_played) * 100
            print(f"Games played: {games_played}")
            print(f"Wins: {wins}")
            print(f"Win rate: {win_rate:.2f}%")

        # Checkpoint every 100 episodes and after the final one, so runs
        # shorter than 100 episodes are saved too.
        if (episode + 1) % 100 == 0 or episode + 1 == num_episodes:
            torch.save(actor_net.state_dict(), f'{save_path}_actor.pth')
            torch.save(critic_net.state_dict(), f'{save_path}_critic.pth')
            np.save(f'{save_path}_history.npy', training_history)

    return actor_net, critic_net


In [63]:
class ChessEnv:
    """Minimal chess environment driven by raw network output vectors."""

    def __init__(self):
        self.board = chess.Board()

    def get_state(self):
        """Return the board as a one-hot float tensor of length 768.

        Layout is 12 planes of 64 squares: plane ``piece_type - 1`` when the
        piece's colour flag is falsy, plus 6 when it is truthy.
        """
        # Convert board state to tensor format (768 input features)
        state = torch.zeros(768)
        for i, square in enumerate(chess.SQUARES):
            piece = self.board.piece_at(square)
            if piece:
                # 12 piece types (6 pieces * 2 colors) * 64 squares = 768 features
                piece_idx = (piece.piece_type - 1 + (6 if piece.color else 0)) * 64 + i
                state[piece_idx] = 1
        return state

    def step(self, action):
        """Play the move encoded by ``action`` and return ``(reward, done)``.

        Returns:
            ``(1.0, True)`` on checkmate, ``(0.0, True)`` on any other game
            end, ``(0.0, False)`` while the game continues, and
            ``(-1.0, True)`` when the decoded move is not legal.
        """
        # Convert model output to chess move
        move = self._action_to_move(action)

        # Make move and get reward
        if move is not None and move in self.board.legal_moves:
            self.board.push(move)
            if self.board.is_game_over():
                if self.board.is_checkmate():
                    return 1.0, True  # Win
                return 0.0, True  # Draw
            return 0.0, False  # Game continues
        return -1.0, True  # Illegal move

    def _action_to_move(self, action):
        """Decode the arg-max entry of ``action`` into a ``chess.Move``.

        Uses ``index = from_square * 64 + to_square``, the same indexing as
        ``create_move_lookup`` / ``move_to_index``, so both squares always
        fall in 0..63. Returns None for indices beyond the 64x64 range (the
        network emits more outputs than there are square pairs).
        """
        action_idx = torch.argmax(action).item()
        from_square, to_square = divmod(action_idx, 64)
        if from_square >= 64:
            return None  # Index does not correspond to any square pair
        return chess.Move(from_square, to_square)


In [None]:
def play_game(model1, model2):
    """Play one game from the initial position, ``model1`` moving first.

    The two models alternate moves; each move is sampled with
    ``select_legal_action`` from that model's output for the current board.

    Returns:
        1 if the result is '1-0' (model1 wins), 2 if it is '0-1'
        (model2 wins), 0 otherwise.
    """
    chess_board = chess.Board()
    movers = (model1, model2)
    ply = 0

    while not chess_board.is_game_over():
        # Even plies belong to model1 (White), odd plies to model2 (Black).
        mover = movers[ply % 2]
        encoded = board_to_tensor(chess_board).unsqueeze(0)
        chosen = select_legal_action(mover(encoded), chess_board.legal_moves)
        chess_board.push(chosen)
        ply += 1

    # '1-0', '0-1', or '1/2-1/2'
    outcome_codes = {'1-0': 1, '0-1': 2}
    return outcome_codes.get(chess_board.result(), 0)


In [64]:
def tournament(models, num_episodes, games_per_match=10):
    """Run a round-robin tournament and return the highest-scoring model.

    For every pair (i, j) with i < j, both models are first trained with
    ``train_chess_ai`` (each with a freshly created critic), then play
    ``games_per_match`` games with model i moving first. A win scores one
    point; draws score nothing. ``models`` is updated in place with the
    trained actors.

    Args:
        models: Non-empty list of actor networks.
        num_episodes: Training episodes per model before each match.
        games_per_match: Games played per pairing (default 10).

    Returns:
        The model with the highest total score; ties go to the lowest index.

    Raises:
        ValueError: If ``models`` is empty.
    """
    num_models = len(models)
    if num_models == 0:
        raise ValueError("tournament requires at least one model")

    scores = [0] * num_models
    print("Starting tournament with", num_models, "models")

    def train_entry(idx):
        """Train model ``idx`` and store the returned actor back in ``models``."""
        print(f"Training Model {idx}...")
        models[idx] = train_chess_ai(
            num_episodes=num_episodes,
            actor_net=models[idx],
            critic_net=Critic(input_size=768),
            save_path=f'models/model_{idx}'
        )[0]  # Get only the actor net

    # Round-robin tournament
    for i in range(num_models):
        for j in range(i + 1, num_models):
            print(f"\nMatch between Model {i} and Model {j}")

            # Train each model before the match
            train_entry(i)
            train_entry(j)

            # Play multiple games between the trained models
            print(f"Playing matches between Model {i} and Model {j}")
            score_i = 0
            score_j = 0
            for game in range(games_per_match):
                result = play_game(models[i], models[j])
                if result == 1:
                    score_i += 1
                    winner = 'Model ' + str(i)
                elif result == 2:
                    score_j += 1
                    winner = 'Model ' + str(j)
                else:
                    winner = 'Draw'
                print(f"Game {game + 1} result: {winner}")

            scores[i] += score_i
            scores[j] += score_j
            print(f"Match results - Model {i}: {score_i}, Model {j}: {score_j}")

    champion_idx = scores.index(max(scores))
    print("\nTournament complete!")
    print(f"Final scores: {scores}")
    print(f"Champion is Model {champion_idx}")
    return models[champion_idx]


In [65]:
# Model Initialization and Testing
# 768 matches the vector built by board_to_tensor: 12 piece symbols x 64 squares.
input_size = 768  # Example input size (board state as a flat vector)
# The actor emits 4672 scores; the move-selection helpers in this notebook
# index them as from_square * 64 + to_square, i.e. only positions 0..4095.
output_size = 4672  # Example output size (number of possible moves in chess)
actor_net = Actor(input_size=input_size, output_size=output_size)
critic_net = Critic(input_size=input_size)

In [66]:
# Test with Dummy Data
# Smoke test: push one random state through both networks to confirm that the
# forward passes run and the outputs have the expected form.
dummy_state = torch.rand(1, input_size)  # A single random state (batch of 1)
action_probs = actor_net(dummy_state)  # Softmax output over output_size actions
print(f"Action probabilities: {action_probs}")
state_value = critic_net(dummy_state)  # One scalar value estimate per state
print(f"State value: {state_value}")

Action probabilities: tensor([[0.0002, 0.0002, 0.0002,  ..., 0.0002, 0.0002, 0.0002]],
       grad_fn=<SoftmaxBackward0>)
State value: tensor([[0.1688]], grad_fn=<AddmmBackward0>)


In [69]:
# Run Training
# Single-run alternative to the tournament below. train_chess_ai returns
# (actor_net, critic_net), so unpack exactly two values:
# train_chess_ai(save_path='models/chess_ai')

# actor_net, critic_net = train_chess_ai(
#     num_episodes=1,
#     save_path='models/chess_ai',
#     batch_size=64,
#     gamma=0.99
# )

# Tournament configuration (small values so the run stays short)
num_models = 4  # Number of actor networks entered in the round-robin
num_episodes = 10  # Training episodes each model gets before every match
save_path = 'champion_model.pth'

# Create initial models
models = [Actor(input_size=768, output_size=4672) for _ in range(num_models)]

# Run tournament
champion = tournament(models, num_episodes=num_episodes)

# Save the champion model
torch.save(champion.state_dict(), save_path)


Starting tournament with 4 models

Match between Model 0 and Model 1
Training Model 0...
♜ ♞ ♝ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ♟ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ♘ ♖
♜ ♞ ♝ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ♟ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ♘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ⭘ ♖
♜ ♞ ♝ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ⭘ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ♟ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ♘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ⭘ ♖
♜ ♞ ♝ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ⭘ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ♟ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ♘ ♖
♜ ♞ ⭘ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ♝ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ♟ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ♙ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ♘ ♖
♜ ♞ ⭘ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ♝ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ♟ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ♙ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ⭘ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ♘ ♖
♜ ♞ ⭘ ♛ ♚ ♝ ♞ ♜
♟ ♟ ♟ ⭘ ♟ ♟ ♟ ♟
⭘ ⭘ ⭘ ♟ ♝ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ♙ ⭘ ⭘ ⭘ ⭘ ⭘
⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘ ⭘
♙ ♙ ⭘ ♙ ♙ ♙ ♙ ♙
♖ ♘ ♗ ♕ ♔ ♗ ♘ ♖
♜ ♞ ⭘ ♛ ♚ ♝ ♞ ♜

In [73]:
from datetime import datetime
def record_champion_game(champion_model, num_games=1, output_file="champion_games.txt"):
    """Let ``champion_model`` play itself and write each game to ``output_file``.

    For every game the file receives a ``Game N`` line, four PGN-style tag
    lines (Event, Date, White, Black), and one line holding the numbered SAN
    moves followed by the result. The file is overwritten on each call.

    Args:
        champion_model: Actor network used for both sides.
        num_games: Number of games to record.
        output_file: Destination text file.
    """
    with open(output_file, "w") as out:
        for game_index in range(num_games):
            game_label = game_index + 1
            chess_board = chess.Board()

            # Header is written before play starts, so the date is the start date.
            out.writelines([
                f"Game {game_label}\n",
                "[Event \"Champion Model Analysis Game\"]\n",
                f"[Date \"{datetime.now().strftime('%Y.%m.%d')}\"]\n",
                "[White \"ChampionModel\"]\n",
                "[Black \"ChampionModel\"]\n\n",
            ])

            san_tokens = []
            full_move = 1
            while not chess_board.is_game_over():
                encoded = board_to_tensor(chess_board).unsqueeze(0)
                chosen = select_legal_action(champion_model(encoded), chess_board.legal_moves)

                # SAN must be generated before the move is pushed.
                san = chess_board.san(chosen)
                if chess_board.turn:
                    # White's move carries the move number.
                    san_tokens.append(f"{full_move}. {san}")
                else:
                    # Black's move completes the full move.
                    san_tokens.append(f"{san}")
                    full_move += 1

                chess_board.push(chosen)

            # Write the moves in PGN format
            movetext = " ".join(san_tokens)
            out.write(f"{movetext} {chess_board.result()}\n\n")

            print(f"Game {game_label} recorded")

# Usage example: record five self-play games of `champion`, the model returned
# by the tournament cell above (that cell must have been run first).
record_champion_game(champion, num_games=5, output_file="champion_analysis_games.txt")


Game 1 recorded
Game 2 recorded
Game 3 recorded
Game 4 recorded
Game 5 recorded
