In [1]:
import tvm
from tvm import tir
from tvm.script import tir as T
from tvm.tir import TensorIntrin

import numpy as np

from tvm import rpc
from tvm.contrib import utils


# Configuration of the LLVM target used for RISC-V 64 with the vector ("+v")
# extension listed among the enabled attributes.
_riscv_target_config = {
    "kind": "llvm",
    "mtriple": "riscv64-linux-unknown-gnu",
    "mattr": ["+m", "+a", "+f", "+d", "+c", "+v"],
    "mabi": "lp64d",
    "vector-width": 256,
    "cl-opt": ["-riscv-v-register-bit-width-lmul:int=2"],
}
target = tvm.target.Target(_riscv_target_config)


In [11]:
import tvm
from tvm import tir
from tvm.script import tir as T

# Problem sizes: A is (M, K), B is (K, N) and C is (M, N).
M, N, K = 16, 16, 64


# Reference matmul written in TVMScript. The reduction block is named "update"
# because the scheduling code below looks it up by that name.
@T.prim_func
def mm(a_: T.handle, b_: T.handle, c_: T.handle) -> None:
    A = T.match_buffer(a_, (M, K))
    B = T.match_buffer(b_, (K, N))
    C = T.match_buffer(c_, (M, N))
    for i, j, k in T.grid(M, N, K):
        with T.block("update"):
            # i and j are spatial axes, k is the reduction axis.
            vi = T.axis.spatial(M, i)
            vj = T.axis.spatial(N, j)
            vk = T.axis.reduce(K, k)
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

In [7]:
def get_kername(m_size, n_size):
    """Return the kernel name for an m_size x n_size tile, e.g. ``mm_16x16``."""
    return f"mm_{m_size!s}x{n_size!s}"

def init_mm_register(m_size, n_size, k_size):
    """Register a tensor intrinsic that maps a small matmul onto an extern kernel.

    The intrinsic is described by two TVMScript functions:

    * ``desc`` - the computation pattern to match: ``C[i, j]`` is
      zero-initialised and then accumulates ``A[i, k] * B[k, j]`` over ``k``.
    * ``intrin`` - the replacement: a single ``call_extern`` to the kernel
      named by ``get_kername(m_size, n_size)``, passing ``C.data``, ``A.data``,
      ``B.data`` and the reduction length.

    Args:
        m_size: Number of rows of A and C.
        n_size: Number of columns of B and C.
        k_size: Reduction length (columns of A / rows of B). It is forwarded
            to the extern kernel as its last argument.

    Returns:
        The name under which the intrinsic was registered; the same string is
        used as the extern function symbol.

    Note:
        The name only encodes ``m_size`` and ``n_size``, so registering a
        different ``k_size`` for the same tile replaces the earlier entry.
    """
    ker_name = get_kername(m_size, n_size)
    from tvm.script import tir as T

    # Pattern to be matched by tensorize: init block + reduction block.
    @T.prim_func
    def desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (m_size, k_size), align=64, offset_factor=1)
        B = T.match_buffer(b, (k_size, n_size), align=64, offset_factor=1)
        C = T.match_buffer(c, (m_size, n_size), align=64, offset_factor=1)

        with T.block("root"):
            T.reads(A[0 : m_size, 0 : k_size], B[0 : k_size, 0 : n_size])
            T.writes(C[0 : m_size, 0 : n_size])
            with T.init():
                for vi, vj in T.grid(m_size, n_size):
                    with T.block("gemm_init"):
                        i, j = T.axis.remap("SS", [vi, vj])
                        T.reads()
                        T.writes(C[i, j])
                        C[i, j] = T.float32(0.0)
            for vi, vj, vk in T.grid(m_size, n_size, k_size):
                with T.block("gemm"):
                    i, j, k = T.axis.remap("SSR", [vi, vj, vk])
                    T.reads(C[i, j], A[i, k], B[k, j])
                    T.writes(C[i, j])
                    C[i, j] = C[i, j] + A[i, k] * B[k, j]

    # Replacement body: one call into the extern kernel.
    @T.prim_func
    def intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (m_size, k_size), align=64, offset_factor=1)
        B = T.match_buffer(b, (k_size, n_size), align=64, offset_factor=1)
        C = T.match_buffer(c, (m_size, n_size), align=64, offset_factor=1)
        with T.block("root"):
            T.reads(A[0 : m_size, 0 : k_size], B[0 : k_size, 0 : n_size])
            T.writes(C[0 : m_size, 0 : n_size])
            # NOTE(review): raw `.data` pointers are passed, so an element
            # offset of a matched sub-buffer would not be applied - fine for
            # whole-buffer use; confirm before tensorizing tiles of a larger
            # buffer.
            T.evaluate(
                T.call_extern(
                    "void",
                    ker_name,
                    # Pass the actual reduction length rather than a fixed 64
                    # so the kernel stays correct for any k_size.
                    C.data, A.data, B.data, k_size
                )
            )

    # override=True keeps this idempotent: re-running the notebook cell would
    # otherwise fail because the name is already registered.
    tir.TensorIntrin.register(ker_name, desc, intrin, override=True)
    return ker_name

