## Training script for the CNN 

Loads in the converted plane representation of the pgn files, defines the network architecture and starts the training process. Checkpoints of the weights are saved if there's an improvement in the validation loss.
The training performance metrics (e.g. losses, accuracies...) are exported to tensorboard and can be checked during training.
* author: QueensGambit

In [None]:
# reload all imported modules before executing code, so edits to DeepCrazyhouse sources take effect without a kernel restart
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): redundant after %load_ext above; only needed when the extension was already loaded in this kernel
%reload_ext autoreload

In [None]:
from __future__ import print_function
import sys
sys.path.insert(0,'../../../')
import os
import re
import glob
import chess
import random
import logging
import datetime
import numpy as np
from time import time
from tqdm import tqdm_notebook
from mxnet import nd, autograd
from collections import deque
from copy import deepcopy
from multiprocessing import cpu_count
from time import time
import mxnet as mx
from mxnet import gluon
from mxnet import autograd as ag
from mxboard import SummaryWriter
import matplotlib.pyplot as plt
from DeepCrazyhouse.src.domain.variants.input_representation import board_to_planes, planes_to_board
from DeepCrazyhouse.src.domain.variants.output_representation import policy_to_moves, policy_to_best_move, policy_to_move
from DeepCrazyhouse.src.preprocessing.dataset_loader import load_pgn_dataset
from DeepCrazyhouse.src.runtime.color_logger import enable_color_logging
from DeepCrazyhouse.src.domain.neural_net.architectures.a0_resnet import AlphaZeroResnet
from DeepCrazyhouse.src.domain.neural_net.architectures.mxnet_alpha_zero import alpha_zero_symbol
from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_symbol import rise_mobile_symbol, preact_resnet_symbol
from DeepCrazyhouse.src.domain.neural_net.architectures.rise import Rise
from DeepCrazyhouse.src.domain.neural_net.architectures.densenet import DenseNet
from DeepCrazyhouse.src.domain.neural_net.architectures.wide_resnet_se import WideResnetSE
from DeepCrazyhouse.src.domain.neural_net.architectures.shuffle_rise import ShuffleRise
from DeepCrazyhouse.src.preprocessing.pgn_record_dataset import PGNRecordDataset
from DeepCrazyhouse.configs.main_config import main_config
from DeepCrazyhouse.configs.main_config import main_config
from DeepCrazyhouse.src.training.trainer_agent import TrainerAgent, evaluate_metrics, acc_sign, reset_metrics
from DeepCrazyhouse.src.training.trainer_agent_mxnet import TrainerAgentMXNET
from DeepCrazyhouse.src.training.lr_schedules.lr_schedules import *
from DeepCrazyhouse.src.domain.variants.plane_policy_representation import FLAT_PLANE_IDX
from DeepCrazyhouse.src.domain.variants.constants import NB_POLICY_MAP_CHANNELS, NB_LABELS

# colored log output (helper from DeepCrazyhouse.src.runtime.color_logger)
enable_color_logging()
# render matplotlib figures inline in the notebook
%matplotlib inline

## Settings

In [None]:
# compute context: set to the first GPU here; use mx.cpu() instead if no GPU is available (a GPU is strongly recommended for training)
ctx = mx.gpu(0) #2
# set a specific seed value for reproducibility
seed = 42

export_weights = True
log_metrics_to_tensorboard = True
export_grad_histograms = True
# scaling factor: divides batch_size and max_lr and multiplies batch_steps (e.g. 2 -> half the batch size and learning rate)
div_factor = 1

# batch_steps = 1000 means for example that every 1000 batches the validation set gets processed
batch_steps = 1000 * div_factor # this defines how often a new checkpoint will be saved and the metrics evaluated
# k_steps_initial defines how many steps have been trained before
# (k_steps_initial != 0 if you continue training from a checkpoint)
k_steps_initial = 0
cur_it = k_steps_initial * batch_steps # iteration counter used for the momentum and learning rate schedule
# these are the weights to continue training with
# (file names relative to ./weights/; None -> the parameters are only Xavier-initialized)
symbol_file = None
params_file = None

batch_size = 1024 // div_factor # 1024 # the batch_size needed to be reduced to 1024 in order to fit in the GPU 1080Ti
#4096 was originally used in the paper -> works slower for current GPU
# 2048 was used in the paper Mastering the game of Go without human knowledge and fits in GPU memory
# typically if you halve the batch_size, you should also halve the lr (linear scaling rule, as done via div_factor)

