In [1]:
import torch
import wandb
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from Utilities import Utilities as Utils
from NeuralNet import ResidualNetwork as ResNet
from NeuralNet import PolicyNetwork as PolicyHead
from NeuralNet import ValueNetwork as ValueHead

import os

# Choose the compute backend once for the whole notebook:
# CUDA when present, otherwise Apple's MPS backend, otherwise plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f'Torch: {torch.__version__} using {device} device')

Torch: 2.1.0 using mps device


In [2]:
# Basic Function definitions

def trainPolicy(dataloader, resNet, polNet, pol_loss, optimizer, epoch, logcount=5, wandb_log=False):
    """Run one training epoch of the residual trunk plus policy head.

    Args:
        dataloader: Yields ``(X, yPol)`` batches; must support ``len()`` and
            expose ``.dataset``.
        resNet: Module applied to ``X`` first; its output is fed to ``polNet``.
        polNet: Module that turns the ``resNet`` output into the policy prediction.
        pol_loss: Callable ``(prediction, target) -> scalar loss tensor``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index; only used as a field in the wandb log entry.
        logcount: Approximate number of progress reports per epoch.
        wandb_log: When True, every report is also sent through ``wandb.log``.

    Returns:
        None. Progress is printed (and optionally logged to wandb).
    """
    size = len(dataloader.dataset)
    # Never let the interval reach 0: with fewer batches than `logcount`
    # the floor division would otherwise make the modulo below divide by zero.
    loginterval = max(1, len(dataloader) // logcount)
    windowLoss = 0.0
    seen = 0

    resNet.train()
    polNet.train()
    for batch, (X, yPol) in enumerate(dataloader):
        X, yPol = X.to(device), yPol.to(device)

        # Clear gradients *before* the backward pass so that anything left over
        # from an interrupted cell run cannot leak into this update.
        optimizer.zero_grad()

        # Compute prediction error
        polPred = polNet(resNet(X))
        batchLoss = pol_loss(polPred, yPol)

        # Backpropagation
        batchLoss.backward()
        optimizer.step()

        windowLoss += batchLoss.item()
        seen += len(X)

        # Report after every full window of `loginterval` batches, so the sum
        # always covers exactly `loginterval` losses.
        if (batch + 1) % loginterval == 0:
            logPolLoss = windowLoss / loginterval
            windowLoss = 0.0
            print(f"Pol Loss: {logPolLoss:>8f} [{seen:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "trainPolLoss": logPolLoss})

def testPolicy(dataloader, resNet, polNet, pol_loss, epoch, wandb_log=False):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    resNet.eval()
    polNet.eval()
    testPolLoss, polCorrect = 0, 0
    with torch.no_grad():
        for X, yPol in dataloader:
            X, yPol = X.to(device), yPol.to(device)
            resNetOut = resNet(X)
            polPred = polNet(resNetOut)
            testPolLoss += pol_loss(polPred, yPol).detach().item()
            polCorrect += (polPred.argmax(1) == yPol.argmax(1)).type(torch.float).sum().item()
    testPolLoss /= num_batches
    polCorrect /= size
    if wandb_log:
        wandb.log({"epoch" : epoch, "testPolAcc": polCorrect, "testPolLoss": testPolLoss})
    print(f"Test Error: \n Pol Acc: {(100*polCorrect):>0.3f}%, Pol Loss: {testPolLoss:>8f}\n")

def parameterCount(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only tensors with ``requires_grad`` set are counted; frozen parameters are
    skipped. A model without parameters yields 0.
    """
    # numel() is the product of a tensor's dimensions, which is what the count needs.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [3]:
# Model Hyperparameters / Config

Filters = 128
Layers = 10
HistoryDepth = 8
BatchSize = 256
LogCount = 5
datasetPath = "../../Datasets/HumanExamples/GeneratedDatasets/HD8,AUG,TS0.8,RULESETS(1)"

# Size of the second axis of the input tensors; the same value is handed to the
# trunk constructor, so derive it once and use it in both places.
historyDimSize = HistoryDepth + 1

resModel = ResNet(Filters, Layers, historyDimSize).to(device)
polModel = PolicyHead(Filters).to(device)
valModel = ValueHead(Filters).to(device)

# Logging stays off until the wandb cell further down turns it on.
wandb_logging = False
datasetName = datasetPath.split("/")[-1]

resParameters = parameterCount(resModel)
polParameters = parameterCount(polModel)
valParameters = parameterCount(valModel)
totalParameters = resParameters + polParameters + valParameters
print(f'Res:{resParameters}; Pol:{polParameters}; Val:{valParameters}')
print(f'Total:{totalParameters}')


def _loadTensor(fileName, dtype, *shape):
    """Read a raw binary dump from ``datasetPath`` as a float32 tensor.

    The file is interpreted as a flat array of ``dtype``, converted to float32
    and reshaped to ``(-1, *shape)``, so the first axis indexes samples.
    """
    raw = np.fromfile(f'{datasetPath}/{fileName}', dtype=dtype)
    return torch.from_numpy(raw.astype(np.float32).reshape(-1, *shape))


# Load dataset into memory
# Inputs are (N, historyDimSize, 15, 15); policy targets are (N, 225) —
# presumably one entry per cell of the 15x15 grid; value targets are (N, 1).
Xtrain = _loadTensor('XTrain.bin', bool, historyDimSize, 15, 15)
YtrainPol = _loadTensor('YTrainPol.bin', bool, 225)
YtrainVal = _loadTensor('YTrainVal.bin', np.int8, 1)
Xtest = _loadTensor('XTest.bin', bool, historyDimSize, 15, 15)
YtestPol = _loadTensor('YTestPol.bin', bool, 225)
YtestVal = _loadTensor('YTestVal.bin', np.int8, 1)

# Training loaders are shuffled so every epoch sees the samples in a new order;
# test loaders are not, since evaluation does not benefit from reordering and a
# fixed order keeps it repeatable.

# Policy Training
pol_train_dataset = TensorDataset(Xtrain, YtrainPol)
pol_train_dataloader = DataLoader(pol_train_dataset, batch_size=BatchSize, shuffle=True)
pol_test_dataset = TensorDataset(Xtest, YtestPol)
pol_test_dataloader = DataLoader(pol_test_dataset, batch_size=BatchSize, shuffle=False)

# Value Training
val_train_dataset = TensorDataset(Xtrain, YtrainVal)
val_train_dataloader = DataLoader(val_train_dataset, batch_size=BatchSize, shuffle=True)
val_test_dataset = TensorDataset(Xtest, YtestVal)
val_test_dataloader = DataLoader(val_test_dataset, batch_size=BatchSize, shuffle=False)

Res:2967552; Pol:102633; Val:58244
Total:3128429


In [4]:
# Validate dataset (recommended before training)
# Picks one random training sample and prints its rendered game state next to
# the policy and value labels stored at the same index, so the three arrays
# can be checked against each other by eye.
randomSample = np.random.randint(len(Xtrain))
gamestate = Xtrain[randomSample]
print(Utils.sliceGamestate(gamestate))

# The policy target is a flat vector; its arg-max is converted to coordinates.
index = YtrainPol[randomSample].argmax(dim=0)
winning = YtrainVal[randomSample]
x, y = Utils.indexToCords(index)

turn = gamestate[0][0][0]
print(f"Turn: {turn}; Move: {x}, {y}; Winning: {winning[0].item()}")

   --------------------------------------------------------------
14 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
13 |   |   |   |   |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
12 |   |   |   | B |   |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
11 |   |   |   |   | W |   |   |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
10 |   |   |   | W | W | W | B |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
 9 |   |   |   | B | B | B | W |   |   |   |   |   |   |   |   |
   --------------------------------------------------------------
 8 |   |   |   |   | W | B | W | W | W |   |   |   |   |   |   |
   --------------------------------------------------------------
 7 |   |   |   | 

In [4]:
# Loss functions for the two heads.
polLoss = nn.CrossEntropyLoss()
valLoss = nn.MSELoss()

# One optimizer updates the trunk and the policy head together.
# A dedicated name is used for the tensor list so that `polParameters`
# (the integer parameter count computed in the config cell) is not rebound
# to a different kind of object.
polTrainableParams = list(resModel.parameters()) + list(polModel.parameters())
polOptimizer = torch.optim.AdamW(polTrainableParams)

In [5]:
# Init wandb tracking
# The run description is passed through `notes=`: assigning to `wandb.notes`
# would only set an attribute on the module and never reach the run.
wandb.init(
    project='big-skull',
    config={
        "DatasetName": datasetName,
        "BatchSize": BatchSize,
        "LogCount": LogCount,
        "HistoryDepth": HistoryDepth,
        "Filters": Filters,
        "Layers": Layers,
        "ParameterCount": totalParameters,
    },
    tags=["Multihead"],
    notes="Resnet and policy",
)
# From here on the train/test helpers also report to wandb.
wandb_logging = True

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malexkurtz[0m ([33mbig-skull[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Start training for specified epochs

epochs = 10
resNetCheckpoints = []
polNetCheckpoints = []


def _snapshotWeights(model):
    """Return a detached CPU copy of ``model``'s state dict.

    ``state_dict()`` hands back references to the live tensors, so storing it
    directly would make every saved epoch alias the final weights. Cloning to
    CPU freezes the values and keeps the copies off the accelerator.
    """
    return {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}


for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    trainPolicy(pol_train_dataloader, resModel, polModel, polLoss, polOptimizer, epoch, logcount=LogCount, wandb_log=wandb_logging)
    testPolicy(pol_test_dataloader, resModel, polModel, polLoss, epoch, wandb_log=wandb_logging)
    # Keep one independent snapshot per epoch for the upload cell below.
    resNetCheckpoints.append(_snapshotWeights(resModel))
    polNetCheckpoints.append(_snapshotWeights(polModel))

print("Done!")

Epoch 1
-------------------------------


In [None]:
# Write every per-epoch snapshot into the wandb run directory and register the
# files with wandb; nothing is written when wandb logging is disabled.
if wandb_logging:
    checkpoints_dir = os.path.join(wandb.run.dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    for i, resState in enumerate(resNetCheckpoints):
        # File names encode the architecture settings and the epoch index.
        resNetPath = os.path.join(checkpoints_dir, f'{Filters}f_{Layers}l_cp:{i}_ResNet')
        polNetPath = os.path.join(checkpoints_dir, f'{Filters}f_cp:{i}_PolNet')

        torch.save(resState, resNetPath)
        torch.save(polNetCheckpoints[i], polNetPath)

        # base_path is the run directory, so the 'checkpoints/' sub-folder is
        # kept in the uploaded file path.
        for savedPath in (resNetPath, polNetPath):
            wandb.save(savedPath, base_path=wandb.run.dir)
    wandb.finish()