In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import math

## Embedding

#### 在 Transformer 中, Embedding 主要由 TokenEmbedding 和 PositionalEmbedding 两部分组成

- TokenEmbedding: 把每个 token 在词表中的序号映射为一个稠密向量, 一个句子因此变成一个矩阵. eg. 输入 (batch_size, seq_len), 输出 (batch_size, seq_len, d_model)
- PositionalEmbedding: 为每个位置生成位置编码, 并与 TokenEmbedding 的结果相加, 从而引入位置信息 (self-attention 本身不区分 token 的先后顺序)

In [2]:
# TokenEmbedding can simply subclass nn.Embedding.
class TokenEmbedding(nn.Embedding):
    """Lookup table mapping token ids to dense vectors of size ``embed_size``.

    Id 0 is reserved for padding (``padding_idx=0``): nn.Embedding initialises
    that row to zeros and never updates it with gradients.

    Args:
        vocab_size: number of rows in the table (valid ids are 0 .. vocab_size-1).
        embed_size: dimensionality of each embedding vector.
    """

    def __init__(self, vocab_size, embed_size):
        super(TokenEmbedding, self).__init__(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0,
        )

# Example usage
vocab_size, embed_size = 1000, 224
batch_size, max_len = 128, 64
device = torch.device("cpu")

net = TokenEmbedding(vocab_size, embed_size).to(device)
# X holds sentences already converted to vocabulary ids, e.g.
# "It is a good time to study Deep Learning." -> [2, 55, 4, 15, 6, 90, 8, 93, 10]
X = torch.randint(0, vocab_size, (batch_size, max_len), device=device)
output = net(X)

print(f"X shape: {X.shape}")
print(f"output shape: {output.shape}")

X shape: torch.Size([128, 64])
output shape: torch.Size([128, 64, 224])


In [3]:
# PositionalEmbedding can be implemented in several ways, e.g. fixed sinusoidal
# encodings or learned parameters. Both variants are shown in this cell; the
# sinusoidal one keeps the name `PositionalEmbedding`, which TransformerEmbedding uses.

# Learned parameters.
# Named distinctly: under the shared name `PositionalEmbedding` this class was
# silently shadowed by the sinusoidal class that follows and could never be used.
class LearnedPositionalEmbedding(nn.Module):
    """Positional embedding stored as a trainable ``(1, max_len, d_model)`` parameter.

    Args:
        d_model: feature size of each position vector.
        max_len: number of positions that get a learned vector.
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()
        # Initialised from a standard normal and trained with the rest of the model.
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model))

    def forward(self, x):
        """Return the first ``x.shape[1]`` position vectors, shape ``(1, seq_len, d_model)``.

        The leading size-1 axis lets the result broadcast over the batch when added
        to token embeddings.
        """
        seq_len = x.shape[1]
        return self.pe[:, :seq_len].to(x.device)

# Sinusoidal encoding (the fixed scheme used in "Attention Is All You Need").
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Row ``pos`` of the precomputed table holds
    ``PE[pos, 2i] = sin(pos / 10000^(2i / d_model))`` and
    ``PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))``.

    Args:
        d_model: feature size of each encoding vector (odd values are supported).
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, max_len=64):
        super(PositionalEmbedding, self).__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        _2i = torch.arange(0, d_model, 2).float()  # even feature indices 0, 2, 4, ...
        angles = position / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model / 2))

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as a buffer (never requires grad, follows module.to()/.cuda()).
        # The old `encoding.requires_grad_ = False` overwrote the tensor's
        # requires_grad_ method instead of setting a flag. persistent=False keeps
        # the table out of state_dict, as the plain attribute was.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``x.shape[1]`` positions.

        Args:
            x: tensor whose dim 1 is the sequence length (token ids or embeddings).

        Returns:
            Tensor of shape ``(seq_len, d_model)``. It broadcasts over the batch
            when added to a ``(batch, seq_len, d_model)`` embedding.
        """
        seq_len = x.shape[1]
        return self.encoding[:seq_len, :].to(x.device)
    

