In [45]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from tqdm.notebook import trange, tqdm

In [46]:
# --- Optimisation settings ---
nepochs = 120          # number of full passes over the training data
batch_size = 64        # samples per mini-batch
learning_rate = 3e-4   # step size handed to the optimizer

# --- Model / data dimensions ---
max_len = 64                  # maximum length of an input sequence
chess_piece_type_count = 13   # six piece types per player plus one for blank
output_size = 2

In [47]:
# Load the stored boards and their labels from disk.
# NOTE: torch.load unpickles the file contents, so only load files you trust.
board_data = torch.load("boards3-1.pt")
label_data = torch.load("labels3-1.pt")

print("board length: ", len(board_data))
print("label length: ", len(label_data))

# Every board needs exactly one label. Fail here with a clear message rather
# than with an index error in the middle of an epoch.
if len(board_data) != len(label_data):
    raise ValueError(
        f"boards/labels length mismatch: {len(board_data)} != {len(label_data)}"
    )

board length:  658385
label length:  658385


In [48]:
class CustomChessDataset(Dataset):
    """Dataset pairing each stored board with a two-element soft label.

    Each item is ``(target, board)`` where ``target`` is the float tensor
    ``[label, 1 - label]`` and ``board`` is the entry at the same index.

    Args:
        boards: Indexable, sized collection of boards. When omitted, the
            notebook-level ``board_data`` is used.
        labels: Indexable, sized collection of scalar labels, parallel to
            ``boards``. When omitted, the notebook-level ``label_data`` is used.

    Raises:
        ValueError: If ``boards`` and ``labels`` differ in length.
    """

    def __init__(self, boards=None, labels=None):
        # Fall back to the notebook globals so `CustomChessDataset()` keeps working.
        self.boards = board_data if boards is None else boards
        self.labels = label_data if labels is None else labels
        if len(self.boards) != len(self.labels):
            raise ValueError(
                f"boards/labels length mismatch: {len(self.boards)} != {len(self.labels)}"
            )

    def __len__(self):
        """Return the number of (label, board) pairs."""
        return len(self.boards)

    def __getitem__(self, idx):
        """Return ``(tensor([label, 1 - label]), board)`` for position ``idx``."""
        label = self.labels[idx]
        text = self.boards[idx]
        return torch.tensor([label, 1.0 - label], dtype=torch.float), text

In [49]:
import torch.utils.data as data

full_dataset = CustomChessDataset()

# 80/20 train/validation split.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

print("Train size: ", train_size)
print("Test size: ", test_size)

# Seed the split so the same boards land in the validation set on every run;
# otherwise metrics from different kernel sessions are not comparable.
split_seed = 42
train_dataset, test_dataset = data.random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Data Loaders:
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The validation metrics are sums over all batches, so batch order is
# irrelevant and the per-epoch shuffle is skipped.
valid_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Train size:  526708
Test size:  131677


## Create LSTM Model

In [50]:
from chess_model import LSTM

## Initialise Model and Optimizer

In [51]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the size of the hidden layer and number of LSTM layers
hidden_size = 64
num_layers = 3

# Create the LSTM classifier model
lstm_classifier = LSTM(
    num_emb=chess_piece_type_count,
    output_size=output_size,
    num_layers=num_layers,
    hidden_size=hidden_size,
).to(device)

# Initialize the optimizer with Adam optimizer
optimizer = optim.Adam(lstm_classifier.parameters(), lr=learning_rate)

# Define the loss function as BCEWithLogitsLoss: it takes raw logits (the
# sigmoid is applied inside the loss), so the model output is passed to it
# directly and the targets may be soft values in [0, 1].
loss_fn = nn.BCEWithLogitsLoss()

# Initialize lists to store training and test loss, as well as accuracy
training_loss_logger = []
test_loss_logger = []
training_acc_logger = []
test_acc_logger = []

In [52]:
# Let's see how many Parameters our Model has!
num_model_params = sum(param.numel() for param in lstm_classifier.parameters())

# Report millions as a fraction: integer floor division printed "0 Million"
# for any model with fewer than one million parameters.
print("-This Model Has %d (Approximately %.2f Million) Parameters!" % (num_model_params, num_model_params / 1e6))

