# 标准Attention内存瓶颈 - 实践篇

本notebook通过实际代码帮助你理解标准Attention的内存问题。

**学习目标：**
- 实现标准Attention并测量显存占用
- 观察O(N²)内存增长
- 理解内存绑定（memory-bound，即性能受限于显存带宽而非算力）的概念


## 环境准备


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time

# Matplotlib: prefer the CJK-capable SimHei font (falling back to DejaVu Sans) so
# the Chinese titles/labels render, and disable the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Report the PyTorch version and, when a GPU is present, its name and total memory.
print(f"PyTorch版本: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA可用: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"显存: {total_bytes / 1e9:.2f} GB")


## 1. 标准Attention实现


In [None]:
def standard_attention(Q, K, V, scale=None):
    """Naive scaled dot-product attention that materializes every intermediate.

    Args:
        Q: query tensor, shape [batch, seq_len, d] (any leading batch dims work,
           since only the last two dims take part in the matmuls).
        K: key tensor, same layout as Q.
        V: value tensor, same layout as Q.
        scale: multiplier applied to QK^T; defaults to 1/sqrt(d) where d = Q.shape[-1].

    Returns:
        (O, S, P):
            O: attention output, shape [batch, seq_len, d].
            S: raw scaled scores QK^T * scale, shape [batch, seq_len, seq_len].
            P: row-wise softmax of S, shape [batch, seq_len, seq_len].
        S and P are returned so callers can inspect the O(N^2) intermediates.
    """
    head_dim = Q.shape[-1]
    if scale is None:
        scale = 1.0 / (head_dim ** 0.5)

    # Every query against every key: [..., N, d] @ [..., d, N] -> [..., N, N]
    scores = (Q @ K.transpose(-2, -1)) * scale

    # Normalise each row so that the weights over the keys sum to one.
    weights = scores.softmax(dim=-1)

    # Weighted sum of values: [..., N, N] @ [..., N, d] -> [..., N, d]
    out = weights @ V

    return out, scores, weights

# Quick sanity check of standard_attention on small CPU tensors.
batch_size = 2
seq_len = 128
d = 64

Q = torch.randn(batch_size, seq_len, d)
K = torch.randn(batch_size, seq_len, d)
V = torch.randn(batch_size, seq_len, d)

O, S, P = standard_attention(Q, K, V)

print(f"输入形状: Q={Q.shape}, K={K.shape}, V={V.shape}")
print(f"中间矩阵形状: S={S.shape}, P={P.shape}")
print(f"输出形状: O={O.shape}")
# Softmax must normalise EVERY row of P, so report the range of all row sums
# (over every batch and query) rather than only the first row.
row_sums = P.sum(dim=-1)
print(f"\nP的每行和（应该都是1）: min={row_sums.min().item():.6f}, max={row_sums.max().item():.6f}")


## 2. 内存占用分析


In [None]:
def calculate_memory_usage(batch_size, seq_len, d, dtype=torch.float16):
    """Return the theoretical memory footprint (in bytes) of standard attention.

    Args:
        batch_size: batch size B.
        seq_len: sequence length N. Only plain arithmetic is applied to it, so a
            NumPy array of lengths is evaluated element-wise.
        d: feature dimension of Q/K/V/O.
        dtype: element dtype of all tensors. Its byte width is taken from PyTorch,
            so float16, bfloat16, float32, float64, ... are all sized correctly.

    Returns:
        (qkvo_memory, sp_memory):
            qkvo_memory: bytes for Q, K, V, O (4 tensors of B x N x d) -> O(N).
            sp_memory: bytes for S, P (2 tensors of B x N x N) -> O(N^2).
    """
    # The previous rule "2 if float16 else 4" mis-sized bfloat16 (2 bytes) and
    # float64 (8 bytes); ask PyTorch for the real element size instead.
    bytes_per_element = torch.empty((), dtype=dtype).element_size()

    # Q, K, V, O: each batch × seq_len × d
    qkvo_memory = 4 * batch_size * seq_len * d * bytes_per_element

    # S, P: each batch × seq_len × seq_len
    sp_memory = 2 * batch_size * seq_len * seq_len * bytes_per_element

    return qkvo_memory, sp_memory

# Tabulate the theoretical memory split for a sweep of sequence lengths
# (batch=1, d=64, default FP16 dtype of calculate_memory_usage).
d = 64
batch_size = 1
seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]

