-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform] Check for zero-param operators in LiftTransformParams #16595
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,15 @@ struct CollectInfo { | |
*/ | ||
std::vector<Binding> computable_at_compile_time; | ||
|
||
/*! \brief Variables that require a compile-time parameter | ||
* | ||
* Used to distinguish between computed tensors that depend on the | ||
* model weights, and computed tensors that require neither model | ||
* weights nor runtime arguments (e.g. `R.zeros([16], "float16")`). | ||
*/ | ||
std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual> | ||
requires_compile_time_param; | ||
|
||
/*! \brief Variables that are required at runtime */ | ||
std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual> | ||
required_at_runtime; | ||
|
@@ -114,7 +123,8 @@ struct CollectInfo { | |
// Any variable that is computed at compile-time, but is required | ||
// at runtime, must be provided as a parameter. | ||
for (const auto& binding : computable_at_compile_time) { | ||
if (required_at_runtime.count(binding->var)) { | ||
if (requires_compile_time_param.count(binding->var) && | ||
required_at_runtime.count(binding->var)) { | ||
params.push_back(binding->var); | ||
} | ||
} | ||
|
@@ -182,16 +192,21 @@ struct CollectInfo { | |
|
||
// Any binding that is computable at compile-time should be | ||
// suppressed at run-time. | ||
struct SuppressCompileTime : ExprMutator { | ||
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress; | ||
explicit SuppressCompileTime(const std::vector<Binding>& bindings) { | ||
for (const auto& binding : bindings) { | ||
to_suppress.insert(binding->var); | ||
} | ||
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress; | ||
for (const auto& binding : computable_at_compile_time) { | ||
if (requires_compile_time_param.count(binding->var)) { | ||
to_suppress.insert(binding->var); | ||
} | ||
} | ||
|
||
class SuppressCompileTime : public ExprMutator { | ||
public: | ||
explicit SuppressCompileTime( | ||
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress) | ||
: to_suppress_(to_suppress) {} | ||
|
||
void VisitBinding(const Binding& binding) override { | ||
if (!to_suppress.count(binding->var)) { | ||
if (!to_suppress_.count(binding->var)) { | ||
ExprMutator::VisitBinding(binding); | ||
} | ||
} | ||
|
@@ -205,8 +220,11 @@ struct CollectInfo { | |
return ExprMutator::VisitExpr_(call); | ||
} | ||
} | ||
|
||
private: | ||
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress_; | ||
}; | ||
Expr body = SuppressCompileTime(computable_at_compile_time)(orig_func->body); | ||
Expr body = SuppressCompileTime(to_suppress)(orig_func->body); | ||
body = SeqExpr({DataflowBlock(bindings)}, body); | ||
|
||
Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs); | ||
|
@@ -300,6 +318,7 @@ class LiftableBindingCollector : ExprVisitor { | |
|
||
for (size_t i = num_runtime_params; i < func->params.size(); i++) { | ||
liftable_vars_.insert(func->params[i]); | ||
info_.requires_compile_time_param.insert(func->params[i]); | ||
for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) { | ||
liftable_vars_.insert(tir_var); | ||
} | ||
|
@@ -315,12 +334,48 @@ class LiftableBindingCollector : ExprVisitor { | |
} | ||
|
||
void VisitBinding(const Binding& binding) override { | ||
auto bound_value = GetBoundValue(binding); | ||
|
||
if (CanLiftBinding(binding)) { | ||
info_.computable_at_compile_time.push_back(binding); | ||
liftable_vars_.insert(binding->var); | ||
|
||
// There are three type of variables we want to distinguish. | ||
// | ||
// 1. Depend on runtime parameters | ||
// | ||
// Must remain within the original function, cannot be | ||
// lifted out into the `transform_params` function. | ||
// | ||
// 2. Depend on model weights, but not runtime parameters. | ||
// | ||
// Legal to lift out into the `transform_params` function. | ||
// Doing so is beneficial, as it reduces the work performed | ||
// in the inference function. | ||
// | ||
// 3. Depend on neither model weights nor runtime parameters | ||
// (e.g. `R.zeros(shape,dtype)`) | ||
// | ||
// Legal to lift out into the `transform_params` function. | ||
// However, doing so would increase the memory footprint of | ||
// the pre-computed parameters, for little to no benefit. | ||
// These may be duplicated between the `transform_params` | ||
// function and the original function, as they typically | ||
// initialize a tensor to an easy-to-compute state. | ||
// | ||
// Tracking whether a variable depends on the model weights, | ||
// either directly or indirectly, allows us to distinguish | ||
// between categories (2) and (3). | ||
auto upstream_vars = FreeVars(bound_value); | ||
bool depends_on_compile_time_param = std::any_of( | ||
upstream_vars.begin(), upstream_vars.end(), | ||
[&](const Var& var) -> bool { return info_.requires_compile_time_param.count(var); }); | ||
if (depends_on_compile_time_param) { | ||
info_.requires_compile_time_param.insert(binding->var); | ||
} | ||
Comment on lines
+369
to
+375
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the logic that is meant to handle zero-param operators? I had a really hard time figuring out what the change was from #16594 and this was the only major area that came up in the diff. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This section along with this line, yes. This section collects information to determine if this parameter depends, directly or indirectly, on one of the model weights. The linked line uses the collected information to determine if the I've updated the comment here to indicate why the additional information is being collected. |
||
|
||
} else { | ||
info_.required_at_runtime.insert(binding->var); | ||
auto bound_value = GetBoundValue(binding); | ||
for (const auto& upstream_var : FreeVars(bound_value)) { | ||
info_.required_at_runtime.insert(upstream_var); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reason for this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
MarkGraphNode
function should be called for IR nodes that have reference equality, such as variables. Using it with therelax::CallNode
means that results fromStructuralEqual
will be dependent on how arelax::Call
node was constructed, even if they have identical contents.Relevant to this PR, if
R.zeros([16], "int32")
appears in bothmain
andtransform_params
, the expected output generated by the TVMScript parser would have two differentrelax::Call
objects, while the output ofLiftTransformParams
would use the samerelax::Call
object. Because therelax::Call
objects were being checked for analogous reference equality, this would causeStructuralEqual
to erroneously report the test as failing.I've added a unit test to specifically exercise this behavior, rather than implicitly relying on the tests for
LiftTransformParams
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!