In [35]:
import torch
import numpy as np
import gym
from collections import deque, namedtuple
import random
from matplotlib import pyplot as plt
import numpy as np

In [36]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, input_dims, n_filters):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(input_dims, n_filters, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(n_filters),
            torch.nn.ReLU(),
            torch.nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(n_filters),
            torch.nn.ReLU(),
        )
        self.shortcut = torch.nn.Sequential(
            torch.nn.Conv2d(input_dims, n_filters, kernel_size=1, stride=1),
            torch.nn.BatchNorm2d(n_filters),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.layers(x) + self.shortcut(x))


class DNN(torch.nn.Module):
    def __init__(self, board_size, input_dims):
        super().__init__()
        n_filters = 256
        
        self.main_path = torch.nn.Sequential(
            torch.nn.Conv2d(input_dims[0], 256, kernel_size=3, stride=1),
            torch.nn.BatchNorm2d(n_filters),
            torch.nn.ReLU(),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
            ResidualBlock(n_filters, n_filters),
        )
        
        
        self.policy = torch.nn.Sequential(
            torch.nn.Conv2d(n_filters, 2, kernel_size=1, stride=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear((board_size-2)**2 * 2, board_size ** 2 + 1),
            torch.nn.Softmax(dim=1),
        )
        
        self.value = torch.nn.Sequential(
            torch.nn.Conv2d(n_filters, 1, kernel_size=1, stride=1),
            torch.nn.BatchNorm2d(1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear((board_size-2)**2, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
            torch.nn.Tanh(),
        )
        
    def forward(self, state):
        x = self.main_path(state)
        policy = self.policy(x)
        value = self.value(x)
        return policy, value
        

In [37]:
class AlphaGoZero:
    def __init__(self, env, board_size=19, device='cpu'):
        self.env = env
        self.device = device
        
        input_dims = env.observation_space.shape
        
        
        
        self.dnn = DNN(board_size, input_dims).to(device)
        
        
        pass
    
    def train(self, iterations):
        state = self.env.reset()
        
        state = torch.from_numpy(state).float().to(self.device)
        state = torch.stack([state] * 3, dim=0)
        print(state.shape)
        policy, value = self.dnn(state)
        print(policy.shape, value.shape)
        
        pass

In [38]:
SIZE = 7

device = 'cuda' if torch.cuda.is_available() else 'cpu'

env = gym.make('gym_go:go-v0', size=7, komi=0, reward_method='real')
alphago_zero = AlphaGoZero(env, board_size=7, device=device)
alphago_zero.train(100)

torch.Size([3, 6, 19, 19])
torch.Size([3, 362]) torch.Size([3, 1])
