In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import os
import io
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from preprocessing import parse_aug_fn, parse_fn
from tensorflow.keras import initializers

In [2]:
from tensorboard.plugins.hparams import api as hp

In [3]:
# Candidate values for each hyperparameter swept by the grid search below.
ia_list = [False, True]  # image augmentation off / on
bn_list = [False, True]  # batch normalization off / on
# Weight-initializer labels. The first label was misspelled ('RanodmNormal');
# it is only used as a display value and falls through to the RandomNormal
# branch of train_test_model, so correcting it does not change the model.
init_list = ['RandomNormal_0.01std', 'glorot_normal', 'he_normal']
lr_list = [0.001, 0.01, 0.03]  # learning rates

In [4]:
# HParams-dashboard declarations: one discrete HParam per swept setting,
# plus the single metric that every run reports.
hp_ia = hp.HParam(
    'ia',
    hp.Discrete(ia_list),
    display_name='Image Augmentation',
)
hp_bn = hp.HParam(
    'bn',
    hp.Discrete(bn_list),
    display_name='Batch Normalization',
)
hp_init = hp.HParam(
    'init',
    hp.Discrete(init_list),
    display_name='Weight Initialization',
)
hp_lr = hp.HParam(
    'lr',
    hp.Discrete(lr_list),
    display_name='Learning Rate',
)

# The tag must match the scalar written by HyperparameterCallback.on_train_end.
hp_metric = hp.Metric(tag='accuracy', display_name='Accuracy')

In [None]:
# Write the experiment-level HParams configuration (which hyperparameters and
# metrics exist) once, into its own log directory.
logs_dir = os.path.join('lab7-logs-hparams', 'hparam_tuning')
# Fixed: the API is tf.summary.create_file_writer (not tf.summary_create_file_writer)
# and the variable defined above is logs_dir (not logs_dirs).
root_logdir_writer = tf.summary.create_file_writer(logs_dir)
with root_logdir_writer.as_default():
    hp.hparams_config(hparams=[hp_ia, hp_bn, hp_init, hp_lr], metrics=[hp_metric])

In [None]:
# Load CIFAR-10: the first 10% of the official training split is held out for
# validation, the remaining 90% is used for training.
valid_data = tfds.load('cifar10', split='train[:10%]')

train_data_noaug, info = tfds.load('cifar10', split='train[10%:100%]', with_info=True)

test_data = tfds.load('cifar10', split=tfds.Split.TEST)

# The augmented pipeline is built from the same raw training examples. The
# cells below rebind each name to a new transformed dataset, so sharing the
# raw dataset object here does not couple the two pipelines.
train_data_aug = train_data_noaug

# Input-pipeline settings used by the following cells. These names were
# referenced but never defined, so a fresh-kernel run raised NameError.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# NOTE(review): no batch size is given anywhere in the visible notebook;
# 64 is an assumed default - adjust as needed.
batch_size = 64
# Shuffle-buffer size: 90% of the full 'train' split, matching the
# 'train[10%:100%]' slice loaded above.
train_num = info.splits['train'].num_examples * 9 // 10

