In [1]:
import torch
import torch.nn as nn

# Set seed for reproducibility: this fixes both the random input below and the
# LSTM's random weight initialisation (both draw from torch's global RNG, in that order).
torch.manual_seed(0)

# Input dimensions
seq_len = 4      # number of time steps
input_dim = 3    # features per time step
hidden_dim = 2   # LSTM hidden state size
batch_size = 1   # for simplicity

# Dummy input (sequence of vectors), laid out as (batch_size, seq_len, input_dim)
# because the LSTM below is built with batch_first=True.
x = torch.randn(batch_size, seq_len, input_dim)

# Define LSTM layer (1 layer by default, unidirectional)
lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

# Initial hidden and cell states (h0, c0), shaped (num_layers, batch_size, hidden_dim).
# batch_first only affects the input/output tensors, NOT the state tensors, and the
# layer dimension is taken from the module so it stays correct if num_layers changes.
h0 = torch.zeros(lstm.num_layers, batch_size, hidden_dim)
c0 = torch.zeros(lstm.num_layers, batch_size, hidden_dim)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Sanity check: for a unidirectional LSTM, the last layer's final hidden state
# is exactly the output emitted at the last time step.
assert torch.allclose(output[:, -1, :], hn[-1])

# Print outputs (squeeze drops the size-1 batch/layer dimensions for readability)
print("Input sequence:\n", x.squeeze())
print("\nOutput at each time step:\n", output.squeeze())
print("\nFinal hidden state:\n", hn.squeeze())
print("\nFinal cell state:\n", cn.squeeze())


Input sequence:
 tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])

Output at each time step:
 tensor([[-0.1491,  0.0378],
        [-0.2774,  0.0892],
        [-0.0008,  0.2280],
        [ 0.0942,  0.1251]], grad_fn=<SqueezeBackward0>)

Final hidden state:
 tensor([0.0942, 0.1251], grad_fn=<SqueezeBackward0>)

Final cell state:
 tensor([0.2080, 0.3955], grad_fn=<SqueezeBackward0>)


In [2]:
import torch
import torch.nn as nn

# Seed locally so this cell is reproducible on its own, instead of silently
# depending on the RNG state left behind by the previous cell.
torch.manual_seed(0)

# Config
input_size = 5     # number of features per time step
hidden_size = 8    # size of hidden state and cell state
seq_len = 4        # length of the input sequence
batch_size = 2     # number of sequences in a batch

# Dummy input: shape = (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)

# Define LSTM layer
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)

# Initial hidden state and cell state, shape: (num_layers, batch_size, hidden_size).
# The layer count is read from the module (1 by default) rather than hard-coded;
# batch_first does not apply to these state tensors.
h0 = torch.zeros(lstm.num_layers, batch_size, hidden_size)
c0 = torch.zeros(lstm.num_layers, batch_size, hidden_size)

# Forward pass
output, (hn, cn) = lstm(x, (h0, c0))

# Invariants: output holds every time step (batch-first); hn/cn hold only the
# final step per layer (layer-first).
assert output.shape == (batch_size, seq_len, hidden_size)
assert hn.shape == cn.shape == (lstm.num_layers, batch_size, hidden_size)

# Print shapes
print("Input shape:", x.shape)
print("Output shape (all time steps):", output.shape)
print("Final hidden state shape:", hn.shape)
print("Final cell state shape:", cn.shape)


Input shape: torch.Size([2, 4, 5])
Output shape (all time steps): torch.Size([2, 4, 8])
Final hidden state shape: torch.Size([1, 2, 8])
Final cell state shape: torch.Size([1, 2, 8])
