# PyTorch_LSTM和LSTMP的原理及其手写复现

来自b站up主deep_thoughts 合集【PyTorch源码教程与前沿人工智能算法复现讲解】

P_30_PyTorch_LSTM和LSTMP的原理及其手写复现：

https://www.bilibili.com/video/BV1zq4y1m7aH/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

***资料***

介绍 LSTM 博客：

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

PyTorch 官方文档：

https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

## LSTM 公式

$i_t = \sigma (W_{ii} x_t + b_{ii}+W_{hi}h_{t-1}+b_{hi})$

$f_t = \sigma (W_{if} x_t + b_{if}+W_{hf}h_{t-1}+b_{hf})$

$g_t = \tanh (W_{ig} x_t + b_{ig}+W_{hg}h_{t-1}+b_{hg})$

$o_t = \sigma (W_{io} x_t + b_{io}+W_{ho}h_{t-1}+b_{ho})$

$c_t = f_t \odot c_{t-1} + i_t \odot g_t$

$h_t = o_t \odot \tanh(c_t)$

***注： $\odot$ 逐元素相乘***

In [15]:
import torch
import torch.nn as nn

# Hand-written implementation of LSTM and LSTMP, checked against nn.LSTM
# Define constants: batch size, sequence length, input size, hidden size
bs, T, i_size, h_size = 2, 3, 4, 5
proj_size = 3  # size of the projected hidden state (LSTMP)
# NOTE(review): `input` shadows the Python builtin of the same name; the name
# is kept because later statements in this cell refer to it.
input = torch.randn(bs, T, i_size)  # input sequence, [bs, T, i_size]
c0 = torch.randn(bs, h_size)  # initial cell state; an initial value, not trained
h0 = torch.randn(bs, proj_size)  # initial hidden state; sized proj_size because the layer below uses proj_size

# Call the official LSTM API
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
# unsqueeze(0) prepends a dimension of size 1 to each 2-D initial state before
# handing it to nn.LSTM (presumably the num_layers * num_directions axis).
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output)

# Print the name and shape of every parameter of the official layer
for k, v in lstm_layer.named_parameters():
    print(k, v.shape)
    
# 自己写一个 LSTM 模型
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    """Hand-written forward pass of a single-layer, batch-first LSTM / LSTMP.

    The rows of the stacked weights and biases are split into four equal
    parts in the order input gate (i), forget gate (f), cell gate (g),
    output gate (o).

    Args:
        input: Input sequence of shape ``[bs, T, i_size]``.
        initial_states: Tuple ``(h0, c0)`` of 2-D tensors. ``h0`` is
            ``[bs, h_size]`` (``[bs, p_size]`` when ``w_hr`` is given) and
            ``c0`` is ``[bs, h_size]``. There is no leading layer dimension.
        w_ih: Input-to-hidden weights, ``[4*h_size, i_size]``.
        w_hh: Hidden-to-hidden weights, ``[4*h_size, h_size]``
            (``[4*h_size, p_size]`` when ``w_hr`` is given).
        b_ih: Input-to-hidden bias, ``[4*h_size]``.
        b_hh: Hidden-to-hidden bias, ``[4*h_size]``.
        w_hr: Optional projection weights, ``[p_size, h_size]``. When given,
            the hidden state is projected to ``p_size`` at every step (LSTMP).

    Returns:
        ``(output, (h_final, c_final))`` where ``output`` is
        ``[bs, T, output_size]`` with ``output_size`` equal to ``p_size`` if
        ``w_hr`` is given and ``h_size`` otherwise, ``h_final`` is the hidden
        state after the last step (``[bs, output_size]``) and ``c_final`` is
        the cell state after the last step (``[bs, h_size]``). For ``T == 0``
        the initial states are returned unchanged.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 4
    output_size = h_size if w_hr is None else w_hr.shape[0]

    # Allocate from `input` so the buffer inherits its dtype and device;
    # a plain torch.zeros would silently downcast float64 results and would
    # fail on the slice assignment when the data lives on another device.
    output = input.new_zeros(bs, T, output_size)

    # The input contribution and both biases do not depend on the recurrence,
    # so they are computed once for all time steps. Broadcasting matmul
    # replaces the per-batch tiled copies of the weights.
    x_gates = input @ w_ih.t() + b_ih + b_hh  # [bs, T, 4*h_size]

    for t in range(T):
        # Pre-activations of all four gates at step t, [bs, 4*h_size]
        gates = x_gates[:, t, :] + prev_h @ w_hh.t()
        i_t, f_t, g_t, o_t = gates.chunk(4, dim=-1)

        i_t = torch.sigmoid(i_t)  # input gate
        f_t = torch.sigmoid(f_t)  # forget gate
        g_t = torch.tanh(g_t)     # cell (candidate) gate
        o_t = torch.sigmoid(o_t)  # output gate

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)  # [bs, h_size]

        if w_hr is not None:
            # LSTMP: project the hidden state down to p_size
            prev_h = prev_h @ w_hr.t()  # [bs, p_size]

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

# Run the hand-written forward pass with the parameters of the official layer
# so its printed result can be compared with the nn.LSTM output printed above.
output_custom, (h_final_custom, c_final_custom) = lstm_forward(
    input,
    (h0, c0),
    lstm_layer.weight_ih_l0,
    lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0,
    lstm_layer.bias_hh_l0,
    lstm_layer.weight_hr_l0,
)

print(output_custom)

tensor([[[ 0.3658, -0.3085,  0.0826],
         [ 0.1759, -0.1754,  0.1034],
         [-0.0041,  0.0428,  0.0213]],

        [[ 0.0031, -0.1201,  0.1253],
         [-0.0213, -0.0656,  0.0130],
         [ 0.0790, -0.2002,  0.1780]]], grad_fn=<TransposeBackward0>)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])
tensor([[[ 0.3658, -0.3085,  0.0826],
         [ 0.1759, -0.1754,  0.1034],
         [-0.0041,  0.0428,  0.0213]],

        [[ 0.0031, -0.1201,  0.1253],
         [-0.0213, -0.0656,  0.0130],
         [ 0.0790, -0.2002,  0.1780]]], grad_fn=<CopySlices>)
