In [1]:
import torch
import modeling

In [2]:
# Seed torch's RNG before the first weights and inputs are drawn so that the
# tensors printed in later cells are reproducible after Restart & Run All.
# The value itself is arbitrary.
torch.manual_seed(0)
rnn_cell = modeling.RNNCell(in_features=128, hidden_features=64)

In [3]:
x = torch.rand(32, 128)

In [4]:
memory = rnn_cell.init_memory(batch_size=x.shape[0])

In [5]:
# Single step of the cell on the whole batch; it returns a pair that is unpacked
# into the prediction and (presumably) the updated memory.
# NOTE(review): `memory` is rebound here, so re-running this cell feeds the
# previous result back in instead of the freshly initialised memory.
pred, memory = rnn_cell(x, memory)

In [6]:
pred.shape

torch.Size([32, 64])

In [7]:
pred

tensor([[-0.0651, -0.4203,  0.2848,  ...,  0.0268,  0.2710, -0.5656],
        [-0.2416, -0.6598,  0.5456,  ..., -0.1998,  0.1020, -0.6889],
        [-0.3814, -0.5202,  0.4697,  ...,  0.0243,  0.1561, -0.6879],
        ...,
        [-0.0147, -0.4815,  0.4722,  ...,  0.0567,  0.2315, -0.6212],
        [-0.1974, -0.3412,  0.4346,  ..., -0.1991, -0.0600, -0.2572],
        [-0.2033, -0.4515,  0.3694,  ...,  0.0057,  0.0714, -0.7429]],
       grad_fn=<TanhBackward>)

In [8]:
rnn = modeling.RNN(rnn_cell=modeling.RNNCell(in_features=128, hidden_features=64))

In [9]:
x = torch.rand(32, 64, 128)

In [10]:
hiddens = rnn(x)

In [11]:
hiddens.shape

torch.Size([32, 64, 64])

In [12]:
# Rebuild the wrapper with output_last=True. With this flag the forward pass
# below yields a [32, 64] tensor instead of the [32, 64, 64] obtained without it
# (presumably only the final time step is kept - confirm in modeling.RNN).
rnn = modeling.RNN(rnn_cell=modeling.RNNCell(in_features=128, hidden_features=64), output_last=True)

In [13]:
rnn

RNN(
  (rnn_cell): RNNCell(
    (input_linear): Linear(in_features=128, out_features=64, bias=True)
    (memory_linear): Linear(in_features=64, out_features=64, bias=True)
  )
)

In [14]:
x = torch.rand(32, 64, 128)

In [15]:
hiddens = rnn(x)

In [16]:
hiddens.shape

torch.Size([32, 64])

In [17]:
hiddens

tensor([[-0.3070,  0.1827,  0.1650,  ...,  0.0149, -0.3389, -0.0159],
        [-0.1961, -0.0348,  0.3140,  ...,  0.1178, -0.1746, -0.2944],
        [-0.1999, -0.2278,  0.2596,  ..., -0.0030, -0.3328, -0.1432],
        ...,
        [-0.1677, -0.0399,  0.0925,  ...,  0.2667, -0.3900,  0.0542],
        [-0.1550,  0.4114,  0.2531,  ..., -0.0120,  0.0215,  0.3434],
        [-0.3744, -0.0557, -0.0739,  ...,  0.1852, -0.4326,  0.0766]],
       grad_fn=<SqueezeBackward0>)

In [18]:
rnn

RNN(
  (rnn_cell): RNNCell(
    (input_linear): Linear(in_features=128, out_features=64, bias=True)
    (memory_linear): Linear(in_features=64, out_features=64, bias=True)
  )
)

In [19]:
lstm_cell = modeling.LSTMCell(in_features=128, hidden_features=64)

