# LSTM
<img src="../img/lstm_im.png/">

Let:

$fn$ = Number of features

$hs$ = Number of output nodes (hidden size)

$bs$ = Batch size

Then:
 * Each $W_{something}$ matrix below has the shape $(fn, hs)$;
 * Each $U_{something}$ matrix below has the shape $(hs, hs)$;
 * Each $b_{something}$ matrix below has the shape $(1, hs)$;
 * The $x_t$ matrix below has shape $(bs, fn)$, corresponding to the element of index $t$ of each sequence inf the batch.; and
 * The $h_t$ matrix below has shape $(bs, hs)$, corresponding to hidden state at time $t$ of each sequence inf the batch.

And:

$f_t = \sigma(W_f \ x_t + U_f \ h_{t-1} + b_f)$

$i_t = \sigma(W_i \ x_t + U_i \ h_{t-1} + b_i)$

$o_t = \sigma(W_o \ x_t + U_o \ h_{t-1} + b_o)$

$g_t = \tanh \ (W_g \ x_t + U_g \ h_{t-1} + b_g)$ a.k.a. $\tilde{c}_t$

$c_t = f_t \circ c_{t-1} + i_t \circ g_t$

$h_t = o_t \circ \tanh \ (c_t)$

<a src="https://colah.github.io/posts/2015-08-Understanding-LSTMs/" target="_blank">Source</a>

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [98]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

## LSTM FROM SCARTCH Finished Version

In [97]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, device):
        super(LSTM, self).__init__()
        self.device = device
        self.W_xf, self.W_hf, self.b_f = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size), torch.zeros(hidden_size)
        self.W_xi, self.W_hi, self.b_i = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size), torch.zeros(hidden_size)
        self.W_xo, self.W_ho, self.b_o = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size), torch.zeros(hidden_size)
        self.W_xc, self.W_hc, self.b_c = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size), torch.zeros(hidden_size)
        
    def forward(self, X, state):
        prev_h, prev_c = state
        batch_size, timesteps, _  = X.shape
        hidden_seq = []
        for t in range(timesteps):
            x = X[:,t,:]
            forget_gate = torch.sigmoid(x @ self.W_xf + prev_h @ self.W_hf + self.b_f)
            input_gate = torch.sigmoid(x @ self.W_xi + prev_h @ self.W_hf + self.b_i)
            output_gate = torch.sigmoid(x @ self.W_xo + prev_h @ self.W_ho + self.b_o)
            C_tilda = torch.tanh(x @ self.W_xc + prev_h @ self.W_hc + self.b_c)
            Ct = forget_gate * prev_c + input_gate * C_tilda
            Ht = output_gate * torch.tanh(Ct)
            hidden_seq.append(Ht.unsqueeze(0))
        hidden_seq = torch.cat(hidden_seq, dim=0)
        hidden_seq = hidden_seq.permute(1, 0, 2)
        
        return hidden_seq, hidden_seq_2

In [89]:
INPUT_SIZE = 20 
HIDDEN_SIZE = 10
SEQ_LEN = 30
lstm = LSTM(INPUT_SIZE, HIDDEN_SIZE, device)

In [90]:
X = torch.randn(SEQ_LEN, 10, INPUT_SIZE)

In [91]:
h0 = torch.zeros(SEQ_LEN, HIDDEN_SIZE)
c0 = torch.zeros(SEQ_LEN, HIDDEN_SIZE)
ht = lstm(X, (h0, c0))

In [92]:
ht[0].shape

torch.Size([30, 10, 10])

In [93]:
ht[1].shape

torch.Size([30, 10, 10])

In [99]:
ht

(tensor([[[ 8.1236e-03, -5.6972e-01, -1.9282e-01,  ...,  3.3110e-03,
           -1.4730e-02,  4.3709e-01],
          [ 4.3937e-04, -3.1241e-04,  4.2398e-04,  ...,  5.7878e-02,
           -6.2739e-01,  7.3442e-03],
          [ 5.4506e-01, -1.5308e-03,  9.1636e-06,  ...,  1.4428e-01,
            9.9743e-03, -1.3020e-01],
          ...,
          [ 8.4202e-02,  4.8387e-01, -1.2761e-03,  ..., -5.8824e-01,
            6.6815e-02, -4.9246e-05],
          [-7.2291e-01, -7.0754e-05, -7.3520e-01,  ..., -6.3249e-06,
            2.1674e-01,  1.4758e-01],
          [-1.9841e-02,  3.2707e-02,  5.9956e-01,  ...,  1.1214e-03,
           -2.2702e-02, -1.1928e-01]],
 
         [[ 2.8724e-04, -4.8192e-01,  4.1065e-04,  ...,  7.4646e-01,
            5.6220e-01, -5.0417e-01],
          [-5.3685e-01,  1.7344e-02, -2.5793e-01,  ..., -1.4721e-05,
            4.2095e-04,  2.6854e-04],
          [ 7.5432e-01,  5.7725e-04,  4.4640e-01,  ...,  1.2265e-01,
            3.9054e-03, -5.4287e-01],
          ...,
    

In [95]:
ht[1][0]

tensor([[ 8.1236e-03, -5.6972e-01, -1.9282e-01, -1.3982e-01,  5.3829e-01,
          7.3191e-02,  3.1443e-02,  3.3110e-03, -1.4730e-02,  4.3709e-01],
        [ 4.3937e-04, -3.1241e-04,  4.2398e-04,  5.1653e-02, -2.8496e-01,
          2.0555e-01,  9.4550e-05,  5.7878e-02, -6.2739e-01,  7.3442e-03],
        [ 5.4506e-01, -1.5308e-03,  9.1636e-06,  6.1318e-03,  1.3247e-01,
         -1.0212e-04, -6.2275e-04,  1.4428e-01,  9.9743e-03, -1.3020e-01],
        [-2.4175e-01,  4.3197e-01, -2.3800e-01,  1.7865e-02,  1.1318e-01,
         -1.6723e-01,  5.5680e-03, -2.2768e-01, -4.3800e-02,  1.9391e-03],
        [-5.3886e-02,  5.9305e-03,  5.6082e-02, -5.9841e-05, -4.6791e-01,
          6.9016e-01,  1.1277e-06, -3.1930e-01,  7.2904e-01,  8.6519e-03],
        [-1.0295e-01, -4.0401e-03,  1.4152e-03, -6.2615e-05, -2.0703e-02,
          2.2247e-01,  3.6647e-01, -6.5065e-04,  1.4157e-02, -1.2546e-01],
        [ 2.7497e-05, -7.3498e-01, -1.1131e-01, -1.5859e-02,  4.5274e-01,
          2.9226e-01,  9.7908e-0