In [None]:
import tvm

import tvm.te as te

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


def showmod(mod: tvm.ir.module.IRModule):
    """Display ``mod`` through its ``show`` method with a fixed option set.

    Parameters
    ----------
    mod : tvm.ir.module.IRModule
        The module to display.
    """
    # Printer options are collected in one place so they are easy to tweak.
    printer_options = {
        "black_format": True,
        "show_meta": False,
        "verbose_expr": True,
        "show_object_address": False,
        "show_all_struct_info": True,
    }
    mod.show(**printer_options)


def createandshowmod(ops):
    """Turn TE operations into a PrimFunc named "test" and display its module.

    Parameters
    ----------
    ops
        Forwarded unchanged to ``te.create_prim_func``.
    """
    prim_func = te.create_prim_func(ops)
    # Attach the symbol name that is also used as the key in the module below.
    prim_func = prim_func.with_attrs({"global_symbol": "test"})
    showmod(tvm.IRModule({"test": prim_func}))


from tvm.tir.buffer import *

#### decl_buffer

In [None]:
"""Declare a new symbolic buffer.

Normally buffer is created automatically during lower and build.
This is only needed if the user wants to specify their own buffer layout.

See the note below for detailed discussion on usage of buffer.

Parameters
----------
shape : tuple of Expr
    The shape of the buffer.

dtype : str, optional
    The data type of the buffer.

name : str, optional
    The name of the buffer.

data : tir.Var, optional
    The data pointer in the buffer.

strides: array of Expr
    The stride of the buffer.

elem_offset: Expr, optional
    The beginning offset of the array to data.
    In terms of number of elements of dtype.

scope: str, optional
    The storage scope of the buffer, if not global.
    If scope equals empty string, it means it is global memory.

data_alignment: int, optional
    The alignment of data pointer in bytes.
    If -1 is passed, the alignment will be set to TVM's internal default.

offset_factor: int, optional
    The factor of elem_offset field, when set,
    elem_offset is required to be multiple of offset_factor.
    If 0 is passed, the alignment will be set to 1.
    If non-zero is passed, we will create a Var for elem_offset if elem_offset is not None.

buffer_type: str, optional, {"", "auto_broadcast"}
    auto_broadcast buffer allows one to implement broadcast computation
    without considering whether dimension size equals to one.
    TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.

axis_separators : list of int, optional
    If passed, a list of separators between groups of axes,
    each of which is flattened to an output axis.  For flat
    memory spaces, should either be None, or an empty list.

span: Optional[Span]
    The location of the decl_buffer creation in the source.

Returns
-------
buffer : tvm.tir.Buffer
    The created buffer

Note
----
Buffer data structure reflects the DLTensor structure in dlpack.
While DLTensor data structure is very general, it is usually helpful
to create function that only handles specific case of data structure
and make compiled function benefit from it.

If the user passes strides and elem_offset as None
when constructing the function, then the function will be specialized
for DLTensors that are compact and aligned.
If the user passes a fully generic symbolic array as the strides,
then the resulting function becomes fully generic.
"""

import tvm.script
import tvm.script.ir_builder


# Demo: hand-build a PrimFunc over a 2-D buffer whose shape (M, N) is symbolic,
# then print the module after Simplify and again after VectorizeLoop + UnrollLoop.

# Symbolic extents of the two buffer dimensions.
M, N = tvm.tir.Var("M", "int32"), tvm.tir.Var("N", "int32")

# Arguments for decl_buffer, spelled out one by one so each can be tweaked.
shape = (M, N)
dtype = "float32"
name = "buffer"
data = None
# Explicit strides: N elements between rows, 1 element between columns.
strides = [N, 1]
# The element offset is supplied as an explicit variable; it is also listed
# as a PrimFunc parameter further down.
elem_offset = tvm.tir.Var("elem_offset", dtype="int32")
scope = "global"
data_alignment = 4
offset_factor = 4
buffer_type = "auto_broadcast"
axis_separators = []

