In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

Important labels for network shape:


*   N = number of samples
*   T = sequence length
*   D = number of input features
*   M = number of hidden units (a hyperparameter that can be tuned)
*   K = number of output units



In [2]:
# Make some data
# Shape constants: N samples, T time steps, D input features,
# M hidden units, K output units.
N = 1
T = 10
D = 3
M = 5
K = 2

# Seed the generator so that a fresh "Restart & Run All" reproduces the same X.
SEED = 0
rng = np.random.default_rng(SEED)
X = rng.standard_normal((N, T, D))  # standard-normal samples, shape (N, T, D)

In [3]:
# Make an RNN
class SimpleRNN(nn.Module):
  """Single-layer tanh RNN followed by a linear layer applied at every time step.

  Args:
    n_inputs: number of input features per time step (D).
    n_hidden: number of hidden units of the recurrent layer (M).
    n_outputs: number of output units of the linear layer (K).
  """

  def __init__(self, n_inputs, n_hidden, n_outputs):
    super().__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs
    # batch_first=True: inputs and outputs are laid out as N x T x features
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity='tanh',
        batch_first=True)
    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    """Map a batch of sequences X (N x T x D) to per-step outputs (N x T x K)."""
    # Initial hidden state, shape (1 layer, N, M). It is created with X's dtype
    # and device so the model also works in float64 or on an accelerator.
    h0 = torch.zeros(1, X.size(0), self.M, dtype=X.dtype, device=X.device)

    # Hidden state at every time step: N x T x M
    out, _ = self.rnn(X, h0)

    # The linear layer is applied to every time step, so the result is
    # N x T x K. (Slicing out[:, -1, :] first would instead keep only the
    # final step h(T) and give an N x K result.)
    out = self.fc(out)
    return out

In [4]:
# Instantiate the model
# Seed PyTorch first: the layer weights are initialised randomly, so without a
# seed the weights and outputs shown below change on every fresh run.
torch.manual_seed(0)
model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [5]:
# Get the output
# The model expects float32 tensors, so cast the NumPy data before wrapping it
X_float32 = X.astype(np.float32)
inputs = torch.from_numpy(X_float32)
out = model(inputs)
out

tensor([[[ 0.0917, -0.2369],
         [ 0.2728, -0.2486],
         [-0.0164, -0.3989],
         [-0.1969, -0.3519],
         [ 0.0475, -0.3387],
         [-0.0955, -0.5257],
         [-0.2014, -0.3308],
         [-0.1045, -0.0631],
         [-0.0864, -0.5629],
         [-0.0057, -0.6995]]], grad_fn=<AddBackward0>)

In [6]:
# Looking at the previous values for the shape variables, the output
# should have shape N x T x K -- assert it instead of checking by eye
assert out.shape == (N, T, K), f"expected {(N, T, K)}, got {tuple(out.shape)}"
out.shape

torch.Size([1, 10, 2])

In [7]:
# Save for later
# Detach from the autograd graph first; a tensor that requires grad
# cannot be converted to NumPy directly
out_detached = out.detach()
Yhats_torch = out_detached.numpy()

In [8]:
# The same four tensors that model.rnn.parameters() yields, picked by name
W_xh = model.rnn.weight_ih_l0  # input-to-hidden weight
W_hh = model.rnn.weight_hh_l0  # hidden-to-hidden weight
b_xh = model.rnn.bias_ih_l0  # input-to-hidden bias
b_hh = model.rnn.bias_hh_l0  # hidden-to-hidden bias

In [9]:
# Input-to-hidden weight
# With D = 3 inputs and M = 5 hidden units it is stored as M x D
W_xh.size()

torch.Size([5, 3])

In [10]:
# Show the input-to-hidden weight values themselves
W_xh

Parameter containing:
tensor([[ 0.1504, -0.1802, -0.1529],
        [-0.3785,  0.2207,  0.0571],
        [ 0.0859, -0.2511,  0.0584],
        [-0.2800, -0.0195,  0.1244],
        [ 0.3980,  0.0428, -0.1774]], requires_grad=True)

In [11]:
# Convert the input-to-hidden weight to a NumPy array.
# Read it straight from the model rather than from the previous value of W_xh,
# so re-running this cell gives the same result instead of failing on an
# ndarray. detach() is the recommended replacement for the legacy .data.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_xh

array([[ 0.1503795 , -0.18021508, -0.15287407],
       [-0.37845144,  0.22074474,  0.05712606],
       [ 0.08588155, -0.25113446,  0.05840528],
       [-0.27996483, -0.01949871,  0.12437741],
       [ 0.39804035,  0.04280187, -0.1773759 ]], dtype=float32)

In [12]:
# Convert the remaining RNN parameters to NumPy arrays.
# Each one is read by name from the model, so the cell does not depend on the
# current type of b_xh / W_hh / b_hh and can be re-run safely.
b_xh = model.rnn.bias_ih_l0.detach().numpy()  # input-to-hidden bias
W_hh = model.rnn.weight_hh_l0.detach().numpy()  # hidden-to-hidden weight
b_hh = model.rnn.bias_hh_l0.detach().numpy()  # hidden-to-hidden bias

In [13]:
# Shapes of the RNN weights: input-to-hidden (weight, bias),
# then hidden-to-hidden (weight, bias)
tuple(array.shape for array in (W_xh, b_xh, W_hh, b_hh))

((5, 3), (5,), (5, 5), (5,))

This makes sense: the hidden-to-hidden weight is M x M = 5 x 5. Note that PyTorch keeps a separate input-to-hidden bias and hidden-to-hidden bias. Both are simply added together inside the tanh, so this is not a problem as long as the shapes are used correctly.

In [14]:
# Now get the FC layer weights
# (the same two tensors model.fc.parameters() yields, picked by name)
Wo = model.fc.weight
bo = model.fc.bias

In [15]:
# Convert the FC layer parameters to NumPy arrays.
# They are read by name from the model, so the cell stays re-runnable even
# after Wo / bo have already been replaced by arrays.
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

The ordering here is fine: the weight is stored as (outputs, inputs) = (K, M) = (2, 5), which matches the convention already seen for the input-to-hidden and hidden-to-hidden weights and biases.

In [17]:
# Can the output be replicated?
x = X[0]  # the one and only sample: T rows of D features
Yhats = np.zeros((T, K))  # row t will hold the output at time step t
h_last = np.zeros(M)  # initial hidden state

for t, x_t in enumerate(x):
  # h(t) = tanh(W_xh x(t) + b_xh + W_hh h(t-1) + b_hh)
  pre_activation = x_t.dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh
  h = np.tanh(pre_activation)
  # y(t) = Wo h(t) + bo; the output of every step is kept, not just the last
  y = h.dot(Wo.T) + bo
  Yhats[t] = y

  # important: carry h forward as the previous state for the next step
  h_last = h

# print the final output
print(Yhats)

[[ 0.09166565 -0.23685287]
 [ 0.2728051  -0.24859737]
 [-0.01639132 -0.39893358]
 [-0.19688449 -0.35194097]
 [ 0.04749314 -0.33867361]
 [-0.09554299 -0.5256515 ]
 [-0.20135794 -0.33078339]
 [-0.10445531 -0.06312564]
 [-0.08644079 -0.56293861]
 [-0.00570175 -0.69949129]]


In [18]:
# Check that manual calculation matches model calculation
# Yhats is T x K for sample 0 only, so compare it with that sample explicitly
# instead of relying on broadcasting against the N x T x K model output
np.allclose(Yhats, Yhats_torch[0])

True