<a href="https://colab.research.google.com/github/Dominique-Yiu/ColabCode/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer Realization
**Transformer Architecture** \
1. Encoder
- Input words embedding
  - Map the sparse one-hot vector to a dense, continuous vector via a bias-free linear layer (equivalent to an embedding lookup).
- Position encoding
- Multi-head self-attention
- Feed-forward network
2. Decoder
- Output words embedding
- Masked multi-head self-attention
- Multi-head cross-attention
- Feed-forward network
- Softmax

实现的难点： \
1. Word Embedding
2. Position Embedding
3. Encoder self-attention mask
4. Intra-attention (cross-attention) mask
5. Decoder self-attention mask
6. Multi-head self-attention

In [3]:
import torch
import numpy
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Seed once so every random draw below (token ids, embedding init, scores in
# later cells) is reproducible under Restart & Run All.
torch.manual_seed(0)

batch_size = 2
# vocabulary size (number of real words); id 0 is reserved for padding
max_num_src_words = 8
max_num_tgt_words = 8
# embedding (model) dimension
model_dim = 8
# the max length of sequence
max_src_seq_len = 5
max_tgt_seq_len = 5
# maximum position index supported by the position-embedding table
max_position_len = 5
# Sequence lengths are fixed for a reproducible demo; to draw them randomly use
# torch.randint(2, 5, (batch_size,)) instead.
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)
assert len(src_len) == len(tgt_len) == batch_size, "one length per batch element"


def make_padded_sequences(lengths, num_words):
    """Return a [len(lengths), max(lengths)] LongTensor of random word ids.

    Row b holds lengths[b] ids drawn uniformly from 1..num_words (inclusive),
    right-padded with 0 (the padding id) up to the longest length.
    """
    max_len = int(lengths.max())
    rows = [
        # randint's upper bound is exclusive, so +1 makes id `num_words` reachable.
        F.pad(torch.randint(1, num_words + 1, (int(length),)), (0, max_len - int(length)))
        for length in lengths
    ]
    return torch.stack(rows)


# generate the src/tgt sentences, padded with the default value 0
src_seq = make_padded_sequences(src_len, max_num_src_words)
tgt_seq = make_padded_sequences(tgt_len, max_num_tgt_words)

"""Word Embedding"""
# Build the word-embedding tables: rows 1..max_num_*_words for words, row 0 for padding.
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)  # [batch, max_src_len, model_dim]
tgt_embedding = tgt_embedding_table(tgt_seq)  # [batch, max_tgt_len, model_dim]

print(src_embedding_table.weight)
print(src_seq)
print(src_embedding)

Parameter containing:
tensor([[-1.6669, -0.0190, -0.7397, -1.3589,  2.0105, -1.5195,  0.5483,  1.6877],
        [ 1.1968,  1.3093,  0.6002,  1.9838, -0.0909, -0.1004, -0.9728, -1.9918],
        [ 0.7387, -1.4952, -0.6838,  0.4220, -0.0895,  0.9863, -0.8907,  0.8737],
        [ 0.3023,  0.0844,  0.0841, -0.2829, -0.6646,  0.1810, -0.9762, -0.4892],
        [ 0.9160,  0.2793, -0.3699,  0.3383, -0.4026, -1.5498,  0.4681,  0.0609],
        [ 0.3499, -0.4947, -1.4099,  1.2324, -1.1353,  0.0489,  0.1097,  0.1956],
        [-1.2591,  0.7273,  1.1955, -1.2456,  1.9668, -0.3416,  1.0928, -0.7418],
        [ 1.8613, -0.2107, -1.4659,  1.0338,  0.1964, -1.6763, -1.9309, -0.3148],
        [ 0.4745,  0.8303, -1.2484, -1.9925, -0.8226,  0.4716, -1.0549,  0.7839]],
       requires_grad=True)
tensor([[7, 7, 0, 0],
        [4, 2, 6, 2]])
tensor([[[ 1.8613, -0.2107, -1.4659,  1.0338,  0.1964, -1.6763, -1.9309,
          -0.3148],
         [ 1.8613, -0.2107, -1.4659,  1.0338,  0.1964, -1.6763, -1.9309,
 

In [5]:
"""Position Embedding"""
# 构造Position Embedding
pos_mat  =  torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1)) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad = False)
src_pos = torch.cat([torch.unsqueeze(torch.arange(max(src_len)), 0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max(tgt_len)), 0) for _ in tgt_len]).to(torch.int32)

src_pos_embedding = pe_embedding(src_pos)
tgt_pos_embedding = pe_embedding(tgt_pos)
print(src_pos_embedding)
print(tgt_pos_embedding)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00]

