# IO-Aware算法设计 - 实践篇

本notebook通过实际代码帮助你理解IO-Aware算法设计的核心概念。

**学习目标：**
- 理解Roofline模型
- 比较不同实现的HBM访问量
- 可视化计算强度与性能的关系


## 环境准备


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Configure matplotlib for the Chinese labels used throughout the notebook:
# prefer SimHei and fall back to DejaVu Sans, and turn off the Unicode minus
# sign (presumably because the CJK font lacks that glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Report the installed PyTorch version.
print(f"PyTorch版本: {torch.__version__}")


## 1. Roofline模型可视化


In [None]:
def plot_roofline(peak_compute, memory_bandwidth, operations=None):
    """
    Draw a Roofline chart and return the ridge point.

    Attainable performance at arithmetic intensity ``ai`` is modelled as
    ``min(ai * memory_bandwidth, peak_compute)``. The figure is shown with
    ``plt.show()`` before the function returns.

    Args:
        peak_compute: Peak compute throughput (TFLOPs). Must be positive.
        memory_bandwidth: Memory bandwidth (TB/s). Must be positive.
        operations: Optional list of ``(name, arithmetic_intensity, color)``
            tuples; each is drawn as a labelled point on the roofline.

    Returns:
        The ridge point ``peak_compute / memory_bandwidth`` (FLOPs/Byte).

    Raises:
        ValueError: If ``peak_compute`` or ``memory_bandwidth`` is not
            positive.
    """
    # Both axes are logarithmic and the ridge point divides by the bandwidth,
    # so non-positive hardware numbers cannot be plotted meaningfully.
    if peak_compute <= 0 or memory_bandwidth <= 0:
        raise ValueError("peak_compute and memory_bandwidth must be positive")

    # Ridge point: intensity at which the memory roof meets the compute roof.
    ridge_point = peak_compute / memory_bandwidth

    # x axis: arithmetic intensity, log-spaced from 0.1 to 1000 FLOPs/Byte.
    ai = np.logspace(-1, 3, 1000)

    # Performance ceiling: the lower of the bandwidth roof and the compute roof.
    memory_bound = ai * memory_bandwidth
    compute_bound = np.full_like(ai, peak_compute)
    performance = np.minimum(memory_bound, compute_bound)

    fig, ax = plt.subplots(figsize=(12, 7))

    # The roofline itself.
    ax.loglog(ai, performance, 'b-', linewidth=2.5, label='Roofline')

    # Guide lines for the compute roof and the ridge point.
    ax.axhline(y=peak_compute, color='r', linestyle='--', alpha=0.5,
               label=f'峰值计算: {peak_compute} TFLOPs')
    ax.axvline(x=ridge_point, color='g', linestyle='--', alpha=0.5,
               label=f'脊点: {ridge_point:.0f} FLOPs/Byte')

    # Shade the memory-bound region (left of the ridge) and the compute-bound
    # region (right of the ridge).
    left_of_ridge = ai < ridge_point
    right_of_ridge = ~left_of_ridge
    ax.fill_between(ai[left_of_ridge], 0, memory_bound[left_of_ridge],
                    alpha=0.1, color='orange', label='内存绑定区域')
    ax.fill_between(ai[right_of_ridge], 0, compute_bound[right_of_ridge],
                    alpha=0.1, color='blue', label='计算绑定区域')

    # Place each operation on the roofline at its arithmetic intensity.
    for name, op_ai, color in operations or ():
        op_perf = min(op_ai * memory_bandwidth, peak_compute)
        ax.scatter([op_ai], [op_perf], s=150, c=color, zorder=5, edgecolors='black')
        ax.annotate(name, (op_ai, op_perf), textcoords="offset points",
                    xytext=(10, 10), fontsize=11, fontweight='bold')

    ax.set_xlabel('计算强度 (FLOPs/Byte)', fontsize=12)
    ax.set_ylabel('性能 (TFLOPs)', fontsize=12)
    ax.set_title('Roofline模型 - GPU性能分析', fontsize=14)
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0.1, 1000)
    # Default window is 1..500 TFLOPs; widen it only when the compute roof
    # would otherwise fall outside the visible range.
    ax.set_ylim(min(1, peak_compute / 100), max(500, peak_compute * 1.5))

    plt.tight_layout()
    plt.show()

    return ridge_point

# A100 GPU figures used for the chart:
# peak compute in TFLOPs (FP16 Tensor Core) and memory bandwidth in TB/s.
peak_compute, memory_bandwidth = 312, 2.0

