In [None]:
import sys
sys.path.append('../../../')

import glob
import logging
import numpy as np
import mxnet as mx
from mxnet import gluon

from DeepCrazyhouse.configs.main_config import main_config
from DeepCrazyhouse.configs.train_config import TrainConfig, TrainObjects
from DeepCrazyhouse.src.runtime.color_logger import enable_color_logging
from DeepCrazyhouse.src.domain.variants.constants import NB_POLICY_MAP_CHANNELS, NB_LABELS
from DeepCrazyhouse.src.domain.variants.plane_policy_representation import FLAT_PLANE_IDX

from DeepCrazyhouse.src.preprocessing.dataset_loader import load_xiangqi_dataset

from DeepCrazyhouse.src.training.lr_schedules.lr_schedules import *

from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_v2 import rise_mobile_v2_symbol
from DeepCrazyhouse.src.domain.neural_net.architectures.rise_mobile_v3 import rise_mobile_v3_symbol

from DeepCrazyhouse.src.training.trainer_agent import TrainerAgent, evaluate_metrics, acc_sign, reset_metrics
from DeepCrazyhouse.src.training.trainer_agent_mxnet import TrainerAgentMXNET, get_context

# Switch on the colored console logging provided by the DeepCrazyhouse runtime helpers.
enable_color_logging()

# Report which MXNet build this notebook runs against.
mxnet_version = mx.__version__
print("mxnet version: ", mxnet_version)

# Main Configuration

In [None]:
# Print every entry of the main configuration, one "key = value" line each.
for config_key, config_value in main_config.items():
    print(config_key, "= ", config_value)

# Settings for training

In [None]:
# tc holds the training hyper-parameters; to collects training objects
# (lr/momentum schedules and metrics are attached to it further below).
tc, to = TrainConfig(), TrainObjects()

In [None]:
# Setting the context to GPU is strongly recommended
tc.context = "gpu"  # Be sure to check the used devices!!!
tc.device_id = 0

# Used for reproducibility
tc.seed = 7

tc.export_weights = True
tc.log_metrics_to_tensorboard = True
tc.export_grad_histograms = True

# div_factor is a constant which divides both the batch size and the learning rate
# use a value larger than 1 if you encounter memory allocation errors
tc.div_factor = 2

# defines how often a new checkpoint will be saved and the metrics evaluated
# (batch_steps = 1000 means that every 1000 batches the validation set gets processed)
# scaled by div_factor so that the number of samples between two evaluations stays constant
tc.batch_steps = 100 * tc.div_factor
# k_steps_initial defines how many steps have been trained before
# (k_steps_initial != 0 if you continue training from a checkpoint)
tc.k_steps_initial = 0

# these are the weights to continue training with (file names relative to ./weights)
tc.symbol_file = None  # 'model-0.81901-0.713-symbol.json'
tc.params_file = None  # 'model-0.81901-0.713-0498.params'

# the learning rate is scaled linearly with the batch size:
# halving the batch size (div_factor = 2) also halves max_lr below
tc.batch_size = int(1024 / tc.div_factor)

# optimization parameters
tc.optimizer_name = "nag"
tc.max_lr = 0.35 / tc.div_factor
tc.min_lr = 0.00001
tc.max_momentum = 0.95
tc.min_momentum = 0.8
# loads a previous checkpoint if the loss increased significantly
tc.use_spike_recovery = True
# stop training as soon as max_spikes has been reached
tc.max_spikes = 20
# define spike threshold when the detection will be triggered
tc.spike_thresh = 1.5
# weight decay
tc.wd = 1e-4
tc.dropout_rate = 0  # 0.15
# weight the value loss a lot lower than the policy loss in order to prevent overfitting
tc.val_loss_factor = 0.01
tc.policy_loss_factor = 0.99
tc.discount = 1.0

tc.normalize = True
tc.nb_epochs = 7
# Boolean if potential legal moves will be selected from final policy output
tc.select_policy_from_plane = True

# define if the policy training target is one-hot encoded (sparse) or a distribution
# (e.g. mcts samples, knowledge distillation)
tc.sparse_policy_label = True
# define if the policy data is also defined in "select_policy_from_plane" representation
tc.is_policy_from_plane_data = False

# Decide between mxnet and gluon style for training
tc.use_mxnet_style = True

# additional custom validation set files which will be logged to tensorboard
tc.variant_metrics = None  # ["chess960", "koth", "three_check"]

tc.name_initials = "Your Initials"

# enable data set augmentation
augment = True

In [None]:
mode = main_config["mode"]
ctx = get_context(tc.context, tc.device_id)

# if use_extra_variant_input is true the current active variant is passed to each residual block and
# concatenated at the end of the final feature representation
use_extra_variant_input = False

# iteration counter used for the momentum and learning rate schedule
# (non-zero when k_steps_initial != 0, i.e. when training continues from a checkpoint)
cur_it = tc.k_steps_initial * tc.batch_steps

