## Training script for the CNN 

Loads in the converted plane representation of the pgn files, defines the network architecture and starts the training process. Checkpoints of the weights are saved if there's an improvement in the validation loss.
The training performance metrics (e.g. losses, accuracies...) are exported to tensorboard and can be checked during training.
* author: QueensGambit

In [None]:
# Load IPython's autoreload extension and select mode 2
# (per the IPython docs: modules are re-imported before code is executed)
%load_ext autoreload
%autoreload 2

In [None]:
# Reload the autoreload extension
%reload_ext autoreload

In [None]:
from __future__ import print_function
import os
import sys
sys.path.insert(0,'../../../')
import glob
import chess
import shutil
import logging
import numpy as np
from mxnet import nd
from copy import deepcopy
import mxnet as mx
from mxnet import gluon
# The metric classes are used below through the alias `metric`; fall back to
# mxnet.gluon.metric when mxnet.metric is not available. Both branches must bind
# the same alias, otherwise the fallback leaves `metric` undefined.
try:
    import mxnet.metric as metric
except ModuleNotFoundError:
    import mxnet.gluon.metric as metric

from DeepCrazyhouse.src.domain.variants.input_representation import board_to_planes, planes_to_board
from DeepCrazyhouse.src.domain.variants.output_representation import policy_to_moves, policy_to_best_move, policy_to_move
from DeepCrazyhouse.src.preprocessing.dataset_loader import load_pgn_dataset
from DeepCrazyhouse.src.runtime.color_logger import enable_color_logging
from DeepCrazyhouse.src.domain.neural_net.architectures.a0_resnet import AlphaZeroResnet
from DeepCrazyhouse.src.domain.neural_net.architectures.mxnet_alpha_zero import alpha_zero_symbol
from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_v2 import rise_mobile_v2_symbol
from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_v3 import rise_mobile_v3_symbol
from DeepCrazyhouse.src.domain.neural_net.architectures.preact_resnet_se import preact_resnet_se
from DeepCrazyhouse.configs.main_config import main_config
from DeepCrazyhouse.configs.train_config import TrainConfig, TrainObjects
from DeepCrazyhouse.src.training.trainer_agent import TrainerAgent, evaluate_metrics, acc_sign
from DeepCrazyhouse.src.training.trainer_agent_mxnet import TrainerAgentMXNET, get_context, prepare_policy
from DeepCrazyhouse.src.training.lr_schedules.lr_schedules import *
from DeepCrazyhouse.src.domain.variants.plane_policy_representation import FLAT_PLANE_IDX
from DeepCrazyhouse.src.domain.variants.constants import NB_POLICY_MAP_CHANNELS, NB_LABELS
from DeepCrazyhouse.src.domain.neural_net.onnx.convert_to_onnx import convert_mxnet_model_to_onnx

# Turn on the project's colored log output and render matplotlib figures inside the notebook
enable_color_logging()
%matplotlib inline

## Settings

In [None]:
# tc holds the training settings, to holds the training objects (schedules, metrics, ...)
# which are filled in by the cells below
tc = TrainConfig()
to = TrainObjects()

In [None]:
# compute context: "gpu" is strongly recommended for training, switch to "cpu" if no GPU is available
tc.context = "gpu"
tc.device_id = 0

# set a specific seed value for reproducibility
tc.seed = 7 # 42

tc.export_weights = True
tc.log_metrics_to_tensorboard = True
tc.export_grad_histograms = True

# directory to write and read weights, logs, onnx and other export files
tc.export_dir = "./"

tc.div_factor = 1  # div factor is a constant which can be used to reduce the batch size and learning rate respectively
# use a value greater than 1 if you encounter memory allocation errors

# batch_steps = 1000 means for example that every 1000 batches the validation set gets processed
tc.batch_steps = 1000 * tc.div_factor # this defines how often a new checkpoint will be saved and the metrics evaluated
# k_steps_initial defines how many steps have been trained before
# (k_steps_initial != 0 if you continue training from a checkpoint)
tc.k_steps_initial = 0
# these are the weights to continue training with
tc.symbol_file = None # 'model-0.81901-0.713-symbol.json'
tc.params_file = None #'model-0.81901-0.713-0498.params'

tc.batch_size = int(1024 / tc.div_factor) # 1024 # the batch_size needed to be reduced to 1024 in order to fit in the GPU 1080Ti
# 4096 was originally used in the paper -> works slower for current GPU
# 2048 was used in the paper Mastering the game of Go without human knowledge and fits in GPU memory
# typically if you halve the batch_size, you should also halve the lr (tc.div_factor scales both)

