Skip to content

Commit

Permalink
Fix topi_matmul test and avoid scalable expression warnings
Browse files Browse the repository at this point in the history
Change-Id: I32273241ae7569b65e082759e4f2ca4355ac6933
  • Loading branch information
lhutton1 committed May 15, 2024
1 parent 7d10268 commit 7363127
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tests/python/relay/strategy/arm_cpu/test_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_sme_dense(data_shape, weight_shape, in_dtype):

with tvm.transform.PassContext(
opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config
), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
), target, meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
executor_factory = tvm.relay.build(
ir_mod,
target=target,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/strategy/arm_cpu/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_sme_matmul_with_const_b(data_shape, weight_shape, transpose_a, transpos
)
with tvm.transform.PassContext(
opt_level=3, config=AOT_APROFILE_AEM_RUNNER.pass_config
), meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
), target, meta_schedule.database.ScheduleFnDatabase(arm_cpu_tir_strategy):
executor_factory = tvm.relay.build(
ir_mod,
target=target,
Expand Down
31 changes: 23 additions & 8 deletions tests/python/topi/test_topi_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,30 @@ def test_tensordot():
verify_tensordot((4, 3, 2, 2), (2, 4, 3, 5), ((1, 2, 0), (2, 0, 1)))


@pytest.mark.parametrize("transpose_a,transpose_b", [(True, False), (False, True)])
def test_unsupported_sme_matmul_compute_transpose(transpose_a, transpose_b):
"""
SME matmul compute does not support transposed inputs for now.
"""
err_msg = "Compute definition currently does not support transposed inputs."
with pytest.raises(AssertionError, match=err_msg) as e:
@pytest.mark.parametrize("in_dtype", ["float32", "float16"])
def test_unsupported_sme_matmul_compute_transpose_a(in_dtype):
err_msg = "Transposed lhs not currently supported."
with pytest.raises(AssertionError, match=err_msg):
compute_matmul_sme(
te.placeholder((32, 32), dtype=in_dtype),
te.placeholder((32, 32), dtype=in_dtype),
None,
None,
True,
False,
)


def test_unsupported_sme_matmul_compute_transpose_b():
err_msg = "Rhs must be transposed when dtype is float16."
with pytest.raises(AssertionError, match=err_msg):
compute_matmul_sme(
te.placeholder((32, 32)), te.placeholder((32, 32)), None, None, transpose_a, transpose_b
te.placeholder((32, 32), dtype="float16"),
te.placeholder((32, 32), dtype="float16"),
None,
None,
False,
False,
)


Expand Down

0 comments on commit 7363127

Please sign in to comment.