# Chess Engine with PyTorch

## Imports

In [3]:
import os
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import DataLoader # type: ignore
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore

# Data preprocessing

## Load data

In [4]:
def load_pgn(file_path):
    """Parse every game stored in a PGN file.

    Games are read one after another with ``pgn.read_game`` until it
    signals end-of-file by returning ``None``.

    Args:
        file_path: Path of the ``.pgn`` file to read.

    Returns:
        A list of the parsed game objects, in file order.
    """
    parsed = []
    with open(file_path, 'r') as handle:
        # Prime the loop with the first game; read_game returns None at EOF.
        game = pgn.read_game(handle)
        while game is not None:
            parsed.append(game)
            game = pgn.read_game(handle)
    return parsed

# Directory holding the raw PGN dumps, relative to this notebook.
PGN_DIR = "../../data/pgn"
# Cap on how many PGN files to parse (loading takes roughly 1.7 s per file).
MAX_PGN_FILES = 20

# Sort so the same subset of files is loaded on every run; os.listdir order is arbitrary.
files = sorted(file for file in os.listdir(PGN_DIR) if file.endswith(".pgn"))
LIMIT_OF_FILES = min(len(files), MAX_PGN_FILES)
games = []
# Iterate only over the files actually loaded so tqdm's total/progress is accurate.
for file in tqdm(files[:LIMIT_OF_FILES]):
    games.extend(load_pgn(os.path.join(PGN_DIR, file)))

 24%|██▍       | 19/78 [00:32<01:39,  1.69s/it]


In [6]:
# Sanity check on the loading step: how many games ended up in memory.
num_games_parsed = len(games)
print(f"GAMES PARSED: {num_games_parsed}")

GAMES PARSED: 2572


## Convert data into tensors

In [1]:
from auxiliary_func import create_input_for_nn, encode_moves

In [5]:
# Turn the parsed games into model inputs X and target moves y.
# NOTE(review): the exact representation (array shape/dtype, move format) is
# defined in auxiliary_func.create_input_for_nn, which is not visible here — confirm there.
X, y = create_input_for_nn(games)


In [6]:
# Cap on the number of training positions kept (presumably to bound memory and epoch time).
MAX_SAMPLES = 2_500_000

# Inputs and targets must stay aligned sample-for-sample before truncation.
assert len(X) == len(y), f"X has {len(X)} samples but y has {len(y)}"
# Slicing is idempotent, so re-running this cell is safe.
X = X[:MAX_SAMPLES]
y = y[:MAX_SAMPLES]

In [7]:
# Encode the target moves as class ids (later converted to torch.long labels);
# move_to_int is the mapping returned by encode_moves, and its size is used as
# num_classes when the model is built.
# NOTE(review): this rebinds `y` from raw moves to encoded ids, so re-running this
# cell on its own would pass already-encoded labels back into encode_moves.
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

In [8]:
# Materialize features and labels as tensors: float32 network inputs and
# int64 class indices (the label dtype nn.CrossEntropyLoss expects).
X, y = (
    torch.tensor(X, dtype=torch.float32),
    torch.tensor(y, dtype=torch.long),
)

# Training setup (dataset, device, model, optimizer)

In [10]:
from dataset import ChessDataset
from model import ChessModel

In [11]:
# Hyperparameters for this run, gathered in one place.
SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 0.0001

# Seed torch's global RNG so weight initialization and DataLoader shuffling
# are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Create Dataset and DataLoader
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Model Initialization; the output layer size comes from the move vocabulary.
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

Using device: cpu


# Training

In [31]:
# Number of epochs already completed on `model` before this cell runs. Printed
# epoch labels continue from here (e.g. 51..100 for a second 50-epoch pass).
# NOTE(review): `model` is freshly initialized in the setup cell, so on a clean
# Restart & Run All this should be 0 — set it to match the model's real history.
START_EPOCH = 50
num_epochs = 50

for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        # Move the batch to the model's device (GPU when available, else CPU).
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Raw logits; nn.CrossEntropyLoss applies log-softmax internally.
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()

        # Clip the global gradient norm to 1.0 to guard against exploding updates.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        running_loss += loss.item()

    epoch_time = time.time() - start_time
    minutes, seconds = divmod(int(epoch_time), 60)
    # Mean per-batch loss; max(..., 1) avoids ZeroDivisionError on an empty loader.
    avg_loss = running_loss / max(len(dataloader), 1)
    # The total is the label of the final epoch (START_EPOCH + num_epochs), not one more.
    print(f'Epoch {START_EPOCH + epoch + 1}/{START_EPOCH + num_epochs}, Loss: {avg_loss:.4f}, Time: {minutes}m{seconds}s')

