From 542041967f3672eee4ca4d579262f70350e7ecb6 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 20 Apr 2022 11:05:06 +0800 Subject: [PATCH] Fix While Node StructuralEqual and StructuralHash issue --- include/tvm/tir/stmt.h | 6 +++--- .../python/unittest/test_tir_structural_equal_hash.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9ccab50eced2..6cdd6499c821 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -996,12 +996,12 @@ class WhileNode : public StmtNode { } bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { - return equal.DefEqual(condition, other->condition) && equal.DefEqual(body, other->body); + return equal(condition, other->condition) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(condition); - hash_reduce.DefHash(body); + hash_reduce(condition); + hash_reduce(body); } static constexpr const char* _type_key = "tir.While"; diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py index d25780a01f79..ff02f1e369ea 100644 --- a/tests/python/unittest/test_tir_structural_equal_hash.py +++ b/tests/python/unittest/test_tir_structural_equal_hash.py @@ -199,6 +199,15 @@ def test_buffer_load_store(): assert not consistent_equal(sy, sz) +def test_while(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + wx = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) + wy = tvm.tir.While(y > 0, tvm.tir.Evaluate(y)) + assert not consistent_equal(wx, wy) + assert consistent_equal(wx, wy, map_free_vars=True) + + if __name__ == "__main__": test_exprs() test_prim_func() @@ -208,3 +217,4 @@ def test_buffer_load_store(): test_stmt() test_buffer_storage_scope() test_buffer_load_store() + test_while()