In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np

In [2]:
class CleanBasicRNN(nn.Module):
    """Minimal recurrent network that unrolls a single ``nn.RNNCell`` over time.

    The hidden state is kept on the module as ``self.hx`` and is overwritten by
    every ``forward`` call, so consecutive calls continue from the previous
    final state (including its autograd history) rather than from a fresh one.

    Args:
        batch_size: Number of rows in the initial hidden state; the batch
            dimension of each time step fed to ``forward`` should match it.
        n_inputs: Number of input features per time step.
        n_neurons: Size of the hidden state produced by the cell.
    """

    def __init__(self, batch_size, n_inputs, n_neurons):
        super(CleanBasicRNN, self).__init__()

        self.rnn = nn.RNNCell(n_inputs, n_neurons)
        # Random initial hidden state. This is a plain tensor attribute, not a
        # parameter, so it is not trained by an optimizer.
        self.hx = torch.randn(batch_size, n_neurons)

    def forward(self, X):
        """Run the cell over every time step in ``X``.

        Args:
            X: Sequence indexed by time step along its first dimension; each
                ``X[t]`` is passed to the cell as one batch of inputs.

        Returns:
            A tuple ``(output, hx)`` where ``output`` is a list holding the
            hidden state after each time step and ``hx`` is the final hidden
            state (the same tensor as ``output[-1]`` when ``X`` is non-empty).
        """
        output = []

        # Walk over however many time steps X actually holds rather than
        # assuming exactly two, so longer sequences are not silently truncated
        # and shorter ones do not raise IndexError.
        for x_t in X:
            self.hx = self.rnn(x_t, self.hx)
            output.append(self.hx)

        return output, self.hx

In [3]:
# Dimensions of the toy problem.
FIXED_BATCH_SIZE = 4  # rows per time step; the batch size is fixed for now
N_INPUT = 3           # features per row
N_NEURONS = 5         # hidden units in the recurrent cell

# Input batch laid out as (time step, row, feature): X0 first, then X1.
X_batch = torch.tensor(
    [
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]],  # X0
        [[9, 8, 7], [0, 0, 0], [6, 5, 4], [3, 2, 1]],  # X1
    ],
    dtype=torch.float32,
)

In [4]:
# Seed torch right before construction: both the RNNCell weights and the
# initial hidden state are drawn randomly, so without a seed every
# Restart & Run All prints different numbers.
torch.manual_seed(0)
model = CleanBasicRNN(FIXED_BATCH_SIZE, N_INPUT, N_NEURONS)

# output_val holds the hidden state after each time step (X0, then X1);
# states_val is the final hidden state, i.e. the last entry of output_val.
output_val, states_val = model(X_batch)
print(output_val) #Output of X0 and X1
print(states_val)

[tensor([[ 0.3732,  0.6449, -0.0716,  0.8421, -0.9486],
        [ 0.7228, -0.8435, -0.7468,  0.9988, -0.9998],
        [ 0.8796, -0.4239, -0.9750,  1.0000, -1.0000],
        [-0.9922, -0.9003, -0.9394,  0.9680, -0.9988]], grad_fn=<TanhBackward>), tensor([[ 0.6581, -0.9726, -0.9796,  1.0000, -1.0000],
        [-0.2322, -0.3966, -0.1854, -0.8376,  0.4302],
        [ 0.3885, -0.9442, -0.9277,  0.9951, -0.9999],
        [-0.5736, -0.8624, -0.2245,  0.6911, -0.9764]], grad_fn=<TanhBackward>)]
tensor([[ 0.6581, -0.9726, -0.9796,  1.0000, -1.0000],
        [-0.2322, -0.3966, -0.1854, -0.8376,  0.4302],
        [ 0.3885, -0.9442, -0.9277,  0.9951, -0.9999],
        [-0.5736, -0.8624, -0.2245,  0.6911, -0.9764]], grad_fn=<TanhBackward>)
