# Implementation of AlphaZero for our game
This follows the ideas described in `idea.md`.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import trange

import random
import math

### OUR GAME

In [13]:
# compile with  g++ -shared -o engine.dll *.cpp -static -static-libgcc -static-libstdc++ -I. -O2 -Wall -DBUILDING_DLL
# for linux compile with g++ -fPIC -shared -o engine.dll *.cpp

import ctypes
import os

# This represents the return types of the functions
class ReturnTypes:
    """Integer status codes returned by the engine library's entry points.

    ``start_game`` compares the raw ``startGame`` result against ``OK`` and
    ``play_move`` returns the raw ``playMove`` code, so these values
    presumably mirror the C++ engine's own enum — confirm when it changes.
    """

    (
        OK,
        ERROR,
        GAME_OVER_DRAW,
        GAME_OVER_WHITE_WINS,
        GAME_OVER_BLACK_WINS,
        INVALID_ARGUMENT,
        INVALID_GAME_NOT_STARTED,
    ) = range(7)

class EngineDLL():
    """Thin ctypes wrapper around the compiled C++ game engine.

    No game handle is passed to any engine call, so the library appears to
    keep a single game in progress that every method reads or mutates.
    """

    def __init__(self, gamestring: str = "", dll_path: "str | None" = None):
        """Load the engine library and start a game.

        Args:
            gamestring: Game string forwarded to ``startGame``.
            dll_path: Path of the compiled engine library. Defaults to
                ``engine.dll`` in the current working directory. On Linux,
                pass the path of a library built with ``-fPIC -shared``;
                a Windows build cannot be loaded there.

        Raises:
            OSError: If the library cannot be loaded; the message includes
                the path that was tried.
        """
        # TODO: place engine.dll in content/ directory
        if dll_path is None:
            dll_path = os.path.join(os.getcwd(), "engine.dll")
        try:
            self.dll = ctypes.CDLL(dll_path)
        except OSError as err:
            # e.g. "invalid ELF header" when a Windows DLL is loaded on Linux.
            raise OSError(f"could not load engine library {dll_path!r}: {err}") from err
        self._setup_functions()
        self.start_game(gamestring)

    def _setup_functions(self):
        """Declare argument/return types of the exported engine functions."""
        # Set up game functions
        self.dll.startGame.argtypes = [ctypes.c_char_p]
        self.dll.startGame.restype = ctypes.c_int

        self.dll.playMove.argtypes = [ctypes.c_char_p]
        self.dll.playMove.restype = ctypes.c_int

        # With a c_char_p restype, ctypes returns ``bytes`` (or None for NULL).
        self.dll.validMoves.restype = ctypes.c_char_p

        self.dll.getBoard.restype = ctypes.c_char_p

        self.dll.undo.argtypes = [ctypes.c_int]

        self.dll.getTurn.restype = ctypes.c_int

        self.dll.oracleEval.restype = ctypes.c_float

    def current_player_turn(self) -> int:
        """Return the turn value reported by the engine's ``getTurn``."""
        return self.dll.getTurn()

    def valid_moves(self) -> str:
        """Alias of :meth:`get_valid_moves`."""
        return self.get_valid_moves()

    def play(self, move_string: str) -> None:
        """Play ``move_string``, discarding the engine's status code."""
        self.play_move(move_string)

    def start_game(self, game_string: str) -> bool:
        """Start a new game; return True if the engine reports ``ReturnTypes.OK``."""
        encoded_string = game_string.encode("utf-8")
        return self.dll.startGame(encoded_string) == ReturnTypes.OK

    def play_move(self, move_string: str) -> int:
        """Play ``move_string`` and return the engine's raw status code."""
        encoded_string = move_string.encode("utf-8")
        return self.dll.playMove(encoded_string)

    def get_valid_moves(self) -> str:
        """Return the engine's valid-moves string, or ``""`` if unavailable.

        A failing native call is deliberately treated as "no moves" (best
        effort). A NULL/empty result also yields ``""`` instead of ``None``.
        """
        try:
            result = self.dll.validMoves()
        except Exception:
            # Best-effort boundary around the native call.
            return ""
        return self._decode(result)

    def get_board(self) -> str:
        """Return the engine's board string, or ``""`` if it returned NULL."""
        return self._decode(self.dll.getBoard())

    @staticmethod
    def _decode(result) -> str:
        """Decode an engine string result; invalid UTF-8 falls back to ASCII with replacement."""
        if not result:
            return ""
        # c_char_p results already arrive as bytes; only raw pointers need string_at.
        raw_bytes = result if isinstance(result, bytes) else ctypes.string_at(result)
        try:
            return raw_bytes.decode("utf-8")
        except UnicodeDecodeError:
            return raw_bytes.decode("ascii", errors="replace")


