In [1]:
import tqdm
import torch

In [2]:
from src.dataset import load_pgn, load_multiple_pgns, create_csv_dataset
from src.dataclass import ChessDataset
from src.model import ChessNet

In [None]:
# Step 1: Create my training dataset
games = load_multiple_pgns(num_pgns=26)
create_csv_dataset(games, name="le_first_26")

In [None]:
# The step above created a csv file with 2.302.559 chess state-action pairs
# Since I have many different datasets I can use some of them as validation sets
# This way I can skip performing a train - test split

In [None]:
# Step 2: Create a validation dataset
games = load_pgn("data/pgn/lichess_elite_2016-01.pgn")
create_csv_dataset(games, name="le_2016-01")

In [3]:
# Step 3: Load the data using the pytorch Dataset and Dataloader classes
dataset = ChessDataset("data/csv/le_first_26.csv")
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

In [4]:
# Step 4: Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessNet()
model = model.to(device)

In [5]:
model

ChessNet(
  (conv1): Conv2d(18, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=4544, bias=True)
)

In [6]:
# Step 5: Choose a loss function and the optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [8]:
# If the model isn't loaded training starts at 0 epochs
epoch = 0

In [None]:
# This cell loads the model from a previous state
checkpoint = torch.load(f"models/le_first_26.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch'] + 1

In [None]:
# Step 6: Train the model
EPOCHS = 50
train_losses = []
for epoch in tqdm.trange(epoch, EPOCHS+1):
    model.train()
    total_loss = 0

    for boards, labels in loader:
        boards = boards.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(boards)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)

    train_losses.append(avg_loss)
    torch.save({
    "train_losses": train_losses,
    }, "data/loss/loss_log_2.pt")
    
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}")

    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f"models/model.2.{epoch}.pth")

        print(f"Checkpoint saved at epoch {epoch}")

  2%|▏         | 1/50 [07:26<6:04:20, 446.12s/it]

Epoch 1/50, Loss: 3.5038
Checkpoint saved at epoch 0


  4%|▍         | 2/50 [14:55<5:58:27, 448.07s/it]

Epoch 2/50, Loss: 2.6309


  6%|▌         | 3/50 [22:24<5:51:17, 448.46s/it]

Epoch 3/50, Loss: 2.4614


  8%|▊         | 4/50 [29:54<5:44:10, 448.93s/it]

Epoch 4/50, Loss: 2.3682


 10%|█         | 5/50 [37:21<5:36:12, 448.27s/it]

Epoch 5/50, Loss: 2.3041


 12%|█▏        | 6/50 [44:46<5:28:05, 447.40s/it]

Epoch 6/50, Loss: 2.2555


 14%|█▍        | 7/50 [52:11<5:20:04, 446.62s/it]

Epoch 7/50, Loss: 2.2147


 16%|█▌        | 8/50 [59:36<5:12:09, 445.93s/it]

Epoch 8/50, Loss: 2.1810


 18%|█▊        | 9/50 [1:07:00<5:04:14, 445.22s/it]

Epoch 9/50, Loss: 2.1503


 20%|██        | 10/50 [1:14:25<4:56:49, 445.24s/it]

Epoch 10/50, Loss: 2.1235


 22%|██▏       | 11/50 [1:21:50<4:49:21, 445.16s/it]

Epoch 11/50, Loss: 2.0987
Checkpoint saved at epoch 10


 24%|██▍       | 12/50 [1:29:15<4:41:51, 445.04s/it]

Epoch 12/50, Loss: 2.0766


 26%|██▌       | 13/50 [1:36:36<4:33:51, 444.09s/it]

Epoch 13/50, Loss: 2.0567


 28%|██▊       | 14/50 [1:43:59<4:26:14, 443.75s/it]

Epoch 14/50, Loss: 2.0383


 30%|███       | 15/50 [1:51:22<4:18:38, 443.40s/it]

Epoch 15/50, Loss: 2.0220


 32%|███▏      | 16/50 [1:58:45<4:11:10, 443.24s/it]

Epoch 16/50, Loss: 2.0061


 34%|███▍      | 17/50 [2:06:08<4:03:41, 443.07s/it]

Epoch 17/50, Loss: 1.9919


 36%|███▌      | 18/50 [2:13:31<3:56:17, 443.04s/it]

Epoch 18/50, Loss: 1.9787


 38%|███▊      | 19/50 [2:20:57<3:49:24, 444.01s/it]

