In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

%matplotlib inline

In [3]:
# Problem dimensions, kept tiny so every tensor can be inspected by eye.
N = 1   # number of sequences (batch size)
T = 10  # time steps per sequence
D = 3   # input features per time step
M = 5   # hidden units in the recurrent layer
K = 2   # outputs per time step

# Fix the random state so that Restart & Run All reproduces the same data
# and the same initial weights for any model created afterwards.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# One batch of random input sequences, shape (N, T, D).
X = np.random.randn(N, T, D)

In [7]:
# Confirm the input array has shape (N, T, D).
X.shape

(1, 10, 3)

In [12]:
class SimpleRNN(nn.Module):
    """Single-layer tanh RNN followed by a linear layer applied at every time step.

    Args:
        n_inputs: Number of features per time step (D).
        n_hidden: Size of the hidden state (M).
        n_outputs: Number of outputs per time step (K).
    """

    def __init__(self, n_inputs, n_hidden, n_outputs):
        super(SimpleRNN, self).__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        # batch_first=True: inputs and outputs are laid out as (batch, time, features).
        self.rnn = nn.RNN(input_size=self.D,
                          hidden_size=self.M,
                          nonlinearity="tanh",
                          batch_first=True)
        self.fc = nn.Linear(in_features=self.M, out_features=self.K)

    def forward(self, X):
        """Run the RNN over a batch of sequences.

        Args:
            X: Tensor of shape (N, T, D).

        Returns:
            Tensor of shape (N, T, K) holding one prediction per time step.
        """
        # Zero initial hidden state of shape (num_layers=1, N, M). It is created
        # with X's dtype and device so the model keeps working after
        # .double() / .to(device); a bare torch.zeros(...) would always use the
        # default dtype and device and then mismatch the weights.
        h0 = torch.zeros(1, X.size(0), self.M, dtype=X.dtype, device=X.device)

        out, _ = self.rnn(X, h0)  # hidden state at every time step: (N, T, M)
        out = self.fc(out)        # nn.Linear acts on the last dimension: (N, T, K)

        return out

In [13]:
# Build the network: D input features, M hidden units, K outputs per time step.
model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [14]:
# np.random.randn yields float64; cast to float32 before wrapping the array
# as a tensor so it can be fed to the freshly constructed model.
X_float32 = X.astype(np.float32)
inputs = torch.from_numpy(X_float32)

# Forward pass over the whole batch; one K-vector per time step.
out = model(inputs)
out

tensor([[[ 0.5987, -0.2216],
         [ 0.7002, -0.6593],
         [ 0.6124, -0.3099],
         [ 0.6238, -0.4604],
         [ 0.5325, -0.7311],
         [ 0.5148, -0.1264],
         [ 0.6934, -0.6686],
         [ 0.2917, -0.5214],
         [ 0.3582, -0.5384],
         [ 0.7235, -0.5445]]], grad_fn=<AddBackward0>)

In [15]:
# Expect (N, T, K): the linear layer is applied at every time step.
out.shape

torch.Size([1, 10, 2])

In [17]:
# Keep the PyTorch predictions as a NumPy array (detached from autograd)
# so they can be compared with a hand-written NumPy forward pass.
Yhats_torch = out.detach().numpy()

In [21]:
# Inspect the recurrent layer's parameters: two weight matrices followed by two bias vectors.
list(model.rnn.parameters())

[Parameter containing:
 tensor([[-0.3180, -0.4152,  0.1189],
         [-0.3502, -0.3808, -0.0015],
         [ 0.2193, -0.1365,  0.0794],
         [-0.2302, -0.2482, -0.0319],
         [-0.0545, -0.2153, -0.1017]], requires_grad=True),
 Parameter containing:
 tensor([[-0.1560, -0.0367, -0.0564, -0.0230, -0.0830],
         [-0.3804, -0.1069, -0.2608, -0.3025, -0.1604],
         [ 0.1130,  0.0211,  0.1655,  0.1819,  0.2535],
         [-0.1622,  0.3057,  0.0178, -0.3950, -0.1463],
         [-0.4174, -0.3990, -0.1748,  0.3901, -0.0380]], requires_grad=True),
 Parameter containing:
 tensor([-0.3132, -0.0158,  0.1979,  0.0683,  0.0420], requires_grad=True),
 Parameter containing:
 tensor([ 0.2985, -0.3526,  0.0883, -0.0957, -0.4079], requires_grad=True)]

In [22]:
# Bind the recurrent layer's parameters by attribute name:
# input-to-hidden and hidden-to-hidden weights, then the matching biases.
W_xh, W_hh = model.rnn.weight_ih_l0, model.rnn.weight_hh_l0
b_xh, b_hh = model.rnn.bias_ih_l0, model.rnn.bias_hh_l0

