In [1]:
import os
import io
import math
import chess
import chess.pgn
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler
from sqlalchemy import create_engine, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.cuda.is_available()

Using device: cuda


True

In [2]:
Base = declarative_base()

class ChessGame(Base):
    __tablename__ = 'games'
    id = Column(Integer, primary_key=True)
    pgn = Column(Text)

engine = create_engine('sqlite:///../chess_games.db')
Session = sessionmaker(bind=engine)
session = Session()

  Base = declarative_base()


In [3]:
def load_openings_from_pgn(pgn_file):
    openings = []
    with open(pgn_file) as f:
        while True:
            game = chess.pgn.read_game(f)
            if game is None:
                break

            board = game.board()
            moves = []
            for move in game.mainline_moves():
                moves.append(board.san(move))
                board.push(move)

            openings.append(moves)
    return openings

#Function to find matching openings based on played moves
def find_matching_openings(played_moves, openings):
    """
    Finding opening works by checking if the played moves match the start of any opening in the opening book.
    This is because a chosen opening can be diverged from at any point by the opponent.
    This makes the bot more dynamic in the opening phase.
    """
    matching_openings = []
    for opening in openings:
        if played_moves == opening[:len(played_moves)]:
            matching_openings.append(opening)
    return matching_openings

#Function to choose the next move from matching openings
def select_next_move(played_moves, matching_openings):
    if not matching_openings:
        return None  # No matching opening found, time for engine

    # Check if there is a next move available in the matching opening
    for opening in matching_openings:
        if len(opening) > len(played_moves):
            next_move = opening[len(played_moves)]
            return next_move
    
    return None  # No more moves in the opening book, time for engine

#Load the openings
openings = load_openings_from_pgn("eco.pgn")

#Example usage
played_moves = ['e4']  
matching_openings = find_matching_openings(played_moves, openings)
next_move = select_next_move(played_moves, matching_openings)

if next_move:
    print(f"Bot's next move: {next_move}")
else:
    print("No matching opening found, calculate the move using engine logic.")

Bot's next move: g6


In [4]:
class ChessNet(nn.Module):
    def __init__(self):
        super(ChessNet, self).__init__()

        self.conv1 = nn.Conv2d(12, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        
        #Positional encoding for the transformer
        self.positional_encoding = PositionalEncoding(d_model=128)

        #Transformer Encoder
        encoder_layers = nn.TransformerEncoderLayer(d_model=128, nhead=8)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers=2)

        #Fully connected layer
        self.fc1 = nn.Linear(8*8*128, 4096)  #4096 possible moves

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)  
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)  
        x = F.relu(x)

        x = self.conv3(x)
        x = self.bn3(x)  
        x = F.relu(x)

        x = x.view(-1, 128, 8*8)  #[batch_size, d_model, sequence_length]
        x = x.permute(2, 0, 1)  #[sequence_length, batch_size, d_model]

        # Positional encoding
        x = self.positional_encoding(x)

        # Transformer encoder
        x = self.transformer(x)

        x = x.permute(1, 0, 2).contiguous()  #[batch_size, sequence_length, d_model]
        x = x.view(-1, 8*8*128)
        x = self.fc1(x)
        return x

#Positional encoding for the transformer in order to give the model information about the position of the pieces
#Uses the sine and cosine functions to encode the position of the board in a unique way
#Experimental, might be overkill. Saw somewhere it could be useful for the transformer, but not sure if it is properly implemented here
class PositionalEncoding(nn.Module): 
    def __init__(self, d_model, max_len=64):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        self.encoding.requires_grad = False

        pos = torch.arange(0, max_len).float().unsqueeze(1)
        _2i = torch.arange(0, d_model, 2).float()

        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        self.encoding = self.encoding.unsqueeze(1)

    def forward(self, x):
        return x + self.encoding[:x.size(0), :].to(x.device)

In [5]:
def board_to_input(board):
    board_planes = np.zeros((8, 8, 12), dtype=np.float32)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            plane = piece.piece_type - 1
            if piece.color == chess.BLACK:
                plane += 6
            row, col = divmod(square, 8)
            board_planes[row, col, plane] = 1
    return board_planes

#Encode the move
def move_to_output(move):
    from_square = move.from_square
    to_square = move.to_square
    return from_square * 64 + to_square

def calculate_accuracy(output, target):
    _, predicted = torch.max(output, 1)
    correct = (predicted == target).sum().item()
    return correct / target.size(0)