Epoch 19/50, Loss: 1.9659


 40%|████      | 20/50 [2:28:26<3:42:47, 445.58s/it]

Epoch 20/50, Loss: 1.9541


 42%|████▏     | 21/50 [2:35:54<3:35:41, 446.26s/it]

Epoch 21/50, Loss: 1.9440
Checkpoint saved at epoch 20


 44%|████▍     | 22/50 [2:43:23<3:28:42, 447.23s/it]

Epoch 22/50, Loss: 1.9336


 46%|████▌     | 23/50 [2:50:51<3:21:17, 447.33s/it]

Epoch 23/50, Loss: 1.9247


 48%|████▊     | 24/50 [2:58:20<3:14:02, 447.77s/it]

Epoch 24/50, Loss: 1.9152


 50%|█████     | 25/50 [3:05:49<3:06:47, 448.28s/it]

Epoch 25/50, Loss: 1.9068


 52%|█████▏    | 26/50 [3:13:18<2:59:21, 448.38s/it]

Epoch 26/50, Loss: 1.8992


 54%|█████▍    | 27/50 [3:20:47<2:51:59, 448.70s/it]

Epoch 27/50, Loss: 1.8922


 56%|█████▌    | 28/50 [3:28:16<2:44:31, 448.73s/it]

Epoch 28/50, Loss: 1.8853


 58%|█████▊    | 29/50 [3:35:45<2:37:03, 448.75s/it]

Epoch 29/50, Loss: 1.8782


 60%|██████    | 30/50 [3:43:14<2:29:38, 448.90s/it]

Epoch 30/50, Loss: 1.8717


 62%|██████▏   | 31/50 [3:50:43<2:22:08, 448.86s/it]

Epoch 31/50, Loss: 1.8656
Checkpoint saved at epoch 30


 64%|██████▍   | 32/50 [3:58:13<2:14:43, 449.09s/it]

Epoch 32/50, Loss: 1.8601


 66%|██████▌   | 33/50 [4:05:42<2:07:16, 449.21s/it]

Epoch 33/50, Loss: 1.8548


 68%|██████▊   | 34/50 [4:13:11<1:59:47, 449.21s/it]

Epoch 34/50, Loss: 1.8496


 70%|███████   | 35/50 [4:20:40<1:52:18, 449.21s/it]

Epoch 35/50, Loss: 1.8446


 72%|███████▏  | 36/50 [4:28:10<1:44:49, 449.25s/it]

Epoch 36/50, Loss: 1.8402


 74%|███████▍  | 37/50 [4:35:38<1:37:17, 449.06s/it]

Epoch 37/50, Loss: 1.8357


 76%|███████▌  | 38/50 [4:43:08<1:29:50, 449.20s/it]

Epoch 38/50, Loss: 1.8305


 78%|███████▊  | 39/50 [4:50:37<1:22:20, 449.17s/it]

Epoch 39/50, Loss: 1.8262


 80%|████████  | 40/50 [4:58:07<1:14:52, 449.29s/it]

Epoch 40/50, Loss: 1.8220


 82%|████████▏ | 41/50 [5:05:37<1:07:26, 449.64s/it]

Epoch 41/50, Loss: 1.8185
Checkpoint saved at epoch 40


 84%|████████▍ | 42/50 [5:13:05<59:52, 449.11s/it]  

Epoch 42/50, Loss: 1.8138


 86%|████████▌ | 43/50 [5:20:35<52:25, 449.42s/it]

Epoch 43/50, Loss: 1.8115


 88%|████████▊ | 44/50 [5:28:04<44:54, 449.13s/it]

Epoch 44/50, Loss: 1.8082


 90%|█████████ | 45/50 [5:35:33<37:25, 449.09s/it]

Epoch 45/50, Loss: 1.8039


 92%|█████████▏| 46/50 [5:43:03<29:57, 449.47s/it]

Epoch 46/50, Loss: 1.8018


 94%|█████████▍| 47/50 [5:50:34<22:29, 449.86s/it]

Epoch 47/50, Loss: 1.7986


 96%|█████████▌| 48/50 [5:58:04<14:59, 449.88s/it]

Epoch 48/50, Loss: 1.7951


 98%|█████████▊| 49/50 [6:05:34<07:30, 450.09s/it]

Epoch 49/50, Loss: 1.7929


100%|██████████| 50/50 [6:13:04<00:00, 447.69s/it]

Epoch 50/50, Loss: 1.7902



