## Playing with hidden state in TF 2

Note: this notebook is an exploratory sketch — the final implementation differs from the code shown here.

In [None]:
import numpy as np
import tensorflow as tf

In [None]:
# Seed both RNGs so the numbers printed by the cells below are repeatable.
tf.random.set_seed(42)
np.random.seed(42)

# Toy problem sizes shared by the rest of this section.
input_dim, output_dim = 3, 3
num_timesteps, batch_size = 2, 10
nodes = 100

# The batch size is pinned on the input because the RNN below is stateful.
input_layer = tf.keras.Input(
    batch_size=batch_size,
    shape=(num_timesteps, input_dim),
)

cell = tf.keras.layers.LSTMCell(
    units=nodes,
    bias_initializer='zeros',
    kernel_initializer='glorot_uniform',
    recurrent_initializer='glorot_uniform',
)

# stateful=True: the layer keeps its state between calls until reset_states().
# return_state=True: the call also returns the final hidden and cell states.
lstm = tf.keras.layers.RNN(
    cell,
    stateful=True,
    return_sequences=True,
    return_state=True,
)

lstm_out, hidden_state, cell_state = lstm(input_layer)

# Per-timestep projection of the LSTM outputs down to output_dim features.
output = tf.keras.layers.Dense(units=output_dim)(lstm_out)

In [None]:
# Functional model exposing only the Dense projection; the LSTM states are not outputs here.
mdl = tf.keras.Model(outputs=[output], inputs=input_layer)

In [None]:
# One random batch, pushed through the model twice.
x = np.random.rand(batch_size, num_timesteps, input_dim).astype(np.float32)

# The RNN was built with stateful=True, so the second pass is expected to start
# from the state left behind by the first and print a different mean.
for _pass in range(2):
    out = mdl(x)
    print(np.mean(out))

In [None]:
# Overwrite both LSTM states (hidden, cell) with zeros, then run the same batch again.
state_shape = (batch_size, nodes)
lstm.reset_states(states=[np.zeros(state_shape) for _ in range(2)])

out = mdl(x)
print(np.mean(out))

In [None]:
# Overwrite both LSTM states (hidden, cell) with ones, then run the same batch again.
state_shape = (batch_size, nodes)
lstm.reset_states(states=[np.ones(state_shape) for _ in range(2)])

out = mdl(x)
print(np.mean(out))

## Using Keras model class

In [None]:
# Re-seed both RNGs so this section starts from the same random streams as the first one.
tf.random.set_seed(42)
np.random.seed(42)

In [None]:
class Model():
    """Wrap a functional Keras LSTM model so an initial state can be given per call.

    The network is ``Input -> RNN(LSTMCell) -> Dense``. Its sizes are read from
    the module-level ``input_dim``, ``output_dim``, ``num_timesteps``,
    ``batch_size`` and ``nodes`` when the instance is constructed.

    Calling the instance runs the wrapped model and returns its three outputs:
    the Dense output, the final hidden state and the final cell state.
    """

    # States supplied to the most recent call; None means "start from zeros".
    # Declared on the class so get_initial_state() is safe before the first call.
    initial_state = None

    def __init__(self):
        """Build the LSTM layer and the functional model around it."""
        self.nodes = nodes
        self.batch_size = batch_size

        input_layer = tf.keras.Input(shape=(num_timesteps, input_dim), batch_size=batch_size)

        cell = tf.keras.layers.LSTMCell(
            nodes,
            kernel_initializer='glorot_uniform',
            recurrent_initializer='glorot_uniform',
            bias_initializer='zeros',
        )

        # Not stateful: the state for each call is injected through
        # get_initial_state() instead of being kept inside the layer.
        self.lstm = tf.keras.layers.RNN(
            cell,
            return_state=True,
            return_sequences=True,
            stateful=False
        )

        lstm_out, hidden_state, cell_state = self.lstm(input_layer)
        output = tf.keras.layers.Dense(output_dim)(lstm_out)

        self.model = tf.keras.Model(inputs=input_layer, outputs=[output, hidden_state, cell_state])

    def get_zero_initial_state(self, inputs):
        """Return all-zero ``[hidden, cell]`` states sized to the batch of *inputs*."""
        return [
            tf.zeros((inputs.shape[0], self.nodes)),
            tf.zeros((inputs.shape[0], self.nodes))
        ]

    def get_initial_state(self, inputs):
        """Return the states stored by the last call, or zeros if none were given."""
        if self.initial_state is None:
            return self.get_zero_initial_state(inputs)
        return self.initial_state

    def __call__(self, inputs, states=None):
        """Run the model on *inputs*, starting the LSTM from *states*.

        This is a hack: the RNN layer's ``get_initial_state`` hook is replaced
        on the instance so the layer picks up the requested state.

        Args:
            inputs: batch of input sequences passed to the wrapped model.
            states: optional ``[hidden, cell]`` initial states; ``None``
                starts from zeros.

        Returns:
            Whatever the wrapped model returns for *inputs*.
        """
        # Storing None also drops states left over from an earlier call.
        self.initial_state = states
        self.lstm.get_initial_state = self.get_initial_state

        return self.model(inputs)

In [None]:
# Same toy sizes as in the first section; Model() reads them from module scope.
input_dim, output_dim = 3, 3
num_timesteps, batch_size = 2, 10
nodes = 100

mdl = Model()

# A single random sequence (batch of 1).
x = np.random.rand(1, num_timesteps, input_dim).astype(np.float32)

# No states supplied, so the wrapper starts the LSTM from zeros.
out, hidden_state, cell_state = mdl(x, states=None)
np.mean(out)

In [None]:
# Same input again with no states: the wrapper's RNN is built with stateful=False,
# so nothing should carry over from the previous call.
out, hidden_state, cell_state = mdl(x, states=None)

np.mean(out)

In [None]:
# Start the LSTM from all-ones hidden and cell states instead of zeros.
ones_states = [tf.ones((1, nodes)) for _ in range(2)]
out, hidden_state, cell_state = mdl(x, states=ones_states)

np.mean(out)

In [None]:
# Shape of the Dense output from the last call; with the batch-of-1 `x` used above this
# should be (1, num_timesteps, output_dim) — TODO confirm by running the cell.
out.shape