
 # Chess Engine with PyTorch



 ## Imports


In [13]:
import os
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import DataLoader # type: ignore
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore


# Data preprocessing


 ## Load data


In [14]:
def load_pgn(file_path, max_games=2000, encoding="utf-8"):
    """Read standard-variant games from a PGN file.

    Args:
        file_path: Path of the PGN file to read.
        max_games: Maximum number of games to return. Defaults to 2000, the
            limit that used to be hard-coded here. Pass ``None`` to read
            until the end of the file.
        encoding: Text encoding used to open the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        A list of the game objects produced by ``pgn.read_game``, in file
        order. Games whose ``Variant`` header is present and is not
        ``"Standard"`` are skipped and do not count towards ``max_games``.
    """
    games = []
    with open(file_path, 'r', encoding=encoding) as pgn_file:
        while max_games is None or len(games) < max_games:
            game = pgn.read_game(pgn_file)
            if game is None:
                # read_game signals the end of the input with None.
                break
            if game.headers.get("Variant", "Standard") != "Standard":
                continue
            games.append(game)
    return games

# Only this one PGN file is loaded. An earlier revision instead listed every
# ".pgn" file under ../../data/pgn and stopped after at most 28 of them.
files = ["lichess_sai_TejA.pgn"]

games = []
for file in tqdm(files):
    parsed = load_pgn(f"../../data/pgn/{file}")
    games += parsed


100%|██████████| 1/1 [00:01<00:00,  1.51s/it]


In [15]:

# Report how many games were collected from the PGN files.
game_count = len(games)
print(f"GAMES PARSED: {game_count}")

GAMES PARSED: 2000



## Convert data into tensors

In [16]:
from utils import create_input_for_nn, encode_moves

# Build the (X, y) pair from the parsed games; how positions and moves are
# represented is defined by utils.create_input_for_nn.
X, y = create_input_for_nn(games)

sample_count = len(y)
print(f"NUMBER OF SAMPLES: {sample_count}")


NUMBER OF SAMPLES: 135225


In [17]:
# Keep at most the first 2,500,000 samples of each collection.
sample_cap = 2500000
X, y = X[:sample_cap], y[:sample_cap]

In [18]:
# encode_moves returns the encoded targets together with the move_to_int
# mapping; the size of that mapping is used as the number of classes.
# NOTE: y is overwritten here, so re-running this cell on its own feeds the
# already-encoded values back into encode_moves.
encoded_moves, move_to_int = encode_moves(y)
y = encoded_moves
num_classes = len(move_to_int)

In [19]:
# Number of entries in move_to_int; this value is later passed to ChessModel
# as num_classes.
print(num_classes)

1808


In [20]:
# Copy the inputs into a float32 tensor and the encoded targets into an
# int64 tensor (torch.long and torch.int64 are the same dtype).
features = torch.tensor(X, dtype=torch.float32)
targets = torch.tensor(y, dtype=torch.int64)
X, y = features, targets

# Preliminary actions


In [26]:
from dataset import ChessDataset
from chess_model import ChessModel

In [27]:
# Wrap the tensors in the project's Dataset and iterate over them in
# shuffled mini-batches of 64.
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

# Check for GPU: prefer CUDA, then Apple's MPS backend, and fall back to CPU.
# torch.backends.mps is looked up defensively because older PyTorch builds
# do not provide that attribute at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

# Build the network and move it to the selected device before creating the
# optimizer, so the optimizer is handed the parameters that will be trained.
model = ChessModel(num_classes=num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


Using device: mps



 # Training


In [28]:
num_epochs = 50
# Number of epochs assumed to have been completed before this run; it only
# shifts the numbers shown in the progress line.
# NOTE(review): the model above is constructed without loading a checkpoint,
# so this should probably be 0 unless weights from an earlier run were
# restored — confirm.
start_epoch = 50
# Last epoch number that will be printed. The previous denominator added an
# extra 1, so the final line read "100/101".
total_epochs = start_epoch + num_epochs

for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)  # Move the batch to the model's device
        optimizer.zero_grad()

        outputs = model(inputs)  # Raw logits

        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping: rescale gradients so their total norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        running_loss += loss.item()

    # Whole seconds elapsed for this epoch, split into minutes and seconds.
    epoch_time = time.time() - start_time
    minutes, seconds = divmod(int(epoch_time), 60)
    # The reported loss is the mean of the per-batch losses for the epoch.
    print(f'Epoch {start_epoch + epoch + 1}/{total_epochs}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')


100%|██████████| 2113/2113 [00:21<00:00, 98.53it/s] 