#Training step with move skipping and batching
def train_on_batch(games, model, optimizer, criterion, device, skip_moves=10):
    all_board_inputs = []
    all_targets = []
    total_moves = 0

    for game_str in games:
        pgn_io = io.StringIO(game_str)
        game = chess.pgn.read_game(pgn_io)
        board = game.board()
        move_count = 0

        for move in game.mainline_moves():
            #A static number of moves are skipped to avoid overfitting to the opening book
            #More sophisticated methods can be used to skip exact amount of book moves, but it is too inefficient for my machine
            if move_count < skip_moves:
                board.push(move)
                move_count += 1
                continue

            #Prepare the input and output
            board_input = board_to_input(board)
            board_input = torch.tensor(board_input, dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2).to(device)
            actual_output = move_to_output(move)
            actual_output = torch.tensor([actual_output], dtype=torch.long).to(device)

            all_board_inputs.append(board_input)
            all_targets.append(actual_output)
            total_moves += 1

            #Update the board with the actual move
            board.push(move)

    if all_board_inputs:
        #Stack all inputs and targets
        batch_inputs = torch.cat(all_board_inputs, dim=0)
        batch_targets = torch.cat(all_targets, dim=0)

        optimizer.zero_grad()
        output = model(batch_inputs)

        loss = criterion(output, batch_targets)
        accuracy = calculate_accuracy(output, batch_targets)

        loss.backward()
        optimizer.step()

        return loss.item(), accuracy, total_moves
    else:
        return 0, 0, 0  #If no valid moves in batch

#Training loop
batch_size = 1000
game_batch_size = 16 
offset = 0
step = 0
j = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ChessNet().to(device)
#model.load_state_dict(torch.load('savedModels/cnn_transformer_model_epoch_1.pth')) #Load the model from the previous training session
model.train()
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

