# MNIST handwritten digit classification with RNNs

In this notebook, we'll train a recurrent neural network (RNN) to classify MNIST digits using **TensorFlow** (version $\ge$ 2.0 required) with the **Keras API**.

This notebook builds on the MNIST-MLP notebook, so the recommended order is to go through the MNIST-MLP notebook before starting with this one. 

First, the needed imports.

In [None]:
%matplotlib inline

from pml_utils import show_failures

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 
from tensorflow.keras.utils import plot_model, to_categorical

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

print('Using Tensorflow version: {}, and Keras version: {}.'.format(tf.__version__, tf.keras.__version__))
# Check the major version directly (distutils was removed in Python 3.12)
assert int(tf.__version__.split('.')[0]) >= 2, "TensorFlow >= 2.0 is required"

Next, let's load and preprocess the MNIST dataset. On the first run the data may need to be downloaded, which can take a while.

In [None]:
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
nb_classes = 10
img_rows, img_cols = 28, 28

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# one-hot encoding:
Y_train = to_categorical(y_train, nb_classes)
Y_test = to_categorical(y_test, nb_classes)

print()
print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('Y_train:', Y_train.shape)
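
`to_categorical()` turns each integer label into a one-hot row vector. As a rough illustration (a NumPy sketch of the same encoding, not Keras' own implementation), indexing an identity matrix with the labels produces the same result:

```python
import numpy as np

nb_classes = 10
y = np.array([5, 0, 4])  # a few example integer labels

# Row k of the identity matrix is the one-hot vector for class k,
# so indexing with the labels gives one one-hot row per label.
Y = np.eye(nb_classes, dtype='float32')[y]

print(Y.shape)  # (3, 10)
print(Y[0])     # 1.0 at index 5, zeros elsewhere
```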

### Images as sequences

Note that in this notebook we use *a sequence model* for image classification, so we treat each image as a sequence of (pixel) input vectors.

More precisely, each MNIST digit image (28x28 pixels) is treated as a sequence of 28 time steps, one per image row. The input at each time step is a 28-dimensional vector holding the pixel values of that row (one value per column).
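
Keras recurrent layers expect input of shape `(batch, timesteps, features)`. The small NumPy sketch below uses random stand-in data instead of real MNIST images to show that a batch of 28x28 images already has this shape, with row `t` of each image serving as the input vector at time step `t`:

```python
import numpy as np

img_rows, img_cols = 28, 28
# Random stand-in for a batch of 4 MNIST images (values in [0, 1))
rng = np.random.default_rng(0)
batch = rng.random((4, img_rows, img_cols), dtype=np.float32)

batch_size, timesteps, features = batch.shape
print(batch_size, timesteps, features)  # 4 28 28

# The input vector fed to the RNN at time step t for image 0 is simply row t
t = 10
x_t = batch[0, t, :]
print(x_t.shape)  # (28,)
```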

### Initialization

Now we are ready to create a recurrent model. Keras provides three built-in recurrent layer types:

 * `SimpleRNN`, a fully-connected RNN where the output of each time step is fed back as input to the next time step.
 * `LSTM`, the Long Short-Term Memory layer.
 * `GRU`, the Gated Recurrent Unit layer.

See https://keras.io/layers/recurrent/ for more information.
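
To make the feedback loop of `SimpleRNN` concrete: with the default `tanh` activation, the layer computes $h_t = \tanh(x_t W + h_{t-1} U + b)$ at every time step, starting from $h_0 = 0$. The NumPy sketch below uses random weights (not trained Keras weights), and the function `simple_rnn_forward` is our own illustration. It also shows what `return_sequences` changes: the layer returns either only the last hidden state or the hidden states of all time steps.

```python
import numpy as np

def simple_rnn_forward(inputs, W, U, b, return_sequences=False):
    """Illustrative forward pass of a SimpleRNN-style layer.

    inputs: (batch, timesteps, features)
    W: (features, units) input kernel
    U: (units, units) recurrent kernel
    b: (units,) bias
    """
    batch, timesteps, _ = inputs.shape
    units = U.shape[0]
    h = np.zeros((batch, units), dtype=inputs.dtype)  # h_0 = 0
    states = []
    for t in range(timesteps):
        # New state depends on the current input row and the previous state
        h = np.tanh(inputs[:, t, :] @ W + h @ U + b)
        states.append(h)
    if return_sequences:
        return np.stack(states, axis=1)  # (batch, timesteps, units)
    return h                             # (batch, units): last state only

rng = np.random.default_rng(0)
features, units = 28, 50
x = rng.random((4, 28, features))
W = rng.normal(scale=0.1, size=(features, units))
U = rng.normal(scale=0.1, size=(units, units))
b = np.zeros(units)

last = simple_rnn_forward(x, W, U, b)
seq = simple_rnn_forward(x, W, U, b, return_sequences=True)
print(last.shape, seq.shape)  # (4, 50) (4, 28, 50)
```

This is why, when stacking RNN layers, every layer except the last needs `return_sequences=True`: the next recurrent layer needs a full sequence as its input.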

In [None]:
# Number of hidden units to use:
nb_units = 50

inputs = keras.Input(shape=(img_rows, img_cols))

# Recurrent layers supported: SimpleRNN, LSTM, GRU:
x = layers.SimpleRNN(nb_units)(inputs)

# To stack multiple RNN layers, all RNN layers except the last one need
# to have "return_sequences=True".  An example of using two RNN layers:
#x = layers.SimpleRNN(16, return_sequences=True)(inputs)
#x = layers.SimpleRNN(32)(x)

outputs = layers.Dense(units=nb_classes, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs,
                    name="rnn_model")
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.summary()

In [None]:
plot_model(model, show_shapes=True)
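
The parameter counts reported by `model.summary()` can be checked by hand. A `SimpleRNN` layer has an input kernel (`input_dim * units`), a recurrent kernel (`units * units`) and a bias (`units`). `LSTM` has four such blocks, one per gate. `GRU` has three, and with the TensorFlow 2 default `reset_after=True` it keeps separate input and recurrent bias vectors. The helpers `rnn_params` and `dense_params` below are our own arithmetic sketch, not Keras functions:

```python
def rnn_params(layer, input_dim, units):
    """Trainable parameters of a Keras recurrent layer (TF 2 defaults assumed)."""
    per_gate = input_dim * units + units * units  # input + recurrent kernels
    if layer == 'SimpleRNN':
        return per_gate + units
    if layer == 'LSTM':
        return 4 * (per_gate + units)
    if layer == 'GRU':  # reset_after=True: two bias vectors per gate
        return 3 * (per_gate + 2 * units)
    raise ValueError(layer)

def dense_params(input_dim, units):
    return input_dim * units + units  # weights + biases

img_cols, nb_units, nb_classes = 28, 50, 10
for layer in ('SimpleRNN', 'LSTM', 'GRU'):
    rnn = rnn_params(layer, img_cols, nb_units)
    total = rnn + dense_params(nb_units, nb_classes)
    print(layer, rnn, total)
# SimpleRNN 3950 4460
# LSTM 15800 16310
# GRU 12000 12510
```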

### Learning

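Now let's train the RNN model. Unlike in the MLP case, we do not need to `reshape()` the images: each 28x28 image already has the (time steps, features) shape that the recurrent layer expects.

Recurrent layers process the 28 time steps one after another, so training (especially with LSTM and GRU layers) can be considerably slower than with MLPs.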

In [None]:
%%time

epochs = 3

history = model.fit(X_train, 
                    Y_train, 
                    epochs=epochs, 
                    batch_size=128,
                    verbose=2)
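
With `batch_size=128`, each epoch performs one weight update per mini-batch, and the last batch is smaller when the training set size is not a multiple of the batch size. For the 60,000 MNIST training images, the arithmetic works out as follows:

```python
import math

n_train = 60000   # MNIST training set size
batch_size = 128
epochs = 3

steps_per_epoch = math.ceil(n_train / batch_size)  # last batch is partial
last_batch = n_train - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch, steps_per_epoch * epochs)  # 469 96 1407
```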

In [None]:
plt.figure(figsize=(5,3))
plt.plot(history.epoch,history.history['loss'])
plt.title('loss')

plt.figure(figsize=(5,3))
plt.plot(history.epoch,history.history['accuracy'])
plt.title('accuracy');

### Inference

With enough training epochs and a large enough model, the test accuracy should exceed 98%.  

You can compare your result with the state of the art [here](http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html). Even more results can be found [here](http://yann.lecun.com/exdb/mnist/).

In [None]:
%%time
scores = model.evaluate(X_test, Y_test, verbose=2)
# evaluate() returns [loss, accuracy] for this model
print("accuracy: %.2f%%" % (scores[1]*100))
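
The accuracy reported by `evaluate()` is the fraction of test images whose highest-scoring class matches the true label. Given softmax outputs such as those of `model.predict()`, it can be recomputed directly. The sketch below uses a tiny made-up probability array in place of real predictions:

```python
import numpy as np

# Made-up softmax outputs for 4 images over 3 classes (each row sums to 1)
predictions = np.array([[0.7, 0.2, 0.1],
                        [0.1, 0.8, 0.1],
                        [0.3, 0.3, 0.4],
                        [0.6, 0.3, 0.1]])
y_true = np.array([0, 1, 2, 1])  # integer labels, like y_test

pred_class = np.argmax(predictions, axis=1)  # most probable class per image
accuracy = np.mean(pred_class == y_true)
print(pred_class, accuracy)  # [0 1 2 0] 0.75
```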

We can now take a closer look at the results using the `show_failures()` helper function.

Here are the first 10 test digits that the RNN misclassified:

In [None]:
predictions = model.predict(X_test)

show_failures(predictions, y_test, X_test)
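
The internals of `show_failures()` live in `pml_utils` and are not shown here. As a hypothetical sketch of the idea, misclassified test images can be located by comparing predicted and true labels. The function `find_failures` below is our own illustration; its optional `trueclass` filter only mimics the argument of the same name:

```python
import numpy as np

def find_failures(predictions, y_true, trueclass=None, max_n=10):
    """Hypothetical helper: indices of the first max_n misclassified samples."""
    pred_class = np.argmax(predictions, axis=1)
    wrong = pred_class != y_true
    if trueclass is not None:
        wrong &= (y_true == trueclass)  # keep only failures of one true class
    return np.flatnonzero(wrong)[:max_n]

# Made-up softmax outputs for 5 samples over 3 classes
predictions = np.array([[0.9, 0.05, 0.05],
                        [0.2, 0.7, 0.1],
                        [0.6, 0.3, 0.1],
                        [0.1, 0.1, 0.8],
                        [0.5, 0.4, 0.1]])
y_true = np.array([0, 2, 1, 2, 1])

print(find_failures(predictions, y_true))               # [1 2 4]
print(find_failures(predictions, y_true, trueclass=1))  # [2 4]
```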

The `trueclass` argument of `show_failures()` restricts the output to a single true class. For example, here are the failures in which the true class was "6":

In [None]:
show_failures(predictions, y_test, X_test, trueclass=6)
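
A confusion matrix summarizes all failures at once and shows which classes a given digit is most often mistaken for: entry `(i, j)` counts test images of true class `i` that were predicted as class `j`. Below is a NumPy-only sketch with made-up labels. In the notebook one would pass `y_test` and `np.argmax(predictions, axis=1)` instead. The `confusion_matrix` function here is our own, not a Keras or pml_utils function:

```python
import numpy as np

def confusion_matrix(y_true, y_pred, nb_classes):
    cm = np.zeros((nb_classes, nb_classes), dtype=int)
    # Unbuffered add so repeated (true, pred) pairs are all counted
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

y_true = np.array([6, 6, 6, 5, 0, 6])
y_pred = np.array([6, 0, 6, 5, 0, 5])
cm = confusion_matrix(y_true, y_pred, nb_classes=10)

# Row 6: how the true "6" images were classified
print(cm[6])  # [1 0 0 0 0 1 2 0 0 0]
```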