# Tiling分块技术 - 实践篇

本notebook通过实际代码帮助你理解Tiling分块技术。

**学习目标：**
- 实现分块矩阵乘法
- 实现简化版FlashAttention
- 验证分块计算的正确性


## 环境准备


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Matplotlib fonts: try SimHei first so the Chinese labels in the figures
# render, falling back to DejaVu Sans. Minus signs are drawn as ASCII '-'
# because CJK fonts such as SimHei commonly lack the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

torch_version = torch.__version__
print(f"PyTorch版本: {torch_version}")


## 1. 分块矩阵乘法


In [None]:
def tiled_matmul(A, B, block_size=64):
    """
    Blocked (tiled) matrix multiplication C = A @ B.

    The output is produced one tile at a time. Each tile is accumulated over
    the shared K dimension in chunks of ``block_size``. Edge tiles are
    smaller when a dimension is not a multiple of ``block_size``.

    Args:
        A: [M, K] matrix
        B: [K, N] matrix
        block_size: positive tile edge length

    Returns:
        C: [M, N] = A @ B, with A's dtype and device

    Raises:
        AssertionError: (via ``assert``) if the inner dimensions differ.
        ValueError: if ``block_size`` is not positive. A negative value
            would otherwise make every tile range empty and silently return
            an all-zero matrix.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "矩阵维度不匹配"
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    C = torch.zeros(M, N, dtype=A.dtype, device=A.device)

    # Walk output tiles by their start offsets; min() clips the last tile
    # in each dimension to the matrix edge.
    for m_start in range(0, M, block_size):
        m_end = min(m_start + block_size, M)
        for n_start in range(0, N, block_size):
            n_end = min(n_start + block_size, N)

            # Accumulator for this output tile.
            C_block = torch.zeros(m_end - m_start, n_end - n_start,
                                  dtype=A.dtype, device=A.device)

            # Sum the contributions of each K-chunk: A[m, k] @ B[k, n].
            for k_start in range(0, K, block_size):
                k_end = min(k_start + block_size, K)
                C_block += A[m_start:m_end, k_start:k_end] @ B[k_start:k_end, n_start:n_end]

            C[m_start:m_end, n_start:n_end] = C_block

    return C

# Check tiled_matmul against the built-in matmul. The seed makes the run
# reproducible.
torch.manual_seed(0)
M, K, N = 256, 128, 192
A = torch.randn(M, K)
B = torch.randn(K, N)

# Reference: standard matrix multiplication.
C_standard = A @ B

# Tiled matrix multiplication.
C_tiled = tiled_matmul(A, B, block_size=64)

# Tiling sums the K dimension in a different order than a single matmul, so
# the float32 results differ by rounding. Entries of C have magnitude around
# sqrt(K) (about 11 here), so judge correctness with a relative tolerance
# rather than a bare absolute threshold.
error = (C_standard - C_tiled).abs().max().item()
is_correct = torch.allclose(C_tiled, C_standard, rtol=1e-5, atol=1e-5)
print(f"矩阵大小: A={A.shape}, B={B.shape}")
print(f"最大误差: {error:.2e}")
print(f"结果正确: {is_correct}")


## 2. 可视化分块过程


In [None]:
def visualize_tiling(N, block_size, d=64):
    """
    Plot how Q, K^T and S = QK^T are split into tiles.

    Q ([N, d]) is cut into row blocks only. K^T ([d, N]) is cut into column
    blocks only. S ([N, N]) is cut into a 2-D grid of tiles. Every tile is
    shaded and labelled with its 1-based block index.

    Args:
        N: sequence length (rows of Q, columns of K^T, both sides of S).
        block_size: tile length along the sequence dimension.
        d: head dimension (columns of Q / rows of K^T). Defaults to 64.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    num_blocks = (N + block_size - 1) // block_size

    # One distinct colour per S tile. Q and K^T reuse entries from the same palette.
    colors = plt.cm.Set3(np.linspace(0, 1, num_blocks * num_blocks))

    # (start, length) of every block along the sequence axis. The last block
    # is shorter when N is not a multiple of block_size.
    spans = [(b * block_size, min(block_size, N - b * block_size))
             for b in range(num_blocks)]
    # Block boundaries along the sequence axis, including both outer edges.
    seq_edges = [start for start, _ in spans] + [N]

    panels = [
        ('Q 矩阵', (N, d)),
        ('K^T 矩阵', (d, N)),
        ('S = QK^T', (N, N)),
    ]
    for ax_idx, (ax, (title, shape)) in enumerate(zip(axes, panels)):
        if ax_idx == 0:
            # Q: row blocks only, each spanning the full width d.
            for i, (row0, height) in enumerate(spans):
                ax.add_patch(plt.Rectangle((0, row0), d, height,
                                           fill=True, alpha=0.5,
                                           facecolor=colors[i * num_blocks]))
                ax.text(d / 2, row0 + height / 2,
                        f'Q_{i+1}', ha='center', va='center', fontsize=12)
            row_edges, col_edges = seq_edges, [0, d]
        elif ax_idx == 1:
            # K^T: column blocks only, each spanning the full height d.
            for j, (col0, width) in enumerate(spans):
                ax.add_patch(plt.Rectangle((col0, 0), width, d,
                                           fill=True, alpha=0.5,
                                           facecolor=colors[j]))
                ax.text(col0 + width / 2, d / 2,
                        f'K^T_{j+1}', ha='center', va='center', fontsize=12)
            row_edges, col_edges = [0, d], seq_edges
        else:
            # S: one tile per (Q block i, K block j) pair.
            for i, (row0, height) in enumerate(spans):
                for j, (col0, width) in enumerate(spans):
                    ax.add_patch(plt.Rectangle((col0, row0), width, height,
                                               fill=True, alpha=0.5,
                                               facecolor=colors[i * num_blocks + j]))
                    ax.text(col0 + width / 2, row0 + height / 2,
                            f'S_{i+1},{j+1}', ha='center', va='center', fontsize=10)
            row_edges = col_edges = seq_edges

        ax.set_xlim(0, shape[1])
        ax.set_ylim(0, shape[0])
        ax.set_aspect('equal')
        ax.invert_yaxis()
        ax.set_title(f'{title}\n形状: {shape}', fontsize=14)
        ax.set_xlabel('列')
        ax.set_ylabel('行')

        # Draw boundaries only where each matrix is actually cut (Q: rows,
        # K^T: columns, S: both) plus the outer border. This way the d axis
        # never looks tiled, even when block_size < d.
        for y in row_edges:
            ax.axhline(y=y, color='black', linewidth=0.5)
        for x in col_edges:
            ax.axvline(x=x, color='black', linewidth=0.5)

    plt.tight_layout()
    plt.suptitle(f'矩阵分块示意图 (N={N}, block_size={block_size})', y=1.02, fontsize=16)
    plt.show()


