diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 4ecac98cde20..0fee976eb130 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -88,7 +88,7 @@
 from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
 from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
 from .op import start_profile_intrinsic, end_profile_intrinsic
-from .op import vscale, get_active_lane_mask, get_vscale_factor
+from .op import vscale, get_active_lane_mask, get_vscale_expr
 from .generic import add, subtract, multiply
 
 from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index c086bebafa4c..22c8e1b9e913 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3370,13 +3370,13 @@ def get_active_lane_mask(dtype, base, limit):
     return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)
 
 
-def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
+def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
     """
     Create a datatype dependent scalable expression.
 
     Parameters
     ----------
-    dtype : tvm.DataType
+    dtype : Union[str, tvm.DataType]
         Element data type.
     min_size : int
         The minimum size of the scalable vector.
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 97d8a304c981..3a3430af514f 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -173,7 +173,7 @@ def _create_ptrue_mask(dtype):
     """
     Creates a mask that enables all lanes of a scalable vector.
     """
-    return T.broadcast(T.bool(True), tir.get_vscale_factor(dtype))
+    return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype))
 
 
 def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
@@ -213,7 +213,7 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
         The SME TensorIntrin that can be used in tensorizing a schedule.
     """
 
-    SVF = tir.get_vscale_factor("float32")
+    SVF = tir.get_vscale_expr("float32")
     SVF2 = 2 * SVF
 
     @T.prim_func
@@ -347,7 +347,7 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
     """
     # pylint: enable=line-too-long
 
-    SVF = tir.get_vscale_factor("float16")
+    SVF = tir.get_vscale_expr("float16")
     SVF2 = 2 * SVF
 
     @T.prim_func
@@ -532,7 +532,7 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
         The SME TensorIntrin that can be used in tensorizing a schedule.
     """
-    SVF = tir.get_vscale_factor("float32")
+    SVF = tir.get_vscale_expr("float32")
     SVF2 = 2 * SVF
     fmopa_intrin = (
         "llvm.aarch64.sme.mopa" if in_dtype == "float32" else "llvm.aarch64.sme.mopa.wide"
     )
@@ -577,7 +577,7 @@ def impl():
             rows_per_iter = 1 if in_dtype == "float32" else 2
             with T.serial(T.ceildiv(K, rows_per_iter)) as k:
                 k_row = k * rows_per_iter
-                in_dtype_svf = tir.get_vscale_factor(in_dtype)
+                in_dtype_svf = tir.get_vscale_expr(in_dtype)
 
                 a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)])
                 b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)])
diff --git a/python/tvm/topi/arm_cpu/dense_alter_op.py b/python/tvm/topi/arm_cpu/dense_alter_op.py
index 398f8398af1c..0ad878b7412e 100644
--- a/python/tvm/topi/arm_cpu/dense_alter_op.py
+++ b/python/tvm/topi/arm_cpu/dense_alter_op.py
@@ -55,7 +55,15 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
         weight_dtype = tinfos[1].dtype
         encoded_weight = inputs[1]
+
+        # For dense the weights (rhs) are provided in transposed format,
+        # i.e. they are of the shape (n, k).
         transpose_b = True
+
+        # The SME schedule expects the rhs to be in the format (k, n). We can do this
+        # transformation at compile time in the case of float32. Note: For the
+        # float16->float32 schedule the transformation currently happens at runtime
+        # with the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic.
         if weight_dtype == "float32":
             encoded_weight = relay.const(encoded_weight.data.numpy().transpose(), weight_dtype)
             transpose_b = False
 
diff --git a/python/tvm/topi/arm_cpu/matmul.py b/python/tvm/topi/arm_cpu/matmul.py
index 42db54137f08..2f09e24c87a2 100644
--- a/python/tvm/topi/arm_cpu/matmul.py
+++ b/python/tvm/topi/arm_cpu/matmul.py
@@ -46,11 +46,11 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra
     if not out_dtype:
         out_dtype = data_a.dtype
 
-    tile_m = 2 * tvm.tir.get_vscale_factor(data_a.dtype)
-    tile_k = tvm.tir.get_vscale_factor(data_a.dtype)
+    tile_m = 2 * tvm.tir.get_vscale_expr(data_a.dtype)
+    tile_k = tvm.tir.get_vscale_expr(data_a.dtype)
     if data_a.dtype == "float32":
         tile_k *= 2
-    tile_n = 2 * tvm.tir.get_vscale_factor(data_a.dtype)
+    tile_n = 2 * tvm.tir.get_vscale_expr(data_a.dtype)
 
     M_padded, pad_M = pad_dim_to_multiple(M, tile_m)
     _, pad_K = pad_dim_to_multiple(K, tile_k)
@@ -140,13 +140,13 @@ def tir_schedule_matmul_sme(sch):
     extent_n = sch.get(n).extent
 
     if in_dtype == "float16":
-        tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype)
-        tile_k = T.cast(tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype)
-        tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype)
+        tile_m = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_m.dtype)
+        tile_k = T.cast(tvm.tir.get_vscale_expr(in_dtype), extent_k.dtype)
+        tile_n = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_n.dtype)
     else:
-        tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype)
-        tile_k = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype)
-        tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype)
+        tile_m = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_m.dtype)
+        tile_k = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_k.dtype)
+        tile_n = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_n.dtype)
 
     # Interleave the input utilizing the matrix tile
     interleave_a_block = sch.cache_read(gemm_block, 0, "global")
@@ -170,8 +170,8 @@ def tir_schedule_matmul_sme(sch):
     sch.tensorize(inner_k, transpose_interleave_intrin_name)
 
     # Split and reorder the loops of the GeMM for tensorization
-    tile_m = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_m.dtype)
-    tile_n = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_n.dtype)
+    tile_m = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_m.dtype)
+    tile_n = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_n.dtype)
     m, n, k = sch.get_loops(gemm_block)
     outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True)
     outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True)
diff --git a/tests/python/relay/strategy/arm_cpu/test_dense.py b/tests/python/relay/strategy/arm_cpu/test_dense.py
index 0419d14201f0..3a8427e8154d 100644
--- a/tests/python/relay/strategy/arm_cpu/test_dense.py
+++ b/tests/python/relay/strategy/arm_cpu/test_dense.py
@@ -102,8 +102,8 @@ class TestDense(BasicDenseTests):
         "data_shape,weight_shape",
         [
             ((32, 32), (32, 32)),
-            ((3, 3), (68, 3)),
             ((2, 35), (6, 35)),
+            ((3, 3), (68, 3)),
             ((79, 65), (152, 65)),
         ],
     )
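For context, a minimal usage sketch of the renamed helper (not part of the patch), assuming the documented signature `get_vscale_expr(dtype, min_size=128)` and that the returned PrimExpr scales a `min_size`-bit granule by the target's `vscale`:

```python
import tvm
from tvm import tir

# Hedged sketch: per the docstring above, get_vscale_expr builds a datatype
# dependent scalable extent; with min_size=128 this should be on the order of
# 4 * vscale() for float32 and 8 * vscale() for float16.
svf_f32 = tir.get_vscale_expr("float32")
svf_f16 = tir.get_vscale_expr("float16", min_size=128)

# The SME schedules use these expressions as tile extents,
# e.g. tile_m = 2 * SVF in compute_matmul_sme above.
tile_m = 2 * svf_f32
print(svf_f32, svf_f16, tile_m)
```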