# Using dynamic RNN and LSTM with TFLite converter

This test script demonstrates that the LSTM state cannot be accessed when the model is run with the TFLite interpreter.

Most of this script is adapted from: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py

In [1]:
import numpy as np
import tensorflow as tf

from tensorflow.python.ops import control_flow_util
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs

# Restrict TensorFlow logging to ERROR level so warnings and info messages are hidden.
tf.logging.set_verbosity(tf.logging.ERROR)

  from ._conv import register_converters as _register_converters


In [2]:
# Show which TensorFlow build this notebook was run against.
print(tf.__version__)

1.13.0-dev20190228


### Set up MNIST and initial parameters

In [3]:
# Load MNIST from /tmp/data/ (downloading it if needed) with one-hot labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# --- Model and training constants ---
time_steps = 28        # the RNN is unrolled through 28 time steps
n_input = 28           # each step consumes one row of 28 pixels
learning_rate = 0.001  # learning rate for the Adam optimizer
n_classes = 10         # MNIST digits are classified into 10 classes (0-9)
batch_size = 16        # examples per training batch
num_units = 16         # LSTM units
TRAIN_STEPS = 1        # number of training iterations to run

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


### Functions for building the network

Use `tf.lite.experimental.nn.TFLiteLSTMCell` for the LSTM cells and `tf.lite.experimental.nn.dynamic_rnn` for the dynamic RNN.

In [4]:
def buildLstmLayer():
    """Return a four-layer stack of TFLite-compatible LSTM cells.

    The layers ("rnn1" .. "rnn4") mix peephole and projection settings; all
    of them use a forget bias of 0. The last layer has `num_units` units and
    no projection.
    """
    lstm_cell = tf.lite.experimental.nn.TFLiteLSTMCell
    stacked_cells = [
        lstm_cell(num_units, use_peepholes=True, forget_bias=0, name="rnn1"),
        lstm_cell(num_units, num_proj=8, forget_bias=0, name="rnn2"),
        lstm_cell(
            num_units // 2,
            use_peepholes=True,
            num_proj=8,
            forget_bias=0,
            name="rnn3"),
        lstm_cell(num_units, forget_bias=0, name="rnn4"),
    ]
    return tf.nn.rnn_cell.MultiRNNCell(stacked_cells)

def buildModel(lstm_layer, is_dynamic_rnn):
    """Build the RNN classifier graph in the current default graph.

    Args:
        lstm_layer: RNN cell (e.g. the stack from buildLstmLayer) to run.
        is_dynamic_rnn: if truthy, use the TFLite dynamic_rnn; otherwise
            unroll with tf.nn.static_rnn.

    Returns:
        A tuple (x, prediction, output_class, state_out):
        the "INPUT_IMAGE" placeholder, the logits, the "OUTPUT_CLASS"
        softmax, and the "state_out" tensor holding the two parts of
        state[0] concatenated along a new leading axis.
    """
    # Parameters of the final softmax layer.
    softmax_weights = tf.Variable(
        tf.random_normal([num_units, n_classes]))
    softmax_bias = tf.Variable(tf.random_normal([n_classes]))

    # Images arrive as [batch_size, time_steps, n_input].
    x = tf.placeholder(
        "float", [None, time_steps, n_input], name="INPUT_IMAGE")

    if not is_dynamic_rnn:
        # Static unrolling wants a Python list with one tensor per time step.
        per_step_inputs = tf.unstack(x, time_steps, 1)
        outputs, state = tf.nn.static_rnn(
            lstm_layer, per_step_inputs, dtype="float32")
    else:
        # The dynamic RNN is run time-major, so move the time axis to the front.
        time_major_input = tf.transpose(x, perm=[1, 0, 2])
        outputs, state = tf.lite.experimental.nn.dynamic_rnn(
            lstm_layer, time_major_input, dtype="float32", time_major=True)
        outputs = tf.unstack(outputs, axis=0)

    # Logits come from the last time step's output: [batch_size, num_units]
    # times [num_units, n_classes], plus the bias.
    prediction = tf.matmul(outputs[-1], softmax_weights) + softmax_bias
    output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS")

    # state[0] is presumably the first cell's (c, h) pair when lstm_layer is
    # a MultiRNNCell -- verify against the cell's state layout.
    cell_part, hidden_part = state[0]
    stacked_state = tf.concat(
        [tf.expand_dims(cell_part, 0), tf.expand_dims(hidden_part, 0)], 0)
    state_out = tf.identity(stacked_state, name='state_out')

    return x, prediction, output_class, state_out


