# Transformer Networks

In [None]:
!pip install d2l==1.0.3

import math
import torch
from torch import nn
from d2l import torch as d2l
import pandas as pd

## Queries, Keys, Values

Given a database $\mathcal{D} = \{(k_1, v_1), \dots, (k_m, v_m)\}$ containing $m$ key-value tuples and a query $q$, it is possible to define the attention of the query over the database as $A(q, \mathcal{D}) = \sum_{i = 1}^m \alpha(q, k_i)v_i$, where $\alpha$ denotes a collection of normalized attention weights.

## Multi-Head Attention

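Multi-head attention runs several attention heads in parallel, each on its own learned linear projection of the queries, keys and values, so that different heads can extract different features; the outputs of all heads are then concatenated and linearly projected to obtain the final result.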

In [None]:
class MultiHeadAttention(d2l.Module):
    """Multi-head attention computed for all heads in parallel.

    Queries, keys and values are each passed through their own lazily
    initialised linear layer, split into ``num_heads`` heads, fed to
    ``d2l.DotProductAttention`` and finally merged and projected by ``W_o``.

    The helper methods ``transpose_qkv`` and ``transpose_output`` are
    attached to this class separately (via ``d2l.add_to_class``).

    Args:
        num_hiddens: Output size of every linear projection.
        num_heads: Number of attention heads.
        dropout: Dropout rate handed to ``d2l.DotProductAttention``.
        bias: Whether the linear projections use a bias term.
        **kwargs: Accepted for call compatibility; not used here.
    """

    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Apply multi-head attention.

        Shape of queries, keys, or values:
            (batch_size, no. of queries or key-value pairs, num_hiddens).
        Shape of valid_lens: (batch_size,) or (batch_size, no. of queries),
            or None to skip the repetition step below.
        Returns a tensor of shape (batch_size, no. of queries, num_hiddens).
        """
        # After projecting and transposing, each tensor has shape
        # (batch_size * num_heads, no. of queries or key-value pairs,
        #  num_hiddens / num_heads).
        q_heads = self.transpose_qkv(self.W_q(queries))
        k_heads = self.transpose_qkv(self.W_k(keys))
        v_heads = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item num_heads times, then the next
            # item, and so on, so the lengths line up with the head layout.
            valid_lens = valid_lens.repeat_interleave(self.num_heads, dim=0)

        # Shape of per_head:
        #   (batch_size * num_heads, no. of queries, num_hiddens / num_heads).
        per_head = self.attention(q_heads, k_heads, v_heads, valid_lens)
        # Shape of merged: (batch_size, no. of queries, num_hiddens).
        merged = self.transpose_output(per_head)
        return self.W_o(merged)

In [None]:
# Transposition for parallel computation of multiple attention heads.
@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
    """Fold the attention heads of X into the batch axis.

    Shape of input X:
        (batch_size, no. of queries or key-value pairs, num_hiddens).
    Shape of output:
        (batch_size * num_heads, no. of queries or key-value pairs,
         num_hiddens / num_heads).
    """
    batch_size, num_items = X.shape[0], X.shape[1]
    # (batch_size, no. of queries or key-value pairs, num_heads,
    #  num_hiddens / num_heads)
    split = X.reshape(batch_size, num_items, self.num_heads, -1)
    # (batch_size, num_heads, no. of queries or key-value pairs,
    #  num_hiddens / num_heads)
    heads_first = split.permute(0, 2, 1, 3)
    # Merge the batch and head axes.
    return heads_first.reshape(-1, heads_first.shape[2], heads_first.shape[3])

# Reverse the operation defined in transpose_qkv.
@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
    """Unfold the heads from the batch axis and concatenate them.

    X is reshaped to (-1, num_heads, X.shape[1], X.shape[2]), the head axis
    is moved next to the last axis, and the two trailing axes are merged.
    """
    num_items, head_dim = X.shape[1], X.shape[2]
    per_head = X.reshape(-1, self.num_heads, num_items, head_dim)
    items_first = per_head.permute(0, 2, 1, 3)
    return items_first.reshape(items_first.shape[0], items_first.shape[1], -1)

In [None]:
# Shape check: the attention output must keep the query layout
# (batch_size, num_queries, num_hiddens).
num_hiddens = 100
num_heads = 5
batch_size = 2
num_queries = 4
num_kvpairs = 6
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(
    attention(X, Y, Y, valid_lens),
    (batch_size, num_queries, num_hiddens),
)

## Self-Attention

Self-attention makes comparisons between input vectors. <br>
However, self-attention on its own has no notion of order: permuting the input vectors simply permutes the outputs in the same way, so it might struggle on tasks where the order of the input sequence matters. <br>
For this reason, it can come in handy, for each input $x_j$, to add a positional encoding $p_j = pos(j)$, defined by a function $pos: \mathbb{N} \to \mathbb{R}^d$ that can either be learned (e.g., using lookup tables) or fixed a priori.

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A table ``P`` of shape (1, max_len, num_hiddens) is precomputed: even
    columns hold sines and odd columns hold cosines of
    ``position / 10000 ** (2j / num_hiddens)``.

    Args:
        num_hiddens: Size of the last axis of the inputs to encode.
        dropout: Dropout rate applied after adding the encoding.
        max_len: Number of positions to precompute.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))  # Create a long enough P.
        # Column vector of positions. It must be reshaped with -1 so that it
        # has max_len rows and broadcasts against the row of frequencies.
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        frequencies = torch.pow(
            10000,
            torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        X = positions / frequencies
        self.P[:, :, 0::2] = torch.sin(X)
        # With an odd num_hiddens there is one fewer odd column than even
        # columns, so trim the cosine table to the number of odd columns.
        self.P[:, :, 1::2] = torch.cos(X[:, :num_hiddens // 2])

    def forward(self, X):
        """Add the first ``X.shape[1]`` rows of ``P`` to X, then apply dropout."""
        # P is a plain tensor attribute, so it is moved to X's device here.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

In [None]:
# Encode an all-zero input so the result equals the positional table, then
# plot columns 6 to 9 of that table against the position index.
encoding_dim = 32
num_steps = 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(
    torch.arange(num_steps),
    P[0, :, 6:10].T,
    xlabel='Row (position)',
    figsize=(6, 2.5),
    legend=["Col %d" % d for d in torch.arange(6, 10)],
)

## Transformers

Commonly used for natural language processing tasks, a transformer is a particular instance of the encoder-decoder architecture. <br>
In particular, the encoder takes the input data together with its positional encoding and processes it with a multi-head self-attention layer, followed by a residual connection (which helps gradient flow) and layer normalization; the result then goes through a position-wise multilayer perceptron, again followed by a residual connection and layer normalization, before producing the output. <br>
The decoder works similarly to the encoder, but it implements both a masked multi-head self-attention layer and a second multi-head attention layer that attends over the encoder outputs. <br>
Overall, transformers are highly scalable and parallelizable, although they require much more memory to store the needed parameters.

In [None]:
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network: linear, ReLU, linear.

    Args:
        ffn_num_hiddens: Output size of the first linear layer.
        ffn_num_outputs: Output size of the second linear layer.
    """

    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        # Input sizes are inferred on the first forward pass (lazy layers).
        self.dense1 = nn.LazyLinear(ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(ffn_num_outputs)

    def forward(self, X):
        """Return dense2(relu(dense1(X)))."""
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [None]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Args:
        norm_shape: Shape argument passed to ``nn.LayerNorm``.
        dropout: Dropout rate applied to the sublayer output.
    """

    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        """Return LayerNorm(dropout(Y) + X), X being the residual input."""
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)

Start by creating the encoder.

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each
    wrapped in an ``AddNorm`` (residual connection + layer normalization).

    Args:
        num_hiddens: Hidden size used by attention, normalization and the
            feed-forward output.
        ffn_num_hiddens: Hidden size inside the feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout rate for the attention and both ``AddNorm`` layers.
        use_bias: Forwarded as the fourth positional argument of
            ``d2l.MultiHeadAttention``.
    """

    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias=False):
        super().__init__()
        self.attention = d2l.MultiHeadAttention(
            num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens):
        """Run self-attention (queries = keys = values = X), then the FFN."""
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

In [None]:
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder: embedding, positional encoding and a stack of
    ``TransformerEncoderBlock`` modules.

    Args:
        vocab_size: Number of rows of the embedding table.
        num_hiddens: Embedding size and hidden size of every block.
        ffn_num_hiddens: Hidden size of each block's feed-forward network.
        num_heads: Number of attention heads per block.
        num_blks: Number of stacked encoder blocks.
        dropout: Dropout rate used throughout.
        use_bias: Forwarded to every ``TransformerEncoderBlock``.
    """

    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for blk_idx in range(num_blks):
            block = TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias)
            self.blks.add_module(f"block{blk_idx}", block)

    def forward(self, X, valid_lens):
        """Encode X and record each block's attention weights.

        After the call, ``self.attention_weights[i]`` holds the attention
        weights read from block ``i``.
        """
        # Since positional encoding values are between -1 and 1, the
        # embedding values are multiplied by the square root of the embedding
        # dimension to rescale before they are summed up.
        scaled_embedding = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled_embedding)
        self.attention_weights = [None] * len(self.blks)
        for blk_idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            weights = blk.attention.attention.attention_weights
            self.attention_weights[blk_idx] = weights
        return X

Then, implement the decoder.

In [None]:
class TransformerDecoderBlock(nn.Module):
    """The i-th decoder block: decoder self-attention, encoder-decoder
    attention and a feed-forward network, each followed by ``AddNorm``.

    Args:
        num_hiddens: Hidden size used by attention and normalization.
        ffn_num_hiddens: Hidden size inside the feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout rate used throughout.
        i: Index of this block; selects its slot in ``state[2]``.
    """

    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, X, state):
        """Run the block on X and return ``(output, state)``.

        ``state`` is ``[enc_outputs, enc_valid_lens, per_block_cache]``; the
        cache slot of this block is updated in place.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]

        # During training, all the tokens of any output sequence are
        # processed at the same time, so state[2][self.i] is None as
        # initialized. When decoding any output sequence token by token
        # during prediction, state[2][self.i] contains representations of the
        # decoded output at the i-th block up to the current time step.
        cached = state[2][self.i]
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values

        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps), where every
            # row is [1, 2, ..., num_steps].
            steps = torch.arange(1, num_steps + 1, device=X.device)
            dec_valid_lens = steps.repeat(batch_size, 1)

        # Decoder self-attention.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: queries come from the decoder, keys and
        # values from the encoder outputs.
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

In [None]:
class TransformerDecoder(d2l.AttentionDecoder):
    """Transformer decoder: embedding, positional encoding, a stack of
    ``TransformerDecoderBlock`` modules and a final linear layer of size
    ``vocab_size``.

    Args:
        vocab_size: Embedding table size and output size of the last layer.
        num_hiddens: Embedding size and hidden size of every block.
        ffn_num_hiddens: Hidden size of each block's feed-forward network.
        num_heads: Number of attention heads per block.
        num_blks: Number of stacked decoder blocks.
        dropout: Dropout rate used throughout.
    """

    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for blk_idx in range(num_blks):
            block = TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, blk_idx)
            self.blks.add_module(f"block{blk_idx}", block)
        self.dense = nn.LazyLinear(vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        """Build the decoder state with one empty cache slot per block."""
        per_block_cache = [None] * self.num_blks
        return [enc_outputs, enc_valid_lens, per_block_cache]

    def forward(self, X, state):
        """Decode X and return ``(dense(output), state)``.

        Also stores the attention weights: ``_attention_weights[0][i]`` for
        the first attention of block ``i`` and ``_attention_weights[1][i]``
        for its second attention.
        """
        scaled_embedding = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled_embedding)
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        self_attn, enc_dec_attn = self._attention_weights
        for blk_idx, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights.
            self_attn[blk_idx] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights.
            enc_dec_attn[blk_idx] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        """Attention weights recorded by the most recent forward pass."""
        return self._attention_weights

Having created the transformer, try to train it. <br>
Then, after training, use the model to try and translate some sentences from English to French.

In [None]:
# Load the translation dataset. The variable must be called `data`: the lines
# below and the later cells read `data.src_vocab`, `data.tgt_vocab`,
# `data.build` and `data.num_steps`.
data = d2l.MTFraEng(batch_size=128)

# Model hyperparameters.
num_hiddens, num_blks, dropout = 256, 2, 0.2
ffn_num_hiddens, num_heads = 64, 4

# Encoder and decoder use their own vocabulary sizes but share every other
# hyperparameter.
encoder = TransformerEncoder(
    len(data.src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_blks,
    dropout)
decoder = TransformerDecoder(
    len(data.tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_blks,
    dropout)
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'], lr=0.001)

# Train for 30 epochs with gradient clipping.
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

In [None]:
# English source sentences and their reference French translations.
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
preds, _ = model.predict_step(
    data.build(engs, fras), d2l.try_gpu(), data.num_steps)

for source, reference, predicted in zip(engs, fras, preds):
    # Keep the predicted tokens up to, but not including, the first '<eos>'.
    hypothesis = []
    for tok in data.tgt_vocab.to_tokens(predicted):
        if tok == '<eos>':
            break
        hypothesis.append(tok)
    score = d2l.bleu(" ".join(hypothesis), reference, k=2)
    print(f'{source} => {hypothesis}, bleu,{score:.3f}')

Show the encoder and decoder attention weights.

In [None]:
# Translate the last example again, this time asking predict_step for the
# decoder attention weights as well (fourth positional argument True).
_, dec_attention_weights = model.predict_step(
    data.build([engs[-1]], [fras[-1]]),
    d2l.try_gpu(),
    data.num_steps,
    True,
)

# Concatenate the per-block encoder weights along axis 0 and regroup them as
# (num_blks, num_heads, query positions, key positions).
enc_attention_weights = torch.cat(model.encoder.attention_weights, dim=0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = enc_attention_weights.reshape(shape)
d2l.check_shape(
    enc_attention_weights,
    (num_blks, num_heads, data.num_steps, data.num_steps),
)

In [None]:
# Heatmaps of the encoder self-attention weights, with one title per head.
d2l.show_heatmaps(
    enc_attention_weights.cpu(),
    xlabel='Key positions',
    ylabel='Query positions',
    titles=[f'Head {i}' for i in range(1, 5)],
    figsize=(7, 3.5),
)

In [None]:
# Flatten the nested decoder attention weights into one row per
# (step, attention kind, block, head) combination.
dec_attention_weights_2d = [
    head[0].tolist()
    for step in dec_attention_weights
    for attn in step
    for blk in attn
    for head in blk
]
# The rows can differ in length; the DataFrame pads the short ones with NaN,
# which is then replaced by 0.0.
dec_attention_weights_filled = torch.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
shape = (-1, 2, num_blks, num_heads, data.num_steps)
dec_attention_weights = dec_attention_weights_filled.reshape(shape)
# Move the axis of size 2 (self- vs. encoder-decoder attention) to the front
# and unpack along it. Parentheses are used for the line continuation: a
# backslash followed by more text on the same line is a syntax error.
dec_self_attention_weights, dec_inter_attention_weights = (
    dec_attention_weights.permute(1, 2, 3, 0, 4))

In [None]:
# Both decoder weight tensors must have the layout
# (num_blks, num_heads, data.num_steps, data.num_steps).
for _weights in (dec_self_attention_weights, dec_inter_attention_weights):
    d2l.check_shape(
        _weights, (num_blks, num_heads, data.num_steps, data.num_steps))

In [None]:
# Heatmaps of the decoder self-attention weights, with one title per head.
d2l.show_heatmaps(
    dec_self_attention_weights,
    xlabel='Key positions',
    ylabel='Query positions',
    titles=[f'Head {i}' for i in range(1, 5)],
    figsize=(7, 3.5),
)

In [None]:
# Heatmaps of the encoder-decoder attention weights, with one title per head.
d2l.show_heatmaps(
    dec_inter_attention_weights,
    xlabel='Key positions',
    ylabel='Query positions',
    titles=[f'Head {i}' for i in range(1, 5)],
    figsize=(7, 3.5),
)