In [3]:
## ENGINE

class Engine():
    """Adapter exposing the C++ engine through the AlphaZero game interface.

    NOTE(review): the engine library is stateful, so most methods ignore
    their ``state`` argument and act on the game currently loaded in the
    engine; ``state`` values are only the board strings from ``get_board``.
    """

    def __init__(self):
        # Handle to the C++ engine (its constructor also calls start_game("")).
        self.CPPInterface = EngineDLL()

    def newgame(self, arguments: "list[str] | str") -> None:
        """Start a new game in the engine.

        Args:
            arguments: Either a list of game-string tokens, joined with
                spaces, or an already-formed game string such as
                ``"Base+MLP"``.
        """
        if isinstance(arguments, str):
            # " ".join("Base+MLP") would produce "B a s e + M L P".
            game_string = arguments
        else:
            game_string = " ".join(arguments)
        self.CPPInterface.start_game(game_string)

    ##TODO: define a state and a board in a convenient way

    def get_initial_state(self):
        """Start a ``Base+MLP`` game and return the engine's board string."""
        self.newgame("Base+MLP")
        return self.CPPInterface.get_board()

    def get_valid_moves(self, state):
        """Return the valid moves of the engine's current position.

        ``state`` is accepted for interface compatibility but unused: the
        engine only knows its own current position.
        """
        return self.CPPInterface.get_valid_moves()

    def get_next_state(self, state, action, player):
        """Play ``action`` in the engine and return the resulting board string.

        ``state`` and ``player`` are unused; the engine plays on its current
        position for whoever is to move.
        """
        self.CPPInterface.play_move(action)
        return self.CPPInterface.get_board()

    def check_win(self, state):  # TODO: check if this action gets to a winning state
        # NOTE(review): EngineDLL defines no check_win method, so this raises
        # AttributeError until one is added to the wrapper.
        return self.CPPInterface.check_win(state)

    def get_value_and_terminated(self, state, action=None):  # TODO: fix, maybe it's finished and we just lost
        """Return ``(value, is_terminal)`` for ``state``.

        Args:
            state: Game state (forwarded to ``check_win``/``get_valid_moves``).
            action: Last move played. Accepted because MCTS and self-play call
                this with ``(state, action)``; currently unused.

        Returns:
            ``(1, True)`` on a win, ``(0, True)`` when there are no valid
            moves, otherwise ``(0, False)``.
        """
        if self.check_win(state):
            return 1, True
        moves = self.get_valid_moves(state)
        if isinstance(moves, str):
            # The engine reports moves as a string; empty means no moves.
            no_moves = not moves
        else:
            no_moves = np.sum(moves) == 0
        if no_moves:
            return 0, True
        return 0, False

    def get_turn(self):
        """Return the engine's current turn value."""
        return self.CPPInterface.current_player_turn()

    def get_opponent(self, player):
        """Return the other player (players are encoded as +1 / -1)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` seen from the opponent's perspective."""
        return -value

    def change_perspective(self, state, player):
        """Multiply ``state`` by ``player``.

        NOTE(review): written for numeric board arrays; with the board
        strings returned by get_board, ``player == -1`` yields ``""``.
        """
        return state * player

    def get_encoded_state(self, state):
        """One-hot encode a numeric board into planes for -1, 0 and +1.

        Returns a float32 array of shape ``(3, *state.shape)``. For batched
        input of rank 3 the first two axes are swapped, giving
        ``(batch, 3, rows, cols)``.
        """
        encoded_state = np.stack(
            (state == -1, state == 0, state == 1)
        ).astype(np.float32)

        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

    def get_graph_representation(self, state):
        """
        Convert the adjacency matrix state representation to graph format.

        NOTE(review): relies on ``get_piece_type``/``get_piece_owner``, which
        this class does not define yet. The node features are still a
        ``[...]`` placeholder.

        Args:
            state: The adjacency matrix representation; row ``i`` lists the
                pieces connected to piece ``i`` (negative entries mean no
                connection).

        Returns:
            node_features: Features for each node/piece [num_nodes, feature_dim]
            edge_index: Graph connectivity in COO format [2, num_edges]
        """
        num_pieces = state.shape[0]  # Number of rows = number of pieces

        # Features could include: piece type, player ownership, etc.
        node_features = []
        for i in range(num_pieces):
            piece_type = self.get_piece_type(state, i)
            piece_owner = self.get_piece_owner(state, i)

            # Create one-hot encodings or other suitable features. TODO
            features = [...]  # Construct appropriate features
            node_features.append(features)

        # Construct edge list from adjacency information
        edge_src = []
        edge_dst = []
        for i in range(num_pieces):
            for j in range(state.shape[1]):
                connected_piece = state[i, j]
                if connected_piece >= 0:  # If there's a connection (not -1 or empty)
                    edge_src.append(i)
                    edge_dst.append(connected_piece)

        edge_index = [edge_src, edge_dst]

        return np.array(node_features), np.array(edge_index)

