In [1]:
# Tensorflow
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, Dropout, MaxPool2D, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
# TF extensions
from tensorboard.plugins.hparams import api as hp
# Python
import os
import json
from functools import partial
# Custom
from utils import preview

In [2]:
# Add before any TF calls - initializes the keras global outside of any tf.functions.
# `preprocess_input` is later called from `load_image`, which runs inside
# `Dataset.map`; this warm-up call on a dummy batch runs it once eagerly first.
temp = tf.zeros([4, 32, 32, 3])
preprocess_input(temp);
# Sentinel passed to `num_parallel_calls` and `prefetch` in `read_dataset`
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Distributed training is currently disabled:
# strategy = tf.distribute.MirroredStrategy()
# print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

In [3]:
def load_image(file_path):
    """
    Load a PNG image from the file path and extract the label from the directory name

    Arguments:
    - file_path - scalar string tensor - path of the form `.../<label>/images/<name>.png`

    Returns:
    - img - float32 tensor (height, width, 3) - image preprocessed for VGG-16
    - label - bool tensor - True when the label directory is named 'stickie'
    """
    img = tf.io.read_file(file_path)
    img = tf.image.decode_png(img, channels=3)
    # VGG-16 `preprocess_input` expects pixel values in [0, 255] (it subtracts the
    # ImageNet channel means). `tf.image.convert_image_dtype` would rescale the
    # uint8 pixels to [0, 1] first, so only change the dtype here.
    img = tf.cast(img, tf.float32)
    img = preprocess_input(img)
    # NOTE(review): images are not resized here, while `get_model` builds the
    # network for 224x224 inputs - confirm the files on disk already have that size.
    # The label is the directory two levels above the file: <label>/images/<name>.png
    label = tf.strings.split(file_path, os.path.sep)[-3]
    label = (label == 'stickie')
    return img, label

def read_dataset(path, batch_size=32):
    """
    Read a dataset split (e.g. train or val) as an endlessly repeating, shuffled, batched dataset

    Arguments:
    - path - str - split directory containing `<label>/images/*.png` files.
      Its parent directory must contain `DATA_DESCRIPTION.json`, which maps the
      split name (last component of `path`) to the number of examples in the split
    - batch_size - int - number of examples per batch

    Returns:
    - dataset - tf.data.Dataset - batches of (image, label) as produced by `load_image`
    - num_steps - int - number of full batches in one pass over the split
    """
    # Load data
    dataset = tf.data.Dataset.list_files(os.path.join(path, '*/images/*.png'))
    dataset = dataset.map(load_image, num_parallel_calls=AUTOTUNE)
    # Repeat, shuffle, batch and prefetch
    dataset = dataset.repeat(None).shuffle(100).batch(batch_size).prefetch(AUTOTUNE)

    # Determine how many steps to run per epoch from the data description
    data_split = path.strip('/').split('/')[-1]
    # Use a context manager so the description file is closed deterministically
    with open(os.path.join(path, '../DATA_DESCRIPTION.json')) as description_file:
        data_description = json.load(description_file)
    num_examples = data_description[data_split]
    num_steps = num_examples // batch_size

    return dataset, num_steps

In [4]:
# Smoke test: build the training dataset once (default batch size) and display
# the returned (dataset, steps per epoch) tuple.
# NOTE(review): this data version (200428_092427) differs from the `data_dir`
# used for the training run further down (200428_095708) - confirm which is intended.
train_path = '../data/processed/200428_092427/train/'
read_dataset(train_path)

(<PrefetchDataset shapes: ((None, None, None, 3), (None,)), types: (tf.float32, tf.bool)>,
 18)

