In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Things you should automatically have memorized
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units, a hyperparameter you can choose
# K = number of output units

In [3]:
# Make some data
# problem sizes: samples, sequence length, input features, hidden units, output units
N, T, D, M, K = 1, 10, 3, 5, 2

In [4]:
# draw an (N, T, D) batch of standard-normal inputs from NumPy's global RNG
X = np.random.standard_normal(size=(N, T, D))
X

array([[[-0.23057351,  0.34011185, -1.02514592],
        [-0.1646384 ,  0.55413199, -1.77329757],
        [-1.04429864, -0.22444312,  0.83856988],
        [ 0.68521604,  0.5823409 , -0.47796542],
        [-1.32740888, -0.70470459, -1.00652956],
        [-0.40565624,  1.2313286 , -1.96906318],
        [ 1.24372006,  0.17281384,  0.20004363],
        [-0.08168604,  0.52878413, -0.22296019],
        [-0.72683405, -1.29945986,  2.18910435],
        [-1.91439998,  0.76421553, -0.06889867]]])

In [5]:
# Make an RNN
class SimpleRNN(nn.Module):
  """Single-layer tanh RNN followed by a linear layer applied at every time step.

  Args:
    n_inputs: number of input features per time step (D).
    n_hidden: number of hidden units (M).
    n_outputs: number of output units per time step (K).
  """

  def __init__(self, n_inputs, n_hidden, n_outputs):
    super(SimpleRNN, self).__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs

    # note: batch_first=True
    # applies the convention that our data will be of shape:
    # (num_samples, sequence_length, num_features)
    # rather than:
    # (sequence_length, num_samples, num_features)
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity='tanh',
        batch_first=True
    )
    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    """Run the RNN over a batch of sequences and map every hidden state to K outputs.

    Args:
      X: tensor of shape (N, T, D).

    Returns:
      Tensor of shape (N, T, K): one output vector per time step.
    """
    # initial hidden state, shape (num_layers=1, N, M); created on X's device
    # and dtype so the module keeps working after .to(device) / .double()
    h0 = torch.zeros(1, X.size(0), self.M, device=X.device, dtype=X.dtype)

    # hidden_states is of size (N, T, M)
    # 2nd return value is the final hidden state of each layer
    # we don't need it now
    hidden_states, _ = self.rnn(X, h0)

    # nn.Linear acts on the last dimension, so this yields N x T x K
    # (an output for every time step, not only the final one)
    return self.fc(hidden_states)


In [6]:
# Instantiate the model
# positional order is (n_inputs, n_hidden, n_outputs)
model = SimpleRNN(D, M, K)

In [7]:
# Get the output
# the model's weights are float32, so cast the NumPy data before the forward pass
inputs = torch.from_numpy(X).float()
out = model(inputs)
out

tensor([[[ 0.5416,  0.4745],
         [ 0.6282,  0.7162],
         [ 0.3514,  0.0372],
         [ 0.3144,  0.4069],
         [ 0.4342,  0.3742],
         [ 0.6546,  0.5738],
         [ 0.3227,  0.5400],
         [ 0.4218,  0.2835],
         [ 0.3228, -0.0259],
         [ 0.4693, -0.1266]]], grad_fn=<ViewBackward0>)

In [8]:
# one K-dimensional output per time step: (N, T, K)
out.size()

torch.Size([1, 10, 2])

In [9]:
# save for later
# `out` carries a grad_fn, so it must be detached from the autograd graph
# before it can be exported to NumPy for the comparison further down
Yhats_torch = out.detach().numpy()

In [10]:
# parameters() returns a generator, so the repr shows no values;
# it has to be unpacked (or iterated) to get at the actual tensors
model.rnn.parameters()

<generator object Module.parameters at 0x7fdc559bded0>

