In [1]:
import torch
import torch.nn as nn

In [12]:
# Problem sizes: batch size and sequence length.
bs = 2
T = 3
# Feature sizes of one input vector and of the hidden state.
input_size = 2
hidden_size = 3

inputs = torch.randn(bs, T, input_size)  # (bs, T, input_size)
h_prev = torch.zeros(bs, hidden_size)  # all-zero initial hidden state, (bs, hidden_size)

rnn = nn.RNN(input_size, hidden_size, batch_first=True)
# The layer is given the initial state with one extra leading axis: (1, bs, hidden_size).
rnn_output, state_final = rnn(inputs, h_prev[None])
print("PyTorch API output:")
print(rnn_output)
print(state_final)

PyTorch API output:
tensor([[[-0.8279,  0.6314, -0.5839],
         [-0.6335,  0.5656, -0.3044],
         [-0.3409,  0.0547,  0.0663]],

        [[-0.5628,  0.2018, -0.4109],
         [-0.7456,  0.7154, -0.4198],
         [-0.6271,  0.4133, -0.2339]]], grad_fn=<TransposeBackward1>)
tensor([[[-0.3409,  0.0547,  0.0663],
         [-0.6271,  0.4133, -0.2339]]], grad_fn=<StackBackward0>)


In [10]:
# Hand-written single-layer, unidirectional RNN forward pass.
def rnn_forward(inputs, weight_ih, weight_hh, bias_ih, bias_hh, h_prev):
    """Run a tanh RNN over a batch-first sequence, one time step at a time.

    Each step computes
        h_t = tanh(x_t @ weight_ih.T + bias_ih + h_{t-1} @ weight_hh.T + bias_hh)

    Args:
        inputs: input sequence, shape (bs, T, input_size).
        weight_ih: input-to-hidden weight, shape (h_dim, input_size).
        weight_hh: hidden-to-hidden weight, shape (h_dim, h_dim).
        bias_ih: input-to-hidden bias, shape (h_dim,).
        bias_hh: hidden-to-hidden bias, shape (h_dim,).
        h_prev: initial hidden state, shape (bs, h_dim).

    Returns:
        A tuple ``(h_out, h_final)``: ``h_out`` holds the hidden state of every
        time step, shape (bs, T, h_dim); ``h_final`` is the last hidden state
        with a leading axis added, shape (1, bs, h_dim).
    """
    bs, T, _ = inputs.shape
    h_dim = weight_ih.shape[0]  # weight_ih is (h_dim, input_size)

    # The weights do not change between time steps, so transpose them once
    # instead of rebuilding per-batch copies inside the loop.
    weight_ih_t = weight_ih.t()  # (input_size, h_dim)
    weight_hh_t = weight_hh.t()  # (h_dim, h_dim)

    step_states = []
    for t in range(T):
        x_t = inputs[:, t, :]  # features of the current time step, (bs, input_size)
        h_prev = torch.tanh(x_t @ weight_ih_t + bias_ih + h_prev @ weight_hh_t + bias_hh)  # (bs, h_dim)
        step_states.append(h_prev)

    if step_states:
        # Stacking the computed states keeps their dtype and device, unlike
        # writing them into a buffer created with the global defaults.
        h_out = torch.stack(step_states, dim=1)  # (bs, T, h_dim)
    else:
        # Empty sequence: nothing to stack, return an empty output of the right shape.
        h_out = inputs.new_zeros(bs, 0, h_dim)

    return h_out, h_prev.unsqueeze(0)  # (bs, T, h_dim), (1, bs, h_dim)

In [13]:
# To inspect the layer's parameter names and values, run:
#   for k, v in rnn.named_parameters():
#       print(k, v)
# Feed the hand-written forward pass the weights of the nn.RNN layer so that
# the two results are directly comparable.
h_out, h_final = rnn_forward(inputs, rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0,
                            rnn.bias_hh_l0, h_prev)
print("rnn_forward output:")
print(h_out, h_final)

# Check the match against the PyTorch API numerically instead of comparing
# the two printouts by eye.
assert torch.allclose(h_out, rnn_output, atol=1e-5)
assert torch.allclose(h_final, state_final, atol=1e-5)

rnn_forward output:
tensor([[[-0.8279,  0.6314, -0.5839],
         [-0.6335,  0.5656, -0.3044],
         [-0.3409,  0.0547,  0.0663]],

        [[-0.5628,  0.2018, -0.4109],
         [-0.7456,  0.7154, -0.4198],
         [-0.6271,  0.4133, -0.2339]]], grad_fn=<CopySlices>) tensor([[[-0.3409,  0.0547,  0.0663],
         [-0.6271,  0.4133, -0.2339]]], grad_fn=<UnsqueezeBackward0>)