In [25]:
# Input-to-hidden weight matrix.
W_xh

Parameter containing:
tensor([[-0.3180, -0.4152,  0.1189],
        [-0.3502, -0.3808, -0.0015],
        [ 0.2193, -0.1365,  0.0794],
        [-0.2302, -0.2482, -0.0319],
        [-0.0545, -0.2153, -0.1017]], requires_grad=True)

In [26]:
# Expect (M, D): one row per hidden unit, one column per input feature.
W_xh.shape

torch.Size([5, 3])

In [29]:
# Export the recurrent layer's parameters as NumPy arrays for the manual
# forward pass. They are read from the model by name instead of re-binding
# each variable from its own previous value: the cell can then be re-run
# safely (an ndarray has no .detach()) and does not rely on the iteration
# order of parameters().
W_xh = model.rnn.weight_ih_l0.detach().numpy()  # (M, D) input-to-hidden weights
W_hh = model.rnn.weight_hh_l0.detach().numpy()  # (M, M) hidden-to-hidden weights
b_xh = model.rnn.bias_ih_l0.detach().numpy()    # (M,)   input-to-hidden bias
b_hh = model.rnn.bias_hh_l0.detach().numpy()    # (M,)   hidden-to-hidden bias

In [30]:
# Shapes of the recurrent-layer arrays, in the order W_xh, b_xh, W_hh, b_hh.
tuple(arr.shape for arr in (W_xh, b_xh, W_hh, b_hh))

((5, 3), (5,), (5, 5), (5,))

In [32]:
# Inspect the output layer's parameters: a weight matrix followed by a bias vector.
list(model.fc.parameters())

[Parameter containing:
 tensor([[ 0.0610, -0.2819,  0.3305,  0.1422,  0.1549],
         [ 0.4065, -0.0250, -0.0797,  0.1969, -0.3758]], requires_grad=True),
 Parameter containing:
 tensor([ 0.4410, -0.4391], requires_grad=True)]

In [33]:
# Bind the output layer's weight matrix and bias vector by attribute name.
Wo, bo = model.fc.weight, model.fc.bias

In [35]:
# Hidden-to-output weight matrix.
Wo

Parameter containing:
tensor([[ 0.0610, -0.2819,  0.3305,  0.1422,  0.1549],
        [ 0.4065, -0.0250, -0.0797,  0.1969, -0.3758]], requires_grad=True)

In [37]:
# Export the output layer's parameters as NumPy arrays. Reading them from
# the module (rather than re-binding Wo/bo from their own previous values)
# keeps the cell safe to re-run, since an ndarray has no .detach().
Wo = model.fc.weight.detach().numpy()  # (K, M) hidden-to-output weights
bo = model.fc.bias.detach().numpy()    # (K,)   output bias

In [38]:
# Shapes of the output-layer arrays, in the order Wo, bo.
tuple(arr.shape for arr in (Wo, bo))

((2, 5), (2,))

In [76]:
# Re-implement the forward pass by hand in NumPy, one time step at a time:
#   h_t = tanh(x_t W_xh^T + b_xh + h_{t-1} W_hh^T + b_hh)
#   y_t = h_t Wo^T + bo
x = X[0]                  # the single sequence in the batch, shape (T, D)
Yhats = np.zeros((T, K))  # one output row per time step
h_last = np.zeros(M)      # zero initial hidden state, matching h0 in SimpleRNN.forward

for t in range(T):
    # Pre-activation, accumulated in the same order as a single left-to-right sum.
    input_term = x[t].dot(W_xh.T) + b_xh
    recurrent_term = h_last.dot(W_hh.T)
    h = np.tanh(input_term + recurrent_term + b_hh)

    # Output layer. Every step's prediction is stored, because the whole
    # sequence of outputs is compared against the PyTorch result.
    y = h.dot(Wo.T) + bo
    Yhats[t] = y

    # Carry the hidden state into the next time step.
    h_last = h
print(Yhats)

[[ 0.59865131 -0.2216457 ]
 [ 0.70023942 -0.65925628]
 [ 0.61236143 -0.30994398]
 [ 0.62379994 -0.46043421]
 [ 0.53251987 -0.73106588]
 [ 0.51476081 -0.12639734]
 [ 0.69336252 -0.66861054]
 [ 0.29174639 -0.52137123]
 [ 0.35820881 -0.53836104]
 [ 0.7234539  -0.54447798]]


In [82]:
# The hand-written NumPy forward pass should reproduce the PyTorch outputs
# up to floating-point tolerance.
np.allclose(Yhats, Yhats_torch)

True