In [None]:
# Imports

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import kagglehub as kh
import wandb
from tqdm import tqdm

# Local imports

from src.dataset import PositionsDataset
from src.model import ChessResNetModel
from src.training import create_dataloaders, training
from src.preprocess import preprocess_data_in_chunks

In [None]:
# Device configuration
# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Reproducibility: seed torch's RNG (torch.manual_seed seeds all devices) here,
# in the config cell, so model weight initialisation and anything random done
# later (e.g. inside create_dataloaders) is repeatable on Restart & Run All.
# NOTE(review): if src.* code also draws from `random` or NumPy, seed those too.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Download the parquet
# Fetch a single shard (file 0 of 13, per its name) of the Lichess evaluations
# dataset; kh.dataset_download returns the local path used by later cells.
kaggle_handle = "lichess/chess-evaluations"
shard_name = "train-00000-of-00013.parquet"
parquet_path = kh.dataset_download(handle=kaggle_handle, path=shard_name)

In [None]:
# Preprocess Data
# Preprocessing is expensive, so it is skipped when `destination_path` is already
# a non-empty directory from an earlier run. If the output is not a directory,
# the guard never triggers and preprocessing runs every time, as before.
# Set FORCE_PREPROCESS = True to rebuild. Do this after an interrupted run,
# which may leave partial output, or after switching to a different parquet shard.

destination_path = "processed_data"
FORCE_PREPROCESS = False

_has_cached_output = os.path.isdir(destination_path) and bool(os.listdir(destination_path))

if FORCE_PREPROCESS or not _has_cached_output:
    preprocess_data_in_chunks(parquet_path, destination_path)
else:
    print(f"Reusing existing preprocessed data in {destination_path!r} (set FORCE_PREPROCESS = True to rebuild)")

In [None]:
# Dataset and Dataloaders
# Wrap the preprocessed output at `destination_path` in a PositionsDataset,
# then unpack the (training, validation) loader pair returned by create_dataloaders.
dataset = PositionsDataset(
    data_path=destination_path,
)
(
    training_dataloader,
    validation_dataloader,
) = create_dataloaders(dataset)


In [None]:
# Test device configuration
# Sanity check: a tensor and a freshly built model should both land on `device`.
# Compare `.type` rather than the device objects themselves. An indexed device
# such as "cuda:0" does not compare equal to the unindexed torch.device("cuda").
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"MPS available: {torch.backends.mps.is_available()}")

test_tensor = torch.randn(3, 3).to(device)
assert test_tensor.device.type == device.type, (
    f"test tensor landed on {test_tensor.device}, expected {device}"
)

test_model = ChessResNetModel().to(device)
model_device = next(test_model.parameters()).device
assert model_device.type == device.type, (
    f"model parameters landed on {model_device}, expected {device}"
)
print(f"Test tensor and model both on: {model_device}")

# Drop the throwaway objects. On CUDA, also release the allocator cache they
# occupied, so the real model below starts from a clean slate.
del test_tensor, test_model, model_device
if device.type == "cuda":
    torch.cuda.empty_cache()

In [None]:
# Start training
# Hyperparameters are named once here so they can be tuned without hunting
# through the calls below.
EPOCHS = 10
LEARNING_RATE = 1e-3
LR_DECAY_FACTOR = 0.5
# NOTE(review): ReduceLROnPlateau lowers the LR only after more than `patience`
# non-improving scheduler steps. If `training` calls scheduler.step() once per
# epoch, patience=10 with EPOCHS=10 means the LR can never be reduced. Confirm
# how src.training steps the scheduler, and lower this (e.g. 2-3) if it is per-epoch.
LR_PATIENCE = 10

model = ChessResNetModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, factor=LR_DECAY_FACTOR, patience=LR_PATIENCE)
# One loss per entry. The order must match what `training` expects; presumably
# one per model output head, so confirm against src.training / src.model.
criterion = [
    nn.CrossEntropyLoss(),
    nn.MSELoss(),
    nn.CrossEntropyLoss(),
]
save_dir = "."  # passed straight through to training() as save_dir

training(
    model=model,
    epochs=EPOCHS,
    train_loader=training_dataloader,
    val_loader=validation_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    save_dir=save_dir,
    device=device,
)