-This Model Has 100802 (Approximately 0 Million) Parameters!


## Training

In [53]:
def _batch_accuracy(logits, soft_labels):
    """Return the fraction of outputs whose thresholded prediction matches the label.

    The sigmoid of ``logits`` and ``soft_labels`` are both thresholded at 0.5
    and compared element-wise. The result is a 0-dim tensor.
    """
    predicted_class = torch.sigmoid(logits) > 0.5  # Apply threshold (adjust as needed)
    target_class = soft_labels > 0.5  # Extract the "stronger" class from soft label
    return (predicted_class == target_class).float().mean()


for epoch in range(nepochs):
    # Set model to training mode
    lstm_classifier.train()

    # Running totals are detached tensors: adding the raw `loss` tensor would
    # chain every batch's autograd graph onto the total for the whole epoch.
    train_loss_sum = torch.zeros((), device=device)
    train_acc_sum = torch.zeros((), device=device)
    train_samples = 0

    # Iterate through training data loader
    for label, text in tqdm(train_loader):
        label = label.to(device)
        text = text.to(device)

        output = lstm_classifier(text)
        loss = loss_fn(output, label)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a smaller final batch does not
        # skew the epoch averages.
        batch_n = label.shape[0]
        train_loss_sum += loss.detach() * batch_n
        train_acc_sum += _batch_accuracy(output.detach(), label) * batch_n
        train_samples += batch_n

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = (train_loss_sum / max(train_samples, 1)).item()
    epoch_accuracy = (train_acc_sum / max(train_samples, 1)).item()

    # Set model to evaluation mode
    lstm_classifier.eval()

    val_loss_sum = torch.zeros((), device=device)
    val_acc_sum = torch.zeros((), device=device)
    val_samples = 0

    # Iterate through test data loader
    with torch.no_grad():
        for label, text in tqdm(valid_loader):
            label = label.to(device)
            text = text.to(device)

            val_output = lstm_classifier(text)
            val_loss = loss_fn(val_output, label)

            batch_n = label.shape[0]
            val_loss_sum += val_loss * batch_n
            val_acc_sum += _batch_accuracy(val_output, label) * batch_n
            val_samples += batch_n

    epoch_val_loss = (val_loss_sum / max(val_samples, 1)).item()
    epoch_val_accuracy = (val_acc_sum / max(val_samples, 1)).item()

    # Record the per-epoch metrics in the logger lists created alongside the
    # model, so they are available after training.
    training_loss_logger.append(epoch_loss)
    training_acc_logger.append(epoch_accuracy)
    test_loss_logger.append(epoch_val_loss)
    test_acc_logger.append(epoch_val_accuracy)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
            

  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 1 - loss : 0.6553 - acc: 0.6120 - val_loss : 0.6462 - val_acc: 0.6315


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 2 - loss : 0.6457 - acc: 0.6338 - val_loss : 0.6463 - val_acc: 0.6188


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 3 - loss : 0.6443 - acc: 0.6367 - val_loss : 0.6436 - val_acc: 0.6395


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 4 - loss : 0.6433 - acc: 0.6390 - val_loss : 0.6426 - val_acc: 0.6424


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 5 - loss : 0.6426 - acc: 0.6410 - val_loss : 0.6423 - val_acc: 0.6424


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 6 - loss : 0.6419 - acc: 0.6430 - val_loss : 0.6414 - val_acc: 0.6467


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 7 - loss : 0.6413 - acc: 0.6456 - val_loss : 0.6412 - val_acc: 0.6443


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 8 - loss : 0.6406 - acc: 0.6459 - val_loss : 0.6410 - val_acc: 0.6466


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 9 - loss : 0.6401 - acc: 0.6472 - val_loss : 0.6405 - val_acc: 0.6499


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 10 - loss : 0.6394 - acc: 0.6496 - val_loss : 0.6405 - val_acc: 0.6495


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 11 - loss : 0.6388 - acc: 0.6504 - val_loss : 0.6406 - val_acc: 0.6450


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 12 - loss : 0.6382 - acc: 0.6511 - val_loss : 0.6397 - val_acc: 0.6516


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 13 - loss : 0.6377 - acc: 0.6530 - val_loss : 0.6399 - val_acc: 0.6499


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 14 - loss : 0.6369 - acc: 0.6544 - val_loss : 0.6385 - val_acc: 0.6542


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 15 - loss : 0.6364 - acc: 0.6554 - val_loss : 0.6380 - val_acc: 0.6547


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 16 - loss : 0.6357 - acc: 0.6566 - val_loss : 0.6390 - val_acc: 0.6539


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 17 - loss : 0.6352 - acc: 0.6577 - val_loss : 0.6380 - val_acc: 0.6512


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 18 - loss : 0.6346 - acc: 0.6589 - val_loss : 0.6377 - val_acc: 0.6488


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 19 - loss : 0.6340 - acc: 0.6593 - val_loss : 0.6371 - val_acc: 0.6567


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 20 - loss : 0.6334 - acc: 0.6612 - val_loss : 0.6374 - val_acc: 0.6555


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 21 - loss : 0.6328 - acc: 0.6621 - val_loss : 0.6370 - val_acc: 0.6564


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 22 - loss : 0.6324 - acc: 0.6629 - val_loss : 0.6363 - val_acc: 0.6587


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 23 - loss : 0.6318 - acc: 0.6643 - val_loss : 0.6370 - val_acc: 0.6562


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 24 - loss : 0.6314 - acc: 0.6653 - val_loss : 0.6369 - val_acc: 0.6579


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 25 - loss : 0.6308 - acc: 0.6672 - val_loss : 0.6370 - val_acc: 0.6577


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 26 - loss : 0.6304 - acc: 0.6676 - val_loss : 0.6357 - val_acc: 0.6595


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 27 - loss : 0.6297 - acc: 0.6690 - val_loss : 0.6355 - val_acc: 0.6612


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 28 - loss : 0.6294 - acc: 0.6687 - val_loss : 0.6359 - val_acc: 0.6611


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 29 - loss : 0.6289 - acc: 0.6700 - val_loss : 0.6367 - val_acc: 0.6606


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 30 - loss : 0.6285 - acc: 0.6710 - val_loss : 0.6366 - val_acc: 0.6592


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 31 - loss : 0.6280 - acc: 0.6712 - val_loss : 0.6359 - val_acc: 0.6613


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 32 - loss : 0.6274 - acc: 0.6726 - val_loss : 0.6358 - val_acc: 0.6599


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 33 - loss : 0.6271 - acc: 0.6734 - val_loss : 0.6358 - val_acc: 0.6626


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 34 - loss : 0.6267 - acc: 0.6736 - val_loss : 0.6355 - val_acc: 0.6617


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 35 - loss : 0.6262 - acc: 0.6747 - val_loss : 0.6359 - val_acc: 0.6623


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 36 - loss : 0.6260 - acc: 0.6756 - val_loss : 0.6351 - val_acc: 0.6591


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 37 - loss : 0.6255 - acc: 0.6755 - val_loss : 0.6352 - val_acc: 0.6610


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 38 - loss : 0.6251 - acc: 0.6763 - val_loss : 0.6346 - val_acc: 0.6624


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 39 - loss : 0.6249 - acc: 0.6761 - val_loss : 0.6356 - val_acc: 0.6622


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 40 - loss : 0.6245 - acc: 0.6773 - val_loss : 0.6349 - val_acc: 0.6630


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 41 - loss : 0.6241 - acc: 0.6779 - val_loss : 0.6359 - val_acc: 0.6626


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 42 - loss : 0.6238 - acc: 0.6789 - val_loss : 0.6349 - val_acc: 0.6629


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 43 - loss : 0.6236 - acc: 0.6793 - val_loss : 0.6356 - val_acc: 0.6604


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 44 - loss : 0.6232 - acc: 0.6794 - val_loss : 0.6358 - val_acc: 0.6610


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 45 - loss : 0.6229 - acc: 0.6802 - val_loss : 0.6353 - val_acc: 0.6629


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 46 - loss : 0.6226 - acc: 0.6807 - val_loss : 0.6351 - val_acc: 0.6641


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 47 - loss : 0.6225 - acc: 0.6805 - val_loss : 0.6357 - val_acc: 0.6604


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 48 - loss : 0.6221 - acc: 0.6812 - val_loss : 0.6351 - val_acc: 0.6635


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 49 - loss : 0.6217 - acc: 0.6821 - val_loss : 0.6348 - val_acc: 0.6633


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 50 - loss : 0.6214 - acc: 0.6818 - val_loss : 0.6352 - val_acc: 0.6617


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 51 - loss : 0.6213 - acc: 0.6819 - val_loss : 0.6349 - val_acc: 0.6635


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 52 - loss : 0.6211 - acc: 0.6824 - val_loss : 0.6349 - val_acc: 0.6642


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 53 - loss : 0.6207 - acc: 0.6828 - val_loss : 0.6354 - val_acc: 0.6633


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 54 - loss : 0.6204 - acc: 0.6833 - val_loss : 0.6345 - val_acc: 0.6643


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 55 - loss : 0.6204 - acc: 0.6840 - val_loss : 0.6346 - val_acc: 0.6644


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 56 - loss : 0.6201 - acc: 0.6839 - val_loss : 0.6347 - val_acc: 0.6654


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 57 - loss : 0.6199 - acc: 0.6853 - val_loss : 0.6348 - val_acc: 0.6648


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 58 - loss : 0.6197 - acc: 0.6848 - val_loss : 0.6353 - val_acc: 0.6641


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 59 - loss : 0.6196 - acc: 0.6856 - val_loss : 0.6354 - val_acc: 0.6639


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 60 - loss : 0.6191 - acc: 0.6853 - val_loss : 0.6348 - val_acc: 0.6641


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 61 - loss : 0.6192 - acc: 0.6855 - val_loss : 0.6346 - val_acc: 0.6622


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 62 - loss : 0.6189 - acc: 0.6861 - val_loss : 0.6349 - val_acc: 0.6648


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 63 - loss : 0.6187 - acc: 0.6864 - val_loss : 0.6348 - val_acc: 0.6662


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 64 - loss : 0.6182 - acc: 0.6867 - val_loss : 0.6354 - val_acc: 0.6623


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 65 - loss : 0.6182 - acc: 0.6868 - val_loss : 0.6356 - val_acc: 0.6641


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 66 - loss : 0.6182 - acc: 0.6875 - val_loss : 0.6342 - val_acc: 0.6642


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 67 - loss : 0.6179 - acc: 0.6877 - val_loss : 0.6352 - val_acc: 0.6650


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 68 - loss : 0.6178 - acc: 0.6875 - val_loss : 0.6361 - val_acc: 0.6655


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 69 - loss : 0.6175 - acc: 0.6877 - val_loss : 0.6350 - val_acc: 0.6643


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 70 - loss : 0.6174 - acc: 0.6881 - val_loss : 0.6346 - val_acc: 0.6652


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 71 - loss : 0.6173 - acc: 0.6874 - val_loss : 0.6350 - val_acc: 0.6658


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 72 - loss : 0.6170 - acc: 0.6892 - val_loss : 0.6349 - val_acc: 0.6653


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 73 - loss : 0.6168 - acc: 0.6887 - val_loss : 0.6348 - val_acc: 0.6652


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 74 - loss : 0.6168 - acc: 0.6892 - val_loss : 0.6344 - val_acc: 0.6656


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 75 - loss : 0.6165 - acc: 0.6895 - val_loss : 0.6356 - val_acc: 0.6649


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 76 - loss : 0.6164 - acc: 0.6901 - val_loss : 0.6350 - val_acc: 0.6591


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 77 - loss : 0.6163 - acc: 0.6899 - val_loss : 0.6361 - val_acc: 0.6657


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 78 - loss : 0.6161 - acc: 0.6901 - val_loss : 0.6349 - val_acc: 0.6669


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 79 - loss : 0.6160 - acc: 0.6901 - val_loss : 0.6343 - val_acc: 0.6670


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 80 - loss : 0.6158 - acc: 0.6903 - val_loss : 0.6348 - val_acc: 0.6622


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 81 - loss : 0.6157 - acc: 0.6907 - val_loss : 0.6347 - val_acc: 0.6624


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 82 - loss : 0.6155 - acc: 0.6903 - val_loss : 0.6347 - val_acc: 0.6647


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 83 - loss : 0.6156 - acc: 0.6901 - val_loss : 0.6358 - val_acc: 0.6651


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 84 - loss : 0.6153 - acc: 0.6916 - val_loss : 0.6363 - val_acc: 0.6650


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 85 - loss : 0.6151 - acc: 0.6913 - val_loss : 0.6351 - val_acc: 0.6648


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 86 - loss : 0.6148 - acc: 0.6916 - val_loss : 0.6344 - val_acc: 0.6642


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 87 - loss : 0.6147 - acc: 0.6921 - val_loss : 0.6355 - val_acc: 0.6657


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 88 - loss : 0.6147 - acc: 0.6919 - val_loss : 0.6363 - val_acc: 0.6662


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 89 - loss : 0.6146 - acc: 0.6921 - val_loss : 0.6355 - val_acc: 0.6649


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 90 - loss : 0.6146 - acc: 0.6920 - val_loss : 0.6353 - val_acc: 0.6662


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 91 - loss : 0.6143 - acc: 0.6928 - val_loss : 0.6351 - val_acc: 0.6660


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 92 - loss : 0.6141 - acc: 0.6923 - val_loss : 0.6349 - val_acc: 0.6652


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 93 - loss : 0.6140 - acc: 0.6927 - val_loss : 0.6357 - val_acc: 0.6646


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 94 - loss : 0.6141 - acc: 0.6926 - val_loss : 0.6354 - val_acc: 0.6647


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 95 - loss : 0.6136 - acc: 0.6933 - val_loss : 0.6359 - val_acc: 0.6637


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 96 - loss : 0.6138 - acc: 0.6933 - val_loss : 0.6351 - val_acc: 0.6651


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 97 - loss : 0.6136 - acc: 0.6939 - val_loss : 0.6356 - val_acc: 0.6634


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 98 - loss : 0.6132 - acc: 0.6940 - val_loss : 0.6346 - val_acc: 0.6643


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 99 - loss : 0.6133 - acc: 0.6937 - val_loss : 0.6353 - val_acc: 0.6649


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 100 - loss : 0.6130 - acc: 0.6941 - val_loss : 0.6357 - val_acc: 0.6657


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 101 - loss : 0.6130 - acc: 0.6944 - val_loss : 0.6356 - val_acc: 0.6591


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 102 - loss : 0.6131 - acc: 0.6946 - val_loss : 0.6352 - val_acc: 0.6649


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 103 - loss : 0.6129 - acc: 0.6941 - val_loss : 0.6363 - val_acc: 0.6651


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 104 - loss : 0.6126 - acc: 0.6940 - val_loss : 0.6367 - val_acc: 0.6645


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 105 - loss : 0.6127 - acc: 0.6945 - val_loss : 0.6358 - val_acc: 0.6649


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 106 - loss : 0.6126 - acc: 0.6940 - val_loss : 0.6352 - val_acc: 0.6603


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 107 - loss : 0.6124 - acc: 0.6955 - val_loss : 0.6359 - val_acc: 0.6630


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 108 - loss : 0.6123 - acc: 0.6956 - val_loss : 0.6360 - val_acc: 0.6655


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 109 - loss : 0.6122 - acc: 0.6951 - val_loss : 0.6355 - val_acc: 0.6639


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 110 - loss : 0.6120 - acc: 0.6959 - val_loss : 0.6350 - val_acc: 0.6656


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 111 - loss : 0.6120 - acc: 0.6957 - val_loss : 0.6348 - val_acc: 0.6660


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 112 - loss : 0.6119 - acc: 0.6957 - val_loss : 0.6359 - val_acc: 0.6662


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 113 - loss : 0.6119 - acc: 0.6956 - val_loss : 0.6362 - val_acc: 0.6642


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 114 - loss : 0.6118 - acc: 0.6965 - val_loss : 0.6353 - val_acc: 0.6602


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 115 - loss : 0.6117 - acc: 0.6969 - val_loss : 0.6363 - val_acc: 0.6646


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 116 - loss : 0.6116 - acc: 0.6973 - val_loss : 0.6362 - val_acc: 0.6660


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 117 - loss : 0.6116 - acc: 0.6969 - val_loss : 0.6358 - val_acc: 0.6645


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 118 - loss : 0.6112 - acc: 0.6965 - val_loss : 0.6355 - val_acc: 0.6626


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 119 - loss : 0.6112 - acc: 0.6972 - val_loss : 0.6351 - val_acc: 0.6659


  0%|          | 0/8230 [00:00<?, ?it/s]

  0%|          | 0/2058 [00:00<?, ?it/s]

