# Шахматный движок на PyTorch

## Импорты

In [1]:
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from chess import pgn
from tqdm import tqdm

# Подготовка данных

## Загрузка данных

In [2]:
def load_pgn(file_path):
    games = []
    with open(file_path, 'r') as pgn_file:
        while True:
            game = pgn.read_game(pgn_file)
            if game is None:
                break
            games.append(game)
    return games

files = [file for file in os.listdir("../../data/pgn") if file.endswith(".pgn")]
LIMIT_OF_FILES = min(len(files), 28)
games = []
i = 1
for file in tqdm(files):
    games.extend(load_pgn(f"../../data/pgn/{file}"))
    if i >= LIMIT_OF_FILES:
        break
    i += 1

 34%|███▍      | 27/79 [03:08<06:03,  6.99s/it]


In [3]:
print(f"GAMES PARSED: {len(games)}")

GAMES PARSED: 41570


## Преобразование данных в тензоры

In [4]:
from auxiliary_func import create_input_for_nn, encode_moves

In [5]:
X, y = create_input_for_nn(games)

print(f"NUMBER OF SAMPLES: {len(y)}")

NUMBER OF SAMPLES: 3332761


In [6]:
X = X[0:2500000]
y = y[0:2500000]

In [7]:
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

In [8]:
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Создание набора данных

In [9]:
from dataset import ChessDataset
from model import ChessModel

In [10]:
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Проверка доступности CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Инициализация модели
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

Using device: cuda


# Обучение