In [8]:
# Register the tensor intrinsic for the M x N x K problem and keep the returned
# name, which is later passed to sch.tensorize.
ker_name = init_mm_register(M, N, K)

In [9]:
# Show the registered intrinsic name (also the extern kernel symbol).
print(ker_name)

mm_16x16


In [None]:
# C source of the extern micro-kernel called by the tensorized block; it is
# attached to the schedule later through the "pragma_import_c" annotation.
# What the kernel text below does:
#   * mm_16x16(out, a, b, K) uses RISC-V vector intrinsics on float32 with
#     LMUL=2 and requests vl = 16 elements per vector operation.
#   * a is read as a[row * K + k], b as the row starting at b + 16 * k, and
#     each result row is stored at out + 16 * row.
#   * Rows 0..14 are accumulated in acc0..acc14 inside one k-loop; row 15 is
#     computed afterwards in a second k-loop that reuses acc0.
#     NOTE(review): presumably split this way to limit vector register
#     pressure - confirm.
#   * NOTE(review): every load/store uses the vl returned by vsetvl and there
#     is no column loop, so the kernel appears to rely on vsetvl granting the
#     full 16 elements - confirm for the target hardware.
# The function name must stay in sync with get_kername(16, 16) ("mm_16x16").
c_src = """ 
#include <riscv_vector.h>
void mm_16x16(float *out, const float *a, const float *b, const int K) { 
    vfloat32m2_t  acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, acc8, acc9, acc10, acc11, acc12, acc13, acc14;
    size_t vl = __riscv_vsetvl_e32m2(16);
    acc0 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc1 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc2 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc3 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc4 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc5 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc6 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc7 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc8 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc9 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc10 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc11 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc12 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc13 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc14 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    for(int k = 0 ; k < K; ++k){
        vfloat32m2_t vb = __riscv_vle32_v_f32m2(b + 16 * k, vl);
        acc0 = __riscv_vfmacc_vf_f32m2(acc0, *(a + k + K * 0), vb, vl);
        acc1 = __riscv_vfmacc_vf_f32m2(acc1, *(a + k + K * 1), vb, vl);
        acc2 = __riscv_vfmacc_vf_f32m2(acc2, *(a + k + K * 2), vb, vl);
        acc3 = __riscv_vfmacc_vf_f32m2(acc3, *(a + k + K * 3), vb, vl);
        acc4 = __riscv_vfmacc_vf_f32m2(acc4, *(a + k + K * 4), vb, vl);
        acc5 = __riscv_vfmacc_vf_f32m2(acc5, *(a + k + K * 5), vb, vl);
        acc6 = __riscv_vfmacc_vf_f32m2(acc6, *(a + k + K * 6), vb, vl);
        acc7 = __riscv_vfmacc_vf_f32m2(acc7, *(a + k + K * 7), vb, vl);
        acc8 = __riscv_vfmacc_vf_f32m2(acc8, *(a + k + K * 8), vb, vl);
        acc9 = __riscv_vfmacc_vf_f32m2(acc9, *(a + k + K * 9), vb, vl);
        acc10 = __riscv_vfmacc_vf_f32m2(acc10, *(a + k + K * 10), vb, vl);
        acc11 = __riscv_vfmacc_vf_f32m2(acc11, *(a + k + K * 11), vb, vl);
        acc12 = __riscv_vfmacc_vf_f32m2(acc12, *(a + k + K * 12), vb, vl);
        acc13 = __riscv_vfmacc_vf_f32m2(acc13, *(a + k + K * 13), vb, vl);
        acc14 = __riscv_vfmacc_vf_f32m2(acc14, *(a + k + K * 14), vb, vl);
    }
    const int out_strides = 16;
    __riscv_vse32_v_f32m2(out + out_strides * 0, acc0, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 1, acc1, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 2, acc2, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 3, acc3, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 4, acc4, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 5, acc5, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 6, acc6, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 7, acc7, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 8, acc8, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 9, acc9, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 10, acc10, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 11, acc11, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 12, acc12, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 13, acc13, vl);
    __riscv_vse32_v_f32m2(out + out_strides * 14, acc14, vl);

    acc0 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    for(int k = 0 ; k < K; ++k){
        vfloat32m2_t vb = __riscv_vle32_v_f32m2(b + 16 * k, vl);
        acc0 = __riscv_vfmacc_vf_f32m2(acc0, *(a + k + K * 15), vb, vl);
    }
    __riscv_vse32_v_f32m2(out + out_strides * 15, acc0, vl);
}
"""

