# Long Short-Term Memory
In this exercise, we will implement an LSTM. In class, we have already seen the LSTM update rules at time step $t$:

$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o) \\
\tilde{c}_t &= \tanh(W_c h_{t-1} + U_c x_t + b_c) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
h_t &= o_t * \tanh(c_t)
\end{align}
$$

In [1]:
import torch
import torch.nn as nn

Implement this original version of the LSTM as an `LSTMCell`.

In [2]:
class LSTMCell(nn.Module):
    """A single LSTM cell implementing the update rules above.

    The gate pre-activations of all chunks are computed with one stacked
    recurrent weight ``w`` (acting on h_{t-1}), one stacked input weight
    ``u`` (acting on x_t) and one stacked bias ``b``. The stacked result is
    split, in order, into forget gate, input gate, output gate and
    candidate cell memory.

    Args:
        input_dim: size of the input vector x_t.
        hidden_dim: size of the hidden and cell state vectors.
        num_chunks: number of stacked pre-activation chunks. Subclasses use
            it to allocate a different number of gates; this class's
            ``forward`` reads the first 4 chunks.
    """

    def __init__(self, input_dim, hidden_dim, num_chunks=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_chunks = num_chunks

        # stack weights and biases of all chunks along the output dimension
        self.w = nn.Linear(hidden_dim, num_chunks * hidden_dim, bias=False)
        self.u = nn.Linear(input_dim, num_chunks * hidden_dim, bias=False)
        # zero-initialized rather than torch.empty so the bias never holds
        # uninitialized memory (possibly NaN/inf) before reset_parameters()
        self.b = nn.Parameter(torch.zeros(num_chunks * hidden_dim))

    def reset_parameters(self):
        """Re-initialize every parameter from a standard normal distribution."""
        for weight in self.parameters():
            nn.init.normal_(weight, mean=0, std=1)

    def forward(self, x, cell_state, hidden_state):
        """Compute one LSTM time step.

        Args:
            x: input x_t, shape (..., input_dim).
            cell_state: previous cell state c_{t-1}, shape (..., hidden_dim).
            hidden_state: previous hidden state h_{t-1}, shape (..., hidden_dim).
                Any leading (e.g. batch) dimensions must broadcast together.

        Returns:
            Tuple ``(c_t, h_t)``, each of shape (..., hidden_dim).
        """
        updates = self.w(hidden_state) + self.u(x) + self.b
        # split the stacked last dimension into its chunks; this also works
        # when the inputs carry leading batch dimensions
        chunks = updates.chunk(self.num_chunks, dim=-1)
        forget_gate = torch.sigmoid(chunks[0])
        input_gate = torch.sigmoid(chunks[1])
        output_gate = torch.sigmoid(chunks[2])
        new_cell_memory = torch.tanh(chunks[3])
        new_cell_state = forget_gate * cell_state + input_gate * new_cell_memory  # element-wise multiplications
        new_hidden_state = output_gate * torch.tanh(new_cell_state)
        return new_cell_state, new_hidden_state

Create a 2-layer LSTM from your `LSTMCell` base class and run a forward pass with a random input sequence to check that all dimensions are correct.

In [3]:
class LSTM(nn.Module):
    """Multi-layer LSTM built by stacking ``LSTMCell`` layers.

    Processes a single (unbatched) sequence one time step at a time and
    returns only the final cell and hidden states of every layer.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])
        for i in range(num_layers):
            # the first layer reads the input, deeper layers read the
            # hidden state of the layer below
            in_dim = input_dim if i == 0 else hidden_dim
            self.layers.append(LSTMCell(in_dim, hidden_dim))

    def reset_parameters(self):
        """Re-initialize the parameters of every layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def _forward(self, x, cell_state, hidden_state):
        """Advance every layer by one time step.

        Args:
            x: input for the first layer at this time step.
            cell_state: per-layer cell states, shape (num_layers, hidden_dim).
            hidden_state: per-layer hidden states, same shape as ``cell_state``.

        Returns:
            Tuple ``(new_cell_state, new_hidden_state)``, each stacked to
            shape (num_layers, hidden_dim).
        """
        assert cell_state.dim() == 2, "Cell state has the wrong number of dimensions"
        assert cell_state.size(0) == len(self.layers), "First dimension should be the number of layers"
        assert cell_state.size() == hidden_state.size(), "Hidden state has the wrong dimensions"
        # Collect per-layer results and stack them at the end. Writing them
        # in place into a preallocated tensor would modify the very tensor
        # whose row was just fed to the next layer as input, which makes
        # backward() fail autograd's in-place modification check.
        new_cell_states = []
        new_hidden_states = []
        for layer, layer_cell, layer_hidden in zip(self.layers, cell_state, hidden_state):
            layer_cell, layer_hidden = layer(x, layer_cell, layer_hidden)
            new_cell_states.append(layer_cell)
            new_hidden_states.append(layer_hidden)
            x = layer_hidden  # input to layers above first is output hidden state
        return torch.stack(new_cell_states), torch.stack(new_hidden_states)

    def forward(self, x, cell_state, hidden_state):
        """Run the whole input sequence through the stacked layers.

        Args:
            x: input sequence of shape (sequence length, input dim).
            cell_state: initial cell states, shape (num_layers, hidden_dim).
            hidden_state: initial hidden states, shape (num_layers, hidden_dim).

        Returns:
            Tuple ``(cell_state, hidden_state)`` after the last time step.
        """
        assert x.dim() == 2, "input needs to be of shape [sequence length, input dim]"
        for x_i in x:
            cell_state, hidden_state = self._forward(x_i, cell_state, hidden_state)
        return cell_state, hidden_state
        

# sizes shared by all experiments in this notebook
input_dim, hidden_dim, output_dim = 5, 10, 8
sequence_length, num_layers = 6, 2

# random input sequence plus one initial cell/hidden state row per layer
x = torch.randn(sequence_length, input_dim)
c0 = torch.randn(num_layers, hidden_dim)
h0 = torch.randn(num_layers, hidden_dim)

# build the stacked LSTM, draw fresh weights and run the whole sequence
lstm = LSTM(input_dim, hidden_dim, num_layers)
lstm.reset_parameters()
cn, hn = lstm(x, c0, h0)

print('LSTM outputs:', cn.shape, hn.shape)
print(hn)

LSTM outputs: torch.Size([2, 10]) torch.Size([2, 10])
tensor([[ 2.6558e-02, -4.7880e-01, -3.8033e-01, -5.0786e-01, -2.5652e-01,
         -6.8788e-01, -1.8028e-01,  1.5213e-01,  1.4925e-01, -3.4361e-01],
        [ 7.8318e-02,  2.1964e-01, -1.4630e-02, -7.0103e-04, -3.5553e-01,
         -3.4361e-01,  4.7176e-01,  7.3530e-01,  4.8829e-01,  1.5587e-01]],
       grad_fn=<CopySlices>)


Implement a subclass of your LSTM that uses coupled forget and input gates, i.e., the input gate is replaced by $1 - f_t$ and the cell state update becomes:

$$c_t = f_t * c_{t-1} + (1-f_t) * \tilde{c}_t$$

In [4]:
class CoupledLSTMCell(LSTMCell):
    """LSTM cell whose input gate is tied to the forget gate (i_t = 1 - f_t).

    Only three stacked pre-activation chunks are needed, in order: forget
    gate, output gate and candidate cell memory.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__(input_dim, hidden_dim, num_chunks=3)

    def forward(self, x, cell_state, hidden_state):
        """Compute one coupled-gate LSTM step and return ``(c_t, h_t)``.

        Inputs are x_t (..., input_dim), c_{t-1} and h_{t-1}
        (..., hidden_dim); leading dimensions must broadcast together.
        """
        updates = self.w(hidden_state) + self.u(x) + self.b
        # split the stacked last dimension using this cell's own size, so
        # the cell works for any hidden_dim and with leading batch dims
        chunks = updates.chunk(self.num_chunks, dim=-1)
        forget_gate = torch.sigmoid(chunks[0])
        output_gate = torch.sigmoid(chunks[1])
        new_cell_memory = torch.tanh(chunks[2])
        # coupled gates: whatever is forgotten is replaced by new memory
        new_cell_state = forget_gate * cell_state + (1 - forget_gate) * new_cell_memory
        new_hidden_state = output_gate * torch.tanh(new_cell_state)
        return new_cell_state, new_hidden_state

class CoupledLSTM(LSTM):
    """Multi-layer LSTM whose layers are ``CoupledLSTMCell`` instances.

    The parent constructor builds plain ``LSTMCell`` layers; they are
    discarded and replaced by coupled cells of the same sizes.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__(input_dim, hidden_dim, num_layers)
        # layer 0 consumes the input, every deeper layer the hidden state below
        self.layers = nn.ModuleList(
            CoupledLSTMCell(hidden_dim if depth else input_dim, hidden_dim)
            for depth in range(num_layers)
        )

# reuse the input sequence and initial states defined earlier
coupled_lstm = CoupledLSTM(input_dim, hidden_dim, num_layers)
coupled_lstm.reset_parameters()
cn, hn = coupled_lstm(x, c0, h0)
print(cn.shape, hn.shape)  # torch.Size is printed identically by .size() and .shape

torch.Size([2, 10]) torch.Size([2, 10])


**Bonus:** Implement *peephole connections* as described at the start of the section *Variants on Long Short Term Memory* in [this blog post explaining LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

The gate update definitions get an additional term that looks at the cell state:
$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f \boldsymbol{+ V_f c_{t-1}}) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i \boldsymbol{+ V_i c_{t-1}}) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_t})
\end{align}
$$

