In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    The inputs are projected by three separate linear layers, split into ``h``
    heads along the feature dimension, attended per head, re-joined and passed
    through a final linear layer.

    Args:
        d_model: feature size of the inputs and of the output.
        h: number of heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, h):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0, "d_model 必须能被h整除"

        self.d_model = d_model
        self.h = h

        # Separate projections for queries, keys and values.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated.
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: tensor of shape (batch, seq_len_q, d_model).
            k: tensor of shape (batch, seq_len_k, d_model).
            v: tensor of shape (batch, seq_len_k, d_model).
            mask: optional tensor forwarded unchanged to
                ``scaled_dot_product_attention``; positions where it equals 0
                are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len_q, d_model).
        """
        batch_size = q.size(0)

        seq_len_q = q.size(1)
        seq_len_k = k.size(1)

        # (batch, seq, d_model) -> (batch, h, seq, d_model // h)
        Q = self.w_q(q).view(batch_size, seq_len_q, self.h, -1).transpose(1, 2)
        K = self.w_k(k).view(batch_size, seq_len_k, self.h, -1).transpose(1, 2)
        V = self.w_v(v).view(batch_size, seq_len_k, self.h, -1).transpose(1, 2)

        # The helper returns (output, attention_weights); only the output is
        # needed here. Using the tuple directly would fail on `.transpose`.
        scaled_attention, _ = scaled_dot_product_attention(Q, K, V, mask)

        # (batch, h, seq_len_q, d_k) -> (batch, seq_len_q, d_model)
        concat_out = scaled_attention.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        out = self.fc_out(concat_out)

        return out

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: query tensor whose last two dimensions are (len_q, d_k).
        K: key tensor whose last two dimensions are (len_k, d_k).
        V: value tensor whose last two dimensions are (len_k, d_v).
        mask: optional tensor broadcastable to the score tensor
            (..., len_q, len_k). Positions where ``mask == 0`` receive a
            score of ``-inf`` and therefore zero attention weight.
            NOTE: a query row whose keys are all masked produces NaN weights,
            because softmax over a row of ``-inf`` is undefined.

    Returns:
        Tuple ``(output, attention_weights)`` where ``output`` has trailing
        dimensions (len_q, d_v) and ``attention_weights`` has trailing
        dimensions (len_q, len_k).
    """
    d_k = Q.size(-1)

    scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)

    if mask is not None:
        # masked_fill is out-of-place: the result must be assigned, otherwise
        # the mask is silently ignored.
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attention_weights = F.softmax(scores, dim=-1)

    output = torch.matmul(attention_weights, V)

    return output, attention_weights


In [2]:
class PositionwiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between: d_model -> d_ff -> d_model.

    Args:
        d_model: input and output feature size.
        d_ff: size of the hidden layer.
        dropout: accepted for signature compatibility; no dropout layer is
            created or applied by this module.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply ``w_2(relu(w_1(x)))`` to the last dimension of ``x``."""
        hidden = torch.relu(self.w_1(x))
        return self.w_2(hidden)

In [3]:
class LayerNorm(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    The input is centred on its mean and divided by ``std + epsilon``, where
    ``std`` comes from ``Tensor.std`` with its default settings.

    Args:
        feature_size: length of the learned ``gamma`` (scale) and ``beta``
            (shift) vectors.
        epsilon: constant added to the standard deviation before dividing.
    """

    def __init__(self, feature_size, epsilon=1e-9):
        super().__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(feature_size))
        self.beta = nn.Parameter(torch.zeros(feature_size))

    def forward(self, x):
        """Return ``gamma * (x - mean) / (std + epsilon) + beta``."""
        centered = x - torch.mean(x, dim=-1, keepdim=True)
        denom = torch.std(x, dim=-1, keepdim=True) + self.epsilon
        scaled = self.gamma * centered
        return scaled / denom + self.beta