In [5]:
# Custom Hparams callbacks
class HparamsMetricCallback(tf.keras.callbacks.Callback):
    """
    Metric callback for Hparams dashboard
    Eager execution mode only (there might be a way to use @tf.function)

    A single summary writer is created on first use and reused for every epoch,
    then closed when training ends.
    """
    def __init__(self, metric, log_dir):
        """
        Arguments:
        - metric - str - validation metric (should correspond to a metric used in `model.compile`)
        - log_dir - str - log directory to store the metric (should be same dir as Tensorboard)

        Example:
        ```
        model.compile(..., metrics=['accuracy'])
        tensorboard_cb = TensorBoard(log_dir=log_dir)
        hparams_metric_cb = HparamsMetricCallback(metric='val_accuracy', log_dir=log_dir)
        ```
        """
        # Let the Keras base class set up its own attributes
        super().__init__()
        self.metric = metric
        self.log_dir = log_dir
        # Created lazily so that nothing is written before the first epoch ends
        self._writer = None

    def on_epoch_end(self, epoch, logs=None):
        """
        This function will automatically be called during a model.fit() call
        Creates a tf.summary from the validation metric stored in the training logs

        Arguments:
        - epoch - int - index of the epoch that just finished (used as summary step)
        - logs - dict - metric values for this epoch

        Raises:
        - KeyError - if `self.metric` is not present in `logs`
        """
        logs = logs or {}
        if self.metric not in logs:
            raise KeyError(
                "Metric '{}' not found in training logs (available: {})".format(
                    self.metric, sorted(logs)
                )
            )
        if self._writer is None:
            self._writer = tf.summary.create_file_writer(self.log_dir)
        with self._writer.as_default():
            tf.summary.scalar(self.metric, logs[self.metric], step=epoch)
        # Make the value visible to TensorBoard while training is still running
        self._writer.flush()

    def on_train_end(self, logs=None):
        """
        Close the summary writer; a new one is created if the callback is reused
        """
        if self._writer is not None:
            self._writer.close()
            self._writer = None

            
def create_hparams_callbacks(log_dir, opt_metric, hparams):
    """
    Create the two callbacks necessary to use hparams in Tensorboard

    Arguments:
    - log_dir - str - log directory (should be same dir as Tensorboard)
    - opt_metric - str - name of the metric shown in the Hparams dashboard
    - hparams - dict or list - hyperparameters to log. Preferably a dict mapping
      hyperparameter name to its value for this run. A list of names is still
      accepted: the values are then looked up in the notebook-global `args` dict.

    Returns:
    - hparams_metric_cb - HparamsMetricCallback - logs `opt_metric` after each epoch
    - hparams_cb - hp.KerasCallback - logs the hyperparameter values

    Raises:
    - NameError - if `hparams` is a list of names and no global `args` is defined
    """
    # Resolve the hyperparameter values up front instead of silently depending
    # on whatever `args` happens to be defined in the notebook namespace
    if isinstance(hparams, dict):
        hparam_values = dict(hparams)
    else:
        run_args = globals().get('args')
        if run_args is None:
            raise NameError(
                "name 'args' is not defined: pass `hparams` as a {name: value} dict"
            )
        hparam_values = {name: run_args[name] for name in hparams}

    # Hparams metric callback to log the validation score
    hparams_metric_cb = HparamsMetricCallback(
        metric=opt_metric,
        log_dir=log_dir
    )
    # Hparams callback to log the hyperparameter values
    config_writer = tf.summary.create_file_writer(log_dir)
    with config_writer.as_default():
        hp.hparams_config(
            hparams=[hp.HParam(name) for name in hparam_values],
            metrics=[hp.Metric(opt_metric)]
        )
    # The experiment config is written once, so the writer can be released
    config_writer.close()
    hparams_cb = hp.KerasCallback(
        writer=log_dir,
        hparams=hparam_values
    )
    return hparams_metric_cb, hparams_cb

In [6]:
def get_model(args, metrics):
    """
    Build and compile a binary classifier on top of a frozen VGG-16 base
    pretrained on ImageNet

    Arguments:
    - args - dict - training parameters; only 'learning_rate' is read here
    - metrics - list - metrics forwarded to `model.compile`

    Returns:
    - compiled Keras model with a single sigmoid output unit
    """
    # Frozen convolutional base, without the ImageNet classification head
    base = VGG16(
        weights='imagenet',
        input_tensor=Input(shape=(224, 224, 3)),
        include_top=False,
    )
    base.trainable = False
    for base_layer in base.layers:
        base_layer.trainable = False

    # Trainable head: one sigmoid unit on the flattened output of the last base layer
    head = Dense(1, activation='sigmoid')
    features = Flatten()(base.layers[-1].output)
    prediction = head(features)
    classifier = Model(base.input, prediction)

    # Binary cross-entropy optimised with Adam
    classifier.compile(
        loss="binary_crossentropy",
        optimizer=Adam(learning_rate=args['learning_rate']),
        metrics=metrics
    )
    return classifier

