In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 
import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler
import chess
import chess.pgn
import math
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
import random
from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sn
import numpy as np
import io
import matplotlib.pyplot as plt
from torch import optim
from torch.utils.data import WeightedRandomSampler

# Display names of the four target classes; the list is indexed by the
# integer class label (used for TensorBoard tags and confusion-matrix axes).
classes = [
    'AI GAME',
    'AI WHITE',
    'AI BLACK',
    'HUMAN GAME',
]

In [None]:
def fen_sequence_to_tensor(fen_list):
    '''
        Input: String list. A list of fen codes that are in ordered sequence.
               Only the first 20 entries are read; a shorter list raises IndexError.
        Output: 4D tensor to input into Neural Net. Dimensions of tensor are (6,20,8,8),
                i.e. (piece type, position in sequence, board row, board column).

        Channel i holds the squares occupied by python-chess piece_type i+1
        (1=pawn, 2=knight, 3=bishop, 4=rook, 5=queen, 6=king). White and black
        pieces of the same type are written into the same plane, so the colour
        of a piece is not encoded.
    '''
    t = torch.zeros(6,20,8,8)
    for j in range(20):
        # Parse every FEN exactly once; the board does not depend on the
        # piece type, so building it inside the piece loop repeats the work 6x.
        board = chess.Board(fen=fen_list[j])
        for i in range(6):
            for color_bool in (True, False):
                for inx in board.pieces(piece_type=i+1, color=color_bool):
                    # square index -> (rank, file); a1 is square 0
                    x, y = divmod(inx, 8)
                    #set the tensor, we flip x axis to make it easier to look at
                    t[i, j, 7 - x, y] = 1
    return t


In [None]:
def randomRows2Cols(t):
    """Randomly mirror the tensor along dimensions 2 and 3.

    A number is drawn uniformly from 1..100. For draws above 95 (a 5% chance)
    a copy of `t` reversed along dims 2 and 3 is returned; for every other
    draw `t` itself is returned unchanged.
    """
    roll = random.randint(1, 100)
    return torch.flip(t, [2, 3]) if roll > 95 else t

def augmentation_6x6(t):
    """Blank the outer ring of every board, leaving only the inner 6x6 area.

    Every cell equal to 1 whose row index or column index (last two
    dimensions) is 0 or 7 is set to 0. Cells holding any other value are left
    untouched.

    The tensor is modified IN PLACE and the same object is returned, so the
    caller's tensor changes too.

    Implemented as a single masked fill instead of a Python loop over every
    board and every occupied cell.
    """
    rows = torch.arange(t.shape[-2], device=t.device)
    cols = torch.arange(t.shape[-1], device=t.device)
    row_edge = (rows == 0) | (rows == 7)
    col_edge = (cols == 0) | (cols == 7)
    # (H, 1) | (1, W) -> (H, W) mask of border cells, broadcast over leading dims
    border = row_edge.unsqueeze(1) | col_edge.unsqueeze(0)
    t.masked_fill_((t == 1) & border, 0)
    return t

def augmentation_6x8(t):
    """Blank the first and last row of every board (6 rows x 8 columns remain).

    Every cell equal to 1 whose row index (second-to-last dimension) is 0 or 7
    is set to 0; columns are not restricted. Cells holding any other value are
    left untouched.

    The tensor is modified IN PLACE and the same object is returned, so the
    caller's tensor changes too.

    Implemented as a single masked fill instead of a Python loop over every
    board and every occupied cell.
    """
    rows = torch.arange(t.shape[-2], device=t.device)
    # shape (H, 1) so the row mask broadcasts across all columns
    row_edge = ((rows == 0) | (rows == 7)).unsqueeze(1)
    t.masked_fill_((t == 1) & row_edge, 0)
    return t

