In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline

# Seed the global random number generators of Python's `random`, NumPy and
# PyTorch with the same fixed value.
random.seed(5)
np.random.seed(5)
torch.random.manual_seed(5)

### Question 2

In this exercise, you will implement a sequence-to-sequence network that reverses strings with the help of attention. We will randomly generate strings consisting of "a", "b", "c", and "d".

In [None]:
# Special tokens added around every sequence by `to_tensor`.
BOS = "<s>"
EOS = "</s>"

# Index <-> symbol tables: the two special tokens come first, followed by the
# raw characters the strings are built from.
raw_vocab = ["a", "b", "c", "d"]
itos = [BOS, EOS, *raw_vocab]
stoi = dict(zip(itos, range(len(itos))))
vocab_size = len(itos)  # raw characters plus BOS/EOS

# Number of training and validation strings to sample.
N = 200
valid_size = 100

def sample_string(min_length, max_length):
    """Draw one random string over `raw_vocab`.

    The length is drawn with `random.randrange(min_length, max_length)`, so
    `max_length` itself is excluded. Each character is then picked
    independently with `random.choice`.
    """
    n_chars = random.randrange(min_length, max_length)
    chars = []
    for _ in range(n_chars):
        chars.append(random.choice(raw_vocab))
    return "".join(chars)

def sample_strings(min_length, max_length, size):
    """Return a list of `size` strings, each drawn with `sample_string`."""
    samples = []
    for _ in range(size):
        samples.append(sample_string(min_length, max_length))
    return samples

def to_tensor(name):
    """Encode an iterable of symbols as a 1 x (length + 2) LongTensor.

    The symbol indices are looked up in `stoi` and wrapped in the BOS and EOS
    indices; the leading dimension of size 1 acts as the batch dimension.
    """
    token_ids = [stoi[BOS]]
    token_ids.extend(stoi[symbol] for symbol in name)
    token_ids.append(stoi[EOS])
    return torch.tensor([token_ids], dtype=torch.long)

def make_dataset(lines):
    """Build a list of (source, target) tensor pairs, one per line.

    The source is the encoded line and the target is the encoded line read
    backwards.
    """
    pairs = []
    for line in lines:
        source = to_tensor(line)
        target = to_tensor(reversed(line))
        pairs.append((source, target))
    return pairs

# Sample the training strings first and the validation strings second. Both
# calls draw from the same global `random` stream, so this order determines
# which strings end up in which split.
train_lines = sample_strings(3, 15, N)
valid_lines = sample_strings(3, 15, valid_size)

# Turn every string into a (source, reversed target) pair of index tensors.
train_dataset = make_dataset(train_lines)
valid_dataset = make_dataset(valid_lines)

# Show a few of the sampled training strings.
print(train_lines[:10])

The first part of the model is an RNN-based encoder:

In [None]:
class Encoder(nn.Module):
    """Embeds a source sequence and encodes it with a batch-first LSTM.

    With `bidirectional=True` each direction gets `hidden_size // 2` units, so
    the concatenated per-step output has `hidden_size` features when
    `hidden_size` is even.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, bidirectional=False):
        super().__init__()

        self.embeddings = nn.Embedding(vocab_size, embedding_size)

        # Split the hidden units across the two directions when bidirectional.
        per_direction_size = hidden_size // 2 if bidirectional else hidden_size
        self.rnn = nn.LSTM(
            input_size=embedding_size,
            hidden_size=per_direction_size,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, input, hidden=None):
        """
        input (LongTensor): batch x src length
        hidden: optional initial hidden/cell state passed to the RNN
        returns:
            output (FloatTensor): batch x src length x hidden size
            hidden_n: final hidden/cell state of the RNN; for a bidirectional
                RNN the two directions are concatenated along the feature
                dimension (see `_reshape_hidden`)
        """
        embedded = self.embeddings(input)
        output, final_state = self.rnn(embedded, hidden)
        if not self.rnn.bidirectional:
            return output, final_state
        return output, self._reshape_hidden(final_state)

    def _merge_tensor(self, state_tensor):
        """Concatenate even-indexed (forward) and odd-indexed (backward) rows
        of `state_tensor` along dimension 2."""
        fwd = state_tensor[0::2]
        bwd = state_tensor[1::2]
        return torch.cat((fwd, bwd), dim=2)

    def _reshape_hidden(self, hidden):
        """
        hidden:
            num_layers * num_directions x batch x per-direction hidden size,
            or a tuple of such tensors
        returns:
            the same structure with the two directions merged by
            `_merge_tensor`, i.e. each tensor has half as many rows and twice
            as many features
        """
        assert self.rnn.bidirectional
        if not isinstance(hidden, tuple):
            return self._merge_tensor(hidden)
        return tuple(map(self._merge_tensor, hidden))

We also need to define a decoder. This implementation works both with and without an attention mechanism.

In [None]:
class Decoder(nn.Module):
    """LSTM decoder with an optional attention layer on top of the RNN states.

    `attn`, when given, is called as `attn(rnn_output, context)` and must
    return a pair `(new_output, alignment)`.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, attn=None):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, vocab_size)
        self.attn = attn

    def forward(self, input, context, hidden):
        """
        input (LongTensor): batch x tgt length
        context (FloatTensor): batch x src length x hidden size
        hidden: initial hidden/cell state for the decoder RNN
        returns:
            a pair (logits, alignment) where logits is a FloatTensor of shape
            (batch*tgt length) x vocab size, and alignment is whatever the
            attention layer returned (None when there is no attention layer)
        """
        embedded = self.embeddings(input)
        states, _ = self.rnn(embedded, hidden)

        # Without attention the raw RNN states are scored directly; with it,
        # the RNN states act as queries against the source context.
        if self.attn is None:
            alignment = None
        else:
            states, alignment = self.attn(states, context)

        # Flatten batch and time so every decoding step is one row of logits.
        flat_states = states.reshape(-1, self.rnn.hidden_size)
        return self.output_layer(flat_states), alignment