buffer = decl_buffer(
    shape=shape,
    dtype=dtype,
    name=name,
    data=data,
    strides=strides,
    elem_offset=elem_offset,
    scope=scope,
    data_alignment=data_alignment,
    offset_factor=offset_factor,
    buffer_type=buffer_type,
    axis_separators=axis_separators,
)

# Loop variables for the two initialisation loops.
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")

# First column: buffer[i, 0] = 1.0 for i starting at 0 with extent M.
initial_loop_M = tvm.tir.For(
    loop_var=i,
    min=0,
    extent=M,
    kind=tvm.tir.ForKind.PARALLEL,
    body=tvm.tir.BufferStore(buffer, 1.0, [i, 0]),
)
# First row without element [0, 0]: buffer[0, j] = 1.0 for j starting at 1
# with extent N - 1.
initial_loop_N = tvm.tir.For(
    loop_var=j,
    min=1,
    extent=N - 1,
    kind=tvm.tir.ForKind.PARALLEL,
    body=tvm.tir.BufferStore(buffer, 1.0, [0, j]),
)

# Loop variables for the nested recurrence loops.
m = tvm.tir.Var("m", "int32")
n = tvm.tir.Var("n", "int32")

# Recurrence body: buffer[m, n] = buffer[m - 1, n] + buffer[m, n - 1].
# NOTE(review): each iteration reads elements stored by other iterations of
# these same loops (row m - 1 and column n - 1), yet both loops are marked
# PARALLEL - confirm this is only meant to exercise the passes, not to be run.
inner_loop = tvm.tir.For(
    loop_var=n,
    min=1,
    extent=N - 1,
    kind=tvm.tir.ForKind.PARALLEL,  # serial/parallel/vectorized/unrolled
    body=tvm.tir.BufferStore(
        buffer,
        tvm.tir.BufferLoad(buffer, [m - 1, n]) + tvm.tir.BufferLoad(buffer, [m, n - 1]),
        [m, n],
    ),
)
outer_loop = tvm.tir.For(
    loop_var=m,
    min=1,
    extent=M - 1,
    kind=tvm.tir.ForKind.PARALLEL,
    body=inner_loop,
)

# The symbolic extents and the element offset are passed as explicit
# parameters next to the buffer itself; the body runs the three loops in order.
func = tvm.tir.PrimFunc(
    params=[M, N, elem_offset, buffer],
    body=tvm.tir.SeqStmt([initial_loop_M, initial_loop_N, outer_loop]),
    ret_type=None,
)
ir_module = tvm.IRModule({"main": func})
# showmod(ir_module)

# Stage 1: run Simplify on its own and print the result.
ir_module = tvm.transform.Sequential([tvm.tir.transform.Simplify()])(ir_module)
showmod(ir_module)

# Stage 2: run VectorizeLoop followed by UnrollLoop on the simplified module.
# None of the loops above is marked VECTORIZED or UNROLLED, so presumably the
# printout matches stage 1 - compare the two outputs to confirm.
seq = tvm.transform.Sequential(
    [
        tvm.tir.transform.VectorizeLoop(),
        tvm.tir.transform.UnrollLoop(),
    ]
)

ir_module = seq(ir_module)
showmod(ir_module)

In [None]:
# Demo: the same three-loop program over a buffer with constant shape (4, 256).
# The column loops start at 0 and are marked VECTORIZED, the row loop of the
# recurrence is marked UNROLLED. The module is printed before any pass, after
# Simplify + VectorizeLoop + UnrollLoop, and after FlattenBuffer.
M, N = 4, 256

# Arguments for decl_buffer, spelled out one by one so each can be tweaked.
shape = (M, N)
dtype = "float32"
name = "buffer"
data = None
# Explicit strides: N elements between rows, 1 element between columns.
strides = [N, 1]
elem_offset = tvm.tir.Var("elem_offset", dtype="int32")
scope = "global"
data_alignment = 4
offset_factor = 4
buffer_type = "auto_broadcast"
axis_separators = []

