# Transformer
The Transformer is based entirely on the attention mechanism, without any convolutional or recurrent layers. It was originally proposed for sequence-to-sequence learning on text data, but it has since been widely used across many deep learning tasks, such as language, vision and audio.
It is a classic encoder-decoder architecture, as shown in the figure below. The seq2seq model we implemented with Bahdanau attention in Section 10.4 was built differently: the Transformer's encoder and decoder are each a stack of self-attention-based layers. The embeddings of the source (input) sequence and the target (output) sequence are each added to positional encodings. The results are then fed into the encoder and the decoder, respectively.
![transformer](../statics/imgs/section10.7_fig1.jpg)
At a high level, the figure above shows that the Transformer encoder is a stack of identical layers, each of which has two sublayers. The first sublayer is multi-head self-attention pooling, and the second is a positionwise feed-forward network. Specifically, in encoder self-attention the queries, keys and values all come from the output of the previous encoder layer. In addition, each sublayer is wrapped in a residual connection followed by layer normalization.
The decoder is likewise a stack of identical layers, and it also uses residual connections and layer normalization. Besides the two sublayers described for the encoder, the decoder inserts a third sublayer between them, called encoder-decoder attention. In encoder-decoder attention, the queries come from the outputs of the previous decoder sublayer, while the keys and values come from the encoder outputs. In decoder self-attention, however, each position may only attend to positions up to and including the current one. This masked attention preserves the auto-regressive property and ensures that each prediction depends only on output tokens that have already been generated.

In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l
import pandas as pd

