Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Do not drop 4th argument to tir.max #15763

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3150,10 +3150,16 @@ def reducer(expr, axis, where=None, init=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
return _make_reduce(expr, axis, where, init)

if where is None:
assert not args
assert init is None
return _reduce_directly(expr, axis)
return _reduce_directly(expr, axis, where, *args)
elif init is None:
assert not args
return _reduce_directly(expr, axis, where)
else:
return _reduce_directly(expr, axis, where, init, *args)

doc_str = """Create a {0} expression over axis.

Expand Down
28 changes: 22 additions & 6 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te

import pytest


def check_throws(f):
try:
Expand Down Expand Up @@ -213,10 +216,23 @@ def test_if_then_else():
raise ValueError("Unknown combinations")


@pytest.mark.parametrize("num_args", list(range(2, 10)))
def test_comm_reducer(num_args):
"""Handle all arguments in tir comm_reducer

The `tir.comm_reducer` API has two distinct usages. It can reduce
a tensor along a specified axis, similar to numpy.max, or it can
reduce several arguments together, simililar to Python's built-in
max(). This choice is based on the type of the second argument.

If the `tir.comm_reducer` is reducing all arguments, then all
arguments should be used. In the past, the introduction of new
arguments intended for use when reducing along a tensor axis has
failed to forward these arguments when reducing along a list of
items.
"""
assert tvm.tir.max(*range(num_args)) == num_args - 1


if __name__ == "__main__":
test_const_fold()
test_const_fold2()
test_const_fold3()
test_const_fold4()
test_binary_dtype_match()
test_if_then_else()
tvm.testing.main()