In [1]:
import tqdm
import torch

In [2]:
from src.dataset import load_pgn, load_multiple_pgns, create_value_csv_dataset
from src.dataclass import ChessDataset
from src.model import ResNet

In [None]:
# Step 1: Create my training dataset
games = load_multiple_pgns(num_pgns=38)
create_value_csv_dataset(games, name="all2016")

In [None]:
# The step above created a csv file(first26) with 2.302.559 chess state-action pairs
# Since I have many different datasets I can use some of them as validation sets
# This way I can skip performing a train - test split (2016-01 is apprx 700.000 moves)
# I also added value to my dataset
# I created a dataset with 25.496.768 positions to train on.

In [11]:
# Step 2: Create a validation dataset
games = load_pgn("data/pgn/lichess_elite_2017-01.pgn")
create_value_csv_dataset(games, name="le2017-01")

In [3]:
# Step 3: Load the data using the pytorch Dataset and Dataloader classes
dataset = ChessDataset("data/csv/all2016.csv")
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)

In [4]:
# Step 4: Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(filters=128, res_blocks=6)
model = model.to(device)

In [5]:
model

ResNet(
  (start_block): Sequential(
    (0): Conv2d(18, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (res_tower): ModuleList(
    (0-5): 6 x ResBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (policy_head): Sequential(
    (0): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=128, out_features=4544, bias=True)
  )
  (value_head): Se

In [6]:
# Step 5: Choose a loss function and the optimizer
policy_loss = torch.nn.CrossEntropyLoss()
value_loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
# If the model isn't loaded training starts at 0 epochs
epoch = 0

In [8]:
# This cell loads the model from a previous state
checkpoint = torch.load(f"models/model5/model.28.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch'] + 1

In [9]:
model

ResNet(
  (start_block): Sequential(
    (0): Conv2d(18, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (res_tower): ModuleList(
    (0-5): 6 x ResBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (policy_head): Sequential(
    (0): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=128, out_features=4544, bias=True)
  )
  (value_head): Se

In [10]:
# Step 6: Train the model
EPOCHS = 50
train_losses = []
scaler = torch.GradScaler()

for epoch in tqdm.trange(epoch, EPOCHS+1):
    model.train()
    total_policy_loss = 0
    total_value_loss = 0

    for boards, labels, values in loader:
        boards = boards.to(device)
        labels = labels.to(device)
        values = values.to(device).float()

        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            policy, value = model(boards)
            p_loss = policy_loss(policy, labels)
            v_loss = value_loss(value.squeeze(), values)

            loss = p_loss + v_loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_policy_loss += p_loss.item()
        total_value_loss += v_loss.item()

    avg_p_loss = total_policy_loss / len(loader)
    avg_v_loss = total_value_loss / len(loader)

    print(f"Epoch {epoch+1}/{EPOCHS+1} - Policy Loss: {avg_p_loss:.4f} - Value Loss: {avg_v_loss:.4f}")

    train_losses.append(avg_p_loss)
    torch.save({
    "train_losses": train_losses,
    }, "data/loss/loss_log_5.pt")

    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, f"models/model5/model.{epoch}.pth")

    print(f"Checkpoint saved at epoch {epoch}")

  5%|▍         | 1/22 [37:40<13:11:07, 2260.36s/it]

Epoch 30/51 - Policy Loss: 1.4783 - Value Loss: 0.7372
Checkpoint saved at epoch 29


  9%|▉         | 2/22 [1:15:17<12:32:52, 2258.64s/it]

Epoch 31/51 - Policy Loss: 1.4765 - Value Loss: 0.7362
Checkpoint saved at epoch 30


 14%|█▎        | 3/22 [1:53:03<11:56:15, 2261.88s/it]

Epoch 32/51 - Policy Loss: 1.4749 - Value Loss: 0.7354
Checkpoint saved at epoch 31


 18%|█▊        | 4/22 [2:31:03<11:20:40, 2268.89s/it]

Epoch 33/51 - Policy Loss: 1.4733 - Value Loss: 0.7345
Checkpoint saved at epoch 32


 23%|██▎       | 5/22 [3:09:13<10:45:03, 2276.67s/it]

Epoch 34/51 - Policy Loss: 1.4719 - Value Loss: 0.7336
Checkpoint saved at epoch 33


 27%|██▋       | 6/22 [3:46:51<10:05:22, 2270.14s/it]

Epoch 35/51 - Policy Loss: 1.4706 - Value Loss: 0.7327
Checkpoint saved at epoch 34


 32%|███▏      | 7/22 [4:24:35<9:27:05, 2268.39s/it] 

Epoch 36/51 - Policy Loss: 1.4693 - Value Loss: 0.7318
Checkpoint saved at epoch 35


 36%|███▋      | 8/22 [5:03:21<8:53:30, 2286.49s/it]

Epoch 37/51 - Policy Loss: 1.4681 - Value Loss: 0.7310
Checkpoint saved at epoch 36


 41%|████      | 9/22 [5:42:40<8:20:21, 2309.31s/it]

Epoch 38/51 - Policy Loss: 1.4669 - Value Loss: 0.7301
Checkpoint saved at epoch 37


 45%|████▌     | 10/22 [6:20:57<7:41:04, 2305.38s/it]

Epoch 39/51 - Policy Loss: 1.4658 - Value Loss: 0.7293
Checkpoint saved at epoch 38


 50%|█████     | 11/22 [6:58:38<7:00:10, 2291.83s/it]

Epoch 40/51 - Policy Loss: 1.4646 - Value Loss: 0.7285
Checkpoint saved at epoch 39


 55%|█████▍    | 12/22 [7:35:44<6:18:38, 2271.86s/it]

Epoch 41/51 - Policy Loss: 1.4636 - Value Loss: 0.7277
Checkpoint saved at epoch 40


 59%|█████▉    | 13/22 [8:12:56<5:38:56, 2259.65s/it]

Epoch 42/51 - Policy Loss: 1.4627 - Value Loss: 0.7269
Checkpoint saved at epoch 41


 64%|██████▎   | 14/22 [8:50:18<5:00:35, 2254.38s/it]

Epoch 43/51 - Policy Loss: 1.4617 - Value Loss: 0.7261
Checkpoint saved at epoch 42


 68%|██████▊   | 15/22 [9:27:46<4:22:47, 2252.44s/it]

Epoch 44/51 - Policy Loss: 1.4609 - Value Loss: 0.7253
Checkpoint saved at epoch 43


 73%|███████▎  | 16/22 [10:05:19<3:45:16, 2252.81s/it]

Epoch 45/51 - Policy Loss: 1.4600 - Value Loss: 0.7246
Checkpoint saved at epoch 44


 77%|███████▋  | 17/22 [10:42:38<3:07:23, 2248.63s/it]

Epoch 46/51 - Policy Loss: 1.4591 - Value Loss: 0.7238
Checkpoint saved at epoch 45


 82%|████████▏ | 18/22 [11:20:19<2:30:08, 2252.17s/it]

Epoch 47/51 - Policy Loss: 1.4582 - Value Loss: 0.7231
Checkpoint saved at epoch 46


 86%|████████▋ | 19/22 [11:57:47<1:52:33, 2251.07s/it]

Epoch 48/51 - Policy Loss: 1.4576 - Value Loss: 0.7224
Checkpoint saved at epoch 47


 91%|█████████ | 20/22 [12:35:43<1:15:17, 2258.51s/it]

Epoch 49/51 - Policy Loss: 1.4568 - Value Loss: 0.7216
Checkpoint saved at epoch 48


 95%|█████████▌| 21/22 [13:13:39<37:43, 2263.74s/it]  

Epoch 50/51 - Policy Loss: 1.4561 - Value Loss: 0.7208
Checkpoint saved at epoch 49


100%|██████████| 22/22 [13:52:46<00:00, 2271.21s/it]

Epoch 51/51 - Policy Loss: 1.4555 - Value Loss: 0.7202
Checkpoint saved at epoch 50



