Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] Allow Module to store BaseFunc #4678

Merged
merged 1 commit into from Jan 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 8 additions & 8 deletions include/tvm/relay/module.h
Expand Up @@ -62,7 +62,7 @@ struct Module;
class ModuleNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;

Expand All @@ -75,7 +75,7 @@ class ModuleNode : public RelayNode {
v->Visit("global_type_var_map_", &global_type_var_map_);
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});

Expand All @@ -86,7 +86,7 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);

/*!
* \brief Add a function to the global environment.
Expand All @@ -95,7 +95,7 @@ class ModuleNode : public RelayNode {
*
* It does not do type inference as Add does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);

/*!
* \brief Add a type-level definition to the global environment.
Expand Down Expand Up @@ -124,7 +124,7 @@ class ModuleNode : public RelayNode {
* \param var The name of the global function to update.
* \param func The new function.
*/
TVM_DLL void Update(const GlobalVar& var, const Function& func);
TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);

/*!
* \brief Update a type definition in the global environment.
Expand Down Expand Up @@ -184,14 +184,14 @@ class ModuleNode : public RelayNode {
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var) const;
TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;

/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name) const;
TVM_DLL BaseFunc Lookup(const std::string& name) const;

/*!
* \brief Look up a global type definition by its variable.
Expand Down Expand Up @@ -256,7 +256,7 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {},
const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});

static constexpr const char* _type_key = "relay.Module";
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Expand Up @@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode {
// Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params);
// Get the updated function.
func = relay_module->Lookup("main");
func = Downcast<Function>(relay_module->Lookup("main"));

// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
Expand Down
29 changes: 20 additions & 9 deletions src/relay/backend/vm/compiler.cc
Expand Up @@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
auto func = context_->module->Lookup(global);

// TODO(tvm-team):
// Think about mixed call into global that is not a relay::Function
// perhaps establish as an invariance(all functions in mod must be relay::Function)
auto func = Downcast<Function>(context_->module->Lookup(global));


if (IsClosure(func)) {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
Expand Down Expand Up @@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod,
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
Expand All @@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod,

for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);

size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
if (auto* n = named_func.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);

size_t func_index = context_.global_map.at(gvar);
CHECK(func_index < exec_->functions.size());
exec_->functions[func_index] = vm_func;
}
}

#if USE_RELAY_DEBUG
Expand Down
30 changes: 17 additions & 13 deletions src/relay/backend/vm/inline_primitives.cc
Expand Up @@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator {
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
auto func = pair.second;
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);

func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global
<< std::endl << AsText(func, false);
}
}
return module_;
}
Expand Down
16 changes: 9 additions & 7 deletions src/relay/backend/vm/lambda_lift.cc
Expand Up @@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator {
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
auto func = pair.second;
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
if (auto* n = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Add(pair.first, func, true);
}
}
return module_;
}
Expand Down