# PyTorch RNN的原理及其手写复现

来自b站up主deep_thoughts 合集【PyTorch源码教程与前沿人工智能算法复现讲解】

P_29_PyTorch RNN的原理及其手写复现：

https://www.bilibili.com/video/BV13i4y1R7jB/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

RNN 官方文档： https://pytorch.org/docs/stable/generated/torch.nn.RNN.html

torch.flip 官方文档： https://pytorch.org/docs/stable/generated/torch.flip.html

***论文***

# 循环神经网络 RNN

## 记忆单元分类
* RNN
* GRU
* LSTM

## 模型类别
* 单向循环
* 双向循环
* 多层单向或双向叠加

## 优缺点
* 优点
  * 可处理变长序列
  * 模型大小与序列长度无关
  * 计算量与序列长度呈线性增长
  * 考虑历史信息
  * 便于流式输出
  * 权重时不变
* 缺点
  * 串行计算比较慢
  * 无法获取太长的历史信息

## 应用场景
* AI诗歌生成
* 文本情感分类
* 词法识别
* 机器翻译
* 语音识别/生成
* 语言模型

In [1]:
import torch
import torch.nn as nn

## 1. 单向、单层RNN

In [5]:
single_rnn = nn.RNN(4, 3, 1, batch_first=True)
input = torch.randn(1, 2, 4)  # bs * sl * fs
output, h_n = single_rnn(input)
print(output.shape)
print(h_n.shape)

torch.Size([1, 2, 3])
torch.Size([1, 1, 3])


## 2. 双向、单层RNN

In [6]:
bidirectional_rnn = nn.RNN(4, 3, 1, batch_first=True, bidirectional=True)
bi_output, bi_h_n = bidirectional_rnn(input)
print(bi_output.shape)
print(bi_h_n.shape)

torch.Size([1, 2, 6])
torch.Size([2, 1, 3])


## 单向RNN与双向RNN的逐行实现

In [14]:
import torch
import torch.nn as nn

bs, T = 2, 3  # 批大小，输入序列长度
input_size, hidden_size = 2, 3  # 输入特征大小，隐含层特征大小
input = torch.randn(bs, T, input_size)  # 随机初始化一个输入特征序列
h_prev = torch.zeros(bs, hidden_size)  # 初始隐含状态

# step1 调用PyTorch RNN API
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
rnn_output, state_final = rnn(input, h_prev.unsqueeze(0))
print("PyTorch API output:")
print(rnn_output)
print(state_final)

# step2 手写一个rnn_forward函数，实现单向RNN的计算原理
def rnn_forward(input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev):
    bs, T, input_size = input.shape
    h_dim = weight_ih.shape[0]
    h_out = torch.zeros(bs, T, h_dim)  # 初始化一个输出（状态）矩阵
    
    for t in range(T):
        x = input[:, t, :].unsqueeze(2)  # 获取当前时刻输入特征, bs * input_size
        w_ih_batch = weight_ih.unsqueeze(0).tile(bs, 1, 1)  # bs*h_dim+input_size
        w_hh_batch = weight_hh.unsqueeze(0).tile(bs, 1, 1)  # bs*h_dim*h_dim
        
        w_times_x = torch.bmm(w_ih_batch, x).squeeze(-1)  # bs*h_dim
        w_times_h = torch.bmm(w_hh_batch, h_prev.unsqueeze(2)).squeeze(-1)  # bs*h_dim
        h_prev = torch.tanh(w_times_x+bias_ih+w_times_h+bias_hh)
        
        h_out[:, t, :] = h_prev
        
    return h_out, h_prev.unsqueeze(0)

# 验证一下rnn_forward的正确性
#for k,v in rnn.named_parameters():
    #print(k, v)
custom_rnn_output, custom_state_final = rnn_forward(input, rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0, rnn.bias_hh_l0, h_prev)

print("\n custom rnn_forward function output:")
print(custom_rnn_output)
print(custom_state_final)

PyTorch API output:
tensor([[[-0.0021,  0.1526,  0.6347],
         [-0.5690,  0.1357,  0.4924],
         [ 0.5337, -0.0540,  0.9681]],

        [[ 0.3480, -0.0049,  0.8693],
         [-0.0484, -0.1439,  0.9146],
         [ 0.6201,  0.0670,  0.9676]]], grad_fn=<TransposeBackward1>)
