In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms

from Utilities import Utilities as Utils

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Basic Function definitions

def train(dataloader, model, loss_fn, optimizer, epoch, logcount=5, wandb_log=False):
    """Run one training epoch over ``dataloader`` and periodically report the loss.

    Args:
        dataloader: DataLoader yielding ``(X, y)`` batches; both tensors are moved
            to the module-level ``device`` before the forward pass.
        model: network to optimise; it is switched to training mode here.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, only forwarded to wandb.
        logcount: approximate number of loss reports per epoch.
        wandb_log: if True, each report is also sent to wandb as ``train_loss``.
    """
    size = len(dataloader.dataset)
    # Guard the modulo below: with fewer batches than ``logcount`` (or a
    # non-positive logcount) the integer division would yield 0.
    loginterval = max(1, len(dataloader) // max(1, logcount))
    running_loss = 0.0
    batches_in_window = 0
    seen = 0

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        batches_in_window += 1
        seen += len(X)

        # Report every ``loginterval`` batches, averaging over exactly the
        # batches accumulated since the previous report.
        if (batch + 1) % loginterval == 0:
            log_loss = running_loss / batches_in_window
            running_loss = 0.0
            batches_in_window = 0
            print(f"loss: {log_loss:>8f}  [{seen:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "train_loss": log_loss})

def test(dataloader, model, loss_fn, epoch, wandb_log=False):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).detach().item()
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    if wandb_log:
        wandb.log({"epoch" : epoch, "test_acc": correct, "test_loss": test_loss})
    print(f"Test Error: \n Accuracy: {(100*correct):>0.3f}%, Avg loss: {test_loss:>8f} \n")

def parameterCount(model):
    """Return the total number of scalar elements in ``model.parameters()``.

    Parameters shared between sub-modules are counted once, because
    ``Module.parameters()`` yields each tensor only once.
    """
    return sum(p.numel() for p in model.parameters())

In [None]:
# The model architecture. Change with caution: renaming or reordering sub-modules changes the state_dict keys, which can break loading of saved checkpoints.

class ResidualLayer(nn.Module):
    """Residual block: (conv -> BN -> ReLU -> conv -> BN) + identity, then ReLU.

    The channel count stays at ``filters``. Padding is ``(kernal_size - 1) // 2``,
    which preserves spatial size for odd kernel sizes.

    The attribute names (``conv2d_sequential``, ``relu``) and the module order
    inside the Sequential define the state_dict keys; keep them stable so saved
    checkpoints keep loading.
    """

    def __init__(self, filters, kernal_size=3):
        super().__init__()
        pad = (kernal_size - 1) // 2

        def conv_bn():
            # One convolution followed by batch normalisation.
            return [
                nn.Conv2d(filters, filters, kernal_size, padding=pad),
                nn.BatchNorm2d(filters),
            ]

        # Indices 0..4: conv, BN, ReLU, conv, BN (creation order fixes init RNG use).
        self.conv2d_sequential = nn.Sequential(*conv_bn(), nn.ReLU(), *conv_bn())
        self.relu = nn.ReLU()

    def forward(self, x):
        skip = x
        out = self.conv2d_sequential(x)
        return self.relu(out + skip)
    
class ConvolutionLayer(nn.Module):
    """Input stem: a single conv -> BatchNorm -> ReLU stage.

    Maps ``infilters`` channels to ``outfilters`` channels with padding
    ``(kernal_size - 1) // 2``. The attribute name ``conv2d_sequential`` and the
    module order inside it form the state_dict keys; keep them stable.
    """

    def __init__(self, infilters, outfilters, kernal_size=3):
        super().__init__()
        stages = [
            nn.Conv2d(infilters, outfilters, kernal_size, padding=(kernal_size - 1) // 2),
            nn.BatchNorm2d(outfilters),
            nn.ReLU(),
        ]
        self.conv2d_sequential = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv2d_sequential(x)
    
class PolicyHead(nn.Module):
    """Policy head: 1x1 conv to one plane, flatten, BN, ReLU, then a linear layer.

    Args:
        filters: number of input channels coming from the trunk.
        board_size: side length of the square board. The head expects
            ``board_size x board_size`` spatial input and emits
            ``board_size ** 2`` outputs. The default 15 gives 225 outputs.

    The attribute ``head`` and the module order inside it form the state_dict
    keys; keep them stable so saved checkpoints keep loading.
    """

    def __init__(self, filters, board_size=15):
        super().__init__()
        self.filters = filters
        self.board_size = board_size
        cells = board_size * board_size

        self.head = nn.Sequential(
            nn.Conv2d(self.filters, 1, 1),
            nn.Flatten(),
            nn.BatchNorm1d(cells),
            nn.ReLU(),
            nn.Linear(cells, cells),
        )

    def forward(self, x):
        return self.head(x)

class NeuralNetwork(nn.Module):
    """Policy network: convolutional stem -> tower of residual layers -> policy head.

    Args:
        filters: channel width of the stem and of every residual layer.
        feature_dimensions: number of input planes fed to the stem.
        residual_layers: how many ``ResidualLayer`` blocks form the tower.
        kernal_size: convolution kernel size for the stem and the tower.

    The attribute names (``conv_layer``, ``residual_layers``, ``policy_head``)
    and their construction order fix both the state_dict keys and the order in
    which weights are randomly initialised; keep them stable.
    """

    def __init__(self, filters, feature_dimensions, residual_layers=5, kernal_size=3):
        super().__init__()

        self.conv_layer = ConvolutionLayer(feature_dimensions, filters, kernal_size=kernal_size)
        tower = [ResidualLayer(filters, kernal_size=kernal_size) for _ in range(residual_layers)]
        self.residual_layers = nn.ModuleList(tower)
        self.policy_head = PolicyHead(filters)

    def forward(self, x):
        features = self.conv_layer(x)
        for block in self.residual_layers:
            features = block(features)
        return self.policy_head(features)

In [None]:
# Model Hyperparameters / Config

Filters = 128
Layers = 20
HistoryDepth = 8
BatchSize = 128
LogCount = 5
KernalSize = 3
Seed = 0
datasetPath = "../../Datasets/HumanExamples/GeneratedDatasets/HD8,TS0.8,RULESETS(1-6, 8-29)"

# Seed the RNGs used below (weight init, DataLoader shuffling, the random
# sample in the validation cell) so a Restart & Run All is reproducible.
torch.manual_seed(Seed)
np.random.seed(Seed)

# HistoryDepth + 1 input planes, matching historyDimSize in the dataset cell.
model = NeuralNetwork(Filters, HistoryDepth + 1, Layers, kernal_size=KernalSize).to(device)
# Optional speed-up. NOTE(review): a compiled model may prefix state_dict keys
# with "_orig_mod." -- verify before saving/loading checkpoints.
#model = torch.compile(model)

wandb_logging = False  # set to True by the wandb init cell
datasetName = datasetPath.split("/")[-1]
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters())

paramCount = parameterCount(model)
print(paramCount)

In [None]:
# Load dataset into memory

historyDimSize = HistoryDepth + 1
boardSize = 15

def load_bool_tensor(path, *shape):
    """Read a raw boolean ``.bin`` file and return it as a float32 tensor of ``shape``.

    A ``-1`` in ``shape`` lets numpy infer that dimension (the sample count).
    """
    return torch.from_numpy(np.fromfile(path, dtype=bool).astype(np.float32).reshape(shape))

Xtrain = load_bool_tensor(f'{datasetPath}/XTrain.bin', -1, historyDimSize, boardSize, boardSize)
Ytrain = load_bool_tensor(f'{datasetPath}/YTrain.bin', -1, boardSize * boardSize)
Xtest = load_bool_tensor(f'{datasetPath}/XTest.bin', -1, historyDimSize, boardSize, boardSize)
Ytest = load_bool_tensor(f'{datasetPath}/YTest.bin', -1, boardSize * boardSize)

train_dataset = TensorDataset(Xtrain, Ytrain)
train_loader = DataLoader(train_dataset, batch_size=BatchSize, shuffle=True)
test_dataset = TensorDataset(Xtest, Ytest)
test_loader = DataLoader(test_dataset, batch_size=BatchSize, shuffle=False)
Xtrain.shape

In [None]:
# Validate dataset (recommended before training)

# Print one random training position and the target move as (row, col).
sample_idx = np.random.randint(0, len(Xtrain))
print(Utils.sliceGamestate(Xtrain[sample_idx], 0))
move_row, move_col = divmod(int(Ytrain[sample_idx].argmax(0)), 15)
print(f"Move: {move_row}, {move_col}")

In [None]:
# Init wandb tracking
runConfig = {
    "DatasetName": datasetName,
    "BatchSize": BatchSize,
    "LogCount": LogCount,
    "HistoryDepth" : HistoryDepth,
    "Filters": Filters,
    "Layers" : Layers,
    "KernalSize" : KernalSize,
}
wandb.init(project='TorchGomoku', config=runConfig, tags=["AlphaGo"])
wandb_logging = True

In [None]:
# Free-form notes for the active wandb run; skipped when the init cell was not
# run (wandb.run is None), so the notebook also works without wandb logging.
if wandb.run is not None:
    wandb.run.notes = ""

In [None]:
# Start training for specified epochs

epochs = 15
checkpoints = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer, epoch, logcount=LogCount, wandb_log=wandb_logging)
    test(test_loader, model, loss_fn, epoch, wandb_log=wandb_logging)
    # model.state_dict() returns references to the live parameter/buffer
    # tensors. Appending it directly would make every entry equal the final
    # weights, so snapshot an independent CPU copy per epoch instead.
    checkpoints.append({k: v.detach().cpu().clone() for k, v in model.state_dict().items()})
print("Done!")

if wandb_logging:
    wandb.finish()

In [None]:
# Encode the architecture in the folder name so runs with different configs do
# not overwrite each other. "AllBut7H1" is a hand-written run tag (presumably the
# dataset's ruleset selection -- confirm); the current config yields
# "../Models/HumanModels/128f20l3kAllBut7H1".
modelNamePath = f"../Models/HumanModels/{Filters}f{Layers}l{KernalSize}kAllBut7H1"
# torch.save does not create missing directories.
os.makedirs(modelNamePath, exist_ok=True)
for i, checkpoint in enumerate(checkpoints):
    torch.save(checkpoint, f'{modelNamePath}/{i}.pt')