In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.utils.data import Dataset, DataLoader

from Utilities import Utilities as Utils
from NeuralNet import NeuralNetwork as Network

# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f'Torch: {torch.__version__} using {device} device')

In [None]:
# Basic Function definitions

class CustomDataset(Dataset):
    """In-memory dataset that yields ``(x, y1, y2)`` triples by a shared index.

    Args:
        X: Indexable inputs; its ``len()`` is the dataset length.
        Y1: First target collection, indexed in parallel with ``X``.
        Y2: Second target collection, indexed in parallel with ``X``.
    """

    def __init__(self, X, Y1, Y2):
        self.X, self.Y1, self.Y2 = X, Y1, Y2

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # The same index goes into every collection, keeping input and targets aligned.
        return tuple(source[idx] for source in (self.X, self.Y1, self.Y2))

def train(dataloader, model, pol_loss, val_loss, optimizer, epoch, logcount=5, wandb_log=False, device=None):
    """Run one training epoch, optimizing the sum of a policy loss and a value loss.

    Each batch yields ``(X, yPol, yVal)``. The model must return a
    ``(polPred, valPred)`` pair. Both losses are added and back-propagated
    together. Running averages of each loss are printed, and optionally sent
    to wandb, after every ``len(dataloader) // logcount`` batches.

    Args:
        dataloader: Iterable of ``(X, yPol, yVal)`` batches. It must support
            ``len()`` and expose ``.dataset`` with a ``len()``.
        model: Module returning ``(polPred, valPred)``.
        pol_loss: Criterion applied to ``(polPred, yPol)``.
        val_loss: Criterion applied to ``(valPred, yVal)``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch index; only used in wandb log records.
        logcount: Approximate number of log lines per epoch.
        wandb_log: If True, also send each averaged loss pair to ``wandb.log``.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has none).
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))
    size = len(dataloader.dataset)
    # Clamp to 1: with fewer batches than ``logcount`` the integer division is 0,
    # and ``batch % 0`` would raise ZeroDivisionError.
    loginterval = max(1, len(dataloader) // max(1, logcount))
    runningPolLoss = 0.0
    runningValLoss = 0.0
    batchesSinceLog = 0
    samplesSeen = 0

    model.train()
    for batch, (X, yPol, yVal) in enumerate(dataloader):
        X, yPol, yVal = X.to(device), yPol.to(device), yVal.to(device)

        # Clear gradients before this batch's backward so nothing stale is applied.
        optimizer.zero_grad()
        polPred, valPred = model(X)
        polLoss = pol_loss(polPred, yPol)
        valLoss = val_loss(valPred, yVal)
        (polLoss + valLoss).backward()
        optimizer.step()

        runningPolLoss += polLoss.item()
        runningValLoss += valLoss.item()
        batchesSinceLog += 1
        samplesSeen += len(X)

        # Log after every ``loginterval`` completed batches. Divide by the number of
        # batches actually accumulated, so every window is a true mean.
        if (batch + 1) % loginterval == 0:
            logPolLoss = runningPolLoss / batchesSinceLog
            logValLoss = runningValLoss / batchesSinceLog
            runningPolLoss = 0.0
            runningValLoss = 0.0
            batchesSinceLog = 0
            print(f"Pol Loss: {logPolLoss:>8f}, Val Loss: {logValLoss:>8f} [{samplesSeen:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "trainPolLoss": logPolLoss, "trainValLoss": logValLoss})

def test(dataloader, model, pol_loss, val_loss, epoch, wandb_log=False):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    testPolLoss, testValLoss, polCorrect, valCorrect = 0, 0, 0, 0
    with torch.no_grad():
        for X, yPol, yVal in dataloader:
            X, yPol, yVal = X.to(device), yPol.to(device), yVal.to(device)
            polPred, valPred = model(X)
            testPolLoss += pol_loss(polPred, yPol).detach().item()
            testValLoss += val_loss(valPred, yVal).detach().item()
            polCorrect += (polPred.argmax(1) == yPol.argmax(1)).type(torch.float).sum().item()
    testPolLoss /= num_batches
    testValLoss /= num_batches
    polCorrect /= size
    if wandb_log:
        wandb.log({"epoch" : epoch, "testPolAcc": polCorrect, "testPolLoss": testPolLoss, "testValLoss": testValLoss})
    print(f"Test Error: \n Pol Acc: {(100*polCorrect):>0.3f}%, Pol Loss: {testPolLoss:>8f}, Val Loss: {testValLoss:>8f} \n")

def parameterCount(model, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Any object with a ``parameters()`` iterator (e.g. an ``nn.Module``).
        trainable_only: If True, count only parameters with ``requires_grad``.

    Returns:
        int: Sum of ``numel()`` over the selected parameters (0 if there are none).
    """
    # numel() replaces the hand-rolled product over p.size(); the old version also
    # shadowed the module-level ``nn`` alias with a local variable.
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)

In [None]:
# Model Hyperparameters / Config

Seed = 0  # seeds the torch and numpy global RNGs below so runs are reproducible
Filters = 128
Layers = 20
HistoryDepth = 8
BatchSize = 256
LogCount = 5
KernalSize = 3  # spelling matches Network's ``kernal_size`` keyword and the wandb config key
datasetPath = "../../Datasets/HumanExamples/GeneratedDatasets/HD8,TS0.8,RULESETS(1-6, 8-29)"

