[Transform] Check for zero-param operators in LiftTransformParams (#16595)

Prior to this commit, `LiftTransformParams` would lift out every
variable binding that has no runtime dependency.  As a result,
expressions such as `R.zeros([16], "int32")` would be extracted into
the parameter transformation, even though they do not depend on any
parameters.

This commit updates `LiftTransformParams` to lift only those bindings
that depend on at least one compile-time parameter.

The unit test for this functionality also found that `relax::Call` was
erroneously calling `MarkGraphNode` in `SEqualReduce` and
`SHashReduce`.  This should only be called for nodes that have
reference equality, such as `relax::Var`, and not for composite
objects.  This caused erroneous failures in the unit test when two
instances of `R.zeros([16], "int32")` were compared by reference
equality in `StructuralEqual`.  These extra calls to `MarkGraphNode`
have been removed.
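For context, `LiftTransformParams` operates on functions whose `num_input` attribute marks the first N parameters as runtime inputs and the remaining parameters as compile-time model weights. The following is a minimal sketch of that convention and of the behavior fixed here; the module and all names in it are illustrative, not part of the commit:

import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor([16], "int32"), weight: R.Tensor([16], "int32")):
        # The first parameter (`x`) is a runtime input; all later
        # parameters are compile-time model weights.
        R.func_attr({"num_input": 1})
        with R.dataflow():
            # Depends only on a weight: liftable into transform_params.
            doubled = R.add(weight, weight)
            # Depends on neither weights nor runtime inputs: prior to
            # this commit it was lifted as well; now it stays in `main`.
            filler = R.zeros([16], "int32")
            partial = R.add(x, doubled)
            out = R.add(partial, filler)
            R.output(out)
        return out


# Produces `main` plus a `main_transform_params` function that
# pre-computes `doubled` from the weights ahead of time.
lifted = relax.transform.LiftTransformParams()(Module)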
Lunderberg committed Feb 29, 2024
1 parent c2c579b commit e261a27
Showing 4 changed files with 142 additions and 12 deletions.
2 changes: 0 additions & 2 deletions include/tvm/relax/expr.h
@@ -169,13 +169,11 @@ class CallNode : public ExprNode {

   bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
     // skip sinfo_args check for primitive ops.
-    equal->MarkGraphNode();
     return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) &&
            equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce->MarkGraphNode();
     hash_reduce(op);
     hash_reduce(args);
     hash_reduce(attrs);
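With `MarkGraphNode` no longer called for `relax::Call`, structurally identical call expressions compare and hash by their contents rather than by object identity. A quick illustrative check of the intended invariant (a sketch mirroring the regression test added in test_utils.py below):

import tvm
from tvm import relax

# Two distinct Call objects with identical structure.
call_a = relax.op.zeros([16], "int32")
call_b = relax.op.zeros([16], "int32")
assert call_a is not call_b

# Composite nodes are compared field-by-field, so the two calls are
# structurally equal and produce the same structural hash.
tvm.ir.assert_structural_equal(call_a, call_b)
assert tvm.ir.structural_hash(call_a) == tvm.ir.structural_hash(call_b)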
75 changes: 65 additions & 10 deletions src/relax/transform/lift_transform_params.cc
@@ -58,6 +58,15 @@ struct CollectInfo {
    */
   std::vector<Binding> computable_at_compile_time;
 
+  /*! \brief Variables that require a compile-time parameter
+   *
+   * Used to distinguish between computed tensors that depend on the
+   * model weights, and computed tensors that require neither model
+   * weights nor runtime arguments (e.g. `R.zeros([16], "float16")`).
+   */
+  std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
+      requires_compile_time_param;
+
   /*! \brief Variables that are required at runtime */
   std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, ObjectPtrEqual>
       required_at_runtime;
@@ -114,7 +123,8 @@ struct CollectInfo {
     // Any variable that is computed at compile-time, but is required
     // at runtime, must be provided as a parameter.
     for (const auto& binding : computable_at_compile_time) {
-      if (required_at_runtime.count(binding->var)) {
+      if (requires_compile_time_param.count(binding->var) &&
+          required_at_runtime.count(binding->var)) {
         params.push_back(binding->var);
       }
     }
@@ -182,16 +192,21 @@ struct CollectInfo {

     // Any binding that is computable at compile-time should be
     // suppressed at run-time.
-    struct SuppressCompileTime : ExprMutator {
-      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
-      explicit SuppressCompileTime(const std::vector<Binding>& bindings) {
-        for (const auto& binding : bindings) {
-          to_suppress.insert(binding->var);
-        }
-      }
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
+    for (const auto& binding : computable_at_compile_time) {
+      if (requires_compile_time_param.count(binding->var)) {
+        to_suppress.insert(binding->var);
+      }
+    }
+
+    class SuppressCompileTime : public ExprMutator {
+     public:
+      explicit SuppressCompileTime(
+          const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress)
+          : to_suppress_(to_suppress) {}
 
       void VisitBinding(const Binding& binding) override {
-        if (!to_suppress.count(binding->var)) {
+        if (!to_suppress_.count(binding->var)) {
           ExprMutator::VisitBinding(binding);
         }
       }
@@ -205,8 +220,11 @@ struct CollectInfo {
           return ExprMutator::VisitExpr_(call);
         }
       }
+
+     private:
+      const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress_;
     };
-    Expr body = SuppressCompileTime(computable_at_compile_time)(orig_func->body);
+    Expr body = SuppressCompileTime(to_suppress)(orig_func->body);
     body = SeqExpr({DataflowBlock(bindings)}, body);
 
     Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs);
@@ -300,6 +318,7 @@ class LiftableBindingCollector : ExprVisitor {

     for (size_t i = num_runtime_params; i < func->params.size(); i++) {
       liftable_vars_.insert(func->params[i]);
+      info_.requires_compile_time_param.insert(func->params[i]);
       for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) {
         liftable_vars_.insert(tir_var);
       }
@@ -315,12 +334,48 @@
   }
 
   void VisitBinding(const Binding& binding) override {
+    auto bound_value = GetBoundValue(binding);
+
     if (CanLiftBinding(binding)) {
       info_.computable_at_compile_time.push_back(binding);
       liftable_vars_.insert(binding->var);
+
+      // There are three types of variables we want to distinguish.
+      //
+      // 1. Depend on runtime parameters
+      //
+      //    Must remain within the original function, and cannot be
+      //    lifted out into the `transform_params` function.
+      //
+      // 2. Depend on model weights, but not on runtime parameters.
+      //
+      //    Legal to lift out into the `transform_params` function.
+      //    Doing so is beneficial, as it reduces the work performed
+      //    in the inference function.
+      //
+      // 3. Depend on neither model weights nor runtime parameters
+      //    (e.g. `R.zeros(shape, dtype)`)
+      //
+      //    Legal to lift out into the `transform_params` function.
+      //    However, doing so would increase the memory footprint of
+      //    the pre-computed parameters for little to no benefit.
+      //    These bindings may instead be duplicated between the
+      //    `transform_params` function and the original function, as
+      //    they typically initialize a tensor to an easy-to-compute
+      //    state.
+      //
+      // Tracking whether a variable depends on the model weights,
+      // either directly or indirectly, distinguishes between
+      // categories (2) and (3).
+      auto upstream_vars = FreeVars(bound_value);
+      bool depends_on_compile_time_param = std::any_of(
+          upstream_vars.begin(), upstream_vars.end(),
+          [&](const Var& var) -> bool { return info_.requires_compile_time_param.count(var); });
+      if (depends_on_compile_time_param) {
+        info_.requires_compile_time_param.insert(binding->var);
+      }
+
     } else {
       info_.required_at_runtime.insert(binding->var);
-      auto bound_value = GetBoundValue(binding);
       for (const auto& upstream_var : FreeVars(bound_value)) {
         info_.required_at_runtime.insert(upstream_var);
       }
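The classification described in the comment above is a simple forward propagation over bindings: a variable requires a compile-time parameter exactly when one of its free variables does. A rough Python equivalent of that rule, assuming `tvm.relax.analysis.free_vars` and a function body in dataflow form (the helper name and structure are illustrative, not the pass's actual implementation):

from tvm import relax


def vars_requiring_compile_time_params(func: relax.Function, num_input: int):
    """Forward-propagate 'depends on a model weight' through bindings."""
    # The trailing parameters (after the first `num_input`) seed the set.
    tainted = list(func.params[num_input:])

    def is_tainted(var):
        return any(var.same_as(t) for t in tainted)

    for block in func.body.blocks:
        for binding in block.bindings:
            if any(is_tainted(v) for v in relax.analysis.free_vars(binding.value)):
                tainted.append(binding.var)
    return tainted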
54 changes: 54 additions & 0 deletions tests/python/relax/test_transform_lift_transform_params.py
@@ -795,5 +795,59 @@ def main(
    tvm.ir.assert_structural_equal(Expected, After)


def test_only_lift_when_variable_uses_constants():
    """A variable that has no inputs should not be lifted

    For example, `R.zeros`, or the result of allocation function
    calls.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B: R.Tensor([16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                offset = R.ones([16], "int32")
                A_offset = R.add(A, offset)
                B_offset = R.add(B, offset)
                output = R.multiply(A_offset, B_offset)
                R.output(output)
            return output

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            A: R.Tensor([16], "int32"),
            B_offset: R.Tensor([16], "int32"),
        ):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                offset = R.ones([16], "int32")
                A_offset = R.add(A, offset)
                output = R.multiply(A_offset, B_offset)
                R.output(output)
            return output

        @R.function
        def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])):
            R.func_attr({"num_input": 0})
            with R.dataflow():
                offset = R.ones([16], "int32")
                B = params[0]
                B_offset = R.add(B, offset)
                output = (B_offset,)
                R.output(output)
            return output

    mod = Before
    after = relax.transform.LiftTransformParams()(mod)
    tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
    tvm.testing.main()
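For the module above, deployment would typically run the lifted function once ahead of time and feed its result to `main` at inference time. A hedged sketch of that flow with the Relax VM; the target, device, and lowering details vary by TVM version and are illustrative here:

import numpy as np
import tvm
from tvm import relax

after = relax.transform.LiftTransformParams()(Before)
# Depending on the TVM version, a legalization/lowering pipeline may be
# required before building, e.g. relax.get_pipeline("zero")(after).
ex = relax.build(after, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# Run once, ahead of time, on the model weights.
weights = (tvm.nd.array(np.ones(16, dtype="int32")),)
transformed = vm["main_transform_params"](weights)

# Run at inference time with the runtime input plus transformed weights.
x = tvm.nd.array(np.arange(16, dtype="int32"))
out = vm["main"](x, transformed[0])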
23 changes: 23 additions & 0 deletions tests/python/relax/test_utils.py
@@ -122,5 +122,28 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
    assert_structural_equal(Actual, Expected)


def test_structural_equal_of_call_nodes():
    """relax.Call must be compared by structural equality, not reference"""

    # Three identical calls to relax.op.zeros
    calls_to_op_zero = [relax.op.zeros([16], "int32") for _ in range(3)]

    @R.function(private=True)
    def uses_same_object_twice():
        A = calls_to_op_zero[0]
        B = calls_to_op_zero[0]
        C = R.add(A, B)
        return C

    @R.function(private=True)
    def uses_two_different_objects():
        A = calls_to_op_zero[1]
        B = calls_to_op_zero[2]
        C = R.add(A, B)
        return C

    tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects)


if __name__ == "__main__":
    pytest.main([__file__])
