Skip to content

Commit

Permalink
[ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5761)
Browse files Browse the repository at this point in the history
Co-authored-by: d00221512 <d00221512@huawei.com>
  • Loading branch information
xqdan and d00221512 committed Jun 10, 2020
1 parent 2cd6321 commit 85af4a9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/arithmetic/rewrite_simplify.cc
Expand Up @@ -1157,8 +1157,11 @@ Mutate_(const Min* op, const Expr& self) {
if (min(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c1val == 0) {
return c2val < 0 ? c2.Eval() : c1.Eval();
}
if (c2val % c1val == 0) {
if (c2val / c1val >= 0) {
if (c1val > 0) {
return (min(x, c2val / c1val) * c1val).Eval();
} else {
return (max(x, c2val / c1val) * c1val).Eval();
Expand Down Expand Up @@ -1331,8 +1334,11 @@ Mutate_(const Max* op, const Expr& self) {
if (max(x * c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
if (c1val == 0) {
return c2val > 0 ? c2.Eval() : c1.Eval();
}
if (c2val % c1val == 0) {
if (c2val / c1val >= 0) {
if (c1val > 0) {
return (max(x, c2val / c1val) * c1val).Eval();
} else {
return (min(x, c2val / c1val) * c1val).Eval();
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Expand Up @@ -529,6 +529,12 @@ def test_min_index_simplify():

ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
ck.verify(tvm.min(3 - x, 2), 3 - tvm.max(x, 1))
ck.verify(tvm.min(x * 2, 0), tvm.min(x, 0) * 2)
ck.verify(tvm.min(0 - x * 2, 0), tvm.max(x, 0) * -2)
ck.verify(tvm.min(x * (-2), -4), tvm.max(x, 2) * -2)
ck.verify(tvm.min(x * (-2), 4), tvm.max(x, -2) * -2)
ck.verify(tvm.min(x * (0), 4), 0)
ck.verify(tvm.min(x * (0), -4), -4)

# DivMod rules
# truc div
Expand Down Expand Up @@ -609,6 +615,12 @@ def test_max_index_simplify():

ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
ck.verify(tvm.max(3 - x, 1), 3 - tvm.min(x, 2))
ck.verify(tvm.max(x * 2, 0), tvm.max(x, 0) * 2)
ck.verify(tvm.max(0 - x * 2, 0), tvm.min(x, 0) * -2)
ck.verify(tvm.max(x * (-2), -4), tvm.min(x, 2) * -2)
ck.verify(tvm.max(x * (-2), 4), tvm.min(x, -2) * -2)
ck.verify(tvm.max(x * (0), 4), 4)
ck.verify(tvm.max(x * (0), -4), 0)

# DivMod rules
# truc div
Expand Down

0 comments on commit 85af4a9

Please sign in to comment.