## Adapted from Victor Zhou's CNN published at [this blog post](https://victorzhou.com/blog/keras-cnn-tutorial/).

In [None]:
import math
import mnist
import numpy as np

# Using tf.keras for Ghost Batch Norm capability.
import tensorflow.keras.backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, BatchNormalization
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import LearningRateScheduler

import tensorflow as tf
# Disable INFO and WARNING messages from TensorFlow.
# Our Keras version (2.2.4 / TF 1.15.0) throws deprecation warnings.
# NOTE(review): `tf.logging` is presumably only available at this path on
# TF 1.x; confirm before running this notebook on a newer TensorFlow.
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
def _center_and_add_channel(images):
    """Divide by 255, shift by -0.5, and insert a new axis at position 3."""
    centered = (images / 255) - 0.5
    return np.expand_dims(centered, axis=3)


# Load the raw splits from the `mnist` package.
train_images, train_labels = mnist.train_images(), mnist.train_labels()
test_images, test_labels = mnist.test_images(), mnist.test_labels()

# Apply identical preprocessing to both image sets; labels stay untouched.
train_images = _center_and_add_channel(train_images)
test_images = _center_and_add_channel(test_images)

# Convolution / pooling hyperparameters consumed when the model is built.
num_filters, filter_size, pool_size = 8, 3, 2

In [None]:
# Build the model.
# NOTE(review): ghost batch norm presumably needs every batch size used for
# training to be a multiple of `ghost_batch_size` -- confirm against the
# BatchNormalization `virtual_batch_size` documentation.
batch_size = 100       # TODO: make this less arbitrary.
ghost_batch_size = 10  # TODO: make this less arbitrary.

model = Sequential([
    # The authors required ghost batch normalization in their experiments.
    # The papers they reference use normalization preceding each convolutional layer.
    # The input shape is declared here, on the FIRST layer: Sequential only
    # reads `input_shape` from its first layer, so declaring it on the Conv2D
    # below left the model unbuilt (no weights) until the first call to fit().
    BatchNormalization(
        virtual_batch_size=ghost_batch_size,
        input_shape=(28, 28, 1),
    ),
    Conv2D(filters=num_filters, kernel_size=filter_size),
    MaxPooling2D(pool_size=pool_size),
    Flatten(),
    # One softmax output per class (10 units).
    Dense(units=10, activation='softmax'),
])

In [None]:
# Adapted from https://machinelearningmastery.com/using-learning-rate-schedules-deep-learning-models-python-keras/
def step_lr_decay(epoch):
    """Return the step-decayed learning rate for the given epoch index.

    The rate starts at 0.15 (chosen to be higher than the default vanilla SGD
    rate) and is multiplied by 0.2 -- a five-fold drop -- once every two
    epochs. Dropping every other epoch keeps training short enough for a
    laptop. The formula has the same shape as a half-life equation.

    Args:
        epoch: Epoch index; it is reported as ``epoch + 1`` in the log line.

    Returns:
        The learning rate to use for that epoch, as a float.
    """
    start_rate, decay_factor, epochs_per_step = 0.15, 0.2, 2.0
    # Number of whole decay steps that have elapsed so far.
    steps_elapsed = math.floor(epoch / epochs_per_step)
    rate = start_rate * math.pow(decay_factor, steps_elapsed)
    print("Epoch %d learning rate: %f" % (epoch + 1, rate))
    return rate

In [None]:
# Compile the model with a plain SGD optimizer (no momentum, no decay).
epochs = 6  # Enough epochs to exercise every step of the training schedule.

# Optimizer hyperparameters. The learning rate is only a placeholder here:
# the scheduler callback overrides it during training.
learning_rate = 0.00
momentum = 0.0     # Default momentum.
decay_rate = 0.00  # Default decay rate.

sgd = SGD(
    lr=learning_rate,
    momentum=momentum,
    decay=decay_rate,
    nesterov=False,
)

# Vanilla SGD experiment.
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

## Experiment 1: Vanilla SGD, decaying learning rate.

In [None]:
# Train the model, letting the step schedule set the learning rate each epoch.
decay_lr_scheduler = LearningRateScheduler(step_lr_decay)

