In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.utils.data import DataLoader, TensorDataset
from data_processing import generate_dataset_from_pgn

In [2]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

dataset = generate_dataset_from_pgn("../pgn_files/LumbrasGigaBase_OTB_2025.pgn", max_games=50000)



Using device: NVIDIA GeForce RTX 4080
Total games: 1000
Total games: 2000
Total games: 3000
Total games: 4000
Total games: 5000
Total games: 6000
Total games: 7000
Total games: 8000
Total games: 9000
Total games: 10000
Total games: 11000
Total games: 12000
Total games: 13000
Total games: 14000
Total games: 15000
Total games: 16000
Total games: 17000
Total games: 18000
Total games: 19000
Total games: 20000
Total games: 21000
Total games: 22000
Total games: 23000
Total games: 24000
Total games: 25000
Total games: 26000
Total games: 27000
Total games: 28000
Total games: 29000
Total games: 30000
Total games: 31000
Total games: 32000
Total games: 33000
Total games: 34000
Total games: 35000
Total games: 36000
Total games: 37000
Total games: 38000
Total games: 39000
Total games: 40000
Total games: 41000
Total games: 42000
Total games: 43000
Total games: 44000
Total games: 45000
Total games: 46000
Total games: 47000
Total games: 48000
Total games: 49000
Total games: 50000
extracting fens with 

In [3]:
train_to_test_ratio = 0.8

train_size = int(len(dataset) * train_to_test_ratio)
test_size = len(dataset) - train_size

# split the dataset
train_data = dataset[:train_size]
test_data = dataset[train_size:]

X_train = torch.stack([board for board, move, winner in train_data])  # (N, 8, 8, 12)
t_train = torch.tensor([(move, winner) for board, move, winner in train_data])  # (N, 2)

X_test = torch.stack([board for board, move, winner in test_data])
t_test = torch.tensor([(move, winner) for board, move, winner in test_data])
# create DataLoaders
batch_size = 4096
train_dataset = TensorDataset(X_train, t_train)
test_dataset = TensorDataset(X_test, t_test)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
from SLPolicyValueGPU import SLPolicyValueNetwork

model = SLPolicyValueNetwork().to(device)
# model.load_state_dict(torch.load("model_files/V3SL_trained.pth", map_location=device))
policy_criterion = nn.CrossEntropyLoss() # softmax regression loss function
value_criterion = nn.MSELoss() # use to use logistic loss but expects labels to be 0 or 1, not a range betwen -1 and 1
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

model.train()

SLPolicyValueNetwork(
  (conv1): Conv2d(13, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (blocks): ModuleList(
    (0-9): 10 x ResNetBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (fc_shared): Linear(in_features=16384, out_features=512, bias=True)
  (fc_policy): Linear(in_features=512, out_features=20480, bias=True)
  (value_conv): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
  (value_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (value_fc1): Linear(in_features=64, out_features=256, bias=True)
  (value_fc2): Linear(in_f

In [5]:
checkpoint = torch.load("V3SL_trained.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"]
start_batch = checkpoint["batch"]
# start_epoch = 0
# start_batch = 0

print(f"Resuming from epoch {start_epoch}, batch {start_batch}")

epochs = 20

for epoch in range(epochs):

    epoch = start_epoch + epoch
    for batch_idx, (data, target) in enumerate(train_dataloader):
        data = data.to(device)
        batch_move_target = target[:, 0].to(device)
        batch_val_target = target[:, 1].float().unsqueeze(1).to(device)

        pred_policy, pred_val = model(data)  # calculate predictions for this batch

        # only calculate policy loss if it's winner's move
        mask = (batch_val_target.view(-1) == 1)
        if mask.sum() > 0:
            policy_loss = policy_criterion(pred_policy[mask], batch_move_target[mask])
        else:
            policy_loss = torch.tensor(0.0, device=device)
        
        value_loss = value_criterion(pred_val, batch_val_target) # calculate loss for value
        loss = policy_loss + value_loss
        optimizer.zero_grad()  # reset gradient
        loss.backward()  # calculate gradient
        optimizer.step()  # update parameters

        if batch_idx % 100 == 0:
            print(f"batch progress: epoch {epoch} {(100 *(batch_idx +1)/len(train_dataloader)):.2f}% loss: {loss.item():.6f}")
            # torch.save({
            #     "model": model.state_dict(),
            #     "optimizer": optimizer.state_dict(),
            #     "epoch": epoch+1, 
            #     "batch": batch_idx,
            # }, "checkpoint2.pth")

        # print(f"batch progress: epoch {epoch+1} {(100 *(batch_idx +1)/len(train_dataloader)):.2f}% loss: {loss.item():.6f}")
        # torch.save({
        #     "model": model.state_dict(),
        #     "optimizer": optimizer.state_dict(),
        #     "epoch": epoch+1,
        #     "batch": batch_idx,
        # }, "checkpoint2.pth")

    # check validation accuracy to see if general patterns are being learnt

    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_dataloader):
            data = data.to(device)
            batch_move_target = target[:, 0].to(device)
            batch_val_target = target[:, 1].float().unsqueeze(1).to(device)

            pred_policy, pred_val = model(data)
            policy_loss = policy_criterion(pred_policy, batch_move_target)  # calculate loss for policy
            value_loss = value_criterion(pred_val, batch_val_target) # calculate loss for value
            loss = policy_loss + value_loss
            test_loss += loss

    total_test_loss = test_loss / len(test_dataloader)
    print('epoch: {}, test loss: {:.6f}'.format(
        epoch + 1,
        total_test_loss,
        ))
    
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch+1, 
        "batch": batch_idx,
    }, "V3SL_trained.pth")
    print('model checkpoint saved')

    if total_test_loss < 1:
        break
    
    model.train()

Resuming from epoch 9, batch 210
batch progress: epoch 9 0.12% loss: 1.540049
batch progress: epoch 9 12.00% loss: 1.419505
batch progress: epoch 9 23.87% loss: 1.469947
batch progress: epoch 9 35.75% loss: 1.540466
batch progress: epoch 9 47.62% loss: 1.559277
batch progress: epoch 9 59.50% loss: 1.600436
batch progress: epoch 9 71.38% loss: 1.595518
batch progress: epoch 9 83.25% loss: 1.668903
batch progress: epoch 9 95.13% loss: 1.699417
epoch: 10, test loss: 3.928909
model checkpoint saved
batch progress: epoch 10 0.12% loss: 1.380563
batch progress: epoch 10 12.00% loss: 1.231666
batch progress: epoch 10 23.87% loss: 1.343590
batch progress: epoch 10 35.75% loss: 1.347462
batch progress: epoch 10 47.62% loss: 1.401613
batch progress: epoch 10 59.50% loss: 1.408080
batch progress: epoch 10 71.38% loss: 1.458766
batch progress: epoch 10 83.25% loss: 1.491179
batch progress: epoch 10 95.13% loss: 1.525310
epoch: 11, test loss: 4.156626
model checkpoint saved
batch progress: epoch 11

KeyboardInterrupt: 