**Imports**

In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
from tqdm.notebook import trange
import random
import math
import chess
import chess.engine
import matplotlib as plt 
from tqdm import tqdm

In [14]:
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

GPU: NVIDIA GeForce RTX 4080 Laptop GPU is available.


**Chess Game**

In [2]:
import chess
import numpy as np

class ChessGame:
    def __init__(self):
        self.board = chess.Board()
        self.action_size = 1000000  # Max possible moves incl. promotions

    def get_initial_state(self):
        return self.board.fen()  # FEN is a standard notation for representing the current state of the chess board
    
    def get_next_state(self, state, action, player):
        """Update the board with the chosen action.""" 
        board = chess.Board(state)
        move = chess.Move.from_uci(action)
        board.push(move)
        return board.fen()
    
    def get_valid_moves(self, state):
        """Get a binary mask of valid moves.""" 
        board = chess.Board(state)
        legal_moves = list(board.legal_moves)
        valid_moves = np.zeros(self.action_size, dtype=np.uint8)
        for move in legal_moves:
            move_idx = self.move_to_index(move)
            valid_moves[move_idx] = 1
        return valid_moves
    
    def get_opponent(self, player):
        """Toggle between players (1 for white, -1 for black)."""
        return -player
    
    def get_opponent_value(self, value):
        return -value

    def check_win(self, state, action):
        board = chess.Board(state)
        return board.is_game_over()
    
    def get_value_and_terminated(self, state, action):
        board = chess.Board(state)
        if board.is_checkmate():
            return 1, True
        elif board.is_stalemate() or board.is_insufficient_material():
            return 0, True
        return 0, False
    
    def get_encoded_state(self, state):
        """Encode the board state into a tensor format for input to the neural network."""
        board = chess.Board(state)
        board_tensor = np.zeros((13, 8, 8), dtype=np.float32)  # Shape [13, 8, 8]
        piece_map = board.piece_map()
        
        for square, piece in piece_map.items():
            piece_type = piece.piece_type - 1 if piece.color else piece.piece_type + 5 
            board_tensor[piece_type, square // 8, square % 8] = 1 
            
        # Set empty squares to the 12th channel
        for row in range(8):
            for col in range(8):
                if board_tensor[:, row, col].sum() == 0:
                    board_tensor[12, row, col] = 1  # Mark as empty
        
        # Convert to torch tensor and move to device
        board_tensor = torch.FloatTensor(board_tensor).to(device)
    
        return board_tensor

    
    def move_to_index(self, move):
        """Convert a move to a unique index for the action space.""" 
        uci_move = move.uci()
        from_square = chess.SQUARE_NAMES.index(uci_move[:2])
        to_square = chess.SQUARE_NAMES.index(uci_move[2:4])
        promotion = move.promotion or 0
        return from_square * 64 + to_square + promotion * 64 * 64

    def index_to_move(self, index):
        """Convert an index back to a move.""" 
        promotion = index // (64 * 64)
        index %= (64 * 64)
        from_square = index // 64
        to_square = index % 64
        move = chess.Move(from_square, to_square, promotion)
        return move.uci()
    
    def change_perspective(self, state, player):
        """Change the perspective of the board state (if necessary)."""
        # This function may need to be defined based on your architecture and how you intend to handle player perspectives.
        return state  # No need to modify state representation for this method


**ResNet**

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResNet(nn.Module):
    def __init__(self, game, num_resBlock, num_hidden):  # num_hidden is the hidden size of conv blocks
        super().__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Updated to accept 13 input channels
        self.startBlock = nn.Sequential(
            nn.Conv2d(13, num_hidden, kernel_size=3, padding=1),  # 6 white, 6 black, 1 empty = 13
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for _ in range(num_resBlock)]  # Using num_hidden for ResBlock
        )

        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, game.action_size),  # Assuming output size is compatible
        )

        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * 8 * 8, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.to(self.device)  # Move input to the appropriate device (GPU/CPU)
        x = self.startBlock(x)  # Pass through the initial block
        for res_block in self.backBone:  # Pass through each ResBlock
            x = res_block(x)
        policy = self.policyHead(x)  # Get policy predictions
        value = self.valueHead(x)  # Get value predictions
        return policy, value

class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):  # x is input
        residual = x  # Save the input for the skip connection
        x = F.relu(self.bn1(self.conv1(x)))  # First convolution + batch norm + ReLU
        x = self.bn2(self.conv2(x))  # Second convolution + batch norm
        x += residual  # Add the input back for the skip connection
        x = F.relu(x)  # ReLU activation after adding residual
        return x


**MCTS**

In [4]:
class Node:
    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game  # ChessGame object
        self.args = args
        self.state = state  # FEN string or any other chess state representation
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        
        self.children = []
        
        self.visit_count = visit_count
        self.value_sum = 0
        
    def is_fully_expanded(self):
        return len(self.children) > 0
    
    def select(self):
        """Select the best child node based on UCB score."""
        best_child = None
        best_ucb = -np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
                
        return best_child
    
    def get_ucb(self, child):
        """Calculate UCB for a child node."""
        if child.visit_count == 0:
            q_value = 0
        else:
            # Scale value to [0, 1] from [-1, 1]
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
    
    def expand(self, policy):
        """Expand the current node by adding children for each valid move."""
        for action, prob in enumerate(policy):
            if prob > 0:
                # Get the next state by applying the action (move)
                child_state = self.state  # Using FEN or other format
                child_state = self.game.get_next_state(child_state, self.game.index_to_move(action), 1)  # Apply action (chess move)
                child_state = self.game.change_perspective(child_state, player=-1)

                # Create a new child node and add it to the children list
                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)
                
        return child  # Return last expanded child node (optional)
            
    def backpropagate(self, value):
        """Update the current node and propagate the result back up to the root."""
        self.value_sum += value
        self.visit_count += 1
        
        # Propagate the value to the parent, changing perspective (opponent's value)
        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


