In [2]:
import torch
import torch.nn as nn

In [3]:
# Toy problem dimensions: batch size, sequence length, per-step feature width.
bs = 3
T = 4
input_size = 5
# Width of the GRU hidden state used throughout the notebook.
hidden_size = 3

# Random batch-first input sequence of shape (bs, T, input_size).
inputs = torch.randn(bs, T, input_size)

In [7]:
# Reference run: PyTorch's built-in single-layer GRU with batch-first inputs.
h_0 = torch.randn(bs, hidden_size)  # initial hidden state, (bs, hidden_size)
gru_layer = nn.GRU(input_size=input_size, hidden_size=hidden_size, batch_first=True)
# Add a leading axis so the initial state is passed as (1, bs, hidden_size).
output, h_final = gru_layer(inputs, h_0[None])

separator = "----" * 20
print("output:", output.shape)
print(output)
print(separator)
print("h_final:", h_final.shape)
print(h_final)
print(separator)
# List every learnable tensor of the layer together with its shape.
for param_name, param in gru_layer.named_parameters():
    print(param_name, param.shape)
# GRU parameter count:
#   3*hidden_size*input_size + 3*hidden_size*hidden_size + 3*hidden_size + 3*hidden_size
# LSTM parameter count:
#   4*hidden_size*input_size + 4*hidden_size*hidden_size + 4*hidden_size + 4*hidden_size
# An alternative LSTM formulation concatenates h_prev and x:
#   f_t = sigmoid(W_f·[h_prev, x] + b_f)
# That form has 4*hidden_size*input_size + 4*hidden_size*hidden_size + 4*hidden_size
# parameters, because b_ih and b_hh are merged into a single bias.

output: torch.Size([3, 4, 3])
tensor([[[-0.8074,  0.0164,  0.2317],
         [-0.3491, -0.2538,  0.4215],
         [ 0.3843, -0.5728,  0.4064],
         [-0.0106, -0.5821,  0.2731]],

        [[ 0.4062, -0.3633, -0.6866],
         [-0.0600, -0.2247, -0.1095],
         [ 0.3834, -0.4157, -0.0751],
         [-0.4303, -0.1369, -0.1540]],

        [[ 0.0753,  0.5420,  0.2229],
         [ 0.1023, -0.5723,  0.7018],
         [-0.0315, -0.6288,  0.5973],
         [-0.1824, -0.5702,  0.4410]]], grad_fn=<TransposeBackward1>)
--------------------------------------------------------------------------------
h_final: torch.Size([1, 3, 3])
tensor([[[-0.0106, -0.5821,  0.2731],
         [-0.4303, -0.1369, -0.1540],
         [-0.1824, -0.5702,  0.4410]]], grad_fn=<StackBackward0>)
--------------------------------------------------------------------------------
weight_ih_l0 torch.Size([9, 5])
weight_hh_l0 torch.Size([9, 3])
bias_ih_l0 torch.Size([9])
bias_hh_l0 torch.Size([9])


