In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 单个 GRU Cell 的实现

## Pytorch 接口调用

输入的是一个形状为 `(batch, input_size)` 的 Tensor，输出的隐藏状态形状为 `(batch, hidden_size)`

In [2]:
# Problem sizes for the single-cell demo.
batch_size, input_size, hidden_size = 8, 64, 128
bias = True

# One random batch of input vectors: [batch_size, input_size].
input_tensor = torch.randn(batch_size, input_size)

# Reference implementation. No hidden state is passed to the cell here; the
# manual implementation below is compared against it with an all-zero state.
torch_gru_cell = torch.nn.GRUCell(
    input_size=input_size, hidden_size=hidden_size, bias=bias
)
output_tensor_torch = torch_gru_cell(input_tensor)

## GRU Cell 手动实现

Batch化的矩阵形式：

$$
\begin{align}
& gi = XW_{ih}^T + b_{ih}\\
& gh = H_{t-1}W_{hh}^T + b_{hh} \\
& ri, zi, ni = \text{split}(gi, 3) \\
& rh, zh, nh = \text{split}(gh, 3) \\
& r_t = \sigma(ri + rh) \\
& z_t = \sigma(zi + zh) \\
& n_t = \tanh(ni + r_t \odot nh) \\
& H_t = (1-z_t)\odot n_t + z_t \odot H_{t-1}
\end{align}
$$


In [3]:
def gru_cell(
    x: torch.Tensor,
    h0: torch.Tensor,
    W_hh: torch.Tensor,
    W_ih: torch.Tensor,
    bias_ih: "torch.Tensor | None",
    bias_hh: "torch.Tensor | None",
) -> torch.Tensor:
    """Run a single GRU step and return the new hidden state.

    Note the argument order: the hidden-to-hidden weight ``W_hh`` comes
    *before* the input-to-hidden weight ``W_ih``.

    Args:
        x: Input at the current step; its last dimension must match
            ``W_ih.size(1)``.
        h0: Previous hidden state; its last dimension must match
            ``W_hh.size(1)``.
        W_hh: Hidden-to-hidden weight. The projection is split into three
            equal chunks along the last dimension (reset, update, new).
        W_ih: Input-to-hidden weight, chunked the same way as ``W_hh``.
        bias_ih: Bias added to the input projection, or ``None`` to skip it
            (e.g. parameters taken from a cell built with ``bias=False``).
        bias_hh: Bias added to the hidden projection, or ``None`` to skip it.

    Returns:
        The new hidden state ``(1 - z) * n + z * h0``.
    """
    # Input projection for all three gates at once.
    gates_i = x @ W_ih.t()
    if bias_ih is not None:
        gates_i = gates_i + bias_ih
    # Hidden-state projection for all three gates at once.
    gates_h = h0 @ W_hh.t()
    if bias_hh is not None:
        gates_h = gates_h + bias_hh

    ri, zi, ni = torch.chunk(gates_i, 3, -1)
    rh, zh, nh = torch.chunk(gates_h, 3, -1)

    # Reset gate.
    rt = torch.sigmoid(ri + rh)
    # Update gate.
    zt = torch.sigmoid(zi + zh)
    # Candidate state: the reset gate scales only the hidden contribution.
    nt = torch.tanh(ni + rt * nh)
    # New state is a gated blend of the candidate and the previous state.
    ht = (1 - zt) * nt + zt * h0
    return ht


# Reuse the reference cell's parameters so both implementations compute the
# same function. The three gates are stacked along dim 0, hence 3 * hidden_size.
weight_ih = torch_gru_cell.weight_ih  # [3 * hidden_size, input_size]
weight_hh = torch_gru_cell.weight_hh  # [3 * hidden_size, hidden_size]
bias_ih = torch_gru_cell.bias_ih  # [3 * hidden_size]
bias_hh = torch_gru_cell.bias_hh  # [3 * hidden_size]

# All-zero initial hidden state for the manual cell.
h_init = torch.zeros(batch_size, hidden_size)  # [batch_size, hidden_size]