# optimization parameters
tc.optimizer_name = "nag"
tc.max_lr = 0.35 / tc.div_factor #0.01 # default lr for adam
tc.min_lr = 0.00001
tc.max_momentum = 0.95
tc.min_momentum = 0.8
# loads a previous checkpoint if the loss increased significantly
tc.use_spike_recovery = True
# stop training as soon as max_spikes has been reached
tc.max_spikes = 20
# define spike threshold when the detection will be triggered
tc.spike_thresh = 1.5
# weight decay
tc.wd = 1e-4
tc.dropout_rate = 0 #0.15
# weight the value loss a lot lower than the policy loss in order to prevent overfitting
tc.val_loss_factor = 0.01
tc.policy_loss_factor = 0.99
tc.discount = 1.0

tc.normalize = True # define whether to normalize input data to [0,1]
tc.nb_training_epochs = 7 # define how many epochs the network will be trained
tc.select_policy_from_plane = True # Boolean if potential legal moves will be selected from final policy output
tc.use_mxnet_style = True  # Decide between mxnet and gluon style for training

# additional custom validation set files which will be logged to tensorboard
to.variant_metrics = None # ["chess960", "koth", "three_check"]
# if use_extra_variant_input is true, the current active variant is passed to each residual block and
# concatenated at the end of the final feature representation (the flag itself is set in a later cell)

# ratio for mixing the value return with the corresponding q-value
# for a ratio of 0 no q-value information will be used
tc.q_value_ratio = 0

# define if the policy training target is one-hot encoded or a distribution (e.g. mcts samples, knowledge distillation)
tc.sparse_policy_label = True
# define if the policy data is also defined in "select_policy_from_plane" representation
tc.is_policy_from_plane_data = False
tc.name_initials = "JC"

In [None]:
# Derive the runtime objects from the configuration
mode = main_config["mode"]
ctx = get_context(tc.context, tc.device_id)
# if use_extra_variant_input is True, the active variant is passed to each residual block and
# concatenated at the end of the final feature representation
use_extra_variant_input = False
cur_it = tc.k_steps_initial * tc.batch_steps # iteration counter used for the momentum and learning rate schedule
# fix MXNet's random seed
mx.random.seed(tc.seed)

In [None]:
# Show the installed MXNet version
mx.__version__

### Create logs and weights directory

In [None]:
# Create the export sub-directories for logs and weights.
# makedirs also creates missing parents of tc.export_dir, and exist_ok=True replaces
# the separate exists() check, so re-running the cell never fails.
for sub_dir in ("logs", "weights"):
    os.makedirs(tc.export_dir + sub_dir, exist_ok=True)

### Show the config files

In [None]:
# Print the global configuration dictionary
print(main_config)

In [None]:
# Print the training configuration set above
print(tc)

In [None]:
# Print the training objects container
print(to)

### Load the dataset-files

### Validation Dataset (which is used during training)

In [None]:
# Load part 0 of the validation set and wrap it for the selected training style.
s_idcs_val, x_val, yv_val, yp_val, plys_to_end, pgn_datasets_val = load_pgn_dataset(
    dataset_type='val', part_id=0, verbose=True, normalize=tc.normalize)
if tc.discount != 1:
    # discount the value target by the number of plies until the end of the game
    yv_val *= tc.discount**plys_to_end

if not tc.use_mxnet_style:
    # gluon style: dataset + data loader
    val_dataset = gluon.data.ArrayDataset(nd.array(x_val), nd.array(yv_val), nd.array(prepare_policy(yp_val, tc.select_policy_from_plane, tc.sparse_policy_label, tc.is_policy_from_plane_data)))
    val_data = gluon.data.DataLoader(val_dataset, tc.batch_size, shuffle=False, num_workers=tc.cpu_count)
elif tc.select_policy_from_plane:
    # mxnet style: the policy label is the arg-max index, mapped through FLAT_PLANE_IDX
    val_iter = mx.io.NDArrayIter({'data': x_val}, {'value_label': yv_val, 'policy_label': np.array(FLAT_PLANE_IDX)[yp_val.argmax(axis=1)]}, tc.batch_size)
else:
    # mxnet style: the policy label is the plain arg-max index
    val_iter = mx.io.NDArrayIter({'data': x_val}, {'value_label': yv_val, 'policy_label': yp_val.argmax(axis=1)}, tc.batch_size)

