Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] External codegen #4482

Merged
merged 11 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ set(USE_ROCBLAS OFF)
# Whether use contrib sort
set(USE_SORT ON)

# Whether use MKL-DNN (DNNL) codegen
set(USE_DNNL_CODEGEN OFF)

# Build ANTLR parser for Relay text format
# Possible values:
# - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
Expand Down
20 changes: 20 additions & 0 deletions cmake/modules/contrib/CODEGENC.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

file(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/codegen.cc)
list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC})

28 changes: 28 additions & 0 deletions cmake/modules/contrib/DNNL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_DNNL_CODEGEN STREQUAL "ON")
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Build with DNNL codegen: " ${EXTERN_LIBRARY_DNNL})
endif()

3 changes: 3 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());

/*! \return A target for external device */
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
zhiics marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::string>());
} // namespace target

/*!
Expand Down
28 changes: 28 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,15 @@ class FunctionNode : public ExprNode {
*/
bool IsPrimitive() const;

/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
Expand Down Expand Up @@ -588,6 +597,25 @@ std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
} // namespace attr

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
11 changes: 10 additions & 1 deletion python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,16 @@ def export_library(self,
self.save(path_obj)
files = [path_obj]
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
has_imported_c_file = False
if self.imported_modules:
for i, m in enumerate(self.imported_modules):
if m.type_key == "c":
has_imported_c_file = True
c_file_name = "tmp_" + str(i) + ".cc"
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
f.write(m.get_source())
files.append(path_cc)
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
Expand All @@ -143,7 +152,7 @@ def export_library(self,
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if self.type_key == "c":
if self.type_key == "c" or has_imported_c_file:
options = []
if "options" in kwargs:
opts = kwargs["options"]
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ Target intel_graphics(const std::vector<std::string>& options) {
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
}

Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
}
} // namespace target

bool LLVMEnabled() {
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
stream->Write(tkey);
if (tkey == "c") continue;
im->SaveToBinary(stream);
}
// translate to C program
Expand Down
22 changes: 22 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ struct GraphCodegen {
return CallFunc<std::string>("get_graph_json", nullptr);
}

Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
}

Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
}
Expand Down Expand Up @@ -148,6 +152,10 @@ class RelayBuildModule : public runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
});
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetExternalModules();
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2);
Expand Down Expand Up @@ -474,6 +482,20 @@ class RelayBuildModule : public runtime::ModuleNode {
target_host_,
BuildConfig::Current());
}
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
if (!ext_mods.empty()) {
CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1)
<< "Expect to have a TVM DSOModule when multiple external runtime modules exist";
if (lowered_funcs.size() == 0) {
// Execute the whole module using external runtime.
ret_.mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (const auto& it : ext_mods) {
ret_.mod.Import(it);
}
}
}
}

protected:
Expand Down
83 changes: 70 additions & 13 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
Expand Down Expand Up @@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImm* symbol_name = ext_symbol.as<tvm::ir::StringImm>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}

Array<tvm::runtime::Module> ret;
for (const auto& it : ext_mods) {
std::string ext_name = "relay.ext." + it.first;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
runtime::Module ext_mod = (*pf)(it.second);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
}

// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
cache_.erase(it);
}
return ret;
}

void Clear() final {
cache_.clear();
}
Expand Down Expand Up @@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
return value;
}
// Enforce use the target.
With<Target> target_scope(key->target);

Expand Down Expand Up @@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() {
return *inst;
}


TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() {
return CompileEngine::Global();
});
return CompileEngine::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
self->Clear();
});
self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
});
return self->Lower(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
return self->LowerShapeFunc(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
return self->LowerExternalFunctions();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
});
return self->JIT(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<NodeRef>(CompileEngine)>(
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
} // namespace relay
} // namespace tvm
7 changes: 7 additions & 0 deletions src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

#include <tvm/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
Expand Down Expand Up @@ -186,6 +187,12 @@ class CompileEngineNode : public Node {
* \return The result.
*/
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*!
* \brief Lower the external function using external codegen tools.
* \return The runtime moduels for each needed external codegen tool.
*/
virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;

/*! \brief clear the cache. */
virtual void Clear() = 0;

Expand Down