In [1]:
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

import numpy as np

# Introduction to the TensorSSA in CuTe DSL

This tutorial introduces what is the `TensorSSA` and why we need it. We also give some examples to show how to use `TensorSSA`.

## What is TensorSSA

`TensorSSA` is a Python class that represents a tensor value in Static Single Assignment (SSA) form within the CuTe DSL. You can think of it as a tensor residing in a (simulated) register.

## Why TensorSSA

`TensorSSA` encapsulates the underlying MLIR tensor value into an object that's easier to manipulate in Python. By overloading numerous Python operators (like `+`, `-`, `*`, `/`, `[]`, etc.), it allows users to express tensor computations (primarily element-wise operations and reductions) in a more Pythonic way. These element-wise operations are then translated into optimized vectorization instructions.

It's part of the CuTe DSL, serving as a bridge between the user-described computational logic and the lower-level MLIR IR, particularly for representing and manipulating register-level data.

## When to use TensorSSA

`TensorSSA` is primarily used in the following scenarios:
  在 CuTe DSL 中，TensorSSA 是**“寄存器中的张量”**。它与普通的 Tensor 有本质区别：
  - Tensor：是一个视图 (View)。它包含一个指向内存（全局内存、共享内存或寄存器堆）的指针和一个布局 (Layout)。它描述了数据“在哪里”以及“如何排列”。
  - TensorSSA：是一个值 (Value)。它代表已经被加载到硬件寄存器中的实际数值，采用静态单赋值 (Static Single Assignment, SSA) 形式。它是不可变的。

### Load from memory and store to memory

：a.load() 在底层触发了 MLIR 的 vector.load 指令。TensorSSA 内部包裹了一个 MLIR 的 vector 类型值。这保证了后续的所有运算都是在硬件的向量单元上直接执行，而不是逐元素循环

In [2]:
@cute.jit
def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):
    """
    Load data from memory and store the result to memory.

    :param res: The destination tensor to store the result.
    :param a: The source tensor to be loaded.
    :param b: The source tensor to be loaded.
    """
    a_vec = a.load()   #! def load(self, *, loc=None, ip=None) -> "TensorSSA": ...
                       #* a 是 Tensor (内存)，a_vec 是 TensorSSA (寄存器) 
                       #* TensorSSA 内部包裹了一个 MLIR 的 vector 类型值。这保证了后续的所有运算都是在硬件的向量单元上直接执行，而不是逐元素循环。
    print(f"a_vec: {a_vec}")
    b_vec = b.load()
    print(f"b_vec: {b_vec}") 
    res.store(a_vec + b_vec)
    cute.print_tensor(a_vec + b_vec)

a = np.ones(12).reshape((3, 4)).astype(np.float32)
b = np.ones(12).reshape((3, 4)).astype(np.float32)
c = np.zeros(12).reshape((3, 4)).astype(np.float32)
load_and_store(from_dlpack(c), from_dlpack(a), from_dlpack(b))

a_vec: tensor_value<vector<12xf32> o (3, 4)>
b_vec: tensor_value<vector<12xf32> o (3, 4)>
tensor(raw_ptr(0x00007ffe8f9f44c0: f32, rmem, align<32>) o (3,4):(1,3), data=
       [[ 2.000000,  2.000000,  2.000000,  2.000000, ],
        [ 2.000000,  2.000000,  2.000000,  2.000000, ],
        [ 2.000000,  2.000000,  2.000000,  2.000000, ]])


### Register-Level Tensor Operations

When writing kernel logic, various computations, transformations, slicing, etc., are performed on data loaded into registers.

底层逻辑：它使用了 MLIR 的 vector.extract_strided_slice。由于 CuTe 是列优先 (Column-Major) 而 MLIR 向量通常期望行优先，TensorSSA 会自动处理数据的 Shuffle (洗牌) 操作，确保逻辑上的切片在物理寄存器中正确实现

In [3]:
@cute.jit
def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):
    """
    Apply slice operation on the src tensor and store the result to the dst tensor.

    :param src: The source tensor to be sliced.
    :param dst: The destination tensor to store the result.
    :param indices: The indices to slice the source tensor.
    """
    src_vec = src.load()
    dst_vec = src_vec[indices]
    print(f"{src_vec} -> {dst_vec}")
    if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):
        dst.store(dst_vec)
        cute.print_tensor(dst)
    else:
        dst[0] = dst_vec
        cute.print_tensor(dst)

#   内存 (src)  ──load()──▶  寄存器 (src_vec: TensorSSA)
#                                 │
#                           [indices] 切片
#                                 │
#                                 ▼
#                           dst_vec (TensorSSA 或 标量)
#                                 │
#                 ┌───────────────┴───────────────┐
#                 │                               │
#           TensorSSA?                        标量?
#                 │                               │
#          dst.store(dst_vec)              dst[0] = dst_vec
#                 │                               │
#                 └───────────────┬───────────────┘
#                                 ▼
#                           内存 (dst)



