Skip to content

Commit

Permalink
[Unity] Skip constants in CSE pass
Browse files Browse the repository at this point in the history
This patch modifies the CSE pass to skip all constants as
[discussed here](https://discuss.tvm.apache.org/t/common-subexpr-elimination-pass-replaces-constant-args-with-vars/15971)
  • Loading branch information
quic-sanirudh committed Nov 14, 2023
1 parent d5daa98 commit 967a841
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/relax/transform/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class SubexprCounter : public ExprVisitor {
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
(e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
e.as<ConstantNode>())) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
int count = 0;
Expand Down
8 changes: 5 additions & 3 deletions tests/python/relax/test_transform_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ class Expected:
def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
with R.dataflow():
lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
lv2 = R.add(lv1, lv1)
gv = (lv0, lv2)
lv1 = R.add(
R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
)
gv = (lv0, lv1)
R.output(gv)
return gv

Expand Down

0 comments on commit 967a841

Please sign in to comment.