In [1]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from datetime import datetime

from src.dataset import WhiteStandardChessDataset
from src.model import ChessModel
from src.training import train_model
from src.training import EarlyStop

## Prepare data and model

In [2]:
# Setup cell: builds the dataset split, the data loaders, the model, the loss,
# the optimizer and the early-stopping helper used by the training cell.

# Tunable settings, gathered here so a reader can find them in one place.
SEED = 42
DATA_DIR = Path("./data/last/")
TRAIN_VAL_SPLIT = [0.8, 0.2]
BATCH_SIZE = 64
LEARNING_RATE = 0.0001
EARLY_STOP_PATIENCE = 3

# Seed the global RNG so weight initialisation and the training shuffle order
# are repeatable after "Restart Kernel -> Run All".
torch.manual_seed(SEED)

# Create Dataset and DataLoader
dataset = WhiteStandardChessDataset(DATA_DIR)

# A dedicated, seeded generator pins the train/validation membership, so the
# same samples end up in the validation set on every run.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, TRAIN_VAL_SPLIT, generator=split_generator
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation data is only evaluated, never trained on: shuffling brings no
# benefit and a fixed order keeps validation passes comparable across epochs.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Check for GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model Initialization
model = ChessModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Passed to train_model as `early_stop`; the patience value is the number of
# epochs given to EarlyStop (its exact semantics live in src.training).
early_stopper = EarlyStop(patience=EARLY_STOP_PATIENCE)

Files Processed:   0%|          | 0/10 [00:00<?, ?it/s]

data/last/lichess_elite_2024-04.pgn:   0%|          | 0/268460 [00:00<?, ?it/s]

data/last/lichess_elite_2024-05.pgn:   0%|          | 0/276494 [00:00<?, ?it/s]

data/last/lichess_elite_2024-06.pgn:   0%|          | 0/253321 [00:00<?, ?it/s]

data/last/lichess_elite_2024-02.pgn:   0%|          | 0/263816 [00:00<?, ?it/s]

data/last/lichess_elite_2024-03.pgn:   0%|          | 0/278293 [00:00<?, ?it/s]

data/last/lichess_elite_2024-01.pgn:   0%|          | 0/298095 [00:00<?, ?it/s]

data/last/lichess_elite_2023-09.pgn:   0%|          | 0/291787 [00:00<?, ?it/s]

data/last/lichess_elite_2023-10.pgn:   0%|          | 0/285920 [00:00<?, ?it/s]

data/last/lichess_elite_2023-11.pgn:   0%|          | 0/284773 [00:00<?, ?it/s]

data/last/lichess_elite_2023-12.pgn:   0%|          | 0/315135 [00:00<?, ?it/s]

  X = torch.frombuffer(buffer[:self.X_nbytes], dtype=torch.int8).view(self.X_shape)


Using device: cpu


## Train model

In [3]:
# Training run configuration.
EPOCHS = 50

# Tag the experiment with the minute the run was launched so that repeated
# runs get distinct experiment names.
run_started = datetime.now()
timestamp = run_started.strftime("%Y%m%d_%H%M")
EXPERIMENT_NAME = f"chess_white_14dim_run_{timestamp}"

# Collect every argument for the training routine in one mapping; all of the
# objects referenced here are created in the setup cell above.
train_kwargs = {
    "training_loader": train_dataloader,
    "validation_loader": val_dataloader,
    "model": model,
    "optimizer": optimizer,
    "loss_fn": criterion,
    "epochs": EPOCHS,
    "experiment_name": EXPERIMENT_NAME,
    "device": device,
    "early_stop": early_stopper,
}

train_model(**train_kwargs)

Training Epoch 1:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 5.379398063182831 valid 5.379934787750244


Training Epoch 2:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 5.156546977519989 valid 5.175371170043945


Training Epoch 3:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 5.018446099281311 valid 5.094732761383057


Training Epoch 4:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.9104393205642705 valid 5.03961706161499


Training Epoch 5:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.845424025058747 valid 5.010734558105469


Training Epoch 6:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.774623353004456 valid 5.003129005432129


Training Epoch 7:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.706614449501037 valid 4.996741771697998


Training Epoch 8:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.670058840751648 valid 4.998325347900391


Training Epoch 9:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.621377175569534 valid 5.004258632659912


Training Epoch 10:   0%|          | 0/12978 [00:00<?, ?it/s]

Calculating vloss:   0%|          | 0/3245 [00:00<?, ?it/s]

LOSS train 4.5796228530406955 valid 5.011059761047363
Early stopping! vloss not decreasing for 3 epochs.
