# todo
* increase chunk size - 5k+
* increase games count
* adaptive learning rate
* batch normalization
* focus on mid- and end-game positions
* more comment tips


In [None]:
import os
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.optim.lr_scheduler import OneCycleLR # type: ignore
import math
from torch.utils.data import DataLoader # type: ignore
from chess import pgn # type: ignore
import tqdm # type: ignore
from dataset import ChessDataset
from model import ChessModel
from helper_funcs import collect_unique_moves, create_input_for_nn
from helper_funcs import process_data_and_save_chunks
import pickle
import gc

# Data Processing

## Load data in chunks so memory is not overwhelmed, and store the chunks in a separate folder

In [None]:
# Collect the raw PGN files in a deterministic (sorted) order and keep at
# most the first 28 of them.
pgn_dir = "../data/pgn"
files = sorted(
    os.path.join(pgn_dir, name)
    for name in os.listdir(pgn_dir)
    if name.endswith(".pgn")
)
LIMIT_OF_FILES = min(len(files), 28)
files = files[:LIMIT_OF_FILES]

# Game budget for the whole run, and the chunk size handed to the chunking helper.
max_games = 150000
chunk_size = 15000

# Build the move vocabulary and its size. The size becomes the model's number
# of output classes in the setup cell (presumably move string -> class index).
move_to_int, num_classes = collect_unique_moves(files, max_games=max_games)

# Persist the vocabulary so the same mapping can be reloaded later (e.g. for
# inference).
# NOTE(review): this writes under "../../models" while the checkpoint cell
# saves to "../models" - confirm the two artifacts are meant to live in
# different directories.
with open("../../models/mark3_move_to_int", "wb") as out_file:
    pickle.dump(move_to_int, out_file)

In [None]:
# Hand the selected games and the move vocabulary to the chunking helper.
# Per the heading above, it stores the data on disk in chunks so the whole
# dataset never sits in memory at once. The returned paths are what the
# training cell iterates over and np.load()s one at a time.
data_chunk_files = process_data_and_save_chunks(
    files,
    move_to_int,
    chunk_size=chunk_size,
    max_games=max_games,
)

## setup

In [None]:
# Pick the fastest available backend: Apple-Silicon MPS first, then a CUDA
# GPU, and only then fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS backend on Apple Silicon (M2).")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("MPS backend not available. Using CUDA GPU.")
else:
    device = torch.device('cpu')
    print("MPS backend not available. Using CPU.")

# Model initialization: one output logit per entry in the move vocabulary.
model = ChessModel(num_classes=num_classes).to(device)
# CrossEntropyLoss takes raw (unnormalized) logits and integer class targets,
# matching the torch.long labels built in the training cell.
criterion = nn.CrossEntropyLoss()

# Train

In [None]:
num_epochs = 20

# A single batch size drives both the scheduler's step budget and every
# DataLoader below. They must agree: OneCycleLR raises once it is stepped
# more than total_steps times, and under-counting or over-counting distorts
# the LR schedule.
batch_size = 64

# OneCycleLR needs the total number of optimizer steps up front, so count the
# batches each chunk will yield. DataLoader keeps the final partial batch by
# default (drop_last=False), hence ceil. This assumes ChessDataset's length
# equals the number of rows in X - TODO confirm.
total_batches_per_epoch = 0
for data_chunk_file in data_chunk_files:
    # `with` closes the file handle np.load keeps open for .npz archives
    # (the data['X'] / data['y'] key access implies .npz chunks).
    with np.load(data_chunk_file) as data:
        num_samples = data['X'].shape[0]
    total_batches_per_epoch += math.ceil(num_samples / batch_size)

total_steps = num_epochs * total_batches_per_epoch

# OneCycleLR sets the learning rate itself on every step (warming up towards
# max_lr, then annealing), so the lr given to Adam is effectively a placeholder.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = OneCycleLR(optimizer, max_lr=0.001, total_steps=total_steps)

# Training loop
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    total_batches = 0
    for data_chunk_file in tqdm.tqdm(data_chunk_files, desc=f'Epoch {epoch+1}/{num_epochs}'):
        # Load one chunk at a time so only a single chunk is held in memory;
        # torch.tensor copies the arrays, so the archive can be closed right away.
        with np.load(data_chunk_file) as data:
            X = torch.tensor(data['X'], dtype=torch.float32)
            y = torch.tensor(data['y'], dtype=torch.long)
        # Create Dataset and DataLoader (same batch_size as the step count above)
        dataset = ChessDataset(X, y)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to device
            optimizer.zero_grad()
            outputs = model(inputs)  # Raw logits; CrossEntropyLoss normalizes internally
            loss = criterion(outputs, labels)
            loss.backward()
            # Clip the global gradient norm to 1.0 to damp unusually large updates.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()  # OneCycleLR is stepped per batch, not per epoch
            running_loss += loss.item()
            total_batches += 1
        # Drop this chunk's tensors before the next chunk is loaded.
        del X, y, dataset, dataloader
        gc.collect()

    end_time = time.time()
    epoch_time = end_time - start_time
    minutes = int(epoch_time // 60)
    seconds = int(epoch_time % 60)
    avg_loss = running_loss / total_batches
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Time: {minutes}m{seconds}s')

In [None]:
# Save the model weights (state_dict only; the architecture is rebuilt from
# ChessModel when loading).
# NOTE(review): the file name says 50 epochs ("50e") but the training cell
# sets num_epochs = 20 - confirm which is correct before trusting the label.
# torch.save does not create missing directories, so make sure it exists.
os.makedirs("../models", exist_ok=True)
torch.save(model.state_dict(), "../models/mark3-50e-150k.pth")

# Mark 1 - 10e - 50k
* baseline; 5k chunks, 14-board representation
* 10 minutes per epoch - 50k games

# Mark 2 - 20e - 100k
* 10k chunks
* adaptive learning rate
* probabilistic favouring of mid- and late-game states
* batch normalization after convolutional layer

# Mark 3
* 150k - 15k chunks - 50e
* it's usable :D — forgot to load the mapping, and data is sampled randomly