## `LSTM`
***
***
Time: 2020-09-14<br>
Author: dsy
***

![LSTM](./imgs/LSTM.png)

$$
\begin{aligned}
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\
h_t &= \omicron_t \odot \tanh(c_t)\\
\tilde{c_t} &= \tanh(W_c x_t + U_c h_{t-1} + b_c)\\
i_t &= \sigma(W_ix_t + U_ih_{t-1}+b_i) \\
f_t &= \sigma(W_fx_t + U_fh_{t-1}+b_f) \\
\omicron_t &= \sigma(W_\omicron x_t+U_\omicron h_{t-1} + b_\omicron)
\end{aligned}
$$

In [1]:
import torch
import torch.nn as nn

In [2]:
class LSTMFromDsy(nn.Module):
    def __init__(self):
        super(LSTMFromDsy,self).__init__()
        
    def wub(self):
        W = torch.rand((self.m,self.m),requires_grad=True)
        U = torch.rand((self.m,self.m),requires_grad=True)
        b = torch.rand((self.m,self.n),requires_grad=True)
        return W,U,b
    
    def forward(self,x):
        '''
        x 形如:[[1,2,3]]
        '''
        self.m,self.n = x.shape
        
        ht = torch.zeros((self.m,self.n))
        ct = torch.zeros((self.m,self.n))
        
        Wi,Ui,bi = self.wub()
        Wf,Uf,bf = self.wub()
        Wo,Uo,bo = self.wub()
        
        it = torch.sigmoid(Wi .matmul(x)+Ui.matmul(ht) + bi)
        ft = torch.sigmoid(Wf .matmul(x)+Uf.matmul(ht) + bf)
        ot = torch.sigmoid(Wo .matmul(x)+Uo.matmul(ht) + bo)
        
        c_t_hat = torch.tanh(ot.matmul(ct))
        ct = ft.matmul(ct) + it.matmul(c_t_hat) 
        ht = ot.matmul(ct)
        
        return ct,ht

In [3]:
lstmfd = LSTMFromDsy()

In [4]:
[lstmfd(torch.rand((1,1))) for i in range(100)]

[(tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (tensor([[0.]], grad_fn=<AddBackward0>),
  tensor([[0.]], grad_fn=<MmBackward>)),
 (te