In [8]:
# imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [101]:
# hyperparameters
EPOCHS = 100
BATCH_SIZE = 64
N = 8 # board size

num_input_planes =  (2*6 + 1) + (1 + 4 + 1)
INPUT_DIM = (BATCH_SIZE, num_input_planes, N, N) # (64, 19, 8, 8)
queen_planes = 56
knight_planes = 8
underpromotion_planes = 9
num_output_planes = queen_planes + knight_planes + underpromotion_planes
OUTPUT_DIM = (BATCH_SIZE, N*N*num_output_planes, 1) # (64, 73x8x8, 1)

LEARNING_RATE = 0.2
POLICY_WEIGHT = 0.5 # weight of policy loss
VALUE_WEIGHT = 0.5 # weight of value loss
CONVOLUTION_FILTERS = 256
NUM_HIDDEN_LAYERS = 19 # number of residual blocks

In [106]:
class ConvNet(nn.Module):
    def __init__(self, input_dim, output_dim, num_hidden_layers, convolution_filters):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_hidden_layers = num_hidden_layers
        self.convolution_filters = convolution_filters

        # convolutional layer
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.input_dim[1], out_channels=self.convolution_filters, kernel_size=3, stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=self.convolution_filters),
            nn.ReLU()
        )

        # residual layer
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels=self.convolution_filters, out_channels=convolution_filters, kernel_size=3, stride=(1, 1), padding=1),
            nn.BatchNorm2d(num_features=self.convolution_filters),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.convolution_filters, out_channels=convolution_filters, kernel_size=3, stride=(1, 1), padding=1),
            nn.BatchNorm2d(self.convolution_filters)
        )
        self.residual_layers = nn.ModuleList()
        for _ in range(num_hidden_layers):
            self.residual_layers.append(self.residual)

        # policy head
        self.policy_head = nn.Sequential(
            nn.Conv2d(in_channels=self.convolution_filters, out_channels=2, kernel_size=(1, 1), padding=0, stride=(1, 1)),
            nn.BatchNorm2d(num_features=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(in_features=2*self.input_dim[2]*self.input_dim[3], out_features=self.output_dim[1]),
            nn.Sigmoid() # probability = (0, 1)
        )

        # value head
        self.value_head = nn.Sequential(
            nn.Conv2d(in_channels=self.convolution_filters, out_channels=1, kernel_size=(1, 1), padding=0, stride=(1, 1)),
            nn.BatchNorm2d(num_features=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(in_features=1*self.input_dim[2]*self.input_dim[3], out_features=self.convolution_filters),
            nn.ReLU(),
            nn.Linear(in_features=self.convolution_filters, out_features=1),
            nn.Tanh() # value = (-1, 1)
        )

        
    def forward(self, x):
        # convolutional layer
        x = self.conv(x)

        # residual layers
        for residual_layer in self.residual_layers:
            x_res = residual_layer(x)
            x += x_res
            x = torch.relu(x)

        # policy output
        policy_output = self.policy_head(x)

        # value output
        value_output = self.value_head(x)

        return policy_output, value_output

In [107]:
from torchsummary import summary
model = ConvNet(INPUT_DIM, OUTPUT_DIM, NUM_HIDDEN_LAYERS, CONVOLUTION_FILTERS)
input_size = INPUT_DIM[1:]
summary(model, input_size=tuple(input_size), device='cpu')


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 256, 8, 8]          44,032
       BatchNorm2d-2            [-1, 256, 8, 8]             512
              ReLU-3            [-1, 256, 8, 8]               0
            Conv2d-4            [-1, 256, 8, 8]         590,080
            Conv2d-5            [-1, 256, 8, 8]         590,080
       BatchNorm2d-6            [-1, 256, 8, 8]             512
       BatchNorm2d-7            [-1, 256, 8, 8]             512
              ReLU-8            [-1, 256, 8, 8]               0
              ReLU-9            [-1, 256, 8, 8]               0
           Conv2d-10            [-1, 256, 8, 8]         590,080
           Conv2d-11            [-1, 256, 8, 8]         590,080
      BatchNorm2d-12            [-1, 256, 8, 8]             512
      BatchNorm2d-13            [-1, 256, 8, 8]             512
           Conv2d-14            [-1, 25

In [6]:
input_dim = INPUT_DIM
output_dim = OUTPUT_DIM
num_hidden_layers = NUM_HIDDEN_LAYERS
convolution_filters = CONVOLUTION_FILTERS

model = ConvNet(input_dim, output_dim, num_hidden_layers, convolution_filters)

policy_criterion = nn.CrossEntropyLoss()
value_criterion = nn.MSELoss()

learning_rate = LEARNING_RATE
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def train_model():
    epochs = EPOCHS
    policy_weight = POLICY_WEIGHT
    value_weight = VALUE_WEIGHT

    for epoch in range(epochs):
        # to be written on the basis of input data
        pass