# optimization parameters
optimizer_name = "nag" # MXNet optimizer name ("nag" = SGD with Nesterov momentum)
max_lr = 0.35 / div_factor #0.01 # default lr for adam
min_lr = 0.00001
max_momentum = 0.95
min_momentum = 0.8
# loads a previous checkpoint if the loss increased significantly
use_spike_recovery = True
# stop training as soon as max_spikes has been reached
max_spikes = 20
# define spike threshold when the detection will be triggered
spike_thresh = 1.5
# weight decay
wd = 1e-4
# dropout rate passed to rise_mobile_symbol
dropout_rate = 0.2
# weight the value loss a lot lower than the policy loss in order to prevent overfitting
val_loss_factor = 0.01
policy_loss_factor = 0.99
# value targets are multiplied by discount**plys_to_end (1.0 disables discounting)
discount = 1.0 # 0.995

normalize = True # define whether to normalize input data to [0,1]
nb_epochs = 7 # define how many epochs the network will be trained

select_policy_from_plane = True # Boolean if potential legal moves will be selected from final policy output
use_mxnet_style = True # Decide between mxnet and gluon style for training

In [None]:
# Fix the seeds of every RNG available in this notebook (MXNet, NumPy and Python's `random`).
# Seeding only MXNet leaves any NumPy / `random` based sampling non-reproducible.
mx.random.seed(seed)
np.random.seed(seed)
random.seed(seed)

In [None]:
# display the installed MXNet version (last expression of the cell)
mx.__version__

### Create a ./logs and ./weights directory

In [None]:
# Create the output directories for tensorboard logs and exported weights.
# exist_ok=True makes re-running the notebook safe; each directory is created independently,
# so an existing ./logs no longer prevents ./weights from being created.
os.makedirs('./logs', exist_ok=True)
os.makedirs('./weights', exist_ok=True)

### load the config file

In [None]:
# show the loaded configuration (e.g. the dataset directories used below)
print(main_config)

In [None]:
# half of the logical CPUs; used as DataLoader num_workers and (minus 3) as the trainer's cpu_count
# NOTE(review): this is 0 on a single-core machine, and CPU_COUNT-3 can become negative — confirm the consumers accept that
CPU_COUNT = cpu_count()//2
#if use_mxnet_style:
#    os.environ["MXNET_CPU_WORKER_NTHREADS"] = str(CPU_COUNT) 

### load the dataset-files

### Validation Dataset (which is used during training)

In [None]:
# Load the first validation part and build the iterator used during training.
s_idcs_val, x_val, yv_val, yp_val, plys_to_end, pgn_datasets_val = load_pgn_dataset(
    dataset_type='val', part_id=0, print_statistics=True, print_parameters=True, normalize=normalize)
# optionally discount the value targets by the number of plies until the game end
if discount != 1:
    yv_val *= discount**plys_to_end

# argmax of each policy target row
policy_idx_val = yp_val.argmax(axis=1)
if use_mxnet_style:
    # with select_policy_from_plane the label is re-mapped through FLAT_PLANE_IDX
    policy_label_val = np.array(FLAT_PLANE_IDX)[policy_idx_val] if select_policy_from_plane else policy_idx_val
    val_iter = mx.io.NDArrayIter({'data': x_val},
                                 {'value_label': yv_val, 'policy_label': policy_label_val},
                                 batch_size)
else:
    val_dataset = gluon.data.ArrayDataset(nd.array(x_val), nd.array(yv_val), nd.array(policy_idx_val))
    val_data = gluon.data.DataLoader(val_dataset, batch_size, shuffle=False, num_workers=CPU_COUNT)

In [None]:
# dtype of the loaded input planes
x_val.dtype

In [None]:
# number of training parts matched by the glob pattern
# NOTE(review): without recursive=True, '**' behaves like '*', so this only counts entries one directory
# below planes_train_dir — confirm this matches the dataset layout
nb_parts = len(glob.glob(main_config['planes_train_dir'] + '**/*'))
nb_parts

In [None]:
# NOTE(review): approximation that assumes every training part holds as many samples as the validation part
nb_it_per_epoch = (len(x_val) * nb_parts) // batch_size # calculate how many iterations per epoch exist
# one iteration is defined by passing 1 batch and doing backprop
total_it = int(nb_it_per_epoch * nb_epochs)
total_it

### Define a Learning Rate schedule

In [None]:
# one-cycle schedule (cycle_length = 30%, cooldown_length = 60% of all iterations) wrapped in a
# linear warm-up starting at min_lr over the first 1/30 of the iterations
lr_schedule = OneCycleSchedule(start_lr=max_lr/8, max_lr=max_lr, cycle_length=total_it*.3, cooldown_length=total_it*.6, finish_lr=min_lr)
lr_schedule = LinearWarmUp(lr_schedule, start_lr=min_lr, length=total_it/30)
plot_schedule(lr_schedule, iterations=total_it)
#lr_schedule = ConstantSchedule(min_lr)
#plot_schedule(lr_schedule, iterations=total_it, ylim=[-min_lr, max_lr*1.1])

