# Long Short-Term Memory
In this exercise, we will implement an LSTM. In class, we have already seen the LSTM update rules at time step $t$, where $\sigma$ is the logistic sigmoid and $*$ denotes element-wise multiplication:

$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o) \\
\tilde{c}_t &= \tanh(W_c h_{t-1} + U_c x_t + b_c) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
h_t &= o_t * \tanh(c_t)
\end{align}
$$

```python
import torch
import torch.nn as nn
```

Implement this original version of the LSTM as an `LSTMCell`.

```python
class LSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Recurrent weights W_* (hidden x hidden) and input weights U_* (hidden x input)
        self.w_f = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.u_f = nn.Parameter(torch.empty(hidden_dim, input_dim))
        self.w_i = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.u_i = nn.Parameter(torch.empty(hidden_dim, input_dim))
        self.w_o = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.u_o = nn.Parameter(torch.empty(hidden_dim, input_dim))
        self.w_c = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.u_c = nn.Parameter(torch.empty(hidden_dim, input_dim))

        # Biases b_* for the forget, input, output and candidate gates
        self.b_f = nn.Parameter(torch.empty(hidden_dim))
        self.b_i = nn.Parameter(torch.empty(hidden_dim))
        self.b_o = nn.Parameter(torch.empty(hidden_dim))
        self.b_c = nn.Parameter(torch.empty(hidden_dim))

        # torch.empty leaves the memory uninitialized, so initialize explicitly
        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            nn.init.uniform_(weight, -1, 1)

    def forward(self, x, h_prev, c_prev):
        # x: (batch, input_dim) or (input_dim,); h_prev, c_prev: (batch, hidden_dim) or (hidden_dim,).
        # Right-multiplying by the transposed weights computes W h for every row of a batch.
        f_t = torch.sigmoid(h_prev @ self.w_f.T + x @ self.u_f.T + self.b_f)
        i_t = torch.sigmoid(h_prev @ self.w_i.T + x @ self.u_i.T + self.b_i)
        o_t = torch.sigmoid(h_prev @ self.w_o.T + x @ self.u_o.T + self.b_o)
        c_tilde = torch.tanh(h_prev @ self.w_c.T + x @ self.u_c.T + self.b_c)

        c_t = f_t * c_prev + i_t * c_tilde  # new cell state
        h_t = o_t * torch.tanh(c_t)         # new hidden state

        return h_t, c_t
```
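Beyond catching shape errors, one way to check the cell is to copy its weights into PyTorch's built-in `torch.nn.LSTMCell` and compare outputs. The sketch below restates the cell compactly so it runs on its own; the `GATES` constant, the `preactivation` helper and the `to_torch_lstm_cell` function are hypothetical names introduced for illustration. It assumes the batch-friendly convention of applying weights as `h @ W.T`, and relies on PyTorch packing the gates in the order input, forget, cell, output (`i, f, g, o`) inside `weight_ih`/`weight_hh`. The built-in cell has two bias vectors, so our biases go into `bias_ih` and `bias_hh` is zeroed.

```python
import torch
import torch.nn as nn

GATES = ("f", "i", "o", "c")  # forget, input, output, candidate


class LSTMCell(nn.Module):
    """Compact restatement of the LSTMCell above so this snippet runs on its own."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        for g in GATES:
            # nn.Module.__setattr__ registers each nn.Parameter under its name
            setattr(self, f"w_{g}", nn.Parameter(torch.empty(hidden_dim, hidden_dim)))
            setattr(self, f"u_{g}", nn.Parameter(torch.empty(hidden_dim, input_dim)))
            setattr(self, f"b_{g}", nn.Parameter(torch.empty(hidden_dim)))
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.uniform_(p, -1, 1)

    def preactivation(self, g, x, h_prev):
        # W_g h_{t-1} + U_g x_t + b_g, applied row-wise to a batch
        return (h_prev @ getattr(self, f"w_{g}").T
                + x @ getattr(self, f"u_{g}").T + getattr(self, f"b_{g}"))

    def forward(self, x, h_prev, c_prev):
        f_t = torch.sigmoid(self.preactivation("f", x, h_prev))
        i_t = torch.sigmoid(self.preactivation("i", x, h_prev))
        o_t = torch.sigmoid(self.preactivation("o", x, h_prev))
        c_tilde = torch.tanh(self.preactivation("c", x, h_prev))
        c_t = f_t * c_prev + i_t * c_tilde
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


def to_torch_lstm_cell(cell, input_dim):
    """Copy our parameters into an nn.LSTMCell (PyTorch gate order: i, f, g, o)."""
    ref = nn.LSTMCell(input_dim, cell.hidden_dim)
    order = ("i", "f", "c", "o")  # our candidate gate "c" is PyTorch's "g"
    with torch.no_grad():
        ref.weight_ih.copy_(torch.cat([getattr(cell, f"u_{g}") for g in order]))
        ref.weight_hh.copy_(torch.cat([getattr(cell, f"w_{g}") for g in order]))
        ref.bias_ih.copy_(torch.cat([getattr(cell, f"b_{g}") for g in order]))
        ref.bias_hh.zero_()
    return ref


torch.manual_seed(0)
cell = LSTMCell(input_dim=4, hidden_dim=6)
ref = to_torch_lstm_cell(cell, input_dim=4)
x, h0, c0 = torch.randn(3, 4), torch.randn(3, 6), torch.randn(3, 6)

with torch.no_grad():
    h_ours, c_ours = cell(x, h0, c0)
    h_ref, c_ref = ref(x, (h0, c0))

print(torch.allclose(h_ours, h_ref, atol=1e-5), torch.allclose(c_ours, c_ref, atol=1e-5))
```

If both values are `True`, the hand-written update rules agree with PyTorch's reference implementation up to floating-point rounding.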

Stack two instances of your `LSTMCell` into a 2-layer LSTM and run a forward pass on a random input sequence to check that all dimensions are correct.
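A minimal sketch of such a stacked LSTM follows, assuming inputs of shape `(seq_len, batch, input_dim)` (the same default layout as `torch.nn.LSTM`) and zero initial states. The class name `LSTM`, the `num_layers` argument and the compact restated cell with its `preactivation` helper are illustrative choices, not part of the exercise. At every time step, the hidden state of layer $l$ becomes the input of layer $l+1$, so only the first cell sees `input_dim`-sized inputs.

```python
import torch
import torch.nn as nn

GATES = ("f", "i", "o", "c")  # forget, input, output, candidate


class LSTMCell(nn.Module):
    """Compact restatement of the LSTMCell above so this snippet runs on its own."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        for g in GATES:
            setattr(self, f"w_{g}", nn.Parameter(torch.empty(hidden_dim, hidden_dim)))
            setattr(self, f"u_{g}", nn.Parameter(torch.empty(hidden_dim, input_dim)))
            setattr(self, f"b_{g}", nn.Parameter(torch.empty(hidden_dim)))
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.uniform_(p, -1, 1)

    def preactivation(self, g, x, h_prev):
        # W_g h_{t-1} + U_g x_t + b_g, applied row-wise to a batch
        return (h_prev @ getattr(self, f"w_{g}").T
                + x @ getattr(self, f"u_{g}").T + getattr(self, f"b_{g}"))

    def forward(self, x, h_prev, c_prev):
        f_t = torch.sigmoid(self.preactivation("f", x, h_prev))
        i_t = torch.sigmoid(self.preactivation("i", x, h_prev))
        o_t = torch.sigmoid(self.preactivation("o", x, h_prev))
        c_tilde = torch.tanh(self.preactivation("c", x, h_prev))
        c_t = f_t * c_prev + i_t * c_tilde
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


class LSTM(nn.Module):
    """Stack of LSTMCells unrolled over time (hypothetical helper class)."""

    def __init__(self, input_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        # First layer reads the input; deeper layers read the layer below
        dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.cells = nn.ModuleList([LSTMCell(d, hidden_dim) for d in dims])

    def forward(self, x):
        seq_len, batch, _ = x.shape
        h = [x.new_zeros(batch, self.hidden_dim) for _ in self.cells]
        c = [x.new_zeros(batch, self.hidden_dim) for _ in self.cells]
        outputs = []
        for t in range(seq_len):
            inp = x[t]
            for l, cell in enumerate(self.cells):
                h[l], c[l] = cell(inp, h[l], c[l])
                inp = h[l]  # hidden state of layer l feeds layer l + 1
            outputs.append(inp)  # top-layer hidden state at time t
        return torch.stack(outputs), (torch.stack(h), torch.stack(c))


torch.manual_seed(0)
model = LSTM(input_dim=5, hidden_dim=8, num_layers=2)
x = torch.randn(10, 3, 5)  # (seq_len, batch, input_dim)
out, (h_n, c_n) = model(x)
print(out.shape, h_n.shape, c_n.shape)
# torch.Size([10, 3, 8]) torch.Size([2, 3, 8]) torch.Size([2, 3, 8])
```

The per-layer final states are stacked along a leading `num_layers` dimension, mirroring what `torch.nn.LSTM` returns.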

Implement a subclass of your `LSTMCell` that couples the forget and input gates, i.e., the input gate $i_t$ is replaced by $1 - f_t$ and the cell state update becomes:

$$c_t = f_t * c_{t-1} + (1-f_t) * \tilde{c}_t$$
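Here is one possible sketch of the coupled variant, again restating a compact cell so the snippet runs on its own. The name `CoupledLSTMCell` and the `preactivation` helper are assumptions for illustration. Because $i_t$ is no longer computed, the sketch also deletes the input-gate parameters `w_i`, `u_i`, `b_i`, which reduces the parameter count by a quarter.

```python
import torch
import torch.nn as nn

GATES = ("f", "i", "o", "c")  # forget, input, output, candidate


class LSTMCell(nn.Module):
    """Compact restatement of the LSTMCell above so this snippet runs on its own."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        for g in GATES:
            setattr(self, f"w_{g}", nn.Parameter(torch.empty(hidden_dim, hidden_dim)))
            setattr(self, f"u_{g}", nn.Parameter(torch.empty(hidden_dim, input_dim)))
            setattr(self, f"b_{g}", nn.Parameter(torch.empty(hidden_dim)))
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.uniform_(p, -1, 1)

    def preactivation(self, g, x, h_prev):
        # W_g h_{t-1} + U_g x_t + b_g, applied row-wise to a batch
        return (h_prev @ getattr(self, f"w_{g}").T
                + x @ getattr(self, f"u_{g}").T + getattr(self, f"b_{g}"))

    def forward(self, x, h_prev, c_prev):
        f_t = torch.sigmoid(self.preactivation("f", x, h_prev))
        i_t = torch.sigmoid(self.preactivation("i", x, h_prev))
        o_t = torch.sigmoid(self.preactivation("o", x, h_prev))
        c_tilde = torch.tanh(self.preactivation("c", x, h_prev))
        c_t = f_t * c_prev + i_t * c_tilde
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


class CoupledLSTMCell(LSTMCell):
    """LSTM cell whose input gate is tied to the forget gate: i_t = 1 - f_t."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__(input_dim, hidden_dim)
        # The separate input gate is unused, so drop its parameters
        for name in ("w_i", "u_i", "b_i"):
            delattr(self, name)

    def forward(self, x, h_prev, c_prev):
        f_t = torch.sigmoid(self.preactivation("f", x, h_prev))
        o_t = torch.sigmoid(self.preactivation("o", x, h_prev))
        c_tilde = torch.tanh(self.preactivation("c", x, h_prev))
        # Coupled update: whatever is forgotten is replaced by new content
        c_t = f_t * c_prev + (1 - f_t) * c_tilde
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


torch.manual_seed(0)
cell = CoupledLSTMCell(input_dim=4, hidden_dim=6)
x, h0, c0 = torch.randn(3, 4), torch.randn(3, 6), torch.randn(3, 6)
h1, c1 = cell(x, h0, c0)
print(h1.shape, sum(p.numel() for p in cell.parameters()))
# torch.Size([3, 6]) 198
```

With `hidden_dim=6` and `input_dim=4`, each gate owns $36 + 24 + 6 = 66$ parameters, so three gates give 198 instead of the 264 of the uncoupled cell.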

**Bonus:** Implement *peephole connections* as described at the start of the section *Variants on Long Short Term Memory* in [this blog post explaining LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

Each gate's pre-activation gets an additional term that looks at the cell state:
$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f \boldsymbol{+ V_f c_{t-1}}) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i \boldsymbol{+ V_i c_{t-1}}) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_t})
\end{align}
$$

To make the task a bit easier, we will instead implement the last equation with the cell state of the previous time step $t-1$, so that all three gates can be computed before $c_t$ is updated:

$$o_t = \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_{t-1}})$$
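A minimal sketch of the simplified peephole cell follows. It assumes full $V_*$ matrices of shape `(hidden_dim, hidden_dim)`, exactly as written in the equations above; a diagonal (element-wise) peephole would use vectors instead. The names `PeepholeLSTMCell`, `preactivation` and `peephole` are hypothetical, and the compact base cell is restated so the snippet runs on its own. As a sanity check, setting every $V_*$ to zero should reproduce the plain LSTM cell with the same remaining weights.

```python
import torch
import torch.nn as nn

GATES = ("f", "i", "o", "c")  # forget, input, output, candidate


class LSTMCell(nn.Module):
    """Compact restatement of the LSTMCell above so this snippet runs on its own."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        for g in GATES:
            setattr(self, f"w_{g}", nn.Parameter(torch.empty(hidden_dim, hidden_dim)))
            setattr(self, f"u_{g}", nn.Parameter(torch.empty(hidden_dim, input_dim)))
            setattr(self, f"b_{g}", nn.Parameter(torch.empty(hidden_dim)))
        self.reset_parameters()

    def reset_parameters(self):
        for p in self.parameters():
            nn.init.uniform_(p, -1, 1)

    def preactivation(self, g, x, h_prev):
        # W_g h_{t-1} + U_g x_t + b_g, applied row-wise to a batch
        return (h_prev @ getattr(self, f"w_{g}").T
                + x @ getattr(self, f"u_{g}").T + getattr(self, f"b_{g}"))

    def forward(self, x, h_prev, c_prev):
        f_t = torch.sigmoid(self.preactivation("f", x, h_prev))
        i_t = torch.sigmoid(self.preactivation("i", x, h_prev))
        o_t = torch.sigmoid(self.preactivation("o", x, h_prev))
        c_tilde = torch.tanh(self.preactivation("c", x, h_prev))
        c_t = f_t * c_prev + i_t * c_tilde
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


class PeepholeLSTMCell(LSTMCell):
    """LSTM cell with peephole terms V_g c_{t-1} on the forget, input and output gates."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__(input_dim, hidden_dim)
        for g in ("f", "i", "o"):
            setattr(self, f"v_{g}", nn.Parameter(torch.empty(hidden_dim, hidden_dim)))
        self.reset_parameters()  # re-initialize so the new V_* are filled too

    def peephole(self, g, x, h_prev, c_prev):
        # Usual pre-activation plus the (simplified) peephole term V_g c_{t-1}
        return self.preactivation(g, x, h_prev) + c_prev @ getattr(self, f"v_{g}").T

    def forward(self, x, h_prev, c_prev):
        f_t = torch.sigmoid(self.peephole("f", x, h_prev, c_prev))
        i_t = torch.sigmoid(self.peephole("i", x, h_prev, c_prev))
        o_t = torch.sigmoid(self.peephole("o", x, h_prev, c_prev))  # c_{t-1}, not c_t
        c_tilde = torch.tanh(self.preactivation("c", x, h_prev))
        c_t = f_t * c_prev + i_t * c_tilde
        h_t = o_t * torch.tanh(c_t)
        return h_t, c_t


torch.manual_seed(0)
base = LSTMCell(input_dim=4, hidden_dim=6)
peep = PeepholeLSTMCell(input_dim=4, hidden_dim=6)
# Share the non-peephole weights (v_* keys are missing, hence strict=False) ...
peep.load_state_dict(base.state_dict(), strict=False)
with torch.no_grad():  # ... and switch the peepholes off
    for g in ("f", "i", "o"):
        getattr(peep, f"v_{g}").zero_()

x, h0, c0 = torch.randn(3, 4), torch.randn(3, 6), torch.randn(3, 6)
h_base, c_base = base(x, h0, c0)
h_peep, c_peep = peep(x, h0, c0)
print(torch.allclose(h_base, h_peep), sum(p.numel() for p in peep.parameters()))
```

The peephole cell adds three `6 x 6` matrices, i.e. 108 parameters, on top of the 264 of the base cell.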