In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM

## In this model, we're passing the rows of the image as the sequence. So basically, we're showing the model each pixel row of the image, in order, and having it make the prediction. (Each image is one sequence of 28 time steps, with 28 pixel values per step.)

In [2]:
# MNIST: 28x28 images of handwritten digits together with their labels.
mnist = tf.keras.datasets.mnist

# load_data() yields a (train, test) pair, each an (images, labels) tuple.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale both image arrays by the same factor (division by 255.0).
x_train, x_test = x_train / 255.0, x_test / 255.0

# Report the shape of the whole training array, then of a single sample.
print(x_train.shape, x_train[0].shape, sep="\n")

(60000, 28, 28)
(28, 28)


## The type of RNN cell that we're going to use is the LSTM cell. Layers will have dropout, and we'll have a dense layer at the end, before the output layer.

In [3]:
# Per-sample input shape: everything after the leading sample axis of x_train.
sequence_shape = x_train.shape[1:]

# Two stacked LSTM layers followed by a small dense head, with dropout after
# every hidden layer.
model = Sequential([
    # return_sequences=True so the next LSTM receives an output per time step.
    LSTM(128, input_shape=sequence_shape, activation='relu', return_sequences=True),
    Dropout(0.2),

    # Last recurrent layer: return_sequences is left at its default.
    LSTM(128, activation='relu'),
    Dropout(0.1),

    Dense(32, activation='relu'),
    Dropout(0.2),

    # Output layer: 10 units with softmax.
    Dense(10, activation='softmax'),
])

## This should all be straightforward: rather than Dense or Conv, we're just using LSTM as the layer type. The only new thing is return_sequences. With return_sequences=True, the layer outputs its result for every time step instead of only the last one, which is what a following recurrent layer needs as input. So set it to True when the next layer is also recurrent, and leave it at its default (False) when it isn't.

In [4]:
# Learning-rate schedule equivalent to the legacy `decay=1e-6` optimizer argument:
#   lr(step) = 0.001 / (1 + 1e-6 * step)
# Expressed as an explicit schedule because the `lr` keyword is deprecated
# (it triggers a warning) and newer Keras optimizers no longer accept
# `lr` or `decay`.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.001,
    decay_steps=1,
    decay_rate=1e-6,
)
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(
    loss='sparse_categorical_crossentropy',  # labels are passed as-is from load_data(), not one-hot encoded here
    optimizer=opt,
    metrics=['accuracy'],
)

# NOTE(review): the test split is also used as validation data here, so the
# validation metrics and the later model.evaluate(x_test, y_test) describe the
# same samples; carve a validation split out of the training data if an
# independent test score is needed.
model.fit(
    x_train,
    y_train,
    epochs=3,
    validation_data=(x_test, y_test),
)

Epoch 1/3


  super(Adam, self).__init__(name, **kwargs)


Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x1384a844a60>

In [5]:
# Evaluate on the test split. The model was compiled with metrics=['accuracy'],
# so the result holds the loss followed by the accuracy.
score = model.evaluate(x_test, y_test)
test_loss, test_accuracy = score
print(test_accuracy)

0.9765999913215637
