# Long Short-Term Memory
In this exercise, we will implement an LSTM. In class, we have already seen the definition of the LSTM update rules at time step $t$:

$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o) \\
\tilde{c}_t &= \tanh(W_c h_{t-1} + U_c x_t + b_c) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
h_t &= o_t * \tanh(c_t)
\end{align}
$$

In [None]:
import torch
import torch.nn as nn

Implement this original version of the LSTM as an `LSTMCell`.

In [None]:
class LSTMCell(nn.Module):
    """Single LSTM cell implementing the classic update rules.

    For every gate ``g`` in (forget ``f``, input ``i``, output ``o``) and for
    the candidate memory ``c``, the cell owns three parameters:

    * ``W_g``: bias-free linear map applied to the previous hidden state
      (``hidden_dim -> hidden_dim``),
    * ``U_g``: bias-free linear map applied to the current input
      (``input_dim -> hidden_dim``),
    * ``b_g``: bias vector of size ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the gate parameters.

        Args:
            input_dim: size of the input vector ``x``.
            hidden_dim: size of the hidden state ``h`` and cell state ``c``.
        """
        super().__init__()

        # Every W_* acts on h, so it must map hidden_dim -> hidden_dim;
        # otherwise W_*(h) cannot be added to U_*(x) and b_*.
        self.W_f = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_f = nn.Linear(input_dim, hidden_dim, bias=False)
        self.b_f = nn.Parameter(torch.randn(hidden_dim))
        self.W_i = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_i = nn.Linear(input_dim, hidden_dim, bias=False)
        self.b_i = nn.Parameter(torch.randn(hidden_dim))
        self.W_o = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_o = nn.Linear(input_dim, hidden_dim, bias=False)
        self.b_o = nn.Parameter(torch.randn(hidden_dim))
        self.W_c = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.U_c = nn.Linear(input_dim, hidden_dim, bias=False)
        self.b_c = nn.Parameter(torch.randn(hidden_dim))

    def reset_parameters(self):
        """Re-draw every parameter of the cell from a standard normal."""
        for weight in self.parameters():
            nn.init.normal_(weight, mean=0, std=1.0)

    def forward(self, x, c, h):
        """Advance the cell by one time step.

        Args:
            x: current input, last dimension of size ``input_dim``.
            c: previous cell state, last dimension of size ``hidden_dim``.
            h: previous hidden state, last dimension of size ``hidden_dim``.

        Returns:
            Tuple ``(new_cell_state, new_hidden_state)``.
        """
        forget_gate = torch.sigmoid(self.W_f(h) + self.U_f(x) + self.b_f)
        input_gate = torch.sigmoid(self.W_i(h) + self.U_i(x) + self.b_i)
        output_gate = torch.sigmoid(self.W_o(h) + self.U_o(x) + self.b_o)
        new_cell_memory = torch.tanh(self.W_c(h) + self.U_c(x) + self.b_c)
        # c_t = f_t * c_{t-1} + i_t * c~_t: the forget gate scales the old
        # cell state, the input gate scales the new candidate memory.
        new_cell_state = forget_gate * c + input_gate * new_cell_memory
        new_hidden_state = output_gate * torch.tanh(new_cell_state)
        return new_cell_state, new_hidden_state

Create a 2-layer LSTM from your LSTMCell base class and run a forward pass with a random input sequence to test that all your dimensions are correct.

