# Long Short-Term Memory
In this exercise, we will implement an LSTM. In class, we have already seen the LSTM update rules at time step $t$, where $\sigma$ denotes the logistic sigmoid and $*$ denotes element-wise multiplication:

$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o) \\
\tilde{c}_t &= \tanh(W_c h_{t-1} + U_c x_t + b_c) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
h_t &= o_t * \tanh(c_t)
\end{align}
$$

In [1]:
import torch
import torch.nn as nn

Implement this original version of the LSTM as an `LSTMCell`.

In [None]:
#DO NOT EXECUTE THIS CELL
#Other implementation to use linear instead of parameter

class LSTMCell(nn.Module):
    """LSTM cell that fuses all gate projections into two ``nn.Linear`` layers.

    ``w`` maps h_{t-1} and ``u`` maps x_t to ``num_chunks * hidden_dim``
    pre-activations, and ``b`` is a bias of the same size. The first four
    ``hidden_dim``-sized chunks of the summed pre-activation are the forget,
    input, output and candidate-cell pre-activations, in that order.

    Args:
        input_dim: size of the input vector x_t.
        hidden_dim: size of the hidden state h and the cell state c.
        num_chunks: number of ``hidden_dim``-sized chunks the fused
            projections produce; ``forward`` uses the first four.
    """

    def __init__(self, input_dim, hidden_dim, num_chunks=4):
        super().__init__()
        # forward() needs these to split the fused pre-activation.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_chunks = num_chunks

        # All gates are computed together: W_{f,i,o,c} stacked row-wise in
        # ``w``, U_{f,i,o,c} stacked row-wise in ``u``.
        self.w = nn.Linear(hidden_dim, num_chunks * hidden_dim, bias=False)
        self.u = nn.Linear(input_dim, num_chunks * hidden_dim, bias=False)
        self.b = nn.Parameter(torch.zeros(num_chunks * hidden_dim))

    def reset_parameters(self):
        """Re-draw every parameter (including the bias) from N(0, 1)."""
        for param in self.parameters():
            nn.init.normal_(param, mean=0, std=1)

    def forward(self, x, hidden_state, cell_state):
        """Advance the cell by one time step.

        Args:
            x: input of shape (input_dim,) or (batch, input_dim).
            hidden_state: previous hidden state, (hidden_dim,) or
                (batch, hidden_dim).
            cell_state: previous cell state, same shape as hidden_state.

        Returns:
            (new_hidden_state, new_cell_state), each shaped like hidden_state.
        """
        updates = self.w(hidden_state) + self.u(x) + self.b
        # Split only the last dimension so batched inputs work too; for a
        # 1-D input this matches reshaping to (num_chunks, hidden_dim).
        chunks = updates.chunk(self.num_chunks, dim=-1)

        forget_gate = torch.sigmoid(chunks[0])
        input_gate = torch.sigmoid(chunks[1])
        output_gate = torch.sigmoid(chunks[2])
        cell_gate = torch.tanh(chunks[3])

        new_cell_state = forget_gate * cell_state + input_gate * cell_gate
        new_hidden_state = output_gate * torch.tanh(new_cell_state)
        return new_hidden_state, new_cell_state