# Seed before building the model so weight init is repeatable; the torch seed also
# drives DataLoader shuffling, and the numpy seed drives np.random sample picks.
torch.manual_seed(Seed)
np.random.seed(Seed)

# HistoryDepth + 1 matches the channel dimension (historyDimSize) used when loading X.
model = Network(Filters, HistoryDepth + 1, Layers, kernal_size=KernalSize).to(device)
#model = torch.compile(model)  # optional; left disabled

wandb_logging = False  # set to True by the wandb init cell
datasetName = datasetPath.split("/")[-1]
polLoss = nn.CrossEntropyLoss()
valLoss = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters())

paramCount = parameterCount(model)
print(f"Parameter count: {paramCount}")

In [None]:
# Load dataset into memory

historyDimSize = HistoryDepth + 1
BoardSize = 15  # board edge length; policy targets are flattened BoardSize*BoardSize grids

def loadTensor(fileName, dtype, shape):
    """Read raw binary ``fileName`` from ``datasetPath`` as a float32 tensor.

    Args:
        fileName: File name inside ``datasetPath``.
        dtype: NumPy dtype the file was written with (``bool`` or ``np.int8`` here).
        shape: Target shape, using -1 for the sample dimension.
    """
    raw = np.fromfile(f'{datasetPath}/{fileName}', dtype=dtype)
    return torch.from_numpy(raw.astype(np.float32).reshape(shape))

boardShape = (-1, historyDimSize, BoardSize, BoardSize)
polShape = (-1, BoardSize * BoardSize)
valShape = (-1, 1)

Xtrain = loadTensor('XTrain.bin', bool, boardShape)
YtrainPol = loadTensor('YTrainPol.bin', bool, polShape)
YtrainVal = loadTensor('YTrainVal.bin', np.int8, valShape)
Xtest = loadTensor('XTest.bin', bool, boardShape)
YtestPol = loadTensor('YTestPol.bin', bool, polShape)
YtestVal = loadTensor('YTestVal.bin', np.int8, valShape)

# reshape(-1, ...) alone accepts files of different lengths; inputs and both targets
# of a split must describe the same samples.
assert len(Xtrain) == len(YtrainPol) == len(YtrainVal), "train split: X / Pol / Val sample counts differ"
assert len(Xtest) == len(YtestPol) == len(YtestVal), "test split: X / Pol / Val sample counts differ"

train_dataset = CustomDataset(Xtrain, YtrainPol, YtrainVal)
train_loader = DataLoader(train_dataset, batch_size=BatchSize, shuffle=True)
test_dataset = CustomDataset(Xtest, YtestPol, YtestVal)
test_loader = DataLoader(test_dataset, batch_size=BatchSize, shuffle=False)
print(Xtrain.shape)
print(YtrainPol.shape)
print(YtrainVal.shape)

In [None]:
# Validate dataset (recommended before training)

randomSample = np.random.randint(0, len(Xtrain))
gamestate = Xtrain[randomSample]
print(Utils.sliceGamestate(gamestate))
# Convert 0-d tensors to Python numbers so the message reads "7, 3"
# rather than "tensor(7), tensor(3)". The flat policy index is split as
# row = index // 15, col = index % 15 over the 15x15 board.
index = int(YtrainPol[randomSample].argmax(0))
row, col = divmod(index, 15)
winning = YtrainVal[randomSample]
print(f"Turn: {gamestate[0][0][0].item()}; Move: {row}, {col}; Winning: {winning[0].item()}")

In [None]:
# Init wandb tracking: start a run that records the config above, then turn on logging
# for the train/test calls below.
wandb.init(
    project='TorchGomoku',
    config={
        "DatasetName": datasetName,
        "BatchSize": BatchSize,
        "LogCount": LogCount,
        "HistoryDepth" : HistoryDepth,
        "Filters": Filters,
        "Layers" : Layers,
        "KernalSize" : KernalSize,
        "ParameterCount" : paramCount,
    },
    tags=["Multihead"],
)
wandb_logging = True

In [None]:
# Free-form notes for the active wandb run. wandb.run is None when the wandb init cell
# was skipped (wandb_logging left False), so guard instead of raising AttributeError.
if wandb.run is not None:
    wandb.run.notes = ""

In [None]:
# Start training for specified epochs

epochs = 10
checkpoints = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_loader, model, polLoss, valLoss, optimizer, epoch, logcount=LogCount, wandb_log=wandb_logging)
    test(test_loader, model, polLoss, valLoss, epoch, wandb_log=wandb_logging)
    # model.state_dict() holds references to the live parameter tensors, so appending
    # it directly would make every checkpoint equal to the final weights. Store a
    # detached CPU copy per epoch instead (this also keeps accelerator memory flat).
    checkpoints.append({name: tensor.detach().to("cpu", copy=True) for name, tensor in model.state_dict().items()})
print("Done!")

if wandb_logging:
    wandb.finish()

In [None]:
# Derive the folder name from the config so it stays in sync with Filters/Layers
# (128 filters, 20 layers -> "128f20l").
modelNamePath = f"../../Models/HumanModels/{Filters}f{Layers}l"
# torch.save fails if the target directory is missing, so create it first.
os.makedirs(modelNamePath, exist_ok=True)
for i, checkpoint in enumerate(checkpoints):
    torch.save(checkpoint, f'{modelNamePath}/{i}.pt')