In [1]:
import torch
import modeling

In [2]:
# Single recurrent cell: 128 input features -> 64 hidden features.
rnn_cell = modeling.RNNCell(in_features=128, hidden_features=64)

In [3]:
# One step of random input. NOTE(review): no seed is set (e.g. torch.manual_seed),
# so the tensor values printed below change on every run.
x = torch.rand(32, 128)

In [4]:
# Starting memory sized to the batch, i.e. x.shape[0] == 32 rows.
memory = rnn_cell.init_memory(batch_size=x.shape[0])

In [5]:
# One step of the cell: returns a prediction and a new memory value, which is rebound
# to `memory` so it could be passed to a following step.
pred, memory = rnn_cell(x, memory)

In [6]:
# (batch, hidden_features) = (32, 64).
pred.shape

torch.Size([32, 64])

In [7]:
# grad_fn=<TanhBackward> in the output shows tanh is the final op that produced `pred`.
pred

tensor([[-0.3369, -0.3292, -0.3237,  ...,  0.4945, -0.2610,  0.3480],
        [-0.2680, -0.5288, -0.1565,  ...,  0.6043, -0.2498,  0.0741],
        [-0.2436, -0.5758, -0.2917,  ...,  0.5885, -0.3456,  0.0952],
        ...,
        [-0.2518, -0.6194, -0.0176,  ...,  0.7087, -0.0825,  0.1125],
        [-0.1086, -0.3499, -0.3520,  ...,  0.7109, -0.1036,  0.2769],
        [-0.1962, -0.6386,  0.0139,  ...,  0.7777,  0.1578, -0.0286]],
       grad_fn=<TanhBackward>)

In [8]:
# Wrap a fresh RNNCell (same sizes as above) in the sequence-level modeling.RNN.
rnn = modeling.RNN(rnn_cell=modeling.RNNCell(in_features=128, hidden_features=64))

In [9]:
# Assumed layout (batch, seq_len, features) = (32, 64, 128) — TODO confirm against modeling.RNN.
x = torch.rand(32, 64, 128)

In [10]:
hiddens = rnn(x)

In [11]:
# NOTE(review): seq_len (64) equals hidden_features (64), so this shape cannot tell which
# axis is time and which is hidden; a distinct seq_len (e.g. 50) would make the check unambiguous.
hiddens.shape

torch.Size([32, 64, 64])

In [12]:
# Same wrapper with output_last=True; the shape in In [16] shows one 64-dim vector per
# sample (presumably the hidden state of the last step — verify in modeling.RNN).
rnn = modeling.RNN(rnn_cell=modeling.RNNCell(in_features=128, hidden_features=64), output_last=True)

In [13]:
# The repr lists only submodules; the output_last flag is not shown.
rnn

RNN(
  (rnn_cell): RNNCell(
    (input_linear): Linear(in_features=128, out_features=64, bias=True)
    (memory_linear): Linear(in_features=64, out_features=64, bias=True)
  )
)

In [14]:
# Fresh random input with the same shape as In [9].
x = torch.rand(32, 64, 128)

In [15]:
hiddens = rnn(x)

In [16]:
hiddens.shape

torch.Size([32, 64])

In [17]:
# grad_fn=<SqueezeBackward0>: the last op producing this tensor was a squeeze.
hiddens

tensor([[-0.1447,  0.3517, -0.0885,  ...,  0.5245, -0.2525, -0.2523],
        [-0.3391,  0.6071,  0.2502,  ...,  0.4700, -0.1598, -0.6810],
        [-0.3252,  0.2833,  0.3660,  ...,  0.3193,  0.0499, -0.7865],
        ...,
        [-0.4222,  0.3862,  0.2638,  ...,  0.3958, -0.1276, -0.6503],
        [-0.1737,  0.6961,  0.1192,  ...,  0.4698, -0.2408, -0.7065],
        [ 0.1545,  0.2464,  0.4547,  ...,  0.3895,  0.0871, -0.7643]],
       grad_fn=<SqueezeBackward0>)

In [18]:
# NOTE(review): repeats the display from In [13]; `rnn` was not reassigned in between,
# so this cell can be dropped.
rnn

RNN(
  (rnn_cell): RNNCell(
    (input_linear): Linear(in_features=128, out_features=64, bias=True)
    (memory_linear): Linear(in_features=64, out_features=64, bias=True)
  )
)

In [19]:
# LSTM cell with the same sizes as the RNN cell above: 128 input features -> 64 hidden features.
lstm_cell = modeling.LSTMCell(in_features=128, hidden_features=64)

In [20]:
def mult(shapes):
    """Return the product of every entry in ``shapes``.

    Works with any iterable of numbers, e.g. a ``torch.Size`` to get an element
    count. An empty iterable yields the multiplicative identity, 1.
    """
    product = 1
    for size in shapes:
        product *= size
    return product

