In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import math

import matplotlib.pyplot as plt

from typing import Tuple, Optional

# 一些辅助函数

In [2]:
# compare_output reports whether a hand-written result matches the PyTorch result
def compare_output(output_tag, output_ref, output, rtol=1e-5, atol=1e-8):
    """Print ``output_tag`` followed by a check mark or a cross.

    The check mark is printed when ``torch.allclose(output_ref, output)`` holds
    under the given ``rtol``/``atol``; otherwise the cross is printed.
    Nothing is returned.
    """
    if torch.allclose(output_ref, output, rtol=rtol, atol=atol):
        verdict = "✅"
    else:
        verdict = "❌"
    print(f"{output_tag}: {verdict}")


# show_heatmaps visualises the attention-weight matrices returned by multi-head attention
def show_heatmaps(
    matrics, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap="Reds"
):
    """Draw one heatmap per (batch item, head) pair on a grid of shared axes.

    Args:
        matrics: tensor indexed as [batch_size, nheads, L, L]. Each row of the
            figure shows one batch item and each column shows one head.
        xlabel: x-axis label, set only on the bottom row of the grid.
        ylabel: y-axis label, set only on the first column of the grid.
        titles: optional sequence indexed by head (column); when given,
            ``titles[j]`` becomes the title of every subplot in column ``j``.
        figsize: forwarded to ``plt.subplots``.
        cmap: colormap name forwarded to ``imshow``.
    """
    num_rows, num_cols = matrics.shape[0], matrics.shape[1]
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False
    )
    for i, (row_axes, row_matrics) in enumerate(zip(axes, matrics)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrics)):
            # .numpy() only works on CPU tensors, so move the data to the CPU
            # first; this is a no-op for tensors that already live there.
            pcm = ax.imshow(matrix.detach().cpu().numpy(), cmap=cmap)
            # x label only on the last row of the canvas
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            # y label only on the first column of the canvas
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    # One colorbar for the whole grid; it is built from the last image drawn
    fig.colorbar(pcm, ax=axes, shrink=0.6)

# Transformer 相关的参数定义

In [3]:
d_model = 512  # dimensionality of the input features
nhead = 8  # number of attention heads
num_encoder_layers = 6  # num_layers handed to nn.TransformerEncoder
num_decoder_layers = 6  # num_layers handed to nn.TransformerDecoder
dim_feedforward = 2048  # hidden width of the position-wise feed-forward block
dropout = 0  # dropout probability passed to every PyTorch layer and manual function
batch_first = True  # tensors are laid out as [batch, seq, feature]
norm_first = True  # pre-norm: layer norm is applied before each sub-block
bias = True

# Token Embedding

我们可以把 Token Embedding 看成一个 `[vocab_size, embed_size]`的查找表矩阵，每行对应了一个 token_id 的 embedding vector。



## Pytorch 中的 `nn.Embedding`

In [4]:
vocab_size = 32000
# Lookup table with one d_model-sized row per token id
token_embedding_layer = nn.Embedding(vocab_size, d_model)

batch_size = 2
seqlen = 5

# Random token ids in [0, vocab_size), shaped [batch_size, seqlen]
input_token_ids = torch.randint(0, vocab_size, (batch_size, seqlen), dtype=torch.long)
# Reference lookup through the PyTorch module
token_embeddings_torch = token_embedding_layer(input_token_ids)

## 手动查找

In [5]:
# PyTorch accepts a 2-D tensor of indices when indexing another tensor, so the
# embedding lookup can be reproduced by indexing the weight matrix directly.
token_embeddings = token_embedding_layer.weight[input_token_ids]
compare_output("token_embeddings", token_embeddings_torch, token_embeddings)

token_embeddings: ✅


# 位置编码

$PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$

