In [47]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [48]:
# Notation
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

In [49]:
# Make some data
N = 1
T = 10
D = 3
M = 5
K = 2
X = np.random.randn(N, T, D)

In [50]:
X[0,:,0]

array([-1.42202959, -0.30584167, -0.10123845,  1.08277257,  0.16910943,
        0.65416014,  0.0113329 ,  0.2241193 , -1.48342856,  0.41035106])

In [51]:
X[0,:,1]

array([ 0.27789963,  1.43820656, -2.15667515,  0.46328696, -0.79548003,
       -0.37435387,  0.4919916 , -1.01477065, -0.24366567, -1.1042471 ])

In [52]:
# Make an RNN
class SimpleRNN(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super(SimpleRNN, self).__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        
        self.rnn = nn.RNN(input_size=self.D,
                          hidden_size=self.M,  
                          nonlinearity='tanh',
                          batch_first=True)
        
        self.fc = nn.Linear(self.M, self.K)

    def forward(self, X):
        # initial hidden states
        h0 = torch.zeros(1, X.size(0), self.M)

        # get RNN unit output
        out, _ = self.rnn(X, h0)

        # we only want h(T) at the final time step
        # out = self.fc(out[:, -1, :])
        out = self.fc(out)
        return out

In [53]:
# Instantiate the model
model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [54]:
# Get the output
inputs = torch.from_numpy(X.astype(np.float32))
out = model(inputs)
out

tensor([[[ 0.3534, -0.0065],
         [ 0.3624, -0.1680],
         [ 0.0615, -0.1249],
         [ 0.2041, -0.8968],
         [ 0.0632, -0.3090],
         [ 0.2877, -0.5916],
         [ 0.3546, -0.4842],
         [ 0.2157, -0.2293],
         [-0.0696, -0.1272],
         [ 0.2788, -0.3015]]], grad_fn=<AddBackward0>)

In [55]:
out.shape

torch.Size([1, 10, 2])

In [56]:
# Save for later
Yhats_torch = out.detach().numpy()

In [57]:
# W_xh input_x to 1st hidden unit connection weights
# W_hh hidden to hidden unit connection
# Bias connection
W_xh, W_hh, b_xh, b_hh = model.rnn.parameters()

In [58]:
W_xh.shape

torch.Size([5, 3])

In [59]:
W_xh

Parameter containing:
tensor([[ 0.1077, -0.3469, -0.1210],
        [-0.4451,  0.2497,  0.4097],
        [-0.2887, -0.3506, -0.0323],
        [-0.0926,  0.1427, -0.0499],
        [ 0.2744, -0.3330,  0.3996]], requires_grad=True)

In [60]:
W_xh = W_xh.data.numpy()
W_xh

array([[ 0.10774289, -0.34690294, -0.12098341],
       [-0.4451445 ,  0.24967088,  0.40971905],
       [-0.28871554, -0.35062268, -0.03232153],
       [-0.09263885,  0.14269963, -0.04990234],
       [ 0.27435327, -0.33303034,  0.39962932]], dtype=float32)

In [61]:
W_xh.shape

(5, 3)

In [62]:
b_xh = b_xh.data.numpy()
W_hh = W_hh.data.numpy()
b_hh = b_hh.data.numpy()

In [63]:
# Did we do it right?
W_xh.shape, b_xh.shape, W_hh.shape, b_hh.shape

((5, 3), (5,), (5, 5), (5,))

In [64]:
# Now get the FC layer weights
Wo, bo = model.fc.parameters()

In [65]:
Wo = Wo.data.numpy()
bo = bo.data.numpy()
Wo.shape, bo.shape

((2, 5), (2,))