Skip to content

Commit

Permalink
[Relay][Pass] Avoid FoldConstant folding some ops (#4245)
Browse files Browse the repository at this point in the history
* [Relay][Pass] Avoid FoldConstant folding some ops

* rename
  • Loading branch information
vinx13 authored and zhiics committed Nov 1, 2019
1 parent cd717de commit aa49e85
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/relay/pass/fold_constant.cc
Expand Up @@ -102,6 +102,9 @@ class ConstantFolder : public ExprMutator {

Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");

std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};

auto origin_args = call->args;
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
Expand All @@ -111,6 +114,9 @@ class ConstantFolder : public ExprMutator {
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
if (skip_list.count(op->name)) {
return res;
}
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
Expand Down
16 changes: 16 additions & 0 deletions tests/python/relay/test_pass_fold_constant.py
Expand Up @@ -146,9 +146,25 @@ def expected(dtype):
assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_full():
c_shape = (8, 9, 10)
def before():
dtype = 'float32'
return relay.full(relay.const(1.0, dtype), c_shape, dtype=dtype)

def expected():
# expect no changes
return before()

zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
assert relay.analysis.graph_equal(zz, zexpected)


if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
test_fold_concat()
test_fold_shape_of()
test_fold_full()

0 comments on commit aa49e85

Please sign in to comment.