In [11]:
import torch
from torch import nn
from utils import dotdict
import torch.nn.functional as F

def conv3x3(inchannels, channels, ks=3, s=1, p=1):
    return nn.Conv2d(inchannels, channels, kernel_size=ks, stride=s, padding=p)
def conv1x1(inchannels, channels, ks=1, s=1, p=0):
    return nn.Conv2d(inchannels, channels, kernel_size=ks, stride=s, padding=p)

class Flatten(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, x): return x.view(x.size(0), -1)

class ResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = conv3x3(channels, channels)
        self.b1 = nn.BatchNorm2d(channels)
        self.conv2 = conv3x3(channels, channels)
        self.b2 = nn.BatchNorm2d(channels)
        
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.b1(out))
        out = self.conv2(out)
        out = self.b2(out)
        out_x = F.relu(out + x)
        return out_x

class MainResnet(nn.Module):
    def __init__(self, create_block, blocks, inchannels, channels):
        super().__init__()        
        self.conv1 = conv3x3(inchannels, channels)
        self.b1 = nn.BatchNorm2d(channels)
        self.resnetblock = self._make_blocks(channels, create_block, blocks)
        
    def _make_blocks(self, channels, resnet_block, blocks):
        layers = []
        for i in range(0, blocks):
            layers.append(resnet_block(channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.resnetblock(x)
        return x

class ValueHead(nn.Module):
    def __init__(self, program_size, channels):
        super().__init__()
        self.conv1 = conv1x1(channels, 1)
        self.b = nn.BatchNorm2d(1)
        self.flatten = Flatten()
        self.linear1 = nn.Linear(program_size**2, 256)
        self.linear2 = nn.Linear(256, 1)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.b(x))
        x = self.flatten(x)
        x = F.relu(self.linear1(x))
        x = F.tanh(self.linear2(x))
        return x

class PolicyHead(nn.Module):
    def __init__(self, program_size, vocab, channels):
        super().__init__()
        self.conv1 = conv1x1(channels, 1)
        self.b = nn.BatchNorm2d(1)
        self.flatten = Flatten()
        self.linear = nn.Linear(program_size**2, 512)
        self.logsoftmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.b(x))
        x = self.flatten(x)
        x = self.logsoftmax(self.linear(x))
        return x

class GolaiZero(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.resnet = MainResnet(ResnetBlock, args.resnetBlocks, args.resnetInputDepth, args.resnetChannelDepth)
        self.policyhead = PolicyHead(args.programSize, args.vocabLen, args.resnetChannelDepth)
        self.valuehead = ValueHead(args.programSize, args.resnetChannelDepth)        
        
        
    def forward(self, x):
        features = self.resnet(x)
        policy_out = self.policyhead(features)
        value_out = self.valuehead(features)
        return policy_out, value_out



In [12]:
args = dotdict({
    'numIters': 1000,
    'numEps': 100,
    'vocabWidth': 2, 
    'vocabHeight': 2,
    'programSize': 6,
    'vocabLen': 9,
    'tempThreshold': 15,
    'updateThreshold': 0.6,
    'maxlenOfQueue': 200000,
    'numMCTSSims': 25,
    'arenaCompare': 40,
    'cpuct': 1,
    'resnetBlocks': 10,
    'resnetInputDepth': 1,
    'resnetChannelDepth': 64,
    'checkpoint': './temp/',
    'load_model': False,
    'load_folder_file': ('/dev/models/8x100x50', 'best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,
})

In [13]:
#Print overview of model
from torchsummary import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
golai_zero = GolaiZero(args).to(device)

In [14]:
summary(golai_zero, (1, 6, 6))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1             [-1, 64, 6, 6]             640
            Conv2d-2             [-1, 64, 6, 6]          36,928
       BatchNorm2d-3             [-1, 64, 6, 6]             128
            Conv2d-4             [-1, 64, 6, 6]          36,928
       BatchNorm2d-5             [-1, 64, 6, 6]             128
       ResnetBlock-6             [-1, 64, 6, 6]               0
            Conv2d-7             [-1, 64, 6, 6]          36,928
       BatchNorm2d-8             [-1, 64, 6, 6]             128
            Conv2d-9             [-1, 64, 6, 6]          36,928
      BatchNorm2d-10             [-1, 64, 6, 6]             128
      ResnetBlock-11             [-1, 64, 6, 6]               0
           Conv2d-12             [-1, 64, 6, 6]          36,928
      BatchNorm2d-13             [-1, 64, 6, 6]             128
           Conv2d-14             [-1, 6