In [1]:
import tqdm
import torch

In [None]:
from src.dataset import load_pgn, load_multiple_pgns, create_csv_dataset
from src.dataclass import ChessDataset
from src.model import ConvModel, ResNet

In [None]:
# Step 1: Create my training dataset
games = load_multiple_pgns(num_pgns=26)
create_csv_dataset(games, name="le_first_26")

In [None]:
# The step above created a csv file with 2.302.559 chess state-action pairs
# Since I have many different datasets I can use some of them as validation sets
# This way I can skip performing a train - test split

In [None]:
# Step 2: Create a validation dataset
games = load_pgn("data/pgn/lichess_elite_2016-01.pgn")
create_csv_dataset(games, name="le_2016-01")

In [3]:
# Step 3: Load the data using the pytorch Dataset and Dataloader classes
dataset = ChessDataset("data/csv/le_first_26.csv")
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

In [None]:
# Step 4: Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(filters=64, res_blocks=4)
model = model.to(device)

In [5]:
model

ResNet(
  (start_conv): Conv2d(18, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (start_batch): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res_tower): ModuleList(
    (0-3): 4 x ResBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (global_avg): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=4544, bias=True)
)

In [None]:
# Step 5: Choose a loss function and the optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
# If the model isn't loaded training starts at 0 epochs
epoch = 0

In [None]:
# This cell loads the model from a previous state
checkpoint = torch.load(f"models/le_first_26.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch'] + 1

In [None]:
# Step 6: Train the model
EPOCHS = 50
train_losses = []
for epoch in tqdm.trange(epoch, EPOCHS+1):
    model.train()
    total_loss = 0

    for boards, labels in loader:
        boards = boards.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(boards)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)

    train_losses.append(avg_loss)
    torch.save({
    "train_losses": train_losses,
    }, "data/loss/loss_log_3.pt")
    
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f"models/model.3.{epoch}.pth")

        print(f"Checkpoint saved at epoch {epoch}")

  2%|▏         | 1/51 [07:54<6:35:31, 474.62s/it]

Epoch 1/50, Loss: 4.6280
Checkpoint saved at epoch 0


  4%|▍         | 2/51 [15:49<6:27:32, 474.54s/it]

Epoch 2/50, Loss: 3.6277


  6%|▌         | 3/51 [23:42<6:19:07, 473.91s/it]

Epoch 3/50, Loss: 3.3165


  8%|▊         | 4/51 [31:35<6:10:59, 473.60s/it]

Epoch 4/50, Loss: 3.1323


 10%|▉         | 5/51 [39:29<6:03:10, 473.72s/it]

Epoch 5/50, Loss: 3.0044


 12%|█▏        | 6/51 [47:19<5:54:20, 472.45s/it]

Epoch 6/50, Loss: 2.9106


 14%|█▎        | 7/51 [55:08<5:45:37, 471.30s/it]

Epoch 7/50, Loss: 2.8378


 16%|█▌        | 8/51 [1:02:54<5:36:37, 469.71s/it]

Epoch 8/50, Loss: 2.7799


 18%|█▊        | 9/51 [1:10:42<5:28:21, 469.09s/it]

Epoch 9/50, Loss: 2.7316


 20%|█▉        | 10/51 [1:18:29<5:20:14, 468.66s/it]

Epoch 10/50, Loss: 2.6913


 22%|██▏       | 11/51 [1:26:17<5:12:16, 468.42s/it]

Epoch 11/50, Loss: 2.6554
Checkpoint saved at epoch 10


 24%|██▎       | 12/51 [1:34:05<5:04:21, 468.23s/it]

Epoch 12/50, Loss: 2.6251


 25%|██▌       | 13/51 [1:41:52<4:56:20, 467.90s/it]

Epoch 13/50, Loss: 2.5986


 27%|██▋       | 14/51 [1:49:40<4:48:33, 467.93s/it]

Epoch 14/50, Loss: 2.5743


 29%|██▉       | 15/51 [1:57:28<4:40:38, 467.74s/it]

Epoch 15/50, Loss: 2.5527


 31%|███▏      | 16/51 [2:05:16<4:32:55, 467.87s/it]

Epoch 16/50, Loss: 2.5327


 33%|███▎      | 17/51 [2:13:03<4:25:03, 467.76s/it]

Epoch 17/50, Loss: 2.5148


 35%|███▌      | 18/51 [2:20:51<4:17:11, 467.64s/it]

Epoch 18/50, Loss: 2.4990


 37%|███▋      | 19/51 [2:28:38<4:09:20, 467.52s/it]