### Neural Network Model

In [5]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [6]:
# Neural network that outputs two things: a policy (a probability for each move, i.e. how likely that move from here leads to a win) and a value (an evaluation of the current state)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class HiveGNN(nn.Module):
    """Graph network returning per-node policy features and a state value.

    Args:
        node_features: Length of each input node-feature vector.
        hidden_dim: Width of the hidden node embeddings.
        device: Device the module is moved to; stored as ``self.device`` so
            callers can build input tensors on the same device.
        num_layers: Number of residual ``GCNConv`` layers after the input
            convolution (defaults to the previously hard-coded 9).

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(self, node_features, hidden_dim, device, num_layers=9):
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.device = device

        # Input graph convolution: node_features -> hidden_dim.
        self.conv1 = GCNConv(node_features, hidden_dim)

        # Residual graph-convolution trunk (the GNN analogue of ResBlocks).
        self.conv_layers = nn.ModuleList(
            [GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

        # Policy head: produces per-node features. The final move scores are
        # computed outside the model because the set of valid moves varies.
        self.policy_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # Value head: pooled graph embedding -> scalar squashed into [-1, 1].
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Tanh()
        )

        self.to(device)

    def forward(self, x, edge_index, batch=None):
        """Compute node policy features and the graph value.

        Args:
            x: Node features, presumably ``[num_nodes, node_features]``
                (torch_geometric convention).
            edge_index: Graph connectivity in COO format ``[2, num_edges]``.
            batch: Optional graph-membership vector for batched graphs; when
                omitted all nodes are treated as one graph.

        Returns:
            ``(node_policy_features, value)``: ``[num_nodes, hidden_dim // 2]``
            and ``[num_graphs, 1]`` (``[1, 1]`` without ``batch``).
        """
        h = F.relu(self.conv1(x, edge_index))

        for conv in self.conv_layers:
            # Residual connection keeps gradients flowing through deep stacks.
            h = h + F.relu(conv(h, edge_index))

        # Value: pool node embeddings into one vector per graph, then MLP.
        if batch is not None:
            pooled = global_mean_pool(h, batch)  # [batch_size, hidden_dim]
        else:
            pooled = h.mean(dim=0, keepdim=True)  # [1, hidden_dim]
        value = self.value_head(pooled)

        # Policy: per-node features, combined later with the valid moves.
        node_policy_features = self.policy_mlp(h)  # [num_nodes, hidden_dim//2]

        return node_policy_features, value


### Basic MCTS and AlphaZero

In [7]:
import torch
import numpy as np
import math

# A node of the Tree: modified to handle variable action spaces
class Node:
    """One node of the MCTS tree for games with a variable action space.

    Args:
        game: Game adapter; this class calls its ``get_next_state``,
            ``change_perspective`` and ``get_opponent_value``.
        args: Search hyper-parameters; ``args['C']`` weights exploration.
        state: Game state stored at this node (must support ``.copy()``).
        parent: Parent node, or ``None`` for the root.
        action_taken: Move that led from ``parent`` to this node.
        prior: Policy prior of ``action_taken``.
        visit_count: Initial visit count.
    """

    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game, self.args, self.state = game, args, state
        self.parent = parent
        # Move (any hashable object) that produced this node from its parent.
        self.action_taken = action_taken
        self.prior = prior

        self.children = []
        # Every action offered to expand(), including zero-probability ones.
        self.valid_actions = []

        self.visit_count = visit_count
        self.value_sum = 0

    # -------------------------------------- SELECT PHASE -------------------------------------------------------

    def is_fully_expanded(self):
        """A node counts as expanded as soon as it has at least one child."""
        return bool(self.children)

    def select(self):
        """Return the child with the highest UCB score (first one on ties), or None."""
        chosen, top_score = None, -np.inf
        for candidate in self.children:
            score = self.get_ucb(candidate)
            if score > top_score:
                chosen, top_score = candidate, score
        return chosen

    def get_ucb(self, child):
        """UCB score of ``child``.

        The exploitation term maps the child's mean value from [-1, 1] to
        [0, 1] and flips it (the child stores values from the opponent's
        side). The exploration term scales the prior by
        ``sqrt(N_parent) / (N_child + 1)``.
        """
        if child.visit_count:
            mean_value = child.value_sum / child.visit_count
            exploit = 1 - (mean_value + 1) / 2
        else:
            exploit = 0
        explore = self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior
        return exploit + explore

    # -------------------------------------- EXPANSION PHASE -------------------------------------------------------

    def expand(self, action_probs):
        """Create one child per action with a positive prior.

        Args:
            action_probs: Mapping ``{action: probability}``.

        Returns:
            The first child created, or None if no action had a positive prior.
        """
        self.valid_actions = list(action_probs)

        for action, prob in action_probs.items():
            if not prob > 0:
                continue
            next_state = self.game.get_next_state(self.state.copy(), action, 1)
            next_state = self.game.change_perspective(next_state, player=-1)
            self.children.append(Node(self.game, self.args, next_state, self, action, prob))

        if not self.children:
            return None
        return self.children[0]

    # -------------------------------------- BACKPROPAGATION PHASE -------------------------------------------------------

    def backpropagate(self, value):
        """Add ``value`` to this node, then walk up the parents, flipping its sign at each level."""
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            value = node.game.get_opponent_value(value)
            node = node.parent


# MCTS class, modified to work with graph neural networks
class MCTS:
    """Monte Carlo Tree Search guided by a graph neural network.

    Args:
        game: Game adapter providing ``get_graph_representation``,
            ``get_valid_moves``, ``get_value_and_terminated`` and
            ``get_opponent_value``. ``compute_action_probabilities`` also
            needs ``get_adjacent_pieces`` and ``get_piece_type``.
        args: Hyper-parameters: ``num_searches``, ``dirichlet_alpha``,
            ``dirichlet_epsilon``, optional ``dirichlet_turn``, and ``C``
            (read by ``Node``).
        model: Network called as ``model(node_features, edge_index)`` that
            returns ``(node_embeddings, value)`` and has a ``device`` attribute.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model
        # Scorer for compute_action_probabilities. It is built lazily once the
        # embedding size is known, then reused so priors are consistent.
        self._move_scorer = None

    def _evaluate(self, state):
        """Run the model on ``state``; return ``(action_probs, value_tensor)``."""
        node_features, edge_index = self.game.get_graph_representation(state)
        node_features = torch.tensor(node_features, device=self.model.device).float()
        edge_index = torch.tensor(edge_index, device=self.model.device).long()

        node_embeddings, value = self.model(node_features, edge_index)

        valid_moves = self.game.get_valid_moves(state)
        action_probs = self.compute_action_probabilities(node_embeddings, valid_moves, state)
        return action_probs, value

    @torch.no_grad()
    def search(self, state, turn=1):
        """Run ``args['num_searches']`` simulations from ``state``.

        Args:
            state: Root state, seen from the player to move.
            turn: Move number. Dirichlet noise is mixed into the root priors
                while ``turn <= args['dirichlet_turn']``. If that key is
                absent, noise is always added (standard AlphaZero);
                ``dirichlet_epsilon = 0`` disables its effect.

        Returns:
            Dict mapping each root child's action to its share of the root
            children's visits. If no simulation ran, the distribution is
            uniform; it is empty if the root has no children.
        """
        root = Node(self.game, self.args, state, visit_count=1)

        # The root value estimate is not needed; only its priors are.
        action_probs, _ = self._evaluate(state)

        if action_probs and turn <= self.args.get('dirichlet_turn', math.inf):
            epsilon = self.args['dirichlet_epsilon']
            noise = np.random.dirichlet([self.args['dirichlet_alpha']] * len(action_probs))
            action_probs = {
                action: (1 - epsilon) * prob + epsilon * eta
                for (action, prob), eta in zip(action_probs.items(), noise)
            }
            total = sum(action_probs.values())
            action_probs = {a: p / total for a, p in action_probs.items()}

        root.expand(action_probs)

        for _ in range(self.args['num_searches']):
            node = root

            # Selection: descend to a leaf.
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            # The terminal value belongs to the player who just moved; the leaf
            # stores values from the opposite side.
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                action_probs, value_tensor = self._evaluate(node.state)
                value = value_tensor.item()
                node.expand(action_probs)

            node.backpropagate(value)

        visit_counts = {child.action_taken: child.visit_count for child in root.children}
        total_visits = sum(visit_counts.values())
        if total_visits == 0:
            # No simulations ran (e.g. num_searches == 0): avoid dividing by zero.
            return {action: 1 / len(visit_counts) for action in visit_counts}
        return {action: count / total_visits for action, count in visit_counts.items()}

    def _get_move_scorer(self, embedding_dim, device):
        """Return the cached MLP scoring a concatenated (source, target) embedding."""
        scorer = getattr(self, "_move_scorer", None)
        if scorer is None or scorer[0].in_features != 2 * embedding_dim:
            scorer = nn.Sequential(
                nn.Linear(embedding_dim * 2, 64),  # Concatenated source and target embeddings
                nn.ReLU(),
                nn.Linear(64, 1)
            )
            self._move_scorer = scorer
        return scorer.to(device)

    def _position_embedding(self, position, state, embeddings):
        """Average the embeddings of pieces adjacent to ``position`` (zeros if none)."""
        adjacent_pieces = self.game.get_adjacent_pieces(position, state)
        if adjacent_pieces:
            return np.mean([embeddings[p] for p in adjacent_pieces], axis=0)
        return np.zeros(embeddings.shape[1], dtype=embeddings.dtype)

    # TODO: this scorer was suggested by Claude and is not trained yet: its
    # parameters are not part of ``self.model``, so the resulting priors are
    # a fixed random projection of the embeddings.
    def compute_action_probabilities(self, node_embeddings, valid_moves, state):
        """Compute a prior probability for each valid move from node embeddings.

        A 2-tuple ``(source_piece_idx, destination)`` is a movement.
        ``destination`` is either a piece index (int) or a position passed to
        ``game.get_adjacent_pieces``. Any other move is a placement: its piece
        type is ``move[0]`` (tuples) or ``move`` itself, and ``move[1]`` is
        an optional position.

        NOTE(review): ``Engine.get_valid_moves`` returns a string; iterating
        it yields single characters, not parsed moves.

        Args:
            node_embeddings: Tensor or array of shape ``[num_nodes, embedding_dim]``.
            valid_moves: Iterable of hashable moves.
            state: Current game state, forwarded to the game helpers.

        Returns:
            Dict mapping each valid move to its softmax probability (empty
            when there are no valid moves).
        """
        move_list = list(valid_moves)
        if not move_list:
            return {}

        if isinstance(node_embeddings, torch.Tensor):
            device = node_embeddings.device
            # detach(): .numpy() refuses tensors that require grad.
            embeddings = node_embeddings.detach().cpu().numpy()
        else:
            device = torch.device("cpu")
            embeddings = np.asarray(node_embeddings)

        move_scorer = self._get_move_scorer(embeddings.shape[1], device)
        zeros = np.zeros(embeddings.shape[1], dtype=embeddings.dtype)

        move_logits = []
        with torch.no_grad():
            for move in move_list:
                if isinstance(move, tuple) and len(move) == 2:
                    # Movement: (source piece, destination piece or position).
                    source_idx, destination = move
                    if isinstance(destination, int):
                        target_embedding = embeddings[destination]
                    else:
                        target_embedding = self._position_embedding(destination, state, embeddings)
                    first, second = embeddings[source_idx], target_embedding
                else:
                    # Placement: average embedding of same-type pieces + position.
                    piece_type = move[0] if isinstance(move, tuple) else move
                    same_type_pieces = [
                        i for i in range(len(embeddings))
                        if self.game.get_piece_type(state, i) == piece_type
                    ]
                    if same_type_pieces:
                        first = np.mean([embeddings[i] for i in same_type_pieces], axis=0)
                    else:
                        first = zeros
                    # ``is not None`` so a falsy position such as 0 is still used.
                    pos = move[1] if isinstance(move, tuple) and len(move) > 1 else None
                    if pos is not None:
                        second = self._position_embedding(pos, state, embeddings)
                    else:
                        second = zeros

                combined = torch.tensor(np.concatenate([first, second]), device=device).float()
                move_logits.append(move_scorer(combined).item())

        logits = np.array(move_logits)
        probs = np.exp(logits - np.max(logits))  # Subtract max for numerical stability
        probs = probs / np.sum(probs)
        return {move: float(prob) for move, prob in zip(move_list, probs)}

In [8]:
class AlphaZero:
    """AlphaZero training loop: MCTS self-play followed by supervised updates.

    Args:
        model: Network being trained; must expose ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        game: Game adapter used for self-play.
        args: Hyper-parameters (``temperature``, ``batch_size``,
            ``num_iterations``, ``num_selfPlay_iterations``, ``num_epochs``,
            plus the keys read by MCTS).
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def _sample_action(self, action_probs):
        """Sample a move from MCTS probabilities after applying the temperature.

        Accepts the ``{action: probability}`` dict returned by ``MCTS.search``,
        returning the chosen action, or a flat probability array, returning
        the chosen index.
        """
        if isinstance(action_probs, dict):
            actions = list(action_probs)
            probs = np.array(list(action_probs.values()), dtype=np.float64)
        else:
            actions = None
            probs = np.asarray(action_probs, dtype=np.float64)

        weights = probs ** (1 / self.args['temperature'])
        # Renormalize: the temperature exponent breaks the sum-to-one property.
        index = np.random.choice(len(weights), p=weights / np.sum(weights))
        return index if actions is None else actions[index]

    def selfPlay(self):
        """Play one game against itself.

        Returns:
            List of ``(encoded_state, action_probs, outcome)`` samples, one per
            move. ``outcome`` is from the perspective of the player to move in
            that state.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state)

            memory.append((neutral_state, action_probs, player))

            action = self._sample_action(action_probs)

            state = self.game.get_next_state(state, action, player)

            value, is_terminal = self.game.get_value_and_terminated(state, action)

            if is_terminal:
                # ``value`` belongs to ``player``, who made the final move;
                # samples recorded for the other player get its negation.
                return_memory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    return_memory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return return_memory

            player = self.game.get_opponent(player)

    def train(self, memory):
        """Run one epoch of minibatch updates over ``memory`` (shuffled in place).

        NOTE(review): this still calls ``self.model(state)`` with a single
        tensor and stacks policy targets with ``np.array``. The graph model
        and the dict policies from the graph MCTS need a different batching
        scheme.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batch_start in range(0, len(memory), batch_size):
            # Plain slicing keeps every sample, including the last one.
            sample = memory[batch_start:batch_start + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training, checkpointing after each iteration.

        Checkpoints are named ``model_{iteration}_{GameClassName}.pt`` and
        ``optimizer_{iteration}_{GameClassName}.pt``. The class name is used
        instead of ``str(game)``, which by default contains "<...object at 0x...>".
        """
        game_name = type(self.game).__name__
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{game_name}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{game_name}.pt")