Epoch : 120 - loss : 0.6113 - acc: 0.6971 - val_loss : 0.6352 - val_acc: 0.6672


In [54]:
# Persist the model by serialising the whole `lstm_classifier` object (not
# just its state_dict) to 'model3-1.pt'; the following cell reads this file
# back with torch.load and uses the result directly as a model.
# NOTE(review): switching to lstm_classifier.state_dict() would need a matching
# change in the loading cell, so the two must be updated together.
torch.save(lstm_classifier, 'model3-1.pt')

In [55]:
import chess.pgn
from create_dataset import chess_board_to_text
import io


# NOTE(review): torch.load unpickles a full module object here, so only load
# checkpoints you trust. On PyTorch >= 2.6 torch.load defaults to
# weights_only=True and refuses full-module pickles; pass weights_only=False there.
# map_location lets a checkpoint saved on the GPU load on whatever `device` is.
model = torch.load("model3-1.pt", map_location=device)
# Make sure inference-time behaviour is used regardless of how the model was saved.
model.eval()

# The context manager closes the PGN file once the game has been parsed.
with open("Nakamura.pgn") as pgn:
    game = chess.pgn.read_game(pgn)

# read_game returns None when the file holds no (further) game.
if game is None:
    raise ValueError("Nakamura.pgn does not contain a game")

