In [44]:
import torch
import numpy as np
import gym
from collections import deque, namedtuple
import random
from matplotlib import pyplot as plt
import numpy as np
import copy

In [43]:
# [black, white, turn, invalid, pass, game_over]
class MCTSNode:
    """One position in the search tree together with its per-action edge statistics.

    Actions are indexed ``0 .. rows*cols - 1`` for board points (row-major) and
    ``rows*cols`` for the pass move.

    Args:
        state: array/tensor indexed as ``state[channel][row][col]``. Only
            ``state.shape[1:3]`` and channel 3 (the "invalid" plane, per the
            channel layout noted above this class) are read here.
        parent: the node this one was reached from, or ``None`` for a root.
        prior_action: index of the action taken at ``parent`` to reach this node.
        device: torch device on which the statistics tensors are allocated.

    Raises:
        ValueError: if ``state`` is ``None`` (the action count cannot be derived).
    """

    def __init__(self, state, parent=None, prior_action=-1, device='cpu'):
        if state is None:
            raise ValueError("MCTSNode needs a concrete state to size its action statistics")

        size = state.shape[1] * state.shape[2] + 1  # board size + pass

        # Non-zero entries of channel 3 mark illegal board points. The plane is
        # flattened so it lines up with the 1-D action index, and one extra slot
        # is appended for the pass action.
        # NOTE(review): pass is assumed to be always legal - confirm against the env.
        invalid_points = torch.as_tensor(state[3]).reshape(-1).to(device) != 0
        pass_is_valid = torch.ones(1, dtype=torch.bool, device=device)
        self.valid_actions = torch.cat([~invalid_points, pass_is_valid])

        self.visit_counts = torch.zeros(size).to(device)
        self.total_action_values = torch.zeros(size).to(device)
        self.mean_action_values = torch.zeros(size).to(device)
        self.prior_probabilities = torch.zeros(size).to(device)

        self.selected_count = 0
        # torch has no object dtype; a plain list holds the child nodes
        # (``None`` marks an action that has not been expanded yet).
        self.children = [None] * size

        self.parent = parent
        self.prior_action = prior_action

    def select(self, puct_multiplier=1.0):
        """Pick the valid action maximising ``Q + U`` and count the visit.

        ``U = puct_multiplier * sqrt(1 + N_total) * P / (1 + N_action)``. The
        ``1 +`` under the square root keeps the prior term alive on the very
        first selection, when no action has been visited yet.

        Args:
            puct_multiplier: exploration constant scaling the prior term.

        Returns:
            ``(action, child)`` where ``action`` is a Python int and ``child``
            is the node stored for that action (``None`` if not yet expanded).
        """
        exploration_scale = puct_multiplier * (1 + self.selected_count) ** 0.5
        us = exploration_scale * self.prior_probabilities / (1 + self.visit_counts)

        scores = self.mean_action_values + us
        # Multiplying by the mask would leave invalid actions at 0, which beats
        # every valid action with a negative score; -inf excludes them for real.
        scores = scores.masked_fill(~self.valid_actions, float('-inf'))
        max_a = int(torch.argmax(scores))

        self.selected_count += 1
        self.visit_counts[max_a] += 1

        return max_a, self.children[max_a]
    