visualize_tiling(256, 64)


## 3. 简化版FlashAttention实现


In [None]:
def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """
    Simplified pure-Python FlashAttention forward pass (tiled attention with
    online softmax).

    Computes softmax(Q K^T / sqrt(d)) V block by block. The outer loop walks
    Q in row blocks of ``Br``, and the inner loop walks K/V in blocks of
    ``Bc``. A running row maximum ``m`` and a running softmax denominator
    ``l`` let earlier partial results be rescaled whenever a larger score
    appears. Only one [Br, Bc] score tile exists at a time, so the full
    score matrix is never materialised.

    Args:
        Q: [..., seq_len_q, d]
        K: [..., seq_len_kv, d]
        V: [..., seq_len_kv, d_v]
            Leading (batch / head) dimensions must match across Q, K and V.
            seq_len_kv may differ from seq_len_q (cross-attention), and
            d_v may differ from d.
        Br: block size along the Q sequence axis.
        Bc: block size along the K/V sequence axis.

    Returns:
        O: [..., seq_len_q, d_v], with Q's dtype and device.

    Raises:
        ValueError: if ``Br`` or ``Bc`` is not positive.
    """
    if Br <= 0 or Bc <= 0:
        raise ValueError(f"Br and Bc must be positive, got Br={Br}, Bc={Bc}")

    lead = Q.shape[:-2]   # batch / head dimensions
    N = Q.shape[-2]       # query length
    N_kv = K.shape[-2]    # key/value length (may differ from N)
    d = Q.shape[-1]
    d_v = V.shape[-1]
    scale = 1.0 / (d ** 0.5)

    # Accumulate in at least float32: summing many exp() terms and rescaled
    # partial outputs in fp16/bf16 loses precision. float64 stays float64.
    acc_dtype = torch.promote_types(Q.dtype, torch.float32)

    # Number of blocks along each sequence axis.
    Tr = (N + Br - 1) // Br
    Tc = (N_kv + Bc - 1) // Bc

    O = torch.zeros(*lead, N, d_v, dtype=Q.dtype, device=Q.device)
    if Tc == 0:
        # No keys to attend to. Without this guard the normalisation below
        # would compute 0 / 0 = NaN. Standard attention yields zeros here.
        return O

    # Outer loop: blocks of Q.
    for i in range(Tr):
        q_start = i * Br
        q_end = min(q_start + Br, N)
        rows = q_end - q_start
        Q_i = Q[..., q_start:q_end, :].to(acc_dtype)  # [..., rows, d]

        # Online-softmax state for these query rows.
        m_i = torch.full((*lead, rows), float('-inf'),
                         dtype=acc_dtype, device=Q.device)      # running max
        l_i = torch.zeros(*lead, rows, dtype=acc_dtype, device=Q.device)  # running sum
        O_i = torch.zeros(*lead, rows, d_v, dtype=acc_dtype, device=Q.device)

        # Inner loop: blocks of K/V.
        for j in range(Tc):
            kv_start = j * Bc
            kv_end = min(kv_start + Bc, N_kv)
            K_j = K[..., kv_start:kv_end, :].to(acc_dtype)  # [..., cols, d]
            V_j = V[..., kv_start:kv_end, :].to(acc_dtype)  # [..., cols, d_v]

            # Local scores S_ij = Q_i @ K_j^T / sqrt(d): [..., rows, cols]
            S_ij = torch.matmul(Q_i, K_j.transpose(-2, -1)) * scale

            # Updated running max. exp(m_old - m_new) rescales the old state
            # (it is 0 on the first block, where m_old = -inf).
            m_new = torch.maximum(m_i, S_ij.amax(dim=-1))
            exp_m_old = torch.exp(m_i - m_new)
            P_ij = torch.exp(S_ij - m_new.unsqueeze(-1))

            l_i = exp_m_old * l_i + P_ij.sum(dim=-1)
            O_i = exp_m_old.unsqueeze(-1) * O_i + torch.matmul(P_ij, V_j)
            m_i = m_new

        # Final normalisation by the softmax denominator, then write back.
        O[..., q_start:q_end, :] = (O_i / l_i.unsqueeze(-1)).to(O.dtype)

    return O

