In [None]:
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms

from Utilities import Utilities as Util

# Pick the compute device: CUDA if present, otherwise Apple's MPS backend,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Basic Function definitions

def train(dataloader, model, loss_fn, optimizer, epoch, logcount=5, wandb_log=False):
    """Run one training epoch over ``dataloader`` and log the running loss.

    Args:
        dataloader: DataLoader-like object exposing ``.dataset`` and ``len()``
            and yielding ``(X, y)`` batches; both are moved to the
            module-level ``device`` before the forward pass.
        model: module to train; switched to training mode for the epoch.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index attached to each wandb log entry.
        logcount: approximate number of loss log lines per epoch.
        wandb_log: when True, each logged loss is also sent to wandb.
    """
    size = len(dataloader.dataset)
    # A loader with fewer batches than ``logcount`` (or logcount <= 0) would
    # otherwise give an interval of 0 and crash on the modulo below.
    loginterval = max(1, len(dataloader) // max(1, logcount))
    running_loss = 0.0
    seen = 0  # samples processed so far this epoch

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        running_loss += loss.item()
        seen += len(X)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Log after every ``loginterval`` completed batches so each reported
        # value is the mean of exactly ``loginterval`` batch losses.
        if (batch + 1) % loginterval == 0:
            log_loss = running_loss / loginterval
            running_loss = 0.0
            print(f"loss: {log_loss:>8f}  [{seen:>5d}/{size:>5d}]")
            if wandb_log:
                wandb.log({"epoch": epoch, "train_loss": log_loss})

def test(dataloader, model, loss_fn, epoch, wandb_log=False):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).detach().item()
    test_loss /= num_batches
    if wandb_log:
        wandb.log({"epoch" : epoch, "test_loss": test_loss})
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")

def parameterCount(model):
    """Return the total number of scalar elements in ``model.parameters()``.

    Each parameter contributes the product of its shape (a 0-dim parameter
    counts as one element), matching a manual shape-product sum.
    """
    # ``Tensor.numel()`` replaces the hand-rolled shape product, whose local
    # counter was named ``nn`` and shadowed the module-level torch.nn alias.
    return sum(p.numel() for p in model.parameters())

