This notebook is copied from [*The Annotated Transformer*](http://nlp.seas.harvard.edu/2018/04/03/attention.html), with some modifications.

In [2]:
# Import dependencies.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
import numpy as np

# Use seaborn's "talk" plotting context for all figures below (presumably for
# larger, presentation-sized labels).
seaborn.set_context(context="talk")
# IPython magic: render matplotlib figures inline in the notebook.
% matplotlib inline

In [None]:
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder model, the base for this and many other models.

    The source is embedded and encoded into a "memory"; the target is then
    embedded and decoded against that memory. ``generator`` is stored on the
    module but is not applied by ``forward`` (presumably the caller applies it
    to the decoder output).
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode the masked source, then decode the masked target against it."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder with ``src_mask``."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder over it, attending to ``memory``."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

In [None]:
class Generator(nn.Module):
    """Map decoder features to log-probabilities over the vocabulary."""

    def __init__(self, d_model, vocab):
        super().__init__()
        # Linear projection from the model width to the vocabulary size.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Project ``x`` and take log-softmax over the last (vocabulary) axis."""
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)

## Encoder and Decoder Stacks
### Encoder
The encoder is composed of a stack of $N = 6$ identical layers.

In [None]:
def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    Because each entry is a ``copy.deepcopy``, the copies start with the same
    weights as ``module`` but do not share parameters with it or each other.
    """
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)

In [None]:
class Encoder(nn.Module):
    """Stack of ``N`` cloned encoder layers followed by a final LayerNorm."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # ``layer.size`` is read as an attribute and used as LayerNorm's feature count.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` through every layer (same ``mask`` each time), then normalize."""
        hidden = x
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, mask)
        return self.norm(hidden)

We employ a **residual connection** around each of the two sub-layers,
followed by **layer normalization**.

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain and bias.

    Note: this divides by ``std + eps``, where ``std`` is ``Tensor.std``
    (unbiased by default). ``torch.nn.LayerNorm`` instead divides by
    ``sqrt(var + eps)`` using the biased variance, so results differ slightly.
    """

    def __init__(self, n_features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(n_features))   # gain, initialised to 1
        self.b_2 = nn.Parameter(torch.zeros(n_features))  # bias, initialised to 0
        self.eps = eps  # keeps the denominator away from zero

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        # Same evaluation order as (a_2 * centered) / spread + b_2.
        return self.a_2 * centered / spread + self.b_2

Note that the output of each sub-layer is $\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x)))$,
where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself.
We apply **dropout** to the output of each sub-layer, **before** it is added to the sub-layer input and normalized.

To facilitate these residual connections, all sub-layers in the model,
as well as the embedding layers, produce outputs of dimension $d_{model} = 512$.

In [None]:
class SublayerConnection(nn.Module):
    """
    Post-norm residual wrapper computing ``norm(x + dropout(sublayer(x)))``.

    Dropout is applied to the sub-layer output before the residual addition,
    and LayerNorm is applied after it, as in the paper.
    """

    def __init__(self, size, p_dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to ``x`` and combine via residual add + LayerNorm.

        ``sublayer`` must return a tensor that can be added to ``x``.
        """
        summed = x + self.dropout(sublayer(x))
        return self.norm(summed)

Each **Encoder** layer has two sub-layers. The first is a multi-headed
self-attention mechanism, and the second is a simple, position-wise fully-connected feed-forward layer.

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each residual-wrapped."""

    def __init__(self, size, self_attn, feed_forward, p_dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size=size, p_dropout=p_dropout), N=2)
        # Exposed so a containing stack (e.g. Encoder) can size its LayerNorm.
        self.size = size

    def forward(self, x, mask):
        """Run masked self-attention, then the position-wise feed-forward."""
        def attend(h):
            # Queries, keys and values all come from the same sequence.
            return self.self_attn(h, h, h, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

### Decoder
The decoder is also composed of a stack of $N = 6$ identical layers.

In [None]:
class Decoder(nn.Module):
    """Stack of ``N`` cloned decoder layers followed by a final LayerNorm."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # ``layer.size`` is a plain attribute of the layer (DecoderLayer stores
        # it as ``self.size``), not a method call.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Decode ``x`` against ``memory`` through every layer, then normalize."""
        out = x
        for dec_layer in self.layers:
            out = dec_layer(out, memory, src_mask, tgt_mask)
        return self.norm(out)

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-headed attention over the output of the **encoder stack**. Similar to the encoder, we employ **residual connections** around each of the sub-layers, followed by **layer normalization**.

