Skip to content

Commit

Permalink
[Transform] Improvements to LazyTransformParams (#16602)
Browse files Browse the repository at this point in the history
* [Transform] Improvements to LazyTransformParams

* Handle non-bundled parameters in LazyTransformParams.

* Check for `"num_input"` attribute

* Handle relax.Const in LazyTransformParams

  Prior to this commit, `LazyTransformParams` would only output a call
  to the `fset_item` function if that element of the output had a
  corresponding `relax.Binding`.  If `relax.Const` appeared in the
  output, then the call to `fset_item` would be omitted.

  This commit updates `LazyTransformParams` to check for any non-`Var`
  elements of the output tuple.

* Update based on review comments
  • Loading branch information
Lunderberg committed Feb 22, 2024
1 parent 8fe0164 commit 5308ef1
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 20 deletions.
126 changes: 107 additions & 19 deletions python/tvm/relax/transform/lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class LivenessAnalysis(PyExprVisitor):
"""

def __init__(self, out_tuple_var: relax.Var) -> None:
self.last_appear_in_var_binding = None
self.last_appear_in_var_binding = []
self.out_tuple_var = out_tuple_var
self.var_liveness_end = {}
self.ended_vars = set()
Expand Down Expand Up @@ -132,20 +132,22 @@ def __init__(
self.extra_get_item_params = extra_get_item_params
self.fset_item = fset_item
self.extra_set_item_params = extra_set_item_params
# the only input param, which should be a Tuple
self.input_tuple_param = None
self.input_params_set = None
self.out_tuple_map = None
self.out_tuple_var = None
self.memory_free_insertion = None

def transform(self, func: relax.Function) -> relax.Function:
self.input_tuple_param = func.params[0]
if func.attrs is not None and "num_input" in func.attrs:
num_input = func.attrs["num_input"].value
else:
num_input = 0

seq_expr = func.body
self.out_tuple_var = seq_expr.body

# Step 1. collect out_tuple_map and input_params_set
forward_collector = ForwardCollector(self.out_tuple_var, self.input_tuple_param)
forward_collector = ForwardCollector(self.out_tuple_var, func.params[num_input])
forward_collector.visit_expr(func)
self.out_tuple_map = forward_collector.out_tuple_map
# input_params_set is the set of binding var for var = params[i]
Expand All @@ -157,24 +159,65 @@ def transform(self, func: relax.Function) -> relax.Function:
self.memory_free_insertion = liveness.var_liveness_end

# Step 3. rewrite get item and set item
new_body = func.body
if self.fget_item is not None:
new_body = LazyInputMutator(self, self.mod).visit_expr(new_body)
new_func = LazyInputMutator(self, self.mod).visit_expr(func)

new_body = new_func.body
if self.fset_item is not None:
# The LazyOutputMutator only inspects variable bindings
# for replacement. If the output tuple includes elements
# that do not have a variable binding, such as
# `relax.Const`, these must still produce a call to the
# `"set_item"` function.
leaf_outputs = {
expr: indices
for expr, indices in self.out_tuple_map.items()
if not isinstance(expr, relax.Var)
}
if leaf_outputs:
new_bindings = [
relax.VarBinding(
relax.Var("_", relax.ObjectStructInfo()),
relax.Call(
relax.ExternFunc(self.fset_item),
[*self.extra_set_item_params, index, expr],
None,
[relax.ObjectStructInfo()],
),
)
for expr, indices in leaf_outputs.items()
for index in indices
]
new_body = relax.SeqExpr(
[*new_body.blocks, relax.BindingBlock(new_bindings)], new_body.body
)

new_body = LazyOutputMutator(self, self.mod).visit_expr(new_body)

# Step 4. Add parameters of get_item and set_item (except index) to the function.
params = [*self.extra_get_item_params, *self.extra_set_item_params]
params = [
*func.params[:num_input],
*self.extra_get_item_params,
*self.extra_set_item_params,
]

# Step 5. Find all shape parameters that should be retained as
# parameters.
symbolic_vars = relax.analysis.defined_symbolic_vars(func)
if symbolic_vars:

def unpack_sinfo(sinfo):
if isinstance(sinfo, relax.TupleStructInfo):
for field in sinfo.fields:
yield from unpack_sinfo(field)
else:
yield sinfo

# direct iterate over the struct info annotation
for sinfo in self.input_tuple_param.struct_info.fields:
if not isinstance(sinfo, relax.TensorStructInfo):
params.append(relax.Var("symbolic_var_holder", sinfo))
for param in func.params[num_input:]:
for sinfo in unpack_sinfo(param.struct_info):
if not isinstance(sinfo, relax.TensorStructInfo):
params.append(relax.Var("symbolic_var_holder", sinfo))

return relax.Function(
params,
Expand All @@ -191,22 +234,67 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
self.func_creator = func_creator
super().__init__(mod)

def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> relax.Expr:
# rewrite get item
tuple_get_item = super().visit_tuple_getitem_(op)
if tuple_get_item.tuple_value == self.func_creator.input_tuple_param:
def visit_function_(self, func: relax.Function) -> relax.Expr:
if func.attrs is not None and "num_input" in func.attrs:
num_input = func.attrs["num_input"].value
else:
num_input = 0

params = list(func.params)[num_input:]
if len(params) == 1 and isinstance(params[0].struct_info_, relax.TupleStructInfo):
self.tuple_param = params[0]
self.params = {}
else:
self.tuple_param = None
self.params = {var: i for i, var in enumerate(params)}
func = relax.Function(
func.params[:num_input],
func.body,
func.ret_struct_info,
is_pure=False,
attrs=func.attrs,
span=func.span,
).without_attr("relax.force_pure")
output = super().visit_function_(func)
self.tuple_param = None
self.params = {}
return output

def visit_var_(self, var: relax.Var) -> relax.Expr:
if var in self.params:
index = self.params[var]
get_item_result = self.builder_.emit(
relax.Call(
relax.ExternFunc(self.func_creator.fget_item),
self.func_creator.extra_get_item_params + [relax.PrimValue(index)],
None,
[relax.ObjectStructInfo()],
)
)
match_cast = relax.MatchCast(var, get_item_result, var.struct_info)
self.builder_.emit_normalized(match_cast)

del self.params[var]

return super().visit_var_(var)

def visit_tuple_getitem_(self, node: relax.TupleGetItem) -> relax.Expr:
sinfo = node.struct_info

node = super().visit_tuple_getitem_(node)

if self.tuple_param is not None and node.tuple_value.same_as(self.tuple_param):
get_item_result = self.builder_.emit(
relax.Call(
relax.ExternFunc(self.func_creator.fget_item),
self.func_creator.extra_get_item_params
+ [relax.PrimValue(tuple_get_item.index)],
self.func_creator.extra_get_item_params + [relax.PrimValue(node.index)],
None,
[relax.ObjectStructInfo()],
)
)
return self.builder_.match_cast(get_item_result, op.struct_info)
return self.builder_.match_cast(get_item_result, sinfo)
else:
return tuple_get_item
return node


@mutator
Expand Down

0 comments on commit 5308ef1

Please sign in to comment.