In [35]:
class SublayerConnection(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``norm(x + dropout(sublayer(x)))``.

    Args:
        feature_size: feature size handed to ``LayerNorm``.
        dropout: dropout probability applied to the sublayer output.
        epsilon: epsilon handed to ``LayerNorm``.
    """

    def __init__(self, feature_size, dropout=0.1, epsilon=1e-9):
        super().__init__()
        self.norm = LayerNorm(feature_size, epsilon)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Run ``sublayer`` (a callable taking ``x``) inside the residual block."""
        branch = self.dropout(sublayer(x))
        residual_sum = x + branch
        return self.norm(residual_sum)

In [5]:
class Embeddings(nn.Module):
    """Token embedding lookup scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding dimension.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.scale_factor = math.sqrt(d_model)

    def forward(self, x):
        """Look up the indices in ``x`` and multiply by the scale factor."""
        looked_up = self.embed(x)
        return looked_up * self.scale_factor

In [6]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    A table of shape (1, max_len, d_model) is precomputed and stored as the
    non-trainable buffer ``pe``. Even feature indices hold sines, odd feature
    indices hold cosines.

    Args:
        d_model: feature size of the inputs.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)

        # One frequency per even feature index: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so the angles are trimmed to match; for even d_model this slice is
        # a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading dimension of 1 lets the table broadcast over the batch.
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, d_model), with
                ``seq_len <= max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [7]:
class SourceEmbedding(nn.Module):
    """Source-side input pipeline: scaled token embedding plus positional encoding.

    Args:
        src_vocab_size: size of the source vocabulary.
        d_model: embedding dimension.
        dropout: dropout probability handed to ``PositionalEncoding``.
    """

    def __init__(self, src_vocab_size, d_model, dropout=0.1):
        super().__init__()
        self.embed = Embeddings(src_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)

    def forward(self, x):
        """Embed the token indices ``x`` and add position information."""
        return self.positional_encoding(self.embed(x))

In [8]:
class TargetEmbedding(nn.Module):
    """Target-side input pipeline: scaled token embedding plus positional encoding.

    Args:
        tgt_vocab_size: size of the target vocabulary.
        d_model: embedding dimension.
        dropout: dropout probability handed to ``PositionalEncoding``.
    """

    def __init__(self, tgt_vocab_size, d_model, dropout=0.1):
        super().__init__()
        self.embed = Embeddings(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)

    def forward(self, x):
        """Embed the token indices ``x`` and add position information."""
        token_vectors = self.embed(x)
        return self.positional_encoding(token_vectors)

In [None]:
# Shifted-target example: the decoder input drops the final token and the
# expected output drops the first one. The shift must be done on tokens:
# slicing the raw string would only remove a single character.
tgt = "<sos> I love NLP <eos>"
tgt_tokens = tgt.split()

tgt_input = " ".join(tgt_tokens[:-1])  # "<sos> I love NLP"
tgt_output = " ".join(tgt_tokens[1:])  # "I love NLP <eos>"

In [9]:
def create_padding_mask(seq, pad_token_id=0):
    """Build a boolean mask that is True at non-padding positions.

    Args:
        seq: tensor of token ids; a (batch, seq_len) input yields a mask of
            shape (batch, 1, 1, seq_len).
        pad_token_id: id that marks padding.

    Returns:
        Boolean tensor with two singleton dimensions inserted after the first
        dimension of ``seq``.
    """
    is_real_token = seq.ne(pad_token_id)
    return is_real_token[:, None, None]

In [10]:
# Demo: two sequences of length 5, right-padded with id 0.
seq = torch.tensor([[5, 7, 9, 0, 0], [8, 6, 0, 0, 0]])  # 0 denotes <PAD>
print(create_padding_mask(seq))

tensor([[[[ True,  True,  True, False, False]]],


        [[[ True,  True, False, False, False]]]])


In [11]:
def create_look_ahead_mask(size):
    """Return a (size, size) lower-triangular boolean mask.

    Entry ``[i, j]`` is True when ``j <= i``, i.e. position ``i`` may look at
    itself and at earlier positions only.
    """
    all_true = torch.ones(size, size).type(torch.bool)
    return all_true.tril()

In [12]:
# Demo: 5x5 look-ahead mask; row i is True only for columns j <= i.
print(create_look_ahead_mask(5))

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])


PyTorch 中的广播机制类似于 NumPy。两个张量进行逐元素运算时，它们的维度会从右向左对齐，如果满足以下任一条件，则可以广播：
- 维度长度相同；
- 其中一个张量该维度长度为 1；
- 缺少该维度（即维度数量较少的会在前面自动补 1）。

In [22]:
def create_decoder_mask(tgt_seq, pad_token_id=0):
    """Combine the look-ahead mask and the padding mask for the decoder.

    Args:
        tgt_seq: tensor of target token ids; its second dimension is the
            sequence length. A (batch, length) input yields a mask of shape
            (batch, 1, length, length).
        pad_token_id: id that marks padding.

    Returns:
        Boolean tensor that is True only where a position is both visible
        under the look-ahead constraint and not padding.
    """
    tgt_len = tgt_seq.size(1)
    causal = create_look_ahead_mask(tgt_len).to(tgt_seq.device)
    not_pad = create_padding_mask(tgt_seq, pad_token_id)

    # A leading batch axis is added to the causal mask so it broadcasts
    # against the padding mask.
    return causal[None] & not_pad

In [23]:
# Demo: a single target sequence of length 5 whose last token is padding.
tgt_seq = torch.tensor([[1, 2, 3, 4, 0]])  # 0 denotes <PAD>
mask = create_decoder_mask(tgt_seq)
print(mask.shape)


torch.Size([1, 1, 5, 5])


In [25]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a position-wise feed-forward network.

    Each of the two stages runs inside its own ``SublayerConnection``.

    Args:
        d_model: feature size.
        h: number of attention heads.
        d_ff: hidden size of the feed-forward network.
        dropout: dropout probability handed to the sub-modules.
    """

    def __init__(self, d_model, h, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, h)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        residual_blocks = [SublayerConnection(d_model, dropout) for _ in range(2)]
        self.sublayers = nn.ModuleList(residual_blocks)
        self.d_model = d_model

    def forward(self, x, src_mask):
        """Run ``x`` through self-attention (masked by ``src_mask``) and the FFN."""
        attn_block, ffn_block = self.sublayers

        def attend(inp):
            # Queries, keys and values all come from the same tensor.
            return self.self_attn(inp, inp, inp, src_mask)

        x = attn_block(x, attend)
        return ffn_block(x, self.feed_forward)

In [26]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, then feed-forward.

    Each of the three stages runs inside its own ``SublayerConnection``.

    Args:
        d_model: feature size.
        h: number of attention heads.
        d_ff: hidden size of the feed-forward network.
        dropout: dropout probability handed to the sub-modules.
    """

    def __init__(self, d_model, h, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, h)
        self.cross_attn = MultiHeadAttention(d_model, h)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        residual_blocks = [SublayerConnection(d_model, dropout) for _ in range(3)]
        self.sublayers = nn.ModuleList(residual_blocks)
        self.d_model = d_model

    def forward(self, x, memory, src_mask, tgt_mask):
        """Decode ``x`` against ``memory``.

        Args:
            x: decoder-side input tensor.
            memory: tensor used as keys and values in cross-attention.
            src_mask: mask applied in cross-attention.
            tgt_mask: mask applied in self-attention.
        """
        self_block, cross_block, ffn_block = self.sublayers

        def attend_self(inp):
            return self.self_attn(inp, inp, inp, tgt_mask)

        def attend_memory(inp):
            # Queries come from the decoder; keys and values from `memory`.
            return self.cross_attn(inp, memory, memory, src_mask)

        x = self_block(x, attend_self)
        x = cross_block(x, attend_memory)
        return ffn_block(x, self.feed_forward)


In [27]:
class Encoder(nn.Module):
    """Stack of ``N`` ``EncoderLayer`` blocks followed by a final ``LayerNorm``.

    Args:
        d_model: feature size.
        N: number of stacked layers.
        h: number of attention heads per layer.
        d_ff: hidden size of each feed-forward network.
        dropout: dropout probability handed to each layer.
    """

    def __init__(self, d_model, N, h, d_ff, dropout=0.1):
        super().__init__()
        stack = [EncoderLayer(d_model, h, d_ff, dropout) for _ in range(N)]
        self.layers = nn.ModuleList(stack)
        self.norm = LayerNorm(d_model)

    def forward(self, x, mask):
        """Pass ``x`` through every layer with the same ``mask``, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [29]:
class Decoder(nn.Module):
    """Stack of ``N`` ``DecoderLayer`` blocks followed by a final ``LayerNorm``.

    Args:
        d_model: feature size.
        N: number of stacked layers.
        h: number of attention heads per layer.
        d_ff: hidden size of each feed-forward network.
        dropout: dropout probability handed to each layer.
    """

    def __init__(self, d_model, N, h, d_ff, dropout=0.1):
        super().__init__()
        stack = [DecoderLayer(d_model, h, d_ff, dropout) for _ in range(N)]
        self.layers = nn.ModuleList(stack)
        self.norm = LayerNorm(d_model)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Pass ``x`` through every layer with the same memory and masks, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In [37]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source/target token ids to target-vocabulary scores.

    Args:
        src_vocab_size: size of the source vocabulary.
        tgt_vocab_size: size of the target vocabulary.
        d_model: feature size used throughout the model.
        N: number of encoder layers and of decoder layers.
        h: number of attention heads.
        d_ff: hidden size of the feed-forward networks.
        dropout: dropout probability used by both embeddings, the encoder and
            the decoder.
        pad_token_id: token id treated as padding when the masks are built.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, N, h, d_ff, dropout=0.1, pad_token_id=0):
        super(Transformer, self).__init__()
        self.pad_token_id = pad_token_id

        self.src_embedding = SourceEmbedding(src_vocab_size, d_model, dropout)
        # Use the configured dropout here too (it was previously fixed at 0.1
        # regardless of the `dropout` argument).
        self.tgt_embedding = TargetEmbedding(tgt_vocab_size, d_model, dropout)

        self.encoder = Encoder(d_model, N, h, d_ff, dropout)
        self.decoder = Decoder(d_model, N, h, d_ff, dropout)

        # Final projection from model features to target-vocabulary scores.
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        """Run a full encoder-decoder pass.

        Args:
            src: source token ids, shape (batch, src_len).
            tgt: target token ids, shape (batch, tgt_len).

        Returns:
            Tensor of shape (batch, tgt_len, tgt_vocab_size) holding the output
            of the final linear layer; no softmax is applied.
        """
        src_mask = create_padding_mask(src, self.pad_token_id)
        tgt_mask = create_decoder_mask(tgt, self.pad_token_id)

        enc_output = self.encoder(self.src_embedding(src), src_mask)

        dec_output = self.decoder(self.tgt_embedding(tgt), enc_output, src_mask, tgt_mask)

        output = self.fc_out(dec_output)

        return output


In [38]:
# Vocabulary sizes (depend on the dataset)
src_vocab_size = 5000  # source-language vocabulary size
tgt_vocab_size = 5000  # target-language vocabulary size

# Transformer "base" hyper-parameters
d_model = 512      # embedding dimension
N = 6              # number of encoder layers and of decoder layers
h = 8              # number of attention heads
d_ff = 2048        # hidden size of the feed-forward network
dropout = 0.1      # dropout probability

# Instantiate the model
model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=d_model,
    N=N,
    h=h,
    d_ff=d_ff,
    dropout=dropout
)

# Print the model architecture
print(model)

Transformer(
  (src_embedding): SourceEmbedding(
    (embed): Embeddings(
      (embed): Embedding(5000, 512)
    )
    (positional_encoding): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (tgt_embedding): TargetEmbedding(
    (embed): Embeddings(
      (embed): Embedding(5000, 512)
    )
    (positional_encoding): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiHeadAttention(
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (w_2): Linear(in_features=2048