In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img, ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.regularizers import l2

In [None]:
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Input, BatchNormalization, Add, Activation, AveragePooling2D, Dropout
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.losses import categorical_crossentropy

In [None]:
def load_dataset(dataset):
  """Load one of the built-in Keras image-classification datasets.

  Args:
    dataset: dataset identifier; one of 'mnist', 'fashion_mnist' or 'cifar-10'.

  Returns:
    The ``((x_train, y_train), (x_test, y_test))`` tuple returned by the
    matching ``tf.keras.datasets.<name>.load_data()`` call.

  Raises:
    ValueError: if ``dataset`` is not a supported identifier. Failing here,
      with the offending name, is clearer than returning a sentinel that the
      caller then tries to unpack.
  """
  # Lambdas defer touching tf.keras until a loader is actually selected.
  loaders = {
      'mnist': lambda: tf.keras.datasets.mnist.load_data(),
      'fashion_mnist': lambda: tf.keras.datasets.fashion_mnist.load_data(),
      'cifar-10': lambda: tf.keras.datasets.cifar10.load_data(),
  }
  if dataset not in loaders:
    raise ValueError(
        'Unknown dataset %r; expected one of %s' % (dataset, sorted(loaders)))
  return loaders[dataset]()

In [None]:
def preprocess_dataset(x_train, y_train, x_test, y_test, target_size=(64, 64)):
  """Prepare raw image/label arrays for training.

  Steps:
    1. add a trailing channel axis to (N, H, W) image stacks (MNIST, Fashion-MNIST),
    2. resize every image to ``target_size`` (passed to PIL ``Image.resize``,
       i.e. (width, height)),
    3. one-hot encode both label arrays with a class count shared by the splits,
    4. cast to float32 and subtract the per-pixel training-set mean from both splits.

  Args:
    x_train, x_test: image stacks of shape (N, H, W) or (N, H, W, C).
    y_train, y_test: integer class labels.
    target_size: output image size; defaults to the original 64x64.

  Returns:
    (x_train, y_train, x_test, y_test) as NumPy arrays, images float32 and
    labels one-hot.
  """
  # np.expand_dims keeps a NumPy array; tf.expand_dims produced a tensor that
  # is much slower to iterate image-by-image in the resize step below.
  if len(x_train.shape) == 3:
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)

  def _resize(images):
    # scale=False keeps the original 0-255 pixel values through the PIL round trip.
    return np.asarray([
        img_to_array(array_to_img(im, scale=False).resize(target_size))
        for im in images
    ])

  x_train = _resize(x_train)
  x_test = _resize(x_test)

  # Derive the class count from BOTH splits. Encoding each split on its own
  # yields different one-hot widths whenever the largest label is absent from
  # one of them.
  num_classes = int(max(np.max(y_train), np.max(y_test))) + 1
  y_train = to_categorical(y_train, num_classes=num_classes)
  y_test = to_categorical(y_test, num_classes=num_classes)

  # Normalise by mean subtraction, using training statistics only.
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train_mean = np.mean(x_train, axis=0)
  x_train -= x_train_mean
  x_test -= x_train_mean

  return x_train, y_train, x_test, y_test

In [None]:
def _pick_cases(pred_labels, test_labels, limit):
  """Return (success_indices, fail_indices), each with at most `limit` entries in index order."""
  success_cases, fail_cases = [], []
  for i, (truth, pred) in enumerate(zip(test_labels, pred_labels)):
    bucket = success_cases if truth.argmax() == pred.argmax() else fail_cases
    if len(bucket) < limit:
      bucket.append(i)
    # Stop as soon as both grids can be filled.
    if len(success_cases) == limit and len(fail_cases) == limit:
      break
  return success_cases, fail_cases


def _plot_grid(images, indices, cmap, title, nrows=2, ncols=2):
  """Show images[indices] row-major on an nrows x ncols grid; unfilled cells are blanked."""
  fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
  for slot, ax in enumerate(np.asarray(axes).ravel()):
    if slot < len(indices):
      ax.imshow(images[indices[slot]], interpolation='nearest', cmap=cmap)
    else:
      ax.axis('off')
  fig.suptitle(title)
  plt.show()


def visualize(pred_labels, test_labels, test_images, dataset_name, model_name='vgg16'):
  """Plot a 2x2 grid of correctly classified and a 2x2 grid of misclassified test images.

  Args:
    pred_labels: per-sample score vectors; the argmax is the predicted class.
    test_labels: per-sample one-hot ground-truth vectors.
    test_images: images to display, indexed like the labels.
    dataset_name: used in the figure titles.
    model_name: model name used in the figure titles (default 'vgg16').

  If fewer than four success or fail cases exist, the remaining grid cells are
  left blank instead of raising IndexError.
  """
  nrows, ncols = 2, 2
  success_cases, fail_cases = _pick_cases(pred_labels, test_labels, nrows * ncols)

  # Single-channel (H, W) images (MNIST, Fashion-MNIST) are drawn in grayscale.
  cmap = 'gray' if len(test_images[0].shape) == 2 else None

  _plot_grid(test_images, success_cases, cmap,
             'Success Cases of ' + model_name + ' on ' + dataset_name, nrows, ncols)
  _plot_grid(test_images, fail_cases, cmap,
             'Fail Cases of ' + model_name + ' on ' + dataset_name, nrows, ncols)