$PE_{(pos,2i + 1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$

In [1]:
import torch

# NOTE(review): the base here is 1000, whereas the formula above and the next
# cell use 10000 — this looks like a leftover experiment; confirm it is intended.
embedding_mat = torch.pow(1000, -torch.arange(0, 512, 2) / 512)
embedding_mat[:5]

tensor([1.0000, 0.9734, 0.9475, 0.9222, 0.8977])

In [2]:
import torch

# Inverse frequencies 10000^(-2i/512) for the even feature indices 2i = 0, 2, ..., 510
embedding_mat = torch.pow(10000, -torch.arange(0, 512, 2) / 512)
# Display the first five entries
embedding_mat[:5]

tensor([1.0000, 0.9647, 0.9306, 0.8977, 0.8660])

In [3]:
max_seqlen = 512
embed_size = 2048

# Positions 0 .. max_seqlen-1 as a column vector: [512, 1]
position_mat = torch.arange(0, max_seqlen).unsqueeze(1)
# Divisors 10000^(2i/embed_size) as a row vector: [1, 1024]
embedding_mat = torch.pow(
    10000, torch.arange(0, embed_size, 2) / embed_size
).unsqueeze(0)
# Broadcasting gives pos / 10000^(2i/embed_size): [512, 1024]
position_embedding_matrix = position_mat / embedding_mat
# Interleave the two halves: even columns hold the sines, odd columns the cosines
position_embed = torch.stack(
    (position_embedding_matrix.sin(), position_embedding_matrix.cos()), dim=-1
).reshape(max_seqlen, embed_size)

In [6]:
# Build the lookup table straight from the precomputed sinusoidal matrix.
# from_pretrained(freeze=True) stores it as a parameter with requires_grad=False
# and avoids randomly initialising a [max_seqlen, embed_size] table that would
# be overwritten immediately afterwards.
position_embedding_layer = nn.Embedding.from_pretrained(position_embed, freeze=True)

seq_len = 10
batch_size = 2
# input_positions holds the position id of every token: [batch_size, seq_len]
input_positions = torch.arange(0, seq_len)[None, :].repeat((batch_size, 1))
# Indexing position_embed with this 2-D index tensor would give the same result
position_embedding = position_embedding_layer(input_positions)
print(position_embedding.shape)

torch.Size([2, 10, 2048])


# MultiHeadAttention

## 生成 Padding Mask 和 Attention Mask

In [10]:
def get_padding_mask(key_seq_lengths, max_seq_len):
    """Return a boolean mask that is True at padded key positions.

    Args:
        key_seq_lengths: one valid length per sequence in the batch.
        max_seq_len: padded length, i.e. the number of mask columns.

    Returns:
        Bool tensor of shape [len(key_seq_lengths), max_seq_len]; entry (b, t)
        is True when the 1-based position t + 1 exceeds key_seq_lengths[b].
    """
    lengths = torch.tensor(key_seq_lengths).unsqueeze(1)  # [batch, 1]
    one_based_positions = torch.arange(1, max_seq_len + 1).unsqueeze(0)  # [1, max_seq_len]
    # Broadcasting compares every position against every sequence length
    return one_based_positions > lengths


def get_causal_attn_mask(max_seq_len):
    """Return a [max_seq_len, max_seq_len] boolean causal mask.

    Entry (i, j) is True when j > i, i.e. when query position i would be
    looking at a later key position j; the diagonal and everything below it
    are False.
    """
    positions = torch.arange(max_seq_len)
    # Row index = query position, column index = key position
    return positions[None, :] > positions[:, None]

## Pytorch API 的使用

In [11]:
# Reference implementation: PyTorch's multi-head attention module, configured
# from the notebook-level hyperparameters defined earlier.
torch_attn = nn.MultiheadAttention(
    d_model,
    nhead,
    dropout=dropout,
    bias=bias,
    batch_first=batch_first,
)

In [12]:
batch_size = 4
seqlen = 10

# Random query / key / value inputs, each [batch_size, seqlen, d_model]
q_tensor = torch.randn(batch_size, seqlen, d_model)
k_tensor = torch.randn(batch_size, seqlen, d_model)
v_tensor = torch.randn(batch_size, seqlen, d_model)

# Number of valid (non-padding) positions in each sequence of the batch
batch_seq_len = [4, 9, 6, 10]
padding_mask = get_padding_mask(batch_seq_len, seqlen)
attn_mask = get_causal_attn_mask(seqlen)

need_weights = True
# Outputs: [batch_size, seqlen, d_model], [batch_size, nheads, seqlen, seqlen]
attn_output_torch, attn_output_weights_torch = torch_attn(
    q_tensor,
    k_tensor,
    v_tensor,
    key_padding_mask=padding_mask,
    attn_mask=attn_mask,
    need_weights=need_weights,
    average_attn_weights=False,
)

In [13]:
# Pull the projection parameters out of the PyTorch module so that the manual
# implementation below runs with exactly the same weights.
in_proj_weight = torch_attn.in_proj_weight  # [3 * d_model, d_model]
in_proj_bias = torch_attn.in_proj_bias  # [3 * d_model]
out_proj_weight = torch_attn.out_proj.weight  # [d_model, d_model]
out_proj_bias = torch_attn.out_proj.bias  # [d_model]

## 手动实现 MHA

In [14]:
def multihead_attn(
    query: Tensor,  # [N, L, E]
    key: Tensor,  # [N, S, E]
    value: Tensor,  # [N, S, E]
    nheads,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    dropout_p: float = 0,
    key_padding_mask: Optional[Tensor] = None,
    attn_mask: Optional[Tensor] = None,
    need_weights=True,
):
    """Batch-first multi-head scaled dot-product attention written out by hand.

    Args:
        query: [N, L, E] query tensor.
        key: [N, S, E] key tensor.
        value: [N, S, E] value tensor.
        nheads: number of attention heads; the projected embedding size is
            split evenly into this many heads.
        in_proj_weight: stacked Q/K/V projection weights, split into three
            equal chunks along dim 0.
        in_proj_bias: stacked Q/K/V projection biases, split the same way.
        out_proj_weight: weight of the output projection.
        out_proj_bias: bias of the output projection.
        dropout_p: dropout probability applied to the attention weights; note
            that ``F.dropout`` is called with its default ``training=True``.
        key_padding_mask: optional boolean [N, S] mask; True marks key
            positions that must not be attended to.
        attn_mask: optional boolean mask broadcastable to [N, nheads, L, S]
            (e.g. [L, S]); True marks query/key pairs that are masked out.
        need_weights: when True the per-head attention weights are returned
            alongside the output.

    Returns:
        ``attn_output`` of shape [N, L, E] when ``need_weights`` is False,
        otherwise the tuple ``(attn_output, attn_weights)`` where
        ``attn_weights`` has shape [N, nheads, L, S].
    """
    q_proj_weight, k_proj_weight, v_proj_weight = in_proj_weight.chunk(3)
    q_proj_bias, k_proj_bias, v_proj_bias = in_proj_bias.chunk(3)

    q_proj = query @ q_proj_weight.t() + q_proj_bias  # [N, L, E]
    k_proj = key @ k_proj_weight.t() + k_proj_bias  # [N, S, E]
    v_proj = value @ v_proj_weight.t() + v_proj_bias  # [N, S, E]

    batch_size, qlen, embed_size = q_proj.shape
    klen = k_proj.size(1)
    head_embed_size = embed_size // nheads
    # Split the embedding into heads using the `nheads` argument (not a
    # notebook-level global) and move the head axis in front of the sequence
    # axis: [batch_size, nheads, len, head_embed_size]
    q_proj = q_proj.reshape(batch_size, qlen, nheads, head_embed_size).permute(
        0, 2, 1, 3
    )
    k_proj = k_proj.reshape(batch_size, klen, nheads, head_embed_size).permute(
        0, 2, 1, 3
    )
    v_proj = v_proj.reshape(batch_size, klen, nheads, head_embed_size).permute(
        0, 2, 1, 3
    )
    # Scaled dot-product scores: [batch_size, nheads, qlen, klen]
    atten_weights = q_proj @ k_proj.transpose(2, 3) / math.sqrt(head_embed_size)

    if key_padding_mask is not None:
        # Turn the boolean mask into an additive bias: -inf where masked, 0 elsewhere
        padding_bias = torch.where(key_padding_mask, float("-inf"), 0)
        atten_weights += padding_bias.view(batch_size, 1, 1, klen)

    if attn_mask is not None:
        attn_bias = torch.where(attn_mask, float("-inf"), 0)
        atten_weights += attn_bias

    # Masked positions carry -inf and therefore receive zero probability
    atten_weights = F.softmax(atten_weights, dim=-1)

    if dropout_p > 0:
        atten_weights = F.dropout(atten_weights, p=dropout_p)

    attn_output = atten_weights @ v_proj  # [batch_size, nheads, qlen, head_embed_size]
    # Merge the heads back together: [batch_size, qlen, embed_size]
    attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, qlen, embed_size)
    attn_output = attn_output @ out_proj_weight.t() + out_proj_bias

    if need_weights:
        return attn_output, atten_weights

    return attn_output

In [15]:
# Run the manual implementation on the same inputs, masks and weights that
# were given to the PyTorch module.
attn_output, attn_output_weights = multihead_attn(
    q_tensor,
    k_tensor,
    v_tensor,
    nhead,
    in_proj_weight,
    in_proj_bias,
    out_proj_weight,
    out_proj_bias,
    key_padding_mask=padding_mask,
    attn_mask=attn_mask,
    need_weights=need_weights,
)

In [16]:
# PyTorch (reference) result first, manual result second, matching
# compare_output's (output_tag, output_ref, output) parameter order.
compare_output("attn_output", attn_output_torch, attn_output)
compare_output("attn_output_weights", attn_output_weights_torch, attn_output_weights)

attn_output: ✅
attn_output_weights: ✅


# Position-wise Feedforward

In [17]:
def feedforward_block(
    x,
    linear1_weight,
    linear1_bias,
    linear2_weight,
    linear2_bias,
    dropout_p=0,
):
    """Position-wise feed-forward block: linear -> ReLU -> dropout -> linear.

    Both linear maps are written as ``input @ weight.T + bias``. Dropout is
    applied through ``F.dropout`` with probability ``dropout_p`` and that
    function's default training flag.
    """
    hidden = torch.matmul(x, linear1_weight.t()) + linear1_bias
    activated = F.dropout(F.relu(hidden), dropout_p)
    return torch.matmul(activated, linear2_weight.t()) + linear2_bias

# TransformerEncoderLayer

## Pytorch API 的使用

In [18]:
# Reference encoder layer from PyTorch, configured from the notebook-level
# hyperparameters.
torch_encoder_layer = nn.TransformerEncoderLayer(
    d_model,
    nhead,
    dim_feedforward,
    dropout=dropout,
    batch_first=batch_first,
    norm_first=norm_first,
    bias=bias,
)

In [19]:
batch_size = 4
seqlen = 10

# Random encoder input: [batch_size, seqlen, d_model]
input_tensor = torch.randn(batch_size, seqlen, d_model)
# No masks are passed here, so every position attends to every other position
encoder_layer_output_torch = torch_encoder_layer(input_tensor)

##  手动实现

In [20]:
def transformer_encoder_layer_fwd(
    src,
    nhead,
    in_proj_weight,
    in_proj_bias,
    out_proj_weight,
    out_proj_bias,
    linear1_weight,
    linear1_bias,
    linear2_weight,
    linear2_bias,
    norm1_weight,
    norm1_bias,
    norm2_weight,
    norm2_bias,
    dropout_p=0,
    layer_norm_eps=1e-5,
    attn_mask=None,
    padding_mask=None,
):
    """Forward pass of a pre-norm Transformer encoder layer.

    Computes ``y = src + SelfAttn(LayerNorm1(src))`` followed by
    ``y + FeedForward(LayerNorm2(y))`` using the supplied parameters.
    ``padding_mask`` is forwarded to the attention as its key padding mask and
    ``attn_mask`` as its attention mask.
    """
    # --- self-attention sub-block ---
    attn_input = F.layer_norm(
        src, (src.size(-1),), norm1_weight, norm1_bias, eps=layer_norm_eps
    )
    attn_out = multihead_attn(
        attn_input,
        attn_input,
        attn_input,
        nhead,
        in_proj_weight,
        in_proj_bias,
        out_proj_weight,
        out_proj_bias,
        dropout_p=dropout_p,
        key_padding_mask=padding_mask,
        attn_mask=attn_mask,
        need_weights=False,
    )
    # Residual connection, accumulated in place on the attention output
    attn_out += src

    # --- feed-forward sub-block ---
    ffn_input = F.layer_norm(
        attn_out, (attn_out.size(-1),), norm2_weight, norm2_bias, eps=layer_norm_eps
    )
    ffn_out = feedforward_block(
        ffn_input, linear1_weight, linear1_bias, linear2_weight, linear2_bias, dropout_p
    )
    return ffn_out + attn_out

In [21]:
# Collect every parameter of the PyTorch encoder layer so that the manual
# forward pass uses exactly the same weights.
# Note: these names overwrite the attention parameters extracted earlier.
in_proj_weight = torch_encoder_layer.self_attn.in_proj_weight
in_proj_bias = torch_encoder_layer.self_attn.in_proj_bias
out_proj_weight = torch_encoder_layer.self_attn.out_proj.weight
out_proj_bias = torch_encoder_layer.self_attn.out_proj.bias
linear1_weight = torch_encoder_layer.linear1.weight
linear1_bias = torch_encoder_layer.linear1.bias
linear2_weight = torch_encoder_layer.linear2.weight
linear2_bias = torch_encoder_layer.linear2.bias
norm1_weight = torch_encoder_layer.norm1.weight
norm1_bias = torch_encoder_layer.norm1.bias
norm2_weight = torch_encoder_layer.norm2.weight
norm2_bias = torch_encoder_layer.norm2.bias


# Manual forward pass on the same input, without any masks
encoder_layer_output = transformer_encoder_layer_fwd(
    input_tensor,
    nhead,
    in_proj_weight,
    in_proj_bias,
    out_proj_weight,
    out_proj_bias,
    linear1_weight,
    linear1_bias,
    linear2_weight,
    linear2_bias,
    norm1_weight,
    norm1_bias,
    norm2_weight,
    norm2_bias,
    dropout_p=dropout,
)

In [22]:
# atol is loosened from the 1e-8 default to 1e-6 for this comparison
compare_output(
    "encoder_layer_output", encoder_layer_output_torch, encoder_layer_output, atol=1e-6
)

encoder_layer_output: ✅


# TransformerDecoderLayer

## Pytorch 实现

In [23]:
# Reference decoder layer from PyTorch, configured from the notebook-level
# hyperparameters.
decoder_layer_torch = nn.TransformerDecoderLayer(
    d_model,
    nhead,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    batch_first=batch_first,
    norm_first=norm_first,
)

In [24]:
batch_size = 2
# Valid (non-padding) lengths of the source and target sequences in the batch
src_seq_len = [8, 10]
tgt_seq_len = [6, 4]

# Random stand-ins for the encoder output and the decoder input, both padded
# to the longest sequence in the batch: [batch_size, max_len, d_model]
encoder_memory = torch.randn(batch_size, max(src_seq_len), d_model)
tgt_input_tensor = torch.randn(batch_size, max(tgt_seq_len), d_model)

# Padding masks are True at padded positions; tgt_mask is the causal mask
src_padding_mask = get_padding_mask(src_seq_len, max(src_seq_len))
tgt_padding_mask = get_padding_mask(tgt_seq_len, max(tgt_seq_len))
tgt_mask = get_causal_attn_mask(max(tgt_seq_len))

decoder_output_torch = decoder_layer_torch(
    tgt_input_tensor,
    encoder_memory,
    tgt_mask=tgt_mask,
    tgt_key_padding_mask=tgt_padding_mask,
    memory_key_padding_mask=src_padding_mask,
    tgt_is_causal=True,
)

## 从 Pytorch Module 中获取每一层的参数

In [25]:
# Parameters of the self-attention over the target sequence
self_attn_in_proj_weight = decoder_layer_torch.self_attn.in_proj_weight
self_attn_in_proj_bias = decoder_layer_torch.self_attn.in_proj_bias
self_attn_out_proj_weight = decoder_layer_torch.self_attn.out_proj.weight
self_attn_out_proj_bias = decoder_layer_torch.self_attn.out_proj.bias

# Parameters of the cross-attention (the module stores it as `multihead_attn`)
cross_attn_in_proj_weight = decoder_layer_torch.multihead_attn.in_proj_weight
cross_attn_in_proj_bias = decoder_layer_torch.multihead_attn.in_proj_bias
cross_attn_out_proj_weight = decoder_layer_torch.multihead_attn.out_proj.weight
cross_attn_out_proj_bias = decoder_layer_torch.multihead_attn.out_proj.bias

# Feed-forward and layer-norm parameters.
# Note: these names overwrite the encoder-layer parameters extracted earlier.
linear1_weight = decoder_layer_torch.linear1.weight
linear1_bias = decoder_layer_torch.linear1.bias
linear2_weight = decoder_layer_torch.linear2.weight
linear2_bias = decoder_layer_torch.linear2.bias
norm1_weight = decoder_layer_torch.norm1.weight
norm1_bias = decoder_layer_torch.norm1.bias
norm2_weight = decoder_layer_torch.norm2.weight
norm2_bias = decoder_layer_torch.norm2.bias
norm3_weight = decoder_layer_torch.norm3.weight
norm3_bias = decoder_layer_torch.norm3.bias

## 手动实现

In [26]:
def transformer_decoder_layer_fwd(
    tgt,
    memory,
    nhead,
    self_attn_in_proj_weight,
    self_attn_in_proj_bias,
    self_attn_out_proj_weight,
    self_attn_out_proj_bias,
    cross_attn_in_proj_weight,
    cross_attn_in_proj_bias,
    cross_attn_out_proj_weight,
    cross_attn_out_proj_bias,
    linear1_weight,
    linear1_bias,
    linear2_weight,
    linear2_bias,
    norm1_weight,
    norm1_bias,
    norm2_weight,
    norm2_bias,
    norm3_weight,
    norm3_bias,
    dropout_p=0,
    layer_norm_eps=1e-5,
    tgt_mask=None,
    memory_mask=None,
    tgt_key_padding_mask=None,
    memory_key_padding_mask=None,
):
    """Forward pass of a pre-norm Transformer decoder layer.

    Three residual sub-blocks are applied in order, each preceded by its own
    layer norm: self-attention over ``tgt`` (using ``tgt_mask`` and
    ``tgt_key_padding_mask``), cross-attention with queries from the decoder
    stream and keys/values from ``memory`` (using ``memory_mask`` and
    ``memory_key_padding_mask``), and the position-wise feed-forward block.
    """
    # 1) self-attention over the target sequence
    self_attn_in = F.layer_norm(
        tgt, (tgt.size(-1),), norm1_weight, norm1_bias, eps=layer_norm_eps
    )
    after_self_attn = multihead_attn(
        self_attn_in,
        self_attn_in,
        self_attn_in,
        nhead,
        self_attn_in_proj_weight,
        self_attn_in_proj_bias,
        self_attn_out_proj_weight,
        self_attn_out_proj_bias,
        dropout_p,
        key_padding_mask=tgt_key_padding_mask,
        attn_mask=tgt_mask,
        need_weights=False,
    )
    # Residual connection, accumulated in place
    after_self_attn += tgt

    # 2) cross-attention: queries from the decoder, keys/values from memory
    cross_attn_in = F.layer_norm(
        after_self_attn,
        (after_self_attn.size(-1),),
        norm2_weight,
        norm2_bias,
        eps=layer_norm_eps,
    )
    after_cross_attn = multihead_attn(
        cross_attn_in,
        memory,
        memory,
        nhead,
        cross_attn_in_proj_weight,
        cross_attn_in_proj_bias,
        cross_attn_out_proj_weight,
        cross_attn_out_proj_bias,
        dropout_p,
        key_padding_mask=memory_key_padding_mask,
        attn_mask=memory_mask,
        need_weights=False,
    )
    after_cross_attn += after_self_attn

    # 3) position-wise feed-forward
    ffn_in = F.layer_norm(
        after_cross_attn,
        (after_cross_attn.size(-1),),
        norm3_weight,
        norm3_bias,
        eps=layer_norm_eps,
    )
    ffn_out = feedforward_block(
        ffn_in, linear1_weight, linear1_bias, linear2_weight, linear2_bias, dropout_p
    )
    return ffn_out + after_cross_attn

In [27]:
# Manual forward pass with the same inputs, masks and parameters that were
# given to the PyTorch decoder layer.
decoder_output = transformer_decoder_layer_fwd(
    tgt_input_tensor,
    encoder_memory,
    nhead,
    self_attn_in_proj_weight,
    self_attn_in_proj_bias,
    self_attn_out_proj_weight,
    self_attn_out_proj_bias,
    cross_attn_in_proj_weight,
    cross_attn_in_proj_bias,
    cross_attn_out_proj_weight,
    cross_attn_out_proj_bias,
    linear1_weight,
    linear1_bias,
    linear2_weight,
    linear2_bias,
    norm1_weight,
    norm1_bias,
    norm2_weight,
    norm2_bias,
    norm3_weight,
    norm3_bias,
    dropout_p=dropout,
    tgt_mask=tgt_mask,
    memory_mask=None,
    tgt_key_padding_mask=tgt_padding_mask,
    memory_key_padding_mask=src_padding_mask,
)

In [28]:
# atol is loosened from the 1e-8 default to 1e-6 for this comparison
compare_output("decoder_layer_output", decoder_output_torch, decoder_output, atol=1e-6)

decoder_layer_output: ✅


# Decoder 损失函数

In [29]:
max_tgt_seqlen = max(tgt_seq_len)

# First map the decoder output to vocabulary logits with a classification head
decoder_head = nn.Linear(d_model, vocab_size)
logits = decoder_head(decoder_output)  # [batch_size, seqlen, vocab_size]

# Build the ground-truth labels. Shorter sequences are padded with -100 up to
# the longest one in the batch; positions equal to ignore_index are skipped
# when the loss is computed.
labels = torch.stack(
    [
        F.pad(
            torch.randint(0, vocab_size, (seqlen,)),
            (0, max_tgt_seqlen - seqlen),
            value=-100,
        )
        for seqlen in tgt_seq_len
    ],
    dim=0,
)

# cross_entropy expects the class dimension second, hence the transpose of
# the logits to [batch_size, vocab_size, seqlen]
decoder_loss = F.cross_entropy(
    logits.transpose(1, 2), labels, ignore_index=-100, reduction="mean"
)
print(decoder_loss)

tensor(10.9175, grad_fn=<NllLoss2DBackward0>)


# 自回归生成

每一步推理预测一个 token，然后将预测出的 token 拼接到 decoder 当前输入的末尾，作为下一步的输入。

In [30]:
# Encoder stack: num_encoder_layers copies of one encoder layer
encoder_layer = nn.TransformerEncoderLayer(
    d_model,
    nhead,
    dim_feedforward,
    dropout,
    batch_first=batch_first,
    norm_first=norm_first,
)
encoder = nn.TransformerEncoder(
    encoder_layer=encoder_layer,
    num_layers=num_encoder_layers,
    enable_nested_tensor=False,
)

# Decoder stack: num_decoder_layers copies of one decoder layer
decoder_layer = nn.TransformerDecoderLayer(
    d_model,
    nhead,
    dim_feedforward,
    dropout,
    batch_first=batch_first,
    norm_first=norm_first,
)

decoder = nn.TransformerDecoder(
    decoder_layer=decoder_layer,
    num_layers=num_decoder_layers,
)
# Classification head and token embedding; the same embedding table is used
# for both the source ids and the generated target ids
decoder_head = nn.Linear(d_model, vocab_size)
token_embedding = nn.Embedding(vocab_size, d_model)

src_seqlen = 10
# A single unbatched source sequence of random token ids: [src_seqlen]
src_ids = torch.randint(0, vocab_size, (src_seqlen,))
tgt_ids = torch.LongTensor([0])  # <bos>

# Encode the source once; the result is reused at every decoding step
input_embeddings = token_embedding(src_ids)
memory = encoder(input_embeddings)

In [31]:
decoder_steps = 5
# Inference only: no autograd graph is needed while decoding
with torch.no_grad():
    for _ in range(decoder_steps):
        tgt_embeddings = token_embedding(tgt_ids)
        # Causal mask so that each target position attends only to itself and
        # earlier positions, as in the masked decoder-layer call earlier on
        causal_mask = get_causal_attn_mask(tgt_ids.size(0))
        decode_output = decoder(tgt_embeddings, memory, tgt_mask=causal_mask)
        # Greedy choice: only the logits of the last position pick the next token
        logits = decoder_head(decode_output)[-1, :][None, :]
        next_token_id = torch.argmax(logits, -1)
        # Append the prediction so that it becomes part of the next step's input
        tgt_ids = torch.concat([tgt_ids, next_token_id], dim=-1)
print(tgt_ids)

tensor([    0,  4474,  5914,  3152,  5914, 10405])
