In [None]:
import tensorflow as tf
import numpy as np

import typing

## Experimenting with `tf.compat.v1.nn.rnn_cell.LSTMCell`

##### Inputs to the LSTM have the batch-major shape `[batch, timesteps, features]`

In [None]:
# Seed TensorFlow's global RNG so the random inputs (and the LSTM weights
# initialised afterwards) are reproducible under Restart & Run All.
SEED = 42
tf.random.set_seed(SEED)

# Batch-major layout: [batch, timesteps, features].
input_shape = (1, 1024, 2)
# tf.random.normal already yields a float32 tensor; no tf.constant re-wrap needed.
x = tf.random.normal(input_shape, stddev=0.1, dtype=tf.float32)

In [None]:
# Stand-alone cell, built only to inspect the raw API. It gets its own name
# instead of `lstm_layer`, which is rebound later to the stacked LSTM module.
single_lstm_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=32, use_peepholes=True)

In [None]:
class LSTM(tf.Module):
    """A stack of ``LSTMCell`` layers unrolled with ``dynamic_rnn``.

    Each entry of ``num_units`` becomes one ``tf.compat.v1.nn.rnn_cell.LSTMCell``.
    The cells are chained with ``MultiRNNCell``, so the output of layer ``i`` is
    the input of layer ``i + 1``.

    Args:
        num_units: Hidden size of a single layer, or a sequence of hidden sizes
            (one per layer, bottom to top). Python ints and NumPy integers are
            accepted.
        use_peepholes: Whether every cell uses peephole connections.
        name: Optional module name forwarded to ``tf.Module``.

    Raises:
        TypeError: If a layer size is not an integer.
        ValueError: If no layer sizes are given or a size is not positive.
    """

    def __init__(self,
                 num_units: typing.Union[typing.Sequence[int], int],
                 use_peepholes: bool,
                 name: typing.Optional[str] = None):
        super().__init__(name=name)

        # Keep a private copy so later mutation of the caller's list cannot
        # desynchronise ``self.num_units`` from the cells built below.
        self.num_units = self._as_layer_sizes(num_units)
        self.use_peepholes = use_peepholes

        self.lstm_layers = [
            tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=size,
                                              use_peepholes=self.use_peepholes)
            for size in self.num_units
        ]
        self.multi_lstm_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(self.lstm_layers)

    @staticmethod
    def _as_layer_sizes(num_units) -> typing.List[int]:
        """Normalise ``num_units`` into a new, non-empty list of positive ints."""
        # NumPy integers are not ``int`` instances, so accept both kinds.
        integer_types = (int, np.integer)
        if isinstance(num_units, integer_types):
            num_units = [num_units]

        sizes = []
        for size in num_units:
            # bool subclasses int but is never a meaningful layer size.
            if isinstance(size, bool) or not isinstance(size, integer_types):
                raise TypeError(f"layer sizes must be integers, got {size!r}")
            if size <= 0:
                raise ValueError(f"layer sizes must be positive, got {size!r}")
            sizes.append(int(size))

        if not sizes:
            raise ValueError("num_units must contain at least one layer size")
        return sizes

    def __call__(self, x_in):
        """Run the stacked LSTM over a batch of sequences.

        Args:
            x_in: Input tensor in ``dynamic_rnn``'s default batch-major layout
                (``time_major=False``): ``[batch, timesteps, features]``. The
                initial state is built with ``dtype=tf.float32``, so the input is
                presumably expected to be float32 as well.

        Returns:
            The ``(outputs, state)`` pair from ``tf.compat.v1.nn.dynamic_rnn``.
            ``outputs`` is ``[batch, timesteps, self.num_units[-1]]``. ``state``
            holds the final state of every layer in the stack.
        """
        return tf.compat.v1.nn.dynamic_rnn(cell=self.multi_lstm_cell,
                                           inputs=x_in,
                                           dtype=tf.float32)

In [None]:
# Single-layer stack: one cell with 32 hidden units and peephole connections.
lstm_layer = LSTM(num_units=[32], use_peepholes=True)

In [None]:
outputs, state = lstm_layer(x_in=x)

# Sanity checks. dynamic_rnn is batch-major by default, so the outputs keep the
# input's batch and time axes and take the last layer's width. MultiRNNCell
# returns one final state per stacked layer.
assert outputs.shape.as_list() == [input_shape[0], input_shape[1], lstm_layer.num_units[-1]]
assert len(state) == len(lstm_layer.num_units)

In [None]:
# Preview the first few timesteps rather than dumping the full output tensor.
outputs[:, :5, :]