In [10]:
def gru_forward(inputs, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Hand-written forward pass of a single-layer, batch-first GRU.

    The stacked weights/biases are sliced in the order [reset | update | new],
    and each step computes:

        r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
        z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
        n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
        h_t = (1 - z_t) * n_t + z_t * h_{t-1}

    Args:
        inputs: input sequence, shape (bs, T, input_size).
        initial_states: initial hidden state, shape (bs, hidden_size).
        w_ih: input-to-hidden weights, shape (3*hidden_size, input_size).
        w_hh: hidden-to-hidden weights, shape (3*hidden_size, hidden_size).
        b_ih: input-to-hidden bias, shape (3*hidden_size,).
        b_hh: hidden-to-hidden bias, shape (3*hidden_size,).

    Returns:
        A tuple ``(output, h_final)`` where ``output`` has shape
        (bs, T, hidden_size) and holds the hidden state at every step, and
        ``h_final`` has shape (1, bs, hidden_size) and holds the last one.
    """
    h_prev = initial_states  # bs, hidden_size
    bs, T, _ = inputs.shape
    hidden_size = w_ih.shape[0] // 3

    # Allocate the result with the inputs' dtype and device. A plain
    # torch.zeros(...) would silently downcast float64 runs to float32 and
    # leave the output on the CPU when the inputs live on another device.
    output = inputs.new_zeros(bs, T, hidden_size)

    # The input projection does not depend on the recurrence, so compute it
    # for every time step at once. Broadcasting replaces the per-batch tiling
    # of weights and biases.
    x_proj = inputs @ w_ih.T + b_ih  # bs, T, 3*hidden_size

    for t in range(T):
        h_proj = h_prev @ w_hh.T + b_hh  # bs, 3*hidden_size

        # Split the stacked projections into reset / update / new-gate parts.
        x_r, x_z, x_n = x_proj[:, t].split(hidden_size, dim=-1)
        h_r, h_z, h_n = h_proj.split(hidden_size, dim=-1)

        r_t = torch.sigmoid(x_r + h_r)
        z_t = torch.sigmoid(x_z + h_z)
        # The reset gate scales only the hidden-side contribution (bias included).
        n_t = torch.tanh(x_n + r_t * h_n)

        h_prev = (1 - z_t) * n_t + z_t * h_prev
        output[:, t, :] = h_prev

    return output, h_prev.unsqueeze(0)


# Run the hand-written GRU with the parameters of the reference nn.GRU layer.
forward_output, forward_h_final = gru_forward(
    inputs,
    h_0,
    gru_layer.weight_ih_l0,
    gru_layer.weight_hh_l0,
    gru_layer.bias_ih_l0,
    gru_layer.bias_hh_l0,
)

# Check agreement with the nn.GRU results (`output`, `h_final`) programmatically
# instead of relying on a visual comparison of the two printouts.
assert torch.allclose(forward_output, output, atol=1e-5), \
    "gru_forward output differs from nn.GRU output"
assert torch.allclose(forward_h_final, h_final, atol=1e-5), \
    "gru_forward final state differs from nn.GRU final state"

print("forward_output:", forward_output.shape)
print(forward_output)
print("----" * 20)
print("forward_h_final:", forward_h_final.shape)
print(forward_h_final)

forward_output: torch.Size([3, 4, 3])
tensor([[[-0.8074,  0.0164,  0.2317],
         [-0.3491, -0.2538,  0.4215],
         [ 0.3843, -0.5728,  0.4064],
         [-0.0106, -0.5821,  0.2731]],

        [[ 0.4062, -0.3633, -0.6866],
         [-0.0600, -0.2247, -0.1095],
         [ 0.3834, -0.4157, -0.0751],
         [-0.4303, -0.1369, -0.1540]],

        [[ 0.0753,  0.5420,  0.2229],
         [ 0.1023, -0.5723,  0.7018],
         [-0.0315, -0.6288,  0.5973],
         [-0.1824, -0.5702,  0.4410]]], grad_fn=<CopySlices>)
--------------------------------------------------------------------------------
forward_h_final: torch.Size([1, 3, 3])
tensor([[[-0.0106, -0.5821,  0.2731],
         [-0.4303, -0.1369, -0.1540],
         [-0.1824, -0.5702,  0.4410]]], grad_fn=<UnsqueezeBackward0>)


In [11]:
# Build small LSTM and GRU layers (input_size=3, hidden_size=5) to inspect
# their parameter counts.
test_lstm_layer = nn.LSTM(input_size=3, hidden_size=5)
test_gru_layer = nn.GRU(input_size=3, hidden_size=5)

In [12]:
# Actual LSTM parameter count next to the closed form 4*(h*h + h*i + h + h).
sum(param.numel() for param in test_lstm_layer.parameters()), 4 * (5 * 5 + 5 * 3 + 5 + 5)

(200, 200)

In [13]:
# Actual GRU parameter count next to the closed form 3*(h*h + h*i + h + h).
sum(param.numel() for param in test_gru_layer.parameters()), 3 * (5 * 5 + 5 * 3 + 5 + 5)

(150, 150)