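# Hand-Coding the Transformer: A From-Scratch Interview Crash Course

This tutorial targets AI/algorithm engineer interviews. The goal is that, when asked to hand-code a Transformer, you can implement the key modules and the complete skeleton on a whiteboard or in an editor quickly, correctly, and in a way you can explain.

You will learn and implement:
- Scaled Dot-Product Attention (with mask)
- Multi-Head Attention (MHA)
- Position-wise Feed Forward (FFN)
- Residual connection + LayerNorm
- Positional Encoding
- EncoderLayer / DecoderLayer
- Full Transformer Encoder-Decoder assembly
- Greedy decoding and a minimal toy task

Advice: in an interview, prioritize code that is correct and clear, with complete comments and consistent shapes.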

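# Environment and Dependencies

- Python ≥ 3.10 (the code cells use `tuple[...]` and `X | None` type annotations)
- PyTorch is recommended (the usual choice in interviews)
- Without torch, install it as needed, or write only pseudocode / interface signatures on paper

The code below tries to import torch and prints a hint if it is missing.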

In [1]:
# Import and quick check
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch import Tensor
    print(torch.__version__)
except Exception as e:
    print("[Warn] torch not available. You can still read/understand the code.")
    print(e)

2.8.0


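# Scaled Dot-Product Attention (with Mask)
Core formulas:

Let $Q\in\mathbb{R}^{B\times H\times T_q\times d_k},\ K\in\mathbb{R}^{B\times H\times T_k\times d_k},\ V\in\mathbb{R}^{B\times H\times T_k\times d_v}$.

- Score matrix (scaled dot product):
$$
\mathrm{scores} \,=\, \frac{QK^{\top}}{\sqrt{d_k}}\;\;\in\;\mathbb{R}^{B\times H\times T_q\times T_k}
$$

- Apply the mask (1 = visible, 0 = masked):
Let $M\in\{0,1\}^{B\times 1\times T_q\times T_k}$ be the visibility mask and define the additive mask elementwise as
$$
\tilde{M} \,=\, \begin{cases} 0, & M = 1,\\ -\infty, & M = 0, \end{cases}
$$
then
$$
\mathrm{attn} \,=\, \mathrm{softmax}(\mathrm{scores} + \tilde{M})\;\;\in\;\mathbb{R}^{B\times H\times T_q\times T_k}.
$$

- Weighted-sum output:
$$
\mathrm{out} \,=\, \mathrm{attn}\,V\;\;\in\;\mathbb{R}^{B\times H\times T_q\times d_v}.
$$

Numerical note: the implementation fills masked positions with $-\infty$, so their probability after softmax is exactly 0 (a large finite negative constant would only make it approach 0). A row in which every position is masked has no finite score and its softmax evaluates to NaN, so every query row must keep at least one visible key.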
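To make the masking step concrete, here is a minimal plain-Python sketch (standard library only, so it runs without torch) of a softmax over a single row of scores with a visibility mask. The helper name `masked_softmax` is hypothetical and not part of the tutorial's code. It only illustrates that positions filled with `-inf` receive exactly zero probability and that a fully masked row has no valid softmax.

```python
import math

def masked_softmax(scores, mask):
    """Softmax over one row of scores; mask uses 1 = visible, 0 = masked."""
    # Additive mask: masked positions become -inf.
    filled = [s if m == 1 else float("-inf") for s, m in zip(scores, mask)]
    row_max = max(filled)
    if row_max == float("-inf"):
        raise ValueError("fully masked row: softmax is undefined")
    # Subtracting the row maximum keeps exp() from overflowing.
    exps = [math.exp(s - row_max) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax([2.0, 1.0, 0.5, 3.0], [1, 1, 0, 0])
print(probs)  # the two masked positions get probability 0.0
```

Note that the largest raw score (3.0) is masked, so it has no influence on the result; the remaining probability mass is shared by the two visible positions only.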

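**Visual structure:**

![Scaled Dot-Product Attention](Scaled_dot-product_attention.png)

The figure above shows the computation flow of scaled dot-product attention: the inputs Q, K, and V go through a matrix multiplication, scaling, masking, and softmax, followed by a weighted sum that produces the output.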

In [2]:
import math
from typing import Optional

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q: Tensor, K: Tensor, V: Tensor, mask: Optional[Tensor] = None) -> tuple[Tensor, Tensor]:
        """
        Q: (B, H, T_q, d_k)
        K: (B, H, T_k, d_k)
        V: (B, H, T_k, d_v)
        mask: (B, 1, T_q, T_k) or (B, H, T_q, T_k); 1 = visible, 0 = masked
        Returns: (out, attn)
          out: (B, H, T_q, d_v)
          attn: (B, H, T_q, T_k)
        """
        d_k = Q.size(-1)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)  # (B,H,T_q,T_k)
        if mask is not None:
            # Set masked positions to -inf
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = scores.softmax(dim=-1)
        attn = self.dropout(attn)
        out = attn @ V  # (B,H,T_q,d_v)
        return out, attn

# quick shape test (no torch run here if not installed)
if 'torch' in globals():
    B, H, T_q, T_k, d_k, d_v = 2, 4, 5, 6, 8, 8
    Q = torch.randn(B, H, T_q, d_k)
    K = torch.randn(B, H, T_k, d_k)
    V = torch.randn(B, H, T_k, d_v)
    mask = torch.ones(B, 1, T_q, T_k)
    attn = ScaledDotProductAttention()
    out, w = attn(Q, K, V, mask)
    print(out.shape, w.shape)  # expect: (2,4,5,8) (2,4,5,6)

torch.Size([2, 4, 5, 8]) torch.Size([2, 4, 5, 6])


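# Multi-Head Attention (MHA)
Let the number of heads be $H$, the model dimension $d_{\text{model}}$, and the per-head dimension $d_k = d_{\text{model}}/H$. For an input $X\in\mathbb{R}^{B\times T\times d_{\text{model}}}$:

- Linear projections (each head has its own parameters):
$$
Q_i = X W_Q^{(i)},\quad K_i = X W_K^{(i)},\quad V_i = X W_V^{(i)},
$$
where $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}\in\mathbb{R}^{d_{\text{model}}\times d_k}$.

- Per-head attention:
$$
\mathrm{head}_i = \mathrm{Attention}(Q_i, K_i, V_i) = \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}} + \tilde{M}\right)V_i.
$$

- Concatenate the heads and apply the output projection:
$$
\mathrm{MHA}(X) = \mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_H)\, W_O,
$$
where $W_O\in\mathbb{R}^{(H\cdot d_k)\times d_{\text{model}}}$, and the final output lies in $\mathbb{R}^{B\times T\times d_{\text{model}}}$. The implementation stacks the $H$ per-head matrices into a single $d_{\text{model}}\times d_{\text{model}}$ linear layer and splits the result into heads with a reshape.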
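The per-head formulas and a single fused projection describe the same computation: taking columns $i\cdot d_k$ to $(i+1)\cdot d_k$ of one $d_{\text{model}}\times d_{\text{model}}$ matrix gives $W^{(i)}$. A small plain-Python check on one token vector, using a hypothetical `matvec` helper and made-up weights, might look like this:

```python
def matvec(x, W):
    """Row vector x (length n) times matrix W (n rows, m columns)."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

d_model, H = 4, 2
d_k = d_model // H
x = [1.0, 2.0, 3.0, 4.0]  # one token, d_model = 4
# Made-up fused weight matrix: W[r][c] = r * d_model + c
W = [[float(r * d_model + c) for c in range(d_model)] for r in range(d_model)]

# Route 1: one fused projection, then split the result into H chunks.
fused = matvec(x, W)
split = [fused[h * d_k:(h + 1) * d_k] for h in range(H)]

# Route 2: project with each head's own column slice of W.
per_head = [matvec(x, [row[h * d_k:(h + 1) * d_k] for row in W]) for h in range(H)]

print(split)              # [[80.0, 90.0], [100.0, 110.0]]
print(split == per_head)  # True
```

This is why a hand-written MHA usually needs only three input projections plus a reshape, rather than $3H$ separate small layers.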

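**Visual structure:**

![Multi-Head Attention](Multi-Head_Attention.png)

The figure above shows the full multi-head attention flow: the input is linearly projected and split into several heads, each head computes attention independently, and the outputs of all heads are concatenated and mapped back to the original dimension by a linear layer.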

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attn = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: Tensor) -> Tensor:
        # x: (B,T,d_model) -> (B,H,T,d_k)
        B, T, _ = x.shape
        x = x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        return x

    def _combine_heads(self, x: Tensor) -> Tensor:
        # x: (B,H,T,d_k) -> (B,T,d_model)
        B, H, T, d_k = x.shape
        x = x.transpose(1, 2).contiguous().view(B, T, H * d_k)
        return x

    def forward(self, x_q: Tensor, x_kv: Tensor, mask: Optional[Tensor] = None) -> tuple[Tensor, Tensor]:
        """
        x_q: (B,T_q,d_model)
        x_kv: (B,T_k,d_model)
        mask: (B,1,T_q,T_k) or (B,H,T_q,T_k)
        Returns: (out, attn)
        """
        Q = self._split_heads(self.W_q(x_q))  # (B,H,T_q,d_k)
        K = self._split_heads(self.W_k(x_kv)) # (B,H,T_k,d_k)
        V = self._split_heads(self.W_v(x_kv)) # (B,H,T_k,d_k)

        out, attn = self.attn(Q, K, V, mask)   # out: (B,H,T_q,d_k)
        out = self._combine_heads(out)         # (B,T_q,d_model)
        out = self.W_o(out)                    # (B,T_q,d_model)
        out = self.dropout(out)
        return out, attn

# quick shape test
if 'torch' in globals():
    B, T_q, T_k, d_model, H = 2, 5, 6, 32, 4
    x_q = torch.randn(B, T_q, d_model)
    x_kv = torch.randn(B, T_k, d_model)
    mask = torch.ones(B, 1, T_q, T_k)
    mha = MultiHeadAttention(d_model, H)
    y, a = mha(x_q, x_kv, mask)
    print(y.shape, a.shape)  # expect: (2,5,32) (2,4,5,6)

torch.Size([2, 5, 32]) torch.Size([2, 4, 5, 6])


# Positional Encoding
A fixed sine/cosine positional encoding is used. For position $\mathrm{pos}\in\{0,\dots,T-1\}$ and dimension index $i\in\{0,\dots,\lfloor\tfrac{d_{\text{model}}}{2}\rfloor-1\}$:

$$
\mathrm{PE}[\mathrm{pos},\,2i] \;=\; \sin\!\left(\frac{\mathrm{pos}}{10000^{\frac{2i}{d_{\text{model}}}}}\right),\quad
\mathrm{PE}[\mathrm{pos},\,2i+1] \;=\; \cos\!\left(\frac{\mathrm{pos}}{10000^{\frac{2i}{d_{\text{model}}}}}\right).
$$

Adding it to the token embeddings gives:
$$
X_{\text{pos}} = X + \mathrm{PE}.
$$
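The formula can be evaluated directly, entry by entry, which is a handy way to check a vectorized implementation. The sketch below is plain Python with a hypothetical function name `sinusoidal_pe`, and it assumes an even `d_model`:

```python
import math

def sinusoidal_pe(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal position encodings."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(d_model // 2):
            angle = pos / (10000 ** (2 * i / d_model))
            pe[pos][2 * i] = math.sin(angle)      # even columns: sine
            pe[pos][2 * i + 1] = math.cos(angle)  # odd columns: cosine
    return pe

pe_table = sinusoidal_pe(max_len=4, d_model=6)
print(pe_table[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

At position 0 every angle is 0, so the sine entries are 0 and the cosine entries are 1. Lower column pairs oscillate fastest; higher pairs change more slowly with position.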

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.0):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (T, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (T,1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1,T,d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        # x: (B,T,d_model)
        T = x.size(1)
        x = x + self.pe[:, :T, :]
        return self.dropout(x)

# quick test
if 'torch' in globals():
    pe = PositionalEncoding(32)
    x = torch.zeros(2, 10, 32)
    y = pe(x)
    print(y.shape)  # (2,10,32)

torch.Size([2, 10, 32])


# FFN + LayerNorm + Residual
The position-wise FFN is usually a two-layer MLP: $d_{\text{model}} \to d_{\mathrm{ff}} \to d_{\text{model}}$.

- FFN:
$$
\mathrm{FFN}(x) = \sigma(x W_1 + b_1)\,W_2 + b_2,\quad W_1\in\mathbb{R}^{d_{\text{model}}\times d_{\mathrm{ff}}},\; W_2\in\mathbb{R}^{d_{\mathrm{ff}}\times d_{\text{model}}}.
$$
Here $x$ is a row vector, the same convention as $Q_i = XW_Q^{(i)}$ in the MHA section, and $\sigma$ is the activation (ReLU or GELU in the code).

- Residual + LayerNorm (Post-LN, used in this tutorial):
$$
y = \mathrm{LN}\big(x + \mathrm{Sublayer}(x)\big).
$$

(For comparison) the Pre-LN variant: $y = x + \mathrm{Sublayer}(\mathrm{LN}(x))$.
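The difference between the two orderings is easy to see on a single vector. The following plain-Python sketch uses a LayerNorm without learned scale or shift and a made-up elementwise `sublayer` standing in for attention or the FFN; both are illustrative assumptions, not the tutorial's modules.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    """Hypothetical stand-in for attention or FFN: a fixed elementwise map."""
    return [2.0 * v + 1.0 for v in x]

def post_ln(x):
    # y = LN(x + Sublayer(x))
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln(x):
    # y = x + Sublayer(LN(x))
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

x = [1.0, 2.0, 4.0, 9.0]
y_post, y_pre = post_ln(x), pre_ln(x)
print(sum(y_post) / len(y_post))  # approximately 0: the Post-LN output is normalized
print(sum(y_pre) / len(y_pre))    # approximately 5: the residual path bypasses LN
```

In Post-LN the layer output is always normalized. In Pre-LN the raw input `x` passes through the residual branch untouched, so the output keeps the input's scale.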

In [5]:
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0, activation: str = 'relu'):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == 'relu':
            self.act = nn.ReLU()
        elif activation == 'gelu':
            self.act = nn.GELU()
        else:
            raise ValueError('activation must be relu or gelu')

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(self.dropout(self.act(self.fc1(x))))

class ResidualLayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x: Tensor, sublayer_out: Tensor) -> Tensor:
        # Add x + sublayer(x), then apply LN
        return self.ln(x + sublayer_out)

# quick test
if 'torch' in globals():
    ff = FeedForward(32, 64)
    x = torch.randn(2, 10, 32)
    y = ff(x)
    print(y.shape)  # (2,10,32)
    ln = ResidualLayerNorm(32)
    z = ln(x, y)
    print(z.shape)  # (2,10,32)

torch.Size([2, 10, 32])
torch.Size([2, 10, 32])


# EncoderLayer / DecoderLayer
- EncoderLayer: Self-Attention + FFN (each sublayer is followed by Residual + LayerNorm)
- DecoderLayer: Masked Self-Attention + Cross-Attention + FFN

Let attention be $\mathrm{Att}(Q,K,V)=\mathrm{softmax}\!\left(\tfrac{QK^\top}{\sqrt{d_k}}+\tilde{M}\right)V$.

EncoderLayer:
$$
\begin{aligned}
\tilde{x}_1 &= \mathrm{Att}(X, X, X),\\
X' &= \mathrm{LN}\big(X + \tilde{x}_1\big),\\
\tilde{x}_2 &= \mathrm{FFN}(X'),\\
Y &= \mathrm{LN}\big(X' + \tilde{x}_2\big).
\end{aligned}
$$

DecoderLayer (with two attention sublayers):
$$
\begin{aligned}
\tilde{y}_1 &= \mathrm{Att}_{\text{masked}}(Y, Y, Y),\\
Y' &= \mathrm{LN}\big(Y + \tilde{y}_1\big),\\
\tilde{y}_2 &= \mathrm{Att}(Y',\,\mathrm{Mem},\,\mathrm{Mem}),\\
Y'' &= \mathrm{LN}\big(Y' + \tilde{y}_2\big),\\
Z &= \mathrm{LN}\big(Y'' + \mathrm{FFN}(Y'')\big).
\end{aligned}
$$

Here $\mathrm{Mem}$ is the Encoder's output memory, and $\mathrm{Att}_{\text{masked}}$ uses a lower-triangular mask for autoregressive decoding (each position may attend only to itself and earlier positions).
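As an illustration of the lower-triangular mask, the sketch below builds it as nested Python lists and combines it with key padding. The function names are hypothetical, and the padding vector is assumed to follow the `1 = pad` convention of the tutorial's pad inputs, while the masks use `1 = visible`.

```python
def subsequent_mask(T):
    """T x T lower-triangular mask: row q may see columns 0..q (1 = visible)."""
    return [[1 if k <= q else 0 for k in range(T)] for q in range(T)]

def combine_with_key_padding(causal, key_is_pad):
    """AND the causal mask with key visibility (key_is_pad uses 1 = pad)."""
    return [[c & (1 - p) for c, p in zip(row, key_is_pad)] for row in causal]

causal = subsequent_mask(4)
for row in causal:
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]

# Suppose the last target position is padding.
tgt_mask = combine_with_key_padding(causal, key_is_pad=[0, 0, 0, 1])
print(tgt_mask[3])  # [1, 1, 1, 0]
```

Every row keeps at least one visible key (column 0), so no query row ends up fully masked as long as the first token is not padding.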

In [6]:
def make_pad_mask(q_len: int, k_len: int, q_pad: Tensor | None, k_pad: Tensor | None) -> Tensor | None:
    """
    Build a padding mask (1 = visible, 0 = masked) of shape (B,1,q_len,k_len).
    q_pad/k_pad: (B,T), where 1 marks a pad position.
    Only padded keys are masked. Masking padded query rows as well would leave
    those rows with every score at -inf; softmax then returns NaN, and the NaN
    spreads to other positions through attn @ V in later attention calls.
    q_pad is kept for interface compatibility; outputs at padded query
    positions are excluded from the loss via ignore_index.
    """
    if k_pad is None:
        return None
    assert k_pad.size(1) == k_len, "k_pad length must match k_len"
    k_visible = (k_pad == 0)[:, None, None, :]   # (B,1,1,T_k)
    return k_visible.expand(-1, 1, q_len, -1)    # (B,1,T_q,T_k)


def make_subsequent_mask(T: int) -> Tensor:
    """Lower-triangular visibility mask for decoder self-attention (1 = visible, 0 = masked), shape (1,1,T,T)."""
    return torch.tril(torch.ones(T, T, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = ResidualLayerNorm(d_model)
        self.norm2 = ResidualLayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, src_mask: Optional[Tensor] = None) -> tuple[Tensor, Tensor]:
        # Self-Attention
        sa_out, sa_w = self.self_attn(x, x, src_mask)
        x = self.norm1(x, self.dropout(sa_out))
        # FFN
        ff_out = self.ffn(x)
        x = self.norm2(x, self.dropout(ff_out))
        return x, sa_w


class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = ResidualLayerNorm(d_model)
        self.norm2 = ResidualLayerNorm(d_model)
        self.norm3 = ResidualLayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y: Tensor, memory: Tensor, tgt_mask: Optional[Tensor], memory_mask: Optional[Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        # Masked Self-Attention (decoder)
        sa_out, sa_w = self.self_attn(y, y, tgt_mask)
        y = self.norm1(y, self.dropout(sa_out))
        # Cross-Attention: Q=decoder, K/V=encoder memory
        ca_out, ca_w = self.cross_attn(y, memory, memory_mask)
        y = self.norm2(y, self.dropout(ca_out))
        # FFN
        ff_out = self.ffn(y)
        y = self.norm3(y, self.dropout(ff_out))
        return y, (sa_w, ca_w)


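# Full Assembly: Transformer Encoder-Decoder
- Token embedding + positional encoding
- A stack of N EncoderLayer / DecoderLayer
- An output linear layer that maps to the vocabulary size
- Decoding with greedy or beam search (this tutorial implements greedy)

**Overall Transformer architecture:**

![Transformer Architecture](attention_architerture.png)

The figure above shows the complete Transformer Encoder-Decoder architecture:
- **Encoder (left)**: input embedding + positional encoding → N×(multi-head self-attention + FFN)
- **Decoder (right)**: output embedding + positional encoding → N×(masked multi-head self-attention + encoder-decoder attention + FFN)
- **Output layer**: linear projection + Softmax, producing a probability distribution over the target vocabulary

Note that every sublayer is followed by a residual connection and LayerNorm.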

In [7]:
class Transformer(nn.Module):
    def __init__(self, src_vocab: int, tgt_vocab: int, d_model: int = 256, num_heads: int = 8,
                 d_ff: int = 512, num_layers: int = 4, dropout: float = 0.1, max_len: int = 512):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.out_proj = nn.Linear(d_model, tgt_vocab)

    def encode(self, src: Tensor, src_pad: Optional[Tensor] = None) -> tuple[Tensor, list[Tensor]]:
        # src: (B,T_s), src_pad: (B,T_s) 1=pad
        x = self.pos_enc(self.src_embed(src))  # (B,T_s,d_model)
        attn_weights = []
        src_len = src.size(1)
        src_mask = make_pad_mask(src_len, src_len, src_pad, src_pad)  # (B,1,T,T)
        for layer in self.encoder_layers:
            x, sa_w = layer(x, src_mask)
            attn_weights.append(sa_w)
        return x, attn_weights

    def decode(self, tgt: Tensor, memory: Tensor, src_pad: Optional[Tensor] = None, tgt_pad: Optional[Tensor] = None) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]:
        # tgt: (B,T_t)
        y = self.pos_enc(self.tgt_embed(tgt))
        T_t = tgt.size(1)
        B, T_s = memory.size(0), memory.size(1)
        # masks
        pad_mask = make_pad_mask(T_t, T_t, tgt_pad, tgt_pad)            # (B,1,T_t,T_t)
        subs_mask = make_subsequent_mask(T_t).to(y.device)              # (1,1,T_t,T_t)
        tgt_mask = pad_mask & subs_mask if pad_mask is not None else subs_mask
        mem_mask = make_pad_mask(T_t, T_s, tgt_pad, src_pad)            # (B,1,T_t,T_s)

        attn_pairs = []
        for layer in self.decoder_layers:
            y, (sa_w, ca_w) = layer(y, memory, tgt_mask, mem_mask)
            attn_pairs.append((sa_w, ca_w))
        return y, attn_pairs

    def forward(self, src: Tensor, tgt_inp: Tensor, src_pad: Optional[Tensor] = None, tgt_pad: Optional[Tensor] = None) -> Tensor:
        memory, _ = self.encode(src, src_pad)
        y, _ = self.decode(tgt_inp, memory, src_pad, tgt_pad)
        logits = self.out_proj(y)
        return logits

    @torch.no_grad()
    def greedy_decode(self, src: Tensor, bos_id: int, eos_id: int, max_new_tokens: int,
                      src_pad: Optional[Tensor] = None) -> Tensor:
        self.eval()
        memory, _ = self.encode(src, src_pad)
        B = src.size(0)
        ys = torch.full((B, 1), bos_id, dtype=torch.long, device=src.device)
        for _ in range(max_new_tokens):
            y, _ = self.decode(ys, memory, src_pad, tgt_pad=None)
            logits = self.out_proj(y)  # (B,T,d_vocab)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # (B,1)
            ys = torch.cat([ys, next_token], dim=1)
            if (next_token == eos_id).all():
                break
        return ys

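# Minimal Toy Task: Copy Task (sanity-check the forward/backward pass)
Task: given the input sequence [a b c], the output is also [a b c].
- Vocabulary: {PAD=0, BOS=1, EOS=2, other tokens 3..V-1}
- Loss: cross-entropy (ignoring PAD)
- Train for only a few steps to show that the loss can decrease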
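To show what "cross-entropy ignoring PAD" computes, here is a plain-Python sketch with a hypothetical helper `cross_entropy_ignore_pad`. It averages the negative log-likelihood over non-PAD targets only, which is the behavior assumed here for a loss configured with `ignore_index=0`.

```python
import math

def cross_entropy_ignore_pad(logits, targets, pad_id=0):
    """Mean negative log-likelihood over non-PAD targets.

    logits: list of per-position score lists; targets: list of token ids.
    """
    total, count = 0.0, 0
    for row, tgt in zip(logits, targets):
        if tgt == pad_id:
            continue  # padded positions contribute nothing to the loss
        row_max = max(row)
        # log-sum-exp, shifted by the row maximum for numerical stability
        log_z = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_z - row[tgt]
        count += 1
    return total / count

logits = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]]
targets = [3, 2, 0]  # the last position is PAD and is skipped
loss = cross_entropy_ignore_pad(logits, targets)
print(round(loss, 4))  # 1.3863, i.e. log(4) for uniform logits over 4 tokens
```

By the same reasoning, a uniform predictor over a vocabulary of 100 tokens scores log(100) ≈ 4.605, which is a useful reference point when reading the loss values of a short training run.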

In [8]:
import random

def make_copy_batch(batch_size: int, seq_len: int, vocab_size: int, pad_id: int = 0, bos_id: int = 1, eos_id: int = 2):
    """Build a batch of copy samples. Returns src, tgt_inp, tgt_out, and the pad masks."""
    src = []
    tgt_inp = []
    tgt_out = []
    for _ in range(batch_size):
        toks = [random.randint(3, vocab_size - 1) for _ in range(seq_len)]
        src.append(toks)
        # tgt: starts with BOS, followed by the same sequence, ending with EOS
        tgt_inp.append([bos_id] + toks)
        tgt_out.append(toks + [eos_id])
    src = torch.tensor(src, dtype=torch.long)
    tgt_inp = torch.tensor(tgt_inp, dtype=torch.long)
    tgt_out = torch.tensor(tgt_out, dtype=torch.long)
    # No padding here, for simplicity
    src_pad = torch.zeros_like(src)
    tgt_pad = torch.zeros_like(tgt_inp)
    return src, tgt_inp, tgt_out, src_pad, tgt_pad

# Training demo (optional)
if 'torch' in globals():
    torch.manual_seed(0)
    V = 100
    model = Transformer(src_vocab=V, tgt_vocab=V, d_model=128, num_heads=4, d_ff=256, num_layers=2, dropout=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optim = torch.optim.Adam(model.parameters(), lr=3e-4)

    for step in range(50):  # a few steps, for demonstration
        model.train()
        src, tgt_inp, tgt_out, src_pad, tgt_pad = make_copy_batch(batch_size=16, seq_len=5, vocab_size=V)
        logits = model(src, tgt_inp, src_pad, tgt_pad)     # (B,T+1,V)
        loss = criterion(logits.reshape(-1, V), tgt_out.reshape(-1))
        optim.zero_grad()
        loss.backward()
        optim.step()
        if (step + 1) % 10 == 0:
            print(f"step {step+1}: loss={loss.item():.4f}")

    # Greedy decoding test
    model.eval()
    src, tgt_inp, tgt_out, src_pad, tgt_pad = make_copy_batch(batch_size=2, seq_len=5, vocab_size=V)
    pred = model.greedy_decode(src, bos_id=1, eos_id=2, max_new_tokens=6)
    print("src:", src)
    print("pred:", pred)

step 10: loss=4.4004
step 20: loss=4.2882
step 30: loss=4.2305
step 40: loss=4.0967
step 50: loss=4.0603
src: tensor([[46, 37, 72, 78, 88],
        [33, 99, 70, 35,  9]])
pred: tensor([[ 1, 46, 46, 46,  2],
        [ 1, 46, 46,  2,  2]])


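# Complexity, Common Pitfalls, and Interview Talking Points

## Complexity
Dominant cost of the attention computation:
$$
\mathcal{O}\big(B\,\cdot\,H\,\cdot\,T_q\,\cdot\,T_k\,\cdot\,d_k\big).
$$
When $T_q\approx T_k\approx T$ and $d_k\approx d_{\text{model}}/H$:
$$
\mathcal{O}\big(B\,\cdot\,T^2\,\cdot\,d_{\text{model}}\big),
$$
which is the bottleneck of the standard Transformer (the $T^2$ term).

(Optional) Memory footprint of the attention weights: $\mathcal{O}(B\,\cdot\,H\,\cdot\,T_q\,\cdot\,T_k)$.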
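A quick way to get a feel for the quadratic term is to count operations for a few sequence lengths. The helper below is a rough, hypothetical estimate: it counts only the two matrix products inside attention, assumes $d_v = d_k$, and ignores the linear projections and softmax.

```python
def attention_cost(B, H, T_q, T_k, d_k):
    """Rough multiply-add count and attention-weight element count for one attention call."""
    qk = B * H * T_q * T_k * d_k   # scores = Q @ K^T
    av = B * H * T_q * T_k * d_k   # out = attn @ V (assuming d_v = d_k)
    weights = B * H * T_q * T_k    # entries in the attention matrix
    return qk + av, weights

costs = {}
for T in (512, 1024, 2048):
    macs, weights = attention_cost(B=1, H=8, T_q=T, T_k=T, d_k=64)
    costs[T] = (macs, weights)
    print(T, macs, weights)

# Doubling T multiplies both numbers by 4.
print(costs[1024][0] // costs[512][0], costs[1024][1] // costs[512][1])  # 4 4
```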

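## Pitfall Checklist
1. Dimension handling when splitting/merging heads in MHA (view/transpose/contiguous)
2. The mask polarity (1 = visible or 1 = masked) must be consistent throughout (this tutorial uses 1 = visible)
3. Decoder self-attention needs a lower-triangular mask (subsequent mask)
4. In cross-attention, $K/V$ come from the Encoder's memory
5. The order of residual + LayerNorm (this tutorial uses Post-LN: add the residual to the sublayer output, then apply LN)
6. The positional encoding length must cover the maximum input length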

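## Suggested Structure for a Quick Interview Walkthrough
- Outline: Embedding+PE → MHA (QKV, head splitting, scaled dot product) → FFN → residual+LN → encoder/decoder stacks → output layer
- Emphasize shapes: write out $(B,T,d_{\text{model}})$ / $(B,H,T,d_k)$ explicitly
- Emphasize masks: where the pad mask and the subsequent mask are applied
- Emphasize decoding: the differences between greedy and beam search
- Possible extensions: relative positional encoding, Pre-LN, RoPE, FlashAttention, Efficient Transformer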
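For the greedy versus beam search point, a tiny worked example can help. The sketch below uses a made-up next-token probability table (`NEXT`) instead of a real model, and it omits EOS handling and length normalization. It only shows that picking the locally best token at each step can miss the sequence with the highest total probability.

```python
import math

# Hypothetical toy "model": next-token probabilities given the previous token.
NEXT = {
    "<bos>": {"A": 0.6, "B": 0.4},
    "A": {"C": 0.5, "D": 0.5},
    "B": {"E": 0.9, "F": 0.1},
}

def greedy(steps):
    """Pick the single most probable next token at every step."""
    seq, logp = ["<bos>"], 0.0
    for _ in range(steps):
        tok, p = max(NEXT[seq[-1]].items(), key=lambda kv: kv[1])
        seq.append(tok)
        logp += math.log(p)
    return seq, logp

def beam_search(steps, beam_width):
    """Keep the beam_width best partial sequences by total log-probability."""
    beams = [(["<bos>"], 0.0)]
    for _ in range(steps):
        candidates = [
            (seq + [tok], logp + math.log(p))
            for seq, logp in beams
            for tok, p in NEXT[seq[-1]].items()
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0]

g_seq, g_logp = greedy(2)
b_seq, b_logp = beam_search(2, beam_width=2)
print(g_seq, round(math.exp(g_logp), 2))  # ['<bos>', 'A', 'C'] 0.3
print(b_seq, round(math.exp(b_logp), 2))  # ['<bos>', 'B', 'E'] 0.36
```

Greedy commits to `A` because 0.6 > 0.4, while beam search keeps `B` alive long enough to find the more probable continuation. With `beam_width=1` the beam search reduces to greedy decoding.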