We can put them together into an encoder-decoder model class, like this:

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper: encodes the source, then decodes the target
    starting from the encoder's final state."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt):
        """
        src, tgt (LongTensor): (batch size x sequence length)
        returns: whatever the decoder returns when called with `tgt`, the
            encoder outputs as `context` and the encoder's final state as
            `hidden`
        """
        encoder_outputs, final_state = self.encoder(src)
        return self.decoder(tgt, context=encoder_outputs, hidden=final_state)

With our base model defined, we can write training and validation code:

In [None]:
def train_epoch(model, train_iter, loss, optimizer):
    """Run one pass over `train_iter`, updating the model after every example.

    `train_iter` is shuffled in place before the pass. For each
    (source, target) pair the model is fed the target without its last token
    and is scored against the target without its first token.

    Returns the sum of the per-example loss values.
    """
    model.train()
    random.shuffle(train_iter)  # present examples in random order

    running_loss = 0.0
    for source, target in train_iter:
        model.zero_grad()

        decoder_input = target[:, :-1]
        expected = target[:, 1:].contiguous().view(-1)
        logits, _ = model(source, decoder_input)

        step_loss = loss(logits, expected)
        step_loss.backward()
        optimizer.step()

        running_loss += step_loss.item()
    return running_loss


def validate(model, data_iter):
    """Return the token-level accuracy of `model` on `data_iter`.

    The model is put in eval mode and run without gradient tracking. As in
    training, it is fed the gold target without its last token, and its
    argmax predictions are compared against the target without its first
    token.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for source, target in data_iter:
            decoder_input = target[:, :-1]
            expected = target[:, 1:].contiguous().view(-1)
            predicted = model(source, decoder_input)[0].argmax(dim=1)
            hits += (predicted == expected).sum().item()
            seen += expected.size(0)
    return hits / seen


def train(model, train, valid, epochs=30, learning_rate=0.5):
    """Train `model` with plain SGD and report progress after every epoch.

    model: module whose parameters are optimized
    train: training examples, passed to `train_epoch` (which shuffles them
        in place)
    valid: validation examples, passed to `validate`
    epochs (int): number of passes over the training examples
    learning_rate (float): SGD learning rate

    returns:
        (train_losses, valid_accs): two lists with one entry per epoch, the
        summed training loss and the validation accuracy. Previously these
        were collected but discarded. In a notebook, end the calling line
        with `;` if the returned history should not be displayed.
    """
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    train_losses = []
    valid_accs = []
    # Iterate over the range directly instead of rebinding `epochs` to a list.
    for epoch in range(1, epochs + 1):
        print('Training epoch {}'.format(epoch))
        train_loss = train_epoch(model, train, loss, optimizer)
        train_losses.append(train_loss)
        valid_acc = validate(model, valid)
        valid_accs.append(valid_acc)
        print('Train loss: {} ; Validation acc: {}'.format(train_loss, valid_acc))
    return train_losses, valid_accs

Train a unidirectional model without attention:

In [None]:
# The embedding size equals the vocabulary size so that the identity matrix
# assigned below turns every token index into its one-hot vector.
embedding_size = vocab_size
hidden_size = 64

enc = Encoder(vocab_size, embedding_size, hidden_size)
dec = Decoder(vocab_size, embedding_size, hidden_size)
# Replace the randomly initialised embeddings with the identity matrix; the
# decoder is given the very same tensor as the encoder.
enc.embeddings.weight.data = torch.eye(vocab_size)
dec.embeddings.weight.data = enc.embeddings.weight.data
# Exclude the embeddings from gradient computation so they stay one-hot.
enc.embeddings.weight.requires_grad = False
dec.embeddings.weight.requires_grad = False

model = Seq2Seq(enc, dec)
print(model)

# Train the unidirectional, attention-free baseline.
train(model, train_dataset, valid_dataset, epochs=50)

This model often manages to predict the right sequence, but it also often fails. Note that the decoder is initialized with the *last* hidden state of the encoder. That state was computed right after reading the final time step of the source sequence (in other words, the first element the decoder needs to predict), whereas the earlier source elements were read longer ago and are harder to recover from it. If only there were a way to make it easier for the model to focus on less recent positions...

The attention mechanism is an extra layer for the decoder that can do precisely this. In this exercise, we consider a simple but effective attention mechanism called *dot product attention*.


In [None]:
class DotProdAttention(nn.Module):
    """Dot-product attention followed by a tanh combination layer.

    Every query vector is scored against every context vector with a dot
    product. The softmax-normalised scores weight the context vectors into one
    summary per query, and `mlp` combines that summary with the query itself.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Maps the concatenation [context summary; query] back to hidden_size.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Tanh()
        )

    def forward(self, query, context):
        """
        query: batch x tgt_length x hidden_size
        context: batch x src_length x hidden_size
        returns:
            attn_h_t: batch x tgt_length x hidden_size, the attentional
                hidden states
            alignment: batch x tgt_length x src_length, attention weights
                that sum to 1 over the source positions
        """
        # scores[b, t, s] = dot(query[b, t], context[b, s])
        scores = torch.bmm(query, context.transpose(1, 2))

        # Normalise over the source positions.
        alignment = F.softmax(scores, dim=-1)

        # Attention-weighted average of the context vectors for every query.
        weighted_context = torch.bmm(alignment, context)

        # Combine the summary with the query it was computed for.
        attn_h_t = self.mlp(torch.cat([weighted_context, query], dim=-1))

        return attn_h_t, alignment

