In [1]:
import numpy as np
import keras
import pickle
import zarr
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, Flatten, BatchNormalization, Activation
from keras import optimizers
from keras.callbacks import ModelCheckpoint
from keras.layers import *


class RL_Datapoint:
    """Plain record holding a state together with its policy and values."""

    def __init__(self, state, policy, values):
        # The three arguments are stored as-is; nothing is copied or validated.
        self.state, self.policy, self.values = state, policy, values
        
def __read_file_data(ID, path='dataset/'):
    """Load one dataset archive fully into memory.

    Opens ``path + str(ID) + '.zip'`` as a read-only zarr ``ZipStore`` and
    copies its ``states``, ``policies`` and ``values`` arrays into NumPy arrays.

    Args:
        ID: Identifier of the archive; it is stringified to build the file name.
        path: Prefix prepended verbatim to the file name, so a directory must
            include its trailing separator.

    Returns:
        Tuple ``(X, policies, values)`` of NumPy arrays.
    """
    store = zarr.ZipStore(path + str(ID) + '.zip', mode="r")
    try:
        dataset = zarr.group(store=store)
        # np.array() materialises each zarr array in memory, so nothing refers
        # to the store after this point and it can be closed safely.
        X = np.array(dataset['states'])
        policies = np.array(dataset['policies'])
        values = np.array(dataset['values'])
    finally:
        # Archives are re-opened on every pass of the training generator;
        # close explicitly instead of leaving zip handles open until GC.
        store.close()
    return X, policies, values

