In [1]:
import os

import torch
import tqdm

In [2]:
from src.dataset import load_pgn, load_multiple_pgns, create_value_csv_dataset
from src.dataclass import ChessDataset
from src.model import ResNet

In [None]:
# Step 1: Build the training set from 38 PGN files.
# create_value_csv_dataset writes the positions (plus value targets) to a CSV named "all2016".
games = load_multiple_pgns(
    num_pgns=38,
)
create_value_csv_dataset(
    games,
    name="all2016",
)

In [None]:
# The step above created a CSV file (first26) with 2,302,559 chess state-action pairs.
# Since I have many different datasets, I can use some of them as validation sets;
# this way I can skip performing a train/test split (2016-01 alone is approx. 700,000 moves).
# I also added a value target to my dataset.
# I created a dataset with 25,496,768 positions to train on.

In [None]:
# Step 2: Build a held-out validation set from a single month (January 2017),
# i.e. from a different period than the 2016 training data.
VALIDATION_PGN = "data/pgn/lichess_elite_2017-01.pgn"
games = load_pgn(VALIDATION_PGN)
create_value_csv_dataset(games, name="le2017-01")

In [3]:
# Step 3: Load the data using the PyTorch Dataset and DataLoader classes
dataset = ChessDataset("data/csv/all2016.csv")
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    # Pinned host memory only speeds up host->GPU copies; requesting it on a
    # CPU-only machine just triggers a warning, so enable it only with CUDA.
    pin_memory=torch.cuda.is_available(),
    # Keep the worker processes alive across epochs instead of re-spawning
    # them (and re-initialising the dataset in each) at every epoch boundary.
    persistent_workers=True,
)

In [4]:
# Step 4: Build the network and place it on the GPU when one is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = ResNet(filters=128, res_blocks=6).to(device)

In [5]:
# Display the module tree (layers and their settings) as a quick architecture sanity check
model

