In [None]:
import sys
import mxnet as mx
sys.path.append('../../../')
from DeepCrazyhouse.src.preprocessing.dataset_loader import load_pgn_dataset
from DeepCrazyhouse.src.runtime.color_logger import enable_color_logging
from DeepCrazyhouse.src.training.crossentropy import *

import DeepCrazyhouse.src.domain.variants.constants as constants
import DeepCrazyhouse.src.domain.variants.plane_policy_representation
import DeepCrazyhouse.src.domain.variants.plane_policy_representation as plane_representation
import zarr
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Switch on the project's colored log formatting (runtime.color_logger) before any data is loaded.
enable_color_logging()

In [None]:
# Load part 0 of the "train" split through the project's PGN dataset loader (normalize=True).
# Only the first four returned items are kept; the last two are discarded.
start_indices, x, y_value, y_policy, _, _ = load_pgn_dataset(
    dataset_type="train",
    part_id=0,
    normalize=True,
    verbose=True,
)

### Alternatively, load the data from its uncompressed `.zarr` form (this replaces the arrays loaded above)

# Machine-specific location of the exported data; adjust before running elsewhere.
ZARR_DATA_PATH = '/media/queensgambit/Volume/Deep_Learning/projects/CrazyAra/engine/src/rl/data_gpu_0.zarr'

# NOTE: this overwrites start_indices, x, y_policy and y_value from the PGN loader cell above.
data = zarr.load(ZARR_DATA_PATH)
if data is None:
    # zarr.load returns None instead of raising when nothing is stored at the given path,
    # which would otherwise surface as an opaque "'NoneType' object is not subscriptable".
    raise FileNotFoundError(f"No zarr array or group found at {ZARR_DATA_PATH}")
start_indices = data['start_indices']
x = data['x']
y_policy = data['y_policy']
y_value = data['y_value']
y_best_move_q = data['y_best_move_q']

# The cells below index x, y_policy, y_value and y_best_move_q with the same sample index.
assert len(x) == len(y_policy) == len(y_value) == len(y_best_move_q), "per-sample arrays differ in length"

In [None]:
# Inspect start_indices (presumably the sample index at which each game starts — TODO confirm).
start_indices

### Check the policy targets for NaN values

In [None]:
# Samples (rows of y_policy) containing at least one NaN. np.isnan(...).argmax() is not
# usable here: it returns 0 both when sample 0 has a NaN and when there is no NaN at all.
nan_rows = np.flatnonzero(np.isnan(y_policy).any(axis=1))
# First affected sample index, or None when the policy targets are NaN-free.
nan_idx = int(nan_rows[0]) if nan_rows.size else None
nan_idx, len(nan_rows)

In [None]:
# The weights directory `path` is defined in the model-loading section below, right before its first use.

## Load a Gluon model from an MXNet checkpoint

In [None]:
def load_gluon_net_from_mxnet_checkpoint(symbol_file: str, param_file: str, ctx, input_name='data',
                                         value_output_name='value_tanh0_output',
                                         policy_output_name='flatten0_output'):
    """
    Loads a gluon style net based on a checkpoint which was initially created in MXNet symbol format.
    The alternative standard gluon.nn.SymbolBlock.imports() returns an error that the parameter for the labels are missing.

    The net is cut at two internal outputs of the stored symbol: the value head output and the policy
    head output, to which a softmax (named 'policy') is appended.

    :param symbol_file: Filename of the symbol architecture to load
    :param param_file: Filename of the parameter weights
    :param ctx: Context to use
    :param input_name: Input name for the data of the model
    :param value_output_name: Name of the internal symbol output used as the value output
                              (other architectures may name it e.g. 'value_output')
    :param policy_output_name: Name of the internal symbol output fed into the policy softmax
                               (other architectures may name it e.g. 'policy_output')
    :return: mx.gluon.SymbolBlock with the outputs [value, softmax(policy)] and its parameters loaded on ctx
    """
    internals = mx.sym.load(symbol_file).get_internals()
    inputs = mx.sym.var(input_name, dtype='float32')
    value_out = internals[value_output_name]
    # Append a softmax so the net returns a normalized policy instead of the raw policy-head output.
    policy_out = mx.sym.SoftmaxActivation(data=internals[policy_output_name], name='policy')
    net = mx.gluon.SymbolBlock(mx.symbol.Group([value_out, policy_out]), inputs)
    # Load only parameters matching the weight/bias/gamma/beta/mean/var pattern; presumably this
    # excludes "*_label" inputs, which have no stored values in the checkpoint.
    net.collect_params('.*weight|.*bias|.*gamma|.*beta|.*mean|.*var').load(param_file, ctx)
    return net