In [7]:
#get_model({'learning_rate':0.1}, ['accuracy']).summary()

In [20]:
def train_and_evaluate(args):
    """
    Main training function
    Training logs and model checkpoints will be stored in args['job_dir']

    Arguments:
    - args - dict - Training parameters.
      Should contain:
        - 'learning_rate'     - float - initial learning rate for training
        - 'batch_size'        - int   - mini-batch size used using training (Adam optimisation)
        - 'epochs'            - float - number of passes over the training data (may be fractional)
        - 'epoch_split'       - int   - number of Keras epochs each pass over the data is split into
        - 'job_dir'           - str   - job directory used to store the logs and model checkpoints
        - 'data_dir'          - str   - directory containing the 'train' and 'val' split directories
      May contain:
        - 'validation_steps'  - int   - validation batches per Keras epoch (default 3,
                                        capped at one full pass over the validation split)
        - 'l2_regularisation' - float - currently not read by this function or `get_model`
    """
    # Training parameters
    metrics = ['accuracy']
    opt_metric = 'val_accuracy'
    # Hyperparameters to log, keyed by name, with the values of this run
    hparams = {name: args[name] for name in ['learning_rate']}
    log_dir = os.path.join(args['job_dir'], 'training-logs')
    model_dir = os.path.join(args['job_dir'], 'model-weights.tf')

    # Model definition
    model = get_model(args, metrics)

    # Callback definition
    tensorboard_cb = TensorBoard(
        log_dir=log_dir
    )
    checkpoint_cb = ModelCheckpoint(
        filepath=model_dir,
        save_format='tf',
        monitor=opt_metric,
        mode='max',
        save_freq='epoch',
        save_weights_only=True,
        save_best_only=True,
        verbose=0
    )
    hparams_metric_cb, hparams_cb = create_hparams_callbacks(log_dir, opt_metric, hparams)
    callbacks = [tensorboard_cb, checkpoint_cb, hparams_metric_cb, hparams_cb]

    # Load data
    train_dir, val_dir = [os.path.join(args['data_dir'], split) for split in ['train', 'val']]
    train_dataset, train_steps = read_dataset(train_dir, args['batch_size'])
    val_dataset, val_steps = read_dataset(val_dir, args['batch_size'])

    # Training schedule
    epoch_split = args['epoch_split']
    # round() instead of int(): a product such as 0.29 * 100 evaluates to
    # 28.999999999999996, which int() would truncate to 28
    num_epochs = int(round(args['epochs'] * epoch_split))
    # Always run at least one step, even when the split is finer than the data
    steps_per_epoch = max(1, train_steps // epoch_split)
    # The validation dataset repeats endlessly, so more than `val_steps` batches
    # would only re-evaluate examples that were already seen
    validation_steps = max(1, min(args.get('validation_steps', 3), val_steps))

    # Train model
    model.fit(
        train_dataset,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_dataset,
        validation_steps=validation_steps,
        callbacks=callbacks,
        verbose=1
    )

In [21]:
! rm -r ../train-output  # remove logs from previous training session

In [22]:
# Train one model per learning rate.
# Each trial gets its own job directory: sharing one directory would mix the
# TensorBoard / Hparams logs of all trials into a single run and let later
# trials overwrite the checkpoints of earlier ones.
# View all trials together with: tensorboard --logdir ../train-output
for learning_rate in [0.001, 0.01]:
    args = {
        'learning_rate': learning_rate,
        'batch_size': 16,
        'epochs': 0.1,
        'epoch_split': 100,  # split epoch to see training progress more frequently
        'job_dir': os.path.join('../train-output', 'lr-{}'.format(learning_rate)),
        'data_dir': '../data/processed/200428_095708'
    }

    train_and_evaluate(args)

Train for 5 steps, validate for 3 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train for 5 steps, validate for 3 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [11]:
# ! poetry run tensorboard --logdir='train-output/training-logs'
# ! tensorboard --logdir='../train-output/training-logs'