In [None]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, source attention, feed-forward.

    Each of the three sub-layers is wrapped in its own SublayerConnection.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, p_dropout):
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size=size, p_dropout=p_dropout), N=3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Attend to earlier target positions, then to ``memory`` (encoder output), then feed-forward."""
        masked_self, cross, ffn = self.sublayer
        # Self-attention restricted by tgt_mask.
        h = masked_self(x, lambda t: self.self_attn(t, t, t, tgt_mask))
        # Queries from the decoder; keys/values from the encoder memory.
        h = cross(h, lambda t: self.src_attn(t, memory, memory, src_mask))
        return ffn(h, self.feed_forward)

We also modify the self-attention sub-layer in the **decoder stack** to prevent positions from attending to *subsequent* positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the *known* outputs at positions less than $i$.

In [None]:
def subsequent_mask(size):
    """Return a ``(1, size, size)`` boolean mask hiding future positions.

    Entry ``[0, i, j]`` is True when position ``i`` may attend to position
    ``j``, i.e. when ``j <= i``.
    """
    # Ones strictly above the diagonal mark "future" positions; compare to 0 to allow the rest.
    future = torch.triu(torch.ones(1, size, size, dtype=torch.uint8), diagonal=1)
    return future == 0

In [None]:
# Visualize subsequent_mask(20): each tgt word (row) may attend only to the
# positions (columns) where the mask is True, i.e. itself and earlier words.
# Words are blocked from attending to future words during training.
plt.figure(figsize=(5, 5))
plt.imshow(subsequent_mask(20)[0])
None  # Presumably here to stop the notebook from echoing imshow's return value.

### Attention
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a *compatibility function* of the query with the corresponding key.

We call our particular attention “Scaled Dot-Product Attention”. The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.

In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$. The keys and values
are also packed together into matrices $K$ and $V$. We compute the matrix of outputs as:
$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute Scaled Dot-Product Attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value``, where ``d_k`` is
    the size of the last dimension of ``query``.

    Args:
        query, key: tensors whose last dimension is ``d_k``. The matmul is over
            the last two dimensions; leading dimensions broadcast.
        value: tensor whose second-to-last dimension matches that of ``key``.
        mask: optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` are excluded from attention.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        ``(output, p_attn)``: the attention-weighted sum of ``value`` and the
        attention weights (after dropout, if given).
    """
    d_k = query.size(-1)
    # ``d_k`` is a Python int, so use math.sqrt; torch.sqrt only accepts tensors.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Large finite negative rather than -inf: masked weights underflow to 0,
        # and a fully-masked row yields a uniform softmax instead of NaN.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

Multi-head attention allows the model to jointly attend to information from **different representation subspaces** at different positions. With a single attention head, averaging inhibits this.
$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O    \\
    \text{where}~\mathrm{head_i} = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
$$

Where the projections are **parameter matrices** $W^Q_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W^K_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W^V_i \in \mathbb{R}^{d_{\text{model}} \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$.

In this work we employ $h=8$ parallel attention layers, or heads. For each of these we use $d_k=d_v=d_{\text{model}}/h=64$. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

In [None]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention with ``h`` heads over a ``d_model``-wide input.

    Four ``d_model -> d_model`` linear layers are used: the first three project
    query, key and value; the last one projects the concatenated head outputs.
    The attention weights from the latest call are stored in ``self.attn``.
    """

    def __init__(self, h, d_model, p_dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        # Per-head width; values use the same width as keys (d_v == d_k).
        head_dim = d_model // h
        self.d_v = self.d_k = head_dim
        self.h = h  # number of attention heads
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # overwritten with the attention weights on each forward()
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value`` across all heads in parallel."""
        if mask is not None:
            # Insert a head axis so the same mask broadcasts over all h heads.
            mask = mask.unsqueeze(1)
        batch = query.size(0)
        q_proj, k_proj, v_proj, out_proj = self.linears

        def split_heads(proj, tensor):
            # Assuming (batch, seq, d_model) input -> (batch, h, seq, d_k); TODO confirm.
            return proj(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q = split_heads(q_proj, query)
        k = split_heads(k_proj, key)
        v = split_heads(v_proj, value)

        heads, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # Re-join the heads along the feature axis, then apply the output projection.
        merged = heads.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return out_proj(merged)