# Random inputs for the FlashAttention demo, laid out as [batch, seq_len, d].
batch_size, seq_len, d = 2, 256, 64
Q, K, V = (torch.randn(batch_size, seq_len, d) for _ in range(3))

print("测试FlashAttention实现...")
print(f"输入形状: Q={Q.shape}, K={K.shape}, V={V.shape}")


## 4. 验证正确性


In [None]:
def standard_attention(Q, K, V):
    """Reference (non-tiled) scaled dot-product attention.

    Materialises the full score matrix softmax(Q K^T / sqrt(d)) and applies
    it to V. It serves as ground truth for the tiled implementation.
    """
    head_dim = Q.shape[-1]
    scale = 1.0 / (head_dim ** 0.5)
    scores = Q @ K.transpose(-2, -1) * scale
    weights = F.softmax(scores, dim=-1)
    return weights @ V

# Compare the reference implementation with the tiled FlashAttention.
O_standard = standard_attention(Q, K, V)
O_flash = flash_attention_forward(Q, K, V, Br=64, Bc=64)

# Error metrics, all computed from one difference tensor.
diff = O_flash - O_standard
abs_diff = diff.abs()
max_error = abs_diff.max().item()
mean_error = abs_diff.mean().item()
# Norm-based relative error ||O_flash - O_standard|| / ||O_standard||.
# An element-wise |diff| / (|O| + 1e-8) averaged over entries is dominated by
# outputs that happen to be close to zero, so it can look large even when
# the two results agree.
relative_error = (torch.linalg.vector_norm(diff)
                  / torch.linalg.vector_norm(O_standard)).item()

print("="*50)
print("正确性验证")
print("="*50)
print(f"最大绝对误差: {max_error:.2e}")
print(f"平均绝对误差: {mean_error:.2e}")
print(f"相对误差(范数): {relative_error:.2e}")
print(f"结果匹配: {max_error < 1e-4}")
print("="*50)


## 5. 不同块大小的影响


In [None]:
def analyze_block_sizes(seq_len, d=64, block_sizes=None, bytes_per_element=2):
    """
    Estimate per-tile on-chip memory and loop counts for tiled attention.

    For every (Br, Bc) pair drawn from ``block_sizes``, estimate two things:
    the bytes needed to hold one set of Q/K/V/S/O tiles plus the m/l row
    statistics, and the number of tile iterations Tr * Tc.

    Args:
        seq_len: sequence length N (used for both Q and K/V).
        d: head dimension.
        block_sizes: candidate block sizes, used for both Br and Bc.
            Defaults to [16, 32, 64, 128, 256].
        bytes_per_element: storage size of one element. The default of 2
            models FP16; use 4 for FP32.

    Returns:
        A list of dicts with keys 'Br', 'Bc', 'SRAM (KB)' and 'Iterations'.
        Br is the outer loop and Bc the inner loop.
    """
    if block_sizes is None:
        block_sizes = [16, 32, 64, 128, 256]

    results = []
    for Br in block_sizes:
        # Number of Q blocks depends only on Br.
        Tr = (seq_len + Br - 1) // Br
        for Bc in block_sizes:
            Tc = (seq_len + Bc - 1) // Bc
            elements = (
                Br * d +   # Q tile
                Bc * d +   # K tile
                Bc * d +   # V tile
                Br * Bc +  # S tile
                Br * d +   # O accumulator
                2 * Br     # m, l statistics
            )
            results.append({
                'Br': Br, 'Bc': Bc,
                'SRAM (KB)': elements * bytes_per_element / 1024,
                'Iterations': Tr * Tc,
            })

    return results

