<a href="https://colab.research.google.com/github/alicepearse/DL_prac/blob/master/Pytorch_practise_RNN_shapes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [32]:
# Make some data
# Seed NumPy (the data below) and PyTorch (the model's random weight init later on)
# so Restart & Run All reproduces the same numbers.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

N = 1   # number of sequences in the batch
T = 10  # sequence length (time steps)
D = 3   # input features per time step
M = 5   # hidden units
K = 2   # outputs per time step
X = np.random.randn(N, T, D)

In [33]:
# Make RNN
class SimpleRNN(nn.Module):
  """Single-layer tanh RNN followed by a linear read-out applied at every time step.

  Args:
    n_inputs: number of input features per time step (D).
    n_hidden: number of hidden units (M).
    n_outputs: number of outputs per time step (K).
  """

  def __init__(self, n_inputs, n_hidden, n_outputs):
    super().__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs
    # batch_first=True: inputs and outputs are laid out as (batch, seq, feature).
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity='tanh',
        batch_first=True)
    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    """Run the RNN over X and map every hidden state to the output space.

    Args:
      X: tensor of shape (batch, seq_len, n_inputs).

    Returns:
      Tensor of shape (batch, seq_len, n_outputs).
    """
    # Zero initial hidden state of shape (num_layers=1, batch, hidden). Build it on
    # X's device and dtype so the model still works after .to(device) / .double().
    h0 = torch.zeros(1, X.size(0), self.M, device=X.device, dtype=X.dtype)

    out, _ = self.rnn(X, h0)  # out: (batch, seq_len, n_hidden)

    out = self.fc(out)
    return out

In [34]:
# Instantiate the model: D inputs per step -> M hidden units -> K outputs per step
model = SimpleRNN(
    n_inputs=D,
    n_hidden=M,
    n_outputs=K,
)

In [35]:
# Get the output: cast the float64 NumPy data to float32 to match the model parameters
inputs = torch.tensor(X, dtype=torch.float32)
out = model(inputs)
out

tensor([[[-0.1016,  0.1344],
         [-0.2939,  0.1317],
         [ 0.4538,  0.2110],
         [-0.0132, -0.0451],
         [ 0.6825,  0.0192],
         [-0.0916, -0.1198],
         [ 0.3576,  0.4373],
         [ 0.3130, -0.3192],
         [ 0.3512,  0.2971],
         [ 0.0263, -0.0525]]], grad_fn=<AddBackward0>)

In [36]:
out.size()  # (N, T, K)

torch.Size([1, 10, 2])

In [37]:
# Save the PyTorch predictions as a NumPy array, shape (N, T, K), for the comparison at the end
Yhats_torch = np.asarray(out.detach())

In [38]:
# Pull the RNN parameters by name. Unpacking model.rnn.parameters() positionally
# silently depends on the order in which nn.RNN registers its parameters.
W_xh = model.rnn.weight_ih_l0  # (M, D) input-to-hidden weights
W_hh = model.rnn.weight_hh_l0  # (M, M) hidden-to-hidden weights
b_xh = model.rnn.bias_ih_l0    # (M,)   input-to-hidden bias
b_hh = model.rnn.bias_hh_l0    # (M,)   hidden-to-hidden bias

In [39]:
W_xh.size()  # (n_hidden, n_inputs) = (M, D)

torch.Size([5, 3])

In [40]:
# Inspect the input-to-hidden weights; at this point W_xh is still a torch Parameter
W_xh

Parameter containing:
tensor([[-0.4398,  0.1717,  0.3512],
        [-0.0110,  0.2314,  0.1526],
        [-0.1131,  0.4237, -0.4227],
        [ 0.0526,  0.2155,  0.1588],
        [ 0.3332, -0.3652, -0.1453]], requires_grad=True)

In [41]:
# Convert to NumPy straight from the model rather than overwriting the previous
# W_xh in place, so this cell is safe to re-run. Prefer .detach() over .data.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_xh

array([[-0.43984896,  0.17169935,  0.35122377],
       [-0.01095418,  0.23137617,  0.15256268],
       [-0.11311445,  0.42370972, -0.4227131 ],
       [ 0.05255882,  0.21554586,  0.1587588 ],
       [ 0.3331866 , -0.36517578, -0.14529742]], dtype=float32)

In [42]:
# Re-read each remaining RNN parameter from the model by name so this cell is
# idempotent on re-run (the variables are NumPy arrays after the first run).
b_xh = model.rnn.bias_ih_l0.detach().numpy()
W_hh = model.rnn.weight_hh_l0.detach().numpy()
b_hh = model.rnn.bias_hh_l0.detach().numpy()

In [43]:
tuple(arr.shape for arr in (W_xh, b_xh, W_hh, b_hh))  # shapes of W_xh, b_xh, W_hh, b_hh

((5, 3), (5,), (5, 5), (5,))

In [44]:
# Get the FC (read-out) layer weights by name rather than by parameter order
Wo = model.fc.weight  # (K, M)
bo = model.fc.bias    # (K,)

In [45]:
# Convert to NumPy straight from the model so this cell is idempotent on re-run;
# .detach() is the supported way to drop autograd tracking (prefer it over .data).
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [48]:
# See if we can replicate the output by hand for the first sequence in the batch:
#   h_t = tanh(x_t W_xh^T + b_xh + h_{t-1} W_hh^T + b_hh)
#   y_t = h_t Wo^T + bo
x = X[0]                  # (T, D)
Yhats = np.zeros((T, K))
h_last = np.zeros(M)      # zero initial state, like h0 in SimpleRNN.forward

for t, x_t in enumerate(x):
  pre_activation = x_t.dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh
  h = np.tanh(pre_activation)
  Yhats[t] = h.dot(Wo.T) + bo
  h_last = h

print(Yhats)

[[-0.1016025   0.13444278]
 [-0.29389649  0.13166414]
 [ 0.45379834  0.21097822]
 [-0.01317817 -0.04505938]
 [ 0.68245626  0.01923217]
 [-0.09164284 -0.11978903]
 [ 0.35764371  0.43726489]
 [ 0.31295147 -0.31922131]
 [ 0.35121695  0.29708971]
 [ 0.02633688 -0.05248884]]


In [49]:
# Check the manual forward pass against PyTorch. The manual loop only processed
# batch element 0, so compare against that slice explicitly instead of relying on
# NumPy broadcasting, which would compare against every sample when N > 1.
Yhats_torch_0 = Yhats_torch[0]
assert Yhats.shape == Yhats_torch_0.shape, (Yhats.shape, Yhats_torch_0.shape)
np.allclose(Yhats, Yhats_torch_0)

True