# Fix the random seeds: MXNet's RNG and NumPy's global RNG, so that any NumPy-based
# randomness in this notebook or the helpers it calls is reproducible as well
mx.random.seed(tc.seed)
np.random.seed(tc.seed)

### Create the ./logs and ./weights directories

In [None]:
# Create the output directories for logs and exported weights.
# Unlike `mkdir ./logs && mkdir ./weights`, this does not fail when a directory already
# exists, so the cell is safe to re-run and ./weights is created even if ./logs is present.
import os

for output_dir in ("./logs", "./weights"):
    os.makedirs(output_dir, exist_ok=True)

# Load Datasets

### Validation set

In [None]:
# Set combined = True to concatenate validation parts 0 and 1; otherwise only part 0 is used.
combined = True
val_part_ids = [0, 1] if combined else [0]

# load_xiangqi_dataset returns (start_indices, x, y_value, y_policy, dataset) for one part
val_parts = [load_xiangqi_dataset(dataset_type="val",
                                  part_id=part_id,
                                  verbose=True,
                                  normalize=tc.normalize)
             for part_id in val_part_ids]

# Stack all parts along the sample axis. np.concatenate keeps the dtype of the loaded arrays
# instead of up-casting everything to float64 (as pre-allocating with np.zeros did).
x_val = np.concatenate([part[1] for part in val_parts], axis=0)
y_value_val = np.concatenate([part[2] for part in val_parts], axis=0)
y_policy_val = np.concatenate([part[3] for part in val_parts], axis=0)
# the per-part arrays are no longer needed once they have been concatenated
del val_parts

if tc.normalize:
    assert x_val.max() <= 1.0, "Error: Normalization not working."

# the iterator is fed class indices, taken as the argmax over the policy target vectors
policy_label_val = y_policy_val.argmax(axis=1)
if tc.select_policy_from_plane:
    # map each flat move index through FLAT_PLANE_IDX (plane-based policy representation)
    policy_label_val = np.array(FLAT_PLANE_IDX)[policy_label_val]

val_iter = mx.io.NDArrayIter({'data': x_val},
                             {'value_label': y_value_val,
                              'policy_label': policy_label_val},
                             tc.batch_size)

In [None]:
# Report the shapes of the validation inputs and both target arrays.
shape_report = [("x_val.shape: ", x_val.shape),
                ("y_value_val.shape: ", y_value_val.shape),
                ("y_policy_val.shape: ", y_policy_val.shape)]
for shape_label, array_shape in shape_report:
    print(shape_label, array_shape)

### Training properties

In [None]:
# number of validation samples (shown as the cell output)
x_val.shape[0]

In [None]:
train_parts_pattern = main_config["planes_train_dir"] + "**/*"
# NOTE(review): without recursive=True, glob treats "**" like a plain "*", so this counts the
# entries matched by planes_train_dir + "*/*" (one directory level down) — confirm this matches
# the expected dataset layout.
tc.nb_parts = len(glob.glob(train_parts_pattern))
print("Parts training dataset: ", tc.nb_parts)

In [None]:
# one iteration is defined by passing 1 batch and doing backpropagation
# one iteration is defined by passing 1 batch and doing backpropagation;
# augmentation is counted as doubling the number of samples per part
augmentation_factor = 2 if augment else 1
# NOTE(review): the estimate assumes every training part holds about len(x_val) samples;
# when several validation parts are combined, x_val spans more than one part — confirm
# this does not over-estimate the epoch length.
nb_it_per_epoch = (len(x_val) * tc.nb_parts * augmentation_factor) // tc.batch_size
tc.total_it = int(nb_it_per_epoch * tc.nb_epochs)
print("Total iterations: ", tc.total_it)

# Learning Rate Schedule

In [None]:
# One-cycle schedule configured with start_lr = max_lr / 8, max_lr and finish_lr = min_lr;
# cycle and cooldown lengths are 30 % and 60 % of all training iterations.
one_cycle_schedule = OneCycleSchedule(start_lr=tc.max_lr / 8,
                                      max_lr=tc.max_lr,
                                      cycle_length=tc.total_it * .3,
                                      cooldown_length=tc.total_it * .6,
                                      finish_lr=tc.min_lr)
# Wrap it in a LinearWarmUp starting at min_lr over the first total_it / 30 iterations.
to.lr_schedule = LinearWarmUp(one_cycle_schedule, start_lr=tc.min_lr, length=tc.total_it / 30)
plot_schedule(to.lr_schedule, iterations=tc.total_it)

# Momentum Schedule

In [None]:
# Momentum schedule built from the lr schedule together with the lr and momentum bounds.
to.momentum_schedule = MomentumSchedule(to.lr_schedule,
                                        tc.min_lr, tc.max_lr,
                                        tc.min_momentum, tc.max_momentum)
plot_schedule(to.momentum_schedule, iterations=tc.total_it, ylabel='Momentum')

# Define NN model / Load pretrained model

In [None]:
# shape of one validation sample: x_val without its leading sample axis
input_shape = x_val.shape[1:]
print("Input shape: ", input_shape)

