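# Chess Engine with PyTorch

Trains a PyTorch classifier over chess moves (one class per distinct move found in the data) on games parsed from PGN files. It then saves the model weights and the move-to-index mapping.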

## Imports

In [12]:
import os
import pickle
import time
from multiprocessing import Pool, cpu_count

import chess
import matplotlib.pyplot as plt
import numpy as np 
import torch
import torch.nn as nn 
import torch.optim as optim 
from chess import pgn 
from torch.utils.data import DataLoader     
from tqdm import tqdm 

from auxiliary_func import create_input_for_nn, encode_moves
from dataset import ChessDataset
from model import ChessModel

# Data preprocessing

## Load data

In [2]:
def load_pgn(file_path, limit_per_file=None, encoding="utf-8"):
    """Parse games from a PGN file.

    Args:
        file_path: path to the ``.pgn`` file.
        limit_per_file: maximum number of games to return; ``None`` reads the
            whole file. ``0`` (or a negative value) returns an empty list.
        encoding: text encoding used to open the file. Passing it explicitly
            keeps parsing independent of the platform's default locale encoding.

    Returns:
        list: games returned by ``pgn.read_game``, in file order.
    """
    games = []
    with open(file_path, "r", encoding=encoding) as pgn_file:
        # Check the limit *before* reading so that a limit of 0 reads nothing.
        while limit_per_file is None or len(games) < limit_per_file:
            game = pgn.read_game(pgn_file)
            if game is None:  # read_game returns None at end of file
                break
            games.append(game)
    return games

# List the PGN files. Sort them because os.listdir order is arbitrary; sorting
# makes the subset selected below the same on every run.
PGN_DIR = "../data/pgn"
files = sorted(file for file in os.listdir(PGN_DIR) if file.endswith(".pgn"))
LIMIT_OF_FILES = min(len(files), 28)  # maximum number of PGN files to load
LIMIT_PER_FILE = 500  # maximum number of games read from each file

# Load the selected files with a progress bar.
games = []
for file in tqdm(files[:LIMIT_OF_FILES], desc="Loading PGN files"):
    games.extend(load_pgn(os.path.join(PGN_DIR, file), limit_per_file=LIMIT_PER_FILE))

print(f"GAMES PARSED: {len(games)}")

Loading PGN files: 100%|██████████| 1/1 [00:01<00:00,  1.02s/it]

GAMES PARSED: 500





In [3]:
print("GAMES PARSED:", len(games))  # quick check of how many games are in memory

GAMES PARSED: 500


## Convert data into tensors

In [4]:
MAX_SAMPLES = 2_500_000  # cap on the number of training samples; None keeps all

X, y = create_input_for_nn(games)
print(f"NUMBER OF SAMPLES: {len(y)}")
assert len(X) == len(y), "inputs and labels must have the same length"

if MAX_SAMPLES is not None:
    X = X[:MAX_SAMPLES]
    y = y[:MAX_SAMPLES]

# Map each distinct move to a class index; the mapping is pickled later so
# predicted class indices can be decoded back into moves.
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

# Go through np.asarray first: torch.tensor on a Python list of ndarrays is
# very slow, while converting one contiguous array is cheap.
X = torch.as_tensor(np.asarray(X, dtype=np.float32))
y = torch.as_tensor(np.asarray(y, dtype=np.int64))

NUMBER OF SAMPLES: 43532


# Training setup: dataset, device and model

In [None]:
SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
NUM_WORKERS = cpu_count()

# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

# Pick the device first: pinned host memory only helps (and only avoids a
# warning) when batches are copied to a CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Create Dataset and DataLoader
dataset = ChessDataset(X, y)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=device.type == "cuda",
    num_workers=NUM_WORKERS,
    # Keep worker processes alive across epochs instead of re-spawning them at
    # the start of every epoch; DataLoader only allows this when workers > 0.
    persistent_workers=NUM_WORKERS > 0,
)

# Model Initialization
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

Using device: cuda


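# Training

The cell below only *defines* `train`; it does not run it. Call `epoch_losses = train(model, dataloader, optimizer, criterion, device)` before the save cells. Otherwise the saved weights are the untrained initialization. The training log stored below was produced by an earlier version of this cell, so its line format differs from the current `print`.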

In [None]:
def train(model, dataloader, optimizer, criterion, device, num_epochs=50,
          max_grad_norm=1.0, plot=True):
    """Train ``model`` for ``num_epochs`` epochs and return the mean loss of each epoch.

    Each step moves the batch to ``device``, computes
    ``criterion(model(inputs), labels)``, backpropagates, optionally clips the
    total gradient norm and takes an optimizer step.

    Args:
        model: network to train; moved to ``device`` in place.
        dataloader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer already bound to ``model.parameters()``.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the model and every batch are moved to.
        num_epochs: number of passes over ``dataloader``.
        max_grad_norm: max total gradient norm passed to ``clip_grad_norm_``;
            ``None`` disables clipping.
        plot: if True, show the loss-per-epoch curve after training.

    Returns:
        list[float]: mean training loss of each epoch, in order.

    Raises:
        ValueError: if ``dataloader`` yields no batch during an epoch.
    """
    model.to(device)
    epoch_losses = []

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        num_batches = 0

        # Per-epoch progress bar; leave=False so only the summary line persists.
        # The context manager closes the bar even if a step raises.
        with tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False) as pbar:
            for num_batches, (inputs, labels) in enumerate(pbar, start=1):
                # Non-blocking copies only overlap with compute when the
                # DataLoader uses pinned memory.
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()

                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

                running_loss += loss.item()
                # tqdm updates pbar.n lazily, so it can lag behind the real batch
                # count; average over our own counter instead.
                pbar.set_postfix(loss=running_loss / num_batches, refresh=False)

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot compute the epoch loss")
        avg_loss = running_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch}/{num_epochs} — avg loss: {avg_loss:.4f}")

    if plot:
        # Loss curve over epochs
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(range(1, num_epochs + 1), epoch_losses, marker="o")
        ax.set(title="Training Loss over Epochs", xlabel="Epoch", ylabel="Loss")
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    return epoch_losses


100%|██████████| 681/681 [00:02<00:00, 279.58it/s]


Epoch 1/50, Loss: 6.5023, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 312.99it/s]


Epoch 2/50, Loss: 6.0527, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 313.88it/s]


Epoch 3/50, Loss: 5.9686, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 300.16it/s]


Epoch 4/50, Loss: 5.9179, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 283.97it/s]


Epoch 5/50, Loss: 5.8749, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 261.36it/s]


Epoch 6/50, Loss: 5.8369, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 246.69it/s]


Epoch 7/50, Loss: 5.7988, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 278.56it/s]


Epoch 8/50, Loss: 5.7657, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 266.63it/s]


Epoch 9/50, Loss: 5.7342, Time: 0m2s


100%|██████████| 681/681 [00:03<00:00, 221.29it/s]


Epoch 10/50, Loss: 5.6967, Time: 0m3s


100%|██████████| 681/681 [00:02<00:00, 284.40it/s]


Epoch 11/50, Loss: 5.6664, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 314.53it/s]


Epoch 12/50, Loss: 5.6333, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 267.96it/s]


Epoch 13/50, Loss: 5.5989, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 274.33it/s]


Epoch 14/50, Loss: 5.5665, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 310.49it/s]


Epoch 15/50, Loss: 5.5324, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 245.46it/s]


Epoch 16/50, Loss: 5.4963, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 283.25it/s]


Epoch 17/50, Loss: 5.4626, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 256.34it/s]


Epoch 18/50, Loss: 5.4277, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 290.53it/s]


Epoch 19/50, Loss: 5.3863, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 302.14it/s]


Epoch 20/50, Loss: 5.3497, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 305.63it/s]


Epoch 21/50, Loss: 5.3141, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 307.14it/s]


Epoch 22/50, Loss: 5.2759, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 313.57it/s]


Epoch 23/50, Loss: 5.2390, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 301.20it/s]


Epoch 24/50, Loss: 5.2046, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 293.79it/s]


Epoch 25/50, Loss: 5.1672, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 315.59it/s]


Epoch 26/50, Loss: 5.1309, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 296.70it/s]


Epoch 27/50, Loss: 5.0899, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 317.32it/s]


Epoch 28/50, Loss: 5.0584, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 313.50it/s]


Epoch 29/50, Loss: 5.0149, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 302.25it/s]


Epoch 30/50, Loss: 4.9868, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 310.24it/s]


Epoch 31/50, Loss: 4.9549, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 312.21it/s]


Epoch 32/50, Loss: 4.9194, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 294.71it/s]


Epoch 33/50, Loss: 4.8897, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 298.67it/s]


Epoch 34/50, Loss: 4.8486, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 311.37it/s]


Epoch 35/50, Loss: 4.8146, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 312.51it/s]


Epoch 36/50, Loss: 4.7826, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 313.89it/s]


Epoch 37/50, Loss: 4.7537, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 309.07it/s]


Epoch 38/50, Loss: 4.7139, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 310.28it/s]


Epoch 39/50, Loss: 4.6845, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 302.80it/s]


Epoch 40/50, Loss: 4.6550, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 312.32it/s]


Epoch 41/50, Loss: 4.6201, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 303.09it/s]


Epoch 42/50, Loss: 4.5908, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 313.68it/s]


Epoch 43/50, Loss: 4.5577, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 307.20it/s]


Epoch 44/50, Loss: 4.5297, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 308.08it/s]


Epoch 45/50, Loss: 4.4994, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 304.50it/s]


Epoch 46/50, Loss: 4.4742, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 249.41it/s]


Epoch 47/50, Loss: 4.4415, Time: 0m2s


100%|██████████| 681/681 [00:03<00:00, 225.49it/s]


Epoch 48/50, Loss: 4.4173, Time: 0m3s


100%|██████████| 681/681 [00:02<00:00, 258.15it/s]


Epoch 49/50, Loss: 4.3816, Time: 0m2s


100%|██████████| 681/681 [00:02<00:00, 251.19it/s]

Epoch 50/50, Loss: 4.3527, Time: 0m2s





# Save the model and mapping

In [7]:
# Save the model weights (state_dict only: rebuild ChessModel(num_classes=...)
# before loading them back).
MODELS_DIR = "../models"
os.makedirs(MODELS_DIR, exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), os.path.join(MODELS_DIR, "noob.pth"))

In [9]:
# Save the move -> class-index mapping needed to decode the model's predictions.
# NOTE(review): the weights are saved as "noob.pth" but this file is named
# "heavy_move_to_int"; confirm both come from the same run before loading them together.
os.makedirs("../models", exist_ok=True)  # open() does not create missing directories
with open("../models/heavy_move_to_int", "wb") as file:
    pickle.dump(move_to_int, file)