rule = "=" * 70
print(rule)
print(f"{'序列长度':<12} {'Q/K/V/O内存':<15} {'S+P内存':<15} {'S+P占比':<10}")
print(rule)

# Per-length sizes in MB, kept for plotting.
qkvo_list = []
sp_list = []

for seq_len in seq_lengths:
    qkvo, sp = calculate_memory_usage(batch_size, seq_len, d)
    qkvo_mb = qkvo / 1e6
    sp_mb = sp / 1e6
    # Share of the total taken by the N x N intermediates, in percent.
    sp_ratio = sp / (qkvo + sp) * 100

    qkvo_list.append(qkvo_mb)
    sp_list.append(sp_mb)

    print(f"{seq_len:<12} {qkvo_mb:<15.2f} MB {sp_mb:<15.2f} MB {sp_ratio:<10.1f}%")

print(rule)


## 3. 可视化内存增长


In [None]:
# Visualize O(N) vs O(N²) memory growth
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: the tabulated sizes (MB) as grouped bars on a log scale
ax1 = axes[0]
positions = np.arange(len(seq_lengths))
ax1.bar(positions - 0.2, qkvo_list, 0.4, label='Q/K/V/O (O(N))', color='steelblue')
ax1.bar(positions + 0.2, sp_list, 0.4, label='S+P (O(N²))', color='coral')
ax1.set_xticks(positions)
ax1.set_xticklabels([str(s) for s in seq_lengths])
ax1.set_xlabel('序列长度 N')
ax1.set_ylabel('内存占用 (MB)')
ax1.set_title('标准Attention内存占用分解')
ax1.legend()
ax1.set_yscale('log')
ax1.grid(True, alpha=0.3)

# Right: continuous growth curves. Evaluate the same cost model as the table
# (same batch_size, d and default dtype) instead of re-hard-coding its constants,
# so both panels stay consistent if those parameters are changed.
ax2 = axes[1]
n = np.linspace(seq_lengths[0], seq_lengths[-1], 100)
qkvo_bytes, sp_bytes = calculate_memory_usage(batch_size, n, d)
linear_memory = qkvo_bytes / 1e6  # Q,K,V,O
quadratic_memory = sp_bytes / 1e6  # S,P

ax2.plot(n, linear_memory, label='O(N): Q/K/V/O', color='steelblue', linewidth=2)
ax2.plot(n, quadratic_memory, label='O(N²): S+P', color='coral', linewidth=2)
ax2.set_xlabel('序列长度 N')
ax2.set_ylabel('内存占用 (MB)')
ax2.set_title('内存增长趋势对比')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_yscale('log')

plt.tight_layout()
plt.show()


## 4. GPU显存实际测量