buffer = decl_buffer(
    shape=shape,
    dtype=dtype,
    name=name,
    data=data,
    strides=strides,
    elem_offset=elem_offset,
    scope=scope,
    data_alignment=data_alignment,
    offset_factor=offset_factor,
    buffer_type=buffer_type,
    axis_separators=axis_separators,
)

# Loop variables for the two initialisation loops.
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")

# First column: buffer[i, 0] = 1.0 for i starting at 0 with extent M (= 4).
# NOTE(review): this loop is VECTORIZED while i indexes the FIRST buffer
# dimension; the remark on outer_loop below says only the last index can be
# used in vectorized loops - confirm VectorizeLoop accepts this store.
initial_loop_M = tvm.tir.For(
    loop_var=i,
    min=0,
    extent=M,
    kind=tvm.tir.ForKind.VECTORIZED,
    body=tvm.tir.BufferStore(buffer, 1.0, [i, 0]),
)
# First row without element [0, 0]: the loop starts at 0 and the store index
# is shifted to j + 1, so columns 1 .. N - 1 are written.
initial_loop_N = tvm.tir.For(
    loop_var=j,
    # TODO: a pass seems to be missing here - can loops that are vectorizable
    # but have a non-zero min be recognised automatically?
    min=0,  # change to 0 for vectorized
    extent=N - 1,
    kind=tvm.tir.ForKind.VECTORIZED,
    body=tvm.tir.BufferStore(buffer, 1.0, [0, j + 1]),
)

# Loop variables for the nested recurrence loops.
m = tvm.tir.Var("m", "int32")
n = tvm.tir.Var("n", "int32")

# Recurrence body with the column index shifted by one:
# buffer[m, n + 1] = buffer[m - 1, n + 1] + buffer[m, n].
# NOTE(review): buffer[m, n] is the element stored by the previous iteration
# of this same VECTORIZED loop - confirm the dependency is acceptable for
# what this demo is meant to show.
inner_loop = tvm.tir.For(
    loop_var=n,
    min=0,  # change to 0 for vectorized
    extent=N - 1,
    kind=tvm.tir.ForKind.VECTORIZED,  # serial/parallel/vectorized/unrolled
    body=tvm.tir.BufferStore(
        buffer,
        tvm.tir.BufferLoad(buffer, [m - 1, n + 1]) + tvm.tir.BufferLoad(buffer, [m, n]),
        [m, n + 1],
    ),
)
outer_loop = tvm.tir.For(
    loop_var=m,
    min=1,  # Only the last index of a buffer can be used in vectorized loops.
    extent=M - 1,
    # UnrollLoop usually takes effect only when one of the following holds:
    # - the loop bounds are compile-time constants and small (the default threshold is 16)
    # - the loop is explicitly marked kind=UNROLLED
    kind=tvm.tir.ForKind.UNROLLED,
    body=inner_loop,
)

# M and N are Python constants in this cell, so only the element offset and
# the buffer are function parameters.
func = tvm.tir.PrimFunc(
    params=[elem_offset, buffer],
    body=tvm.tir.SeqStmt([initial_loop_M, initial_loop_N, outer_loop]),
    ret_type=None,
)
ir_module = tvm.IRModule({"main": func})
# Print the hand-built module before any pass is applied.
showmod(ir_module)

# Stage 1: Simplify, then VectorizeLoop, then UnrollLoop; print the result.
seq = tvm.transform.Sequential(
    [
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.VectorizeLoop(),
        tvm.tir.transform.UnrollLoop(),
    ]
)

ir_module = seq(ir_module)
showmod(ir_module)

# Stage 2: run FlattenBuffer on the output of stage 1 and print the result.
seq = tvm.transform.Sequential(
    [
        tvm.tir.transform.FlattenBuffer(),
    ]
)
ir_module = seq(ir_module)
showmod(ir_module)