## 1. nn.LSTM Basics

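`nn.LSTM` is PyTorch's implementation of the Long Short-Term Memory (LSTM) network.
Its gated cell state mitigates the vanishing-gradient problem of plain RNNs, which makes it better suited to tasks with long sequences (e.g. NLP, time-series forecasting, speech recognition).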
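
To make the gating concrete, here is a minimal sketch of one LSTM time step written by hand and checked against `nn.LSTM`. It assumes PyTorch's documented gate layout, where `weight_ih_l0` and `weight_hh_l0` stack the input, forget, cell and output gate weights (in that order) along dimension 0. `manual_lstm_step` is a hypothetical helper written only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 5, 3, 2
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

def manual_lstm_step(x_t, h, c, m):
    """One LSTM step using the layer-0 weights of nn.LSTM module m (hypothetical helper)."""
    # All four gates in one matmul, then split into input (i), forget (f), cell (g), output (o)
    gates = x_t @ m.weight_ih_l0.T + m.bias_ih_l0 + h @ m.weight_hh_l0.T + m.bias_hh_l0
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_new = f * c + i * g          # additive cell-state update
    h_new = o * torch.tanh(c_new)  # hidden state exposed to the next step / layer
    return h_new, c_new

x = torch.randn(batch_size, 1, input_size)   # a single time step
h0 = torch.zeros(1, batch_size, hidden_size)
c0 = torch.zeros(1, batch_size, hidden_size)

with torch.no_grad():
    _, (h_n, c_n) = lstm(x, (h0, c0))
    h_manual, c_manual = manual_lstm_step(x[:, 0], h0[0], c0[0], lstm)

print(torch.allclose(h_manual, h_n[0], atol=1e-5))  # True
print(torch.allclose(c_manual, c_n[0], atol=1e-5))  # True
```

The forget gate `f` scales the previous cell state instead of passing it through a weight matrix, which is the mechanism behind the gradient behaviour described above.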

### Definition

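```python
# Commonly used constructor arguments (defaults shown)
torch.nn.LSTM(
    input_size,           # number of input features per time step
    hidden_size,          # size of the hidden (and cell) state
    num_layers=1,         # number of stacked LSTM layers
    bias=True,            # whether to use bias terms
    batch_first=False,    # True: input/output are (batch, seq, feature)
    dropout=0.0,          # dropout on the outputs of every layer except the last (only matters if num_layers > 1)
    bidirectional=False,  # whether to use a bidirectional LSTM
)
```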
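
The constructor arguments determine the learnable weights. As a quick sanity check (an illustrative configuration, not part of the original notes), the snippet below lists every parameter: each weight matrix stacks the four gates, so its first dimension is `4 * hidden_size`, and layers after the first take `num_directions * hidden_size` input features.

```python
import torch.nn as nn

# Illustrative configuration: 2 stacked layers, bidirectional
lstm = nn.LSTM(input_size=5, hidden_size=3, num_layers=2, bidirectional=True)

# Collect every learnable tensor the constructor creates
shapes = {name: tuple(p.shape) for name, p in lstm.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
```

Here `weight_ih_l0` is `(12, 5)` (4 gates × hidden_size 3, by input_size 5), while `weight_ih_l1` is `(12, 6)` because layer 1 reads both directions of layer 0 (2 × 3 features). Parameters of the backward direction carry a `_reverse` suffix.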


## 2. Input & Output Shapes

- Notation:

    - batch_size = B
    
    - seq_len = T
    
    - input_size = D_in
    
    - hidden_size = D_h

    - num_directions = 2 if bidirectional=True, otherwise 1

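- Inputs

    - input:
        - (T, B, D_in) if batch_first=False
        - (B, T, D_in) if batch_first=True
    - h_0 (optional, initial hidden state): (num_layers * num_directions, B, D_h); defaults to zeros
    - c_0 (optional, initial cell state): (num_layers * num_directions, B, D_h); defaults to zeros
    - h_0 and c_0 are passed together as a tuple, `lstm(input, (h_0, c_0))`; their layout does not change with batch_first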

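- Outputs

    - output:
        - (T, B, num_directions * D_h) if batch_first=False, or (B, T, num_directions * D_h) if batch_first=True. Contains the last layer's hidden state at every time step.
    
    - (h_n, c_n):
    
        - h_n: final hidden state of every layer and direction, (num_layers * num_directions, B, D_h)
        
        - c_n: final cell state of every layer and direction, (num_layers * num_directions, B, D_h)
        
        - For the backward direction of a bidirectional LSTM, "final" means after reading the sequence in reverse, i.e. at the first time step.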
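
A short sketch of how the optional `h_0`/`c_0` inputs and the returned `(h_n, c_n)` fit together (the sizes are illustrative): omitting the initial states is equivalent to passing zeros, and feeding `(h_n, c_n)` from one chunk into the next reproduces a single pass over the whole sequence, which is one way to process long sequences piece by piece.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D_in, D_h = 2, 6, 5, 3
lstm = nn.LSTM(D_in, D_h, num_layers=1, batch_first=True)
x = torch.randn(B, T, D_in)

with torch.no_grad():
    out_full, (h_full, c_full) = lstm(x)        # initial states omitted
    h0 = torch.zeros(1, B, D_h)                  # (num_layers * num_directions, B, D_h)
    c0 = torch.zeros(1, B, D_h)
    out_zeros, _ = lstm(x, (h0, c0))            # explicit zero initial states

    out_a, state = lstm(x[:, :3])                # first half; state = (h_n, c_n)
    out_b, (h_b, c_b) = lstm(x[:, 3:], state)    # second half continues from that state

print(torch.allclose(out_full, out_zeros))                                   # True
print(torch.allclose(torch.cat([out_a, out_b], dim=1), out_full, atol=1e-5)) # True
print(torch.allclose(h_b, h_full, atol=1e-5))                                # True
```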

---
## 3. Code Examples
### (1) Single-layer LSTM

```python
import torch
import torch.nn as nn

# Parameters
input_size = 5    # number of input features per time step
hidden_size = 3   # hidden state size
seq_len = 4
batch_size = 2

# Define the LSTM
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

# Input: (batch, seq, feature)
x = torch.randn(batch_size, seq_len, input_size)

# Forward pass
output, (h_n, c_n) = lstm(x)

print("input x:", x.shape)                   # [2, 4, 5]
print("output:", output.shape)               # [2, 4, 3]
print("final hidden state h_n:", h_n.shape)  # [1, 2, 3]
print("final cell state c_n:", c_n.shape)    # [1, 2, 3]
```
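
In the single-layer case the two outputs overlap: `output` holds the hidden state at every time step, so its last step matches `h_n`. The standalone check below (same shapes as the example, fresh random weights) illustrates this. `c_n` has no counterpart in `output`, since the cell state is only returned for the final step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=3, batch_first=True)
x = torch.randn(2, 4, 5)

with torch.no_grad():
    output, (h_n, c_n) = lstm(x)

# Single layer, unidirectional: the last time step of `output`
# is the final hidden state (h_n[0] selects the only layer)
print(torch.allclose(output[:, -1, :], h_n[0]))  # True
```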


### (2) Multi-layer LSTM

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(2, 4, 5)
output, (h_n, c_n) = lstm(x)

print("output:", output.shape)   # [2, 4, 3]
print("h_n:", h_n.shape)         # [2, 2, 3] (2 layers, batch=2, hidden=3)
print("c_n:", c_n.shape)         # [2, 2, 3]
```
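
With `num_layers=2`, the second layer consumes the first layer's output sequence. One way to see this, sketched under the assumption of the documented parameter names `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, `bias_hh_l{k}` and no inter-layer dropout, is to copy each layer's weights into a separate single-layer LSTM and chain the two by hand:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# dropout=0.0 (the default), so the stacked model is deterministic
stacked = nn.LSTM(input_size=5, hidden_size=3, num_layers=2, batch_first=True)
layer0 = nn.LSTM(input_size=5, hidden_size=3, batch_first=True)
layer1 = nn.LSTM(input_size=3, hidden_size=3, batch_first=True)  # input = layer 0's hidden size

# Copy layer k of the stacked model into a standalone single-layer LSTM
with torch.no_grad():
    for k, single in enumerate((layer0, layer1)):
        for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
            getattr(single, f"{name}_l0").copy_(getattr(stacked, f"{name}_l{k}"))

x = torch.randn(2, 4, 5)
with torch.no_grad():
    out_stacked, (h_stacked, c_stacked) = stacked(x)
    out0, (h0, c0) = layer0(x)     # first layer on the raw input
    out1, (h1, c1) = layer1(out0)  # second layer on the first layer's outputs

print(torch.allclose(out_stacked, out1, atol=1e-5))    # True
print(torch.allclose(h_stacked[0], h0[0], atol=1e-5))  # True
print(torch.allclose(h_stacked[1], h1[0], atol=1e-5))  # True
```

So `h_n[k]` is the final hidden state of layer `k`: `h_n[-1]` equals `output[:, -1]`, while `h_n[0]` comes from the first layer and does not appear in `output`. With `dropout > 0` in training mode the hand-chained version would no longer match exactly.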


### (3) Bidirectional LSTM

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=3, num_layers=1, bidirectional=True, batch_first=True)
x = torch.randn(2, 4, 5)
output, (h_n, c_n) = lstm(x)

print("output:", output.shape)   # [2, 4, 6] (hidden_size * 2)
print("h_n:", h_n.shape)         # [2, 2, 3] (2 directions, batch=2, hidden=3)
print("c_n:", c_n.shape)         # [2, 2, 3]
```
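
In the bidirectional case the last dimension of `output` concatenates the forward and backward hidden states. The sketch below (same shapes, fresh weights) splits them and shows that the forward direction ends at the last time step while the backward direction ends at `t = 0`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 3
lstm = nn.LSTM(input_size=5, hidden_size=H, bidirectional=True, batch_first=True)
x = torch.randn(2, 4, 5)

with torch.no_grad():
    output, (h_n, c_n) = lstm(x)

forward_out = output[..., :H]    # forward direction, all time steps
backward_out = output[..., H:]   # backward direction, all time steps

print(torch.allclose(forward_out[:, -1], h_n[0]))  # True
print(torch.allclose(backward_out[:, 0], h_n[1]))  # True
```

For stacked bidirectional models, the PyTorch documentation notes that `h_n.view(num_layers, num_directions, B, H)` separates layers from directions.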


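## 4. Advantages of LSTM

- Captures longer-range dependencies than a plain RNN

- Mitigates vanishing gradients (exploding gradients are usually still handled with gradient clipping)

- Well suited to NLP and time-series tasks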
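
The gradient advantage comes from how the cell state is updated. Using the gate equations from the PyTorch `nn.LSTM` documentation ($\sigma$ is the sigmoid, $\odot$ the elementwise product), one time step is:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

Along the direct cell-state path the update is additive, so the Jacobian across $k$ steps reduces to a product of forget gates:

$$
\left.\frac{\partial c_t}{\partial c_{t-k}}\right|_{\text{direct path}} = \prod_{j=t-k+1}^{t} \operatorname{diag}(f_j)
$$

When the forget gates stay close to 1 this product does not shrink, whereas a plain RNN repeatedly multiplies by the recurrent weight matrix and the $\tanh$ derivative. This is a heuristic view: it ignores the indirect paths through $h_{t-1}$.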