# LSTM in PyTorch

Similar to `nn.RNN`, PyTorch provides an `nn.LSTM` module. You can read more about it in the official [PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html). For the most part you pass the module the same arguments as `nn.RNN`, but the output consists of three parts instead of two: the output, the final hidden state and the final cell state. The two states are returned together as a tuple `(h_n, c_n)`.
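To see the difference in return values side by side, here is a small sketch that calls an `nn.RNN` and an `nn.LSTM` with the same (arbitrarily chosen) sizes and prints the shapes. With the default `batch_first=False`, `output` has shape `(sequence_length, batch_size, hidden_size)` and each state has shape `(num_layers, batch_size, hidden_size)`. Two layers are used here only to make the `num_layers` dimension visible.

```python
import torch
import torch.nn as nn

# Arbitrary sizes chosen for this sketch
seq_len, batch, input_size, hidden_size, num_layers = 5, 4, 2, 3, 2

rnn = nn.RNN(input_size, hidden_size, num_layers)
lstm = nn.LSTM(input_size, hidden_size, num_layers)
x = torch.randn(seq_len, batch, input_size)  # (L, N, H_in) because batch_first=False

# nn.RNN returns two parts: output and final hidden state
rnn_out, rnn_h_n = rnn(x)
# nn.LSTM returns output and a tuple with the final hidden and cell state
lstm_out, (lstm_h_n, lstm_c_n) = lstm(x)

print(rnn_out.shape, rnn_h_n.shape)    # torch.Size([5, 4, 3]) torch.Size([2, 4, 3])
print(lstm_out.shape, lstm_h_n.shape)  # torch.Size([5, 4, 3]) torch.Size([2, 4, 3])
print(lstm_c_n.shape)                  # torch.Size([2, 4, 3])

# For a unidirectional LSTM, the last time step of `output` is the
# final hidden state of the last layer
print(torch.allclose(lstm_out[-1], lstm_h_n[-1]))
```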

In [1]:
import torch
import torch.nn as nn

In [2]:
BATCH_SIZE=4
SEQUENCE_LENGTH=5
INPUT_SIZE=2
HIDDEN_SIZE=3
NUM_LAYERS=1

In [3]:
lstm = nn.LSTM(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)

An LSTM has four times as many weights as a vanilla RNN of the same size. The labels $i, f, g, o$ refer to the input gate, the forget gate, the cell (candidate) gate and the output gate that we discussed in the theoretical part. PyTorch stores the weights of all four in a single stacked tensor, in the order $i, f, g, o$, so we slice them out for convenience.
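A quick way to convince yourself of the stacked layout: `weight_ih_l0` has shape `(4 * hidden_size, input_size)`, and `torch.chunk` with 4 chunks along dimension 0 yields exactly the same blocks as the manual slicing used below. This is a minimal sketch with the same sizes as this notebook; the chunk-based unpacking is just an alternative to slicing, not something the notebook relies on.

```python
import torch
import torch.nn as nn

INPUT_SIZE, HIDDEN_SIZE = 2, 3
lstm = nn.LSTM(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE)

w_ih = lstm.weight_ih_l0  # (4 * HIDDEN_SIZE, INPUT_SIZE)
w_hh = lstm.weight_hh_l0  # (4 * HIDDEN_SIZE, HIDDEN_SIZE)
print(w_ih.shape, w_hh.shape)  # torch.Size([12, 2]) torch.Size([12, 3])

# The gate blocks are stacked in the order i, f, g, o along dim 0,
# so splitting into 4 equal chunks recovers them in that order
w_ii, w_if, w_ig, w_io = torch.chunk(w_ih, 4, dim=0)
b_ii, b_if, b_ig, b_io = torch.chunk(lstm.bias_ih_l0, 4, dim=0)

print(w_ii.shape)  # torch.Size([3, 2])
```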

In [4]:
# input to hidden weights and biases
w_ih = lstm.weight_ih_l0
b_ih = lstm.bias_ih_l0

w_ii = w_ih[0*HIDDEN_SIZE:1*HIDDEN_SIZE]
w_if = w_ih[1*HIDDEN_SIZE:2*HIDDEN_SIZE]
w_ig = w_ih[2*HIDDEN_SIZE:3*HIDDEN_SIZE]
w_io = w_ih[3*HIDDEN_SIZE:]

b_ii = b_ih[0*HIDDEN_SIZE:1*HIDDEN_SIZE]
b_if = b_ih[1*HIDDEN_SIZE:2*HIDDEN_SIZE]
b_ig = b_ih[2*HIDDEN_SIZE:3*HIDDEN_SIZE]
b_io = b_ih[3*HIDDEN_SIZE:]

# hidden to hidden weights and biases
w_hh = lstm.weight_hh_l0
b_hh = lstm.bias_hh_l0

w_hi = w_hh[0*HIDDEN_SIZE:1*HIDDEN_SIZE]
w_hf = w_hh[1*HIDDEN_SIZE:2*HIDDEN_SIZE]
w_hg = w_hh[2*HIDDEN_SIZE:3*HIDDEN_SIZE]
w_ho = w_hh[3*HIDDEN_SIZE:]

b_hi = b_hh[0*HIDDEN_SIZE:1*HIDDEN_SIZE]
b_hf = b_hh[1*HIDDEN_SIZE:2*HIDDEN_SIZE]
b_hg = b_hh[2*HIDDEN_SIZE:3*HIDDEN_SIZE]
b_ho = b_hh[3*HIDDEN_SIZE:]

In [5]:
# create inputs to the LSTM
sequence = torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE)
h_0 = torch.zeros(NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)
c_0 = torch.zeros(NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)

In [6]:
with torch.inference_mode():
    output, (h_n, c_n) = lstm(sequence, (h_0, c_0))
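Passing zero-filled `(h_0, c_0)` explicitly, as above, makes the initial state visible. `nn.LSTM` also initializes both states to zeros when no initial state is given, so the two calls below should produce the same result. A small sketch with the notebook's sizes (the seed is arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=2, hidden_size=3, num_layers=1)
sequence = torch.randn(5, 4, 2)  # (SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE)

with torch.no_grad():
    # explicit zero initial states, shape (NUM_LAYERS, BATCH_SIZE, HIDDEN_SIZE)
    zeros = torch.zeros(1, 4, 3)
    out_explicit, (h_a, c_a) = lstm(sequence, (zeros, zeros.clone()))
    # no initial state: nn.LSTM falls back to zeros internally
    out_default, (h_b, c_b) = lstm(sequence)

print(torch.allclose(out_explicit, out_default))
```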

In the example below we compute, at every time step, the outputs of the four fully connected layers (one per gate), each applied to the current input and the previous hidden state. The forget gate $f$, the input gate $i$ and the output gate $o$, together with the candidate values $g$, determine the new cell state and the new hidden state. Work through the example step by step to fully understand the LSTM cell.
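For reference, these are the per-step equations that the loop below implements, written in the notation of the `nn.LSTM` documentation for a single column vector $x_t$ ($\sigma$ is the sigmoid, $\odot$ the element-wise product). Because the code stores a batch as rows, it computes $x_t W^\top$ instead of $W x_t$, which is the same operation transposed.

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$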

In [7]:
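def manual_lstm():
    # work on copies so that h_0 and c_0 stay unchanged
    hidden = h_0.clone()
    cell = c_0.clone()
    output = torch.zeros(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE)
    with torch.inference_mode():
        # seq is the input at time step idx, shape (BATCH_SIZE, INPUT_SIZE);
        # index 0 selects the single layer (NUM_LAYERS=1)
        for idx, seq in enumerate(sequence):
            # gate activations, each of shape (BATCH_SIZE, HIDDEN_SIZE)
            f = torch.sigmoid(seq @ w_if.T + b_if + hidden[0] @ w_hf.T + b_hf)
            i = torch.sigmoid(seq @ w_ii.T + b_ii + hidden[0] @ w_hi.T + b_hi)
            o = torch.sigmoid(seq @ w_io.T + b_io + hidden[0] @ w_ho.T + b_ho)
            g = torch.tanh(seq @ w_ig.T + b_ig + hidden[0] @ w_hg.T + b_hg)  # candidate values

            # keep part of the old cell state and add the gated candidate
            cell[0] = f * cell[0] + i * g
            # the new hidden state is the gated, squashed cell state
            hidden[0] = o * torch.tanh(cell[0])
            output[idx] = hidden[0]
    return output, (hidden, cell)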
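Slicing the weights per gate keeps the code close to the equations, but the same step can also be computed with one matrix multiplication per weight tensor, splitting the result into the four gates afterwards. The sketch below is an alternative formulation written for this explanation (the function name `fused_lstm` is hypothetical, not a PyTorch API). It assumes a single-layer, unidirectional `nn.LSTM` with `bias=True` and `batch_first=False`, and checks its result against the module.

```python
import torch
import torch.nn as nn

def fused_lstm(lstm, sequence, h_0, c_0):
    """Hypothetical helper: single-layer LSTM forward pass using one matmul
    per weight tensor instead of one per gate."""
    with torch.no_grad():
        h, c = h_0[0].clone(), c_0[0].clone()  # (BATCH_SIZE, HIDDEN_SIZE)
        outputs = []
        for x_t in sequence:
            # all four gate pre-activations at once: (BATCH_SIZE, 4 * HIDDEN_SIZE)
            gates = (x_t @ lstm.weight_ih_l0.T + lstm.bias_ih_l0
                     + h @ lstm.weight_hh_l0.T + lstm.bias_hh_l0)
            # columns follow the same i, f, g, o order as the weight rows
            i, f, g, o = gates.chunk(4, dim=1)
            i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
            g = torch.tanh(g)
            c = f * c + i * g
            h = o * torch.tanh(c)
            outputs.append(h)
        return torch.stack(outputs), (h.unsqueeze(0), c.unsqueeze(0))

# usage: compare against nn.LSTM with the notebook's sizes
torch.manual_seed(0)
lstm = nn.LSTM(input_size=2, hidden_size=3, num_layers=1)
sequence = torch.randn(5, 4, 2)
h_0 = torch.zeros(1, 4, 3)
c_0 = torch.zeros(1, 4, 3)

with torch.no_grad():
    ref_output, (ref_h_n, ref_c_n) = lstm(sequence, (h_0, c_0))
fused_output, (fused_h_n, fused_c_n) = fused_lstm(lstm, sequence, h_0, c_0)

print(torch.allclose(ref_output, fused_output, atol=1e-6))
```

Fusing the gates is closer to how LSTM kernels are usually organized, because one large matrix multiplication is typically cheaper than four small ones; for learning purposes the per-gate version above is easier to read.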

In [8]:
manual_output, (manual_h_n, manual_c_n) = manual_lstm()

In the last step we compare the values from the LSTM module with those of our manual implementation. The output, the final hidden state and the final cell state agree in every displayed digit.
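Printed tensors are rounded to four decimals, so a numerical check such as `torch.allclose` is more reliable than comparing printouts by eye. The self-contained sketch below applies it to another way of unrolling the same computation: an `nn.LSTMCell` that is given the layer-0 weights of an `nn.LSTM` (both modules store their weights in the same stacked $i, f, g, o$ layout) and is stepped through the sequence in a Python loop. The tolerance `atol=1e-6` is an assumption chosen to absorb float32 rounding differences.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE, HIDDEN_SIZE = 5, 4, 2, 3

lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, num_layers=1)
cell = nn.LSTMCell(INPUT_SIZE, HIDDEN_SIZE)

# copy the layer-0 parameters of nn.LSTM into the cell (same shapes and layout)
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

sequence = torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, INPUT_SIZE)

with torch.no_grad():
    output, (h_n, c_n) = lstm(sequence)
    # step the cell manually, starting from zero states
    h = torch.zeros(BATCH_SIZE, HIDDEN_SIZE)
    c = torch.zeros(BATCH_SIZE, HIDDEN_SIZE)
    steps = []
    for x_t in sequence:
        h, c = cell(x_t, (h, c))
        steps.append(h)
    cell_output = torch.stack(steps)

# compare within a tolerance rather than by eye
print(torch.allclose(output, cell_output, atol=1e-6))
print(torch.allclose(c_n[0], c, atol=1e-6))
```

In this notebook, the same idea would be a call like `torch.allclose(output, manual_output)`.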

In [9]:
output

tensor([[[ 0.0909, -0.1463,  0.1052],
         [-0.0274,  0.0318,  0.2162],
         [ 0.1561, -0.1113,  0.1520],
         [ 0.0342, -0.1877,  0.0212]],

        [[ 0.0991, -0.1448,  0.2481],
         [ 0.0410, -0.1487,  0.2124],
         [ 0.2016, -0.2148,  0.1453],
         [ 0.0042, -0.0748,  0.2674]],

        [[ 0.0602, -0.0646,  0.3548],
         [ 0.0159, -0.0749,  0.3285],
         [ 0.2282, -0.2431,  0.2154],
         [-0.0071, -0.1779,  0.2184]],

        [[ 0.0058, -0.1205,  0.2827],
         [-0.0484, -0.1683,  0.1527],
         [ 0.2777, -0.2918,  0.1636],
         [ 0.0416, -0.1358,  0.3673]],

        [[-0.0062, -0.1312,  0.3340],
         [-0.0066, -0.2489,  0.1571],
         [ 0.2070, -0.2525,  0.2602],
         [ 0.1375, -0.2403,  0.2903]]])

In [10]:
manual_output

tensor([[[ 0.0909, -0.1463,  0.1052],
         [-0.0274,  0.0318,  0.2162],
         [ 0.1561, -0.1113,  0.1520],
         [ 0.0342, -0.1877,  0.0212]],

        [[ 0.0991, -0.1448,  0.2481],
         [ 0.0410, -0.1487,  0.2124],
         [ 0.2016, -0.2148,  0.1453],
         [ 0.0042, -0.0748,  0.2674]],

        [[ 0.0602, -0.0646,  0.3548],
         [ 0.0159, -0.0749,  0.3285],
         [ 0.2282, -0.2431,  0.2154],
         [-0.0071, -0.1779,  0.2184]],

        [[ 0.0058, -0.1205,  0.2827],
         [-0.0484, -0.1683,  0.1527],
         [ 0.2777, -0.2918,  0.1636],
         [ 0.0416, -0.1358,  0.3673]],

        [[-0.0062, -0.1312,  0.3340],
         [-0.0066, -0.2489,  0.1571],
         [ 0.2070, -0.2525,  0.2602],
         [ 0.1375, -0.2403,  0.2903]]])

In [11]:
h_n

tensor([[[-0.0062, -0.1312,  0.3340],
         [-0.0066, -0.2489,  0.1571],
         [ 0.2070, -0.2525,  0.2602],
         [ 0.1375, -0.2403,  0.2903]]])

In [12]:
manual_h_n

tensor([[[-0.0062, -0.1312,  0.3340],
         [-0.0066, -0.2489,  0.1571],
         [ 0.2070, -0.2525,  0.2602],
         [ 0.1375, -0.2403,  0.2903]]])

In [13]:
c_n

tensor([[[-0.0222, -0.2391,  0.6195],
         [-0.0178, -0.4750,  0.2472],
         [ 0.6347, -0.5466,  0.3933],
         [ 0.3488, -0.5005,  0.3994]]])

In [14]:
manual_c_n

tensor([[[-0.0222, -0.2391,  0.6195],
         [-0.0178, -0.4750,  0.2472],
         [ 0.6347, -0.5466,  0.3933],
         [ 0.3488, -0.5005,  0.3994]]])