# Layout内存布局 - 实践篇

本notebook通过可视化和代码示例帮助你深入理解cute的Layout概念。

**学习目标：**
- 理解Shape和Stride的关系
- 可视化不同Layout的内存映射
- 理解Layout在实际场景中的应用（padding、bank conflict、FlashAttention）；注意：本notebook的简化Layout类只支持扁平Shape，层次化（嵌套）Layout不在本节范围内


In [None]:
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt

# Candidate CJK-capable sans-serif families for the Chinese plot labels.
# Matplotlib uses the first installed family in this list, so listing several
# keeps labels readable on machines where 'SimHei' is not available.
plt.rcParams['font.sans-serif'] = [
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'Heiti TC',
    'Noto Sans CJK SC',
    'Source Han Sans SC',
    'WenQuanYi Micro Hei',
    'DejaVu Sans',
]
# Render the minus sign as ASCII '-' so it is not drawn as a missing glyph in CJK fonts.
plt.rcParams['axes.unicode_minus'] = False

print("环境准备完成！")


## 1. Layout基础类


In [None]:
class Layout:
    """Minimal Python model of a CuTe ``Layout``: a ``(shape, stride)`` pair.

    A layout maps a logical coordinate ``(c0, c1, ...)`` to the linear memory
    offset ``c0*stride0 + c1*stride1 + ...``. Only flat (non-nested) shapes
    are modeled here.

    Note: when ``stride`` is omitted this model derives a *row-major* stride
    (last mode contiguous). CuTe's own ``make_layout(shape)`` defaults to a
    column-major (``LayoutLeft``) stride instead.

    Args:
        shape: Extent of each mode, e.g. ``(4, 3)``. A bare int is treated as
            a rank-1 shape.
        stride: Memory step per unit of each mode. A bare int is treated as
            rank-1. Defaults to the row-major stride of ``shape``.

    Raises:
        ValueError: If ``stride`` does not have the same rank as ``shape``.
    """

    def __init__(self, shape: Union[int, Tuple], stride: Optional[Union[int, Tuple]] = None):
        self.shape = self._as_tuple(shape)
        if stride is None:
            self.stride = self._compute_row_major_stride(self.shape)
        else:
            self.stride = self._as_tuple(stride)
            if len(self.stride) != len(self.shape):
                raise ValueError(
                    f"stride {self.stride} has rank {len(self.stride)}, "
                    f"but shape {self.shape} has rank {len(self.shape)}"
                )

    @staticmethod
    def _as_tuple(value) -> Tuple:
        """Return ``value`` as a tuple; a non-iterable scalar becomes ``(value,)``."""
        try:
            return tuple(value)
        except TypeError:
            return (value,)

    def _compute_row_major_stride(self, shape: Tuple) -> Tuple:
        """Return the row-major (last mode fastest) stride for ``shape``.

        Example: ``(2, 3, 4) -> (12, 4, 1)``; an empty shape yields ``()``.
        """
        stride = []
        running = 1
        # Walk modes from innermost to outermost, accumulating the product of extents.
        for extent in reversed(shape):
            stride.append(running)
            running *= extent
        return tuple(reversed(stride))

    def __call__(self, *coords) -> int:
        """Map a logical coordinate to its linear memory offset.

        Raises:
            ValueError: If the number of coordinates differs from the layout's
                rank (``zip`` would otherwise silently drop the extras).
        """
        if len(coords) != len(self.stride):
            raise ValueError(
                f"expected {len(self.stride)} coordinate(s) for {self!r}, got {len(coords)}"
            )
        return sum(c * s for c, s in zip(coords, self.stride))

    def size(self) -> int:
        """Number of logical coordinates (product of all extents)."""
        return math.prod(self.shape)

    def cosize(self) -> int:
        """Minimum buffer length (in elements) that holds every mapped offset.

        Computed as the largest reachable offset plus one, which assumes all
        strides are non-negative. A layout with any zero extent maps no
        coordinates and therefore needs no storage.
        """
        if self.size() == 0:
            return 0
        max_offset = sum((extent - 1) * step for extent, step in zip(self.shape, self.stride))
        return max_offset + 1

    def __repr__(self):
        return f"Layout({self.shape}:{self.stride})"

