# Recomputation重计算策略 - 实践篇

本notebook通过实际代码帮助你理解Recomputation策略。

**学习目标：**
- 理解PyTorch的gradient checkpointing
- 对比有无重计算的显存占用
- 理解时间与空间的权衡


## 环境准备


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import numpy as np
import time

# Configure fonts so Chinese titles/labels render, and keep ASCII minus signs
# on axes (the CJK font may lack the Unicode minus glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Report the runtime environment.
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")


## 1. 理解反向传播的内存需求


In [None]:
def standard_attention(Q, K, V):
    """Reference (non-fused) attention that materializes every intermediate.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V`` with ``d = Q.shape[-1]``.
    The full score matrix S and the probability matrix P are built as ordinary
    tensors. These are the N x N intermediates that FlashAttention avoids
    keeping for the backward pass.

    Returns:
        The attention output (leading dims from Q, last dim from V).
    """
    head_dim = Q.shape[-1]
    scale = 1.0 / (head_dim ** 0.5)

    # S: scaled dot-product scores, one row per query, one column per key.
    scores = Q @ K.transpose(-2, -1) * scale
    # P: softmax over the key axis.
    probs = scores.softmax(dim=-1)
    # O: probability-weighted sum of the value rows.
    return probs @ V

def analyze_memory_for_backward(seq_len, d=64, batch_size=1, num_heads=12,
                                bytes_per_element=2):
    """Count what standard attention keeps around for the backward pass.

    Follows the accounting used in the FlashAttention paper. Q, K and V are
    each ``seq_len x d`` per head. The materialized score matrix S = QK^T and
    the softmax matrix P are each ``seq_len x seq_len`` per head. The set of
    tensors PyTorch autograd actually saves for a given implementation may
    differ; this is a simplified model.

    Args:
        seq_len: Sequence length N.
        d: Per-head dimension.
        batch_size: Batch size.
        num_heads: Number of attention heads.
        bytes_per_element: Storage size of one element: 2 for FP16/BF16
            (default, as before), 4 for FP32.

    Returns:
        ``(saved_tensors, total)``. ``saved_tensors`` maps each tensor name to
        its element count. ``total`` is the combined size in bytes.
    """
    # Element counts of one (seq_len x d) and one (seq_len x seq_len) tensor,
    # summed over all heads and batch entries.
    vec_elems = batch_size * num_heads * seq_len * d
    mat_elems = batch_size * num_heads * seq_len * seq_len

    saved_tensors = {
        'Q': vec_elems,
        'K': vec_elems,
        'V': vec_elems,
        'S (QK^T)': mat_elems,
        'P (softmax)': mat_elems,
    }

    total = sum(saved_tensors.values()) * bytes_per_element

    return saved_tensors, total

# Analyze several sequence lengths.
print("反向传播需要保存的tensor (标准Attention)")
print("="*70)

for seq_len in [512, 1024, 2048, 4096]:
    tensors, total = analyze_memory_for_backward(seq_len)
    # Derive the element size from the returned total instead of hard-coding 2,
    # so per-tensor figures always agree with the function's dtype assumption.
    bytes_per_elem = total / sum(tensors.values())
    print(f"\n序列长度 N = {seq_len}")
    print("-"*50)
    for name, size in tensors.items():
        print(f"  {name}: {size * bytes_per_elem / 1e6:.2f} MB")
    print(f"  总计: {total / 1e6:.2f} MB")

    # Share of the total taken by the two N x N matrices S and P.
    sp_size = (tensors['S (QK^T)'] + tensors['P (softmax)']) * bytes_per_elem
    sp_ratio = sp_size / total * 100
    print(f"  S+P占比: {sp_ratio:.1f}%")

print("\n" + "="*70)


## 2. PyTorch Checkpoint使用示例


In [None]:
class AttentionLayer(nn.Module):
    """Multi-head self-attention that materializes the full attention matrix.

    Args:
        d_model: Feature dimension of the input and output.
        num_heads: Number of heads; must be positive and divide ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``. Without this check the error would only surface later
            as a confusing ``view`` failure in ``forward``.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, d_model)."""
        batch, seq_len, _ = x.shape

        # Project, then split into heads: (batch, seq, d_model) -> (batch, heads, seq, d_head)
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention; `scores` and `attn` are full seq x seq matrices.
        scale = 1.0 / (self.d_head ** 0.5)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads back: (batch, heads, seq, d_head) -> (batch, seq, d_model)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.out_proj(out)