In [None]:
# bottleneck residual block configuration passed to rise_mobile_v2_symbol
# (the commented-out alternative in the next cell uses [3] * 13)
bc_res_blocks = [3] * 5
if tc.symbol_file is None:
    # build a fresh network from scratch
    # channels_operating_init, channel_expansion
    symbol = rise_mobile_v2_symbol(channels=256, channels_operating_init=512, channel_expansion=0, channels_value_head=8,
                      channels_policy_head=NB_POLICY_MAP_CHANNELS, value_fc_size=256, bc_res_blocks=bc_res_blocks, res_blocks=[], act_type='relu',
                      n_labels=NB_LABELS, grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor, select_policy_from_plane=tc.select_policy_from_plane,
                      use_se=True, dropout_rate=tc.dropout_rate, use_extra_variant_input=use_extra_variant_input)
else:
    # continue training: load the saved network definition from ./weights
    symbol = mx.sym.load("weights/" + tc.symbol_file)

In [None]:
"""
bc_res_blocks = [3] * 13 
if tc.symbol_file is None:
    # channels_operating_init, channel_expansion
    symbol = rise_mobile_v2_symbol(channels=256, channels_operating_init=128, channel_expansion=64, channels_value_head=8,
                      channels_policy_head=NB_POLICY_MAP_CHANNELS, value_fc_size=256, bc_res_blocks=bc_res_blocks, res_blocks=[], act_type='relu',
                      n_labels=NB_LABELS, grad_scale_value=tc.val_loss_factor, grad_scale_policy=tc.policy_loss_factor, select_policy_from_plane=tc.select_policy_from_plane,
                      use_se=True, dropout_rate=tc.dropout_rate, use_extra_variant_input=use_extra_variant_input)
else:
    symbol = mx.sym.load("weights/" + symbol_file)
"""

# Network summary

In [None]:
# Render the computation graph for a single input sample (batch size 1).
nb_planes_in, nb_rows_in, nb_cols_in = input_shape[0], input_shape[1], input_shape[2]
network_graph = mx.viz.plot_network(
    symbol,
    shape={'data': (1, nb_planes_in, nb_rows_in, nb_cols_in)},
    node_attrs={"shape": "oval", "fixedsize": "false"},
)
display(network_graph)

In [None]:
# Layer-by-layer summary for a single input sample (batch size 1).
single_sample_shape = (1, input_shape[0], input_shape[1], input_shape[2])
mx.viz.print_summary(symbol, shape={'data': single_sample_shape})

# Initialize weights (pretrained parameters are loaded on top if `tc.params_file` is set)

In [None]:
# create a trainable module on compute context
model = mx.mod.Module(symbol=symbol, context=ctx, label_names=['value_label', 'policy_label'])
model.bind(for_training=True, data_shapes=[('data', (tc.batch_size, input_shape[0], input_shape[1], input_shape[2]))],
          label_shapes=val_iter.provide_label)
model.init_params(mx.initializer.Xavier(rnd_type='uniform', factor_type='avg', magnitude=2.24))
if tc.params_file:
    model.load_params("weights/" + tc.params_file)    

# Metrics

In [None]:
def create_metrics():
    """Return a fresh dict of evaluation metrics keyed by metric name.

    Every call builds new metric instances, so the list and dict forms created
    below do not share accumulated state. Each metric is bound to the network
    output and label it reads via ``output_names`` / ``label_names``.
    """
    return {
        'value_loss': mx.metric.MSE(name='value_loss', output_names=['value_output'],
                                    label_names=['value_label']),
        'policy_loss': mx.metric.CrossEntropy(name='policy_loss', output_names=['policy_output'],
                                              label_names=['policy_label']),
        'value_acc_sign': mx.metric.create(acc_sign, name='value_acc_sign', output_names=['value_output'],
                                           label_names=['value_label']),
        'policy_acc': mx.metric.Accuracy(axis=1, name='policy_acc', output_names=['policy_output'],
                                         label_names=['policy_label']),
    }


# list form (same order as the dict) is handed to the trainer via to.metrics;
# the dict form keyed by metric name is kept for gluon-style training code
metrics_mxnet = list(create_metrics().values())
metrics_gluon = create_metrics()
to.metrics = metrics_mxnet

# Training Agent

In [None]:
# Trainer wrapping the bound module, its symbol, the validation iterator and both config objects.
train_agent = TrainerAgentMXNET(model, symbol, val_iter, tc, to,
                                use_rtpt=False,
                                augment=augment)

# Performance before training

In [None]:
# Validation metrics of the initialized (or loaded) model before any training step.
pre_training_scores = model.score(val_iter, to.metrics)
print(pre_training_scores)

# Training

In [None]:
# Run training starting at iteration cur_it; train() returns (final stats, best stats).
final_stats, best_stats = train_agent.train(cur_it)
k_steps_final, value_loss_final, policy_loss_final, value_acc_sign_final, val_p_acc_final = final_stats
k_steps_best, val_loss_best, val_p_acc_best = best_stats