In [None]:
import tqdm
import torch

In [None]:
from src.dataset import load_pgn, load_multiple_pgns, create_csv_dataset
from src.dataclass import ChessDataset
from src.model import ConvModel, ResNet

In [None]:
# Step 1: Build the training dataset
# Load games through load_multiple_pgns (asking for 26 PGNs) and hand them to
# create_csv_dataset, which writes them out under the name "le_first_26".
num_training_pgns = 26
games = load_multiple_pgns(num_pgns=num_training_pgns)
create_csv_dataset(
    games,
    name="le_first_26",
)

In [None]:
# The step above created a CSV file with 2,302,559 chess state-action pairs.
# Since I have many separate datasets, I can use some of them as validation sets.
# This way I can skip performing a train-test split.

In [None]:
# Step 2: Build a validation dataset
# A single PGN file is loaded here and exported under the name "le_2016-01".
validation_pgn_path = "data/pgn/lichess_elite_2016-01.pgn"
games = load_pgn(validation_pgn_path)
create_csv_dataset(
    games,
    name="le_2016-01",
)

In [None]:
# Step 3: Wrap the training CSV in a PyTorch Dataset and iterate over it
# with a DataLoader that yields shuffled mini-batches of 256 samples.
dataset = ChessDataset("data/csv/le_first_26.csv")
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=True,
)

In [None]:
# Step 4: Build the network and place it on the compute device
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# ResNet configured with 64 filters and 4 residual blocks, moved to `device`.
model = ResNet(filters=64, res_blocks=4).to(device)

In [None]:
# Show the model as this cell's output so it can be inspected
# (presumably prints the layer structure if ResNet is a torch.nn.Module — TODO confirm).
model

In [None]:
# Step 5: Pick the training objective and the optimizer
# Cross-entropy loss is used as the criterion; Adam updates every model
# parameter with a learning rate of 0.001.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=0.001,
)

In [None]:
# Fresh start: when no checkpoint is restored, training begins at epoch 0.
# (Running the checkpoint-loading cell replaces this value.)
epoch: int = 0

In [None]:
# Optional: restore the model and optimizer from a previously saved checkpoint.
# Run this cell only when resuming; skip it to train from scratch.
#
# map_location=device remaps the stored tensors onto the device selected for
# this session, so a checkpoint written on a GPU machine can still be loaded
# when only the CPU is available (and vice versa).
checkpoint = torch.load("models/le_first_26.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The checkpoint stores the index of the last finished epoch; resume with the next one.
epoch = checkpoint['epoch'] + 1

In [None]:
# Step 6: Train the model
#
# `epoch` holds the 0-based index of the next epoch to run: 0 for a fresh start,
# or `checkpoint['epoch'] + 1` after a checkpoint was restored. Training runs
# until EPOCHS epochs have been completed in total.
EPOCHS = 50

# NOTE(review): the loss history starts empty every time this cell runs, so a
# resumed run overwrites the saved loss log with only the new epochs — confirm
# whether the earlier losses should be reloaded first.
train_losses = []

# A separate loop variable keeps the resume counter `epoch` intact while an
# epoch is in progress; it is advanced only once the epoch has fully finished,
# so re-running this cell after an interruption repeats the unfinished epoch
# instead of skipping it.
for current_epoch in tqdm.trange(epoch, EPOCHS):
    model.train()
    total_loss = 0

    for boards, labels in loader:
        boards = boards.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(boards)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    avg_loss = total_loss / len(loader)

    # Persist the running loss history after every epoch.
    train_losses.append(avg_loss)
    torch.save({
        "train_losses": train_losses,
    }, "data/loss/loss_log_3.pt")

    print(f"Epoch {current_epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    # Checkpoint every 10th epoch index, and always after the final epoch so
    # the fully trained model is saved even though EPOCHS - 1 is not a multiple of 10.
    is_last_epoch = current_epoch == EPOCHS - 1
    if current_epoch % 10 == 0 or is_last_epoch:
        torch.save({
            # 0-based index of the epoch just completed (the loading cell adds 1).
            'epoch': current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f"models/model.3.{current_epoch}.pth")

        print(f"Checkpoint saved at epoch {current_epoch}")

    # This epoch is done: the next one to run is current_epoch + 1.
    epoch = current_epoch + 1