# Sequence-to-sequence RNN with Attention
We will now add attention to our sequence-to-sequence RNN. There are several ways to incorporate the context vector $c$ into the RNN architecture:
1. Add an additional term to the computation of the gates/states (i.e. treat it as an input just like $h_{t-1}$ and $x_t$). This was used in the original paper (Bahdanau et al., 2015), described in Appendix A.
2. Concatenate it with the hidden state of the last time step $h_{t-1}$ and project the concatenation down from `enc_hidden_dim + dec_hidden_dim` to `dec_hidden_dim`.
3. Concatenate it with the input $x_t$ and project the concatenation down to the input dimension.

We will use variant 2 in this exercise. We'll make our lives a bit easier by implementing a 1-layer decoder and working with a batch size of 1.
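
As a rough illustration of variant 2, the sketch below concatenates a previous hidden state with a context vector and projects the result back to the decoder size. The sizes and the random stand-in for the context vector are assumptions chosen only for this example; the real context vector comes from the attention mechanism implemented further down.

```python
import torch
import torch.nn as nn

enc_output_dim, dec_hidden_dim = 30, 20  # illustrative sizes (assumption)

# Bias-free linear map from the concatenation back to the decoder size
down_project = nn.Linear(enc_output_dim + dec_hidden_dim, dec_hidden_dim, bias=False)

h_prev = torch.zeros(dec_hidden_dim)   # previous decoder hidden state h_{t-1}
context = torch.randn(enc_output_dim)  # random stand-in for the context vector c

concatenated = torch.cat([h_prev, context], dim=-1)  # shape: [enc_output_dim + dec_hidden_dim]
h_with_context = down_project(concatenated)           # shape: [dec_hidden_dim]
print(concatenated.shape, h_with_context.shape)
```

The projected vector then takes the place of $h_{t-1}$ when the LSTM cell is called.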

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Since we have to recompute the context vector at every decoding step, we can't use PyTorch's high-level `nn.LSTM` interface, which consumes the whole sequence in a single call. Instead, we first implement a decoder LSTM class that steps an `nn.LSTMCell` manually. We start with the `__init__` method, where we define all parameters.
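
To make the difference concrete, here is a minimal sketch of stepping an `nn.LSTMCell` over one unbatched sequence by hand. The sizes are illustrative assumptions, and the sketch assumes a PyTorch version that accepts unbatched inputs, as the rest of this notebook does.

```python
import torch
import torch.nn as nn

input_dim, hidden_dim, seq_len = 10, 20, 8  # illustrative sizes (assumption)

cell = nn.LSTMCell(input_dim, hidden_dim)
inputs = torch.randn(seq_len, input_dim)  # one unbatched sequence

h = torch.zeros(hidden_dim)  # initial hidden state
c = torch.zeros(hidden_dim)  # initial cell state
step_outputs = []
for x_t in inputs:
    # One time step. Because we control the loop, h could be modified here
    # (for example, mixed with a context vector) before it is passed to the cell.
    h, c = cell(x_t, (h, c))
    step_outputs.append(h)

step_outputs = torch.stack(step_outputs)  # shape: [seq_len, hidden_dim]
print(step_outputs.shape)
```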

In [2]:
class DecoderLSTMWithAttention(nn.Module):
    
    def __init__(self, input_dim, enc_output_dim, dec_hidden_dim):
        super().__init__()
        self.hidden_dim = dec_hidden_dim
        self.down_project = nn.Linear(enc_output_dim + dec_hidden_dim, dec_hidden_dim, bias=False)
        self.v = nn.Parameter(torch.empty(dec_hidden_dim))
        self.W = nn.Parameter(torch.zeros(dec_hidden_dim, dec_hidden_dim))
        self.U = nn.Linear(enc_output_dim, dec_hidden_dim, bias=False)  # can use parameter or Linear layer
        self.cell = nn.LSTMCell(input_dim, dec_hidden_dim)

Add a `reset_parameters` method that initializes all parameters.

In [3]:
def reset_parameters(self):
    self.down_project.reset_parameters()
    nn.init.normal_(self.v, mean=0, std=1)
    nn.init.normal_(self.W, mean=0, std=1)
    self.U.reset_parameters()
    self.cell.reset_parameters()

DecoderLSTMWithAttention.reset_parameters = reset_parameters

Add a `forward` method that takes a sequence `y` and the encoder hidden states `encoder_hidden_states` as input. `encoder_hidden_states` is a tensor of size `[sequence_length, encoder_output_dim]`, where `encoder_output_dim = num_directions * encoder_hidden_dim`. At every step, the `forward` method should call `compute_context_vector`, which computes the attention-weighted context vector. We will implement it later.
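
The shape of `encoder_hidden_states` can be checked with a small standalone sketch. The sizes below are illustrative assumptions; the point is only that, for an unbatched input, a bidirectional `nn.LSTM` returns per-step outputs whose last dimension is `num_directions * enc_hidden_dim`.

```python
import torch
import torch.nn as nn

embedding_dim, enc_hidden_dim, num_layers, seq_len = 10, 15, 2, 10  # illustrative sizes (assumption)

encoder = nn.LSTM(embedding_dim, enc_hidden_dim, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, embedding_dim)  # unbatched input sequence

# Initial states default to zeros when they are not passed explicitly
encoder_hidden_states, (h_n, c_n) = encoder(x)

print(encoder_hidden_states.shape)  # [seq_len, 2 * enc_hidden_dim]
print(h_n.shape)                    # [num_layers * 2, enc_hidden_dim]
```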

In [4]:
def forward(self, y, encoder_hidden_states):
    outputs = []
    previous_decoder_hidden_state = torch.zeros(self.hidden_dim)
    cell_state = torch.zeros(self.hidden_dim)
    
    # iterate over sequence y
    for y_i in y:
        context_vector = self.compute_context_vector(previous_decoder_hidden_state, encoder_hidden_states)
        concatenated = torch.cat([previous_decoder_hidden_state, context_vector], dim=-1)
        projected = self.down_project(concatenated)
        previous_decoder_hidden_state, cell_state = self.cell(y_i, (projected, cell_state))
        outputs.append(previous_decoder_hidden_state)
    
    return torch.stack(outputs), (previous_decoder_hidden_state, cell_state)
    
DecoderLSTMWithAttention.forward = forward

Now it's time to implement the `compute_context_vector` function. Its inputs are `previous_decoder_hidden_state` and `encoder_hidden_states`. Use either additive or multiplicative attention, as covered in the course. If necessary, add trainable parameters to your `__init__` method and initialize them in `reset_parameters`.

In [5]:
def compute_context_vector(self, previous_decoder_hidden_state, encoder_hidden_states):
    scores = []

    for encoder_hidden_state in encoder_hidden_states:
        # W @ h_prev: [dec_hidden_dim, dec_hidden_dim] @ [dec_hidden_dim] --> [dec_hidden_dim]
        # U(e_i):     [enc_output_dim] --> [dec_hidden_dim]
        # v @ tanh(...): [dec_hidden_dim] @ [dec_hidden_dim] --> scalar score
        scores.append(self.v @ torch.tanh(self.W @ previous_decoder_hidden_state + self.U(encoder_hidden_state)))

    scores = torch.stack(scores)
    attention_weights = F.softmax(scores, dim=-1)

    context_vector = attention_weights @ encoder_hidden_states
    # [seq_len] @ [seq_len, enc_output_dim] --> [enc_output_dim]
    return context_vector

DecoderLSTMWithAttention.compute_context_vector = compute_context_vector
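
The Python loop over encoder states is easy to read but not required. As a sketch, the same additive score $v^\top \tanh(W h_{t-1} + U e_i)$ can be computed for all encoder states at once. The standalone function `additive_context_vector` and its explicit `v`, `W`, `U` arguments are hypothetical names introduced for this example, with `U` passed as a plain weight matrix rather than an `nn.Linear` layer.

```python
import torch
import torch.nn.functional as F

def additive_context_vector(h_prev, encoder_hidden_states, v, W, U):
    """Vectorized additive attention for one unbatched decoder step.

    h_prev: [dec_hidden_dim]
    encoder_hidden_states: [seq_len, enc_output_dim]
    v: [dec_hidden_dim], W: [dec_hidden_dim, dec_hidden_dim],
    U: [dec_hidden_dim, enc_output_dim]
    """
    # W @ h_prev is [dec_hidden_dim] and broadcasts over the seq_len rows
    pre_activation = W @ h_prev + encoder_hidden_states @ U.T  # [seq_len, dec_hidden_dim]
    scores = torch.tanh(pre_activation) @ v                     # [seq_len]
    attention_weights = F.softmax(scores, dim=-1)               # [seq_len], sums to 1
    return attention_weights @ encoder_hidden_states            # [enc_output_dim]

# Illustrative usage with random tensors (sizes are assumptions)
seq_len, enc_output_dim, dec_hidden_dim = 10, 30, 20
context = additive_context_vector(
    torch.zeros(dec_hidden_dim),
    torch.randn(seq_len, enc_output_dim),
    torch.randn(dec_hidden_dim),
    torch.randn(dec_hidden_dim, dec_hidden_dim),
    torch.randn(dec_hidden_dim, enc_output_dim),
)
print(context.shape)
```

For short sequences the two versions should behave the same up to floating-point error; the vectorized form mainly avoids per-step Python overhead.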

**Sequence-to-sequence model.** We will use the following hyperparameters.

In [6]:
# Typically, encoder/decoder hidden dimensions are the same,
# but here we choose them differently to test our implementation.
embedding_dim = 10
enc_hidden_dim = 15
dec_hidden_dim = 20
num_layers = 2
bidirectional = True
num_directions = 2 if bidirectional else 1

Now we define the model.

In [7]:
class Seq2seqLSTMWithAttention(nn.Module):
    
    def __init__(self, embedding_dim, enc_hidden_dim, num_enc_layers, bidirectional, dec_hidden_dim):
        super().__init__()
        num_directions = 2 if bidirectional else 1
        encoder_output_dim = enc_hidden_dim * num_directions
        self.encoder = nn.LSTM(embedding_dim, enc_hidden_dim, num_layers=num_enc_layers, bidirectional=bidirectional)
        self.decoder = DecoderLSTMWithAttention(embedding_dim, encoder_output_dim, dec_hidden_dim)

        self.encoder.reset_parameters()
        self.decoder.reset_parameters()

    def forward(self, x, y, h0, c0):
        encoder_hidden_states, _ = self.encoder(x, (h0, c0))
        decoder_outputs, (hn_dec, cn_dec) = self.decoder(y, encoder_hidden_states)
        return decoder_outputs, (hn_dec, cn_dec)

Try your module with an example input.

In [8]:
model = Seq2seqLSTMWithAttention(embedding_dim, enc_hidden_dim, num_layers, bidirectional, dec_hidden_dim)
x = torch.randn(10, embedding_dim)
y = torch.randn(8, embedding_dim)
h0 = torch.zeros(num_layers * num_directions, enc_hidden_dim)
c0 = torch.zeros(num_layers * num_directions, enc_hidden_dim)
outputs, _ = model(x, y, h0, c0)
assert list(outputs.shape) == [8, dec_hidden_dim], "Wrong output shape"

Create a subclass of your decoder LSTM that implements the other type of attention (additive or multiplicative) that you haven't implemented above. What do you need to change?

In [9]:
class DecoderLSTMWithAdditiveAttention(DecoderLSTMWithAttention):
    # or:  DecoderLSTMWithMultiplicativeAttention
    
    def __init__(self, input_dim, enc_output_dim, dec_hidden_dim):
        super().__init__(input_dim, enc_output_dim, dec_hidden_dim)
        self.W = nn.Parameter(torch.randn(dec_hidden_dim, dec_hidden_dim))
        self.U = nn.Parameter(torch.randn(dec_hidden_dim, enc_output_dim))
        self.v = nn.Parameter(torch.randn(dec_hidden_dim))

    def reset_parameters(self):
        # Do not call super().reset_parameters(): it calls self.U.reset_parameters(),
        # which does not exist now that U is a raw nn.Parameter.
        # Initialize every parameter directly instead.
        self.down_project.reset_parameters()
        nn.init.normal_(self.U, mean=0, std=1)
        nn.init.normal_(self.W, mean=0, std=1)
        nn.init.normal_(self.v, mean=0, std=1)
        self.cell.reset_parameters()

    def compute_context_vector(self, previous_decoder_hidden_state, encoder_hidden_states):
        scores = []

        for encoder_hidden_state in encoder_hidden_states:
            attention_score = self.v @ torch.tanh(self.W @ previous_decoder_hidden_state + self.U @ encoder_hidden_state)
            scores.append(attention_score)

        scores = torch.stack(scores)
        attention_weights = F.softmax(scores, dim=-1)

        context_vector = attention_weights @ encoder_hidden_states
        return context_vector
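
The cell above scores each encoder state with the additive form $v^\top \tanh(W h_{t-1} + U e_i)$. For comparison, a multiplicative (bilinear) score $h_{t-1}^\top W e_i$ needs only a single weight matrix of shape `[dec_hidden_dim, enc_output_dim]` and no `v` or `U`. The standalone function below is a minimal sketch under that assumption; `multiplicative_context_vector` is a hypothetical name, not part of the notebook's classes.

```python
import torch
import torch.nn.functional as F

def multiplicative_context_vector(h_prev, encoder_hidden_states, W):
    """Multiplicative attention for one unbatched decoder step.

    h_prev: [dec_hidden_dim]
    encoder_hidden_states: [seq_len, enc_output_dim]
    W: [dec_hidden_dim, enc_output_dim]
    """
    # Project the decoder state into the encoder space once, then take a
    # dot product with every encoder state: score_i = h_prev^T W e_i
    projected_query = W.T @ h_prev                      # [enc_output_dim]
    scores = encoder_hidden_states @ projected_query    # [seq_len]
    attention_weights = F.softmax(scores, dim=-1)       # [seq_len], sums to 1
    return attention_weights @ encoder_hidden_states    # [enc_output_dim]

# Illustrative usage with random tensors (sizes are assumptions)
seq_len, enc_output_dim, dec_hidden_dim = 10, 30, 20
mult_context = multiplicative_context_vector(
    torch.randn(dec_hidden_dim),
    torch.randn(seq_len, enc_output_dim),
    torch.randn(dec_hidden_dim, enc_output_dim),
)
print(mult_context.shape)
```

Inside a decoder subclass, `W` would be registered as an `nn.Parameter` in `__init__` and initialized in `reset_parameters`, mirroring the pattern used for the additive version.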
        

We can test our implementation with the code below.

In [10]:
enc_output_dim = enc_hidden_dim * num_directions
# Uncomment the version you just implemented
model.decoder = DecoderLSTMWithAdditiveAttention(embedding_dim, enc_output_dim, dec_hidden_dim)
# model.decoder = DecoderLSTMWithMultiplicativeAttention(embedding_dim, enc_output_dim, dec_hidden_dim)
model.decoder.reset_parameters()
outputs, _ = model(x, y, h0, c0)
assert list(outputs.shape) == [8, dec_hidden_dim], "Wrong output shape"