In [None]:
# Install the gpustat CLI (unpinned) into the runtime.
!pip install gpustat
# Print the GPU status to confirm which accelerator this runtime has attached.
!gpustat

In [None]:
import numpy as np
import os
import h5py
from tensorflow import keras

from google.colab import drive

# Attach Google Drive under /content/gdrive; every data/model/log path below lives there.
drive.mount(mountpoint='/content/gdrive')

In [None]:
# Root folder of the Reversi Zero dataset on the mounted Drive; every path below derives from it,
# so relocating the data only requires changing this one line.
DATA_ROOT = '/content/gdrive/My Drive/data_source/reversi_zero'

STATES_FILE_PATH = os.path.join(DATA_ROOT, 'states.npy')
POLICIES_FILE_PATH = os.path.join(DATA_ROOT, 'policies.npy')
VALUES_FILE_PATH = os.path.join(DATA_ROOT, 'values.npy')

# The trailing '' component keeps the trailing slash, so file names can be
# appended to these directory paths by plain string concatenation.
MODEL_ARCHIVE_PATH = os.path.join(DATA_ROOT, 'models', '')
LOG_PATH = os.path.join(DATA_ROOT, 'log', '')

In [None]:
class NNModel(object):
    """Policy/value residual network for 8x8 Reversi, built with Keras.

    Input: batches of board states shaped (8, 8, 2).
    Outputs: ``policy`` -- a softmax over the 64 squares -- and ``value`` --
    a single tanh unit in [-1, 1]. The model is built and compiled on
    construction.
    """

    def __init__(self):
        self._create_model()

    def predict(self, state):
        """Evaluate a single board state.

        Args:
            state: one board state (presumably shaped (8, 8, 2) to match the
                Input layer); it is wrapped into a batch of size 1.

        Returns:
            ``(policy, value)`` for that state: a length-64 probability
            vector and a length-1 array holding the value.
        """
        result = self.model.predict(np.array([state]))
        return result[0][0], result[1][0]

    def fit(self, states, policies, values, batch_size=None, epochs=1, callbacks=None, **fit_kwargs):
        """Train the network to map ``states`` to ``[policies, values]``.

        Args:
            states: input batch for the Input layer.
            policies: policy targets, aligned with ``states``.
            values: value targets, aligned with ``states``.
            batch_size: forwarded to ``keras.Model.fit``.
            epochs: forwarded to ``keras.Model.fit``.
            callbacks: optional list of Keras callbacks (default: none).
            **fit_kwargs: any further keyword arguments, forwarded unchanged
                to ``keras.Model.fit`` (e.g. ``validation_split``).

        Returns:
            The ``History`` object returned by ``keras.Model.fit``.
        """
        if callbacks is None:
            # Fresh list per call rather than one mutable default shared by all calls.
            callbacks = []
        return self.model.fit(
            x=states,
            y=[policies, values],
            batch_size=batch_size,
            epochs=epochs,
            callbacks=callbacks,
            **fit_kwargs
        )

    def save(self, path):
        """Save the full Keras model to ``path`` via ``keras.Model.save``."""
        self.model.save(path)

    def load(self, path):
        """Replace this model's weights with those of a model saved at ``path``.

        Weights are copied positionally with ``set_weights``, so the saved
        model must have exactly the architecture built by ``_create_model``.
        ``compile=False`` skips rebuilding the saved optimizer/loss, which are
        not needed just to read the weights.
        """
        loaded_model = keras.models.load_model(path, compile=False)
        self.model.set_weights(loaded_model.get_weights())

    def clone(self):
        """Return a new, freshly compiled NNModel holding a copy of these weights.

        Only the layer weights are copied; optimizer state is not.
        """
        cloned_model = NNModel()
        cloned_model.model.set_weights(self.model.get_weights())
        return cloned_model

    def _create_model(self):
        """Build the network into ``self.model`` and compile it.

        NOTE: layer creation order determines weight order, which ``load``
        and ``clone`` rely on -- do not reorder the layers below.
        """
        model_input = keras.layers.Input(
            shape=(8, 8, 2),
            dtype=np.float32
        )

        # Stem: one conv-BN-ReLU lifting the 2 input planes to 128 feature maps.
        layer = NNModel._conv_batchnorm(
            filters=128,
            kernel_size=(3, 3),
            padding='same',
            kernel_regularizer=keras.regularizers.l2(1e-4),
            activation='relu'
        )(model_input)
        # Residual tower: 16 blocks; 'same' padding keeps the 8x8 spatial size for the skip adds.
        for _ in range(16):
            layer = NNModel._residual_block(
                filters=128,
                kernel_size=(3, 3),
                padding='same',
                kernel_regularizer=keras.regularizers.l2(1e-4)
            )(layer)

        # Value head: 1x1 conv -> dense(64, relu) -> single tanh unit.
        value = layer
        value = NNModel._conv_batchnorm(
            filters=1,
            kernel_size=(1, 1),
            kernel_regularizer=keras.regularizers.l2(1e-4),
            activation='relu'
        )(value)
        value = keras.layers.Flatten()(value)
        value = keras.layers.Dense(
            units=64,
            activation='relu',
            kernel_regularizer=keras.regularizers.l2(1e-4)
        )(value)
        value = keras.layers.Dense(
            units=1,
            activation='tanh',
            name='value'
        )(value)

        # Policy head: 1x1 conv (2 filters) -> softmax over the 64 squares.
        policy = layer
        policy = NNModel._conv_batchnorm(
            filters=2,
            kernel_size=(1, 1),
            kernel_regularizer=keras.regularizers.l2(1e-4),
            activation='relu'
        )(policy)
        policy = keras.layers.Flatten()(policy)
        policy = keras.layers.Dense(
            units=8 * 8,
            activation='softmax',
            name='policy'
        )(policy)

        # Model assembly: outputs are ordered [policy, value], matching fit's y and predict's result.
        model = keras.models.Model(inputs=[model_input], outputs=[policy, value])
        self.model = model
        self._compile()

    def _compile(self):
        """Compile with Adam; cross-entropy on policy and MSE on value, weighted 0.5 each."""
        self.model.compile(
            optimizer=keras.optimizers.Adam(),
            loss=[keras.losses.categorical_crossentropy, keras.losses.mean_squared_error],
            loss_weights=[0.5, 0.5]
        )

    @staticmethod
    def _conv_batchnorm(filters, kernel_size, padding='valid', kernel_regularizer=None, activation='linear'):
        """Return a function applying Conv2D -> BatchNormalization -> Activation to a tensor."""
        def structure(input_tensor):
            layer = keras.layers.Conv2D(
                filters=filters,
                kernel_size=kernel_size,
                padding=padding,
                kernel_regularizer=kernel_regularizer
            )(input_tensor)
            layer = keras.layers.BatchNormalization()(layer)
            layer = keras.layers.Activation(activation)(layer)
            return layer
        return structure

    @staticmethod
    def _residual_block(filters, kernel_size, padding='valid', kernel_regularizer=None):
        """Return a function applying a two-conv residual block to a tensor.

        Computes ``relu(x + BN(conv(relu(BN(conv(x))))))``. The skip connection
        is an element-wise ``Add``, so ``padding`` must preserve the spatial
        size and ``filters`` must equal the input's channel count.

        Args:
            kernel_regularizer: regularizer applied to both convolution kernels.
        """
        def structure(input_tensor):
            residual = input_tensor
            layer = NNModel._conv_batchnorm(
                filters=filters,
                kernel_size=kernel_size,
                padding=padding,
                kernel_regularizer=kernel_regularizer,
                activation='relu'
            )(input_tensor)
            # Second conv stays linear: the ReLU is applied after the skip addition.
            layer = NNModel._conv_batchnorm(
                filters=filters,
                kernel_size=kernel_size,
                padding=padding,
                kernel_regularizer=kernel_regularizer
            )(layer)
            layer = keras.layers.Add()([residual, layer])
            layer = keras.layers.Activation('relu')(layer)
            return layer
        return structure

