Skip to content

Commit

Permalink
Update based on review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Feb 22, 2024
1 parent a8ca845 commit 8d69261
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 6 deletions.
41 changes: 35 additions & 6 deletions src/relax/transform/lift_transform_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,14 @@ struct CollectInfo {
}
}

struct SuppressCompileTime : ExprMutator {
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
class SuppressCompileTime : public ExprMutator {
public:
explicit SuppressCompileTime(
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress)
: to_suppress(to_suppress) {}
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress)
: to_suppress_(to_suppress) {}

void VisitBinding(const Binding& binding) override {
if (!to_suppress.count(binding->var)) {
if (!to_suppress_.count(binding->var)) {
ExprMutator::VisitBinding(binding);
}
}
Expand All @@ -218,8 +218,11 @@ struct CollectInfo {
return ExprMutator::VisitExpr_(call);
}
}

private:
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& to_suppress_;
};
Expr body = SuppressCompileTime(std::move(to_suppress))(orig_func->body);
Expr body = SuppressCompileTime{to_suppress}(orig_func->body);
body = SeqExpr({DataflowBlock(bindings)}, body);

Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs);
Expand Down Expand Up @@ -335,6 +338,32 @@ class LiftableBindingCollector : ExprVisitor {
info_.computable_at_compile_time.push_back(binding);
liftable_vars_.insert(binding->var);

// There are three type of variables we want to distinguish.
//
// 1. Depend on runtime parameters
//
// Must remain within the original function, cannot be
// lifted out into the `transform_params` function.
//
// 2. Depend on model weights, but not runtime parameters.
//
// Legal to lift out into the `transform_params` function.
// Doing so is beneficial, as it reduces the work performed
// in the inference function.
//
// 3. Depend on neither model weights nor runtime parameters
// (e.g. `R.zeros(shape,dtype)`)
//
// Legal to lift out into the `transform_params` function.
// However, doing so would increase the memory footprint of
// the pre-computed parameters, for little to no benefit.
// These may be duplicated between the `transform_params`
// function and the original function, as they typically
// initialize a tensor to an easy-to-compute state.
//
// Tracking whether a variable depends on the model weights,
// either directly or indirectly, allows us to distinguish
// between categories (2) and (3).
auto upstream_vars = FreeVars(bound_value);
bool depends_on_compile_time_param = std::any_of(
upstream_vars.begin(), upstream_vars.end(),
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,28 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
assert_structural_equal(Actual, Expected)


def test_structural_equal_of_call_nodes():
"""relax.Call must be compared by structural equality, not reference"""

# Three identical calls to relax.op.zeros
calls_to_op_zero = [relax.op.zeros([16], "int32") for _ in range(3)]

@R.function(private=True)
def uses_same_object_twice():
A = calls_to_op_zero[0]
B = calls_to_op_zero[0]
C = R.add(A, B)
return C

@R.function(private=True)
def uses_two_different_objects():
A = calls_to_op_zero[1]
B = calls_to_op_zero[2]
C = R.add(A, B)
return C

tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 8d69261

Please sign in to comment.