# MNIST image classification using a CNN in Keras

## 1. Import Modules

In [None]:
# Import MNIST dataset
from keras.datasets import mnist
# Import keras numpy utilities
from keras.utils import np_utils
# Import keras modules
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.losses import categorical_crossentropy
from keras.optimizers import Adam
# Import numpy
import numpy as np
# Import matplotlib
import matplotlib.pyplot as plt

## 2. Read in MNIST data

In [None]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
_train_split, _test_split = mnist.load_data()
x_train, y_train = _train_split
x_test, y_test = _test_split
# Drop the temporaries so only the four arrays remain in the namespace.
del _train_split, _test_split

In [None]:
# Shapes of training set
print("Training set (images) shape: {shape}".format(shape=x_train.shape))
print("Training set (labels) shape: {shape}".format(shape=y_train.shape))

# Shapes of test set
print("Test set (images) shape: {shape}".format(shape=x_test.shape))
print("Test set (labels) shape: {shape}".format(shape=y_test.shape))

n_classes = 10

In [None]:
# Map each digit class index (0-9) to its English name, used for plot titles.
label_dict = dict(enumerate([
    'Zero', 'One', 'Two', 'Three', 'Four',
    'Five', 'Six', 'Seven', 'Eight', 'Nine',
]))

In [None]:
plt.figure(figsize=[5, 5])

# Show the first training image (left) beside the first test image (right),
# each titled with its human-readable label.
for subplot_index, (images, labels) in enumerate(
        ((x_train, y_train), (x_test, y_test)), start=1):
    plt.subplot(1, 2, subplot_index)
    img = images[0]
    lbl = labels[0]
    plt.imshow(img, cmap='gray')
    # label_dict values are already strings, so no str() is needed; lbl is a
    # NumPy integer, which hashes and compares equal to the matching int key.
    plt.title("(Label: " + label_dict[lbl] + ")")

# Render explicitly so the figure also appears when run as a plain script
# (not only under the notebook's inline backend), and so it stops being the
# "current" figure that later plt.plot calls would draw onto.
plt.show()

## 3. Data preprocessing

In [None]:
# Cast uint8 pixels to float32, scale 0-255 down to 0-1, and add a trailing
# single-channel axis: (N, 28, 28) -> (N, 28, 28, 1), the layout Conv2D expects.
X_train, X_test = (
    (images.astype('float32') / 255).reshape(-1, 28, 28, 1)
    for images in (x_train, x_test)
)
print("Shape of training data: ", X_train.shape)
print("Shape of testing data: ", X_test.shape)
print("Type of training data element: ", type(X_train[0, 0, 0, 0]))
print("Type of testing data element: ", type(X_test[0, 0, 0, 0]))

# One-hot encode the integer labels into n_classes-wide vectors.
Y_train, Y_test = (
    np_utils.to_categorical(labels, n_classes) for labels in (y_train, y_test)
)
print("Shape of training label after one-hot encoding: ", Y_train.shape)
print("Shape of testing label after one-hot encoding: ", Y_test.shape)

## 4. Training Parameters

In [None]:
# Number of passes over the training set, and samples per gradient update.
epochs, batch_size = 50, 64

## 5. Model

### 5.1. Network

In [None]:
# Small CNN. Each Conv2D/Dense uses a 'linear' activation so that the
# nonlinearity comes from the LeakyReLU(alpha=0.1) layer that follows it.
# With padding='same', spatial size goes 28x28 -> 14x14 -> 7x7 across the two
# 2x2 max-pool stages.
model = Sequential([
    # Stage 1: 4 filters of 3x3 on the (28, 28, 1) input.
    Conv2D(4, kernel_size=(3, 3), activation='linear',
           input_shape=(28, 28, 1), padding='same'),
    LeakyReLU(alpha=0.1),
    MaxPooling2D((2, 2), padding='same'),

    # Stage 2: 8 filters of 3x3.
    Conv2D(8, kernel_size=(3, 3), activation='linear', padding='same'),
    LeakyReLU(alpha=0.1),
    MaxPooling2D((2, 2), padding='same'),

    # Classifier head: flatten, an 8-unit hidden layer, then softmax over classes.
    Flatten(),
    Dense(8, activation='linear'),
    LeakyReLU(alpha=0.1),
    Dense(n_classes, activation='softmax'),
])

model.summary()

### 5.2. Loss and Optimizer

In [None]:
# Categorical cross-entropy matches the one-hot Y_* targets; Adam uses its
# default settings; accuracy is tracked alongside the loss.
model.compile(
    loss=categorical_crossentropy,
    optimizer=Adam(),
    metrics=['accuracy'],
)

## 6. Train the model

In [None]:
# Train on the training split. The test split is passed as validation_data, so
# Keras evaluates loss/accuracy on it after each epoch. The returned History
# object is kept for plotting.
mnist_train = model.fit(
    X_train,
    Y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(X_test, Y_test),
)

## 7. Visualize Loss and Accuracy

In [None]:
def _plot_train_vs_test(history, key, epoch_axis, metric_name):
    """Draw training vs. test curves for one metric on a new figure and show it.

    history: the ``History.history`` dict returned by ``model.fit``.
    key: history key of the training curve; ``'val_' + key`` is the test curve.
    epoch_axis: x values, one per recorded epoch.
    metric_name: display name used in the legend, title and y-axis label.
    """
    # Create the figure BEFORE plotting; otherwise the curves are drawn on
    # whatever figure is current (e.g. the sample-image grid in a script), and
    # a figure() call after legend() just leaves an extra empty figure behind.
    plt.figure()
    plt.plot(epoch_axis, history[key], 'b', label='Training ' + metric_name)
    plt.plot(epoch_axis, history['val_' + key], 'r', label='Test ' + metric_name)
    plt.title('Training and Test ' + metric_name)
    plt.xlabel('Epochs ', fontsize=16)
    plt.ylabel(metric_name, fontsize=16)
    plt.legend()
    plt.show()


history = mnist_train.history
# Keras >= 2.3 and tf.keras log accuracy as 'accuracy'/'val_accuracy'; older
# Keras used 'acc'/'val_acc'. Use whichever this history actually contains.
acc_key = 'accuracy' if 'accuracy' in history else 'acc'
# Derive the x-axis from what was actually recorded, rather than the `epochs`
# setting, so lengths always match, and number epochs from 1.
epoch_axis = range(1, len(history['loss']) + 1)

_plot_train_vs_test(history, 'loss', epoch_axis, 'Loss')
_plot_train_vs_test(history, acc_key, epoch_axis, 'Accuracy')