# GRU


### 比较LSTM和GRU的参数量

In [5]:
import torch
import torch.nn as nn

# Same sizes for both layers: input_size=3, hidden_size=5.
lstm_layer = nn.LSTM(input_size=3, hidden_size=5)
gru_layer = nn.GRU(input_size=3, hidden_size=5)

# Total learnable scalars per layer. An LSTM stacks 4 gate blocks in its
# weights/biases while a GRU stacks 3, hence the 4:3 ratio in the printout.
num_lstm_p, num_gru_p = (
    sum(param.numel() for param in layer.parameters())
    for layer in (lstm_layer, gru_layer)
)

print(num_lstm_p, num_gru_p)

200 150


### 官方API

In [11]:
import torch
import torch.nn as nn

# input_size: feature size of each time step; seq_len: number of time steps;
# hidden_size: size of the GRU hidden state (a GRU has no separate cell state).
batch_size, seq_len, input_size, hidden_size = 2, 3, 4, 5
# NOTE: `input` shadows the builtin, but later cells reuse this name.
input = torch.randn(batch_size, seq_len, input_size)  # batch_first layout
h0 = torch.randn(batch_size, hidden_size)

gru_layer = nn.GRU(input_size, hidden_size, batch_first=True)
# nn.GRU expects h_0 as (num_layers * num_directions, batch, hidden) even with
# batch_first=True, hence the leading axis added by unsqueeze(0).
output, h_final = gru_layer(input, h0.unsqueeze(0))
print(output)

# named_parameters() yields (name, tensor) pairs, in that order.
for name, param in gru_layer.named_parameters():
    print(name, param.shape)



tensor([[[ 0.3980, -1.8681,  0.4913, -0.5419,  0.0948],
         [ 0.6235, -1.3081,  0.6259, -0.3296, -0.4504],
         [ 0.2099, -0.9961,  0.7369, -0.3072, -0.1944]],

        [[ 0.8470, -0.4014, -0.9545,  0.8089, -0.5204],
         [ 0.2678, -0.0425, -0.7397,  0.7672, -0.1327],
         [ 0.0273,  0.0763, -0.7891,  0.8487,  0.0870]]],
       grad_fn=<TransposeBackward1>)
weight_ih_l0 torch.Size([15, 4])
weight_hh_l0 torch.Size([15, 5])
bias_ih_l0 torch.Size([15])
bias_hh_l0 torch.Size([15])


### GRU的代码实现

公式中的 \* 表示逐元素相乘（Hadamard 积），Wx 表示矩阵乘法。

In [None]:

def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Run a single-layer, unidirectional GRU over a batch-first sequence.

    Implements the same update equations as ``torch.nn.GRU``:

        r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
        z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
        n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
        h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Args:
        input: tensor of shape (batch_size, seq_len, input_size).
        initial_states: initial hidden state, shape (batch_size, hidden_size).
        w_ih: input-to-hidden weights, shape (3 * hidden_size, input_size),
            with the gate blocks stacked in (r, z, n) order.
        w_hh: hidden-to-hidden weights, shape (3 * hidden_size, hidden_size),
            same gate order.
        b_ih, b_hh: biases of shape (3 * hidden_size,), or None for a GRU
            created with ``bias=False``.

    Returns:
        (output, h_final): ``output`` has shape (batch_size, seq_len,
        hidden_size) and holds h_t for every step; ``h_final`` is the last
        hidden state, shape (batch_size, hidden_size).
    """
    batch_size, seq_len, _ = input.shape
    hidden_size = w_ih.shape[0] // 3  # weights stack 3 gate blocks (r, z, n)

    # The input projection does not depend on the recurrence, so compute it
    # for all time steps at once: (batch_size, seq_len, 3 * hidden_size).
    gates_x = input @ w_ih.T
    if b_ih is not None:
        gates_x = gates_x + b_ih

    # Allocate with input's dtype/device so float64 or CUDA inputs work.
    output = input.new_zeros(batch_size, seq_len, hidden_size)
    prev_h = initial_states

    for t in range(seq_len):
        gates_h = prev_h @ w_hh.T  # (batch_size, 3 * hidden_size)
        if b_hh is not None:
            gates_h = gates_h + b_hh

        x_r, x_z, x_n = gates_x[:, t].chunk(3, dim=-1)
        h_r, h_z, h_n = gates_h.chunk(3, dim=-1)

        r_t = torch.sigmoid(x_r + h_r)  # reset gate
        z_t = torch.sigmoid(x_z + h_z)  # update gate
        # Candidate state: the reset gate scales the whole hidden-side term,
        # including b_hn, matching torch.nn.GRU's definition.
        n_t = torch.tanh(x_n + r_t * h_n)

        prev_h = (1 - z_t) * n_t + z_t * prev_h
        output[:, t, :] = prev_h

    return output, prev_h


# Feed the hand-written GRU the exact parameters of the nn.GRU layer defined
# earlier, so its results can be compared against the official output.
output_custom, h_final_custom = gru_forward(
    input,
    h0,
    gru_layer.weight_ih_l0,
    gru_layer.weight_hh_l0,
    gru_layer.bias_ih_l0,
    gru_layer.bias_hh_l0,
)

output_custom


tensor([[[ 0.3980, -1.8681,  0.4913, -0.5419,  0.0948],
         [ 0.6235, -1.3081,  0.6259, -0.3296, -0.4504],
         [ 0.2099, -0.9961,  0.7369, -0.3072, -0.1944]],

        [[ 0.8470, -0.4014, -0.9545,  0.8089, -0.5204],
         [ 0.2678, -0.0425, -0.7397,  0.7672, -0.1327],
         [ 0.0273,  0.0763, -0.7891,  0.8487,  0.0870]]], grad_fn=<CopySlices>)

In [13]:
# Check both the per-step outputs and the final hidden state against nn.GRU.
# h_final is (1, batch_size, hidden_size); h_final_custom has no leading
# num_layers axis, so index it away before comparing.
print(torch.allclose(output, output_custom), torch.allclose(h_final[0], h_final_custom))

True
