### Imports

In [1]:
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from chess import pgn
from tqdm import tqdm


### Loading Data

In [2]:

def load_pgn(file_path):
    """Read every game stored in a PGN file.

    Args:
        file_path: Path to the PGN file. It is opened as UTF-8 text and
            bytes that cannot be decoded are dropped (``errors='ignore'``).

    Returns:
        list: The objects returned by successive ``pgn.read_game`` calls,
        in file order, up to (not including) the first ``None``.
    """
    parsed_games = []
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
        current = pgn.read_game(handle)
        # A None result means there is nothing left to read from the file.
        while current is not None:
            parsed_games.append(current)
            current = pgn.read_game(handle)
    return parsed_games

# Directory that holds the raw .pgn archives.
PGN_DIR = "/home/rusted/Projects/chessengine/data/pgn"

# Number of PGN files to read; set to None to read every file in PGN_DIR.
# NOTE: this cell previously contained an unconditional `break` inside the
# loop, so it always stopped after the first file. 1 keeps that behaviour
# while making the limit explicit and easy to change.
FILE_LIMIT = 1

# Sorted so the selected files do not depend on os.listdir() ordering.
files = sorted(file for file in os.listdir(PGN_DIR) if file.endswith(".pgn"))

games = []
# Slicing with None as the stop keeps the whole list.
for file in tqdm(files[:FILE_LIMIT]):
    games.extend(load_pgn(os.path.join(PGN_DIR, file)))

  0%|          | 0/1 [00:17<?, ?it/s]


In [3]:
# Sanity check: number of games collected in `games` by the loading cell.
len(games)

8439

### Convert data to tensors

In [4]:
from auxiliary_functions import create_input_for_nn, encode_moves

In [5]:
# Build the network inputs X and the targets y from the loaded games.
# NOTE(review): create_input_for_nn comes from auxiliary_functions and is not
# shown here — confirm the exact layout of X and y against that module.
X, y = create_input_for_nn(games)

# Report how many samples were produced (one entry of y per sample).
print(f"SAMPLES : {len(y)}")

SAMPLES : 723923


In [6]:
# Keep at most the first 1,000,000 samples; slicing past the end of a shorter
# sequence simply keeps everything.
X, y = X[:1000000], y[:1000000]

In [7]:
# Encode the targets and keep the mapping returned by encode_moves
# (presumably move -> integer class id; defined in auxiliary_functions,
# not shown here — verify there).
# NOTE(review): this cell rebinds `y`, so re-running it without re-running the
# cells above would pass the already-encoded labels back into encode_moves.
y, move_to_int = encode_moves(y)
# The classifier gets one output class per entry in the mapping.
num_classes = len(move_to_int)

In [8]:
# Convert the inputs to float32 and the labels to int64 (torch.long) tensors;
# the labels are later used as class-index targets for nn.CrossEntropyLoss.
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

### Training

In [9]:
from dataset import ChessDataset
from model import ChessModel

In [10]:
# Hyperparameters for this run, collected in one place so they are easy to tune.
SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 0.0001

# Seed PyTorch's RNG before anything random happens in this cell, so that
# weight initialisation and DataLoader shuffling repeat from run to run.
# NOTE: some CUDA kernels are non-deterministic, so GPU runs may still differ
# slightly between executions.
torch.manual_seed(SEED)

# Make dataset and DataLoader
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# CUDA if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Initializing the model, the loss and the optimizer
model = ChessModel(num_classes=num_classes).to(device)
loss_criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


Using device: cuda


In [11]:
num_epochs = 50
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()  # make sure the model is in training mode for this epoch
    total_loss = 0.0  # sum of per-batch losses, averaged when reporting

    for batch_x, batch_y in tqdm(dataloader):
        # Move the batch to the device the model lives on
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass (raw logits) and loss
        logits = model(batch_x)
        loss = loss_criterion(logits, batch_y)

        # Backward pass; clip the global gradient norm to 1.0 before stepping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

    # Wall-clock duration of the epoch, in seconds
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}, Time: {epoch_time}')

100%|██████████| 11312/11312 [00:37<00:00, 301.58it/s]


Epoch 1/50, Loss: 4.5913, Time: 37.51087141036987


