In [159]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT
from vit_pytorch import SimpleViT

In [160]:
print(f"Torch: {torch.__version__}")

Torch: 2.2.1


In [161]:
print("loading image dataset")
images = torch.load("boards2.pt")
print("loading label dataset")
labels = torch.load("labels2.pt")
print(len(images), len(labels))

loading image dataset
loading label dataset
100014 100014


In [162]:
# Hyperparameters:
batch_size = 64
epochs = 30
lr = 5e-4
gamma = 0.7
seed = 142
IMG_SIZE = 8
patch_size = 2
num_classes = 10
dim = 128

In [163]:
device = 'cuda'

In [164]:
class CustomChessDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return len(labels)

    def __getitem__(self, idx):
        label = labels[idx]
        img = images[idx]
        return img, label

In [165]:
import torch.utils.data as data

full_dataset = CustomChessDataset()

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

print("Train size: ", train_size)
print("Test size: ", test_size)

train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

# Data Loaders:
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

Train size:  80011
Test size:  20003


In [166]:
# Linear Transformer:
efficient_transformer = Linformer(
    dim=dim,
    seq_len= 65,
    depth=12,
    heads=8,
    k=64
)

# Vision Transformer Model:
model = ViT(dim=dim, image_size=IMG_SIZE, patch_size=patch_size, num_classes=num_classes, transformer=efficient_transformer, channels=1).to(device)

# loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

# Learning Rate Scheduler for Optimizer:
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [167]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 1 - loss : 1.6715 - acc: 0.3140 - val_loss : 1.6015 - val_acc: 0.3391


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 2 - loss : 1.5430 - acc: 0.3665 - val_loss : 1.5435 - val_acc: 0.3671


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 3 - loss : 1.4446 - acc: 0.4119 - val_loss : 1.5041 - val_acc: 0.3893


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 4 - loss : 1.3484 - acc: 0.4532 - val_loss : 1.4661 - val_acc: 0.4050


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 5 - loss : 1.2591 - acc: 0.4887 - val_loss : 1.4288 - val_acc: 0.4290


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 6 - loss : 1.1747 - acc: 0.5270 - val_loss : 1.4305 - val_acc: 0.4374


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 7 - loss : 1.0966 - acc: 0.5583 - val_loss : 1.3991 - val_acc: 0.4510


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 8 - loss : 1.0101 - acc: 0.5942 - val_loss : 1.4067 - val_acc: 0.4640


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 9 - loss : 0.9259 - acc: 0.6252 - val_loss : 1.4134 - val_acc: 0.4676


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 10 - loss : 0.8369 - acc: 0.6607 - val_loss : 1.4384 - val_acc: 0.4814


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 11 - loss : 0.7492 - acc: 0.6959 - val_loss : 1.4615 - val_acc: 0.4866


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 12 - loss : 0.6665 - acc: 0.7280 - val_loss : 1.4874 - val_acc: 0.5016


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 13 - loss : 0.5922 - acc: 0.7587 - val_loss : 1.5318 - val_acc: 0.5054


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 14 - loss : 0.5223 - acc: 0.7863 - val_loss : 1.5831 - val_acc: 0.5068


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 15 - loss : 0.4665 - acc: 0.8087 - val_loss : 1.6092 - val_acc: 0.5208


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 16 - loss : 0.4199 - acc: 0.8253 - val_loss : 1.6650 - val_acc: 0.5235


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 17 - loss : 0.3852 - acc: 0.8389 - val_loss : 1.7381 - val_acc: 0.5249


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 18 - loss : 0.3555 - acc: 0.8524 - val_loss : 1.7594 - val_acc: 0.5261


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 19 - loss : 0.3289 - acc: 0.8626 - val_loss : 1.7613 - val_acc: 0.5295


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 20 - loss : 0.3100 - acc: 0.8701 - val_loss : 1.8162 - val_acc: 0.5345


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 21 - loss : 0.2982 - acc: 0.8747 - val_loss : 1.8420 - val_acc: 0.5299


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 22 - loss : 0.2782 - acc: 0.8827 - val_loss : 1.8769 - val_acc: 0.5352


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 23 - loss : 0.2712 - acc: 0.8835 - val_loss : 1.9108 - val_acc: 0.5349


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 24 - loss : 0.2562 - acc: 0.8898 - val_loss : 1.9126 - val_acc: 0.5418


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 25 - loss : 0.2545 - acc: 0.8893 - val_loss : 1.9289 - val_acc: 0.5398


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 26 - loss : 0.2416 - acc: 0.8951 - val_loss : 1.8728 - val_acc: 0.5420


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 27 - loss : 0.2367 - acc: 0.8962 - val_loss : 1.9094 - val_acc: 0.5361


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 28 - loss : 0.2300 - acc: 0.8985 - val_loss : 2.0217 - val_acc: 0.5329


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 29 - loss : 0.2296 - acc: 0.8990 - val_loss : 1.9382 - val_acc: 0.5384


  0%|          | 0/1251 [00:00<?, ?it/s]

Epoch : 30 - loss : 0.2194 - acc: 0.9024 - val_loss : 1.9730 - val_acc: 0.5393


In [168]:
torch.save(model, "model2.pt")
print("model saved")

model saved


In [169]:
import chess.pgn
from create_dataset import chess_board_to_image
import io



model = torch.load("model2.pt")

pgn = open("Nakamura.pgn")

game = chess.pgn.read_game(pgn)
header = chess.pgn.read_headers(pgn)
print(header)
#while not game.is_end():
game = game.end()
#game = game.next()
tensor = chess_board_to_image(str(game.board()))
tensor.unsqueeze_(0)
with torch.no_grad():
    predictions = model(tensor.to(device))
    print(predictions)
    probabilities = F.softmax(predictions, dim=1) 
    np_arr = predictions.detach().cpu().numpy()
    
        

Headers(Event='Wch U10', Site='Cannes', Date='1997.??.??', Round='2', White='Nakamura, Hikaru', Black='El Mikati, Mohamad', Result='1-0', WhiteElo='', BlackElo='', ECO='C11')
tensor([[-5.5200, -2.2672, -5.5602, -2.9908,  6.6583, -4.0547, -0.7927,  0.6678,
          6.3818,  5.6270]], device='cuda:0')
