**Imports**

In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
from tqdm.notebook import trange
import random
import math
import chess
import chess.engine
import matplotlib as plt 
from tqdm import tqdm
import torch.multiprocessing as mp

In [22]:
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

GPU: NVIDIA GeForce RTX 4080 Laptop GPU is available.


**Chess Game**

In [23]:
class ChessGame:
    def __init__(self, device):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.board = chess.Board()
        self.action_size = 24710  # Max possible moves incl. promotions

    def get_initial_state(self):
        return self.board.fen()  # FEN is a standard notation for representing the current state of the chess board
    
    def get_next_state(self, state, action, player):
        """Update the board with the chosen action.""" 
        board = chess.Board(state)
        move = chess.Move.from_uci(action)
        if move in board.legal_moves:
            board.push(move)
        else:
            raise ValueError(f"Illegal move: {action}")
        return board.fen()
    
    def get_valid_moves(self, state):
        """Get a binary mask of valid moves.""" 
        board = chess.Board(state)
        legal_moves = list(board.legal_moves)
        valid_moves = np.zeros(self.action_size, dtype=np.uint8)
        for move in legal_moves:
            move_idx = self.move_to_index(move)
            valid_moves[move_idx] = 1
        return valid_moves
    
    def get_opponent(self, player):
        """Toggle between players (1 for white, -1 for black)."""
        return -player
    
    def get_opponent_value(self, value):
        return -value

    def check_win(self, state, action):
        board = chess.Board(state)
        return board.is_game_over()
    
    def get_value_and_terminated(self, state, action):
        board = chess.Board(state)
        if board.is_checkmate():
            return 1, True
        elif board.is_stalemate() or board.is_insufficient_material() or board.is_fifty_moves() or board.is_fivefold_repetition():
            return 0, True
        return 0, False
    
    def get_encoded_state(self, state):
        """Encode the board state into a tensor format for input to the neural network."""
        board = chess.Board(state)
        board_tensor = np.zeros((13, 8, 8), dtype=np.float32)  # Shape [13, 8, 8]
        piece_map = board.piece_map()
        
        for square, piece in piece_map.items():
            piece_type = piece.piece_type - 1 if piece.color else piece.piece_type + 5 
            board_tensor[piece_type, square // 8, square % 8] = 1 
            
        # Set empty squares to the 12th channel
        for row in range(8):
            for col in range(8):
                if board_tensor[:, row, col].sum() == 0:
                    board_tensor[12, row, col] = 1  # Mark as empty
        
        # Convert to torch tensor and move to device
        board_tensor = torch.FloatTensor(board_tensor).to(device)
    
        return board_tensor

    
    def move_to_index(self, move):
        """Convert a move to a unique index for the action space.""" 
        uci_move = move.uci()
        from_square = chess.SQUARE_NAMES.index(uci_move[:2])
        to_square = chess.SQUARE_NAMES.index(uci_move[2:4])
        promotion = move.promotion or 0
        return from_square * 64 + to_square + promotion * 64 * 64

    def index_to_move(self, index):
        """Convert an index back to a move.""" 
        promotion = index // (64 * 64)
        index %= (64 * 64)
        from_square = index // 64
        to_square = index % 64
        move = chess.Move(from_square, to_square, promotion)
        return move.uci()
    
    def change_perspective(self, state, player):
        """Change the perspective of the board state (if necessary)."""
        # This function may need to be defined based on your architecture and how you intend to handle player perspectives.
        return state  # No need to modify state representation for this method


**ResNet**

In [24]:
class ResNet(nn.Module):
    def __init__(self, game, num_resBlock, num_hidden):  # num_hidden is the hidden size of conv blocks
        super().__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Updated to accept 13 input channels
        self.startBlock = nn.Sequential(
            nn.Conv2d(13, num_hidden, kernel_size=3, padding=1),  # 6 white, 6 black, 1 empty = 13
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for _ in range(num_resBlock)]  # Using num_hidden for ResBlock
        )

        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, game.action_size),  # Assuming output size is compatible
        )

        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * 8 * 8, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.to(self.device)  # Move input to the appropriate device (GPU/CPU)
        x = self.startBlock(x)  # Pass through the initial block
        for res_block in self.backBone:  # Pass through each ResBlock
            x = res_block(x)
        policy = self.policyHead(x)  # Get policy predictions
        value = self.valueHead(x)  # Get value predictions
        return policy, value