In [6]:
import numpy as np
"""Encoder: Self-Attention Mask"""
# 构造encoder的self-attention mask
# mask的shape: [batch_size, max_src_len, max_src_len]，数值为1/-inf
valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0) for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
score = torch.randn(batch_size, max(src_len), max(src_len))
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
prob = F.softmax(masked_score, -1)

print(score)
print(masked_score)
print(prob)

tensor([[[ 0.7272,  1.8509,  1.3269,  0.5349],
         [ 2.4649, -0.9675, -0.8312, -0.7691],
         [-0.6841,  0.6586, -0.1978, -1.0875],
         [-0.3524, -1.1417,  0.4297,  0.7176]],

        [[ 0.2645, -0.4864, -1.2504, -0.4116],
         [-0.1779, -0.0664,  0.5948,  0.3009],
         [-0.3937,  0.3348,  0.5173,  0.3991],
         [ 1.9093, -0.8432, -0.1482,  0.0304]]])
tensor([[[ 7.2721e-01,  1.8509e+00, -1.0000e+09, -1.0000e+09],
         [ 2.4649e+00, -9.6753e-01, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 2.6455e-01, -4.8645e-01, -1.2504e+00, -4.1157e-01],
         [-1.7793e-01, -6.6429e-02,  5.9476e-01,  3.0092e-01],
         [-3.9373e-01,  3.3476e-01,  5.1728e-01,  3.9905e-01],
         [ 1.9093e+00, -8.4321e-01, -1.4817e-01,  3.0356e-02]]])
tensor([[[0.2453, 0.7547, 0.0000, 0.0000],
         [0.9687, 0.0313, 0.0000, 0.0000],
         [0.2500, 0.2500, 

以上主要实现了词向量、位置编码以及编码器自注意力的掩码

In [26]:
"""Corss-Attention"""
# Q @ K^T shape: [batch_size, tht_seq_len, src_seq_len]
valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0) for L in src_len]), 2)
valid_decoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(tgt_len) - L)), 0) for L in tgt_len]), 2)
# 目标序列位置对原始序列的关系（有效性）
valid_cross_pos_matrix = torch.bmm(valid_encoder_pos, valid_decoder_pos.transpose(1, 2))
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)
score = torch.randn(batch_size, max(tgt_len), max(src_len))
masked_score = score.masked_fill(mask_cross_attention, -1e9)
prob = F.softmax(masked_score, -1)

print(valid_encoder_pos)
print(valid_decoder_pos)
print(valid_cross_pos_matrix)
print(masked_score)
print(prob)

tensor([[[1.],
         [1.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.]]])
tensor([[[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [0.]]])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 0.]]])
tensor([[[-2.7826e-01, -1.8842e+00, -1.0313e+00, -1.6518e-01],
         [-1.5267e+00,  1.0573e-01, -1.5322e+00,  2.0614e-01],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 1.3311e-01,  2.7074e+00,  5.1707e-01, -1.0000e+09],
         [ 1.5017e+00, -5.1609e-01,  8.0322e-01, -1.0000e+09],
         [ 1.5498e+00,  8.3718e-02,  7.6715e-01, -1.0000e+09],
         [-2.6354e-01, -2.0937e-01,  7.5223e-01, -1.0000e+09]]])


In [35]:
"""Decoder Self-Attention Mask"""
# 因果Mask
valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones((L, L))), (0, max(tgt_len) - L, 0, max(tgt_len) - L)), 0) for L in tgt_len])
invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
mask_invalid_decoder_tri_matrix = invalid_decoder_tri_matrix.to(torch.bool)

score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
masked_score = score.masked_fill(mask_invalid_decoder_tri_matrix, -1e9)
prob = F.softmax(masked_score, -1)
print(tgt_len)
print(prob)

tensor([4, 3], dtype=torch.int32)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.1625, 0.8375, 0.0000, 0.0000],
         [0.1744, 0.1313, 0.6943, 0.0000],
         [0.1704, 0.4480, 0.0531, 0.3285]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.6037, 0.3963, 0.0000, 0.0000],
         [0.3056, 0.2307, 0.4637, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500]]])


$Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V$

In [None]:
"""Scaled Self-Attention"""
def scaled_dot_product_attention(Q, K, V, attn_mask):
  score = torch.bmm(Q, K.transpose(-2, -1)) / torch.sqrt(model_dim)
  masked_score = torch.masked_fill(score * attn_mask, -1e9)
  prob = F.softmax(masked_score, -1)
  context = torch.bmm(prob, V)
  return context