In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import trange
import random
import os

# Seed every random number generator the notebook draws from so that a fresh
# "Restart & Run All" is repeatable: torch is seeded as before, numpy is used
# to sample self-play moves (np.random.choice) and `random` shuffles the
# training memory (random.shuffle).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Board cell encoding. The two players are negatives of each other, so
# multiplying a board by a player value swaps the point of view
# (see Gomoku.change_perspective).
EMPTY = 0
PLAYER1 = 1
PLAYER2 = -1

In [None]:
def in_colab():
    """Report whether the notebook is running inside Google Colab.

    Returns:
        bool: True when the ``google.colab`` package can be imported, False
        when the import raises ImportError.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True

# Choose where checkpoints live: the current directory by default, or the
# mounted Google Drive folder when running on Colab.
base_path = '.'
if in_colab():
    from google.colab import drive
    drive.mount('/content/drive')
    base_path = '/content/drive/My Drive/Colab Notebooks'

# Separate folders for network weights and optimizer state.
model_path = os.path.join(base_path, 'alpha_gomoku/model')
optimizer_path = os.path.join(base_path, 'alpha_gomoku/optimizer')

# Create both folders up front; exist_ok keeps re-running the cell harmless.
os.makedirs(model_path, exist_ok=True)
os.makedirs(optimizer_path, exist_ok=True)

Mounted at /content/drive


In [None]:
class Gomoku:
    """Rules for Gomoku (five in a row) on a square board.

    A board is a ``(size, size)`` int8 numpy array whose cells hold EMPTY,
    PLAYER1 or PLAYER2. An action is the flat cell index ``row * size + col``.
    """

    # Line orientations to test; each one is scanned forwards and backwards.
    _DIRECTIONS = ((1, 0), (0, 1), (1, 1), (1, -1))

    def __init__(self, size=15):
        """Create a game played on a ``size`` x ``size`` board."""
        self.size = size

    def get_initial_state(self):
        """Return an empty int8 board of shape ``(size, size)``."""
        return np.zeros((self.size, self.size), dtype=np.int8)

    def get_next_state(self, state, action, player):
        """Place ``player``'s stone on cell ``action``.

        The board is modified IN PLACE and also returned; callers that need
        the previous position must copy it first.

        Raises:
            ValueError: if the target cell is not empty.
        """
        row, col = divmod(action, self.size)
        if state[row, col] != EMPTY:
            raise ValueError("Invalid action")
        state[row, col] = player
        return state

    def get_moves(self, state):
        """Return a flat uint8 mask with 1 for every empty cell, 0 otherwise."""
        return (state.reshape(-1) == EMPTY).astype(np.uint8)

    def _count_run(self, state, row, col, dr, dc, player):
        """Count consecutive ``player`` stones starting next to (row, col).

        Walks in direction (dr, dc) and stops at the board edge or at the
        first cell that does not hold ``player``.
        """
        run = 0
        r, c = row + dr, col + dc
        while 0 <= r < self.size and 0 <= c < self.size and state[r, c] == player:
            run += 1
            r += dr
            c += dc
        return run

    def check_win(self, state, action):
        """Return True if the stone on cell ``action`` completes a line of 5+.

        Only lines passing through ``action`` are examined. Returns False when
        ``action`` is None or the cell is empty.
        """
        if action is None:
            return False

        row, col = divmod(action, self.size)
        player = state[row, col]
        if player == EMPTY:
            return False

        for dr, dc in self._DIRECTIONS:
            # The stones must be contiguous: count the unbroken run on each
            # side of the placed stone rather than every matching stone that
            # happens to lie within four cells of it.
            line_length = (
                1
                + self._count_run(state, row, col, dr, dc, player)
                + self._count_run(state, row, col, -dr, -dc, player)
            )
            if line_length >= 5:
                return True

        return False

    def get_value_and_terminated(self, state, action):
        """Evaluate the position reached by playing ``action``.

        Returns:
            tuple: ``(1, True)`` if ``action`` won the game, ``(0, True)`` if
            the board is full (draw), otherwise ``(0, False)``.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(state == EMPTY) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player's id."""
        return PLAYER1 if player == PLAYER2 else PLAYER2

    def get_opponent_value(self, player):
        """Negate a value, i.e. express it from the opponent's point of view."""
        return -player

    def change_perspective(self, state, player):
        """Return a new board multiplied by ``player``.

        With ``player == PLAYER2`` every stone changes owner, so the side to
        move always appears as PLAYER1.
        """
        return state * player

    def get_encoded_state(self, state):
        """Encode a board as three stacked float32 masks.

        The planes are, in order: PLAYER1 stones, PLAYER2 stones, empty cells.
        """
        return np.stack(
            (state == PLAYER1, state == PLAYER2, state == EMPTY)
        ).astype(np.float32)

    def print(self, state):
        """Print the board using 'X' for 1, 'O' for -1 and '.' for empty."""
        board_str = ''
        for row in range(self.size):
            row_str = ' '.join(
                str(state[row, col]) if state[row, col] != EMPTY else '.'
                for col in range(self.size)
            )
            board_str += row_str + '\n'
        # Replace '-1' before '1' so the minus sign is consumed with its digit.
        board_str = board_str.replace('-1', 'O').replace('1', 'X')
        print(board_str)

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages around a skip connection.

    The channel count (``num_hidden``) and the spatial size are preserved, so
    the input can be added back onto the output.
    """

    def __init__(self, num_hidden):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps height and width unchanged.
        conv_options = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection: add the untouched input before the final ReLU.
        return F.relu(hidden + x)

class ResNet(nn.Module):
    """Residual network with a policy head and a value head.

    Args:
        size: board side length; the policy head emits ``size * size`` logits.
        num_blocks: number of ResBlock modules in the backbone.
        num_hidden: channel width of the stem and the backbone.
        device: device the module is moved to; also kept as ``self.device``.
    """

    def __init__(self, size, num_blocks, num_hidden, device):
        super().__init__()

        self.device = device
        board_cells = size * size

        # Stem: lift the 3 input planes to num_hidden channels.
        self.startBlock = nn.Sequential(
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        )

        # Backbone: a stack of identical residual blocks.
        residual_blocks = []
        for _ in range(num_blocks):
            residual_blocks.append(ResBlock(num_hidden))
        self.backBone = nn.ModuleList(residual_blocks)

        # Policy head: one raw logit per board cell (softmax is applied by callers).
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(board_cells * 32, board_cells),
        )

        # Value head: a single scalar squashed into [-1, 1] by tanh.
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * board_cells, 1),
            nn.Tanh(),
        )

        self.to(device)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded boards."""
        features = self.startBlock(x)
        for residual_block in self.backBone:
            features = residual_block(features)
        return self.policyHead(features), self.valueHead(features)

