# 1. GRU概念

![img](./pic/GRU.png)  
  
**本质是对```H_t```的计算方法进行了改良:**
  
- **GRU**的核心思想是：在上一个时刻的隐状态```H_t-1```基础上，通过门控函数来决定隐状态```H_t```的更新程度。门控函数的输出是逐元素的0到1之间的数，用作加权系数而不是硬阈值：更新门越接近1，越多地保留旧状态```H_t-1```；越接近0，越多地采用新的候选隐状态。门控函数的输入是上一个时刻的隐状态```H_t-1```和当前时刻的输入```x_t```。
- R_t是重置门,有助于捕捉序列中短期依赖关系
- Z_t是更新门，有助于捕捉序列中长期依赖关系



In [1]:
import torch
from torch import nn

def get_params(emb_size, num_hiddens):
    """Create the trainable tensors of a single-layer GRU plus its output layer.

    Both the input width and the output width equal ``emb_size``.

    Args:
        emb_size: Width of each input vector and of each output vector.
        num_hiddens: Width of the hidden state.

    Returns:
        A list of 11 tensors with ``requires_grad=True``, ordered as
        ``[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]``
        (update gate, reset gate, candidate state, output layer).
    """
    scale = 0.01

    def gaussian(*shape):
        # Zero-mean Gaussian with standard deviation 0.01 (not variance 0.01).
        return torch.randn(*shape) * scale

    def gate_params():
        # Input-to-hidden weight, hidden-to-hidden weight, zero bias.
        w_input = gaussian(emb_size, num_hiddens)
        w_hidden = gaussian(num_hiddens, num_hiddens)
        return w_input, w_hidden, torch.zeros(num_hiddens)

    # Update gate, reset gate and candidate state, in that order.
    params = [*gate_params(), *gate_params(), *gate_params()]

    # Output layer projecting the hidden state back to emb_size.
    params.append(gaussian(num_hiddens, emb_size))
    params.append(torch.zeros(emb_size))

    # Mark every tensor as trainable.
    for tensor in params:
        tensor.requires_grad_(True)
    return params

In [2]:
def init_gru_state(batch_size, num_hiddens):
    """Return an all-zero initial hidden state of shape (batch_size, num_hiddens)."""
    state_shape = (batch_size, num_hiddens)
    return torch.zeros(state_shape)

<div style="border-left:2px solid black;padding:10px;margin-left:20px;">

$$ R_t = \sigma(x_t W_{xr} + H_{t-1} W_{hr} + b_r) $$
$$ Z_t = \sigma(x_t W_{xz} + H_{t-1} W_{hz} + b_z) $$
$$ \tilde{H}_t = \tanh(x_t W_{xh} + (R_t \odot H_{t-1}) W_{hh} + b_h) $$
$$ H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \tilde{H}_t $$
</div>
  
$$ o_t = H_t W_{hq} + b_q $$

In [6]:
def gru(inputs, state, params):
    """Run a GRU over a sequence, one step per slice along the first axis.

    Per step, with X the current input and H the previous hidden state:
        Z       = sigmoid(X W_xz + H W_hz + b_z)            (update gate)
        R       = sigmoid(X W_xr + H W_hr + b_r)            (reset gate)
        H_tilde = tanh(X W_xh + (R * H) W_hh + b_h)         (candidate state)
        H       = Z * H + (1 - Z) * H_tilde
        Y       = H W_hq + b_q

    Args:
        inputs: Sequence tensor iterated along dim 0; intended shape is
            (T, batch, emb).
        state: Initial hidden state, intended shape (batch, num_hiddens).
        params: The 11 tensors
            [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q].

    Returns:
        A pair ``(outputs, H)``: the per-step outputs concatenated along
        dim 0 (shape (T * batch, output_width)) and the final hidden state.
    """
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H = state
    outputs = []
    # Each X is one time step of shape (batch, emb).
    for X in inputs:
        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)
        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)
        # The reset gate masks H *before* the hidden-to-hidden projection:
        # (R * H) @ W_hh, not R * (H @ W_hh). The two are not equivalent.
        H_tilde = torch.tanh(
            torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h
        )
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(torch.matmul(H, W_hq) + b_q)

    return torch.cat(outputs, dim=0), H

In [7]:
class RNNModel:
    """Wrap a hand-written recurrent step function as a token-level model.

    The model embeds integer token ids, runs them through ``forward_fn``
    and maps the result to vocabulary-sized scores with a linear layer.
    """

    def __init__(self, vocab_size, emb_size, num_hiddens, get_params,
                 init_state, forward_fn):
        """Store the configuration and build the parameters and layers.

        Args:
            vocab_size: Number of distinct tokens.
            emb_size: Width of the token embeddings.
            num_hiddens: Width of the recurrent hidden state.
            get_params: Callable ``(emb_size, num_hiddens) -> params``.
            init_state: Callable ``(batch_size, num_hiddens) -> state``.
            forward_fn: Callable ``(inputs, state, params) -> (outputs, state)``.
        """
        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens
        self.params = get_params(emb_size, num_hiddens)
        self.init_state = init_state
        self.forward_fn = forward_fn
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Maps forward_fn outputs (width emb_size) to vocabulary scores.
        self.Linear = nn.Linear(emb_size, vocab_size)

    def __call__(self, X, state):
        """Embed the transposed token ids, run the recurrence, project to vocab.

        Returns:
            ``(scores, new_state)`` where ``scores`` is the linear layer applied
            to the first result of ``forward_fn`` and ``new_state`` is its second.
        """
        embedded = self.embedding(X.T).float()
        features, new_state = self.forward_fn(embedded, state, self.params)
        scores = self.Linear(features)
        return scores, new_state

    def begin_state(self, batch_size):
        """Return the initial state produced by ``init_state`` for this batch size."""
        return self.init_state(batch_size, self.num_hiddens)

其余过程与RNN一样，只是```H_t```计算过程不同！

In [8]:
# Smoke test of the hand-written GRU model.
num_hiddens = 512
vocab_size = 28
emb_size = 20
X = torch.arange(10).reshape(2, 5)  # token ids with bs=2, T=5
net = RNNModel(vocab_size, emb_size, num_hiddens, get_params, init_gru_state, gru)

state = net.begin_state(X.shape[0])
# net receives (bs, T) token ids and returns (T*bs, vocab) scores
# together with the final (bs, num_hiddens) hidden state.
Y, new_state = net(X, state)
print(Y.shape, len(new_state), new_state.shape)

"""
torch.Size([10, 28]) 2 torch.Size([2, 512])
"""

torch.Size([10, 28]) 2 torch.Size([2, 512])


'\ntorch.Size([10, 28]) 2 torch.Size([2, 512])\n'

## 1.2 调用api

In [9]:
# Same shapes through the built-in layer: nn.GRU expects time-major input
# (T, bs, emb) and an initial state of (num_layers, bs, num_hiddens).
emb = nn.Embedding(vocab_size, emb_size)
net2 = nn.GRU(emb_size, num_hiddens)
state2 = torch.zeros(1, 2, num_hiddens)
# Embedded input (T, bs, emb) --> per-step hidden states (T, bs, num_hiddens).
Y, new_state = net2(emb(X.T), state2)
print(Y.shape, len(new_state), new_state.shape)

"""
torch.Size([5, 2, 512]) 1 torch.Size([1, 2, 512])
"""

torch.Size([5, 2, 512]) 1 torch.Size([1, 2, 512])


'\ntorch.Size([5, 2, 512]) 1 torch.Size([1, 2, 512])\n'