# training overview
	•	Data Chunking: By preprocessing the data and saving it in chunks, we prevent memory overflow and ensure that only necessary data is loaded during training.
	•	Custom Dataset: The ChunkedChessDataset class efficiently accesses individual samples without loading entire datasets into memory.
	•	Batch Normalization: Including batch normalization layers helps stabilize and accelerate training, even with deeper networks.
	
	•	Efficient Data Loading: Loading data in batches and on-the-fly reduces memory usage.
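	•	Minimizing Disk I/O: Preprocessing the data once and saving it as chunk files means training reads ready-to-use chunks instead of re-parsing the raw PGN files, which also removes the need for complex computations during training.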
	•	Balanced Workload: Adjusting num_workers and batch_size ensures optimal CPU and GPU utilization.
	•	Avoiding Repetition: Data is only processed once during preprocessing, avoiding redundant computations in the training loop.


In [1]:
import glob
import os
import time

import numpy as np # type: ignore
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from chess import pgn # type: ignore
from torch.utils.data import DataLoader, random_split # type: ignore
from tqdm import tqdm # type: ignore

from dataset import ChunkedChessDataset
from helper_funcs import encode_moves, preprocess_and_save_data
from model import ChessModel

# Data Processing

## Load data in chunks so that memory is not overwhelmed, and store the chunks in a separate folder

In [2]:
# TODO: record how long this cell takes to run.

# List the PGN files in the data directory.
# glob.glob returns paths in arbitrary order, so sort before slicing; otherwise
# the single file picked for testing can differ between runs and machines.
pgn_files = sorted(glob.glob('../data/pgn/*.pgn'))[:1]  # limited to 1 file for testing
# pgn_files = sorted(glob.glob('../data/pgn/*.pgn'))  # full data set

# Preprocess the selected files and save the result in chunks (chunk_size=10000).
# The original note says this uses probabilistic sampling; that logic lives in
# helper_funcs.preprocess_and_save_data and is not visible here.
# Returns the move -> int mapping and the class count passed to the model later.
move_to_int, num_classes = preprocess_and_save_data(pgn_files, chunk_size=10000)

Collecting unique moves...


Processing PGN Files for Move Collection:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# len(move_to_int) is the number of entries in the move -> int mapping, i.e.
# the size of the move vocabulary, not the number of games that were parsed.
print(f"UNIQUE MOVES: {len(move_to_int)}")

## setup

In [None]:
# Device configuration: use the MPS backend when available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS backend on Apple Silicon (M2).")
else:
    device = torch.device('cpu')
    print("MPS backend not available. Using CPU.")

# Create dataset and split it 80/20 into training and validation subsets.
# random_split lives in torch.utils.data and must be imported next to DataLoader.
dataset = ChunkedChessDataset(data_dir='data_chunks')
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create dataloaders
train_dataloader = DataLoader(train_dataset,
                              batch_size=64,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=False)  # Disabled for MPS backend

# Validation data is not shuffled: order does not affect the averaged loss.
val_dataloader = DataLoader(val_dataset,
                            batch_size=64,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=False)

# Initialize the model, loss function, and optimizer
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Initialize the Learning Rate Scheduler
# Option 1: ReduceLROnPlateau - multiplies the LR by `factor` once the monitored
# value (stepped with the validation loss in the training loop) stops improving
# for `patience` epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version and drop it if the scheduler rejects it.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                 factor=0.1, patience=5,
                                                 verbose=True, threshold=1e-4)

# Option 2: StepLR (Alternative)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Train

In [None]:
# Training loop: each epoch makes one optimisation pass over the training
# split, then one gradient-free pass over the validation split.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    start_time = time.time()
    running_loss = 0.0

    # ---- Training phase ----
    train_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", unit="batch")
    for inputs, labels in train_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the parameter update.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * inputs.shape[0]

    train_loss = running_loss / train_size

    # ---- Validation phase ----
    model.eval()
    val_running_loss = 0.0
    val_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", unit="batch")
    with torch.no_grad():
        for val_inputs, val_labels in val_bar:
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            batch_val_loss = criterion(model(val_inputs), val_labels)
            val_running_loss += batch_val_loss.item() * val_inputs.shape[0]

    val_loss = val_running_loss / val_size

    # The plateau scheduler is driven by the epoch's average validation loss.
    scheduler.step(val_loss)

    # Wall-clock time for the whole epoch (training + validation), as m/s.
    elapsed_time = time.time() - start_time
    minutes, seconds = (int(part) for part in divmod(elapsed_time, 60))

    epoch_summary = (f"Epoch [{epoch+1}/{num_epochs}], "
                     f"Train Loss: {train_loss:.4f}, "
                     f"Val Loss: {val_loss:.4f}, "
                     f"Time: {minutes}m {seconds}s")
    print(epoch_summary)

    # Save a checkpoint of the model weights every fifth epoch.
    is_checkpoint_epoch = (epoch + 1) % 5 == 0
    if is_checkpoint_epoch:
        checkpoint_path = f'chess_model_epoch_{epoch+1}.pth'
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at '{checkpoint_path}'.")

# Persist the weights of the fully trained model.
torch.save(model.state_dict(), 'chess_model_final.pth')
print("Training complete. Final model saved as 'chess_model_final.pth'.")