In [None]:
# Number of training parts = number of paths matching '<planes_train_dir>**/*'
# NOTE(review): glob is called without recursive=True here -- confirm the pattern matches the intended files
tc.nb_parts = len(glob.glob(main_config['planes_train_dir'] + '**/*'))
tc.nb_parts

In [None]:
# NOTE(review): the size of the loaded validation part is used as the size of every training part -- confirm
nb_it_per_epoch = (len(x_val) * tc.nb_parts) // tc.batch_size # calculate how many iterations per epoch exist
# one iteration is defined by passing 1 batch and doing backprop
tc.total_it = int(nb_it_per_epoch * tc.nb_training_epochs)
tc.total_it

### Define a Learning Rate schedule

In [None]:
# One-cycle schedule over 30% (cycle) + 60% (cooldown) of all iterations, wrapped in a
# linear warm-up that starts at tc.min_lr and lasts total_it/30 iterations
to.lr_schedule = OneCycleSchedule(start_lr=tc.max_lr/8, max_lr=tc.max_lr, cycle_length=tc.total_it*.3, cooldown_length=tc.total_it*.6, finish_lr=tc.min_lr)
to.lr_schedule = LinearWarmUp(to.lr_schedule, start_lr=tc.min_lr, length=tc.total_it/30)

# raise the log level while plotting, then switch the root logger to DEBUG
logging.getLogger().setLevel(logging.WARNING)
plot_schedule(to.lr_schedule, iterations=tc.total_it)
logging.getLogger().setLevel(logging.DEBUG)

### Momentum schedule

In [None]:
# Momentum schedule built from the learning-rate schedule and the lr / momentum bounds
to.momentum_schedule = MomentumSchedule(to.lr_schedule, tc.min_lr, tc.max_lr, tc.min_momentum, tc.max_momentum)
plot_schedule(to.momentum_schedule, iterations=tc.total_it, ylabel='Momentum')

### Create the model

In [None]:
# Shape of a single input sample (without the batch axis), taken from the first validation sample
input_shape = x_val[0].shape
input_shape

In [None]:
# Drop a network object left over from a previous run of the notebook.
# Only the "net does not exist yet" case is expected, so catch exactly that.
try:
    del net
except NameError:
    pass

### Define the NN model / Load the pretrained model

In [None]:
# No symbol defined yet; the cells below either build or load one
symbol = None

In [None]:
#net = AlphaZeroResnet(n_labels=2272, channels=256, channels_value_head=8, channels_policy_head=81, num_res_blocks=19, value_fc_size=256, bn_mom=0.9, act_type='relu', select_policy_from_plane=select_policy_from_plane)

In [None]:
#net = alpha_zero_resnet(n_labels=2272, channels=256, channels_value_head=1, channels_policy_head=81, num_res_blocks=19, value_fc_size=256, bn_mom=0.9, act_type='relu')

In [None]:
#symbol = alpha_zero_symbol(num_filter=256, channels_value_head=4, channels_policy_head=81, workspace=1024, value_fc_size=256, num_res_blocks=19, bn_mom=0.9, act_type='relu',
#                            n_labels=2272, grad_scale_value=0.01, grad_scale_policy=0.99, select_policy_from_plane=select_policy_from_plane)

# RISE architecture definition. A new symbol is only built when no stored symbol file is
# configured; otherwise the stored symbol is loaded and -- unlike before -- not overwritten,
# so that continuing from a checkpoint keeps the checkpoint's architecture.
bc_res_blocks = [3] * 13

# per-block settings for rise_mobile_v3 (presumably one entry per block -- verify against the architecture)
kernels = [3] * 15
for block_idx in (7, 11, 12, 13):
    kernels[block_idx] = 5

se_types = [None] * len(kernels)
for block_idx in (5, 8, 12, 13, 14):
    se_types[block_idx] = "eca_se"

