<a href="https://colab.research.google.com/github/GuiXu40/deeplearning0/blob/main/Basic_code/%E5%8D%95%E5%90%91RNN%E5%92%8C%E5%8F%8C%E5%90%91RNN%E5%AE%9E%E7%8E%B0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [17]:
import torch
import torch.nn as nn

bs, T = 2, 3 # batch size and input sequence length
input_size, hidden_size = 2, 3 # input feature size and hidden-state feature size
input = torch.randn(bs, T, input_size) # randomly initialised input sequence, [bs, T, input_size]
h_prev = torch.zeros(bs, hidden_size) # initial hidden state, [bs, hidden_size]
# step1: call the PyTorch RNN API https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
# unsqueeze(0) turns the [bs, hidden_size] state into [1, bs, hidden_size] before it is passed to nn.RNN
rnn_ouput, state_final = rnn(input, h_prev.unsqueeze(0))
#print("Pytorch RNN API:")
#print(rnn_ouput)
#print(state_final)

# step2: hand-written rnn_forward that reproduces the computation of a single-layer RNN
def rnn_forward(input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev):
  """Run a single-layer, unidirectional tanh RNN over a batch-first sequence.

  For every time step t the hidden state is updated as
    h_t = tanh(W_ih @ x_t + b_ih + W_hh @ h_{t-1} + b_hh)

  Args:
    input: [b, T, input_size] input sequence (batch first)
    weight_ih: [hidden_size, input_size]
    weight_hh: [hidden_size, hidden_size]
    bias_ih: [hidden_size]
    bias_hh: [hidden_size]
    h_prev: [b, hidden_size] initial hidden state

  Returns:
    A tuple (h_out, h_n):
      h_out: [b, T, hidden_size] hidden state at every time step
      h_n: [1, b, hidden_size] hidden state after the last time step
        (the initial state itself when T == 0)
  """
  bs, T, _ = input.shape
  h_dim = weight_ih.shape[0]
  # Allocate the output with the input's dtype/device so that e.g. float64
  # inputs are not silently downcast to float32 when written into the buffer.
  h_out = torch.zeros(bs, T, h_dim, dtype=input.dtype, device=input.device)

  # Loop-invariant: x @ W^T is the batched form of W @ x for every sample.
  weight_ih_t = weight_ih.transpose(0, 1) # input_size*h_dim
  weight_hh_t = weight_hh.transpose(0, 1) # h_dim*h_dim

  for t in range(T):
    x = input[:, t, :] # input features at the current time step, b*input_size

    w_times_x = torch.matmul(x, weight_ih_t) # b*h_dim
    w_times_h = torch.matmul(h_prev, weight_hh_t) # b*h_dim
    h_prev = torch.tanh(w_times_x + bias_ih + w_times_h + bias_hh) # b*h_dim

    h_out[:, t, :] = h_prev
  return h_out, h_prev.unsqueeze(0) # b*T*h_dim, 1*b*h_dim

#for k, v in rnn.named_parameters():
#  print(k, v)
# Run the hand-written forward pass with the nn.RNN layer's own parameters and the same
# input and initial state that were given to `rnn` above.
custom_run_output, custom_state_final = rnn_forward(input, rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0, rnn.bias_hh_l0, h_prev)
#print("\n rnn_forward function output:")
#print(custom_run_output)
#print(custom_state_final)