def randomCrop(t):
    """Apply one of the two crop augmentations with a combined 10% probability.

    A number is drawn uniformly from 1..100:
      * 1-5   -> augmentation_6x6(t)
      * 6-10  -> augmentation_6x8(t)
      * 11-100 -> t is returned as is
    """
    roll = random.randint(1, 100)
    if roll > 10:
        return t
    if roll > 5:
        return augmentation_6x8(t)
    return augmentation_6x6(t)

In [None]:
# Augmentation pipeline: the random border crop runs first, followed by the
# random flip of the last two tensor dimensions.
transform = transforms.Compose([
    transforms.Lambda(randomCrop),
    transforms.Lambda(randomRows2Cols),
])

In [None]:
class ChessSeq(Dataset):
    """Dataset of FEN-sequence text files grouped into four class folders.

    `path` must contain the sub-directories listed in `_CLASS_DIRS`; every
    ``*.txt`` file inside them is one sample whose lines are FEN strings.

    Attributes:
        items: list of ``(filename, label)`` tuples, label being the class
            index as a string ("0".."3").
        length: number of samples.
        transform: the module-level `transform` pipeline, applied to each
            tensor returned by `__getitem__`.
    """

    # (sub-directory, label) pairs; the label is the index into `classes`.
    _CLASS_DIRS = (("AvA", "0"), ("A_W", "1"), ("A_B", "2"), ("HvH", "3"))

    def __init__(self, path):
        """Collect the sample files below `path` (a `str` or `Path`)."""
        root = Path(path)
        # Keep only every n-th file of each class; 1 keeps everything.
        divide_num_files_to_read = 1
        items = []
        for subdir, label in self._CLASS_DIRS:
            # Sorted so the index -> file mapping does not depend on the
            # order in which the filesystem happens to list the directory.
            files = sorted((root / subdir).glob('*.txt'))
            items += [(str(f), label) for f in files][::divide_num_files_to_read]
        self.items = items
        self.length = len(self.items)
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for the sample at `index`.

        The file is read into a list of FEN lines (blank lines are dropped and
        surrounding whitespace removed), converted with
        `fen_sequence_to_tensor` and passed through `self.transform`.
        """
        filename, label = self.items[index]
        # `with` guarantees the handle is closed even if conversion fails.
        with open(filename) as f:
            fen_array = [line.strip() for line in f if line.strip()]
        seqTensor = fen_sequence_to_tensor(fen_array)
        seqTensor = self.transform(seqTensor)
        return (seqTensor, int(label))

    def __len__(self):
        """Number of samples found at construction time."""
        return self.length

In [None]:
# Batch size shared by the train / validation / test loaders below.
bs = 128
# Data root: the datasets are read from ./train, ./valid and ./test relative
# to the current working directory.
path = Path.cwd()
# TensorBoard writer used by the training and evaluation cells.
writer = SummaryWriter("runs/final_2")

# Inverse of four hard-coded counts -- presumably the number of training
# samples per class, in label order (TODO confirm).
sample_weights = [1/x for x in [303168,33952,36853,132609]]

# NOTE(review): these samplers are only referenced by the disabled loaders in
# the string literal below. WeightedRandomSampler appears to expect one weight
# per dataset sample, whereas `sample_weights` has four entries (one per
# class); as written the samplers would only ever draw indices 0-3. Expand
# the weights per sample before re-enabling them -- verify against the
# PyTorch documentation.
sampler_train = WeightedRandomSampler(weights=sample_weights, num_samples = 50000,replacement=True)
sampler_valid_test = WeightedRandomSampler(weights=sample_weights, num_samples = 50000,replacement=True)

# Scanning the class folders happens here, in each ChessSeq constructor.
train_chessseq = ChessSeq(path / "train")
valid_chessseq = ChessSeq(path / "valid")
test_chessseq = ChessSeq(path / "test")

# Disabled alternative: sampler-based loaders kept as an unused string literal.
'''
train_loader = torch.utils.data.DataLoader(train_chessseq, batch_size = bs, sampler= sampler_train)
valid_loader = torch.utils.data.DataLoader(valid_chessseq, batch_size = bs, sampler=sampler_valid_test)
test_loader  = torch.utils.data.DataLoader(test_chessseq, batch_size = bs, sampler=sampler_valid_test)
'''


# Active loaders: plain shuffling over the full datasets, no class re-weighting.
train_loader = torch.utils.data.DataLoader(train_chessseq, batch_size = bs, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_chessseq, batch_size = bs, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_chessseq, batch_size = bs, shuffle=True)

In [None]:
# Sanity check: pull a single batch from the training loader and show its
# labels and the shape of the stacked input tensor.
data_batch, labels_batch = next(iter(train_loader))
print(labels_batch)
print(data_batch.shape)

In [None]:
# TESTING: convert a hand-written 20-position FEN sequence and display one
# 8x8 plane of the result as a quick visual check of the encoding.

fen_sequence_to_tensor(['r2q1rk1/pb1nbppp/2p2n2/1p2p1B1/4P3/1BN2N2/PPP1QPPP/R4RK1 w - - 2 11',
'r2q1rk1/pb1nbppp/2p2n2/1p2p1B1/4P3/1BN2N2/PPP1QPPP/3R1RK1 b - - 3 11',
'r4rk1/pbqnbppp/2p2n2/1p2p1B1/4P3/1BN2N2/PPP1QPPP/3R1RK1 w - - 4 12',
'r4rk1/pbqnbppp/2p2n2/1p2p1B1/4P3/1BN2N2/PPPRQPPP/5RK1 b - - 5 12',
'r4rk1/1bqnbppp/2p2n2/pp2p1B1/4P3/1BN2N2/PPPRQPPP/5RK1 w - - 0 13',
'r4rk1/1bqnbppp/2p2n2/pp2p1B1/4P3/PBN2N2/1PPRQPPP/5RK1 b - - 0 13',
'r4rk1/1bqnbppp/2p2n2/p3p1B1/1p2P3/PBN2N2/1PPRQPPP/5RK1 w - - 0 14',
'r4rk1/1bqnbppp/2p2n2/p3p1B1/1P2P3/1BN2N2/1PPRQPPP/5RK1 b - - 0 14',
'r4rk1/1bqnbppp/2p2n2/4p1B1/1p2P3/1BN2N2/1PPRQPPP/5RK1 w - - 0 15',
'r4rk1/1bqnbppp/2p2n2/4p1B1/1p2P3/1B3N2/NPPRQPPP/5RK1 b - - 1 15',
'r4rk1/1bqnbppp/5n2/2p1p1B1/1p2P3/1B3N2/NPPRQPPP/5RK1 w - - 0 16',
'r4rk1/1bqnbppp/5n2/2p1p1B1/1p2P3/1B3N2/1PPRQPPP/2N2RK1 b - - 1 16',
'r4rk1/2qnbppp/b4n2/2p1p1B1/1p2P3/1B3N2/1PPRQPPP/2N2RK1 w - - 2 17',
'r4rk1/2qnbppp/b4n2/2p1p1B1/1pB1P3/5N2/1PPRQPPP/2N2RK1 b - - 3 17',
'r4rk1/2qnbppp/5n2/2p1p1B1/1pb1P3/5N2/1PPRQPPP/2N2RK1 w - - 0 18',
'r4rk1/2qnbppp/5n2/2p1p1B1/1pQ1P3/5N2/1PPR1PPP/2N2RK1 b - - 0 18',
'r4rk1/2q1bppp/1n3n2/2p1p1B1/1pQ1P3/5N2/1PPR1PPP/2N2RK1 w - - 1 19',
'r4rk1/2q1bppp/1n3n2/2p1p1B1/1p2P3/5N2/1PPRQPPP/2N2RK1 b - - 2 19',
'r4rk1/2q1bppp/5n2/2p1p1B1/np2P3/5N2/1PPRQPPP/2N2RK1 w - - 3 20',
'r4rk1/2q1bppp/5n2/2p1p1B1/np2P3/1P3N2/2PRQPPP/2N2RK1 b - - 0 20',
])[0][0] # channel 0 is piece_type 1 (pawns in python-chess; kings are channel 5), first position of the sequence



In [None]:
def add_pr_curve(class_inx, probs, label, global_step=0):
    """Write a precision-recall curve for a single class to TensorBoard.

    Args:
        class_inx: index of the class treated as "positive"; also selects the
            tag from the module-level `classes` list.
        probs: 2-D tensor indexed as ``probs[sample, class]``.
        label: tensor of integer class labels, one per sample.
        global_step: step value recorded with the curve.

    The curve is sent to the module-level `writer`.
    """
    writer.add_pr_curve(
        tag=classes[class_inx],
        labels=(label == class_inx),
        predictions=probs[:, class_inx],
        global_step=global_step,
    )

def calculate_precision_recall_f1(probs, label, positive_class=0):
    """Compute one-vs-rest precision, recall and F1 for one class.

    Args:
        probs: 2-D tensor indexed as ``probs[sample, class]``; the predicted
            class of a sample is the argmax of its row.
        label: tensor of integer class labels, one per sample.
        positive_class: class treated as "positive" (default 0); every other
            class counts as "negative".

    Returns:
        dict with keys ``'pr/precision'``, ``'pr/recall'`` and ``'pr/f1'``
        (Python floats). A metric whose denominator is zero is reported as 0.0.

    Side effect: prints the four counts ``tp fp fn tn``.

    Notes on the fixes relative to the earlier version:
      * the argmax is taken on the raw probabilities; rounding them first
        turned every row without a value above 0.5 into a class-0 prediction;
      * a false positive is "predicted positive, actually negative" and a
        false negative the opposite -- the two had been exchanged, which
        swapped precision and recall;
      * samples of every other class are now counted as negatives instead of
        only samples labelled 1.
    """
    label = label.cpu().detach().numpy()
    probs = probs.cpu().detach().numpy()

    choices = probs.argmax(axis=1)

    predicted_pos = choices == positive_class
    actual_pos = label == positive_class

    tp_count = int((predicted_pos & actual_pos).sum())
    tn_count = int((~predicted_pos & ~actual_pos).sum())
    fp_count = int((predicted_pos & ~actual_pos).sum())
    fn_count = int((~predicted_pos & actual_pos).sum())

    print(tp_count, fp_count, fn_count, tn_count)

    predicted_total = tp_count + fp_count
    actual_total = tp_count + fn_count
    precision = tp_count / predicted_total if predicted_total else 0.0
    recall = tp_count / actual_total if actual_total else 0.0
    denominator = precision + recall
    f1 = 2 * (precision * recall) / denominator if denominator else 0.0

    return {'pr/precision': precision, 'pr/recall': recall, 'pr/f1': f1}

def make_confusion_matrix(probs, label):
    """Plot the confusion matrix of the predictions and save it as a PNG.

    Args:
        probs: 2-D tensor indexed as ``probs[sample, class]``; the predicted
            class of a sample is the argmax of its row.
        label: tensor of integer class labels, one per sample.

    Side effect: writes ``conf_matr.png`` to the current working directory.
    The figure is left open so a notebook can still display it inline.
    """
    label = label.cpu().detach().numpy()
    # argmax on the raw probabilities: rounding first maps every row without
    # a value above 0.5 to class 0.
    probs = probs.cpu().detach().numpy()

    choices = probs.argmax(axis=1)

    # Passing `labels` keeps the matrix len(classes) x len(classes) even when
    # a class is absent from both the labels and the predictions; otherwise
    # the DataFrame below would not match the class names.
    cf_matrix = confusion_matrix(label, choices, labels=list(range(len(classes))))
    df_cm = pd.DataFrame(cf_matrix, index=list(classes), columns=list(classes))

    fig, ax = plt.subplots(figsize=(6, 6))
    # fmt='d' prints the integer counts in full instead of the default
    # two-significant-digit format.
    sn.heatmap(df_cm, annot=True, fmt='d', ax=ax)
    # scikit-learn puts the true labels on the rows, predictions on the columns
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

    fig.savefig('conf_matr.png')

In [None]:
from torch import nn
import torch.nn.functional as F


class ChessNet(nn.Module):
    """3-D convolutional classifier for encoded chess position sequences.

    The input is batch-normalised, passed through three `ChessBlock`s that
    widen the channels to `conv_channels`, `2 * conv_channels` and
    `4 * conv_channels`, flattened, and classified by a two-layer MLP.

    Args:
        num_classes: number of output logits.
        conv_channels: channel count of the first block.

    The size of the first linear layer assumes an input of shape
    (batch, 6, 20, 8, 8): with ChessBlock's stride of 2 along the sequence
    axis the 20 positions collapse to depth 1 after three blocks while the
    8x8 board size is preserved, giving ``4 * conv_channels * 8 * 8``
    features (2048 for the default of 8).
    """

    def __init__(self, num_classes=4, conv_channels=8):
        super().__init__()

        self.b1 = ChessBlock(6, conv_channels)
        self.b2 = ChessBlock(conv_channels, conv_channels * 2)
        self.b3 = ChessBlock(conv_channels * 2, conv_channels * 4)
        self.batchnorm = nn.BatchNorm3d(6)
        # Derived from conv_channels instead of a hard-coded 2048 so that a
        # non-default width does not fail with a shape mismatch in forward().
        flat_features = conv_channels * 4 * 8 * 8
        self.linear = nn.Sequential(
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes))

    def forward(self, x):
        """Return the class logits, shape (batch, num_classes)."""
        batch_out = self.batchnorm(x)
        b = self.b1(batch_out)
        b = self.b2(b)
        b = self.b3(b)

        # keep the batch dimension, merge everything else
        b = torch.flatten(b, 1)
        b = self.linear(b)
        return b

class ChessBlock(nn.Module):
    """Two 3-D convolutions, each followed by ReLU, then a 3-D max-pool.

    All three layers use a 5x3x3 window, stride (2, 1, 1) and padding
    (2, 1, 1): the first spatial axis of the input is roughly halved by each
    layer while the last two axes keep their size.

    Args:
        in_channels: channels of the incoming tensor.
        conv_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, conv_channels):
        super().__init__()

        # Window geometry shared by the convolutions and the pooling layer.
        geometry = dict(kernel_size=(5, 3, 3), stride=(2, 1, 1), padding=(2, 1, 1))

        self.conv1 = nn.Conv3d(in_channels, conv_channels, bias=True, **geometry)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(conv_channels, conv_channels, bias=True, **geometry)
        self.relu2 = nn.ReLU(inplace=True)

        self.pool = nn.MaxPool3d(**geometry)

    def forward(self, input_batch):
        """Run conv -> ReLU -> conv -> ReLU -> max-pool on `input_batch`."""
        hidden = self.relu1(self.conv1(input_batch))
        hidden = self.relu2(self.conv2(hidden))
        return self.pool(hidden)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU, then