Epoch 19/50, Loss: 2.4831


 39%|███▉      | 20/51 [2:36:25<4:01:28, 467.38s/it]

Epoch 20/50, Loss: 2.4691


 41%|████      | 21/51 [2:44:13<3:53:44, 467.49s/it]

Epoch 21/50, Loss: 2.4556
Checkpoint saved at epoch 20


 43%|████▎     | 22/51 [2:52:00<3:45:53, 467.37s/it]

Epoch 22/50, Loss: 2.4435


 45%|████▌     | 23/51 [2:59:46<3:38:00, 467.17s/it]

Epoch 23/50, Loss: 2.4324


 47%|████▋     | 24/51 [3:07:33<3:30:12, 467.12s/it]

Epoch 24/50, Loss: 2.4218


 49%|████▉     | 25/51 [3:15:20<3:22:19, 466.91s/it]

Epoch 25/50, Loss: 2.4115


 51%|█████     | 26/51 [3:23:06<3:14:28, 466.76s/it]

Epoch 26/50, Loss: 2.4015


 53%|█████▎    | 27/51 [3:30:53<3:06:39, 466.66s/it]

Epoch 27/50, Loss: 2.3934


 55%|█████▍    | 28/51 [3:38:39<2:58:52, 466.63s/it]

Epoch 28/50, Loss: 2.3847


 57%|█████▋    | 29/51 [3:46:28<2:51:19, 467.26s/it]

Epoch 29/50, Loss: 2.3768


 59%|█████▉    | 30/51 [3:54:16<2:43:36, 467.43s/it]

Epoch 30/50, Loss: 2.3688


 61%|██████    | 31/51 [4:02:03<2:35:46, 467.34s/it]

Epoch 31/50, Loss: 2.3614
Checkpoint saved at epoch 30


 63%|██████▎   | 32/51 [4:09:50<2:27:59, 467.37s/it]

Epoch 32/50, Loss: 2.3550


 65%|██████▍   | 33/51 [4:17:38<2:20:12, 467.34s/it]

Epoch 33/50, Loss: 2.3479


 67%|██████▋   | 34/51 [4:25:25<2:12:26, 467.43s/it]

Epoch 34/50, Loss: 2.3415


 69%|██████▊   | 35/51 [4:33:14<2:04:42, 467.67s/it]

Epoch 35/50, Loss: 2.3355


 71%|███████   | 36/51 [4:41:02<1:56:59, 467.95s/it]

Epoch 36/50, Loss: 2.3296


 73%|███████▎  | 37/51 [4:48:49<1:49:08, 467.73s/it]

Epoch 37/50, Loss: 2.3241


 75%|███████▍  | 38/51 [4:56:36<1:41:16, 467.41s/it]

Epoch 38/50, Loss: 2.3188


 76%|███████▋  | 39/51 [5:04:24<1:33:31, 467.64s/it]

Epoch 39/50, Loss: 2.3132


 78%|███████▊  | 40/51 [5:12:12<1:25:45, 467.81s/it]

Epoch 40/50, Loss: 2.3079


 80%|████████  | 41/51 [5:20:01<1:18:00, 468.06s/it]

Epoch 41/50, Loss: 2.3044
Checkpoint saved at epoch 40


 82%|████████▏ | 42/51 [5:27:49<1:10:12, 468.06s/it]

Epoch 42/50, Loss: 2.2994


 84%|████████▍ | 43/51 [5:35:38<1:02:25, 468.20s/it]

Epoch 43/50, Loss: 2.2945


 86%|████████▋ | 44/51 [5:43:25<54:35, 467.94s/it]  

Epoch 44/50, Loss: 2.2906


 88%|████████▊ | 45/51 [5:51:12<46:46, 467.76s/it]

Epoch 45/50, Loss: 2.2862


 90%|█████████ | 46/51 [5:58:59<38:56, 467.38s/it]

Epoch 46/50, Loss: 2.2827


 92%|█████████▏| 47/51 [6:06:47<31:10, 467.64s/it]

Epoch 47/50, Loss: 2.2787


 94%|█████████▍| 48/51 [6:14:33<23:21, 467.28s/it]

Epoch 48/50, Loss: 2.2742


 96%|█████████▌| 49/51 [6:22:21<15:34, 467.24s/it]

Epoch 49/50, Loss: 2.2710


 98%|█████████▊| 50/51 [6:30:07<07:47, 467.01s/it]

Epoch 50/50, Loss: 2.2671


100%|██████████| 51/51 [6:37:53<00:00, 468.11s/it]

Epoch 51/50, Loss: 2.2639
Checkpoint saved at epoch 50