### Momentum schedule

In [None]:
# momentum schedule derived from the lr schedule and the lr / momentum ranges (see MomentumSchedule)
momentum_schedule = MomentumSchedule(lr_schedule, min_lr, max_lr, min_momentum, max_momentum)
#momentum_schedule = ConstantSchedule(min_momentum)
plot_schedule(momentum_schedule, iterations=total_it, ylabel='Momentum')

### Create the model

In [None]:
# shape of a single input sample without the batch dimension (used as a 3-D shape below)
input_shape = x_val[0].shape

In [None]:
# Remove a network left over from a previous run of this notebook (if any) to free its memory.
# Only a missing `net` is expected here; a bare `except:` would also hide KeyboardInterrupt and real errors.
try:
    del net
except NameError:
    pass

### Load the pretrained model

In [None]:
# net = gluon.nn.SymbolBlock.imports(symbol_file='weights/%s'%symbol_file, input_names='data', param_file='weights/%s'%params_file, ctx=ctx)

In [None]:
#net = AlphaZeroResnet(n_labels=2272, channels=256, channels_value_head=8, channels_policy_head=81, num_res_blocks=19, value_fc_size=256, bn_mom=0.9, act_type='relu', select_policy_from_plane=select_policy_from_plane)

In [None]:
#et = alpha_zero_resnet(n_labels=2272, channels=256, channels_value_head=1, channels_policy_head=81, num_res_blocks=19, value_fc_size=256, bn_mom=0.9, act_type='relu')

In [None]:
#expand_res_blocks=[3,3,3,5,5,5,5,5,7,7,7,7,7]
#expand_res_blocks=[3,3,7,7,7]
#net = Rise(n_labels=yp_val.shape[1], channels=256, channels_value_head=8, channels_policy_head=81, nb_res_blocks_x=0, nb_res_blocks_x_neck=0, expand_res_blocks=expand_res_blocks, value_fc_size=256, bn_mom=0.9, act_type='relu', squeeze_excitation_type="cSE", select_policy_from_plane=select_policy_from_plane, use_rise_stem=True)

In [None]:
#net = DenseNet(channels_init=64, growth_rate=24, n_layers=7, bottleneck_factor=4, n_labels=yp_val.shape[1], channels_value_head=4, channels_policy_head=8, value_fc_size=256)

In [None]:
# net = WideResnetSE(n_labels=yp_val.shape[1], channels=512, channels_value_head=4, channels_policy_head=8, nb_res_blocks=6, value_fc_size=512, bn_mom=0.9, act_type='relu', use_se=True, use_rise_stem=True)

In [None]:
#net = ShuffleRise(n_labels=yp_val.shape[1], channels=64, channels_value_head=4, channels_policy_head=81, nb_res_blocks_x=0, nb_shuffle_blocks=19, nb_shuffle_blocks_neck=0, value_fc_size=256, bn_mom=0.9, act_type='lrelu', squeeze_excitation_type=None, select_policy_from_plane=select_policy_from_plane, use_rise_stem=True)

In [None]:
#net = PyramidResnetSE(n_labels=2272, channels=256,  channels_value_head=1, channels_policy_head=81, num_res_blocks=19, value_fc_size=256,  bn_mom=0.9, act_type='relu')

In [None]:
#symbol = alpha_zero_symbol(num_filter=256, channels_value_head=4, channels_policy_head=8, workspace=1024, value_fc_size=256, num_res_blocks=7, bn_mom=0.9, act_type='relu',
#                            n_labels=2272, grad_scale_value=0.01, grad_scale_policy=0.99, select_policy_from_plane=select_policy_from_plane)

In [None]:
# one entry per bottleneck residual block of the RISE-mobile network (13 blocks)
# NOTE(review): the meaning of each value (presumably a kernel size) is defined in rise_mobile_symbol — confirm
bc_res_blocks = [3]+[5]+[7]+[5]+[3]*9

In [None]:
# NOTE(review): grad_scale_value/grad_scale_policy are hard-coded and duplicate val_loss_factor/policy_loss_factor;
# also note that the next cell overwrites `symbol` with preact_resnet_symbol
symbol = rise_mobile_symbol(channels=256, channels_operating_init=128, channel_expansion=64, channels_value_head=4,
                   channels_policy_head=NB_POLICY_MAP_CHANNELS, value_fc_size=256, bc_res_blocks=bc_res_blocks, res_blocks=[], act_type='relu',
                   n_labels=NB_LABELS, grad_scale_value=0.01, grad_scale_policy=0.99, select_policy_from_plane=select_policy_from_plane,
                   use_se=True, dropout_rate=dropout_rate)