ResNet(
  (start_block): Sequential(
    (0): Conv2d(18, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (res_tower): ModuleList(
    (0-5): 6 x ResBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (policy_head): Sequential(
    (0): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=128, out_features=4544, bias=True)
  )
  (value_head): Se

In [6]:
# Step 5: One loss per network head, plus the optimizer
policy_loss = torch.nn.CrossEntropyLoss()  # policy head: classification over move indices
value_loss = torch.nn.MSELoss()            # value head: regression on the position value
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [7]:
# Fresh run: training starts at epoch 0 (the checkpoint cell below overrides this when resuming)
epoch: int = 0

In [None]:
# This cell resumes training from a previous checkpoint (skip it to train from scratch)
CHECKPOINT_PATH = "models/le_first_26.pth"
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# The stored epoch has already been completed, so continue with the next one
epoch = checkpoint['epoch'] + 1

In [8]:
# Step 6: Train the model
EPOCHS = 50
train_losses = []        # mean policy loss per epoch
train_value_losses = []  # mean value loss per epoch

# fp16 autocast and gradient scaling only apply on CUDA. Gate both on the
# actual device so a CPU fallback runs plain fp32 instead of relying on
# PyTorch's "CUDA is not available, disabling" warnings.
use_amp = device.type == "cuda"
scaler = torch.GradScaler(enabled=use_amp)

# torch.save does not create missing parent directories
os.makedirs("data/loss", exist_ok=True)
os.makedirs("models/model5", exist_ok=True)

# NOTE: the upper bound is EPOCHS+1, so a fresh run trains EPOCHS + 1 epochs
# (indices 0..EPOCHS); a resumed run continues from the loaded `epoch`.
for epoch in tqdm.trange(epoch, EPOCHS+1):
    model.train()
    total_policy_loss = 0.0
    total_value_loss = 0.0

    for boards, labels, values in loader:
        # non_blocking lets the copy overlap with compute when the loader pins memory
        boards = boards.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        values = values.to(device, non_blocking=True).float()

        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            policy, value = model(boards)
            p_loss = policy_loss(policy, labels)
            # Match the prediction to the target's shape explicitly (assumes one
            # value per position). A bare squeeze() turns a batch of size 1 into a
            # 0-d tensor and would broadcast to (B, B) if targets were (B, 1).
            v_loss = value_loss(value.reshape(values.shape), values)

            loss = p_loss + v_loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_policy_loss += p_loss.item()
        total_value_loss += v_loss.item()

    avg_p_loss = total_policy_loss / len(loader)
    avg_v_loss = total_value_loss / len(loader)

    print(f"Epoch {epoch+1}/{EPOCHS+1} - Policy Loss: {avg_p_loss:.4f} - Value Loss: {avg_v_loss:.4f}")

    # Rewrite the loss log every epoch so it survives an interrupted run
    train_losses.append(avg_p_loss)
    train_value_losses.append(avg_v_loss)
    torch.save({
        "train_losses": train_losses,
        "train_value_losses": train_value_losses,
    }, "data/loss/loss_log_5.pt")

    # Also store the GradScaler state so a resumed AMP run keeps its loss scale
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
    }, f"models/model5/model.{epoch}.pth")

    print(f"Checkpoint saved at epoch {epoch}")

  0%|          | 0/51 [00:00<?, ?it/s]

Epoch 1/51 - Policy Loss: 2.1874 - Value Loss: 0.8047
Checkpoint saved at epoch 0


  4%|▍         | 2/51 [1:13:55<30:10:20, 2216.74s/it]

Epoch 2/51 - Policy Loss: 1.8100 - Value Loss: 0.7865
Checkpoint saved at epoch 1


  6%|▌         | 3/51 [1:51:04<29:37:58, 2222.47s/it]

Epoch 3/51 - Policy Loss: 1.7206 - Value Loss: 0.7797
Checkpoint saved at epoch 2


  8%|▊         | 4/51 [2:28:09<29:01:43, 2223.48s/it]

Epoch 4/51 - Policy Loss: 1.6713 - Value Loss: 0.7752
Checkpoint saved at epoch 3


 10%|▉         | 5/51 [3:05:25<28:28:13, 2228.12s/it]

Epoch 5/51 - Policy Loss: 1.6388 - Value Loss: 0.7719
Checkpoint saved at epoch 4


 12%|█▏        | 6/51 [3:42:48<27:54:44, 2232.99s/it]

Epoch 6/51 - Policy Loss: 1.6151 - Value Loss: 0.7692
Checkpoint saved at epoch 5


 14%|█▎        | 7/51 [4:20:13<27:20:26, 2236.97s/it]

Epoch 7/51 - Policy Loss: 1.5967 - Value Loss: 0.7667
Checkpoint saved at epoch 6


 16%|█▌        | 8/51 [4:57:35<26:44:12, 2238.43s/it]

Epoch 8/51 - Policy Loss: 1.5818 - Value Loss: 0.7647
Checkpoint saved at epoch 7


 18%|█▊        | 9/51 [5:34:45<26:05:10, 2235.96s/it]

Epoch 9/51 - Policy Loss: 1.5696 - Value Loss: 0.7627
Checkpoint saved at epoch 8


 20%|█▉        | 10/51 [6:11:48<25:25:06, 2231.88s/it]

Epoch 10/51 - Policy Loss: 1.5592 - Value Loss: 0.7610
Checkpoint saved at epoch 9


 22%|██▏       | 11/51 [6:49:00<24:47:55, 2231.88s/it]

Epoch 11/51 - Policy Loss: 1.5502 - Value Loss: 0.7594
Checkpoint saved at epoch 10


 24%|██▎       | 12/51 [7:26:10<24:10:28, 2231.50s/it]

Epoch 12/51 - Policy Loss: 1.5423 - Value Loss: 0.7578
Checkpoint saved at epoch 11


 25%|██▌       | 13/51 [8:03:44<23:37:33, 2238.26s/it]

Epoch 13/51 - Policy Loss: 1.5355 - Value Loss: 0.7563
Checkpoint saved at epoch 12


 27%|██▋       | 14/51 [8:40:56<22:59:05, 2236.37s/it]

Epoch 14/51 - Policy Loss: 1.5293 - Value Loss: 0.7550
Checkpoint saved at epoch 13


 29%|██▉       | 15/51 [9:18:27<22:24:29, 2240.83s/it]

Epoch 15/51 - Policy Loss: 1.5236 - Value Loss: 0.7536
Checkpoint saved at epoch 14


 31%|███▏      | 16/51 [9:55:29<21:43:42, 2234.92s/it]

Epoch 16/51 - Policy Loss: 1.5186 - Value Loss: 0.7523
Checkpoint saved at epoch 15


 33%|███▎      | 17/51 [10:33:01<21:09:22, 2240.09s/it]

Epoch 17/51 - Policy Loss: 1.5142 - Value Loss: 0.7510
Checkpoint saved at epoch 16


 35%|███▌      | 18/51 [11:17:19<21:41:12, 2365.83s/it]

Epoch 18/51 - Policy Loss: 1.5100 - Value Loss: 0.7498
Checkpoint saved at epoch 17


 37%|███▋      | 19/51 [11:55:20<20:48:13, 2340.43s/it]

Epoch 19/51 - Policy Loss: 1.5061 - Value Loss: 0.7487
Checkpoint saved at epoch 18


 39%|███▉      | 20/51 [12:33:35<20:02:03, 2326.56s/it]

Epoch 20/51 - Policy Loss: 1.5025 - Value Loss: 0.7475
Checkpoint saved at epoch 19


 41%|████      | 21/51 [13:11:23<19:14:30, 2309.02s/it]

Epoch 21/51 - Policy Loss: 1.4992 - Value Loss: 0.7463
Checkpoint saved at epoch 20


 43%|████▎     | 22/51 [13:49:11<18:30:05, 2296.73s/it]

Epoch 22/51 - Policy Loss: 1.4963 - Value Loss: 0.7453
Checkpoint saved at epoch 21


 45%|████▌     | 23/51 [14:27:21<17:50:52, 2294.73s/it]

Epoch 23/51 - Policy Loss: 1.4935 - Value Loss: 0.7442
Checkpoint saved at epoch 22


 47%|████▋     | 24/51 [15:05:02<17:08:02, 2284.54s/it]

Epoch 24/51 - Policy Loss: 1.4908 - Value Loss: 0.7432
Checkpoint saved at epoch 23


 49%|████▉     | 25/51 [15:42:42<16:26:46, 2277.17s/it]

Epoch 25/51 - Policy Loss: 1.4884 - Value Loss: 0.7421
Checkpoint saved at epoch 24


 51%|█████     | 26/51 [16:20:19<15:46:21, 2271.26s/it]

Epoch 26/51 - Policy Loss: 1.4861 - Value Loss: 0.7411
Checkpoint saved at epoch 25


 53%|█████▎    | 27/51 [16:58:00<15:07:11, 2267.99s/it]

Epoch 27/51 - Policy Loss: 1.4839 - Value Loss: 0.7401
Checkpoint saved at epoch 26


 55%|█████▍    | 28/51 [17:35:50<14:29:40, 2268.73s/it]

Epoch 28/51 - Policy Loss: 1.4820 - Value Loss: 0.7391
Checkpoint saved at epoch 27


 57%|█████▋    | 29/51 [18:14:01<13:54:19, 2275.44s/it]

Epoch 29/51 - Policy Loss: 1.4800 - Value Loss: 0.7382
Checkpoint saved at epoch 28


 57%|█████▋    | 29/51 [18:30:03<14:02:06, 2296.66s/it]


KeyboardInterrupt: 