In [None]:
def measure_attention_memory(seq_len, d=64, batch_size=1, num_heads=12, dtype=torch.float16):
    """Measure GPU memory used by one naive attention forward pass.

    Random Q/K/V of shape [batch_size, num_heads, seq_len, d] are created on the
    GPU, then S = QK^T / sqrt(d), P = softmax(S) and O = PV are computed.

    Returns:
        (input_memory, peak_memory) in bytes, both measured relative to the amount
        allocated at the start of the call:
            input_memory: memory taken by Q, K, V alone.
            peak_memory: peak allocation reached during the whole forward pass.
        (None, None) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return None, None

    gpu = torch.device('cuda')
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    baseline = torch.cuda.memory_allocated()

    # Inputs (created in Q, K, V order).
    shape = (batch_size, num_heads, seq_len, d)
    Q, K, V = (torch.randn(shape, device=gpu, dtype=dtype) for _ in range(3))
    input_memory = torch.cuda.memory_allocated() - baseline

    # Naive attention: every intermediate (S, P) is fully materialized.
    S = (Q @ K.transpose(-2, -1)) * (1.0 / (d ** 0.5))
    P = F.softmax(S, dim=-1)
    O = P @ V

    peak_memory = torch.cuda.max_memory_allocated() - baseline

    # Release everything so the next measurement starts clean.
    del Q, K, V, S, P, O
    torch.cuda.empty_cache()

    return input_memory, peak_memory

if torch.cuda.is_available():
    print("实际GPU显存测量 (batch=1, heads=12, d=64, FP16)")
    print("="*60)
    print(f"{'序列长度':<12} {'输入显存':<15} {'峰值显存':<15} {'中间矩阵':<15}")
    print("="*60)
    
    for seq_len in [512, 1024, 2048, 4096]:
        try:
            input_mem, peak_mem = measure_attention_memory(seq_len)
        except RuntimeError as e:
            # Only a CUDA out-of-memory error is an expected outcome here; any other
            # RuntimeError is a genuine bug and must not be reported as "OOM".
            if 'out of memory' not in str(e):
                raise
            print(f"{seq_len:<12} OOM - 显存不足!")
        else:
            # Everything above the inputs at peak is attributable to S, P, O.
            intermediate = peak_mem - input_mem
            print(f"{seq_len:<12} {input_mem/1e6:<15.2f} MB {peak_mem/1e6:<15.2f} MB {intermediate/1e6:<15.2f} MB")
    
    print("="*60)
else:
    print("CUDA不可用，跳过GPU测量")


## 5. 计算强度分析


In [None]:
def calculate_arithmetic_intensity(seq_len, d, batch_size=1, num_heads=12, bytes_per_element=2):
    """Return the arithmetic intensity (FLOPs per byte of HBM traffic) of each attention step.

    Args:
        seq_len: sequence length N.
        d: per-head feature dimension.
        batch_size: batch size B (cancels out of every ratio, kept for clarity).
        num_heads: number of heads H (also cancels out).
        bytes_per_element: bytes per tensor element; defaults to 2 (FP16/BF16),
            use 4 for FP32.

    Returns:
        (ai_qk, ai_softmax, ai_pv): FLOPs/Byte for S = QK^T, P = softmax(S), O = PV,
        assuming every input is read from and every output written to HBM once.
    """
    bh = batch_size * num_heads
    n_sq = seq_len * seq_len  # elements in one N x N score matrix
    n_d = seq_len * d         # elements in one N x d Q/K/V/O matrix

    # Step 1: S = QK^T — 2 FLOPs (multiply + add) per (i, j, k); read Q, K; write S.
    flops_1 = 2 * bh * n_sq * d
    bytes_1 = (2 * bh * n_d + bh * n_sq) * bytes_per_element
    ai_1 = flops_1 / bytes_1

    # Step 2: P = softmax(S) — rough ~5 FLOPs per element (max, sub, exp, sum, div);
    # read S, write P.
    flops_2 = 5 * bh * n_sq
    bytes_2 = 2 * bh * n_sq * bytes_per_element
    ai_2 = flops_2 / bytes_2

    # Step 3: O = PV — same FLOP count as step 1; read P, V; write O.
    flops_3 = 2 * bh * n_sq * d
    bytes_3 = (bh * (n_sq + n_d) + bh * n_d) * bytes_per_element
    ai_3 = flops_3 / bytes_3

    return ai_1, ai_2, ai_3

# Print the per-step arithmetic intensity for several sequence lengths (d=64)
# and compare against GPU ridge points.
rule = "=" * 60
print("Attention各步骤的计算强度 (FLOPs/Byte)")
print(rule)
print(f"{'序列长度':<12} {'QK^T':<15} {'Softmax':<15} {'PV':<15}")
print(rule)

for seq_len in (512, 1024, 2048, 4096, 8192):
    intensities = calculate_arithmetic_intensity(seq_len, 64)
    print(f"{seq_len:<12} " + " ".join(f"{value:<15.2f}" for value in intensities))

print(rule)
print("\nA100 GPU平衡点: ~156 FLOPs/Byte")
print("H100 GPU平衡点: ~296 FLOPs/Byte")
print("\n结论: Attention的计算强度远低于GPU平衡点，是内存绑定的!")


## 6. 训练时的显存问题


In [None]:
def attention_with_grad(Q, K, V):
    """Naive attention forward pass returning only the output O = softmax(QK^T / sqrt(d)) V.

    Autograd is left enabled, so whether a graph is recorded depends on the caller:
    it is invoked both under torch.no_grad() and with requires_grad inputs to
    compare inference vs. training memory.
    """
    inv_sqrt_d = 1.0 / (Q.shape[-1] ** 0.5)
    weights = F.softmax(torch.matmul(Q, K.transpose(-2, -1)) * inv_sqrt_d, dim=-1)
    return torch.matmul(weights, V)

if torch.cuda.is_available():
    print("训练模式 vs 推理模式 显存对比")
    print("="*60)
    print(f"{'序列长度':<12} {'推理显存':<15} {'训练显存':<15} {'增加倍数':<10}")
    print("="*60)
    
    device = torch.device('cuda')
    
    for seq_len in [512, 1024, 2048]:
        try:
            # --- Inference: forward pass only, no autograd graph recorded ---
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            # Measure relative to what is already allocated (other notebook tensors,
            # leftovers from an earlier failed size) so both numbers and their ratio
            # reflect this experiment only.
            baseline = torch.cuda.memory_allocated()
            
            with torch.no_grad():
                Q = torch.randn(1, 12, seq_len, 64, device=device, dtype=torch.float16)
                K = torch.randn(1, 12, seq_len, 64, device=device, dtype=torch.float16)
                V = torch.randn(1, 12, seq_len, 64, device=device, dtype=torch.float16)
                O = attention_with_grad(Q, K, V)
            
            inference_mem = (torch.cuda.max_memory_allocated() - baseline) / 1e6
            
            del Q, K, V, O
            torch.cuda.empty_cache()
            
            # --- Training: forward + backward with gradients for Q, K, V ---
            torch.cuda.reset_peak_memory_stats()
            baseline = torch.cuda.memory_allocated()
            
            Q = torch.randn(1, 12, seq_len, 64, device=device, dtype=torch.float16, requires_grad=True)
            K = torch.randn(1, 12, seq_len, 64, device=device, dtype=torch.float16, requires_grad=True)
            V = torch.randn(1, 12, seq_len, 64, device=device, dtype=torch.float16, requires_grad=True)
            O = attention_with_grad(Q, K, V)
            loss = O.sum()
            loss.backward()
            
            training_mem = (torch.cuda.max_memory_allocated() - baseline) / 1e6
            
            ratio = training_mem / inference_mem
            
            print(f"{seq_len:<12} {inference_mem:<15.2f} MB {training_mem:<15.2f} MB {ratio:<10.2f}x")
            
            del Q, K, V, O, loss
            torch.cuda.empty_cache()
            
        except RuntimeError as e:
            # Only a CUDA out-of-memory error is an expected outcome; re-raise
            # anything else instead of mislabelling a real bug as "OOM".
            if 'out of memory' not in str(e):
                raise
            print(f"{seq_len:<12} OOM")
    
    print("="*60)
    print("\n说明: 训练时需要保存中间结果用于反向传播，显存占用更大!")
else:
    print("CUDA不可用，跳过测试")


## 总结

通过本notebook，你应该理解了：

1. **O(N²)内存问题**
   - 中间矩阵S和P的大小是N×N
   - 随着序列长度增加，内存占用呈二次增长

2. **内存绑定特性**
   - 两次矩阵乘法（QK^T、PV）的计算强度随N增大趋近于d（d=64时约51~63 FLOPs/Byte），Softmax仅约1.25 FLOPs/Byte，均远低于GPU平衡点（A100约156、H100约296 FLOPs/Byte）
   - 性能受限于HBM带宽而非计算能力

3. **训练时显存显著增加**
   - 反向传播需要保存前向的中间结果
   - 进一步加剧了内存问题

## 下一步

在下一节"IO-Aware算法设计"中，我们将学习如何从内存访问角度优化算法。
