# QuantLSTM - ONNX (QCDQ) representation

This notebook is divided into `five` parts:

<br><b>Part 1</b> : Introduction to LSTMs.
<br>
<br><b>Part 2</b> : Model creation with brevitas QuantLSTM layer. 
<br>
<br><b>Part 3</b> : Build ONNX model representing the LSTM computation used to process a single input with `QCDQ quantization` (weights/inputs/activations) 
<br>
<br><b>Part 4</b> : Integration of the QCDQ-LSTM graph with the `SCAN` operator. This operator expresses the cyclic computations (<i>required for state updates in recurrent neural networks</i>) that an ONNX graph cannot otherwise represent.
<br>
<br><b>Part 5</b> : Functional verification of the `QCDQ-LSTM` model with brevitas `QuantLSTM` model output.

# Introduction to LSTMs

`LSTMs (Long Short-Term Memory networks)` are recurrent neural networks capable of learning long-term dependencies, which makes them well suited to sequence prediction problems. They are used in machine translation, speech recognition, image captioning, and time-series analysis.

LSTMs have `feedback connections`, unlike conventional feed-forward neural networks (where the compute path goes only in the forward direction). This makes them capable of processing time-series data such as video streams or network traffic patterns.
These feedback connections also complicate hardware implementations, because they require state updates that feed-forward neural networks do not.
<br>
<br>
An LSTM cell computes the following six equations at each time step:
$$
  f_t = \sigma (W_f * x_t + U_f * H_{t-1} + b_f) 
$$
$$
  i_t = \sigma (W_i * x_t + U_i * H_{t-1} + b_i)
$$
$$
   \tilde{C_t} = \tanh(W_c * x_t + U_c * H_{t-1} + b_c)
$$
$$
  o_t = \sigma (W_o * x_t + U_o * H_{t-1} + b_o)
$$
$$
  C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C_t}
$$
$$
  H_t = \tanh(C_t) \odot o_t
$$

The first four equations represent the `gate computations`.
We compute the `cell state` and the `hidden state` in the last two equations respectively. 
These two states are then fed back into the LSTM cell for the computation of the next input.
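
As a minimal sketch of the six equations, the following plain-Python function computes one time step without any quantization. The helper names (`matvec`, `gate`, `lstm_step`) and the layout of `params` are assumptions made for illustration; they are not part of brevitas or ONNX.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def gate(W, U, b, x, h_prev, act):
    """act(W x + U h_prev + b): the structure shared by the four gate equations."""
    wx = matvec(W, x)
    uh = matvec(U, h_prev)
    return [act(p + q + r) for p, q, r in zip(wx, uh, b)]

def lstm_step(x, h_prev, c_prev, params):
    """One LSTM time step. `params` maps 'f', 'i', 'c', 'o' to (W, U, b) tuples."""
    f = gate(*params["f"], x, h_prev, sigmoid)          # forget gate
    i = gate(*params["i"], x, h_prev, sigmoid)          # input gate
    c_tilde = gate(*params["c"], x, h_prev, math.tanh)  # candidate cell state
    o = gate(*params["o"], x, h_prev, sigmoid)          # output gate
    c = [ft * cp + it * ct for ft, cp, it, ct in zip(f, c_prev, i, c_tilde)]
    h = [math.tanh(ct) * ot for ct, ot in zip(c, o)]
    return h, c

# Toy example: 2 features, 1 hidden cell, all-zero parameters (illustrative values only).
zero = ([[0.0, 0.0]], [[0.0]], [0.0])
params = {k: zero for k in "fico"}
h, c = lstm_step([0.8, 0.8], [0.0], [1.0], params)
print(h, c)
```

Calling `lstm_step` again with the returned `h` and `c` processes the next input, which is the feedback described above.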

# QuantLSTM model creation

In the 2nd part of the notebook, we will create a single-layer `QuantLSTM` model in brevitas and evaluate it with a fixed test input. We then export this model to ONNX in the `QCDQ` representation so that the same parameters (weights/biases/scales) can be extracted and used in the `QCDQ-LSTM` implementation.

In [1]:
# We import the libraries required by the different steps of the notebook.
# The first four imports are required to build the QuantLSTM model in brevitas.
# The model created will then be exported in the QCDQ representation and its parameters used in the QCDQ implementation.

import torch
from torch import nn
from brevitas.nn import QuantLSTM
from brevitas.export import export_onnx_qcdq

#We need the onnx and onnx helper nodes to build the onnx graph for the LSTM compute.
import onnx
from onnx import numpy_helper
from onnx.helper import make_tensor_value_info, make_node, make_graph, make_model, make_tensor
#onnxruntime will be used to execute our onnx model.
import onnxruntime as rt 
from qonnx.util.basic import qonnx_make_model
#numpy allows us to manipulate outputs from the brevitas and the ONNX model
import numpy as np 
# The Netron visualization tool lets us view interactive graphs
import netron


In [2]:
# In this block of code we will create the QuantLSTM model using the brevitas layer
torch.manual_seed(0) #Setting the manual seeds to 0 for consistency in outputs.

# Initializing attributes that can be changed according to the user's requirements.

num_inputs = 25                 #Defining the number of inputs 
num_features_brevitas = 10      #This attribute defines number of features each input comprises of
num_hidden_cells_brevitas = 20  #This attribute defines the number of hidden cells in the QuantLSTM layer

# Creating a sequential model

model_lstm = nn.Sequential( 
    QuantLSTM(input_size = num_features_brevitas, hidden_size = num_hidden_cells_brevitas, bias_quant=None) 
    )                           #With no other arguments given, inputs/parameters/activations are quantized to 8 bits.
model_lstm.eval()               #Setting the model to eval mode to make sure all the parameters and scales are frozen and not updated at runtime.
export_path = './quant_lstm_quantization_qcdq.onnx' #Setting export path for the model
export_onnx_qcdq(model_lstm,(torch.randn(num_inputs, 1, num_features_brevitas)), opset_version=14, export_path=export_path) #Exporting the model to QCDQ representation. 

# Creating a test input to execute the above created model

in_qcdq_node = np.full([num_inputs,1,num_features_brevitas], 0.8, dtype=np.float32)  #Creating a numpy array filled with a value of 0.8
test_input = torch.from_numpy(in_qcdq_node)     #Converting the array to a torch tensor
brevitas_output = model_lstm(test_input)        #Executing the model with the set input
brevitas_output = brevitas_output[0].detach().numpy()
print(brevitas_output)

quant_input_supplied to brevitas =  tensor([[-1.0000, -0.5000, -1.0000,  0.5156, -1.0000,  0.9922, -0.8047, -1.0000,
          0.2188,  0.9922]])
----------------------------
quant_input_supplied to brevitas =  tensor([[-0.7266, -0.9531,  0.9922,  0.9922, -1.0000,  0.9922, -0.7734, -1.0000,
         -0.0859,  0.6250]])
----------------------------
quant_input_supplied to brevitas =  tensor([[-0.6719, -1.0000,  0.0547, -0.5234, -0.0000,  0.1250, -1.0000,  0.3047,
         -0.0312, -1.0000]])
----------------------------
quant_input_supplied to brevitas =  tensor([[-1.0000, -0.1797,  0.3516, -0.1328, -1.0000, -1.0000,  0.8750, -0.2812,
          0.4844, -0.3203]])
----------------------------
quant_input_supplied to brevitas =  tensor([[ 0.6719, -0.1484,  0.5078,  0.5312, -0.2969,  0.1719, -1.0000,  0.4688,
         -0.2500,  0.8672]])
----------------------------
quant_input_supplied to brevitas =  tensor([[ 0.3125,  0.9922,  0.8281, -0.4297, -1.0000,  0.9922, -1.0000,  0.9922,
        

`Abbreviations` : The short forms used in the next code block are defined here.

* <b>Wi</b> = "Weight matrix for the input gate" (Similarly for the other three gates)
* <b>Ui</b> = "Recurrence matrix for the input gate" (Similarly for the other three gates)
* <b>bi</b> = "Bias for the input gate" (Similarly for the other three gates)

In [3]:
# In this block of code we store all the parameters (weight matrices, recurrence matrices, biases, scales and zero-points) that we will need to import in the QCDQ implementation.
# Importing the exported quantized model from brevitas
brevitas_lstm_export = onnx.load("./quant_lstm_quantization_qcdq.onnx")
parameters = brevitas_lstm_export.graph.initializer #Extracting all the parameters from the loaded graph

# Print every parameter together with its index so that each value can be assigned to the right variable below
for i, param in enumerate(parameters):
    w = numpy_helper.to_array(param)
    print(param.name)
    print(w.shape)
    print(w,',',i)
    print("-------------------------")
    
# Storing the extracted parameters (weights/biases/scales) in the right variables, based on the order in which they are exported.
# The abbreviations described above explain what each variable denotes.

bi_val = numpy_helper.to_array(parameters[0])
Wi_val = numpy_helper.to_array(parameters[1])
Ui_val = numpy_helper.to_array(parameters[2])
bf_val = numpy_helper.to_array(parameters[3])
Wf_val = numpy_helper.to_array(parameters[4])
Uf_val = numpy_helper.to_array(parameters[5])
bc_val = numpy_helper.to_array(parameters[6])
Wc_val = numpy_helper.to_array(parameters[7])
Uc_val = numpy_helper.to_array(parameters[8])
bo_val = numpy_helper.to_array(parameters[9])
Wo_val = numpy_helper.to_array(parameters[10])
Uo_val = numpy_helper.to_array(parameters[11])
# Scalar values can either be int or float
inp_scale_val = float(numpy_helper.to_array(parameters[12])) 
w1_scale_val = float(numpy_helper.to_array(parameters[15]))
w2_scale_val = float(numpy_helper.to_array(parameters[18]))
w3_scale_val = float(numpy_helper.to_array(parameters[19]))
w4_scale_val = float(numpy_helper.to_array(parameters[20]))
eq_scale_val_1 = float(numpy_helper.to_array(parameters[12]))
eq_scale_val_2 = float(numpy_helper.to_array(parameters[22]))

0.layers.0.0.input_gate_params.bias
(20,)
[-0.02587563 -0.18425222 -0.18189065  0.02914573 -0.21827428  0.0595416
 -0.20598626 -0.15559138 -0.04639753 -0.2133838   0.18059207  0.18321364
 -0.11679631  0.04684116  0.11439164  0.07105622 -0.02995344 -0.21090843
  0.1625932  -0.19612479] , 0
-------------------------
0.layers.0.0.input_gate_params.input_weight.weight
(20, 10)
[[-4.14119214e-02  1.38706667e-02 -7.36431107e-02 -8.17852393e-02
  -1.93256751e-01  1.23205660e-02 -2.53894478e-02  1.94940954e-01
  -7.36160800e-02  1.72829047e-01]
 [ 1.05855539e-02 -1.00462548e-01 -5.31778559e-02 -2.53751595e-02
   2.31616711e-03 -3.68398018e-02  6.63604736e-02  1.84143797e-01
   3.51473056e-02  8.09932351e-02]
 [ 1.38081744e-01  4.81988601e-02  1.03076197e-01  1.17293097e-01
   2.09298924e-01 -2.04075590e-01  7.65163079e-02 -1.01319486e-02
  -4.01576199e-02 -8.62098187e-02]
 [ 1.34432539e-01  2.04552680e-01 -1.82483241e-01  1.20810278e-01
   1.54187992e-01  3.90806384e-02  2.63404008e-03  1.7207

# LSTM ONNX model

In the 3rd part of the notebook, we will construct the `QCDQ-LSTM` model with standard ONNX operators. With all the parameters loaded in the block above, we can now start building our ONNX model with QCDQ quantization to represent the LSTM computations described in Part 1.


In [4]:
# Setting parameters: these must match the input and hidden sizes of the model exported from brevitas
num_features = 10
num_hidden_cells = 20
activation_bit_width = 8

# Bounds for the 'Clip' nodes: the symmetric signed range for the chosen bit-width (-127 to 127 for 8 bits).
max_clip_val = (2 ** (activation_bit_width -1) - 1)
min_clip_val = -(2 ** (activation_bit_width -1) - 1)

# Zero points for quantization. The datatype of the zero point decides the datatype of the quantized output tensor,
# so we define two: one for signed (INT8) and one for unsigned (UINT8) outputs.
zero_point_signed_val = 0
zero_point_unsigned_val = 0
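
To illustrate why two zero points are kept, here is a scalar sketch of 8-bit quantization in plain Python. It follows the documented behaviour of the ONNX `QuantizeLinear` operator (round to nearest even, add the zero point, saturate to the output type), but the function name `quantize_linear`, its `signed` flag, and the scale of 1/256 are assumptions chosen for illustration, not values from the exported model.

```python
def quantize_linear(x, scale, zero_point=0, signed=True):
    """Scalar model of 8-bit quantization: round, shift by the zero point, saturate."""
    lo, hi = (-128, 127) if signed else (0, 255)
    q = round(x / scale) + zero_point   # Python's round() rounds halves to even
    return max(lo, min(hi, q))

scale = 1 / 256  # illustrative scale only

# A sigmoid output lies in [0, 1], so the unsigned range can use all 8 bits for it.
print(quantize_linear(0.9, scale, signed=False))   # 230
print(quantize_linear(0.9, scale, signed=True))    # 127 (saturated)

# A signed value such as a tanh output needs the signed range.
print(quantize_linear(-0.3, scale, signed=True))   # -77
print(quantize_linear(-0.3, scale, signed=False))  # 0 (saturated)
```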

`Abbreviations` : These describe different short-forms used in the next two blocks.

* <b>ql</b> = "QuantizeLinear"
* <b>dql</b> = "DequantizeLinear"
* <b>clp</b> = "Clip"
* <b>id</b> = "Identity"
* <b>matmul</b> = "Matrix Multiplication"
* <b>el_mul</b> = "Elementwise Multiplication"
* <b>sig</b> = "Sigmoid"
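
The `ql`, `clp` and `dql` nodes always appear as a chain. A minimal scalar sketch of that chain, assuming a signed 8-bit container and the symmetric clip bounds defined earlier, could look as follows. The function name `qcdq` and the scale of 0.125 are hypothetical and chosen only so that the arithmetic is easy to follow.

```python
def qcdq(x, scale, bit_width, zero_point=0):
    """Scalar model of QuantizeLinear -> Clip -> DequantizeLinear."""
    q = round(x / scale) + zero_point      # QuantizeLinear: round to nearest even
    q = max(-128, min(127, q))             # ... and saturate to the int8 range
    bound = 2 ** (bit_width - 1) - 1
    q = max(-bound, min(bound, q))         # Clip: narrow to the requested bit-width
    return (q - zero_point) * scale        # DequantizeLinear: back to float

print(qcdq(0.8, 0.125, bit_width=8))    # 0.75    (0.8 snaps to the nearest step)
print(qcdq(1.5, 0.125, bit_width=4))    # 0.875   (12 is clipped to 7)
print(qcdq(-20.0, 0.125, bit_width=8))  # -15.875 (-160 saturates, then clips to -127)
```

With `bit_width=8` the Clip only removes the value -128; lowering `bit_width` is what emulates quantization below 8 bits inside an 8-bit tensor.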

We start the model definition with its `inputs` and `outputs`, which ONNX describes as value_info tensors.
For LSTMs we need three inputs : `inputs`, `previous hidden state` and `previous cell state`. 
We get three outputs : `hidden_state`, `cell_state` and `concatenated_hidden_states`.
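
The split between the two state outputs and the concatenated output can be pictured with a plain-Python loop: the states are carried from one step to the next, while the third output is collected at every step. The sketch below is only an analogy for that pattern. `scan` and `toy_body` are hypothetical names, and the toy body is deliberately not an LSTM.

```python
def scan(body, init_states, sequence):
    """Carry the states across the sequence and collect one output per step."""
    states = tuple(init_states)
    collected = []
    for x in sequence:
        # The body returns the updated states followed by the per-step output.
        *states, step_out = body(*states, x)
        collected.append(step_out)
    return (*states, collected)

def toy_body(h_prev, c_prev, x):
    """Stand-in for the LSTM body: same three inputs and three outputs."""
    c = c_prev + x
    h = h_prev + c
    return h, c, h   # hidden state, cell state, copy of h for collection

h_final, c_final, h_all = scan(toy_body, (0.0, 0.0), [1.0, 2.0])
print(h_final, c_final, h_all)   # 4.0 3.0 [1.0, 4.0]
```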

In [5]:
# Defining the inputs 'value info' tensors for the compute graph.
hidden_state = make_tensor_value_info("h_t-1",onnx.TensorProto.FLOAT, [num_hidden_cells,1])
cell_state = make_tensor_value_info("c_t-1", onnx.TensorProto.FLOAT, [num_hidden_cells,1])
inputs = make_tensor_value_info("inp",onnx.TensorProto.FLOAT, [num_features,1])

#Output value info tensor definitions
out_hidden_state = make_tensor_value_info("h_t", onnx.TensorProto.FLOAT, [num_hidden_cells,1])
out_cell_state = make_tensor_value_info("c_t", onnx.TensorProto.FLOAT, [num_hidden_cells,1])
out_hidden_state_concat = make_tensor_value_info("h_t_concat", onnx.TensorProto.FLOAT, [num_hidden_cells,1]) #Note: this may need one more dimension once the per-step outputs are concatenated

In [6]:
# Once we have defined the inputs and outputs, we will now start defining the operations in the LSTM compute graph.
# We start by quantizing the input with the standard QDQ operation, which performs 8-bit quantization.
# Note: For quantization to lower bit-widths we can use the Clip node.

# Input quantization
ql_input = make_node("QuantizeLinear", inputs=["inp","inp_scale","zero_point_signed"], outputs=["ql_input_out"],name="ql_input")
id_0_input = make_node("Identity", inputs=["ql_input_out"], outputs=["first_input_out"], name="id_0_input")
dql_input = make_node("DequantizeLinear", inputs=["ql_input_out", 'inp_scale', "zero_point_signed"], outputs=["dql_input_out"],name="dql_input")

# Quantization of the four weight matrices showing QCDQ quantization
ql_w1 = make_node("QuantizeLinear", inputs=["W_f","scale_f","zero_point_signed"], outputs=["ql_wf_out"], name="ql_w1")
clp_w1 = make_node("Clip", inputs=["ql_wf_out","min","max"], outputs=["clp_wf"], name="clp_w1")
dql_w1 = make_node("DequantizeLinear", inputs=["clp_wf","scale_f","zero_point_signed"], outputs=["dql_wf_out"], name="dql_w1")

ql_w2 = make_node("QuantizeLinear", inputs=["W_i","scale_i","zero_point_signed"], outputs=["ql_wi_out"], name="ql_w2")
clp_w2 = make_node("Clip", inputs=["ql_wi_out","min","max"], outputs=["clp_wi"], name="clp_w2")
dql_w2 = make_node("DequantizeLinear", inputs=["clp_wi","scale_i","zero_point_signed"], outputs=["dql_wi_out"], name="dql_w2")

ql_w3 = make_node("QuantizeLinear", inputs=["W_c","scale_c","zero_point_signed"], outputs=["ql_wc_out"], name="ql_w3")
clp_w3 = make_node("Clip", inputs=["ql_wc_out","min","max"], outputs=["clp_wc"], name="clp_w3")
dql_w3 = make_node("DequantizeLinear", inputs=["clp_wc","scale_c","zero_point_signed"], outputs=["dql_wc_out"], name="dql_w3")

ql_w4 = make_node("QuantizeLinear", inputs=["W_o","scale_o","zero_point_signed"], outputs=["ql_wo_out"], name="ql_w4")
clp_w4 = make_node("Clip", inputs=["ql_wo_out","min","max"], outputs=["clp_wo"], name="clp_w4")
dql_w4 = make_node("DequantizeLinear", inputs=["clp_wo","scale_o","zero_point_signed"], outputs=["dql_wo_out"], name="dql_w4")

# Quantizations for the four recurrence weight matrices showing QCDQ quantization
ql_u1 = make_node("QuantizeLinear", inputs=["U_f","scale_f","zero_point_signed"], outputs=["ql_uf_out"], name="ql_u1")
clp_u1 = make_node("Clip", inputs=["ql_uf_out","min","max"], outputs=["clp_uf"], name="clp_u1")
dql_u1 = make_node("DequantizeLinear", inputs=["clp_uf","scale_f","zero_point_signed"], outputs=["dql_uf_out"], name="dql_u1")

ql_u2 = make_node("QuantizeLinear", inputs=["U_i","scale_i","zero_point_signed"], outputs=["ql_ui_out"], name="ql_u2")
clp_u2 = make_node("Clip", inputs=["ql_ui_out","min","max"], outputs=["clp_ui"], name="clp_u2")
dql_u2 = make_node("DequantizeLinear", inputs=["clp_ui","scale_i","zero_point_signed"], outputs=["dql_ui_out"], name="dql_u2")

ql_u3 = make_node("QuantizeLinear", inputs=["U_c","scale_c","zero_point_signed"], outputs=["ql_uc_out"], name="ql_u3")
clp_u3 = make_node("Clip", inputs=["ql_uc_out","min","max"], outputs=["clp_uc"], name="clp_u3")
dql_u3 = make_node("DequantizeLinear", inputs=["clp_uc","scale_c","zero_point_signed"], outputs=["dql_uc_out"], name="dql_u3")

ql_u4 = make_node("QuantizeLinear", inputs=["U_o","scale_o","zero_point_signed"], outputs=["ql_uo_out"], name="ql_u4")
clp_u4 = make_node("Clip", inputs=["ql_uo_out","min","max"], outputs=["clp_uo"], name="clp_u4")
dql_u4 = make_node("DequantizeLinear", inputs=["clp_uo","scale_o","zero_point_signed"], outputs=["dql_uo_out"], name="dql_u4")

# Once we have quantized the weights and inputs we can now start defining the operations for the 6 LSTM equations.
# The first four gate equations have a very similar compute structure. We define the first four gate computations in this order : Forget, Input, Output, Cell 

# 1st Equation : Forget gate
matmul_1_e1 = make_node("MatMul", inputs=["dql_wf_out","dql_input_out"], outputs=["out_m1_e1"], name="matmul_1_e1")
matmul_2_e1 = make_node("MatMul", inputs=["dql_uf_out","h_t-1"], outputs=["out_m2_e1"],name="matmul_2_e1")
add_1_e1 = make_node("Add", inputs=["out_m1_e1","out_m2_e1"], outputs=["out_add1_e1"],name="add_1_e1")
add_2_e1 = make_node("Add", inputs=["out_add1_e1","b_f"], outputs=["f_t_ba"],name="add_2_e1")
ql_1_e1 = make_node("QuantizeLinear", inputs=["f_t_ba","scale_3","zero_point_signed"], outputs=["f_t_ql1"],name="ql_1_e1")
dql_1_e1 = make_node("DequantizeLinear", inputs=["f_t_ql1", "scale_4", "zero_point_signed"], outputs=["f_t_dql1"], name="dql_1_e1")
sig_f_e1     = make_node("Sigmoid", inputs=["f_t_dql1"], outputs=["f_t"],name="sig_f_e1")
ql_2_e1 = make_node("QuantizeLinear", inputs=["f_t","scale_4","zero_point_unsigned"], outputs=["f_t_ql2"],name="ql_2_e1")
dql_2_e1 = make_node("DequantizeLinear", inputs=["f_t_ql2", "scale_4", "zero_point_unsigned"], outputs=["f_t_dql2"], name="dql_2_e1")

# 2nd Equation : Input gate
matmul_1_e2 = make_node("MatMul", inputs=["dql_wi_out","dql_input_out"], outputs=["out_m1_e2"], name="matmul_1_e2")
matmul_2_e2 = make_node("MatMul", inputs=["dql_ui_out","h_t-1"], outputs=["out_m2_e2"],name="matmul_2_e2")
add_1_e2 = make_node("Add", inputs=["out_m1_e2","out_m2_e2"], outputs=["out_add1_e2"],name="add_1_e2")
add_2_e2 = make_node("Add", inputs=["out_add1_e2","b_i"], outputs=["i_t_ba"],name="add_2_e2")
ql_1_e2 = make_node("QuantizeLinear", inputs=["i_t_ba","scale_1","zero_point_signed"], outputs=["i_t_ql1"],name="ql_1_e2")
dql_1_e2 = make_node("DequantizeLinear", inputs=["i_t_ql1","scale_1", "zero_point_signed"], outputs=["i_t_dql1"], name="dql_1_e2")
sig_i_e2     = make_node("Sigmoid", inputs=["i_t_dql1"], outputs=["i_t"],name="sig_i_e2")
ql_2_e2 = make_node("QuantizeLinear", inputs=["i_t","scale_2","zero_point_unsigned"], outputs=["i_t_ql2"],name="ql_2_e2")
dql_2_e2 = make_node("DequantizeLinear", inputs=["i_t_ql2", "scale_2", "zero_point_unsigned"], outputs=["i_t_dql2"], name="dql_2_e2")

# 3rd Equation : Output gate
matmul_1_e3 = make_node("MatMul", inputs=["dql_wo_out","dql_input_out"], outputs=["out_m1_e3"], name="matmul_1_e3")
matmul_2_e3 = make_node("MatMul", inputs=["dql_uo_out","h_t-1"], outputs=["out_m2_e3"],name="matmul_2_e3")
add_1_e3 = make_node("Add", inputs=["out_m1_e3","out_m2_e3"], outputs=["out_add1_e3"],name="add_1_e3")
add_2_e3 = make_node("Add", inputs=["out_add1_e3","b_o"], outputs=["o_t_ba"],name="add_2_e3" )
ql_1_e3 = make_node("QuantizeLinear", inputs=["o_t_ba","scale_7","zero_point_signed"], outputs=["o_t_ql1"],name="ql_1_e3")
dql_1_e3 = make_node("DequantizeLinear", inputs=["o_t_ql1","scale_7", "zero_point_signed"], outputs=["o_t_dql1"], name="dql_1_e3")
sig_o_e3     = make_node("Sigmoid", inputs=["o_t_dql1"], outputs=["o_t"],name="sig_o_e3")
ql_2_e3 = make_node("QuantizeLinear", inputs=["o_t","scale_8","zero_point_unsigned"], outputs=["o_t_ql2"],name="ql_2_e3")
dql_2_e3 = make_node("DequantizeLinear", inputs=["o_t_ql2", "scale_8", "zero_point_unsigned"], outputs=["o_t_dql2"], name="dql_2_e3")

# 4th Equation : Cell gate
matmul_1_e4 = make_node("MatMul", inputs=["dql_wc_out","dql_input_out"], outputs=["out_m1_e4"], name="matmul_1_e4")
matmul_2_e4 = make_node("MatMul", inputs=["dql_uc_out","h_t-1"], outputs=["out_m2_e4"],name="matmul_2_e4")
add_1_e4 = make_node("Add", inputs=["out_m1_e4","out_m2_e4"], outputs=["out_add1_e4"],name="add_1_e4")
add_2_e4 = make_node("Add", inputs=["out_add1_e4","b_c"], outputs=["c_t_ba"],name="add_2_e4")
ql_1_e4 = make_node("QuantizeLinear", inputs=["c_t_ba","scale_5","zero_point_signed"], outputs=["c_t_ql1"],name="ql_1_e4")
dql_1_e4 = make_node("DequantizeLinear", inputs=["c_t_ql1","scale_5", "zero_point_signed"], outputs=["c_t_dql1"], name="dql_1_e4")
tanh_c_e4    = make_node("Tanh", inputs=["c_t_dql1"], outputs=["c_t_partial"],name="tanh_c_e4")
ql_2_e4 = make_node("QuantizeLinear", inputs=["c_t_partial","scale_6","zero_point_signed"], outputs=["c_t_ql2"],name="ql_2_e4")
dql_2_e4 = make_node("DequantizeLinear", inputs=["c_t_ql2", "scale_6", "zero_point_signed"], outputs=["c_t_dql2"], name="dql_2_e4")

# Once we have the first four gate computations we can proceed with the computation of the cell_state and the hidden_state in the 5th and the 6th equations.
# 5th Equation : Cell state compute
el_mul_1_e5 = make_node("Mul", inputs=["f_t_dql2","c_t-1"], outputs=["out_el_mul1_e5"],name="el_mul_1_e5")
ql_1_e5 = make_node("QuantizeLinear", inputs=["out_el_mul1_e5","scale_9","zero_point_signed"], outputs=["fifth_ql1"],name="ql_1_e5")
dql_1_e5 = make_node("DequantizeLinear", inputs=["fifth_ql1","scale_9", "zero_point_signed"], outputs=["fifth_dql1"], name="dql_1_e5")
el_mul_2_e5 = make_node("Mul", inputs=["i_t_dql2","c_t_dql2"], outputs=["out_el_mul2_e5"], name="el_mul_2_e5") 
ql_2_e5 = make_node("QuantizeLinear", inputs=["out_el_mul2_e5","scale_9","zero_point_signed"], outputs=["fifth_ql2"],name="ql_2_e5")
dql_2_e5 = make_node("DequantizeLinear", inputs=["fifth_ql2","scale_9", "zero_point_signed"], outputs=["fifth_dql2"], name="dql_2_e5")
add_1_e5     = make_node("Add", inputs=["fifth_dql1","fifth_dql2"], outputs=["c_t"], name="add_1_e5")   #-----------------> The first output is computed here.
ql_3_e5 = make_node("QuantizeLinear", inputs=["c_t","scale_9","zero_point_signed"], outputs=["h_t_ql"], name="ql_3_e5")
dql_3_e5 = make_node("DequantizeLinear", inputs=["h_t_ql","scale_9","zero_point_signed"], outputs=["h_t_dql"], name="dql_3_e5")

# 6th Equation : Hidden state compute
tanh_node_e6    = make_node("Tanh", inputs=["h_t_dql"], outputs=["out_tanh_e6"], name="tanh_node_e6") 
ql_1_e6 = make_node("QuantizeLinear", inputs=["out_tanh_e6","scale_10","zero_point_signed"], outputs=["sixth_ql1"], name="ql_1_e6")
dql_1_e6 = make_node("DequantizeLinear", inputs=["sixth_ql1","scale_10","zero_point_signed"], outputs=["sixth_dql1"], name="dql_1_e6")
el_mul_1_e6 = make_node("Mul", inputs=["sixth_dql1","o_t_dql2"], outputs=["h_t_inter"], name="el_mul_1_e6")#h_t_inter
ql_2_e6 = make_node("QuantizeLinear", inputs=["h_t_inter","scale_11","zero_point_signed"], outputs=["sixth_ql2"], name="ql_2_e6")
dql_2_e6 = make_node("DequantizeLinear", inputs=["sixth_ql2","scale_11","zero_point_signed"], outputs=["h_t"], name="dql_2_e6") #-----------------> The second output is computed here.
id_1_e6 = make_node("Identity", inputs=["h_t"], outputs=["h_t_concat"], name="id_1_e6") #-----------------> The third output is computed here.

After defining the above operations, we now connect them into a graph with the `make_graph` utility function from `onnx.helper`.
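
ONNX expects the node list of a graph to be topologically sorted, so every node must come after the nodes that produce its inputs. A small pre-check for a long hand-written list can be sketched in plain Python. The helper name `first_unresolved` and the tuple layout are assumptions for illustration; with real nodes, each tuple could be built as `(n.name, list(n.input), list(n.output))`.

```python
def first_unresolved(nodes, available):
    """Return (node_name, tensor_name) for the first node that reads a tensor
    not yet produced, or None if the list is in a valid order.

    nodes:     list of (name, inputs, outputs) tuples in graph order
    available: names of graph inputs and initializers
    """
    known = set(available)
    for name, inputs, outputs in nodes:
        for tensor in inputs:
            if tensor not in known:
                return name, tensor
        known.update(outputs)
    return None

# Two of the nodes defined earlier, described by their tensor names only.
nodes = [
    ("ql_input", ["inp", "inp_scale", "zero_point_signed"], ["ql_input_out"]),
    ("dql_input", ["ql_input_out", "inp_scale", "zero_point_signed"], ["dql_input_out"]),
]
available = {"inp", "inp_scale", "zero_point_signed"}

print(first_unresolved(nodes, available))                  # None
print(first_unresolved(list(reversed(nodes)), available))  # ('dql_input', 'ql_input_out')
```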

In [7]:
lstm_body = make_graph(
    nodes=[
           ql_input,
           dql_input, 
           ql_w1,
           clp_w1, 
           dql_w1,
           ql_w2,
           clp_w2, 
           dql_w2,
           ql_w3,
           clp_w3, 
           dql_w3,
           ql_w4,
           clp_w4, 
           dql_w4,
           ql_u1,
           clp_u1, 
           dql_u1,
           ql_u2,
           clp_u2,
           dql_u2,    
           ql_u3,
           clp_u3,
           dql_u3,    
           ql_u4,
           clp_u4,
           dql_u4, 
           matmul_1_e1,
           matmul_2_e1, 
           add_1_e1, 
           add_2_e1,
           ql_1_e1,
           dql_1_e1,
           sig_f_e1,
           ql_2_e1, 
           dql_2_e1, 
           matmul_1_e2,
           matmul_2_e2, 
           add_1_e2, 
           add_2_e2,
           ql_1_e2,
           dql_1_e2,
           sig_i_e2,
           ql_2_e2, 
           dql_2_e2, 
           matmul_1_e3,
           matmul_2_e3, 
           add_1_e3, 
           add_2_e3,
           ql_1_e3,
           dql_1_e3,
           sig_o_e3,
           ql_2_e3, 
           dql_2_e3,  
           matmul_1_e4,
           matmul_2_e4, 
           add_1_e4, 
           add_2_e4,
           ql_1_e4,
           dql_1_e4,
           tanh_c_e4,
           ql_2_e4, 
           dql_2_e4, 
           el_mul_1_e5,
           ql_1_e5, 
           dql_1_e5,
           el_mul_2_e5,
           ql_2_e5,
           dql_2_e5,
           add_1_e5,
           ql_3_e5, 
           dql_3_e5,
           tanh_node_e6,
           ql_1_e6, 
           dql_1_e6,
           el_mul_1_e6,
           ql_2_e6,
           dql_2_e6,   
           id_1_e6
          ],
    name = "qcdq-lsmt-body",
    inputs=[hidden_state,cell_state,inputs], #The order in which the inputs are defined here should match the input order when the scan node is defined.
    outputs = [out_hidden_state, out_cell_state, out_hidden_state_concat],
    value_info=[
            make_tensor_value_info("ql_input_out",onnx.TensorProto.INT8, [num_features,1]),
            make_tensor_value_info("dql_input_out",onnx.TensorProto.FLOAT, [num_features,1]),
            make_tensor_value_info("out_m1_e1",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m2_e1",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_add1_e1",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("f_t_ba",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("f_t_ql1",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("f_t_dql1", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("f_t_ql2",onnx.TensorProto.UINT8, [num_hidden_cells,1]),
            make_tensor_value_info("f_t_dql2", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m1_e2",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m2_e2",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_add1_e2",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("i_t_ba",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("i_t_ql1",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("i_t_dql1", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("i_t_ql2",onnx.TensorProto.UINT8, [num_hidden_cells,1]),
            make_tensor_value_info("i_t_dql2", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m1_e3",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m2_e3",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_add1_e3",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("o_t_ba",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("o_t_ql1",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("o_t_dql1", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("o_t_ql2",onnx.TensorProto.UINT8, [num_hidden_cells,1]),
            make_tensor_value_info("o_t_dql2", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m1_e4",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_m2_e4",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_add1_e4",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("c_t_ba",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("c_t_ql1",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("c_t_dql1", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("c_t_ql2",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("c_t_dql2", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("f_t",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("i_t",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("o_t",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("c_t_partial",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_el_mul1_e5",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_el_mul2_e5",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("fifth_ql1",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("fifth_dql1", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("fifth_ql2",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("fifth_dql2", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("h_t_ql",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("h_t_dql", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("out_tanh_e6",onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("sixth_ql1",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("sixth_dql1", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("sixth_ql2",onnx.TensorProto.INT8, [num_hidden_cells,1]),
            make_tensor_value_info("h_t_inter", onnx.TensorProto.FLOAT, [num_hidden_cells,1]),
            make_tensor_value_info("ql_wf_out", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("dql_wf_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),
            make_tensor_value_info("ql_wi_out", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("dql_wi_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),
            make_tensor_value_info("ql_wc_out", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("dql_wc_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),
            make_tensor_value_info("ql_wo_out", onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("dql_wo_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_features]),
            make_tensor_value_info("ql_uf_out",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("dql_uf_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("ql_ui_out",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("dql_ui_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("ql_uc_out",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("dql_uc_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("ql_uo_out",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("dql_uo_out",onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("clp_wf",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("clp_wi",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("clp_wc",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("clp_wo",onnx.TensorProto.INT8, [num_hidden_cells,num_features]),
            make_tensor_value_info("clp_uf",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]), 
            make_tensor_value_info("clp_ui",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("clp_uc",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
            make_tensor_value_info("clp_uo",onnx.TensorProto.INT8, [num_hidden_cells,num_hidden_cells]),
        ],
    initializer=[
                 # Initializing the weight and recurrence matrices
                 make_tensor('W_f',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wf_val)),
                 make_tensor('U_f',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uf_val)),
                 make_tensor('b_f',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bf_val)),
                 make_tensor('W_i',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wi_val)),
                 make_tensor('U_i',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Ui_val)),
                 make_tensor('b_i',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bi_val)),
                 make_tensor('W_o',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wo_val)),
                 make_tensor('U_o',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uo_val)),
                 make_tensor('b_o',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bo_val)),
                 make_tensor('W_c',onnx.TensorProto.FLOAT, [num_hidden_cells,num_features], (Wc_val)),
                 make_tensor('U_c',onnx.TensorProto.FLOAT, [num_hidden_cells,num_hidden_cells], (Uc_val)),
                 make_tensor('b_c',onnx.TensorProto.FLOAT, [num_hidden_cells,1], (bc_val)),
                 # Input scale value
                 make_tensor('inp_scale',onnx.TensorProto.FLOAT, [],[inp_scale_val]),
                 # Scale weight values
                 make_tensor('scale_i',onnx.TensorProto.FLOAT, [],[w1_scale_val]),
                 make_tensor('scale_c',onnx.TensorProto.FLOAT, [],[w2_scale_val]),
                 make_tensor('scale_o',onnx.TensorProto.FLOAT, [],[w3_scale_val]),
                 make_tensor('scale_f',onnx.TensorProto.FLOAT, [],[w4_scale_val]),
                 # Scale values for the six equations
                 make_tensor('scale_1',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_2',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]), 
                 make_tensor('scale_3',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_test',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_4',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_5',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_6',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_7',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]), 
                 make_tensor('scale_8',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]),
                 make_tensor('scale_9',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 make_tensor('scale_10',onnx.TensorProto.FLOAT, [],[eq_scale_val_2]),
                 make_tensor('scale_11',onnx.TensorProto.FLOAT, [],[eq_scale_val_1]),
                 # Zero-points: the zero-point datatype defines the datatype of the output of that quantization
                 make_tensor('zero_point_signed',onnx.TensorProto.INT8,[],[zero_point_signed_val]),
                 make_tensor('zero_point_unsigned',onnx.TensorProto.UINT8,[],[zero_point_unsigned_val]),
                 # Introducing scalars for the clip operators.
                 make_tensor('min', onnx.TensorProto.INT8, [], [min_clip_val]),
                 make_tensor('max', onnx.TensorProto.INT8, [], [max_clip_val]),
                ]
)

The graph created above can now be converted into a QONNX model with the `qonnx_make_model` utility. We save the model with the `onnx.save` utility and then view it in Netron with `netron.start`.


In [8]:
lstm_model = qonnx_make_model(lstm_body, producer_name="QuantizeLSTM_scan")
onnx.save(lstm_model, './lstm_full_graph.onnx')
netron.start('./lstm_full_graph.onnx')

Serving './lstm_full_graph.onnx' at http://localhost:8080


('localhost', 8080)

In this block of code we execute the ONNX graph to check that it runs without any errors. We perform its functional verification in a later part of the notebook.

In [9]:
# Before the model can be executed, its opset version needs to be set to a minimum of 14 to accommodate Clip nodes with INT8 and UINT8 inputs. Otherwise ONNX Runtime cannot create an execution session and we get errors.
lstm_model.opset_import[0].version = 14

# Creating the inference session for the updated model
sess = rt.InferenceSession(lstm_model.SerializeToString())

# Defining dummy inputs and the model parameters for dummy execution
X_inp = np.full((num_features, 1), 0.8, dtype=np.float32)
hidden_state_input =  np.zeros((num_hidden_cells, 1)).astype(np.float32)
cell_state_input =  np.zeros((num_hidden_cells, 1)).astype(np.float32)

# Assigning the above defined values to the input dictionary of the ONNX model.
input_dict = {}
input_dict["inp"] = X_inp
input_dict["h_t-1"] = hidden_state_input
input_dict["c_t-1"] = cell_state_input 

# Setting up the inference session and executing the onnx model here.
sess = rt.InferenceSession(lstm_model.SerializeToString())
output = sess.run(None, input_dict)
print(output)

[array([[ 0.1484375],
       [-0.0078125],
       [ 0.0390625],
       [ 0.140625 ],
       [ 0.015625 ],
       [ 0.       ],
       [ 0.1015625],
       [-0.1015625],
       [ 0.0390625],
       [-0.0625   ],
       [ 0.015625 ],
       [-0.125    ],
       [ 0.1015625],
       [ 0.03125  ],
       [ 0.1640625],
       [-0.015625 ],
       [-0.0234375],
       [-0.015625 ],
       [-0.046875 ],
       [ 0.0078125]], dtype=float32), array([[ 0.2421875],
       [-0.0078125],
       [ 0.0625   ],
       [ 0.2421875],
       [ 0.03125  ],
       [ 0.0078125],
       [ 0.2265625],
       [-0.234375 ],
       [ 0.0859375],
       [-0.1328125],
       [ 0.0390625],
       [-0.2421875],
       [ 0.1875   ],
       [ 0.0546875],
       [ 0.296875 ],
       [-0.03125  ],
       [-0.0546875],
       [-0.03125  ],
       [-0.109375 ],
       [ 0.0234375]], dtype=float32), array([[ 0.1484375],
       [-0.0078125],
       [ 0.0390625],
       [ 0.140625 ],
       [ 0.015625 ],
       [ 0.       ],

2023-10-20 11:07:46.350885612 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.
2023-10-20 11:07:46.370978980 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.
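
Every value printed above is a multiple of 0.0078125 (2^-7), which is what an 8-bit quantize/dequantize pair with that scale would produce. The following NumPy sketch mimics one Quantize, Clip, Dequantize chain. The scale of 2^-7, the zero-point of 0 and the clip bounds are assumptions chosen for illustration; they are not read from the notebook's `inp_scale_val`, `min_clip_val` or `max_clip_val`, and `qcdq` is a hypothetical helper name.

```python
import numpy as np

def qcdq(x, scale=2.0**-7, zero_point=0, qmin=-128, qmax=127):
    """Quantize -> Clip -> Dequantize on a float array (illustrative only)."""
    q = np.round(x / scale) + zero_point   # np.round rounds half to even
    q = np.clip(q, -128, 127)              # saturate to the int8 range
    q = np.clip(q, qmin, qmax)             # optional narrower range (the Clip node)
    q = q.astype(np.int8)
    return (q.astype(np.float32) - zero_point) * scale

x = np.array([0.15, -0.3, 2.0], dtype=np.float32)
y = qcdq(x)
print(y)  # [ 0.1484375 -0.296875   0.9921875]
```

Values inside the representable range snap to the nearest multiple of the scale, while 2.0 saturates at 127 * 2^-7.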


# SCAN Operation Integration

### Introduction to ONNX Scan operation
Observations regarding the `Scan` operator in ONNX:

1. `Scan` can be used to iterate over one or more scan input tensors, constructing zero or more scan output tensors. It combines ideas from general recurrences and functional programming constructs such as scan, fold, map and zip.
2. The attribute `body` in the node must be a graph specifying the computation to be performed in every iteration.
3. In each iteration the body receives the current values of the `state variables` and the current `iterated element` of each scan input. It returns the updated values of the `state variables` and the `scan output element tensors` (there can be more than one of these).
4. The values of the scan output tensors are concatenated over all the iterations to produce the scan output values of the scan construct.
5. The properties that make a scan node unique and different from a normal compute node are:
* It allows the state variables to be updated after each input is processed, so that they can be used when processing the next input.
* It scans the inputs row by row or column by column, computing the output with the updated hidden state for every input while storing all the intermediate outputs in the form of hidden states.
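
The observations above can be sketched in plain NumPy. This is not how ONNX Runtime implements the operator; `scan` and `running_sum` are hypothetical names, and the sketch assumes a single scan input iterated along axis 0 and a single scan output.

```python
import numpy as np

def scan(body, init_states, scan_input):
    """Apply `body` to every slice along axis 0, carrying the state forward."""
    states = list(init_states)
    per_step_outputs = []
    for x_t in scan_input:                    # one iteration per sequence element
        *states, out_t = body(*states, x_t)   # updated states + one scan output
        per_step_outputs.append(out_t)
    # Final states first, then the scan output stacked over all iterations.
    return (*states, np.stack(per_step_outputs, axis=0))

# Toy body with one state variable: a running sum of the inputs.
def running_sum(acc, x_t):
    new_acc = acc + x_t
    return new_acc, new_acc

seq = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
final_state, history = scan(running_sum, [np.zeros(2)], seq)
print(final_state)  # [ 9. 12.]
print(history)      # [[ 1.  2.] [ 4.  6.] [ 9. 12.]]
```

For the LSTM, the body would take two state variables (hidden state and cell state) plus the current input, and return the two updated states followed by the hidden state as the scan output.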

More information regarding this op can be found in these links:

* https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scan
* https://onnx.ai/onnx/intro/python.html#scan

The `Scan` operation is essentially a container operator which will consume the LSTM graph that we created above in its body.
To create it, we need to define separate input and output value info tensors just for the Scan operator. We will then follow the same steps as the `QCDQ-LSTM` graph creation to convert the above graph into an executable ONNX model.
<br><br>
We start by defining the input and output value info tensors for the `scan_graph` creation. These tensors act as the wrapper to the previously defined graph.


In [10]:
# Inputs
scan_input = make_tensor_value_info("scan_input",onnx.TensorProto.FLOAT, [None,num_features,1])#X ; scan input. Here None denotes the variable number of inputs that can be supplied for processing.
scan_hidden_state      = make_tensor_value_info("scan_hidden_state",onnx.TensorProto.FLOAT, [num_hidden_cells,1])# h_t-1
scan_cell_state      = make_tensor_value_info("scan_cell_state",onnx.TensorProto.FLOAT, [num_hidden_cells,1])# c_t-1

# Outputs
scan_out_hidden_state = make_tensor_value_info("scan_out_hidden_state", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#h_t
scan_out_cell_state = make_tensor_value_info("scan_out_cell_state", onnx.TensorProto.FLOAT, [num_hidden_cells,1])#c_t
scan_out_hidden_state_concat = make_tensor_value_info("scan_out_hidden_state_concat", onnx.TensorProto.FLOAT, [None,num_hidden_cells,1])

We now create the scan operator with the `make_node` utility from ONNX.
Note, in the body of the operation we have included the `lstm_body` graph we created in the above steps.

In [11]:
scan_node_lstm = make_node(
    "Scan", 
    inputs=["scan_hidden_state","scan_cell_state","scan_input"], 
    outputs=["scan_out_hidden_state","scan_out_cell_state","scan_out_hidden_state_concat"], 
    num_scan_inputs=1,
    body=lstm_body, domain=''
)
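
The order of the names passed to the `Scan` node matters: the ONNX operator takes the state variables first, followed by the scan inputs, and returns the final state values first, followed by the scan outputs. The pure-Python sketch below only illustrates that convention; `split_scan_io` is a hypothetical helper, not an ONNX API.

```python
def split_scan_io(inputs, outputs, num_scan_inputs):
    """Split Scan node input/output names into state and scan groups."""
    num_states = len(inputs) - num_scan_inputs
    return {
        "state_inputs": inputs[:num_states],
        "scan_inputs": inputs[num_states:],
        "state_outputs": outputs[:num_states],
        "scan_outputs": outputs[num_states:],
    }

groups = split_scan_io(
    inputs=["scan_hidden_state", "scan_cell_state", "scan_input"],
    outputs=["scan_out_hidden_state", "scan_out_cell_state", "scan_out_hidden_state_concat"],
    num_scan_inputs=1,
)
for role, names in groups.items():
    print(role, names)
```

With `num_scan_inputs=1`, the first two inputs are treated as state variables and only `scan_input` is iterated over; `scan_out_hidden_state_concat` is the single scan output.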

We can now define the graph for the scan operator utilizing the `make_graph` utility.

In [12]:
scan_lstm_node_graph = make_graph(
    nodes = [scan_node_lstm],
    name="lstm-scan-node",
    inputs=[scan_hidden_state,scan_cell_state,scan_input],#h_t-1, c_t-1, X
    outputs=[scan_out_hidden_state,scan_out_cell_state,scan_out_hidden_state_concat]#h_t,c_t,h_t_concat
)

# Creating the model from the above created graph and saving it.
lstm_scan_node_model = qonnx_make_model(scan_lstm_node_graph, producer_name="scan-lstm")
onnx.save(lstm_scan_node_model, './lstm_scan_node_model.onnx')
netron.start('./lstm_scan_node_model.onnx')

#Checking the model for any errors
onnx.checker.check_model(lstm_scan_node_model)
print(lstm_scan_node_model.graph.value_info)

#Conversion to opset version 14 to accommodate Clip nodes, as was done for the LSTM graph.
lstm_scan_node_model.opset_import[0].version = 14

Serving './lstm_scan_node_model.onnx' at http://localhost:8081
[]


Now that the SCAN-based quantized LSTM model is ready, we can test it with the same set of inputs that we used to test the brevitas model.


In [13]:
# Defining the values of the variables to test the execution of the scan model
num_inputs = 25

#Initializing the initial values of the hidden state and the cell state. 
# Also assigning the same input as the one used for the brevitas execution.

hidden_state_inp =  np.zeros((num_hidden_cells, 1)).astype(np.float32)#'h_t-1'
cell_state_inp = np.zeros((num_hidden_cells, 1)).astype(np.float32)#'c_t-1'
scan_inp = np.full((num_inputs, num_features, 1), 0.8, dtype=np.float32)

# Assigning the defined input values to the input dictionary of the scan model
input_dict = {}
input_dict["scan_hidden_state"] = hidden_state_inp
input_dict["scan_cell_state"] = cell_state_inp
input_dict["scan_input"] = scan_inp

# We can now set up the inference session and execute the scan onnx model here. 
# The execution session gives some warnings which can be ignored.

sess = rt.InferenceSession(lstm_scan_node_model.SerializeToString())
scan_output = sess.run(None, input_dict)
print('Final Hidden State',scan_output[0])
print("------------------------")
print('Final Cell State',scan_output[1])
print("------------------------")
print('All Hidden States',scan_output[2])

Final Hidden State [[ 0.25     ]
 [-0.046875 ]
 [ 0.015625 ]
 [ 0.2734375]
 [ 0.0546875]
 [-0.0390625]
 [ 0.25     ]
 [-0.1953125]
 [ 0.0546875]
 [-0.140625 ]
 [ 0.015625 ]
 [-0.203125 ]
 [ 0.203125 ]
 [ 0.140625 ]
 [ 0.2734375]
 [ 0.03125  ]
 [-0.03125  ]
 [-0.046875 ]
 [-0.0703125]
 [ 0.0078125]]
------------------------
Final Cell State [[ 0.421875 ]
 [-0.078125 ]
 [ 0.0234375]
 [ 0.4921875]
 [ 0.1484375]
 [-0.09375  ]
 [ 0.75     ]
 [-0.59375  ]
 [ 0.1171875]
 [-0.3125   ]
 [ 0.0390625]
 [-0.421875 ]
 [ 0.3984375]
 [ 0.2578125]
 [ 0.828125 ]
 [ 0.0625   ]
 [-0.0703125]
 [-0.109375 ]
 [-0.1484375]
 [ 0.0234375]]
------------------------
All Hidden States [[[ 0.1484375]
  [-0.0078125]
  [ 0.0390625]
  [ 0.140625 ]
  [ 0.015625 ]
  [ 0.       ]
  [ 0.1015625]
  [-0.1015625]
  [ 0.0390625]
  [-0.0625   ]
  [ 0.015625 ]
  [-0.125    ]
  [ 0.1015625]
  [ 0.03125  ]
  [ 0.1640625]
  [-0.015625 ]
  [-0.0234375]
  [-0.015625 ]
  [-0.046875 ]
  [ 0.0078125]]

 [[ 0.203125 ]
  [-0.0234375]
  

2023-10-20 10:50:38.892379706 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'scale_test'. It is not used by any node and should be removed from the model.
2023-10-20 10:50:38.894726380 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_uo_out'. It is not used by any node and should be removed from the model.
2023-10-20 10:50:38.894741924 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_wf_out'. It is not used by any node and should be removed from the model.
2023-10-20 10:50:38.894750521 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ql_ui_out'. It is not used by any node and should be removed from the model.
2023-10-20 10:50:38.894758793 [W:onnxruntime:, graph.cc:3559 CleanUnusedInitializersAndNodeArgs] Removing initializer 'max'. It is not used by any node and should be removed from the model.
2023-10-20 10:50:38.89476

# Functional Verification

In the final part of the notebook, we compare the output of the 8-bit quantized `(QCDQ)-LSTM` implementation with the `QuantLSTM` brevitas model.
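
The scan output holds the hidden states with shape `[num_inputs, num_hidden_cells, 1]`, while the brevitas output printed in this section has shape `(25, 1, 20)`, so the two need to be brought to a common layout before subtracting. A small sketch with hypothetical sizes (`T` and `H` are placeholders, not the notebook's values) shows that a plain reshape is safe here because one of the two swapped axes has length 1:

```python
import numpy as np

T, H = 3, 4  # hypothetical sequence length and hidden size
scan_hidden = np.arange(T * H, dtype=np.float32).reshape(T, H, 1)

as_reshape = scan_hidden.reshape(T, 1, H)
as_transpose = np.transpose(scan_hidden, (0, 2, 1))

# Both views hold the same values in the same positions.
print(np.array_equal(as_reshape, as_transpose))  # True
```

If the trailing axis were longer than 1, a reshape would reorder elements differently from a transpose and should not be used for this purpose.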


In [14]:
# We first match the shape of both the outputs to perform the functional verification correctly

print('Brevitas Output shape : ', brevitas_output.shape)
all_hidden_states = np.array(scan_output[2])
all_hidden_states = all_hidden_states.reshape([num_inputs,1,num_hidden_cells])
print('SCAN-QCDQ-LSTM output shape :', all_hidden_states.shape)
print('-----------------------------------')
print('Brevitas Output = ',brevitas_output)
print('-----------------------------------')
print('SCAN-QCDQ-LSTM output',all_hidden_states)
print('-----------------------------------')

# Comparison between the 'Scan-LSTM output' and the brevitas 'QuantLSTM' output
# Since the outputs from both models are floating-point, to get a better understanding of the differences we scale the outputs to INT8 precision and then compare their differences.
# The scale used to do that is the last scale of the LSTM graph.

scale = inp_scale_val #The scale value is equal to the value of the inp_scale_val
all_hidden_states = np.array(scan_output[2])
all_hidden_states = all_hidden_states.reshape([num_inputs,1,num_hidden_cells])
all_hidden_state_diff = (all_hidden_states - brevitas_output)
print(all_hidden_state_diff/scale)

Brevitas Output shape :  (25, 1, 20)
SCAN-QCDQ-LSTM output shape : (25, 1, 20)
-----------------------------------
Brevitas Output =  [[[ 0.1484375 -0.0078125  0.0390625  0.140625   0.0078125  0.
    0.109375  -0.09375    0.0390625 -0.0625     0.015625  -0.1171875
    0.1015625  0.03125    0.1640625 -0.015625  -0.0234375 -0.015625
   -0.046875   0.0078125]]

 [[ 0.2109375 -0.0234375  0.03125    0.2109375  0.0234375 -0.015625
    0.1875    -0.1484375  0.046875  -0.09375    0.0234375 -0.1640625
    0.1484375  0.0625     0.2578125 -0.015625  -0.03125   -0.0234375
   -0.0703125  0.015625 ]]

 [[ 0.2421875 -0.0390625  0.015625   0.25       0.03125   -0.0234375
    0.234375  -0.1796875  0.0546875 -0.109375   0.015625  -0.1875
    0.1796875  0.09375    0.3125     0.        -0.03125   -0.03125
   -0.078125   0.0078125]]

 [[ 0.25      -0.0390625  0.015625   0.265625   0.0390625 -0.03125
    0.265625  -0.1875     0.0546875 -0.125      0.015625  -0.1953125
    0.1953125  0.1171875  0.3359375  0.

Note that the difference between the outputs increases as more inputs are processed. The first two outputs are very close to one another, but for later inputs some values differ from the brevitas output by a considerable amount.
This behaviour can be attributed to some values being slightly different in the first few outputs (<i>which are not visible</i>); these small differences are fed back through the state updates and grow as more inputs are processed.
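
One way to make this growth easier to see is to reduce the difference to a single number per time step, expressed in quantization steps. The sketch below is an assumption-laden illustration: `max_step_error` is a hypothetical helper, the scale of 2^-7 is inferred from the printed values rather than taken from `inp_scale_val`, and the sample data are only the first two entries of the first two time steps shown in the outputs above.

```python
import numpy as np

def max_step_error(a, b, scale):
    """Largest absolute difference per time step, in units of `scale`."""
    diff = np.abs(np.asarray(a) - np.asarray(b)) / scale
    per_step = diff.reshape(diff.shape[0], -1).max(axis=1)
    return np.rint(per_step).astype(int)

# First two entries of time steps 0 and 1, copied from the printed outputs.
scan_part = [[0.1484375, -0.0078125], [0.203125, -0.0234375]]
brevitas_part = [[0.1484375, -0.0078125], [0.2109375, -0.0234375]]

errors = max_step_error(scan_part, brevitas_part, scale=2.0**-7)
print(errors)  # [0 1]
```

Applied to the full `all_hidden_states` and `brevitas_output` arrays, the same helper would give one value per processed input, which can then be printed or plotted against the time step.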