# Example usage
batch_size = 128
max_len = 64
d_model = 224  # d_model is the feature size after embedding
device = torch.device("cpu")

PE = PositionalEmbedding(d_model, max_len).to(device)
# Use batch_size rather than a hard-coded 128 so the constants stay in sync.
X = torch.randn(batch_size, max_len, d_model).float().to(device)
output = PE(X)

print(f"X shape: {X.shape}")
print(f"output shape: {output.shape}")  # (max_len, d_model): broadcasts over the batch

X shape: torch.Size([128, 64, 224])
output shape: torch.Size([64, 224])


In [4]:
# Combine TokenEmbedding and PositionalEmbedding.
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: size of the token vocabulary.
        max_len: longest sequence the positional table covers.
        d_model: embedding size.
        dropout: dropout probability applied to the summed embedding.
    """

    def __init__(self, vocab_size, max_len, d_model, dropout=0.1):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.positional_embedding = PositionalEmbedding(d_model, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed token ids ``x`` (assumed ``(batch, seq_len)``) into ``(batch, seq_len, d_model)``."""
        summed = self.token_embedding(x) + self.positional_embedding(x)
        return self.dropout(summed)
    
# Example usage
vocab_size, max_len, d_model = 1000, 64, 224
batch_size = 128
device = torch.device("cpu")

net = TransformerEmbedding(vocab_size, max_len, d_model).to(device)
X = torch.randint(0, vocab_size, (batch_size, max_len), device=device)
output = net(X)
print(f"X shape: {X.shape}")
print(f"output shape: {output.shape}")

X shape: torch.Size([128, 64])
output shape: torch.Size([128, 64, 224])


## Multi-Head Attention

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_head != 0:
            # Otherwise the per-head reshape in forward() can never succeed.
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = self.d_model // n_head

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        """Reshape ``(batch, length, d_model)`` to ``(batch, n_head, length, head_dim)``."""
        batch_size, length, _ = x.shape
        return x.view(batch_size, length, self.n_head, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, ``(batch, len_q, d_model)``.
            k: keys, ``(batch, len_k, d_model)``; ``len_k`` may differ from ``len_q``
                (e.g. decoder-to-encoder cross-attention).
            v: values, ``(batch, len_k, d_model)``.
            mask: optional tensor broadcastable to ``(batch, n_head, len_q, len_k)``.
                Positions where ``mask == 0`` are not attended to.

        Returns:
            Tensor of shape ``(batch, len_q, d_model)``. A query row whose keys are
            all masked receives zero attention, so its output is ``wo``'s bias.
        """
        batch_size, len_q, dimension = q.shape

        # Each projection is split using its own length; reusing len_q for K/V
        # broke cross-attention whenever the source and target lengths differed.
        Q = self._split_heads(self.wq(q))  # (batch, n_head, len_q, head_dim)
        K = self._split_heads(self.wk(k))  # (batch, n_head, len_k, head_dim)
        V = self._split_heads(self.wv(v))  # (batch, n_head, len_k, head_dim)

        score = Q @ K.transpose(2, 3) / math.sqrt(self.head_dim)  # (batch, n_head, len_q, len_k)
        if mask is not None:
            keep = mask != 0
            score = score.masked_fill(~keep, float("-inf"))

        attention = F.softmax(score, dim=-1)
        if mask is not None:
            # A row with no visible key is all -inf, and softmax turns it into NaN.
            # That NaN would leak into every later layer, because a zero weight
            # times a NaN value is still NaN. Give such rows zero weight instead.
            attention = attention.masked_fill(~keep.any(dim=-1, keepdim=True), 0.0)

        out = attention @ V  # (batch, n_head, len_q, head_dim)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, len_q, dimension)
        return self.wo(out)
    
# Example usage
batch_size, seq_len = 128, 64
d_model, n_head = 224, 2
device = torch.device("cpu")

attention_model = MultiHeadAttention(d_model, n_head).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
output = attention_model(X, X, X, None)  # self-attention without a mask

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## LayerNorm

In [6]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable affine map.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``, where ``mean`` and the
    biased (population) ``var`` are taken over the last axis.

    Args:
        d_model: size of the last dimension; the shape of ``gamma`` and ``beta``.
        eps: added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta
    
# Example usage
batch_size, seq_len, d_model = 128, 64, 224
device = torch.device("cpu")

norm = LayerNorm(d_model).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
output = norm(X)

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## PositionwiseFeedForward(FFN)

In [7]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied at every position: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output feature size.
        hidden_dim: width of the inner layer.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)
    
# Example usage
batch_size, seq_len, d_model = 128, 64, 224
hidden_dim, dropout = 512, 0.1
device = torch.device("cpu")

ffn = PositionwiseFeedForward(d_model, hidden_dim, dropout).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
output = ffn(X)

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## EncoderLayer

In [8]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Two sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
    multi-head self-attention, then a position-wise feed-forward network.

    Args:
        d_model: model width.
        n_head: number of attention heads.
        ffn_hidden_dim: inner width of the feed-forward network.
        dropout: dropout probability for both residual branches and inside the FFN.
    """

    def __init__(self, d_model, n_head, ffn_hidden_dim, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, n_head)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = LayerNorm(d_model)
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden_dim, dropout)
        # Residual dropout on the FFN output as well, matching the attention
        # branch (the paper applies dropout to the output of every sub-layer).
        # Dropout has no parameters, so state_dict is unchanged.
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode ``x`` of shape ``(batch, seq_len, d_model)``.

        Args:
            x: input sequence.
            mask: optional attention mask passed to the self-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _x = x
        x = self.attention(x, x, x, mask)
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x
    
