Skip to content

Commit

Permalink
[TIR] Do not drop 4th argument to tir.max
Browse files Browse the repository at this point in the history
The `tir.op.comm_reducer` utility provides two distinct APIs, either
reducing along a tensor axis or reducing along a list of arguments.

Prior to this commit, when reducing along a list of arguments, the 4th
argument was silently dropped.  For example,
`tvm.tir.max(1,2,3,4,3,2,1)` would return `3`.
  • Loading branch information
Lunderberg committed Sep 15, 2023
1 parent 64ab31e commit 9b0df1e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
8 changes: 7 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3150,10 +3150,16 @@ def reducer(expr, axis, where=None, init=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where, init)

if where is None:
assert not args
assert init is None
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
elif init is None:
assert not args
return _reduce_directly(expr, axis, where)
else:
return _reduce_directly(expr, axis, where, init, *args)

doc_str = """Create a {0} expression over axis.
Expand Down
28 changes: 22 additions & 6 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te

import pytest


def check_throws(f):
try:
Expand Down Expand Up @@ -213,10 +216,23 @@ def test_if_then_else():
raise ValueError("Unknown combinations")


@pytest.mark.parametrize("num_args", list(range(2, 10)))
def test_comm_reducer(num_args):
"""Handle all arguments in tir comm_reducer
The `tir.comm_reducer` API has two distinct usages. It can reduce
a tensor along a specified axis, similar to numpy.max, or it can
reduce several arguments together, simililar to Python's built-in
max(). This choice is based on the type of the second argument.
If the `tir.comm_reducer` is reducing all arguments, then all
arguments should be used. In the past, the introduction of new
arguments intended for use when reducing along a tensor axis has
failed to forward these arguments when reducing along a list of
items.
"""
assert tvm.tir.max(*range(num_args)) == num_args - 1


if __name__ == "__main__":
test_const_fold()
test_const_fold2()
test_const_fold3()
test_const_fold4()
test_binary_dtype_match()
test_if_then_else()
tvm.testing.main()

0 comments on commit 9b0df1e

Please sign in to comment.