In [1]:
import torch
import wandb
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from Utilities import Utilities as Utils
from NeuralNet import ResidualNetwork as ResNet
from NeuralNet import PolicyNetwork as PolicyHead
from NeuralNet import ValueNetwork as ValueHead

import os

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f'Torch: {torch.__version__} using {device} device')

Torch: 2.1.0 using mps device


In [2]:
def loadModel(ConstructedModel, Path):
    ConstructedModel.load_state_dict(torch.load(Path, map_location=torch.device('cpu')))
    ConstructedModel.eval()
    ConstructedModel.to(device)

def loadDataset(Path, Shape):
    return torch.from_numpy(np.fromfile(Path, dtype=bool).astype(np.float32).reshape((Shape)))

def toDataloader(X, Y, BatchSize=128, Shuffle=False):
    dataset = TensorDataset(X, Y)
    return DataLoader(dataset, batch_size=BatchSize, shuffle=Shuffle)

def trainValue(dataloader, resNet, valNet, val_loss, optimizer, epoch, logcount=5, wandb_log=False):
    size = len(dataloader.dataset)
    loginterval = len(dataloader) // logcount
    averageValLoss = 0.0

    resNet.train()
    valNet.train()
    for batch, (X, yPol) in enumerate(dataloader):
        X, yPol = X.to(device), yPol.to(device)

        # Compute prediction error
        resNetOut = resNet(X)
        valPred = valNet(resNetOut)
        valLoss = val_loss(valPred, yPol)

        averageValLoss += valLoss.detach().item()
        # Backpropagation
        valLoss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (batch % loginterval == 0) and (batch > 0):
            logValLoss = averageValLoss / loginterval
            averageValLoss = 0
            current = batch * len(X)
            print(f"Val Loss: {logValLoss:>8f} [{current:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "trainValLoss": logValLoss})

In [3]:
Filters = 128
Layers = 13
HistoryDepth = 8
BatchSize = 128
LogCount = 5
datasetPath = "../../Datasets/HumanExamples/GeneratedDatasets/HD8,AUG,TS0.8,RULESETS(1)"

X = loadDataset(f'{datasetPath}/XTrain.bin', (-1, HistoryDepth + 1, 15, 15))
Y = loadDataset(f'{datasetPath}/YTrainVal.bin', (-1, 1))
dataloader = toDataloader(X, Y, BatchSize=BatchSize, Shuffle=True)

In [4]:
resNet = ResNet(Filters, Layers, HistoryDepth + 1)
loadModel(resNet, "../../Models/Human/ResNet/ResNet.pt")

for param in resNet.parameters():
    param.requires_grad = False

valHead = ValueHead(Filters).to(device)

loss = nn.MSELoss()
optimizer = torch.optim.AdamW(valHead.parameters())

In [5]:
epochs = 15
valHeadCheckpoints = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    trainValue(dataloader, resNet, valHead, loss, optimizer, epoch, logcount=50, wandb_log=False)

Epoch 1
-------------------------------
Val Loss: 0.132071 [76928/3847720]
Val Loss: 0.119759 [153856/3847720]
Val Loss: 0.120535 [230784/3847720]
Val Loss: 0.117745 [307712/3847720]
Val Loss: 0.118940 [384640/3847720]
Val Loss: 0.118785 [461568/3847720]
