In [None]:
import torch
import wandb
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from Utilities import Utilities as Utils
from NeuralNet import ResidualNetwork as ResNet
from NeuralNet import PolicyNetwork as PolicyHead
from NeuralNet import ValueNetwork as ValueHead

import os

# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
# The MPS probe is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f'Torch: {torch.__version__} using {device} device')

In [None]:
def loadModel(ConstructedModel, Path):
    """Restore saved weights into an already-constructed model, in place.

    The object stored at ``Path`` is deserialised with all tensors mapped to the CPU,
    passed to ``ConstructedModel.load_state_dict``, and the model is then switched to
    eval mode and moved to the notebook-level ``device``. Nothing is returned.

    Args:
        ConstructedModel: Model instance whose architecture matches the saved weights.
        Path: Location of the file previously written with ``torch.save``.
    """
    cpu = torch.device('cpu')
    savedWeights = torch.load(Path, map_location=cpu)
    ConstructedModel.load_state_dict(savedWeights)
    ConstructedModel.to(device)
    ConstructedModel.eval()

def trainValue(dataloader, resNet, valNet, val_loss, optimizer, epoch, logcount=5, wandb_log=False):
    """Run one training epoch of the value head on top of the backbone.

    Each batch is moved to the notebook-level ``device``, passed through ``resNet`` and
    then ``valNet``, and ``optimizer`` is stepped on the resulting loss. The mean loss of
    each logging window is printed and, optionally, sent to wandb.

    Args:
        dataloader: Yields ``(X, y)`` batches; must support ``len()`` and expose ``.dataset``.
        resNet: Backbone module whose output is the input of ``valNet``.
        valNet: Value head being trained.
        val_loss: Callable ``(prediction, target)`` returning a scalar loss tensor.
        optimizer: Optimizer that is stepped once per batch.
        epoch: Epoch index; only used as the ``"epoch"`` field of the wandb record.
        logcount: Desired number of progress reports per epoch.
        wandb_log: When true, every progress report is also logged to wandb.
    """
    size = len(dataloader.dataset)
    # Clamp to >= 1: a loader with fewer batches than logcount (or logcount <= 0)
    # would otherwise give an interval of 0 and a modulo-by-zero below.
    loginterval = max(1, len(dataloader) // max(1, logcount))

    # A backbone with no trainable parameters is kept in eval mode so that any
    # mode-dependent layers it may contain (e.g. BatchNorm running statistics, Dropout)
    # behave exactly as they will at inference time. A trainable backbone is put in
    # train mode as before.
    backboneTrainable = any(p.requires_grad for p in resNet.parameters())
    resNet.train(backboneTrainable)
    valNet.train()

    windowLoss = 0.0
    windowBatches = 0
    samplesSeen = 0
    for batch, (X, yVal) in enumerate(dataloader):
        X, yVal = X.to(device), yVal.to(device)

        # Skip building an autograd graph through a frozen backbone.
        with torch.set_grad_enabled(backboneTrainable):
            resNetOut = resNet(X)
        valPred = valNet(resNetOut)
        valLoss = val_loss(valPred, yVal)

        # Backpropagation. Gradients are cleared first so that anything left over
        # from before this call cannot leak into the first step.
        optimizer.zero_grad()
        valLoss.backward()
        optimizer.step()

        windowLoss += valLoss.item()
        windowBatches += 1
        samplesSeen += len(X)

        # Report at the end of every full window, dividing by the number of batches
        # actually accumulated so the printed value is a true mean.
        if (batch + 1) % loginterval == 0:
            logValLoss = windowLoss / windowBatches
            windowLoss = 0.0
            windowBatches = 0
            current = samplesSeen
            print(f"Val Loss: {logValLoss:>8f} [{current:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "trainValLoss": logValLoss})

def testValue(dataloader, resNet, valNet, val_loss, epoch, wandb_log=False):
    num_batches = len(dataloader)
    resNet.eval()
    valNet.eval()
    testValLoss = 0
    with torch.no_grad():
        for X, yPol in dataloader:
            X, yPol = X.to(device), yPol.to(device)
            resNetOut = resNet(X)
            valPred = valNet(resNetOut)
            testValLoss += val_loss(valPred, yPol).detach().item()
    testValLoss /= num_batches
    if wandb_log:
        wandb.log({"epoch" : epoch, "testValLoss": testValLoss})
    print(f"Test Error: \n Val Loss: {testValLoss:>8f}\n")

In [None]:
# Backbone / data hyperparameters (Filters, Layers and HistoryDepth + 1 are the
# arguments the ResNet is constructed with further down).
Filters = 128
Layers = 13
HistoryDepth = 8
BatchSize = 128
LogCount = 5

# Head hyperparameters
ConvFilters = 256
LinearFilters = 512

datasetPath = "../../Datasets/HumanExamples/GeneratedDatasets/HD8,TS0.8,RULESETS(1-6, 8-29)"

# Stays False unless the wandb initialisation cell is executed.
wandb_logging = False
# Last path component, recorded in the wandb run config.
datasetName = datasetPath.split("/")[-1]

# Inputs are requested with shape (-1, HistoryDepth + 1, 15, 15) as bool and the value
# targets as a single int8 column.
# NOTE(review): the 15x15 trailing dimensions are presumably the board size - confirm.
XTrain = Utils.loadDataset(f'{datasetPath}/XTrain.bin', (-1, HistoryDepth + 1, 15, 15), bool)
YTrain = Utils.loadDataset(f'{datasetPath}/YTrainVal.bin', (-1, 1), np.int8)

XTest = Utils.loadDataset(f'{datasetPath}/XTest.bin', (-1, HistoryDepth + 1, 15, 15), bool)
YTest = Utils.loadDataset(f'{datasetPath}/YTestVal.bin', (-1, 1), np.int8)

TrainDataloader = Utils.toDataloader(XTrain, YTrain, BatchSize=BatchSize, Shuffle=True)
# Evaluation gains nothing from shuffling. A fixed order keeps the batch composition,
# and therefore the per-batch mean test loss, reproducible between runs.
TestDataloader = Utils.toDataloader(XTest, YTest, BatchSize=BatchSize, Shuffle=False)

In [None]:
# Backbone: build the architecture, then restore its saved weights from disk.
resNet = ResNet(Filters, Layers, HistoryDepth + 1)
loadModel(resNet, "../../Models/Human/ResNet/test1.pt")

# Freeze every backbone parameter; only the value head is handed to the optimizer.
resNet.requires_grad_(False)

# Value head, created fresh and moved to the selected device.
valHead = ValueHead(Filters, ConvFilters, LinearFilters)
valHead = valHead.to(device)

# Mean squared error on the value target, optimised with AdamW at its default settings.
loss = nn.MSELoss()
optimizer = torch.optim.AdamW(valHead.parameters())

In [None]:
# Start a wandb run and record this notebook's hyperparameters in its config.
wandb.init(
    project='big-skull',
    config={
        "DatasetName": datasetName,
        "BatchSize": BatchSize,
        "LogCount": LogCount,
        "HistoryDepth" : HistoryDepth,
        "InFilters": Filters,
        "ConvFilters": ConvFilters,
        "LinearFilters": LinearFilters,
    },
    tags=["Multihead", "Value"],
)
wandb.run.notes = "More linear layers"

# From here on the train/test helpers also log to wandb and checkpoints get uploaded.
wandb_logging = True

In [None]:
epochs = 3
# One independent snapshot of the value-head weights per epoch.
valHeadCheckpoints = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    trainValue(TrainDataloader, resNet, valHead, loss, optimizer, epoch, logcount=LogCount, wandb_log=wandb_logging)
    testValue(TestDataloader, resNet, valHead, loss, epoch, wandb_log=wandb_logging)
    # state_dict() hands back references to the live parameter tensors, so appending
    # it directly would leave every entry pointing at the final weights. Store a
    # detached CPU copy instead so each epoch's snapshot is preserved.
    valHeadCheckpoints.append(
        {name: tensor.detach().cpu().clone() for name, tensor in valHead.state_dict().items()}
    )

In [None]:
# Checkpoints are only written when a wandb run is active: they are stored in a
# 'checkpoints' folder inside the run directory and registered with wandb.save.
if wandb_logging:
    runDir = wandb.run.dir
    ckptDir = os.path.join(runDir, 'checkpoints')
    os.makedirs(ckptDir, exist_ok=True)

    for i, stateDict in enumerate(valHeadCheckpoints):
        ckptPath = os.path.join(ckptDir, f"ValHead_{i}.pt")
        torch.save(stateDict, ckptPath)
        wandb.save(ckptPath, base_path=wandb.run.dir)

    # Close the run once every checkpoint has been handed to wandb.
    wandb.finish()