[Relay][Pass] Don't consider constants as free vars in MergeComposite #4919

Closed
wants to merge 3 commits into from

soiferj
Contributor

@soiferj soiferj commented Feb 20, 2020

Based on my post here, if a pattern contains a constant, that constant should not be a parameter to the function that MergeComposite generates. This is because it is assumed that the constant is dealt with internally by the external codegen.

A summary of the fix is: create a new const_list which is passed around to ExtractPattern calls. When creating the final function, we call GetFreeVarsWithoutConst instead of FreeVars, as FreeVars also returns constants. GetFreeVarsWithoutConst will only return free vars that aren't also in const_list.
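The filtering step described above can be sketched roughly as follows. This is a toy model in plain Python, not the C++ in this PR: `Var`, `Const`, `Call`, `leaves`, and `free_vars_without_const` are hypothetical names used only for illustration, and `leaves` stands in for a free-variable query that, as described, also returns constants.

```python
# Toy model only: these classes and helpers are hypothetical stand-ins,
# not TVM classes or the PR's actual C++ code.
class Var:
    def __init__(self, name):
        self.name = name


class Const:
    def __init__(self, value):
        self.value = value


class Call:
    def __init__(self, op, args):
        self.op = op
        self.args = args


def leaves(expr):
    """Return every leaf (Var or Const) of expr, in argument order, without duplicates."""
    if isinstance(expr, Call):
        found = []
        for arg in expr.args:
            for leaf in leaves(arg):
                # Compare by identity, like comparing node pointers.
                if not any(leaf is seen for seen in found):
                    found.append(leaf)
        return found
    return [expr]


def free_vars_without_const(expr, const_list):
    """Keep only the leaves that are not (by identity) in const_list."""
    return [leaf for leaf in leaves(expr)
            if not any(leaf is c for c in const_list)]


x = Var("x")
w = Const(2.0)
body = Call("add", [Call("multiply", [x, w]), x])

# Only x becomes a parameter of the generated function; w stays inside it.
params = free_vars_without_const(body, const_list=[w])
print([p.name for p in params])  # ['x']
```

The comparison is by object identity, which mirrors the node-pointer comparison in the loop reviewed further down.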

@comaniac @mbaret @zhiics would you be able to take a look?

Contributor

@comaniac comaniac left a comment

LGTM

@@ -162,9 +164,10 @@ class MergeCompositeWrapper : public ExprMutator {
 CHECK(pattern.defined());
 Map<std::string, Array<Expr>> args_map;
 Map<Expr, Expr> call_map;
-auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
+Array<Expr> const_list;
+auto extract = ExtractPattern(pattern, call, &args_map, &call_map, &const_list);
Member

Can we use class members instead of making this API bulky? That would also let us clean up the code a bit: every ExtractPattern overload would take only two parameters, and the ExtractPattern overload for ConstantNode could be removed because it doesn't really use var_map.

Contributor Author

Updated

for (const auto& free_var : free_vars) {
bool is_const = false;
for (const auto& const_var : const_list) {
if (free_var.get() == const_var.get()) {
Member

Why not just free_var == const_var?

Contributor Author

Updated

@mbaret
Contributor

mbaret commented Feb 21, 2020

I have a feeling you may be able to get around this problem by using bind_params_by_name. I have to do this when using merge composite because weight parameters are still treated as variables until this pass is called, at which point they are replaced with constants.

Could you try running:

f = relay.build_module.bind_params_by_name(mod["main"], params)
mod = tvm.IRModule()
mod["main"] = f

before the merge composite?
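To illustrate why binding changes what the merge composite pass sees, here is a toy model in plain Python. It is not TVM's implementation: expressions are nested tuples and `bind_params` is a hypothetical helper that only mimics the idea of replacing named variables with constants.

```python
# Toy model only: tuple-based expressions and a hypothetical bind_params
# helper, not TVM's bind_params_by_name.
def bind_params(expr, params):
    """Replace ("var", name) leaves whose name is in params with ("const", value)."""
    kind = expr[0]
    if kind == "var":
        name = expr[1]
        return ("const", params[name]) if name in params else expr
    if kind == "call":
        _, op, args = expr
        return ("call", op, [bind_params(a, params) for a in args])
    return expr  # constants are left untouched


conv = ("call", "nn.conv2d", [("var", "data"), ("var", "weight")])
bound = bind_params(conv, {"weight": [[1.0]]})
print(bound)
# ('call', 'nn.conv2d', [('var', 'data'), ('const', [[1.0]])])
```

In this sketch the weight is a variable before binding and a constant afterwards, while the data input, which has no entry in the parameter dictionary, stays a variable.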

}
}

-Expr ExtractPattern(const Constant& pattern, const Expr& root,
-Map<std::string, Array<Expr>>* var_map) {
+Expr ExtractPattern(const Constant& pattern, const Expr& root) {
Member

I think this one can be safely removed

@soiferj
Contributor Author

soiferj commented Feb 21, 2020

@mbaret thanks for the heads-up. It looks like this works. Do you still think this PR is a worthwhile improvement?

@zhiics
Member

zhiics commented Feb 21, 2020

Thanks @mbaret. If that's the case, I think we can probably just add a unit test and apply the bind pass. All other changes could be reverted. What do you think?

@mbaret
Contributor

mbaret commented Feb 21, 2020

I think it actually implies a slightly different improvement should be made :) Currently, ExtractPattern does no real checking for constants; they are always accepted as part of the pattern. What should probably happen instead is that the pattern fails to match if it expects a constant and sees a variable.

Adding a quick unit test to make sure it doesn't incorrectly match a constant to a var and then implementing that check in ExtractPattern should be sufficient.
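The proposed check can be sketched as a toy matcher in plain Python. This is an assumption-laden illustration, not the ExtractPattern code: expressions are nested tuples and `matches` is a hypothetical helper.

```python
# Toy model only: tuple-based expressions and a hypothetical matches() helper,
# not the ExtractPattern code in TVM.
def matches(pattern, expr):
    """Return True if expr has the same structure as pattern."""
    kind = pattern[0]
    if kind == "var":
        return True  # a pattern variable acts as a wildcard for any input
    if kind == "const":
        return expr[0] == "const"  # the proposed check: reject non-constants
    if kind == "call":
        if expr[0] != "call" or expr[1] != pattern[1]:
            return False
        if len(expr[2]) != len(pattern[2]):
            return False
        return all(matches(p, e) for p, e in zip(pattern[2], expr[2]))
    return False


pattern = ("call", "add", [("var", "x"), ("const", None)])
unbound = ("call", "add", [("var", "data"), ("var", "bias")])
bound = ("call", "add", [("var", "data"), ("const", 1.0)])

print(matches(pattern, unbound))  # False
print(matches(pattern, bound))    # True
```

A unit test along these lines would assert that the unbound graph is left unmerged and the bound graph is merged.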

@soiferj
Contributor Author

soiferj commented Feb 25, 2020

Sounds good. I'll close this PR for now.

@soiferj soiferj closed this Feb 25, 2020
@corehalt
Contributor

corehalt commented Feb 26, 2020

I have a feeling you may be able to get around this problem by using bind_params_by_name. I have to do this when using merge composite because weight parameters are still treated as variables until this pass is called, at which point they are replaced with constants.

Could you try running:

f = relay.build_module.bind_params_by_name(mod["main"], params)
mod = tvm.IRModule()
mod["main"] = f

before the merge composite?

@mbaret @soiferj I tried to call bind_params_by_name() before merge composite.
If I don't call bind_params_by_name() I get something like:

  %155 = nn.relu(%154) /* ty=Tensor[(1, 1024, 7, 7), float32] */;
  %157 = fn (%scompiler_input52: Tensor[(1, 1024, 7, 7), float32], %scompiler_input53: Tensor[(1024, 1024, 1, 1), float32], Compiler="scompiler", ExternalSymbol="scompiler_0", Primitive=1) -> Tensor[(1, 1024, 7, 7), float32] {
    %156 = fn (%x26: Tensor[(1, 1024, 7, 7), float32], %y26: Tensor[(1024, 1024, 1, 1), float32], Primitive=1, Composite="conv") -> Tensor[(1, 1024, 7, 7), float32] {
      nn.conv2d(%x26, %y26, padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]) /* ty=Tensor[(1, 1024, 7, 7), float32] */
    };
    %156(%scompiler_input52, %scompiler_input53) /* ty=Tensor[(1, 1024, 7, 7), float32] */
  };
  %158 = %157(%155, %separable_conv_block_13_conv2_weight) /* ty=Tensor[(1, 1024, 7, 7), float32] */;

But if I call bind_params_by_name() I can see something different:

  %129 = nn.relu(%128) /* ty=Tensor[(1, 1024, 7, 7), float32] */;
  %131 = fn (%scompiler_input52: Tensor[(1, 1024, 7, 7), float32], %scompiler_input53: Tensor[(1024, 1024, 1, 1), float32], Compiler="scompiler", ExternalSymbol="scompiler_0", Primitive=1) -> Tensor[(1, 1024, 7, 7), float32] {
    %130 = fn (%x26: Tensor[(1, 1024, 7, 7), float32], %y26: Tensor[(1024, 1024, 1, 1), float32], Primitive=1, Composite="conv") -> Tensor[(1, 1024, 7, 7), float32] {
      nn.conv2d(%x26, %y26, padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]) /* ty=Tensor[(1, 1024, 7, 7), float32] */
    };
    %130(%scompiler_input52, %scompiler_input53) /* ty=Tensor[(1, 1024, 7, 7), float32] */
  };
  %132 = %131(%129, meta[relay.Constant][52] /* ty=Tensor[(1024, 1024, 1, 1), float32] */ /* ty=Tensor[(1024, 1024, 1, 1), float32] */) /* ty=Tensor[(1, 1024, 7, 7), float32] */;

The thing is that later meta[relay.Constant][52] is exposed as a Relay.VarNode on the codegen anyway. Is there any way of accessing these constants from a custom codegen?

6 participants