In [1]:
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# 单个 LSTM Cell 的实现

## Pytorch 接口调用

输入的是一个形状为 `(batch, input_size)` 的 Tensor；还可以额外传入 `(h_0, c_0)`（形状均为 `(batch, hidden_size)`），未传入时两者默认为全零。输出为 `(h_1, c_1)`。

In [2]:
# Fix the RNG so the random inputs and the randomly initialised LSTM weights
# (in this cell and every later cell) are reproducible under Restart & Run All.
torch.manual_seed(0)

batch_size = 8
input_size = 64
hidden_size = 128
bias = True

input_tensor = torch.randn(batch_size, input_size)  # [batch_size, input_size]

torch_lstm_cell = torch.nn.LSTMCell(input_size, hidden_size, bias)
# hx is omitted, so nn.LSTMCell starts from zero (h0, c0); the result is (h1, c1).
output_tensor_torch = torch_lstm_cell(input_tensor)

## LSTM Cell 手动实现

Batch化的矩阵形式：

$$
\begin{align}
& Z = X_t W_{ih}^T + \mathbf{b}_{ih} + H_{t-1} W_{hh}^T + \mathbf{b}_{hh} \\
& i,f,g,o = \text{split}(Z, 4) \\
& i = \text{sigmoid}(i) \\
& f = \text{sigmoid}(f) \\
& g = \text{tanh}(g) \\
& o = \text{sigmoid}(o) \\
& c_t = i \odot g + c_{t-1} \odot f \\
& h_t = o \odot \text{tanh}(c_t)
\end{align}
$$

