# Long Short-Term Memory
In this exercise, we will implement an LSTM. In class, we have already seen the LSTM update rules at time step $t$, where $\sigma$ is the logistic sigmoid and $*$ denotes element-wise multiplication:

$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o) \\
\tilde{c}_t &= \tanh(W_c h_{t-1} + U_c x_t + b_c) \\
c_t &= f_t * c_{t-1} + i_t * \tilde{c}_t \\
h_t &= o_t * \tanh(c_t)
\end{align}
$$
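
To make the update rules concrete, here is a minimal scalar sketch (hidden size 1) that uses only the standard library. The helper `lstm_step_scalar` and the parameter dictionary are illustrative assumptions, not part of the exercise. All weights and biases are set to zero, so the step can be checked by hand: every gate equals $\sigma(0) = 0.5$ and the candidate equals $\tanh(0) = 0$.

```python
import math

def sigmoid(z):
  return 1.0 / (1.0 + math.exp(-z))

def lstm_step_scalar(x_t, h_prev, c_prev, p):
  """One LSTM step for hidden size 1; p maps names such as 'W_f' to floats."""
  f_t = sigmoid(p["W_f"] * h_prev + p["U_f"] * x_t + p["b_f"])
  i_t = sigmoid(p["W_i"] * h_prev + p["U_i"] * x_t + p["b_i"])
  o_t = sigmoid(p["W_o"] * h_prev + p["U_o"] * x_t + p["b_o"])
  c_tilde = math.tanh(p["W_c"] * h_prev + p["U_c"] * x_t + p["b_c"])
  c_t = f_t * c_prev + i_t * c_tilde
  h_t = o_t * math.tanh(c_t)
  return h_t, c_t

# Illustrative values only: zero weights and biases for all four parameter groups.
params = {f"{kind}_{gate}": 0.0 for kind in ("W", "U", "b") for gate in "fioc"}
h_t, c_t = lstm_step_scalar(x_t=1.0, h_prev=0.0, c_prev=2.0, p=params)
print(c_t)  # 0.5 * 2.0 + 0.5 * 0.0 = 1.0
print(h_t)  # 0.5 * tanh(1.0)
```

With zero parameters the forget gate halves the old cell state and the input gate adds half of a zero candidate, which is why $c_t = 1.0$ here.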

In [1]:
import torch
import torch.nn as nn

Implement this original version of the LSTM as an `LSTMCell`.

In [5]:
class LSTMCell(nn.Module):
  def __init__(self, input_size, hidden_size):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size

    # weights and biases for forget gate
    self.Wfh = nn.Parameter(torch.zeros(hidden_size, hidden_size))
    self.Wfx = nn.Parameter(torch.zeros(hidden_size, input_size))
    self.bf = nn.Parameter(torch.zeros(hidden_size))

    # weights and biases for input gate
    self.Wih = nn.Parameter(torch.zeros(hidden_size, hidden_size))
    # could also be written as: self.Wix = nn.Linear(input_size, hidden_size, bias=False) # => y = Wx
    self.Wix = nn.Parameter(torch.zeros(hidden_size, input_size))
    self.bi = nn.Parameter(torch.zeros(hidden_size))

    # weights and biases for output gate
    self.Woh = nn.Parameter(torch.zeros(hidden_size, hidden_size))
    self.Wox = nn.Parameter(torch.zeros(hidden_size, input_size))
    self.bo = nn.Parameter(torch.zeros(hidden_size))

    # weights and biases for new cell memory
    self.Wch = nn.Parameter(torch.zeros(hidden_size, hidden_size))
    self.Wcx = nn.Parameter(torch.zeros(hidden_size, input_size))
    self.bc = nn.Parameter(torch.zeros(hidden_size))

    # initialize here so that the cell never starts with all-zero parameters
    self.reset_parameters()

  def reset_parameters(self):
    for weight in self.parameters():
      nn.init.uniform_(weight, -1, 1)

  def forward(self, x, hidden_state, cell_state):
    forget_gate = torch.sigmoid(self.Wfh @ hidden_state + self.Wfx @ x + self.bf)
    input_gate = torch.sigmoid(self.Wih @ hidden_state + self.Wix @ x + self.bi)
    output_gate = torch.sigmoid(self.Woh @ hidden_state + self.Wox @ x + self.bo)
    new_cell_memory = torch.tanh(self.Wch @ hidden_state + self.Wcx @ x + self.bc)
    new_cell_state = forget_gate * cell_state + input_gate * new_cell_memory
    new_hidden_state = output_gate * torch.tanh(new_cell_state)

    return new_hidden_state, new_cell_state
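
One way to sanity-check the update rules is to compare them against PyTorch's built-in `nn.LSTMCell`. The sketch below is an assumption-laden check, not part of the exercise: the helper `manual_step` is hypothetical, and it relies on `nn.LSTMCell` stacking its gate parameters row-wise in the order input, forget, candidate, output, and on it keeping two bias vectors (`bias_ih` and `bias_hh`) that are simply added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 5, 10
reference = nn.LSTMCell(input_size, hidden_size)

def manual_step(x, h, c, cell):
  # Split the stacked parameters into four per-gate blocks (order: i, f, c~, o).
  W_x = cell.weight_ih.chunk(4)  # four (hidden_size, input_size) blocks
  W_h = cell.weight_hh.chunk(4)  # four (hidden_size, hidden_size) blocks
  b = (cell.bias_ih + cell.bias_hh).chunk(4)
  pre = [W_h[k] @ h + W_x[k] @ x + b[k] for k in range(4)]
  i, f, o = torch.sigmoid(pre[0]), torch.sigmoid(pre[1]), torch.sigmoid(pre[3])
  c_tilde = torch.tanh(pre[2])
  c_new = f * c + i * c_tilde
  h_new = o * torch.tanh(c_new)
  return h_new, c_new

x, h, c = torch.randn(input_size), torch.randn(hidden_size), torch.randn(hidden_size)
h_manual, c_manual = manual_step(x, h, c, reference)
# A batch dimension of size 1 is added for the reference call and removed afterwards.
h_ref, c_ref = reference(x.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
print(torch.allclose(h_manual, h_ref.squeeze(0), atol=1e-5))
print(torch.allclose(c_manual, c_ref.squeeze(0), atol=1e-5))
```

The same idea can be applied to a hand-written cell by copying the corresponding blocks into its parameters before comparing.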

Create a 2-layer LSTM from your `LSTMCell` and run a forward pass with a random input sequence to check that all your dimensions are correct.

In [10]:
class LSTM(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers

    self.layers = nn.ModuleList([])
    for i in range(num_layers):
      in_dim = input_size if i == 0 else hidden_size
      self.layers.append(LSTMCell(in_dim, hidden_size))

  def reset_parameters(self):
    for lstm_cell in self.layers:
      lstm_cell.reset_parameters()

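  def forward(self, x, hidden_state, cell_state):
    # x: (sequence_length, input_size); hidden_state, cell_state: (num_layers, hidden_size)
    outputs = []
    for x_t in x: # iterate over the time steps first
      new_hidden_states, new_cell_states = [], []
      layer_input = x_t
      for i, lstm_cell in enumerate(self.layers): # iterate over the layers
        h_i, c_i = lstm_cell(layer_input, hidden_state[i], cell_state[i])
        new_hidden_states.append(h_i)
        new_cell_states.append(c_i)
        layer_input = h_i # the next layer reads this layer's hidden state
      # carry the updated states over to the next time step (no in-place writes, so autograd stays valid)
      hidden_state = torch.stack(new_hidden_states)
      cell_state = torch.stack(new_cell_states)
      outputs.append(hidden_state[-1])
    return outputs, (hidden_state, cell_state)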
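
A stacked LSTM is expected to reuse each layer's hidden and cell state from the previous time step, and to feed each layer's new hidden state to the layer above. The following sketch cross-checks such a time-and-layer loop against `nn.LSTM`. It is only an illustration under assumptions: it uses PyTorch's `nn.LSTMCell` (not the `LSTMCell` from this exercise), a batch dimension of size 1, and weights copied from the reference module's `weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}` and `bias_hh_l{k}` parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, num_layers, sequence_length = 5, 10, 2, 6
reference = nn.LSTM(input_size, hidden_size, num_layers)

# One nn.LSTMCell per layer, sharing the weights of the reference module.
cells = [nn.LSTMCell(input_size if k == 0 else hidden_size, hidden_size) for k in range(num_layers)]
with torch.no_grad():
  for k, cell in enumerate(cells):
    for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
      getattr(cell, name).copy_(getattr(reference, f"{name}_l{k}"))

x = torch.randn(sequence_length, 1, input_size)  # (time, batch=1, features)
h = [torch.zeros(1, hidden_size) for _ in range(num_layers)]
c = [torch.zeros(1, hidden_size) for _ in range(num_layers)]

outputs = []
with torch.no_grad():
  for x_t in x:  # iterate over the time steps
    layer_input = x_t
    for k, cell in enumerate(cells):  # iterate over the layers
      # h[k] and c[k] still hold this layer's state from the previous time step
      h[k], c[k] = cell(layer_input, (h[k], c[k]))
      layer_input = h[k]
    outputs.append(h[-1])
  ref_out, (ref_h, ref_c) = reference(x)  # initial states default to zeros

manual_out = torch.stack(outputs)
print(manual_out.shape)  # torch.Size([6, 1, 10])
print(torch.allclose(manual_out, ref_out, atol=1e-5))
```

If the states were not carried across time steps, the manual outputs would stop matching the reference from the second time step onward.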

In [12]:
input_size = 5
hidden_size = 10
sequence_length = 6
num_layers = 2

x = torch.randn(sequence_length, input_size) # random input sequence
hidden_state = torch.zeros(num_layers, hidden_size)
cell_state = torch.zeros(num_layers, hidden_size)

lstm = LSTM(input_size, hidden_size, num_layers)
lstm.reset_parameters()
outputs, (hidden_state, cell_state) = lstm(x, hidden_state, cell_state)

print(len(outputs))
print(outputs[0].shape)
print(outputs[-1])
print(hidden_state.shape, cell_state.shape)
print(hidden_state)

2
torch.Size([10])
tensor([-0.1625,  0.0985,  0.0153,  0.0129, -0.0629, -0.0757,  0.0426, -0.1395,
         0.0219, -0.0074], grad_fn=<SelectBackward0>)
torch.Size([10]) torch.Size([10])
tensor([-0.2060, -0.0070,  0.0467, -0.3238,  0.3237, -0.0578, -0.1450, -0.0792,
         0.2230, -0.2209], grad_fn=<UnbindBackward0>)


Implement a subclass of your LSTM that uses a coupled forget and input gate, i.e. the cell state update becomes:

$$c_t = f_t * c_{t-1} + (1-f_t) * \tilde{c}_t$$
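
As a rough illustration of the coupled update, here is a stand-alone sketch. The class name `CoupledLSTMCell` and the fused `nn.Linear` over the concatenated $[h_{t-1}, x_t]$ are assumptions made to keep the example short; they are not the subclass asked for in the exercise. Because the input gate is replaced by $1 - f_t$, only three sets of weights are needed.

```python
import torch
import torch.nn as nn

class CoupledLSTMCell(nn.Module):
  """Hypothetical stand-alone cell with a coupled forget and input gate."""

  def __init__(self, input_size, hidden_size):
    super().__init__()
    # One fused linear map for f, o and the candidate; there is no separate input gate.
    self.linear = nn.Linear(hidden_size + input_size, 3 * hidden_size)

  def forward(self, x, hidden_state, cell_state):
    pre_f, pre_o, pre_c = self.linear(torch.cat([hidden_state, x])).chunk(3)
    forget_gate = torch.sigmoid(pre_f)
    output_gate = torch.sigmoid(pre_o)
    new_cell_memory = torch.tanh(pre_c)
    # Coupled update: whatever is forgotten is replaced by new memory.
    new_cell_state = forget_gate * cell_state + (1 - forget_gate) * new_cell_memory
    new_hidden_state = output_gate * torch.tanh(new_cell_state)
    return new_hidden_state, new_cell_state

torch.manual_seed(0)
cell = CoupledLSTMCell(input_size=5, hidden_size=10)
h, c = cell(torch.randn(5), torch.zeros(10), torch.zeros(10))
print(h.shape, c.shape)
```

Since $c_t$ is a convex combination of $c_{t-1}$ and $\tilde{c}_t$, the cell state stays within $[-1, 1]$ whenever the previous cell state does.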

**Bonus:** Implement *peephole connections* as described at the start of the Section *Variants on Long Short Term Memory* in [this blog post explaining LSTMs](https://colah.github.io/posts/2015-08-Understanding-LSTMs/).

The gate update definitions get an additional term that looks at the cell state:
$$
\begin{align}
f_t &= \sigma(W_f h_{t-1} + U_f x_t + b_f \boldsymbol{+ V_f c_{t-1}}) \\
i_t &= \sigma(W_i h_{t-1} + U_i x_t + b_i \boldsymbol{+ V_i c_{t-1}}) \\
o_t &= \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_t})
\end{align}
$$

To make the task a bit easier, we will instead use the cell state of the previous time step $t-1$ in the last equation:

$$
o_t = \sigma(W_o h_{t-1} + U_o x_t + b_o \boldsymbol{+ V_o c_{t-1}})
$$
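
The simplified peephole variant, in which all three gates read $c_{t-1}$, might be sketched as follows. The class name `PeepholeLSTMCell` and the two fused `nn.Linear` layers are assumptions for brevity: `linear` plays the role of $W$, $U$ and $b$ for all four pre-activations, and `peephole` holds $V_f$, $V_i$ and $V_o$ as full matrices without a bias.

```python
import torch
import torch.nn as nn

class PeepholeLSTMCell(nn.Module):
  """Hypothetical stand-alone cell with peephole connections on c_{t-1}."""

  def __init__(self, input_size, hidden_size):
    super().__init__()
    # W, U and b for the forget, input and output gates and the candidate.
    self.linear = nn.Linear(hidden_size + input_size, 4 * hidden_size)
    # V_f, V_i and V_o: the extra terms that look at the cell state.
    self.peephole = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

  def forward(self, x, hidden_state, cell_state):
    pre_f, pre_i, pre_o, pre_c = self.linear(torch.cat([hidden_state, x])).chunk(4)
    peep_f, peep_i, peep_o = self.peephole(cell_state).chunk(3)
    forget_gate = torch.sigmoid(pre_f + peep_f)
    input_gate = torch.sigmoid(pre_i + peep_i)
    output_gate = torch.sigmoid(pre_o + peep_o)  # simplified: uses c_{t-1}, not c_t
    new_cell_memory = torch.tanh(pre_c)
    new_cell_state = forget_gate * cell_state + input_gate * new_cell_memory
    new_hidden_state = output_gate * torch.tanh(new_cell_state)
    return new_hidden_state, new_cell_state

torch.manual_seed(0)
cell = PeepholeLSTMCell(input_size=5, hidden_size=10)
h, c = cell(torch.randn(5), torch.randn(10), torch.randn(10))
h.sum().backward()  # gradients should reach the peephole weights
print(h.shape, c.shape, cell.peephole.weight.grad is not None)
```

Using $c_{t-1}$ for the output gate means all three peephole terms can be computed in a single matrix product before the cell state is updated.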