if tc.symbol_file is None:
    symbol = rise_mobile_v2_symbol(channels=256, channels_operating_init=128, channel_expansion=64, channels_value_head=8,
                      channels_policy_head=NB_POLICY_MAP_CHANNELS, value_fc_size=256, bc_res_blocks=bc_res_blocks, res_blocks=[], act_type='relu',
                      n_labels=NB_LABELS, grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor, select_policy_from_plane=tc.select_policy_from_plane,
                      use_se=True, dropout_rate=tc.dropout_rate, use_extra_variant_input=use_extra_variant_input)
    # NOTE(review): as in the original cell, the rise_mobile_v2 symbol is replaced right away by the
    # rise_mobile_v3 symbol -- consider dropping one of the two definitions.
    symbol = rise_mobile_v3_symbol(channels=256, channels_operating_init=224, channel_expansion=32, act_type='relu',
                                   channels_value_head=8, value_fc_size=256,
                                   channels_policy_head=NB_POLICY_MAP_CHANNELS,
                                   grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor,
                                   dropout_rate=tc.dropout_rate, select_policy_from_plane=tc.select_policy_from_plane,
                                   kernels=kernels, se_types=se_types, use_avg_features=False)
else:
    symbol = mx.sym.load(tc.export_dir + "weights/" + tc.symbol_file)

In [None]:
# Pre-activation ResNet with SE blocks. It is only built when training starts from scratch;
# with tc.symbol_file set, the stored symbol is loaded instead of being overwritten here.
kernels = [3, 3, 3, 3, 3, 3, 5, 5]

# "eca_se" is enabled for the 4th and the 8th entry, all other entries use None
se_types = [None] * len(kernels)
se_types[3] = "eca_se"
se_types[7] = "eca_se"

if tc.symbol_file is None:
    symbol = preact_resnet_se(channels=288, act_type='relu',
                              channels_value_head=8, value_fc_size=256,
                              channels_policy_head=NB_POLICY_MAP_CHANNELS,
                              grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor,
                              dropout_rate=tc.dropout_rate, select_policy_from_plane=tc.select_policy_from_plane,
                              kernels=kernels, se_types=se_types, use_avg_features=True, use_raw_features=True)
else:
    symbol = mx.sym.load(tc.export_dir + "weights/" + tc.symbol_file)

### Convert MXNet Symbol to Gluon Network

In [None]:
# For gluon-style training, wrap the value and policy outputs of the symbol in a SymbolBlock
if symbol is not None and not tc.use_mxnet_style:
    inputs = mx.sym.var('data', dtype='float32')
    internals = symbol.get_internals()
    value_out = internals[main_config['value_output']+'_output']
    policy_out = internals[main_config['policy_output']+'_output']
    sym = mx.symbol.Group([value_out, policy_out])
    net = mx.gluon.SymbolBlock(sym, inputs)

## Network summary

In [None]:
# The gluon network object only exists for gluon-style training
if not tc.use_mxnet_style:
    print(net)

In [None]:
# Plot the network graph: use the symbol if there is one, otherwise the second output of the gluon net
graph_symbol = symbol if symbol is not None else net(mx.sym.var('data'))[1]
display(mx.viz.plot_network(
    graph_symbol,
    shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},
    node_attrs={"shape":"oval","fixedsize":"false"}
))

In [None]:
# Print a layer summary for a batch of size one; in gluon style the second output of the net is summarized
summary_symbol = symbol if tc.use_mxnet_style else net(mx.sym.var('data'))[1]
mx.viz.print_summary(
    summary_symbol,
    shape={'data':(1, input_shape[0], input_shape[1], input_shape[2])},
)

## Initialize the weights 
(only needed if no pretrained weights are used)

In [None]:
# create a trainable module on compute context
if tc.use_mxnet_style:
    model = mx.mod.Module(symbol=symbol, context=ctx, label_names=['value_label', 'policy_label'])
    model.bind(for_training=True, data_shapes=[('data', (tc.batch_size, input_shape[0], input_shape[1], input_shape[2]))],
             label_shapes=val_iter.provide_label)
    model.init_params(mx.initializer.Xavier(rnd_type='uniform', factor_type='avg', magnitude=2.24))
    if tc.params_file:
        model.load_params(tc.export_dir + "weights/" + tc.params_file)
else:    
    # Initializing the parameters
    for param in net.collect_params('.*gamma|.*moving_mean|.*moving_var'):
        net.params[param].initialize(mx.initializer.Constant(1), ctx=ctx)
    for param in net.collect_params('.*beta|.*bias'):
        net.params[param].initialize(mx.initializer.Constant(0), ctx=ctx)
    for param in net.collect_params('.*weight'):
        net.params[param].initialize(mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=2.24), ctx=ctx)

    if tc.params_file:
        net.collect_params().load(tc.export_dir + "weights/" + tc.params_file, ctx)
    net.hybridize()

## Define the metrics to use

