Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Transform] Check for zero-param operators in LiftTransformParams #16595

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,11 @@ class CallNode : public ExprNode {

bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
// skip sinfo_args check for primitive ops.
equal->MarkGraphNode();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MarkGraphNode function should be called for IR nodes that have reference equality, such as variables. Using it with the relax::CallNode means that results from StructuralEqual will be dependent on how a relax::Call node was constructed, even if they have identical contents.

Relevant to this PR, if R.zeros([16], "int32") appears in both main and transform_params, the expected output generated by the TVMScript parser would have two different relax::Call objects, while the output of LiftTransformParams would use the same relax::Call object. Because the relax::Call objects were being checked for analogous reference equality, this would cause StructuralEqual to erroneously report the test as failing.

I've added a unit test to specifically exercise this behavior, rather than implicitly relying on the tests for LiftTransformParams.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(op);
hash_reduce(args);
hash_reduce(attrs);
Expand Down
75 changes: 65 additions & 10 deletions src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ struct CollectInfo {
*/
std::vector<Binding> computable_at_compile_time;

/*! \brief Variables that require a compile-time parameter
*
* Used to distinguish between computed tensors that depend on the
* model weights, and computed tensors that require neither model
* weights nor runtime arguments (e.g. `R.zeros([16], "float16")`).
*/
std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
requires_compile_time_param;

/*! \brief Variables that are required at runtime */
std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
required_at_runtime;
Expand Down Expand Up @@ -114,7 +123,8 @@ struct CollectInfo {
// Any variable that is computed at compile-time, but is required
// at runtime, must be provided as a parameter.
for (const auto& binding : computable_at_compile_time) {
if (required_at_runtime.count(binding->var)) {
if (requires_compile_time_param.count(binding->var) &&
required_at_runtime.count(binding->var)) {
params.push_back(binding->var);
}
}
Expand Down Expand Up @@ -182,16 +192,21 @@ struct CollectInfo {

// Any binding that is computable at compile-time should be
// suppressed at run-time.
struct SuppressCompileTime : ExprMutator {
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
explicit SuppressCompileTime(const std::vector<Binding>& bindings) {
for (const auto& binding : bindings) {
to_suppress.insert(binding->var);
}
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
for (const auto& binding : computable_at_compile_time) {
if (requires_compile_time_param.count(binding->var)) {
to_suppress.insert(binding->var);
}
}

class SuppressCompileTime : public ExprMutator {
public:
explicit SuppressCompileTime(
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress)
: to_suppress_(to_suppress) {}

void VisitBinding(const Binding& binding) override {
if (!to_suppress.count(binding->var)) {
if (!to_suppress_.count(binding->var)) {
ExprMutator::VisitBinding(binding);
}
}
Expand All @@ -205,8 +220,11 @@ struct CollectInfo {
return ExprMutator::VisitExpr_(call);
}
}

private:
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress_;
};
Expr body = SuppressCompileTime(computable_at_compile_time)(orig_func->body);
Expr body = SuppressCompileTime(to_suppress)(orig_func->body);
body = SeqExpr({DataflowBlock(bindings)}, body);

Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs);
Expand Down Expand Up @@ -300,6 +318,7 @@ class LiftableBindingCollector : ExprVisitor {

for (size_t i = num_runtime_params; i < func->params.size(); i++) {
liftable_vars_.insert(func->params[i]);
info_.requires_compile_time_param.insert(func->params[i]);
for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) {
liftable_vars_.insert(tir_var);
}
Expand All @@ -315,12 +334,48 @@ class LiftableBindingCollector : ExprVisitor {
}

void VisitBinding(const Binding& binding) override {
auto bound_value = GetBoundValue(binding);

if (CanLiftBinding(binding)) {
info_.computable_at_compile_time.push_back(binding);
liftable_vars_.insert(binding->var);

// There are three types of variables we want to distinguish.
//
// 1. Depend on runtime parameters
//
// Must remain within the original function, cannot be
// lifted out into the `transform_params` function.
//
// 2. Depend on model weights, but not runtime parameters.
//
// Legal to lift out into the `transform_params` function.
// Doing so is beneficial, as it reduces the work performed
// in the inference function.
//
// 3. Depend on neither model weights nor runtime parameters
// (e.g. `R.zeros(shape, dtype)`)
//
// Legal to lift out into the `transform_params` function.
// However, doing so would increase the memory footprint of
// the pre-computed parameters, for little to no benefit.
// These may be duplicated between the `transform_params`
// function and the original function, as they typically
// initialize a tensor to an easy-to-compute state.
//
// Tracking whether a variable depends on the model weights,
// either directly or indirectly, allows us to distinguish
// between categories (2) and (3).
auto upstream_vars = FreeVars(bound_value);
bool depends_on_compile_time_param = std::any_of(
upstream_vars.begin(), upstream_vars.end(),
[&](const Var& var) -> bool { return info_.requires_compile_time_param.count(var); });
if (depends_on_compile_time_param) {
info_.requires_compile_time_param.insert(binding->var);
}
Comment on lines +369 to +375
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the logic that is meant to handle zero-param operators? I had a really hard time figuring out what the change was from #16594 and this was the only major area that came up in the diff.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section along with this line, yes. This section collects information to determine if this parameter depends, directly or indirectly, on one of the model weights. The linked line uses the collected information to determine if the transform_params should output a variable.

I've updated the comment here to indicate why the additional information is being collected.


} else {
info_.required_at_runtime.insert(binding->var);
auto bound_value = GetBoundValue(binding);
for (const auto& upstream_var : FreeVars(bound_value)) {
info_.required_at_runtime.insert(upstream_var);
}
Expand Down
54 changes: 54 additions & 0 deletions tests/python/relax/test_transform_lift_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,5 +795,59 @@ def main(
tvm.ir.assert_structural_equal(Expected, After)


def test_only_lift_when_variable_uses_constants():
    """A variable that has no inputs should not be lifted

    For example, `R.zeros`, or the result of allocation function
    calls.

    In the modules below, `offset` is produced by `R.ones` and has no
    inputs at all.  The expected output keeps `offset` in `main` and
    also computes it inside `main_transform_params`, instead of
    returning it from `main_transform_params` as an extra parameter.
    Only `B_offset`, which is computed from the parameter `B`, becomes
    a parameter of the transformed `main`.
    """

    # Module handed to LiftTransformParams.
    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B: R.Tensor([16], "int32"),
        ):
            # NOTE(review): presumably "num_input": 1 marks `A` as the
            # only runtime input, leaving `B` as the parameter eligible
            # for lifting -- consistent with `Expected` below, confirm
            # against the pass documentation.
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # No inputs: depends on neither `A` nor `B`.
                offset = R.ones([16], "int32")
                # Depends on `A`.
                A_offset = R.add(A, offset)
                # Depends on `B` only (plus the input-free `offset`).
                B_offset = R.add(B, offset)
                output = R.multiply(A_offset, B_offset)
                R.output(output)
            return output

    # Module the pass is expected to produce.
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B_offset: R.Tensor([16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                # `offset` is still computed here; it is not received
                # as a parameter.
                offset = R.ones([16], "int32")
                A_offset = R.add(A, offset)
                # `B_offset` now arrives as a function parameter.
                output = R.multiply(A_offset, B_offset)
                R.output(output)
            return output

        @R.function
        def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])):
            R.func_attr({"num_input": 0})
            with R.dataflow():
                # `offset` is computed here as well, because `B_offset`
                # needs it, but it is not part of the returned tuple.
                offset = R.ones([16], "int32")
                B = params[0]
                B_offset = R.add(B, offset)
                # Only `B_offset` is returned.
                output = (B_offset,)
                R.output(output)
            return output

    mod = Before
    after = relax.transform.LiftTransformParams()(mod)
    tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
tvm.testing.main()
23 changes: 23 additions & 0 deletions tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,28 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
assert_structural_equal(Actual, Expected)


def test_structural_equal_of_call_nodes():
    """relax.Call must be compared by structural equality, not reference

    Two functions are built from calls with identical contents.  One
    binds the same list element twice, the other binds two different
    list elements.  The two functions must be structurally equal.
    """

    # Three calls with identical arguments, each produced by a separate
    # invocation of relax.op.zeros
    calls_to_op_zero = [relax.op.zeros([16], "int32") for _ in range(3)]

    @R.function(private=True)
    def uses_same_object_twice():
        # Both bindings use element 0 of the list.
        A = calls_to_op_zero[0]
        B = calls_to_op_zero[0]
        C = R.add(A, B)
        return C

    @R.function(private=True)
    def uses_two_different_objects():
        # The bindings use elements 1 and 2 of the list.
        A = calls_to_op_zero[1]
        B = calls_to_op_zero[2]
        C = R.add(A, B)
        return C

    # Raises if the two functions are not structurally equal.
    tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects)


if __name__ == "__main__":
pytest.main([__file__])