#    symbol = mx.sym.load("weights/" + symbol_file)

In [None]:
#symbol =preact_resnet_symbol(channels=256, channels_value_head=4,
#                   channels_policy_head=81, value_fc_size=256, value_kernelsize=7,res_blocks=19,
#                   act_type='relu', n_labels=2272, grad_scale_value=0.01, grad_scale_policy=0.99,
#                   select_policy_from_plane=select_policy_from_plane)

# NOTE(review): this unconditionally replaces the rise_mobile_symbol built in the previous cell, and it
# hard-codes channels_policy_head=81, n_labels=4992 and select_policy_from_plane=True instead of using
# NB_POLICY_MAP_CHANNELS / NB_LABELS / the select_policy_from_plane setting — keep only the intended architecture
symbol = preact_resnet_symbol(channels=256, channels_value_head=8,
                   channels_policy_head=81, value_fc_size=256, res_blocks=19, act_type='relu',
                   n_labels=4992, grad_scale_value=0.01, grad_scale_policy=0.99, select_policy_from_plane=True)
# when continuing from a checkpoint, the saved architecture replaces the symbol defined above
if symbol_file:
    symbol = mx.sym.load("weights/" + symbol_file)
    # probably change 'policy_loss_factor' and 'val_loss_factor'
    #value_out = symbol.get_internals()['value_out_output']
    #policy_out = symbol.get_internals()['policy_out_output']
    #sym = mx.symbol.Group([value_out, policy_out])
    #policy_out = mx.sym.SoftmaxOutput(data=policy_out, name='policy', grad_scale=policy_loss_factor)
    #value_out = mx.sym.LinearRegressionOutput(data=value_out, name='value', grad_scale=val_loss_factor)

    # group value_out and policy_out together
    #symbol = mx.symbol.Group([value_out, policy_out])

## Network summary

In [None]:
if not use_mxnet_style:
    print(net)

In [None]:
if use_mxnet_style:
    display(mx.viz.plot_network(
        symbol,
        shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},
        node_attrs={"shape":"oval","fixedsize":"false"}
    ))
else:
    display(mx.viz.plot_network(
        net(mx.sym.var('data'))[1],
        shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},
        node_attrs={"shape":"oval","fixedsize":"false"}
    ))

In [None]:
if use_mxnet_style:
    mx.viz.print_summary(
        symbol,
        shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},
    )
else:
    mx.viz.print_summary(
    net(mx.sym.var('data'))[1], 
    shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},
    ) 

## Initialize the weights 
(only needed if no pretrained weights are used)

In [None]:
# create a trainable module on compute context
if use_mxnet_style:
    model = mx.mod.Module(symbol=symbol, context=ctx, label_names=['value_label', 'policy_label'])
    # label shapes are taken from the validation iterator so that label names and shapes match
    model.bind(for_training=True, data_shapes=[('data', (batch_size, input_shape[0], input_shape[1], input_shape[2]))],
             label_shapes=val_iter.provide_label)
    # Xavier-initialize all parameters first, then overwrite them with pretrained weights if params_file is set
    model.init_params(mx.initializer.Xavier(rnd_type='uniform', factor_type='avg', magnitude=2.24))
    if params_file:
        model.load_params("weights/" + params_file)    
else:
    # NOTE(review): requires `net` to be created by one of the (currently commented-out) gluon architecture cells
    net.collect_params().initialize(mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=2.24), ctx=ctx)
    net.hybridize()

## Define the metrics to use

In [None]:
# metrics for mxnet-style training; output_names / label_names select which module outputs and labels each metric uses
metrics_mxnet = [
mx.metric.MSE(name='value_loss', output_names=['value_output'], label_names=['value_label']),
mx.metric.CrossEntropy(name='policy_loss', output_names=['policy_output'],
                                            label_names=['policy_label']),
# custom metric built from acc_sign (presumably the fraction of value predictions with the correct sign)
mx.metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'],
                                         label_names=['value_label']),
mx.metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'],
                                       label_names=['policy_label'])
]
# the same metrics keyed by name, for the gluon trainer and evaluate_metrics
# NOTE(review): unlike the mxnet list, 'value_loss' has no label_names here
metrics_gluon = {
'value_loss': mx.metric.MSE(name='value_loss', output_names=['value_output']),
'policy_loss': mx.metric.CrossEntropy(name='policy_loss', output_names=['policy_output'],
                                            label_names=['policy_label']),
'value_acc_sign': mx.metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'],
                                         label_names=['value_label']),
