In [None]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
import pandas as pd

In [None]:
# Load MNIST through the Keras dataset helper, which returns a
# (training, test) pair of (images, integer labels) tuples.
train_split, test_split = mnist.load_data()
x_train, labels_train = train_split
x_test, labels_test = test_split

In [None]:
# Sanity check: show the shape of the training images, then of the training labels.
for arr in (x_train, labels_train):
    print(arr.shape)

In [None]:
# Cast the images to float32 and divide by 255 so pixel values are rescaled
# before training (cast and scale done in a single expression per split).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

In [None]:
# One-hot encode the integer digit labels into 10-column target matrices,
# the format expected by the categorical cross-entropy loss used below.
y_train = to_categorical(labels_train, num_classes=10)
y_test = to_categorical(labels_test, num_classes=10)

In [None]:
# Conv2D expects a trailing channel axis: reshape both splits to
# (n_samples, 28, 28, 1). The -1 lets NumPy infer the sample count.
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

In [None]:
#CNN
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense, Flatten,Dropout

# Small CNN: two convolution + max-pooling stages feeding a dense classifier head.
net = Sequential([
    # Stage 1: 32 filters of size 5x5 on the 28x28 single-channel input, then 2x2 pooling.
    Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),
    MaxPool2D(pool_size=(2, 2)),
    # Stage 2: 32 filters of size 3x3, then 2x2 pooling.
    Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    MaxPool2D(pool_size=(2, 2)),
    # Classifier head: flatten, 256-unit hidden layer with 50% dropout,
    # then a 10-way softmax (one output per digit class).
    Flatten(),
    Dense(units=256, activation='relu'),
    Dropout(rate=0.5),
    Dense(units=10, activation='softmax'),
])

In [None]:
# Print the layer-by-layer summary of `net` to double-check the architecture.
net.summary()

In [None]:
# Train with Adam on categorical cross-entropy for 30 epochs, 256 images per batch.
# NOTE(review): the test split is passed as validation data, so the
# 'val_loss' curve recorded in `history` is measured on the test set.
net.compile(optimizer='adam', loss='categorical_crossentropy')
history = net.fit(
    x=x_train,
    y=y_train,
    batch_size=256,
    epochs=30,
    validation_data=(x_test, y_test),
)

In [None]:
import matplotlib.pyplot as plt
# Plot the per-epoch training and validation loss recorded by net.fit.
fig, ax = plt.subplots()
for history_key, curve_label in (('loss', 'training loss'),
                                 ('val_loss', 'validation loss')):
    ax.plot(history.history[history_key], label=curve_label)
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.legend()

In [None]:
#testing
import numpy as np
# Run the trained network on the test images; each row of `outputs` holds
# the 10 softmax scores produced by the final Dense layer for one image.
outputs = net.predict(x_test)
# The predicted digit is the index of the highest score in each row.
labels_predicted = np.argmax(outputs, axis=1)
# Count mismatches with a vectorized NumPy reduction; the builtin sum()
# would iterate over the boolean array element by element in Python.
misclassified = np.count_nonzero(labels_predicted != labels_test)
print('Percentage misclassified = ', 100 * misclassified / labels_test.size)

In [None]:
# Reuse `labels_predicted` from the misclassification cell above instead of
# running net.predict over the whole test set a second time.
# Matches are counted with a vectorized NumPy reduction rather than the
# builtin sum(), which loops over the boolean array in Python.
correct_classified = np.count_nonzero(labels_predicted == labels_test)
print('Percentage correctly classified MNIST=', 100 * correct_classified / labels_test.size)

In [None]:
# Look at confusion matrix
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Draw a confusion matrix as an annotated heat map on the current figure.

    Parameters
    ----------
    cm : array-like, 2-D
        Confusion matrix; rows are true labels, columns are predicted labels.
    classes : sequence
        Tick labels, one per class, used on both axes.
    normalize : bool, default False
        If True, each row is divided by its sum before plotting, so both the
        colours and the cell annotations show per-true-class fractions.
        Rows whose sum is zero are shown as all zeros.
    title : str
        Title of the plot.
    cmap : matplotlib colormap
        Colormap used for the heat map.

    Returns
    -------
    None. The plot is drawn through the pyplot state machine.
    """
    cm = np.asarray(cm)

    # Normalise BEFORE drawing so the heat map, the text-colour threshold and
    # the cell annotations are all based on the same values.
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # `where` skips rows with no samples, avoiding a 0/0 -> NaN division.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts are printed as-is; fractional values get two decimals so
    # the annotations stay readable inside each cell.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Class-score predictions of the network for every test image (one row per image)
Y_pred = net.predict(x_test)
# Collapse each score row to the index of its largest entry, i.e. the predicted digit
Y_pred_classes = Y_pred.argmax(axis=1)
# Recover the integer class labels from the one-hot encoded test targets
Y_true = y_test.argmax(axis=1)
# Tabulate true-vs-predicted label counts
confusion_mtx = confusion_matrix(y_true=Y_true, y_pred=Y_pred_classes)
# Draw the table as an annotated heat map
plot_confusion_matrix(confusion_mtx, classes=range(10))

In [None]:
# Show the first test digits (top row, titled with the true label) above the
# network's 10 class scores for each of them (bottom row, titled with the
# predicted digit).
n_examples = 8

# One batched forward pass for all displayed images instead of a separate
# net.predict call per image inside the loop.
example_scores = net.predict(x_test[:n_examples].reshape(-1, 28, 28, 1))

plt.figure(figsize=(8, 2))
for i in range(n_examples):
    # Top row: the digit image, with axes hidden.
    ax = plt.subplot(2, n_examples, i + 1)
    plt.imshow(x_test[i, :].reshape(28, 28), cmap=plt.get_cmap('gray_r'))
    plt.title(labels_test[i])
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Bottom row: bar chart of the class scores for the same image.
    output = example_scores[i]
    plt.subplot(2, n_examples, n_examples + i + 1)
    plt.bar(np.arange(10.), output)
    plt.title(np.argmax(output))

In [None]:
# net.save("network_for_mnist.h5")