In [1]:
import tqdm
import torch

In [2]:
from src.dataset import load_pgn, load_multiple_pgns, create_csv_dataset
from src.dataclass import ChessDataset
from src.model import ChessNet

In [None]:
# Step 1: Create my training dataset
# Load games from 26 PGN files with the project helper, then export them as a
# CSV dataset named "le_first_26".
# NOTE(review): presumably this writes data/csv/le_first_26.csv, the file read
# in Step 3 — confirm against create_csv_dataset.
games = load_multiple_pgns(num_pgns=26)
create_csv_dataset(games, name="le_first_26")

27it [01:11,  2.64s/it]                        


In [4]:
# The step above created a CSV file with 2,302,559 chess state-action pairs.
# Since I have many different datasets, I can use some of them as validation sets.
# This way I can skip performing a train/test split.

In [None]:
# Step 2: Create a validation dataset
# Load the games of a single PGN file and export them as a CSV dataset named
# "le_2016-01", to be used for validation instead of a train/test split.
# NOTE(review): whether this file is disjoint from the 26 training PGNs is not
# visible here — confirm against load_multiple_pgns.
games = load_pgn("data/pgn/lichess_elite_2016-01.pgn")
create_csv_dataset(games, name="le_2016-01")

In [3]:
# Step 3: Wrap the training CSV in the project's Dataset class and serve it
# through a PyTorch DataLoader in shuffled mini-batches of 256 samples
dataset = ChessDataset("data/csv/le_first_26.csv")
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=256,
)

In [4]:
# Step 4: Build the network and place it on the GPU when CUDA is available,
# otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = ChessNet().to(device)

In [5]:
# Step 5: Set up the Adam optimizer over the model's parameters and the
# cross-entropy loss used as the training objective
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [6]:
# Starting epoch for a fresh run: training begins at epoch 0 unless the
# checkpoint-loading cell below is executed, which overwrites this value with
# the epoch to resume from.
epoch = 0

In [7]:
# Resume from a previous state: restore the model weights, the optimizer state
# and the epoch counter from a saved checkpoint. Skip this cell for a fresh run.
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load("models/le_first_26.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The checkpoint stores the 0-based index of the last finished epoch, so
# training continues with the following one.
epoch = checkpoint['epoch'] + 1

In [None]:
# Step 6: Train the model
# Runs from the current value of `epoch` (0 for a fresh run, or the value set
# by the checkpoint-loading cell) up to EPOCHS. After each epoch the average
# batch loss is logged, and checkpoints are written periodically.
EPOCHS = 50
# NOTE(review): the loss history restarts empty on every run of this cell, so
# resuming from a checkpoint overwrites data/loss/loss_log_1.pt with only the
# losses of the resumed epochs.
train_losses = []
for epoch in tqdm.trange(epoch, EPOCHS):
    model.train()
    total_loss = 0

    for boards, labels in loader:
        boards = boards.to(device)
        labels = labels.to(device)

        # Standard optimisation step: clear old gradients, forward pass,
        # compute the loss, backpropagate and update the weights.
        optimizer.zero_grad()
        outputs = model(boards)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(loader)

    # Persist the loss history after every epoch so it survives an interruption.
    train_losses.append(avg_loss)
    torch.save({
        "train_losses": train_losses,
    }, "data/loss/loss_log_1.pt")

    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    # Checkpoint every 10th epoch (0-based index) and always after the final
    # epoch; without the second condition the last epoch (index EPOCHS - 1)
    # is never saved and the fully trained weights are lost.
    if epoch % 10 == 0 or epoch == EPOCHS - 1:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f"models/le_first_26_{epoch}.pth")

        print(f"Checkpoint saved at epoch {epoch}")

  2%|▏         | 1/44 [06:03<4:20:11, 363.07s/it]

Epoch 7/50, Loss: 2.9509


  5%|▍         | 2/44 [12:10<4:16:02, 365.77s/it]

Epoch 8/50, Loss: 2.9173


  7%|▋         | 3/44 [18:18<4:10:35, 366.72s/it]

Epoch 9/50, Loss: 2.8877


  9%|▉         | 4/44 [24:26<4:04:48, 367.21s/it]

Epoch 10/50, Loss: 2.8623


 11%|█▏        | 5/44 [30:33<3:58:40, 367.19s/it]

Epoch 11/50, Loss: 2.8393
Checkpoint saved at epoch 10


 14%|█▎        | 6/44 [36:35<3:51:28, 365.48s/it]

Epoch 12/50, Loss: 2.8188


 16%|█▌        | 7/44 [42:41<3:45:25, 365.54s/it]

Epoch 13/50, Loss: 2.8007


 18%|█▊        | 8/44 [48:46<3:39:18, 365.50s/it]

Epoch 14/50, Loss: 2.7834


 20%|██        | 9/44 [54:52<3:33:16, 365.62s/it]