class ResBlock(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):  # x is input
        residual = x  # Save the input for the skip connection
        x = F.relu(self.bn1(self.conv1(x)))  # First convolution + batch norm + ReLU
        x = self.bn2(self.conv2(x))  # Second convolution + batch norm
        x += residual  # Add the input back for the skip connection
        x = F.relu(x)  # ReLU activation after adding residual
        return x


**MCTS**

In [25]:
class Node:
    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game  # ChessGame object
        self.args = args
        self.state = state  # FEN string or any other chess state representation
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior
        
        self.children = []
        
        self.visit_count = visit_count
        self.value_sum = 0
        
    def is_fully_expanded(self):
        return len(self.children) > 0
    
    def select(self):
        """Select the best child node based on UCB score."""
        best_child = None
        best_ucb = -np.inf
        
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
                
        return best_child
    
    def get_ucb(self, child):
        """Calculate UCB for a child node."""
        if child.visit_count == 0:
            q_value = 0
        else:
            # Scale value to [0, 1] from [-1, 1]
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
    
    def expand(self, policy):
        """Expand the current node by adding children for each valid move."""
        last_child = None  # Initialize to None to handle no valid move case
        for action, prob in enumerate(policy):
            if prob > 0:
                # Get the next state by applying the action (move)
                child_state = self.state  # Using FEN or other format
                child_state = self.game.get_next_state(child_state, self.game.index_to_move(action), 1)  # Apply action (chess move)
                child_state = self.game.change_perspective(child_state, player=-1)

                # Create a new child node and add it to the children list
                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)
                last_child = child  # Track the last valid child created

        if last_child is None:
            raise ValueError("No valid moves were found in the policy. Check if the policy is correctly generated.")

        return last_child  # Return last valid expanded child node
            
    def backpropagate(self, value):
        """Update the current node and propagate the result back up to the root."""
        self.value_sum += value
        self.visit_count += 1
        
        # Propagate the value to the parent, changing perspective (opponent's value)
        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


In [26]:
#lass MCTS:
#   def __init__(self, game, args, model):
#       self.game = game
#       self.args = args
#       self.model = model
#       
#   @torch.no_grad()
#   def search(self, state):
#       # Initialize the root node of the MCTS tree
#       root = Node(self.game, self.args, state, visit_count=1)
#       
#       # Get the policy and value from the model
#       policy, _ = self.model(
#       torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0).to(self.model.device)
#       )
#       
#       # Apply softmax to policy
#       policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
#       
#       # Add Dirichlet noise for exploration
#       policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
#           * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
#       
#       # Get valid moves and update policy
#       valid_moves = self.game.get_valid_moves(state)
#       policy *= valid_moves
#       if np.sum(policy)>0:
#           policy /= np.sum(policy)  # Normalize the policy
#       else:
#           policy = valid_moves / np.sum(valid_moves)
#       # Expand the root node with the computed policy
#       root.expand(policy)
#       
#       # Perform the search for a number of iterations
#       for search in range(self.args['num_searches']):
#           node = root
#           
#           # Traverse down the tree until an unexpanded node is found
#           while node.is_fully_expanded():
#               node = node.select()
#               
#           # Get the value and check if the node is terminal
#           value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
#           value = self.game.get_opponent_value(value)  # Convert value for the opponent
#           
#           # If the node is not terminal, expand further
#           if not is_terminal:
#               policy, value = self.model(
#                   torch.tensor(self.game.get_encoded_state(node.state), device=self.model.device).unsqueeze(0)
#               )
#               
#               policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
#               valid_moves = self.game.get_valid_moves(node.state)
#               
#               # Update policy based on valid moves
#               policy *= valid_moves
#               if np.sum(policy) > 0:
#                   policy /= np.sum(policy)  # Normalize the policy
#               else:
#                   policy = valid_moves / np.sum(valid_moves)
#               
#               value = value.item()  # Get the value as a scalar
#               
#               # Expand the node with the new policy
#               node.expand(policy)
#               
#           # Backpropagate the value up to the root node
#           node.backpropagate(value)
#           
#       # Collect visit counts from the root's children for action probabilities
#       action_probs = np.zeros(self.game.action_size)
#       for child in root.children:
#           action_probs[child.action_taken] = child.visit_count
#           
#       action_probs /= np.sum(action_probs)  # Normalize the action probabilities
#       return action_probs
#

