Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Change-Id: I237b4c5cb5ca22e33529d98cbd75177b94904857
  • Loading branch information
lhutton1 committed May 22, 2024
1 parent 7363127 commit 0d2be71
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale, get_active_lane_mask, get_vscale_factor
from .op import vscale, get_active_lane_mask, get_vscale_expr
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3370,13 +3370,13 @@ def get_active_lane_mask(dtype, base, limit):
return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)


def get_vscale_factor(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> PrimExpr:
"""
Create a datatype-dependent scalable expression.
Parameters
----------
dtype : tvm.DataType
dtype : Union[str, tvm.DataType]
Element data type.
min_size : int
The minimum size of the scalable vector.
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _create_ptrue_mask(dtype):
"""
Creates a mask that enables all lanes of a scalable vector.
"""
return T.broadcast(T.bool(True), tir.get_vscale_factor(dtype))
return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype))


def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
Expand Down Expand Up @@ -213,7 +213,7 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin():
The SME TensorIntrin that can be used in tensorizing a schedule.
"""
SVF = tir.get_vscale_factor("float32")
SVF = tir.get_vscale_expr("float32")
SVF2 = 2 * SVF

@T.prim_func
Expand Down Expand Up @@ -347,7 +347,7 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
"""
# pylint: enable=line-too-long
SVF = tir.get_vscale_factor("float16")
SVF = tir.get_vscale_expr("float16")
SVF2 = 2 * SVF

@T.prim_func
Expand Down Expand Up @@ -532,7 +532,7 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K, in_dtype):
The SME TensorIntrin that can be used in tensorizing a schedule.
"""
SVF = tir.get_vscale_factor("float32")
SVF = tir.get_vscale_expr("float32")
SVF2 = 2 * SVF
fmopa_intrin = (
"llvm.aarch64.sme.mopa" if in_dtype == "float32" else "llvm.aarch64.sme.mopa.wide"
Expand Down Expand Up @@ -577,7 +577,7 @@ def impl():
rows_per_iter = 1 if in_dtype == "float32" else 2
with T.serial(T.ceildiv(K, rows_per_iter)) as k:
k_row = k * rows_per_iter
in_dtype_svf = tir.get_vscale_factor(in_dtype)
in_dtype_svf = tir.get_vscale_expr(in_dtype)

a_low = T.BufferLoad(A, [k_row, T.Ramp(0, 1, in_dtype_svf)])
b_low = T.BufferLoad(B, [k_row, T.Ramp(0, 1, in_dtype_svf)])
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/topi/arm_cpu/dense_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ def _alter_dense(attrs, inputs, tinfos, out_type):

weight_dtype = tinfos[1].dtype
encoded_weight = inputs[1]

# For dense, the weights (rhs) are provided in transposed format,
# i.e. they are of the shape (n, k).
transpose_b = True

# The SME schedule expects the rhs to be in the format (k, n). For float32,
# this transformation can be done at compile time. Note: for the
# float16->float32 schedule, the transformation currently happens at runtime
# with the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic.
if weight_dtype == "float32":
encoded_weight = relay.const(encoded_weight.data.numpy().transpose(), weight_dtype)
transpose_b = False
Expand Down
22 changes: 11 additions & 11 deletions python/tvm/topi/arm_cpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def compute_matmul_sme(cfg, data_a, data_b, _, out_dtype, transpose_a=False, tra
if not out_dtype:
out_dtype = data_a.dtype

tile_m = 2 * tvm.tir.get_vscale_factor(data_a.dtype)
tile_k = tvm.tir.get_vscale_factor(data_a.dtype)
tile_m = 2 * tvm.tir.get_vscale_expr(data_a.dtype)
tile_k = tvm.tir.get_vscale_expr(data_a.dtype)
if data_a.dtype == "float32":
tile_k *= 2
tile_n = 2 * tvm.tir.get_vscale_factor(data_a.dtype)
tile_n = 2 * tvm.tir.get_vscale_expr(data_a.dtype)

M_padded, pad_M = pad_dim_to_multiple(M, tile_m)
_, pad_K = pad_dim_to_multiple(K, tile_k)
Expand Down Expand Up @@ -140,13 +140,13 @@ def tir_schedule_matmul_sme(sch):
extent_n = sch.get(n).extent

if in_dtype == "float16":
tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype)
tile_k = T.cast(tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype)
tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype)
tile_m = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_m.dtype)
tile_k = T.cast(tvm.tir.get_vscale_expr(in_dtype), extent_k.dtype)
tile_n = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_n.dtype)
else:
tile_m = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_m.dtype)
tile_k = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_k.dtype)
tile_n = T.cast(2 * tvm.tir.get_vscale_factor(in_dtype), extent_n.dtype)
tile_m = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_m.dtype)
tile_k = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_k.dtype)
tile_n = T.cast(2 * tvm.tir.get_vscale_expr(in_dtype), extent_n.dtype)

# Interleave the input utilizing the matrix tile
interleave_a_block = sch.cache_read(gemm_block, 0, "global")
Expand All @@ -170,8 +170,8 @@ def tir_schedule_matmul_sme(sch):
sch.tensorize(inner_k, transpose_interleave_intrin_name)

# Split and reorder the loops of the GeMM for tensorization
tile_m = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_m.dtype)
tile_n = T.cast(2 * tvm.tir.get_vscale_factor(out_dtype), extent_n.dtype)
tile_m = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_m.dtype)
tile_n = T.cast(2 * tvm.tir.get_vscale_expr(out_dtype), extent_n.dtype)
m, n, k = sch.get_loops(gemm_block)
outer_m, inner_m = sch.split(m, factors=(None, tile_m), disable_predication=True)
outer_n, inner_n = sch.split(n, factors=(None, tile_n), disable_predication=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/strategy/arm_cpu/test_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class TestDense(BasicDenseTests):
"data_shape,weight_shape",
[
((32, 32), (32, 32)),
((3, 3), (68, 3)),
((2, 35), (6, 35)),
((3, 3), (68, 3)),
((79, 65), (152, 65)),
],
)
Expand Down

0 comments on commit 0d2be71

Please sign in to comment.