class MCTS:
    """Monte Carlo tree search guided by a policy/value evaluator.

    Conventions used throughout this class:

    * ``MCTSNode.select`` is the single place where visit counts are
      incremented; ``backup`` only accumulates values.
    * A value stored on an edge is from the point of view of the player who
      moves at the edge's parent node, so the sign flips at every level.
    * NOTE(review): terminal rewards returned by ``env.step`` are assumed to be
      from black's perspective, and ``state[2]`` is assumed to be 0 when black
      is to move - confirm against the environment.

    Args:
        evaluator: callable mapping a batch ``(B, C, H, W)`` to
            ``(policy (B, H*W + 1), value (B, 1))``.
        board_size: edge length of the board (stored, not otherwise used here).
        device: torch device used for nodes and evaluator inputs.
    """

    def __init__(self, evaluator, board_size=19, device='cpu'):
        # The root is created lazily by train() from the first real state; a
        # node cannot be sized from a missing state.
        self.root = None
        self.PUCT_CONST = 1.0
        self.tau = 1.0

        # dihedral transforms on the board (dims 1 and 2 are the board axes)
        self.dihedrals_transforms = [
            lambda x: x,
            lambda x: torch.rot90(x, k=1, dims=(1, 2)),
            lambda x: torch.rot90(x, k=2, dims=(1, 2)),
            lambda x: torch.rot90(x, k=3, dims=(1, 2)),
            lambda x: torch.flip(x, dims=(1,)),
            lambda x: torch.flip(x, dims=(2,)),
            lambda x: torch.flip(torch.rot90(x, k=1, dims=(1, 2)), dims=(1,)),
            lambda x: torch.flip(torch.rot90(x, k=1, dims=(1, 2)), dims=(2,)),
        ]
        # Inverse of each transform above (same order). Needed to map a policy
        # predicted for a transformed board back onto the original board.
        self.inverse_dihedral_transforms = [
            lambda x: x,
            lambda x: torch.rot90(x, k=3, dims=(1, 2)),
            lambda x: torch.rot90(x, k=2, dims=(1, 2)),
            lambda x: torch.rot90(x, k=1, dims=(1, 2)),
            lambda x: torch.flip(x, dims=(1,)),
            lambda x: torch.flip(x, dims=(2,)),
            lambda x: torch.rot90(torch.flip(x, dims=(1,)), k=3, dims=(1, 2)),
            lambda x: torch.rot90(torch.flip(x, dims=(2,)), k=3, dims=(1, 2)),
        ]

        self.evaluator = evaluator
        self.board_size = board_size
        self.device = device

    def _as_state_tensor(self, state):
        """Convert an environment state to a float32 tensor on ``self.device``."""
        return torch.as_tensor(state, dtype=torch.float32, device=self.device)

    def _evaluate(self, state):
        """Run the evaluator on all eight symmetries of ``state`` and average.

        Each predicted policy is mapped back through the inverse symmetry
        before averaging, so every entry refers to the same board point.

        Returns:
            ``(policy, value)``: a 1-D tensor of length ``H*W + 1`` and a float.
        """
        state = self._as_state_tensor(state)
        batch = torch.stack([t(state) for t in self.dihedrals_transforms])
        with torch.no_grad():  # search never needs gradients
            policies, values = self.evaluator(batch)

        rows, cols = state.shape[1], state.shape[2]
        restored = []
        for inverse, policy in zip(self.inverse_dihedral_transforms, policies):
            board_part = inverse(policy[:-1].reshape(1, rows, cols)).reshape(-1)
            restored.append(torch.cat([board_part, policy[-1:]]))  # pass is symmetric
        return torch.stack(restored).mean(dim=0), float(values.mean())

    def train(self, env, state, deterministic=False):
        """Run one simulation (select -> expand/evaluate -> backup) from the root.

        Args:
            env: environment positioned at ``state``; it is stepped in place,
                so callers should pass a throw-away copy.
            state: the current real state, used to build the root if needed.
            deterministic: unused; kept for interface compatibility.
        """
        if self.root is None:
            root_state = self._as_state_tensor(state)
            self.root = MCTSNode(root_state, device=self.device)
            self.root.prior_probabilities, _ = self._evaluate(root_state)

        state, parent, child, action, done, reward = self.search(env, self.root)
        if done:
            # Express the terminal reward from the leaf's side-to-move
            # perspective; backup() flips it for the parent.
            black_to_move = float(state[2][0][0]) == 0
            value = float(reward) if black_to_move else -float(reward)
            if child is None:
                # Terminal positions get a node without a network evaluation.
                child = MCTSNode(self._as_state_tensor(state), parent=parent,
                                 prior_action=action, device=self.device)
                parent.children[action] = child
        else:
            child, value = self.expand(state, parent, child, action)
        self.backup(child, value)

    def search(self, env, root):
        """Descend from ``root`` until an unexpanded edge or the game end.

        Returns:
            ``(state, node, cnode, action, done, reward)``: the last state
            returned by ``env.step``, the node the last action was taken from,
            the child reached (``None`` if unexpanded), the last action, and
            the ``done``/``reward`` values of the last step.
        """
        node = root
        cnode = root
        state, action, reward, done = None, None, 0.0, False
        while cnode is not None and not done:
            node = cnode
            action, cnode = node.select(self.PUCT_CONST)
            action = int(action)
            state, reward, done, _ = env.step(action)

        return state, node, cnode, action, done, reward

    def expand(self, state, parent, child, action):
        """Create the child for ``(parent, action)`` and evaluate its state.

        Returns:
            ``(child, value)`` where ``value`` is the evaluator's estimate for
            the new state as a float.
        """
        state = self._as_state_tensor(state)
        policy, value = self._evaluate(state)

        child = MCTSNode(state, parent=parent, prior_action=action, device=self.device)
        parent.children[action] = child
        child.prior_probabilities = policy

        return child, value

    def backup(self, child, value):
        """Propagate ``value`` from ``child`` up to the root.

        ``value`` is from the perspective of the player to move at ``child``;
        it is negated once per level because the players alternate. Visit
        counts are not touched here (``select`` already incremented them).
        """
        node = child
        while node.parent is not None:
            parent = node.parent
            action = node.prior_action
            value = -value

            parent.total_action_values[action] += value
            # clamp guards against a zero count if the edge was never selected
            visits = parent.visit_counts[action].clamp(min=1)
            parent.mean_action_values[action] = parent.total_action_values[action] / visits

            node = parent

    def play(self, state, node, deterministic=False):
        """Choose a move at ``node`` from its visit counts and re-root the tree.

        The move distribution is ``N^(1/tau) / sum(N^(1/tau))``. If ``node`` is
        the current root, the chosen child becomes the new root (or the root is
        cleared when that child was never expanded, so the next ``train`` call
        rebuilds it).

        Args:
            state: unused; kept for interface compatibility.
            node: node whose statistics drive the choice.
            deterministic: pick the most visited move instead of sampling.

        Returns:
            The selected action as a Python int.
        """
        weights = node.visit_counts ** (1 / self.tau)
        total = weights.sum()
        if float(total) > 0:
            pis = weights / total
        else:
            # No simulations were run: fall back to uniform over valid moves.
            valid = node.valid_actions.to(weights.dtype)
            pis = valid / valid.sum()

        if deterministic:
            selected_action = int(torch.argmax(pis))
        else:
            selected_action = int(torch.multinomial(pis, 1))

        child = node.children[selected_action]
        if child is not None:
            child.parent = None  # detach so backup() stops at the new root
        if node is self.root:
            self.root = child
        return selected_action
        
        
        
        