In [11]:
class LSTMCell(nn.Module):
    """Single LSTM cell with separate weight matrices for every gate.

    Implements, for one time step,

        f_t  = sigmoid(W_f h_{t-1} + U_f x_t + b_f)
        i_t  = sigmoid(W_i h_{t-1} + U_i x_t + b_i)
        o_t  = sigmoid(W_o h_{t-1} + U_o x_t + b_o)
        c~_t = tanh(W_c h_{t-1} + U_c x_t + b_c)
        c_t  = f_t * c_{t-1} + i_t * c~_t
        h_t  = o_t * tanh(c_t)

    All parameters are drawn from N(0, 1) at construction.

    Args:
        input_dim: size of the input vector x_t.
        hidden_dim: size of the hidden state h and the cell state c.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Parameter creation order is unchanged so a fixed seed reproduces
        # the same initial weights.
        self.W_f = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.U_f = nn.Parameter(torch.randn(hidden_dim, input_dim))
        self.W_i = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.U_i = nn.Parameter(torch.randn(hidden_dim, input_dim))
        self.W_o = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.U_o = nn.Parameter(torch.randn(hidden_dim, input_dim))
        self.W_c = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.U_c = nn.Parameter(torch.randn(hidden_dim, input_dim))

        self.b_f = nn.Parameter(torch.randn(hidden_dim))
        self.b_i = nn.Parameter(torch.randn(hidden_dim))
        self.b_o = nn.Parameter(torch.randn(hidden_dim))
        self.b_c = nn.Parameter(torch.randn(hidden_dim))

    def reset_parameters(self):
        """Re-draw every parameter from N(0, 1)."""
        for param in self.parameters():
            # ``normal_`` is the non-deprecated, in-place initialiser.
            nn.init.normal_(param, mean=0, std=1)

    @staticmethod
    def _preactivation(W, U, b, x, hidden_state):
        """Return W h + U x + b (row-wise when h and x are batched)."""
        # h @ W.T equals W @ h for a 1-D h, and also handles (batch, dim).
        return hidden_state @ W.T + x @ U.T + b

    def forward(self, x, hidden_state, cell_state):
        """Advance the cell by one time step.

        Args:
            x: input of shape (input_dim,) or (batch, input_dim).
            hidden_state: previous hidden state, (hidden_dim,) or
                (batch, hidden_dim).
            cell_state: previous cell state, same shape as hidden_state.

        Returns:
            (h, c): new hidden and cell state, each shaped like hidden_state.
        """
        f = torch.sigmoid(self._preactivation(self.W_f, self.U_f, self.b_f, x, hidden_state))
        i = torch.sigmoid(self._preactivation(self.W_i, self.U_i, self.b_i, x, hidden_state))
        o = torch.sigmoid(self._preactivation(self.W_o, self.U_o, self.b_o, x, hidden_state))
        c_hat = torch.tanh(self._preactivation(self.W_c, self.U_c, self.b_c, x, hidden_state))
        c = f * cell_state + i * c_hat
        h = o * torch.tanh(c)
        return h, c

Create a 2-layer LSTM from your LSTMCell base class and run a forward pass with a random input sequence to test that all your dimensions are correct.

In [16]:
class LSTM(nn.Module):
    """Stack of ``LSTMCell`` layers unrolled over a sequence.

    Layer 0 receives the input x_t; every later layer receives the hidden
    state produced by the layer below it at the same time step.

    Args:
        input_dim: size of each input vector x_t.
        hidden_dim: size of every layer's hidden and cell state.
        num_layers: number of stacked cells.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            self.layers.append(LSTMCell(in_dim, hidden_dim))

    def reset_parameters(self):
        """Re-initialise the parameters of every layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, x, hidden_state=None, cell_state=None):
        """Run the stacked LSTM over a sequence.

        Args:
            x: input sequence, iterated over its first dimension (time),
                e.g. shape (seq_len, input_dim).
            hidden_state: initial hidden states indexed by layer, e.g. shape
                (num_layers, hidden_dim). Defaults to zeros. Not modified.
            cell_state: initial cell states, same layout as hidden_state.
                Defaults to zeros. Not modified.

        Returns:
            (outputs, (h_n, c_n)): ``outputs`` is a list holding the last
            layer's hidden state at every time step; ``h_n`` and ``c_n`` stack
            the final per-layer states along a new first dimension.
        """
        if hidden_state is None:
            hidden_state = x.new_zeros(len(self.layers), *x.shape[1:-1], self.hidden_dim)
        if cell_state is None:
            cell_state = x.new_zeros(len(self.layers), *x.shape[1:-1], self.hidden_dim)

        # Keep the states in per-layer Python lists instead of writing rows
        # of the caller's tensors in place: in-place row writes would make
        # every appended output a view of the same storage (all outputs equal
        # to the last step) and would overwrite tensors autograd saved for
        # the backward pass.
        hs = list(hidden_state)
        cs = list(cell_state)
        outputs = []
        for x_t in x:
            layer_input = x_t
            for i, layer in enumerate(self.layers):
                hs[i], cs[i] = layer(layer_input, hs[i], cs[i])
                layer_input = hs[i]
            outputs.append(hs[-1])
        return outputs, (torch.stack(hs), torch.stack(cs))
    

# Example of usage: run a 2-layer LSTM over a random sequence and print the
# resulting shapes to confirm the dimensions line up.
input_dim, hidden_dim = 5, 10
num_layers, seq_len = 2, 6

# One input vector per time step; zero-initialised states, one row per layer.
x = torch.randn(seq_len, input_dim)
hidden_state = torch.zeros(num_layers, hidden_dim)
cell_state = torch.zeros_like(hidden_state)

lstm = LSTM(input_dim, hidden_dim, num_layers)
lstm.reset_parameters()

outputs, (hidden_state, cell_state) = lstm(x, hidden_state, cell_state)
print(f"LSTM outputs:  {len(outputs)} {outputs[0].shape}")
print(f"Hidden state:  {hidden_state.shape}")
print(f"Cell state:  {cell_state.shape}")


LSTM outputs:  6 torch.Size([10])
Hidden state:  torch.Size([2, 10])
Cell state:  torch.Size([2, 10])


  nn.init.normal(param, mean=0, std=1)


Implement a subclass of your LSTM that uses coupled forget and input gates, i.e. the input gate is replaced by $1 - f_t$, so the cell state update becomes:

$$c_t = f_t * c_{t-1} + (1-f_t) * \tilde{c}_t$$

In [None]:
class CoupledLSTMCell(LSTMCell):
    """LSTM cell with coupled forget and input gates (i_t = 1 - f_t).

    The cell update becomes

        c_t = f_t * c_{t-1} + (1 - f_t) * c~_t

    while f_t, o_t and c~_t are computed as in ``LSTMCell`` from the
    per-gate parameters W_f/U_f/b_f, W_o/U_o/b_o and W_c/U_c/b_c that
    ``LSTMCell.__init__`` creates.
    """

    def __init__(self, inout_dim, hidden_dim):
        # ``inout_dim`` is the input dimension; the parameter name is kept
        # unchanged for existing keyword callers.
        super().__init__(inout_dim, hidden_dim)
        # The input gate is derived from the forget gate, so its dedicated
        # parameters would never be used; remove them so they are not
        # initialised, optimised or saved for nothing.
        del self.W_i, self.U_i, self.b_i

    def forward(self, x, hidden_state, cell_state):
        """Advance the coupled cell by one time step.

        Args:
            x: input of shape (input_dim,) or (batch, input_dim).
            hidden_state: previous hidden state, (hidden_dim,) or
                (batch, hidden_dim).
            cell_state: previous cell state, same shape as hidden_state.

        Returns:
            (h, c): new hidden and cell state, each shaped like hidden_state.
        """
        def affine(W, U, b):
            # h @ W.T equals W @ h for a 1-D h, and also handles (batch, dim).
            return hidden_state @ W.T + x @ U.T + b

        f = torch.sigmoid(affine(self.W_f, self.U_f, self.b_f))
        o = torch.sigmoid(affine(self.W_o, self.U_o, self.b_o))
        c_hat = torch.tanh(affine(self.W_c, self.U_c, self.b_c))

        c = f * cell_state + (1 - f) * c_hat
        h = o * torch.tanh(c)
        return h, c
    
class CoupledLSTM(LSTM):
    """Stacked LSTM whose layers are ``CoupledLSTMCell`` instances.

    The parent constructor first builds ordinary ``LSTMCell`` layers; they are
    then replaced wholesale by coupled cells with the same input/hidden sizes.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__(input_dim, hidden_dim, num_layers)
        # Layer 0 consumes the raw input; every deeper layer consumes the
        # hidden state of the layer below it.
        layer_input_dims = [input_dim if k == 0 else hidden_dim for k in range(num_layers)]
        self.layers = nn.ModuleList(
            [CoupledLSTMCell(in_size, hidden_dim) for in_size in layer_input_dims]
        )

**Bonus:** Implement *peephole connections* as described at the start of the section *Variants on Long Short Term Memory* in [this blog post explaining LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

The gate update definitions get an additional term that looks at the cell state:
$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f \boldsymbol{+ V_f c_{t-1}}) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i \boldsymbol{+ V_i c_{t-1}}) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_t})
\end{align}
$$

To make the task a bit easier, we will implement the last equation with the cell state of the previous time step $t-1$ as $$o_t = \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_{t-1}})$$ instead.