In [None]:
class Node:
    """One position in the Monte Carlo search tree.

    ``expand`` always plays the move as PLAYER1 and then flips the board, so a
    child's ``state`` is stored from the point of view of the side that moves
    next (this assumes the root was passed in the same convention).

    Attributes set here:
        action: flat cell index that led from ``parent`` to this node
            (None for the root).
        prior: probability given to ``action`` by the policy passed to
            ``parent.expand``.
        visit_count / total_value: statistics accumulated by ``backpropagate``.
    """

    def __init__(self, game, args, state, parent=None, action=None, prior=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action = action
        self.prior = prior

        self.children = []
        self.visit_count = 0
        self.total_value = 0

    def is_fully_expanded(self):
        """Return True once ``expand`` has created at least one child."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (first one on ties).

        Returns None when the node has no children.
        """
        best_child = None
        best_ucb = -np.inf
        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child

        return best_child

    def get_ucb(self, child):
        """Score ``child`` for selection: exploitation term plus exploration bonus."""
        if child.visit_count == 0:
            q_value = 0
        else:
            # Map the child's mean value from [-1, 1] to [0, 1] and invert it:
            # backpropagate stores each node's value with the opposite sign of
            # its parent's, so a low child value is good for this node.
            q_value = 1 - (child.total_value / child.visit_count + 1) / 2
        return q_value + self.args['C'] * np.sqrt(self.visit_count) / (child.visit_count + 1) * child.prior

    def expand(self, policy):
        """Create one child for every action with positive probability.

        Args:
            policy: iterable of probabilities indexed by flat action.

        Returns:
            The last child created, or None when no entry of ``policy`` is
            positive (previously this case raised UnboundLocalError).
        """
        child = None
        for action, prob in enumerate(policy):
            if prob > 0:
                # Work on a copy: get_next_state writes into the array it is given.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, PLAYER1)
                child_state = self.game.change_perspective(child_state, player=PLAYER2)

                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)

        return child

    def backpropagate(self, value):
        """Add ``value`` to this node and walk up to the root, flipping its sign
        at every step because consecutive levels belong to opposite players."""
        node = self
        while True:
            node.visit_count += 1
            node.total_value += value
            if node.parent is None:
                break
            value = node.game.get_opponent_value(value)
            node = node.parent


class AlphaMCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Args:
        game: rules object providing get_value_and_terminated,
            get_opponent_value, get_encoded_state, get_moves and ``size``.
        args: settings dict; ``'num_searches'`` is read here and ``'C'`` by Node.
        model: callable returning ``(policy_logits, value)``; must expose a
            ``device`` attribute.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` simulations from ``state``.

        Args:
            state: board seen from the side to move (stones of the player to
                move equal PLAYER1), since Node.expand always plays as PLAYER1.

        Returns:
            np.ndarray of length ``size * size``: root visit counts normalised
            to sum to 1. If no child was ever visited the vector is all zeros.
        """
        root = Node(self.game, self.args, state)
        for _ in range(self.args['num_searches']):
            node = root

            # Selection: descend until reaching a node without children.
            while node.is_fully_expanded():
                node = node.select()

            # node.action was played by the opponent of the side to move at
            # `node`, so a win found here is negated to get node's own value.
            value, terminated = self.game.get_value_and_terminated(node.state, node.action)
            value = self.game.get_opponent_value(value)

            if not terminated:
                encoded = torch.tensor(
                    self.game.get_encoded_state(node.state), device=self.model.device
                ).unsqueeze(0)
                policy, value = self.model(encoded)
                policy = torch.softmax(policy, dim=1).squeeze(0).detach().cpu().numpy()

                # Zero out occupied cells and renormalise over legal moves.
                valid_moves = self.game.get_moves(node.state)
                policy = policy * valid_moves
                policy_sum = np.sum(policy)
                if policy_sum > 0:
                    policy = policy / policy_sum
                else:
                    # The softmax underflowed to 0 on every legal move; dividing
                    # would produce NaNs, so fall back to a uniform prior.
                    policy = valid_moves / np.sum(valid_moves)

                value = value.item()
                node.expand(policy)

            node.backpropagate(value)

        action_probs = np.zeros(self.game.size * self.game.size)
        for child in root.children:
            action_probs[child.action] = child.visit_count
        total_visits = np.sum(action_probs)
        if total_visits > 0:  # avoid 0/0 when the root was never expanded or visited
            action_probs = action_probs / total_visits

        return action_probs


In [None]:
class AlphaZero:
    """Self-play training loop: generate games with MCTS, then fit the network.

    Args:
        model: policy/value network; must expose ``device``.
        optimizer: torch optimizer over ``model``'s parameters.
        game: rules object (see Gomoku).
        args: settings dict; reads ``'batch_size'``, ``'num_iterations'``,
            ``'num_self_play_iterations'`` and ``'num_epochs'`` here.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = AlphaMCTS(game, args, model)

    def self_play(self):
        """Play one game against itself and return its training samples.

        Returns:
            list of ``(encoded_state, action_probs, outcome)`` tuples, one per
            move. ``encoded_state`` is the board as seen by the player who was
            to move and ``outcome`` is the final result from that same
            player's point of view (+1 win, -1 loss, 0 draw).
        """
        memory = []
        player = PLAYER1
        state = self.game.get_initial_state()

        while True:
            # change_perspective returns a new array, so the stored position is
            # not affected when get_next_state later writes into `state`.
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state)
            memory.append((neutral_state, action_probs, player))

            action = np.random.choice(self.game.size * self.game.size, p=action_probs)
            state = self.game.get_next_state(state, action, player)
            value, terminated = self.game.get_value_and_terminated(state, action)

            if terminated:
                # `value` is expressed for `player`, who made the final move.
                # Every stored position gets that value if the same player was
                # to move, and the negated value otherwise. (Comparing against
                # PLAYER1 here would invert the labels of games won by PLAYER2.)
                return [
                    (
                        self.game.get_encoded_state(hist_state),
                        hist_probs,
                        value if hist_player == player else self.game.get_opponent_value(value),
                    )
                    for hist_state, hist_probs, hist_player in memory
                ]

            player = self.game.get_opponent(player)

    def train(self, memory):
        """Run one pass over ``memory`` in mini-batches.

        ``memory`` is shuffled in place. The loss is the cross-entropy between
        policy logits and MCTS visit distributions plus the MSE between the
        predicted value and the game outcome.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batch_index in range(0, len(memory), batch_size):
            # Slicing clamps at the end of the list, so the last batch may be short.
            batch = memory[batch_index:batch_index + batch_size]
            states, policy_targets, value_targets = zip(*batch)

            states = np.array(states)
            policy_targets = np.array(policy_targets)
            value_targets = np.array(value_targets).reshape(-1, 1)

            states = torch.tensor(states, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(states)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for ``args['num_iterations']`` rounds.

        After every round the model and optimizer state dicts are written to
        the module-level ``model_path`` / ``optimizer_path`` directories.
        """
        for iteration in trange(self.args['num_iterations']):
            memory = []

            # Self-play runs in eval mode, training in train mode.
            self.model.eval()
            for play_iteration in trange(self.args['num_self_play_iterations']):
                memory += self.self_play()

            self.model.train()
            for epoch in range(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f'{model_path}/model_{iteration}.pt')
            torch.save(self.optimizer.state_dict(), f'{optimizer_path}/optimizer_{iteration}.pt')

In [9]:
# Hyper-parameters for search and training.
args = {
    'C': 2,                          # exploration weight in Node.get_ucb
    'num_searches': 100,             # MCTS simulations per move
    'num_iterations': 3,             # self-play / train rounds (one checkpoint each)
    'num_self_play_iterations': 50,  # games generated per round
    'num_epochs': 4,                 # passes over the collected games per round
    'batch_size': 64,
}

# Board and network dimensions; later cells reuse num_blocks, num_hidden and device.
game = Gomoku(10)
num_blocks = 4
num_hidden = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Network and optimizer are created after the config, in the same order as
# before so weight initialisation consumes the RNG identically.
model = ResNet(game.size, num_blocks, num_hidden, device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Run the full training loop; checkpoints are written by AlphaZero.learn.
alphaZero = AlphaZero(model, optimizer, game, args)
alphaZero.learn()

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Inspect the trained network on a hand-built position.
# NOTE: num_blocks, num_hidden and device come from the training cell above.
game = Gomoku(10)

# PLAYER1 takes cells 0-3 (four in a row on the top edge), PLAYER2 takes
# 7, 8, 9 and 10; after these eight moves PLAYER1 is to move.
state = game.get_initial_state()
actions = [0, 7, 1, 8, 2, 9, 3, 10]
player = PLAYER1
for action in actions:
    state = game.get_next_state(state, action, player)
    player = game.get_opponent(player)

game.print(state)

# The input must live on the same device as the model, otherwise the forward
# pass fails when the model is on the GPU.
tensor_state = torch.tensor(game.get_encoded_state(state), device=device).unsqueeze(0)

model = ResNet(game.size, num_blocks, num_hidden, device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(f'{model_path}/model_2.pt', map_location=device))
model.eval()

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    policy, value = model(tensor_state)
value = value.item()
policy = torch.softmax(policy, dim=1).squeeze(0).detach().cpu().numpy()

print(value, policy)

In [None]:
# Play against the trained network: the human is PLAYER1 ('X') and enters
# moves as "row,col"; the network answers as PLAYER2 ('O') using MCTS.
# NOTE: num_blocks, num_hidden and device come from the training cell above.
game = Gomoku(10)
player = PLAYER1

args = {
    'C': 2,
    'num_searches': 100,
    'num_iterations': 3,
    'num_self_play_iterations': 50,
    'num_epochs': 4,
    'batch_size': 64,
}
model = ResNet(game.size, num_blocks, num_hidden, device)
model.load_state_dict(torch.load(f'{model_path}/model_2.pt', map_location=device))
model.eval()
mcts = AlphaMCTS(game, args, model)

states = game.get_initial_state()
tiles = game.size * game.size

while True:
    game.print(states)
    if player == PLAYER1:
        # Human turn. (A stray `break` used to sit here, which ended the game
        # before the first move and made this whole branch unreachable.)
        valid_moves = game.get_moves(states)
        # Convert flat indices to (row, col) using the board width, not the cell count.
        print("Valid moves:", [(i // game.size, i % game.size) for i in range(tiles) if valid_moves[i] == 1])
        user_input = input("Enter action: ")

        # Reject malformed input instead of crashing the cell.
        try:
            parts = user_input.split(',')
            row, col = int(parts[0]), int(parts[1])
        except (ValueError, IndexError):
            print("Invalid move")
            continue

        # Range-check first so negative or oversized coordinates cannot wrap
        # around or index outside the mask.
        if not (0 <= row < game.size and 0 <= col < game.size):
            print("Invalid move")
            continue

        action = row * game.size + col
        if valid_moves[action] == 0:
            print("Invalid move")
            continue
    else:
        # The search expects the board from the mover's point of view, so
        # flip it for PLAYER2 before searching.
        neutral_state = game.change_perspective(states, player)
        mcts_action_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_action_probs)

    states = game.get_next_state(states, action, player)
    value, terminated = game.get_value_and_terminated(states, action)
    if terminated:
        if value == 1:
            print("Player", player, "wins")
        else:
            print("Draw")
        break

    player = game.get_opponent(player)