# Chess Engine with PyTorch

## Imports

In [1]:
import os
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import DataLoader # type: ignore
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore

# Data preprocessing

## Load data

In [2]:
def load_pgn(file_path):
    games = []
    with open(file_path, 'r') as pgn_file:
        while True:
            game = pgn.read_game(pgn_file)
            if game is None:
                break
            games.append(game)
    return games

files = [file for file in os.listdir("../../data/pgn") if file.endswith(".pgn")]
print(len(files))
LIMIT_OF_FILES = min(len(files), 10)
games = []
i = 1
for file in tqdm(files):
    games.extend(load_pgn(f"../../data/pgn/{file}"))
    if i >= LIMIT_OF_FILES:
        break
    i += 1

79


 11%|█▏        | 9/79 [00:01<00:14,  4.81it/s]


In [3]:
print(f"GAMES PARSED: {len(games)}")

GAMES PARSED: 686


## Convert data into tensors

In [4]:
from auxiliary_func import create_input_for_nn, encode_moves, create_input_for_nn_2

In [5]:
X, y = create_input_for_nn_2(games)

print(f"NUMBER OF SAMPLES: {len(y)}")

NUMBER OF SAMPLES: 61619


In [6]:
X = X[0:2_500_000]
y = y[0:2_500_000]

In [7]:
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

In [8]:
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Preliminary actions

In [29]:
from dataset import ChessDataset
import model as custom_model
import importlib
importlib.reload(custom_model)

<module 'model' from 'c:\\Users\\salla\\OneDrive\\Documents\\Cours\\A3\\Advanced_Machine_Learning\\Kaggle_Competitions\\chess_engine_main\\engines\\torch\\model.py'>