In [2]:
# position-wise feed-forward network
class PositionWiseFFN(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied at every position.

    nn.Linear acts only on the last dimension, so an input of shape
    (..., ffn_num_input) becomes (..., ffn_num_outputs). All leading
    dimensions pass through unchanged, and every position shares the same
    weights.

    Args:
        ffn_num_input: size of the input's last dimension.
        ffn_num_hiddens: width of the hidden layer.
        ffn_num_outputs: size of the output's last dimension.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        # Submodules are registered as dense1, relu, dense2 so the module
        # repr and state_dict keys stay the same.
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

In [3]:
# The same MLP is applied at every position, so identical inputs give identical rows.
ffn = PositionWiseFFN(ffn_num_input=4, ffn_num_hiddens=4, ffn_num_outputs=8)
ffn.eval()
ffn(torch.ones(2, 3, 4))[0]

tensor([[ 0.4256, -0.0439, -0.2478,  0.1027,  0.3032, -0.6424, -0.0154, -0.6512],
        [ 0.4256, -0.0439, -0.2478,  0.1027,  0.3032, -0.6424, -0.0154, -0.6512],
        [ 0.4256, -0.0439, -0.2478,  0.1027,  0.3032, -0.6424, -0.0154, -0.6512]],
       grad_fn=<SelectBackward0>)

In [4]:
# residual & normalization
# LayerNorm normalizes each row (one sample across its features);
# BatchNorm1d (in training mode) normalizes each column across the batch.
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
ln, bn = nn.LayerNorm(2), nn.BatchNorm1d(2)
print('layer norm:', ln(X), '\nbatch norm:', bn(X))

layer norm: tensor([[-1.0000,  1.0000],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>) 
batch norm: tensor([[-1.0000, -1.0000],
        [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)


In [5]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization (post-norm).

    Computes ``LayerNorm(X + Dropout(Y))``, where ``X`` is a sublayer's input
    and ``Y`` is that sublayer's output. The normalization runs over the
    trailing ``normalized_shape`` dimensions.

    Args:
        normalized_shape: trailing shape passed to nn.LayerNorm.
        dropout: dropout probability applied to ``Y`` before the addition.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # Add the residual first, then normalize the sum.
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)

In [6]:
# AddNorm keeps the shape of its inputs.
add_norm = AddNorm(normalized_shape=[3, 4], dropout=0.5)
add_norm.eval()
ones = torch.ones((2, 3, 4))
add_norm(ones, ones).shape

torch.Size([2, 3, 4])

In [7]:
# encoder
class EncoderBlock(nn.Module):
    """One Transformer encoder layer made of two sublayers.

    1. Multi-head self-attention (queries, keys and values are all ``X``),
       wrapped in AddNorm.
    2. Position-wise feed-forward network, wrapped in AddNorm.

    The residual additions require ``X``'s last dimension to equal
    ``num_hiddens``. ``valid_lens`` is forwarded unchanged to the attention
    module (presumably to mask padded positions; see d2l.MultiHeadAttention).
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        # Registration order (attention, addNorm1, ffn, addNorm2) is kept so
        # the module repr and state_dict keys are unchanged.
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads,
            dropout, use_bias)
        self.addNorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addNorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addNorm1(X, attended)
        return self.addNorm2(Y, self.ffn(Y))

In [8]:
# An encoder block never changes the shape of its input.
X = torch.ones((2, 100, 24))  # 2 sequences, 100 steps, width 24 (= num_hiddens)
valid_lens = torch.tensor([3, 2])  # one valid length per sequence
encoder_blk = EncoderBlock(
    key_size=24, query_size=24, value_size=24, num_hiddens=24,
    norm_shape=[100, 24], ffn_num_input=24, ffn_num_hiddens=48,
    num_heads=8, dropout=0.5,
)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape

torch.Size([2, 100, 24])

In [9]:
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder built from several parts.

    The parts are a token embedding, positional encoding, and a stack of
    ``num_layers`` EncoderBlocks registered as ``block0``, ``block1``, and so on.

    After each forward pass, ``self.attention_weights[i]`` holds the
    attention weights recorded by the inner attention of block ``i``.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for idx in range(num_layers):
            block = EncoderBlock(key_size, query_size, value_size, num_hiddens,
                                 norm_shape, ffn_num_input, ffn_num_hiddens,
                                 num_heads, dropout, use_bias)
            self.blks.add_module(f"block{idx}", block)

    def forward(self, X, valid_lens, *args):
        # Positional encodings lie in [-1, 1], so the embeddings are rescaled
        # by sqrt(num_hiddens) before the positional encodings are added.
        embedded = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(embedded)
        self.attention_weights = [None] * len(self.blks)
        for idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # Keep this block's weights from its inner dot-product attention.
            self.attention_weights[idx] = blk.attention.attention.attention_weights
        return X

In [10]:
encoder = TransformerEncoder(
    vocab_size=200, key_size=24, query_size=24, value_size=24,
    num_hiddens=24, norm_shape=[100, 24], ffn_num_input=24,
    ffn_num_hiddens=48, num_heads=8, num_layers=2, dropout=0.5,
)
encoder.eval()

TransformerEncoder(
  (embedding): Embedding(200, 24)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (blks): Sequential(
    (block0): EncoderBlock(
      (attention): MultiHeadAttention(
        (attention): DotProductAttention(
          (dropout): Dropout(p=0.5, inplace=False)
        )
        (W_q): Linear(in_features=24, out_features=24, bias=False)
        (W_k): Linear(in_features=24, out_features=24, bias=False)
        (W_v): Linear(in_features=24, out_features=24, bias=False)
        (W_o): Linear(in_features=24, out_features=24, bias=False)
      )
      (addNorm1): AddNorm(
        (dropout): Dropout(p=0.5, inplace=False)
        (ln): LayerNorm((100, 24), eps=1e-05, elementwise_affine=True)
      )
      (ffn): PositionWiseFFN(
        (dense1): Linear(in_features=24, out_features=48, bias=True)
        (relu): ReLU()
        (dense2): Linear(in_features=48, out_features=24, bias=True)
      )
      (addNorm2): AddNorm(
        (d

In [11]:
tokens = torch.ones((2, 100), dtype=torch.long)  # 2 sequences of 100 token ids
encoder(tokens, valid_lens).shape

torch.Size([2, 100, 24])

In [None]:
# decoder