# local_tile局部分块 - 实践篇

本notebook帮助你理解cute的local_tile操作及其在FlashAttention中的应用。

**学习目标：**
- 理解local_tile的分块语义
- 可视化分块过程
- 分析FlashAttention中的分块策略


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Configure matplotlib so the Chinese titles/labels in this notebook render:
# try SimHei first, fall back to DejaVu Sans, and draw negative numbers with
# an ASCII hyphen instead of the Unicode minus sign.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

print("环境准备完成！")


## 1. local_tile基本概念可视化


In [None]:
def visualize_local_tile_basic():
    """Draw a two-panel figure introducing local_tile.

    Left panel: a 2048x64 ``mQ`` split into 128-row tiles (tile 0 is drawn at
    the bottom of the axes, indices grow upward). Right panel: tile 3 being
    extracted as ``gQ`` via ``local_tile(mQ, Shape<_128, _64>{}, make_coord(3, 0))``.
    """
    total_rows, total_cols = 2048, 64
    tile_rows, tile_cols = 128, 64
    n_tiles = total_rows // tile_rows
    pick = 3  # tile highlighted in the right panel

    fig, (left, right) = plt.subplots(1, 2, figsize=(14, 6))

    # ---- left panel: the full tensor and its row tiles ----
    left.set_xlim(0, 10)
    left.set_ylim(0, 10)
    left.add_patch(patches.Rectangle((1, 1), 3, 8, linewidth=2,
                                     edgecolor='black', facecolor='lightblue'))
    left.text(2.5, 0.5, f'mQ ({total_rows}×{total_cols})', ha='center', fontsize=10)

    band = 8 / n_tiles
    palette = plt.cm.Set3(np.linspace(0, 1, n_tiles))
    for idx, color in enumerate(palette):
        y0 = 1 + idx * band
        left.add_patch(patches.Rectangle((1, y0), 3, band, linewidth=1,
                                         edgecolor='gray', facecolor=color, alpha=0.7))
        # Label the first four and last two tiles; put an ellipsis at slot 4.
        shown = idx < 4 or idx >= n_tiles - 2
        if shown or idx == 4:
            left.text(0.8, y0 + band / 2, f'{idx}' if shown else '...',
                      ha='right', va='center', fontsize=8)

    left.text(2.5, 9.3, f'{n_tiles}个分块，每块{tile_rows}行', ha='center', fontsize=10)
    left.set_title('原始Tensor与分块划分', fontsize=12)
    left.axis('off')

    # ---- right panel: local_tile picks one tile out as a view ----
    right.set_xlim(0, 10)
    right.set_ylim(0, 10)
    right.add_patch(patches.Rectangle((0.5, 2), 2, 6, linewidth=2,
                                      edgecolor='black', facecolor='lightblue'))
    right.text(1.5, 1.5, 'mQ', ha='center', fontsize=10)

    small_band = 6 / n_tiles
    right.add_patch(patches.Rectangle((0.5, 2 + pick * small_band), 2, small_band,
                                      linewidth=3, edgecolor='red', facecolor='coral'))

    right.annotate('', xy=(5, 5), xytext=(3, 5),
                   arrowprops={'arrowstyle': '->', 'color': 'green', 'lw': 2})
    right.text(4, 5.5, 'local_tile\n(tile_shape, coord)', ha='center', fontsize=9)

    right.add_patch(patches.Rectangle((5.5, 3.5), 3, 3, linewidth=2,
                                      edgecolor='red', facecolor='coral'))
    right.text(7, 3, f'gQ ({tile_rows}×{tile_cols})', ha='center', fontsize=10)
    right.text(7, 5, f'第{pick}块的视图', ha='center', va='center', fontsize=11)

    snippet = (f'local_tile(mQ,\n  Shape<_{tile_rows}, _{tile_cols}>{{}},\n'
               f'  make_coord({pick}, 0))')
    right.text(5, 1.5, snippet, ha='center', fontsize=9, family='monospace',
               bbox={'boxstyle': 'round', 'facecolor': 'wheat'})

    right.set_title(f'local_tile操作：获取第{pick}个分块', fontsize=12)
    right.axis('off')

    plt.suptitle('local_tile 基本概念', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_local_tile_basic()


## 2. Python模拟local_tile


In [None]:
class TensorView:
    """Emulate a cute Tensor view: a window onto shared ``data`` without copying.

    A 2-D view covers ``shape`` starting at row ``offset`` and column
    ``col_offset`` of the underlying ``data``. A 3-D ``shape``
    ``(tile_rows, tile_cols, num_blocks)`` over 2-D ``data`` represents the
    result of a wildcard ``local_tile``: element ``[r, c, b]`` lives at
    ``data[offset + b * tile_rows + r, col_offset + c]``.

    Indexing is relative to the view, so negative indices and slices refer to
    the view's own extent rather than to the whole of ``data``.
    """

    def __init__(self, data, offset=0, shape=None, name="tensor", col_offset=0):
        self.data = data  # underlying storage, shared by every view of it
        self.offset = offset  # first row of the view inside data
        self.col_offset = col_offset  # first column of the view inside data
        self.shape = shape if shape else data.shape
        self.name = name

    def _region(self):
        """Return a numpy view (not a copy) of exactly the elements this view covers."""
        if len(self.shape) == 3 and self.data.ndim == 2:
            tile_rows, tile_cols, num_blocks = self.shape
            rows = self.data[self.offset:self.offset + tile_rows * num_blocks,
                             self.col_offset:self.col_offset + tile_cols]
            # Split the row axis into (block, row-in-block) and move the block
            # index last, matching cute's (TileM, TileN, num_blocks) layout.
            return rows.reshape(num_blocks, tile_rows, tile_cols).transpose(1, 2, 0)
        starts = (self.offset, self.col_offset)
        window = tuple(slice(s, s + n) for s, n in zip(starts, self.shape))
        return self.data[window]

    def __getitem__(self, idx):
        """Index with view-local coordinates (ints, slices and negatives all allowed)."""
        return self._region()[idx]

    def __repr__(self):
        text = f"{self.name}: shape={self.shape}, offset={self.offset}"
        if self.col_offset:
            text += f", col_offset={self.col_offset}"
        return text


def local_tile_py(tensor, tile_shape, tile_coord):
    """Emulate cute's ``local_tile`` on a 2-D matrix.

    Args:
        tensor: numpy array or TensorView to tile.
        tile_shape: ``(tile_rows, tile_cols)``.
        tile_coord: ``(row_block, col_block)``, or ``(None, col_block)`` where
            ``None`` plays the role of cute's ``_`` wildcard over row blocks.

    Returns:
        A TensorView sharing ``tensor``'s storage. A fixed row block gives a
        view of shape ``(tile_rows, tile_cols)``. The wildcard gives shape
        ``(tile_rows, tile_cols, num_row_blocks)``, counting only whole row
        blocks of ``tensor``.
    """
    if isinstance(tensor, TensorView):
        data, base_row, base_col = tensor.data, tensor.offset, tensor.col_offset
    else:
        data, base_row, base_col = tensor, 0, 0

    tile_rows, tile_cols = tile_shape
    row_coord, col_coord = tile_coord
    # The column tile is selected the same way in both branches.
    col_start = base_col + col_coord * tile_cols

    if row_coord is None:
        # Wildcard: every whole row block of *this* tensor (not of the raw data).
        num_row_blocks = tensor.shape[0] // tile_rows
        return TensorView(data, base_row, (tile_rows, tile_cols, num_row_blocks),
                          f"local_tile(all_rows, col={col_coord})",
                          col_offset=col_start)

    return TensorView(data, base_row + row_coord * tile_rows, (tile_rows, tile_cols),
                      f"local_tile(row={row_coord}, col={col_coord})",
                      col_offset=col_start)


# Demo: tile a simulated Q matrix with local_tile_py and verify the result
# by reading through the returned view.
print("=" * 60)
print("Python模拟 local_tile")
print("=" * 60)

# Simulated Q matrix whose element values equal their row-major flat index
seqlen_q = 2048
headdim = 64
kBlockM = 128

mQ = np.arange(seqlen_q * headdim).reshape(seqlen_q, headdim)
print(f"\n原始Tensor mQ: shape = {mQ.shape}")
print(f"分块大小: {kBlockM} × {headdim}")
print(f"分块数量: {seqlen_q // kBlockM}")

# Fetch row tile number 3 (0-based) along M
m_block = 3
gQ = local_tile_py(mQ, (kBlockM, headdim), (m_block, 0))
print(f"\n获取第{m_block}个块: {gQ}")

# Verify via the view itself: gQ[0, 0] must be mQ[m_block * kBlockM, 0], whose
# value is m_block * kBlockM * headdim because mQ stores its own flat indices.
expected_first = m_block * kBlockM * headdim
actual_first = gQ[0, 0]
print(f"第{m_block}块第一个元素: 预期={expected_first}, 实际={actual_first}")
assert actual_first == expected_first, "local_tile_py view does not start at the expected row"

# Fetch every row tile with the wildcard (None stands in for cute's `_`)
gQ_all = local_tile_py(mQ, (kBlockM, headdim), (None, 0))
print(f"\n获取所有块: {gQ_all}")


## 3. 可视化FlashAttention的分块策略


In [None]:
def visualize_flash_attention_tiling():
    """Show how FlashAttention tiles Q, K/V and the score matrix S = QK^T.

    Panel 1 marks the Q block owned by one thread block (block 5). Panel 2
    lists the K/V blocks. Panel 3 highlights that Q block's row of tiles in S.
    A short configuration summary is printed afterwards.
    """
    seqlen_q, seqlen_k, headdim = 2048, 2048, 64
    kBlockM, kBlockN = 128, 64
    num_m_blocks = seqlen_q // kBlockM
    num_n_blocks = seqlen_k // kBlockN
    current_m = 5  # Q block handled by the illustrated thread block

    fig, (ax_q, ax_kv, ax_s) = plt.subplots(1, 3, figsize=(16, 6))

    def label_strip(ax, count, prefix, extra=None):
        """Label a one-column strip of blocks, eliding the middle with '...'."""
        for row in range(count):
            if row < 4 or row >= count - 2 or row == extra:
                ax.text(0, row, f'{prefix}{row}', ha='center', va='center', fontsize=8)
            elif row == 4:
                ax.text(0, row, '...', ha='center', va='center', fontsize=8)

    def finish(ax, title, ylabel, xlabel=None):
        """Apply title/axis labels and hide ticks for one panel."""
        ax.set_title(title, fontsize=10)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xticks([])
        ax.set_yticks([])

    # Panel 1: Q blocks, with the current one boxed in red
    ax_q.imshow(np.zeros((num_m_blocks, 1)), cmap='Blues', aspect='auto')
    ax_q.add_patch(patches.Rectangle((-0.5, current_m - 0.5), 1, 1,
                                     fill=False, edgecolor='red', linewidth=3))
    label_strip(ax_q, num_m_blocks, 'Q', extra=current_m)
    finish(ax_q, f'Q矩阵\n{num_m_blocks}块，每块{kBlockM}行\n红框=当前Thread Block处理', 'M维度')

    # Panel 2: K/V blocks
    ax_kv.imshow(np.zeros((num_n_blocks, 1)), cmap='Greens', aspect='auto')
    label_strip(ax_kv, num_n_blocks, 'K')
    finish(ax_kv, f'K/V矩阵\n{num_n_blocks}块，每块{kBlockN}行\n需要遍历所有块', 'N维度')

    # Panel 3: score-matrix tiles; the current Q block's row is shaded and outlined
    row_mask = np.zeros((num_m_blocks, num_n_blocks))
    row_mask[current_m] = 1
    ax_s.imshow(row_mask, cmap='OrRd', aspect='auto')
    for edge in (current_m - 0.5, current_m + 0.5):
        ax_s.axhline(y=edge, color='red', linewidth=2)
    finish(ax_s, '注意力矩阵 S=QK^T\n每个Thread Block处理一行\n遍历所有列',
           'M维度（Q块）', xlabel='N维度（K块）')

    plt.suptitle('FlashAttention 分块计算策略', fontsize=14)
    plt.tight_layout()
    plt.show()

    summary = [
        f"配置: seqlen_q={seqlen_q}, seqlen_k={seqlen_k}, headdim={headdim}",
        f"分块: kBlockM={kBlockM}, kBlockN={kBlockN}",
        f"Q块数: {num_m_blocks}, K块数: {num_n_blocks}",
        "\n每个Thread Block:",
        "  - 处理固定的1个Q块 (local_tile with fixed coord)",
        f"  - 遍历所有{num_n_blocks}个K/V块 (local_tile with _ wildcard)",
    ]
    print("\n".join(summary))

visualize_flash_attention_tiling()


## 4. 分块坐标详解


In [None]:
def visualize_tile_coordinates():
    """Illustrate which tiles a ``make_coord`` argument selects in ``local_tile``.

    Uses a 4x4 grid of tiles (M blocks down, N blocks across) and draws three
    panels. The first shows a fixed coordinate (one tile), the second a wildcard
    over M, and the third a wildcard over N. Selected tiles are shaded.
    """
    num_m, num_n = 4, 4
    # One entry per panel: (title, colormap, numpy selection into the tile grid).
    panels = [
        ('make_coord(2, 1)\n返回单个块 (TileM, TileN)', 'Reds', (2, 1)),
        ('make_coord(_, 1)\n返回所有M块 (TileM, TileN, num_m)', 'Greens', (slice(None), 1)),
        ('make_coord(2, _)\n返回所有N块 (TileM, TileN, num_n)', 'Blues', (2, slice(None))),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (title, cmap, selection) in zip(axes, panels):
        grid = np.zeros((num_m, num_n))
        grid[selection] = 1  # shade the tiles this coordinate selects
        ax.imshow(grid, cmap=cmap, vmin=0, vmax=1)
        for i in range(num_m):
            for j in range(num_n):
                ax.text(j, i, f'({i},{j})', ha='center', va='center', fontsize=10)
        ax.set_title(title, fontsize=11)
        ax.set_xlabel('N维度')
        ax.set_ylabel('M维度')

    plt.suptitle('tile_coord 不同用法', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_tile_coordinates()


## 5. FlashAttention中的local_tile代码分析


In [None]:
# Annotated C++ excerpt, kept as a string so the notebook can display it.
# It shows how gQ, gK and gV are built with local_tile: a fixed coordinate for
# Q and the `_` wildcard for K/V. It then shows the main loop indexing the
# partitioned K/V tiles by n_block.
# The text is only printed, never compiled. Names such as n_block_min,
# tSrQ and acc_s appear without their definitions in this excerpt.
local_tile_code = '''
// ============================================================
// FlashAttention中的local_tile使用
// 来自 hopper/mainloop_fwd_sm80.hpp
// ============================================================

// 定义TileShape
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
//                          ^M=128       ^N=64        ^K=64

// 创建全局Q矩阵Tensor
Tensor mQ = make_tensor(
    make_gmem_ptr(params.ptr_Q + offset), 
    params.shape_Q_packed,    // (seqlen_q, headdim, ...)
    params.stride_Q_packed
)(_, _, bidh, bidb);          // 选择head和batch

// 使用local_tile获取Q的第m_block个块
// select<0, 2> 选择M和K维度 → Shape<_128, _64>
Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));
// gQ形状: (kBlockM, headdim) = (128, 64)

// 创建K矩阵Tensor并获取所有块
Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + offset), ...);

// 使用通配符_获取所有K块
// select<1, 2> 选择N和K维度 → Shape<_64, _64>
Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));
// gK形状: (kBlockN, headdim, num_n_blocks) = (64, 64, seqlen_k/64)

// 类似地获取V的所有块
Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));

// ============================================================
// 主循环中使用分块
// ============================================================

for (int n_block = n_block_min; n_block < n_block_max; ++n_block) {
    // 获取第n_block个K块 (通过索引gK的第3维)
    // 这里tKgK是partition后的结果
    Tensor tKgK_cur = tKgK(_, _, _, n_block);
    
    // 拷贝到共享内存
    cute::copy(gmem_tiled_copy_QKV, tKgK_cur, tKsK);
    cute::cp_async_fence();
    
    // 类似处理V
    Tensor tVgV_cur = tVgV(_, _, _, n_block);
    cute::copy(gmem_tiled_copy_QKV, tVgV_cur, tVsV);
    
    // 等待拷贝完成
    cute::cp_async_wait<0>();
    __syncthreads();
    
    // MMA计算: S = Q @ K^T
    cute::gemm(tiled_mma, tSrQ, tSrK, acc_s);
    
    // ... softmax ...
    
    // MMA计算: O += P @ V
    cute::gemm(tiled_mma, tSrP, tSrV, acc_o);
}
'''

print(local_tile_code)


## 6. 边界处理可视化


In [None]:
def visualize_boundary_handling(seqlen=2000, kBlockM=128):
    """Visualize tiling of a sequence whose length is not a multiple of the tile size.

    Left panel: the row tiles of a ``seqlen``-row tensor cut into ``kBlockM``-row
    tiles. At most 8 slots are drawn. When there are more tiles, the first few
    are shown, then ``...``, then always the last tile, so that a partially
    filled final tile stays visible.
    Right panel: a 16-row schematic of the last tile that splits rows into
    "copy" (in bounds) and "clear" (out of bounds), as a predicated copy does.
    A numeric summary is printed at the end.

    Args:
        seqlen: number of rows in the tensor (default 2000, not a multiple of 128).
        kBlockM: rows per tile (default 128).

    Raises:
        ValueError: if ``seqlen`` or ``kBlockM`` is not positive.
    """
    if seqlen <= 0 or kBlockM <= 0:
        raise ValueError("seqlen and kBlockM must be positive")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    num_full_blocks = seqlen // kBlockM
    remainder = seqlen % kBlockM
    total_blocks = num_full_blocks + (1 if remainder > 0 else 0)
    # Valid rows in the final tile; a whole tile when seqlen divides evenly.
    last_valid = remainder if remainder else kBlockM

    # ---- left panel: full tiles vs. the (possibly) partial last tile ----
    ax = axes[0]
    block_height = 0.8
    max_slots = 8
    if total_blocks <= max_slots:
        slots = list(range(total_blocks))
    else:
        # First tiles, an ellipsis marker (None), then always the last tile.
        slots = list(range(max_slots - 2)) + [None, total_blocks - 1]

    for y, block in enumerate(slots):
        if block is None:
            ax.text(2, y + block_height / 2, '...', ha='center', va='center', fontsize=12)
        elif block < num_full_blocks:
            ax.add_patch(patches.Rectangle((0.5, y), 3, block_height,
                                           facecolor='lightblue', edgecolor='blue'))
            ax.text(2, y + block_height / 2, f'块{block}: {kBlockM}行 (完整)',
                    ha='center', va='center', fontsize=9)
        else:
            # Only reachable when remainder > 0: the trailing partial tile.
            valid_width = 3 * remainder / kBlockM
            ax.add_patch(patches.Rectangle((0.5, y), valid_width, block_height,
                                           facecolor='lightgreen', edgecolor='green'))
            ax.add_patch(patches.Rectangle((0.5 + valid_width, y), 3 - valid_width,
                                           block_height, facecolor='lightcoral',
                                           edgecolor='red', hatch='//'))
            ax.text(2, y + block_height / 2, f'块{block}: {remainder}行有效',
                    ha='center', va='center', fontsize=9)

    ax.set_xlim(0, 4)
    ax.set_ylim(-0.5, 8.5)
    status = '最后一块不完整' if remainder else '所有块均完整'
    ax.set_title(f'seqlen={seqlen}, kBlockM={kBlockM}\n{status}', fontsize=11)
    ax.axis('off')

    # ---- right panel: predicate handling inside the last tile ----
    ax = axes[1]
    tile_rows = 16  # schematic rows standing in for the kBlockM real rows
    # Scale to the schematic, rounding up so a small remainder is still visible.
    valid_rows = (tile_rows * last_valid + kBlockM - 1) // kBlockM

    for i in range(tile_rows):
        if i < valid_rows:
            color, label = 'lightgreen', 'copy'
        else:
            color, label = 'lightcoral', 'clear'
        ax.add_patch(patches.Rectangle((0.5, i), 3, 0.8, facecolor=color, edgecolor='gray'))
        if i < 3 or i >= tile_rows - 2 or i == valid_rows - 1 or i == valid_rows:
            ax.text(2, i + 0.4, f'行{i}: {label}', ha='center', va='center', fontsize=8)

    ax.axhline(y=valid_rows - 0.1, color='red', linewidth=2, linestyle='--')
    ax.text(4, valid_rows - 0.1, '边界', ha='left', va='center', fontsize=10, color='red')

    ax.set_xlim(0, 5)
    ax.set_ylim(-0.5, tile_rows + 0.5)
    ax.set_title('Predicate边界处理\n有效行拷贝，越界行清零', fontsize=11)
    ax.axis('off')

    plt.suptitle('不完整块的边界处理', fontsize=14)
    plt.tight_layout()
    plt.show()

    print(f"seqlen = {seqlen}")
    print(f"kBlockM = {kBlockM}")
    print(f"完整块数 = {num_full_blocks}")
    print(f"最后一块有效行数 = {last_valid}")
    print(f"最后一块越界行数 = {kBlockM - last_valid}")

visualize_boundary_handling()


## 7. select辅助函数


In [None]:
def demonstrate_select():
    """Print what cute's ``select<I, J>`` extracts from an (M, N, K) tile shape.

    ``select<0, 2>`` keeps M and K (the Q tile), and ``select<1, 2>`` keeps N
    and K (the K/V tile). The output ends with the matching FlashAttention lines.
    """
    rule = "=" * 60
    print("select 函数演示")
    print(rule)

    tile_mnk = (128, 64, 64)  # M, N, K
    print(f"\nTileShape_MNK = {tile_mnk}")
    for pos, (axis, extent) in enumerate(zip("MNK", tile_mnk)):
        print(f"  索引{pos} ({axis}) = {extent}")

    uses = (
        ((0, 2), "Q矩阵: (seqlen_q, headdim)"),
        ((1, 2), "K/V矩阵: (seqlen_k, headdim)"),
    )
    for (first, second), target in uses:
        picked = (tile_mnk[first], tile_mnk[second])
        print(f"\nselect<{first}, {second}>(TileShape_MNK) = {picked}")
        print(f"  用于{target} 分块为 ({picked[0]}, {picked[1]})")

    print("\n" + rule)
    print("FlashAttention中的使用:")
    print("""
// Q矩阵分块
Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), ...);
// 分块大小: (kBlockM=128, headdim=64)

// K矩阵分块  
Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), ...);
// 分块大小: (kBlockN=64, headdim=64)

// V矩阵分块 (与K相同)
Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), ...);
// 分块大小: (kBlockN=64, headdim=64)
""")

demonstrate_select()


## 8. 完整数据流可视化


In [None]:
def visualize_complete_dataflow():
    """Print the FlashAttention forward data flow as a chain of boxed stages.

    Stage 2 (local_tile) is the topic of this notebook. The pseudo-code uses
    C++ syntax and is consistent with the main-loop excerpt shown earlier:
    K and V are both staged through shared memory, and a ``__syncthreads()``
    follows the async-copy wait. Box widths are computed from each line's
    terminal display width (wide CJK characters count as two columns) so the
    right-hand borders line up.
    """
    import unicodedata

    def display_width(text):
        """Terminal columns occupied by ``text`` (wide/fullwidth chars count 2)."""
        return sum(2 if unicodedata.east_asian_width(ch) in ('W', 'F') else 1
                   for ch in text)

    def box(lines, indent='   '):
        """Render ``lines`` inside a box-drawing frame padded to a common width."""
        inner = max(display_width(line) for line in lines) + 4
        top = f"{indent}┌{'─' * inner}┐"
        body = [f"{indent}│  {line}{' ' * (inner - 2 - display_width(line))}│"
                for line in lines]
        bottom = f"{indent}└{'─' * inner}┘"
        return "\n".join([top, *body, bottom])

    stages = [
        ("1. 创建全局Tensor", [
            "Tensor mQ = make_tensor(make_gmem_ptr(ptr_Q), shape, stride)",
            "Tensor mK = make_tensor(make_gmem_ptr(ptr_K), shape, stride)",
            "Tensor mV = make_tensor(make_gmem_ptr(ptr_V), shape, stride)",
        ]),
        ("2. 使用local_tile获取分块视图  ◄── 本节重点", [
            "Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}))",
            "Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}))",
            "Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}))",
        ]),
        ("3. 使用TiledCopy分区", [
            "auto thr_copy = gmem_tiled_copy.get_thread_slice(threadIdx.x)",
            "Tensor tQgQ = thr_copy.partition_S(gQ)",
            "Tensor tKgK = thr_copy.partition_S(gK)",
            "Tensor tVgV = thr_copy.partition_S(gV)",
        ]),
        ("4. 主循环：遍历K/V块", [
            "for (int n_block = 0; n_block < num_blocks; ++n_block) {",
            "    // 拷贝K[n_block]、V[n_block]到共享内存",
            "    copy(tKgK(_, _, _, n_block), tKsK)",
            "    copy(tVgV(_, _, _, n_block), tVsV)",
            "    cp_async_fence(); cp_async_wait<0>(); __syncthreads()",
            "",
            "    // 从共享内存拷贝到寄存器",
            "    copy(smem_tiled_copy, tCsK, tCrK)",
            "    copy(smem_tiled_copy, tCsV, tCrV)",
            "",
            "    // MMA计算",
            "    gemm(tiled_mma, tCrQ, tCrK, acc_s)  // S = Q @ K^T",
            "    // ... softmax ...",
            "    gemm(tiled_mma, tCrP, tCrV, acc_o)  // O += P @ V",
            "}",
        ]),
        ("5. 写回输出", [
            "// acc_o 按softmax行和归一化并转换为输出精度后写回",
            "Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}))",
            "copy(acc_o, gO)",
        ]),
    ]

    pad = ' ' * 35
    connector = f"\n{pad}│\n{pad}▼\n"
    flow = connector.join(f"{title}\n{box(lines)}" for title, lines in stages)

    print("FlashAttention 完整数据流")
    print("=" * 70)
    print(f"\n{flow}\n")

visualize_complete_dataflow()


## 总结

通过本notebook，你应该理解了：

1. **local_tile基本概念**: 从大Tensor获取分块视图，零拷贝
2. **tile_shape**: 定义分块大小
3. **tile_coord**: 指定分块位置，支持通配符`_`
4. **select辅助函数**: 从多维Shape中选择特定维度
5. **边界处理**: 当序列长度不是分块大小的整数倍时，最后一块不完整，需要用predicate保护：有效行正常拷贝，越界行清零
6. **在FlashAttention中的应用**: 每个Thread Block用固定坐标取出1个Q块，用通配符`_`取出全部K/V块，并在主循环中按`n_block`逐块遍历

## 关键代码模式

```cpp
// 1. 定义TileShape
using TileShape_MNK = Shape<Int<128>, Int<64>, Int<64>>;

// 2. 创建全局Tensor
Tensor mQ = make_tensor(make_gmem_ptr(ptr), shape, stride);

// 3. 获取固定块（Q）
Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));

// 4. 获取所有块（K/V）用于遍历
Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));

// 5. 在循环中访问每个块
for (int n = 0; n < size<2>(gK); ++n) {
    Tensor gK_n = gK(_, _, n);  // 第n个K块
}
```

## 3_2 cute核心概念：学习完成！

至此，你已经学习了cute的5个核心概念：
1. **Tensor**: 多维数组抽象
2. **Layout**: 描述数据在内存中的排布
3. **TiledMMA**: 封装Tensor Core矩阵乘法
4. **TiledCopy**: 高效内存拷贝抽象
5. **local_tile**: 获取Tensor的局部分块视图

这些概念是理解FlashAttention实现的基础！