# One-hot encode the integer labels for the categorical cross-entropy loss.
exp1_train_targets = to_categorical(train_labels)
exp1_validation = (test_images, to_categorical(test_labels))

exp1_results = model.fit(
    train_images,
    exp1_train_targets,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[decay_lr_scheduler],
    validation_data=exp1_validation,
)

## Experiment 2: Vanilla SGD, increasing batch size.

In [None]:
# Keras has no built in for this.
# Adapted from https://www.codementor.io/nitinsurya/how-to-re-initialize-keras-model-weights-et41zre2g
def reset_weights(model):
    """Re-initialize every variable owned by the layers of ``model``.

    The initializer op of each layer variable is re-run in the current Keras
    backend session. This covers kernels as well as biases and the
    BatchNormalization variables, so no learned state leaks from one
    experiment into the next. Re-initializing only ``layer.kernel`` would
    leave those other variables at their previously trained values.

    Layers that own no variables (or have not been built yet) are skipped
    rather than raising.

    Args:
        model: A Keras model exposing ``layers``; each layer exposes
            ``weights``.

    Returns:
        None. The variables are reset in place.
    """
    session = K.get_session()
    init_ops = [
        variable.initializer
        for layer in model.layers
        for variable in layer.weights
    ]
    # Run all initializers in a single call; nothing to do for an empty list.
    if init_ops:
        session.run(init_ops)

In [None]:
# Start from fresh weights and a fresh optimizer with a constant learning rate.
reset_weights(model)
learning_rate = 0.07  # Set a new constant.
sgd = SGD(
    lr=learning_rate,
    momentum=momentum,
    decay=decay_rate,
    nesterov=False,
)

# Vanilla SGD experiment.
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

exp2_train_targets = to_categorical(train_labels)
exp2_validation = (test_images, to_categorical(test_labels))

# Keras has no callback for changing the batch size, so each step is run as
# its own fit() call: two epochs per step, with the batch size growing
# five-fold from one step to the next.
for batch_multiplier in (1, 5, 25):
    model.fit(
        train_images,
        exp2_train_targets,
        epochs=2,
        batch_size=batch_size * batch_multiplier,
        validation_data=exp2_validation,
    )

## Experiment 3: Vanilla SGD, hybrid.

In [None]:
def hybrid_lr_decay(epoch):
    """Return the step-decayed learning rate for the hybrid experiment.

    The schedule starts from 0.05, the constant learning rate the hybrid run
    is initialized with, and multiplies it by 0.2 once every two epochs.

    Args:
        epoch: Epoch index; it is reported as ``epoch + 1`` in the log line.

    Returns:
        The learning rate to use for that epoch, as a float.
    """
    start_rate, decay_factor, epochs_per_step = 0.05, 0.2, 2.0
    # Number of whole decay steps that have elapsed so far.
    steps_elapsed = math.floor(epoch / epochs_per_step)
    rate = start_rate * math.pow(decay_factor, steps_elapsed)
    print("Epoch %d learning rate: %f" % (epoch + 1, rate))
    return rate

# Callback that applies hybrid_lr_decay to the fit() call it is attached to.
hybrid_lr_scheduler = LearningRateScheduler(schedule=hybrid_lr_decay)

In [None]:
# Start from fresh weights and a fresh optimizer with a constant learning rate.
reset_weights(model)
learning_rate = 0.05  # Set a new constant.
sgd = SGD(
    lr=learning_rate,
    momentum=momentum,
    decay=decay_rate,
    nesterov=False,
)

# Vanilla SGD experiment.
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

exp3_train_targets = to_categorical(train_labels)
exp3_validation = (test_images, to_categorical(test_labels))

# First step: two epochs at the base batch size and the constant learning rate.
model.fit(
    train_images,
    exp3_train_targets,
    batch_size=batch_size,
    epochs=2,
    validation_data=exp3_validation,
)

# Second step: the batch size is increased manually (five-fold), and the
# decay scheduler is attached so the learning rate decays in the step after.
model.fit(
    train_images,
    exp3_train_targets,
    batch_size=batch_size * 5,
    epochs=4,
    callbacks=[hybrid_lr_scheduler],
    validation_data=exp3_validation,
)

## Experiment 4: SGD with momentum, decaying learning rate.