In [1]:
import torch
import torch.nn as nn

import airsim

In [2]:
E, H  = 2, 3
B, L  = 4, 5

rnn = nn.RNN(E, H)

In [3]:
for k, v in rnn.state_dict().items():
    print(f'{k:10s} : {tuple(v.shape)}')

weight_ih_l0 : (3, 2)
weight_hh_l0 : (3, 3)
bias_ih_l0 : (3,)
bias_hh_l0 : (3,)


In [9]:
X  = torch.rand(L, B, E)
Y, Hn = rnn(X)
print(tuple(Y.shape), tuple(Hn.shape))

(5, 4, 3) (1, 4, 3)


tensor([[[-0.5892, -0.7327, -0.7597],
         [-0.6388, -0.7431, -0.7853],
         [-0.6909, -0.6159, -0.6811],
         [-0.5768, -0.6716, -0.6906]],

        [[-0.6946, -0.0788, -0.6712],
         [-0.6010, -0.0499, -0.6274],
         [-0.7398,  0.0119, -0.6447],
         [-0.4624, -0.1920, -0.6333]],

        [[-0.7228, -0.4213, -0.6799],
         [-0.7076, -0.3794, -0.5940],
         [-0.8328, -0.2243, -0.5916],
         [-0.7911, -0.4191, -0.6873]],

        [[-0.6418, -0.3269, -0.7339],
         [-0.5429, -0.3485, -0.7091],
         [-0.7455, -0.1306, -0.6357],
         [-0.5461, -0.2064, -0.6436]],

        [[-0.6439, -0.0139, -0.3838],
         [-0.7741, -0.2550, -0.6453],
         [-0.5617, -0.4312, -0.6689],
         [-0.7838, -0.2850, -0.6220]]], grad_fn=<StackBackward0>)

In [5]:
W_ih, W_hh = rnn.weight_ih_l0.detach(), rnn.weight_hh_l0.detach()
B_ih, B_hh = rnn.bias_ih_l0.detach(),   rnn.bias_hh_l0.detach()

In [6]:
Hn = torch.zeros(B,H)

for x in X:
    Hn =torch.tanh(  torch.addmm(B_ih, x,  W_ih.t()) 
                   + torch.addmm(B_hh, Hn, W_hh.t()) )
    print(Hn)

tensor([[-0.6241, -0.6546, -0.6907],
        [-0.6302, -0.7953, -0.8331],
        [-0.4298, -0.7612, -0.7472],
        [-0.5355, -0.6998, -0.7071]])
tensor([[-0.6029, -0.0277, -0.5855],
        [-0.5112, -0.0350, -0.5878],
        [-0.5839,  0.0292, -0.4721],
        [-0.4434,  0.1026, -0.3651]])
tensor([[-0.8247, -0.4566, -0.7299],
        [-0.7248, -0.5091, -0.6773],
        [-0.7972, -0.4669, -0.7140],
        [-0.5600, -0.6489, -0.6995]])
tensor([[-0.7217, -0.2778, -0.7733],
        [-0.7296,  0.0712, -0.5544],
        [-0.7343, -0.2306, -0.7528],
        [-0.4039, -0.1314, -0.5416]])
tensor([[-0.7988, -0.1129, -0.5872],
        [-0.8177, -0.5626, -0.8028],
        [-0.6999,  0.0087, -0.3847],
        [-0.6034, -0.3323, -0.4602]])


In [7]:
Hn = torch.zeros(1,B,H)
for x in X:    
    _, Hn = rnn( x.view(1,B,E), Hn )
    print(Hn)

tensor([[[-0.6241, -0.6546, -0.6907],
         [-0.6302, -0.7953, -0.8331],
         [-0.4298, -0.7612, -0.7472],
         [-0.5355, -0.6998, -0.7071]]], grad_fn=<StackBackward0>)
tensor([[[-0.6029, -0.0277, -0.5855],
         [-0.5112, -0.0350, -0.5878],
         [-0.5839,  0.0292, -0.4721],
         [-0.4434,  0.1026, -0.3651]]], grad_fn=<StackBackward0>)
tensor([[[-0.8247, -0.4566, -0.7299],
         [-0.7248, -0.5091, -0.6773],
         [-0.7972, -0.4669, -0.7140],
         [-0.5600, -0.6489, -0.6995]]], grad_fn=<StackBackward0>)
tensor([[[-0.7217, -0.2778, -0.7733],
         [-0.7296,  0.0712, -0.5544],
         [-0.7343, -0.2306, -0.7528],
         [-0.4039, -0.1314, -0.5416]]], grad_fn=<StackBackward0>)
tensor([[[-0.7988, -0.1129, -0.5872],
         [-0.8177, -0.5626, -0.8028],
         [-0.6999,  0.0087, -0.3847],
         [-0.6034, -0.3323, -0.4602]]], grad_fn=<StackBackward0>)