while True:
    games = session.query(ChessGame).offset(offset).limit(batch_size).all()
    if not games:
        break

    total_loss = 0.0
    total_accuracy = 0.0
    total_moves = 0
    

    with tqdm(total=len(games) // game_batch_size, desc=f"Processing Batch {offset // batch_size + 1}") as pbar:
        for i in range(0, len(games), game_batch_size):
            game_batch = [game.pgn for game in games[i:i + game_batch_size]]
            loss, accuracy, moves = train_on_batch(game_batch, model, optimizer, criterion, device, skip_moves=10)
            total_loss += loss * moves
            total_accuracy += accuracy * moves
            total_moves += moves
            pbar.update(1)
            if total_moves > 0:
                pbar.set_postfix({'Loss': total_loss / total_moves, 'Accuracy': total_accuracy / total_moves})
    
    j += 1
    if j % 10 == 0:
        model_save_path = os.path.join('savedModels', f'cnn_transformer_model_epoch_{j}.pth')
        torch.save(model.state_dict(), model_save_path)
    offset += batch_size
    
#Close the session and TensorBoard writer
#Still have not tried TensorBoard, might not work
session.close()
model_save_path = os.path.join('savedModels', f'cnn_transformer_model_final.pth')
torch.save(model.state_dict(), model_save_path)


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Processing Batch 1: 63it [00:45,  1.38it/s, Loss=7.35, Accuracy=0.0152]                         
Processing Batch 2: 63it [00:31,  2.02it/s, Loss=5.42, Accuracy=0.0669]                        
Processing Batch 3: 63it [00:26,  2.39it/s, Loss=4.67, Accuracy=0.0981]                        
Processing Batch 4: 63it [00:29,  2.12it/s, Loss=4.38, Accuracy=0.113]                        
Processing Batch 5: 63it [00:28,  2.22it/s, Loss=4.19, Accuracy=0.125]                        
Processing Batch 6: 63it [02:52,  2.74s/it, Loss=4.07, Accuracy=0.133]                        
Processing Batch 7: 63it [02:14,  2.13s/it, Loss=3.95, Accuracy=0.142]                        
Processing Batch 8: 63it [00:28,  2.23it/s, Loss=3.89, Accuracy=0.146]                        
Processing Batch 9: 63it [00:29,  2.16it/s, Loss=3.82, Accuracy=0.151]                        
Processing Batch 10: 63it [00:28,  2.21it/s, Loss=3.73

KeyboardInterrupt: 

In [8]:
###### RL PART ######

model = ChessNet().to(device)
model.load_state_dict(torch.load('savedModels/cnn_transformer_model_epoch_60.pth'))


def board_to_input(board):
    board_planes = torch.zeros((8, 8, 12), dtype=torch.float32)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            plane = piece.piece_type - 1
            if piece.color == chess.BLACK:
                plane += 6
            row, col = divmod(square, 8)
            board_planes[row, col, plane] = 1
    board_input = board_planes.unsqueeze(0).permute(0, 3, 1, 2)
    return board_input

def move_to_index(move):
    from_square = move.from_square
    to_square = move.to_square
    return from_square * 64 + to_square

# Function to select a move given the current board position. Not currently used in the training loop whilst trying to fix some stuff
def select_move(model, board):
    state = board_to_input(board)
    state = state.unsqueeze(0) 

    
    logits = model(state) 
    probabilities = torch.softmax(logits, dim=1)

    legal_moves = list(board.legal_moves)
    legal_indices = [move_to_index(move) for move in legal_moves]

    mask = torch.zeros_like(probabilities)
    mask[:, legal_indices] = 1  # Mask out illegal moves

    legal_probabilities = probabilities * mask  # Apply the mask
    legal_probabilities = legal_probabilities / legal_probabilities.sum(dim=1, keepdim=True)  # Re-normalize

    m = Categorical(legal_probabilities)
    # Sample from moves based on the distribution of probabilities.
    # In the early stages of training, the bot will explore many moves that it does not particularly prefer, 
    # but as the training progresses, the bot will start to get better, the distribution of probabilities of moves will more greatly favor the better moves, 
    # leading to less exploration.
    move_idx = m.sample() 

    selected_move = legal_moves[move_idx.item()]
    return selected_move, m.log_prob(move_idx)

# Function to simulate a batch of games of self-play
def play_batch_games(model, batch_size):
    boards = [chess.Board() for _ in range(batch_size)]
    log_probs = [[] for _ in range(batch_size)]
    rewards = [[] for _ in range(batch_size)]
    previous_material_balances = [material_balance(board) for board in boards]
    checkmate_count = 0
    while any(not board.is_game_over() for board in boards):
        active_indices = [i for i, board in enumerate(boards) if not board.is_game_over()]
        states = [board_to_input(boards[i]) for i in active_indices]
        states = torch.cat(states).to(device)

        
        logits = model(states) 
        probabilities = torch.softmax(logits, dim=1)

        for idx, i in enumerate(active_indices):
            legal_moves = list(boards[i].legal_moves)
            if len(legal_moves) == 0:
                continue  # Skip if no legal moves are available

            legal_indices = [move_to_index(move) for move in legal_moves]
            mask = torch.zeros_like(probabilities[idx])
            mask[legal_indices] = 1
            legal_probabilities = probabilities[idx] * mask

            if legal_probabilities.sum() == 0:
                print("All legal probabilities are zero")
                legal_probabilities = mask  # fallback to uniform distribution over legal moves

            legal_probabilities = legal_probabilities / legal_probabilities.sum(dim=0, keepdim=True)

            m = Categorical(legal_probabilities)
            move_idx = m.sample()

            log_prob = m.log_prob(move_idx) 
            log_probs[i].append(log_prob)

            move_idx_in_legal_moves = legal_indices.index(move_idx.item()) if move_idx.item() in legal_indices else None

            if move_idx_in_legal_moves is None:
                continue  # Safeguard against out-of-bound indices

            selected_move = legal_moves[move_idx_in_legal_moves]
            boards[i].push(selected_move)

            if boards[i].is_checkmate():
                checkmate_count += 1  # Increment checkmate counter
                #Greatly reward the bot for checkmating the opponent
                rewards[i].append(5)
                break

            # If too many moves, break
            if len(rewards[i]) > 200:
                boards[i].push(chess.Move.null())
                # Penalize the bot for taking too many moves
                rewards[i][-1] -= 0.5
                break

            current_material_balance = material_balance(boards[i])
            reward = current_material_balance - previous_material_balances[i]

            if boards[i].turn == chess.WHITE:  # Bot just played as black
                rewards[i].append(-reward)  # Negative reward if bot is black (after white's move)
            else:  # Bot just played as white
                rewards[i].append(reward)  # Positive reward if bot is white (after black's move)

            previous_material_balances[i] = current_material_balance

    for i in range(batch_size):
        result = boards[i].result()
        
        for j in range(len(rewards[i])):
            rewards[i][j] -= 0.01 # Penalize each move to try make the bot not do many unnecessary moves

        if result == '1-0':  # White wins
            if len(rewards[i]) % 2 == 1:  
                rewards[i][-1] += 1  
                rewards[i][-2] -= 1 

        elif result == '0-1':  # Black wins
            if len(rewards[i]) % 2 == 1:  
                rewards[i][-1] += 1  
                rewards[i][-2] -= 1  

        elif result == '1/2-1/2':  # Draw
            rewards[i][-1] += 0.5  
            if len(rewards[i]) > 1:
                rewards[i][-2] += 0.5  

    del states, logits, probabilities
    torch.cuda.empty_cache()

    return log_probs, rewards, checkmate_count



def material_balance(board):
    white_material = 0
    black_material = 0
    
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is not None:
            value = piece_value(piece)
            if piece.color == chess.WHITE:
                white_material += value
            else:
                black_material += value

    return white_material - black_material  # Positive if white has more material

# Function to assign value to pieces TODO: fix right values since these are maybe not good
def piece_value(piece):
    if piece is None:
        return 0
    elif piece.piece_type == chess.PAWN:
        return 0.1
    elif piece.piece_type in [chess.KNIGHT, chess.BISHOP]:
        return 0.3
    elif piece.piece_type == chess.ROOK:
        return 0.5
    elif piece.piece_type == chess.QUEEN:
        return 0.9
    return 0

def update_policy_batch(log_probs_batch, rewards_batch, optimizer, gamma=0.99):
    policy_loss = 0

    for log_probs, rewards in zip(log_probs_batch, rewards_batch):
        R = 0
        returns = []
        for r in rewards[::-1]:
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns, dtype=torch.float32, requires_grad=False).to(device) # requires_grad=False since having problems with torch.no_grad during self-play. Trying to fix it
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)  # Normalize returns

        for log_prob, R in zip(log_probs, returns):
            policy_loss += -log_prob * R

    return policy_loss

