## Introduction

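1. The core of modern language models, and of modern AI more broadly, is the Transformer architecture, and the defining low-level computation of the Transformer is attention.
2. Any amount of time spent exploring the Transformer architecture and the attention computation is time well spent.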

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
torch.manual_seed(42)

<torch._C.Generator at 0x7efb58c6f4d0>

## Implementing causal (decoder-only) unidirectional attention

- BERT: bidirectional self-attention

    $$
    \quad \text{Attention}(Q^{(n \times d_k)}, K^{(n \times d_k)}, V^{(n \times d_v)}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 
    $$

- GPT: unidirectional causal self-attention

    $$
    \quad \text{Attention}(Q^{(n \times d_k)}, K^{(n \times d_k)}, V^{(n \times d_v)}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}+ M\right)V
    $$

    - $M_{ij}=0$ for $j\le i$ (position $i$ may attend to itself and to earlier positions)
    - $M_{ij}=-\infty$ for $j> i$ (later positions get zero weight after the softmax)
    
    $$
    M = \begin{pmatrix}
    0 & -\infty & -\infty & \cdots & -\infty \\
    0 & 0 & -\infty & \cdots & -\infty \\
    0 & 0 & 0 & \cdots & -\infty \\
    \vdots & \vdots & \vdots & \ddots & \vdots \\
    0 & 0 & 0 & \cdots & 0
    \end{pmatrix}_{n\times n}
    $$

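- T5: the encoder output supplies K and V (both derived from the same encoder output), the decoder supplies Q, and the two sides are combined through cross-attention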

    $$
    \begin{split}
    \text{Encoder Self-Attention} &: \quad \text{Attention}(Q^{(n \times d_k)}, K^{(n \times d_k)}, V^{(n \times d_v)}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\\
    \text{Decoder Masked Self-Attention} & : \quad \text{Attention}(Q^{(m \times d_k)}, K^{(m \times d_k)}, V^{(m \times d_v)}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}+M\right)V \\
    \text{Cross-Attention} & : \quad \text{Attention}(Q^{(m \times d_k)}, K^{(n \times d_k)}, V^{(n \times d_v)}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \\
    \end{split}
    $$
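The three variants above differ only in the mask and in where Q, K and V come from. The following is a minimal sketch of that idea, not the implementation used by any particular model: `causal_mask` and `attention` are hypothetical helper names, the tensors are random, and the learned Q/K/V projections and multiple heads are left out.

```python
import math

import torch
import torch.nn.functional as F


def causal_mask(n: int) -> torch.Tensor:
    # Additive mask M: 0 on and below the diagonal, -inf strictly above it.
    return torch.full((n, n), float("-inf")).triu(diagonal=1)


def attention(q, k, v, mask=None):
    # softmax(Q K^T / sqrt(d_k) + M) V, with M omitted for bidirectional attention.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
n, m, d_k, d_v = 4, 3, 8, 8
q, k, v = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_v)

# GPT-style causal self-attention: row i only weights positions j <= i.
out, weights = attention(q, k, v, causal_mask(n))

# BERT-style bidirectional self-attention: no mask.
out_bi, weights_bi = attention(q, k, v)

# T5-style cross-attention: m decoder queries attend to n encoder keys/values.
out_cross, weights_cross = attention(torch.randn(m, d_k), k, v)
```

With the causal mask, the first output row equals `v[0]`, because the first position can only attend to itself.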

## Training & Inference/Generate

- llama2/3 inference code: autoregressive, token by token generation
    - https://github.com/meta-llama/llama3/blob/main/llama/generation.py#L179-L192C13
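    - A mask matrix is implicitly present: at each step, only the tokens generated so far exist
    - The first token predicts the second token
    - The first and second tokens together predict the third token
    - ...
- During training, the causal mask (a lower-triangular matrix) makes a single forward pass equivalent to autoregressive, token-by-token prediction
    - The mask matrix is added explicitly so that the model cannot see later tokens
- Computing PPL (perplexity, a metric of how well a language model has been trained) works on a test set of existing text, so self-attention with a causal mask can reproduce autoregressive, token-by-token prediction in one pass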
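As a rough illustration of the PPL point, the sketch below computes perplexity from the logits of one causally masked forward pass: the logits at position t are scored against the token at position t+1, and PPL is the exponential of the mean negative log-likelihood. The function name `perplexity` and the toy uniform logits are assumptions made for this example; in practice the logits would come from a model.

```python
import torch
import torch.nn.functional as F


def perplexity(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab); input_ids: (batch, seq)
    shift_logits = logits[:, :-1, :]   # predictions made at positions 0..seq-2
    shift_labels = input_ids[:, 1:]    # the tokens that actually came next
    nll = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )                                  # mean negative log-likelihood per token
    return torch.exp(nll)


torch.manual_seed(0)
vocab = 10
toy_ids = torch.randint(0, vocab, (1, 6))
uniform_logits = torch.zeros(1, 6, vocab)  # a "model" that guesses uniformly

ppl = perplexity(uniform_logits, toy_ids)
print(ppl.item())
```

A uniform guess over a vocabulary of size 10 gives a perplexity of 10, which is a useful sanity check for the shifting logic.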

In [3]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Initialize the model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2').to('cuda')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Input sequence
input_text = "The quick brown fox jumps over the lazy dog"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

In [4]:
input_ids.shape

torch.Size([1, 9])

In [5]:
# Approach 1: a single forward pass; model() applies the causal attention mask internally
outputs = model(input_ids.to('cuda'))
logits = outputs.logits
logits.shape, logits[:, 1:-1, :]

(torch.Size([1, 9, 50257]),
 tensor([[[-62.3139, -61.5645, -66.4938,  ..., -68.1286, -68.3228, -63.5829],
          [-66.3240, -66.7452, -72.1618,  ..., -75.1955, -73.4651, -68.1786],
          [-88.2910, -88.7236, -93.4422,  ..., -98.6212, -90.6379, -90.9913],
          ...,
          [-80.7563, -82.8596, -87.4034,  ..., -91.0716, -89.5648, -84.5701],
          [-94.8247, -94.5054, -97.7886,  ..., -97.1508, -98.4995, -96.5095],
          [-88.8787, -87.6110, -92.3262,  ..., -95.8310, -93.5164, -91.9581]]],
        device='cuda:0', grad_fn=<SliceBackward0>))

In [6]:
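# Approach 2: predict token by token and collect the logits from each step
generated_logits = []

# Feed prefixes of increasing length, starting from the first token
for i in range(1, input_ids.size(1)):
    step_input_ids = input_ids[:, :i]  # input prefix for the current step
    outputs = model(step_input_ids.to('cuda'))
    logits = outputs.logits
    next_token_logits = logits[:, -1, :]  # logits at the last position (next-token prediction)
    generated_logits.append(next_token_logits)

# Stack the per-step logits along the sequence dimension: (batch, steps, vocab)
generated_logits = torch.stack(generated_logits, dim=1)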

In [7]:
generated_logits.shape, generated_logits[:, 1:, :]

(torch.Size([1, 8, 50257]),
 tensor([[[-62.3139, -61.5645, -66.4938,  ..., -68.1286, -68.3228, -63.5830],
          [-66.3240, -66.7452, -72.1618,  ..., -75.1955, -73.4651, -68.1786],
          [-88.2910, -88.7236, -93.4422,  ..., -98.6211, -90.6379, -90.9913],
          ...,
          [-80.7563, -82.8596, -87.4034,  ..., -91.0716, -89.5648, -84.5701],
          [-94.8247, -94.5054, -97.7886,  ..., -97.1508, -98.4995, -96.5095],
          [-88.8787, -87.6110, -92.3262,  ..., -95.8310, -93.5164, -91.9581]]],
        device='cuda:0', grad_fn=<SliceBackward0>))
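The two printed tensors agree up to small floating-point differences (for example `-63.5829` versus `-63.5830`), so a tolerance-based comparison is a more reliable check than reading the output by eye. The sketch below makes the same comparison on a toy single-head causal self-attention layer with random weights, used here as a stand-in for GPT-2 so the example runs without downloading a model. The names `w_q`, `w_k`, `w_v` and `causal_self_attention` are hypothetical.

```python
import math

import torch

torch.manual_seed(0)
n, d = 6, 16
x = torch.randn(1, n, d)  # (batch, seq, hidden) stand-in for token embeddings
w_q, w_k, w_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    t = x.size(1)
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    mask = torch.full((t, t), float("-inf")).triu(diagonal=1)
    return torch.softmax(scores + mask, dim=-1) @ v


# Approach 1: one forward pass over the full sequence with a causal mask.
full = causal_self_attention(x)

# Approach 2: one pass per prefix, keeping only the last position each time.
stepwise = torch.stack(
    [causal_self_attention(x[:, :i])[:, -1] for i in range(1, n + 1)], dim=1
)

print(torch.allclose(full, stepwise, atol=1e-5))
```

For the GPT-2 tensors in the notebook, the analogous check would compare the Approach 1 logits sliced as `[:, 1:-1]` with `generated_logits[:, 1:]` using `torch.allclose` and a small `atol`. Cell 6 reuses the name `logits` inside its loop, so the Approach 1 logits would need to be saved under a different name first.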