# Tabulate the (Br, Bc) trade-off for a sequence length of 1024.
results = analyze_block_sizes(1024)

rule = "=" * 60
row_fmt = "{:<8} {:<8} {:<15.1f} {:<15}"

print("块大小与资源使用分析 (N=1024, d=64)")
print(rule)
print(f"{'Br':<8} {'Bc':<8} {'SRAM使用(KB)':<15} {'迭代次数':<15}")
print(rule)

# Show only the first 15 rows (with the default sizes: Br in 16, 32, 64).
for row in results[:15]:
    print(row_fmt.format(row['Br'], row['Bc'], row['SRAM (KB)'], row['Iterations']))

print(rule)
print("\n注意: SRAM使用需要小于GPU共享内存大小 (如A100为~164KB/SM)")


## 6. 可视化块大小权衡


In [None]:
# Trade-off for square tiles (Br = Bc) with head dim d = 64 and 2-byte
# (FP16) elements. Left: estimated on-chip memory per tile set.
# Right: number of tile iterations for two sequence lengths.
block_sizes = [16, 32, 64, 128, 256]

# Q + 2 * (K, V) + O tiles (4 * bs * 64), S tile (bs * bs), m and l (2 * bs),
# times 2 bytes, in KB.
sram_usage = [(4 * 64 * bs + bs * bs + 2 * bs) * 2 / 1024 for bs in block_sizes]

# Tr * Tc, where Tr == Tc == ceil(N / bs) because Br == Bc.
iterations_1024 = [((1024 + bs - 1) // bs) ** 2 for bs in block_sizes]
iterations_4096 = [((4096 + bs - 1) // bs) ** 2 for bs in block_sizes]

tick_pos = np.arange(len(block_sizes))
tick_labels = [str(bs) for bs in block_sizes]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes

# Left panel: memory per tile set against the A100 shared-memory limit.
bars = ax1.bar(tick_pos, sram_usage, color='steelblue', alpha=0.7)
ax1.axhline(y=164, color='red', linestyle='--', label='A100 SRAM限制 (164KB)')
ax1.set_xticks(tick_pos)
ax1.set_xticklabels(tick_labels)
ax1.set_xlabel('块大小 (Br = Bc)')
ax1.set_ylabel('SRAM使用 (KB)')
ax1.set_title('块大小与SRAM使用')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Value label above each bar.
for bar, kb in zip(bars, sram_usage):
    ax1.annotate(
        f'{kb:.1f}',
        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
        xytext=(0, 3),
        textcoords="offset points",
        ha='center',
        va='bottom',
        fontsize=10,
    )

# Right panel: grouped bars (N=1024 vs N=4096) on a log scale.
bar_w = 0.35
ax2.bar(tick_pos - bar_w / 2, iterations_1024, bar_w, label='N=1024', color='steelblue')
ax2.bar(tick_pos + bar_w / 2, iterations_4096, bar_w, label='N=4096', color='coral')
ax2.set_xticks(tick_pos)
ax2.set_xticklabels(tick_labels)
ax2.set_xlabel('块大小 (Br = Bc)')
ax2.set_ylabel('迭代次数')
ax2.set_title('块大小与迭代次数')
ax2.legend()
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n结论: 需要在SRAM容量限制和迭代次数之间找到平衡")


## 总结

通过本notebook，你应该理解了：

1. **分块矩阵乘法**
   - 将大矩阵分成小块逐步计算
   - 结果与直接计算在数学上等价（由于浮点累加顺序不同，数值上只在舍入误差范围内一致）

2. **FlashAttention的分块策略**
   - 外循环遍历Q块，内循环遍历K/V块
   - 使用Online Softmax增量更新

3. **块大小的权衡**
   - 大块: 更少迭代，但需要更多SRAM
   - 小块: 更多迭代，但SRAM需求小
   - 需要根据GPU硬件选择合适的块大小

## 下一步

在下一节"Recomputation重计算策略"中，我们将学习如何用计算换内存。