100%|██████████| 11312/11312 [00:37<00:00, 304.89it/s]


Epoch 2/50, Loss: 3.4590, Time: 37.10342717170715


100%|██████████| 11312/11312 [00:39<00:00, 284.44it/s]


Epoch 3/50, Loss: 3.0693, Time: 39.77075242996216


100%|██████████| 11312/11312 [00:37<00:00, 305.08it/s]


Epoch 4/50, Loss: 2.8321, Time: 37.081443309783936


100%|██████████| 11312/11312 [00:37<00:00, 301.67it/s]


Epoch 5/50, Loss: 2.6655, Time: 37.4993155002594


100%|██████████| 11312/11312 [00:37<00:00, 298.70it/s]


Epoch 6/50, Loss: 2.5342, Time: 37.872594118118286


100%|██████████| 11312/11312 [00:37<00:00, 304.37it/s]


Epoch 7/50, Loss: 2.4234, Time: 37.16701650619507


100%|██████████| 11312/11312 [00:35<00:00, 314.82it/s]


Epoch 8/50, Loss: 2.3264, Time: 35.93359875679016


100%|██████████| 11312/11312 [00:36<00:00, 306.73it/s]


Epoch 9/50, Loss: 2.2372, Time: 36.88014101982117


100%|██████████| 11312/11312 [00:37<00:00, 298.87it/s]


Epoch 10/50, Loss: 2.1569, Time: 37.851282835006714


100%|██████████| 11312/11312 [00:36<00:00, 313.82it/s]


Epoch 11/50, Loss: 2.0819, Time: 36.047269344329834


100%|██████████| 11312/11312 [00:37<00:00, 301.25it/s]


Epoch 12/50, Loss: 2.0121, Time: 37.55250430107117


100%|██████████| 11312/11312 [00:36<00:00, 311.44it/s]


Epoch 13/50, Loss: 1.9466, Time: 36.32283616065979


100%|██████████| 11312/11312 [00:36<00:00, 311.90it/s]


Epoch 14/50, Loss: 1.8843, Time: 36.27039194107056


100%|██████████| 11312/11312 [00:36<00:00, 311.71it/s]


Epoch 15/50, Loss: 1.8247, Time: 36.29286050796509


100%|██████████| 11312/11312 [00:36<00:00, 309.80it/s]


Epoch 16/50, Loss: 1.7693, Time: 36.515533447265625


100%|██████████| 11312/11312 [00:36<00:00, 311.76it/s]


Epoch 17/50, Loss: 1.7166, Time: 36.286365032196045


100%|██████████| 11312/11312 [00:36<00:00, 311.85it/s]


Epoch 18/50, Loss: 1.6651, Time: 36.27550029754639


100%|██████████| 11312/11312 [00:36<00:00, 312.57it/s]


Epoch 19/50, Loss: 1.6171, Time: 36.19163656234741


100%|██████████| 11312/11312 [00:37<00:00, 303.05it/s]


Epoch 20/50, Loss: 1.5711, Time: 37.32876753807068


100%|██████████| 11312/11312 [00:36<00:00, 305.90it/s]


Epoch 21/50, Loss: 1.5276, Time: 36.98045039176941


100%|██████████| 11312/11312 [00:37<00:00, 301.59it/s]


Epoch 22/50, Loss: 1.4850, Time: 37.509441614151


100%|██████████| 11312/11312 [00:37<00:00, 305.38it/s]


Epoch 23/50, Loss: 1.4447, Time: 37.044172525405884


100%|██████████| 11312/11312 [00:36<00:00, 310.74it/s]


Epoch 24/50, Loss: 1.4058, Time: 36.40485715866089


100%|██████████| 11312/11312 [00:37<00:00, 305.01it/s]


Epoch 25/50, Loss: 1.3688, Time: 37.08911895751953


100%|██████████| 11312/11312 [00:37<00:00, 301.39it/s]


Epoch 26/50, Loss: 1.3330, Time: 37.53433847427368


100%|██████████| 11312/11312 [00:38<00:00, 295.98it/s]


Epoch 27/50, Loss: 1.2991, Time: 38.220566272735596


100%|██████████| 11312/11312 [00:36<00:00, 308.15it/s]