# step3: hand-written bidirectional_rnn_forward that reproduces the computation of a bidirectional RNN
def bidirectional_rnn_forward(input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev,
                              weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse):
  """Run a single-layer bidirectional tanh RNN over a batch-first sequence.

  The forward direction processes `input` as given; the reverse direction
  processes `input` flipped along the time axis with its own parameters.

  Args:
    input: [b, T, input_size] input sequence (batch first)
    weight_ih: [hidden_size, input_size]
    weight_hh: [hidden_size, hidden_size]
    bias_ih: [hidden_size]
    bias_hh: [hidden_size]
    h_prev: [b, hidden_size] initial hidden state of the forward direction
    weight_ih_reverse: [hidden_size, input_size]
    weight_hh_reverse: [hidden_size, hidden_size]
    bias_ih_reverse: [hidden_size]
    bias_hh_reverse: [hidden_size]
    h_prev_reverse: [b, hidden_size] initial hidden state of the reverse direction

  Returns:
    A tuple (h_out, h_n):
      h_out: [b, T, 2*hidden_size]; the first hidden_size features are the
        forward states, the last hidden_size features are the reverse states
        re-aligned to the original time order
      h_n: [2, b, hidden_size]; final forward state followed by final
        reverse state
  """
  forward_output, forward_state = rnn_forward(input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev)
  backward_output, backward_state = rnn_forward(torch.flip(input, [1]), weight_ih_reverse, weight_hh_reverse,
                                                bias_ih_reverse, bias_hh_reverse, h_prev_reverse)

  # The reverse pass ran over the flipped sequence, so flip its outputs back
  # to line them up with the forward outputs step by step. Concatenating
  # (instead of writing into a pre-allocated torch.zeros buffer) keeps the
  # dtype/device of the per-direction results.
  h_out = torch.cat([forward_output, torch.flip(backward_output, [1])], dim=-1) # b*T*(2*h_dim)

  # rnn_forward already returns each direction's final state as 1*b*h_dim;
  # using it avoids indexing output[:, -1, :], which fails when T == 0.
  h_n = torch.cat([forward_state, backward_state], dim=0) # 2*b*h_dim

  return h_out, h_n

# Verify the correctness of bidirectional_rnn_forward against the PyTorch bidirectional RNN
bi_rnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
# h_prev is rebound here: one [bs, hidden_size] initial state per direction
h_prev = torch.zeros(2, bs, hidden_size)
bi_rnn_output, bi_state_final = bi_rnn(input, h_prev)
#for k, v in bi_rnn.named_parameters():
#  print(k, v)
print("pytorch bi_RNN API:")
print(bi_rnn_output)
print(bi_state_final)

# Feed the layer's forward and reverse parameters to the hand-written implementation;
# h_prev[0] / h_prev[1] are the forward / reverse initial states.
custom_bi_run_output, custom_bi_state_final = bidirectional_rnn_forward(input, bi_rnn.weight_ih_l0, bi_rnn.weight_hh_l0, bi_rnn.bias_ih_l0, bi_rnn.bias_hh_l0, h_prev[0], \
                                                                        bi_rnn.weight_ih_l0_reverse, bi_rnn.weight_hh_l0_reverse, \
                                                                        bi_rnn.bias_ih_l0_reverse, bi_rnn.bias_hh_l0_reverse, h_prev[1])

print("\n bidirectional_rnn_forward output:")
print(custom_bi_run_output)
print(custom_bi_state_final)

pytorch bi_RNN API:
tensor([[[ 0.6269,  0.2766,  0.3033, -0.3628,  0.5605,  0.0917],
         [ 0.8897, -0.1003,  0.0740, -0.5450, -0.1588,  0.0874],
         [ 0.8579,  0.1222,  0.4492, -0.5948,  0.4981, -0.0735]],

        [[ 0.7717,  0.4923,  0.2075, -0.3371,  0.7338, -0.3508],
         [ 0.5323, -0.5082,  0.3746, -0.0385,  0.6485, -0.2444],
         [ 0.8457,  0.1279,  0.4327, -0.5614,  0.3480, -0.1893]]],
       grad_fn=<TransposeBackward1>)
tensor([[[ 0.8579,  0.1222,  0.4492],
         [ 0.8457,  0.1279,  0.4327]],

        [[-0.3628,  0.5605,  0.0917],
         [-0.3371,  0.7338, -0.3508]]], grad_fn=<StackBackward0>)

 bidirectional_rnn_forward output:
tensor([[[ 0.6269,  0.2766,  0.3033, -0.3628,  0.5605,  0.0917],
         [ 0.8897, -0.1003,  0.0740, -0.5450, -0.1588,  0.0874],
         [ 0.8579,  0.1222,  0.4492, -0.5948,  0.4981, -0.0735]],

        [[ 0.7717,  0.4923,  0.2075, -0.3371,  0.7338, -0.3508],
         [ 0.5323, -0.5082,  0.3746, -0.0385,  0.6485, -0.2444],
    