## Training a ChessNet42069 model

In [None]:
import time, os
import wandb

import chess
import chess.pgn

import torch
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from torch.utils.data import random_split

from OBM_ChessNetwork import Chess42069Network

import sys
sys.path.append('../../chess_utils')
from chess_dataset import ChessDataset
from utils import RunningAverage
from adversarial_gym.chess_env import ChessEnv


In [None]:
# Alternative (combined) dataset, kept for reference:
# PGN_FILE = "/home/kage/chess_workspace/PGN-data/pgncombined/COMBINED.pgn"
PGN_FILE = "/home/kage/chess_workspace/PGN-data/alphazero_stockfish_all/alphazero_vs_stockfish_all.pgn"
NUM_EPOCH = 1

# Start the timer before building the dataset so the reported time covers
# dataset construction, the split, and loader creation.
t1 = time.perf_counter()

# Build the dataset from the PGN file and report how many samples it holds.
chess_dataset = ChessDataset(PGN_FILE)
total_samples = len(chess_dataset)
print(total_samples)

# 90% of the samples go to training; the remainder is used for validation.
train_ratio = 0.9
train_size = int(train_ratio * total_samples)
val_size = total_samples - train_size

# Randomly partition the dataset into the two subsets.
train_dataset, val_dataset = random_split(chess_dataset, [train_size, val_size])

# Both loaders share every setting except shuffling: training batches are
# shuffled, validation batches are served in a fixed order.
loader_options = {"batch_size": 16, "pin_memory": False, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f"Loaded dataset in {time.perf_counter() - t1} seconds")


Training a ChessNetwork (values and actions)

In [None]:
# Initialize the model, resuming from a checkpoint when one exists.
MODEL_PATH = 'ChessNet42069.pt'

model = Chess42069Network(hidden_dim=256)
if os.path.exists(MODEL_PATH):
    # f-string so the actual checkpoint path is printed (the original
    # message lacked the `f` prefix and printed "{MODEL_PATH}" literally).
    print(f"Loading model at: {MODEL_PATH}")
    # Load the tensors onto the CPU first: the model has not been moved to
    # its target device yet, and this lets a checkpoint saved from a GPU run
    # be restored on a machine without CUDA.
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

# Move the (possibly restored) model to the GPU when available.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def run_validation(model, val_loader, stats):
    """Evaluate ``model`` on ``val_loader`` and return the mean combined loss.

    The model is put in eval mode for the duration of the pass and gradient
    tracking is disabled. The mode the model was in on entry (train or eval)
    is restored before returning, so calling this in the middle of a training
    epoch does not silently leave the model in eval mode.

    Args:
        model: Network whose forward pass returns ``(policy_output,
            value_output)`` and which provides ``policy_loss`` and
            ``val_loss`` methods.
        val_loader: Iterable yielding ``(state, action, result)`` batches.
        stats: Tracker exposing ``reset(name)``, ``update(name, value)`` and
            ``get_average(name)``; its ``"val_loss"`` entry is reset at the
            start of each call.

    Returns:
        ``stats.get_average("val_loss")`` after all batches were processed.
    """
    # Resolve the target device once instead of three times per batch.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    was_training = model.training
    model.eval()
    stats.reset("val_loss")
    t1 = time.perf_counter()
    try:
        with torch.no_grad():
            for state, action, result in val_loader:
                state = state.float().to(device)
                action = action.to(device)
                result = result.float().to(device)

                # NOTE(review): .squeeze() with no dim also drops a batch
                # dimension of size 1 (e.g. a short final batch) - confirm
                # the loss functions tolerate that.
                policy_output, value_output = model(state.unsqueeze(1))
                policy_loss = model.policy_loss(policy_output.squeeze(), action)
                value_loss = model.val_loss(value_output.squeeze(), result)

                loss = policy_loss + value_loss
                stats.update("val_loss", loss.item())
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)

    mean_loss = stats.get_average('val_loss')
    print(f"Mean Validation Loss: {mean_loss}, time elapsed: {time.perf_counter()-t1} seconds")
    return mean_loss

In [None]:
NUM_EPOCH = 10

# Resolve the target device once instead of three times per batch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

stats = RunningAverage()
stats.add(["train_loss", "val_loss", "train_p_loss", "train_v_loss"])
grad_scaler = GradScaler()

wandb.init(project='Chess')
for epoch in range(NUM_EPOCH):
    model.train()
    t1 = time.perf_counter()
    for i, (state, action, result) in enumerate(train_loader):
        state = state.float().to(device)
        action = action.to(device)
        result = result.float().to(device)

        # Forward pass and loss computation under mixed precision.
        with autocast():
            policy_output, value_output = model(state.unsqueeze(1))
            policy_loss = model.policy_loss(policy_output.squeeze(), action)
            value_loss = model.val_loss(value_output.squeeze(), result)
            loss = policy_loss + value_loss

        # AMP with gradient clipping: gradients must be unscaled before
        # clipping so max_norm applies to their true magnitude.
        model.optimizer.zero_grad()
        grad_scaler.scale(loss).backward()
        grad_scaler.unscale_(model.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        grad_scaler.step(model.optimizer)
        grad_scaler.update()

        stats.update({
            "train_loss": loss.item(),
            "train_p_loss": policy_loss.item(),
            "train_v_loss": value_loss.item()
            })

        if i % 1000 == 0:
            print(f"Epoch: {epoch}, Iter: {i}, Mean Loss: {stats.get_average('train_loss')}")
            wandb.log({"train_loss": stats.get_average('train_loss'), "iter": i})
        if i % 20_000 == 0 and i > 0:
            # run_validation requires the stats tracker (the original call
            # omitted it and raised TypeError); it prints its own summary,
            # so no second print is needed here.
            valid_loss = run_validation(model, val_loader, stats)
            wandb.log({"val_loss": valid_loss, "iter": i})
            # run_validation switches the model to eval mode; switch back so
            # the remainder of the epoch still trains in train mode.
            model.train()

    print(f"Epoch took {time.perf_counter()-t1} seconds ")
    # Validate at the end of every epoch so the logged value is always
    # defined and fresh, even when the epoch is shorter than the 20,000
    # iteration interval used for mid-epoch validation.
    valid_loss = run_validation(model, val_loader, stats)
    wandb.log({"val_loss": valid_loss, "iter": i})
    torch.save(model.state_dict(), MODEL_PATH)