In [32]:
# Create Dataset and DataLoader
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Model Initialization
model = custom_model.ChessModel4(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

Using device: cpu


# Training

In [33]:
num_epochs = 100
# resume training
# model.load_state_dict(torch.load("../../models/TORCH_50EPOCHS_3.pth", map_location=device))
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
        optimizer.zero_grad()

        outputs = model(inputs)  # Raw logits

        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        running_loss += loss.item()
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes: int = int(epoch_time // 60)
    seconds: int = int(epoch_time) - minutes * 60
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')

100%|██████████| 963/963 [01:33<00:00, 10.25it/s]


Epoch 1/100, Loss: 6.2387, Time: 1m33s


100%|██████████| 963/963 [01:16<00:00, 12.59it/s]


Epoch 2/100, Loss: 5.3948, Time: 1m16s


100%|██████████| 963/963 [01:26<00:00, 11.17it/s]


Epoch 3/100, Loss: 4.8497, Time: 1m26s


100%|██████████| 963/963 [01:07<00:00, 14.18it/s]


Epoch 4/100, Loss: 4.3900, Time: 1m7s


100%|██████████| 963/963 [01:18<00:00, 12.27it/s]


Epoch 5/100, Loss: 4.0005, Time: 1m18s


100%|██████████| 963/963 [01:19<00:00, 12.12it/s]


Epoch 6/100, Loss: 3.6734, Time: 1m19s


100%|██████████| 963/963 [01:46<00:00,  9.02it/s]


Epoch 7/100, Loss: 3.3721, Time: 1m46s


100%|██████████| 963/963 [01:17<00:00, 12.47it/s]


Epoch 8/100, Loss: 3.1157, Time: 1m17s


100%|██████████| 963/963 [01:06<00:00, 14.48it/s]


Epoch 9/100, Loss: 2.8855, Time: 1m6s


100%|██████████| 963/963 [00:45<00:00, 21.24it/s]


Epoch 10/100, Loss: 2.6846, Time: 0m45s


100%|██████████| 963/963 [01:42<00:00,  9.38it/s]


Epoch 11/100, Loss: 2.4883, Time: 1m42s


100%|██████████| 963/963 [01:48<00:00,  8.86it/s]


Epoch 12/100, Loss: 2.3146, Time: 1m48s


100%|██████████| 963/963 [01:44<00:00,  9.25it/s]


Epoch 13/100, Loss: 2.1560, Time: 1m44s


100%|██████████| 963/963 [01:31<00:00, 10.53it/s]


Epoch 14/100, Loss: 2.0146, Time: 1m31s


100%|██████████| 963/963 [01:43<00:00,  9.29it/s]


Epoch 15/100, Loss: 1.8710, Time: 1m43s


100%|██████████| 963/963 [01:48<00:00,  8.90it/s]


Epoch 16/100, Loss: 1.7394, Time: 1m48s


100%|██████████| 963/963 [01:40<00:00,  9.61it/s]


Epoch 17/100, Loss: 1.6267, Time: 1m40s


100%|██████████| 963/963 [01:36<00:00, 10.02it/s]


Epoch 18/100, Loss: 1.5284, Time: 1m36s


100%|██████████| 963/963 [01:40<00:00,  9.59it/s]


Epoch 19/100, Loss: 1.4313, Time: 1m40s


100%|██████████| 963/963 [02:02<00:00,  7.89it/s]


Epoch 20/100, Loss: 1.3337, Time: 2m2s


100%|██████████| 963/963 [01:24<00:00, 11.38it/s]


Epoch 21/100, Loss: 1.2489, Time: 1m24s


100%|██████████| 963/963 [01:20<00:00, 11.92it/s]


Epoch 22/100, Loss: 1.1762, Time: 1m20s


100%|██████████| 963/963 [01:38<00:00,  9.80it/s]


Epoch 23/100, Loss: 1.0961, Time: 1m38s


100%|██████████| 963/963 [01:22<00:00, 11.64it/s]


Epoch 24/100, Loss: 1.0339, Time: 1m22s


100%|██████████| 963/963 [01:19<00:00, 12.09it/s]


Epoch 25/100, Loss: 0.9734, Time: 1m19s


100%|██████████| 963/963 [01:19<00:00, 12.17it/s]


Epoch 26/100, Loss: 0.9179, Time: 1m19s


100%|██████████| 963/963 [01:20<00:00, 12.03it/s]


Epoch 27/100, Loss: 0.8620, Time: 1m20s


100%|██████████| 963/963 [01:17<00:00, 12.35it/s]


Epoch 28/100, Loss: 0.8255, Time: 1m17s


100%|██████████| 963/963 [00:42<00:00, 22.84it/s]


Epoch 29/100, Loss: 0.7734, Time: 0m42s


100%|██████████| 963/963 [00:36<00:00, 26.49it/s]


Epoch 30/100, Loss: 0.7368, Time: 0m36s


100%|██████████| 963/963 [00:34<00:00, 27.86it/s]


Epoch 31/100, Loss: 0.7034, Time: 0m34s


100%|██████████| 963/963 [00:34<00:00, 28.19it/s]


Epoch 32/100, Loss: 0.6681, Time: 0m34s


100%|██████████| 963/963 [00:34<00:00, 28.13it/s]


Epoch 33/100, Loss: 0.6381, Time: 0m34s


100%|██████████| 963/963 [00:34<00:00, 28.25it/s]


Epoch 34/100, Loss: 0.6096, Time: 0m34s


100%|██████████| 963/963 [00:33<00:00, 28.35it/s]


Epoch 35/100, Loss: 0.5857, Time: 0m33s


100%|██████████| 963/963 [00:35<00:00, 26.87it/s]


Epoch 36/100, Loss: 0.5579, Time: 0m35s


100%|██████████| 963/963 [00:35<00:00, 27.39it/s]


Epoch 37/100, Loss: 0.5400, Time: 0m35s


100%|██████████| 963/963 [00:34<00:00, 28.25it/s]


Epoch 38/100, Loss: 0.5162, Time: 0m34s


100%|██████████| 963/963 [00:49<00:00, 19.33it/s]


Epoch 39/100, Loss: 0.5015, Time: 0m49s


100%|██████████| 963/963 [00:51<00:00, 18.74it/s]


Epoch 40/100, Loss: 0.4863, Time: 0m51s


100%|██████████| 963/963 [00:49<00:00, 19.34it/s]


Epoch 41/100, Loss: 0.4694, Time: 0m49s


100%|██████████| 963/963 [00:51<00:00, 18.56it/s]


Epoch 42/100, Loss: 0.4556, Time: 0m51s


100%|██████████| 963/963 [01:00<00:00, 15.90it/s]


Epoch 43/100, Loss: 0.4373, Time: 1m0s


100%|██████████| 963/963 [00:43<00:00, 21.94it/s]


Epoch 44/100, Loss: 0.4229, Time: 0m43s


100%|██████████| 963/963 [00:52<00:00, 18.41it/s]


Epoch 45/100, Loss: 0.4202, Time: 0m52s


100%|██████████| 963/963 [00:47<00:00, 20.36it/s]


Epoch 46/100, Loss: 0.4055, Time: 0m47s


100%|██████████| 963/963 [00:47<00:00, 20.23it/s]


Epoch 47/100, Loss: 0.3996, Time: 0m47s


100%|██████████| 963/963 [00:52<00:00, 18.46it/s]


Epoch 48/100, Loss: 0.3866, Time: 0m52s


100%|██████████| 963/963 [00:46<00:00, 20.55it/s]


Epoch 49/100, Loss: 0.3738, Time: 0m46s


100%|██████████| 963/963 [00:46<00:00, 20.55it/s]


Epoch 50/100, Loss: 0.3710, Time: 0m46s


100%|██████████| 963/963 [00:40<00:00, 23.53it/s]


Epoch 51/100, Loss: 0.3619, Time: 0m40s


100%|██████████| 963/963 [00:38<00:00, 25.12it/s]


Epoch 52/100, Loss: 0.3534, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.19it/s]


Epoch 53/100, Loss: 0.3474, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.21it/s]


Epoch 54/100, Loss: 0.3425, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.00it/s]


Epoch 55/100, Loss: 0.3359, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.18it/s]


Epoch 56/100, Loss: 0.3301, Time: 0m38s


100%|██████████| 963/963 [00:41<00:00, 23.18it/s]


Epoch 57/100, Loss: 0.3239, Time: 0m41s