# build the network and move its parameters to that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn = ChessNet()
cnn.to(device)

In [None]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=15, device="cpu"):
    """Train `model` and validate it after every epoch.

    Args:
        model: network to optimise; it is left in eval mode on return.
        optimizer: optimizer already bound to `model`'s parameters.
        loss_fn: callable ``loss_fn(output, targets)`` returning a scalar loss.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        epochs: number of passes over `train_loader`.
        device: device the batches are moved to before the forward pass.

    Per epoch the mean loss and the accuracy (fraction correct, 0..1) of both
    splits are logged to the module-level `writer` under ``Loss/train``,
    ``Correct/train``, ``Loss/valid`` and ``Correct/valid``, and a summary
    line is printed. The writer is flushed after each epoch but NOT closed:
    it is shared with other cells, which close it when they are done.
    """
    graph_logged = False

    for epoch in range(1, epochs+1):
        training_loss = 0.0
        num_correct = 0
        num_examples = 0

        model.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            if not graph_logged:
                # Logging the graph runs an extra forward pass of the model,
                # so do it once instead of on every batch.
                writer.add_graph(model, inputs)
                graph_logged = True

            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()

            # weight the batch-mean loss by the batch size so the epoch mean
            # is correct even when the last batch is smaller
            training_loss += loss.item() * inputs.size(0)
            # softmax is monotonic, so the argmax of the logits is the prediction
            num_correct += (output.argmax(dim=1) == targets).sum().item()
            num_examples += targets.size(0)

        # Normalise by the samples actually seen (robust to samplers and
        # dropped batches); max(..., 1) avoids dividing by zero on an empty loader.
        seen = max(num_examples, 1)
        training_loss /= seen
        writer.add_scalar("Loss/train", training_loss, epoch)
        writer.add_scalar("Correct/train", num_correct / seen, epoch)

        model.eval()

        valid_loss = 0.0
        num_correct = 0
        num_examples = 0
        # No gradients are needed for validation; skipping them saves memory.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, targets)
                valid_loss += loss.item() * inputs.size(0)
                num_correct += (output.argmax(dim=1) == targets).sum().item()
                num_examples += targets.size(0)

        seen = max(num_examples, 1)
        valid_loss /= seen
        valid_accuracy = num_correct / seen
        writer.add_scalar("Loss/valid", valid_loss, epoch)
        # log the accuracy, consistent with "Correct/train" (not the raw count)
        writer.add_scalar("Correct/valid", valid_accuracy, epoch)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, valid_accuracy))
        writer.flush()

