## `LSTM`
***
***
Time: 2020-09-14<br>
Author: dsy
***

![LSTM](./imgs/LSTM.png)

$$
\begin{aligned}
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)\\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c)\\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o)
\end{aligned}
$$

In [1]:
import torch
import torch.nn as nn

In [2]:
class LSTMFromDsy2(nn.Module):
    """Minimal LSTM with one scalar weight per gate, run along each row of x.

    Every gate (input, forget, output) and the candidate cell state has a
    scalar input weight ``w``, a scalar recurrent weight ``u`` and a scalar
    bias ``b``.  All twelve scalars are registered as ``nn.Parameter`` so the
    module exposes them through ``parameters()`` and gives reproducible
    results across calls.
    """

    def __init__(self):
        super(LSTMFromDsy2, self).__init__()
        # Assigning nn.Parameter objects as attributes registers them with
        # the module, so they are created once instead of on every forward.
        self.wi, self.ui, self.bi = self.wub()  # input gate
        self.wf, self.uf, self.bf = self.wub()  # forget gate
        self.wo, self.uo, self.bo = self.wub()  # output gate
        self.wc, self.uc, self.bc = self.wub()  # candidate cell state

    def wub(self):
        """Return a fresh (W, U, b) triple of scalar parameters drawn from U[0, 1)."""
        W = nn.Parameter(torch.rand(1))
        U = nn.Parameter(torch.rand(1))
        b = nn.Parameter(torch.rand(1))
        return W, U, b

    def forward(self, x):
        """Run the LSTM recurrence along dimension 1 of a 2-D tensor.

        Args:
            x: tensor of shape (m, n); each of the m rows is treated as an
                independent sequence of n scalar inputs.

        Returns:
            A tuple ``(ct, ht)`` of tensors of shape (m, n) holding the cell
            state and hidden state after every step.  Both start from zero.
        """
        m, n = x.shape

        # Rows never interact, so the whole column x[:, j] is processed at once.
        c_prev = self.wi.new_zeros(m)
        h_prev = self.wi.new_zeros(m)

        # States are collected in lists and stacked at the end; writing them
        # in place into a preallocated tensor would modify values autograd
        # has already saved for the backward pass.
        cells, hiddens = [], []
        for j in range(n):
            x_t = x[:, j]
            c_hat = torch.tanh(self.wc * x_t + self.uc * h_prev + self.bc)
            i_t = torch.sigmoid(self.wi * x_t + self.ui * h_prev + self.bi)
            f_t = torch.sigmoid(self.wf * x_t + self.uf * h_prev + self.bf)
            o_t = torch.sigmoid(self.wo * x_t + self.uo * h_prev + self.bo)
            c_prev = f_t * c_prev + i_t * c_hat
            h_prev = o_t * torch.tanh(c_prev)
            cells.append(c_prev)
            hiddens.append(h_prev)

        if not cells:
            # Zero-length sequences: torch.stack rejects an empty list.
            empty = self.wi.new_zeros((m, 0))
            return empty, empty.clone()

        return torch.stack(cells, dim=1), torch.stack(hiddens, dim=1)
        

In [3]:
# Sample input: a 4x4 tensor drawn with torch.randn.
X = torch.randn((4,4))
# Bare expression so the notebook displays the tensor.
X

tensor([[ 0.7357,  1.3949,  0.0956, -0.2546],
        [ 0.7246, -0.4216, -0.5635, -0.3912],
        [ 0.9221, -0.6590,  0.9428, -0.2178],
        [-0.4692,  0.1325, -1.4172, -0.8113]])

In [4]:
# Build the scalar-weight LSTM and run it on the sample input X.
lstmfd2 = LSTMFromDsy2()
# forward returns the cell states (ct) and hidden states (ht), one per input position.
ct,ht = lstmfd2(X)

In [5]:
ct

tensor([[0.3991, 0.8497, 1.0377, 1.0834],
        [0.3984, 0.5952, 0.7042, 0.8172],
        [0.4118, 0.5642, 0.9567, 1.0358],
        [0.3119, 0.6136, 0.5323, 0.6150]], grad_fn=<SliceBackward>)

In [6]:
ht

tensor([[0.2542, 0.4833, 0.5628, 0.5820],
        [0.2537, 0.3731, 0.4324, 0.4841],
        [0.2614, 0.3577, 0.5281, 0.5659],
        [0.2025, 0.3791, 0.3472, 0.3885]], grad_fn=<SliceBackward>)

In [7]:
# Earlier implementation idea
class LSTMFromDsy(nn.Module):
    """Single LSTM step from a zero initial state, using matrix weights.

    For an input of shape (m, n) every gate and the candidate cell state has
    an (m, m) input weight ``W``, an (m, m) recurrent weight ``U`` and an
    (m, n) bias ``b``.  The weights are created on the first call, because
    their shapes depend on the input, and are registered as parameters.
    They are re-created whenever the input shape changes, so an optimizer
    built earlier will not see the new parameters.
    """

    def __init__(self):
        super(LSTMFromDsy, self).__init__()
        # Input shape the current weights were built for; None until first call.
        self._shape = None

    def wub(self):
        """Return a fresh (W, U, b) parameter triple sized from self.m and self.n."""
        W = nn.Parameter(torch.rand((self.m, self.m)))
        U = nn.Parameter(torch.rand((self.m, self.m)))
        b = nn.Parameter(torch.rand((self.m, self.n)))
        return W, U, b

    def forward(self, x):
        """Apply one LSTM step to x, starting from zero hidden and cell states.

        Args:
            x: 2-D floating-point tensor shaped like [[1., 2., 3.]].

        Returns:
            A tuple ``(ct, ht)`` of tensors with the same shape as x.
        """
        shape = tuple(x.shape)
        if self._shape != shape:
            self.m, self.n = shape
            self.Wi, self.Ui, self.bi = self.wub()  # input gate
            self.Wf, self.Uf, self.bf = self.wub()  # forget gate
            self.Wo, self.Uo, self.bo = self.wub()  # output gate
            self.Wc, self.Uc, self.bc = self.wub()  # candidate cell state
            self._shape = shape

        h_prev = self.Wi.new_zeros(shape)
        c_prev = self.Wi.new_zeros(shape)

        it = torch.sigmoid(self.Wi.matmul(x) + self.Ui.matmul(h_prev) + self.bi)
        ft = torch.sigmoid(self.Wf.matmul(x) + self.Uf.matmul(h_prev) + self.bf)
        ot = torch.sigmoid(self.Wo.matmul(x) + self.Uo.matmul(h_prev) + self.bo)

        # The candidate state has its own weights, and the gates combine with
        # the states element-wise (Hadamard product), not by matrix product.
        c_t_hat = torch.tanh(self.Wc.matmul(x) + self.Uc.matmul(h_prev) + self.bc)
        ct = ft * c_prev + it * c_t_hat
        ht = ot * torch.tanh(ct)

        return ct, ht