In [13]:
import os

import torch
import tqdm

In [10]:
from src.dataset import load_pgn, load_multiple_pgns, create_csv_dataset
from src.dataclass import ChessDataset
from src.model import ChessNet

In [None]:
# Step 1: Create my training dataset
# Parse NUM_TRAIN_PGNS PGN files and write all of their (state, action) pairs
# into a single CSV dataset named "le_first_26".
# NOTE(review): the progress output of this cell reported 27 iterations for
# num_pgns=26 — confirm the helpers do not read one file more than requested.
NUM_TRAIN_PGNS = 26

games = load_multiple_pgns(num_pgns=NUM_TRAIN_PGNS)
create_csv_dataset(games, name="le_first_26")  # keep the name in sync with NUM_TRAIN_PGNS

27it [01:11,  2.64s/it]                        


In [4]:
# The step above created a CSV file with 2,302,559 chess state-action pairs.
# Because the games are spread across many separate PGN files, whole files can be held out as validation sets.
# This way I can skip performing a random train/test split.

In [None]:
# Step 2: Create a validation dataset
# Convert one held-out PGN file (lichess_elite_2016-01) into its own CSV so it
# can be evaluated separately from the training files.
# NOTE(review): this validation CSV is not loaded or evaluated in the training
# loop below — add an evaluation pass if it is meant to be used.
VAL_PGN_PATH = "data/pgn/lichess_elite_2016-01.pgn"

games = load_pgn(VAL_PGN_PATH)
create_csv_dataset(games, name="le_2016-01")

In [9]:
# Step 3: Load the data using the pytorch Dataset and Dataloader classes
# Wrap the training CSV in a Dataset and serve shuffled mini-batches from it.
# NOTE(review): Step 1 writes its CSV with name="le_first_26", but this path
# points at "dataset_first_26.csv" — confirm this is the file create_csv_dataset()
# actually produced, not a stale file from an earlier run.
TRAIN_CSV_PATH = "data/csv/dataset_first_26.csv"
BATCH_SIZE = 256

dataset = ChessDataset(TRAIN_CSV_PATH)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [11]:
# Step 4: Initialize the model
# Seed torch's global RNG before the model is built so that weight
# initialisation is repeatable on Restart & Run All. This assumes ChessNet uses
# torch's default initialisers; confirm.
# The Step 3 DataLoader has no explicit generator, so its shuffle order is also
# drawn from this global RNG. Some CUDA kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is available; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessNet().to(device)

In [12]:
# Step 5: Choose a loss function and the optimizer
# Multi-class cross-entropy over the model's outputs, minimised with Adam.
# NOTE(review): CrossEntropyLoss applies log-softmax itself, so ChessNet is
# expected to return raw, unnormalised scores; confirm it has no final softmax.
LEARNING_RATE = 1e-3

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [14]:
# Step 6: Train the model
# Each epoch runs one full pass over `loader` and reports the mean per-sample
# training loss. It then overwrites a single checkpoint file with the latest
# model and optimizer state.
# NOTE(review): the validation CSV from Step 2 is not evaluated here.
EPOCHS = 5
CHECKPOINT_PATH = "models/le_first_26.pth"

# torch.save does not create missing parent directories, so create it up front.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

for epoch in tqdm.trange(EPOCHS):
    model.train()
    total_loss = 0.0

    for boards, labels in loader:
        boards = boards.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(boards)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion uses the default reduction="mean" (batch average), so weight
        # it by the batch size. Otherwise a smaller final batch skews the epoch
        # average. The default collate stacks samples along dim 0.
        total_loss += loss.item() * boards.size(0)

    avg_loss = total_loss / len(loader.dataset)
    # tqdm.write prints above the progress bar instead of breaking it like print().
    tqdm.tqdm.write(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    # Overwritten every epoch: the file always holds the most recent state.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, CHECKPOINT_PATH)

 20%|██        | 1/5 [06:20<25:21, 380.47s/it]

Epoch 1/5, Loss: 4.1496


 40%|████      | 2/5 [12:35<18:51, 377.33s/it]

Epoch 2/5, Loss: 3.3603


 60%|██████    | 3/5 [18:47<12:30, 375.07s/it]

Epoch 3/5, Loss: 3.1971


 80%|████████  | 4/5 [25:02<06:14, 374.89s/it]

Epoch 4/5, Loss: 3.1045


100%|██████████| 5/5 [31:16<00:00, 375.32s/it]

Epoch 5/5, Loss: 3.0404



