New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform] Improvements to LazyTransformParams #16602
[Transform] Improvements to LazyTransformParams #16602
Conversation
* Handle non-bundled parameters in LazyTransformParams. * Check for `"num_input"` attribute * Handle relax.Const in LazyTransformParams Prior to this commit, `LazyTransformParams` would only output a call to the `fset_item` function if that element of the output had a corresponding `relax.Binding`. If `relax.Const` appeared in the output, then the call to `fset_item` would be omitted. This commit updates `LazyTransformParams` to check for any non-`Var` elements of the output tuple.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, the changes look good. My only concern is if all the added functionality is being tested.
for expr, indices in leaf_outputs.items() | ||
for index in indices |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You've gotta love the syntax for nested list comprehensions 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Absolutely. I think I prefer the chained iterators used in Rust or C#, but Pythons list/dict/set comprehensions are a close second.
def unpack_sinfo(sinfo): | ||
if isinstance(sinfo, relax.TupleStructInfo): | ||
for field in sinfo.fields: | ||
yield from unpack_sinfo(field) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First I'd seen yield from
, this seems like a good use for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and I really like using yield from
for providing flattened results from a graph structures. This way, the result of unpack_sinfo(top_level.struct_info)
can be collected into a list, but there aren't any temporary lists made along the way.
leaf_outputs = { | ||
expr: indices | ||
for expr, indices in self.out_tuple_map.items() | ||
if not isinstance(expr, relax.Var) | ||
} | ||
if leaf_outputs: | ||
new_bindings = [ | ||
relax.VarBinding( | ||
relax.Var("_", relax.ObjectStructInfo()), | ||
relax.Call( | ||
relax.ExternFunc(self.fset_item), | ||
[*self.extra_set_item_params, index, expr], | ||
None, | ||
[relax.ObjectStructInfo()], | ||
), | ||
) | ||
for expr, indices in leaf_outputs.items() | ||
for index in indices | ||
] | ||
new_body = relax.SeqExpr( | ||
[*new_body.blocks, relax.BindingBlock(new_bindings)], new_body.body | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I presume these additions are for handling the non-var case mentioned in the description?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct. Otherwise, a R.const(...)
that occurs within the output tuple wouldn't produce a call to the fset_item
function. I've added a comment to clarify.
After = LazyTransformParams(fset_item=None)(Before) | ||
tvm.ir.assert_structural_equal(After, Expected) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any test cases that make use of extra parameters for get_item and set_item? If it's not tested, it should be. If there also isn't a case of a non-var output (I'm not sure exactly what that should look like, as I haven't used this pass), that would be good to add too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should be an existing unit test test_extra_params
that validates the extra_get_item_params
. However, there aren't any unit tests that validate extra_set_item_params
, nor are there any that validate extra_set_item_params
in the code path for R.const(...)
outputs.
I've updated the unit tests to cover those additional cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the additional tests.
Handle non-bundled parameters in LazyTransformParams.
Check for
"num_input"
attributeHandle relax.Const in LazyTransformParams
Prior to this commit,
LazyTransformParams
would only output a call to thefset_item
function if that element of the output had a correspondingrelax.Binding
. Ifrelax.Const
appeared in the output, then the call tofset_item
would be omitted.This commit updates
LazyTransformParams
to check for any non-Var
elements of the output tuple.