Unify Python and C++ TIR lower API #8110

Merged
merged 57 commits into from
Jun 12, 2021


@CircleSpin CircleSpin commented May 21, 2021

Currently, the Lower API is bifurcated. There is an implementation in Python and an implementation in C++. Both are used in the codebase despite having different signatures and different implementations. This PR unifies the C++ and Python API into one C++ backend, and allows the C++ API to be called from Python through the FFI.

Unfortunately, the Python version of the API relied on duck typing quite a bit, which made it difficult to fully support the Python API. To allow the C++ backend to fully replace the Python version of lower, @electriclilies had to write three different versions of Lower: LowerModule, LowerPrimFunc, and LowerSchedule, as well as make other changes.

The major changes to the C++ API are:

1. Implementing versions of Lower for Modules, PrimFuncs and Schedules
2. Adding user-defined passes to the C++ version
3. Overloading LowerSchedule to take in an Array<ObjectRef> for the arguments. The Python version accepts a List[Union[Buffer, Tensor, Var]] for the args. Unfortunately, the TVM object system does not allow union types, so Lily had to allow LowerSchedule to take in an Array<ObjectRef> for the args and raise an error if something other than a Buffer, Tensor or Var is passed in
4. Removing SchedulePostProcForTensorCore (this is dead code; it was used in one tutorial that has since been deleted, and it causes errors if called in certain cases)
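
The run-time type check described in item 3 can be sketched in Python. This is only an illustration of the idea, not TVM's implementation: the `Buffer`, `Tensor` and `Var` classes below are hypothetical stand-ins, and `check_lower_args` is a made-up helper name.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for TVM's Buffer, Tensor and Var; not the real classes.
@dataclass(frozen=True)
class Buffer:
    name: str


@dataclass(frozen=True)
class Tensor:
    name: str


@dataclass(frozen=True)
class Var:
    name: str


ALLOWED_ARG_TYPES = (Buffer, Tensor, Var)


def check_lower_args(args):
    """Accept a loosely typed list and reject anything outside the union at run time."""
    for arg in args:
        if not isinstance(arg, ALLOWED_ARG_TYPES):
            raise ValueError(
                f"args must be Tensor, Buffer or Var, got {type(arg).__name__}"
            )
    return list(args)
```

The trade-off is that a mistake which a union type would catch statically only surfaces when the function is called.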

Additionally, Lily moved the python version of get_binds into C++ and exposed it through the FFI. Lily also renamed form_irmodule to schedule_to_module and moved that into C++. Finally, this refactor was a bit painful. @electriclilies will write a discussion post to go through why it was difficult and what can be changed to avoid issues like this in the future.

@jwfromm jwfromm requested a review from jroesch May 24, 2021 20:02
} else {
ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule or "
<< "IRModule";
throw;

Don't add a throw here, instead return an empty IRModule(...)


jroesch commented May 25, 2021

@CircleSpin @electriclilies ping me again when tests are green, left one comment that jumped out at me, otherwise looks good if we can get it to pass the test suites


jroesch commented May 25, 2021

cc @manupa-arm @giuseros this is worth paying attention to, we have been working on cleaning up the internal APIs and bringing everything to C++.

@manupak manupak left a comment

Thanks for this work! much appreciated. :)

Just a few clarification questions about the overall design.

bool simple_mode = false);

/*!
* \brief Build an IRModule given a module, args and binds

typo : should be a PrimFunc?

pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
.value();

auto user_lower_phase0 = Array<tvm::transform::Pass>();

Clarification : What is the importance of having phases and what do they represent ?


I'm not sure, it would be good to get some clarification from someone else on this, maybe @junrushao1994 knows?


It would be nice to take the refactoring opportunity to add some documentation on this. Though completely understand you have mirrored the previous lower impl.


I'm happy to add documentation, I just don't actually know what the answer is to this question
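
The thread leaves the meaning of the phases open, but the mechanics of sorting user passes into per-phase lists can be sketched. The following Python is a rough illustration under assumptions: the config is treated as a list of (phase, pass) pairs, phase numbers beyond the last slot are folded into the final bucket, and the way user buckets interleave with built-in stages is a guess. `bucket_user_passes` and `build_pipeline` are hypothetical names.

```python
from itertools import zip_longest


def bucket_user_passes(add_lower_pass, num_phases=4):
    """Group (phase, pass) pairs into per-phase lists, preserving order."""
    buckets = [[] for _ in range(num_phases)]
    for phase, lower_pass in add_lower_pass:
        if phase < 0:
            raise ValueError("phase number must be non-negative")
        # Assumption: phases past the last slot are folded into the final bucket.
        buckets[min(phase, num_phases - 1)].append(lower_pass)
    return buckets


def build_pipeline(user_buckets, builtin_stages):
    """Assumed interleaving: user bucket i runs just before built-in stage i."""
    pipeline = []
    for user_passes, stage in zip_longest(user_buckets, builtin_stages, fillvalue=()):
        pipeline.extend(user_passes)
        pipeline.extend(stage)
    return pipeline
```

With four user buckets and three built-in stages, the last user bucket ends up after all built-in passes.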

IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
Array<ObjectRef> out_arg_list;
Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {

Clarification : what does legacy_te_pass mean here?


Legacy te pass means that the object we are lowering is a te::Schedule. I've copied the logic and phrasing directly from the python version -- I don't know the details of why this is still here. But without it, some tests fail.


Maybe we could be more specific in the name ? e.g. using such as "bool from_te_schedule" ?

Also I see simple_mode is just enabling LoopPartition, we could use "bool enable_loop_partition" ?

pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
// Phase 1
// PHASE 1
if (legacy_te_pass) {

@jroesch @CircleSpin
A general design question : What do you think of relocating the pass pipeline closer to the target registry ? So that way we could have a function (that can also be a bit parameterized based on target args) that describes pass pipeline rather than mandating the passes here.

Here, we could just query the pass pipeline based on the target and run here. I think this can remove the need to support user defined custom passes at this level.

Thoughts ?


I think doing something like this is a good idea. I also want to simplify the signature for Lower..

Right now we have a lot of function signatures for Lower because the python version allows multiple input types for each argument.


To clarify, are you saying that we would construct the pass list in the registration of driver.lower, then pass it into a function like LowerWithPassList?


Yes. OK, having looked at what's being removed, maybe this is out of scope for the PR.

My original question was about the motivation for custom passes, specifically the four possible locations where they get inserted (known as phases). If those have some meaning, it might be worth adding a comment.

The next question was, given the ability to add custom passes, whether we could instead have a custom pass pipeline registered close to the target. That way, someone could create a new target and reuse or reorganize the passes that need to run.

Again, I now feel that's out of scope for this PR, as it essentially mimics what the Python lower is doing. :)

@manupak manupak left a comment

Thanks! some further thoughts

@electriclilies

@manupa-arm Yeah I'm not exactly sure what the meaning of each "phase" is-- again, I'm just trying to duplicate the python code in C++. @junrushao1994 Could you weigh in about the significance of each phase?

Also, thanks for the naming suggestions, I'm going to go through and clean stuff up later and I will take them into account then! Ideally I'd like to remove flags if possible

@electriclilies electriclilies left a comment

Thanks Tristan!

/*!
* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param simple_mode Disables the loop partition pass. Defaults to false.

I'm not completely sure. This logic was already in the code base, so I just duplicated it and made it more explicit. (For context, when I started this refactor, Jared said he wanted to get it in quickly so I should just try to naively duplicate the existing logic, but he hasn't reviewed it yet so I'm not sure what the timeline is now).

I can try to remove it, but I'm not sure what the best way to go about this is since there are a few tests that call lower directly..

* \param binds Buffer assignments.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,

Initially this was a function called form_irmodule in python. I translated it to C++ and renamed it as ScheduleToModule. Unfortunately, form_irmodule is called by some tests. To preserve that behavior I had to split the functions apart again and register them separately in the FFI.
The difference between these is that ScheduleToModule just converts the schedule to a module that hasn't yet been lowered, whereas LowerSchedule converts the schedule into a module and then applies the passes.
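
The relationship between the two functions can be sketched as a composition. This Python is a toy model, not TVM code: modules are plain dictionaries, passes are ordinary callables, and `make_pass` is a hypothetical helper that only records which passes ran.

```python
def schedule_to_module(schedule, name):
    """Convert only: wrap the schedule in a module, with no passes applied."""
    return {"name": name, "body": schedule, "applied": []}


def lower_with_pass_list(module, pass_list):
    """Apply each pass in order, returning the transformed module."""
    for lower_pass in pass_list:
        module = lower_pass(module)
    return module


def lower_schedule(schedule, name, pass_list):
    """Convert, then lower: schedule_to_module followed by the pass list."""
    return lower_with_pass_list(schedule_to_module(schedule, name), pass_list)


def make_pass(label):
    """Hypothetical pass that just records its own label on the module."""
    def run(module):
        return {**module, "applied": module["applied"] + [label]}
    return run
```

Keeping the conversion step callable on its own is what lets tests inspect the module before any passes have run.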

@@ -136,7 +98,7 @@ def lower(

Parameters
----------
input : Union[schedule.Schedule, PrimFunc, IRModule]
inputs : Union[schedule.Schedule, PrimFunc, IRModule]

input is a Python built-in function, so a variable named input would shadow it. Initially the name of the parameter to the function was inputs and the name of the parameter in the documentation was input. I just changed the documentation to match the function signature, since I don't like them being different. I agree that inputs is not an ideal name though, and I'm open to suggestions for other names.
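
A small illustration of the naming concern: Python does accept a parameter called `input`, but inside that function the name then refers to the parameter and hides the built-in. The function names below are made up for the example.

```python
import builtins


def describe(input):
    # Inside this function the name `input` refers to the parameter,
    # so the built-in input() is no longer reachable by that name.
    return input is builtins.input


def describe_renamed(inp):
    # With a different parameter name, `input` still resolves to the built-in.
    return input is builtins.input
```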

@csullivan csullivan left a comment

Absolutely phenomenal improvement, thank you for the very well thought out refactor!

@@ -42,17 +43,64 @@
#include <vector>

namespace tvm {

/*!
* \brief Build an IRModule given a module, args and binds

Suggested change
* \brief Build an IRModule given a module, args and binds
* \brief Build an IRModule given a module

Updated the comment to reflect the API. Please check the docs for the other APIs that don't take args/binds as well.


Updated!

Comment on lines 57 to 64
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    binds = {} if binds is None else binds.copy()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                buf = tvm.tir.decl_buffer(
                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
                )
                binds[x] = buf
                arg_list.append(buf)
            else:
                arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, tvm.tir.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list


def form_irmodule(sch, args, name, binds):
    """According to the given schedule, form a function.
    out_arr = ffi.get_binds(args, compact, binds)
    return out_arr[0], out_arr[1]


Probably we should update the doc string to match the return variable names. Actually the reverse would be preferable, binds, arg_list = ffi.get_binds(...) but I understand you are limited here by the object system.


Yup can do this

Comment on lines 773 to 774
cache_node->funcs =
tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds);
@csullivan csullivan Jun 7, 2021

We shouldn't need to route through the relay.backend.lower packed function any longer right? I'd suggest either removing it and the above conditional (line 766) or also reproducing its functionality in c++ if we do need it as part of this PR. This way we avoid an unnecessary round trip through python.


Yeah I don't think we do. I'll remove it (I didn't look too closely at what this call was doing, oops!)

/*!
* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param simple_mode Disables the loop partition pass. Defaults to false.

It is probably fine for now.

* \param binds Buffer assignments.
* \return The result module.
*/
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,

Can you be a little more explicit about this in the documentation?

@@ -136,7 +98,7 @@ def lower(

Parameters
----------
input : Union[schedule.Schedule, PrimFunc, IRModule]
inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Maybe use inp instead? I can't really think of a good name either.

@Hzfengsy Hzfengsy left a comment

Thanks for this great work! I think it should work for TensorIR as long as it passes all existing unit tests :) Generally, it looks good to me except for some nits in code style.

c_binds.insert({kv.first, kv.second});
}
}
IRModule mod = ScheduleToModule(sch, args, name, c_binds);

Suggested change
IRModule mod = ScheduleToModule(sch, args, name, c_binds);
IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds);


IRModule LowerModule(IRModule mod, bool simple_mode) {
auto pass_list = CreatePassList(simple_mode, false);
return LowerWithPassList(mod, pass_list);

Suggested change
return LowerWithPassList(mod, pass_list);
return LowerWithPassList(std::move(mod), pass_list);


Please apply std::move to all following FFI functions

sch = sch.normalize();

// Before TIR transformation.
auto bounds = te::InferBound(sch);

Please write out the type explicitly wherever possible. auto is not very reader-friendly; it makes the code harder for others to understand.


// Before TIR transformation.
auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false);

Please use const ref or std::move to reduce unnecessary memory copy

Suggested change
auto stmt = te::ScheduleOps(sch, bounds, false);
auto stmt = te::ScheduleOps(sch, std::move(bounds), false);

@tkonolige tkonolige left a comment

One small nit, but otherwise things look really good. That net -1K lines is so satisfying to see. Big thanks!

@@ -42,17 +43,67 @@
#include <vector>

namespace tvm {

/*!
* \brief Build an IRModule given an input IRModule

Suggested change
* \brief Build an IRModule given an input IRModule
* \brief Lower an IRModule

Could you maybe also add a little more detail on what it means to lower an irmodule?

@electriclilies

Thanks @tkonolige! @Hzfengsy @csullivan can you take another look and approve?

@csullivan csullivan left a comment

LGTM! Thanks again for the great refactor @electriclilies, removing 4x the code than what is needed by the change == 💪 🔥 .

@Hzfengsy Hzfengsy left a comment

LGTM. Thanks a lot for the refactoring

@tqchen tqchen merged commit 9dd1286 into apache:main Jun 12, 2021

tqchen commented Jun 12, 2021

Thanks @CircleSpin and @electriclilies ! Thanks @Hzfengsy @tkonolige @csullivan @YuchenJin @manupa-arm for reviewing

@electriclilies

Thanks @tqchen!

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 17, 2021
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jun 17, 2021

10 participants