In [27]:
class MCTS:
    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model
        
    @torch.no_grad()
    def search(self, state_batch):
        # We assume state_batch is a batch of game states
        batch_size = len(state_batch)
        action_probs_batch = np.zeros((batch_size, self.game.action_size))

        for i in range(batch_size):
            state = state_batch[i]
            root_node = Node(self.game, self.args, state, visit_count=1)

            # Get the policy and value from the model for the current state
            encoded_state = self.game.get_encoded_state(state).unsqueeze(0).to(self.model.device)
            policy, _ = self.model(encoded_state)

            # Apply softmax to the policy for the current state
            policy = torch.softmax(policy, axis=1).cpu().numpy().flatten()

            # Apply Dirichlet noise for exploration
            policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
                * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)

            # Get valid moves and update policy for the current game state
            valid_moves = self.game.get_valid_moves(state)
            policy *= valid_moves

            valid_sum = np.sum(policy)  # Check sum of the policy
            if valid_sum > 0:
                policy /= valid_sum  # Normalize the policy
            else:
                valid_moves_sum = np.sum(valid_moves)
                if valid_moves_sum > 0:
                    policy = valid_moves / valid_moves_sum
                else:
                    print(f"No valid moves found for game state {i}. Skipping expansion for this node.")
                    continue  # Skip this node if no valid moves are available

            root_node.expand(policy)

            # Perform MCTS search for the current game
            for search in range(self.args['num_searches']):
                node = root_node

                # Traverse down the tree until an unexpanded node is found
                while node.is_fully_expanded():
                    node = node.select()

                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                value = self.game.get_opponent_value(value)  # Convert value for the opponent

                if not is_terminal:
                    encoded_state = self.game.get_encoded_state(node.state).unsqueeze(0).to(self.model.device)
                    policy, value = self.model(encoded_state)
                    policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
                    valid_moves = self.game.get_valid_moves(node.state)

                    # Update policy based on valid moves
                    policy *= valid_moves
                    valid_sum = np.sum(policy)  # Check sum of the policy
                    if valid_sum > 0:
                        policy /= valid_sum  # Normalize the policy
                    else:
                        valid_moves_sum = np.sum(valid_moves)
                        if valid_moves_sum > 0:
                            policy = valid_moves / valid_moves_sum
                        else:
                            print(f"No valid moves found for the current node. Skipping expansion for this node.")
                            continue  # Skip this node if no valid moves are available

                    value = value.item()  # Get the value as a scalar
                    node.expand(policy)

                node.backpropagate(value)

            # Collect action probabilities for the current game
            for child in root_node.children:
                action_probs_batch[i][child.action_taken] = child.visit_count
        
            # Normalize action probabilities for the current game
            action_probs_batch[i] /= np.sum(action_probs_batch[i])

        return action_probs_batch


**AlphaZero Training loop**

