In [None]:
import torch
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from tqdm.notebook import trange
# Seed every RNG the notebook draws from so a full re-run is reproducible:
# torch (weight initialisation), numpy (np.random.choice in self-play) and
# Python's `random` (random.shuffle of the replay memory during training).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
def in_colab():
    """
    Detect whether the notebook is running inside Google Colab.

    Returns:
        bool: True when the `google.colab` package can be imported, False otherwise.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe for Colab)
    except ImportError:
        running_in_colab = False
    else:
        running_in_colab = True
    return running_in_colab

# Pick where checkpoints are stored: Google Drive when running on Colab
# (mounted first so the files survive the session), the working directory otherwise.
if not in_colab():
    base_path = '.'
else:
    from google.colab import drive
    drive.mount('/content/drive')
    base_path = '/content/drive/My Drive/Colab Notebooks'

# Network weights and optimizer state are saved into sibling directories.
model_path = os.path.join(base_path, 'alpha_gomoku/model')
optimizer_path = os.path.join(base_path, 'alpha_gomoku/optimizer')

os.makedirs(model_path, exist_ok=True)
os.makedirs(optimizer_path, exist_ok=True)

In [None]:
# Cell values on the board. PLAYER2 is the negation of PLAYER1, so switching
# perspective (Gomoku.change_perspective) is just a multiplication by the player.
EMPTY, PLAYER1, PLAYER2 = 0, 1, -1

In [None]:
class Gomoku:
    """
    Rules of Gomoku (five in a row) on a square board.

    Boards are 2D numpy arrays whose cells hold EMPTY, PLAYER1 or PLAYER2.
    Actions are flat cell indices `row * size + col` in the range [0, size^2).
    """

    def __init__(self, size=15, win_length=5):
        """
        Args:
            size (int): Side length of the square board.
            win_length (int): Number of consecutive stones needed to win (5 in Gomoku).
        """
        self.size = size
        self.win_length = win_length

    def get_initial_state(self):
        """
        Create an empty board.

        Returns:
            np.array: A (size, size) int8 array filled with EMPTY.
        """
        return np.full((self.size, self.size), EMPTY, dtype=np.int8)

    def get_next_state(self, state, action, player):
        """
        Place `player`'s stone on the cell addressed by `action`.

        Note: `state` is modified IN PLACE and also returned; copy it first if the
        previous position is still needed.

        Args:
            state (np.array): The current state of the game.
            action (int): The action to apply, an integer from 0 to size^2 - 1.
            player (int): The player making the move, PLAYER1 or PLAYER2.

        Raises:
            ValueError: The action is invalid because the cell is already occupied.

        Returns:
            np.array: The same array, now containing the new stone.
        """
        row, col = divmod(action, self.size)
        if state[row, col] != EMPTY:
            raise ValueError("Invalid action")
        state[row, col] = player
        return state

    def get_moves(self, state):
        """
        Get all the legal actions for the current state.

        Args:
            state (np.array): The current state of the game.

        Returns:
            np.array: A flat uint8 mask of length size^2; 1 marks an empty (legal) cell.
        """
        return (state.reshape(-1) == EMPTY).astype(np.uint8)

    def check_win(self, state, action):
        """
        Check whether the stone placed by `action` completes an unbroken line of
        at least `win_length` stones of the same player.

        Args:
            state (np.array): The current state of the game (after the action).
            action (int or None): Last action made; None means no move yet.

        Returns:
            bool: True if the action won the game, False otherwise.
        """
        if action is None:
            return False

        row, col = divmod(action, self.size)
        player = state[row, col]
        if player == EMPTY:
            return False

        # Vertical, horizontal and both diagonals; each line is scanned both ways.
        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1)):
            count = 1  # the stone at (row, col) itself
            for sign in (1, -1):
                r, c = row + sign * dr, col + sign * dc
                # Stop at the first off-board or non-matching cell: stones beyond a
                # gap or an opponent stone are not part of the line.
                while (
                    0 <= r < self.size
                    and 0 <= c < self.size
                    and state[r, c] == player
                ):
                    count += 1
                    r += sign * dr
                    c += sign * dc
            if count >= self.win_length:
                return True

        return False

    def get_value_and_terminated(self, state, action):
        """
        Returns the value of the state and whether the game is terminated.

        Args:
            state (np.array): The current state of the game.
            action (int): The last action made.

        Returns:
            (int, bool): (1, True) if the last mover won, (0, True) for a full
            board without a winner (draw), (0, False) otherwise.
        """
        if self.check_win(state, action):
            return 1, True
        if not np.any(state == EMPTY):
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """
        Returns the opponent of the given player.

        Args:
            player (int): The current player.

        Returns:
            int: The opponent.
        """
        return PLAYER1 if player == PLAYER2 else PLAYER2

    def get_opponent_value(self, value):
        """
        Convert a game value to the opponent's point of view (zero-sum game).

        Args:
            value (float): A value from one player's perspective (1 win, -1 loss, 0 draw).

        Returns:
            float: The same outcome seen by the opponent, i.e. `-value`.
        """
        return -value

    def change_perspective(self, state, player):
        """
        Returns the state from the perspective of the given player. Also works for batched states.

        Multiplying by PLAYER2 (-1) swaps the colours, so `player`'s stones become PLAYER1.

        Args:
            state (np.array): The current state of the game.
            player (int): The player to change the perspective to.

        Returns:
            np.array: A new array with the re-coloured state.
        """
        return state * player

    def get_encoded_state(self, state):
        """
        One-hot encode the board into three float32 planes: PLAYER1 stones,
        PLAYER2 stones and empty cells.

        Args:
            state (np.array): A single board (H, W) or a batch of boards (B, H, W).

        Returns:
            np.array: (3, H, W) for a single board, (B, 3, H, W) for a batch.
        """
        encoded_state = np.stack(
            (state == PLAYER1, state == PLAYER2, state == EMPTY)
        ).astype(np.float32)

        if state.ndim == 3:
            # np.stack put the plane axis first: (3, B, H, W) -> (B, 3, H, W).
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

    def print(self, state):
        """
        Print the board with X for PLAYER1, O for PLAYER2 and . for empty cells.

        Args:
            state (np.array): The current state of the Gomoku board.

        Returns:
            None
        """
        symbols = {EMPTY: ".", PLAYER1: "X", PLAYER2: "O"}
        rows = (
            " ".join(symbols.get(state[row, col], "?") for col in range(self.size))
            for row in range(self.size)
        )
        # Inside this method `print` resolves to the builtin, not to this method.
        print("\n".join(rows) + "\n")


In [None]:
class Node:
    """
    A node of the Monte Carlo search tree.

    `state` is stored from the perspective of the player to move at this node:
    `expand` plays each child's action as PLAYER1 and then flips the board, so
    the player to move is always encoded as PLAYER1.
    """

    def __init__(self, game, args, state, parent=None, action=None, prior=0):
        self.game = game
        self.args = args  # get_ucb reads args["C"], the exploration constant
        self.state = state
        self.parent = parent
        self.action = action  # action that led from `parent` to this node; None at the root
        self.prior = prior  # policy probability of `action` at the parent

        self.children = []
        self.visit_count = 0
        # Sum of backed-up values, each from the perspective of the player to move here.
        self.total_value = 0

    def is_fully_expanded(self):
        """
        Checks if the node is fully expanded.

        `expand` creates all children in one go, so having any child means the
        node is fully expanded.

        Returns:
            bool: True if the node is fully expanded, False otherwise.
        """
        return len(self.children) > 0

    def select(self):
        """
        Selects the child with the highest UCB value; ties keep the earliest child.

        Returns:
            Node or None: The best child node (None if there are no children).
        """
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child

        return best_child

    def get_ucb(self, child):
        """
        Calculates the PUCT score of `child` as seen from this node.

        Q maps the child's mean value from [-1, 1] to [0, 1] and flips it, because
        the child's values are from the opponent's perspective. The exploration
        term is C * sqrt(N_parent) / (N_child + 1) * prior.

        Parameters:
            child (Node): The child node for which to calculate the UCB value.

        Returns:
            float: The UCB value for the child node.
        """
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = 1 - (child.total_value / child.visit_count + 1) / 2
        return (
            q_value
            + self.args["C"]
            * np.sqrt(self.visit_count)
            / (child.visit_count + 1)
            * child.prior
        )

    def expand(self, policy):
        """
        Creates one child per action with a strictly positive probability.

        Args:
            policy (sequence of float): Probability for each action index.

        Returns:
            Node or None: The last child created, or None when no action has a
            positive probability (e.g. an all-zero or NaN policy).
        """
        child = None
        for action, prob in enumerate(policy):
            if prob > 0:
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, PLAYER1)
                # Flip colours so the child's player to move is encoded as PLAYER1.
                child_state = self.game.change_perspective(child_state, player=PLAYER2)

                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)

        return child

    def backpropagate(self, value):
        """
        Add `value` to this node's statistics and propagate it to all ancestors,
        negating it at every level because the player to move alternates.

        Implemented iteratively so deep trees cannot hit Python's recursion limit.

        Args:
            value: The value obtained from the simulation, from this node's perspective.

        Returns:
            None
        """
        node = self
        while True:
            node.visit_count += 1
            node.total_value += value
            if node.parent is None:
                break
            value = node.game.get_opponent_value(value)
            node = node.parent


class AlphaMCTS:
    """
    Network-guided Monte Carlo Tree Search over a single position.
    """

    def __init__(self, game, args, model):
        self.game, self.args, self.model = game, args, model

    @torch.no_grad()
    def search(self, state):
        """
        Run `args["num_searches"]` simulations from `state` and return the root's
        normalised child visit counts.

        Args:
            state (np.array): The current state of the game, from the perspective
                of the player to move.

        Returns:
            np.array: A length size^2 vector of action probabilities proportional
            to how often each root child was visited.
        """
        root = Node(self.game, self.args, state)

        for _ in range(self.args["num_searches"]):
            # Selection: descend by UCB until reaching an unexpanded node.
            leaf = root
            while leaf.is_fully_expanded():
                leaf = leaf.select()

            outcome, is_over = self.game.get_value_and_terminated(
                leaf.state, leaf.action
            )
            # `outcome` is from the perspective of the player who made `leaf.action`;
            # flip it to the player to move at `leaf`.
            outcome = self.game.get_opponent_value(outcome)

            if not is_over:
                # Expansion + evaluation: the network supplies priors and a value.
                batch = torch.tensor(
                    self.game.get_encoded_state(leaf.state),
                    device=self.model.device,
                ).unsqueeze(0)
                logits, predicted_value = self.model(batch)
                priors = torch.softmax(logits, axis=1).squeeze(0).detach().cpu().numpy()
                # Mask illegal moves and renormalise.
                priors = priors * self.game.get_moves(leaf.state)
                priors /= np.sum(priors)

                outcome = predicted_value.item()
                leaf.expand(priors)

            leaf.backpropagate(outcome)

        num_cells = self.game.size * self.game.size
        visits = np.zeros(num_cells)
        for child in root.children:
            visits[child.action] = child.visit_count
        return visits / np.sum(visits)


In [None]:
class ResBlock(nn.Module):
    """
    Residual block: conv-BN-ReLU followed by conv-BN, an identity skip
    connection, and a final ReLU. Spatial size and channel count are preserved.
    """

    def __init__(self, num_hidden):
        super().__init__()
        # Attribute names form the state_dict keys of saved checkpoints; keep them stable.
        conv_options = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        residual = x
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)

class ResNet(nn.Module):
    """
    Policy/value network: a convolutional stem, a tower of residual blocks and
    two heads.

    Args:
        size (int): Board side length; the policy head emits size*size logits.
        num_blocks (int): Number of ResBlock layers in the tower.
        num_hidden (int): Channel width of the stem and the tower.
        device (torch.device): Device the parameters are moved to; kept as
            `self.device` so callers can create inputs on the same device.

    forward(x) expects a (batch, 3, size, size) tensor and returns
    (policy logits of shape (batch, size*size), value of shape (batch, 1) in [-1, 1]).
    """

    def __init__(self, size, num_blocks, num_hidden, device):
        super().__init__()
        self.device = device
        num_cells = size * size

        # Attribute names and Sequential layouts define the state_dict keys used
        # by saved checkpoints, so they must stay exactly as they are.
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        )
        self.backBone = nn.ModuleList(ResBlock(num_hidden) for _ in range(num_blocks))
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * num_cells, num_cells),
        )
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * num_cells, 1),
            nn.Tanh(),
        )

        self.to(device)

    def forward(self, x):
        features = self.startBlock(x)
        for res_block in self.backBone:
            features = res_block(features)
        return self.policyHead(features), self.valueHead(features)

In [None]:
class TrainingMCTS:
    """
    Batched Monte Carlo Tree Search used during self-play.

    One search tree is grown per game in `self_play_games`. In every simulation,
    the non-terminal leaves of all trees are evaluated with a single batched
    forward pass of the network.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def search(self, states, self_play_games):
        """
        Grow one search tree per self-play game.

        Args:
            states (np.array): Stacked board states, one per game, already converted
                to the perspective of the player to move.
            self_play_games (list): Game records exposing `root` and `node`
                attributes; both are overwritten by this method.

        Returns:
            None. Each game's `root` holds the finished tree; the visit counts of
            its children give the search policy.
        """
        # Expand every root with the network prior using one batched call.
        root_logits, _ = self.model(
            torch.tensor(self.game.get_encoded_state(states), device=self.model.device)
        )
        root_policies = torch.softmax(root_logits, axis=1).cpu().numpy()

        for i, spg in enumerate(self_play_games):
            root_policy = root_policies[i] * self.game.get_moves(states[i])
            root_policy /= np.sum(root_policy)

            spg.root = Node(self.game, self.args, states[i])
            spg.root.expand(root_policy)
            # Count the expansion as a visit. With visit_count == 0 the exploration
            # term sqrt(N) is 0 for every child, so the first selection would ignore
            # the priors and simply take the first legal child.
            spg.root.visit_count = 1

        for _ in range(self.args["num_searches"]):
            # Selection: walk each tree down to a leaf. Terminal leaves are backed up
            # immediately; the rest are queued for batched network evaluation.
            for spg in self_play_games:
                spg.node = None
                node = spg.root

                while node.is_fully_expanded():
                    node = node.select()

                value, terminated = self.game.get_value_and_terminated(
                    node.state, node.action
                )
                # `value` is from the perspective of the player who made
                # `node.action`; flip it to the player to move at `node`.
                value = self.game.get_opponent_value(value)

                if terminated:
                    node.backpropagate(value)
                else:
                    spg.node = node

            expandable_games = [spg for spg in self_play_games if spg.node is not None]
            if not expandable_games:
                continue

            leaf_states = np.stack([spg.node.state for spg in expandable_games])
            leaf_logits, leaf_values = self.model(
                torch.tensor(
                    self.game.get_encoded_state(leaf_states), device=self.model.device
                )
            )
            leaf_policies = torch.softmax(leaf_logits, axis=1).cpu().numpy()
            # One scalar per leaf. Backpropagating plain floats keeps
            # Node.total_value a number instead of a 1-element array.
            leaf_values = leaf_values.reshape(-1).cpu().numpy()

            for spg, leaf_policy, leaf_value in zip(
                expandable_games, leaf_policies, leaf_values
            ):
                node = spg.node
                leaf_policy = leaf_policy * self.game.get_moves(node.state)
                leaf_policy /= np.sum(leaf_policy)
                node.expand(leaf_policy)
                node.backpropagate(leaf_value.item())