def slice_1():
    src_shape = (4, 2, 3)
    dst_shape = (4, 3)
    indices = (None, 1, None)    #! None 在 CuTe DSL 中等价于通配符 _（保留该维度所有元素）

    a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)
    dst = np.random.randn(*dst_shape).astype(np.float32)
    apply_slice(from_dlpack(a), from_dlpack(dst), indices)

slice_1()


#   ┌─────────────────────────────────────────────────────────────────┐
#   │  NumPy (CPU)                                                    │
#   │  a: shape (4,2,3), values 0-23                                  │
#   └───────────────────────────┬─────────────────────────────────────┘
#                               │ from_dlpack()
#                               ▼
#   ┌─────────────────────────────────────────────────────────────────┐
#   │  CuTe Tensor (cute.Tensor)                                      │
#   │  内存指针 + Layout 元数据                                         │
#   └───────────────────────────┬─────────────────────────────────────┘
#                               │ src.load()
#                               ▼
#   ┌─────────────────────────────────────────────────────────────────┐
#   │  TensorSSA (register-level)                                     │
#   │  src_vec: vector<24xf32> o (4,2,3)                              │
#   └───────────────────────────┬─────────────────────────────────────┘
#                               │ src_vec[(None, 1, None)]
#                               ▼
#   ┌─────────────────────────────────────────────────────────────────┐
#   │  TensorSSA (sliced)                                             │
#   │  dst_vec: vector<12xf32> o (4,3)                                │
#   │  值: [3,4,5, 9,10,11, 15,16,17, 21,22,23]                        │
#   └───────────────────────────┬─────────────────────────────────────┘
#                               │ dst.store(dst_vec)
#                               ▼
#   ┌─────────────────────────────────────────────────────────────────┐
#   │  结果写回内存                                                    │
#   └─────────────────────────────────────────────────────────────────┘

#   切片操作提取的是原张量中 Dim 1 = 1 的那一"层"（即每个 4 个 2×3 子矩阵的第 2 行）

tensor_value<vector<24xf32> o (4, 2, 3)> -> tensor_value<vector<12xf32> o (4, 3)>
tensor(raw_ptr(0x000055fc790cbc40: f32, generic, align<4>) o (4,3):(3,1), data=
       [[ 3.000000,  4.000000,  5.000000, ],
        [ 9.000000,  10.000000,  11.000000, ],
        [ 15.000000,  16.000000,  17.000000, ],
        [ 21.000000,  22.000000,  23.000000, ]])


In [4]:
def slice_2():
    src_shape = (4, 2, 3)
    dst_shape = (1,)
    indices = 10
    a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)
    dst  = np.random.randn(*dst_shape).astype(np.float32)
    apply_slice(from_dlpack(a), from_dlpack(dst), indices)

slice_2()

tensor_value<vector<24xf32> o (4, 2, 3)> -> ?
tensor(raw_ptr(0x000055fc71cef030: f32, generic, align<4>) o (1):(1), data=
       [ 13.000000, ])


#### cute.print_tensor，只能打印 HBM 中的结果

In [21]:
import torch

@cute.jit
def slicing_examples(t: cute.Tensor):
    # Scalar access
    cute.printf("t[1,2] = {}", t[1, 2])

    #! 提取第二列 (shape: (N,)) using (None, row_index)
    row = t[(None, 1)]
    row_frag = cute.make_rmem_tensor(row.layout, row.element_type)
    row_frag.store(row.load())      #! 需要读回HBM才能通过 cute.print_tensor
    # print("Second row:")
    cute.print_tensor(row_frag)

    #!  提取第三行 (shape: (M,)) using (col_index, None)
    col = t[(2, None)]
    col_frag = cute.make_rmem_tensor(col.layout, col.element_type)
    col_frag.store(col.load())
    # print("Third column:")
    cute.print_tensor(col_frag)

    # Printing the first row directly (*t[2] == *t[2, 0])
    cute.printf(
        "t[2] = {} (equivalent to t[{}])",
        t[2],
        cute.make_identity_tensor(t.layout.shape)[2]
    )

# 4x3 example tensor
arr = torch.arange(12, dtype=torch.float32).reshape(4, 3)
slicing_examples(from_dlpack(arr))

t[1,2] = 5.000000
tensor(raw_ptr(0x00007ffe8f9f44e0: f32, rmem, align<32>) o (4):(3), data=
       [ 1.000000, ],
       [ 4.000000, ],
       [ 7.000000, ],
       [ 10.000000, ])
tensor(raw_ptr(0x00007ffe8f9f44c0: f32, rmem, align<32>) o (3):(1), data=
       [ 6.000000, ],
       [ 7.000000, ],
       [ 8.000000, ])
t[2] = 6.000000 (equivalent to t[(2,0)])


In [32]:
@cute.jit
def create_tensor_exam():
    layout = cute.make_layout((4, 3))
    # ptr = cute.make_ptr()
    reg_tensor = cute.make_rmem_tensor(layout, cutlass.BFloat16)
    #reg_tensor.fill(0)
    cute.print_tensor(reg_tensor)
    # return tensor

create_tensor_exam()

