# 1.基础概念  
![img](./pic/LSTM.png)  

**目的与GRU一样，本质都是对隐状态 `H_t` 的计算方式做不同的设计：LSTM 额外引入记忆元 `C_t`，并用输入门、遗忘门、输出门分别控制信息的写入、保留与输出。**

# 2.代码实现  


In [6]:
import torch
from torch import nn

def get_params(emb_size, num_hiddens):
    """Create the trainable parameters of a from-scratch LSTM layer.

    The layer's input width and output width are both ``emb_size``: it
    consumes embedded tokens, and its output projection ``W_hq`` maps the
    hidden state back to ``emb_size`` (in this file, ``RNNModel`` then
    projects that to the vocabulary with its own ``nn.Linear``).

    Args:
        emb_size: Width of each input vector and of each output vector.
        num_hiddens: Width of the hidden state H and of the memory cell C.

    Returns:
        A list of 14 tensors, all with ``requires_grad=True``, in the order
        ``[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
        W_xc, W_hc, b_c, W_hq, b_q]``. Weight matrices are drawn from
        N(0, 0.01**2); biases are zeros.
    """
    num_inputs = num_outputs = emb_size

    def normal(shape):
        # Standard normal scaled by 0.01: mean 0, std 0.01 (variance 1e-4).
        return torch.randn(size=shape) * 0.01

    def three():
        # (input->hidden weight, hidden->hidden weight, bias) for one gate.
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens))

    # Hidden-layer parameters.
    W_xi, W_hi, b_i = three()  # input gate
    W_xf, W_hf, b_f = three()  # forget gate
    W_xo, W_ho, b_o = three()  # output gate
    W_xc, W_hc, b_c = three()  # candidate memory cell

    # Output-layer parameters.
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs)

    # Enable gradient tracking on every parameter.
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

In [7]:
def init_lstm_state(batch_size, num_hiddens):
    """Return an all-zero initial LSTM state ``(H, C)``.

    H (hidden state) and C (memory cell) are two separate tensors, each of
    shape ``(batch_size, num_hiddens)``.
    """
    shape = (batch_size, num_hiddens)
    hidden = torch.zeros(shape)
    cell = torch.zeros(shape)
    return hidden, cell

计算公式:
<div style="border-left:2px solid black;padding:10px;margin-left:20px;">

$$ I_t = \sigma(W_{xi}x_t + W_{hi}h_{t-1} + b_i) $$  
$$ F_t = \sigma(W_{xf}x_t + W_{hf}h_{t-1} + b_f) $$  
$$ O_t = \sigma(W_{xo}x_t + W_{ho}h_{t-1} + b_o) $$
$$ \hat{C_t} = \tanh(W_{xc}x_t + W_{hc}h_{t-1} + b_c) $$  
$$ C_t = F_t \odot C_{t-1} + I_t \odot \hat{C_t} $$  
</div>
  
$$ h_t = O_t \odot \tanh(C_t) $$  

In [8]:
def lstm(inputs, state, params):
    """Run a from-scratch LSTM over a whole sequence.

    Args:
        inputs: Tensor of shape (T, batch_size, num_inputs); iterated over
            its first (time) axis.
        state: Tuple ``(H, C)`` of hidden state and memory cell, each of
            shape (batch_size, num_hiddens), e.g. from ``init_lstm_state``.
        params: The 14 tensors in the order produced by ``get_params``.

    Returns:
        ``(Y, (H, C))``. Y has shape (T * batch_size, num_outputs): the
        per-step outputs concatenated along dim 0. ``(H, C)`` is the state
        after the last time step.
    """
    (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q) = params
    H, C = state
    outputs = []
    for X in inputs:  # X: (batch_size, num_inputs)
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)  # input gate
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)  # forget gate
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)  # output gate
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)  # candidate memory
        C = F * C + I * C_tilda
        # h_t = O_t ⊙ tanh(C_t). Without this update H would stay at its
        # initial value: the output gate would be ignored and every Y (and
        # the returned H) would come from the initial state.
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)

    return torch.cat(outputs, dim=0), (H, C)

In [9]:
class RNNModel:
    """Wrap a from-scratch recurrent layer with an embedding and an output projection.

    Token ids are embedded by ``self.embedding``, run through ``forward_fn``
    with the raw tensors in ``self.params``, and projected to ``vocab_size``
    logits by ``self.Linear``. The class is not an ``nn.Module``, so its
    trainable tensors live in three places; ``parameters()`` gathers them all.
    """

    def __init__(
        self, vocab_size, emb_size, num_hiddens, get_params, init_state, forward_fn):
        """
        Args:
            vocab_size: Number of token ids; size of the embedding table and
                of the output logits.
            emb_size: Embedding width; passed to ``get_params`` and used as
                the input width of ``self.Linear``.
            num_hiddens: Recurrent state width, passed to ``get_params`` and
                ``init_state``.
            get_params: Callable ``(emb_size, num_hiddens) -> params``.
            init_state: Callable ``(batch_size, num_hiddens) -> state``.
            forward_fn: Callable ``(inputs, state, params) -> (outputs, state)``;
                its outputs must have a last dimension of ``emb_size`` to
                feed ``self.Linear``.
        """
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.params = get_params(emb_size, num_hiddens)
        self.init_state, self.forward_fn = init_state, forward_fn
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.Linear = nn.Linear(emb_size, vocab_size)

    def __call__(self, X, state):
        """Run the model on a 2-D tensor of token ids.

        ``X.T`` makes the input time-major. With X of shape
        (batch_size, num_steps), the embedding yields
        (num_steps, batch_size, emb_size) for ``forward_fn``.
        Returns ``(logits, new_state)``.
        """
        X = self.embedding(X.T).float()
        X, HC = self.forward_fn(X, state, self.params)
        return self.Linear(X), HC

    def begin_state(self, batch_size):
        """Return the initial recurrent state for ``batch_size`` sequences."""
        return self.init_state(batch_size, self.num_hiddens)

    def parameters(self):
        """Return every trainable tensor, including embedding and projection weights.

        ``self.params`` alone misses the ``nn.Embedding`` and ``nn.Linear``
        weights, so an optimizer given only ``self.params`` would leave them
        untrained.
        """
        return [*self.params, *self.embedding.parameters(), *self.Linear.parameters()]

In [11]:
num_hiddens, vocab_size, emb_size = 512, 28, 20
X = torch.arange(10).reshape(2, 5)  # token ids, shape (bs=2, T=5)
net = RNNModel(vocab_size, emb_size, num_hiddens, get_params, init_lstm_state, lstm)

state = net.begin_state(X.shape[0])
# Input: (bs, T) token ids -> logits of shape (T * bs, vocab);
# new_state is (H, C), each of shape (bs, num_hiddens).
Y, new_state = net(X, state)
print(Y.shape, len(new_state), new_state[0].shape)

# Expected output:
# torch.Size([10, 28]) 2 torch.Size([2, 512])

torch.Size([10, 28]) 2 torch.Size([2, 512])


'\ntorch.Size([10, 28]) 2 torch.Size([2, 512])\n'