In [None]:
class ConvolutionLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    The convolution uses ``padding=(kernal_size - 1) // 2`` with the default
    stride of 1, so odd kernel sizes preserve the spatial size.

    Args:
        infilters: number of input channels.
        outfilters: number of output channels.
        kernal_size: square convolution kernel size (default 3).
    """

    def __init__(self, infilters, outfilters, kernal_size=3):
        super().__init__()
        pad = (kernal_size - 1) // 2
        conv = nn.Conv2d(infilters, outfilters, kernal_size, padding=pad)
        norm = nn.BatchNorm2d(outfilters)
        # The attribute name is part of the state_dict keys
        # ("conv2d_sequential.0.weight", ...), so saved checkpoints rely on it.
        self.conv2d_sequential = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        return self.conv2d_sequential(x)
    
class NeuralNetwork(nn.Module):
    """Model made of a single ``ConvolutionLayer`` (3x3 kernel by default).

    Args:
        inFilters: number of input channels.
        outFilters: number of output channels.
    """

    def __init__(self, inFilters, outFilters):
        super().__init__()
        # The attribute name determines the state_dict prefix ("conv_layer.").
        self.conv_layer = ConvolutionLayer(inFilters, outFilters)

    def forward(self, x):
        return self.conv_layer(x)

In [None]:
# Model Hyperparameters / Config

# Mini-batch size for both loaders, and number of loss log lines per epoch.
BatchSize, LogCount = 128, 5

# Dataset directories: V1 is the legacy encoding (loaded below as 11 planes),
# V2 the newer one (loaded below as 7 planes).
datasetPathV1 = "../../Datasets/HumanExamples/GeneratedDatasets/HD5,TS0.8,RULESETS(1-6, 8-29),Legacy"
datasetPathV2 = "../../Datasets/HumanExamples/GeneratedDatasets/HD6,TS0.8,RULESETS(1-6, 8-29)"

In [None]:
# Load datasets into memory

def _loadPlanes(path, planes):
    """Read a raw bool dump and return it as a float32 tensor of shape (-1, planes, 15, 15)."""
    raw = np.fromfile(path, dtype=bool)
    return torch.from_numpy(raw.astype(np.float32).reshape(-1, planes, 15, 15))


# V1 (legacy) samples have 11 planes, V2 samples have 7.
XtrainV1 = _loadPlanes(f'{datasetPathV1}/XTrain.bin', 11)
XtestV1 = _loadPlanes(f'{datasetPathV1}/XTest.bin', 11)

XtrainV2 = _loadPlanes(f'{datasetPathV2}/XTrain.bin', 7)
XtestV2 = _loadPlanes(f'{datasetPathV2}/XTest.bin', 7)

print(XtrainV1.shape)
print(XtrainV2.shape)

In [None]:
def renderGamestateSliceV1(gamestate, HD, depth):
    """Print an ASCII 15x15 board for one slice of a V1-encoded gamestate.

    Args:
        gamestate: indexable as ``gamestate[plane][x][y]`` with 0/1 entries
            (presumably the 11-plane V1 tensor — confirm).
        HD: history depth. When 1, plane 2 is drawn as black and plane 1 as
            white, ignoring ``depth``.
        depth: otherwise selects plane ``HD * 2 - depth`` (black) and plane
            ``HD - depth`` (white).

    Rows run from y=14 at the top to y=0 at the bottom. Empty squares are
    blank, black stones are drawn as ``B`` and white stones as ``W``.
    """
    if HD == 1:
        blackPlane, whitePlane = 2, 1
    else:
        blackPlane, whitePlane = HD * 2 - depth, HD - depth
    blackStones = gamestate[blackPlane]
    whiteStones = gamestate[whitePlane]

    print("     0   1   2   3   4   5   6   7   8   9  10  11  12  13  14")
    print("   --------------------------------------------------------------")
    for y in reversed(range(15)):
        pieces = [f'{y:2} |']
        for x in range(15):
            if blackStones[x][y] == 0 and whiteStones[x][y] == 0:
                symbol = "   "
            elif blackStones[x][y] == 1:
                symbol = " B "
            elif whiteStones[x][y] == 1:
                symbol = " W "
            else:
                symbol = ""  # unexpected value: nothing drawn, only the divider
            pieces.append(symbol + "|")
        print("".join(pieces), end="")
        print("\n   --------------------------------------------------------------")

In [None]:
# Show the same training sample in both encodings: the legacy V1 tensor via
# renderGamestateSliceV1 (HD=5, depth=3) and the V2 tensor via the project
# helper Util.sliceGamestate (HD=6, slice 1).
# NOTE(review): Util.sliceGamestate lives in the Utilities module (not shown);
# presumably it returns a printable board string — confirm its arguments.
i = 10
print("Old:")
renderGamestateSliceV1(XtrainV1[i], 5, 3)
print("New:")
print(Util.sliceGamestate(XtrainV2[i], 6, 1))

In [None]:
# FIXME: Xtrain, Ytrain, Xtest and Ytest are never defined in this notebook.
# Only XtrainV1/XtestV1/XtrainV2/XtestV2 are loaded, and no label tensors are
# loaded at all, so this cell raises NameError on a fresh kernel.
# Wrap features/labels into datasets; only the training loader is shuffled.
train_dataset = TensorDataset(Xtrain, Ytrain)
train_loader = DataLoader(train_dataset, batch_size=BatchSize, shuffle=True)
test_dataset = TensorDataset(Xtest, Ytest)
test_loader = DataLoader(test_dataset, batch_size=BatchSize, shuffle=False)

In [None]:
# FIXME: NeuralNetwork.__init__ requires (inFilters, outFilters); calling it
# with no arguments raises TypeError. Pass the input plane count and the
# desired output channel count explicitly.
model = NeuralNetwork().to(device)
#model = torch.compile(model)

# Set to True by the wandb init cell; read by the training loop.
wandb_logging = False
loss_fn = nn.CrossEntropyLoss()
# AdamW with PyTorch's default hyperparameters.
optimizer = torch.optim.AdamW(model.parameters())

# Total parameter count (same value parameterCount(model) would return).
paramCount = sum(p.numel() for p in model.parameters())
print(paramCount)

In [None]:
# Validate dataset (recommended before training)

# NOTE(review): the random draw is immediately overwritten by the fixed index 1
# on the next line — keep whichever one is intended.
randomSample = np.random.randint(0, len(Xtrain))
randomSample = 1
# FIXME: renderGamestateSlice and HistoryDepth are not defined in this notebook
# (the defined renderer is renderGamestateSliceV1(gamestate, HD, depth)), and
# Xtrain/Ytrain are undefined as well, so this cell raises NameError.
renderGamestateSlice(Xtrain[randomSample], HistoryDepth)
# The label's argmax is decoded as (index // 15, index % 15); presumably each
# Ytrain row is a flattened 15x15 move distribution — confirm the layout.
index = Ytrain[randomSample].argmax(0)
print(f"Move: {index // 15}, {index % 15}")

In [None]:
# Init wandb tracking
# FIXME: datasetName, HistoryDepth, Filters, Layers and KernalSize are not
# defined anywhere in this notebook, so building this config raises NameError.
# Define them in the config cell or drop them from the dict.
wandb.init(project='TorchGomoku', config={"DatasetName": datasetName, "BatchSize": BatchSize, "LogCount": LogCount, "HistoryDepth" : HistoryDepth, "Filters": Filters, "Layers" : Layers, "KernalSize" : KernalSize}, tags=["AlphaGo"])
# Enable wandb logging in the train/test calls of the training loop.
wandb_logging = True

In [None]:
# Attach a free-text note to the current wandb run. Presumably requires the
# wandb.init cell to have run first (wandb.run is otherwise unset) — confirm.
wandb.run.notes = "Trying no history"

In [None]:
# Start training for specified epochs

epochs = 15
checkpoints = []  # one weight snapshot per epoch, saved by a later cell
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer, epoch, logcount=LogCount, wandb_log=wandb_logging)
    test(test_loader, model, loss_fn, epoch, wandb_log=wandb_logging)
    # state_dict() holds references to the live parameter/buffer tensors, so
    # appending it directly would make every checkpoint equal the final
    # weights. Deep-copy to freeze this epoch's values.
    checkpoints.append(copy.deepcopy(model.state_dict()))
print("Done!")

In [None]:
# Mark the wandb run started by the init cell as finished.
wandb.finish()

In [None]:
# Save the final weights.
# NOTE(review): this path is relative to the working directory and uses a
# different root ("Model/Models") from the "../Models/..." path used for the
# per-epoch checkpoints; presumably torch.save fails if the directory does not
# exist — confirm the intended location.
torch.save(model.state_dict(), "Model/Models/test1noHistory.pt")

In [None]:
# Directory receiving one file per entry of ``checkpoints`` (0.pt, 1.pt, ...).
modelNamePath = "../Models/HumanModels/128f20l3kAllBut7H1"
# torch.save does not create parent directories, and this per-run directory is
# typically new, so create it (no-op if it already exists).
os.makedirs(modelNamePath, exist_ok=True)
for i, checkpoint in enumerate(checkpoints):
    torch.save(checkpoint, f'{modelNamePath}/{i}.pt')