
[runtime] AOTExecutor implementation and c target code-generator #10283

Merged: 21 commits merged into apache:main on Mar 3, 2022

Conversation

@areusch (Contributor, author) commented Feb 17, 2022:

This PR adds the AOTExecutor implementation and the metadata code-generator for use with the c target, allowing users to run CPU-only models via the AOTExecutor using the C++ runtime.

cc @Mousius @manupa-arm @kparzysz-quic @masahi @mehrdadh

@kparzysz-quic (Contributor) left a comment:

Looks OK to me. Please change the make_unique statements, though; that can go into a subsequent PR.


const int num_args = args_.size();
::std::unique_ptr<TVMValue> call_values{new TVMValue[num_args]};
::std::unique_ptr<int> call_type_codes{new int[num_args]};
Reviewer comment (Contributor):

These two lines should be

auto call_values = std::make_unique<TVMValue[]>(num_args);
auto call_type_codes = std::make_unique<int[]>(num_args);

@areusch (author) replied:

fixed, thanks

return_ttypes_.reserve(ttypes.size());
for (auto ttype : ttypes) {
return_ttypes_.push_back(ttype);
}
Reviewer comment (Contributor):

All of this is equivalent to return_ttypes_ = FlattenTupleType(e->checked_type());.


@areusch (author) replied:

done

@masahi (Member) left a comment:

Found some left-over logging while running test_conv2d:

auto metadata = Downcast<runtime::metadata::MetadataBase>(*value);
const runtime::metadata::MetadataArrayNode* arr =
value->as<runtime::metadata::MetadataArrayNode>();
std::cout << "Is array? " << arr << std::endl;
Reviewer comment (Member):

Remove

@areusch (author) replied:

done