In [11]:
# grab the single layer's parameters by name instead of relying on the
# order in which parameters() yields them
W_xh = model.rnn.weight_ih_l0  # input-to-hidden weights
W_hh = model.rnn.weight_hh_l0  # hidden-to-hidden weights
b_xh = model.rnn.bias_ih_l0    # input-to-hidden bias
b_hh = model.rnn.bias_hh_l0    # hidden-to-hidden bias

In [12]:
# input-to-hidden weights: one row per hidden unit (M), one column per input feature (D)
W_xh.shape

torch.Size([5, 3])

In [13]:
# inspect the values; at this point W_xh is still a torch Parameter (requires_grad=True)
W_xh

Parameter containing:
tensor([[-0.3067,  0.3642,  0.2580],
        [-0.3319,  0.2953,  0.2456],
        [ 0.3493,  0.3615,  0.1211],
        [-0.1931, -0.2119,  0.0412],
        [-0.1780, -0.3662,  0.3844]], requires_grad=True)

In [14]:
# Export the input-to-hidden weights as a NumPy array.
# Read them straight from the module instead of rebinding via `W_xh.data`:
# once W_xh is an ndarray it has no `.numpy()`, so the old form failed when
# the cell was run a second time. detach() is used in place of `.data`;
# the resulting array shares memory with the parameter.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_xh

array([[-0.30672246,  0.3642384 ,  0.25798488],
       [-0.33188254,  0.2952884 ,  0.24557331],
       [ 0.34931114,  0.3614636 ,  0.12114366],
       [-0.19305944, -0.21193168,  0.04120395],
       [-0.17796248, -0.36623344,  0.38440922]], dtype=float32)

In [15]:
# Export the remaining RNN parameters as NumPy arrays.
# Each one is read from the module by name so the cell is idempotent
# (re-running it does not call `.data.numpy()` on an ndarray).
b_xh = model.rnn.bias_ih_l0.detach().numpy()
W_hh = model.rnn.weight_hh_l0.detach().numpy()
b_hh = model.rnn.bias_hh_l0.detach().numpy()

In [16]:
# Did we do it right?
# expected: (M, D), (M,), (M, M), (M,)
tuple(arr.shape for arr in (W_xh, b_xh, W_hh, b_hh))

((5, 3), (5,), (5, 5), (5,))

In [17]:
# Now get the FC layer weights
# (named attributes rather than unpacking parameters() by position)
Wo = model.fc.weight
bo = model.fc.bias

In [18]:
# Export the output-layer parameters as NumPy arrays.
# Reading from the module by name keeps the cell re-runnable; the previous
# `Wo.data.numpy()` form failed on a second run because Wo was already an ndarray.
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [19]:
# see if we can replicate the output
x = X[0]  # the one and only sample, shape (T, D)
Yhats = np.zeros((T, K))  # one K-dimensional output per time step
h_prev = np.zeros(M)  # initial hidden state, all zeros like h0 in the model
for t, x_t in enumerate(x):
  # recurrence: h(t) = tanh(W_xh x(t) + b_xh + W_hh h(t-1) + b_hh)
  pre_activation = x_t.dot(W_xh.T) + b_xh + h_prev.dot(W_hh.T) + b_hh
  # important: the new hidden state becomes the "previous" one for the next step
  h_prev = np.tanh(pre_activation)
  # the output layer is applied at every step, so every row is stored
  Yhats[t] = h_prev.dot(Wo.T) + bo

# print the final output
print(Yhats)

[[ 0.54155822  0.47452386]
 [ 0.62818147  0.71615068]
 [ 0.35137679  0.0371797 ]
 [ 0.31441542  0.40687861]
 [ 0.43416625  0.37416902]
 [ 0.65457506  0.57377731]
 [ 0.32267969  0.53996839]
 [ 0.42176932  0.28352412]
 [ 0.32281961 -0.02594833]
 [ 0.46930728 -0.12656425]]


In [20]:
# Check: the manual NumPy recurrence should reproduce PyTorch's output
# (compared within allclose's default tolerances, since the weights are float32)
np.allclose(Yhats, Yhats_torch)

True