class AttentionWithCheckpoint(nn.Module):
    """``AttentionLayer`` run under activation checkpointing.

    The forward pass calls the inner layer through
    ``torch.utils.checkpoint.checkpoint`` (non-reentrant variant). The inner
    layer's intermediate activations are therefore not kept after forward and
    are recomputed during backward. The wrapped layer is exposed as
    ``self.attention`` so callers can swap it, e.g. to share weights.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attention = AttentionLayer(d_model, num_heads)

    def forward(self, x):
        wrapped = self.attention
        # Trade extra compute in backward for not storing `wrapped`'s activations.
        return checkpoint(wrapped, x, use_reentrant=False)

# Sanity check: the checkpointed layer must match the plain layer in both the
# forward output and the gradients (checkpointing only changes what is stored).
d_model = 768
num_heads = 12
batch_size = 2
seq_len = 128

x = torch.randn(batch_size, seq_len, d_model)

attn_normal = AttentionLayer(d_model, num_heads)
attn_checkpoint = AttentionWithCheckpoint(d_model, num_heads)
attn_checkpoint.attention = attn_normal  # share weights: both paths use the same parameters

# Forward outputs should be identical.
out1 = attn_normal(x)
out2 = attn_checkpoint(x)

print(f"输出形状: {out1.shape}")
print(f"输出是否相同: {torch.allclose(out1, out2, atol=1e-5)}")

# Checkpointing only affects the backward pass, so the meaningful check is that
# gradients w.r.t. the shared parameters also agree. autograd.grad does not
# accumulate into .grad, so the two passes do not interfere.
params = list(attn_normal.parameters())
grads_normal = torch.autograd.grad(out1.sum(), params)
grads_ckpt = torch.autograd.grad(out2.sum(), params)
grads_match = all(
    torch.allclose(g1, g2, atol=1e-5) for g1, g2 in zip(grads_normal, grads_ckpt)
)
print(f"梯度是否相同: {grads_match}")


## 3. 显存占用对比（需要GPU）


In [None]:
def measure_memory(model, x, use_checkpoint=False):
    """Measure peak CUDA memory allocated during one forward + backward pass.

    Args:
        model: Module to run. It is moved to the GPU for the measurement and
            moved back to the CPU afterwards, even if the pass fails.
        x: Input tensor. A GPU copy with ``requires_grad=True`` is used.
        use_checkpoint: Unused; kept for backward compatibility. Whether
            checkpointing happens is decided by ``model`` itself.

    Returns:
        ``torch.cuda.max_memory_allocated()`` in bytes after the pass, measured
        from a freshly reset peak. This includes the model's weights and
        gradients. Returns ``None`` when CUDA is unavailable.

    Raises:
        Whatever forward/backward raises (e.g. a CUDA OOM ``RuntimeError``).
        GPU state is cleaned up first.
    """
    if not torch.cuda.is_available():
        return None

    device = torch.device('cuda')
    out = loss = None
    try:
        model = model.to(device)
        x = x.to(device).requires_grad_(True)

        # Drop cached-but-free blocks, then start a fresh peak window.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Forward pass
        out = model(x)

        # Backward pass
        loss = out.sum()
        loss.backward()

        return torch.cuda.max_memory_allocated()
    finally:
        # Runs on success and on failure (e.g. OOM), so the next measurement
        # does not start with this run's model/activations still on the GPU.
        out = loss = x = None
        model.cpu()
        torch.cuda.empty_cache()

if torch.cuda.is_available():
    print("显存占用对比 (前向+反向)")
    print("="*70)
    print(f"{'序列长度':<12} {'标准Attention':<20} {'With Checkpoint':<20} {'节省比例':<15}")
    print("="*70)

    d_model = 768
    num_heads = 12
    batch_size = 1

    # NOTE(review): with a single checkpointed layer, the recomputed activations
    # are alive during that layer's backward, so the measured peak may shrink
    # less than the theoretical estimate. Stacking several layers shows the
    # effect more clearly.
    for seq_len in [256, 512, 1024]:
        try:
            x = torch.randn(batch_size, seq_len, d_model)

            # Standard attention: keeps its activations for backward.
            model_normal = AttentionLayer(d_model, num_heads)
            mem_normal = measure_memory(model_normal, x.clone())

            # Same layer under checkpoint: activations recomputed in backward.
            model_ckpt = AttentionWithCheckpoint(d_model, num_heads)
            mem_ckpt = measure_memory(model_ckpt, x.clone())
        except RuntimeError as e:
            # Report only genuine CUDA out-of-memory errors as OOM; anything
            # else is a real bug and must not be disguised.
            if "out of memory" not in str(e):
                raise
            print(f"{seq_len:<12} OOM")
            continue

        if mem_normal and mem_ckpt:
            saving = (mem_normal - mem_ckpt) / mem_normal * 100
            # Format each cell (value + unit) before padding so the columns
            # line up with the header.
            normal_col = f"{mem_normal / 1e6:.2f} MB"
            ckpt_col = f"{mem_ckpt / 1e6:.2f} MB"
            saving_col = f"{saving:.1f}%"
            print(f"{seq_len:<12} {normal_col:<20} {ckpt_col:<20} {saving_col:<15}")

    print("="*70)
else:
    print("CUDA不可用，跳过显存测量")
    print("\n理论分析：使用Checkpoint可以节省S和P矩阵的存储，约减少50-80%的激活值显存")
    print("\n理论分析：使用Checkpoint可以节省S和P矩阵的存储，约减少50-80%的激活值显存")


## 4. FlashAttention的保存策略模拟


In [None]:
def compare_save_strategies(seq_len, d=64, batch_size=1, num_heads=12,
                            bytes_per_element=2):
    """Compare what standard attention vs. FlashAttention saves for backward.

    Standard attention keeps Q, K, V plus the N x N matrices S and P.
    FlashAttention keeps Q, K, V, the output O, and two per-row softmax
    statistics: m (row max) and l (row sum). All tensors are counted at the
    same element size.

    Args:
        seq_len: Sequence length N.
        d: Per-head dimension.
        batch_size: Batch size.
        num_heads: Number of attention heads.
        bytes_per_element: Storage size of one element: 2 for FP16/BF16
            (default, as before), 4 for FP32.

    Returns:
        ``(standard_save, flash_save, standard_total, flash_total)``. The two
        dicts map tensor names to element counts. The two totals are in bytes.
    """
    # Element counts over all heads and batch entries.
    vec_elems = batch_size * num_heads * seq_len * d        # one (N x d) tensor
    mat_elems = batch_size * num_heads * seq_len * seq_len  # one (N x N) tensor
    row_elems = batch_size * num_heads * seq_len            # one per-row statistic

    # What standard attention saves.
    standard_save = {
        'Q': vec_elems,
        'K': vec_elems,
        'V': vec_elems,
        'S': mat_elems,
        'P': mat_elems,
    }

    # What FlashAttention saves.
    flash_save = {
        'Q': vec_elems,
        'K': vec_elems,
        'V': vec_elems,
        'O': vec_elems,
        'm (row max)': row_elems,
        'l (row sum)': row_elems,
    }

    standard_total = sum(standard_save.values()) * bytes_per_element
    flash_total = sum(flash_save.values()) * bytes_per_element

    return standard_save, flash_save, standard_total, flash_total

# Compare the saved sizes across sequence lengths. The MB series collected here
# are used by the plotting cell below.
print("保存策略对比分析")
print("="*70)

seq_lengths = [512, 1024, 2048, 4096, 8192]
standard_mem = []
flash_mem = []

for seq_len in seq_lengths:
    _, _, std_total, flash_total = compare_save_strategies(seq_len)
    reduction = (std_total - flash_total) / std_total * 100
    standard_mem.append(std_total / 1e6)
    flash_mem.append(flash_total / 1e6)
    print(
        f"N={seq_len:<5}: 标准={std_total/1e6:>8.1f}MB, "
        f"Flash={flash_total/1e6:>8.1f}MB, 减少{reduction:.1f}%"
    )

print("="*70)


## 5. 可视化保存策略对比


In [None]:
# Left: total saved size per sequence length. Right: per-tensor breakdown at N=2048.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: absolute saved-activation size (log scale)
ax1 = axes[0]
# Bar positions get their own name so the input tensor `x` from earlier cells
# is not clobbered.
positions = np.arange(len(seq_lengths))
width = 0.35

ax1.bar(positions - width/2, standard_mem, width, label='标准Attention', color='coral')
ax1.bar(positions + width/2, flash_mem, width, label='FlashAttention', color='steelblue')

ax1.set_xlabel('序列长度')
ax1.set_ylabel('保存的激活值大小 (MB)')
ax1.set_title('反向传播需要保存的数据量')
ax1.set_xticks(positions)
ax1.set_xticklabels([str(s) for s in seq_lengths])
ax1.legend()
ax1.set_yscale('log')
ax1.grid(True, alpha=0.3, axis='y')

# Right: stacked breakdown of the saved tensors at N=2048
ax2 = axes[1]

std_save, flash_save, std_total, _ = compare_save_strategies(2048)
# Derive the element size from the returned total so it cannot drift from
# compare_save_strategies' own assumption.
bytes_per_element = std_total / sum(std_save.values())

# One stacked bar per strategy: (x position, element counts, colormap, legend prefix)
stacks = [
    (0, std_save, plt.cm.Reds, 'Std'),
    (1, flash_save, plt.cm.Blues, 'FA'),
]
for xpos, saved, cmap, prefix in stacks:
    colors = cmap(np.linspace(0.3, 0.7, len(saved)))
    bottom = 0
    for (label, count), color in zip(saved.items(), colors):
        size_mb = count * bytes_per_element / 1e6
        ax2.bar(xpos, size_mb, bottom=bottom, color=color, label=f'{prefix}: {label}')
        bottom += size_mb

ax2.set_xticks([0, 1])
ax2.set_xticklabels(['标准Attention', 'FlashAttention'])
ax2.set_ylabel('保存的数据量 (MB)')
ax2.set_title('N=2048时的保存数据分解')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)

plt.tight_layout()
plt.show()


## 6. 时间vs空间权衡分析


In [None]:
def analyze_time_space_tradeoff(seq_len, d=64, batch_size=1, num_heads=12,
                                 peak_tflops=312, memory_bw_tbs=2.0):
    """Rough estimate of what FlashAttention's recomputation costs and saves.

    Args:
        seq_len: Sequence length N.
        d: Per-head dimension.
        batch_size: Batch size.
        num_heads: Number of attention heads.
        peak_tflops: Assumed peak compute throughput in TFLOP/s.
        memory_bw_tbs: Assumed memory bandwidth in TB/s.

    Returns:
        A dict with:
          std_memory_mb / flash_memory_mb: FP16 bytes saved for backward (MB).
          saved_memory_mb: Difference between the two.
          extra_time_us: Compute time, at ``peak_tflops``, of what the
              FlashAttention backward recomputes. That is S = QK^T and the
              softmax P; O is saved, so P @ V is not redone.
          memory_saving_ratio: saved_memory as a percentage of the standard size.
          saved_io_time_us: Time to stream the no-longer-saved bytes once at
              ``memory_bw_tbs``. This is a lower bound on the memory traffic
              avoided.
    """
    bytes_per_element = 2

    # Forward FLOPs per head: QK^T and P@V cost 2*N*N*d each; softmax ~5*N*N.
    qk_flops = 2 * seq_len * seq_len * d
    softmax_flops = 5 * seq_len * seq_len
    # Backward recomputation only rebuilds S and P from the saved statistics.
    recompute_flops = batch_size * num_heads * (qk_flops + softmax_flops)

    # Standard: Q, K, V + S, P.  FlashAttention: Q, K, V, O + per-row m, l.
    std_save_bytes = batch_size * num_heads * (3 * seq_len * d + 2 * seq_len * seq_len) * bytes_per_element
    flash_save_bytes = batch_size * num_heads * (4 * seq_len * d + 2 * seq_len) * bytes_per_element
    saved_memory = std_save_bytes - flash_save_bytes

    # Extra compute time vs. memory traffic that no longer has to happen.
    extra_time = recompute_flops / (peak_tflops * 1e12)
    saved_io_time = saved_memory / (memory_bw_tbs * 1e12)

    return {
        'std_memory_mb': std_save_bytes / 1e6,
        'flash_memory_mb': flash_save_bytes / 1e6,
        'saved_memory_mb': saved_memory / 1e6,
        'extra_time_us': extra_time * 1e6,
        'memory_saving_ratio': saved_memory / std_save_bytes * 100,
        'saved_io_time_us': saved_io_time * 1e6,
    }

# Tabulate the time/space trade-off for several sequence lengths.
print("时间vs空间权衡分析")
print("="*70)
print(f"{'序列长度':<10} {'标准内存':<15} {'Flash内存':<15} {'节省内存':<15} {'额外时间':<15}")
print("="*70)

for seq_len in [512, 1024, 2048, 4096]:
    result = analyze_time_space_tradeoff(seq_len)
    cells = [
        f"{seq_len:<10}",
        f"{result['std_memory_mb']:<15.1f}MB",
        f"{result['flash_memory_mb']:<15.1f}MB",
        f"{result['saved_memory_mb']:<15.1f}MB",
        f"{result['extra_time_us']:<15.1f}μs",
    ]
    print(" ".join(cells))

print("="*70)
print("\n结论: FlashAttention用少量额外计算时间换取大量内存节省")


## 总结

通过本notebook，你应该理解了：

1. **反向传播的内存需求**
   - 需要保存前向传播的中间结果
   - S和P矩阵占据大部分激活值内存

2. **PyTorch Checkpoint机制**
   - 使用`checkpoint`函数包装需要重计算的部分
   - 反向时自动重新执行前向计算

3. **FlashAttention的保存策略**
   - 只保存Q, K, V, O, m, l
   - 不保存S和P矩阵
   - 为反向传播保存的激活值显存随序列长度N从O(N²)降到O(N)

4. **时间vs空间权衡**
   - 用额外的计算时间换取大量内存节省
   - 由于Attention是访存受限（memory-bound）的，省去S和P矩阵在HBM上的读写后，实际总时间可能反而更短

## 章节总结

恭喜你完成了FlashAttention核心思想的学习！你现在应该理解：
- 为什么标准Attention有内存瓶颈
- IO-Aware算法设计的重要性
- Tiling如何解决大矩阵问题
- Recomputation如何节省显存

下一步可以学习**Online Softmax算法**，这是FlashAttention的数学核心。
