In [None]:
import tensorflow as tf

### Definition of the cell.

Make sure to create the gate layers in this order (the weight comparison further down relies on it):

- `input`
- `forget`
- `carry`
- `output`



In [None]:
class MyLSTMCell(tf.keras.layers.Layer):
  """LSTM cell assembled from four `Dense` layers, one per gate.

  At every step the current input is concatenated with the previous memory
  state (h) and fed to the input, forget, carry (candidate) and output gate
  layers. The gate layers are created in exactly that order.

  Args:
    units: Number of units of every gate layer; also used as the output size
      and as the size of both states.
    **kwargs: Forwarded unchanged to `tf.keras.layers.Layer`.
  """

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)

    # Kept so that get_config() can report the constructor argument.
    self.units = units

    # Two states of `units` entries each: memory state (h) and carry state (c).
    self.state_size = [units,units]
    self.output_size = units

    # Input gate.
    self.dense_i = tf.keras.layers.Dense(
        units=units,
        bias_initializer="zeros",
        activation="sigmoid")

    # Forget gate; its bias starts at one instead of zero.
    self.dense_f = tf.keras.layers.Dense(
        units=units,
        bias_initializer="ones",
        activation="sigmoid")

    # Carry (candidate) values.
    self.dense_g = tf.keras.layers.Dense(
        units=units,
        bias_initializer="zeros",
        activation="tanh")

    # Output gate.
    self.dense_o = tf.keras.layers.Dense(
        units=units,
        bias_initializer="zeros",
        activation="sigmoid")

  def call(self, inputs, outputs):
    """Run one step of the cell.

    Args:
      inputs: Input tensor for the current step.
      outputs: Despite its name, the states of the previous step: index 0 is
        the memory state (h) and index 1 is the carry state (c).

    Returns:
      A pair `(y, [y, new_carry_state])`: the output of this step and the new
      states. The new memory state is the step output itself.
    """
    memory_state = outputs[0] # h
    carry_state = outputs[1]  # c

    # Every gate sees the input and the previous memory state side by side.
    inputs_and_memory = tf.concat([inputs, memory_state], axis=-1)

    i = self.dense_i(inputs_and_memory)
    f = self.dense_f(inputs_and_memory)
    g = self.dense_g(inputs_and_memory)
    o = self.dense_o(inputs_and_memory)

    # Keep part of the old carry state and add the gated candidate values.
    new_carry_state = f * carry_state + i * g
    y = o * tf.keras.activations.tanh(new_carry_state)

    return y, [y, new_carry_state]

  def get_config(self):
    """Return the constructor arguments needed to re-create this cell.

    Without this override the `units` argument is not part of the config, so
    serializing the cell (or an RNN layer wrapping it) cannot round-trip.
    """
    config = super().get_config()
    config["units"] = self.units
    return config

### Optional: Check implementation

The implementation of `MyLSTMCell` is checked by comparing its behavior with that of the built-in Keras `LSTMCell`.

We initialize both cells with the same weights, run each of them through an RNN, and check whether they give the same output for the same input.

In [None]:
BATCH_SIZE = 5
NUM_STEPS = 10
NUM_FEATURES = 3
UNITS = 4
SEED = 42

# Fix the global TensorFlow seed so the random input below is the same on
# every Restart & Run All.
tf.random.set_seed(SEED)

# Some random input of shape [BATCH_SIZE, NUM_STEPS, NUM_FEATURES]
inputs = tf.random.normal([BATCH_SIZE, NUM_STEPS, NUM_FEATURES])

In [None]:
# Wrap our cell in a generic RNN layer, run it on the input and inspect the
# shape of the result
my_lstm_cell = MyLSTMCell(units=UNITS)
rnn = tf.keras.layers.RNN(cell=my_lstm_cell)
output = rnn(inputs)
output.shape

In [None]:
# Same procedure with the LSTMCell that ships with Keras, for comparison
cell = tf.keras.layers.LSTMCell(units=UNITS)
rnn2 = tf.keras.layers.RNN(cell=cell)
output2 = rnn2(inputs)
output2.shape

The internal implementation of `LSTMCell` is different: it doesn't use `Dense` layers, and it organizes its weights into three tensors: `kernel`, `recurrent_kernel` and `bias`.

In [None]:
# Retrieve the three weight tensors of the built-in cell and show their shapes
kernel_weights, recurrent_kernel_weights, bias = rnn2.get_weights()
tuple(
    weights.shape
    for weights in (kernel_weights, recurrent_kernel_weights, bias)
)

Our implementation of an LSTM cell organizes the weights into 8 tensors (4 Dense layers with 2 weight tensors each)

In [None]:
# Print the shape of every weight tensor held by our RNN, one per line
weight_shapes = [weights.shape for weights in rnn.get_weights()]
for shape in weight_shapes:
  print(shape)

Get the weights from our implementation and name them.
The order in which they were created matters!

In [None]:
# The weights come back as (kernel, bias) pairs in layer-creation order:
# input gate, forget gate, carry, output gate.
rnn_weights = rnn.get_weights()
(i_w, i_b,
 f_w, f_b,
 c_w, c_b,
 o_w, o_b) = rnn_weights[:8]

In [None]:
# Re-pack the per-gate weights into the three-tensor layout.
# The first NUM_FEATURES rows of each gate kernel multiply the input; the
# remaining rows multiply the memory state.
gate_kernels = [i_w, f_w, c_w, o_w]
gate_biases = [i_b, f_b, c_b, o_b]

bias = tf.concat(values=gate_biases, axis=-1)
kernel = tf.concat(
    values=[gate_w[:NUM_FEATURES] for gate_w in gate_kernels], axis=-1)
recurrent_kernel = tf.concat(
    values=[gate_w[NUM_FEATURES:] for gate_w in gate_kernels], axis=-1)
bias.shape

In [None]:
# Load the re-packed weights into rnn2 (order: kernel, recurrent kernel, bias)
repacked_weights = [kernel, recurrent_kernel, bias]
rnn2.set_weights(repacked_weights)

Check that the result of running both RNNs is the same.  Eyeball the results.

In [None]:
# Output of the RNN that wraps the built-in LSTMCell
rnn2(inputs)

In [None]:
# Output of the RNN that wraps our MyLSTMCell
rnn(inputs)

Use `tf.debugging.assert_near` to check that the results are actually "equal".

In [None]:

# Run both RNNs on the same input; assert_near raises if the results differ
# by more than its default tolerance.
builtin_result = rnn2(inputs)
custom_result = rnn(inputs)
tf.debugging.assert_near(builtin_result, custom_result)
print("OK")