## Setup

In [None]:
import tensorflow as tf
import numpy as np
import os, sys

sys.path.append( os.path.abspath('..') )
import utils

In [None]:
# Download (or read from the local Keras cache) the MNIST digits.
# load_data() returns two (images, labels) pairs: the training split and the test split.
train_split, test_split = tf.keras.datasets.mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
def _scale_and_add_channel(images):
    """Convert an image array to float32, rescale it, and append a channel axis.

    The rescaling (v - 127.5) / 127.5 maps 0 -> -1.0 and 255 -> 1.0
    (assumes 8-bit pixel values in [0, 255] -- TODO confirm against the loader).
    A trailing axis of size 1 is appended so the array matches the
    (28, 28, 1) input shape declared by the classifier.
    """
    scaled = (images.astype('float32') - 127.5) / 127.5
    return np.expand_dims(scaled, -1)


# NOTE: this cell overwrites x_train / x_test with their transformed versions,
# so it is not idempotent -- re-run the data-loading cell before running it again.
x_train = _scale_and_add_channel(x_train)
x_test = _scale_and_add_channel(x_test)

## Classifier Model

The model is inspired by the Inception architecture from the paper [Going Deeper with Convolutions](https://arxiv.org/abs/1409.4842).

In [None]:
def _conv_block(index, filters, strides):
    """Build the layers of one convolutional stage.

    Each stage is Conv2D (5x5 kernel, 'same' padding) -> LeakyReLU(0.2)
    -> Dropout(0.5) -> BatchNormalization, with every layer name suffixed
    by `index` (e.g. 'conv2d_0', 'leaky_relu_0', ...).

    Args:
        index: Integer suffix used in the layer names.
        filters: Number of output channels of the convolution.
        strides: Stride of the convolution.

    Returns:
        A list with the four layers, in the order they are applied.
    """
    return [
        tf.keras.layers.Conv2D(
            filters=filters, kernel_size=5, strides=strides,
            padding='same', name=f'conv2d_{index}'
        ),
        tf.keras.layers.LeakyReLU(0.2, name=f'leaky_relu_{index}'),
        tf.keras.layers.Dropout(0.5, name=f'dropout_{index}'),
        tf.keras.layers.BatchNormalization(name=f'batchnorm_{index}'),
    ]


# Three convolutional stages (the first two downsample with stride 2),
# followed by global average pooling and two Dense layers.
# No activation is passed to the Dense layers; the last one emits the 10 raw
# class logits expected by a loss configured with from_logits=True.
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(28, 28, 1), name='input'),
        *_conv_block(0, filters=64, strides=2),
        *_conv_block(1, filters=128, strides=2),
        *_conv_block(2, filters=256, strides=1),
        tf.keras.layers.GlobalAvgPool2D(name='features'),
        tf.keras.layers.Dense(64, name='dense'),
        tf.keras.layers.Dense(10, name='logits'),
    ],
    name='mnist_classifier',
)

In [None]:
# The labels are passed as integer class ids and the network outputs raw logits,
# hence the sparse cross-entropy configured with from_logits=True.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

In [None]:
# Print a layer-by-layer summary of the classifier (see the Keras Model.summary docs).
model.summary()

## Training

The model is trained for 30 epochs, and the version that achieves the best accuracy is saved and later used to evaluate the *Classifier Score* (CS) and *Fréchet Classifier Distance* (FCD) of the GAN models.

In [None]:
# Project-local checkpoint callback. Presumably it writes the model to
# 'mnist.h5' whenever the monitored score improves, starting once `save_after`
# epochs have elapsed -- confirm against utils.callback.SaveIfBestCallback.
save_best_cb = utils.callback.SaveIfBestCallback(filename='mnist.h5', save_after=5)

# The test split is used as validation data; `hist.history` keeps the
# per-epoch metrics returned by fit().
hist = model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=30,
    validation_data=(x_test, y_test),
    callbacks=[save_best_cb],
)

In [None]:
# Highest validation accuracy reached in any epoch, reported as a percentage.
best_val_accuracy = max(hist.history['val_accuracy'])
print('Best model accuracy: {:.2f}'.format(100 * best_val_accuracy))

Execute this last cell to load the saved model and calculate its accuracy. It should match the value printed by the cell above.
TensorFlow seems to have a problem computing the accuracy of loaded models in this case, so it is necessary to explicitly compile the model to measure the *sparse_categorical_accuracy*.

In [None]:
# Reload the checkpoint without its saved training configuration, then compile
# it with only the metric needed for evaluation (no loss or optimizer is set).
checkpoint_path = 'mnist.h5'
eval_metrics = ['sparse_categorical_accuracy']

loaded = tf.keras.models.load_model(checkpoint_path, compile=False)
loaded.compile(metrics=eval_metrics)

# Kept as the bare last expression so the notebook displays the returned values.
loaded.evaluate(x=x_test, y=y_test)