In [68]:
import torch
import numpy as np
import gym
from collections import deque, namedtuple
import random
from matplotlib import pyplot as plt
import numpy as np
import copy

In [69]:
class ResidualBlock(torch.nn.Module):
    """Two 3x3 conv/BatchNorm/ReLU stages added to a 1x1-projected shortcut.

    Args:
        input_dims: number of channels of the incoming feature map.
        n_filters: number of channels produced by the block.

    The 3x3 convolutions use stride 1 and padding 1, and the shortcut is a
    stride-1 1x1 convolution, so the spatial size of the input is preserved.
    """

    def __init__(self, input_dims, n_filters):
        super().__init__()
        nn = torch.nn

        # Residual branch: (conv3x3 -> BN -> ReLU) applied twice.
        stages = []
        for in_channels in (input_dims, n_filters):
            stages.extend([
                nn.Conv2d(in_channels, n_filters, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(n_filters),
                nn.ReLU(),
            ])
        self.layers = nn.Sequential(*stages)

        # Shortcut branch: 1x1 projection so channel counts always match.
        projection = nn.Conv2d(input_dims, n_filters, kernel_size=1, stride=1)
        self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(n_filters))

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(residual_branch(x) + shortcut(x))."""
        transformed = self.layers(x)
        skipped = self.shortcut(x)
        return self.relu(transformed + skipped)


class DNN(torch.nn.Module):
    """Shared convolutional trunk with a policy head and a value head.

    Args:
        board_size: side length of the (square) board.
        input_dims: shape of one observation; only ``input_dims[0]`` (the
            channel count) is used.

    The first convolution is unpadded, so the trunk output is
    ``(board_size - 2) x (board_size - 2)``; both heads size their linear
    layers accordingly. ``forward`` returns ``(policy, value)`` where
    ``policy`` is a softmax over ``board_size ** 2 + 1`` outputs and
    ``value`` is a tanh-squashed scalar per sample.
    """

    def __init__(self, board_size, input_dims):
        super().__init__()
        nn = torch.nn
        width = 32
        in_channels = input_dims[0]
        trunk_cells = (board_size - 2) ** 2   # spatial cells left after the unpadded 3x3 conv
        n_outputs = board_size ** 2 + 1

        trunk = [
            nn.Conv2d(in_channels, width, kernel_size=3, stride=1),
            nn.BatchNorm2d(width),
            nn.ReLU(),
        ]
        trunk += [ResidualBlock(width, width) for _ in range(2)]
        self.main_path = nn.Sequential(*trunk)

        policy_layers = [
            nn.Conv2d(width, 2, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(trunk_cells * 2, n_outputs),
            nn.Softmax(dim=1),
        ]
        self.policy = nn.Sequential(*policy_layers)

        value_layers = [
            nn.Conv2d(width, 1, kernel_size=1, stride=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(trunk_cells, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh(),
        ]
        self.value = nn.Sequential(*value_layers)

    def forward(self, state):
        """Run the trunk once and feed its features to both heads."""
        features = self.main_path(state)
        return self.policy(features), self.value(features)
        

In [70]:
# [black, white, turn, invalid, pass, game_over]
class MCTSNode:
    """Per-position statistics for one node of the search tree.

    Args:
        state: observation tensor indexed as ``state[channel, row, col]``;
            channel 3 is read as the "invalid move" mask.
        parent: node this one was reached from (``None`` for a root).
        prior_action: action index taken in ``parent`` to reach this node.

    Every per-action array has ``rows * cols + 1`` entries; the last entry is
    the pass move, which is always marked valid.
    """

    def __init__(self, state, parent=None, prior_action=-1):
        n_actions = state.shape[1] * state.shape[2] + 1  # every board cell + pass

        # 1 where a move is allowed, in the same flattened order as the board.
        board_mask = (1 - state[3]).flatten()
        pass_mask = torch.tensor([1.0])
        self.valid_actions = torch.cat((board_mask, pass_mask)).bool()

        self.visit_counts = torch.zeros(n_actions, dtype=torch.int32)
        self.total_action_values = torch.zeros(n_actions, dtype=torch.float32)
        self.mean_action_values = torch.zeros(n_actions, dtype=torch.float32)
        self.prior_probabilities = torch.zeros(n_actions, dtype=torch.float32)

        self.selected_count = 0
        self.children = [None] * n_actions

        self.parent = parent
        self.prior_action = prior_action

    def select(self, puct_multiplier=1.0):
        """Pick the valid action maximising mean value + exploration bonus.

        Returns:
            ``(action, child)`` where ``action`` is a 0-d index tensor and
            ``child`` is the stored child for it (``None`` if not expanded).
        """
        scale = puct_multiplier * np.sqrt(self.selected_count)
        exploration = scale * self.prior_probabilities / (1 + self.visit_counts)
        scores = self.mean_action_values + exploration

        # Push invalid moves strictly below every real score so argmax skips them.
        floor = torch.min(scores) - 1
        scores = torch.where(self.valid_actions, scores, floor)
        best = torch.argmax(scores)

        return best, self.children[best]
    

class MCTS:
    """Monte-Carlo tree search driven by a policy/value evaluator.

    Args:
        state: observation tensor of the root position, indexed as
            ``state[channel, row, col]``.
        evaluator: callable mapping a batch of states ``(B, C, H, W)`` to
            ``(policy, value)`` with ``policy`` of shape ``(B, H * W + 1)``
            (last column = pass) and one value per batch row.
        board_size: stored on the instance; the search itself takes the
            board shape from the states it is given.

    One simulation is ``train(env, state)``: descend with PUCT (``search``),
    evaluate and attach the reached position (``expand``), then propagate the
    value to the root (``backup``). ``play`` commits to a move and re-roots
    the tree at the corresponding child.
    """

    def __init__(self, state, evaluator, board_size=19):
        self.PUCT_CONST = 1.0
        self.tau = 1.0

        # The 8 dihedral transforms of the board (applied to the H, W axes).
        self.dihedrals_transforms = [
            lambda x: x,
            lambda x: torch.rot90(x, k=1, dims=(1, 2)),
            lambda x: torch.rot90(x, k=2, dims=(1, 2)),
            lambda x: torch.rot90(x, k=3, dims=(1, 2)),
            lambda x: torch.flip(x, dims=(1,)),
            lambda x: torch.flip(x, dims=(2,)),
            lambda x: torch.flip(torch.rot90(x, k=1, dims=(1, 2)), dims=(1,)),
            lambda x: torch.flip(torch.rot90(x, k=1, dims=(1, 2)), dims=(2,)),
        ]
        # Inverse of each transform above, in the same order. They are applied
        # to the predicted policy grid so that every prediction is expressed
        # in the ORIGINAL board orientation before averaging.
        self._inverse_transforms = [
            lambda x: x,
            lambda x: torch.rot90(x, k=-1, dims=(1, 2)),
            lambda x: torch.rot90(x, k=-2, dims=(1, 2)),
            lambda x: torch.rot90(x, k=-3, dims=(1, 2)),
            lambda x: torch.flip(x, dims=(1,)),
            lambda x: torch.flip(x, dims=(2,)),
            lambda x: torch.rot90(torch.flip(x, dims=(1,)), k=-1, dims=(1, 2)),
            lambda x: torch.rot90(torch.flip(x, dims=(2,)), k=-1, dims=(1, 2)),
        ]

        self.evaluator = evaluator
        self.board_size = board_size

        self.root = MCTSNode(state)
        # Give the root real priors; with all-zero priors the exploration term
        # in MCTSNode.select vanishes and the root never explores.
        root_policy, _ = self._evaluate(state)
        self.root.prior_probabilities = root_policy

    def _evaluate(self, state):
        """Evaluate ``state`` under all 8 symmetries and average the results.

        Returns:
            ``(policy, value)``: ``policy`` is a 1-D tensor over
            ``H * W + 1`` actions in the original orientation, ``value`` is a
            0-d tensor. Neither carries autograd history.
        """
        batch = torch.stack([transform(state) for transform in self.dihedrals_transforms])

        # Inference only: use running BatchNorm statistics and build no graph.
        # The previous training flag is restored so an outer training loop is
        # unaffected.
        module = self.evaluator if isinstance(self.evaluator, torch.nn.Module) else None
        was_training = module is not None and module.training
        if was_training:
            module.eval()
        try:
            with torch.no_grad():
                policies, values = self.evaluator(batch)
        finally:
            if was_training:
                module.train()

        rows, cols = state.shape[1], state.shape[2]
        restored = []
        for inverse, policy in zip(self._inverse_transforms, policies):
            # Board moves are flattened row-major (same order as the invalid
            # mask in MCTSNode); the trailing entry is the pass move and is
            # unaffected by board symmetries.
            grid = inverse(policy[:-1].reshape(1, rows, cols))
            restored.append(torch.cat((grid.reshape(-1), policy[-1:])))

        return torch.stack(restored).mean(dim=0), values.mean()

    def train(self, env, state):
        """Run one simulation from the root.

        ``env`` must be positioned at the root's state; it is stepped in
        place during the descent. ``state`` is accepted for interface
        compatibility and is not read.
        """
        state, parent, child, action = self.search(env, self.root)
        child, value = self.expand(state, parent, action)
        self.backup(child, value)

    def search(self, env, root):
        """Descend with PUCT until an unexpanded edge or the end of the game.

        Returns:
            ``(state, node, child, action)``: the state tensor after the last
            step, the last node a move was selected from, the child stored
            for that move (``None`` if unexpanded) and the selected action.
        """
        node = root
        cnode = root
        state = None
        done = False
        while cnode and not done:
            node = cnode
            action, cnode = node.select(self.PUCT_CONST)
            state, _, done, _ = env.step(action.item())
            state = torch.tensor(state, dtype=torch.float32)

        return state, node, cnode, action

    def expand(self, state, parent, action):
        """Create the child reached from ``parent`` via ``action``.

        Returns:
            ``(child, value)`` where ``child`` is the new node (already
            linked into ``parent.children``) carrying the symmetry-averaged
            priors, and ``value`` is the symmetry-averaged evaluation.
        """
        policy, value = self._evaluate(state)

        child = MCTSNode(state, parent=parent, prior_action=action)
        parent.children[action] = child
        child.prior_probabilities = policy

        return child, value

    def backup(self, child, value, child_was_none=False):
        """Propagate ``value`` from ``child`` up to the root.

        The value is credited unchanged to the edge leading into ``child``
        and its sign is flipped at every further step up, because the player
        to move alternates between tree levels.

        NOTE(review): this assumes the evaluator's value is expressed for the
        player who just moved INTO the evaluated state, which is what the
        ``reward * side`` targets in AlphaGoZero.train appear to encode --
        confirm against the environment's reward convention.

        ``child_was_none=True`` starts the walk at ``child`` itself instead of
        its parent (kept from the original interface; unused by ``train``).
        """
        parent = child
        if not child_was_none:
            parent = child.parent

        signed_value = value
        while parent is not None:
            action = child.prior_action
            parent.selected_count += 1
            parent.visit_counts[action] += 1
            parent.total_action_values[action] += signed_value
            parent.mean_action_values[action] = (
                parent.total_action_values[action] / parent.visit_counts[action]
            )

            child, parent = parent, parent.parent
            signed_value = -signed_value

    def play(self, state, deterministic=False):
        """Choose a move at the root from its visit counts and re-root the tree.

        Args:
            state: accepted for interface compatibility; not read.
            deterministic: pick the most visited move instead of sampling
                proportionally to ``visit_count ** (1 / tau)``.

        Returns:
            The selected action as a Python ``int``.

        Raises:
            RuntimeError: if the root has no visits yet or the chosen move
                has no expanded child (no simulation went through it).
        """
        node = self.root
        counts = node.visit_counts.to(torch.float32)
        if counts.sum() <= 0:
            raise RuntimeError("MCTS.play() called before any simulation visited the root")

        # Normalise by the sum of the tempered counts so this is a proper
        # distribution for any tau, not only tau == 1.
        weights = counts ** (1 / self.tau)
        pis = weights / weights.sum()

        if deterministic:
            selected_action = torch.argmax(pis).item()
        else:
            selected_action = torch.multinomial(pis, 1).item()

        child = node.children[selected_action]
        if child is None:
            raise RuntimeError(f"selected action {selected_action} has no expanded child")
        child.parent = None
        self.root = child
        return selected_action
        
        
        
        

In [71]:
# One replay-buffer record. As filled in by AlphaGoZero.train:
#   state   - the state tensor the move was chosen from
#   policy  - index of the action that was played (an int, not a probability vector)
#   outcome - environment reward multiplied by the mover's side flag
Experience = namedtuple('Experience', 'state policy outcome')

In [72]:
class AlphaGoZero:
    """Self-play training loop coupling an MCTS with a policy/value network.

    Args:
        env: the Go environment used for the real games. It must expose
            ``reset``, ``step``, ``render`` and ``observation_space``.
        board_size: side length of the board; must match ``env``.

    Attributes of interest:
        sub_iterations: MCTS simulations run before every real move.
        batch_size / min_buffer_size: replay sampling parameters.
        max_moves: move cap per game (inclusive bound used by ``train``).
        trojectory_buffer: replay buffer of ``Experience`` records.
    """

    def __init__(self, env, board_size=19):
        self.env = env
        self.board_size = board_size

        self.max_moves = board_size ** 2 + 1

        self.sub_iterations = 25
        self.batch_size = 16
        self.min_buffer_size = 16

        input_dims = env.observation_space.shape

        base_state = env.reset()
        base_state = torch.tensor(base_state, dtype=torch.float32)
        self.dnn = DNN(board_size, input_dims)
        self.mcts = MCTS(base_state, self.dnn, board_size)
        self.trojectory_buffer = deque(maxlen=10000)

        self.optimizer = torch.optim.Adam(self.dnn.parameters(), lr=0.0001)

    def train(self, iterations, render=False, eval=False):
        """Play ``iterations`` games, searching before every move.

        Args:
            iterations: number of games to play.
            render: print a header and render the board after every move.
            eval: play the most-visited move instead of sampling, and skip
                the network update at the end of each game.
        """
        for i in range(iterations):

            state = self.env.reset()
            state = torch.tensor(state, dtype=torch.float32)
            # Start every game with a fresh tree rooted at the reset position;
            # otherwise the root left over from the previous game's final
            # position would be searched against a freshly reset board.
            self.mcts = MCTS(state, self.dnn, self.board_size)
            done = False

            side = -1
            step_n = 0
            print("Iteration: ", i)

            action_so_far = []

            while not done and step_n <= self.max_moves:
                if render:
                    print("Iteration: ", i, "Step: ", step_n, " Side: ", "Black" if side == -1 else "White", "action: ")

                # Scratch environment for the simulations; it has to use the
                # same board size as the real game.
                # NOTE(review): komi / reward_method are assumed to match the
                # settings self.env was created with -- confirm.
                env2 = gym.make('gym_go:go-v0', size=self.board_size, komi=0, reward_method='heuristic')
                try:
                    for _ in range(self.sub_iterations):
                        # Rewind and replay the real game so the scratch env
                        # sits at the tree root before each simulation.
                        env2.reset()
                        for a in action_so_far:
                            env2.step(a)
                        try:
                            self.mcts.train(env2, state)
                        except Exception:
                            # Show the position that broke the search, then
                            # let the error propagate.
                            env2.render("terminal")
                            raise
                finally:
                    env2.close()

                # Normalise to a plain int so the action can be replayed,
                # stepped and stored regardless of what play() hands back.
                action = int(self.mcts.play(state, deterministic=eval))
                action_so_far.append(action)
                nstate, reward, done, _ = self.env.step(action)
                if render:
                    self.env.render("terminal")

                self.trojectory_buffer.append(Experience(state, action, reward * side))
                state = torch.tensor(nstate, dtype=torch.float32)

                side *= -1
                step_n += 1

            if not eval:
                self.train_dnn()

    def train_dnn(self):
        """Do one optimisation step on a random replay batch.

        Does nothing until the buffer holds at least
        ``max(min_buffer_size, batch_size)`` records.
        """
        required = max(self.min_buffer_size, self.batch_size)
        if len(self.trojectory_buffer) < required:
            return

        batches = random.sample(self.trojectory_buffer, self.batch_size)
        states, policies, outcomes = zip(*batches)
        states = torch.stack(states).detach()
        policies = torch.tensor([int(a) for a in policies], dtype=torch.long)
        outcomes = torch.tensor(outcomes, dtype=torch.float32)

        # Forward pass
        predicted_policies, predicted_values = self.dnn(states)

        # NOTE(review): the value head ends in tanh, so targets outside
        # [-1, 1] cannot be matched exactly -- confirm the reward scale.
        value_loss = torch.nn.functional.mse_loss(predicted_values.squeeze(-1), outcomes)

        # The network already outputs probabilities (its policy head ends in
        # Softmax). cross_entropy would apply log-softmax a second time, so
        # take the log explicitly and use the negative log-likelihood.
        log_probs = torch.log(predicted_policies.clamp_min(1e-8))
        policy_loss = torch.nn.functional.nll_loss(log_probs, policies)
        total_loss = value_loss + policy_loss

        # Optimize the model
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        

In [73]:
SIZE = 5
SEED = 0

# Seed the RNGs this notebook draws from (replay sampling, numpy, torch weight
# init and move sampling) so a Restart & Run All starts from the same state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the environment and the agent on the same board size, then run
# ten self-play games (one network update after each game).
env = gym.make('gym_go:go-v0', size=SIZE, komi=0, reward_method='heuristic')
alphago_zero = AlphaGoZero(env, board_size=SIZE)
alphago_zero.train(10)

Iteration:  0
Iteration:  1
Iteration:  2
Iteration:  3
Iteration:  4
Iteration:  5
Iteration:  6
Iteration:  7
Iteration:  8
Iteration:  9


In [76]:
# Raise the per-game move cap (the constructor default is board_size ** 2 + 1).
alphago_zero.max_moves = 100

# Play one rendered game; eval=True selects moves deterministically and
# skips the network update at the end of the game.
alphago_zero.train(1, render=True, eval=True)

Iteration:  0
Iteration:  0 Step:  0  Side:  Black action: 


  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


	0 1 2 3 4 
0	╔═╤═╤═╤═╗
1	╟─┼─┼─┼─╢
2	╟─┼─┼─┼─╢
3	╟─┼─┼─┼─╢
4	○═╧═╧═╧═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 25, White Area: 0

Iteration:  0 Step:  1  Side:  White action: 
	0 1 2 3 4 
0	╔═╤═╤═╤═╗
1	╟─┼─┼─┼─╢
2	╟─┼─┼─┼─╢
3	╟─┼─┼─┼─╢
4	○═╧═╧═●═╝
	Turn: BLACK, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 1, White Area: 1

Iteration:  0 Step:  2  Side:  Black action: 
	0 1 2 3 4 
0	○═╤═╤═╤═╗
1	╟─┼─┼─┼─╢
2	╟─┼─┼─┼─╢
3	╟─┼─┼─┼─╢
4	○═╧═╧═●═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 2, White Area: 1

Iteration:  0 Step:  3  Side:  White action: 
	0 1 2 3 4 
0	○═╤═╤═╤═●
1	╟─┼─┼─┼─╢
2	╟─┼─┼─┼─╢
3	╟─┼─┼─┼─╢
4	○═╧═╧═●═╝
	Turn: BLACK, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 2, White Area: 2

Iteration:  0 Step:  4  Side:  Black action: 
	0 1 2 3 4 
0	○═╤═╤═╤═●
1	○─┼─┼─┼─╢
2	╟─┼─┼─┼─╢
3	╟─┼─┼─┼─╢
4	○═╧═╧═●═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 3, White Area: 2

Iteration:  0 Step:  5  S