In [None]:
# Raw self-play training data: board states and their policy / value targets.
states = np.load(STATES_FILE_PATH)
policies = np.load(POLICIES_FILE_PATH)
values = np.load(VALUES_FILE_PATH)

def augment(states, policies, values):
    """Expand a training set with the 8 symmetries of the square board.

    Each sample is mapped through every element of the dihedral group D4:
    rotations by 0/90/180/270 degrees, each with and without a left-right
    mirror. Each element appears exactly once, so no image is duplicated and
    both diagonal reflections are included. The same spatial transform is
    applied to a state and to its policy target; the value target is
    repeated unchanged.

    Args:
        states: array shaped (N, size, size, channels); axes 1 and 2 are the board.
        policies: N policy targets with size * size entries each, in
            row-major board order (they are reshaped to (N, size, size)).
        values: value targets whose first axis has length N.

    Returns:
        ``(states, policies, values)`` with 8 * N samples each, grouped by
        transform. The first N samples are the untransformed originals.
        Policies are flattened back to (8 * N, size * size).
    """
    size = states.shape[1]
    policy_boards = np.reshape(policies, (policies.shape[0], size, size))

    def symmetries(boards):
        # r^k and (r^k then mirror), k = 0..3: every element of D4 exactly once, identity first.
        images = []
        for k in range(4):
            rotated = np.rot90(boards, k=k, axes=(1, 2))
            images.append(rotated)
            images.append(rotated[:, :, ::-1])
        return images

    states = np.concatenate(symmetries(states))
    policies = np.concatenate(symmetries(policy_boards))
    policies = np.reshape(policies, (policies.shape[0], size * size))
    values = np.concatenate([values] * 8)
    return states, policies, values

states, policies, values = augment(states, policies, values)
# Every augmented state must still line up with exactly one policy and one value target.
assert len(states) == len(policies) == len(values), 'augmented arrays are misaligned'
print(states.shape, policies.shape, values.shape)

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 256
EPOCHS = 64
LR_HALVING_PERIOD = 16  # halve the learning rate at epochs 16, 32, 48
CHECKPOINT_PERIOD = 8   # save a checkpoint every 8 epochs

# ModelCheckpoint's `period` argument is deprecated in tf.keras 2.x and not accepted by Keras 3.
# `save_freq` counts batches instead, so "every N epochs" is N * batches-per-epoch.
# NOTE(review): Keras 3 may also reject the '.hdf5' extension for full-model checkpoints -- confirm for the runtime's Keras version.
steps_per_epoch = (len(states) + BATCH_SIZE - 1) // BATCH_SIZE


def halve_lr_periodically(epoch, lr):
    """LearningRateScheduler schedule: halve `lr` at every positive multiple of LR_HALVING_PERIOD."""
    if epoch > 0 and epoch % LR_HALVING_PERIOD == 0:
        return lr / 2.0
    return lr


model = NNModel()
model.fit(
    states,
    policies,
    values,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[
        keras.callbacks.LearningRateScheduler(halve_lr_periodically),
        keras.callbacks.TensorBoard(log_dir=LOG_PATH, update_freq=64),
        keras.callbacks.ModelCheckpoint(
            MODEL_ARCHIVE_PATH + 'model_checkpoint_{epoch:02d}.hdf5',
            save_freq=CHECKPOINT_PERIOD * steps_per_epoch
        )
    ]
)