In [None]:
# Single training epoch with Adam (learning rate 3e-4, weight decay 1e-5)
# and cross-entropy loss.
train(
    cnn,
    optim.Adam(cnn.parameters(), lr=0.0003, weight_decay=1e-5),
    nn.CrossEntropyLoss(),
    train_loader,
    valid_loader,
    epochs=1,
    device=device,
)

In [None]:
def test(model, test_loader, device="cpu"):
    num_correct=0
    total=0
    for batch in test_loader:
        inputs, targets = batch
        inputs = inputs.to(device)
        output = model(inputs)
        targets = targets.to(device)
        correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
        num_correct += torch.sum(correct).item()
        total+=len(correct)
    print("Prediction accuracy {:.2f}% ({} correct out of {})".format((num_correct/total)*100,num_correct,total))



In [None]:
# Collect the class probabilities and true labels for the whole test set,
# then log PR curves, the confusion matrix and precision/recall/F1.
class_probs = []
class_label = []
# Evaluation must not depend on which mode an earlier cell left the model in.
cnn.eval()
with torch.no_grad():
    for images, labels in test_loader:
        output = cnn(images.to(device))
        # One batched softmax over the class dimension instead of one call per
        # row; moved to the CPU so the accumulated results do not occupy
        # device memory.
        class_probs.append(F.softmax(output, dim=1).cpu())
        class_label.append(labels)

test_probs = torch.cat(class_probs)
test_label = torch.cat(class_label)

# One precision-recall curve per class (all entries of `classes`).
for class_inx in range(len(classes)):
    add_pr_curve(class_inx, test_probs, test_label)
make_confusion_matrix(test_probs,test_label)
metrics = calculate_precision_recall_f1(test_probs,test_label)
print(metrics)
writer.close()

In [None]:
# Report the accuracy of the trained network on the held-out test loader.
test(cnn, test_loader, device=device)

In [None]:
# Report the accuracy on the sequences stored in the "experiment_sequences"
# folder below `path`, read in shuffled batches of 32.
experiment_seq = ChessSeq(path / "experiment_sequences")
exp_loader = DataLoader(experiment_seq, batch_size=32, shuffle=True)
test(cnn, exp_loader, device)