In [16]:
# Create a schedule for `mm` and replace the loop nest around the "update"
# block with the registered tensor intrinsic.
sch = tir.Schedule(mm)
s_blk = sch.get_block("update")
update_loops = sch.get_loops(s_blk)
loop_x, loop_y, loop_z = update_loops[-3:]
# Tensorize starting at the outermost of the three loops.
sch.tensorize(loop_x, ker_name)

In [17]:
# Print the scheduled function as TVMScript to inspect the tensorized block.
print(sch.mod["main"].script())

# from tvm.script import tir as T

@T.prim_func
def mm(A: T.Buffer((16, 64), "float32"), B: T.Buffer((64, 16), "float32"), C: T.Buffer((16, 16), "float32")):
    # with T.block("root"):
    with T.block("update_o"):
        vi_o = T.axis.spatial(1, 0)
        vj_o = T.axis.spatial(1, 0)
        vk_o = T.axis.reduce(1, 0)
        T.reads(A[0:16, 0:64], B[0:64, 0:16])
        T.writes(C[0:16, 0:16])
        A_1 = T.match_buffer(A[0:16, 0:64], (16, 64), offset_factor=1)
        B_1 = T.match_buffer(B[0:64, 0:16], (64, 16), offset_factor=1)
        C_1 = T.match_buffer(C[0:16, 0:16], (16, 16), offset_factor=1)
        with T.init():
            for i, j in T.grid(16, 16):
                with T.block("update_init"):
                    vi_i_init, vj_i_init = T.axis.remap("SS", [i, j])
                    T.reads()
                    T.writes(C[vi_i_init, vj_i_init])
                    C[vi_i_init, vj_i_init] = T.float32(0.0)
        T.call_extern("void", "mm_16x16", C_1.data, A_1.data, B

In [22]:
# Attach the C kernel source to the tensorized outer block via the
# "pragma_import_c" annotation.
sch.annotate(sch.get_block("update_o"), "pragma_import_c", c_src)

In [23]:
# Print the function again to confirm the block now carries the annotation.
print(sch.mod["main"])

# from tvm.script import tir as T

@T.prim_func
def mm(A: T.Buffer((16, 64), "float32"), B: T.Buffer((64, 16), "float32"), C: T.Buffer((16, 16), "float32")):
    # with T.block("root"):
    with T.block("update_o"):
        vi_o = T.axis.spatial(1, 0)
        vj_o = T.axis.spatial(1, 0)
        vk_o = T.axis.reduce(1, 0)
        T.reads(A[0:16, 0:64], B[0:64, 0:16])
        T.writes(C[0:16, 0:16])
        T.block_attr({"pragma_import_c": metadata["runtime.String"][0]})
        A_1 = T.match_buffer(A[0:16, 0:64], (16, 64), offset_factor=1)
        B_1 = T.match_buffer(B[0:64, 0:16], (64, 16), offset_factor=1)
        C_1 = T.match_buffer(C[0:16, 0:16], (16, 16), offset_factor=1)
        with T.init():
            for i, j in T.grid(16, 16):
                with T.block("update_init"):
                    vi_i_init, vj_i_init = T.axis.remap("SS", [i, j])
                    T.reads()
                    T.writes(C[vi_i_init, vj_i_init])
                    C[vi_i_init, vj_i_init] = T.flo

In [24]:
# Build the scheduled module with the "c" target so the generated source can be
# inspected.
# NOTE(review): the RISC-V `target` defined in the first cell is not used
# here - confirm this is intended.
f = tvm.build(sch.mod, target="c")

In [26]:
# Dump the generated C source of the built module.
print(f.get_source())

// tvm target: c -keys=cpu 
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#include <stdbool.h>
 
void mm_16x16(float *out, const float *a, const float *b, const int K) { 
    vfloat32m2_t  acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, acc8, acc9, acc10, acc11, acc12, acc13, acc14;
    size_t vl = __riscv_vsetvl_e32m2(16);
    acc0 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc1 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc2 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc3 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc4 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc5 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc6 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc7 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc8 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc9 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc10 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc11 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc12 = __riscv_vfmv_v_f_f32m2(0.f, vl);
    acc13 = __riscv_vfmv_v_f_