'policy_acc': mx.metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'],
                                       label_names=['policy_label'])
}
if use_mxnet_style:
    metrics = metrics_mxnet
else:
    metrics = metrics_gluon

## Define a training agent

In [None]:
# the trainer agent drives the training loop with the lr/momentum schedules, checkpoint export,
# spike recovery and tensorboard logging configured in the settings cell
# NOTE(review): cpu_count=CPU_COUNT-3 is <= 0 on machines with fewer than 8 logical CPUs
if use_mxnet_style:
    train_agent = TrainerAgentMXNET(model, symbol, val_iter, nb_parts, lr_schedule, momentum_schedule, total_it, optimizer_name, wd=wd, batch_steps=batch_steps,
                 k_steps_initial=k_steps_initial, cpu_count=CPU_COUNT-3, batch_size=batch_size, normalize=normalize, export_weights=export_weights,
                 export_grad_histograms=export_grad_histograms, log_metrics_to_tensorboard=log_metrics_to_tensorboard, ctx=ctx, metrics=metrics,
                use_spike_recovery=use_spike_recovery, max_spikes=max_spikes, spike_thresh=spike_thresh, seed=seed,
                           val_loss_factor=val_loss_factor, policy_loss_factor=policy_loss_factor, select_policy_from_plane=select_policy_from_plane, discount=discount)
else:
    # NOTE(review): unlike the mxnet agent, the gluon TrainerAgent is not given `discount`
    train_agent = TrainerAgent(net, val_data, nb_parts, lr_schedule, momentum_schedule, total_it, optimizer_name, wd=wd, batch_steps=batch_steps,
                     k_steps_initial=k_steps_initial, cpu_count=CPU_COUNT-3, batch_size=batch_size, normalize=normalize, export_weights=export_weights,
                     export_grad_histograms=export_grad_histograms, log_metrics_to_tensorboard=log_metrics_to_tensorboard, ctx=ctx, metrics=metrics,
                    use_spike_recovery=use_spike_recovery, max_spikes=max_spikes, spike_thresh=spike_thresh, seed=seed,
                               val_loss_factor=val_loss_factor, policy_loss_factor=policy_loss_factor, select_policy_from_plane=select_policy_from_plane)

## Performance Pre-Training

In [None]:
# Evaluate the initialized (or pretrained) model on the validation set before training.
# The result is printed explicitly: a non-final expression of a cell is not displayed by Jupyter.
if use_mxnet_style:
    print(model.score(val_iter, metrics))
else:
    # gluon style has no `model`; use the same evaluation helper as for the test set
    print(evaluate_metrics(metrics, val_data, net, nb_batches=None, select_policy_from_plane=select_policy_from_plane, ctx=ctx))

# Inference benchmark, adapted from: https://cwiki.apache.org/confluence/display/MXNET/How+to+use+MXNet-TensorRT+integration
nb_warmup_runs = 10
nb_timed_runs = 500
x_batch = nd.array(x_val[0:128], ctx=ctx)
# take the batch length from the actual slice so the bound shape also fits when fewer than 128 samples exist
batch_shape = (x_batch.shape[0], input_shape[0], input_shape[1], input_shape[2])
if use_mxnet_style:
    # separate inference executor on the configured context (not a hard-coded GPU), holding the module's current weights;
    # simple_bind only interprets tuple keyword arguments as input shapes
    executor = symbol.simple_bind(ctx=ctx, data=batch_shape, grad_req='null')
    (arg_params, aux_params) = model.get_params()
    executor.copy_params_from(arg_params, aux_params)

    def _run_batch():
        """Run one forward pass of the executor and block until the first output is computed."""
        executor.forward(is_train=False, data=x_batch)[0].wait_to_read()
else:
    def _run_batch():
        """Run one forward pass of the gluon net and block until the first output is computed."""
        net(x_batch)[0][0].wait_to_read()

# Warmup
print('Warming up MXNet')
for _ in range(nb_warmup_runs):
    _run_batch()

# Timing
print('Starting MXNet timed run')
start = time()
for _ in range(nb_timed_runs):
    _run_batch()
end = time()
print("Elapsed time: %.4fs" % (end - start))

## Start the training process

In [None]:
# start training at iteration cur_it; returns (k_steps, val_loss, val_policy_acc) of the final state and of the
# best checkpoint (checkpoints are saved when the validation loss improves)
(k_steps_final, val_loss_final, val_p_acc_final), (k_steps_best, val_loss_best, val_p_acc_best) = train_agent.train(cur_it)

## Export the last model state

In [None]:
# the file prefix encodes the final validation loss and policy accuracy
prefix = "./weights/model-%.5f-%.3f" % (val_loss_final, val_p_acc_final)

