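Tests for ChessNet: evaluates the saved `cnn_final.pth` checkpoint on the train, validation, and test splits, reporting "from" accuracy, per-piece accuracy, and overall piece accuracy.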

In [1]:
import numpy as np
import re
import pandas as pd
import gc
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from config import Config #TODO: update to device to use config object
from data import ChessDataset
from complex_cnn.complex_model import ChessNet
from preprocess import preprocess_chess_data
from test import *

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Checkpoint to evaluate. The constructor arguments must match the ones used in
# training, otherwise load_state_dict (strict by default) raises on a mismatch.
model_name = 'cnn_final.pth'
# TODO: machine-specific absolute path; move it into Config or make it relative.
PATH_name = f"/home/tamiroffen/AI_Project/project/saved_models/{model_name}"

saved_model = ChessNet(hidden_layers=4, hidden_size=200)
# map_location remaps stored tensors onto `device`, so a checkpoint written on a
# GPU still loads on a CPU-only machine. weights_only=True restricts unpickling
# to tensors and plain containers, which is all a state_dict needs.
state_dict = torch.load(PATH_name, map_location=device, weights_only=True)
saved_model.load_state_dict(state_dict)
saved_model.to(device)
# eval() puts the BatchNorm layers into inference mode (running statistics).
# As the cell's last expression it also displays the model summary.
saved_model.eval()

cuda


ChessNet(
  (input_layer): Conv2d(6, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (module_list): ModuleList(
    (0-3): 4 x module(
      (conv1): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation1): SELU()
      (activation2): SELU()
    )
  )
  (output_layer): Conv2d(200, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [13]:
config = Config()


def load_split(split_name):
    """Preprocess ``<split_name>_dataset.csv`` from ``config.chess_dataset_dir``.

    Returns a ``(preprocessed_data, dataset)`` pair, where ``dataset`` is a
    ChessDataset built from ``preprocessed_data['AN']``.
    """
    chess_data = preprocess_chess_data(f'{config.chess_dataset_dir}/{split_name}_dataset.csv')
    return chess_data, ChessDataset(chess_data['AN'])


def make_eval_loader(dataset):
    """Wrap ``dataset`` in an unshuffled DataLoader using ``config.batch_size``."""
    # NOTE(review): drop_last=True discards the final partial batch, so up to
    # batch_size - 1 samples per split are never scored. Kept as-is in case
    # test_piece_accuracy assumes full batches; confirm before switching to False.
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=False, drop_last=True)


train_chess_data, data_train = load_split('train')
valid_chess_data, data_valid = load_split('valid')
test_chess_data, data_test = load_split('test')
print("Loaded datasets")

data_train_loader = make_eval_loader(data_train)
data_valid_loader = make_eval_loader(data_valid)
data_test_loader = make_eval_loader(data_test)
print("Loaders finished processing")



Loaded datasets
Loaders finished processing


In [14]:
print("Train Set:")
train_metrics = test_piece_accuracy(saved_model, data_train_loader)
# Unpacked as: "from" accuracy, per-piece accuracies (labelled [p,r,n,b,q,k]
# in the printout), and overall piece accuracy.
from_accuracy_train, piece_accuracies_train, overall_piece_accuracy_train = train_metrics
print(f'from accuracy: {from_accuracy_train}')
print(f'piece accuracy: [p,r,n,b,q,k]: {piece_accuracies_train}')
print(f'overall piece accuracy: {overall_piece_accuracy_train}')

Train Set:
Testing piece accuracy of model, over 1 num of epochs
from accuracy: 0.39163910961946596
piece accuracy: [p,r,n,b,q,k]: [0.8822695035460993, 0.37555555555555553, 0.6209463051568315, 0.5227736233854521, 0.2888198757763975, 0.62203519510329]
overall piece accuracy: 0.5810964083175804


In [5]:
print("Valid Set:")
valid_metrics = test_piece_accuracy(saved_model, data_valid_loader)
# Unpacked as: "from" accuracy, per-piece accuracies (labelled [p,r,n,b,q,k]
# in the printout), and overall piece accuracy.
from_accuracy_valid, piece_accuracies_valid, overall_piece_accuracy_valid = valid_metrics
print(f'from accuracy: {from_accuracy_valid}')
print(f'piece accuracy: [p,r,n,b,q,k]: {piece_accuracies_valid}')
print(f'overall piece accuracy: {overall_piece_accuracy_valid}')

Valid Set:
Testing piece accuracy of model, over 1 num of epochs
from accuracy: 0.39210487553395196
piece accuracy: [p,r,n,b,q,k]: [0.8622881355932204, 0.35508849557522126, 0.5996940336562978, 0.5239085239085239, 0.3013013013013013, 0.6484560570071259]
overall piece accuracy: 0.5752566992236414


In [6]:
print("Test Set:")
test_metrics = test_piece_accuracy(saved_model, data_test_loader)
# Unpacked as: "from" accuracy, per-piece accuracies (labelled [p,r,n,b,q,k]
# in the printout), and overall piece accuracy.
from_accuracy_test, piece_accuracies_test, overall_piece_accuracy_test = test_metrics
print(f'from accuracy: {from_accuracy_test}')
print(f'piece accuracy: [p,r,n,b,q,k]: {piece_accuracies_test}')
print(f'overall piece accuracy: {overall_piece_accuracy_test}')

Test Set:
Testing piece accuracy of model, over 1 num of epochs
from accuracy: 0.38743196437407224
piece accuracy: [p,r,n,b,q,k]: [0.8544698544698545, 0.3425814234016888, 0.6037424325811778, 0.549792531120332, 0.3007518796992481, 0.6539398862713242]
overall piece accuracy: 0.5790549169859515