# gru_cell takes W_hh before W_ih.
output_tensor = gru_cell(input_tensor, h_init, weight_hh, weight_ih, bias_ih, bias_hh)
print(
    "rnn cell output allclose: ",
    "✅" if torch.allclose(output_tensor_torch, output_tensor, atol=1e-6) else "❌",
)

rnn cell output allclose:  ✅


# 多层 GRU 网络

## 多层 GRU 的 Pytorch API

In [4]:
# Problem sizes for the sequence demos.
batch_size, input_size, hidden_size = 8, 64, 128
seqlen, num_layers = 32, 2
bias = True

# Batch-first random input sequence: [batch_size, seqlen, input_size].
input_tensor = torch.randn(batch_size, seqlen, input_size)

In [5]:
# Reference stacked, unidirectional GRU operating on batch-first input.
two_layer_gru_torch = nn.GRU(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    bidirectional=False,
)

# Calling the module returns a pair; it is indexed below as [0] (per-step
# outputs) and [1] (final hidden states).
output_tensors_torch = two_layer_gru_torch(input_tensor)

## 多层 GRU 的手动实现

In [6]:
def two_layer_gru(input_tensor, layer_params, h0=None):
    """Run a stacked, unidirectional GRU over a batch-first sequence.

    Despite the name, the number of layers is ``len(layer_params)``.

    Args:
        input_tensor: Batch-first input; it is permuted to time-major
            internally, so it must be 3-D ``(batch, seqlen, features)``.
        layer_params: One ``(W_ih, W_hh, bias_ih, bias_hh)`` tuple per layer.
            The hidden size is derived as ``W_ih.size(0) // 3`` of layer 0.
        h0: Optional initial hidden states indexed by layer along dim 0.
            Defaults to zeros with the dtype and device of ``input_tensor``.
            It is only read, never modified.

    Returns:
        ``(output, h_n)``: the batch-first outputs of the last layer and a
        new tensor stacking each layer's final hidden state along dim 0.
    """
    # Work time-major so that indexing [t] yields one step for the whole batch.
    layer_input = input_tensor.permute(1, 0, 2)
    seqlen, batch_size, _ = layer_input.shape
    num_layers = len(layer_params)
    hidden_size = layer_params[0][0].size(0) // 3  # weight_ih_l0
    if h0 is None:
        h0 = input_tensor.new_zeros(num_layers, batch_size, hidden_size)

    final_states = []
    for layer, (W_ih, W_hh, bias_ih, bias_hh) in enumerate(layer_params):
        # Carry the state in a local variable instead of writing it back into
        # a shared buffer: this leaves the caller's h0 untouched and keeps the
        # computation free of in-place ops, so autograd can backprop through it.
        h = h0[layer]
        outputs = []
        for t in range(seqlen):
            h = gru_cell(layer_input[t], h, W_hh, W_ih, bias_ih, bias_hh)
            outputs.append(h)
        # The per-step states of this layer are the next layer's input.
        layer_input = torch.stack(outputs)
        final_states.append(h)
    return layer_input.permute(1, 0, 2), torch.stack(final_states)


# Per-layer parameters of the reference GRU, in the order two_layer_gru
# unpacks them: (weight_ih, weight_hh, bias_ih, bias_hh).
layer_params = [
    tuple(
        getattr(two_layer_gru_torch, f"{name}_l{layer}")
        for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh")
    )
    for layer in (0, 1)
]
output_tensors = two_layer_gru(input_tensor, layer_params)

# Index 0: per-step outputs of the last layer.
print(
    "otuput allclose: ",
    "✅" if torch.allclose(output_tensors_torch[0], output_tensors[0], atol=1e-6) else "❌",
)
# Index 1: final hidden state of every layer.
print(
    "hidden states allclose: ",
    "✅" if torch.allclose(output_tensors_torch[1], output_tensors[1], atol=1e-6) else "❌",
)

otuput allclose:  ✅
hidden states allclose:  ✅


# 双向多层 GRU

## 双向 GRU 的 Pytorch 接口

In [7]:
# Reference stacked, bidirectional GRU operating on batch-first input.
two_layer_bidir_rnn_torch = nn.GRU(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    bidirectional=True,
)
# Calling the module returns a pair; it is indexed below as [0] (per-step
# outputs) and [1] (final hidden states).
output_tensors_torch = two_layer_bidir_rnn_torch(input_tensor)

