In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import random
from einops.layers.torch import Rearrange
from einops import rearrange

from typing import Any, Dict, Tuple, Optional
from game_mechanics import (
    State,
    all_legal_moves,
    choose_move_randomly,
    human_player,
    is_terminal,
    load_pkl,
    play_go,
    reward_function,
    save_pkl,
    transition_function,
)
from tqdm.notebook import tqdm

from functools import partial
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from architecture import *
from configs.v0_PPO_MCTS import *

# from utils import *
%load_ext autoreload
%autoreload 2

In [10]:
class UCTNode():
    """Node of a UCT search tree whose statistics are stored on its parent.

    Every node owns three length-81 float32 arrays indexed by move:
    `child_priors`, `child_total_value` and `child_number_visits`. A node's
    own visit count and total value are read from, and written to, the
    parent's arrays at index `self.move`. A node without a parent (the root)
    keeps its own scalar statistics instead.

    NOTE(review): `all_legal_moves` output elsewhere in this notebook contains
    move 81, which does not fit in arrays of length 81 — confirm how that move
    should be handled before indexing these arrays with it.

    Args:
        game_state: state object associated with this node (stored as-is).
        move: index of the move that led to this node; `None` for a root.
        parent: parent `UCTNode`, or `None` for a root.
    """

    def __init__(self, game_state,
                 move=None, parent=None):
        self.game_state = game_state
        self.move = move
        self.is_expanded = False
        self.parent = parent  # Optional[UCTNode]
        self.children = {}  # Dict[move, UCTNode]
        self.child_priors = np.zeros(
            [81], dtype=np.float32)
        self.child_total_value = np.zeros(
            [81], dtype=np.float32)
        self.child_number_visits = np.zeros(
            [81], dtype=np.float32)
        # Statistics used only when this node has no parent to hold them.
        self._root_number_visits = 0.0
        self._root_total_value = 0.0

    @property
    def number_visits(self):
        """Visit count of this node (parent's slot, or own counter at root)."""
        if self.parent is None:
            return self._root_number_visits
        return self.parent.child_number_visits[self.move]

    @number_visits.setter
    def number_visits(self, value):
        if self.parent is None:
            self._root_number_visits = value
        else:
            # Write only this node's slot; assigning to the attribute itself
            # would replace the parent's whole array with a scalar.
            self.parent.child_number_visits[self.move] = value

    @property
    def total_value(self):
        """Accumulated value of this node (parent's slot, or own at root)."""
        if self.parent is None:
            return self._root_total_value
        return self.parent.child_total_value[self.move]

    @total_value.setter
    def total_value(self, value):
        if self.parent is None:
            self._root_total_value = value
        else:
            # Same reasoning as number_visits: update one slot, keep the array.
            self.parent.child_total_value[self.move] = value

    def child_Q(self):
        """Return the per-move mean value: total value / (1 + visits)."""
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        """Return the per-move exploration term.

        Computed as sqrt(own visits) * prior / (1 + child visits).
        """
        # Convert to a Python float so the float32 arrays keep their dtype.
        sqrt_visits = float(np.sqrt(self.number_visits))
        return sqrt_visits * (
            self.child_priors / (1 + self.child_number_visits))

    def best_child(self):
        """Return the move index that maximises child_U() + child_Q()."""
        return np.argmax(self.child_U() + self.child_Q())

In [12]:
# Build the Go environment, handing it `choose_move_randomly` from game_mechanics
# (presumably as the opponent's move policy — TODO confirm).
# NOTE(review): GoEnv is not imported by name; it presumably arrives through one
# of the star imports in the first cell — confirm.
env = GoEnv(choose_move_randomly)

In [13]:
# Start a fresh game; reset() is unpacked as a (state, reward, done, info) 4-tuple.
state, reward, done, info = env.reset() 

In [15]:
# Display the board array of the current state.
state.board

array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int8)

In [17]:
# Dump the liberty tracker's instance attributes for inspection.
state.lib_tracker.__dict__

{'group_index': array([[-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1]]),
 'groups': {},
 'liberty_cache': array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint8),
 'max_group_id': 0}

