[Transform] Check for zero-param operators in LiftTransformParams #16595

Merged
Lunderberg merged 1 commit into apache:main from relax_only_lift_if_params_used on Feb 29, 2024

Lunderberg
Contributor

Prior to this commit, LiftTransformParams would extract every variable binding that has no runtime dependencies. As a result, expressions such as R.zeros([16], "int32") would be extracted into the parameter transformation, even though they do not depend on any parameters.

This commit updates LiftTransformParams to only output variables that depend on at least one compile-time parameter.

The unit test for this functionality also found that relax::Call was erroneously calling MarkGraphNode in SEqualReduce and SHashReduce. MarkGraphNode should only be called for nodes that have reference equality, such as relax::Var, and not for composite objects. The extra calls caused erroneous failures in the unit test, because two instances of R.zeros([16], "int32") were being compared by reference equality in StructuralEqual. These extra calls to MarkGraphNode have been removed.

@Lunderberg
Contributor Author

This PR depends on changes made in #16594, and is marked as a draft until it lands.

@vinx13
Member

vinx13 commented Feb 21, 2024

> As a result, expressions such as R.zeros([16], "int32") would be extracted out into the parameter transformation, even though they do not depend on any parameters.

Does this affect the result? If a weight transformation depends on values such as R.zeros, that transformation will no longer be lifted when R.zeros is not lifted. It may be better to check for such dependencies.

@Lunderberg
Contributor Author

> Does this affect the result? If a weight transformation depends on values such as R.zeros, that transformation will no longer be lifted when R.zeros is not lifted. It may be better to check for such dependencies.

Good question. This case should be handled, because zero-param operators are allowed to appear in both functions. (See this unit test for how this looks in practice.) The case is never handled explicitly; instead, it falls out of the overall lifting procedure:

  1. Add every variable binding that doesn't require runtime parameters to the lifted transform_params function.
  2. For every variable binding that is present in transform_params and depends on the model weights, replace with the output from transform_params.
  3. Run dead-code elimination on both functions.

Since the zero-param operators don't depend on runtime parameters, they appear in transform_params. Since the zero-param operators don't depend on the model weights, they aren't replaced in the original function. If they are only required in one of the two functions, then dead-code elimination will remove them from the other.
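The three steps above can be sketched with a toy model in plain Python. This is only an illustration under stated assumptions, not the code in LiftTransformParams: the `Binding` class, the helper names, and the example variables (`zeros`, `w_t`, `w_b`, `x`, `y`, `z`) are all hypothetical, and each binding is reduced to the variable it defines plus the variables it reads.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Binding:
    var: str       # variable being defined (hypothetical toy representation)
    inputs: tuple  # variables read by the bound expression


def reachable(start, bindings):
    """Variables that depend, directly or indirectly, on any variable in `start`."""
    found = set(start)
    for b in bindings:  # bindings are assumed to be in definition order
        if any(i in found for i in b.inputs):
            found.add(b.var)
    return found


def eliminate_dead_code(bindings, roots):
    """Keep only the bindings needed to compute `roots`."""
    live, kept = set(roots), []
    for b in reversed(bindings):
        if b.var in live:
            kept.append(b)
            live.update(b.inputs)
    return kept[::-1]


def lift(bindings, runtime_inputs, weights, outputs):
    needs_runtime = reachable(runtime_inputs, bindings)
    needs_weight = reachable(weights, bindings)

    # Step 1: every binding without a runtime dependency goes to transform_params.
    lifted = [b for b in bindings if b.var not in needs_runtime]
    # Step 2: lifted bindings that depend on the weights are replaced in the
    # original function by outputs of transform_params.
    replaced = {b.var for b in lifted if b.var in needs_weight}
    main = [b for b in bindings if b.var not in replaced]
    # Step 3: dead-code elimination on both functions.
    main = eliminate_dead_code(main, outputs)
    used_by_main = set(outputs).union(*(b.inputs for b in main))
    lifted = eliminate_dead_code(lifted, replaced & used_by_main)
    return [b.var for b in lifted], [b.var for b in main]


program = [
    Binding("zeros", ()),             # zero-param operator, e.g. R.zeros
    Binding("w_t", ("w",)),           # depends on the weight `w`
    Binding("w_b", ("w_t", "zeros")),
    Binding("y", ("x", "w_b")),       # depends on the runtime input `x`
    Binding("z", ("y", "zeros")),
]
transform_params, main = lift(program, {"x"}, {"w"}, ["z"])
print(transform_params)  # ['zeros', 'w_t', 'w_b']
print(main)              # ['zeros', 'y', 'z']
```

Here `zeros` survives in both functions because both use it. If the function returned `y` instead of `z`, dead-code elimination would drop `zeros` from the original function and keep it only in the toy `transform_params`.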

Contributor

@slyubomirsky slyubomirsky left a comment

I realize it's a draft but I had a look anyway. Since the code changes also included those from #16594, it was a little difficult to see what had changed. I didn't see anything to take issue with. It was not obvious how the delta from #16594 accomplished the stated purpose, but I think I follow it.

@@ -169,13 +169,11 @@ class CallNode : public ExprNode {

  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
    // skip sinfo_args check for primitive ops.
    equal->MarkGraphNode();
Contributor

What is the reason for this change?

Contributor Author

The MarkGraphNode function should be called for IR nodes that have reference equality, such as variables. Using it with the relax::CallNode means that results from StructuralEqual will be dependent on how a relax::Call node was constructed, even if they have identical contents.

Relevant to this PR, if R.zeros([16], "int32") appears in both main and transform_params, the expected output generated by the TVMScript parser would have two different relax::Call objects, while the output of LiftTransformParams would use the same relax::Call object. Because the relax::Call objects were being checked for analogous reference equality, this would cause StructuralEqual to erroneously report the test as failing.

I've added a unit test to specifically exercise this behavior, rather than implicitly relying on the tests for LiftTransformParams.
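
The failure mode described above can be illustrated with a simplified model in plain Python. This is a sketch under assumptions, not TVM's StructuralEqual: the `Call` class, the `structural_equal` function, and the `mark_calls_as_graph_nodes` flag are hypothetical stand-ins. The model assumes that marking a node as a graph node makes the comparison remember which right-hand object each left-hand object was matched with, and require the same pairing on every later visit.

```python
class Call:
    """Toy stand-in for a call expression such as R.zeros([16], "int32")."""

    def __init__(self, op, args):
        self.op = op
        self.args = tuple(args)


def structural_equal(lhs, rhs, mark_calls_as_graph_nodes):
    lhs_to_rhs = {}  # id(lhs node) -> rhs node it was first matched with

    def visit(a, b):
        if isinstance(a, Call) and isinstance(b, Call):
            if mark_calls_as_graph_nodes:
                # A graph node seen before must map to the same rhs object.
                if id(a) in lhs_to_rhs:
                    return lhs_to_rhs[id(a)] is b
                lhs_to_rhs[id(a)] = b
            return (
                a.op == b.op
                and len(a.args) == len(b.args)
                and all(visit(x, y) for x, y in zip(a.args, b.args))
            )
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            return len(a) == len(b) and all(visit(x, y) for x, y in zip(a, b))
        return a == b

    return visit(lhs, rhs)


# The pass output reuses one Call object in two places, while the parsed
# expected output contains two distinct objects with identical contents.
shared = Call("zeros", [16, "int32"])
produced = [shared, shared]
expected = [Call("zeros", [16, "int32"]), Call("zeros", [16, "int32"])]

print(structural_equal(produced, expected, mark_calls_as_graph_nodes=True))   # False
print(structural_equal(produced, expected, mark_calls_as_graph_nodes=False))  # True
```

With the flag enabled, the second occurrence of `shared` is rejected because it was already paired with a different object, even though the contents match. Comparing calls purely by content gives the expected result.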

Contributor

Good catch!

Comment on lines +338 to +375
auto upstream_vars = FreeVars(bound_value);
bool depends_on_compile_time_param = std::any_of(
    upstream_vars.begin(), upstream_vars.end(),
    [&](const Var& var) -> bool { return info_.requires_compile_time_param.count(var); });
if (depends_on_compile_time_param) {
  info_.requires_compile_time_param.insert(binding->var);
}
Contributor

Is this the logic that is meant to handle zero-param operators? I had a really hard time figuring out what the change was from #16594 and this was the only major area that came up in the diff.

Contributor Author

This section along with this line, yes. This section collects information to determine if this parameter depends, directly or indirectly, on one of the model weights. The linked line uses the collected information to determine if the transform_params should output a variable.

I've updated the comment here to indicate why the additional information is being collected.
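
As a rough Python rendering of the collection logic discussed in this thread (a sketch only; the function name, the `(var, free_vars)` pair representation, and the example variables are assumptions made for illustration), a variable is recorded as requiring a compile-time parameter when any of its free variables is already in the set:

```python
def collect_requires_compile_time_param(bindings, compile_time_params):
    """bindings: list of (var, free_vars) pairs in definition order."""
    requires = set(compile_time_params)
    for var, free_vars in bindings:
        # Mirrors the std::any_of check: does any upstream variable depend
        # on a compile-time parameter?
        if any(v in requires for v in free_vars):
            requires.add(var)
    return requires


bindings = [
    ("zeros", []),                # R.zeros([16], "int32"): no free variables
    ("w_t", ["w"]),               # depends directly on the weight `w`
    ("w_sum", ["w_t", "zeros"]),  # depends on `w` indirectly, through `w_t`
]
requires = collect_requires_compile_time_param(bindings, {"w"})

# Only variables that depend on a compile-time parameter are candidates
# for being output by the lifted function.
candidates = [var for var, _ in bindings if var in requires]
print(candidates)  # ['w_t', 'w_sum']
```

The zero-param binding `zeros` never enters the set, so in this model it is not a candidate output, although it can still be computed inside the lifted function when another binding needs it.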

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
explicit SuppressCompileTime(
    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress)
    : to_suppress(to_suppress) {}
Contributor

Question separate from the review: Are we going with the convention of having member var names end in an underscore? I'm not a partisan on that one, but we should be consistent. @tqchen

Contributor Author

The TVM convention, as I understand it, is for private members to have a trailing underscore, while public members do not. Often, if I'm making quick subclasses of ExprMutator, I'll use a struct with public fields rather than a class with private fields. Since the class definition already occurs within a function scope, the visibility of the entire class is already restricted, and restricting it further with private members isn't necessary.

That said, I've updated it to be a class with a private to_suppress_. I figure that if a deviation from convention is big enough to raise a question, it's big enough to avoid the deviation altogether.

Contributor

I was just curious; I don't think it's worth holding up a review over. Having a different rule for public and private members is reasonable.

@Lunderberg
Contributor Author

> I realize it's a draft but I had a look anyway. Since the code changes also included those from #16594, it was a little difficult to see what had changed.

Apologies there. Once #16594 lands, the "Files Changed" tab in this PR should update to only show the changes unique to this PR. In the meantime, this PR branch has its changes in a separate commit (link), where they can be viewed separately from the #16594 changes.

@slyubomirsky
Contributor

Thank you for the changes. The new unit test and the new comment are both helpful.

@Lunderberg
Contributor Author

With #16594 landed, I've rebased this PR on top of it, and it is now ready for review.

@Lunderberg Lunderberg marked this pull request as ready for review February 23, 2024 15:19
@Lunderberg Lunderberg changed the title [Draft][Transform] Check for zero-param operators in LiftTransformParams [Transform] Check for zero-param operators in LiftTransformParams Feb 23, 2024
Contributor

@slyubomirsky slyubomirsky left a comment

Thanks for responding to previous feedback.

@Lunderberg Lunderberg merged commit e261a27 into apache:main Feb 29, 2024
19 of 21 checks passed
@Lunderberg Lunderberg deleted the relax_only_lift_if_params_used branch February 29, 2024 14:07
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Mar 12, 2024