In [28]:
#class AlphaZero:
#    def __init__(self, model, optimizer, game, args):
#        self.model = model
#        self.optimizer = optimizer
#        self.game = game
#        self.args = args
#        self.mcts = MCTS(game, args, model)
#        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Ensure GPU is set
#
#    def selfPlay(self):
#        memory = []
#        player = 1
#        state = self.game.get_initial_state()  # Ensure this returns a 13-channel state
#
#        while True:
#            neutral_state = self.game.change_perspective(state, player)  # This should also return a 13-channel state
#            action_probs = self.mcts.search(neutral_state)  # Ensure output is compatible
#
#            memory.append((self.game.get_encoded_state(neutral_state).to(self.device), action_probs, player))  # Move to GPU
#
#
#            # Use temperature for exploration
#            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
#            if temperature_action_probs.sum() > 0:
#                temperature_action_probs /= temperature_action_probs.sum()
#            else:
#                temperature_action_probs = np.ones_like(temperature_action_probs) / len(temperature_action_probs)  # Handle edge case
#
#            action = np.random.choice(self.game.action_size, p=temperature_action_probs)  # Use temperature-adjusted probabilities
#
#            uci_move = self.game.index_to_move(action)  # Convert action index to UCI move
#            state = self.game.get_next_state(state, uci_move, player)  # Ensure this state has 13 channels
#
#            value, is_terminal = self.game.get_value_and_terminated(state, action)  # Check termination
#
#            if is_terminal:
#                return_memory = []
#                for hist_neutral_state, hist_action_probs, hist_player in memory:
#                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
#                    return_memory.append((
#                        self.game.get_encoded_state(hist_neutral_state).to(self.device),  # Moved to GPU
#                        hist_action_probs,
#                        hist_outcome
#                    ))
#                return return_memory
#
#            # Switch player
#            player = self.game.get_opponent(player)
#
#    def train(self, memory):
#        random.shuffle(memory)
#        for batchIdx in tqdm(range(0, len(memory), self.args['batch_size']), desc="Training Batches"):
#            sample = memory[batchIdx:min(len(memory), batchIdx + self.args['batch_size'])]
#            
#            state, policy_targets, value_targets = zip(*sample)
#            state = torch.stack(state).to(self.device)
#            policy_targets = torch.tensor(policy_targets, dtype=torch.float32).to(self.device)
#            value_targets = torch.tensor(value_targets, dtype=torch.float32).view(-1, 1).to(self.device)
#
#            out_policy, out_value = self.model(state)
#            policy_loss = F.cross_entropy(out_policy, policy_targets)
#            value_loss = F.mse_loss(out_value, value_targets)
#            loss = policy_loss + value_loss
#            
#            self.optimizer.zero_grad()
#            loss.backward()
#            self.optimizer.step()
#
#            # Show loss for each batch in tqdm
#            tqdm.write(f"Batch Loss = {loss.item()}")
#
#            # Monitor GPU memory
#            print(f"Batch Loss = {loss.item()}")
#            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
#            print(f"GPU Memory Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
#
#    def learn(self):
#        for iteration in range(self.args['num_iterations']):
#            memory = []
#            
#            # Self-play phase
#            self.model.eval()
#            with torch.no_grad():  # No need to compute gradients during self-play
#                for _ in trange(self.args['num_selfPlay_iterations']):
#                    memory += self.selfPlay()
#                
#            # Training phase
#            self.model.train()
#            for epoch in trange(self.args['num_epochs']):
#                self.train(memory)
#            
#            # Save model and optimizer states after each iteration
#            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
#            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")
#

