# Demo: simple recurrent neural network with time-averaging

In [1]:
import tensorflow as tf
from connectionist.layers import TimeAveragedDense

SEED = 42
# Seed Python's `random`, NumPy and the TensorFlow/Keras backend so the randomly
# initialised weights of every layer created below (and therefore all outputs and
# plots in this notebook) are reproducible on Restart & Run All.
tf.keras.utils.set_random_seed(SEED)

print(TimeAveragedDense.__doc__)

Dense layer with Time-averaging mechanism.

    In short, time-averaging mechanism simulates continuous-temporal dynamics in a discrete-time recurrent neural networks.
    See Plaut, McClelland, Seidenberg, and Patterson (1996) equation (15) for more details.

    Args:
        tau (float): Time-averaging parameter (How much information should take from the new input). range: [0, 1].

        average_at (str): Where to average. Options: 'before_activation', 'after_activation'.

            When average_at is 'before_activation', the time-averaging is applied BEFORE activation. i.e., time-averaging INPUT.:
                outputs = activation(integrated_input);
                integrated input = tau * (inputs @ weights + bias) + (1-tau) * last_inputs;
                last_inputs is obtained from the last call of this layer, its values stored at `self.states`

            When average_at is 'after_activation', the time-averaging is applied AFTER activation. i.e., time-averaging OUTPUT.:


### Toy model with only one TimeAveragedDense layer

In [None]:
# A single TimeAveragedDense layer (3 units) that averages after its activation.
model = TimeAveragedDense(units=3, tau=0.2, average_at="after_activation")
x = tf.constant([[1., 2.]], dtype=tf.float32)  # shape (1, 2)
model(x)

- In a typical Dense layer, the same input always produces the same output, no matter how many times the layer is called.
- This is not the case for a TimeAveragedDense layer: its output changes every time it is called, because it blends the new value with the state stored during the previous call.
- This is the core mechanism of the TimeAveragedDense layer. It acts as a kind of "dampening" layer: the more times it is called with the same input, the closer its output gets to an asymptotic value.

In [None]:
# Same input again: unlike a plain Dense layer, the output differs from the first call,
# because the layer blends in the state it stored during the previous call.
model(x)

In [None]:
# A third call with the same input: the output keeps moving toward its asymptotic value.
model(x)

### Building a toy RNN cell with a TimeAveragedDense layer

In [None]:
from typing import Tuple, List

class RNNCell(tf.keras.layers.Layer):
    """Toy recurrent cell whose output is smoothed by a ``TimeAveragedDense`` layer.

    Each call computes ``input_dense(inputs)`` (plus ``recurrent_dense(states)``
    when a previous state is given), passes the result through a sigmoid
    ``TimeAveragedDense`` layer that averages *after* activation, and returns
    that result both as the output and as the new state.

    Args:
        tau (float): Time-averaging parameter forwarded to ``TimeAveragedDense``.
        units (int): Number of hidden units.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, tau, units, **kwargs):
        super().__init__(**kwargs)
        self.tau = tau
        self.units = units
        # Sublayers are created here rather than in ``build`` (the usual Keras
        # subclassing pattern). They are tracked from construction, and
        # ``self.time_averaged_dense`` exists before the first call, so
        # ``reset_states`` no longer fails with AttributeError on this cell.
        # Whether TimeAveragedDense.reset_states is valid before that layer is
        # built is up to TimeAveragedDense. Weights are still created lazily on
        # first use.
        self.recurrent_dense = tf.keras.layers.Dense(units, use_bias=False)
        self.input_dense = tf.keras.layers.Dense(units, use_bias=False)
        self.sum = tf.keras.layers.Add()
        self.time_averaged_dense = TimeAveragedDense(
            tau=tau, average_at="after_activation", units=units, activation="sigmoid"
        )

    def call(self, inputs, states=None) -> Tuple[tf.Tensor, tf.Tensor]:
        """Run one tick of the cell.

        Args:
            inputs: Input for this tick, presumably ``(batch_size, input_dim)``.
            states: State returned by the previous call, or ``None`` on the first tick
                (then the recurrent path is skipped).

        Returns:
            ``(output, new_state)``; both are the same tensor.
        """
        pre_activation = self.input_dense(inputs)
        if states is not None:
            pre_activation = self.sum([pre_activation, self.recurrent_dense(states)])

        outputs = self.time_averaged_dense(pre_activation)
        return outputs, outputs

    def reset_states(self):
        """Reset the time-averaging state held by the inner ``TimeAveragedDense`` layer."""
        self.time_averaged_dense.reset_states()

In [None]:
cell = RNNCell(units=3, tau=0.2)  # 3 hidden units, time-averaging parameter 0.2

#### Manually unroll the RNN Cell

In [None]:
# Manually unroll: feed the same input for `n_ticks` ticks, passing each tick's
# new state back into the cell so the recurrent path is used from tick 2 on.
n_ticks = 10
x = tf.constant([[1.0, 2.0, 3.0]])  # shape (1, 3)
states = None
ys = []
for _ in range(n_ticks):
    y, states = cell(x, states)  # (output, new_state)
    ys.append(y.numpy())

# The cell keeps its time-averaging state between calls; clear it so that
# re-running this cell reproduces the same trajectory.
cell.reset_states()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots()
ax.plot(np.stack(ys).squeeze())  # one line per unit, one point per tick
ax.set(title="Manually unrolled RNNCell (constant input)", xlabel="tick", ylabel="activation")
ax.legend([f"unit {i}" for i in range(ys[0].shape[-1])]);

#### Unrolling the cell inside a proper Keras layer

In [None]:
class TimeAveragedRNN(tf.keras.layers.Layer):
    """Unroll an ``RNNCell`` over the time axis of a batch of sequences.

    The cell's time-averaging state carries over from tick to tick within one
    call and is reset at the end of the call, so separate calls are independent.

    Args:
        tau (float): Time-averaging parameter forwarded to ``RNNCell``.
        units (int): Number of hidden units of the cell.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Input shape: ``(batch_size, seq_len, input_dim)``. ``seq_len`` must be
    statically known because the unrolling uses a Python ``for`` loop.
    Output shape: ``(batch_size, seq_len, units)``.
    """

    def __init__(self, tau, units, **kwargs):
        super().__init__(**kwargs)
        self.tau = tau
        self.units = units
        # Created here so the sublayer is tracked from construction; its weights
        # are still built lazily on the first call.
        self.rnn_cell = RNNCell(tau=tau, units=units)

    def call(self, inputs):
        max_ticks = inputs.shape[1]  # (batch_size, seq_len, input_dim)
        if max_ticks is None:
            raise ValueError(
                "TimeAveragedRNN needs a statically known sequence length (inputs.shape[1])."
            )
        if max_ticks == 0:
            # Nothing to unroll (and the cell was never built, so it cannot be
            # reset): return an empty sequence with the expected rank.
            return tf.zeros([tf.shape(inputs)[0], 0, self.units], dtype=tf.float32)

        outputs = tf.TensorArray(dtype=tf.float32, size=max_ticks)
        states = None

        for t in range(max_ticks):
            this_tick_input = inputs[:, t, :]
            # RNNCell returns (output, new_state); here both are the same tensor.
            output, states = self.rnn_cell(this_tick_input, states=states)
            outputs = outputs.write(t, output)

        # States persist across ticks of one sequence but must not leak into the
        # next call, so reset them here.
        self.rnn_cell.reset_states()
        outputs = outputs.stack()  # (seq_len, batch_size, units)
        return tf.transpose(outputs, [1, 0, 2])  # (batch_size, seq_len, units)

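Small improvements compared with manually unrolling the RNN cell:

- The number of time ticks is inferred from the input shape (`inputs.shape[1]`), so sequences of different lengths can be fed without changing the layer, and the input can vary from tick to tick (`inputs[:, t, :]`).
- Instead of collecting outputs in a Python list (`ys`), they are accumulated in a `tf.TensorArray`. This is the TensorFlow-native container for per-tick outputs and, unlike a Python list, it also works inside a `tf.function` graph.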

In [None]:
rnn = TimeAveragedRNN(units=3, tau=0.2)  # same hyper-parameters as the manual cell

In [None]:
x = tf.ones(shape=(1, 10, 3), dtype=tf.float32)  # (batch_size, seq_len, input_dim): 10 ticks of constant input

In [None]:
y = rnn(x)  # (batch_size, seq_len, units) = (1, 10, 3)

In [None]:
np.squeeze(y.numpy())  # drop the batch axis of size 1 for display

In [None]:
fig, ax = plt.subplots()
ax.plot(y.numpy().squeeze())  # one line per unit, one point per tick
ax.set(title="TimeAveragedRNN output (constant input)", xlabel="tick", ylabel="activation")
ax.legend([f"unit {i}" for i in range(y.shape[-1])]);