-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Pass] Don't consider constants as free vars in MergeComposite #4919
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
src/relay/pass/merge_composite.cc
Outdated
@@ -162,9 +164,10 @@ class MergeCompositeWrapper : public ExprMutator { | |||
CHECK(pattern.defined()); | |||
Map<std::string, Array<Expr>> args_map; | |||
Map<Expr, Expr> call_map; | |||
auto extract = ExtractPattern(pattern, call, &args_map, &call_map); | |||
Array<Expr> const_list; | |||
auto extract = ExtractPattern(pattern, call, &args_map, &call_map, &const_list); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use members instead of making this API bulky? Then we can also clean a bit for the code, i.e. all ExtractPattern
could only take two parameters and ExtractPattern
for constantnode
could be removed because it really doesn't use var_map
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
src/relay/pass/merge_composite.cc
Outdated
for (const auto& free_var : free_vars) { | ||
bool is_const = false; | ||
for (const auto& const_var : const_list) { | ||
if (free_var.get() == const_var.get()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just free_var == const_var
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
I have a feeling you may be able to get around this problem by using bind_params_by_name. I have to do this when using merge composite because weight parameters are still treating as variables until this pass it called, at which time they are replaced with constants. Could you try running:
before the merge composite? |
} | ||
} | ||
|
||
Expr ExtractPattern(const Constant& pattern, const Expr& root, | ||
Map<std::string, Array<Expr>>* var_map) { | ||
Expr ExtractPattern(const Constant& pattern, const Expr& root) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this one can be safely removed
@mbaret thanks for the heads-up. It looks like this works. Do you still think this PR is a worthwhile improvement? |
Thanks @mbaret. If that's the case, I think we can probably just add a unit test and apply the bind pass. All other changes could be reverted. How do you think? |
I think it actually implies a slightly different improvement should be made :) Currently no real checking happens in the ExtractPattern for constants, they are just always accepted as part of the pattern. Probably what should happen instead is that the pattern should fail to match if it's expecting a constant and instead sees a variable. Adding a quick unit test to make sure it doesn't incorrectly match a constant to a var and then implementing that check in ExtractPattern should be sufficient. |
Sounds good. I'll close this PR for now. |
@mbaret @soiferj I tried to call
But if I call
The thing is that later |
Based on my post here, if a pattern contains a constant, that constant should not be a parameter to the function that
MergeComposite
generates. This is because it is assumed that the constant is dealt with internally by the external codegen.A summary of the fix is: create a new
const_list
which is passed around toExtractPattern
calls. When creating the final function, we callGetFreeVarsWithoutConst
instead ofFreeVars
, asFreeVars
also returns constants.GetFreeVarsWithoutConst
will only return free vars that aren't also inconst_list
.@comaniac @mbaret @zhiics would you be able to take a look?