In [None]:
import torch
import wandb
import numpy as np
import torch.nn as nn
from Config import Config as Conf

from Utilities import Utilities as Utils
from NeuralNet import ResidualNetwork as ResNet
from NeuralNet import PolicyNetwork as PolicyHead
from NeuralNet import ValueNetwork as ValueHead

import os

# Pick the compute device: prefer CUDA, fall back to Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f'Torch: {torch.__version__} using {device} device')

In [None]:
# Basic Function definitions

def trainPolicy(dataloader, resNet, polNet, pol_loss, optimizer, epoch, logcount=5, wandb_log=False):
    """Train the residual trunk and the policy head jointly for one epoch.

    Args:
        dataloader: yields ``(X, yPol)`` batches; must support ``len()`` (number
            of batches) and expose ``.dataset`` (used for the progress total).
        resNet: trunk module whose output is fed into ``polNet``.
        polNet: policy head producing the predictions passed to ``pol_loss``.
        pol_loss: callable ``pol_loss(prediction, target)`` returning a scalar tensor.
        optimizer: optimizer over the parameters being trained.
        epoch: epoch index; only used as a label in wandb logs.
        logcount: approximate number of progress reports per epoch.
        wandb_log: when True, also send the running loss to wandb.

    Each batch is moved to the module-level ``device`` before the forward pass.
    """
    size = len(dataloader.dataset)
    # Clamp to 1 so a dataloader with fewer batches than `logcount` (or a
    # logcount of 0) does not cause a modulo-by-zero below.
    loginterval = max(1, len(dataloader) // max(1, logcount))
    runningPolLoss = 0.0
    batchesSinceLog = 0
    samplesSeen = 0

    resNet.train()
    polNet.train()
    for batch, (X, yPol) in enumerate(dataloader):
        X, yPol = X.to(device), yPol.to(device)

        # Clear gradients before backprop so stale grads (e.g. from an
        # interrupted earlier run in the same kernel) never leak into a step.
        optimizer.zero_grad()

        # Compute prediction error
        polPred = polNet(resNet(X))
        polLoss = pol_loss(polPred, yPol)

        # Backpropagation
        polLoss.backward()
        optimizer.step()

        runningPolLoss += polLoss.item()
        batchesSinceLog += 1
        samplesSeen += len(X)

        # Report after every `loginterval` completed batches; dividing by the
        # number of batches actually accumulated keeps the average exact.
        if (batch + 1) % loginterval == 0:
            logPolLoss = runningPolLoss / batchesSinceLog
            runningPolLoss = 0.0
            batchesSinceLog = 0
            print(f"Pol Loss: {logPolLoss:>8f} [{samplesSeen:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "trainPolLoss": logPolLoss})


def testPolicy(dataloader, resNet, polNet, pol_loss, epoch, wandb_log=False):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    resNet.eval()
    polNet.eval()
    testPolLoss, polCorrect = 0, 0
    with torch.no_grad():
        for X, yPol in dataloader:
            X, yPol = X.to(device), yPol.to(device)
            resNetOut = resNet(X)
            polPred = polNet(resNetOut)
            testPolLoss += pol_loss(polPred, yPol).detach().item()
            polCorrect += (polPred.argmax(1) == yPol.argmax(1)).type(torch.float).sum().item()
    testPolLoss /= num_batches
    polCorrect /= size
    if wandb_log:
        wandb.log({"epoch" : epoch, "testPolAcc": polCorrect, "testPolLoss": testPolLoss})
    print(f"Test Error: \n Pol Acc: {(100*polCorrect):>0.3f}%, Pol Loss: {testPolLoss:>8f}\n")

In [None]:
# Model Hyperparameters / Config

Filters = Conf.NN_FILTERS
Layers = Conf.NN_RESNETLAYERS
HistoryDepth = Conf.HISTORYDEPTH
BatchSize = 256
LogCount = 5
BoardSize = 15                      # spatial size of each input plane (BoardSize x BoardSize)
MoveCount = BoardSize * BoardSize   # 225 policy target slots, presumably one per board cell
Seed = 0
datasetPath = "../../Datasets/HumanExamples/GeneratedDatasets/HD8,TS0.8,RULESETS(1-6, 8-29)"

# Seed every RNG the notebook uses (weight init, DataLoader shuffling, the
# random sample in the validation cell) so Restart & Run All is reproducible.
torch.manual_seed(Seed)
np.random.seed(Seed)

# Number of input planes per example: dim 1 of X below, and the channel count
# the ResNet trunk is built for.
historyDimSize = HistoryDepth + 1

resModel = ResNet(Filters, Layers, historyDimSize).to(device)
polModel = PolicyHead(Filters).to(device)

wandb_logging = False
datasetName = datasetPath.split("/")[-1]

resParameters = Utils.trainableParameterCount(resModel)
polParameters = Utils.trainableParameterCount(polModel)

totalParameters = resParameters + polParameters
print(f'Res:{resParameters}; Pol:{polParameters};')
print(f'Total:{totalParameters}')

# Load dataset into memory
Xtrain = Utils.loadDataset(f'{datasetPath}/XTrain.bin', (-1, historyDimSize, BoardSize, BoardSize), bool)
Ytrain = Utils.loadDataset(f'{datasetPath}/YTrainPol.bin', (-1, MoveCount), bool)

Xtest = Utils.loadDataset(f'{datasetPath}/XTest.bin', (-1, historyDimSize, BoardSize, BoardSize), bool)
Ytest = Utils.loadDataset(f'{datasetPath}/YTestPol.bin', (-1, MoveCount), bool)

train_dataloader = Utils.toDataloader(Xtrain, Ytrain, BatchSize=BatchSize, Shuffle=True)
test_dataloader = Utils.toDataloader(Xtest, Ytest, BatchSize=BatchSize, Shuffle=False)

In [None]:
# Validate dataset (recommended before training): pick one random training
# example, print it via Utils.sliceGamestate, then decode the argmax of its
# policy target into board coordinates.
randomSample = np.random.randint(len(Xtrain))
gamestate = Xtrain[randomSample]
print(Utils.sliceGamestate(gamestate))

targetPolicy = Ytrain[randomSample]
index = targetPolicy.argmax(0)
x, y = Utils.indexToCords(index)
print(f"Turn: {gamestate[0][0][0]}; Move: {x}, {y};")

In [None]:
# Loss functions. `valLoss` is defined for a value head but is not used by the
# training cells below, which only train the policy head.
polLoss = nn.CrossEntropyLoss()
valLoss = nn.MSELoss()

# Train the trunk and the policy head jointly with a single optimizer. Use a
# dedicated name so the `polParameters` count computed in the config cell is
# not overwritten with a list of tensors.
polTrainableParams = list(resModel.parameters()) + list(polModel.parameters())
# lr / weight_decay are AdamW's defaults, spelled out so they are visible and tunable.
polOptimizer = torch.optim.AdamW(polTrainableParams, lr=1e-3, weight_decay=1e-2)

In [None]:
# Init wandb tracking: open a run in the 'big-skull' project that records the
# hyperparameters and parameter count, then turn on logging for the
# train/test helpers.
wandb.init(
    project='big-skull',
    config={
        "DatasetName": datasetName,
        "BatchSize": BatchSize,
        "LogCount": LogCount,
        "HistoryDepth": HistoryDepth,
        "Filters": Filters,
        "Layers": Layers,
        "ParameterCount": totalParameters,
    },
    tags=["ResNet", "Policy"],
)
wandb.run.notes = "Resnet and policy"
wandb_logging = True

In [None]:
# Start training for specified epochs.
#
# Checkpoints must be copied out of the models: `state_dict()` returns
# references to the live parameter tensors, so appending it directly would make
# every stored "checkpoint" alias the final weights.

def snapshotStateDict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    Cloning decouples the snapshot from later optimizer updates. Moving it to
    the CPU keeps the accumulated per-epoch checkpoints out of accelerator
    memory. The state_dict's ``_metadata`` attribute, if present, is carried
    over so the copy loads with ``load_state_dict`` the same way the original would.
    """
    state = model.state_dict()
    snapshot = type(state)((key, tensor.detach().cpu().clone()) for key, tensor in state.items())
    metadata = getattr(state, "_metadata", None)
    if metadata is not None:
        snapshot._metadata = metadata
    return snapshot

epochs = 5
resNetCheckpoints = []
polNetCheckpoints = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    trainPolicy(train_dataloader, resModel, polModel, polLoss, polOptimizer, epoch, logcount=LogCount, wandb_log=wandb_logging)
    testPolicy(test_dataloader, resModel, polModel, polLoss, epoch, wandb_log=wandb_logging)
    resNetCheckpoints.append(snapshotStateDict(resModel))
    polNetCheckpoints.append(snapshotStateDict(polModel))

print("Done!")

In [None]:
# When wandb logging is on: write every per-epoch checkpoint under
# <run dir>/checkpoints, register each file with wandb.save, then close the run.
# Nothing is written when wandb logging is off.
if wandb_logging:
    checkpoints_dir = os.path.join(wandb.run.dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    def _saveToRun(checkpoint, filename):
        # Save one checkpoint locally, then hand the file to wandb for upload.
        checkpoint_path = os.path.join(checkpoints_dir, filename)
        torch.save(checkpoint, checkpoint_path)
        wandb.save(checkpoint_path, base_path=wandb.run.dir)

    for i, checkpoint in enumerate(resNetCheckpoints):
        _saveToRun(checkpoint, f"ResNet_{i}.pt")

    for i, checkpoint in enumerate(polNetCheckpoints):
        _saveToRun(checkpoint, f"PolNet_{i}.pt")

    wandb.finish()