# Arithmetic intensity of some common operations, as
# (label, FLOPs/Byte, marker colour).
operations = [
    ('Softmax', 2.5, 'red'),
    ('矩阵乘法 (小)', 30, 'orange'),
    ('Attention (标准)', 70, 'purple'),
    ('矩阵乘法 (大)', 200, 'green'),
]

# Draw the chart, then report the ridge point it returned.
ridge = plot_roofline(
    peak_compute=peak_compute,
    memory_bandwidth=memory_bandwidth,
    operations=operations,
)
print(f"\nA100 GPU脊点: {ridge:.0f} FLOPs/Byte")
print("标准Attention的计算强度约为70 FLOPs/Byte，远低于脊点，是内存绑定的！")


## 2. HBM访问量计算


In [None]:
def calculate_hbm_access(seq_len, d, batch_size=1, num_heads=12,
                          sram_size=100*1024, method='standard'):
    """
    Estimate the HBM traffic of one attention forward pass.

    All tensors are assumed to be FP16 (2 bytes per element).

    Args:
        seq_len: Sequence length N.
        d: Head dimension.
        batch_size: Batch size.
        num_heads: Number of attention heads.
        sram_size: SRAM size in bytes (only used when ``method='flash'``).
        method: ``'standard'`` or ``'flash'``.

    Returns:
        Tuple ``(hbm_read, hbm_write, hbm_read + hbm_write)`` in bytes.

    Raises:
        ValueError: If ``method`` is not ``'standard'`` or ``'flash'``, or if
            ``sram_size`` cannot hold a single element.
    """
    bytes_per_element = 2  # FP16

    # Size in bytes of each of Q, K, V and O.
    qkvo_size = batch_size * num_heads * seq_len * d * bytes_per_element

    if method == 'standard':
        # Size in bytes of each of the N x N intermediates S and P.
        sp_size = batch_size * num_heads * seq_len * seq_len * bytes_per_element

        # Reads:  Q, K (for QK^T), S (for softmax), P, V (for PV).
        # Writes: S, P, O.
        hbm_read = 3 * qkvo_size + 2 * sp_size
        hbm_write = qkvo_size + 2 * sp_size

    elif method == 'flash':
        # Number of elements that fit in SRAM.
        M = sram_size // bytes_per_element
        if M <= 0:
            raise ValueError(
                f"sram_size must hold at least one element "
                f"({bytes_per_element} bytes), got {sram_size}"
            )

        # Simplified tiling model: Q is split into ceil(N * d / M) blocks and
        # every block re-reads all of K and V, so K/V traffic scales as
        # N^2 * d^2 / M. Ceiling division is required because a partially
        # filled last block still costs a full pass over K and V.
        num_outer_loops = max(1, -(-(seq_len * d) // M))

        # Q is read once; K and V are read once per outer loop; only O is
        # written back (S and P never leave SRAM).
        hbm_read = qkvo_size + num_outer_loops * 2 * qkvo_size
        hbm_write = qkvo_size

    else:
        raise ValueError(
            f"unknown method {method!r}; expected 'standard' or 'flash'"
        )

    return hbm_read, hbm_write, hbm_read + hbm_write

# Compare the HBM traffic of both methods across sequence lengths and keep
# the per-length totals (in MB) for the charts that follow.
print("HBM访问量对比 (batch=1, heads=12, d=64, FP16)")
print("="*80)
print(f"{'序列长度':<12} {'标准Attention':<20} {'FlashAttention':<20} {'减少比例':<15}")
print("="*80)

seq_lengths = [512, 1024, 2048, 4096, 8192]
standard_access = []
flash_access = []

for seq_len in seq_lengths:
    std_r, std_w, std_total = calculate_hbm_access(seq_len, 64, method='standard')
    flash_r, flash_w, flash_total = calculate_hbm_access(seq_len, 64, method='flash')

    standard_access.append(std_total / 1e6)
    flash_access.append(flash_total / 1e6)

    reduction = (std_total - flash_total) / std_total * 100

    # Attach the unit to the number before padding, so each cell is exactly
    # as wide as its header column and the unit stays next to its value.
    std_cell = f"{std_total/1e6:.2f} MB"
    flash_cell = f"{flash_total/1e6:.2f} MB"
    reduction_cell = f"{reduction:.1f}%"
    print(f"{seq_len:<12} {std_cell:<20} {flash_cell:<20} {reduction_cell:<15}")

print("="*80)


## 3. 可视化HBM访问量对比


In [None]:
# Two-panel figure: absolute HBM traffic (left) and relative saving (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

x = np.arange(len(seq_lengths))
width = 0.35

# Left panel: grouped bars, one pair per sequence length, on a log y axis.
bars1 = ax1.bar(x - width/2, standard_access, width, label='标准Attention', color='coral')
bars2 = ax1.bar(x + width/2, flash_access, width, label='FlashAttention', color='steelblue')

ax1.set(xlabel='序列长度', ylabel='HBM访问量 (MB)', title='HBM访问量对比')
ax1.set_xticks(x)
ax1.set_xticklabels(list(map(str, seq_lengths)))
ax1.legend()
ax1.set_yscale('log')
ax1.grid(True, alpha=0.3, axis='y')

# Right panel: percentage of HBM traffic removed by FlashAttention.
reduction_ratios = [
    (std_mb - flash_mb) / std_mb * 100
    for std_mb, flash_mb in zip(standard_access, flash_access)
]

bars = ax2.bar(x, reduction_ratios, color='green', alpha=0.7)
ax2.set(xlabel='序列长度', ylabel='HBM访问减少比例 (%)', title='FlashAttention减少的HBM访问量')
ax2.set_xticks(x)
ax2.set_xticklabels(list(map(str, seq_lengths)))
ax2.axhline(y=50, color='red', linestyle='--', label='50%减少')
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')

# Write each percentage just above the top centre of its bar.
for bar, ratio in zip(bars, reduction_ratios):
    height = bar.get_height()
    top_centre = (bar.get_x() + bar.get_width() / 2, height)
    ax2.annotate(f'{ratio:.1f}%',
                 xy=top_centre,
                 xytext=(0, 3),
                 textcoords="offset points",
                 ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()


## 4. IO成本与计算成本的对比


In [None]:
def estimate_time_breakdown(seq_len, d=64, batch_size=1, num_heads=12,
                            peak_tflops=312, memory_bw_tbs=2.0):
    """
    Estimate compute time and HBM access time for standard attention.

    Args:
        seq_len: Sequence length N.
        d: Head dimension.
        batch_size: Batch size.
        num_heads: Number of attention heads.
        peak_tflops: Peak compute throughput in TFLOPs.
        memory_bw_tbs: Memory bandwidth in TB/s.

    Returns:
        Tuple ``(compute_time, memory_time)`` in seconds.
    """
    element_bytes = 2  # FP16

    # FLOPs per head: QK^T and PV cost 2*N*N*d each; softmax is taken as ~5*N*N.
    matmul_flops = 4 * seq_len * seq_len * d
    softmax_flops = 5 * seq_len * seq_len
    flops = batch_size * num_heads * (matmul_flops + softmax_flops)

    # HBM traffic: the four N x d tensors (Q, K, V, O) plus the two N x N
    # intermediates (S, P), each of which is written once and read once.
    io_bytes_qkvo = 4 * batch_size * num_heads * seq_len * d * element_bytes
    io_bytes_sp = 2 * batch_size * num_heads * seq_len * seq_len * element_bytes
    traffic = io_bytes_qkvo + 2 * io_bytes_sp

    # Convert tera-units to base units to get seconds.
    return flops / (peak_tflops * 1e12), traffic / (memory_bw_tbs * 1e12)

# Tabulate the estimated compute time and memory time per sequence length and
# keep both series (in microseconds) for the charts that follow.
print("计算时间 vs 内存访问时间 (A100 GPU)")
print("="*70)
print(f"{'序列长度':<12} {'计算时间(μs)':<15} {'内存时间(μs)':<15} {'内存时间占比':<15}")
print("="*70)

compute_times = []
memory_times = []

for seq_len in seq_lengths:
    ct, mt = estimate_time_breakdown(seq_len)
    compute_times.append(ct * 1e6)  # seconds -> microseconds
    memory_times.append(mt * 1e6)

    mem_ratio = mt / (ct + mt) * 100
    # Attach the percent sign before padding so it stays next to the number
    # instead of trailing after the column's padding spaces.
    ratio_cell = f"{mem_ratio:.1f}%"
    print(f"{seq_len:<12} {ct*1e6:<15.2f} {mt*1e6:<15.2f} {ratio_cell:<15}")

print("="*70)
print("\n结论：内存访问时间占总时间的大部分，这就是为什么需要IO-Aware优化！")


## 5. 可视化时间占比


In [None]:
# Two-panel figure: stacked time per sequence length (left) and the split
# for a single sequence length (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

x = np.arange(len(seq_lengths))
width = 0.6

# Left panel: memory time stacked on top of compute time.
ax1.bar(x, compute_times, width, label='计算时间', color='steelblue')
ax1.bar(x, memory_times, width, bottom=compute_times, label='内存访问时间', color='coral')

ax1.set(xlabel='序列长度', ylabel='时间 (μs)', title='Attention执行时间分解')
ax1.set_xticks(x)
ax1.set_xticklabels(list(map(str, seq_lengths)))
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Right panel: pie chart for N=4096, with the memory slice pulled out slightly.
idx = seq_lengths.index(4096)
sizes = [compute_times[idx], memory_times[idx]]
labels = [f'计算时间\n{compute_times[idx]:.1f}μs', f'内存访问时间\n{memory_times[idx]:.1f}μs']
colors = ['steelblue', 'coral']
explode = (0, 0.05)

ax2.pie(
    sizes,
    explode=explode,
    labels=labels,
    colors=colors,
    autopct='%1.1f%%',
    shadow=True,
    startangle=90,
)
ax2.set_title(f'N=4096时的时间分解')

plt.tight_layout()
plt.show()


## 6. 算子融合的效果模拟


In [None]:
def simulate_kernel_fusion_effect(seq_len, d=64, batch_size=1, num_heads=12):
    """
    模拟算子融合的效果
    """
    bytes_per_element = 2
    
    # 非融合版本：三个独立kernel
    # Kernel 1: 读Q,K -> 写S
    # Kernel 2: 读S -> 写P
    # Kernel 3: 读P,V -> 写O
    
    qkv_size = batch_size * num_heads * seq_len * d * bytes_per_element
    sp_size = batch_size * num_heads * seq_len * seq_len * bytes_per_element
    
    unfused_io = {
        'K1_read': 2 * qkv_size,  # Q, K
        'K1_write': sp_size,       # S
        'K2_read': sp_size,        # S
        'K2_write': sp_size,       # P
        'K3_read': sp_size + qkv_size,  # P, V
        'K3_write': qkv_size,      # O
    }
    unfused_total = sum(unfused_io.values())
    
    # 融合版本：单个kernel
    # 只需要读Q,K,V，写O
    fused_io = {
        'read': 3 * qkv_size,  # Q, K, V
        'write': qkv_size,     # O
    }
    fused_total = sum(fused_io.values())
    
    return unfused_total, fused_total, unfused_io, fused_io

# Print a per-kernel HBM traffic report for a few sequence lengths.
print("算子融合效果模拟")
print("="*70)

for seq_len in (1024, 2048, 4096):
    unfused, fused, unfused_io, fused_io = simulate_kernel_fusion_effect(seq_len)
    reduction = (unfused - fused) / unfused * 100

    print(f"\n序列长度 N = {seq_len}")
    print("-"*50)

    # Both variants share one layout: heading, one line per entry, then total.
    for heading, breakdown, total in (
        ("非融合版本 (3个kernel):", unfused_io, unfused),
        ("\n融合版本 (1个kernel):", fused_io, fused),
    ):
        print(heading)
        for k, v in breakdown.items():
            print(f"  {k}: {v/1e6:.2f} MB")
        print(f"  总计: {total/1e6:.2f} MB")

    print(f"\n减少IO: {reduction:.1f}%")
    print("-"*50)


## 总结

通过本notebook，你应该理解了：

1. **Roofline模型**
   - 性能上限 = min(峰值计算, 计算强度 × 带宽)
   - 标准Attention的计算强度（约70 FLOPs/Byte）低于A100的脊点（156 FLOPs/Byte），因此处于内存绑定（memory-bound）区域

2. **HBM访问量分析**
   - 标准Attention的中间矩阵S,P导致大量HBM访问
   - FlashAttention通过分块和融合显著减少HBM访问

3. **时间分解**
   - 内存访问时间占总时间的大部分
   - 对于这类内存绑定的操作，减少内存访问比减少计算量更能提升性能

4. **算子融合**
   - 将多个操作合并成单个kernel
   - 避免中间结果写回HBM

## 下一步

在下一节"Tiling分块技术"中，我们将学习如何将大矩阵分块处理。