def train_self_play_batch(model, optimizer, num_episodes=1000, batch_size=4, accumulation_steps=4):
    model.train()
    scaler = GradScaler()  # Initialize the gradient scaler for mixed precision
    total_checkmates = 0
    optimizer.zero_grad() 

    for episode in tqdm(range(num_episodes), desc="Training Progress"):
        episode_loss = 0

        for _ in range(accumulation_steps):
            log_probs_batch, rewards_batch, checkmate_count = play_batch_games(model, batch_size)
            
            # Ensure autocast is used only during forward pass
            with autocast(device_type="cuda"):  # Enable mixed precision during policy update
                policy_loss = update_policy_batch(log_probs_batch, rewards_batch, optimizer)

                # Accumulate the loss
                episode_loss += policy_loss / accumulation_steps

        # Check that episode_loss requires gradients
       # assert episode_loss.requires_grad, "episode_loss does not require gradients."

        # Backward pass with scaled gradients after accumulation
        scaler.scale(episode_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        total_checkmates += checkmate_count  #Count checkmates as metric for progress in early stages

        if episode % 10 == 0:
            avg_reward = sum(sum(rewards) for rewards in rewards_batch) / batch_size
            print(f"Episode {episode} complete - Average Reward: {avg_reward:.2f}, Checkmates: {total_checkmates / (episode + 1) * batch_size:.2f}%")
            total_checkmates = 0

        if episode % 100 == 0:
            model_save_path = os.path.join('savedModels', f'cnn_transformer_model_rl_{episode}.pth')
            torch.save(model.state_dict(), model_save_path)

optimizer = optim.Adam(model.parameters(), lr=0.001)

train_self_play_batch(model, optimizer, num_episodes=1000, batch_size=16, accumulation_steps=16)

  model.load_state_dict(torch.load('savedModels/cnn_transformer_model_epoch_60.pth'))
Training Progress:   0%|          | 0/1000 [00:17<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 14.38 GiB is allocated by PyTorch, and 10.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assuming your ChessNet model is defined elsewhere
model = ChessNet().to(device)
model.load_state_dict(torch.load('savedModels/cnn_transformer_model_epoch_10.pth'))

scaler = GradScaler()

# Functions to convert board states and moves
def board_to_input(board):
    board_planes = torch.zeros((8, 8, 12), dtype=torch.float32)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            plane = piece.piece_type - 1
            if piece.color == chess.BLACK:
                plane += 6
            row, col = divmod(square, 8)
            board_planes[row, col, plane] = 1
    board_input = board_planes.unsqueeze(0).permute(0, 3, 1, 2)
    return board_input.to(device)

def move_to_index(move):
    from_square = move.from_square
    to_square = move.to_square
    return from_square * 64 + to_square

def piece_value(piece):
    if piece is None:
        return 0
    elif piece.piece_type == chess.PAWN:
        return 1
    elif piece.piece_type in [chess.KNIGHT, chess.BISHOP]:
        return 3
    elif piece.piece_type == chess.ROOK:
        return 5
    elif piece.piece_type == chess.QUEEN:
        return 9
    return 0

def material_balance(board):
    white_material = 0
    black_material = 0
    
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is not None:
            value = piece_value(piece)
            if piece.color == chess.WHITE:
                white_material += value
            else:
                black_material += value

    return white_material - black_material  # Positive if white has more material

def center_control(board):
    central_squares = [chess.D4, chess.E4, chess.D5, chess.E5]
    control_score = 0

    for square in central_squares:
        attackers = board.attackers(chess.WHITE, square)
        defenders = board.attackers(chess.BLACK, square)

        if len(attackers) > len(defenders):
            control_score += 0.1
        elif len(attackers) < len(defenders):
            control_score -= 0.1

    return control_score

def select_move(model, board):
    state = board_to_input(board)
    state = state.unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        logits = model(state)
    probabilities = torch.softmax(logits, dim=1)

    legal_moves = list(board.legal_moves)
    legal_indices = [move_to_index(move) for move in legal_moves]

    mask = torch.zeros_like(probabilities)
    mask[:, legal_indices] = 1  # Mask out illegal moves

    legal_probabilities = probabilities * mask  # Apply the mask
    legal_probabilities = legal_probabilities / legal_probabilities.sum(dim=1, keepdim=True)  # Re-normalize

    m = Categorical(legal_probabilities)
    move_idx = m.sample()

    selected_move = legal_moves[move_idx.item()]
    return selected_move, m.log_prob(move_idx)

def dynamic_gamma(move_count):
    if move_count < 20:
        return 0.95  # Early game
    elif move_count < 40:
        return 0.98  # Mid game
    else:
        return 0.99  # End game

def play_batch_games(model, batch_size):
    boards = [chess.Board() for _ in range(batch_size)]
    log_probs = [[] for _ in range(batch_size)]
    rewards = [[] for _ in range(batch_size)]
    previous_material_balances = [material_balance(board) for board in boards]
    checkmate_count = 0
    move_count = 0

    while any(not board.is_game_over() for board in boards):
        active_indices = [i for i, board in enumerate(boards) if not board.is_game_over()]
        states = torch.cat([board_to_input(boards[i]) for i in active_indices])

        logits = model(states)
        probabilities = torch.softmax(logits, dim=1)

        for idx, i in enumerate(active_indices):
            legal_moves = list(boards[i].legal_moves)
            if len(legal_moves) == 0:
                continue  # Skip if no legal moves are available

            legal_indices = [move_to_index(move) for move in legal_moves]
            mask = torch.zeros_like(probabilities[idx])
            mask[legal_indices] = 1
            legal_probabilities = probabilities[idx] * mask

            if legal_probabilities.sum() == 0:
                legal_probabilities = mask  # fallback to uniform distribution over legal moves

            legal_probabilities = legal_probabilities / legal_probabilities.sum(dim=0, keepdim=True)

            m = Categorical(legal_probabilities)
            move_idx = m.sample()

            log_prob = m.log_prob(move_idx)
            log_probs[i].append(log_prob)

            move_idx_in_legal_moves = legal_indices.index(move_idx.item()) if move_idx.item() in legal_indices else None

            if move_idx_in_legal_moves is None:
                continue  # Safeguard against out-of-bound indices

            selected_move = legal_moves[move_idx_in_legal_moves]
            boards[i].push(selected_move)

            if boards[i].is_checkmate():
                checkmate_count += 1  # Increment checkmate counter
                break

            current_material_balance = material_balance(boards[i])
            current_control_score = center_control(boards[i])
            reward = current_material_balance - previous_material_balances[i] + current_control_score

            if boards[i].turn == chess.WHITE:  # Bot just played as black
                rewards[i].append(-reward)  # Negative reward if bot is black (after white's move)
            else:  # Bot just played as white
                rewards[i].append(reward)  # Positive reward if bot is white (after black's move)

            previous_material_balances[i] = current_material_balance
            move_count += 1

    for i in range(batch_size):
        result = boards[i].result()

        for j in range(len(rewards[i])):
            rewards[i][j] -= 0.01  # Penalize each move to try make the bot not do many unnecessary moves

        if result == '1-0':  # White wins
            if len(rewards[i]) % 2 == 1:
                rewards[i][-1] += 1
                rewards[i][-2] -= 1

        elif result == '0-1':  # Black wins
            if len(rewards[i]) % 2 == 1:
                rewards[i][-1] += 1
                rewards[i][-2] -= 1

        elif result == '1/2-1/2':  # Draw
            rewards[i][-1] += 0.5
            if len(rewards[i]) > 1:
                rewards[i][-2] += 0.5

    return log_probs, rewards, checkmate_count, move_count

def update_policy_batch(log_probs_batch, rewards_batch, optimizer, move_count):
    optimizer.zero_grad()
    gamma = dynamic_gamma(move_count)

    for log_probs, rewards in zip(log_probs_batch, rewards_batch):
        R = 0
        returns = []
        for r in rewards[::-1]:
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns).to(device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)  # Normalize returns

        for log_prob, R in zip(log_probs, returns):
            loss = -log_prob * R
            scaler.scale(loss).backward(retain_graph=True)

    scaler.step(optimizer)
    scaler.update()

def train_self_play_batch(model, optimizer, num_episodes=1000, batch_size=8, accumulate_grad_steps=4):
    model.train()
    total_checkmates = 0

    for episode in tqdm(range(num_episodes), desc="Training Progress"):
        log_probs_batch, rewards_batch, checkmate_count, move_count = play_batch_games(model, batch_size)
        update_policy_batch(log_probs_batch, rewards_batch, optimizer, move_count)

        total_checkmates += checkmate_count  # Count checkmates as metric for progress in early stages

        if episode % 10 == 0:
            avg_reward = sum(sum(rewards) for rewards in rewards_batch) / batch_size
            print(f"Episode {episode} complete - Average Reward: {avg_reward:.2f}, Checkmates: {total_checkmates/(episode+1)*batch_size:.2f}%")
            total_checkmates = 0

        if episode % 100 == 0:
            model_save_path = os.path.join('savedModels', f'cnn_transformer_model_rl_{episode}.pth')
            torch.save(model.state_dict(), model_save_path)
            

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Start training
train_self_play_batch(model, optimizer, num_episodes=1000, batch_size=8, accumulate_grad_steps=4)

  model.load_state_dict(torch.load('savedModels/cnn_transformer_model_epoch_10.pth'))
Training Progress:   0%|          | 0/1000 [04:32<?, ?it/s]


KeyboardInterrupt: 

In [None]:
model = ChessNet()
model.load_state_dict(torch.load('savedModels/cnn_transformer_model_epoch_1.pth'))
model.eval()

board = chess.Board()

def board_to_input(board):
    board_planes = torch.zeros((8, 8, 12), dtype=torch.float32)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            plane = piece.piece_type - 1
            if piece.color == chess.BLACK:
                plane += 6
            row, col = divmod(square, 8)
            board_planes[row, col, plane] = 1
    board_input = board_planes.unsqueeze(0).permute(0, 3, 1, 2) 
    return board_input

def predict_move(model, board):
    board_input = board_to_input(board)
    with torch.no_grad():
        output = model(board_input)
    
    move_scores = output.squeeze().sort(descending=True)
    move_indices = move_scores.indices.tolist()
    
    for move_index in move_indices:
        from_square = move_index // 64
        to_square = move_index % 64
        move = chess.Move(from_square, to_square)
        
        if move in board.legal_moves:
            return move
    
    # If no valid moves are found (which shouldn't happen), return None
    return None


while not board.is_game_over():
    display(board)  
    user_move = input("Your move: ")

    
    try:
        move = chess.Move.from_uci(user_move)
        if move in board.legal_moves:
            board.push(move)
        else:
            print("Invalid move. Try again.")
            continue
    except ValueError:
        print("Invalid format. Use UCI format (e.g., e2e4).")
        continue

    if board.is_game_over():
        break

    #Get the bots move
    bot_move = predict_move(model, board)
    if bot_move:
        board.push(bot_move)
        print(f"Bot's move: {bot_move}")
    else:
        print("Bot could not find a valid move.")
        break

print("Game over!")
print(f"Result: {board.result()}")