# Take the headers from the game that was just parsed. Calling
# chess.pgn.read_headers on the handle after read_game would read the headers
# of the *next* game in the file instead.
header = game.headers
print(header)

# Walk the main line move by move and score every resulting position.
while not game.is_end():
    game = game.next()
    # Add a leading batch dimension of size 1 for the model.
    tensor = chess_board_to_text(str(game.board())).unsqueeze(0)
    with torch.no_grad():
        predictions = model(tensor.to(device))
        probabilities = torch.sigmoid(predictions)
        np_arr = probabilities.detach().cpu().numpy()
        print(predictions, np_arr)

Headers(Event='Wch U10', Site='Cannes', Date='1997.??.??', Round='2', White='Nakamura, Hikaru', Black='El Mikati, Mohamad', Result='1-0', WhiteElo='', BlackElo='', ECO='C11')
tensor([[-0.0146,  0.0146]], device='cuda:0') [[0.4963506 0.5036495]]
tensor([[-0.0114,  0.0114]], device='cuda:0') [[0.49714217 0.5028578 ]]
tensor([[-0.0166,  0.0166]], device='cuda:0') [[0.49585295 0.50414705]]
tensor([[-0.0271,  0.0271]], device='cuda:0') [[0.49323434 0.5067656 ]]
tensor([[-0.0763,  0.0763]], device='cuda:0') [[0.48092866 0.51907134]]
tensor([[-0.0484,  0.0484]], device='cuda:0') [[0.48790276 0.5120972 ]]
tensor([[-0.0428,  0.0428]], device='cuda:0') [[0.4892895 0.5107105]]
tensor([[-0.0374,  0.0374]], device='cuda:0') [[0.49065402 0.509346  ]]
tensor([[-0.0267,  0.0267]], device='cuda:0') [[0.49333423 0.5066657 ]]
tensor([[-0.0072,  0.0072]], device='cuda:0') [[0.4982078  0.50179213]]
tensor([[-0.0052,  0.0052]], device='cuda:0') [[0.4986938  0.50130624]]
tensor([[-0.0014,  0.0014]], device='