100%|██████████| 963/963 [00:59<00:00, 16.29it/s]


Epoch 58/100, Loss: 0.3162, Time: 0m59s


100%|██████████| 963/963 [00:55<00:00, 17.37it/s]


Epoch 59/100, Loss: 0.3168, Time: 0m55s


100%|██████████| 963/963 [01:04<00:00, 14.92it/s]


Epoch 60/100, Loss: 0.3075, Time: 1m4s


100%|██████████| 963/963 [00:54<00:00, 17.80it/s]


Epoch 61/100, Loss: 0.3022, Time: 0m54s


100%|██████████| 963/963 [00:37<00:00, 25.48it/s]


Epoch 62/100, Loss: 0.3002, Time: 0m37s


100%|██████████| 963/963 [00:36<00:00, 26.27it/s]


Epoch 63/100, Loss: 0.2945, Time: 0m36s


100%|██████████| 963/963 [00:37<00:00, 25.39it/s]


Epoch 64/100, Loss: 0.2918, Time: 0m37s


100%|██████████| 963/963 [00:36<00:00, 26.07it/s]


Epoch 65/100, Loss: 0.2884, Time: 0m36s


100%|██████████| 963/963 [00:37<00:00, 25.47it/s]


Epoch 66/100, Loss: 0.2820, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.84it/s]


Epoch 67/100, Loss: 0.2831, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.37it/s]


Epoch 68/100, Loss: 0.2787, Time: 0m37s


100%|██████████| 963/963 [00:39<00:00, 24.27it/s]


Epoch 69/100, Loss: 0.2746, Time: 0m39s


100%|██████████| 963/963 [00:38<00:00, 25.20it/s]


Epoch 70/100, Loss: 0.2713, Time: 0m38s


100%|██████████| 963/963 [00:37<00:00, 25.46it/s]


Epoch 71/100, Loss: 0.2704, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.46it/s]


Epoch 72/100, Loss: 0.2665, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.41it/s]


Epoch 73/100, Loss: 0.2620, Time: 0m37s


100%|██████████| 963/963 [00:38<00:00, 25.21it/s]


Epoch 74/100, Loss: 0.2587, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.26it/s]


Epoch 75/100, Loss: 0.2622, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.02it/s]


Epoch 76/100, Loss: 0.2565, Time: 0m38s


100%|██████████| 963/963 [00:37<00:00, 25.38it/s]


Epoch 77/100, Loss: 0.2553, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.42it/s]


Epoch 78/100, Loss: 0.2527, Time: 0m37s


100%|██████████| 963/963 [00:38<00:00, 25.19it/s]


Epoch 79/100, Loss: 0.2504, Time: 0m38s


100%|██████████| 963/963 [00:37<00:00, 25.47it/s]


Epoch 80/100, Loss: 0.2466, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.41it/s]


Epoch 81/100, Loss: 0.2455, Time: 0m37s


100%|██████████| 963/963 [00:38<00:00, 25.30it/s]


Epoch 82/100, Loss: 0.2475, Time: 0m38s


100%|██████████| 963/963 [00:37<00:00, 25.39it/s]


Epoch 83/100, Loss: 0.2423, Time: 0m37s


100%|██████████| 963/963 [00:38<00:00, 25.26it/s]


Epoch 84/100, Loss: 0.2458, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.24it/s]


Epoch 85/100, Loss: 0.2353, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.32it/s]


Epoch 86/100, Loss: 0.2363, Time: 0m38s


100%|██████████| 963/963 [00:37<00:00, 25.38it/s]


Epoch 87/100, Loss: 0.2375, Time: 0m37s


100%|██████████| 963/963 [00:38<00:00, 24.78it/s]


Epoch 88/100, Loss: 0.2352, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.18it/s]


Epoch 89/100, Loss: 0.2286, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.20it/s]


Epoch 90/100, Loss: 0.2304, Time: 0m38s


100%|██████████| 963/963 [00:38<00:00, 25.25it/s]


Epoch 91/100, Loss: 0.2290, Time: 0m38s


100%|██████████| 963/963 [00:37<00:00, 25.46it/s]


Epoch 92/100, Loss: 0.2288, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.39it/s]


Epoch 93/100, Loss: 0.2283, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.41it/s]


Epoch 94/100, Loss: 0.2231, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.43it/s]


Epoch 95/100, Loss: 0.2241, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.41it/s]


Epoch 96/100, Loss: 0.2217, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.37it/s]


Epoch 97/100, Loss: 0.2237, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.39it/s]


Epoch 98/100, Loss: 0.2211, Time: 0m37s


100%|██████████| 963/963 [00:37<00:00, 25.42it/s]


Epoch 99/100, Loss: 0.2204, Time: 0m37s


100%|██████████| 963/963 [00:38<00:00, 25.30it/s]

Epoch 100/100, Loss: 0.2194, Time: 0m38s





# Save the model and mapping

In [34]:
# Save the model
torch.save(model.state_dict(), "../../models/TORCH_50EPOCHS_3.pth")

In [35]:
import pickle

with open("../../models/heavy_move_to_int3", "wb") as file:
    pickle.dump(move_to_int, file)