In [1]:
import os
import numpy as np # type: ignore
import torch
import time
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import DataLoader # type: ignore
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore

# Data preprocessing

## Load data

In [4]:
def load_pgn(file_path):
    games = []
    with open(file_path, 'r') as pgn_file:
        while True:
            game = pgn.read_game(pgn_file)
            if game is None:
                break
            games.append(game)
    return games

files = [file for file in os.listdir("../../data/pgn") if file.endswith(".pgn")]
LIMIT_OF_FILES = min(len(files), 28)
games = []
i = 1
for file in tqdm(files):
    games.extend(load_pgn(f"../../data/pgn/{file}"))
    if i >= LIMIT_OF_FILES:
        break
    i += 1

 34%|███▍      | 27/79 [01:11<02:18,  2.66s/it]


In [5]:
print(f"GAMES PARSED: {len(games)}")

GAMES PARSED: 41570


## Convert data into tensors

In [1]:
from auxiliary_func import create_input_for_nn, encode_moves

In [7]:
X, y = create_input_for_nn(games)

print(f"NUMBER OF SAMPLES: {len(y)}")

NUMBER OF SAMPLES: 3332761


In [8]:
X = X[0:2500000]
y = y[0:2500000]

In [9]:
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

In [10]:
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

The convolutional layers in PyTorch expects the input to be in the shape ```[batch_size, channels, height, width]```, 

but the current input is in the shape ```[batch_size, height, width, channels]```

In [11]:
print(f"before: {X.shape}")
X = np.transpose(X, (0, 3, 1, 2))  # Change shape from [num_samples, 8, 8, 13] to [num_samples, 8, 13, 8]
print(f"after: {X.shape}")

before: torch.Size([2500000, 8, 8, 13])
after: torch.Size([2500000, 13, 8, 8])


# Preliminary actions

In [4]:
from dataset import ChessDataset
from model import ChessModel

In [12]:
# Create Dataset and DataLoader
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Model Initialization
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

Using device: cuda


# Training

In [13]:
# Training Loop with Verbose Output
num_epochs = 50
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
        optimizer.zero_grad()

        outputs = model(inputs)  # Raw logits

        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        running_loss += loss.item()
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes: int = int(epoch_time // 60)
    seconds: int = int(epoch_time) - minutes * 60
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')

100%|██████████| 39063/39063 [01:40<00:00, 388.23it/s]


Epoch 1/50, Loss: 3.5800, Time: 1m40s


100%|██████████| 39063/39063 [01:33<00:00, 419.75it/s]


Epoch 2/50, Loss: 2.6938, Time: 1m33s


100%|██████████| 39063/39063 [01:32<00:00, 420.76it/s]


Epoch 3/50, Loss: 2.4705, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.70it/s]


Epoch 4/50, Loss: 2.3449, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.21it/s]


Epoch 5/50, Loss: 2.2568, Time: 1m32s


100%|██████████| 39063/39063 [01:33<00:00, 419.57it/s]


Epoch 6/50, Loss: 2.1892, Time: 1m33s


100%|██████████| 39063/39063 [01:32<00:00, 420.17it/s]


Epoch 7/50, Loss: 2.1347, Time: 1m32s


100%|██████████| 39063/39063 [01:33<00:00, 419.96it/s]


Epoch 8/50, Loss: 2.0892, Time: 1m33s


100%|██████████| 39063/39063 [01:33<00:00, 419.41it/s]


Epoch 9/50, Loss: 2.0497, Time: 1m33s


100%|██████████| 39063/39063 [01:33<00:00, 419.89it/s]


Epoch 10/50, Loss: 2.0150, Time: 1m33s


100%|██████████| 39063/39063 [01:33<00:00, 418.96it/s]


Epoch 11/50, Loss: 1.9838, Time: 1m33s


100%|██████████| 39063/39063 [01:33<00:00, 419.38it/s]


Epoch 12/50, Loss: 1.9557, Time: 1m33s


100%|██████████| 39063/39063 [01:32<00:00, 420.04it/s]


Epoch 13/50, Loss: 1.9303, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.88it/s]


Epoch 14/50, Loss: 1.9068, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.91it/s]


Epoch 15/50, Loss: 1.8852, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.05it/s]


Epoch 16/50, Loss: 1.8658, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.01it/s]


Epoch 17/50, Loss: 1.8472, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.28it/s]


Epoch 18/50, Loss: 1.8297, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.53it/s]


Epoch 19/50, Loss: 1.8132, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.96it/s]


Epoch 20/50, Loss: 1.7976, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.46it/s]


Epoch 21/50, Loss: 1.7828, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.12it/s]


Epoch 22/50, Loss: 1.7686, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.89it/s]


Epoch 23/50, Loss: 1.7557, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.88it/s]


Epoch 24/50, Loss: 1.7432, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.62it/s]


Epoch 25/50, Loss: 1.7314, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.86it/s]


Epoch 26/50, Loss: 1.7199, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.24it/s]


Epoch 27/50, Loss: 1.7096, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.60it/s]


Epoch 28/50, Loss: 1.6993, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.46it/s]


Epoch 29/50, Loss: 1.6894, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.45it/s]


Epoch 30/50, Loss: 1.6802, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.33it/s]


Epoch 31/50, Loss: 1.6714, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.60it/s]


Epoch 32/50, Loss: 1.6629, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.44it/s]


Epoch 33/50, Loss: 1.6547, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.41it/s]


Epoch 34/50, Loss: 1.6470, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.73it/s]


Epoch 35/50, Loss: 1.6397, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.59it/s]


Epoch 36/50, Loss: 1.6326, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.04it/s]


Epoch 37/50, Loss: 1.6252, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.69it/s]


Epoch 38/50, Loss: 1.6189, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.47it/s]


Epoch 39/50, Loss: 1.6123, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.88it/s]


Epoch 40/50, Loss: 1.6061, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.52it/s]


Epoch 41/50, Loss: 1.6001, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.33it/s]


Epoch 42/50, Loss: 1.5937, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.42it/s]


Epoch 43/50, Loss: 1.5881, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 422.54it/s]


Epoch 44/50, Loss: 1.5827, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.82it/s]


Epoch 45/50, Loss: 1.5770, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.12it/s]


Epoch 46/50, Loss: 1.5715, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.18it/s]


Epoch 47/50, Loss: 1.5663, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.23it/s]


Epoch 48/50, Loss: 1.5610, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 421.66it/s]


Epoch 49/50, Loss: 1.5563, Time: 1m32s


100%|██████████| 39063/39063 [01:32<00:00, 420.52it/s]

Epoch 50/50, Loss: 1.5515, Time: 1m32s





# Save the model and mappings

In [17]:
# Save the model
torch.save(model.state_dict(), "../../model/heavy_chess_model.pth")

In [22]:
import pickle

with open("../../model/heavy_move_to_int", "wb") as file:
    pickle.dump(move_to_int, file)