diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index fa90d4193337..842470c46357 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -94,7 +94,7 @@ class SubexprCounter : public ExprVisitor { if (!(e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || - e->IsInstance() || + e->IsInstance() || e->IsInstance() || (e.as() && (e.as()->is_scalar())))) { // also if e has an impure subexpression, we will not deduplicate it if (!impurity_detector_.Detect(e)) { diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index cf66ae3c1c70..d69ec61b5c1f 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -276,5 +276,19 @@ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32 verify(Before, Expected) +def test_do_not_eliminate_extern_func(): + @I.ir_module + class Before: + @R.function(pure=False) + def foo(x: R.Tensor((2, 3), dtype="float32")): + y = R.call_packed("extern_func_name", x, sinfo_args=R.Tensor([2, 3])) + z = R.call_packed("extern_func_name", y, sinfo_args=R.Tensor([2, 3])) + return z + + Expected = Before + + verify(Before, Expected) + + if __name__ == "__main__": tvm.testing.main()