In [26]:
def bidirectional_rnn_forward(inputs, weight_ih, weight_hh, bias_ih, bias_hh, h_prev,
                              weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse):
    """Run a hand-written bidirectional RNN built from two ``rnn_forward`` passes.

    The forward direction reads ``inputs`` as given; the backward direction
    reads the sequence flipped along the time axis, using its own
    ``*_reverse`` parameters and initial state.

    Args:
        inputs: input sequence, shape (bs, T, input_size).
        weight_ih, weight_hh, bias_ih, bias_hh: forward-direction parameters.
        h_prev: forward-direction initial hidden state, shape (bs, h_dim).
        weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse:
            backward-direction parameters.
        h_prev_reverse: backward-direction initial hidden state, shape (bs, h_dim).

    Returns:
        A tuple ``(h_out, h_n)``. ``h_out`` has shape (bs, T, 2 * h_dim): the
        first ``h_dim`` features of each time step come from the forward pass,
        the rest from the backward pass. ``h_n`` has shape (2, bs, h_dim):
        ``h_n[0]`` is the forward final state and ``h_n[1]`` the backward
        final state.
    """
    forward_output, forward_final = rnn_forward(
        inputs, weight_ih, weight_hh, bias_ih, bias_hh, h_prev)  # (bs, T, h_dim), (1, bs, h_dim)

    # Flip along the sequence axis (dim 1) so the backward pass walks from the
    # last time step to the first.
    backward_output, backward_final = rnn_forward(
        torch.flip(inputs, [1]),
        weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse)

    # Flip the backward outputs back so both halves of the feature axis refer
    # to the same time step; concatenation doubles the feature size.
    h_out = torch.cat([forward_output, torch.flip(backward_output, [1])], dim=-1)

    # The backward final state is the one produced after the last step of the
    # flipped sequence, i.e. the first time step in the original order.
    h_n = torch.cat([forward_final, backward_final], dim=0)  # (2, bs, h_dim)
    return h_out, h_n

In [24]:
# Fresh all-zero initial hidden state, shape (bs, hidden_size).
h_prev = torch.zeros(bs, hidden_size)
birnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
# The bidirectional layer gets one copy of the initial state per direction: (2, bs, hidden_size).
birnn_output, bistate_final = birnn(inputs, torch.stack([h_prev, h_prev]))
print("PyTorch API output:")
print(birnn_output)  # (bs, T, 2 * hidden_size)
print(bistate_final)  # (2, bs, hidden_size)

PyTorch API output:
tensor([[[ 0.1362, -0.0957,  0.6336,  0.8974,  0.6554,  0.2893],
         [-0.4360,  0.0530,  0.1111, -0.1689,  0.4685,  0.3232],
         [ 0.4801, -0.3914, -0.3356, -0.5937,  0.5782,  0.8116]],

        [[ 0.0231, -0.0647, -0.0660,  0.4987,  0.1443,  0.1118],
         [ 0.0248, -0.0509,  0.3082,  0.7010,  0.5588,  0.1377],
         [-0.0444, -0.0964,  0.0993, -0.0877,  0.4508,  0.6527]]],
       grad_fn=<TransposeBackward1>)
tensor([[[ 0.4801, -0.3914, -0.3356],
         [-0.0444, -0.0964,  0.0993]],

        [[ 0.8974,  0.6554,  0.2893],
         [ 0.4987,  0.1443,  0.1118]]], grad_fn=<StackBackward0>)


In [27]:
# To inspect the layer's parameter names and values, run:
#   for k, v in birnn.named_parameters():
#       print(k, v)
# One all-zero initial state per direction: index 0 is passed to the forward
# pass, index 1 to the backward pass.
h_prev = torch.zeros(2, bs, hidden_size)
bih_out, bih_final = bidirectional_rnn_forward(inputs, birnn.weight_ih_l0, birnn.weight_hh_l0, birnn.bias_ih_l0, \
                            birnn.bias_hh_l0, h_prev[0], \
                            birnn.weight_ih_l0_reverse, birnn.weight_hh_l0_reverse, birnn.bias_ih_l0_reverse, \
                            birnn.bias_hh_l0_reverse, h_prev[1])
print("bidirectional_rnn_forward output:")
print(bih_out, bih_final)

# Check the match against the PyTorch API numerically instead of comparing
# the two printouts by eye.
assert torch.allclose(bih_out, birnn_output, atol=1e-5)
assert torch.allclose(bih_final, bistate_final, atol=1e-5)

bidirectional_rnn_forward output:
tensor([[[ 0.1362, -0.0957,  0.6336,  0.8974,  0.6554,  0.2893],
         [-0.4360,  0.0530,  0.1111, -0.1689,  0.4685,  0.3232],
         [ 0.4801, -0.3914, -0.3356, -0.5937,  0.5782,  0.8116]],

        [[ 0.0231, -0.0647, -0.0660,  0.4987,  0.1443,  0.1118],
         [ 0.0248, -0.0509,  0.3082,  0.7010,  0.5588,  0.1377],
         [-0.0444, -0.0964,  0.0993, -0.0877,  0.4508,  0.6527]]],
       grad_fn=<CopySlices>) tensor([[[ 0.4801, -0.3914, -0.3356],
         [-0.0444, -0.0964,  0.0993]],

        [[ 0.8974,  0.6554,  0.2893],
         [ 0.4987,  0.1443,  0.1118]]], grad_fn=<CopySlices>)
