# Tensor张量抽象 - 实践篇

本notebook通过Python模拟和代码分析帮助你理解cute Tensor的核心概念。

**学习目标：**
- 理解Tensor = Engine（数据存储，例如指针）+ Layout 的结构
- 模拟cute的索引计算过程
- 分析FlashAttention中的Tensor使用


## 环境准备


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Union

# Font setup for the Chinese plot labels: try SimHei first, then DejaVu Sans.
# axes.unicode_minus is switched off alongside it (presumably because the
# CJK font lacks the Unicode minus glyph - TODO confirm).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

print("环境准备完成！")


## 1. 模拟cute Layout

Layout是cute的基础概念，它定义了逻辑坐标到物理偏移的映射。


In [None]:
class Layout:
    """
    Minimal Python model of a cute Layout.

    A layout is a (shape, stride) pair. Calling it with a logical coordinate
    returns the linear offset ``sum(coord[i] * stride[i])``.
    """

    def __init__(self, shape: Tuple[int, ...], stride: Tuple[int, ...] = None):
        """
        Args:
            shape: Extent of each logical dimension.
            stride: Step (in elements) of each dimension. When omitted,
                row-major strides are derived from ``shape``.

        Raises:
            ValueError: If ``stride`` is given with a different number of
                entries than ``shape``.
        """
        self.shape = shape
        if stride is None:
            self.stride = self._compute_row_major_stride(shape)
        else:
            # __call__ pairs coordinates with strides via zip(), which would
            # silently drop unmatched entries, so reject a rank mismatch here.
            if len(stride) != len(shape):
                raise ValueError(f"stride维度不匹配: {len(stride)} vs {len(shape)}")
            self.stride = stride

    def _compute_row_major_stride(self, shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return row-major strides for ``shape`` (last dimension has stride 1).

        The result always has one entry per dimension, including the empty
        tuple for a rank-0 shape.
        """
        strides = []
        running = 1
        # Walk from the innermost dimension outwards; each stride is the
        # product of all extents to its right.
        for extent in reversed(shape):
            strides.append(running)
            running *= extent
        return tuple(reversed(strides))

    def __call__(self, *coords) -> int:
        """Map a logical coordinate (one value per dimension) to a linear offset."""
        assert len(coords) == len(self.shape), f"坐标维度不匹配: {len(coords)} vs {len(self.shape)}"
        return sum(c * s for c, s in zip(coords, self.stride))

    def size(self) -> int:
        """Return the number of logical elements (product of the shape extents)."""
        result = 1
        for extent in self.shape:
            result *= extent
        return result

    def __repr__(self):
        return f"Layout(shape={self.shape}, stride={self.stride})"

# Exercise Layout: a 4x3 shape with no explicit stride (row-major default)
layout = Layout((4, 3))
print(f"行优先 4x3 矩阵: {layout}")
print(f"元素(2,1)的偏移: {layout(2, 1)}")
print(f"总大小: {layout.size()}")


## 2. 模拟cute Tensor

Tensor = Engine (数据存储) + Layout (索引映射)


In [None]:
class Tensor:
    """
    Minimal Python model of a cute Tensor.

    Tensor = Engine (data storage) + Layout (index mapping). Indexing with a
    tuple of coordinates goes through the layout; indexing with a single
    integer addresses the linear storage directly.
    """

    def __init__(self, data: np.ndarray, layout: "Layout"):
        """
        Args:
            data: Array whose elements back the tensor.
            layout: Callable mapping coordinates to a linear offset; must
                expose a ``shape`` attribute.
        """
        # flatten() returns a 1-D copy that stands in for linear memory, so
        # writes through this Tensor do not modify the caller's array.
        self.data = data.flatten()
        self.layout = layout

    @staticmethod
    def _is_flat_index(coords) -> bool:
        """Tell a single linear index apart from a coordinate tuple.

        NumPy integer scalars (e.g. values produced by np.arange) are not
        instances of ``int``, so they are accepted explicitly.
        """
        return isinstance(coords, (int, np.integer))

    def __getitem__(self, coords) -> float:
        """Read one element by coordinate tuple or by linear index."""
        if self._is_flat_index(coords):
            return self.data[coords]
        return self.data[self.layout(*coords)]

    def __setitem__(self, coords, value):
        """Write one element by coordinate tuple or by linear index."""
        if self._is_flat_index(coords):
            self.data[coords] = value
        else:
            self.data[self.layout(*coords)] = value

    @property
    def shape(self):
        """Logical shape, taken from the layout."""
        return self.layout.shape

    def __repr__(self):
        return f"Tensor(shape={self.shape}, layout={self.layout})"

def make_tensor(data: np.ndarray, shape: Tuple[int, ...], stride: Tuple[int, ...] = None) -> Tensor:
    """Build a Tensor over ``data``, in the spirit of cute's make_tensor.

    ``shape`` and ``stride`` are wrapped in a Layout; ``stride`` is forwarded
    unchanged, so passing None leaves the stride choice to Layout.
    """
    return Tensor(data, Layout(shape, stride))

# Create test data: the values 0..11 as float32
data = np.arange(12, dtype=np.float32)
print(f"原始数据: {data}")

# Create a 4x3 Tensor; no stride is passed, so make_tensor's default (None) is used
tensor = make_tensor(data, (4, 3))
print(f"\nTensor: {tensor}")
print(f"tensor[2, 1] = {tensor[2, 1]}")
print(f"tensor[0, 0] = {tensor[0, 0]}")
print(f"tensor[3, 2] = {tensor[3, 2]}")


In [None]:
# Compare row-major and column-major layouts over the same 12 values
data = np.arange(12, dtype=np.float32)

# Row-major: stride = (3, 1)
row_major = make_tensor(data, (4, 3), stride=(3, 1))

# Column-major: stride = (1, 4)
col_major = make_tensor(data, (4, 3), stride=(1, 4))

print("线性内存: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]")
print()

# Print each tensor row by row as seen through its layout
print("行优先 (stride=(3,1)):")
print("  逻辑视图:")
for i in range(4):
    row = [f"{row_major[i, j]:.0f}" for j in range(3)]
    print(f"    [{', '.join(row)}]")

print("\n列优先 (stride=(1,4)):")
print("  逻辑视图:")
for i in range(4):
    row = [f"{col_major[i, j]:.0f}" for j in range(3)]
    print(f"    [{', '.join(row)}]")

# Same logical coordinate (1, 0), different offset under each layout
print("\n注意: 相同的逻辑坐标在不同Layout下对应不同的物理偏移")
print(f"  row_major[1,0] = {row_major[1, 0]} (offset={row_major.layout(1, 0)})")
print(f"  col_major[1,0] = {col_major[1, 0]} (offset={col_major.layout(1, 0)})")


## 4. 可视化Layout映射


In [None]:
def visualize_layout(layout: Layout, title: str):
    """
    Plot a 2-D layout's logical-to-physical mapping.

    Left panel: the logical (row, column) grid, each cell coloured and
    labelled with the offset ``layout(i, j)``. Right panel: a single strip of
    ``layout.size()`` cells numbered from 0, standing in for linear memory.

    Args:
        layout: Layout whose shape unpacks into exactly two extents.
        title: Figure title; the layout's shape and stride are appended.
    """
    n_rows, n_cols = layout.shape

    _, (grid_ax, strip_ax) = plt.subplots(1, 2, figsize=(14, 5))
    centered_white = dict(ha='center', va='center', color='white')

    # Left panel: tabulate the offset of every logical coordinate.
    offsets = np.zeros((n_rows, n_cols))
    for i, j in np.ndindex(n_rows, n_cols):
        offsets[i, j] = layout(i, j)

    grid_image = grid_ax.imshow(offsets, cmap='viridis')
    grid_ax.set_title(f'逻辑视图 ({n_rows}x{n_cols})')
    grid_ax.set_xlabel('列索引 j')
    grid_ax.set_ylabel('行索引 i')

    # Label each cell with its offset; x is the column index, y the row index.
    for i, j in np.ndindex(n_rows, n_cols):
        grid_ax.text(j, i, str(int(offsets[i, j])), fontweight='bold', **centered_white)

    plt.colorbar(grid_image, ax=grid_ax, label='物理偏移')

    # Right panel: one strip with a cell per element, numbered by offset.
    n_cells = layout.size()
    strip = np.arange(n_cells).reshape(1, -1)

    strip_ax.imshow(strip, cmap='viridis', aspect='auto')
    strip_ax.set_title('线性内存')
    strip_ax.set_xlabel('物理偏移')
    strip_ax.set_yticks([])

    for cell in range(n_cells):
        strip_ax.text(cell, 0, str(cell), fontsize=8, **centered_white)

    plt.suptitle(f'{title}\nshape={layout.shape}, stride={layout.stride}', fontsize=12)
    plt.tight_layout()
    plt.show()

# Visualize the row-major layout
visualize_layout(Layout((4, 3), (3, 1)), "行优先 (Row-major)")

# Visualize the column-major layout
visualize_layout(Layout((4, 3), (1, 4)), "列优先 (Column-major)")


## 5. 模拟local_tile操作

local_tile从大Tensor中获取一个小的分块视图，不拷贝数据。


In [None]:
class TensorView:
    """
    Non-copying view into a Tensor, modelling the tile returned by local_tile.

    Element accesses are redirected to the parent tensor's storage at
    ``offset + layout(coords)``; a single integer is treated as a linear
    index relative to ``offset``.
    """

    def __init__(self, tensor: "Tensor", offset: int, layout: "Layout"):
        """
        Args:
            tensor: Parent object whose ``data`` attribute holds the storage.
            offset: Start of the view inside the parent's storage.
            layout: Callable mapping view-local coordinates to an offset
                relative to ``offset``; must expose a ``shape`` attribute.
        """
        self.tensor = tensor
        self.offset = offset
        self.layout = layout

    def _local_offset(self, coords):
        """Return the offset of ``coords`` relative to the start of the view.

        NumPy integer scalars are accepted as linear indices alongside
        ``int``; they are not instances of ``int`` and cannot be unpacked
        as a coordinate tuple.
        """
        if isinstance(coords, (int, np.integer)):
            return coords
        return self.layout(*coords)

    def __getitem__(self, coords) -> float:
        """Read one element of the view from the parent's storage."""
        return self.tensor.data[self.offset + self._local_offset(coords)]

    def __setitem__(self, coords, value):
        """Write one element of the view through to the parent's storage."""
        self.tensor.data[self.offset + self._local_offset(coords)] = value

    @property
    def shape(self):
        """Logical shape of the view, taken from its layout."""
        return self.layout.shape

    def __repr__(self):
        return f"TensorView(shape={self.shape}, offset={self.offset}, layout={self.layout})"

def local_tile(tensor: Tensor, tile_shape: Tuple[int, ...], tile_coord: Tuple[int, ...]) -> TensorView:
    """
    Return one tile of ``tensor`` as a view, in the spirit of cute's local_tile.

    Args:
        tensor: Tensor to take the tile from.
        tile_shape: Extent of the tile in each dimension.
        tile_coord: Index of the tile in each dimension, counted in tiles.

    Returns:
        A TensorView that starts at the tile's first element and reuses the
        parent's strides.
    """
    # Logical coordinate of the tile's first element: tile index * tile extent.
    origin = tuple(index * extent for index, extent in zip(tile_coord, tile_shape))

    # The tile layout pairs the tile's shape with the parent's strides, so
    # tile-local coordinates step through the parent's storage.
    return TensorView(
        tensor,
        tensor.layout(*origin),
        Layout(tile_shape, tensor.layout.stride),
    )

# Create an 8x6 Tensor holding the values 0..47
data = np.arange(48, dtype=np.float32)
tensor = make_tensor(data, (8, 6))

print(f"原始Tensor: shape={tensor.shape}")
print(f"内容预览:")
for i in range(8):
    row = [f"{tensor[i, j]:2.0f}" for j in range(6)]
    print(f"  [{', '.join(row)}]")

# Take two 4x3 tiles: tile coordinate (0, 0) and tile coordinate (1, 1)
tile_00 = local_tile(tensor, (4, 3), (0, 0))
tile_11 = local_tile(tensor, (4, 3), (1, 1))

# Print each tile row by row, read through its view
print(f"\ntile(0,0) at offset {tile_00.offset}:")
for i in range(4):
    row = [f"{tile_00[i, j]:2.0f}" for j in range(3)]
    print(f"  [{', '.join(row)}]")

print(f"\ntile(1,1) at offset {tile_11.offset}:")
for i in range(4):
    row = [f"{tile_11[i, j]:2.0f}" for j in range(3)]
    print(f"  [{', '.join(row)}]")


## 6. FlashAttention中的Tensor使用分析


In [None]:
# Walk-through of how FlashAttention uses Tensors.
# The C++ excerpt below is kept as a string and only printed; nothing in it is executed.

flashattention_code = '''
// ============================================================
// FlashAttention中的Tensor创建和使用模式
// 来自 hopper/mainloop_fwd_sm80.hpp
// ============================================================

// 1. 创建全局内存Tensor
Tensor mQ = make_tensor(
    make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)),
    make_shape(params.seqlen_q, params.d),     // [N, d]
    make_stride(params.q_row_stride, Int<1>{}) // 行优先
);

// 2. 使用local_tile获取当前block处理的分块
Tensor gQ = local_tile(
    mQ,                               // 原始Tensor
    TileShape_MK{},                   // tile大小，如(128, 64)
    make_coord(m_block, 0)            // tile坐标
);

// 3. 创建共享内存Tensor
extern __shared__ char smem_[];
Tensor sQ = make_tensor(
    make_smem_ptr(reinterpret_cast<Element*>(smem_)),
    SmemLayoutQ{}  // 预定义的共享内存Layout
);

// 4. 使用TiledCopy进行分区和拷贝
auto gmem_thr_copy = gmem_tiled_copy.get_thread_slice(threadIdx.x);
Tensor tQgQ = gmem_thr_copy.partition_S(gQ);  // 源分区
Tensor tQsQ = gmem_thr_copy.partition_D(sQ);  // 目标分区
cute::copy(gmem_tiled_copy, tQgQ, tQsQ);

// 5. 创建Fragment并执行MMA
Tensor acc = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc);
cute::gemm(tiled_mma, tSrQ, tSrK, acc);
'''

print(flashattention_code)


## 7. 模拟FlashAttention的分块策略


In [None]:
def simulate_flashattention_tiling(seqlen, headdim, block_m, block_n):
    """模拟FlashAttention的分块策略"""
    print(f"=" * 60)
    print(f"FlashAttention分块模拟")
    print(f"=" * 60)
    print(f"序列长度: {seqlen}")
    print(f"头维度: {headdim}")
    print(f"Q块大小 (kBlockM): {block_m}")
    print(f"K/V块大小 (kBlockN): {block_n}")
    print()
    
    # 计算分块数
    num_q_blocks = (seqlen + block_m - 1) // block_m
    num_kv_blocks = (seqlen + block_n - 1) // block_n
    
    print(f"Q分块数: {num_q_blocks}")
    print(f"K/V分块数: {num_kv_blocks}")
    print()
    
    # 模拟处理流程
    print("处理流程:")
    print("-" * 60)
    
    for m_block in range(min(num_q_blocks, 2)):
        q_start = m_block * block_m
        q_end = min((m_block + 1) * block_m, seqlen)
        
        print(f"\nm_block={m_block}: 处理 Q[{q_start}:{q_end}, :]")
        print(f"  gQ = local_tile(mQ, ({block_m}, {headdim}), ({m_block}, 0))")
        
        for n_block in range(min(num_kv_blocks, 2)):
            k_start = n_block * block_n
            k_end = min((n_block + 1) * block_n, seqlen)
            
            print(f"    n_block={n_block}: K/V[{k_start}:{k_end}, :]")
            print(f"      S_block = Q_block @ K_block^T")
            print(f"      O_block += softmax(S_block) @ V_block")
        
        if num_kv_blocks > 2:
            print(f"    ... 还有 {num_kv_blocks - 2} 个K/V块")
    
    if num_q_blocks > 2:
        print(f"\n... 还有 {num_q_blocks - 2} 个Q块")

# Run the simulation for a typical configuration
simulate_flashattention_tiling(seqlen=2048, headdim=64, block_m=128, block_n=64)


## 总结

通过本notebook，你应该理解了：

1. **Tensor结构**: Tensor = Engine (数据) + Layout (索引映射)
2. **指针类型**: `make_gmem_ptr()` 全局内存, `make_smem_ptr()` 共享内存
3. **Layout与Stride**: 对于形状为 (M, N) 的矩阵，行优先 stride=(N,1)，列优先 stride=(1,M)
4. **local_tile操作**: 从大Tensor获取分块视图，不拷贝数据
5. **FlashAttention应用**: Q/K/V创建为gmem Tensor，分块处理

## 下一步

在下一节"Layout内存布局"中，我们将深入学习Layout的高级特性。
