Unify Python and C++ TIR lower API #8110
src/driver/driver_api.cc
Outdated
} else {
  ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule or "
                << "IRModule";
  throw;
Don't add a throw here, instead return an empty IRModule(...)
@CircleSpin @electriclilies Ping me again when tests are green. I left one comment that jumped out at me; otherwise this looks good if we can get it to pass the test suites.
cc @manupa-arm @giuseros This is worth paying attention to: we have been working on cleaning up the internal APIs and bringing everything to C++.
Thanks for this work! Much appreciated. :)
Just a few clarification questions about the overall design.
include/tvm/driver/driver_api.h
Outdated
bool simple_mode = false);

/*!
 * \brief Build an IRModule given a module, args and binds
Typo: should this be a PrimFunc?
src/driver/driver_api.cc
Outdated
pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
    .value();

auto user_lower_phase0 = Array<tvm::transform::Pass>();
Clarification: what is the importance of having phases, and what do they represent?
I'm not sure, it would be good to get some clarification from someone else on this, maybe @junrushao1994 knows?
It would be nice to take the refactoring opportunity to add some documentation on this, though I completely understand that you have mirrored the previous lower implementation.
I'm happy to add documentation; I just don't actually know the answer to this question.
src/driver/driver_api.cc
Outdated
IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
  Array<ObjectRef> out_arg_list;
Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
Clarification: what does `legacy_te_pass` mean here?
`legacy_te_pass` means that the object we are lowering is a te::Schedule. I've copied the logic and phrasing directly from the Python version; I don't know the details of why this is still here. But without it, some tests fail.
Maybe we could be more specific in the name, e.g. `bool from_te_schedule`?
Also, I see `simple_mode` just toggles LoopPartition; could we use `bool enable_loop_partition`?
src/driver/driver_api.cc
Outdated
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
// Phase 1
// PHASE 1
if (legacy_te_pass) {
@jroesch @CircleSpin
A general design question: what do you think of relocating the pass pipeline closer to the target registry? That way we could have a function (which could also be parameterized by target args) that describes the pass pipeline, rather than mandating the passes here.
Here, we could just query the pass pipeline based on the target and run it. I think this could remove the need to support user-defined custom passes at this level.
Thoughts?
I think doing something like this is a good idea. I also want to simplify the signature for Lower.
Right now we have a lot of function signatures for Lower because the Python version allows multiple input types for each argument.
To clarify, are you saying that we would construct the pass list in the registration of `driver.lower`, then pass it into a function like `LowerWithPassList`?
Yes. OK, having looked at what's being removed, maybe this is out of scope for the PR.
My original question was about the motivation for custom passes, specifically the four possible locations where they get inserted (known as phases). If the phases have some meaning, it might be worth adding a comment.
The next question followed from the custom-pass feature: I was thinking we could instead have a custom pass pipeline registered or provided close to the target. That way, someone could create a new target and reuse or reorganize the passes that need to run.
Again, I now feel that's out of scope for this PR, as it essentially mimics what the Python lower is doing. :)
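For readers following the phase question, here is a minimal sketch of how the four insertion points could behave. It assumes, based on the Python lower that this PR mirrors, that each entry of `tir.add_lower_pass` is a (phase, pass) pair and that any phase above 2 lands in the last slot. The names `bucket_user_passes` and `create_pass_list`, the negative-phase check, and the string stand-ins for passes are hypothetical and not part of the TVM API.

```python
def bucket_user_passes(add_lower_pass):
    """Group (phase, pass) pairs into four buckets keyed 0..3."""
    buckets = {0: [], 1: [], 2: [], 3: []}
    for phase, lower_pass in add_lower_pass:
        if phase < 0:
            # Hypothetical guard for this sketch only.
            raise ValueError(f"phase must be non-negative, got {phase}")
        # Assumption: phases above 2 are all treated as the final phase.
        buckets[min(phase, 3)].append(lower_pass)
    return buckets


def create_pass_list(builtin_passes, add_lower_pass):
    """Interleave built-in passes with user passes, phase by phase."""
    user = bucket_user_passes(add_lower_pass)
    pass_list = []
    for phase in range(4):
        # Built-in passes of a phase run first, then the user passes for it.
        pass_list.extend(builtin_passes.get(phase, []))
        pass_list.extend(user[phase])
    return pass_list


# Strings stand in for real pass objects.
builtin = {1: ["StorageFlatten"], 2: ["LoopPartition"], 3: ["Simplify"]}
custom = [(0, "my_early_pass"), (2, "my_mid_pass"), (7, "my_late_pass")]
ordered = create_pass_list(builtin, custom)
```

Under these assumptions a phase number is just a coarse ordering hint: it says after which group of built-in passes a custom pass should run.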
Thanks! Some further thoughts.
@manupa-arm Yeah, I'm not exactly sure what the meaning of each "phase" is. Again, I'm just trying to duplicate the Python code in C++. @junrushao1994 Could you weigh in on the significance of each phase? Also, thanks for the naming suggestions. I'm going to go through and clean things up later, and I will take them into account then! Ideally I'd like to remove flags if possible.
Thanks Tristan!
/*!
 * \brief Build an IRModule given a schedule, args and binds
 * \param sch The schedule to lower.
 * \param args The arguments to the function.
 * \param name The name of the lowered function.
 * \param binds Buffer assignments.
 * \param simple_mode Disables the loop partition pass. Defaults to false.
I'm not completely sure. This logic was already in the codebase, so I just duplicated it and made it more explicit. (For context, when I started this refactor, Jared said he wanted to get it in quickly, so I should just try to naively duplicate the existing logic, but he hasn't reviewed it yet so I'm not sure what the timeline is now.)
I can try to remove it, but I'm not sure what the best way to go about this is, since there are a few tests that call lower directly.
 * \param binds Buffer assignments.
 * \return The result module.
 */
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
Initially this was a function called `form_irmodule` in Python. I translated it to C++ and renamed it to `ScheduleToModule`. Unfortunately, `form_irmodule` is called by some tests. To preserve that behavior I had to split the functions apart again and register them separately in the FFI.
The difference between these is that `ScheduleToModule` just converts the schedule to a module that hasn't yet been lowered, whereas `LowerSchedule` converts the schedule into a module and then applies the passes.
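A toy sketch of that split may help, assuming nothing about the real TVM types: a "module" here is a plain dict and a "pass" is any function from module to module. All function names below are hypothetical Python stand-ins for the C++ functions discussed in this thread.

```python
def schedule_to_module(schedule, name):
    """Stand-in for ScheduleToModule: wrap the schedule, apply no passes."""
    return {"name": name, "stmts": list(schedule), "applied": []}


def lower_with_pass_list(mod, pass_list):
    """Stand-in for LowerWithPassList: run each pass in order."""
    for lower_pass in pass_list:
        mod = lower_pass(mod)
    return mod


def lower_schedule(schedule, name, pass_list):
    """Stand-in for LowerSchedule: convert first, then apply the passes."""
    return lower_with_pass_list(schedule_to_module(schedule, name), pass_list)


def record(pass_name):
    """Build a toy pass that only records that it ran."""
    def run(mod):
        return {**mod, "applied": mod["applied"] + [pass_name]}
    return run


unlowered = schedule_to_module(["stage0"], "main")
lowered = lower_schedule(["stage0"], "main", [record("flatten"), record("simplify")])
```

Tests that only need the unlowered module can call the first function directly, which is the behavior the separate FFI registration preserves.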
@@ -136,7 +98,7 @@ def lower(

    Parameters
    ----------
-   input : Union[schedule.Schedule, PrimFunc, IRModule]
+   inputs : Union[schedule.Schedule, PrimFunc, IRModule]
`input` is a Python built-in function, so a parameter named `input` would shadow it.
Initially the name of the parameter to the function was `inputs` and the name of the parameter in the documentation was `input`. I just changed the documentation to match the function signature. I don't like that they are different. I agree that `inputs` is not an ideal name though, and I'm open to suggestions for other names.
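To make the naming concern concrete, here is a small sketch showing what happens when a parameter is called `input`. Python allows it, but the built-in becomes unreachable by that name inside the function. The two function names are hypothetical and only illustrate the point.

```python
import builtins


def lower_shadowing(input):
    # Inside this function, `input` is the argument, not the built-in.
    return input is builtins.input


def lower_renamed(inputs):
    # The built-in `input` is still reachable under its usual name.
    return input is builtins.input


shadowed = lower_shadowing("schedule")
intact = lower_renamed("schedule")
```

So the rename is a style and safety choice rather than a hard language requirement.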
Absolutely phenomenal improvement, thank you for the very well thought out refactor!
include/tvm/driver/driver_api.h
Outdated
@@ -42,17 +43,64 @@
#include <vector>

namespace tvm {

/*!
 * \brief Build an IRModule given a module, args and binds
- * \brief Build an IRModule given a module, args and binds
+ * \brief Build an IRModule given a module
Updated the comment to reflect the API. Please check the docs for the other APIs that don't take args/binds as well.
Updated!
python/tvm/driver/build_module.py
Outdated
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    binds = {} if binds is None else binds.copy()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                buf = tvm.tir.decl_buffer(
                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
                )
                binds[x] = buf
                arg_list.append(buf)
            else:
                arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, tvm.tir.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list


def form_irmodule(sch, args, name, binds):
    """According to the given schedule, form a function.
    out_arr = ffi.get_binds(args, compact, binds)
    return out_arr[0], out_arr[1]

We should probably update the docstring to match the returned variable names. Actually the reverse would be preferable, `binds, arg_list = ffi.get_binds(...)`, but I understand you are limited here by the object system.
Yup can do this
src/relay/backend/compile_engine.cc
Outdated
cache_node->funcs =
    tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds);
We shouldn't need to route through the `relay.backend.lower` packed function any longer, right? I'd suggest either removing it and the above conditional (line 766), or reproducing its functionality in C++ if we do need it as part of this PR. This way we avoid an unnecessary round trip through Python.
Yeah I don't think we do. I'll remove it (I didn't look too closely at what this call was doing, oops!)
Co-authored-by: Chris Sullivan <csullivan@octoml.ai>
/*!
 * \brief Build an IRModule given a schedule, args and binds
 * \param sch The schedule to lower.
 * \param args The arguments to the function.
 * \param name The name of the lowered function.
 * \param binds Buffer assignments.
 * \param simple_mode Disables the loop partition pass. Defaults to false.
It is probably fine for now.
 * \param binds Buffer assignments.
 * \return The result module.
 */
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
Can you be a little more explicit about this in the documentation?
@@ -136,7 +98,7 @@ def lower(

    Parameters
    ----------
-   input : Union[schedule.Schedule, PrimFunc, IRModule]
+   inputs : Union[schedule.Schedule, PrimFunc, IRModule]
Maybe use `inp` instead? I can't really think of a good name either.
Thanks for this great work! I think it should work for TensorIR as long as it passes all existing unit tests :) Generally, it looks good to me except for some code style nits.
src/driver/driver_api.cc
Outdated
    c_binds.insert({kv.first, kv.second});
  }
}
IRModule mod = ScheduleToModule(sch, args, name, c_binds);
- IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+ IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds);
src/driver/driver_api.cc
Outdated

IRModule LowerModule(IRModule mod, bool simple_mode) {
  auto pass_list = CreatePassList(simple_mode, false);
  return LowerWithPassList(mod, pass_list);
- return LowerWithPassList(mod, pass_list);
+ return LowerWithPassList(std::move(mod), pass_list);
Please apply `std::move` to all following FFI functions.
src/driver/driver_api.cc
Outdated
sch = sch.normalize();

// Before TIR transformation.
auto bounds = te::InferBound(sch);
Please spell out the type explicitly wherever possible. `auto` makes it harder for other readers to understand the code.
src/driver/driver_api.cc
Outdated

// Before TIR transformation.
auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false);
Please use a const reference or `std::move` to avoid unnecessary copies.
- auto stmt = te::ScheduleOps(sch, bounds, false);
+ auto stmt = te::ScheduleOps(sch, std::move(bounds), false);
Co-authored-by: Tristan Konolige <tristan.konolige@gmail.com>
One small nit, but otherwise things look really good. That net -1K lines is so satisfying to see. Big thanks!
include/tvm/driver/driver_api.h
Outdated
@@ -42,17 +43,67 @@
#include <vector>

namespace tvm {

/*!
 * \brief Build an IRModule given an input IRModule
- * \brief Build an IRModule given an input IRModule
+ * \brief Lower an IRModule
Could you maybe also add a little more detail on what it means to lower an IRModule?
Thanks @tkonolige! @Hzfengsy @csullivan can you take another look and approve?
LGTM! Thanks again for the great refactor @electriclilies, removing 4x the code than what is needed by the change == 💪 🔥 .
LGTM. Thanks a lot for the refactoring.
Thanks @CircleSpin and @electriclilies! Thanks @Hzfengsy @tkonolige @csullivan @YuchenJin @manupa-arm for reviewing.
Thanks @tqchen!
Currently, the Lower API is bifurcated. There is an implementation in Python and an implementation in C++. Both are used in the codebase despite having different signatures and different implementations. This PR unifies the C++ and Python API into one C++ backend, and allows the C++ API to be called from Python through the FFI.
Unfortunately, the Python version of the API relied on duck typing quite a bit, which made it difficult to fully support the Python API. To allow the C++ backend to fully replace the Python version of lower, @electriclilies had to write three different versions of Lower: LowerModule, LowerPrimFunc, and LowerSchedule, as well as make other changes.
The major changes to the C++ API are:
1. Implementing versions of Lower for Modules, PrimFuncs and Schedules
2. Adding user-defined passes to the C++ version
3. Overloading LowerSchedule to take in an Array<ObjectRef> for the arguments. The Python version accepts a List[Union[Buffer, Tensor, Var]] for the args. Unfortunately, the TVM object system does not allow union types, so Lily had to allow LowerSchedule to take in an Array<ObjectRef> for the args and raise an error if something other than a Buffer, Tensor or Var is passed in
4. Removing SchedulePostProcForTensorCore (this is dead code: it was used in one tutorial that has since been deleted, and it causes errors if called in certain cases)
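The argument check described in item 3 can be sketched roughly as follows. This is a Python illustration only: the empty `Buffer`, `Tensor` and `Var` classes stand in for the real TVM types, and `check_lower_args` is a hypothetical name, not a function from this PR.

```python
class Buffer:
    """Stand-in for a TIR buffer."""


class Tensor:
    """Stand-in for a TE tensor."""


class Var:
    """Stand-in for a TIR variable."""


def check_lower_args(args):
    """Accept a loosely typed list, reject anything outside the allowed kinds."""
    checked = []
    for index, arg in enumerate(args):
        if not isinstance(arg, (Buffer, Tensor, Var)):
            raise ValueError(
                f"args[{index}] must be Tensor, Buffer or Var, got {type(arg).__name__}"
            )
        checked.append(arg)
    return checked


accepted = check_lower_args([Tensor(), Buffer(), Var()])
```

The container type is deliberately broad, so the union constraint is enforced at run time instead of in the signature.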
Additionally, Lily moved the Python version of get_binds into C++ and exposed it through the FFI. Lily also renamed form_irmodule to schedule_to_module and moved that into C++. Finally, this refactor was a bit painful. @electriclilies will write a discussion post to go through why it was difficult and what can be changed to avoid issues like this in the future.