Epoch 15/50, Loss: 2.7675


 23%|██▎       | 10/44 [1:00:58<3:27:12, 365.66s/it]

Epoch 16/50, Loss: 2.7513


 25%|██▌       | 11/44 [1:07:04<3:21:09, 365.73s/it]

Epoch 17/50, Loss: 2.7370


 27%|██▋       | 12/44 [1:13:09<3:14:58, 365.57s/it]

Epoch 18/50, Loss: 2.7235


 30%|██▉       | 13/44 [1:19:15<3:08:52, 365.55s/it]

Epoch 19/50, Loss: 2.7116


 32%|███▏      | 14/44 [1:25:21<3:02:50, 365.68s/it]

Epoch 20/50, Loss: 2.7002


 34%|███▍      | 15/44 [1:31:27<2:56:51, 365.91s/it]

Epoch 21/50, Loss: 2.6892
Checkpoint saved at epoch 20


 36%|███▋      | 16/44 [1:37:34<2:50:50, 366.10s/it]

Epoch 22/50, Loss: 2.6787


 39%|███▊      | 17/44 [1:43:39<2:44:40, 365.93s/it]

Epoch 23/50, Loss: 2.6692


 41%|████      | 18/44 [1:49:45<2:38:33, 365.89s/it]

Epoch 24/50, Loss: 2.6604


 43%|████▎     | 19/44 [1:55:51<2:32:30, 366.00s/it]

Epoch 25/50, Loss: 2.6514


 45%|████▌     | 20/44 [2:01:53<2:25:53, 364.73s/it]

Epoch 26/50, Loss: 2.6438


 48%|████▊     | 21/44 [2:07:55<2:19:29, 363.89s/it]

Epoch 27/50, Loss: 2.6361


 50%|█████     | 22/44 [2:13:56<2:13:08, 363.11s/it]

Epoch 28/50, Loss: 2.6288


 52%|█████▏    | 23/44 [2:19:57<2:06:49, 362.35s/it]

Epoch 29/50, Loss: 2.6222


 55%|█████▍    | 24/44 [2:25:58<2:00:40, 362.02s/it]

Epoch 30/50, Loss: 2.6162


 57%|█████▋    | 25/44 [2:31:59<1:54:35, 361.84s/it]

Epoch 31/50, Loss: 2.6105
Checkpoint saved at epoch 30


 59%|█████▉    | 26/44 [2:37:59<1:48:19, 361.09s/it]

Epoch 32/50, Loss: 2.6050


 61%|██████▏   | 27/44 [2:43:59<1:42:16, 360.94s/it]

Epoch 33/50, Loss: 2.5989


 64%|██████▎   | 28/44 [2:49:58<1:36:04, 360.30s/it]

Epoch 34/50, Loss: 2.5938


 66%|██████▌   | 29/44 [2:55:59<1:30:08, 360.59s/it]

Epoch 35/50, Loss: 2.5884


 68%|██████▊   | 30/44 [3:01:58<1:24:00, 360.05s/it]

Epoch 36/50, Loss: 2.5830


 70%|███████   | 31/44 [3:07:59<1:18:04, 360.34s/it]

Epoch 37/50, Loss: 2.5787


 73%|███████▎  | 32/44 [3:13:57<1:11:55, 359.61s/it]

Epoch 38/50, Loss: 2.5755


 75%|███████▌  | 33/44 [3:19:54<1:05:45, 358.72s/it]

Epoch 39/50, Loss: 2.5708


 77%|███████▋  | 34/44 [3:25:52<59:46, 358.66s/it]  

Epoch 40/50, Loss: 2.5667


 80%|███████▉  | 35/44 [3:31:52<53:49, 358.82s/it]

Epoch 41/50, Loss: 2.5626
Checkpoint saved at epoch 40


 82%|████████▏ | 36/44 [3:37:53<47:57, 359.73s/it]

Epoch 42/50, Loss: 2.5596


 84%|████████▍ | 37/44 [3:43:53<41:57, 359.71s/it]

Epoch 43/50, Loss: 2.5563


 86%|████████▋ | 38/44 [3:49:53<35:59, 359.85s/it]

Epoch 44/50, Loss: 2.5530


 89%|████████▊ | 39/44 [3:55:54<30:00, 360.08s/it]

Epoch 45/50, Loss: 2.5493


 91%|█████████ | 40/44 [4:01:55<24:01, 360.27s/it]

Epoch 46/50, Loss: 2.5467


 93%|█████████▎| 41/44 [4:07:56<18:01, 360.59s/it]

Epoch 47/50, Loss: 2.5441


 95%|█████████▌| 42/44 [4:13:57<12:01, 360.84s/it]

Epoch 48/50, Loss: 2.5413


 98%|█████████▊| 43/44 [4:20:00<06:01, 361.29s/it]

Epoch 49/50, Loss: 2.5383


100%|██████████| 44/44 [4:25:59<00:00, 362.73s/it]

Epoch 50/50, Loss: 2.5364



