Python code to train an image classifier on the MNIST dataset using the Keras
library with an LSTM architecture

In [8]:
# import necessary packages
import os
import sys

import numpy as np
from keras.layers import LSTM, Dense
from keras.models import Sequential
from keras.models import load_model
from tensorflow.examples.tutorials.mnist import input_data

# Define the MNIST classifier, built from an LSTM layer followed by a softmax layer
class MnistLSTMClassifier(object):
    """LSTM-based MNIST digit classifier built with Keras.

    Every image is fed to the network as a sequence of ``time_steps`` rows,
    each row holding ``n_inputs`` pixel values. The network is one LSTM layer
    followed by a softmax ``Dense`` layer with ``n_classes`` outputs.
    """

    # File written by train(save_model=True) when no explicit path is given.
    DEFAULT_MODEL_PATH = "/content/saved_model/lstm-model.h5"

    def __init__(self):
        # Classifier
        self.time_steps = 28   # timesteps to unroll
        self.n_units = 128     # hidden LSTM units
        self.n_inputs = 28     # rows of 28 pixels (an mnist img is 28x28)
        self.n_classes = 10    # mnist classes/labels (0-9)
        self.batch_size = 128  # Size of each batch
        self.n_epochs = 5
        # Internal
        self._data_loaded = False
        self._trained = False

    def __create_model(self):
        """Build and compile a fresh LSTM -> softmax model into ``self.model``."""
        self.model = Sequential()
        self.model.add(LSTM(self.n_units, input_shape=(self.time_steps, self.n_inputs)))
        self.model.add(Dense(self.n_classes, activation='softmax'))
        # Labels are loaded one-hot encoded, hence categorical cross-entropy.
        self.model.compile(loss='categorical_crossentropy',
                           optimizer='rmsprop',
                           metrics=['accuracy'])

    def __load_data(self):
        """Load MNIST (one-hot labels) into ``self.mnist`` and mark it as loaded."""
        # NOTE(review): `tensorflow.examples.tutorials` appears to exist only in
        # TensorFlow 1.x -- confirm the TensorFlow version this notebook targets.
        self.mnist = input_data.read_data_sets("mnist", one_hot=True)
        self._data_loaded = True

    def _as_sequences(self, images):
        """Reshape a batch of images to ``(n_samples, time_steps, n_inputs)``.

        One vectorised reshape replaces the former per-image reshape followed
        by a second reshape of the stacked result; the output is identical.
        """
        return np.asarray(images).reshape((-1, self.time_steps, self.n_inputs))

    def train(self, save_model=False, model_path=None):
        """Build the model, fit it on the MNIST training split, optionally save it.

        Args:
            save_model: When true, write the trained model to disk.
            model_path: Destination file used when ``save_model`` is true.
                Defaults to ``DEFAULT_MODEL_PATH``.
        """
        self.__create_model()
        if not self._data_loaded:
            self.__load_data()

        x_train = self._as_sequences(self.mnist.train.images)

        # shuffle=False keeps Keras from reshuffling the samples between epochs.
        self.model.fit(x_train, self.mnist.train.labels,
                       batch_size=self.batch_size, epochs=self.n_epochs, shuffle=False)

        self._trained = True

        if save_model:
            target_path = model_path or self.DEFAULT_MODEL_PATH
            # Create the destination directory first so a missing folder does
            # not throw away a finished training run at save time.
            target_dir = os.path.dirname(target_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            self.model.save(target_path)

    def evaluate(self, model=None):
        """Evaluate on the MNIST test split and print the result.

        Args:
            model: Path of a saved Keras model to load and evaluate. When
                omitted, the model trained by this instance is used.

        Raises:
            SystemExit: With status 1 when no model path is given and this
                instance has not been trained.
        """
        # `not model` (instead of `model is None`) also rejects an empty path,
        # which would otherwise fall through to a non-existent `self.model`.
        if not self._trained and not model:
            errmsg = "[!] Error: classifier wasn't trained or classifier path is not precised."
            print(errmsg, file=sys.stderr)
            sys.exit(1)  # non-zero status: this is an error exit

        if not self._data_loaded:
            self.__load_data()

        x_test = self._as_sequences(self.mnist.test.images)

        model = load_model(model) if model else self.model
        # Keras returns the loss followed by the compiled metrics (accuracy).
        test_loss = model.evaluate(x_test, self.mnist.test.labels)
        print(test_loss)

# Entry point: build the classifier, train it, and save the trained model to disk.
if __name__ == "__main__":
    lstm_classifier = MnistLSTMClassifier()
    lstm_classifier.train(save_model=True)

Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


Now evaluate the trained model on the MNIST test set.


In [9]:
# Load the model file written by train(save_model=True) and score it on the MNIST test split.
lstm_classifier.evaluate(model="/content/saved_model/lstm-model.h5")

[0.07600377612132579, 0.9766]