To make the task a bit easier, we will implement the last equation with the cell state of the previous time step $t-1$ as $$o_t = \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_{t-1}})$$ instead.

In [5]:
class PeepholeLSTMCell(LSTMCell):
    """LSTM cell with peephole connections from c_{t-1} into all three gates.

    Implements f/i/o = sigmoid(W h_{t-1} + U x_t + b + V c_{t-1}); the
    candidate cell memory gets no peephole term. ``forward`` reads the
    first 4 stacked chunks (forget, input, output, candidate memory).
    """

    def __init__(self, input_dim, hidden_dim, num_chunks=4):
        super().__init__(input_dim, hidden_dim, num_chunks)

        # peephole weights V_f, V_i, V_o stacked along the output dimension
        self.v = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)

    def forward(self, x, cell_state, hidden_state):
        """Compute one peephole LSTM step and return ``(c_t, h_t)``.

        Inputs are x_t (..., input_dim), c_{t-1} and h_{t-1}
        (..., hidden_dim); leading dimensions must broadcast together.
        """
        updates = self.w(hidden_state) + self.u(x) + self.b
        chunks = updates.chunk(self.num_chunks, dim=-1)
        # the peephole term looks at the previous *cell* state c_{t-1}
        # (V c_{t-1}) and only feeds the three gates; computing it from the
        # inputs keeps it on the same device/dtype as everything else
        peep_forget, peep_input, peep_output = self.v(cell_state).chunk(3, dim=-1)
        forget_gate = torch.sigmoid(chunks[0] + peep_forget)
        input_gate = torch.sigmoid(chunks[1] + peep_input)
        output_gate = torch.sigmoid(chunks[2] + peep_output)
        new_cell_memory = torch.tanh(chunks[3])
        new_cell_state = forget_gate * cell_state + input_gate * new_cell_memory  # element-wise multiplications
        new_hidden_state = output_gate * torch.tanh(new_cell_state)
        return new_cell_state, new_hidden_state

class PeepholeLSTM(LSTM):
    """Multi-layer LSTM whose layers are ``PeepholeLSTMCell`` instances.

    The parent constructor builds plain ``LSTMCell`` layers; they are
    discarded and replaced by peephole cells of the same sizes.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__(input_dim, hidden_dim, num_layers)
        # layer 0 consumes the input, every deeper layer the hidden state below
        self.layers = nn.ModuleList(
            PeepholeLSTMCell(hidden_dim if depth else input_dim, hidden_dim)
            for depth in range(num_layers)
        )

# reuse the input sequence and initial states defined earlier
peephole_lstm = PeepholeLSTM(input_dim, hidden_dim, num_layers)
peephole_lstm.reset_parameters()
cn, hn = peephole_lstm(x, c0, h0)
print(cn.shape, hn.shape)  # torch.Size is printed identically by .size() and .shape

torch.Size([2, 10]) torch.Size([2, 10])
