#### LSTM vs LSTMCell

This notebook compares the implementations of LSTM and LSTMCell in pytorch. It also has an example of how to use a LSTM to "manually" iterate through time.

1.  LSTMCell vs LSTM (GPU Support)

In [1]:
import torch.nn as nn
import torch

torch.manual_seed(12)

# features 10, hidden size 20
rnn = nn.LSTMCell(10, 20)
# (time_steps, batch, input_size)
input = torch.randn(2, 3, 10)
# (batch, hidden_size)
hx = torch.randn(3, 20)
# (batch, hidden_size)
cx = torch.randn(3, 20)
# list to save outputs each time step
output = []
for i in range(input.size()[0]):
        print(f" step {i} , Input : {input[i].shape}")
        print(f" step {i} , hidden state : {hx.shape}")
        print(f" step {i} , cell state : {cx.shape}")

        hx, cx = rnn(input[i], (hx, cx))
        output.append(hx)       

output  = torch.stack(output, dim = 0)

print(f" \n Output LSTMCell: {output.shape}")
print(f" \n Output LSTMCell:\n {output}")

 step 0 , Input : torch.Size([3, 10])
 step 0 , hidden state : torch.Size([3, 20])
 step 0 , cell state : torch.Size([3, 20])
 step 1 , Input : torch.Size([3, 10])
 step 1 , hidden state : torch.Size([3, 20])
 step 1 , cell state : torch.Size([3, 20])
 
 Output LSTMCell: torch.Size([2, 3, 20])
 
 Output LSTMCell:
 tensor([[[ 0.2552, -0.1090, -0.0835,  0.2958, -0.4213, -0.0170,  0.2474,
           0.1677, -0.3955, -0.1091,  0.2386, -0.0052,  0.3028,  0.2637,
           0.0113, -0.3665, -0.0687, -0.0280, -0.3560,  0.0247],
         [-0.0958, -0.0987,  0.3122, -0.1819, -0.0390, -0.3634,  0.1293,
           0.2141, -0.5724, -0.1846, -0.3133, -0.1564, -0.0349,  0.2602,
           0.3982, -0.3867,  0.3485,  0.0338, -0.2344,  0.0718],
         [-0.1596,  0.0429,  0.3663, -0.0377,  0.1968, -0.2228, -0.0025,
          -0.0077, -0.1359,  0.0810,  0.1197, -0.0336,  0.0129, -0.1988,
          -0.2814,  0.2207, -0.0302, -0.1498, -0.3188, -0.0169]],

        [[ 0.1216, -0.0640, -0.1142, -0.1368, -0.

In [2]:
torch.manual_seed(12)

# (input_size, hidden_size, layers)
rnn = nn.LSTM(10, 20, 1)
# (seq_len, batch, input_size)
input = torch.randn(2, 3, 10)
# (num_layers, batch, hidden_size)
h0 = torch.randn(1, 3, 20)
c0 = torch.randn(1, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

print(f" \n Output : {output.shape}")
print(f" \n hidden state : {hn.shape}")
print(f" \n cell state : {cn.shape}")
print(f" \n Output LSTM:\n {output}")


 
 Output : torch.Size([2, 3, 20])
 
 hidden state : torch.Size([1, 3, 20])
 
 cell state : torch.Size([1, 3, 20])
 
 Output LSTM:
 tensor([[[ 0.2552, -0.1090, -0.0835,  0.2958, -0.4213, -0.0170,  0.2474,
           0.1677, -0.3955, -0.1091,  0.2386, -0.0052,  0.3028,  0.2637,
           0.0113, -0.3665, -0.0687, -0.0280, -0.3560,  0.0247],
         [-0.0958, -0.0987,  0.3122, -0.1819, -0.0390, -0.3634,  0.1293,
           0.2141, -0.5724, -0.1846, -0.3133, -0.1564, -0.0349,  0.2602,
           0.3982, -0.3867,  0.3485,  0.0338, -0.2344,  0.0718],
         [-0.1596,  0.0429,  0.3663, -0.0377,  0.1968, -0.2228, -0.0025,
          -0.0077, -0.1359,  0.0810,  0.1197, -0.0336,  0.0129, -0.1988,
          -0.2814,  0.2207, -0.0302, -0.1498, -0.3188, -0.0169]],

        [[ 0.1216, -0.0640, -0.1142, -0.1368, -0.2694, -0.1684,  0.1206,
           0.1096, -0.2894, -0.0532,  0.1241,  0.0037,  0.0609,  0.0435,
          -0.0029,  0.0019, -0.0049, -0.0011, -0.2284, -0.0170],
         [-0.1441,  0.

2. Using LSTM like an LSTMCell (loop through time)

In [3]:
torch.manual_seed(12)

# (input_size, hidden_size, layers)
rnn = nn.LSTM(10, 20, 1)
# (seq_len, batch, input_size)
input = torch.randn(2, 3, 10)
# (num_layers, batch, hidden_size)
hx = torch.randn(1, 3, 20)
cx = torch.randn(1, 3, 20)

outputs = []
for i in range(input.shape[0]):
    
        print(f" \n step {i} , Input : {input[i:i+1,:,:].shape}")
        print(f" step {i} , hidden state : {hx.shape}")
        print(f" step {i} , cell state : {cx.shape}")

        output, (hx, cx) = rnn(input[i:i+1,:,:], (hx, cx))
        outputs.append(output)
        
outputs  = torch.stack(outputs, dim = 0)
print(f" \n Output LSTM: \n {outputs}")

 
 step 0 , Input : torch.Size([1, 3, 10])
 step 0 , hidden state : torch.Size([1, 3, 20])
 step 0 , cell state : torch.Size([1, 3, 20])
 
 step 1 , Input : torch.Size([1, 3, 10])
 step 1 , hidden state : torch.Size([1, 3, 20])
 step 1 , cell state : torch.Size([1, 3, 20])
 
 Output LSTM: 
 tensor([[[[ 0.2552, -0.1090, -0.0835,  0.2958, -0.4213, -0.0170,  0.2474,
            0.1677, -0.3955, -0.1091,  0.2386, -0.0052,  0.3028,  0.2637,
            0.0113, -0.3665, -0.0687, -0.0280, -0.3560,  0.0247],
          [-0.0958, -0.0987,  0.3122, -0.1819, -0.0390, -0.3634,  0.1293,
            0.2141, -0.5724, -0.1846, -0.3133, -0.1564, -0.0349,  0.2602,
            0.3982, -0.3867,  0.3485,  0.0338, -0.2344,  0.0718],
          [-0.1596,  0.0429,  0.3663, -0.0377,  0.1968, -0.2228, -0.0025,
           -0.0077, -0.1359,  0.0810,  0.1197, -0.0336,  0.0129, -0.1988,
           -0.2814,  0.2207, -0.0302, -0.1498, -0.3188, -0.0169]]],


        [[[ 0.1216, -0.0640, -0.1142, -0.1368, -0.2694, -0.168