# Adapted from Victor Zhou's CNN published at [this blog post](https://victorzhou.com/blog/keras-cnn-tutorial/).

In [None]:
import numpy as np
import mnist

# Using tf.keras for Ghost Batch Norm capability.
import tensorflow.keras.backend as K  # For viewing "backend" parameter values of the model.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical

import tensorflow as tf
# Disable INFO and WARNING messages from TensorFlow.
# Our Keras version (2.2.4 / TF 1.15.0) throws deprecation warnings.
# The bare `tf.logging` alias is itself deprecated in TF 1.15, so reach the
# logger through `tf.compat.v1` to avoid triggering one more warning here.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
# Fetch the train and test splits (images and integer labels) from `mnist`.
train_images, train_labels = mnist.train_images(), mnist.train_labels()
test_images, test_labels = mnist.test_images(), mnist.test_labels()

# Rescale pixels: divide by 255, then shift down by 0.5
# (a raw value of 0 becomes -0.5 and 255 becomes 0.5).
train_images = train_images / 255 - 0.5
test_images = test_images / 255 - 0.5

# Insert a new axis at position 3 so each image carries an explicit
# channel dimension, as the (28, 28, 1) model input expects.
train_images = train_images[:, :, :, np.newaxis]
test_images = test_images[:, :, :, np.newaxis]

# Convolution / pooling hyperparameters used when building the model.
num_filters, filter_size, pool_size = 8, 3, 2

In [None]:
# Build the model.
batch_size = 100       # TODO: make this less arbitrary.
ghost_batch_size = 10  # TODO: make this less arbitrary.

# Ghost batch normalization splits every batch into virtual sub-batches of
# `ghost_batch_size` samples, so that size must divide the batch actually fed
# to the layer. Fail early with a clear message instead of a reshape error
# in the middle of training.
# NOTE(review): a final partial batch (dataset size not a multiple of
# batch_size) would also need to be divisible by ghost_batch_size.
if batch_size % ghost_batch_size != 0:
    raise ValueError(
        "batch_size ({}) must be a multiple of ghost_batch_size ({})".format(
            batch_size, ghost_batch_size
        )
    )

model = Sequential([
    # The authors required ghost batch normalization in their experiments.
    # The papers they reference use normalization preceding each convolutional layer.
    # `input_shape` belongs on the first layer of a Sequential model; on a
    # later layer it is not used to define the model's input.
    BatchNormalization(
        virtual_batch_size=ghost_batch_size,
        input_shape=(28, 28, 1),
    ),
    Conv2D(filters=num_filters, kernel_size=filter_size),
    MaxPooling2D(pool_size=pool_size),
    Flatten(),
    # One softmax output per digit class.
    Dense(units=10, activation='softmax'),
])

In [None]:
# Compile the model.

# Number of full passes over the training data.
# TODO: do enough full-data passes to establish a training schedule.
epochs = 1

# Optimizer hyperparameters, all left at the SGD defaults.
learning_rate = 0.01
decay_rate = 0.00
momentum = 0.0

# Vanilla SGD experiment: no momentum, no Nesterov update, no decay.
sgd = SGD(
    lr=learning_rate,
    momentum=momentum,
    decay=decay_rate,
    nesterov=False,
)

# Categorical cross-entropy pairs with the softmax output layer;
# accuracy is tracked as the reported metric.
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Train the model.
# Labels are converted with `to_categorical` so they match the
# categorical cross-entropy loss the model was compiled with.
train_targets = to_categorical(train_labels)

# The test split doubles as validation data, evaluated after each epoch.
validation_pair = (test_images, to_categorical(test_labels))

results = model.fit(
    x=train_images,
    y=train_targets,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=validation_pair,
)

In [None]:
# Show the validation accuracy recorded for each epoch by `model.fit`.
# NOTE(review): "val_acc" is the key name under the TF 1.x Keras used here;
# newer TensorFlow releases report it as "val_accuracy" — confirm if upgrading.
results.history["val_acc"]