
[Relay][transform][SimplifyExpr] simplify adjacent muls and adds with constants #13213

Merged 2 commits into apache:main on Oct 29, 2022

Conversation

yangulei
Contributor

This PR enables simplification and folding of a subgraph containing adjacent muls and adds with constant inputs.

Motivation

Workloads like densenet-121 have several partitions with the conv-bn-mul-add-relu pattern, for example:

def @main(%data_0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */) -> Tensor[(1, 1000, 1, 1), float32] {
  %0 = nn.conv2d(%data_0, meta[relay.Constant][0] /* ty=Tensor[(64, 3, 7, 7), float32] */, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7]) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */, meta[relay.Constant][4] /* ty=Tensor[(64), float32] */) /* ty=(Tensor[(1, 64, 112, 112), float32], Tensor[(64), float32], Tensor[(64), float32]) */;
  %2 = %1.0 /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = multiply(%2, meta[relay.Constant][5] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %4 = add(%3, meta[relay.Constant][6] /* ty=Tensor[(64, 1, 1), float32] */) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %5 = nn.relu(%4) /* ty=Tensor[(1, 64, 112, 112), float32] */;
  // ...
}

Current transforms on this pattern are:

  1. conv-bn-mul-add-relu as the original pattern.
  2. to conv-mul-add-mul-add-relu as bn is expanded to mul-add.
  3. to conv-add-mul-add-relu as the first mul is folded into conv.

Since all the muls and adds have constant second inputs, they should be folded into a single mul-add, and the preferred transform sequence would be:

  1. conv-bn-mul-add-relu as the original pattern.
  2. to conv-mul-add-mul-add-relu as bn is expanded to mul-add.
  3. to conv-mul-add-relu as the muls and adds are folded into a single mul-add.
  4. to conv-add-relu as the mul is folded into conv.

Solution

In fact, any series of muls and adds with constant inputs can be folded into a single mul-add. Three rewrite rules are added to make this happen:

  1. mul-mul -> mul
  2. add-add -> add
  3. add-mul -> mul-add

Since SimplifyExpr applies the simplifications iteratively until the graph stops changing, any series of muls and adds can be rewritten into a single mul, add, or mul-add whose second input evaluates to a constant in the subsequent FoldConstant pass.
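For illustration, here is a minimal Python sketch of the intended effect, assuming the standard relay.transform.SimplifyExpr and relay.transform.FoldConstant passes; the shapes and constant values below are made up rather than taken from densenet-121:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64, 112, 112), dtype="float32")
s1 = relay.const(np.full((64, 1, 1), 2.0, dtype="float32"))
b1 = relay.const(np.full((64, 1, 1), 1.0, dtype="float32"))
s2 = relay.const(np.full((64, 1, 1), 3.0, dtype="float32"))
b2 = relay.const(np.full((64, 1, 1), 4.0, dtype="float32"))

# mul-add-mul-add chain, as left behind after bn is expanded to mul-add
y = relay.add(relay.multiply(relay.add(relay.multiply(x, s1), b1), s2), b2)
mod = tvm.IRModule.from_expr(relay.Function([x], y))

seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.SimplifyExpr(),
        relay.transform.FoldConstant(),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)  # expected with this PR: a single multiply followed by a single add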

@tvm-bot
Collaborator

tvm-bot commented Oct 27, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: relay, transform, simplifyexpr See #10317 for details
  • Built docs for commit 14d8bba can be found here.

Generated by tvm-bot

@yangulei
Contributor Author

yangulei commented Oct 27, 2022

@vinx13 This PR could also solve the issue about conv-bias-mul folding mentioned at https://discuss.tvm.apache.org/t/relay-pass-constant-folding-question-on-conv2d-bias-constant-mul-constant-folding/11648.

@yangulei
Contributor Author

# conv + bias_add + bn + add + relu -> fused conv_bias_sum, relu
test_sum_pattern([conv2d_bias_sum_pat], 2)
# conv + bias_add + bn + add + relu -> fused conv_bias_sum_relu,
test_sum_pattern([conv2d_bias_sum_relu_pat], 1)

@masahi The tests above failed because the pattern matcher always matches the last two adds in conv-add[bias]-add[bn]-add[sum], but the last add doesn't have any constant inputs and thus cannot be simplified. Do you have any idea how to fix this?

In this PR, I generalize the add-add and mul-mul simplifications from constant inputs to constant-expression inputs, so that a more complex pattern like mul-add-mul can be simplified to mul-add. In the constant-inputs case, the constant-ness can be embedded into the pattern for matching, so the last two adds in conv-add[bias]-add[bn]-add[sum] will not match. In the constant-expression-inputs case, however, I match all add-add patterns first and then check whether both adds have constant inputs, as sketched below.
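For reference, the "match first, check constant-ness in the callback" idea can be sketched with Relay's Python dataflow-pattern API (the PR itself implements the rewrites in C++ inside SimplifyExpr; the names and shapes below are illustrative):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, rewrite, wildcard


class AddAddFolder(DFPatternCallback):
    """Rewrite (x + c1) + c2 into x + (c1 + c2) only when c1 and c2 are constants."""

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.c1 = wildcard()  # matched loosely; constant-ness is checked in the callback
        self.c2 = wildcard()
        self.pattern = is_op("add")(is_op("add")(self.x, self.c1), self.c2)

    def callback(self, pre, post, node_map):
        c1 = node_map[self.c1][0]
        c2 = node_map[self.c2][0]
        if isinstance(c1, relay.Constant) and isinstance(c2, relay.Constant):
            return relay.add(node_map[self.x][0], relay.add(c1, c2))
        return post  # e.g. the element-wise sum add stays untouched


x = relay.var("x", shape=(1, 64, 56, 56), dtype="float32")
c1 = relay.const(np.full((64, 1, 1), 0.5, dtype="float32"))
c2 = relay.const(np.full((64, 1, 1), 1.5, dtype="float32"))
out = rewrite(AddAddFolder(), relay.add(relay.add(x, c1), c2))
print(out)  # add(x, add(c1, c2)); FoldConstant then turns add(c1, c2) into a constant

For brevity the sketch only handles constants appearing as the second argument of each add.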

@masahi
Member

masahi commented Oct 27, 2022

This sounds tricky, since I'd expect the pattern matcher to be operating in a "bottom up" manner. Is it possible to add a "constant-ness" condition to the pattern? Otherwise I can only think of adding conv2d into the add -> add pattern, but that is a terrible solution...

@yangulei
Contributor Author

The pattern with constant inputs is fine for the first iteration of the SimplifyExpr pass. However, starting from the second iteration, some constant inputs have been rewritten into constant-expression inputs.

Take ((x + c1) + c2) + c3 as an example: the first iteration rewrites it to (x + (c1 + c2)) + c3. Now the first add becomes x + (c1 + c2), whose second input merely evaluates to a constant rather than being a constant, so the pattern with constant inputs will not match. I think this could be solved by evaluating c1 + c2 to a concrete constant like c_12 at the end of the first iteration.

Is there a way to evaluate a constant expression to a concrete constant inside the SimplifyExpr pass? If we can do that, the input graph for the later iterations will also match the constant-inputs pattern and the simplification can carry on.
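One way to do that (a hypothetical helper for illustration, not the code that ended up in the PR) would be to reuse FoldConstant on just the constant sub-expression:

import numpy as np
import tvm
from tvm import relay


def fold_to_constant(expr):
    """Evaluate a Relay expression with no free variables down to a relay.Constant."""
    mod = tvm.IRModule.from_expr(expr)
    seq = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.FoldConstant()]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    return mod["main"].body


c1 = relay.const(np.full((64, 1, 1), 0.5, dtype="float32"))
c2 = relay.const(np.full((64, 1, 1), 1.5, dtype="float32"))
c12 = fold_to_constant(relay.add(c1, c2))
print(isinstance(c12, relay.Constant))  # True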

@masahi
Member

masahi commented Oct 28, 2022

Since we run FoldConstant before SimplifyExpr:

pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::SimplifyExpr());
aren't constant expressions already constants when we reach SimplifyExpr?

@yangulei
Contributor Author

yangulei commented Oct 28, 2022

The SimplifyExpr pass does its simplifications iteratively:

Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
  // the rewrites will be applied in the given order, and repeated until fixed point
  DFPatternRewriteComposer composer;
  composer.AddRewrite<ConcretizeZerosLikeRewrite>();
  composer.AddRewrite<ConcretizeOnesLikeRewrite>();
  // ...

The iterations happen inside SimplifyExpr, so we need to find a way to evaluate the constant expressions manually.

aren't constant expressions already constants when we reach SimplifyExpr?

Yes, but only for the constant expressions that already exist before SimplifyExpr runs. Now some constant expressions are created inside SimplifyExpr, after the FoldConstant pass has already run.

@yangulei
Contributor Author

@masahi The problem is solved by applying FoldConstant inside SimplifyExpr.
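Conceptually, the fix amounts to interleaving the rewrites with constant folding until a fixed point is reached, so that constants produced by one iteration are concrete constants for the next. The actual change lives in the C++ SimplifyExpr pass; the Python below is only a sketch of the idea:

import tvm
from tvm import relay


def simplify_with_folding(mod, max_iter=10):
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.SimplifyExpr(),
            relay.transform.FoldConstant(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        for _ in range(max_iter):
            new_mod = seq(mod)
            if tvm.ir.structural_equal(new_mod["main"], mod["main"]):
                break
            mod = new_mod
    return mod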

@masahi masahi merged commit e971956 into apache:main Oct 29, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request on Nov 10, 2022, and again on Nov 25, 2022:

… constants (apache#13213)

* simplify adjacent muls and adds with constants

* apply FoldConstant inside SimplifyExpr