# Example usage
batch_size, seq_len, d_model = 128, 64, 224
n_head, ffn_hidden_dim, dropout = 2, 512, 0.1
device = torch.device("cpu")

encoderlayer = EncoderLayer(d_model, n_head, ffn_hidden_dim, dropout).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
output = encoderlayer(X)

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## DecoderLayer

In [9]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Up to three sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
    masked self-attention, cross-attention over the encoder output (skipped when
    ``enc_x`` is None), and a position-wise feed-forward network.

    Args:
        d_model: model width.
        n_head: number of attention heads.
        ffn_hidden_dim: inner width of the feed-forward network.
        dropout: dropout probability for every residual branch and inside the FFN.
    """

    def __init__(self, d_model, n_head, ffn_hidden_dim, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.attention1 = MultiHeadAttention(d_model, n_head)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = LayerNorm(d_model)

        self.cross_attention = MultiHeadAttention(d_model, n_head)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = LayerNorm(d_model)

        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden_dim, dropout)
        # Residual dropout on the FFN branch too, like the two attention branches.
        # Dropout has no parameters, so state_dict is unchanged.
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = LayerNorm(d_model)

    def forward(self, dec_x, enc_x, t_mask, s_mask):
        """Decode one layer.

        Args:
            dec_x: decoder input, ``(batch, trg_len, d_model)``.
            enc_x: encoder output, or None to skip cross-attention.
            t_mask: mask for decoder self-attention (lower-triangular in the demos).
            s_mask: mask for cross-attention from decoder queries to encoder keys.

        Returns:
            Tensor with the same shape as ``dec_x``.
        """
        _x = dec_x
        x = self.attention1(dec_x, dec_x, dec_x, t_mask)  # masked (causal) self-attention
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc_x is not None:
            _x = x
            x = self.cross_attention(x, enc_x, enc_x, s_mask)
            x = self.dropout2(x)
            x = self.norm2(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.dropout3(x)
        x = self.norm3(x + _x)

        return x
    
# Example usage
batch_size, seq_len, d_model = 128, 64, 224
n_head, ffn_hidden_dim, dropout = 2, 512, 0.1
device = torch.device("cpu")

decoderlayer = DecoderLayer(d_model, n_head, ffn_hidden_dim, dropout).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
enc_X = torch.randn((batch_size, seq_len, d_model), device=device)
# Lower-triangular (causal) target mask
t_mask = torch.tril(torch.ones(seq_len, seq_len)).gt(0).to(device)
# Source mask: all ones here (nothing masked); build it from the real padding in practice
s_mask = torch.ones(seq_len, seq_len, device=device)
output = decoderlayer(X, enc_X, t_mask, s_mask)

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## Encoder

In [10]:
class Encoder(nn.Module):
    """Stack of ``n_layer`` EncoderLayer modules applied one after another.

    Args:
        d_model, n_head, ffn_hidden_dim, dropout: forwarded to every EncoderLayer.
        n_layer: number of stacked layers.
    """

    def __init__(self, d_model, n_head, ffn_hidden_dim, n_layer, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_head, ffn_hidden_dim, dropout)
            for _ in range(n_layer)
        )

    def forward(self, x, s_mask):
        """Run ``x`` through every layer, giving each the same source mask."""
        hidden = x
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, s_mask)
        return hidden
    
# Example usage
batch_size, seq_len, d_model = 128, 64, 224
n_head, ffn_hidden_dim = 2, 512
n_layer, dropout = 4, 0.1
device = torch.device("cpu")

encoder = Encoder(d_model, n_head, ffn_hidden_dim, n_layer, dropout).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
s_mask = torch.ones(seq_len, seq_len, device=device)  # source mask (nothing masked)
output = encoder(X, s_mask)

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## Decoder

In [11]:
class Decoder(nn.Module):
    """Stack of ``n_layer`` DecoderLayer modules applied one after another.

    Args:
        d_model, n_head, ffn_hidden_dim, dropout: forwarded to every DecoderLayer.
        n_layer: number of stacked layers.
    """

    def __init__(self, d_model, n_head, ffn_hidden_dim, n_layer, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, n_head, ffn_hidden_dim, dropout)
            for _ in range(n_layer)
        )

    def forward(self, dec_x, enc_x, t_mask, s_mask):
        """Feed ``dec_x`` through every layer; all layers share ``enc_x`` and both masks."""
        hidden = dec_x
        for dec_layer in self.layers:
            hidden = dec_layer(hidden, enc_x, t_mask, s_mask)
        return hidden
    
# Example usage
batch_size, seq_len, d_model = 128, 64, 224
n_head, ffn_hidden_dim = 2, 512
n_layer, dropout = 4, 0.1
device = torch.device("cpu")

decoder = Decoder(d_model, n_head, ffn_hidden_dim, n_layer, dropout).to(device)
X = torch.randn((batch_size, seq_len, d_model), device=device)
enc_X = torch.randn((batch_size, seq_len, d_model), device=device)
# Lower-triangular (causal) target mask
t_mask = torch.tril(torch.ones(seq_len, seq_len)).gt(0).to(device)
# Source mask: all ones here (nothing masked); build it from the real padding in practice
s_mask = torch.ones(seq_len, seq_len, device=device)
output = decoder(X, enc_X, t_mask, s_mask)

print(f"X shape : {X.shape}")
print(f"output shape : {output.shape}")

X shape : torch.Size([128, 64, 224])
output shape : torch.Size([128, 64, 224])


## Transformer

In [12]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence tasks.

    Args:
        src_pad_idx: id of the padding token in encoder inputs.
        trg_pad_idx: id of the padding token in decoder inputs.
        enc_vocab_size: source vocabulary size.
        dec_vocab_size: target vocabulary size (also the output width).
        max_len: longest sequence the positional encodings cover.
        d_model: model width.
        n_head: number of attention heads.
        ffn_hidden_dim: inner width of the feed-forward networks.
        n_encoderlayer: number of encoder layers.
        n_decoderlayer: number of decoder layers.
        dropout: dropout probability used throughout.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_vocab_size, dec_vocab_size, max_len, d_model, n_head, ffn_hidden_dim, n_encoderlayer, n_decoderlayer, dropout=0.1):
        super(Transformer, self).__init__()
        self.src_pad_idx = src_pad_idx  # padding id in the encoder input
        self.trg_pad_idx = trg_pad_idx  # padding id in the decoder input

        self.encoder_embedding = TransformerEmbedding(enc_vocab_size, max_len, d_model, dropout)
        self.decoder_embedding = TransformerEmbedding(dec_vocab_size, max_len, d_model, dropout)

        self.encoder = Encoder(d_model, n_head, ffn_hidden_dim, n_encoderlayer, dropout)
        self.decoder = Decoder(d_model, n_head, ffn_hidden_dim, n_decoderlayer, dropout)

        self.fc = nn.Linear(d_model, dec_vocab_size)

    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        """Build a padding mask for attention from ``q`` tokens to ``k`` tokens.

        Args:
            q: query token ids, ``(batch, len_q)``.
            k: key token ids, ``(batch, len_k)``.
            pad_idx_q: padding id in ``q``.
            pad_idx_k: padding id in ``k``.

        Returns:
            Bool tensor ``(batch, 1, len_q, len_k)``; entry ``[b, 0, i, j]`` is True
            iff query ``i`` and key ``j`` are both non-padding. The size-1 axis
            broadcasts over attention heads.
        """
        len_q, len_k = q.shape[1], k.shape[1]

        # Query validity runs along rows: (batch, 1, len_q, 1) -> (batch, 1, len_q, len_k)
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, len_k)
        # Key validity must run along columns: (batch, 1, 1, len_k) -> (batch, 1, len_q, len_k).
        # Laying it along rows, as before, never masked padded keys and failed
        # to broadcast whenever len_q != len_k (cross-attention).
        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2).repeat(1, 1, len_q, 1)

        return q & k

    def make_causal_mask(self, q, k):
        """Lower-triangular ``(len_q, len_k)`` bool mask so that position i only sees j <= i."""
        len_q, len_k = q.shape[1], k.shape[1]
        return torch.ones(len_q, len_k, dtype=torch.bool, device=q.device).tril()

    # Backward-compatible alias for the original (misspelled) method name.
    make_casual_mask = make_causal_mask

    def forward(self, src, trg):
        """Run the full encoder-decoder.

        Args:
            src: source token ids, ``(batch, src_len)``.
            trg: target token ids, ``(batch, trg_len)``.

        Returns:
            ``(batch, trg_len, dec_vocab_size)`` probabilities (softmax over the
            vocabulary). NOTE: these are probabilities, not logits.
            ``nn.CrossEntropyLoss`` expects logits, so train on ``self.fc(...)``
            output or on the log of this result.
        """
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) & self.make_causal_mask(trg, trg)
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        enc_X = self.encoder_embedding(src)
        dec_X = self.decoder_embedding(trg)

        enc = self.encoder(enc_X, src_mask)
        dec = self.decoder(dec_X, enc, trg_mask, src_trg_mask)

        out = self.fc(dec)
        out = F.softmax(out, dim=-1)
        return out
    
# Example usage
src_pad_idx = 0
trg_pad_idx = 0
enc_vocab_size = 1000
dec_vocab_size = 1000
batch_size = 128
max_len = 64
d_model = 224
n_head = 2
ffn_hidden_dim = 512
n_encoderlayer = 4
n_decoderlayer = 4
dropout = 0.1
device = torch.device("cpu")

transformer = Transformer(src_pad_idx, 
                          trg_pad_idx, 
                          enc_vocab_size, 
                          dec_vocab_size, 
                          max_len, 
                          d_model,
                          n_head,
                          ffn_hidden_dim,
                          n_encoderlayer, 
                          n_decoderlayer, 
                          dropout).to(device)

# Id 0 is the padding index here, so sample real tokens from [1, vocab_size).
# Sampling from [0, vocab_size) scattered random padding through the middle of
# sequences and produced query rows that are fully masked in attention.
src = torch.randint(1, enc_vocab_size, (batch_size, max_len)).to(device)
trg = torch.randint(1, dec_vocab_size, (batch_size, max_len)).to(device)
output = transformer(src, trg)
assert torch.isfinite(output).all(), "transformer output contains NaN/inf"

print(f"src shape : {src.shape}")
print(f"trg shape : {trg.shape}")
print(f"output shape : {output.shape}")

src shape : torch.Size([128, 64])
trg shape : torch.Size([128, 64])
output shape : torch.Size([128, 64, 1000])
