# Fuck Transformer

下图是Transformer的基本结构，先把编码器和解码器的原理搞懂。

<img src="assets/transformer.png"  width="400" />

In [1]:
# Restrict this process to physical GPU 2 (CUDA reads this variable when it
# initializes, so it must be set before any CUDA work happens).
import os

os.environ.update(CUDA_VISIBLE_DEVICES='2')

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
import math
# Notebook sanity check: the cell displays whether torch can see a CUDA device.
bool(torch.cuda.is_available())

True

In [5]:
# Dummy input batch laid out as (batch_size, seq_len, hidden_size).
X = torch.rand((128, 64, 512))
X.shape

torch.Size([128, 64, 512])

In [6]:
# Model width (size of each token's q/k/v representation) and number of attention heads.
d_model, n_heads = 512, 8

## Fuck Token Embedding

token embedding是将输入的token转换成向量

In [None]:
class TokenEmbedding(nn.Embedding):
    """Lookup table mapping token ids to dense ``d_model``-dimensional vectors.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding dimension.
        padding_idx: id of the padding token. nn.Embedding initializes its row
            to zeros and excludes it from gradient updates. Defaults to 1, the
            value that was previously hard-coded here.
    """

    def __init__(self, vocab_size, d_model, padding_idx=1):
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)

## Fuck Positional Encoding
$$
\begin{aligned}
PE_{(pos,\,2i)} &= \sin\left(pos / 10000^{2i / d_{\mathrm{model}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\left(pos / 10000^{2i / d_{\mathrm{model}}}\right)
\end{aligned}
$$

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    Row ``pos`` of the table holds sin(pos / 10000^(2i/d_model)) in the even
    columns 2i and cos of the same angle in the odd columns 2i+1.

    Args:
        d_model: width of each encoding row.
        max_len: number of positions precomputed; ``forward`` inputs must not
            be longer than this.
        device: device the table is created on.
    """

    def __init__(self, d_model, max_len=5000, device='cuda'):
        super(PositionalEncoding, self).__init__()
        # `device` must be passed by keyword: positionally, torch.zeros and
        # torch.arange treat it as a size/step argument and raise.
        encoding = torch.zeros(max_len, d_model, device=device)

        pos = torch.arange(0, max_len, device=device, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, 2, device=device, dtype=torch.float)  # (ceil(d_model/2),)
        angle = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model/2))

        # even indices (2i)
        encoding[:, 0::2] = torch.sin(angle)
        # odd indices (2i+1); when d_model is odd there is one fewer odd column
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        # A buffer is never a trainable parameter but does follow module.to(device).
        # persistent=False keeps it out of state_dict, as the plain attribute was.
        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x):
        """Return the first ``x.shape[1]`` rows of the table, shape (seq_len, d_model).

        Only ``x.shape[1]`` is read, so ``x`` may be token ids (batch, seq_len)
        or embeddings (batch, seq_len, d_model); the result broadcasts over the
        batch axis when added to (batch, seq_len, d_model) embeddings.
        """
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :]

In [None]:
# 可以把两个encoding结合起来~
class TransformerEmbedding(nn.Module):
    """Sum of token embedding and positional encoding, followed by dropout.

    Args:
        vocab_size: vocabulary size of the token lookup table.
        d_model: embedding width shared by both encodings.
        max_len: longest sequence the positional table covers.
        drop_prob: dropout probability applied to the summed embedding.
        device: device forwarded to PositionalEncoding.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super(TransformerEmbedding, self).__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.position_embedding = PositionalEncoding(d_model, max_len, device)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # x: token ids, presumably (batch, seq_len) -- TODO confirm with callers.
        # The positional rows (seq_len, d_model) broadcast over the batch axis.
        return self.dropout(self.token_embedding(x) + self.position_embedding(x))

## Fuck LN

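图像处理中一般用BN（Batch Normalization），而Transformer用的是LN（Layer Normalization）。BN沿batch维度统计均值和方差，结果依赖batch大小，对小batch和长度不一（padding较多）的序列不稳定，推理时还要依赖训练阶段累积的滑动统计量。LN则对每个样本单独在特征（hidden）维度上做归一化，与batch大小无关，训练和推理的计算方式完全一致。LN的主要优势在于不依赖batch统计量，而不是节省显存。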
$$
\begin{aligned}
\mu &= \frac{1}{d}\sum_{i=1}^{d}x_{i} \\
\sigma^{2} &= \frac{1}{d}\sum_{i=1}^{d}(x_{i}-\mu)^{2} \\
\hat{x}_{i} &= \frac{x_{i}-\mu}{\sqrt{\sigma^{2}+\epsilon}} \\
y_{i} &= \gamma_{i} \hat{x}_{i} + \beta_{i}
\end{aligned}
$$
其中 $d$ 为特征维度（即 $d_{\mathrm{model}}$），求和在单个样本的特征维上进行；$\gamma$、$\beta$ 为可学习的逐维缩放和平移参数。

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift.

    Each vector along the last axis is normalized to zero mean and unit
    (biased) variance, then scaled by ``gamma`` and shifted by ``beta``.

    Args:
        d_model: size of the input's last dimension.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-10):
        super(LayerNorm, self).__init__()
        # learnable per-feature scale and shift
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        # Must be stored: forward() reads self.eps.
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        # biased variance (divide by N), as in the formula and nn.LayerNorm
        var = x.var(-1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

## Fuck FFN

$$
\operatorname{FFN}(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2
$$

In [None]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied at every position: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        hidden_dim: width of the inner layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden_dim, dropout=0.1):
        # Must run before any submodule is assigned; otherwise nn.Module raises
        # "cannot assign module before Module.__init__() call".
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # nn.Linear acts on the last dimension, so leading dims pass through unchanged.
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

## Fuck Multi-Head Attention
$$ \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$
$$ \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O $$
$$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V), \quad d_k = d_{\text{model}} / h $$

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    q, k and v are each projected to ``d_model`` channels, split into
    ``n_heads`` heads of size ``d_model // n_heads``, attended per head,
    concatenated back and mixed by ``w_combine``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads

        # linear projections for q, k, v
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # output projection that mixes the concatenated heads
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (batch, length, d_model) -> (batch, n_heads, length, d_model // n_heads)."""
        batch, length, _ = t.shape
        # Each tensor uses its OWN length, so keys/values may be longer or
        # shorter than the queries (cross-attention).
        return t.view(batch, length, self.n_heads, self.d_model // self.n_heads).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, (batch, len_q, d_model).
            k: keys, (batch, len_k, d_model); len_k may differ from len_q.
            v: values, (batch, len_k, d_model).
            mask: optional tensor broadcastable to (batch, n_heads, len_q, len_k);
                scores where ``mask == 0`` are excluded from the softmax.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch, len_q, _ = q.shape
        n_d = self.d_model // self.n_heads

        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))

        score = q @ k.transpose(-2, -1) / math.sqrt(n_d)  # (batch, n_heads, len_q, len_k)
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        out = self.softmax(score) @ v  # (batch, n_heads, len_q, n_d)

        # merge heads back into the channel dimension
        out = out.permute(0, 2, 1, 3).contiguous().view(batch, len_q, self.d_model)
        return self.w_combine(out)
        

In [9]:
# Smoke test: run the layer as unmasked self-attention on the dummy batch X.
attention = MultiHeadAttention(d_model=d_model, n_heads=n_heads)
output = attention(X, X, X)
for item in (output, output.shape):
    print(item)

tensor([[[ 0.2861, -0.1101, -0.1069,  ...,  0.0108,  0.1135, -0.0261],
         [ 0.3062, -0.0061, -0.0564,  ...,  0.0058,  0.1606, -0.0263],
         [ 0.2674,  0.0042, -0.0396,  ..., -0.0093,  0.1470, -0.0329],
         ...,
         [ 0.2354,  0.0711,  0.0138,  ..., -0.0502,  0.0706, -0.0317],
         [ 0.2355,  0.0703,  0.0145,  ..., -0.0496,  0.0719, -0.0338],
         [ 0.2348,  0.0689,  0.0163,  ..., -0.0479,  0.0699, -0.0328]],

        [[ 0.3113,  0.0644, -0.0197,  ..., -0.0653,  0.0694,  0.0685],
         [ 0.2187,  0.0320, -0.0161,  ..., -0.1151,  0.0081,  0.0978],
         [ 0.2235, -0.0074,  0.0051,  ..., -0.1105,  0.0440,  0.0430],
         ...,
         [ 0.2411,  0.0832,  0.0254,  ..., -0.0854,  0.0786, -0.0326],
         [ 0.2420,  0.0843,  0.0254,  ..., -0.0897,  0.0798, -0.0304],
         [ 0.2403,  0.0845,  0.0270,  ..., -0.0889,  0.0787, -0.0318]],

        [[ 0.1825, -0.0087, -0.0757,  ...,  0.0404,  0.0239, -0.0520],
         [ 0.2673,  0.1149, -0.0148,  ..., -0

## Fuck Encoder

<img src="assets/encoder.png"  width="400" />

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Sub-layer 1: self-attention -> dropout -> add residual -> LayerNorm.
    Sub-layer 2: position-wise feed-forward -> dropout -> add residual -> LayerNorm.

    Args:
        d_model: model width.
        ffn_hidden: inner width of the feed-forward network.
        n_heads: number of attention heads.
        drop_prob: dropout probability (also passed to the feed-forward network).
    """

    def __init__(self, d_model, ffn_hidden, n_heads, drop_prob=0.1):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(drop_prob)
        self.ffn = PositionWiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(drop_prob)

    def forward(self, x, mask=None):
        # `mask` goes straight to self-attention, where entries == 0 are masked out.
        attended = self.dropout1(self.attention(x, x, x, mask))
        x = self.norm1(attended + x)

        transformed = self.dropout2(self.ffn(x))
        return self.norm2(transformed + x)

## Fuck Decoder

<img src="assets/decoder.png"  width="400" />

Decoder层和Encoder层的区别：
* 带因果mask的自注意力：当前位置看不到未来位置的信息
* 交叉注意力：Decoder提供q，Encoder提供k和v

In [None]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Three sub-layers, each followed by dropout, a residual add and LayerNorm:
      1. masked self-attention over the decoder input;
      2. cross-attention with queries from the decoder and keys/values from
         the encoder output (skipped when ``enc`` is None);
      3. position-wise feed-forward.

    Args:
        d_model: model width.
        ffn_hidden: inner width of the feed-forward network.
        n_heads: number of attention heads.
        drop_prob: dropout probability (also passed to the feed-forward network).
    """

    def __init__(self, d_model, ffn_hidden, n_heads, drop_prob=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(drop_prob)

        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(drop_prob)

        self.ffn = PositionWiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(drop_prob)

    def forward(self, dec, enc, t_mask, s_mask):
        """Run the decoder layer.

        Args:
            dec: decoder-side input (embedded target sequence or the previous
                decoder layer's output).
            enc: encoder output, or None to skip cross-attention.
            t_mask: target mask for the decoder SELF-attention. Transformer.forward
                passes the target padding mask combined with the lower-triangular
                causal mask, so a position cannot attend to later positions.
            s_mask: source mask for the CROSS-attention. Transformer.forward
                passes the target-query x source-key padding mask, which hides
                source padding tokens.

        Returns:
            Tensor with the same shape as ``dec``.
        """
        # 1) masked self-attention
        _x = dec
        x = self.self_attn(dec, dec, dec, t_mask)
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        # 2) cross-attention: q from the decoder, k/v from the encoder
        if enc is not None:
            _x = x
            x = self.cross_attn(x, enc, enc, s_mask)
            x = self.dropout2(x)
            x = self.norm2(x + _x)

        # 3) position-wise feed-forward
        _x = x
        x = self.ffn(x)
        x = self.dropout3(x)
        x = self.norm3(x + _x)

        return x

In [None]:
# 接下来把每个层拼起来
class Encoder(nn.Module):
    """Embedding layer followed by a stack of ``n_layer`` EncoderLayers.

    Args:
        env_voc_size: encoder (source) vocabulary size.
        max_len: longest sequence the positional encoding covers.
        d_model: model width.
        ffn_hidden: inner width of each feed-forward network.
        n_head: number of attention heads per layer.
        n_layer: number of stacked encoder layers.
        drop_prob: dropout probability.
        device: device forwarded to TransformerEmbedding.
    """

    def __init__(
        self,
        env_voc_size,
        max_len,
        d_model,
        ffn_hidden,
        n_head,
        n_layer,
        drop_prob,
        device,
    ):
        super(Encoder, self).__init__()
        self.embedding = TransformerEmbedding(env_voc_size, d_model, max_len, drop_prob, device)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layer)
        )

    def forward(self, x, s_mask):
        # the same source mask is shared by every layer
        hidden = self.embedding(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, s_mask)
        return hidden

In [None]:
class Decoder(nn.Module):
    """Embedding, a stack of ``n_layer`` DecoderLayers, then a vocabulary projection.

    Args:
        dec_voc_size: decoder (target) vocabulary size; also the output width.
        max_len: longest sequence the positional encoding covers.
        d_model: model width.
        ffn_hidden: inner width of each feed-forward network.
        n_head: number of attention heads per layer.
        n_layer: number of stacked decoder layers.
        drop_prob: dropout probability.
        device: device forwarded to TransformerEmbedding.
    """

    def __init__(
        self,
        dec_voc_size,
        max_len,
        d_model,
        ffn_hidden,
        n_head,
        n_layer,
        drop_prob,
        device,
    ):
        super(Decoder, self).__init__()
        self.embedding = TransformerEmbedding(dec_voc_size, d_model, max_len, drop_prob, device)
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layer)
        )
        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, dec, enc, t_mask, s_mask):
        hidden = self.embedding(dec)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc, t_mask, s_mask)
        # project to vocabulary-sized scores; no softmax is applied here
        return self.fc(hidden)

## Fuck Transformer!

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer returning target-vocabulary scores.

    Args:
        src_pad_idx: padding token id of the source sequence (used for masks).
        trg_pad_idx: padding token id of the target sequence (used for masks).
        enc_voc_size: source vocabulary size.
        dec_voc_size: target vocabulary size (width of the output).
        max_len: longest sequence covered by the positional encoding.
        d_model: model width.
        n_heads: number of attention heads.
        ffn_hidden: inner width of the feed-forward networks.
        n_layers: number of encoder layers and of decoder layers.
        drop_prob: dropout probability.
        device: device for the positional tables and the causal mask.

    NOTE: the pad ids here only drive the attention masks; TokenEmbedding uses
    padding_idx=1 by default, so keep the two consistent.
    """

    def __init__(
        self,
        src_pad_idx,
        trg_pad_idx,
        enc_voc_size,
        dec_voc_size,
        max_len,
        d_model,
        n_heads,
        ffn_hidden,
        n_layers,
        drop_prob,
        device,
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            enc_voc_size,
            max_len,
            d_model,
            ffn_hidden,
            n_heads,
            n_layers,
            drop_prob,
            device,
        )
        self.decoder = Decoder(
            dec_voc_size,
            max_len,
            d_model,
            ffn_hidden,
            n_heads,
            n_layers,
            drop_prob,
            device,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        """Build a padding mask of shape (batch, 1, len_q, len_k).

        Entry [b, 0, i, j] is True iff query token i and key token j of sample
        b are both non-padding. The size-1 axis broadcasts over attention heads.

        Args:
            q: query token ids, expected as (batch, len_q).
            k: key token ids, expected as (batch, len_k).
            pad_idx_q: padding id in ``q``.
            pad_idx_k: padding id in ``k``.
        """
        q_valid = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)  # (batch, 1, len_q, 1)
        k_valid = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, len_k)
        return q_valid & k_valid  # broadcasts to (batch, 1, len_q, len_k)

    def make_causal_mask(self, q, k):
        """Lower-triangular bool mask (len_q, len_k): position i may attend to j <= i."""
        len_q, len_k = q.size(1), k.size(1)
        # Built directly on self.device (no CPU BoolTensor round trip).
        return torch.tril(torch.ones(len_q, len_k, device=self.device)).bool()

    # Backward-compatible alias for the original (misspelled) method name.
    make_casual_mask = make_causal_mask

    def forward(self, src, trg):
        """Run the full model.

        Args:
            src: source token ids, presumably (batch, src_len) -- TODO confirm.
            trg: target token ids, presumably (batch, trg_len).

        Returns:
            Decoder output of width ``dec_voc_size`` for every target position.
        """
        # encoder self-attention: hide source padding
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # decoder self-attention: hide target padding AND future positions
        trg_mask = self.make_pad_mask(
            trg, trg, self.trg_pad_idx, self.trg_pad_idx
        ) & self.make_causal_mask(trg, trg)
        # cross-attention: target queries vs. source keys, hide source padding
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        enc = self.encoder(src, src_mask)
        output = self.decoder(trg, enc, trg_mask, src_trg_mask)
        return output

In [None]:
def initialize_weights(m):
    """Xavier-uniform initialize ``m.weight`` when it is a tensor with more than one dimension.

    Can be passed to ``model.apply``. Modules without a ``weight`` attribute,
    modules whose ``weight`` is None (e.g. ``nn.LayerNorm(..., elementwise_affine=False)``)
    and 1-D weights are left untouched.

    NOTE: nn.Embedding weights are 2-D, so this also overwrites the zeroed
    padding_idx row with random values.
    """
    weight = getattr(m, 'weight', None)
    # hasattr() alone is not enough: some modules register `weight` as None.
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        # nn.init functions already run under torch.no_grad(), so `.data` is unnecessary
        nn.init.xavier_uniform_(weight)