void Visit(const char* key, ObjectRef* value) final {
const runtime::metadata::MetadataArrayNode* arr =
value->as<runtime::metadata::MetadataArrayNode>();
std::cout << "Is array? " << arr << std::endl;
Reviewer comment (Member):

Remove

@areusch (author) replied:

done

for (unsigned int i = 0; i < array->array.size(); ++i) { // ObjectRef o : *(array->array)) {
ObjectRef o = array->array[i];
std::cout << "visiting array element " << i << ": " << o->type_index() << " ("
<< o.operator->() << ")" << std::endl;
Reviewer comment (Member):

Remove

@areusch (author) replied:

done

for (unsigned int i = 0; i < arr->array.size(); i++) {
ObjectRef o = arr->array[i];
std::cout << "queue-visiting array element " << i << ": " << o->type_index() << " ("
<< o.operator->() << ")" << std::endl;
Reviewer comment (Member):

Remove

@areusch (author) replied:

done

auto struct_name = std::get<0>(item);
auto obj = std::get<1>(item);
auto arr = obj.as<runtime::metadata::MetadataArrayNode>();
std::cout << "codegen: " << struct_name;
Reviewer comment (Member):

Remove

@areusch (author) replied:

done

pools.push_back(
runtime::metadata::TensorInfo(make_object<target::metadata::InMemoryTensorInfoNode>(
var->name_hint,
std::vector<int64_t>{metadata->pool_inputs.value()[var]->allocated_size},
Reviewer comment (Member):

Check that metadata->pool_inputs is non-null (it's Optional<Map<tir::Var, tir::usmp::AllocatedPoolInfo>>)

@areusch (author) replied:

done

Reviewer comment (Member):

Is it done?

@areusch (author) replied:

It is done earlier in the function, but it was in the wrong spot, and I've reorganized the checks.

lib: tvm.runtime.Module = tvm.runtime.load_module("compiled_lib.so")
# Call the library factory function for default and create
# a new runtime.Module, wrap with graph module.
gmod = graph_executor.GraphModule(lib["default"](dev))
Reviewer comment (Member):

Need update?

@areusch (author) replied:

fixed, thanks!

- GraphExecutor, in python/tvm/contrib/graph_executor.py
- VM Executor, in python/tvm/runtime/vm.py

TODO(areusch): Consolidate these into this module.
@masahi (Member) commented Mar 1, 2022:

I just realized that we have two notions of executor. One is the runtime one above; the other is

def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None):

which is used a lot in the test cases.

Do we intend to support create_executor(kind="aot", ...), given that we can now run things via the cpp runtime?

@areusch (author) replied:

yeah that is a good point. added support here.

@@ -75,6 +75,13 @@ struct TVMMetadata {
const struct TVMTensorInfo* outputs;
/*! \brief Number of elements in `outputs` array. */
int64_t num_outputs;
/*! \brief Memory Pools needed by the AOT run_model function.
Reviewer comment (Member):

You are not using run_model anymore, are you? In your previous branch I indeed see tvmgen_default_run_model generated, but after the rebase it is replaced with tvmgen_default___tvm_main__.

Reviewer comment (Contributor):

That is right!

We had run (the entry point), run_model, and tvm_main. The entry point is supposed to call run_model; however, run_model was identical to tvm_main, so it was removed to avoid maintaining two symbols for the Relay and TIR versions of main.

So I think it needs to be tvm_main now.

@areusch (author) replied:

changed to just "AOT main" since the function should probably eventually be renamed based on mod_name.

@manupak (Contributor) left a comment:

Neat! Looking good.

Few nits and questions (for my understanding).

import numpy as np


class AotModule(object):
Reviewer comment (Contributor):

(stylistic) : AotModule --> AOTModule or AoTModule ?

@areusch (author) replied:

I find the acronyms easier to read if we use CapWords; this also follows https://www.python.org/dev/peps/pep-0008/#class-names, which is linked from numpydoc.

@@ -317,6 +327,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}

void PushArgs(const Expr& expr, std::vector<tir::Var> sids, Array<PrimExpr>* args) {
Reviewer comment (Contributor):

nit: any reason for the mixed use of pass-by-value, pass-by-const-reference, and pointers here?

@areusch (author) replied:

Oops, sorry, that was a bit rough. Changed to pass-by-const-ref and pointer per the C++ style guide.


// Validate choice of use_unpacked_api_ and use_call_cpacked_
if (runtime_config->name == kTvmRuntimeCrt) {
CHECK(interface_api == "packed" || static_cast<bool>(use_unpacked_api_) == true)
Reviewer comment (Contributor):

Are both of these CHECKs tested?

@areusch (author) replied:

The CRT CHECK is thoroughly tested by existing code (I ran into a lot of those tests while making CI happy). Added tests for the C++ one.


/*!
* \file aot_executor_factory.cc
* \brief Graph executor factory implementations
Reviewer comment (Contributor):

Need update

@areusch (author) replied:

fixed

namespace tvm {
namespace runtime {

AotExecutorFactory::AotExecutorFactory(
Reviewer comment (Contributor):

Just for my knowledge, I never understood why these are called "factories". What might be the reason? :)

@areusch (author) replied:

It's because they are container objects that hold the constructor arguments to an executor. They also handle serializing those arguments during export_library and reconstructing them during load_module. They follow the factory pattern (i.e., a functor taking no or few arguments that constructs a given datatype), hence the name.

with pytest.raises(tvm.TVMError, match="Packed interface required for packed operators"):
with pytest.raises(
tvm.TVMError,
match=re.escape(
Reviewer comment (Contributor):

From the above comment, I'd expect two of these tests here.

@areusch (author) replied:

I think the check fails only when both of these conditions are false, so we just need one test to assert that we get the error message. The remaining single-true cases are covered by various other tests, so I think we should be OK, but let me know if you have a more specific concern.

@masahi (Member) left a comment:

More minor comments, finished the whole thing.

At some point we should review our metadata situation... there are so many classes/files, and they look quite messy.

<< ") or unpacked-api == true (got: " << use_unpacked_api_
<< ") when targeting c runtime";
} else if (runtime_config->name == kTvmRuntimeCpp) {
CHECK(static_cast<bool>(use_unpacked_api_) == false &&
Reviewer comment (Member):

Should use ICHECK, also at L918

@areusch (author) replied:

fixed

* processor.
* \param devs The device of the host and devices where graph nodes will be
* executed on.
* \param lookup_linked_param_func If given, a PackedFunc invoked to lookup linked parameters
Reviewer comment (Member):

lookup_linked_param_func not given in the param list

@areusch (author) replied:

fixed

* \brief Initialize the AOT executor with metadata, runtime::Module, and device.
* \param module The module containing the compiled functions for the host
* processor.
* \param devs The device of the host and devices where graph nodes will be
@masahi (Member) commented Mar 1, 2022:

There is no "graph".

And why is only the first element of devs ever used in aot_executor.cc?

@areusch (author) replied:

reworded the comment, added ICHECK, and tested.

// "
// << "(when passed from runtime)";
metadata_ = metadata;
// code_ = code;
Reviewer comment (Member):

Remove

@areusch (author) replied:

done

CHECK_EQ(ret_type_code, kTVMOpaqueHandle)
<< "Expected kOpaqueHandle returned; got " << ret_type_code;
CHECK(ret_value.v_handle != nullptr) << "get_c_metadata returned nullptr";

Reviewer comment (Member):

ICHECK or ICHECK_EQ in this function

@areusch (author) replied:

done

* module.
* \param target_module The main TIR-lowered internal runtime module
* \param modules All the external modules that needs to be imported inside the metadata module(s).
* \param target The target that all the modules are compiled for
Reviewer comment (Member):

metadata not documented

@areusch (author) replied:

done

address_.push_back(i_str.str());
Visit(nullptr, &metadata);
address_.pop_back();
// ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
Reviewer comment (Member):

Remove?

@areusch (author) replied:

done


auto mod = MetadataModuleCreate(metadata);
std::vector<String> func_names{"get_c_metadata"};
// definer.GetOutput() +
Reviewer comment (Member):

Remove?

@areusch (author) replied:

done

lookup_func << "};" << std::endl;

auto mod = MetadataModuleCreate(metadata);
std::vector<String> func_names{"get_c_metadata"};
Reviewer comment (Member):

Rather than hard-coding get_c_metadata everywhere, we should introduce a new symbol under namespace symbol (similar to tvm_module_main).

@areusch (author) replied:

done

{sid_array, 0, tir::builtin::kArrDeviceType, kDLCPU})));
new_stmts.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrDeviceId, 0})));
Reviewer comment (Member):

What are these changes for?

@areusch (author) replied:

These changes fixed the flaky stack bug where device allocations were using bad thread-local storage for workspace. The fix was to explicitly initialize the device_id and device_type fields of the created DLTensor. There are other fields we should initialize as well, but they aren't used (at least not extensively) in the c-codegen-generated code.

@areusch (author) left a comment:

@manupa-arm @kparzysz-quic @masahi thanks for the review! addressed your comments, ptal when you have time.

@github-actions github-actions bot requested a review from masahi March 2, 2022 18:24
@masahi (Member) left a comment:

LGTM modulo some unaddressed comments. But they can go into the next LLVM PR.

#10283 (comment)
#10283 (comment)

@github-actions github-actions bot requested a review from masahi March 3, 2022 00:22
@areusch (author) commented Mar 3, 2022:

@manupa-arm @kparzysz-quic please take a look when you have a minute, CI is green now.

@tmoreau89 (Contributor) left a comment:

LGTM - thanks @areusch for landing the AOT executor!

@tmoreau89 tmoreau89 merged commit d721d32 into apache:main Mar 3, 2022
@tmoreau89 (Contributor) commented:

Thank you @areusch @masahi @kparzysz-quic @manupa-arm for the reviews! The PR has been merged.

shukun-ziqiangxu pushed a commit to shukun-ziqiangxu/tvm that referenced this pull request Mar 6, 2022
…che#10283)

* Add memory pools to Metadata classes.

* Move ShapeToJSON to utils.

* Track returned TensorType from AOTExecutorCodegen.

* Support calling Relay functions with Tuple.

* Expand supported TIR calling conventions to work with C++ runtime.

* Rename MetadataModule to ConstLoaderModule.

* Add runtime AOT executor module.

* Add AOT code-generation.

* Add a runtime Module to mux between .text Metadata and live Metadata.

* Move launch_param to namespace

* Add test of c++ AOT.

* Fix incongruity between kTvmRuntimeCrt constant

* Expand ExecutorCodegenMetadata to include AOT runtime metadata.

* commit cpp test

* Make Metadata compile under C.

* Ignore ephemeral metadata_module export_model_library_format.

 * This module does not need to be exported, since it is merely a C++
   wrapper around get_c_metadata, and get_metadata is not used in C.

* address manupa, kparszsyc, masahi comments.

* further address comments

* clang and python format

* Fix broken test

* Address lingering comments from masahi, kparszyzc
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
…che#10283)

(same commit list as above)