# both calls save the architecture ('<prefix>-symbol.json') and the weights ('<prefix>-<epoch:04d>.params')
if use_mxnet_style:
    model.save_checkpoint(prefix, epoch=k_steps_final)
else:
    net.export(prefix, epoch=k_steps_final)
# report the written weights file for both training styles
logging.info("Saved checkpoint to %s-%04d.params", prefix, k_steps_final)

## Load the best model once again

In [None]:
# delete the current net object from memory
if not use_mxnet_style:
    del net

In [None]:
# rebuild the best checkpoint as a gluon SymbolBlock for inference
# NOTE(review): assumes the trainer saved the best model with exactly this prefix format — confirm in the trainer agent
model_prefix = "./weights/model-%.5f-%.3f" % (val_loss_best, val_p_acc_best)
model_arch_path = '%s-symbol.json' % model_prefix
model_params_path = '%s-%04d.params' % (model_prefix, k_steps_best)
print('load current best model:', model_params_path)
symbol = mx.sym.load(model_arch_path)
inputs = mx.sym.var('data', dtype='float32')
# use the internal head outputs (presumably before the training loss layers); the block returns [value, policy]
value_out = symbol.get_internals()['value_out_output'] #value_out_output']
policy_out = symbol.get_internals()['policy_out_output']
sym = mx.symbol.Group([value_out, policy_out])
net = mx.gluon.SymbolBlock(sym, inputs)
net.collect_params().load(model_params_path, ctx) #, allow_missing=True)

In [None]:
print('best val_loss: %.5f with v_policy_acc: %.5f at k_steps_best %d' % (val_loss_best, val_p_acc_best, k_steps_best))

In [None]:
# index of the validation sample to inspect
idx = 0

In [None]:
# decode the input planes back into a python-chess board
board = planes_to_board(x_val[idx], normalized_input=normalize)

# side to move
print(chess.COLOR_NAMES[board.turn])
# pieces in hand exist only on crazyhouse boards
if board.uci_variant == "crazyhouse":
    print(board.pockets)
# display the board (last expression of the cell)
board

In [None]:
def predict_single(net, x, select_policy_from_plane=False):
    """Run `net` on one input sample and return its value and policy predictions.

    :param net: network whose output is indexable: [0] = value, [1] = policy logits
    :param x: a single input sample without batch dimension
    :param select_policy_from_plane: gather the policy logits at FLAT_PLANE_IDX before the softmax
    :return: list [value, policy] of numpy arrays, each with a leading batch dimension of 1

    The input is placed on the notebook-global compute context `ctx`.
    """
    batch = mx.nd.array(np.expand_dims(x, axis=0), ctx=ctx)
    outputs = net(batch)
    value_nd = outputs[0]
    policy_nd = outputs[1]
    if select_policy_from_plane:
        policy_nd = policy_nd[:, FLAT_PLANE_IDX]
    policy_nd = policy_nd.softmax()
    return [value_nd.asnumpy(), policy_nd.asnumpy()]

In [None]:
pred = predict_single(net, x_val[0], select_policy_from_plane)
pred

In [None]:
pred = predict_single(net, x_val[0], select_policy_from_plane)

In [None]:
policy_to_best_move(board, yp_val[idx])

In [None]:
opts = 5
selected_moves, probs = policy_to_moves(board, pred[1][0])
selected_moves[:opts]

In [None]:
# Horizontal bar chart of the top moves, best move at the top.
# Bars, ticks and labels are sliced to the same length: matplotlib rejects a label list whose
# length differs from the number of ticks, and a position may have fewer than `opts` moves.
nb_shown = min(opts, len(selected_moves))
plt.barh(range(nb_shown)[::-1], probs[:nb_shown])
ax = plt.gca()
ax.set_yticks(range(nb_shown)[::-1])
ax.set_yticklabels(selected_moves[:nb_shown])

In [None]:
# Build the position after 1. e4 e5 2. Bc4 Nc6 3. Qh5: Black to move must parry the mate threat Qxf7#.
# A validation sample is decoded (with the same normalization setting as everywhere else) only to obtain a board
# of the right variant; reset() then restores that variant's starting position, so the moves are always legal.
board = planes_to_board(x_val[0], normalized_input=normalize)
board.reset()
for uci_move in ('e2e4', 'e7e5', 'f1c4', 'b8c6', 'd1h5'):
    board.push_uci(uci_move)
x_scholar_atck = board_to_planes(board, normalize=normalize)
board

In [None]:
# predict the scholar's-mate position and plot the top moves
pred = predict_single(net, x_scholar_atck, select_policy_from_plane)