### Optimized and parallel (though not multi-threaded) version of MCTS and AlphaZero

In [None]:
#TODO

### Training

In [17]:
# NOTE(review): Engine() loads engine.dll from the current working directory.
# The "invalid ELF header" OSError shown in this cell's output typically means
# a Windows build was loaded on Linux; build the engine for Linux first.
game = Engine()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): HiveGNN.__init__ takes (node_features, hidden_dim, device), so
# this four-argument call (the ResNet(game, 9, 128, device) pattern from the
# play cell) raises TypeError. The first argument should be the node-feature
# length built in Engine.get_graph_representation (still a TODO there).
model = HiveGNN(game, 9, 128, device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# NOTE(review): no 'dirichlet_turn' entry — confirm MCTS.search does not require it.
args = {
    'C': 2,  # exploration weight in Node.get_ucb
    'num_searches': 600,  # MCTS simulations per move
    'num_iterations': 8,  # outer self-play + training rounds in AlphaZero.learn
    'num_selfPlay_iterations': 500,  # self-play games per round
    'num_parallel_games': 100,  # NOTE(review): not read by the classes above (parallel version is TODO)
    'num_epochs': 4,  # training passes over the collected memory per round
    'batch_size': 128,
    'temperature': 1.25,  # MCTS probabilities are raised to 1/temperature before sampling
    'dirichlet_epsilon': 0.25,  # weight of Dirichlet noise mixed into root priors
    'dirichlet_alpha': 0.3  # concentration parameter of that noise
}

alphaZero = AlphaZero(model, optimizer, game, args)
alphaZero.learn()

OSError: /content/engine.dll: invalid ELF header

### Test to see how it plays

To play against it we only need the model (ResNet), the weights file produced by training (`.pt`), and the standard MCTS algorithm.

In [None]:
game = ConnectFour()
player = 1

args = {
    'C': 2,
    'num_searches': 600,
    # MCTS.search in this notebook reads 'dirichlet_turn'. Epsilon is 0, so
    # root noise would have no effect anyway; turn it off explicitly.
    'dirichlet_turn': 0,
    'dirichlet_epsilon': 0.,
    'dirichlet_alpha': 0.3
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(game, 9, 128, device)
model.load_state_dict(torch.load("model_7_ConnectFour.pt", map_location=device))
model.eval()

mcts = MCTS(game, args, model)

state = game.get_initial_state()


def _best_action(probs):
    """Return the action with the highest MCTS probability.

    ``MCTS.search`` in this notebook returns a dict ``{action: probability}``.
    ``np.argmax`` on a dict never looks at its values, so dicts are handled
    explicitly; array outputs still use ``np.argmax``.
    """
    if isinstance(probs, dict):
        return max(probs, key=probs.get)
    return np.argmax(probs)


while True:
    print(state)

    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("valid_moves", [i for i in range(game.action_size) if valid_moves[i] == 1])
        try:
            action = int(input(f"{player}:"))
        except ValueError:
            print("action not valid")
            continue

        # Reject out-of-range input too: negative indices would silently wrap.
        if not 0 <= action < game.action_size or valid_moves[action] == 0:
            print("action not valid")
            continue

    else:
        neutral_state = game.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = _best_action(mcts_probs)

    state = game.get_next_state(state, action, player)

    value, is_terminal = game.get_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        if value == 1:
            print(player, "won")
        else:
            print("draw")
        break

    player = game.get_opponent(player)