In [3]:
import numpy as np
import torch

import time
from Utilities import Utilities as Utils
from IPython.display import clear_output

In [5]:
import torch.nn as nn

# The model architecture, change with caution due to possible state loading issues

class ResidualLayer(nn.Module):
    def __init__(self, filters, kernal_size=3):
        super().__init__()

        self.conv2d_sequential = nn.Sequential(
            nn.Conv2d(filters, filters, kernal_size, padding=(kernal_size - 1) // 2),
            nn.BatchNorm2d(filters),
            nn.ReLU(),
            nn.Conv2d(filters, filters, kernal_size, padding=(kernal_size - 1) // 2),
            nn.BatchNorm2d(filters),
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.conv2d_sequential(x)
        x += residual
        x = self.relu(x)

        return x
    
class ConvolutionLayer(nn.Module):
    def __init__(self, infilters, outfilters, kernal_size=3):
        super().__init__()
        
        self.conv2d_sequential = nn.Sequential(                
            nn.Conv2d(infilters, outfilters, kernal_size, padding=(kernal_size - 1) // 2),
            nn.BatchNorm2d(outfilters),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.conv2d_sequential(x)
        return x
    
class PolicyHead(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.filters = filters

        self.head = nn.Sequential(
            nn.Conv2d(self.filters, 2, 1),
            nn.Flatten(),
            nn.BatchNorm1d(450),
            nn.ReLU(),
            nn.Linear(450, 225)
        )

    def forward(self, x):
        x = self.head(x)
        return x

    def forward(self, x):
        x = self.head(x)
        return x
    
class ValueHead(nn.Module):
    def __init__(self, filters):
        super().__init__()
        self.filters = filters

        self.value = nn.Sequential(
            nn.Conv2d(self.filters, 1, 1, padding=0),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(225, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.value(x)
        return x

class NeuralNetwork(nn.Module):
    def __init__(self, filters, feature_dimensions, residual_layers=5, kernal_size=3):
        super().__init__()

        self.conv_layer = ConvolutionLayer(feature_dimensions, filters, kernal_size=kernal_size)
        self.residual_layers = nn.ModuleList([ResidualLayer(filters, kernal_size=kernal_size) for _ in range(residual_layers)])
        self.policy_head = PolicyHead(filters)
        self.value_head = ValueHead(filters)

    def forward(self, x):      
        x = self.conv_layer(x)
        for layer in self.residual_layers:
            x = layer(x)
        policyResult = self.policy_head(x)
        valueResult = self.value_head(x)

        return policyResult, valueResult

In [7]:
Filters = 128
Layers = 15
HistoryDepth = 8
KernalSize = 3

model = NeuralNetwork(Filters, HistoryDepth + 1, Layers, kernal_size=KernalSize)

model.load_state_dict(torch.load('../../128f_15l_8hd.pt', map_location=torch.device('cpu')))
model.eval()

TypeError: Expected state_dict to be dict-like, got <class 'torch.jit._script.RecursiveScriptModule'>.

In [3]:
board = np.zeros((HistoryDepth + 1, 15, 15), dtype=bool)

In [4]:
player = 1
while (True):
    if player % 2:
        modelInput = torch.from_numpy(board.astype(np.float32)).unsqueeze(0)
        with torch.no_grad():
            output = model(modelInput)
        policyOut, valueOut = output
        policyOut = torch.nn.functional.softmax(policyOut, dim=1)
        policyOut = np.array(policyOut)
        policyOut = policyOut.reshape(15, 15)
        rawPolicyOut = policyOut.copy()
        policyOut *= ~board[HistoryDepth]
        policyOut *= ~board[HistoryDepth // 2]
        policyOut = policyOut.flatten()

        valueOut = valueOut[0].item()

        index = policyOut.argmax()
        x, y = int(index // 15), int(index % 15)
        board = Utils.makeMove(board, x, y)
        clear_output(wait=True)
        print(x,y)
        print(Utils.sliceGamestate(board, 0))
        #^print(rawPolicyOut)
        time.sleep(1)
    else:
        x = int(input("X:"))
        y = int(input("Y:"))

        board = Utils.makeMove(board, x, y)
        clear_output(wait=True)
        print(Utils.sliceGamestate(board, 0))
        time.sleep(1)
    player += 1

8 7
   --------------------------------------------------------------
14 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
13 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
12 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
11 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
10 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
 9 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
 8 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
 7 |   |   | 

ValueError: invalid literal for int() with base 10: ''