In [21]:
def n_params(model):
    """Count the parameters of a ``torch.nn.Module``.

    Args:
        model: module whose ``parameters()`` are counted. ``Module.parameters()``
            yields each parameter tensor once, so weights shared between
            submodules are counted a single time.

    Returns:
        int: total number of scalar elements over all parameters
        (0 for a module that has no parameters).
    """
    # Tensor.numel() is the built-in element count: no need to multiply the shape
    # by hand, and the generator avoids building an intermediate list for sum().
    return sum(p.numel() for p in model.parameters())

In [22]:
# 8 Linear layers (see the repr in In [30]): 4 * (128*64 + 64) + 4 * (64*64 + 64) = 49664.
n_params(lstm_cell)

49664

In [23]:
# `x` still holds the (32, 64, 128) tensor drawn in In [14].
x.shape

torch.Size([32, 64, 128])

In [24]:
# x[0] is already 2-D (64, 128), so squeeze() is a no-op here.
# NOTE(review): squeeze() without a dim would also drop any other size-1 axis
# (e.g. if seq_len were 1); plain x[0] is safer.
x_0 = x[0].squeeze()

In [25]:
x_0.shape

torch.Size([64, 128])

In [26]:
# NOTE(review): under the assumed (batch, seq_len, features) layout, this feeds the 64
# steps of one sample to the cell as if they were a batch of 64 inputs; x_0.shape[0]
# becomes the leading dim of the outputs below.
# Unlike the RNNCell call in In [5], no memory is passed — presumably LSTMCell
# initialises its own state when none is given; verify in modeling.LSTMCell.
output, (hidden, cell) = lstm_cell(x_0)

In [27]:
output.shape

torch.Size([64, 64])

In [28]:
hidden.shape

torch.Size([64, 64])

In [29]:
cell.shape

torch.Size([64, 64])

In [30]:
# One input-side (128 -> 64) and one hidden-side (64 -> 64) Linear per gate:
# forget, in, cell and out.
lstm_cell

LSTMCell(
  (forget_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (forget_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  (in_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (in_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  (cell_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (cell_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  (out_gate_input): Linear(in_features=128, out_features=64, bias=True)
  (out_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
)

In [31]:
# The same modeling.RNN wrapper used in In [8], now driving an LSTMCell instead of an RNNCell.
lstm = modeling.RNN(rnn_cell=modeling.LSTMCell(in_features=128, hidden_features=64))

In [32]:
# Fresh random input, same shape as In [9] / In [14].
x = torch.rand(32, 64, 128)

In [33]:
pred = lstm(x)

In [34]:
# Same output shape as the RNNCell-based wrapper in In [11].
pred.shape

torch.Size([32, 64, 64])

In [35]:
# Bidirectional wrapper built from two separate LSTM cells, one per direction.
bi_lstm = modeling.BidirectionalRNN(rnn_cell_forward=modeling.LSTMCell(in_features=128, hidden_features=64),
                                    rnn_cell_backward=modeling.LSTMCell(in_features=128, hidden_features=64))

In [36]:
# Exactly 2 * 49664 (the single LSTMCell count from In [22]): the two cells share no
# weights and the wrapper adds no parameters of its own.
n_params(bi_lstm)

99328

In [37]:
# `x` is the (32, 64, 128) tensor drawn in In [32].
hiddens = bi_lstm(x)

In [38]:
# The extra size-2 axis is presumably the direction (forward / backward), giving
# (batch, direction, seq_len, hidden) — TODO confirm in modeling.BidirectionalRNN.
hiddens.shape

torch.Size([32, 2, 64, 64])

In [39]:
# Same wrapper with output_last=True: one 64-dim vector per sample and per direction (see In [41]).
bi_lstm = modeling.BidirectionalRNN(rnn_cell_forward=modeling.LSTMCell(in_features=128, hidden_features=64), 
                                    rnn_cell_backward=modeling.LSTMCell(in_features=128, hidden_features=64),
                                    output_last=True)

In [40]:
hiddens = bi_lstm(x)

In [41]:
hiddens.shape

torch.Size([32, 2, 64])

In [42]:
bi_lstm

BidirectionalRNN(
  (rnn_cell_forward): LSTMCell(
    (forget_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (forget_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (in_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (in_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (cell_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (cell_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (out_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (out_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
  )
  (rnn_cell_backward): LSTMCell(
    (forget_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (forget_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (in_gate_input): Linear(in_features=128, out_features=64, bias=True)
    (in_gate_hidden): Linear(in_features=64, out_features=64, bias=True)
    (cell_gate_