In [95]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import random
from einops.layers.torch import Rearrange
from einops import rearrange

from typing import Any, Dict, Tuple, Optional
from game_mechanics import (
    State,
    all_legal_moves,
    choose_move_randomly,
    human_player,
    is_terminal,
    load_pkl,
    play_go,
    reward_function,
    save_pkl,
    transition_function,
)
from tqdm.notebook import tqdm

from functools import partial
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from network import *
from configs.v0_PPO_MCTS import *
from MCTS import *

# from utils import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Go environment used below to generate a few example positions.
# NOTE(review): the argument is presumably the opponent's move policy (a random
# move chooser, judging by its name) — confirm against GoEnv's signature.
env = GoEnv(choose_move_randomly)

In [97]:
# Build the network from the architecture settings dict.
# NOTE(review): `architecture_settings` arrives via one of the star imports
# above (presumably configs.v0_PPO_MCTS) — confirm its source.
net = AlphaGoZeroBatch(
    n_residual_blocks=architecture_settings["n_residual_blocks"],
    block_width=architecture_settings["block_width"],
)

In [66]:
def tensorize(state):
    """Convert a game state's board into a float32 torch tensor.

    Args:
        state: any object exposing a ``board`` attribute (array-like).

    Returns:
        ``torch.Tensor`` with dtype ``float32`` holding the board values.
        ``torch.as_tensor`` is used, so no copy is made when ``state.board`` is
        already float32 CPU data; otherwise the values are converted.
    """
    board = state.board
    return torch.as_tensor(board, dtype=torch.float32)

# Play a few fixed opening moves (0, 1, 2, ...). After each step, record the
# resulting board as a float tensor together with that position's legal moves.
# These form a small batch for a forward-pass sanity check of `net`.
# NOTE(review): the env's opponent moves randomly, so a later fixed move could
# already be occupied — consider picking from the recorded legal moves instead.
N_OPENING_MOVES = 5
boards = []
legal_moves = []
state, reward, done, info = env.reset()
for move in range(N_OPENING_MOVES):
    state, reward, done, info = env.step(move)
    boards.append(tensorize(state))
    legal_moves.append(all_legal_moves(state.board, state.ko))
    # Don't keep stepping once the environment reports the episode is over.
    if done:
        break

In [69]:
# Stack the per-position boards into one batch tensor.
# `list(...)` makes this cell safe to re-run. On a second run `boards` is
# already a Tensor, which `torch.stack` rejects. Iterating a Tensor yields its
# first-axis slices, so the result is unchanged.
boards = torch.stack(list(boards))

In [71]:
# (number of positions, board rows, board columns)
boards.size()

torch.Size([5, 9, 9])

In [85]:
# Forward pass for inspection only.
# - Autograd is disabled.
# - The module runs in eval mode, so any BatchNorm/Dropout layers (not visible
#   here) use inference behaviour and don't update running statistics.
# - The previous train/eval mode is restored afterwards, so later cells see the
#   module unchanged.
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        net_output = net(boards, legal_moves)
finally:
    net.train(was_training)
net_output

(tensor([[0.0000e+00, 1.7069e-02, 4.2378e-03, 2.8627e-03, 1.2309e-03, 1.1055e-03,
          4.2781e-04, 5.9385e-02, 1.3311e-03, 4.3965e-03, 9.1464e-03, 1.4114e-02,
          1.3187e-02, 1.2248e-03, 1.1151e-02, 2.4886e-02, 3.0666e-03, 1.2912e-03,
          4.5420e-03, 2.1956e-03, 4.1933e-03, 5.6979e-03, 1.9598e-03, 3.8316e-03,
          2.1236e-02, 5.1215e-03, 8.1335e-03, 3.6050e-03, 2.0207e-02, 1.3014e-03,
          6.0377e-03, 8.7932e-04, 4.5097e-02, 4.6495e-02, 3.6479e-03, 4.9087e-04,
          8.2803e-03, 3.1976e-03, 1.6818e-02, 3.6947e-03, 9.2572e-03, 9.1271e-04,
          1.5196e-02, 1.7517e-03, 1.3186e-02, 4.6502e-03, 1.0345e-03, 3.2948e-03,
          1.7118e-02, 6.5746e-03, 3.0423e-02, 7.7311e-02, 5.7950e-02, 2.5900e-03,
          2.1337e-02, 0.0000e+00, 2.9044e-03, 1.3965e-01, 8.0718e-03, 8.9049e-03,
          4.4472e-04, 4.5735e-03, 4.4404e-02, 2.9204e-03, 4.3449e-03, 4.3044e-03,
          6.9560e-03, 1.5270e-02, 1.9981e-03, 3.5350e-03, 2.7919e-02, 4.3354e-03,
          1.1827