def generator(batch_size, datasetFileLength, path='dataset/'):
    """Endlessly yield shuffled batches read from the dataset archives.

    Each pass visits the archives ``0 .. datasetFileLength - 1`` in a fresh
    random order, shuffles the samples inside every archive and emits them in
    chunks of ``batch_size``. Samples left over after the last full batch of
    an archive are dropped for that pass.

    Args:
        batch_size: Number of samples per yielded batch.
        datasetFileLength: Number of archives to iterate over.
        path: Location prefix forwarded to ``__read_file_data``.

    Yields:
        ``(X_batch, [policy_batch, value_batch])``.
    """
    while True:
        files_sequence = list(range(datasetFileLength))
        np.random.shuffle(files_sequence)

        for file_id in files_sequence:
            # Forward `path`: it used to be ignored, so data was always read
            # from the default 'dataset/' directory regardless of the caller.
            X, policies, values = __read_file_data(file_id, path=path)
            rand_indices = np.arange(len(X))
            np.random.shuffle(rand_indices)

            # Only full batches are produced.
            for i in range(len(X) // batch_size):
                batch_indices = rand_indices[i*batch_size: (i+1)*batch_size]
                yield X[batch_indices], [policies[batch_indices], values[batch_indices]]

Using TensorFlow backend.


<h3>NN Architecture: ResNet</h3>

In [2]:
def getResidualNetwork(input_shape, num_classes=None):
    """Build the two-headed residual CNN (policy + value) and print its summary.

    Layout, with every convolution in ``channels_first`` format:
      * stem: 3x3 conv (256 filters) -> batch norm -> ReLU
      * 7 residual blocks: conv-BN-ReLU-conv-BN, added to the block input,
        then ReLU
      * value head: 1x1 conv (1 filter) -> BN -> ReLU -> flatten ->
        Dense(256, relu) -> Dense(1, tanh) named ``value``
      * policy head: 1x1 conv (2 filters) -> BN -> ReLU -> flatten ->
        Dense(num_classes, softmax) named ``policy``

    Every layer keeps the exact name of the original hand-unrolled version
    (including the irregular ``add1`` / ``add_2`` spelling) so checkpoints and
    by-name weight loading remain compatible.

    Args:
        input_shape: Shape of one sample without the batch axis.
        num_classes: Width of the policy output. When omitted, the
            module-level ``CLASSES_LEN`` is used, as before.

    Returns:
        A Keras ``Model`` with outputs ``[policy, value]``.
    """
    if num_classes is None:
        num_classes = CLASSES_LEN

    channel_pos = 'channels_first'
    n_filters = 256
    # One merge-layer name per residual block; the spelling is intentionally
    # irregular because it mirrors the names already used in saved models.
    add_names = ('add1', 'add_2', 'add_3', 'add4', 'add_5', 'add6', 'add_7')

    def conv_bn(tensor, idx, filters, kernel_size, **conv_kwargs):
        """Apply conv2d_<idx> followed by batch_normalization_<idx>."""
        tensor = Conv2D(filters, kernel_size=kernel_size, padding='same',
                        data_format=channel_pos, name='conv2d_%d' % idx,
                        **conv_kwargs)(tensor)
        # axis=1 is the channel axis for channels_first tensors.
        return BatchNormalization(axis=1, name='batch_normalization_%d' % idx)(tensor)

    def relu(tensor, idx):
        """Apply the ReLU layer named activation_<idx>."""
        return Activation('relu', name='activation_%d' % idx)(tensor)

    inputs = Input(input_shape, name='input1')

    # Stem.
    x = relu(conv_bn(inputs, 1, n_filters, (3, 3), input_shape=input_shape), 1)

    # Residual tower: block k uses layer indices 2k and 2k+1.
    for block, add_name in enumerate(add_names, start=1):
        shortcut = x
        first, second = 2 * block, 2 * block + 1
        x = relu(conv_bn(shortcut, first, n_filters, (3, 3)), first)
        x = conv_bn(x, second, n_filters, (3, 3))
        x = keras.layers.add([x, shortcut], name=add_name)
        x = relu(x, second)

    # Heads. Weight-bearing layers are created in the same order as in the
    # unrolled version (conv2d_17, conv2d_16, dense_1, value, policy).
    value_x = conv_bn(x, 17, 1, (1, 1))
    policy_x = conv_bn(x, 16, 2, (1, 1))
    value_x = Flatten(name='flatten_2')(relu(value_x, 17))
    policy_x = Flatten(name='flatten_1')(relu(policy_x, 16))

    value_x = Dense(256, activation='relu', name='dense_1')(value_x)
    value = Dense(1, activation='tanh', name='value')(value_x)
    policy = Dense(num_classes, activation='softmax', name='policy')(policy_x)

    model = Model(inputs, [policy, value])

    model.summary()
    return model

<h3>Main Parameters of the NN</h3>

In [12]:
# ----------- parameters ----------------
files_len = 10                                # number of zip files in the dataset
file_ids = np.arange(files_len)               # ids 0 .. files_len-1, one per zip file
# np.random.shuffle(file_ids)
train_ids = file_ids[:int(len(file_ids)*0.9)] # file_ids for training set (first 90%)
val_ids = file_ids[int(len(file_ids)*0.9):]   # file_ids for validation set (remaining 10%)
filepath="models/model-{epoch:02d}.hdf5"      # path template for saved models
zip_length = 10000                            # number of datapoints in a zip file
data_len = files_len * zip_length             # datapoints in all zip files (training + validation)
batch_size = 256                              # batch size
steps_per_epoch = (len(train_ids)*zip_length) // batch_size # number of full training batches in one epoch
CLASSES_LEN = 2272                            # number of classes for policy
# The network is built with data_format='channels_first' and inp_shape below
# is (channels, height, width), so this flag must say channels_first as well.
channel_pos = 'channels_first'
dataset_path = 'dataset_tuning/'              # relative directory to the dataset
inp_shape = (60,8,8)                          # shape of a datapoint

<h3>Build the model</h3>

In [4]:
# Build the network for the configured input shape; getResidualNetwork also
# prints the layer summary shown below.
model = getResidualNetwork(inp_shape)
# NOTE(review): lr=0.000 means plain SGD would apply no updates; presumably the
# per-batch learning-rate callback defined further down overrides it — confirm.
sgd = optimizers.SGD(lr=0.000, momentum=0.9, decay=0.1/5, nesterov=False)

def acc_reg(y_true,y_pred):
    """Metric: one minus the squared mean error along axis 1.

    The signed difference ``y_pred - y_true`` is averaged over axis 1, squared,
    and subtracted from 1, so a perfect prediction scores 1.
    """
    mean_error = K.mean(y_pred - y_true, axis=1)
    squared = K.square(mean_error)
    return K.constant(1) - squared

# One loss per model output, in output order [policy, value]: cross-entropy for
# the policy and MSE for the value. loss_weights put 0.999 on the policy loss
# and 0.001 on the value loss. Both metrics are reported for each output
# (policy_acc / policy_acc_reg / value_acc / value_acc_reg in the training log).
model.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer=sgd,
              metrics=['accuracy', acc_reg], loss_weights=[0.999,0.001])

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input1 (InputLayer)             (None, 60, 8, 8)     0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 8, 8)    138496      input1[0][0]                     
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256, 8, 8)    1024        conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 256, 8, 8)    0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (

In [13]:
# Use the last zip file (id files_len-1) as validation data and load it fully
# into memory; with files_len = 10 this is the single id in val_ids.
val_id = files_len -1 
x_val, policies_val, values_val = __read_file_data(val_id,path=dataset_path)


<h3>Define callbacks</h3>

<h4>Callback for checkpoint and tensorboard</h4>

In [8]:
# callbacks
from datetime import datetime
# Checkpoint the full model (save_weights_only=False) to `filepath`, keeping
# only epochs where the monitored val_loss is the best so far.
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)
# Timestamped log directory; it is only referenced by the TensorBoard callback,
# which is currently commented out below.
logdir="logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir, update_freq='batch')