In [29]:
class AlphaZero:
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)  # Use the non-parallel MCTS class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Ensure GPU is set

    def selfPlay(self):
        """Run a single game."""
        memory_batch = []
        player = 1  # Start as white
        state = self.game.get_initial_state()  # Initial state for the game
        move_no = 1

        while True:
            # Change perspective for the player and get action probabilities
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search([neutral_state])[0]  # Run MCTS search for the current state

            memory_batch.append((
                self.game.get_encoded_state(neutral_state).to(self.device),
                action_probs,
                player
            ))

            # Use temperature for exploration
            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            
            # Normalize action probabilities
            temp_sum = np.sum(temperature_action_probs)
            if temp_sum > 0:
                temperature_action_probs /= temp_sum
            else:
                temperature_action_probs = np.ones_like(temperature_action_probs) / len(temperature_action_probs)

            # Choose action based on temperature-adjusted probabilities
            action = np.random.choice(self.game.action_size, p=temperature_action_probs)
            uci_move = self.game.index_to_move(action)

            value, terminated = self.game.get_value_and_terminated(state, uci_move)

            if terminated:
                print(f"Game over! Result: {'Draw' if value == 0 else 'Win'}")
                break

            # Convert the UCI move string into a chess.Move object
            move_object = chess.Move.from_uci(uci_move)

            # Print the move and the current player
            player_str = 'White' if player == 1 else 'Black'
            print(f"Player: {player_str}, Move {move_no}: {uci_move}")
            if player_str == "Black":
                move_no += 1

            # Ensure the board is in the correct state for the current player
            current_turn = self.game.board.turn  # True for White, False for Black
            if current_turn != (player == 1):
                print(f"Board turn mismatch detected. Correcting turn.")
                self.game.board.turn = player == 1  # Set board to the correct turn

            # Check if the move is valid
            if move_object in self.game.board.legal_moves:
                # Apply the move to the board
                self.game.board.push(move_object)  # Update the actual board with the move

                # Now update the state with the new board position
                state = self.game.board.fen()  # Update state to the FEN notation of the new board

                # Print the updated board position
                print(f"Board after move {uci_move}:\n{self.game.board}")
            else:
                print(f"Illegal move encountered: {uci_move}. Skipping this game.")
                
                # Fallback: Get a valid random move, ensuring it's from the current legal moves
                valid_moves = list(self.game.board.legal_moves)
                if valid_moves:
                    fallback_move = np.random.choice(valid_moves)  # Get a valid move directly from legal moves
                    print(f"Using fallback move: {fallback_move.uci()}")
                    
                    # Update the board with the fallback move
                    try:
                        self.game.board.push(fallback_move)  # Push the fallback move onto the board

                        # Update the state with the new board position
                        state = self.game.board.fen()  # Store the new FEN notation after fallback move

                        # Print fallback move and updated board
                        print(f"Board after fallback move {fallback_move.uci()}:\n{self.game.board}")
                    except ValueError as e:
                        print(f"Error applying fallback move: {fallback_move.uci()}. Error: {e}")
                        continue
                else:
                    print(f"No valid moves available. Game over.")
                    break  # No moves available, end the game

            # Check for terminal state
            is_terminal = self.game.get_value_and_terminated(state, action)[1]
            if is_terminal:
                break  # Exit the loop if the game has ended

            # Switch players
            player = self.game.get_opponent(player)

        return memory_batch

    def train(self, memory):
        """Training process remains the same as in the original class."""
        random.shuffle(memory)
        for batchIdx in tqdm(range(0, len(memory), self.args['batch_size']), desc="Training Batches"):
            sample = memory[batchIdx:min(len(memory), batchIdx + self.args['batch_size'])]
            
            state, policy_targets, value_targets = zip(*sample)
            state = torch.stack(state).to(self.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32).to(self.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32).view(-1, 1).to(self.device)

            out_policy, out_value = self.model(state)
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Show loss for each batch in tqdm
            tqdm.write(f"Batch Loss = {loss.item()}")

    def learn(self):
        """Learn using non-parallel self-play."""
        for iteration in range(self.args['num_iterations']):
            memory = []
            
            # Self-play phase
            self.model.eval()
            with torch.no_grad():  # No need to compute gradients during self-play
                for _ in trange(self.args['num_selfPlay_iterations']):
                    memory += self.selfPlay()
                
            # Training phase
            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)
            
            # Save model and optimizer states after each iteration
            torch.save(self.model.state_dict(), f"model.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer.pt")


**Base Training Here**

In [30]:
# Set the device for model training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Initialize the ChessGame instance
chess_game = ChessGame(device=device)

# Initialize the ResNet model for chess with specified parameters
model = ResNet(chess_game, num_resBlock=4, num_hidden=64).to(device)  # Ensure model is moved to GPU


# Initialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# Set the arguments for AlphaZero
args = {
    'C': 2,
    'num_searches': 20, #60
    'num_iterations': 5,
    'num_selfPlay_iterations': 10, #500
    'num_epochs': 15,
    'batch_size': 64,  # This will manage GPU memory by processing in batches
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

# Create an instance of AlphaZero for chess
alphaZero = AlphaZero(model, optimizer, chess_game, args)

# Start the learning process
alphaZero.learn()  # This now runs on the GPU with batching


cuda


  0%|          | 0/10 [00:00<?, ?it/s]

Player: White, Move 1: b2b4
Board after move b2b4:
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. P . . . . . .
. . . . . . . .
P . P P P P P P
R N B Q K B N R
Player: Black, Move 1: d7d5
Board after move d7d5:
r n b q k b n r
p p p . p p p p
. . . . . . . .
. . . p . . . .
. P . . . . . .
. . . . . . . .
P . P P P P P P
R N B Q K B N R
Player: White, Move 2: c1a3
Board after move c1a3:
r n b q k b n r
p p p . p p p p
. . . . . . . .
. . . p . . . .
. P . . . . . .
B . . . . . . .
P . P P P P P P
R N . Q K B N R
Player: Black, Move 2: d8d7
Board after move d8d7:
r n b . k b n r
p p p q p p p p
. . . . . . . .
. . . p . . . .
. P . . . . . .
B . . . . . . .
P . P P P P P P
R N . Q K B N R
Player: White, Move 3: f2f3
Board after move f2f3:
r n b . k b n r
p p p q p p p p
. . . . . . . .
. . . p . . . .
. P . . . . . .
B . . . . P . .
P . P P P . P P
R N . Q K B N R
Player: Black, Move 3: d7h3
Board after move d7h3:
r n b . k b n r
p p p . p p p p
. . . . . . . .
. . . 

  0%|          | 0/15 [00:00<?, ?it/s]

  policy_targets = torch.tensor(policy_targets, dtype=torch.float32).to(self.device)

[A                                                    
[A                                                            
Training Batches:  25%|██▌       | 1/4 [00:00<00:01,  1.99it/s]

Batch Loss = 11.201610565185547
Batch Loss = 11.57590389251709


[A
[A                                                            
Training Batches: 100%|██████████| 4/4 [00:01<00:00,  3.80it/s]


Batch Loss = 11.207752227783203
Batch Loss = 11.004985809326172



[A                                                    
[A                                                            

Batch Loss = 7.660636901855469
Batch Loss = 6.833587169647217



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.08it/s]


Batch Loss = 7.007371425628662
Batch Loss = 6.488181114196777



[A                                                    
[A                                                            

Batch Loss = 5.705410003662109
Batch Loss = 5.060520172119141



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.98it/s]


Batch Loss = 5.019509792327881
Batch Loss = 4.875611305236816



[A                                                    
[A                                                            

Batch Loss = 4.012418746948242
Batch Loss = 3.8280389308929443



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.20it/s]


Batch Loss = 3.956040382385254
Batch Loss = 3.78560733795166



[A                                                    
[A                                                            

Batch Loss = 3.1195924282073975
Batch Loss = 3.1602845191955566



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.75it/s]


Batch Loss = 2.9687066078186035
Batch Loss = 3.0294132232666016



[A                                                    
[A                                                            

Batch Loss = 2.623812198638916
Batch Loss = 2.356936454772949



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.05it/s]


Batch Loss = 2.4677648544311523
Batch Loss = 2.6329898834228516



[A                                                    
[A                                                            

Batch Loss = 2.1220955848693848
Batch Loss = 2.0018701553344727



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


Batch Loss = 2.102227210998535
Batch Loss = 2.2966532707214355



[A                                                    
[A                                                            

Batch Loss = 1.8198251724243164
Batch Loss = 1.7714412212371826



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.65it/s]


Batch Loss = 1.954291820526123
Batch Loss = 1.9623687267303467



[A                                                    
[A                                                            

Batch Loss = 1.74827241897583
Batch Loss = 1.572890281677246



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.36it/s]


Batch Loss = 1.6144468784332275
Batch Loss = 1.6488746404647827



[A                                                    
[A                                                            

Batch Loss = 1.5188157558441162
Batch Loss = 1.3365223407745361



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.09it/s]


Batch Loss = 1.585639476776123
Batch Loss = 1.4311294555664062



[A                                                    
[A                                                            

Batch Loss = 1.378335952758789
Batch Loss = 1.239251971244812



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.10it/s]


Batch Loss = 1.3580260276794434
Batch Loss = 1.4995083808898926



[A                                                    
[A                                                            

Batch Loss = 1.2194361686706543
Batch Loss = 1.1617616415023804



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.32it/s]


Batch Loss = 1.1695168018341064
Batch Loss = 1.2605892419815063



[A                                                    


Batch Loss = 1.092496395111084


[A                                                            
[A                                                            

Batch Loss = 1.0057857036590576
Batch Loss = 1.231766939163208



Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.04it/s]


Batch Loss = 1.1534160375595093



[A                                                    

Batch Loss = 0.9278549551963806



[A                                                            

Batch Loss = 1.0341001749038696



[A                                                            

Batch Loss = 0.9985952973365784



Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.67it/s]


Batch Loss = 1.177155613899231



[A                                                    

Batch Loss = 0.8361006379127502



[A                                                            

Batch Loss = 1.0015325546264648



[A                                                            
Training Batches: 100%|██████████| 4/4 [00:00<00:00,  6.11it/s]


Batch Loss = 1.0602346658706665
Batch Loss = 0.9773311614990234


  0%|          | 0/10 [00:00<?, ?it/s]

Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw


  0%|          | 0/15 [00:00<?, ?it/s]


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 10.97it/s]


Batch Loss = 2.117870330810547



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 14.83it/s]


Batch Loss = 2.2127621173858643



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 17.55it/s]


Batch Loss = 1.7784230709075928



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.73it/s]


Batch Loss = 1.8226325511932373



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.91it/s]


Batch Loss = 1.7345813512802124



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.19it/s]


Batch Loss = 1.722790241241455



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 18.23it/s]


Batch Loss = 1.6636521816253662



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 18.79it/s]

Batch Loss = 1.670383334159851




Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.84it/s]


Batch Loss = 1.6663542985916138



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.90it/s]


Batch Loss = 1.6235471963882446



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 22.07it/s]


Batch Loss = 1.6545372009277344



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.34it/s]


Batch Loss = 1.6458345651626587



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.85it/s]


Batch Loss = 1.6172279119491577



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.92it/s]


Batch Loss = 1.6256287097930908



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.27it/s]


Batch Loss = 1.6332358121871948


  0%|          | 0/10 [00:00<?, ?it/s]

Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw


  0%|          | 0/15 [00:00<?, ?it/s]


[A                                                    

Batch Loss = 1.5220309495925903


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.85it/s]

Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.26it/s]


Batch Loss = 1.384861707687378



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.83it/s]


Batch Loss = 1.3843724727630615



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.51it/s]


Batch Loss = 1.3613569736480713



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 18.89it/s]


Batch Loss = 1.334825873374939



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.67it/s]


Batch Loss = 1.3448957204818726



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.43it/s]


Batch Loss = 1.3464617729187012



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.97it/s]


Batch Loss = 1.330277442932129



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.94it/s]


Batch Loss = 1.3346933126449585



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 15.53it/s]