class SelfPlayGame:
    """
    Mutable bookkeeping for one game in a batch of parallel self-play games.

    Args:
        game: Game rules object; only `get_initial_state()` is used here.
    """

    def __init__(self, game):
        # Board with the real player values (not perspective-flipped).
        self.state = game.get_initial_state()
        # One (neutral state, action probabilities, player) tuple per move played.
        self.memory = []
        # Search-tree root and the pending leaf, both managed by TrainingMCTS.search.
        self.root, self.node = None, None


class AlphaZero:
    """
    AlphaZero training loop: alternates batched self-play with training of the
    policy/value network on the samples collected during self-play.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = TrainingMCTS(game, args, model)

    def self_play(self):
        """
        Play `args["num_parallel_games"]` games in lockstep until all have ended.

        Returns:
            list[tuple]: (encoded neutral state, MCTS action probabilities, outcome)
            per move, where outcome is from the perspective of the player who was
            to move in that state: 1 win, -1 loss, 0 draw.
        """
        return_memory = []
        player = PLAYER1
        self_play_games = [
            SelfPlayGame(self.game) for _ in range(self.args["num_parallel_games"])
        ]
        num_cells = self.game.size * self.game.size

        while self_play_games:
            states = np.stack([spg.state for spg in self_play_games])
            neutral_states = self.game.change_perspective(states, player)
            self.mcts.search(neutral_states, self_play_games)

            # Iterate backwards so finished games can be deleted while looping.
            for i in reversed(range(len(self_play_games))):
                spg = self_play_games[i]

                action_probs = np.zeros(num_cells)
                for child in spg.root.children:
                    action_probs[child.action] = child.visit_count
                action_probs /= np.sum(action_probs)

                spg.memory.append((spg.root.state, action_probs, player))
                action = np.random.choice(num_cells, p=action_probs)
                spg.state = self.game.get_next_state(spg.state, action, player)
                value, terminated = self.game.get_value_and_terminated(
                    spg.state, action
                )

                if terminated:
                    # `value` is from the perspective of `player`, who just moved, so
                    # each stored position is labelled relative to that player.
                    for hist_state, hist_action_probs, hist_player in spg.memory:
                        hist_outcome = (
                            value
                            if hist_player == player
                            else self.game.get_opponent_value(value)
                        )
                        return_memory.append(
                            (
                                self.game.get_encoded_state(hist_state),
                                hist_action_probs,
                                hist_outcome,
                            )
                        )
                    del self_play_games[i]

            player = self.game.get_opponent(player)

        return return_memory

    def train(self, memory):
        """
        Run one epoch of minibatch training over `memory` (shuffled in place).

        Args:
            memory (list[tuple]): Samples as produced by `self_play`.
        """
        random.shuffle(memory)
        batch_size = self.args["batch_size"]
        device = self.model.device

        for batch_start in range(0, len(memory), batch_size):
            batch = memory[batch_start : batch_start + batch_size]
            states, policy_targets, value_targets = zip(*batch)

            states = torch.tensor(np.array(states), dtype=torch.float32, device=device)
            policy_targets = torch.tensor(
                np.array(policy_targets), dtype=torch.float32, device=device
            )
            value_targets = torch.tensor(
                np.array(value_targets).reshape(-1, 1), dtype=torch.float32, device=device
            )

            out_policy, out_value = self.model(states)

            # Cross-entropy against the MCTS visit distribution (soft targets) plus
            # mean squared error against the game outcome.
            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """
        Run `args["num_iterations"]` rounds of self-play followed by training,
        saving model and optimizer checkpoints after every round.
        """
        for iteration in trange(self.args["num_iterations"]):
            memory = []

            # Self-play only runs the network for inference.
            self.model.eval()
            for _ in trange(
                self.args["num_self_play_iterations"] // self.args["num_parallel_games"]
            ):
                memory += self.self_play()

            self.model.train()
            for _ in range(self.args["num_epochs"]):
                self.train(memory)

            torch.save(
                self.model.state_dict(),
                f'{self.args["model_path"]}/model_{iteration}.pt',
            )
            torch.save(
                self.optimizer.state_dict(),
                f'{self.args["optimizer_path"]}/optimizer_{iteration}.pt',
            )


In [None]:
# Training run: build the game, the network and the optimizer, then run the
# AlphaZero loop. Checkpoints are written to model_path / optimizer_path once per
# iteration. NOTE(review): with these settings this cell is likely very slow,
# especially without a GPU.
BOARD_SIZE = 10

game = Gomoku(BOARD_SIZE)
num_blocks = 16  # residual blocks in the ResNet tower
num_hidden = 256  # channels per convolution in the tower
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(game.size, num_blocks, num_hidden, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
args = {
    'C': 2,  # exploration constant in Node.get_ucb
    'num_searches': 500,  # MCTS simulations per move
    'num_iterations': 3,  # self-play + training rounds; one checkpoint pair each
    'num_self_play_iterations': 200,  # games per round (played in batches of num_parallel_games)
    'num_epochs': 4,  # passes over the collected samples per round
    'num_parallel_games': 200,  # games searched together with batched network calls
    'batch_size': 128,  # minibatch size in AlphaZero.train

    'model_path': model_path,
    'optimizer_path': optimizer_path
}

alphaZero = AlphaZero(model, optimizer, game, args)
alphaZero.learn()

In [None]:
# Sanity-check the last saved checkpoint on a hand-built position.
game = Gomoku(BOARD_SIZE)

state = game.get_initial_state()
actions = [0, 7, 1, 8, 2, 9, 3, 10]
player = PLAYER1
for action in actions:
    state = game.get_next_state(state, action, player)
    player = game.get_opponent(player)

game.print(state)

# The network expects the board from the perspective of the player to move
# (that player encoded as PLAYER1), and its input must live on the model's device.
neutral_state = game.change_perspective(state, player)
tensor_state = torch.tensor(
    game.get_encoded_state(neutral_state), device=device
).unsqueeze(0)

model = ResNet(game.size, num_blocks, num_hidden, device)
# map_location lets a checkpoint trained on GPU be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load(
        f'{model_path}/model_{args["num_iterations"] - 1}.pt', map_location=device
    )
)
model.eval()

with torch.no_grad():
    policy, value = model(tensor_state)
value = value.item()
policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()

print(value, policy)