### Utility functions for training, saving/restoring, and serving (inferencing) the model

In [5]:
def trainModel(x, prediction, output_class, sess):
    """Initialise all variables and run TRAIN_STEPS Adam updates on MNIST.

    Args:
        x: input image placeholder that each batch is fed into.
        prediction: logits tensor the cross-entropy loss is computed from.
        output_class: not used by this function.
        sess: session used for variable initialisation and training.
    """
    # Placeholder for the one-hot labels.
    labels = tf.placeholder("float", [None, n_classes])

    # Mean softmax cross-entropy, minimised with Adam.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        logits=prediction, labels=labels)
    loss = tf.reduce_mean(cross_entropy)
    train_op = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(loss)

    sess.run(tf.global_variables_initializer())

    image_shape = (batch_size, time_steps, n_input)
    for _ in range(TRAIN_STEPS):
        flat_images, targets = mnist.train.next_batch(
            batch_size=batch_size, shuffle=False)
        # Batches come back flattened; restore the sequence layout.
        images = flat_images.reshape(image_shape)
        sess.run(train_op, feed_dict={x: images, labels: targets})

def saveAndRestoreModel(lstm_layer, sess, saver, is_dynamic_rnn):
    """Checkpoint `sess`, rebuild the model in a fresh graph and restore it.

    Args:
        lstm_layer: RNN cell passed to buildModel for the rebuilt graph.
        sess: session whose variables are saved.
        saver: tf.train.Saver used to write the checkpoint.
        is_dynamic_rnn: forwarded to buildModel.

    Returns:
        (x, prediction, output_class, output_state, new_sess) for the
        rebuilt graph, where new_sess holds the restored variables.
    """
    checkpoint_prefix = 'export/dynamic_rnn'
    saver.save(sess, checkpoint_prefix)

    # Start from an empty default graph before building the model again.
    tf.reset_default_graph()
    new_x, new_prediction, new_output_class, new_output_state = buildModel(
        lstm_layer, is_dynamic_rnn)

    new_sess = tf.Session()
    # A new Saver is needed because the variables belong to the new graph.
    tf.train.Saver().restore(new_sess, checkpoint_prefix)
    return new_x, new_prediction, new_output_class, new_output_state, new_sess

def getInferenceResult(x, output_class, output_state, sess):
    """Run one MNIST sample through the model and freeze the graph.

    Args:
        x: input image placeholder.
        output_class: class-probability tensor to evaluate.
        output_state: state tensor to evaluate.
        sess: session holding the model's variables.

    Returns:
        (sample_input, expected_output, expected_state, frozen_graph).
        The frozen graph is also written to 'output_graph.pb'.
    """
    flat_image, _ = mnist.train.next_batch(batch_size=1)
    sample_input = np.reshape(flat_image, (1, time_steps, n_input))

    fetches = [output_class, output_state]
    expected_output, expected_state = sess.run(
        fetches, feed_dict={x: sample_input})

    # Freeze variables into constants, keeping both fetched tensors as outputs.
    output_node_names = [tensor.op.name for tensor in fetches]
    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names)
    with open('output_graph.pb', 'wb') as pb_file:
        pb_file.write(frozen_graph.SerializeToString())

    return sample_input, expected_output, expected_state, frozen_graph

### Create and run session

In [6]:
# Set the control-flow-v2 flag before any graph is built in the cells below.
# NOTE(review): presumably required for the TFLite dynamic_rnn path -- confirm.
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