In [None]:
# Metrics for the MXNet module API: a plain list
metrics_mxnet = [
    metric.MSE(name='value_loss', output_names=['value_output'], label_names=['value_label']),
    metric.CrossEntropy(name='policy_loss', output_names=['policy_output'], label_names=['policy_label']),
    metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'], label_names=['value_label']),
    metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'], label_names=['policy_label']),
]
# The same metrics for the gluon trainer, keyed by metric name
metrics_gluon = {
    'value_loss': metric.MSE(name='value_loss', output_names=['value_output']),
    'policy_loss': metric.CrossEntropy(name='policy_loss', output_names=['policy_output'], label_names=['policy_label']),
    'value_acc_sign': metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'], label_names=['value_label']),
    'policy_acc': metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'], label_names=['policy_label']),
}
# hand the matching collection to the training objects
to.metrics = metrics_mxnet if tc.use_mxnet_style else metrics_gluon

## Define a training agent

In [None]:
# Create the trainer that matches the selected training style
if tc.use_mxnet_style:
    train_agent = TrainerAgentMXNET(model, symbol, val_iter, tc, to, use_rtpt=True)
else:
    train_agent = TrainerAgent(net, val_data, tc, to, use_rtpt=True)

## Performance Pre-Training

In [None]:
# Validation metrics of the model before any training step (mxnet style only)
if tc.use_mxnet_style:
    print(model.score(val_iter, to.metrics))

## Start the training process

In [None]:
# Train starting at iteration cur_it; the result holds the metrics of the final state
# and the step count / validation metrics of the best state
(k_steps_final, value_loss_final, policy_loss_final, value_acc_sign_final, val_p_acc_final), \
    (k_steps_best, val_metric_values_best) = train_agent.train(cur_it)

## Export the last model state

In [None]:
# File prefix of the final model: weights/model-<policy loss>-<policy accuracy>
prefix = tc.export_dir + "weights/model-%.5f-%.3f" % (policy_loss_final, val_p_acc_final)

if tc.use_mxnet_style:
    # the export function saves both the architecture and the weights
    model.save_checkpoint(prefix, epoch=k_steps_final)
else:
    # the export function saves both the architecture and the weights
    net.export(prefix, epoch=k_steps_final)
    logging.info("Saved checkpoint to %s-%04d.params", prefix, k_steps_final)

## Print validation metrics for best model

In [None]:
# Validation metrics of the best model state as returned by the training run
print(val_metric_values_best)

## Load the best model once again

In [None]:
# delete the current net object form memory
if not tc.use_mxnet_style:
    del net
    del model

In [None]:
# Rebuild the file names of the best checkpoint from its validation loss, policy accuracy and step count
# NOTE(review): this assumes the trainer stores its checkpoints under exactly this naming scheme -- confirm
val_loss_best = val_metric_values_best["loss"]
val_p_acc_best = val_metric_values_best["policy_acc"]

model_name = "model-%.5f-%.3f" % (val_loss_best, val_p_acc_best)
model_prefix = tc.export_dir + "weights/" + model_name
model_arch_path = '%s-symbol.json' % model_prefix
model_params_path = '%s-%04d.params' % (model_prefix, k_steps_best)
print('load current best model:', model_params_path)

# load the architecture, expose the value and policy outputs as a gluon block and load the weights
symbol = mx.sym.load(model_arch_path)
inputs = mx.sym.var('data', dtype='float32')
value_out = symbol.get_internals()[main_config['value_output']+'_output']
policy_out = symbol.get_internals()[main_config['policy_output']+'_output']
sym = mx.symbol.Group([value_out, policy_out])
net = mx.gluon.SymbolBlock(sym, inputs)
net.collect_params().load(model_params_path, ctx)

In [None]:
# One-line summary of the best checkpoint
print('best val_loss: %.5f with v_policy_acc: %.5f at k_steps_best %d' % (val_loss_best, val_p_acc_best, k_steps_best))

## Copy best model & convert to onnx

In [None]:
# Copy the best checkpoint into <export_dir>/best-model and convert it to ONNX.
# exist_ok=True makes the cell safe to re-run and replaces the separate exists() check.
os.makedirs(tc.export_dir + "best-model", exist_ok=True)

best_model_prefix = tc.export_dir + "best-model/" + model_name
best_model_arch_path = '%s-symbol.json' % best_model_prefix
best_model_params_path = '%s-%04d.params' % (best_model_prefix, k_steps_best)