In [5]:
class MCTS:
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model
        
    @torch.no_grad()
    def search(self, state):
        # Initialize the root node of the MCTS tree
        root = Node(self.game, self.args, state, visit_count=1)
        
        # Get the policy and value from the model
        policy, _ = self.model(
        torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0).to(self.model.device)
        )
        
        # Apply softmax to policy
        policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
        
        # Add Dirichlet noise for exploration
        policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
            * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
        
        # Get valid moves and update policy
        valid_moves = self.game.get_valid_moves(state)
        policy *= valid_moves
        if np.sum(policy)>0:
            policy /= np.sum(policy)  # Normalize the policy
        else:
            policy = valid_moves / np.sum(valid_moves)
        # Expand the root node with the computed policy
        root.expand(policy)
        
        # Perform the search for a number of iterations
        for search in range(self.args['num_searches']):
            node = root
            
            # Traverse down the tree until an unexpanded node is found
            while node.is_fully_expanded():
                node = node.select()
                
            # Get the value and check if the node is terminal
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)  # Convert value for the opponent
            
            # If the node is not terminal, expand further
            if not is_terminal:
                policy, value = self.model(
                    torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
                )
                
                policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
                valid_moves = self.game.get_valid_moves(node.state)
                
                # Update policy based on valid moves
                policy *= valid_moves
                if np.sum(policy) > 0:
                    policy /= np.sum(policy)  # Normalize the policy
                else:
                    policy = valid_moves / np.sum(valid_moves)
                
                value = value.item()  # Get the value as a scalar
                
                # Expand the node with the new policy
                node.expand(policy)
                
            # Backpropagate the value up to the root node
            node.backpropagate(value)
            
        # Collect visit counts from the root's children for action probabilities
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
            
        action_probs /= np.sum(action_probs)  # Normalize the action probabilities
        return action_probs


**AlphaZero Training loop**

In [6]:
class AlphaZero:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Ensure GPU is set

    def selfPlay(self):
        memory = []
        player = 1
        state = self.game.get_initial_state()  # Ensure this returns a 13-channel state

        while True:
            neutral_state = self.game.change_perspective(state, player)  # This should also return a 13-channel state
            action_probs = self.mcts.search(neutral_state)  # Ensure output is compatible

            memory.append((self.game.get_encoded_state(neutral_state).to(self.device), action_probs, player))  # Moved to GPU

            # Use temperature for exploration
            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            if temperature_action_probs.sum() > 0:
                temperature_action_probs /= temperature_action_probs.sum()
            else:
                temperature_action_probs = np.ones_like(temperature_action_probs) / len(temperature_action_probs)  # Handle edge case

            action = np.random.choice(self.game.action_size, p=temperature_action_probs)  # Use temperature-adjusted probabilities

            uci_move = self.game.index_to_move(action)  # Convert action index to UCI move
            state = self.game.get_next_state(state, uci_move, player)  # Ensure this state has 13 channels

            value, is_terminal = self.game.get_value_and_terminated(state, action)  # Check termination

            if is_terminal:
                return_memory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    return_memory.append((
                        self.game.get_encoded_state(hist_neutral_state).to(self.device),  # Moved to GPU
                        hist_action_probs,
                        hist_outcome
                    ))
                return return_memory

            # Switch player
            player = self.game.get_opponent(player)

    def train(self, memory):
        random.shuffle(memory)
        for batchIdx in tqdm(range(0, len(memory), self.args['batch_size']), desc="Training Batches"):
            sample = memory[batchIdx:min(len(memory), batchIdx + self.args['batch_size'])]
            
            state, policy_targets, value_targets = zip(*sample)
            state = torch.stack(state).to(self.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32).to(self.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32).view(-1, 1).to(self.device)

            out_policy, out_value = self.model(state)
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Show loss for each batch in tqdm
            tqdm.write(f"Batch Loss = {loss.item()}")

            # Monitor GPU memory
            print(f"Batch Loss = {loss.item()}")
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
            print(f"GPU Memory Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

    def learn(self):
        for iteration in range(self.args['num_iterations']):
            memory = []
            
            # Self-play phase
            self.model.eval()
            with torch.no_grad():  # No need to compute gradients during self-play
                for _ in trange(self.args['num_selfPlay_iterations']):
                    memory += self.selfPlay()
                
            # Training phase
            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)
            
            # Save model and optimizer states after each iteration
            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")


In [15]:
# Initialize the ChessGame instance
chess_game = ChessGame()

# Set the device for model training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Initialize the ResNet model for chess with specified parameters
model = ResNet(chess_game, num_resBlock=4, num_hidden=64).to(device)  # Ensure model is moved to GPU


# Initialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# Set the arguments for AlphaZero
args = {
    'C': 2,
    'num_searches': 60,
    'num_iterations': 3,
    'num_selfPlay_iterations': 500,
    'num_parallel_games': 100,
    'num_epochs': 4,
    'batch_size': 64,  # This will manage GPU memory by processing in batches
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

# Create an instance of AlphaZero for chess
alphaZero = AlphaZero(model, optimizer, chess_game, args)

# Start the learning process
alphaZero.learn()  # This now runs on the GPU with batching


cuda


  0%|          | 0/500 [00:00<?, ?it/s]

  torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0).to(self.model.device)
  torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