In [None]:
def model_vgg16(image_shape, num_category, weight_decay=1e-4, dropout_rate=0.3):
  """Build a VGG-16-style classifier with the Keras Sequential API.

  Five conv stages (3x3 'same' convs + 2x2 max-pool) are followed by two
  4096-unit dense layers with dropout and a softmax output.

  Args:
    image_shape: input shape of a single image, e.g. (64, 64, 1).
    num_category: number of output classes.
    weight_decay: L2 factor applied to every conv kernel (default 1e-4).
    dropout_rate: dropout after each hidden dense layer (default 0.3).

  Returns:
    An uncompiled ``Sequential`` model.
  """
  # (number of filters, number of 3x3 conv layers) for each VGG-16 stage.
  stages = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]
  model = Sequential()
  model.add(Input(shape=image_shape))
  for num_filters, num_convs in stages:
    for _ in range(num_convs):
      # Take the regularizer from tf.keras, like the layers, so a separately
      # installed `keras` package of a different version cannot be mixed in.
      model.add(Conv2D(filters=num_filters, kernel_size=(3, 3), padding="same",
                       activation="relu", kernel_initializer='he_normal',
                       kernel_regularizer=tf.keras.regularizers.l2(weight_decay)))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
  model.add(Flatten())
  model.add(Dense(units=4096, activation="relu", kernel_initializer='he_normal'))
  model.add(Dropout(dropout_rate))
  model.add(Dense(units=4096, activation="relu", kernel_initializer='he_normal'))
  model.add(Dropout(dropout_rate))
  model.add(Dense(units=num_category, activation="softmax", kernel_initializer='he_normal'))
  return model

In [None]:
# Dataset to run; load_dataset accepts 'mnist', 'fashion_mnist' or 'cifar-10'.
dataset_name = 'mnist'
train_split, test_split = load_dataset(dataset=dataset_name)
train_img, train_lbl = train_split
test_img, test_lbl = test_split

In [None]:
# Resized, mean-centred images and one-hot labels; the raw test_img is kept for display.
train_images, train_labels, test_images, test_labels = preprocess_dataset(
    x_train=train_img, y_train=train_lbl, x_test=test_img, y_test=test_lbl)

In [None]:
augment_data = True
batch_size = 128
validation_percent = 0.1
epochs = 100
seed = 42

# Seed NumPy (used by ImageDataGenerator) and TensorFlow (weight init, dropout)
# so runs are more repeatable; GPU kernels may still be nondeterministic.
np.random.seed(seed)
tf.random.set_seed(seed)

if augment_data:
  datagen = ImageDataGenerator(
      width_shift_range=0.1,
      height_shift_range=0.1,
      # Mirroring a handwritten digit does not yield a valid digit of the same
      # class, so only flip the other datasets.
      horizontal_flip=(dataset_name != 'mnist'),
      rotation_range=0,
      fill_mode='nearest',
      validation_split=validation_percent
  )
  # Validation images are not augmented, because val_loss drives
  # ReduceLROnPlateau and EarlyStopping. The same validation_split selects the
  # same held-out slice of train_images as the training subset excludes.
  val_datagen = ImageDataGenerator(validation_split=validation_percent)
  train_iterator = datagen.flow(train_images, train_labels, batch_size=batch_size,
                                subset='training', seed=seed)
  validation_iterator = val_datagen.flow(train_images, train_labels, batch_size=batch_size,
                                         subset='validation', shuffle=False)

In [None]:
# Input shape and class count are read off the preprocessed arrays.
input_shape = train_images.shape[1:]
n_classes = train_labels.shape[1]
vgg16 = model_vgg16(image_shape=input_shape, num_category=n_classes)
vgg16.compile(loss="categorical_crossentropy", optimizer='adam', metrics=['accuracy'])
vgg16.summary()

In [None]:
# Divide the LR by ~3.16 after 3 stagnant val_loss epochs; stop after 10 and
# roll back to the best weights observed.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=np.sqrt(0.1), patience=3, min_lr=0.5e-6)
early_stop = EarlyStopping(monitor='val_loss', patience=10, verbose=1, restore_best_weights=True)
training_callbacks = [reduce_lr, early_stop]

if not augment_data:
  # Plain arrays: Keras holds out the validation fraction itself.
  history = vgg16.fit(x=train_images, y=train_labels,
                      epochs=epochs,
                      batch_size=batch_size,
                      validation_split=validation_percent,
                      callbacks=training_callbacks)
else:
  # Augmented generators built in the configuration cell.
  history = vgg16.fit(train_iterator,
                      epochs=epochs,
                      steps_per_epoch=len(train_iterator),
                      validation_data=validation_iterator,
                      validation_steps=len(validation_iterator),
                      callbacks=training_callbacks)

In [None]:
# Summarize training history: one figure per metric, train vs. validation.
# The val_* series come from the held-out validation slice, not the test set,
# so they are labelled accordingly.
for metric in ('accuracy', 'loss'):
  fig, ax = plt.subplots()
  ax.plot(history.history[metric], label='train')
  ax.plot(history.history['val_' + metric], label='validation')
  ax.set_title('model ' + metric + ' of vgg16 on ' + dataset_name)
  ax.set_ylabel(metric)
  ax.set_xlabel('epoch')
  ax.legend(loc='upper left')
  plt.show()

In [None]:
# Test-set loss and accuracy as a {'loss': ..., 'accuracy': ...} dict; a
# descriptive name avoids shadowing the Python builtin `eval`.
test_scores = vgg16.evaluate(test_images, test_labels, batch_size=batch_size, return_dict=True)
test_scores

In [None]:
# Softmax scores for every test image, batched the same way as evaluation.
predictions = vgg16.predict(x=test_images, batch_size=batch_size)

In [None]:
# Display the raw (un-resized, un-normalised) test images next to the model's verdicts.
visualize(pred_labels=predictions,
          test_labels=test_labels,
          test_images=test_img,
          dataset_name=dataset_name)