Batch Loss = 1.338127613067627



[A                                                    
Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batch Loss = 1.3328136205673218


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.35it/s]

Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.79it/s]


Batch Loss = 1.3258297443389893



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.58it/s]


Batch Loss = 1.3333193063735962



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.39it/s]


Batch Loss = 1.3320488929748535



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 22.10it/s]


Batch Loss = 1.3242775201797485


  0%|          | 0/10 [00:00<?, ?it/s]

Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw


  0%|          | 0/15 [00:00<?, ?it/s]


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 14.13it/s]


Batch Loss = 1.150601863861084



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 17.13it/s]


Batch Loss = 1.0609508752822876



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.18it/s]


Batch Loss = 1.0316537618637085



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.38it/s]


Batch Loss = 1.0240120887756348



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.45it/s]


Batch Loss = 1.0042712688446045



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.45it/s]


Batch Loss = 1.0059291124343872



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 15.82it/s]


Batch Loss = 1.0101697444915771



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.01it/s]


Batch Loss = 1.0041712522506714



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.56it/s]


Batch Loss = 1.0023339986801147



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.81it/s]


Batch Loss = 1.0040432214736938



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.67it/s]


Batch Loss = 1.006851077079773



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.12it/s]


Batch Loss = 0.9997209310531616



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.47it/s]