In [7]:
# Build and train the dynamic-RNN model in its own session.
sess = tf.Session()
x, prediction, output_class, output_state = buildModel(
    lstm_layer=buildLstmLayer(), is_dynamic_rnn=True)
trainModel(x=x, prediction=prediction, output_class=output_class, sess=sess)
saver = tf.train.Saver()

# Checkpoint the trained model, then rebind the handles to a freshly rebuilt
# graph whose variables are restored into new_sess.
x, prediction, output_class, output_state, new_sess = saveAndRestoreModel(
    lstm_layer=buildLstmLayer(), sess=sess, saver=saver, is_dynamic_rnn=True)

### Get the inference result for a sanity check

In [8]:
# Evaluate one sample in the restored TensorFlow session; these values are the
# reference the TFLite results are compared against. Also freezes the graph.
test_inputs, expected_output, expected_state, frozen_graph = getInferenceResult(
    x, output_class, output_state, new_sess)

### Functions for converting graph to TFLite

#### Convert the frozen graph to a TFLite model and perform inference with the interpreter

In [9]:
def tfliteInvoke(graph, test_inputs, outputs, output_state, state_out=False):
    """Convert a frozen graph to a TFLite model and run one inference on it.

    Args:
        graph: frozen GraphDef containing an "INPUT_IMAGE" node.
        test_inputs: array set on the interpreter's first input tensor.
        outputs: tensor passed to the converter as the first model output.
        output_state: tensor passed to the converter as a second model
            output when `state_out` is set.
        state_out: if True, include `output_state` as an output tensor and
            read it back after invocation.

    Returns:
        (result, state): `result` is the value of the interpreter's first
        output tensor; `state` is the value of its second output tensor when
        `state_out` is set, otherwise None. This assumes the interpreter
        lists outputs in the order they were given to the converter.

    Raises:
        AssertionError: if the interpreter cannot allocate its tensors.
    """
    tf.reset_default_graph()
    # Turn the input into a placeholder with a fixed batch dimension of 1.
    tflite_input = tf.placeholder(
        "float", [1, time_steps, n_input], name="INPUT_IMAGE_LITE")
    tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input})
    with tf.Session() as sess:
        curr = convert_op_hints_to_stubs(graph_def=sess.graph_def)

    # If set, include the state as a second output tensor.
    output_tensors = [outputs, output_state] if state_out else [outputs]
    converter = tf.lite.TFLiteConverter(curr, [tflite_input], output_tensors)

    tflite = converter.convert()
    interpreter = tf.lite.Interpreter(model_content=tflite)

    try:
        interpreter.allocate_tensors()
    except ValueError as err:
        # Raise explicitly instead of `assert False`: the failure must still
        # surface under `python -O`, and the original message is kept.
        raise AssertionError(
            "TFLite interpreter failed to allocate tensors: %s" % err)

    input_index = interpreter.get_input_details()[0]["index"]
    interpreter.set_tensor(input_index, test_inputs)

    # note: kernel will crash here if trying to include lstm state
    interpreter.invoke()
    output_details = interpreter.get_output_details()
    result = interpreter.get_tensor(output_details[0]["index"])

    # If set, get the state from the interpreter.
    if state_out:
        state = interpreter.get_tensor(output_details[1]["index"])
    else:
        state = None

    # Reset all variables so it will not pollute other inferences.
    interpreter.reset_all_variables()

    return result, state

In [10]:
# Convert and run with only the class output; state_out defaults to False, so state is None.
result, state = tfliteInvoke(frozen_graph, test_inputs, output_class, output_state)

In [11]:
# Check that the TFLite output matches the TensorFlow session output within an absolute tolerance of 1e-2.
print(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))

True


### Kernel crashes while trying to access LSTM state

#### Note: the kernel crashes in `interpreter.invoke()` when the LSTM state is included as an output

In [None]:
# Same conversion, but also requesting the LSTM state as a second output (state_out=True).
result, state = tfliteInvoke(frozen_graph, test_inputs, output_class, output_state, state_out=True)