In [37]:
# reference: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html, https://blog.csdn.net/m0_45478865/article/details/104455978
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

SEED = 1  # single seed for torch's global RNG (LSTM weight init and every torch.randn call)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the cell output

<torch._C.Generator at 0x1d8ff39de30>

In [38]:
# An LSTM that maps 3 input features per step to a 3-dim hidden state
# (its per-step output has the hidden size, i.e. also 3).
lstm_layer = nn.LSTM(input_size=3, hidden_size=3)
# A toy sequence of length 5; each element is a (1, 3) tensor.
inputs = []
for _ in range(5):
    inputs.append(torch.randn(1, 3))

In [40]:
# Initial LSTM state as an (h_0, c_0) tuple. Each tensor is shaped
# (num_layers=1, batch=1, hidden_size=3) to match the LSTM defined above;
# h_0 is drawn from the RNG before c_0.
hidden_layer = tuple(torch.randn(1, 1, 3) for _ in range(2))

In [41]:
# Step through the sequence one element at a time. Each (1, 3) element is
# reshaped to (seq_len=1, batch=1, features) and the previous (h, c) state is
# fed back in, so after every iteration `hidden_layer` holds the latest state.
# Note: re-running this cell without re-running the initialization cell
# continues from the previous final state rather than from (h_0, c_0).
for step_input in inputs:
    step_batch = step_input.view(1, 1, -1)
    output, hidden_layer = lstm_layer(step_batch, hidden_layer)

In [42]:
# Output of the final step only: the loop above fed one element at a time,
# so this is the output for a length-1 sequence, shaped (1, 1, 3).
output

tensor([[[-0.3600,  0.0893,  0.0215]]], grad_fn=<StackBackward>)

In [43]:
# (h_n, c_n) after the last step; h_n matches `output` above, since a
# single-step LSTM call's output is its new hidden state.
hidden_layer

(tensor([[[-0.3600,  0.0893,  0.0215]]], grad_fn=<StackBackward>),
 tensor([[[-1.1298,  0.4467,  0.0254]]], grad_fn=<StackBackward>))

In [44]:
# Alternatively, we can feed the entire sequence to the LSTM in one call.
# The LSTM returns two things:
#   * `output`: the hidden state at every time step (one slice per step along
#     the first dimension), giving access to all hidden states in the sequence;
#   * `hidden_layer`: the (h_n, c_n) pair from the last time step, which lets
#     you continue the sequence (and backpropagate through it) by passing it
#     back to the LSTM as an argument at a later time.
# The last slice of `output` is the same as h_n (the first element of
# `hidden_layer`); the assertion below checks this.
#
# Stack the per-step (1, 3) tensors into one (seq_len, batch=1, features)
# tensor, adding the extra 2nd (batch) dimension. `list(...)` keeps the cell
# re-runnable: after the first run `inputs` is already a (seq_len, 1, 3)
# tensor, and iterating it yields the same (1, 3) step slices again, whereas
# torch.cat on a bare tensor would raise.
inputs = torch.cat(list(inputs)).view(len(inputs), 1, -1)
hidden_layer = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # fresh random initial (h_0, c_0)
output, hidden_layer = lstm_layer(inputs, hidden_layer)
# Last time step of `output` vs. h_n of the last (here the only) layer.
assert torch.allclose(output[-1], hidden_layer[0][-1])
print(output)
print(hidden_layer)

tensor([[[-0.0187,  0.1713, -0.2944]],

        [[-0.3521,  0.1026, -0.2971]],

        [[-0.3191,  0.0781, -0.1957]],

        [[-0.1634,  0.0941, -0.1637]],

        [[-0.3368,  0.0959, -0.0538]]], grad_fn=<StackBackward>)
(tensor([[[-0.3368,  0.0959, -0.0538]]], grad_fn=<StackBackward>), tensor([[[-0.9825,  0.4715, -0.0633]]], grad_fn=<StackBackward>))