def make_layout(shape, stride=None):
    """Construct a ``Layout`` from ``shape`` and an optional ``stride``.

    Thin factory named after CuTe's ``make_layout``. When ``stride`` is
    ``None`` the ``Layout`` constructor picks its default stride (row-major
    in this model; CuTe's real default is column-major / ``LayoutLeft``).
    """
    new_layout = Layout(shape=shape, stride=stride)
    return new_layout

# Quick check: build a 4x3 row-major layout and print its basic properties.
layout = make_layout((4, 3), (3, 1))
summary = (
    ("Layout", layout),
    ("Size", layout.size()),
    ("Cosize", layout.cosize()),
)
for label, value in summary:
    print(f"{label}: {value}")


## 2. 可视化Layout映射


In [None]:
def visualize_2d_layout(layout: Layout, title: str = ""):
    """Plot a rank-2 layout as a logical grid and as a linear memory strip.

    Left panel: the (M, N) logical grid; each cell is colored and labeled
    with its physical offset ``layout(i, j)``.
    Right panel: the linear buffer of length ``layout.cosize()``. Each slot
    the layout maps to is labeled with its logical coordinate. Slots the
    layout never touches (e.g. row padding) are masked and drawn blank, so
    they are not mistaken for a real offset and do not skew the color scale.

    Raises:
        ValueError: If ``layout`` is not rank-2.
    """
    if len(layout.shape) != 2:
        raise ValueError(f"visualize_2d_layout expects a rank-2 layout, got {layout}")
    M, N = layout.shape

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Logical view: color encodes the physical offset of each (i, j).
    offsets = np.zeros((M, N), dtype=int)
    for i in range(M):
        for j in range(N):
            offsets[i, j] = layout(i, j)

    im1 = ax1.imshow(offsets, cmap='viridis')
    ax1.set_title('逻辑视图 (颜色=物理偏移)')
    ax1.set_xlabel('列 j')
    ax1.set_ylabel('行 i')

    for i in range(M):
        for j in range(N):
            ax1.text(j, i, f'{offsets[i, j]}', ha='center', va='center',
                     color='white', fontweight='bold', fontsize=10)

    fig.colorbar(im1, ax=ax1)

    # Linear memory view: start fully masked; assigning a slot unmasks it.
    cosize = layout.cosize()
    linear_view = np.ma.masked_all((1, cosize), dtype=int)
    coord_labels = [""] * cosize

    for i in range(M):
        for j in range(N):
            offset = int(offsets[i, j])
            linear_view[0, offset] = offset
            coord_labels[offset] = f"({i},{j})"

    ax2.imshow(linear_view, cmap='viridis', aspect='auto')
    ax2.set_title('线性内存 (显示逻辑坐标)')
    ax2.set_xlabel('物理偏移')
    ax2.set_yticks([])

    for idx, label in enumerate(coord_labels):
        if label:
            ax2.text(idx, 0, label, ha='center', va='center',
                     color='white', fontsize=8, rotation=45)

    fig.suptitle(f'{title}\nshape={layout.shape}, stride={layout.stride}', fontsize=12)
    fig.tight_layout()
    plt.show()

# Same 4x4 shape under two strides: row-major first, then column-major.
major_order_demos = (
    ((4, 1), "行优先 (Row-Major)"),
    ((1, 4), "列优先 (Column-Major)"),
)
for demo_stride, demo_title in major_order_demos:
    visualize_2d_layout(make_layout((4, 4), demo_stride), demo_title)


## 3. 转置只需改变Layout（交换Shape与Stride，数据不动）


In [None]:
# Demo: a transpose moves no data -- only the Layout describing the buffer changes.