Batch Loss = 1.00164794921875



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.66it/s]


Batch Loss = 1.0049883127212524



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 21.35it/s]


Batch Loss = 1.0003920793533325


  0%|          | 0/10 [00:00<?, ?it/s]

Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw
Game over! Result: Draw


  0%|          | 0/15 [00:00<?, ?it/s]


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 18.08it/s]


Batch Loss = 0.949451208114624



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 17.20it/s]


Batch Loss = 0.939667820930481



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 17.60it/s]


Batch Loss = 0.9508942365646362



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 18.24it/s]


Batch Loss = 0.9503247737884521



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 17.41it/s]


Batch Loss = 0.9448715448379517



[A                                                    
Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batch Loss = 0.9414928555488586


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.08it/s]

Training Batches: 100%|██████████| 1/1 [00:00<00:00, 15.19it/s]


Batch Loss = 0.94450843334198



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 16.79it/s]


Batch Loss = 0.9482269883155823



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.30it/s]


Batch Loss = 0.9413946866989136



[A                                                    
Training Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batch Loss = 0.940932035446167


Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.07it/s]

Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.55it/s]


Batch Loss = 0.9445201754570007



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 19.67it/s]


Batch Loss = 0.9437317252159119



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 20.30it/s]


Batch Loss = 0.9420797228813171



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 18.23it/s]


Batch Loss = 0.9398247599601746



Training Batches: 100%|██████████| 1/1 [00:00<00:00, 17.45it/s]


Batch Loss = 0.9431111812591553


**Restart Training Here** (not working)

In [31]:
# Check if a GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Assuming the AlphaZero class, model, optimizer, and args are defined elsewhere
# model, optimizer, game, and args should be defined based on your code setup

# Create an instance of the AlphaZero class
alpha_zero = AlphaZero(model, optimizer, game, args)  # Replace with actual objects for model, optimizer, game, and args

# Load the model and optimizer states from the saved files
alpha_zero.model.load_state_dict(torch.load("model.pt", map_location=device))  # Load model to the correct device
alpha_zero.model.to(device)  # Move the model to GPU if available
alpha_zero.optimizer.load_state_dict(torch.load("optimizer.pt", map_location=device))  # Load optimizer to the correct device

# You can modify these parameters for the restart
num_iterations = 500  # Total number of iterations for training
num_selfPlay_iterations = 1000  # Number of self-play iterations per training cycle

# Resume training
for iteration in range(num_iterations):
    memory = []  # Memory to store game data for training

    # Self-play phase
    alpha_zero.model.eval()  # Set model to evaluation mode (no gradients needed during self-play)
    with torch.no_grad():
        for _ in trange(num_selfPlay_iterations, desc="Self-Play"):
            memory += alpha_zero.selfPlay()  # Collect data from self-play games

    # Training phase
    alpha_zero.model.train()  # Set model to training mode
    for epoch in trange(alpha_zero.args['num_epochs'], desc="Training"):
        alpha_zero.train(memory)  # Train the model using self-play data

    # Save the updated model and optimizer states after each iteration
    torch.save(alpha_zero.model.state_dict(), "model.pt")
    torch.save(alpha_zero.optimizer.state_dict(), "optimizer.pt")

    # If using a scheduler, you can also save its state
    # torch.save(alpha_zero.scheduler.state_dict(), "scheduler.pt")

print("Training resumed and completed.")


Using device: cuda


NameError: name 'game' is not defined