tensor([[[ 0.5337, -0.0540,  0.9681],
         [ 0.6201,  0.0670,  0.9676]]], grad_fn=<StackBackward0>)

 custom rnn_forward function output:
tensor([[[-0.0021,  0.1526,  0.6347],
         [-0.5690,  0.1357,  0.4924],
         [ 0.5337, -0.0540,  0.9681]],

        [[ 0.3480, -0.0049,  0.8693],
         [-0.0484, -0.1439,  0.9146],
         [ 0.6201,  0.0670,  0.9676]]], grad_fn=<CopySlices>)
tensor([[[ 0.5337, -0.0540,  0.9681],
         [ 0.6201,  0.0670,  0.9676]]], grad_fn=<UnsqueezeBackward0>)


## 手写双向RNN

In [15]:
# step3 手写一个bidirectional_rnn_forward函数，实现双向RNN的计算原理
def bidirectional_rnn_forward(input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev, \
                             weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse):
    bs, T, input_size = input.shape
    h_dim = weight_ih.shape[0]
    h_out = torch.zeros(bs, T, h_dim*2)  # 初始化一个输出（状态）矩阵，注意双向是两倍的特征大小
    
    forward_output = rnn_forward(input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev)[0]  # forward layer
    backward_output = rnn_forward(torch.flip(input, [1]), weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse)[0]  # backward layer
    
    h_out[:, :, :h_dim] = forward_output
    h_out[:, :, h_dim:] = backward_output
    
    return h_out, h_out[:, -1, :].reshape((bs, 2, h_dim)).transpose(0, 1)

# 验证一下bidirectional_rnn_forward的正确性
bi_rnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
h_prev = torch.zeros(2, bs, hidden_size)
bi_rnn_output, bi_state_final = bi_rnn(input, h_prev)

#for k,v in bi_rnn.named_parameters():
    #print(k, v)

custom_bi_rnn_output, custom_bi_state_final = bidirectional_rnn_forward(input, bi_rnn.weight_ih_l0, bi_rnn.weight_hh_l0, bi_rnn.bias_ih_l0, bi_rnn.bias_hh_l0, h_prev[0], \
                             bi_rnn.weight_ih_l0_reverse, bi_rnn.weight_hh_l0_reverse, bi_rnn.bias_ih_l0_reverse, bi_rnn.bias_hh_l0_reverse, h_prev[1])

print("PyTorch API output:")
print(bi_rnn_output)
print(bi_state_final)

print("\n custom bidirectional_rnn_forward function output:")
print(custom_bi_rnn_output)
print(custom_bi_state_final)

PyTorch API output:
tensor([[[ 0.4750,  0.4664,  0.1212, -0.5255, -0.5600,  0.3535],
         [-0.0551,  0.7235,  0.1034, -0.2400, -0.9647, -0.3409],
         [ 0.4662, -0.3591, -0.8366,  0.0743, -0.2414,  0.8352]],

        [[ 0.6957, -0.0872, -0.4927, -0.3050, -0.3211,  0.3237],
         [ 0.6182, -0.3142, -0.6594,  0.2009, -0.6185,  0.6459],
         [ 0.7947, -0.5041, -0.6501,  0.0895,  0.2259,  0.8508]]],
       grad_fn=<TransposeBackward1>)
tensor([[[ 0.4662, -0.3591, -0.8366],
         [ 0.7947, -0.5041, -0.6501]],

        [[-0.5255, -0.5600,  0.3535],
         [-0.3050, -0.3211,  0.3237]]], grad_fn=<StackBackward0>)

 custom bidirectional_rnn_forward function output:
tensor([[[ 0.4750,  0.4664,  0.1212,  0.0743, -0.2414,  0.8352],
         [-0.0551,  0.7235,  0.1034, -0.2400, -0.9647, -0.3409],
         [ 0.4662, -0.3591, -0.8366, -0.5255, -0.5600,  0.3535]],

        [[ 0.6957, -0.0872, -0.4927,  0.0895,  0.2259,  0.8508],
         [ 0.6182, -0.3142, -0.6594,  0.2009, -0.6185