# Original 3x4 matrix, row-major.
original = make_layout((3, 4), (4, 1))

# Transposed 4x3 view of the same buffer: shape and stride are both swapped.
transposed = make_layout((4, 3), (1, 4))

for heading, view in (
    ("原始矩阵 3x4 (行优先):", original),
    ("\n转置后 4x3 (只改变Layout，数据不动):", transposed),
):
    print(heading)
    print(f"Layout: {view}")
    print("逻辑视图:")
    n_rows, n_cols = view.shape
    for r in range(n_rows):
        cells = ", ".join(f"{view(r, c):2d}" for c in range(n_cols))
        print(f"  [{cells}]")

print("\n观察: 转置后(1,2)的偏移 = 原始(2,1)的偏移")
for expr, value in (
    ("transposed(1, 2)", transposed(1, 2)),
    ("original(2, 1)", original(2, 1)),
):
    print(f"{expr} = {value}")


## 4. 非连续Layout（带padding）


In [None]:
# Shared-memory tiles are often padded so that column accesses spread across banks.

# Contiguous 8x8 layout: the row stride equals the row length.
contiguous = make_layout((8, 8), (8, 1))

# Padded 8x8 layout: row stride 10, i.e. 2 unused elements after every row.
padded = make_layout((8, 8), (10, 1))

print("连续Layout:")
print(f"  {contiguous}")
print(f"  size={contiguous.size()}, cosize={contiguous.cosize()}")

print("\n带padding的Layout:")
print(f"  {padded}")
print(f"  size={padded.size()}, cosize={padded.cosize()}")

# The lists below are the offsets of column 0, i.e. where each row starts.
print("\n各行起始偏移对比:")
print("  连续Layout各行起始: ", [contiguous(i, 0) for i in range(8)])
print("  Padded Layout各行起始:", [padded(i, 0) for i in range(8)])

print("\nPadding的作用: 打破bank conflict模式")
print("GPU共享内存有32个bank（每个bank宽4字节）；同一warp内的线程访问同一bank中的不同地址时会产生冲突（访问同一地址则是广播，不冲突）")
print("通过padding改变stride，使得不同线程访问不同bank")


## 5. Bank Conflict可视化


In [None]:
def visualize_bank_conflict(layout: Layout, title: str, *, num_banks: int = 32, element_bytes: int = 4):
    """Plot the shared-memory bank of every element of a rank-2 layout.

    Banks are modeled as 4 bytes wide, so an element's bank is
    ``(offset * element_bytes // 4) % num_banks``. With the defaults
    (32 banks, 4-byte elements) that reduces to ``offset % 32``.
    After plotting, prints the banks hit by each of the first (up to) four
    columns together with how many distinct banks each column touches.

    Args:
        layout: Rank-2 layout to inspect.
        title: Plot title and heading of the printed report.
        num_banks: Number of shared-memory banks (keyword-only).
        element_bytes: Size of one element in bytes (keyword-only).

    Raises:
        ValueError: If ``layout`` is not rank-2.
    """
    if len(layout.shape) != 2:
        raise ValueError(f"visualize_bank_conflict expects a rank-2 layout, got {layout}")
    M, N = layout.shape
    bank_width_bytes = 4

    banks = np.zeros((M, N), dtype=int)
    for i in range(M):
        for j in range(N):
            word_index = layout(i, j) * element_bytes // bank_width_bytes
            banks[i, j] = word_index % num_banks

    if element_bytes == bank_width_bytes:
        formula = f"offset % {num_banks}"
    else:
        formula = f"offset * {element_bytes} // {bank_width_bytes} % {num_banks}"

    fig, ax = plt.subplots(figsize=(12, 6))

    # The color scale spans every bank ID 0..num_banks-1, so no bank is clipped.
    im = ax.imshow(banks, cmap='turbo', vmin=0, vmax=num_banks - 1)
    ax.set_title(f'{title}\n每个元素的bank编号 ({formula})')
    ax.set_xlabel('列')
    ax.set_ylabel('行')

    for i in range(M):
        for j in range(N):
            ax.text(j, i, f'{banks[i, j]}', ha='center', va='center',
                    fontsize=8, color='white')

    fig.colorbar(im, ax=ax, label='Bank ID')
    fig.tight_layout()
    plt.show()

    # Per-column bank report; .tolist() yields plain ints (NumPy >= 2 would
    # otherwise print each entry as np.int64(...)).
    print(f"\n{title} - 列方向bank分布分析:")
    for col in range(min(4, N)):
        col_banks = banks[:, col].tolist()
        suffix = "..." if M > 8 else ""
        print(f"  列{col}: banks={col_banks[:8]}{suffix} 唯一bank数={len(set(col_banks))}")