## 双向 GRU 的手动实现

In [8]:
def two_layer_bidir_rnn(input_tensor, layer_params, h0=None, bidirection=True):
    """Run a stacked GRU, optionally bidirectional, over a batch-first sequence.

    Despite the name, the number of layers is ``len(layer_params)``.

    Args:
        input_tensor: Batch-first input; it is permuted to time-major
            internally, so it must be 3-D ``(batch, seqlen, features)``.
        layer_params: One tuple per layer. The first four entries are the
            forward ``(W_ih, W_hh, bias_ih, bias_hh)``; when ``bidirection``
            is true, the next four are the same parameters for the reverse
            direction. The hidden size is ``W_ih.size(0) // 3`` of layer 0.
        h0: Optional initial hidden states. Row ``directions * layer`` is the
            forward state of ``layer`` and, when bidirectional, the row after
            it is the reverse state. Defaults to zeros with the dtype and
            device of ``input_tensor``. It is only read, never modified.
        bidirection: If true, also scan the sequence back-to-front and
            concatenate both directions along the feature dimension. If
            false, only the forward parameters are used.

    Returns:
        ``(output, h_n)``: the batch-first outputs of the last layer and a
        new tensor stacking the final hidden state of every layer/direction
        along dim 0, in the same row order as ``h0``.
    """
    # Work time-major so that indexing [t] yields one step for the whole batch.
    layer_input = input_tensor.permute(1, 0, 2)
    seqlen, batch_size, _ = layer_input.shape
    num_layers = len(layer_params)
    hidden_size = layer_params[0][0].size(0) // 3  # weight_ih_l0
    directions = 2 if bidirection else 1
    if h0 is None:
        h0 = input_tensor.new_zeros(directions * num_layers, batch_size, hidden_size)

    final_states = []
    for layer, params in enumerate(layer_params):
        per_direction = []
        for direction in range(directions):
            # Entries 0-3 are the forward parameters, 4-7 the reverse ones.
            W_ih, W_hh, bias_ih, bias_hh = params[4 * direction : 4 * direction + 4]
            # Carry the state locally rather than writing into a shared
            # buffer: the caller's h0 stays untouched and no in-place op
            # gets in the way of autograd.
            h = h0[directions * layer + direction]
            if direction == 0:
                steps = range(seqlen)
            else:
                steps = range(seqlen - 1, -1, -1)
            # Store each state at its own time index so the reverse scan ends
            # up aligned with the forward scan without a separate reversal.
            outputs = [None] * seqlen
            for t in steps:
                h = gru_cell(layer_input[t], h, W_hh, W_ih, bias_ih, bias_hh)
                outputs[t] = h
            per_direction.append(torch.stack(outputs))
            final_states.append(h)
        # Both directions side by side form the next layer's input.
        layer_input = torch.cat(per_direction, dim=-1)
    return layer_input.permute(1, 0, 2), torch.stack(final_states)


# Per-layer parameters of the reference bidirectional GRU: the four forward
# tensors first, then the four "_reverse" ones, matching the [:4] / [4:]
# split that two_layer_bidir_rnn expects.
bi_rnn_layer_params = [
    tuple(
        getattr(two_layer_bidir_rnn_torch, f"{name}_l{layer}{suffix}")
        for suffix in ("", "_reverse")
        for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh")
    )
    for layer in (0, 1)
]

output_tensors = two_layer_bidir_rnn(
    input_tensor, bi_rnn_layer_params, bidirection=True
)

# Index 0: per-step outputs of the last layer (both directions concatenated).
print(
    "otuput allclose: ",
    "✅" if torch.allclose(output_tensors_torch[0], output_tensors[0], atol=1e-6) else "❌",
)
# Index 1: final hidden state of every layer and direction.
print(
    "hidden states allclose: ",
    "✅" if torch.allclose(output_tensors_torch[1], output_tensors[1], atol=1e-6) else "❌",
)

otuput allclose:  ✅
hidden states allclose:  ✅
