
<img src="../imgs/transformer.png"  width="500" />

# 1 Encoder Block（单层编码器块）

In [1]:
import torch
from torch import nn


class EncoderBlock(nn.Module):
    """A single post-LayerNorm Transformer encoder layer.

    Two sub-layers, each applied as ``LayerNorm(x + Dropout(sublayer(x)))``:
    multi-head self-attention, then a position-wise feed-forward network.

    Args:
        embed_dim: model width; the last dimension of the input and output.
        num_heads: number of attention heads (must divide ``embed_dim``, as
            required by ``nn.MultiheadAttention``).
        ffn_hidden_dim: hidden width of the feed-forward network.
        dropout: dropout probability used inside attention and on each
            sub-layer output.
        batch_first: forwarded to ``nn.MultiheadAttention``. With the default
            ``False`` tensors are laid out as ``(seq_len, batch, embed_dim)``;
            with ``True`` as ``(batch, seq_len, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads, ffn_hidden_dim, dropout=0.1, batch_first=False):
        super().__init__()
        # Self-attention sub-layer.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=batch_first,
        )
        # Post-norm applied after the attention residual.
        self.ln1 = nn.LayerNorm(embed_dim)
        # Position-wise feed-forward network. Its trailing Dropout is the only
        # dropout applied to the FFN branch (see forward).
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ffn_hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=ffn_hidden_dim, out_features=embed_dim),
            nn.Dropout(dropout),
        )
        # Post-norm applied after the FFN residual.
        self.ln2 = nn.LayerNorm(embed_dim)
        # Dropout on the attention sub-layer output.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, key_padding_mask=None):
        """Run the encoder layer on ``x``.

        Args:
            x: input tensor, laid out according to ``batch_first``.
            mask: optional ``attn_mask``, passed to self-attention unchanged.
            key_padding_mask: optional ``key_padding_mask``, passed to
                self-attention unchanged (see ``nn.MultiheadAttention`` for
                accepted shapes and dtypes).

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Call the module (not .forward) so registered hooks run. The averaged
        # attention weights are discarded, so don't ask for them.
        attn_output, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.ln1(x + self.dropout(attn_output))
        # self.ffn already ends in nn.Dropout; applying self.dropout on top
        # would drop the FFN activations twice.
        x = self.ln2(x + self.ffn(x))
        return x

## 测试

In [2]:
# Smoke test: an encoder layer maps its input to an output of the same shape.
# NOTE: nn.MultiheadAttention defaults to batch_first=False and nothing here enables
# it, so this (2, 10, 512) tensor is read as (seq_len=2, batch=10, embed_dim=512).
# `encoder_output` is reused as the memory input of the decoder test cell below.
encoder_block = EncoderBlock(embed_dim=512, num_heads=8, ffn_hidden_dim=2048)
x = torch.randn(2, 10, 512)
encoder_output = encoder_block(x)
# Last expression of the cell: displayed as the notebook output.
encoder_output.shape

torch.Size([2, 10, 512])

# 2 Decoder Block（单层解码器块）

In [10]:
import torch
from torch import nn


class DecoderBlock(nn.Module):
    """A single post-LayerNorm Transformer decoder layer.

    Three sub-layers, each applied as ``LayerNorm(x + Dropout(sublayer(x)))``:
    self-attention over the target, cross-attention from the target to
    ``encoder_output``, then a position-wise feed-forward network.

    The self-attention is only causal ("masked") when the caller supplies a
    causal ``self_attn_mask``; nothing here builds one automatically.

    Args:
        embed_dim: model width; the last dimension of all inputs and outputs.
        attn_heads: number of heads for both attention sub-layers (must
            divide ``embed_dim``, as required by ``nn.MultiheadAttention``).
        ffn_hidden_dim: hidden width of the feed-forward network.
        dropout: dropout probability used inside attention and on each
            sub-layer output.
        batch_first: forwarded to both ``nn.MultiheadAttention`` modules. With
            the default ``False`` tensors are laid out as
            ``(seq_len, batch, embed_dim)``; with ``True`` as
            ``(batch, seq_len, embed_dim)``.
    """

    def __init__(self, embed_dim, attn_heads, ffn_hidden_dim, dropout=0.1, batch_first=False):
        super().__init__()
        # Self-attention over the target sequence (causal only if masked).
        self.masked_self_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=attn_heads,
            dropout=dropout,
            batch_first=batch_first,
        )
        # Post-norm after the self-attention residual.
        self.ln1 = nn.LayerNorm(embed_dim)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=attn_heads,
            dropout=dropout,
            batch_first=batch_first,
        )
        # Post-norm after the cross-attention residual.
        self.ln2 = nn.LayerNorm(embed_dim)
        # Position-wise feed-forward network. Its trailing Dropout is the only
        # dropout applied to the FFN branch (see forward).
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ffn_hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=ffn_hidden_dim, out_features=embed_dim),
            nn.Dropout(dropout),
        )
        # Post-norm after the FFN residual.
        self.ln3 = nn.LayerNorm(embed_dim)
        # Dropout on the two attention sub-layer outputs.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        encoder_output,
        self_attn_mask=None,
        cross_attn_mask=None,
        self_attn_key_padding_mask=None,
        cross_attn_key_padding_mask=None,
    ):
        """Run the decoder layer.

        Args:
            x: target-side input, laid out according to ``batch_first``.
            encoder_output: encoder memory, laid out the same way.
            self_attn_mask: optional ``attn_mask`` for self-attention
                (pass a causal mask for autoregressive decoding).
            cross_attn_mask: optional ``attn_mask`` for cross-attention.
            self_attn_key_padding_mask: optional ``key_padding_mask`` for
                self-attention (padded target positions).
            cross_attn_key_padding_mask: optional ``key_padding_mask`` for
                cross-attention (padded encoder positions).

        All masks are passed to ``nn.MultiheadAttention`` unchanged.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Call modules (not .forward) so hooks run; attention weights are
        # discarded, so skip computing them.
        attn_output, _ = self.masked_self_attn(
            query=x,
            key=x,
            value=x,
            attn_mask=self_attn_mask,
            key_padding_mask=self_attn_key_padding_mask,
            need_weights=False,
        )
        x = self.ln1(x + self.dropout(attn_output))

        cross_attn_output, _ = self.cross_attn(
            query=x,
            key=encoder_output,
            value=encoder_output,
            attn_mask=cross_attn_mask,
            key_padding_mask=cross_attn_key_padding_mask,
            need_weights=False,
        )
        x = self.ln2(x + self.dropout(cross_attn_output))

        # Residual around the FFN. x must survive here: overwriting it with the
        # FFN output would drop the residual path. self.ffn already ends in
        # nn.Dropout, so no extra dropout is applied.
        x = self.ln3(x + self.ffn(x))
        return x

## 测试

In [11]:
# Smoke test for DecoderBlock; relies on `encoder_output` from the encoder test cell.
assert encoder_output is not None

x = torch.randn(2, 10, 512)
decoder_block = DecoderBlock(embed_dim=512, attn_heads=8, ffn_hidden_dim=2048)
# nn.MultiheadAttention defaults to batch_first=False and nothing here enables it,
# so dim 0 of x is the target sequence length. A boolean upper-triangular mask
# (True = may not attend) stops each position from attending to later ones. Without
# it, the "masked" self-attention would see the whole target sequence.
tgt_len = x.size(0)
causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
decoder_output = decoder_block(x, encoder_output, self_attn_mask=causal_mask)
# Last expression of the cell: displayed as the notebook output.
decoder_output.shape

torch.Size([2, 10, 512])