# Compare a contiguous 8x8 tile (possible bank conflicts) with a padded one.
separator = "=" * 60
bank_demos = (
    ("连续Layout (stride=8,1):", (8, 1), "连续Layout"),
    ("Padded Layout (stride=9,1):", (9, 1), "Padded Layout"),
)
for demo_idx, (heading, row_stride, plot_title) in enumerate(bank_demos):
    print(separator if demo_idx == 0 else "\n" + separator)
    print(heading)
    visualize_bank_conflict(make_layout((8, 8), row_stride), plot_title)


## 6. FlashAttention中的Layout示例


In [None]:
# Typical layout configuration in FlashAttention (illustrative values).

print("FlashAttention中的典型Layout配置")
print("=" * 60)

# Tile sizes and dimensions
kBlockM = 128   # rows of Q per tile
kBlockN = 64    # rows of K/V per tile
kHeadDim = 64   # head dimension
kPadding = 8    # extra elements appended to every shared-memory row in this demo
seqlen = 2048   # example sequence length for the global-memory view

print("\n1. Q矩阵全局内存Layout:")
print("   shape = (seqlen, headdim)")
print("   stride = (headdim, 1)  # 行优先")
gmem_Q = make_layout((seqlen, kHeadDim), (kHeadDim, 1))
print(f"   示例: {gmem_Q}")

# NOTE: row padding is used here only for illustration. The FlashAttention-2
# CUDA kernels build their shared-memory layouts by composing a CuTe Swizzle
# with the tile layout instead of padding rows.
print("\n2. Q块共享内存Layout (带padding):")
print(f"   shape = ({kBlockM}, {kHeadDim})")
print(f"   stride = ({kHeadDim + kPadding}, 1)")
smem_Q = make_layout((kBlockM, kHeadDim), (kHeadDim + kPadding, 1))
print(f"   {smem_Q}")
print(f"   size={smem_Q.size()}, cosize={smem_Q.cosize()}")
print(f"   内存使用率: {smem_Q.size() / smem_Q.cosize() * 100:.1f}%")

print("\n3. K/V矩阵Layout:")
smem_K = make_layout((kBlockN, kHeadDim), (kHeadDim + kPadding, 1))
smem_V = smem_K  # V tile shares K's layout in this example
print(f"   K: {smem_K}")
print(f"   V: {smem_V}  # 与K相同")


## 总结

通过本notebook，你应该理解了：

1. **Layout基础**: Layout = (Shape, Stride)，定义逻辑坐标到物理偏移的映射：offset = Σ coord_i × stride_i
2. **行优先vs列优先**: 通过不同的Stride实现
3. **转置操作**: 只需交换Shape和Stride，不移动数据
4. **Padding**: 让行stride大于每行的元素个数（例如shape=(8,8)、stride=(10,1)），从而打乱bank映射、避免bank conflict
5. **FlashAttention应用**: 本例用带padding的共享内存Layout做演示；实际的FlashAttention-2实现使用Swizzle布局来避免bank conflict

## 下一步

在下一节"TiledMMA矩阵乘法"中，我们将学习如何封装Tensor Core操作。
