Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME] Update Module and Registry to use String Container #14902

Merged
merged 2 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions apps/dso_plugin_module/plugin_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class MyModuleNode : public ModuleNode {

virtual const char* type_key() const final { return "MyModule"; }

virtual PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "add") {
return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) { return value_ + value; });
} else if (name == "mul") {
Expand Down
19 changes: 9 additions & 10 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Module : public ObjectRef {
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
inline PackedFunc GetFunction(const String& name, bool query_imports = false);
// The following functions requires link with runtime.
/*!
* \brief Import another module into this module.
Expand All @@ -111,7 +111,7 @@ class Module : public ObjectRef {
* \note This function won't load the import relationship.
* Re-create import relationship by calling Import.
*/
TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = "");
TVM_DLL static Module LoadFromFile(const String& file_name, const String& format = "");
// refer to the corresponding container.
using ContainerType = ModuleNode;
friend class ModuleNode;
Expand Down Expand Up @@ -165,14 +165,13 @@ class TVM_DLL ModuleNode : public Object {
* If the function need resource from the module(e.g. late linking),
* it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) = 0;
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name, const std::string& format);
virtual void SaveToFile(const String& file_name, const String& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
Expand All @@ -186,12 +185,12 @@ class TVM_DLL ModuleNode : public Object {
* \param format Format of the source code, can be empty by default.
* \return Possible source code when available.
*/
virtual std::string GetSource(const std::string& format = "");
virtual String GetSource(const String& format = "");
/*!
* \brief Get the format of the module, when available.
* \return Possible format when available.
*/
virtual std::string GetFormat();
virtual String GetFormat();
/*!
* \brief Get packed function from current module by name.
*
Expand All @@ -201,7 +200,7 @@ class TVM_DLL ModuleNode : public Object {
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
PackedFunc GetFunction(const std::string& name, bool query_imports = false);
PackedFunc GetFunction(const String& name, bool query_imports = false);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
Expand All @@ -217,7 +216,7 @@ class TVM_DLL ModuleNode : public Object {
* \param name name of the function.
* \return The corresponding function.
*/
const PackedFunc* GetFuncFromEnv(const std::string& name);
const PackedFunc* GetFuncFromEnv(const String& name);
/*! \return The module it imports from */
const std::vector<Module>& imports() const { return imports_; }

Expand Down Expand Up @@ -268,7 +267,7 @@ class TVM_DLL ModuleNode : public Object {
* \param target The target module name.
* \return Whether runtime is enabled.
*/
TVM_DLL bool RuntimeEnabled(const std::string& target);
TVM_DLL bool RuntimeEnabled(const String& target);

/*! \brief namespace for constant symbols */
namespace symbol {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1942,7 +1942,7 @@ inline TVMRetValue::operator T() const {
return PackedFuncValueConverter<T>::From(*this);
}

inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
inline PackedFunc Module::GetFunction(const String& name, bool query_imports) {
return (*this)->GetFunction(name, query_imports);
}

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
#ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_

#include <tvm/runtime/container/string.h>
#include <tvm/runtime/packed_func.h>

#include <string>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -295,32 +295,32 @@ class Registry {
* \param override Whether allow override existing function.
* \return Reference to the registry.
*/
TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
TVM_DLL static Registry& Register(const String& name, bool override = false); // NOLINT(*)
/*!
* \brief Erase global function from registry, if exist.
* \param name The name of the function.
* \return Whether function exist.
*/
TVM_DLL static bool Remove(const std::string& name);
TVM_DLL static bool Remove(const String& name);
/*!
* \brief Get the global function by name.
* \param name The name of the function.
* \return pointer to the registered function,
* nullptr if it does not exist.
*/
TVM_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*)
TVM_DLL static const PackedFunc* Get(const String& name); // NOLINT(*)
/*!
* \brief Get the names of currently registered global function.
* \return The names
*/
TVM_DLL static std::vector<std::string> ListNames();
TVM_DLL static std::vector<String> ListNames();

// Internal class.
struct Manager;

protected:
/*! \brief name of the function */
std::string name_;
String name_;
/*! \brief internal packed function */
PackedFunc func_;
friend struct Manager;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TVM_DLL Executable : public ModuleNode {
*
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;

/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };
Expand All @@ -88,7 +88,7 @@ class TVM_DLL Executable : public ModuleNode {
* \param path The path to write the serialized data to.
* \param format The format of the serialized blob.
*/
void SaveToFile(const std::string& path, const std::string& format) final;
void SaveToFile(const String& path, const String& format) final;

/*!
* \brief Serialize the executable into global section, constant section, and
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
* If the function needs resource from the module(e.g. late linking),
* it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self);

virtual ~VirtualMachine() {}

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
class AOTExecutorCodegenModule : public runtime::ModuleNode {
public:
AOTExecutorCodegenModule() {}
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_graph_json") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); });
Expand Down
8 changes: 4 additions & 4 deletions src/relay/backend/contrib/ethosu/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class EthosUModuleNode : public ModuleNode {
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
void SaveToFile(const std::string& file_name, const std::string& format) final {
void SaveToFile(const String& file_name, const String& format) final {
std::string fmt = GetFileFormat(file_name, format);
ICHECK_EQ(fmt, "c") << "Can only save to format="
<< "c";
Expand All @@ -87,9 +87,9 @@ class EthosUModuleNode : public ModuleNode {
out.close();
}

std::string GetSource(const std::string& format) final { return c_source; }
String GetSource(const String& format) final { return c_source; }

std::string GetFormat() override { return "c"; }
String GetFormat() override { return "c"; }

Array<CompilationArtifact> GetArtifacts() { return compilation_artifacts_; }

Expand All @@ -101,7 +101,7 @@ class EthosUModuleNode : public ModuleNode {
*
* \return The function pointer when it is found, otherwise, PackedFunc(nullptr).
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_func_names") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Array<String> func_names;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
class GraphExecutorCodegenModule : public runtime::ModuleNode {
public:
GraphExecutorCodegenModule() {}
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
VirtualDevice host_virtual_device_;
};

PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
PackedFunc VMCompiler::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "lower") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.num_args, 2);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class VMCompiler : public runtime::ModuleNode {
VMCompiler() = default;
virtual ~VMCompiler() = default;

virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self);

const char* type_key() const final { return "VMCompiler"; }

Expand Down
2 changes: 1 addition & 1 deletion src/relay/printer/model_library_format_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
return rv;
}

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override {
if (name == "print") {
return TypedPackedFunc<std::string(ObjectRef)>(
[sptr_to_self, this](ObjectRef node) { return Print(node); });
Expand Down
3 changes: 1 addition & 2 deletions src/runtime/aot_executor/aot_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ AotExecutor::AotExecutor(tvm::runtime::Module module, const std::vector<Device>&
}
}

PackedFunc AotExecutor::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
PackedFunc AotExecutor::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/aot_executor/aot_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TVM_DLL AotExecutor : public ModuleNode {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override;
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override;

/*!
* \return The type key of the executor.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/aot_executor/aot_executor_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ AotExecutorFactory::AotExecutorFactory(
}

PackedFunc AotExecutorFactory::GetFunction(
const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
const String& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
if (name == module_name_) {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_GT(args.num_args, 0) << "Must supply at least one device argument";
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/aot_executor/aot_executor_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TVM_DLL AotExecutorFactory : public runtime::ModuleNode {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;

/*!
* \return The type key of the executor.
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/const_loader_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ConstLoaderModuleNode : public ModuleNode {
}
}

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")";
// Initialize and memoize the module.
// Usually, we have some warmup runs. The module initialization should be
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/coreml/coreml_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class CoreMLRuntime : public ModuleNode {
* \param sptr_to_self The pointer to the module node.
* \return The corresponding member function.
*/
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self);

/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final {
Expand Down
3 changes: 1 addition & 2 deletions src/runtime/contrib/coreml/coreml_runtime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@
model_ = std::unique_ptr<CoreMLModel>(new CoreMLModel(url));
}

PackedFunc CoreMLRuntime::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
PackedFunc CoreMLRuntime::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "invoke" || name == "run") {
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { model_->Invoke(); });
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
}

/* Override GetFunction to reimplement Run method */
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override {
if (this->symbol_name_ == name) {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK(this->initialized_) << "The module has not been initialized";
Expand Down
5 changes: 2 additions & 3 deletions src/runtime/contrib/ethosn/ethosn_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ EthosnModule::EthosnModule(std::vector<OrderedCompiledNetwork>* cmms) {
}
}

PackedFunc EthosnModule::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
PackedFunc EthosnModule::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
if (network_map_.find(name) != network_map_.end()) {
return PackedFunc([sptr_to_self, this, name](TVMArgs args, TVMRetValue* rv) {
*rv = Inference(args, network_map_[name].proc_mem_alloc.get(),
Expand Down Expand Up @@ -143,7 +142,7 @@ Module EthosnModule::LoadFromBinary(void* strm) {
return Module(n);
}

void EthosnModule::SaveToFile(const std::string& path, const std::string& format) {
void EthosnModule::SaveToFile(const String& path, const String& format) {
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::SeekStream* strm = &writer;
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/contrib/ethosn/ethosn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class EthosnModule : public ModuleNode {
* \param sptr_to_self The ObjectPtr that points to this module node.
* \return The function pointer when it is found, otherwise, PackedFunc(nullptr).
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
/*!
* \brief Save a compiled network to a binary stream, which can then be
* serialized to disk.
Expand Down Expand Up @@ -100,7 +100,7 @@ class EthosnModule : public ModuleNode {
* \brief Save a module to a specified path.
* \param path Where to save the serialized module.
*/
void SaveToFile(const std::string& path, const std::string& format) override;
void SaveToFile(const String& path, const String& format) override;

const char* type_key() const override { return "ethos-n"; }

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/contrib/json/json_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class JSONRuntimeBase : public ModuleNode {
* \param sptr_to_self The pointer to the module node.
* \return The packed function.
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) override {
if (name == "get_symbol") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_name_; });
Expand Down Expand Up @@ -145,7 +145,7 @@ class JSONRuntimeBase : public ModuleNode {
* \param format the format to return.
* \return A string of JSON.
*/
std::string GetSource(const std::string& format = "json") override { return graph_json_; }
String GetSource(const String& format = "json") override { return graph_json_; }

protected:
/*!
Expand Down
Loading
Loading