### What is the difference between a vanilla RNN and an LSTM?
LSTMs are a kind of RNN. How do they differ from a vanilla RNN?

- In a vanilla RNN, a single hidden state $h_t$ is computed from the previous hidden state and the current input $x_t$, then passed forward to the next time step
$$
h_t = \tanh\left( W \cdot \begin{bmatrix} h_{t-1} \\ x_t \end{bmatrix}\right)
$$

- In an LSTM, we compute a set of four gates $i$, $f$, $o$, $g$, which together control how the cell state, and through it the next hidden state, is updated. Each activation is applied elementwise to its own block of the product
$$
\begin{bmatrix} i \\ f \\ o \\ g\end{bmatrix} = \begin{bmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh\end{bmatrix} \left( W \cdot \begin{bmatrix} h_{t-1} \\ x_t\end{bmatrix} \right)
$$
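
To make the contrast concrete, the two updates above can be sketched as a single time step each. This is a minimal sketch rather than a reference implementation: the sizes, the random weights `W_rnn` and `W_lstm`, and all variable names are illustrative assumptions. Biases are omitted, as in the equations, and the code uses the row-vector convention (`hx @ W`) instead of the column-vector form written above.

```python
import torch

torch.manual_seed(0)
batch, input_sz, hidden_sz = 2, 3, 4    # illustrative sizes (assumed)

x_t = torch.randn(batch, input_sz)      # input at one time step
h_prev = torch.randn(batch, hidden_sz)  # previous hidden state
hx = torch.cat([h_prev, x_t], dim=1)    # stacked [h_{t-1}; x_t], shape (batch, hidden + input)

# Vanilla RNN: one weight matrix and one tanh give the new hidden state
W_rnn = torch.randn(hidden_sz + input_sz, hidden_sz)
h_rnn = torch.tanh(hx @ W_rnn)

# LSTM: a 4x wider weight matrix, split into four gate pre-activations
W_lstm = torch.randn(hidden_sz + input_sz, 4 * hidden_sz)
i, f, o, g = (hx @ W_lstm).chunk(4, dim=1)  # same order as the equation: i, f, o, g
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
g = torch.tanh(g)

print(h_rnn.shape, i.shape, f.shape, o.shape, g.shape)
```

Every gate has the same shape as the hidden state, which is what allows the elementwise updates in the next section.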

### LSTM Formulation
Functionally, we use these "gates" to compute the following updates
$$c_t = f \odot c_{t-1} + i \odot g$$
$$h_t = o \odot \tanh(c_t)$$
Here $c_t$ is the "cell state", $h_t$ is the hidden state, and $\odot$ denotes elementwise multiplication.

### Explaining the Gates

| symbol | full name | range of values | role |
|--------|-----------|-----------------|------|
| i | input gate | 0 to 1 | Which elements of the cell state we want to increment |
| f | forget gate | 0 to 1 | What to erase from the previous cell state |
| o | output gate | 0 to 1 | What to pass forward from the cell state to the hidden state |
| g | gate (candidate values) | -1 to 1 | By how much we want to increment each element of the cell state |

Remember that the gates act elementwise: each element of a gate controls the increment of a single element of the cell state

$$c_t[0][0] = f[0][0] \cdot c_{t-1}[0][0] + i[0][0] \cdot g[0][0]$$
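
The elementwise behaviour can be checked on a tiny hand-made example. The numbers below are made up purely for illustration: the first element is kept unchanged, the second is erased and overwritten, and the third is halved and then decremented.

```python
import torch

c_prev = torch.tensor([[2.0, -1.0, 0.5]])  # previous cell state (made-up values)
f = torch.tensor([[1.0, 0.0, 0.5]])        # forget gate: keep, erase, halve
i = torch.tensor([[0.0, 1.0, 1.0]])        # input gate: block, allow, allow
g = torch.tensor([[0.9, 0.3, -0.5]])       # candidate increments, each in [-1, 1]

c_t = f * c_prev + i * g                   # elementwise update of the cell state
print(c_t)                                 # values: 2.0, 0.3, -0.25

# The same update written out for a single element
single = f[0][0] * c_prev[0][0] + i[0][0] * g[0][0]
print(single)                              # value: 2.0
```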

In [1]:
# Torch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class LSTM(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
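        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4))
        # torch.empty leaves memory uninitialised, so fill every parameter
        # uniformly in [-1/sqrt(hidden_sz), 1/sqrt(hidden_sz)], as nn.LSTM does
        stdv = 1.0 / hidden_sz ** 0.5
        for p in self.parameters():
            nn.init.uniform_(p, -stdv, stdv)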
         
    def forward(self, x, h_t, c_t):
        # x: (batch, seq_len, input_sz); h_t, c_t: (batch, hidden_sz)
        bs, seq_sz, _ = x.size()
        hidden_seq = []
         
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]), # gate
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            
            hidden_seq.append(h_t)
        # Stack the per-step hidden states into (batch, seq_len, hidden_sz)
        hidden_seq = torch.stack(hidden_seq, dim=1)
        return hidden_seq, (h_t, c_t)
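
As a sanity check on the update equations used in the forward pass above, one time step can be recomputed by hand from the weights of PyTorch's built-in `nn.LSTMCell` and compared with the module's own output. This is a sketch under stated assumptions: the sizes and random inputs are illustrative, and the `_manual` / `_ref` variable names are my own. `nn.LSTMCell` keeps two bias vectors and stores its weights transposed relative to the class above, but it uses the same gate order (input, forget, g, output).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_sz, hidden_sz, batch = 3, 4, 2  # illustrative sizes (assumed)
cell = nn.LSTMCell(input_sz, hidden_sz)

x_t = torch.randn(batch, input_sz)
h_prev = torch.randn(batch, hidden_sz)
c_prev = torch.randn(batch, hidden_sz)

with torch.no_grad():
    # Weights are stored as (4 * hidden_sz, in_features), hence the transposes
    gates = (x_t @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)  # gate order: input, forget, g, output
    c_manual = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
    h_manual = torch.sigmoid(o) * torch.tanh(c_manual)

    # Reference result from the built-in cell
    h_ref, c_ref = cell(x_t, (h_prev, c_prev))

h_matches = torch.allclose(h_manual, h_ref, atol=1e-5)
c_matches = torch.allclose(c_manual, c_ref, atol=1e-5)
print(h_matches, c_matches)
```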

### Usage
PyTorch provides an `nn.LSTM` module. Its usage is as follows:

In [None]:
# num_features_inp (embedding size) (100)
# hidden_features (256)
# num_layers (2)
lstm = nn.LSTM(100, 256, 2)

# shape = seq_len, batch_size, num_features (batch_first=False is the default)
x = torch.ones((100, 1, 100))

output, (hidden, cell) = lstm(x)
output.shape, hidden.shape, cell.shape

# output[t] = the hidden state of the last layer at time step t
# hidden[i] = the last hidden state of the ith layer
# cell[i] = the last cell state of the ith layer
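
The shape conventions described in the comments above can be verified directly. The sketch below reuses the same sizes as the cell above. The keyword arguments and the `same` variable are my additions for readability, and gradients are disabled because only shapes and values are inspected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=100, hidden_size=256, num_layers=2)
x = torch.ones(100, 1, 100)  # (seq_len, batch_size, num_features)

with torch.no_grad():
    output, (hidden, cell) = lstm(x)

print(output.shape)  # (seq_len, batch_size, hidden_size) = (100, 1, 256)
print(hidden.shape)  # (num_layers, batch_size, hidden_size) = (2, 1, 256)
print(cell.shape)    # (num_layers, batch_size, hidden_size) = (2, 1, 256)

# The last time step of `output` is the final hidden state of the last layer
same = torch.allclose(output[-1], hidden[-1])
print(same)
```

If the input is laid out as `(batch_size, seq_len, num_features)` instead, `nn.LSTM` accepts `batch_first=True`, which changes the layout of `x` and `output` but not of `hidden` and `cell`.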

### References
[Stanford RNN Lecture](https://www.youtube.com/watch?v=6niqTuYFZLQ&t=3445s)