In [None]:
class LSTM(nn.Module):
    """Multi-layer LSTM built by stacking ``LSTMCell`` modules.

    The first layer consumes the external input; every further layer
    consumes the hidden state produced by the layer below it.
    """

    def __init__(self, input_dim, hidden_dim, n_layers):
        """Build ``n_layers`` stacked cells.

        Args:
            input_dim: size of each input vector fed to the first layer.
            hidden_dim: size of the hidden and cell state of every layer.
            n_layers: number of stacked cells.
        """
        super().__init__()
        self.layers = nn.ModuleList([])
        for i in range(n_layers):
            # Only the first layer sees the raw input.
            in_dim = input_dim if i == 0 else hidden_dim
            self.layers.append(LSTMCell(in_dim, hidden_dim))

    def reset_parameters(self):
        """Re-initialise the parameters of every layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, x, c, h):
        """Run the stacked cells over a whole input sequence.

        x.shape: [sequence_length, input_dim]
        c.shape: [n_layers, hidden_dim]
        h.shape: [n_layers, hidden_dim]

        Returns:
            ``(outputs, (c_n, h_n))`` where ``outputs`` is a list holding the
            top layer's hidden state for every time step, and ``c_n`` / ``h_n``
            are new tensors with the final per-layer states. The ``c`` and
            ``h`` arguments are left untouched.
        """
        # Track the per-layer states in Python lists instead of writing into
        # ``c`` / ``h``. In-place row assignment would overwrite the caller's
        # initial states and make every stored output a view of the same
        # storage, so all outputs would silently equal the last hidden state.
        cell_states = list(torch.unbind(c, dim=0))
        hidden_states = list(torch.unbind(h, dim=0))

        outputs = []
        # iteration over time steps
        for x_i in x:
            # x_i.shape: [input_dim]
            layer_input = x_i
            # iteration over layers
            for layer_i, layer in enumerate(self.layers):
                # LSTMCell forward pass
                cell_states[layer_i], hidden_states[layer_i] = layer(
                    layer_input, cell_states[layer_i], hidden_states[layer_i]
                )
                layer_input = hidden_states[layer_i]
            outputs.append(layer_input)
        return outputs, (torch.stack(cell_states), torch.stack(hidden_states))


# Forward-pass check requested by the exercise. It used to sit after the
# ``return`` inside ``forward`` and was therefore never executed. The guard
# keeps it from running when this code is imported as a module.
if __name__ == "__main__":
    sequence_length = 10
    input_dim = 300
    hidden_dim = 20
    n_layers = 2
    inputs = torch.randn(sequence_length, input_dim)
    c0 = torch.zeros(n_layers, hidden_dim)
    h0 = torch.zeros(n_layers, hidden_dim)
    lstm = LSTM(input_dim, hidden_dim, n_layers)
    lstm.reset_parameters()
    outputs, (cn, hn) = lstm(inputs, c0, h0)
    print(len(outputs), outputs[0].shape)
    print(cn.shape)
    print(hn.shape)

Implement a subclass of your LSTM that uses a coupled forget and input gate, i.e. the cell state update becomes:

$$c_t = f_t * c_{t-1} + (1-f_t) * \tilde{c}_t$$

In [None]:
class CoupledLSTMCell(LSTMCell):
    """LSTM cell whose input gate is tied to the forget gate.

    The cell state update is ``c_t = f_t * c_{t-1} + (1 - f_t) * c~_t``, so
    no separate input-gate parameters are used.
    """

    def __init__(self, input_dim, hidden_dim):
        """Build the parent cell, then null out the input-gate parameters."""
        super().__init__(input_dim, hidden_dim)

        # The coupled variant derives the input gate from the forget gate,
        # so the parent's input-gate parameters are replaced by None.
        for unused_name in ("W_i", "U_i", "b_i"):
            setattr(self, unused_name, None)

    def forward(self, x, c, h):
        """Advance the coupled cell by one time step.

        Args:
            x: current input.
            c: previous cell state.
            h: previous hidden state.

        Returns:
            Tuple ``(new_cell_state, new_hidden_state)``.
        """
        keep = torch.sigmoid(self.W_f(h) + self.U_f(x) + self.b_f)
        expose = torch.sigmoid(self.W_o(h) + self.U_o(x) + self.b_o)
        candidate = torch.tanh(self.W_c(h) + self.U_c(x) + self.b_c)

        # Whatever fraction of the old state is forgotten is filled with
        # the candidate memory.
        next_cell = keep * c + (1 - keep) * candidate
        next_hidden = expose * torch.tanh(next_cell)
        return next_cell, next_hidden
    
class CoupledLSTM(LSTM):
    """Multi-layer LSTM whose layers are ``CoupledLSTMCell`` instances."""

    def __init__(self, input_dim, hidden_dim, n_layers):
        """Build the parent model, then swap in coupled cells.

        Args:
            input_dim: size of each input vector fed to the first layer.
            hidden_dim: size of the hidden and cell state of every layer.
            n_layers: number of stacked cells.
        """
        super().__init__(input_dim, hidden_dim, n_layers)
        # Replace the layers created by the parent constructor; only the
        # first layer takes the raw input dimension.
        coupled_cells = [
            CoupledLSTMCell(input_dim if idx == 0 else hidden_dim, hidden_dim)
            for idx in range(n_layers)
        ]
        self.layers = nn.ModuleList(coupled_cells)
            

# CoupledLSTM has three required constructor arguments; calling it with none
# raises a TypeError. Instantiate it with explicit dimensions instead.
coupled_lstm = CoupledLSTM(input_dim=300, hidden_dim=20, n_layers=2)
        

**Bonus:** Implement *peephole connections* as described at the start of the Section *Variants on Long Short Term Memory* in [this blog post explaining LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

The gate update definitions get an additional term that looks at the cell state:
$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f \boldsymbol{+ V_f c_{t-1}}) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i \boldsymbol{+ V_i c_{t-1}}) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_t})
\end{align}
$$

To make the task a bit easier, we will implement the last equation with the cell state of the previous time step $t-1$ as $$o_t = \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_{t-1}})$$ instead.