In [None]:
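# Uncomment on a fresh environment to install the dependencies into the
# active kernel's environment:
# %pip install -r requirements.txt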

In [None]:
# Imports

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import kagglehub as kh
import wandb
from tqdm import tqdm

import matplotlib.pyplot as plt

# Local imports

from src.dataset import PositionsDataset
from src.model import ChessResNetModel
from src.training import create_dataloaders, training
from src.early_stopping import EarlyStopping

In [None]:
# Device configuration and reproducibility
#
# Seed torch's global random number generator before any tensor or model is
# created, so that a fresh "Restart Kernel -> Run All" starts from the same
# random state. Project-local code may draw from other RNG sources (e.g.
# `random` or numpy) that are not seeded here, and some GPU kernels are
# non-deterministic, so treat this as best-effort reproducibility.
SEED = 42
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA first, then Apple's MPS backend,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Download the parquet
#
# Ask kagglehub for one specific file of the dataset rather than the whole
# dataset; the value it returns is stored in `parquet_path` and handed to the
# dataset class in the next cell.
dataset_handle = "lichess/chess-evaluations"
parquet_file = "train-00000-of-00013.parquet"
parquet_path = kh.dataset_download(handle=dataset_handle, path=parquet_file)

In [None]:
# Dataset and Dataloaders

# Build the project's PositionsDataset from the parquet file obtained by the
# kagglehub download; `dataset` is what the dataloaders are created from.
dataset = PositionsDataset(parquet_path)

In [None]:
# Create the training and validation loaders from the dataset.
# NOTE(review): 2048 is passed positionally as the second argument --
# presumably the batch size; confirm against src.training.create_dataloaders.
dataloaders = create_dataloaders(dataset, 2048)
training_dataloader, validation_dataloader = dataloaders

In [None]:
# Test device configuration
#
# Sanity check: report the selected device and backend availability, then
# confirm that both a tensor and a model instance actually land on `device`.
print(f"Device being used: {device}")
for backend_label, availability_probe in (
    ("CUDA", torch.cuda.is_available),
    ("MPS", torch.backends.mps.is_available),
):
    print(f"{backend_label} available: {availability_probe()}")

# Test tensor creation on device
probe_tensor = torch.randn(3, 3)
probe_tensor = probe_tensor.to(device)
print(f"Test tensor device: {probe_tensor.device}")

# Test model device
probe_model = ChessResNetModel().to(device)
model_device = next(probe_model.parameters()).device
print(f"Model device: {model_device}")

# Clean up: drop the throwaway objects so they do not linger in the kernel.
del probe_tensor, probe_model

In [None]:
# Start training

# Model and optimisation objects. The model is moved to `device` before the
# optimizer is built from its parameters.
model = ChessResNetModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Plateau scheduler configured with factor=0.5 and patience=5; which metric it
# is stepped on is decided inside `training`, not in this cell.
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

# Three loss functions, in this order: cross-entropy, MSE, cross-entropy.
# NOTE(review): presumably one per model output head -- confirm against
# src.training.training / ChessResNetModel.
criterion = [nn.CrossEntropyLoss(), nn.MSELoss(), nn.CrossEntropyLoss()]

# Project-local early-stopping helper and the directory handed to `training`
# as `save_dir` (the current working directory).
early_stopping = EarlyStopping(patience=10)
save_dir = "."

# Collect every argument of the run in one place, then launch it. `training`
# returns two values, kept as the training and validation loss histories that
# the plotting cell consumes.
training_kwargs = dict(
    model=model,
    epochs=100,
    train_loader=training_dataloader,
    val_loader=validation_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    early_stopping=early_stopping,
    criterion=criterion,
    save_dir=save_dir,
    device=device,
)
train_losses, validation_losses = training(**training_kwargs)

In [None]:
# Plot training and validation loss
#
# Draw both loss histories returned by the training run on a single set of
# axes, one point per entry, so the two curves can be compared directly.
fig, ax = plt.subplots(figsize=(10, 5))
loss_curves = (
    (train_losses, "Training Loss"),
    (validation_losses, "Validation Loss"),
)
for loss_history, curve_label in loss_curves:
    ax.plot(loss_history, label=curve_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
plt.show()