# Chess Engine with Pytorch

### Imports

In [1]:
import os 
import numpy
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from chess import pgn
from tqdm import tqdm

### Data Preprocessing


#### Load Data

In [2]:
def load_pgn(file_path):
    games = []
    with open(file_path, "r") as pgn_file:
        while True:
            game = pgn.read_game(pgn_file)
            if game is None:
                break
            games.append(game)	
    return games

files = [file for file in os.listdir("../Database") if file.endswith(".pgn")]

LIMIT_OF_FILES = min(len(files), 20)

games = []
i = 1
for file in tqdm(files):
    games.extend(load_pgn(f"../Database/{file}"))
    if i == LIMIT_OF_FILES:
        break
    i += 1


 17%|█▋        | 19/110 [00:21<01:41,  1.11s/it]


In [3]:
print(f"Games Parsed: {len(games)}")

Games Parsed: 5667


# Convert data into tensors

In [4]:
from auxiliary_functs import create_input_for_nn, encode_moves

In [5]:
X, y = create_input_for_nn(games)
print(f"Number of samples: {len(y)}")

Number of samples: 479333


In [6]:
# X = X[0:2500000]
# y = y[0:2500000]

In [7]:
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

In [8]:
X = torch.tensor(X,dtype=torch.float32)
y = torch.tensor(y,dtype=torch.long)


# Preliminary actions

In [9]:
from dataset import ChessDataset
from model import ChessModel

In [10]:
# Create Dataset and DataLoader
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Model Initialization
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Load the model state
if os.path.exists("../model/chess_model_pytorch.pth"):
    model.load_state_dict(torch.load("../model/chess_model_pytorch.pth"))
    print("Loaded model from disk.")

# Load the optimizer state
if os.path.exists("../model/chess_optimizer_pytorch.pth"):
    optimizer.load_state_dict(torch.load("../model/chess_optimizer_pytorch.pth"))
    print("Loaded optimizer state from disk.")

Using device: cuda


# Training with games

In [11]:

num_epochs = 50

for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device) # move data to GPU
        optimizer.zero_grad()
        outputs = model(inputs) # Raw logits

        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        running_loss += loss.item()
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes: int = int(epoch_time // 60)
    seconds: int = int(epoch_time % 60)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss / len(dataloader):.4f} - Time: {minutes}m {seconds}s")

100%|██████████| 7490/7490 [01:24<00:00, 88.28it/s] 


Epoch 1/50 - Loss: 4.8225 - Time: 1m 24s


100%|██████████| 7490/7490 [01:26<00:00, 86.39it/s] 


Epoch 2/50 - Loss: 3.6173 - Time: 1m 26s


100%|██████████| 7490/7490 [01:25<00:00, 87.10it/s] 


Epoch 3/50 - Loss: 3.1958 - Time: 1m 26s


100%|██████████| 7490/7490 [01:35<00:00, 78.20it/s] 


Epoch 4/50 - Loss: 2.9227 - Time: 1m 35s


100%|██████████| 7490/7490 [01:24<00:00, 88.42it/s] 


Epoch 5/50 - Loss: 2.7186 - Time: 1m 24s


100%|██████████| 7490/7490 [01:24<00:00, 88.53it/s] 


Epoch 6/50 - Loss: 2.5578 - Time: 1m 24s


100%|██████████| 7490/7490 [01:22<00:00, 90.27it/s] 


Epoch 7/50 - Loss: 2.4218 - Time: 1m 22s


100%|██████████| 7490/7490 [01:16<00:00, 97.97it/s] 


Epoch 8/50 - Loss: 2.3018 - Time: 1m 16s


100%|██████████| 7490/7490 [01:15<00:00, 99.40it/s] 


Epoch 9/50 - Loss: 2.1935 - Time: 1m 15s


100%|██████████| 7490/7490 [01:24<00:00, 88.23it/s] 


Epoch 10/50 - Loss: 2.0947 - Time: 1m 24s


100%|██████████| 7490/7490 [01:15<00:00, 99.19it/s] 


Epoch 11/50 - Loss: 2.0008 - Time: 1m 15s


100%|██████████| 7490/7490 [01:23<00:00, 89.89it/s] 