Epoch 51/101, Loss: 5.5230, Time: 0m21s


100%|██████████| 2113/2113 [00:17<00:00, 118.68it/s]


Epoch 52/101, Loss: 4.5840, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 117.05it/s]


Epoch 53/101, Loss: 3.9953, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.54it/s]


Epoch 54/101, Loss: 3.6133, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 121.34it/s]


Epoch 55/101, Loss: 3.3398, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.53it/s]


Epoch 56/101, Loss: 3.1259, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 118.60it/s]


Epoch 57/101, Loss: 2.9450, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 117.29it/s]


Epoch 58/101, Loss: 2.7870, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 117.07it/s]


Epoch 59/101, Loss: 2.6423, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 117.21it/s]


Epoch 60/101, Loss: 2.5092, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 117.88it/s]


Epoch 61/101, Loss: 2.3836, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 123.38it/s]


Epoch 62/101, Loss: 2.2676, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.74it/s]


Epoch 63/101, Loss: 2.1535, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.83it/s]


Epoch 64/101, Loss: 2.0471, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.90it/s]


Epoch 65/101, Loss: 1.9440, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 117.26it/s]


Epoch 66/101, Loss: 1.8436, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.91it/s]


Epoch 67/101, Loss: 1.7491, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.98it/s]


Epoch 68/101, Loss: 1.6577, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.93it/s]


Epoch 69/101, Loss: 1.5666, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 117.17it/s]


Epoch 70/101, Loss: 1.4798, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.60it/s]


Epoch 71/101, Loss: 1.3976, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.74it/s]


Epoch 72/101, Loss: 1.3162, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 123.07it/s]


Epoch 73/101, Loss: 1.2402, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.72it/s]


Epoch 74/101, Loss: 1.1652, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 118.82it/s]


Epoch 75/101, Loss: 1.0949, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 118.15it/s]


Epoch 76/101, Loss: 1.0294, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 122.55it/s]


Epoch 77/101, Loss: 0.9658, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 119.83it/s]


Epoch 78/101, Loss: 0.9046, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 117.05it/s]


Epoch 79/101, Loss: 0.8488, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.49it/s]


Epoch 80/101, Loss: 0.7973, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 116.38it/s]


Epoch 81/101, Loss: 0.7467, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 118.61it/s]


Epoch 82/101, Loss: 0.7017, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 115.50it/s]


Epoch 83/101, Loss: 0.6604, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 117.78it/s]


Epoch 84/101, Loss: 0.6221, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.58it/s]


Epoch 85/101, Loss: 0.5857, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 117.29it/s]


Epoch 86/101, Loss: 0.5549, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 119.45it/s]


Epoch 87/101, Loss: 0.5250, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 118.50it/s]


Epoch 88/101, Loss: 0.4988, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 116.99it/s]


Epoch 89/101, Loss: 0.4754, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 117.33it/s]


Epoch 90/101, Loss: 0.4533, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 117.42it/s]


Epoch 91/101, Loss: 0.4348, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.94it/s]


Epoch 92/101, Loss: 0.4154, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.52it/s]


Epoch 93/101, Loss: 0.3991, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 117.91it/s]


Epoch 94/101, Loss: 0.3859, Time: 0m17s


100%|██████████| 2113/2113 [00:17<00:00, 118.87it/s]


Epoch 95/101, Loss: 0.3724, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 116.80it/s]


Epoch 96/101, Loss: 0.3590, Time: 0m18s


100%|██████████| 2113/2113 [00:18<00:00, 117.23it/s]


Epoch 97/101, Loss: 0.3501, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 118.78it/s]


Epoch 98/101, Loss: 0.3393, Time: 0m17s


100%|██████████| 2113/2113 [00:18<00:00, 117.03it/s]


Epoch 99/101, Loss: 0.3314, Time: 0m18s


100%|██████████| 2113/2113 [00:17<00:00, 123.06it/s]

Epoch 100/101, Loss: 0.3216, Time: 0m17s






# Save the model and mapping

In [29]:
# Persist only the model's state_dict (its parameters and buffers), not the
# module object itself.
# NOTE(review): the file name says 100 epochs while the training loop in this
# notebook sets num_epochs = 50 — confirm the name matches the saved weights.
model_weights = model.state_dict()
torch.save(model_weights, "../../models/TORCH_100EPOCHS.pth")

In [30]:
import pickle

# Store the move_to_int mapping next to the weights so the same encoding can
# be reloaded together with the model.
with open("../../models/move_to_int", "wb") as mapping_file:
    pickle.dump(move_to_int, mapping_file)