In [None]:
# Imports

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import kagglehub as kh
import wandb
from tqdm import tqdm

# Local imports

from src.dataset import PositionsDataset
from src.model import ChessResNetModel
from src.training import create_dataloaders, training

In [None]:
# Download the parquet
# Fetch only the first of the dataset's 13 training shards; kagglehub hands back
# the local path of the downloaded file, which the next cell reads.
LICHESS_HANDLE = "lichess/chess-evaluations"
SHARD_FILE = "train-00000-of-00013.parquet"
parquet_path = kh.dataset_download(handle=LICHESS_HANDLE, path=SHARD_FILE)

In [None]:
# Dataset and Dataloaders

# Batch size handed to create_dataloaders (presumably shared by both loaders it returns).
BATCH_SIZE = 2048

# Wrap the downloaded parquet shard in a dataset, then build the
# training and validation loaders from it.
dataset = PositionsDataset(parquet_path)
(
    training_dataloader,
    validation_dataloader,
) = create_dataloaders(dataset, batch_size=BATCH_SIZE)

In [None]:
# Start training

# Hyperparameters for this run, collected in one place so they are easy to tune.
SEED = 0
LEARNING_RATE = 1e-3
EPOCHS = 10
LR_DECAY_FACTOR = 0.5
LR_PATIENCE = 10
SAVE_DIR = "."

# Seed torch's global RNG so the model's initial weights are reproducible
# on Restart & Run All.
torch.manual_seed(SEED)

model = ChessResNetModel()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The `verbose` argument is deprecated for PyTorch LR schedulers (>= 2.2) and
# triggers a deprecation warning; inspect optimizer.param_groups[0]["lr"] to
# monitor the learning rate instead.
# NOTE(review): ReduceLROnPlateau only decays after MORE than `patience` steps
# without improvement. With LR_PATIENCE == EPOCHS, the LR can never be reduced if
# `training` calls scheduler.step() once per epoch. Confirm how `training` steps it.
scheduler = ReduceLROnPlateau(optimizer, factor=LR_DECAY_FACTOR, patience=LR_PATIENCE)

# One loss per model output. The assumption is that the order matches what
# ChessResNetModel returns / what `training` expects (CE, MSE, CE). TODO confirm.
criterion = [
    nn.CrossEntropyLoss(),
    nn.MSELoss(),
    nn.CrossEntropyLoss(),
]
save_dir = SAVE_DIR

# NOTE(review): `validation_dataloader` from the dataset cell is never passed in.
# If `training` steps the plateau scheduler on training loss, consider wiring
# validation in. Check the `training` signature first.
training(
    model=model,
    epochs=EPOCHS,
    train_loader=training_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    save_dir=save_dir,
)