tensor(raw_ptr(0x00007ffe8f9f44e0: bf16, rmem, align<32>) o (4,3):(1,4), data=
       [[ 0.000000, -0.000000, -16.000000, ],
        [ 0.000000,  293601280.000000, -81857426827086135296.000000, ],
        [ 186091919409888222206532988439248240640.000000,  34634616274944.000000,  190079603397242969825244409620089274368.000000, ],
        [ 0.000000,  0.000000,  0.000000, ]])


## 算术运算 (Arithmetic) —— 自动向量化与类型提升

  add_res = a_vec + b_vec # 两个寄存器向量相加
  mul_res = a_vec * 2.0   # 向量与标量相乘
  - 含义：TensorSSA 重载了 Python 运算符，使其表现得像 NumPy。
  - 底层逻辑：
    - 自动广播：当 a_vec * 2.0 执行时，TensorSSA 会自动调用 vector.broadcast 将标量 2.0 扩展为与 a_vec 相同形状的向量。
    - 类型提升：它会自动处理不同精度（如 Float16 到 Float32）的转换，生成对应的类型转换 IR 指令。

In [34]:
@cute.jit
def binary_op_2(res: cute.Tensor, a: cute.Tensor, c: cutlass.Constexpr):
    a_vec = a.load()

    add_res = a_vec + c
    cute.print_tensor(add_res)  # prints [3.000000, 3.000000, 3.000000]

    sub_res = a_vec - c
    cute.print_tensor(sub_res)  # prints [-1.000000, -1.000000, -1.000000]

    mul_res = a_vec * c
    cute.print_tensor(mul_res)  # prints [2.000000, 2.000000, 2.000000]

    div_res = a_vec / c
    cute.print_tensor(div_res)  # prints [0.500000, 0.500000, 0.500000]

    floor_div_res = a_vec // c
    cute.print_tensor(floor_div_res)  # prints [0.000000, 0.000000, 0.000000]

    mod_res = a_vec % c
    cute.print_tensor(mod_res)  # prints [1.000000, 1.000000, 1.000000]


a = np.empty((3,), dtype=np.float32)
a.fill(1.0)
c = 2.0
res = np.empty((3,), dtype=np.float32)
binary_op_2(from_dlpack(res), from_dlpack(a), c)

tensor(raw_ptr(0x00007ffe8f9f4420: f32, rmem, align<32>) o (3):(1), data=
       [ 3.000000, ],
       [ 3.000000, ],
       [ 3.000000, ])
tensor(raw_ptr(0x00007ffe8f9f4440: f32, rmem, align<32>) o (3):(1), data=
       [-1.000000, ],
       [-1.000000, ],
       [-1.000000, ])
tensor(raw_ptr(0x00007ffe8f9f4460: f32, rmem, align<32>) o (3):(1), data=
       [ 2.000000, ],
       [ 2.000000, ],
       [ 2.000000, ])
tensor(raw_ptr(0x00007ffe8f9f4480: f32, rmem, align<32>) o (3):(1), data=
       [ 0.500000, ],
       [ 0.500000, ],
       [ 0.500000, ])
tensor(raw_ptr(0x00007ffe8f9f44a0: f32, rmem, align<32>) o (3):(1), data=
       [ 0.000000, ],
       [ 0.000000, ],
       [ 0.000000, ])
tensor(raw_ptr(0x00007ffe8f9f44c0: f32, rmem, align<32>) o (3):(1), data=
       [ 1.000000, ],
       [ 1.000000, ],
       [ 1.000000, ])


In [35]:
@cute.jit
def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):
    a_vec = a.load()
    b_vec = b.load()

    gt_res = a_vec > b_vec
    res.store(gt_res)

    """
    ge_res = a_ >= b_   # [False, True, False]
    lt_res = a_ < b_    # [True, False, True]
    le_res = a_ <= b_   # [True, False, True]
    eq_res = a_ == b_   # [False, False, False]
    """


a = np.array([1, 2, 3], dtype=np.float32)
b = np.array([2, 1, 4], dtype=np.float32)
res = np.empty((3,), dtype=np.bool_)
binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))
print(res)  # prints [False, True, False]

[False  True False]


## SSA reduce

In [39]:
@cute.jit
def ssa_reduce(a: cute.Tensor):
    """
    Apply reduction operation on the src tensor.

    :param src: The source tensor to be reduced.
    """
    a_vec = a.load()
    red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile = 0)
    cute.printf(red_res)

    red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile = (None, 1))
    cute.print_tensor(red_res)

    red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile = (1, None))
    cute.print_tensor(red_res)

a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16)
ssa_reduce(from_dlpack(a))

21.000000
tensor(raw_ptr(0x00007ffe8f9f4460: f16, rmem, align<32>) o (2):(1), data=
       [ 6.000000, ],
       [ 15.000000, ])
tensor(raw_ptr(0x00007ffe8f9f4480: f16, rmem, align<32>) o (3):(1), data=
       [ 5.000000, ],
       [ 7.000000, ],
       [ 9.000000, ])