<h4>Callback for Learning rate</h4>

In [9]:
# learning rate
# OneCycleSchedule, LinearWarmUp and BatchLearningRateScheduler come from this
# project-local module via the star import.
from  LearningRateScheduler import *
epochs = 20
# Iteration budget handed to the schedule.
# NOTE(review): this is derived from data_len, which counts every zip file
# including the one held out for validation (20 * 390 = 7800), while training
# runs epochs * steps_per_epoch = 20 * 351 = 7020 batches — confirm whether the
# schedule is meant to complete within training.
batch_len = epochs * int(data_len/ (batch_size))
max_lr = 0.001*8
total_it = batch_len
min_lr = 0.0001
print('BatchLen: ', batch_len, ' - DataLen: ', data_len)
# The cycle covers 40% of total_it and the cooldown the remaining 60%; the
# cycle starts at max_lr/8 and the schedule finishes at min_lr.
lr_schedule = OneCycleSchedule(start_lr=max_lr/8, max_lr=max_lr, cycle_length=total_it*.4, cooldown_length=total_it*.6, finish_lr=min_lr)
# Warm-up wrapper starting from min_lr over total_it/30 iterations.
# NOTE(review): whether the warm-up precedes or overlaps the cycle depends on
# LinearWarmUp's implementation, which is not visible here.
scheduler = LinearWarmUp(lr_schedule, start_lr=min_lr, length=total_it/30)
# Keras callback wrapper around the schedule (presumably stepped once per
# batch, judging by its name — confirm).
bt = BatchLearningRateScheduler(scheduler)

BatchLen:  7800  - DataLen:  100000


In [10]:
# Callbacks passed to training: model checkpointing plus the learning-rate
# scheduler. The variant including the TensorBoard callback is kept disabled.
# callbacks_list = [checkpoint,tensorboard_callback,bt]
callbacks_list = [checkpoint,bt]


In [14]:
# Endless stream of shuffled training batches drawn from the training files.
train_batches = generator(batch_size, len(train_ids), path=dataset_path)
# Validation runs on the in-memory hold-out arrays, targets ordered [policy, value].
val_arrays = (x_val, [policies_val, values_val])

history = model.fit_generator(
    train_batches,
    steps_per_epoch=int((len(train_ids) * zip_length) / batch_size),
    epochs=epochs,
    callbacks=callbacks_list,
    validation_data=val_arrays,
)


Epoch 1/20
  2/351 [..............................] - ETA: 1:26:21 - loss: 7.7351 - policy_loss: 7.7428 - value_loss: 0.0424 - policy_acc: 0.0000e+00 - policy_acc_reg: 1.0000 - value_acc: 0.9746 - value_acc_reg: 0.9576

KeyboardInterrupt: 