Epoch 12/50 - Loss: 1.9127 - Time: 1m 23s


100%|██████████| 7490/7490 [01:20<00:00, 92.99it/s] 


Epoch 13/50 - Loss: 1.8285 - Time: 1m 20s


100%|██████████| 7490/7490 [01:16<00:00, 97.86it/s] 


Epoch 14/50 - Loss: 1.7492 - Time: 1m 16s


100%|██████████| 7490/7490 [01:22<00:00, 91.09it/s] 


Epoch 15/50 - Loss: 1.6728 - Time: 1m 22s


100%|██████████| 7490/7490 [01:20<00:00, 92.63it/s] 


Epoch 16/50 - Loss: 1.5994 - Time: 1m 20s


100%|██████████| 7490/7490 [01:10<00:00, 106.55it/s]


Epoch 17/50 - Loss: 1.5298 - Time: 1m 10s


100%|██████████| 7490/7490 [01:09<00:00, 108.09it/s]


Epoch 18/50 - Loss: 1.4631 - Time: 1m 9s


100%|██████████| 7490/7490 [01:04<00:00, 116.05it/s]


Epoch 19/50 - Loss: 1.3989 - Time: 1m 4s


100%|██████████| 7490/7490 [01:05<00:00, 113.97it/s]


Epoch 20/50 - Loss: 1.3375 - Time: 1m 5s


100%|██████████| 7490/7490 [01:05<00:00, 114.90it/s]


Epoch 21/50 - Loss: 1.2786 - Time: 1m 5s


100%|██████████| 7490/7490 [01:06<00:00, 112.93it/s]


Epoch 22/50 - Loss: 1.2226 - Time: 1m 6s


100%|██████████| 7490/7490 [01:04<00:00, 115.94it/s]


Epoch 23/50 - Loss: 1.1692 - Time: 1m 4s


100%|██████████| 7490/7490 [01:03<00:00, 118.26it/s]


Epoch 24/50 - Loss: 1.1177 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.07it/s]


Epoch 25/50 - Loss: 1.0680 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.53it/s]


Epoch 26/50 - Loss: 1.0203 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.08it/s]


Epoch 27/50 - Loss: 0.9761 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.32it/s]


Epoch 28/50 - Loss: 0.9326 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.50it/s]


Epoch 29/50 - Loss: 0.8914 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.37it/s]


Epoch 30/50 - Loss: 0.8529 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.13it/s]


Epoch 31/50 - Loss: 0.8150 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.49it/s]


Epoch 32/50 - Loss: 0.7795 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.17it/s]


Epoch 33/50 - Loss: 0.7459 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.70it/s]


Epoch 34/50 - Loss: 0.7141 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.58it/s]


Epoch 35/50 - Loss: 0.6848 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.30it/s]


Epoch 36/50 - Loss: 0.6565 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.73it/s]


Epoch 37/50 - Loss: 0.6303 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.29it/s]


Epoch 38/50 - Loss: 0.6053 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.57it/s]


Epoch 39/50 - Loss: 0.5826 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.66it/s]


Epoch 40/50 - Loss: 0.5602 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.06it/s]


Epoch 41/50 - Loss: 0.5397 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.52it/s]


Epoch 42/50 - Loss: 0.5207 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.21it/s]


Epoch 43/50 - Loss: 0.5040 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.61it/s]


Epoch 44/50 - Loss: 0.4873 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.32it/s]


Epoch 45/50 - Loss: 0.4725 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.27it/s]


Epoch 46/50 - Loss: 0.4590 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.56it/s]


Epoch 47/50 - Loss: 0.4458 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.25it/s]


Epoch 48/50 - Loss: 0.4341 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.59it/s]


Epoch 49/50 - Loss: 0.4229 - Time: 1m 3s


100%|██████████| 7490/7490 [01:03<00:00, 117.47it/s]


Epoch 50/50 - Loss: 0.4115 - Time: 1m 3s


In [12]:
torch.save(model.state_dict(), "../model/chess_model_pytorch.pth")
torch.save(optimizer.state_dict(), "../model/chess_optimizer_pytorch.pth")

In [13]:
import pickle

with open("../model/move_to_int", "wb") as file:
    pickle.dump(move_to_int, file)