shutil.copy(model_arch_path, best_model_arch_path)
shutil.copy(model_params_path, best_model_params_path)

# NOTE(review): the meaning of the last two arguments ((1, 8, 16) and True) is defined by
# convert_mxnet_model_to_onnx -- check its signature before changing them
convert_mxnet_model_to_onnx(best_model_arch_path, best_model_params_path,
                            ["value_out_output", "policy_out_output"],
                            tuple(input_shape), tuple([1, 8, 16]), True)

In [None]:
# Report where the exported files of the best model were written
print("Saved json, weight & onnx files of the best model to %s" % (tc.export_dir + "best-model"))

## Show move predictions

In [None]:
# index of the validation sample inspected in the following cells
idx = 0

In [None]:
# Convert the selected input planes back into a board and show side to move (and pockets for crazyhouse)
board = planes_to_board(x_val[idx], normalized_input=tc.normalize, mode=mode)

print(chess.COLOR_NAMES[board.turn])
if board.uci_variant == "crazyhouse":
    print(board.pockets)
board

In [None]:
def predict_single(net, x, select_policy_from_plane=False):
    """
    Run the network on a single position and return its two outputs as numpy arrays.

    :param net: Network object; it is called with a batch that contains only x
    :param x: Input planes of one position (without batch axis)
    :param select_policy_from_plane: If True, the policy output is indexed with FLAT_PLANE_IDX before the softmax
    :return: List [value, policy]; the policy entry has been passed through a softmax
    """
    batch = mx.nd.array(np.expand_dims(x, axis=0), ctx=ctx)
    outputs = net(batch)
    value_out, policy_out = outputs[0], outputs[1]

    if select_policy_from_plane:
        policy_out = policy_out[:, FLAT_PLANE_IDX]
    policy_out = policy_out.softmax()

    return [value_out.asnumpy(), policy_out.asnumpy()]

In [None]:
# predict the sample selected by idx, i.e. the same position that is shown as board
pred = predict_single(net, x_val[idx], tc.select_policy_from_plane)
pred

In [None]:
# same prediction without displaying it; uses idx so that it matches the board and label cells
pred = predict_single(net, x_val[idx], tc.select_policy_from_plane)

In [None]:
# best move according to the policy label of the selected sample
policy_to_best_move(board, yp_val[idx])

In [None]:
# number of candidate moves to show
opts = 5
# convert the predicted policy into moves and their probabilities for this board
selected_moves, probs = policy_to_moves(board, pred[1][0])
selected_moves[:opts]

In [None]:
# horizontal bar chart of the first opts moves; the reversed range puts the first move at the top
plt.barh(range(opts)[::-1], probs[:opts])
ax = plt.gca()
ax.set_yticks(range(opts)[::-1])
ax.set_yticklabels(selected_moves[:opts])

In [None]:
# Build a test position by playing five moves from the first validation sample.
# normalized_input follows tc.normalize because x_val was loaded with normalize=tc.normalize.
# NOTE(review): the pushed moves assume that x_val[0] is a starting position -- confirm
board = planes_to_board(x_val[0], normalized_input=tc.normalize, mode=mode)
for uci_move in ('e2e4', 'e7e5', 'f1c4', 'b8c6', 'd1h5'):
    board.push_uci(uci_move)
x_scholar_atck = board_to_planes(board, normalize=tc.normalize, mode=mode)
board

In [None]:
# Predict the constructed position and plot the first opts candidate moves
pred = predict_single(net, x_scholar_atck, tc.select_policy_from_plane)

selected_moves, probs = policy_to_moves(board, pred[1][0])
plt.barh(range(opts)[::-1], probs[:opts])
ax = plt.gca()
ax.set_yticks(range(opts)[::-1])
ax.set_yticklabels(selected_moves[:opts])

In [None]:
# play the first candidate move on the board and show the result
board.push(selected_moves[0])
board

### Performance on test dataset


In [None]:
# Load part 0 of the test set. The input normalization follows tc.normalize so that the
# test data is pre-processed in the same way as the validation and mate-in-one data.
s_idcs_test, x_test, yv_test, yp_test, _, pgn_datasets_test = load_pgn_dataset(dataset_type='test', part_id=0,
                                                                               verbose=True, normalize=tc.normalize)
yp_test = prepare_policy(y_policy=yp_test, select_policy_from_plane=tc.select_policy_from_plane,
                          sparse_policy_label=False,
                          is_policy_from_plane_data=tc.is_policy_from_plane_data)