In [None]:
# Directory holding the MXNet checkpoint files loaded below (machine-specific; adjust before running).
path = "/media/queensgambit/Volume/Deep_Learning/data/RL/weights/"

In [None]:
# Architecture (symbol JSON) and weights (epoch-8 .params) of the checkpoint inside `path`.
model_arch_path, model_params_path = (
    f"{path}model-1.29188-1.000-symbol.json",
    f"{path}model-1.29188-1.000-0008.params",
)

In [None]:
# Use the first GPU when one is present, otherwise fall back to the CPU instead of failing;
# the net is only exported afterwards, so either context works.
model_ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
net = load_gluon_net_from_mxnet_checkpoint(model_arch_path, model_params_path, model_ctx)

In [None]:
# Save the net (value output + softmax policy) under this prefix with epoch 8, i.e. as
# "<prefix>-symbol.json" and "<prefix>-0008.params".
# NOTE: the prefix is relative, so the files are written to the current working directory, not to `path`.
net.export('model-1.29188-1.000-softmax', epoch=8)

## Inspect training samples — careful: the policy's `argmax()` is not necessarily the move that was chosen in the game

In [None]:
# Show the board planes, the two most likely moves of the policy target and the value target
# for a few consecutive samples.
idx = 0  # first sample to inspect (also used by the value plot at the end)
n_samples = 15
for i in range(idx, idx + n_samples):
    plt.imshow(plane_representation.get_plane_vis(x[i]), cmap='coolwarm_r')
    plt.show()

    # Two largest policy entries, largest first. A stable sort keeps the lower index on ties
    # (like argmax) and, unlike zeroing the maximum in place, never modifies y_policy.
    max_idx, second_max_idx = np.argsort(-y_policy[i], kind='stable')[:2]

    # NOTE(review): assumes even samples are decoded with LABELS and odd ones with the mirrored
    # table; this parity can be off after a game boundary (see start_indices) — TODO confirm.
    labels = constants.LABELS if i % 2 == 0 else constants.LABELS_MIRRORED
    print(labels[max_idx], y_policy[i][max_idx])
    print(labels[second_max_idx], y_policy[i][second_max_idx])

    print(y_value[i])

In [None]:
# Work on a copy: y_policy[1] is a view, so zeroing entries in the next cell would otherwise
# silently modify the policy targets stored in y_policy.
second_max_policy = y_policy[1].copy()

In [None]:
# Rebuild the copy and zero its largest entry so that argmax() then yields the second most likely
# move. Re-copying keeps the cell idempotent and leaves y_policy itself untouched.
second_max_policy = y_policy[1].copy()
second_max_policy[second_max_policy.argmax()] = 0

In [None]:
# Move label of the second largest policy entry of sample 1; sample 1 has an odd index,
# which this notebook decodes with the mirrored label table.
constants.LABELS_MIRRORED[second_max_policy.argmax()]

In [None]:
# Full policy target vector of sample 1.
fig, ax = plt.subplots()
ax.plot(y_policy[1])
ax.set(title="Policy target of sample 1", xlabel="policy index", ylabel="policy target");

In [None]:
# Value targets of the first 14 samples (plotted in the cells below).
y_value[:14]

In [None]:
# y_best_move_q of the first 14 samples, for comparison with the value targets below.
fig, ax = plt.subplots()
ax.plot(y_best_move_q[:14], 'o-')
ax.set(title="y_best_move_q of the first 14 samples", xlabel="sample index", ylabel="y_best_move_q");

In [None]:
# Value targets of the first 14 samples.
fig, ax = plt.subplots()
ax.plot(y_value[:14], 'o-')
ax.set(title="Value targets of the first 14 samples", xlabel="sample index", ylabel="y_value");

In [None]:
# Look at start_indices again, e.g. to pick an `idx` for the plot below
# (assuming it holds the sample index where each game starts — TODO confirm).
start_indices

In [None]:
# Total number of value targets (samples) in the loaded data.
len(y_value)

In [None]:
# Value targets of up to 14 samples starting at `idx`, plotted against their actual sample index
# (the window is shorter if fewer samples remain).
value_window = y_value[idx:idx + 14]
fig, ax = plt.subplots()
ax.plot(np.arange(idx, idx + len(value_window)), value_window, 'o-')
ax.set(title=f"Value targets starting at sample {idx}", xlabel="sample index", ylabel="y_value");