# <center> transformer框架 </center>
by Hyr1sky_He

Link to the lecture [68 Transformer【动手学深度学习v2】](https://www.bilibili.com/video/BV1Kq4y1H7FL/?share_source=copy_web&vd_source=7d2cf6f427cab8ff5afa3cb534b98123)

## Multihead Attention

In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

if torch.cuda.is_available():
    device = torch.device('cuda')

In [2]:
def transpose_qkv(X, num_heads):
    """Note
    For parrallel computation, we can concat the heads together
    INPUT: X.shape = (batch_size, num_steps, num_hiddens)
    OUTPUT: X.shape = (batch_size * num_heads, num_steps, num_hiddens/num_heads)
    combine the num_heads and num_hiddens/num_heads together
    """
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """Note
    Inverse of transpose_qkv
    INPUT: X.shape = (batch_size * num_heads, num_steps, num_hiddens/num_heads)
    OUTPUT: X.shape = (batch_size, num_steps, num_hiddens)
    """
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

In [3]:
class multiheadAttention (nn.Module):
    def __init__(self, query_size, key_size, value_size,
                 num_hiddens, num_heads, dropout, bias=False, 
                 *args, **kwargs) -> None:
        super(multiheadAttention, self).__init__(*args, **kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        """Note
        concat all the heads together for matrix multiplication
        """
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

In [4]:
# Test
num_hiddens, num_heads = 100, 5 
attention = multiheadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_steps = 2, 4
num_kvpairs, valid_lens = 2, torch.tensor([3, 2])
X = torch.ones((batch_size, num_steps, num_hiddens))
Y = torch.ones((batch_size, num_steps, num_hiddens))
attention(X, Y, Y, valid_lens).shape

multiheadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)


torch.Size([2, 4, 100])

## Transformer

In [5]:
# Feed Forward
# Actually, it's a two-layer MLP
class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, pw_num_outputs, **kwargs) -> None:
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, pw_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

In [10]:
ffn = PositionWiseFFN(4, 4, 12)
ffn.eval()
print(ffn)
ffn(torch.ones((2, 3, 4)))[0]

PositionWiseFFN(
  (dense1): Linear(in_features=4, out_features=4, bias=True)
  (relu): ReLU()
  (dense2): Linear(in_features=4, out_features=12, bias=True)
)


tensor([[ 0.2211,  0.0970, -0.4608, -0.0946,  0.2883, -0.4178,  0.1048, -0.3202,
          0.1748, -0.1873,  0.4107, -0.1500],
        [ 0.2211,  0.0970, -0.4608, -0.0946,  0.2883, -0.4178,  0.1048, -0.3202,
          0.1748, -0.1873,  0.4107, -0.1500],
        [ 0.2211,  0.0970, -0.4608, -0.0946,  0.2883, -0.4178,  0.1048, -0.3202,
          0.1748, -0.1873,  0.4107, -0.1500]], grad_fn=<SelectBackward0>)

In [14]:
"""
Compare Layer Normalization and Batch Normalization
"""
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# Layer Normalization
print('layer norm:', ln(X), '\nbatch norm:', bn(X))

layer norm: tensor([[-1.0000,  1.0000],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>) 
batch norm: tensor([[-1.0000, -1.0000],
        [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)


In [15]:
"""
Residual Connection
Layer Normalization
"""
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs) -> None:
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)

In [16]:
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape

torch.Size([2, 3, 4])

In [18]:
class EncoderBlock(nn.Module):
    def __init__(self, query_size, key_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, bias=False, **kwargs) -> None:
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = multiheadAttention(query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens)) # self-attention
        return self.addnorm2(Y, self.ffn(Y))

In [20]:
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape

torch.Size([2, 100, 24])

In [30]:
"""
- Self-Attention
- Add & Norm
- Feed Forward
- Add & Norm
- Encoder Block
"""
class TransformerEncoder(d2l.Encoder):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs) -> None:
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f'block{i}', EncoderBlock(query_size, key_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

In [29]:
# Two layer Transformer Encoder
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

In [32]:
class DecoderBlock(nn.Module):
    """No.i Block"""
    def __init__(self, query_size, key_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs) -> None:
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = multiheadAttention(query_size, key_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = multiheadAttention(query_size, key_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Masked Multihead Attention
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training: # train
            batch_size, num_steps, _ = X.shape
            # Shape: (batch_size, num_steps), the values in the attention weights
            # at the positions excluded from the attention
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None # predict
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens) # Masked Multihead Attention
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens) # Introduced syntax factor
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

In [33]:
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, None]
decoder_blk(X, state)[0].shape

TypeError: multiheadAttention.forward() takes 5 positional arguments but 6 were given