In [5]:
def lstm_cell(
    x: torch.Tensor,
    hx: Optional[Tuple[torch.Tensor, torch.Tensor]],
    W_hh: torch.Tensor,
    W_ih: torch.Tensor,
    bias_ih: Optional[torch.Tensor] = None,
    bias_hh: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute one LSTM step with the same gate layout as ``torch.nn.LSTMCell``.

    Args:
        x: input at the current step, ``(batch, input_size)``.
        hx: previous ``(h0, c0)``, each ``(batch, hidden_size)``. ``None`` means
            zero states, matching ``nn.LSTMCell`` when ``hx`` is omitted.
        W_hh: hidden-to-hidden weights, ``(4 * hidden_size, hidden_size)``.
        W_ih: input-to-hidden weights, ``(4 * hidden_size, input_size)``.
        bias_ih, bias_hh: biases of shape ``(4 * hidden_size,)``, or ``None`` for
            a bias-free cell (``nn.LSTMCell(..., bias=False)`` exposes ``None``).

    Returns:
        ``(h1, c1)``: the new hidden and cell states, each shaped like ``h0``.
    """
    if hx is None:
        # hidden_size is the second dim of W_hh; keep x's dtype/device and leading dims.
        zeros = x.new_zeros((*x.shape[:-1], W_hh.size(1)))
        hx = (zeros, zeros)
    h0, c0 = hx
    z = x @ W_ih.t() + h0 @ W_hh.t()
    # Same summation order as before (bias_hh, then bias_ih) to keep identical numerics.
    if bias_hh is not None:
        z = z + bias_hh
    if bias_ih is not None:
        z = z + bias_ih
    # The 4 * hidden_size pre-activations are stacked as [i | f | g | o].
    i, f, g, o = torch.chunk(z, 4, dim=-1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    c1 = i * g + c0 * f
    h1 = o * torch.tanh(c1)
    return h1, c1


# Reuse the reference cell's parameters so both implementations share weights.
weight_ih = torch_lstm_cell.weight_ih  # [4 * hidden_size, input_size]
weight_hh = torch_lstm_cell.weight_hh  # [4 * hidden_size, hidden_size]
bias_ih = torch_lstm_cell.bias_ih  # [4 * hidden_size]
bias_hh = torch_lstm_cell.bias_hh  # [4 * hidden_size]

# The reference call omitted hx, so nn.LSTMCell started from zero (h0, c0).
zero_tensor = torch.zeros(batch_size, hidden_size)  # [batch_size, hidden_size]
hx = (zero_tensor, zero_tensor)

output_tensor = lstm_cell(input_tensor, hx, weight_hh, weight_ih, bias_ih, bias_hh)

# atol=1e-6, as in the later cells: 1e-7 is at float32 rounding level, so a
# different summation order could flip the check for near-zero entries.
for name, expected, actual in (
    ("h1", output_tensor_torch[0], output_tensor[0]),
    ("c1", output_tensor_torch[1], output_tensor[1]),
):
    print(
        f"lstm cell output {name} allclose: ",
        "✅" if torch.allclose(expected, actual, atol=1e-6) else "❌",
    )

lstm cell output h1 allclose:  ✅
lstm cell output c1 allclose:  ✅


# 多层 RNN 网络

## 多层 RNN 的 Pytorch API

In [8]:
# Hyper-parameters shared by the stacked and bidirectional LSTM sections.
batch_size, input_size, hidden_size = 8, 64, 128
seqlen = 32
bias = True
num_layers = 2

# Batch-first input sequence: [batch_size, seqlen, input_size]
input_tensor = torch.randn(batch_size, seqlen, input_size)

In [9]:
# Reference model: a stacked, unidirectional, batch-first LSTM.
two_layer_lstm_torch = nn.LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    bidirectional=False,
)

# nn.LSTM returns (output, (h_n, c_n)).
output_tensors_torch = two_layer_lstm_torch(input_tensor)

## 多层 LSTM 的手动实现

In [12]:
def two_layer_lstm(input_tensor, layer_params, hx=None):
    """Run a stacked unidirectional LSTM over a batch-first sequence.

    Despite the name, any number of layers is supported: one per entry of
    ``layer_params``.

    Args:
        input_tensor: batch-first input, ``(batch, seqlen, input_size)``.
        layer_params: one ``(W_ih, W_hh, bias_ih, bias_hh)`` tuple per layer, in
            ``nn.LSTM`` naming order (``weight_ih_l{k}``, ``weight_hh_l{k}``, ...).
        hx: optional ``(h0, c0)``, each ``(num_layers, batch, hidden_size)``.
            Defaults to zeros. The caller's tensors are never modified.

    Returns:
        ``(output, (h_n, c_n))``. ``output`` is ``(batch, seqlen, hidden_size)``
        (the last layer's hidden state at every step); ``h_n``/``c_n`` are
        ``(num_layers, batch, hidden_size)``.
    """
    # Time-major view so layer_input[t] is the whole batch at step t.
    layer_input = input_tensor.permute(1, 0, 2)
    seqlen, batch_size, _ = layer_input.shape
    num_layers = len(layer_params)
    # weight_ih_l0 is [4 * hidden_size, input_size]: four stacked gates.
    hidden_size = layer_params[0][0].size(0) // 4
    if hx is None:
        h0 = input_tensor.new_zeros(num_layers, batch_size, hidden_size)
        c0 = input_tensor.new_zeros(num_layers, batch_size, hidden_size)
    else:
        h0, c0 = hx

    # Rebinding per-step states (instead of writing into a preallocated tensor)
    # leaves the caller's hx untouched and keeps the graph usable for backward().
    h_n, c_n = [], []
    for layer in range(num_layers):
        W_ih, W_hh, bias_ih, bias_hh = layer_params[layer]
        h, c = h0[layer], c0[layer]
        outputs = []
        for t in range(seqlen):
            h, c = lstm_cell(layer_input[t], (h, c), W_hh, W_ih, bias_ih, bias_hh)
            outputs.append(h)
        h_n.append(h)
        c_n.append(c)
        # This layer's hidden states are the next layer's inputs.
        layer_input = torch.stack(outputs)
    return layer_input.permute(1, 0, 2), (torch.stack(h_n), torch.stack(c_n))


# Collect (W_ih, W_hh, bias_ih, bias_hh) per layer using nn.LSTM's parameter
# naming (weight_ih_l{k}, ...), so this works for any num_layers.
layer_params = [
    tuple(
        getattr(two_layer_lstm_torch, f"{param}_l{layer}")
        for param in ("weight_ih", "weight_hh", "bias_ih", "bias_hh")
    )
    for layer in range(two_layer_lstm_torch.num_layers)
]
output_tensors = two_layer_lstm(input_tensor, layer_params)

# Compare the output sequence and the final (h_n, c_n) against nn.LSTM.
for label, expected, actual in (
    ("output", output_tensors_torch[0], output_tensors[0]),
    ("hidden states ht", output_tensors_torch[1][0], output_tensors[1][0]),
    ("hidden states ct", output_tensors_torch[1][1], output_tensors[1][1]),
):
    print(
        f"{label} allclose: ",
        "✅" if torch.allclose(expected, actual, atol=1e-6) else "❌",
    )

otuput allclose:  ✅
hidden states ht allclose:  ✅
hidden states ct allclose:  ✅


# 双向多层 LSTM

## 双向 LSTM 的 Pytorch 接口

In [13]:
# Reference model: a stacked bidirectional LSTM with the same hyper-parameters.
two_layer_bidir_lstm_torch = nn.LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    bidirectional=True,
)
# output: [batch, seqlen, 2 * hidden_size]; h_n / c_n: [2 * num_layers, batch, hidden_size]
output_tensors_torch = two_layer_bidir_lstm_torch(input_tensor)

## 双向 LSTM 的手动实现

In [16]:
def two_layer_bidir_lstm(input_tensor, layer_params, hx=None, bidirection=True):
    """Run a stacked (optionally bidirectional) LSTM over a batch-first sequence.

    Args:
        input_tensor: batch-first input, ``(batch, seqlen, input_size)``.
        layer_params: per layer, a flat tuple with the forward
            ``(W_ih, W_hh, bias_ih, bias_hh)`` followed, when ``bidirection`` is
            true, by the reverse-direction four (``*_reverse`` in ``nn.LSTM``).
        hx: optional ``(h0, c0)``, each ``(directions * num_layers, batch,
            hidden_size)`` with index ``layer * directions + direction`` (the
            ``nn.LSTM`` ``h_n`` layout). Defaults to zeros; never modified.
        bidirection: if false, only the forward four parameters of each layer
            are used and the model is unidirectional.

    Returns:
        ``(output, (h_n, c_n))``. ``output`` is
        ``(batch, seqlen, directions * hidden_size)``, with the forward and
        reverse features concatenated on the last dim. ``h_n``/``c_n`` are
        ``(directions * num_layers, batch, hidden_size)``.
    """
    # Time-major view so layer_input[t] is the whole batch at step t.
    layer_input = input_tensor.permute(1, 0, 2)
    seqlen, batch_size, _ = layer_input.shape
    num_layers = len(layer_params)
    hidden_size = layer_params[0][0].size(0) // 4  # weight_ih_l0: [4 * hidden, in]
    directions = 2 if bidirection else 1

    if hx is None:
        h0 = input_tensor.new_zeros(directions * num_layers, batch_size, hidden_size)
        c0 = input_tensor.new_zeros(directions * num_layers, batch_size, hidden_size)
    else:
        h0, c0 = hx

    # Final states are appended in index order layer * directions + d.
    h_n, c_n = [], []
    for layer in range(num_layers):
        direction_outputs = []
        for d in range(directions):
            W_ih, W_hh, bias_ih, bias_hh = layer_params[layer][4 * d : 4 * d + 4]
            state_idx = layer * directions + d
            h, c = h0[state_idx], c0[state_idx]
            # d == 0 walks forward in time; d == 1 walks backward.
            steps = range(seqlen) if d == 0 else reversed(range(seqlen))
            outputs = [None] * seqlen
            for t in steps:
                h, c = lstm_cell(layer_input[t], (h, c), W_hh, W_ih, bias_ih, bias_hh)
                # Store by original time index so both directions line up per step.
                outputs[t] = h
            h_n.append(h)
            c_n.append(c)
            direction_outputs.append(torch.stack(outputs))
        # The next layer sees [forward | reverse] features at each step.
        layer_input = torch.cat(direction_outputs, dim=-1)
    return layer_input.permute(1, 0, 2), (torch.stack(h_n), torch.stack(c_n))


# Per layer: forward (W_ih, W_hh, bias_ih, bias_hh), then the "_reverse" set,
# gathered by nn.LSTM's parameter names.
bi_rnn_layer_params = [
    tuple(
        getattr(two_layer_bidir_lstm_torch, f"{param}_l{layer}{suffix}")
        for suffix in ("", "_reverse")
        for param in ("weight_ih", "weight_hh", "bias_ih", "bias_hh")
    )
    for layer in range(two_layer_bidir_lstm_torch.num_layers)
]

output_tensors = two_layer_bidir_lstm(
    input_tensor, bi_rnn_layer_params, bidirection=True
)

# Compare the output sequence and the final (h_n, c_n) against nn.LSTM.
for label, expected, actual in (
    ("output", output_tensors_torch[0], output_tensors[0]),
    ("hidden states ht", output_tensors_torch[1][0], output_tensors[1][0]),
    ("hidden states ct", output_tensors_torch[1][1], output_tensors[1][1]),
):
    print(
        f"{label} allclose: ",
        "✅" if torch.allclose(expected, actual, atol=1e-6) else "❌",
    )

output allclose:  ✅
hidden states ht allclose:  ✅
hidden states ct allclose:  ✅