In [None]:
num_epochs = 100
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # В моём случае вычисления будут производиться на видеокарте
        # Если CUDA не установлена, вычисления будут производиться на центральном процессоре
        optimizer.zero_grad()

        outputs = model(inputs)  # Необработанные логиты

        # Вычисление издержек
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Обрезка градиента
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        running_loss += loss.item()
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes: int = int(epoch_time // 60)
    seconds: int = int(epoch_time) - minutes * 60
    print(f'Epoch {epoch}/{num_epochs}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')

100%|██████████| 39063/39063 [04:46<00:00, 136.48it/s]


Epoch 0/100, Loss: 3.5637, Time: 4m46s


100%|██████████| 39063/39063 [05:37<00:00, 115.82it/s]


Epoch 1/100, Loss: 2.6847, Time: 5m37s


100%|██████████| 39063/39063 [04:36<00:00, 141.39it/s]


Epoch 2/100, Loss: 2.4634, Time: 4m36s


100%|██████████| 39063/39063 [04:33<00:00, 143.01it/s]


Epoch 3/100, Loss: 2.3363, Time: 4m33s


100%|██████████| 39063/39063 [04:42<00:00, 138.17it/s]


Epoch 4/100, Loss: 2.2479, Time: 4m42s


100%|██████████| 39063/39063 [04:43<00:00, 137.92it/s]


Epoch 5/100, Loss: 2.1799, Time: 4m43s


100%|██████████| 39063/39063 [04:30<00:00, 144.40it/s]


Epoch 6/100, Loss: 2.1244, Time: 4m30s


100%|██████████| 39063/39063 [04:34<00:00, 142.44it/s]


Epoch 7/100, Loss: 2.0779, Time: 4m34s


100%|██████████| 39063/39063 [04:33<00:00, 142.96it/s]


Epoch 8/100, Loss: 2.0378, Time: 4m33s


100%|██████████| 39063/39063 [04:24<00:00, 147.71it/s]


Epoch 9/100, Loss: 2.0027, Time: 4m24s


100%|██████████| 39063/39063 [04:27<00:00, 146.27it/s]


Epoch 10/100, Loss: 1.9709, Time: 4m27s


100%|██████████| 39063/39063 [04:27<00:00, 146.10it/s]


Epoch 11/100, Loss: 1.9426, Time: 4m27s


100%|██████████| 39063/39063 [04:33<00:00, 142.61it/s]


Epoch 12/100, Loss: 1.9168, Time: 4m33s


100%|██████████| 39063/39063 [04:30<00:00, 144.28it/s]


Epoch 13/100, Loss: 1.8935, Time: 4m30s


100%|██████████| 39063/39063 [04:28<00:00, 145.73it/s]


Epoch 14/100, Loss: 1.8719, Time: 4m28s


100%|██████████| 39063/39063 [04:27<00:00, 145.83it/s]


Epoch 15/100, Loss: 1.8525, Time: 4m27s


100%|██████████| 39063/39063 [04:36<00:00, 141.52it/s]


Epoch 16/100, Loss: 1.8339, Time: 4m36s


100%|██████████| 39063/39063 [04:33<00:00, 143.02it/s]


Epoch 17/100, Loss: 1.8169, Time: 4m33s


100%|██████████| 39063/39063 [04:41<00:00, 138.64it/s]


Epoch 18/100, Loss: 1.8012, Time: 4m41s


100%|██████████| 39063/39063 [04:36<00:00, 141.33it/s]


Epoch 19/100, Loss: 1.7869, Time: 4m36s


100%|██████████| 39063/39063 [04:32<00:00, 143.54it/s]


Epoch 20/100, Loss: 1.7733, Time: 4m32s


100%|██████████| 39063/39063 [04:39<00:00, 139.88it/s]


Epoch 21/100, Loss: 1.7606, Time: 4m39s


100%|██████████| 39063/39063 [04:40<00:00, 139.18it/s]


Epoch 22/100, Loss: 1.7489, Time: 4m40s


100%|██████████| 39063/39063 [04:46<00:00, 136.33it/s]


Epoch 23/100, Loss: 1.7373, Time: 4m46s


100%|██████████| 39063/39063 [04:28<00:00, 145.56it/s]


Epoch 24/100, Loss: 1.7265, Time: 4m28s


100%|██████████| 39063/39063 [04:37<00:00, 140.90it/s]


Epoch 25/100, Loss: 1.7168, Time: 4m37s


100%|██████████| 39063/39063 [04:32<00:00, 143.39it/s]


Epoch 26/100, Loss: 1.7070, Time: 4m32s


100%|██████████| 39063/39063 [04:25<00:00, 146.85it/s]


Epoch 27/100, Loss: 1.6976, Time: 4m26s


100%|██████████| 39063/39063 [04:30<00:00, 144.56it/s]


Epoch 28/100, Loss: 1.6892, Time: 4m30s


100%|██████████| 39063/39063 [04:26<00:00, 146.39it/s]


Epoch 29/100, Loss: 1.6805, Time: 4m26s


100%|██████████| 39063/39063 [04:26<00:00, 146.58it/s]


Epoch 30/100, Loss: 1.6721, Time: 4m26s


100%|██████████| 39063/39063 [04:31<00:00, 144.13it/s]


Epoch 31/100, Loss: 1.6641, Time: 4m31s


100%|██████████| 39063/39063 [04:31<00:00, 144.10it/s]


Epoch 32/100, Loss: 1.6568, Time: 4m31s


100%|██████████| 39063/39063 [04:26<00:00, 146.43it/s]


Epoch 33/100, Loss: 1.6494, Time: 4m26s


100%|██████████| 39063/39063 [04:33<00:00, 142.57it/s]


Epoch 34/100, Loss: 1.6418, Time: 4m33s


100%|██████████| 39063/39063 [04:38<00:00, 140.51it/s]


Epoch 35/100, Loss: 1.6354, Time: 4m38s


100%|██████████| 39063/39063 [04:23<00:00, 148.45it/s]


Epoch 36/100, Loss: 1.6292, Time: 4m23s


100%|██████████| 39063/39063 [04:36<00:00, 141.17it/s]


Epoch 37/100, Loss: 1.6230, Time: 4m36s


100%|██████████| 39063/39063 [04:26<00:00, 146.52it/s]


Epoch 38/100, Loss: 1.6171, Time: 4m26s


100%|██████████| 39063/39063 [04:27<00:00, 145.96it/s]


Epoch 39/100, Loss: 1.6112, Time: 4m27s


100%|██████████| 39063/39063 [04:28<00:00, 145.53it/s]


Epoch 40/100, Loss: 1.6055, Time: 4m28s


100%|██████████| 39063/39063 [04:28<00:00, 145.27it/s]


Epoch 41/100, Loss: 1.6006, Time: 4m28s


100%|██████████| 39063/39063 [04:26<00:00, 146.37it/s]


Epoch 42/100, Loss: 1.5953, Time: 4m26s


100%|██████████| 39063/39063 [04:33<00:00, 142.69it/s]


Epoch 43/100, Loss: 1.5902, Time: 4m33s


100%|██████████| 39063/39063 [04:32<00:00, 143.44it/s]


Epoch 44/100, Loss: 1.5852, Time: 4m32s


100%|██████████| 39063/39063 [04:30<00:00, 144.48it/s]


Epoch 45/100, Loss: 1.5808, Time: 4m30s


100%|██████████| 39063/39063 [04:33<00:00, 142.83it/s]


Epoch 46/100, Loss: 1.5764, Time: 4m33s


100%|██████████| 39063/39063 [04:32<00:00, 143.11it/s]


Epoch 47/100, Loss: 1.5722, Time: 4m32s


100%|██████████| 39063/39063 [04:27<00:00, 146.07it/s]


Epoch 48/100, Loss: 1.5677, Time: 4m27s


100%|██████████| 39063/39063 [04:35<00:00, 141.96it/s]


Epoch 49/100, Loss: 1.5641, Time: 4m35s


100%|██████████| 39063/39063 [04:37<00:00, 140.68it/s]


Epoch 50/100, Loss: 1.5599, Time: 4m37s


100%|██████████| 39063/39063 [04:29<00:00, 145.01it/s]


Epoch 51/100, Loss: 1.5563, Time: 4m29s


100%|██████████| 39063/39063 [04:31<00:00, 144.08it/s]


Epoch 52/100, Loss: 1.5525, Time: 4m31s


100%|██████████| 39063/39063 [04:32<00:00, 143.21it/s]


Epoch 53/100, Loss: 1.5488, Time: 4m32s


100%|██████████| 39063/39063 [04:28<00:00, 145.26it/s]


Epoch 54/100, Loss: 1.5453, Time: 4m28s


100%|██████████| 39063/39063 [04:33<00:00, 142.89it/s]


Epoch 55/100, Loss: 1.5421, Time: 4m33s


100%|██████████| 39063/39063 [04:33<00:00, 142.95it/s]


Epoch 56/100, Loss: 1.5387, Time: 4m33s


100%|██████████| 39063/39063 [04:31<00:00, 144.12it/s]


Epoch 57/100, Loss: 1.5352, Time: 4m31s


100%|██████████| 39063/39063 [04:31<00:00, 143.94it/s]


Epoch 58/100, Loss: 1.5326, Time: 4m31s


100%|██████████| 39063/39063 [04:31<00:00, 143.73it/s]


Epoch 59/100, Loss: 1.5291, Time: 4m31s


100%|██████████| 39063/39063 [04:33<00:00, 142.66it/s]


Epoch 60/100, Loss: 1.5263, Time: 4m33s


100%|██████████| 39063/39063 [04:32<00:00, 143.36it/s]


Epoch 61/100, Loss: 1.5230, Time: 4m32s


100%|██████████| 39063/39063 [04:26<00:00, 146.66it/s]


Epoch 62/100, Loss: 1.5195, Time: 4m26s


100%|██████████| 39063/39063 [04:25<00:00, 147.30it/s]


Epoch 63/100, Loss: 1.5168, Time: 4m25s


100%|██████████| 39063/39063 [04:23<00:00, 148.26it/s]


Epoch 64/100, Loss: 1.5138, Time: 4m23s


100%|██████████| 39063/39063 [04:31<00:00, 143.75it/s]


Epoch 65/100, Loss: 1.5108, Time: 4m31s


100%|██████████| 39063/39063 [04:30<00:00, 144.17it/s]


Epoch 66/100, Loss: 1.5081, Time: 4m30s


100%|██████████| 39063/39063 [04:27<00:00, 145.96it/s]


Epoch 67/100, Loss: 1.5053, Time: 4m27s


100%|██████████| 39063/39063 [04:44<00:00, 137.07it/s]


Epoch 68/100, Loss: 1.5024, Time: 4m44s


100%|██████████| 39063/39063 [04:54<00:00, 132.71it/s]


Epoch 69/100, Loss: 1.4999, Time: 4m54s


100%|██████████| 39063/39063 [04:54<00:00, 132.74it/s]


Epoch 70/100, Loss: 1.4971, Time: 4m54s


100%|██████████| 39063/39063 [04:43<00:00, 137.75it/s]


Epoch 71/100, Loss: 1.4949, Time: 4m43s


100%|██████████| 39063/39063 [04:26<00:00, 146.58it/s]


Epoch 72/100, Loss: 1.4920, Time: 4m26s


 27%|██▋       | 10501/39063 [01:11<03:16, 145.54it/s]

# Сохранение модели и маппинга

In [None]:
torch.save(model.state_dict(), "../../models/TORCH_100EPOCHS.pth")

In [None]:
import pickle

with open("../../models/heavy_move_to_int", "wb") as file:
    pickle.dump(move_to_int, file)