In [40]:
class ResidualBlock(torch.nn.Module):
    """Two 3x3 conv/batch-norm/ReLU stages added to a 1x1 projected shortcut.

    Args:
        input_dims: number of input channels.
        n_filters: number of output channels of both branches.
    """

    def __init__(self, input_dims, n_filters):
        super().__init__()
        nn = torch.nn
        same_size_conv = dict(kernel_size=3, stride=1, padding=1)

        # Main branch: spatial size is preserved by the padding of 1.
        main_branch = [
            nn.Conv2d(input_dims, n_filters, **same_size_conv),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(),
            nn.Conv2d(n_filters, n_filters, **same_size_conv),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(),
        ]
        self.layers = nn.Sequential(*main_branch)

        # Shortcut branch: a 1x1 convolution matches the channel count.
        projection = [
            nn.Conv2d(input_dims, n_filters, kernel_size=1, stride=1),
            nn.BatchNorm2d(n_filters),
        ]
        self.shortcut = nn.Sequential(*projection)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(layers(x) + shortcut(x))``."""
        transformed = self.layers(x)
        skipped = self.shortcut(x)
        return self.relu(transformed + skipped)


class DNN(torch.nn.Module):
    """Shared convolutional trunk with a policy head and a value head.

    The stem convolution uses no padding, so the trunk's feature maps are
    ``(board_size - 2) x (board_size - 2)``; both heads' linear layers are
    sized accordingly.

    Args:
        board_size: edge length of the (square) board.
        input_dims: shape of one input state; only ``input_dims[0]`` (the
            channel count) is read.
    """

    def __init__(self, board_size, input_dims):
        super().__init__()
        n_filters = 64

        self.main_path = torch.nn.Sequential(
            # The stem must emit n_filters channels: the batch norm and the
            # residual blocks that follow are all built for n_filters.
            torch.nn.Conv2d(input_dims[0], n_filters, kernel_size=3, stride=1),
            torch.nn.BatchNorm2d(n_filters),
            torch.nn.ReLU(),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
        )

        # Policy head: probabilities over every board point plus one extra
        # action (board_size ** 2 + 1 outputs).
        self.policy = torch.nn.Sequential(
            torch.nn.Conv2d(n_filters, 2, kernel_size=1, stride=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear((board_size-2)**2 * 2, board_size ** 2 + 1),
            torch.nn.Softmax(dim=1),
        )

        # Value head: a single scalar squashed to [-1, 1] by tanh.
        self.value = torch.nn.Sequential(
            torch.nn.Conv2d(n_filters, 1, kernel_size=1, stride=1),
            torch.nn.BatchNorm2d(1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear((board_size-2)**2, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
            torch.nn.Tanh(),
        )

    def forward(self, state):
        """Return ``(policy, value)`` for a batch of states.

        Args:
            state: float tensor of shape ``(batch, input_dims[0], board_size, board_size)``.

        Returns:
            ``policy`` of shape ``(batch, board_size ** 2 + 1)`` (rows sum to 1)
            and ``value`` of shape ``(batch, 1)``.
        """
        x = self.main_path(state)
        policy = self.policy(x)
        value = self.value(x)
        return policy, value
        

In [49]:
# One replay-buffer record. The training loop in this notebook stores the
# action that was played in the `policy` slot.
Experience = namedtuple('Experience', ['state', 'policy', 'outcome'])

In [41]:
class AlphaGoZero:
    """Self-play training loop: MCTS picks moves, the network learns from results.

    Args:
        env: gym-style environment exposing ``reset``, ``step`` and
            ``observation_space.shape``.
        board_size: edge length of the board.
        device: torch device for the network and all tensors.
    """

    def __init__(self, env, board_size=19, device='cpu'):
        self.env = env
        self.device = device
        self.board_size = board_size

        self.sub_iterations = 100   # MCTS simulations per real move
        self.batch_size = 16
        self.min_buffer_size = 16

        input_dims = env.observation_space.shape

        self.dnn = DNN(board_size, input_dims).to(device)
        self.mcts = MCTS(self.dnn, board_size, device)
        self.trojectory_buffer = deque(maxlen=10000)

        self.optimizer = torch.optim.Adam(self.dnn.parameters(), lr=0.0001)

    def _to_tensor(self, state):
        """Convert an environment state to a float32 tensor on ``self.device``."""
        return torch.as_tensor(state, dtype=torch.float32, device=self.device)

    def train(self, iterations):
        """Play ``iterations`` self-play games, training the network after each.

        Every position of a finished game is stored with the final outcome
        expressed from the perspective of the player to move in that position.

        NOTE(review): assumes the final ``reward`` is from black's perspective
        and that ``state[2]`` is 0 when black is to move - confirm against the env.
        """
        for _ in range(iterations):
            state = self._to_tensor(self.env.reset())
            # A new game needs a new search tree; statistics from the previous
            # game's positions do not apply to the fresh starting position.
            self.mcts = MCTS(self.dnn, self.board_size, self.device)

            done = False
            reward = 0.0
            history = []  # (state, action, black_to_move) per move played

            while not done:
                black_to_move = float(state[2][0][0]) == 0

                for _ in range(self.sub_iterations):
                    # Each simulation steps its own throw-away copy of the env.
                    env2 = copy.deepcopy(self.env)
                    try:
                        self.mcts.train(env2, state)
                    finally:
                        env2.close()

                action = int(self.mcts.play(state, self.mcts.root, deterministic=False))
                nstate, reward, done, _ = self.env.step(action)

                history.append((state, action, black_to_move))
                state = self._to_tensor(nstate)

            outcome = float(reward)

            for past_state, past_action, black_moved in history:
                signed_outcome = outcome if black_moved else -outcome
                self.trojectory_buffer.append(Experience(past_state, past_action, signed_outcome))

            self.train_dnn()

    def train_dnn(self):
        """Run one optimisation step on a random batch from the replay buffer.

        Returns:
            The total loss as a float, or ``None`` when the buffer holds fewer
            than ``min_buffer_size`` records and no step was taken.
        """
        if len(self.trojectory_buffer) < self.min_buffer_size:
            return None

        # Never ask for more samples than the buffer holds.
        sample_size = min(self.batch_size, len(self.trojectory_buffer))
        batch = random.sample(list(self.trojectory_buffer), sample_size)
        states, actions, outcomes = zip(*batch)

        states = torch.stack([self._to_tensor(s) for s in states]).detach()
        actions = torch.as_tensor([int(a) for a in actions], dtype=torch.long, device=self.device)
        outcomes = torch.as_tensor([float(o) for o in outcomes], dtype=torch.float32, device=self.device)

        # Forward pass
        predicted_policies, predicted_values = self.dnn(states)

        value_loss = torch.nn.functional.mse_loss(predicted_values.squeeze(-1), outcomes)
        # The network already outputs probabilities (softmax), so take their log
        # and use NLL; cross_entropy would apply a second softmax.
        log_policies = torch.log(predicted_policies.clamp_min(1e-8))
        policy_loss = torch.nn.functional.nll_loss(log_policies, actions)
        total_loss = value_loss + policy_loss

        # Optimize the model
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return float(total_loss.item())

        

In [42]:
# Board edge length shared by the environment and the agent.
SIZE = 7

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

env = gym.make('gym_go:go-v0', size=SIZE, komi=0, reward_method='real')

alphago_zero = AlphaGoZero(env, board_size=SIZE, device=device)
alphago_zero.train(100)

torch.Size([3, 6, 7, 7])
torch.Size([3, 50]) torch.Size([3, 1])