Epoch 28/50, Loss: 1.2665, Time: 36.7112603187561


100%|██████████| 11312/11312 [00:35<00:00, 316.46it/s]


Epoch 29/50, Loss: 1.2340, Time: 35.74732995033264


100%|██████████| 11312/11312 [00:32<00:00, 347.31it/s]


Epoch 30/50, Loss: 1.2033, Time: 32.572322845458984


100%|██████████| 11312/11312 [00:32<00:00, 344.05it/s]


Epoch 31/50, Loss: 1.1740, Time: 32.88073253631592


100%|██████████| 11312/11312 [00:31<00:00, 353.86it/s]


Epoch 32/50, Loss: 1.1451, Time: 31.969374179840088


100%|██████████| 11312/11312 [00:32<00:00, 345.19it/s]


Epoch 33/50, Loss: 1.1175, Time: 32.77223515510559


100%|██████████| 11312/11312 [00:38<00:00, 293.22it/s]


Epoch 34/50, Loss: 1.0908, Time: 38.579925298690796


100%|██████████| 11312/11312 [00:36<00:00, 308.54it/s]


Epoch 35/50, Loss: 1.0653, Time: 36.66445517539978


100%|██████████| 11312/11312 [00:37<00:00, 300.91it/s]


Epoch 36/50, Loss: 1.0400, Time: 37.595386028289795


100%|██████████| 11312/11312 [00:36<00:00, 313.12it/s]


Epoch 37/50, Loss: 1.0152, Time: 36.12775254249573


100%|██████████| 11312/11312 [00:36<00:00, 312.83it/s]


Epoch 38/50, Loss: 0.9924, Time: 36.16213274002075


100%|██████████| 11312/11312 [00:37<00:00, 303.98it/s]


Epoch 39/50, Loss: 0.9696, Time: 37.21456742286682


100%|██████████| 11312/11312 [00:36<00:00, 311.68it/s]


Epoch 40/50, Loss: 0.9479, Time: 36.295090198516846


100%|██████████| 11312/11312 [00:36<00:00, 313.78it/s]


Epoch 41/50, Loss: 0.9267, Time: 36.05256938934326


100%|██████████| 11312/11312 [00:35<00:00, 317.14it/s]


Epoch 42/50, Loss: 0.9060, Time: 35.67002725601196


100%|██████████| 11312/11312 [00:38<00:00, 293.00it/s]


Epoch 43/50, Loss: 0.8862, Time: 38.60930895805359


100%|██████████| 11312/11312 [00:37<00:00, 304.86it/s]


Epoch 44/50, Loss: 0.8672, Time: 37.107171297073364


100%|██████████| 11312/11312 [00:37<00:00, 302.38it/s]


Epoch 45/50, Loss: 0.8484, Time: 37.41244578361511


100%|██████████| 11312/11312 [00:35<00:00, 320.34it/s]


Epoch 46/50, Loss: 0.8315, Time: 35.31370449066162


100%|██████████| 11312/11312 [00:31<00:00, 358.70it/s]


Epoch 47/50, Loss: 0.8139, Time: 31.53785729408264


100%|██████████| 11312/11312 [00:32<00:00, 349.85it/s]


Epoch 48/50, Loss: 0.7969, Time: 32.33605241775513


100%|██████████| 11312/11312 [00:37<00:00, 300.86it/s]


Epoch 49/50, Loss: 0.7795, Time: 37.600913524627686


100%|██████████| 11312/11312 [00:35<00:00, 314.32it/s]

Epoch 50/50, Loss: 0.7647, Time: 35.99007201194763





### Save model and mappings

In [12]:
# Persist only the learned weights (the model's state_dict).
MODEL_PATH = "/home/rusted/Projects/chessengine/models/CHESSDATA2.pth"
# Create the target directory if it is missing, so that saving cannot fail on
# a non-existent folder after the long training run.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)

In [13]:
import pickle

# Store the move_to_int mapping in the same directory as the saved weights;
# it is needed to interpret the model's class indices later.
mapping_path = "/home/rusted/Projects/chessengine/models/move_to_int2"
with open(mapping_path, "wb") as handle:
    pickle.dump(move_to_int, handle)