In [1]:
import tensorflow as tf
import numpy as np
# Fixed seed so the NumPy-generated inputs are identical on every run.
SEED = 42
rng = np.random.default_rng(SEED)

In [89]:
# Sliding windows of 3 consecutive integers (shift 1) are materialised as
# length-3 tensors, then grouped 4 at a time into (4, 3) blocks. Each block is
# split column-wise: first column -> inputs, remaining two columns -> targets.
ds1 = (
    tf.data.Dataset.range(100)
    .window(size=3, shift=1, drop_remainder=True)
    .flat_map(lambda win: win.batch(3))
    .window(size=4, drop_remainder=True)
    .flat_map(lambda group: group.batch(4))
    .map(lambda block: (block[:, 0], block[:, 1:]))
)

# Show the first three (input, target) pairs.
for item, label in ds1.take(3):
    print(f"{item} maps to {label}\n")

[0 1 2 3] maps to [[1 2]
 [2 3]
 [3 4]
 [4 5]]

[4 5 6 7] maps to [[5 6]
 [6 7]
 [7 8]
 [8 9]]

[ 8  9 10 11] maps to [[ 9 10]
 [10 11]
 [11 12]
 [12 13]]



In [24]:
# First recurrent layer: 32 units, emitting the full output sequence. The input
# shape declares 5 features per step and leaves the sequence length unspecified.
rnn = tf.keras.layers.SimpleRNN(
    units=32,
    input_shape=[None, 5],
    return_sequences=True,
)

In [26]:
# Random standard-normal array of shape (10, 20, 5); the last axis matches the
# 5 input features declared for `rnn`.
X = rng.standard_normal((10, 20, 5))

In [31]:
# Shape check: because `rnn` was built with return_sequences=True, the result
# keeps the first two axes of X and has 32 values (one per unit) on the last axis.
rnn(X).shape

TensorShape([10, 20, 32])

In [65]:
# Stack three SimpleRNN layers (reusing the `rnn` layer defined earlier as the
# first one) followed by a 14-unit Dense layer. Every recurrent layer returns
# its full sequence, so the next layer receives one vector per step.
model = tf.keras.Sequential(
    layers=[
        rnn,
        tf.keras.layers.SimpleRNN(units=32, return_sequences=True),
        tf.keras.layers.SimpleRNN(units=64, return_sequences=True),
        tf.keras.layers.Dense(units=14),
    ]
)

In [66]:
# Print the per-layer output shapes and parameter counts of `model`.
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 simple_rnn (SimpleRNN)      (None, None, 32)          1216      
                                                                 
 simple_rnn_2 (SimpleRNN)    (None, None, 32)          2080      
                                                                 
 simple_rnn_3 (SimpleRNN)    (None, None, 64)          6208      
                                                                 
 dense_2 (Dense)             (None, None, 14)          910       
                                                                 
Total params: 10414 (40.68 KB)
Trainable params: 10414 (40.68 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


### Parameter calculation

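RNN (`SimpleRNN`): (units \* units) + (units \* num_features) + (1 \* units)

- units \* units = recurrent weights (`recurrent_kernel` in Keras)
- units \* num_features = input weights (`kernel` in Keras); for a stacked layer, num_features is the number of units of the previous layer
- 1 \* units = `bias`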


Dense: (input_size \* units) + (1 \* units)

- input_size \* units = input weights (`kernel` in Keras)
- 1 \* units = `bias`

In [79]:
def _simple_rnn_params(units, num_features):
    """Return recurrent (units*units) + input (units*num_features) + bias (units) weight count."""
    return units * units + units * num_features + units


def _dense_params(input_size, units):
    """Return input (input_size*units) + bias (units) weight count."""
    return input_size * units + units


# Each layer's input width is the previous layer's number of units.
l1_par = _simple_rnn_params(units=32, num_features=5)   # RNN
l2_par = _simple_rnn_params(units=32, num_features=32)  # RNN
l3_par = _simple_rnn_params(units=64, num_features=32)  # RNN
l4_par = _dense_params(input_size=64, units=14)         # Dense

l1_par, l2_par, l3_par, l4_par

(1216, 2080, 6208, 910)