100%|██████████| 3369/3369 [03:08<00:00, 17.87it/s]


Epoch 51/101, Loss: 0.2015, Time: 3m8s


100%|██████████| 3369/3369 [03:13<00:00, 17.44it/s]


Epoch 52/101, Loss: 0.2010, Time: 3m13s


100%|██████████| 3369/3369 [03:15<00:00, 17.20it/s]


Epoch 53/101, Loss: 0.2007, Time: 3m15s


100%|██████████| 3369/3369 [03:15<00:00, 17.21it/s]


Epoch 54/101, Loss: 0.1995, Time: 3m15s


100%|██████████| 3369/3369 [03:31<00:00, 15.92it/s]


Epoch 55/101, Loss: 0.1983, Time: 3m31s


100%|██████████| 3369/3369 [03:04<00:00, 18.27it/s]


Epoch 56/101, Loss: 0.1983, Time: 3m4s


100%|██████████| 3369/3369 [03:00<00:00, 18.67it/s]


Epoch 57/101, Loss: 0.1978, Time: 3m0s


100%|██████████| 3369/3369 [02:47<00:00, 20.08it/s]


Epoch 58/101, Loss: 0.1981, Time: 2m47s


100%|██████████| 3369/3369 [02:49<00:00, 19.87it/s]


Epoch 59/101, Loss: 0.1968, Time: 2m49s


100%|██████████| 3369/3369 [02:48<00:00, 20.04it/s]


Epoch 60/101, Loss: 0.1964, Time: 2m48s


100%|██████████| 3369/3369 [02:48<00:00, 19.94it/s]


Epoch 61/101, Loss: 0.1960, Time: 2m48s


100%|██████████| 3369/3369 [02:48<00:00, 20.01it/s]


Epoch 62/101, Loss: 0.1962, Time: 2m48s


100%|██████████| 3369/3369 [02:45<00:00, 20.38it/s]


Epoch 63/101, Loss: 0.1953, Time: 2m45s


100%|██████████| 3369/3369 [02:45<00:00, 20.32it/s]


Epoch 64/101, Loss: 0.1939, Time: 2m45s


100%|██████████| 3369/3369 [02:45<00:00, 20.32it/s]


Epoch 65/101, Loss: 0.1943, Time: 2m45s


100%|██████████| 3369/3369 [02:46<00:00, 20.22it/s]


Epoch 66/101, Loss: 0.1938, Time: 2m46s


100%|██████████| 3369/3369 [02:46<00:00, 20.28it/s]


Epoch 67/101, Loss: 0.1926, Time: 2m46s


100%|██████████| 3369/3369 [02:47<00:00, 20.09it/s]


Epoch 68/101, Loss: 0.1951, Time: 2m47s


100%|██████████| 3369/3369 [02:48<00:00, 20.04it/s]


Epoch 69/101, Loss: 0.1939, Time: 2m48s


100%|██████████| 3369/3369 [02:47<00:00, 20.06it/s]


Epoch 70/101, Loss: 0.1928, Time: 2m47s


100%|██████████| 3369/3369 [02:47<00:00, 20.13it/s]


Epoch 71/101, Loss: 0.1925, Time: 2m47s


100%|██████████| 3369/3369 [02:46<00:00, 20.21it/s]


Epoch 72/101, Loss: 0.1924, Time: 2m46s


100%|██████████| 3369/3369 [02:47<00:00, 20.15it/s]


Epoch 73/101, Loss: 0.1914, Time: 2m47s


100%|██████████| 3369/3369 [02:47<00:00, 20.15it/s]


Epoch 74/101, Loss: 0.1916, Time: 2m47s


100%|██████████| 3369/3369 [02:48<00:00, 20.05it/s]


Epoch 75/101, Loss: 0.1901, Time: 2m48s


100%|██████████| 3369/3369 [03:14<00:00, 17.31it/s]


Epoch 76/101, Loss: 0.1903, Time: 3m14s