In [21]:
# Display the repr of the whole State object.
state

State(board=array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int8), recent_moves=(), to_play=1, ko=None, board_deltas=[], lib_tracker=<game_mechanics.liberty_tracker.LibertyTracker object at 0x00000149AEDDDCA0>)

In [43]:
# Legal moves for the current position, given the board and the ko point.
legal_moves = all_legal_moves(state.board, state.ko)
# Drop move 81. Board points are presumably indexed 0..80 on the 9x9 board,
# which would make 81 the pass move — TODO confirm against game_mechanics.
legal_moves = legal_moves[legal_moves != 81]

In [29]:
# Play move index 4 and capture the resulting (state, reward, done, info).
# NOTE(review): the board displayed after this cell already holds several
# stones, so the shown result depends on earlier steps that are not part of
# this notebook's visible cells — re-run from a fresh kernel to verify.
state, reward, done, info = env.step(4)

In [30]:
# Display the board array after the step above.
state.board

array([[ 0,  1,  0,  1,  1,  1,  0,  0,  0],
       [-1,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0, -1, -1,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0, -1,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  0,  0]], dtype=int8)

In [31]:
# Dump the liberty tracker's instance attributes after the step above.
state.lib_tracker.__dict__

{'group_index': array([[-1,  1, -1,  7,  7,  7, -1, -1, -1],
        [ 2, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1,  8,  8, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1,  6, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1]]),
 'groups': {1: Group(id=1, stones=frozenset({(0, 1)}), liberties=frozenset({(0, 0), (1, 1), (0, 2)}), color=1),
  2: Group(id=2, stones=frozenset({(1, 0)}), liberties=frozenset({(1, 1), (2, 0), (0, 0)}), color=-1),
  6: Group(id=6, stones=frozenset({(7, 3)}), liberties=frozenset({(7, 4), (8, 3), (6, 3), (7, 2)}), color=-1),
  7: Group(id=7, stones=frozenset({(0, 3), (0, 4), (0, 5)}), liberties=frozenset({(0, 2), (1, 5), (1, 3), (1, 4), (0, 6)}), color=1),
  8: Group(id=8, stones=frozenset({(5, 3), (5, 4)}), liberties=frozenset({(4, 4), (5, 5), (4, 3), (6, 3), (6,

In [None]:
# Search and training hyperparameters for this experiment.
n_states_sample = 800 # MCTS evaluations per move
gamma = 1.0  # presumably the reward discount factor — TODO confirm
n_residual_blocks = 4  # matches the n_residual_blocks argument given to AlphaGoZeroBatch in a later cell
block_width = 1000  # matches the block_width argument given to AlphaGoZeroBatch in a later cell
update_opponent_wr = 0.55  # presumably the win rate at which the opponent is updated — TODO confirm
num_steps = 1_000_000 # steps to update the network
batch_size = 2000  # NOTE(review): presumably samples per network update — confirm


In [37]:
# Scratch lookup: IPython's trailing "?" shows the documentation of nn.BatchNorm2d.
nn.BatchNorm2d?

In [84]:
# Build the network from the hyperparameters declared in the config cell
# rather than repeating the literals 4 and 1000, so the two cannot drift apart.
net = AlphaGoZeroBatch(n_residual_blocks=n_residual_blocks, block_width=block_width)

In [62]:
# Display the legal moves for the current position (this time without filtering out 81).
all_legal_moves(state.board, state.ko)

array([ 0,  2,  6,  7,  8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
       22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
       39, 40, 41, 42, 43, 44, 45, 46, 47, 50, 51, 52, 53, 54, 55, 56, 57,
       58, 59, 60, 61, 62, 63, 64, 65, 67, 68, 69, 70, 71, 72, 73, 74, 75,
       76, 77, 78, 79, 80, 81])

In [86]:
def UCT_search(game_state, num_reads):
    """Run `num_reads` search iterations from `game_state` and pick a move.

    Each iteration selects a leaf with `root.select_leaf()`, evaluates the
    leaf's game state with `NeuralNet.evaluate`, expands the leaf with the
    returned priors and backs the value estimate up the tree.

    NOTE(review): `select_leaf`, `expand` and `backup` are not among the
    UCTNode methods defined in this notebook, and `NeuralNet` is not imported
    by name — presumably still to be written or star-imported; confirm.

    Args:
        game_state: state to search from; becomes the root node's state.
        num_reads: number of select/evaluate/expand/backup iterations.

    Returns:
        The `(move, child_node)` pair from `root.children` whose child has the
        largest `number_visits`.

    Raises:
        ValueError: raised by `max` if `root.children` is still empty after
            the loop (for example when `num_reads` is 0).
    """
    # The root is not reached through any move, so pass `move=None` explicitly;
    # omitting the argument raises TypeError when `move` is a required parameter.
    root = UCTNode(game_state, move=None)
    for _ in range(num_reads):
        leaf = root.select_leaf()
        child_priors, value_estimate = NeuralNet.evaluate(leaf.game_state)
        leaf.expand(child_priors)
        leaf.backup(value_estimate)
    return max(root.children.items(),
               key=lambda item: item[1].number_visits)

In [66]:
def tensorize(state):
    """Convert the board of `state` into a float32 torch tensor.

    Args:
        state: object exposing a `board` attribute accepted by
            `torch.as_tensor`.

    Returns:
        A `torch.Tensor` of dtype float32 built from `state.board`.
    """
    board = state.board
    return torch.as_tensor(board, dtype=torch.float32)

# Collect a small batch: after each of five steps, store the board tensor and
# the legal moves of the resulting position.
boards = []
legal_moves = []  # NOTE: rebinds the `legal_moves` name that another cell uses for a single array
state, reward, done, info = env.reset() 
for move in range(5):
    # Moves 0..4 are played in order, with no legality or `done` check here.
    state, reward, done, info = env.step(move)
    
    boards.append(tensorize(state))
    legal_moves.append(all_legal_moves(state.board, state.ko))

In [69]:
# Stack the per-step board tensors into one batch tensor.
# NOTE(review): this rebinds `boards` from a list to a tensor, so the cell is
# not idempotent — re-run the collection cell before running it again.
boards = torch.stack(boards)

In [71]:
# Check the shape of the stacked batch.
boards.shape

torch.Size([5, 9, 9])

In [85]:
# Run the network on the stacked boards together with the per-position
# legal-move arrays collected alongside them.
net(boards, legal_moves)

(tensor([[0.0000e+00, 1.7069e-02, 4.2378e-03, 2.8627e-03, 1.2309e-03, 1.1055e-03,
          4.2781e-04, 5.9385e-02, 1.3311e-03, 4.3965e-03, 9.1464e-03, 1.4114e-02,
          1.3187e-02, 1.2248e-03, 1.1151e-02, 2.4886e-02, 3.0666e-03, 1.2912e-03,
          4.5420e-03, 2.1956e-03, 4.1933e-03, 5.6979e-03, 1.9598e-03, 3.8316e-03,
          2.1236e-02, 5.1215e-03, 8.1335e-03, 3.6050e-03, 2.0207e-02, 1.3014e-03,
          6.0377e-03, 8.7932e-04, 4.5097e-02, 4.6495e-02, 3.6479e-03, 4.9087e-04,
          8.2803e-03, 3.1976e-03, 1.6818e-02, 3.6947e-03, 9.2572e-03, 9.1271e-04,
          1.5196e-02, 1.7517e-03, 1.3186e-02, 4.6502e-03, 1.0345e-03, 3.2948e-03,
          1.7118e-02, 6.5746e-03, 3.0423e-02, 7.7311e-02, 5.7950e-02, 2.5900e-03,
          2.1337e-02, 0.0000e+00, 2.9044e-03, 1.3965e-01, 8.0718e-03, 8.9049e-03,
          4.4472e-04, 4.5735e-03, 4.4404e-02, 2.9204e-03, 4.3449e-03, 4.3044e-03,
          6.9560e-03, 1.5270e-02, 1.9981e-03, 3.5350e-03, 2.7919e-02, 4.3354e-03,
          1.1827