From 9b0df1e9a96ff19418909223f303abf87a88de04 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Sep 2023 15:14:09 -0500 Subject: [PATCH] [TIR] Do not drop 4th argument to tir.max The `tir.op.comm_reducer` utility provides two distinct APIs, either reducing along a tensor axis or reducing along a list of arguments. Prior to this commit, when reducing along a list of arguments, the 4th argument was silently dropped. For example, `tvm.tir.max(1,2,3,4,3,2,1)` would return `3`. --- python/tvm/tir/op.py | 8 +++++++- tests/python/unittest/test_tir_ops.py | 28 +++++++++++++++++++++------ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 30e2a2948769..905d14296d98 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3150,10 +3150,16 @@ def reducer(expr, axis, where=None, init=None, *args): if isinstance(axis, (tvm.tir.IterVar, list, tuple)): assert not args return _make_reduce(expr, axis, where, init) + if where is None: assert not args + assert init is None return _reduce_directly(expr, axis) - return _reduce_directly(expr, axis, where, *args) + elif init is None: + assert not args + return _reduce_directly(expr, axis, where) + else: + return _reduce_directly(expr, axis, where, init, *args) doc_str = """Create a {0} expression over axis. diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index 9725650eadae..21981d1f0ba1 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -15,8 +15,11 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te +import pytest + def check_throws(f): try: @@ -213,10 +216,23 @@ def test_if_then_else(): raise ValueError("Unknown combinations") +@pytest.mark.parametrize("num_args", list(range(2, 10))) +def test_comm_reducer(num_args): + """Handle all arguments in tir comm_reducer + + The `tir.comm_reducer` API has two distinct usages. It can reduce + a tensor along a specified axis, similar to numpy.max, or it can + reduce several arguments together, simililar to Python's built-in + max(). This choice is based on the type of the second argument. + + If the `tir.comm_reducer` is reducing all arguments, then all + arguments should be used. In the past, the introduction of new + arguments intended for use when reducing along a tensor axis has + failed to forward these arguments when reducing along a list of + items. + """ + assert tvm.tir.max(*range(num_args)) == num_args - 1 + + if __name__ == "__main__": - test_const_fold() - test_const_fold2() - test_const_fold3() - test_const_fold4() - test_binary_dtype_match() - test_if_then_else() + tvm.testing.main()