100%|██████████| 3369/3369 [02:51<00:00, 19.63it/s]


Epoch 77/101, Loss: 0.1897, Time: 2m51s


100%|██████████| 3369/3369 [02:51<00:00, 19.68it/s]


Epoch 78/101, Loss: 0.1897, Time: 2m51s


100%|██████████| 3369/3369 [02:53<00:00, 19.47it/s]


Epoch 79/101, Loss: 0.1892, Time: 2m53s


100%|██████████| 3369/3369 [02:52<00:00, 19.52it/s]


Epoch 80/101, Loss: 0.1900, Time: 2m52s


100%|██████████| 3369/3369 [02:53<00:00, 19.39it/s]


Epoch 81/101, Loss: 0.1892, Time: 2m53s


100%|██████████| 3369/3369 [03:02<00:00, 18.43it/s]


Epoch 82/101, Loss: 0.1879, Time: 3m2s


100%|██████████| 3369/3369 [03:13<00:00, 17.44it/s]


Epoch 83/101, Loss: 0.1876, Time: 3m13s


100%|██████████| 3369/3369 [02:57<00:00, 19.01it/s]


Epoch 84/101, Loss: 0.1878, Time: 2m57s


100%|██████████| 3369/3369 [03:02<00:00, 18.44it/s]


Epoch 85/101, Loss: 0.1867, Time: 3m2s


100%|██████████| 3369/3369 [02:59<00:00, 18.81it/s]


Epoch 86/101, Loss: 0.1869, Time: 2m59s


100%|██████████| 3369/3369 [03:10<00:00, 17.72it/s]


Epoch 87/101, Loss: 0.1868, Time: 3m10s


100%|██████████| 3369/3369 [02:57<00:00, 18.98it/s]


Epoch 88/101, Loss: 0.1846, Time: 2m57s


100%|██████████| 3369/3369 [02:54<00:00, 19.28it/s]


Epoch 89/101, Loss: 0.1856, Time: 2m54s


100%|██████████| 3369/3369 [02:53<00:00, 19.46it/s]


Epoch 90/101, Loss: 0.1850, Time: 2m53s


100%|██████████| 3369/3369 [02:53<00:00, 19.39it/s]


Epoch 91/101, Loss: 0.1845, Time: 2m53s


100%|██████████| 3369/3369 [02:52<00:00, 19.52it/s]


Epoch 92/101, Loss: 0.1850, Time: 2m52s


100%|██████████| 3369/3369 [02:54<00:00, 19.35it/s]


Epoch 93/101, Loss: 0.1841, Time: 2m54s


100%|██████████| 3369/3369 [02:54<00:00, 19.32it/s]


Epoch 94/101, Loss: 0.1839, Time: 2m54s


100%|██████████| 3369/3369 [02:54<00:00, 19.31it/s]


Epoch 95/101, Loss: 0.1833, Time: 2m54s


100%|██████████| 3369/3369 [02:53<00:00, 19.41it/s]


Epoch 96/101, Loss: 0.1842, Time: 2m53s


100%|██████████| 3369/3369 [02:53<00:00, 19.37it/s]


Epoch 97/101, Loss: 0.1822, Time: 2m53s


100%|██████████| 3369/3369 [02:53<00:00, 19.38it/s]


Epoch 98/101, Loss: 0.1848, Time: 2m53s


100%|██████████| 3369/3369 [02:53<00:00, 19.39it/s]


Epoch 99/101, Loss: 0.1821, Time: 2m53s


100%|██████████| 3369/3369 [02:55<00:00, 19.22it/s]

Epoch 100/101, Loss: 0.1837, Time: 2m55s





# Save the model and the move-to-index mapping

In [12]:
# Save the model weights (state_dict only; reload into a ChessModel built with the same num_classes).
# Create the output directory first so this does not fail on a fresh checkout.
os.makedirs("../models", exist_ok=True)
torch.save(model.state_dict(), "../models/TORCH_100EPOCHS.pth")

In [13]:
import pickle

# Persist the move -> class-id mapping alongside the model weights
# (presumably needed to decode model outputs back into moves at inference time).
# NOTE: only unpickle this file from a trusted source — pickle can execute code on load.
mapping_path = "../models/heavy_move_to_int"
with open(mapping_path, "wb") as mapping_file:
    pickle.dump(move_to_int, mapping_file)