Improved MLF to contain workspace info #7938

Merged: 5 commits, May 7, 2021
95 changes: 90 additions & 5 deletions python/tvm/micro/model_library_format.py
@@ -27,6 +27,9 @@
from ..relay.backend import executor_factory
from ..relay import param_dict

+# This should be kept identical to runtime::symbol::tvm_module_main
+MAIN_FUNC_NAME_STR = "__tvm_main__"


class UnsupportedInModelLibraryFormatError(Exception):
    """Raised when export_model_library_format does not support the given Module tree."""
@@ -73,8 +76,16 @@ def _populate_codegen_dir(mod, codegen_dir: str):
        dso_mod.save(file_name)


-def _build_memory_map(graph_json):
-    """Build a simpler memory map from graph JSON.
+def _build_memory_map(mod):
+    ret = dict()
+    if isinstance(mod, executor_factory.GraphExecutorFactoryModule):
+        ret["sids"] = _build_sid_map(mod.graph_json)
+    ret["functions"] = _build_function_memory_map(mod.function_metadata)
+    return ret
+
+
+def _build_sid_map(graph_json):
+    """Build a simpler storage id info map from graph JSON.

    Parameters
    ----------
@@ -117,6 +128,81 @@ def _build_memory_map(graph_json):
    return memory_map


+def _build_function_memory_map(function_metadata):
+    """Build a simple map that shows how much workspace is required to execute
+    each primitive function. The main_func describes how much memory is required
+    to execute the main control code.
+
+    Parameters
+    ----------
+    function_metadata : Map<String, FunctionInfo>
+        This contains all the compiled metadata on a function basis
+
+    Returns
+    -------
+    dict :
+        This will have two entries:
+        1.) A list with one entry per function describing the local memory it uses.
+        2.) A global memory requirement if all functions are executed sequentially
+    """
+    device_max_workspace = dict()
+    main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
+    num_targets = len(main_func_metadata.workspace_sizes.items())
+    func_entries = []
+    target_local_entries = dict()
+    for i in range(num_targets):
+        target = main_func_metadata.workspace_sizes.items()[i][0]
+        device_max_workspace[target] = 0
+        for func_name, finfo in function_metadata.items():
+            if func_name == MAIN_FUNC_NAME_STR:
+                continue
+            target_local_entries[func_name] = list()
+
+        for func_name, finfo in function_metadata.items():
+            if func_name == MAIN_FUNC_NAME_STR:
+                continue
+            assert len(finfo.constant_sizes.items()) == num_targets
+            assert len(finfo.io_sizes.items()) == num_targets
+            target = finfo.workspace_sizes.items()[i][0]
+            workspace_size = finfo.workspace_sizes.items()[i][1]
+            target_entry = {
+                "device": int(target.kind.device_type),
+                "workspace_size_bytes": int(workspace_size),
+            }
+            target_local_entries[func_name].append(target_entry)
+            if workspace_size > device_max_workspace[target]:
+                device_max_workspace[target] = workspace_size
+
+    for func_name, target_entries_ in target_local_entries.items():
+        func_entry = {
+            "function_name": str(func_name),
+            "workspace": target_entries_,
+        }
+        func_entries.append(func_entry)
+
+    target_main_entries = list()
+    for i in range(num_targets):
+        target = main_func_metadata.workspace_sizes.items()[i][0]
+        main_func_local_workspace = main_func_metadata.workspace_sizes.items()[i][1]
+        main_func_constants = main_func_metadata.constant_sizes.items()[i][1]
+        main_func_io = main_func_metadata.io_sizes.items()[i][1]
+        target_main_entries.append(
+            {
+                "device": int(target.kind.device_type),
+                "workspace_size_bytes": int(device_max_workspace[target])
+                + int(main_func_local_workspace),
+                "constants_size_bytes": int(main_func_constants),
+                "io_size_bytes": int(main_func_io),
+            }
+        )
+
+    ret = {
+        "operator_functions": func_entries,
+        "main": target_main_entries,
+    }
+    return ret
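For a single-target CPU build, the dict returned above lands under "memory"/"functions" in metadata.json, shaped roughly as follows (an illustrative sketch: the function name and all byte counts are hypothetical, and "device" is the numeric TVM device_type, 1 for CPU):

    {
        "operator_functions": [
            {
                "function_name": "fused_nn_conv2d",
                "workspace": [{"device": 1, "workspace_size_bytes": 9344}]
            }
        ],
        "main": [
            {
                "device": 1,
                "workspace_size_bytes": 9408,
                "constants_size_bytes": 4096,
                "io_size_bytes": 1208
            }
        ]
    }

The "main" figure follows the aggregation above: the largest per-device operator workspace (9344 bytes in this sketch) plus main's own local workspace (64 bytes here), i.e. 9344 + 64 = 9408.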


def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name):
    """Export the build artifact in Model Library Format.

@@ -133,14 +219,13 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name):
    """
    tempdir = utils.tempdir()
    is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule)
-    memory_map = [] if is_aot else _build_memory_map(mod.get_executor_config())
    runtime = ["aot"] if is_aot else ["graph"]

    metadata = {
-        "version": 1,
+        "version": 2,
        "model_name": mod.libmod_name,
        "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
-        "memory": memory_map,
+        "memory": _build_memory_map(mod),
        "target": {int(k): str(v) for k, v in mod.target.items()},
        "runtimes": runtime,
    }
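A minimal end-to-end sketch of the new flow, assuming an existing Relay module relay_mod with its params (the target choice, pass options, and paths are illustrative, not mandated by this change):

    import json
    import tarfile

    import tvm
    from tvm import relay
    from tvm.micro import export_model_library_format

    # Build with the C backend (illustrative; any supported target works).
    with tvm.transform.PassContext(opt_level=3):
        factory = relay.build(relay_mod, target="c", params=params)

    export_model_library_format(factory, "./model.tar")

    # metadata.json now reports version 2 and carries the memory map.
    with tarfile.open("./model.tar") as tar:
        tar.extractall("./mlf")
    with open("./mlf/metadata.json") as f:
        metadata = json.load(f)
    assert metadata["version"] == 2
    print(metadata["memory"]["functions"]["main"])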
12 changes: 10 additions & 2 deletions python/tvm/relay/backend/executor_factory.py
@@ -81,15 +81,18 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
        The name of module
    params : dict of str to NDArray
        The parameters of module
+    function_metadata : Map of String to FunctionInfo
+        This holds a map of function names to their information
    """

-    def __init__(self, ir_mod, target, libmod, libmod_name, params):
+    def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata):
        self.ir_mod = ir_mod
        self.target = target
        self.lib = libmod
        self.libmod_name = libmod_name
        self.params = params
        self.iter_cnt = 0
+        self.function_metadata = function_metadata

    def get_params(self):
        return self.params
@@ -118,9 +121,13 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule):
        The name of module
    params : dict of str to NDArray
        The parameters of module
+    function_metadata : Map of String to FunctionInfo
+        This holds a map of function names to their information
    """

-    def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
+    def __init__(
+        self, ir_mod, target, graph_json_str, libmod, libmod_name, params, function_metadata
+    ):
        assert isinstance(graph_json_str, string_types)
        fcreate = get_global_func("tvm.graph_executor_factory.create")
        args = []
@@ -136,6 +143,7 @@ def __init__(self, ir_mod, target, graph_json_str, libmod, libmod_name, params):
        self.libmod_name = libmod_name
        self.params = params
        self.iter_cnt = 0
+        self.function_metadata = function_metadata

    def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
        return self.module.export_library(file_name, fcompile, addons, **kwargs)
12 changes: 10 additions & 2 deletions python/tvm/relay/build_module.py
@@ -83,6 +83,7 @@ def __init__(self):
        self._optimize = self.mod["optimize"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]
+        self._get_function_metadata = self.mod["get_function_metadata"]

    def build(self, mod, target=None, target_host=None, params=None, executor="graph"):
        """
@@ -200,6 +201,12 @@ def get_module(self):
        """Return the built module."""
        return self._get_module()

+    def get_function_metadata(self):
+        """Return the compiled function metadata.
+        Currently, the metadata contains the workspace size required by
+        each PrimFunc."""
+        return self._get_function_metadata()

    def get_params(self):
        """Return the updated weights."""
        params = self._get_params_func()
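The new hook can also be exercised directly on BuildModule, though most callers go through the free build function below (a sketch; relay_mod and params are assumed to exist):

    from tvm.relay.build_module import BuildModule

    bld_mod = BuildModule()
    bld_mod.build(mod=relay_mod, target="llvm", params=params)

    # Map<String, FunctionInfo>: per-function, per-Target size estimates.
    func_meta = bld_mod.get_function_metadata()
    for name, finfo in func_meta.items():
        for target, size in finfo.workspace_sizes.items():
            print(name, target, int(size))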
@@ -325,14 +332,15 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
    executor_config, runtime_mod, params = bld_mod.build(
        mod=ir_mod, target=target, params=params, executor=executor
    )
+    func_metadata = bld_mod.get_function_metadata()

    if executor == "aot":
        executor_factory = _executor_factory.AOTExecutorFactoryModule(
-            ir_mod, target, runtime_mod, mod_name, params
+            ir_mod, target, runtime_mod, mod_name, params, func_metadata
        )
    elif executor == "graph":
        executor_factory = _executor_factory.GraphExecutorFactoryModule(
-            ir_mod, target, executor_config, runtime_mod, mod_name, params
+            ir_mod, target, executor_config, runtime_mod, mod_name, params, func_metadata
        )
    else:
        assert False, "Executor " + executor + " not supported"
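Because both factory classes now receive func_metadata, the build result exposes it uniformly. A sketch under the same assumptions as above; "__tvm_main__" is the MAIN_FUNC_NAME_STR entry the codegen records for the top-level control function:

    factory = relay.build(relay_mod, target="llvm", params=params)

    # Per-operator entries, keyed by lowered function name.
    for name, finfo in factory.function_metadata.items():
        print(name, finfo.workspace_sizes.items())

    # The entry for the main control function.
    main_info = factory.function_metadata["__tvm_main__"]
    print(main_info.io_sizes.items(), main_info.constant_sizes.items())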
87 changes: 86 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -25,8 +25,11 @@
#include <tvm/ir/module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>

#include <algorithm>
@@ -270,6 +273,79 @@ class AOTExecutorCodegen : public ExprVisitor {
    return ss.str();
  }

+  /*!
+   * \brief Update the "main" control function's metadata
+   *
+   * \param primfunc The main TIR control function generated for the model
+   * \param func The main function that contains calls to operator TIR primitive functions
+   */
+  void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) {
+    Integer workspace_size = CalculateWorkspaceBytes(primfunc);
+    // Populate FunctionInfo
+    auto fi_node = make_object<FunctionInfoNode>();
+    // Initialize all target workspaces to zero
+    for (const auto& kv : targets_) {
+      auto tgt = kv.second;
+      fi_node->workspace_sizes.Set(tgt, 0);
+    }
+    fi_node->workspace_sizes.Set(target_host_, workspace_size);
+    fi_node->relay_primfuncs.Set(target_host_, func);
+
+    int64_t io_size = 0;
+    for (const auto& input : input_vars_) {
+      io_size += CalculateRelayExprSizeBytes(input->checked_type());
+    }
+    io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
+    fi_node->io_sizes.Set(target_host_, io_size);
+
+    int64_t const_size = 0;
+    for (const auto& kv : params_by_expr_) {
+      const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
+    }
+    fi_node->constant_sizes.Set(target_host_, const_size);
+    function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
+  }

+  /*!
+   * \brief Update the function metadata for a given cached function and its relay
+   * primitive function.
+   *
+   * \param cfunc The cached function as provided by the compile engine
+   * \param relay_func The source relay primitive function
+   * \param relay_target The target associated with the relay primitive function
+   */
+  void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
+                              const Target& relay_target) {
+    auto fi_node = make_object<FunctionInfoNode>();
+    for (const auto& kv : cfunc->funcs->functions) {
+      auto primfunc = Downcast<tir::PrimFunc>(kv.second);
+      Integer workspace_size = CalculateWorkspaceBytes(primfunc);
+      Target primfunc_target = relay_target;
+      if (primfunc->attrs->dict.count("target")) {
+        primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
+      }
+      fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
+      // Calculating size for I/O
+      for (auto const& param : primfunc->params) {
+        auto p_shape = primfunc->buffer_map[param]->shape;
+        int num_of_elements = 1;
+        for (const auto& dim_index_expr : p_shape) {
+          if (dim_index_expr->IsInstance<IntImmNode>()) {
+            num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
+          } else {
+            // If the shape is dynamic, we cannot calculate the size at compile time.
+            num_of_elements = 0;
+          }
+        }
+        int element_size = primfunc->buffer_map[param]->dtype.bytes();
+        fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
+      }
+      fi_node->constant_sizes.Set(primfunc_target, 0);
+      fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
+      fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
+    }
+    function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
+  }
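As a worked example with hypothetical shapes: a parameter buffer of shape (1, 224, 224, 3) and dtype float32 yields 1 * 224 * 224 * 3 = 150,528 elements at 4 bytes each, so 602,112 bytes are recorded in io_sizes; any dynamic dimension collapses num_of_elements to 0, since the size cannot be known at compile time.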

  void VisitExpr_(const CallNode* op) override {
    // Descend the call tree
    for (auto arg : op->args) {
@@ -336,6 +412,8 @@ class AOTExecutorCodegen : public ExprVisitor {
      lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
    }
    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
+    // Update function metadata by inspecting all primfuncs
+    UpdateFunctionMetadata(lowered_func, func, target);

    // Generate the TIR function call
    CreateFuncCall(GetRef<Call>(op), lowered_func->func_name);
@@ -488,6 +566,8 @@ class AOTExecutorCodegen : public ExprVisitor {
  std::unordered_map<int, te::Var> sids_table_;
  /*! \brief lowered funcs */
  std::unordered_map<std::string, IRModule> lowered_funcs_;
+  /*! \brief function metadata */
+  Map<String, FunctionInfo> function_metadata_;
  /*! \brief compile engine */
  CompileEngine compile_engine_;
  /*! \brief the set of statements that make the program */
@@ -531,6 +611,7 @@ class AOTExecutorCodegen : public ExprVisitor {
    VisitExpr(func->body);

    auto prim_func = CreateMainFunc(func->params.size());
+    UpdateMainWorkspaceSize(prim_func, func);
    LoweredOutput ret;

    ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
@@ -559,7 +640,7 @@ class AOTExecutorCodegen : public ExprVisitor {
      symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
      ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
    }
-
+    ret.function_metadata = std::move(function_metadata_);
    ret.metadata =
        runtime::Metadata(input_vars_.size(), return_sid_.size(), runtime::kTvmExecutorAot);
    return ret;
@@ -602,6 +683,10 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
    } else if (name == "get_external_modules") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); });
+    } else if (name == "get_function_metadata") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->output_.function_metadata;
+      });
    } else if (name == "get_metadata") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; });
8 changes: 8 additions & 0 deletions src/relay/backend/build_module.cc
@@ -62,6 +62,10 @@ struct ExecutorCodegen {

  virtual void UpdateOutput(BuildOutput* ret) = 0;

+  Map<String, FunctionInfo> GetFunctionMetadata() {
+    return CallFunc<Map<String, FunctionInfo>>("get_function_metadata", nullptr);
+  }
+
  std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
    std::unordered_map<std::string, tvm::runtime::NDArray> ret;
    auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
@@ -197,6 +201,10 @@ class RelayBuildModule : public runtime::ModuleNode {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->executor_codegen_->GetExternalModules();
      });
+    } else if (name == "get_function_metadata") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->executor_codegen_->GetFunctionMetadata();
+      });
    } else if (name == "optimize") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2);