In [None]:
import pandas as pd
import numpy as np
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.onnx
import torchvision.models as models
from torchvision.models import resnet18, ResNet18_Weights

In [None]:
class Chess_Model(nn.Module):
    """CNN that scores a position from a stacked bit-board tensor plus float features.

    The bit board goes through three densely connected residual blocks (four
    3x3 convolutions each); every block is followed by a 3x3 / stride-2
    max-pool. The flattened result is concatenated with a 512-wide ReLU
    embedding of the float inputs and reduced by a fully connected head
    (1024 -> 64 -> 1) to a single sigmoid output.

    Args:
        bit_board_shape: Shape of one bit-board sample. Only index 0 (the
            channel count) is read by this class.
        num_float_inputs: Number of features in the auxiliary float vector.
        channel_multiple: Factor by which the channel count is widened inside
            each residual block.
        concatenated_size: Input width of the first fully connected layer. It
            has to equal 512 (float embedding) plus the number of flattened
            convolutional features, otherwise ``forward`` fails in ``fc1``.

    NOTE(review): the convolutions keep the spatial size (kernel 3, stride 1,
    padding 1) while each un-padded 3x3 / stride-2 pool maps a side ``d`` to
    ``(d - 3) // 2 + 1``. Three pools therefore appear to need a spatial side
    of at least 15 (15 -> 7 -> 3 -> 1); smaller boards would shrink below the
    pooling kernel. Confirm the intended input size and ``concatenated_size``
    against the callers.
    """

    def __init__(self, bit_board_shape, num_float_inputs, channel_multiple, concatenated_size):
        super(Chess_Model, self).__init__()
        self.num_channels = bit_board_shape[0]
        self.multiple = channel_multiple
        self.num_float_inputs = num_float_inputs
        self.concat_size = concatenated_size

        base = self.num_channels
        wide = self.num_channels * self.multiple

        def conv3x3(in_channels, out_channels):
            # "Same"-sized 3x3 convolution: padding=1 with stride=1 keeps H and W.
            return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=3, stride=1, padding=1)

        # Attribute names conv1..conv12 (and their creation order) are kept so
        # that state_dict keys and parameter ordering stay stable.

        #RESNET BLOCK 1
        self.conv1 = conv3x3(base, wide)
        self.conv2 = conv3x3(wide, wide)
        self.conv3 = conv3x3(wide, wide)
        self.conv4 = conv3x3(wide, base)

        #RESNET BLOCK 2
        self.conv5 = conv3x3(base, wide)
        self.conv6 = conv3x3(wide, wide)
        self.conv7 = conv3x3(wide, wide)
        self.conv8 = conv3x3(wide, base)

        #RESNET BLOCK 3
        self.conv9 = conv3x3(base, wide)
        self.conv10 = conv3x3(wide, wide)
        self.conv11 = conv3x3(wide, wide)
        self.conv12 = conv3x3(wide, base)

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.float_inputs_fc = nn.Linear(self.num_float_inputs, 512)
        self.fc1 = nn.Linear(self.concat_size, 1024)
        self.fc2 = nn.Linear(1024, 64)
        self.output_layer = nn.Linear(64, 1)

    def forward(self, bit_board, hanging_inputs):
        """Run the network.

        Args:
            bit_board: Tensor fed to the convolutional trunk; its channel
                dimension must match ``bit_board_shape[0]``.
            hanging_inputs: Tensor fed to ``float_inputs_fc``; its last
                dimension must be ``num_float_inputs``.

        Returns:
            Tensor with one sigmoid-activated value per sample (last dim 1).
        """
        conv_x = self.pool(self.ResNetBlock1(bit_board))
        conv_x = self.pool(self.ResNetBlock2(conv_x))
        conv_x = self.pool(self.ResNetBlock3(conv_x))

        # Flatten everything except the batch dimension.
        conv_x = conv_x.view(conv_x.size(0), -1)
        float_x = nn.functional.relu(self.float_inputs_fc(hanging_inputs))
        # Float embedding first, conv features second; fc1 expects this width.
        x = torch.cat((float_x, conv_x), dim=1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = torch.sigmoid(self.output_layer(x))
        return x

    @staticmethod
    def _dense_residual_block(x, conv_in, conv_mid1, conv_mid2, conv_out):
        """Shared body of the three residual blocks.

        ``conv_in`` widens the input, the two middle convolutions each consume
        the sum of all previous wide feature maps, ``conv_out`` projects back
        to the input channel count, and the block input is added as a skip
        connection. No activation is applied inside the block.
        """
        first = conv_in(x)
        second = conv_mid1(first)
        third = conv_mid2(first + second)
        reduced = conv_out(first + second + third)
        return reduced + x

    def ResNetBlock1(self, x):
        """Apply the first residual block (conv1..conv4); output shape equals input shape."""
        return self._dense_residual_block(x, self.conv1, self.conv2, self.conv3, self.conv4)

    def ResNetBlock2(self, x):
        """Apply the second residual block (conv5..conv8); output shape equals input shape."""
        return self._dense_residual_block(x, self.conv5, self.conv6, self.conv7, self.conv8)

    def ResNetBlock3(self, x):
        """Apply the third residual block (conv9..conv12); output shape equals input shape."""
        return self._dense_residual_block(x, self.conv9, self.conv10, self.conv11, self.conv12)

In [None]:
def count_parameters(model):
    """Return the total number of trainable scalar values in ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` of every
    parameter whose ``requires_grad`` flag is set; frozen parameters are
    skipped. Returns 0 when there are no such parameters.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [None]:
# Instantiate the network for 72-channel 8x8 bit boards with 8 float features
# and report how many trainable parameters it has.
# NOTE(review): concatenated_size sets the input width of fc1, while forward()
# concatenates a 512-wide float embedding with the flattened conv features, so
# a value of 10 looks too small for a forward pass -- confirm the intended size.
model = Chess_Model(
    bit_board_shape=(72, 8, 8),
    num_float_inputs=8,
    channel_multiple=4,
    concatenated_size=10,
)
num_params = count_parameters(model)
params_in_millions = round(num_params / 1e6, 4)
print("Number of parameters in the model in millions:", params_in_millions)

In [None]:
# Export the model to ONNX by tracing it once on CPU with random dummy inputs.
model.eval()  # switch to evaluation mode before tracing
model = model.to("cpu")
opset_version = 20
board_shape = (1, 72, 8, 8)
floats_shape = (1, 8)
# The dummy inputs only provide shapes/dtypes for the exporter; their values
# are arbitrary, so they are drawn directly with torch instead of via NumPy.
input_bitboard = torch.rand(board_shape, dtype=torch.float32)
input_floats = torch.rand(floats_shape, dtype=torch.float32)
# Build the output path from components so it resolves to the same location on
# Windows and POSIX (the previous hard-coded backslashes only worked on
# Windows), and create the target directory if it does not exist yet.
onnx_path = os.path.join(".", "Models", "PikeBot_Models", "CNN_model_basic.onnx")
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
torch.onnx.export(model, (input_bitboard, input_floats), onnx_path, opset_version=opset_version)
print("Model exported successfully!")