Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] External codegen #4482

Merged
merged 11 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Expand Up @@ -254,6 +254,7 @@ include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Extern.cmake)
zhiics marked this conversation as resolved.
Show resolved Hide resolved
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Expand Up @@ -172,6 +172,9 @@ set(USE_ROCBLAS OFF)
# Whether use contrib sort
set(USE_SORT ON)

# Whether use MKL-DNN (DNNL) codegen
set(USE_DNNL_CODEGEN OFF)

# Build ANTLR parser for Relay text format
# Possible values:
# - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
Expand Down
33 changes: 33 additions & 0 deletions cmake/modules/contrib/Extern.cmake
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

message(STATUS "Build with relay.backend.contrib")

file(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/csource/codegen.cc)
list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC})

if(USE_DNNL_CODEGEN STREQUAL "ON")
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Use DNNL codegen: " ${EXTERN_LIBRARY_DNNL})
endif()

3 changes: 3 additions & 0 deletions include/tvm/build_module.h
Expand Up @@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());

/*! \return A target for external device */
TVM_DLL Target ext(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target

/*!
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/expr.h
Expand Up @@ -268,6 +268,14 @@ class FunctionNode : public ExprNode {
*/
bool IsPrimitive() const;

/*!
* \brief Check whether the function is an external function.
* External functions are subgraphes that supported by external libraries.
*
* \return Whether the function is external or not.
*/
bool IsExternal() const;
zhiics marked this conversation as resolved.
Show resolved Hide resolved

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
Expand Down Expand Up @@ -588,6 +596,25 @@ std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Mark the function as an external function that needs to be handled by
* the external codegen tool/backend.
*/
constexpr const char* kExternal = "External";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the function name. */
constexpr const char* kFuncName = "FuncName";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
} // namespace attr

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
11 changes: 10 additions & 1 deletion python/tvm/module.py
Expand Up @@ -133,7 +133,16 @@ def export_library(self,
self.save(path_obj)
files = [path_obj]
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
has_imported_c_file = False
if self.imported_modules:
for i, m in enumerate(self.imported_modules):
if m.type_key == "c":
has_imported_c_file = True
c_file_name = "tmp_" + str(i) + ".cc"
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
f.write(m.get_source())
files.append(path_cc)
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
Expand All @@ -143,7 +152,7 @@ def export_library(self,
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if self.type_key == "c":
if self.type_key == "c" or has_imported_c_file:
options = []
if "options" in kwargs:
opts = kwargs["options"]
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/build_module.cc
Expand Up @@ -309,6 +309,10 @@ Target intel_graphics(const std::vector<std::string>& options) {
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
}

Target ext(const std::vector<std::string>& options) {
zhiics marked this conversation as resolved.
Show resolved Hide resolved
return CreateTarget("ext_dev", options);
}
} // namespace target

bool LLVMEnabled() {
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen.cc
Expand Up @@ -69,6 +69,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
stream->Write(tkey);
if (tkey == "c") continue;
im->SaveToBinary(stream);
}
// translate to C program
Expand Down
22 changes: 22 additions & 0 deletions src/relay/backend/build_module.cc
Expand Up @@ -73,6 +73,10 @@ struct GraphCodegen {
return CallFunc<std::string>("get_graph_json", nullptr);
}

Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
}

Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
}
Expand Down Expand Up @@ -148,6 +152,10 @@ class RelayBuildModule : public runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
});
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetExternalModules();
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2);
Expand Down Expand Up @@ -474,6 +482,20 @@ class RelayBuildModule : public runtime::ModuleNode {
target_host_,
BuildConfig::Current());
}
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
if (!ext_mods.empty()) {
CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1)
<< "Expect to have a TVM DSOModule when multiple external runtime modules exist";
if (lowered_funcs.size() == 0) {
// Execute the whole module using external runtime.
ret_.mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (const auto& it : ext_mods) {
ret_.mod.Import(it);
}
}
}
}

protected:
Expand Down
83 changes: 70 additions & 13 deletions src/relay/backend/compile_engine.cc
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
Expand Down Expand Up @@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (src_func->IsExternal()) {
auto compiler = FunctionGetAttr(src_func, attr::kExternal);
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_func_name = FunctionGetAttr(src_func, attr::kFuncName);
const tvm::ir::StringImm* func_name = ext_func_name.as<tvm::ir::StringImm>();
CHECK(func_name) << "No external function name is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(func_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}

Array<tvm::runtime::Module> ret;
for (const auto& it : ext_mods) {
std::string ext_name = "relay.ext." + it.first;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
runtime::Module ext_mod = (*pf)(it.second);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
}

// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
cache_.erase(it);
}
return ret;
}

void Clear() final {
cache_.clear();
}
Expand Down Expand Up @@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// No need to lower external function for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->IsExternal()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kFuncName).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext();
value->cached_func = CachedFunc(cache_node);
return value;
}
// Enforce use the target.
With<Target> target_scope(key->target);

Expand Down Expand Up @@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() {
return *inst;
}


TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() {
return CompileEngine::Global();
});
return CompileEngine::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
self->Clear();
});
self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
});
return self->Lower(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
return self->LowerShapeFunc(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
return self->LowerExternalFunctions();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
});
return self->JIT(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<NodeRef>(CompileEngine)>(
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
} // namespace relay
} // namespace tvm
7 changes: 7 additions & 0 deletions src/relay/backend/compile_engine.h
Expand Up @@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

#include <tvm/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
Expand Down Expand Up @@ -186,6 +187,12 @@ class CompileEngineNode : public Node {
* \return The result.
*/
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*!
* \brief Lower the external function using external codegen tools.
* \return The runtime moduels for each needed external codegen tool.
*/
virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;

/*! \brief clear the cache. */
virtual void Clear() = 0;

Expand Down