# the policy label is reduced to the arg-max index for the evaluation
test_dataset = gluon.data.ArrayDataset(nd.array(x_test), nd.array(yv_test), nd.array(yp_test.argmax(axis=1)))
test_data = gluon.data.DataLoader(test_dataset, batch_size=tc.batch_size, shuffle=True, num_workers=tc.cpu_count)

In [None]:
# Evaluate the gluon metrics on the whole test set (nb_batches=None)
metrics = metrics_gluon
evaluate_metrics(metrics, test_data, net, nb_batches=None, sparse_policy_label=True, ctx=ctx,
                 apply_select_policy_from_plane=tc.select_policy_from_plane)

### Show result on mate-in-one problems

In [None]:
# Load part 0 of the mate-in-one set; yp_mate keeps the original policy, yp_mate_new the prepared one
s_idcs_mate, x_mate, yv_mate, yp_mate, _, pgn_dataset_mate = load_pgn_dataset(dataset_type='mate_in_one', part_id=0,
                                                                              verbose=True, normalize=tc.normalize)
yp_mate_new = prepare_policy(y_policy=yp_mate, select_policy_from_plane=tc.select_policy_from_plane,
                          sparse_policy_label=False,
                          is_policy_from_plane_data=tc.is_policy_from_plane_data)

In [None]:
# Wrap the mate-in-one data in a gluon data loader; the policy label is the arg-max index
mate_dataset = mx.gluon.data.dataset.ArrayDataset(nd.array(x_mate), nd.array(yv_mate), nd.array(yp_mate_new.argmax(axis=1)))
mate_data = mx.gluon.data.DataLoader(mate_dataset, batch_size=tc.batch_size, num_workers=tc.cpu_count)

### Mate In One Performance

In [None]:
# Evaluate the gluon metrics on the mate-in-one set
metrics = metrics_gluon
evaluate_metrics(metrics, mate_data, net, sparse_policy_label=True, ctx=ctx,
                 apply_select_policy_from_plane=tc.select_policy_from_plane)

### Show some example mate problems

In [None]:
# Let the notebook display every top-level expression of a cell instead of only the last one
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

### Evaluate Performance

In [None]:
def eval_pos(net, x_mate, yp_mate, verbose=False, select_policy_from_plane=False):
    """
    Evaluate the move prediction of the network for a single position.

    :param net: Network which is passed on to predict_single()
    :param x_mate: Input planes of the position
    :param yp_mate: Policy label of the position; it is converted into the true move
    :param verbose: If True, the side to move (and the pockets for crazyhouse) are printed and
     the candidate moves are plotted
    :param select_policy_from_plane: Forwarded to predict_single()
    :return: Tuple (pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt)
     pred_moves holds at most 5 candidate moves; board is returned with the first candidate move
     already pushed; both counts refer to the position before that move
    """
    board = planes_to_board(x_mate, normalized_input=tc.normalize, mode=mode)
    if verbose:
        print("{0}'s turn".format(chess.COLOR_NAMES[board.turn]))
        if board.uci_variant == "crazyhouse":
            print("black/white {0}".format(board.pockets))
    pred = predict_single(net, x_mate, select_policy_from_plane=select_policy_from_plane)

    true_move = policy_to_move(yp_mate, is_white_to_move=board.turn)

    opts = 5
    pred_moves, probs = policy_to_moves(board, pred[1][0])
    pred_moves = pred_moves[:opts]
    # probabilities that belong to pred_moves (same pairing as in the plotting cells of this notebook)
    top_probs = probs[:opts]

    legal_move_cnt = board.legal_moves.count()
    # mating moves carry a '#' in the string representation of the legal moves
    mate_move_cnt = str(board.legal_moves).count('#')

    # check whether any of the candidate moves delivers mate
    is_mate_5_top = False
    for pred_move in pred_moves:
        board_5_top = deepcopy(board)
        board_5_top.push(pred_move)
        if board_5_top.is_checkmate():
            is_mate_5_top = True
            break

    # play the first candidate move on the returned board
    board.push(pred_moves[0])
    is_checkmate = bool(board.is_checkmate())

    if verbose:
        # use the actual number of candidates: a position can have fewer than opts moves
        bar_positions = range(len(pred_moves))[::-1]
        plt.barh(bar_positions, top_probs)
        ax = plt.gca()
        ax.set_yticks(bar_positions)
        ax.set_yticklabels(pred_moves)
        plt.title('True Move:' + str(true_move) +
                 '\nEval:' + str(pred[0][0]))
        plt.show()

    return pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt

In [None]:
# Evaluate every mate-in-one position and collect the per-position results
nb_pos = len(x_mate)
mates_found = []  # first candidate move gives mate
mates_5_top_found = []  # one of the candidate moves gives mate
legal_mv_cnts = []  # number of legal moves per position
mate_mv_cnts = []  # number of mating moves per position

for i in range(nb_pos):
    pred, pred_moves, true_move, board, is_mate, is_mate_5_top, legal_mv_cnt, mate_mv_cnt= eval_pos(net, x_mate[i], yp_mate[i], select_policy_from_plane=tc.select_policy_from_plane)
    mates_found.append(is_mate)
    legal_mv_cnts.append(legal_mv_cnt)
    mate_mv_cnts.append(mate_mv_cnt)
    mates_5_top_found.append(is_mate_5_top)

In [None]:
# average number of mating moves per position
np.array(mate_mv_cnts).mean()

In [None]:
# average number of legal moves per position
np.array(legal_mv_cnts).mean()

### Random Guessing Baseline

In [None]:
# ratio of the average number of mating moves to the average number of legal moves
np.array(mate_mv_cnts).mean() / np.array(legal_mv_cnts).mean()

### Prediction Performance

In [None]:
# share of positions in which the first candidate move gives mate
print('mate_in_one_acc:', sum(mates_found) / nb_pos)

In [None]:
# share of positions in which one of the candidate moves gives mate
sum(mates_5_top_found) / nb_pos

In [None]:
# show the structure of the mate-in-one dataset
pgn_dataset_mate.tree()

In [None]:
# load the metadata table of the dataset and show its first two rows
metadata = np.array(pgn_dataset_mate['metadata'])
metadata[0, :]
metadata[1, :]

In [None]:
# column 1 of all rows except row 0
# NOTE(review): presumably row 0 is the header and column 1 the site of the game -- verify against the dataset
site_mate = metadata[1:, 1]

In [None]:
def clean_string(np_string):
    """
    Return the text of the given value without the bytes prefix and without quote characters.

    The value passed in is used (previously the function ignored its argument and read
    the global site_mate[i] instead).

    :param np_string: Value to clean, e.g. a bytes entry of the metadata array
    :return: str(np_string) with every "b'", "'" and '"' removed
    """
    string = str(np_string).replace("b'", "")
    string = string.replace("'", "")
    string = string.replace('"', '')

    return string

In [None]:
import chess.svg
from IPython.display import SVG, HTML

## Show the result of the first 17 examples

In [None]:
# Show the first 17 positions (or fewer if the dataset is smaller).
# display() is needed because an expression inside a loop body is not rendered by the notebook.
for i in range(min(17, len(x_mate))):
    print(clean_string(site_mate[i]))
    pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt = eval_pos(net, x_mate[i], yp_mate[i], verbose=True, select_policy_from_plane=tc.select_policy_from_plane)
    pred_move = pred_moves[0]
    # arrow for the first candidate move, drawn on the board returned by eval_pos
    pred_arrow = chess.svg.Arrow(pred_move.from_square, pred_move.to_square)
    display(SVG(data=chess.svg.board(board=board, arrows=[pred_arrow], size=400)))

## Show examples where it failed

In [None]:
# Show up to 15 positions among the first 1000 (or fewer if the dataset is smaller)
# in which none of the candidate moves gives mate.
mate_missed = 0
for i in range(min(1000, len(x_mate))):
    pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt = eval_pos(net, x_mate[i], yp_mate[i], verbose=False, select_policy_from_plane=tc.select_policy_from_plane)
    if not is_mate_5_top:
        mate_missed += 1
        print(clean_string(site_mate[i]))
        # evaluate once more with verbose=True to print and plot the details of the missed position
        pred, pred_moves, true_move, board, is_checkmate, is_mate_5_top, legal_move_cnt, mate_move_cnt = eval_pos(net, x_mate[i], yp_mate[i], verbose=True, select_policy_from_plane=tc.select_policy_from_plane)
        pred_move = pred_moves[0]
        pred_arrow = chess.svg.Arrow(pred_move.from_square, pred_move.to_square)
        # display() is needed because an expression inside a loop body is not rendered by the notebook
        display(SVG(data=chess.svg.board(board=board, arrows=[pred_arrow], size=400)))
    if mate_missed == 15:
        break