Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
250 lines (231 sloc) 8.21 KB
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/module.h
* \brief Runtime container of the functions generated by TVM.
* It supports dynamically linking, loading and saving functions
* that follow different conventions under a unified API.
*/
#ifndef TVM_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace tvm {
namespace runtime {
class ModuleNode;
class PackedFunc;
/*!
 * \brief Module container of TVM.
 *
 * A reference-counted handle (an ObjectRef) whose underlying container is a
 * ModuleNode (see ContainerType). Member access through operator-> reaches
 * the ModuleNode interface.
 */
class Module : public ObjectRef {
 public:
  Module() {}
  /*!
   * \brief Construct a Module from an existing container.
   * \param n Pointer to the container object. It is expected to point to a
   *        ModuleNode (or a subclass); no runtime check is done here.
   *        Taken by value and moved into the base to avoid an extra
   *        reference-count increment/decrement.
   */
  explicit Module(ObjectPtr<Object> n)
      : ObjectRef(std::move(n)) {}
  /*!
   * \brief Get packed function from current module by name.
   *
   * \param name The name of the function.
   * \param query_imports Whether to also query dependency (imported) modules.
   * \return The result function, or PackedFunc(nullptr) if the function
   *         does not exist.
   * \note Declared inline; the definition is provided by a header included
   *       after this one. NOTE(review): the original note said
   *       "packed_func.cc" — confirm (presumably tvm/runtime/packed_func.h).
   */
  inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
  // The following functions require linking with the runtime.
  /*!
   * \brief Import another module into this module.
   * \param other The module to be imported.
   *
   * \note Cyclic dependency is not allowed among modules;
   *       an error will be thrown when a cyclic dependency is detected.
   */
  inline void Import(Module other);
  /*! \return The internal container (mutable access). */
  inline ModuleNode* operator->();
  /*! \return The internal container (read-only access). */
  inline const ModuleNode* operator->() const;
  /*!
   * \brief Load a module from file.
   * \param file_name The name of the host function module.
   * \param format The format of the file.
   * \note This function won't load the import relationship.
   *       Re-create the import relationship by calling Import.
   */
  TVM_DLL static Module LoadFromFile(const std::string& file_name,
                                     const std::string& format = "");
  // Refer to the corresponding container type.
  using ContainerType = ModuleNode;
  friend class ModuleNode;
};
/*!
* \brief Base container of module.
*
* ModuleNode is the object a Module handle refers to (see
* Module::ContainerType). Please subclass ModuleNode to create a specific
* runtime module; subclasses must implement the pure virtual type_key()
* and GetFunction(name, sptr_to_self).
*
* \code
*
* class MyModuleNode : public ModuleNode {
* public:
* // implement the interface
* };
*
* // use make_object to create a specific
* // instance of MyModuleNode.
* Module CreateMyModule() {
* ObjectPtr<MyModuleNode> n =
* tvm::runtime::make_object<MyModuleNode>();
* return Module(n);
* }
*
* \endcode
*/
class ModuleNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*!
* \return The per-module type key.
* \note This key is used for serializing custom modules.
*/
virtual const char* type_key() const = 0;
/*!
* \brief Get a PackedFunc from module.
*
* The PackedFunc may not be fully initialized;
* there might still be first-time running overhead when
* executing the function on certain devices.
* When benchmarking, exclude this first-run overhead from timings.
* NOTE(review): the original comment referred to a "prepare" step that is
* not declared in this header — confirm the intended mechanism.
*
* \param name the name of the function.
* \param sptr_to_self The ObjectPtr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
* \note The function will always remain valid.
* If the function needs resources from the module (e.g. late linking),
* it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
TVM_DLL virtual void SaveToFile(const std::string& file_name,
const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
* \note It is recommended to implement this for device modules,
* but not necessarily host modules.
* We can use this to do AOT loading of bundled device functions.
*/
TVM_DLL virtual void SaveToBinary(dmlc::Stream* stream);
/*!
* \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default.
* \return Possible source code when available.
*/
TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get packed function from current module by name.
*
* Non-virtual overload taking a bool; distinct from the pure virtual
* GetFunction(name, sptr_to_self) above.
*
* \param name The name of the function.
* \param query_imports Whether to also query dependency (imported) modules.
* \return The result function, or PackedFunc(nullptr) if the function
* does not exist.
* \note Implemented in packed_func.cc
*/
TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
* \note Cyclic dependency is not allowed among modules;
* an error will be thrown when a cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
*
* \param name name of the function.
* \return The corresponding function, as a non-owning pointer.
* NOTE(review): the pointee's lifetime is not visible here; it is
* presumably owned by import_cache_ — confirm before storing it.
*/
TVM_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The modules this module imports from */
const std::vector<Module>& imports() const {
return imports_;
}
// Integration with the object type system: fixed type index and key.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
// NOTE: although declared FINAL for the object type system, ModuleNode can
// still be sub-classed in C++ (see the class comment). Presumably subclasses
// share ModuleNode's type index rather than registering their own — confirm.
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
protected:
friend class Module;
/*! \brief The modules this module depends on */
std::vector<Module> imports_;
private:
/*!
* \brief Cache of functions looked up by name.
* NOTE(review): the original comment said "GetImport", which is not declared
* here; presumably this backs GetFuncFromEnv — confirm.
*/
std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_;
};
/*!
 * \brief Well-known symbol names shared between generated code and the runtime.
 */
namespace symbol {
/*! \brief Name of the placeholder symbol for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
/*! \brief Name of the global variable that stores the module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";

/*! \brief Name of the global variable that stores the device module blob. */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Name of the symbol holding the byte size of the device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";

/*! \brief Name of the global function used to set the device. */
constexpr const char* tvm_set_device = "__tvm_set_device";

/*! \brief Name of the function that prepares the global barrier before kernels that use it. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Name of the auxiliary counter backing the global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
}  // namespace symbol
// implementations of inline functions.
/*!
 * \brief Import another module by forwarding to the underlying ModuleNode.
 * \param other The module to be imported. It is taken by value and moved
 *        into the container call, so no extra reference-count round trip
 *        is paid for a second copy.
 */
inline void Module::Import(Module other) {
  (*this)->Import(std::move(other));
}
/*!
 * \brief Mutable access to the underlying container.
 * \return The held object viewed as a ModuleNode. Assumes the held object is
 *         a ModuleNode (ContainerType); no runtime type check is performed.
 */
inline ModuleNode* Module::operator->() {
  ModuleNode* node = static_cast<ModuleNode*>(this->get_mutable());
  return node;
}
/*!
 * \brief Read-only access to the underlying container.
 * \return The held object viewed as a const ModuleNode. Assumes the held
 *         object is a ModuleNode (ContainerType); no runtime check is done.
 */
inline const ModuleNode* Module::operator->() const {
  const auto* raw = this->get();
  return static_cast<const ModuleNode*>(raw);
}
} // namespace runtime
} // namespace tvm
#include <tvm/runtime/packed_func.h> // NOLINT(*)
#endif // TVM_RUNTIME_MODULE_H_
You can’t perform that action at this time.