selected_moves, probs = policy_to_moves(board, pred[1][0])
# bars, ticks and labels must have the same length (matplotlib raises on a tick/label count mismatch)
nb_shown = min(opts, len(selected_moves))
plt.barh(range(nb_shown)[::-1], probs[:nb_shown])
ax = plt.gca()
ax.set_yticks(range(nb_shown)[::-1])
ax.set_yticklabels(selected_moves[:nb_shown])

In [None]:
# play the network's top move on the scholar's-mate board and display the result
board.push(selected_moves[0])
board

### Performance on test dataset


In [None]:
# Load the test set with the same input normalization that the network was trained with.
# NOTE(review): unlike the validation set, value targets are not discounted here (plys_to_end is discarded)
s_idcs_test, x_test, yv_test, yp_test, _, pgn_datasets_test = load_pgn_dataset(dataset_type='test', part_id=0,
                                                                           print_statistics=True, print_parameters=True, normalize=normalize)
test_dataset = gluon.data.ArrayDataset(nd.array(x_test), nd.array(yv_test), nd.array(yp_test.argmax(axis=1)))
# evaluation does not need shuffling; a fixed order keeps repeated runs comparable
test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=CPU_COUNT)

In [None]:
# `net` is now a gluon SymbolBlock (loaded above), so the gluon metric dict is used regardless of use_mxnet_style
metrics = metrics_gluon
evaluate_metrics(metrics, test_data, net, nb_batches=None, select_policy_from_plane=select_policy_from_plane, ctx=ctx)

### Show result on mate-in-one problems

In [None]:
s_idcs_mate, x_mate, yv_mate, yp_mate, _, pgn_dataset_mate = load_pgn_dataset(dataset_type='mate_in_one', part_id=1,
                                                         print_parameters=True, print_statistics=True, normalize=normalize)

In [None]:
# labels are the argmax of the policy targets (same layout as the test set)
mate_dataset = mx.gluon.data.dataset.ArrayDataset(nd.array(x_mate), nd.array(yv_mate), nd.array(yp_mate.argmax(axis=1)))
mate_data = mx.gluon.data.DataLoader(mate_dataset, batch_size=batch_size, num_workers=CPU_COUNT)

### Mate In One Performance

In [None]:
evaluate_metrics(metrics, mate_data, net, select_policy_from_plane=select_policy_from_plane, ctx=ctx)

### Show some example mate problems

In [None]:
# display every top-level expression of a cell, not only the last one
# (expressions nested inside loops are still not displayed automatically)
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

### Evaluate Performance

In [None]:
def eval_pos(net, x_mate, yp_mate, verbose=False, select_policy_from_plane=False):
    """Evaluate the network on a single (mate-in-one) position.

    :param net: network passed on to `predict_single`
    :param x_mate: input planes of one position (decoded with the notebook-global `normalize` flag)
    :param yp_mate: policy target of that position, used to look up the expected move
    :param verbose: print side to move / pockets and plot the top predicted moves
    :param select_policy_from_plane: forwarded to `predict_single`
    :return: tuple (pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt)
        pred: [value, policy] as returned by `predict_single`
        pred_moves: up to 5 predicted moves, best first
        true_move: move decoded from `yp_mate`
        board: the position AFTER playing the top predicted move
        is_checkmate: whether the top predicted move mates
        is_mate_5_top: whether any of the top predicted moves mates
        legal_move_cnt: number of legal moves in the original position
        mate_move_cnt: number of mating moves in the original position
    """
    opts = 5
    board = planes_to_board(x_mate, normalized_input=normalize)
    if verbose:
        print("{0}'s turn".format(chess.COLOR_NAMES[board.turn]))
        if board.uci_variant == "crazyhouse":
            print("black/white {0}".format(board.pockets))
    pred = predict_single(net, x_mate, select_policy_from_plane=select_policy_from_plane)

    true_move = policy_to_move(yp_mate, is_white_to_move=board.turn)

    pred_moves, probs = policy_to_moves(board, pred[1][0])
    pred_moves = pred_moves[:opts]
    # probabilities of exactly the moves kept above (same order), so plotted bars match their labels
    pred_probs = probs[:len(pred_moves)]

    legal_move_cnt = board.legal_moves.count()
    # the repr of the legal move generator lists all moves in SAN, where mating moves end with '#'
    mate_move_cnt = str(board.legal_moves).count('#')

    # does any of the top moves deliver mate? (checked on copies; the original position stays untouched)
    is_mate_5_top = False
    for pred_move in pred_moves:
        board_5_top = board.copy()
        board_5_top.push(pred_move)
        if board_5_top.is_checkmate():
            is_mate_5_top = True
            break

    board.push(pred_moves[0])
    is_checkmate = board.is_checkmate()

    if verbose:
        nb_shown = len(pred_moves)
        plt.barh(range(nb_shown)[::-1], pred_probs)
        ax = plt.gca()
        ax.set_yticks(range(nb_shown)[::-1])
        ax.set_yticklabels(pred_moves)
        plt.title('True Move:' + str(true_move) +
                  '\nEval:' + str(pred[0][0]))
        plt.show()

    return pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt

In [None]:
# evaluate every mate-in-one position and collect per-position results
nb_pos = len(x_mate)
mates_found = []
mates_5_top_found = []
legal_mv_cnts = []
mate_mv_cnts = []

for i in range(nb_pos):
    pred, pred_moves, true_move, board, is_mate, is_mate_5_top, legal_mv_cnt, mate_mv_cnt= eval_pos(net, x_mate[i], yp_mate[i], select_policy_from_plane=select_policy_from_plane)
    mates_found.append(is_mate)
    legal_mv_cnts.append(legal_mv_cnt)
    mate_mv_cnts.append(mate_mv_cnt)
    mates_5_top_found.append(is_mate_5_top)

In [None]:
# average number of mating moves per position
np.array(mate_mv_cnts).mean()

In [None]:
# average number of legal moves per position
np.array(legal_mv_cnts).mean()

### Random Guessing Baseline

In [None]:
# NOTE(review): this is a ratio of means; the expected accuracy of a uniformly random move would be
# the mean of the per-position ratios mate_mv_cnt / legal_mv_cnt
np.array(mate_mv_cnts).mean() / np.array(legal_mv_cnts).mean()

### Prediction Performance

In [None]:
# fraction of positions in which the top predicted move mates
print('mate_in_one_acc:', sum(mates_found) / nb_pos)

In [None]:
# fraction of positions in which any of the top-5 predicted moves mates
sum(mates_5_top_found) / nb_pos

In [None]:
# print the hierarchy of the loaded dataset container
pgn_dataset_mate.tree()

In [None]:
# both rows are displayed because ast_node_interactivity was set to "all"
# NOTE(review): row 0 presumably holds the column names — the next cell skips it
metadata = np.array(pgn_dataset_mate['metadata'])
metadata[0, :]
metadata[1, :]

In [None]:
# column 1 of the data rows (presumably the game's Site tag / URL — confirm against the header row)
site_mate = metadata[1:, 1]

In [None]:
def clean_string(np_string):
    """Convert a (numpy) bytes value into a plain string without quotes.

    ``str()`` of a bytes value yields ``"b'...'"``; the ``b'`` prefix is removed and then every
    remaining single and double quote is stripped.

    :param np_string: value to clean (bytes, numpy bytes or str)
    :return: the cleaned string
    """
    # operate on the argument itself (the previous body ignored it and read the global site_mate[i])
    string = str(np_string).replace("b'", "")
    return string.translate(str.maketrans("", "", "'\""))

In [None]:
import chess.svg
from IPython.display import SVG, HTML

## Show the result of the first 17 examples

In [None]:
# Show the first 17 mate-in-one problems: site, plot of the top predictions and the board after the top move.
# Expressions inside a loop are never auto-displayed (not even with ast_node_interactivity="all"),
# so the board SVG is passed to display() explicitly.
for i in range(min(17, nb_pos)):
    print(clean_string(site_mate[i]))
    pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt = eval_pos(net, x_mate[i], yp_mate[i], verbose=True, select_policy_from_plane=select_policy_from_plane)
    pred_move = pred_moves[0]
    pred_arrow = chess.svg.Arrow(pred_move.from_square, pred_move.to_square)
    display(SVG(data=chess.svg.board(board=board, arrows=[pred_arrow], size=400)))

## Show examples where it failed

In [None]:
# Show the positions (among the first 1000, bounded by the dataset size) where none of the top-5 moves mates.
# The board SVG must be passed to display(): expressions inside a loop are not auto-displayed.
for i in range(min(1000, nb_pos)):
    pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt = eval_pos(net, x_mate[i], yp_mate[i], verbose=False, select_policy_from_plane=select_policy_from_plane)
    if not is_mate_5_top:
        print(clean_string(site_mate[i]))
        # evaluate again in verbose mode to print the position info and plot the predictions
        pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt = eval_pos(net, x_mate[i], yp_mate[i], verbose=True, select_policy_from_plane=select_policy_from_plane)
        pred_move = pred_moves[0]
        pred_arrow = chess.svg.Arrow(pred_move.from_square, pred_move.to_square)
        display(SVG(data=chess.svg.board(board=board, arrows=[pred_arrow], size=400)))