<a href="https://colab.research.google.com/github/Jed-77/tensorflow-deeplearning/blob/master/TF2_0_RNN_Shapes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, SimpleRNN, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam
import numpy as np
import matplotlib.pyplot as plt

In [0]:
# Things you should know!
# N = number of samples (e.g. if we using windows of length 10, then its how many windows can fit into sequence)
# T = sequence length (e.g. if we using windows of length 10 to predict next value, T=10)
# D = number of input features (e.g. the number of functions/stocks etc we have data for)
# M = number of hidden units
# K = number of output units

In [0]:
# Make some data
N = 2
T = 10
D = 3
K = 2
X = np.random.randn(N, T, D)
print(X)

[[[-0.77793668 -1.07274058 -1.28332961]
  [-0.60240791 -0.58348462 -1.33258745]
  [-0.26409752  0.61896939 -0.24345892]
  [ 1.33617474 -0.8782038  -0.87178341]
  [-1.36911427 -0.11213752 -0.83463443]
  [-0.87771077 -1.31518936 -0.14279207]
  [ 0.60548834 -0.50641906  0.11053444]
  [ 0.59813107 -1.14695889 -1.67745537]
  [-0.23479046  1.33625878 -0.07712646]
  [-0.94593678 -1.88330775  0.68517081]]

 [[ 0.84228322  1.10270578  0.50337863]
  [-0.96769632 -0.4457217  -0.27218136]
  [-0.2626106  -2.1637526   0.96763933]
  [-0.97246147  1.12598383  0.98963127]
  [-0.40665527  1.22254771  0.12271403]
  [ 0.59278986 -2.03277646 -0.00542645]
  [ 1.26070947  1.76812368  0.12452051]
  [ 0.40587855  0.92816074 -0.55172977]
  [ 2.96032756 -1.57868623 -0.58505284]
  [-2.08291318 -0.87404986 -1.24929438]]]


In [0]:
# Make an RNN
M = 5
i = Input(shape=(T, D))
x = SimpleRNN(M)(i)
x = Dense(K)(x)
model = Model(i, x)

In [0]:
# Make a prediction
Yhat = model.predict(X)
Yhat

array([[-1.3382235 , -0.5116218 ],
       [-1.9449204 , -0.28821778]], dtype=float32)

In [0]:
# Model summary
model.summary()

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 10, 3)]           0         
_________________________________________________________________
simple_rnn_4 (SimpleRNN)     (None, 5)                 45        
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 12        
Total params: 57
Trainable params: 57
Non-trainable params: 0
_________________________________________________________________


In [0]:
# Get the weights
model.layers[1].get_weights()

[array([[-0.07464129, -0.24965143,  0.429192  , -0.22778082, -0.8423667 ],
        [-0.25631195, -0.40106767,  0.2588299 , -0.698128  , -0.8103905 ],
        [-0.7223803 , -0.7482208 ,  0.6736869 , -0.7189992 , -0.48838496]],
       dtype=float32),
 array([[ 0.30003   , -0.3827919 , -0.6921318 ,  0.05085356, -0.530867  ],
        [ 0.1766843 , -0.524882  ,  0.30694216, -0.77398443,  0.00400643],
        [-0.6218559 , -0.39580104,  0.35819134,  0.26588002, -0.50758594],
        [-0.65360725, -0.1836087 , -0.53750956, -0.23556776,  0.44122177],
        [-0.2546712 ,  0.62257415, -0.09758354, -0.52170634, -0.51560056]],
       dtype=float32),
 array([0., 0., 0., 0., 0.], dtype=float32)]

In [0]:
# Check the shapes of the weights
# so a is the weights connecting input to hidden
# b is the hidden to hidden weights
# c is bias vector
a, b, c = model.layers[1].get_weights()
print(a.shape, b.shape, c.shape)

(3, 5) (5, 5) (5,)


In [0]:
# So we can get weights and biases
Wx, Wh, bh = model.layers[1].get_weights()
Wo, bo = model.layers[2].get_weights()
print(Wo.shape, bo.shape)

(5, 2) (2,)


In [0]:
# lets do a manual RNN layer!
for i in range(N):
  h_last = np.zeros(M)
  x = X[i]
  Yhats = []

  # for each timestep t, 
  for t in range(T):
    h = np.tanh(x[t].dot(Wx) + h_last.dot(Wh) + bh)
    y = h.dot(Wo) + bo
    Yhats.append(y)
    h_last=h

  print(Yhats[-1])

[-1.3382235  -0.51162178]
[-1.94492055 -0.28821775]
