This notebook tackles Fashion-MNIST, one of the simplest image-classification challenges in computer vision (CV).
It uses Keras and TensorFlow to do so.

# Import Dependencies

In [1]:
# Load in TensorFlow and Keras
import tensorflow as tf
import keras

In [None]:
# Fashion-MNIST comes pre-split: load_data() returns a (images, labels)
# tuple for the training set followed by one for the held-out test set.
(train_img, train_GT), (test_img, test_GT) = (
    keras.datasets.fashion_mnist.load_data()
)

In [16]:
# Hold out the last VAL_SIZE training images as the validation set used
# during fit(). The training size is derived from the array length instead of
# being hard-coded, so the cell does not silently assume 60000 examples.
# NOTE: this reassigns train_img, so re-running the cell without re-running
# the load cell shrinks the training set again.
# NOTE(review): 128 validation images gives val_accuracy a resolution of
# ~0.8%, which makes early stopping on it noisy; consider a larger split.
VAL_SIZE = 128
train_img, val_img = tf.split(
    train_img, (train_img.shape[0] - VAL_SIZE, VAL_SIZE)
)

In [18]:
# Split the labels at exactly the same boundary as the images in the previous
# cell. The split size is taken from val_img rather than hard-coded, so images
# and labels cannot fall out of alignment.
train_GT, val_GT = tf.split(
    train_GT, (train_GT.shape[0] - val_img.shape[0], val_img.shape[0])
)

Simple CNN (convolutional neural network) in Keras

In [19]:
# All-convolutional classifier. The stack is:
# - conv + batch-norm stages,
# - two 2x max-pool downsamplings (28x28 -> 14x14 -> 7x7),
# - global average pooling,
# - a 10-way softmax over the Fashion-MNIST classes.
model = keras.Sequential(layers = [
    keras.Input(shape = (28, 28)),
    # Add the single channel axis Conv2D expects: (28, 28) -> (28, 28, 1).
    keras.layers.Reshape((28, 28, 1)),
    # load_data() yields raw pixel values in [0, 255] (uint8 per the Keras
    # datasets docs). Scaling to [0, 1] inside the model means fit, evaluate
    # and predict all receive identically preprocessed inputs.
    keras.layers.Rescaling(1.0 / 255),
    keras.layers.Conv2D(64, 3, activation = 'relu', padding = 'same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(256, 3, activation = 'relu', padding = 'same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(),
    keras.layers.Conv2D(512, 3, activation = 'relu', padding = 'same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(),
    keras.layers.Conv2D(1024, 3, activation='relu', padding = 'same'),
    keras.layers.BatchNormalization(),
    # Average each of the 1024 feature maps down to a single value. This is
    # equivalent to AvgPool2D(7) + Flatten on 7x7 maps, but it does not break
    # if the spatial size changes.
    keras.layers.GlobalAveragePooling2D(),
    # Softmax probabilities. This pairs with SparseCategoricalCrossentropy's
    # default from_logits=False.
    keras.layers.Dense(10, activation = 'softmax')
])

Creating a Custom Keras Callback

In [41]:
class EarlyStopping(keras.callbacks.Callback):
  """Stop training once a validation metric stops improving.

  Only "higher is better" metrics are supported. Note that this class shadows
  the name of the built-in keras.callbacks.EarlyStopping.

  Args:
    patience: number of consecutive epochs without improvement after which
      training is stopped. Defaults to 5.
    monitor: key of the epoch-end `logs` dict to track. Defaults to
      'val_accuracy'. Keras only logs 'val_*' entries when fit() receives
      validation data; a missing key raises KeyError.
  """

  def __init__(self, patience=5, monitor='val_accuracy'):
    super().__init__()
    self.patience = patience
    self.monitor = monitor
    self._reset()

  def _reset(self):
    # Best monitored value so far (attribute name kept for compatibility)
    # and the count of consecutive epochs since it last improved.
    self.best_val_acc = float('-inf')
    self.fails = 0

  def on_train_begin(self, logs=None):
    # Reset state so one instance can be reused across several fit() calls
    # without carrying over a stale best value or failure count.
    self._reset()

  def on_epoch_end(self, epoch, logs=None):
    # Use None instead of a shared mutable {} default.
    if logs is None:
      logs = {}
    current = logs[self.monitor]
    if current > self.best_val_acc:
      self.best_val_acc = current
      self.fails = 0
    else:
      self.fails += 1
      if self.fails >= self.patience:
        self.model.stop_training = True

In [42]:
# Adam optimizer with sparse categorical cross-entropy, which takes integer
# class ids as targets (not one-hot vectors). Compiling with 'accuracy' makes
# fit() log 'val_accuracy' when validation data is supplied.
model.compile(
    optimizer = keras.optimizers.Adam(),
    loss = keras.losses.SparseCategoricalCrossentropy(),
    metrics = ['accuracy'],
)

In [None]:
# Train with the custom EarlyStopping callback monitoring val_accuracy.
# Keep the returned History (history.history maps each logged metric to its
# per-epoch values) so training curves can be inspected or plotted later.
# EarlyStopping() halts after 5 consecutive epochs without improvement. With
# epochs=10, it only triggers if the best epoch is among the first five.
history = model.fit(
    train_img,
    train_GT,
    batch_size = 2048,
    epochs = 10,
    validation_data = (val_img, val_GT),
    callbacks = [EarlyStopping()],
)

In [None]:
# Score the trained model on the untouched test split. The result is
# [loss, accuracy], following the compiled loss and metrics order.
model.evaluate(x = test_img, y = test_GT)