In [None]:
# Non-augmented training pipeline: shuffle with a buffer of train_num
# examples, parse each example with parse_fn, then batch and prefetch.
train_data_noaug = (
    train_data_noaug
    .shuffle(train_num)
    .map(map_func=parse_fn, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
# Augmented training pipeline: same stages as the non-augmented one, but
# each example goes through parse_aug_fn instead of parse_fn.
train_data_aug = (
    train_data_aug
    .shuffle(train_num)
    .map(map_func=parse_aug_fn, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
# Validation pipeline: parse with parse_fn (no shuffling), then batch and prefetch.
valid_data = (
    valid_data
    .map(map_func=parse_fn, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
class HyperparameterCallback(tf.keras.callbacks.Callback):
    """Log a run's hyperparameters and its best validation accuracy.

    When training begins, the hyperparameter values are written to a summary
    writer rooted at ``log_dir``. After every epoch the highest
    ``val_categorical_accuracy`` seen so far is tracked, and when training
    ends it is written as the scalar ``'accuracy'`` at step 0.

    Args:
        log_dir: Directory in which the summary file writer is created.
        hparams: Hyperparameter mapping passed unchanged to ``hp.hparams``.
    """

    def __init__(self, log_dir, hparams):
        super().__init__()
        self.log_dir = log_dir
        # Fixed: this was stored as `self.hpararms`, so on_train_begin raised
        # AttributeError when reading `self.hparams`.
        self.hparams = hparams
        self.best_accuracy = 0
        self.writer = None

    def on_train_begin(self, logs=None):
        """Create the summary writer and record this run's hyperparameters."""
        self.writer = tf.summary.create_file_writer(self.log_dir)

        with self.writer.as_default():
            hp.hparams(self.hparams)

    def on_epoch_end(self, epoch, logs=None):
        """Remember the best validation accuracy observed so far."""
        # `logs` defaults to None and the metric may be absent (e.g. no
        # validation data); skip the update instead of crashing.
        current_accuracy = (logs or {}).get('val_categorical_accuracy')
        if current_accuracy is not None and current_accuracy > self.best_accuracy:
            self.best_accuracy = current_accuracy

    def on_train_end(self, logs=None):
        """Write the best validation accuracy as the 'accuracy' scalar."""
        with self.writer.as_default():
            tf.summary.scalar('accuracy', self.best_accuracy, step=0)
        # Push the final scalar to disk right away rather than waiting for
        # the writer to be closed or garbage-collected.
        self.writer.flush()

In [None]:
def train_test_model(logs_dir, hparams):
    """Build and train one CNN for a single hyperparameter combination.

    Args:
        logs_dir: Directory for this run's TensorBoard logs and the best-model
            checkpoint.
        hparams: Mapping keyed by the module-level ``hp_ia``, ``hp_bn``,
            ``hp_init`` and ``hp_lr`` objects, giving the value chosen for
            this run.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Pick the kernel initializer; any label other than the two named ones
    # falls back to a normal distribution with mean 0 and stddev 0.01.
    if hparams[hp_init] == 'glorot_normal':
        init = initializers.glorot_normal()
    elif hparams[hp_init] == 'he_normal':
        init = initializers.he_normal()
    else:
        init = initializers.RandomNormal(0., 0.01)

    use_bn = hparams[hp_bn]

    def _bn_relu(tensor):
        """Apply optional batch normalization followed by ReLU."""
        if use_bn:
            tensor = layers.BatchNormalization()(tensor)
        return layers.ReLU()(tensor)

    inputs = keras.Input(shape=(32, 32, 3))
    x = layers.Conv2D(64, (3, 3), kernel_initializer=init)(inputs)
    x = _bn_relu(x)
    x = layers.MaxPool2D()(x)
    # Remaining convolution stages; each layer is applied to the running
    # tensor (the 128-filter layer previously was never called on `x`).
    for filters in (128, 256, 128, 64):
        x = layers.Conv2D(filters, (3, 3), kernel_initializer=init)(x)
        x = _bn_relu(x)
    x = layers.Flatten()(x)
    # Dense takes no kernel size; the stray (3, 3) was being passed as the
    # activation argument.
    x = layers.Dense(64, kernel_initializer=init)(x)
    x = _bn_relu(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(10, activation='softmax')(x)

    model = keras.Model(inputs, outputs, name='model')

    # The body used an undefined `log_dir`; the parameter is `logs_dir`.
    model_tb = keras.callbacks.TensorBoard(log_dir=logs_dir, write_graph=False)

    model_mckp = keras.callbacks.ModelCheckpoint(
        os.path.join(logs_dir, 'best-model.h5'),
        monitor='val_categorical_accuracy',
        save_best_only=True,
        mode='max',
    )

    model_els = keras.callbacks.EarlyStopping(
        monitor='val_categorical_accuracy', min_delta=0, patience=30, mode='max')

    # NOTE(review): plain concatenation yields e.g. 'run-1validation' (a
    # sibling directory), not 'run-1/validation'. Kept as in the original
    # because the intended layout is not visible here - confirm.
    model_hparam = HyperparameterCallback(logs_dir + 'validation', hparams)

    model.compile(
        keras.optimizers.Adam(hparams[hp_lr]),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=[keras.metrics.CategoricalAccuracy()],
    )

    # Only the training dataset differs between the augmented and plain runs.
    train_data = train_data_aug if hparams[hp_ia] else train_data_noaug
    history = model.fit(
        train_data,
        epochs=100,
        validation_data=valid_data,
        callbacks=[model_tb, model_mckp, model_els, model_hparam],
    )
    return history

In [None]:
# Exhaustive grid search: train one model per combination of the four
# hyperparameter lists, each logging to its own numbered run directory.
session_id = 1
logs_dir = os.path.join('lab7-logs-hparams', 'run-{}')

# Same iteration order as four nested loops: ia outermost, lr innermost.
grid = [
    (ia, bn, init, lr)
    for ia in ia_list
    for bn in bn_list
    for init in init_list
    for lr in lr_list
]

for ia, bn, init, lr in grid:
    print('--- Running training session {}'.format(session_id))
    hparams = {hp_ia: ia, hp_bn: bn, hp_init: init, hp_lr: lr}
    train_test_model(logs_dir.format(session_id), hparams)
    session_id += 1