In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn 

In [11]:
def gru_forward(input, initial_states, w_ih, w_hh, b_ih=None, b_hh=None):
    """Run a single-layer, unidirectional GRU over a batch-first sequence.

    Reproduces the recurrence of ``torch.nn.GRU`` (layer 0, ``batch_first=True``):

        r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)        # reset gate
        z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)        # update gate
        n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))   # candidate state
        h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Args:
        input: tensor of shape (bs, T, input_size).
        initial_states: initial hidden state, shape (bs, hidden_size).
        w_ih: input-to-hidden weights, shape (3*hidden_size, input_size), rows
            stacked as [reset | update | new] (same layout as ``weight_ih_l0``).
        w_hh: hidden-to-hidden weights, shape (3*hidden_size, hidden_size).
        b_ih: input-to-hidden bias, shape (3*hidden_size,), or None for no bias.
        b_hh: hidden-to-hidden bias, shape (3*hidden_size,), or None for no bias.

    Returns:
        (output, h_T), where:
        - output has shape (bs, T, hidden_size) and holds the hidden state at
          every step;
        - h_T has shape (bs, hidden_size) and is the final hidden state. When
          T == 0 it is ``initial_states``.
        Both have the same dtype/device as the computation on ``input``.
    """
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 3
    prev_h = initial_states

    # The input projection does not depend on the recurrence, so compute it for
    # all time steps at once: (bs, T, 3*h_size). Broadcasting matmul replaces
    # the per-batch tiled copies of the weight matrix.
    gates_x = input @ w_ih.T
    if b_ih is not None:
        gates_x = gates_x + b_ih

    outputs = []
    for t in range(T):
        gates_h = prev_h @ w_hh.T  # (bs, 3*h_size)
        if b_hh is not None:
            gates_h = gates_h + b_hh

        # Split the stacked projections into [reset | update | new] slices.
        x_r, x_z, x_n = gates_x[:, t].chunk(3, dim=-1)
        h_r, h_z, h_n = gates_h.chunk(3, dim=-1)

        r_t = torch.sigmoid(x_r + h_r)        # reset gate
        z_t = torch.sigmoid(x_z + h_z)        # update gate
        n_t = torch.tanh(x_n + r_t * h_n)     # candidate state; r_t gates the hidden term only

        prev_h = (1 - z_t) * n_t + z_t * prev_h
        outputs.append(prev_h)

    # Build the output from the computed states so it keeps the input's
    # dtype/device. A preallocated float32 CPU buffer would silently downcast
    # float64 input and fail for CUDA input.
    if outputs:
        output = torch.stack(outputs, dim=1)
    else:
        output = input.new_zeros(bs, 0, h_size)
    return output, prev_h

# Test: compare the hand-written gru_forward against torch.nn.GRU on random data.
torch.manual_seed(0)  # fixed seed so the random inputs/weights (and printed tensors) are reproducible
bs, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(bs, T, i_size)
h0 = torch.randn(bs, h_size)

gru_layer = nn.GRU(i_size, h_size, batch_first=True)
# nn.GRU takes h0 as (num_layers * num_directions, bs, h_size), so add a leading axis.
output, h_final = gru_layer(input, h0.unsqueeze(0))
print(output)

for k, v in gru_layer.named_parameters():
    print(k, v.shape)

output_custom, h_final_custom = gru_forward(
    input, h0,
    gru_layer.weight_ih_l0, gru_layer.weight_hh_l0,
    gru_layer.bias_ih_l0, gru_layer.bias_hh_l0,
)
print(output_custom)
print(h_final_custom)

# Fail loudly on a mismatch instead of relying on eyeballing the printed tensors.
assert torch.allclose(output, output_custom, atol=1e-6), "per-step outputs differ from nn.GRU"
assert torch.allclose(h_final[0], h_final_custom, atol=1e-6), "final hidden state differs from nn.GRU"

tensor([[[-0.2419,  0.4648, -0.4883, -1.1455,  0.2463],
         [-0.5579,  0.3708, -0.4725, -1.0375,  0.3324],
         [-0.8012,  0.4762, -0.4361, -0.8410,  0.4185]],

        [[ 0.9285, -0.2837, -0.0574,  0.5011, -0.5571],
         [ 0.3154, -0.0776,  0.1312,  0.1449, -0.3146],
         [-0.2662,  0.1968,  0.0817,  0.0113, -0.1194]]],
       grad_fn=<TransposeBackward1>)
weight_ih_l0 torch.Size([15, 4])
weight_hh_l0 torch.Size([15, 5])
bias_ih_l0 torch.Size([15])
bias_hh_l0 torch.Size([15])
tensor([[[-0.2419,  0.4648, -0.4883, -1.1455,  0.2463],
         [-0.5579,  0.3708, -0.4725, -1.0375,  0.3324],
         [-0.8012,  0.4762, -0.4361, -0.8410,  0.4185]],

        [[ 0.9285, -0.2837, -0.0574,  0.5011, -0.5571],
         [ 0.3154, -0.0776,  0.1312,  0.1449, -0.3146],
         [-0.2662,  0.1968,  0.0817,  0.0113, -0.1194]]], grad_fn=<CopySlices>)
tensor([[-0.8012,  0.4762, -0.4361, -0.8410,  0.4185],
        [-0.2662,  0.1968,  0.0817,  0.0113, -0.1194]], grad_fn=<AddBackward0>)
