# PyTorch GRU的原理及其手写复现

来自b站up主deep_thoughts 合集【PyTorch源码教程与前沿人工智能算法复现讲解】

P_31_PyTorch_GRU的原理及其手写复现：

https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=pageDriver&vd_source=18e91d849da09d846f771c89a366ed40

***资料***

PyTorch GRU 官方文档：

https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU

***论文***

Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling：

https://arxiv.org/pdf/1412.3555.pdf

## GRU 公式

$r_t = \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{(t-1)}+b_{hr})$

$z_t = \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{(t-1)}+b_{hz})$

$n_t = \tanh(W_{in}x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn}))$

$h_t = (1-z_t)*n_t + z_t*h_{(t-1)}$

其中 $\sigma$ 为 sigmoid 函数，$*$ 表示 Hadamard（逐元素）乘积；$r_t$、$z_t$、$n_t$ 分别是重置门、更新门和候选状态。

## 查看 LSTM 和 GRU 的参数量

In [2]:
import torch
import torch.nn as nn

# LSTM stacks 4 gate blocks and GRU stacks 3; each block holds
# W_i (5x3) + W_h (5x5) + b_i (5) + b_h (5) = 50 parameters here, so
# LSTM: 4 * 50 = 200 parameters, GRU: 3 * 50 = 150 parameters.
lstm_layer = nn.LSTM(3, 5)
gru_layer = nn.GRU(3, 5)
print(
    *(sum(p.numel() for p in m.parameters()) for m in (lstm_layer, gru_layer)),
    sep="\n",
)

200
150


## 逐行实现GRU网络

In [6]:


def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Run a single-layer, unidirectional GRU over a batch-first sequence.

    Hand-written equivalent of one layer of ``nn.GRU(..., batch_first=True)``:

        r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)          # reset gate
        z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)          # update gate
        n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))     # candidate
        h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Args:
        input: sequence of shape ``[bs, T, i_size]``.
        initial_states: initial hidden state ``h_0`` of shape ``[bs, h_size]``
            (no leading ``num_layers`` dimension).
        w_ih: ``[3*h_size, i_size]`` input weights, rows stacked as (r, z, n),
            i.e. the layout of ``nn.GRU.weight_ih_l0``.
        w_hh: ``[3*h_size, h_size]`` hidden weights, same (r, z, n) layout.
        b_ih: ``[3*h_size]`` input bias, or ``None`` for a bias-free layer.
        b_hh: ``[3*h_size]`` hidden bias, or ``None`` for a bias-free layer.

    Returns:
        ``(output, h_T)``: all hidden states ``[bs, T, h_size]`` and the
        final hidden state ``[bs, h_size]``.
    """
    prev_h = initial_states
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 3

    # The input projection does not depend on the recurrence, so compute it
    # for every time step at once: [bs, T, 3*h_size]. A broadcasting matmul
    # avoids materializing bs copies of the weight matrices.
    x_proj = input @ w_ih.T
    if b_ih is not None:
        x_proj = x_proj + b_ih

    # Allocate with the computation's dtype/device instead of the global
    # default (float32 CPU), so float64 or GPU inputs are not silently converted.
    output = torch.zeros(bs, T, h_size, dtype=x_proj.dtype, device=x_proj.device)

    for t in range(T):
        h_proj = prev_h @ w_hh.T  # [bs, 3*h_size]
        if b_hh is not None:
            h_proj = h_proj + b_hh

        x_r, x_z, x_n = x_proj[:, t].chunk(3, dim=-1)
        h_r, h_z, h_n = h_proj.chunk(3, dim=-1)

        r_t = torch.sigmoid(x_r + h_r)  # reset gate
        z_t = torch.sigmoid(x_z + h_z)  # update gate
        # r_t multiplies (W_hn h_{t-1} + b_hn) including its bias, as in PyTorch.
        n_t = torch.tanh(x_n + r_t * h_n)  # candidate state
        prev_h = (1 - z_t) * n_t + z_t * prev_h  # interpolate old and new state
        output[:, t] = prev_h

    return output, prev_h

# Verify gru_forward against PyTorch's built-in nn.GRU on random data.

bs, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(bs, T, i_size)  # input sequence, [bs, T, i_size]
h0 = torch.randn(bs, h_size)        # initial hidden state, [bs, h_size]

# Reference layer; nn.GRU expects h0 as [num_layers, bs, h_size].
gru_layer = nn.GRU(i_size, h_size, batch_first=True)
output, h_final = gru_layer(input, h0.unsqueeze(0))
print(output)
for k, v in gru_layer.named_parameters():
    print(f"{k} {v.shape}")

# Hand-written version, fed the very same parameters (h0 without layer dim).
output_custom, h_final_custom = gru_forward(
    input,
    h0,
    w_ih=gru_layer.weight_ih_l0,
    w_hh=gru_layer.weight_hh_l0,
    b_ih=gru_layer.bias_ih_l0,
    b_hh=gru_layer.bias_hh_l0,
)

# h_final is [1, bs, h_size]; allclose broadcasts it against [bs, h_size].
print(torch.allclose(output, output_custom))
print(torch.allclose(h_final, h_final_custom))

tensor([[[-0.1634, -0.4756, -0.7797,  0.3584, -0.2965],
         [ 0.0975, -0.2879, -0.7104,  0.1393, -0.3203],
         [ 0.1759, -0.2848, -0.3293,  0.0650, -0.3861]],

        [[-0.3849,  0.3312,  0.6562,  1.2750, -0.6056],
         [-0.4849, -0.0553,  0.5266,  0.9541, -0.7548],
         [-0.0237, -0.0220,  0.4013,  0.7372, -0.7686]]],
       grad_fn=<TransposeBackward1>)
weight_ih_l0 torch.Size([15, 4])
weight_hh_l0 torch.Size([15, 5])
bias_ih_l0 torch.Size([15])
bias_hh_l0 torch.Size([15])
True
True