Now that attention has been implemented, we can train the model:

In [None]:
# Same setup as the baseline, but with a bidirectional encoder and an
# attention layer inside the decoder. The bidirectional encoder halves its
# per-direction hidden size, so its outputs still have hidden_size features.
attn = DotProdAttention(hidden_size)
enc = Encoder(vocab_size, embedding_size, hidden_size, bidirectional=True)
dec = Decoder(vocab_size, embedding_size, hidden_size, attn=attn)
# Fixed one-hot embeddings: identity weights, shared by encoder and decoder,
# and excluded from gradient computation.
enc.embeddings.weight.data = torch.eye(vocab_size)
dec.embeddings.weight.data = enc.embeddings.weight.data
enc.embeddings.weight.requires_grad = False
dec.embeddings.weight.requires_grad = False

attn_model = Seq2Seq(enc, dec)
print(attn_model)

# Train the attention model.
train(attn_model, train_dataset, valid_dataset, epochs=30)

We can also visualize the model's attention matrix:

In [None]:
import matplotlib.pyplot as plt

string = "abacadabacc"  # try something
# A real string rather than a single-use `reversed` iterator, so the value
# can be inspected or reused after encoding.
reversed_string = string[::-1]

src = to_tensor(string)
tgt = to_tensor(reversed_string)

# Feed the decoder the target without its final EOS, exactly as in
# train_epoch/validate; otherwise the matrix gets an extra row for an input
# the decoder never sees during training.
tgt_in = tgt[:, :-1]

with torch.no_grad():
    _, alignment = attn_model(src, tgt_in)

# Drop the batch dimension before plotting.
attn_matrix = alignment.squeeze(0).numpy()
plt.matshow(attn_matrix)
# Assumes alignment is batch x tgt length x src length (rows are decoder
# steps, columns are source positions); swap the labels if the attention
# layer uses the other layout.
plt.xlabel("source position")
plt.ylabel("decoder step");