In [20]:
def mult(shapes):
    """Multiply together all entries of ``shapes``.

    Args:
        shapes: iterable of dimension sizes (for example a tensor's ``shape``).

    Returns:
        The product of the entries, or ``1`` when ``shapes`` is empty.
    """
    product = 1
    for dim in shapes:
        product *= dim
    return product

In [21]:
def n_params(model):
    """Count the scalar parameters of a model.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (for example a ``torch.nn.Module``).

    Returns:
        int: total number of elements across all parameter tensors,
        ``0`` when the model has no parameters.
    """
    # Tensor.numel() is the built-in element count, so this function does not
    # depend on a shape-product helper defined in a different cell, and the
    # generator avoids materialising an intermediate list.
    return sum(p.numel() for p in model.parameters())

In [22]:
n_params(lstm_cell)

49664

In [23]:
x.shape

torch.Size([32, 64, 128])

In [24]:
# Take the first sequence of the batch. Integer indexing already removes the
# leading axis, so no squeeze() is needed; squeeze() would additionally collapse
# any other size-1 axis and silently change the tensor's rank.
x_0 = x[0]

In [25]:
x_0.shape

torch.Size([64, 128])

In [26]:
# The cell is called without an explicit state, so it presumably builds its own
# initial hidden/cell state (confirm in modeling.LSTMCell). It returns the output
# plus a (hidden, cell) pair.
# NOTE(review): x_0 is [64, 128], i.e. one slice of x along its first axis; if
# axis 1 of x is the time axis, 64 time steps are being treated as a batch here.
output, (hidden, cell) = lstm_cell(x_0)

In [27]:
output.shape

torch.Size([64, 64])

In [28]:
hidden.shape

torch.Size([64, 64])

In [29]:
cell.shape

torch.Size([64, 64])

In [30]:
lstm_cell

LSTMCell(
  (forget_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (forget_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  (in_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (in_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  (cell_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (cell_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  (out_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (out_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
)

In [31]:
# The generic RNN wrapper is reused here with an LSTMCell as the per-step cell.
lstm = modeling.RNN(rnn_cell=modeling.LSTMCell(in_features=128, hidden_features=64))

In [32]:
x = torch.rand(32, 64, 128)

In [33]:
pred = lstm(x)

In [34]:
pred.shape

torch.Size([32, 64, 64])

In [35]:
# Two separately constructed LSTM cells, one passed as the forward cell and one
# as the backward cell. The forward pass below returns a [32, 2, 64, 64] tensor
# (axis 1 presumably indexes the two directions - confirm in
# modeling.BidirectionalRNN).
bi_lstm = modeling.BidirectionalRNN(rnn_cell_forward=modeling.LSTMCell(in_features=128, hidden_features=64),
                                    rnn_cell_backward=modeling.LSTMCell(in_features=128, hidden_features=64))

In [36]:
n_params(bi_lstm)

99328

In [37]:
hiddens = bi_lstm(x)

In [38]:
hiddens.shape

torch.Size([32, 2, 64, 64])

In [39]:
# Rebinds bi_lstm to a new model built from freshly constructed cells, this time
# with output_last=True: the result below is [32, 2, 64] rather than
# [32, 2, 64, 64] (presumably only the last step per direction is kept - confirm
# in modeling.BidirectionalRNN).
bi_lstm = modeling.BidirectionalRNN(rnn_cell_forward=modeling.LSTMCell(in_features=128, hidden_features=64), 
                                    rnn_cell_backward=modeling.LSTMCell(in_features=128, hidden_features=64),
                                    output_last=True)

In [40]:
hiddens = bi_lstm(x)

In [41]:
hiddens.shape

torch.Size([32, 2, 64])

In [42]:
bi_lstm

BidirectionalRNN(
  (rnn_cell_forward): LSTMCell(
    (forget_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (forget_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (in_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (in_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (cell_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (cell_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (out_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (out_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  )
  (rnn_cell_backward): LSTMCell(
    (forget_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (forget_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (in_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (in_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (cell_gate_