diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8b710819e63f..6bc6fbf5b026 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -1301,6 +1301,15 @@ class String : public ObjectRef { */ operator std::string() const { return std::string{get()->data, size()}; } + /*! + * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String + * \param val The value to be checked + * \return A boolean indicating if val can be converted to String + */ + static bool CanConvertFrom(const TVMArgValue& val) { + return val.type_code() == kTVMStr || val.IsObjectRef(); + } + /*! * \brief Hash the binary bytes * \param data The data pointer diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 18b17d36651d..af46439cff7c 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -37,7 +37,7 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un runtime::TVMArgValue val = args[i + 1]; if (val.IsObjectRef()) { dict.Set(key, val.operator ObjectRef()); - } else if (val.type_code() == kTVMStr) { + } else if (String::CanConvertFrom(val)) { dict.Set(key, val.operator String()); } else { dict.Set(key, val.operator PrimExpr()); diff --git a/src/node/container.cc b/src/node/container.cc index f8bad0070c55..f7b9dd32d1cc 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -292,29 +292,16 @@ TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size() % 2, 0); - if (args.size() != 0 && args[0].type_code() == kTVMStr) { - MapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str"; - CHECK(args[i + 1].IsObjectRef()) << "value of the map to be object"; - data.emplace( - std::make_pair(String(args[i].operator std::string()), args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); - } else { - // Container node. - MapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].IsObjectRef()) << "key of map need to be object"; - CHECK(args[i + 1].IsObjectRef()) << "value of map to be object"; - data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); + MapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + ObjectRef k = + String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef(); + ObjectRef v = args[i + 1]; + data.emplace(std::move(k), std::move(v)); } + auto node = make_object(); + node->data = std::move(data); + *ret = Map(node); }); TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -331,15 +318,10 @@ TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* re CHECK(ptr->IsInstance()); auto* n = static_cast(ptr); - if (args[1].type_code() == kTVMStr) { - auto it = n->data.find(String(args[1].operator std::string())); - CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } else { - auto it = n->data.find(args[1].operator ObjectRef()); - CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } + auto it = n->data.find(String::CanConvertFrom(args[1]) ? args[1].operator String() + : args[1].operator ObjectRef()); + CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; }); TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -347,11 +329,9 @@ TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) Object* ptr = static_cast(args[0].value().v_handle); CHECK(ptr->IsInstance()); const MapNode* n = static_cast(ptr); - if (args[1].type_code() == kTVMStr) { - *ret = static_cast(n->data.count(String(args[1].operator std::string()))); - } else { - *ret = static_cast(n->data.count(args[1].operator ObjectRef())); - } + int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String() + : args[1].operator ObjectRef()); + *ret = cnt; }); TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index eb305c95a8a6..e9543e354bd1 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -121,6 +121,15 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { return val->data == rhs.operator std::string(); } break; + case kTVMObjectHandle: + if (rhs.IsObjectRef()) { + if (auto* val = lhs.as()) { + return rhs.operator String() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator String() == val->data; + } + } + break; default: CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code(); } diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 9f206fd48d6e..5439be9109f9 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -20,6 +20,7 @@ /*! * \file graph_runtime_debug.cc */ +#include #include #include #include @@ -173,7 +174,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name, }); } else if (name == "debug_get_output") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { + if (String::CanConvertFrom(args[0])) { this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); } else { this->DebugGetNodeOutput(args[0], args[1]); diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 8f7f98808bd5..59bfb68f039b 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -390,8 +390,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - int in_idx = this->GetInputIndex(args[0]); + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); if (in_idx >= 0) this->SetInput(in_idx, args[1]); } else { this->SetInput(args[0], args[1]); @@ -399,8 +399,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, }); } else if (name == "set_input_zero_copy") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - int in_idx = this->GetInputIndex(args[0]); + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); } else { this->SetInputZeroCopy(args[0], args[1]); @@ -417,11 +417,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, } else if (name == "get_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { int in_idx = 0; - if (args[0].type_code() == kTVMStr) { - in_idx = this->GetInputIndex(args[0]); - } else if (args[0].IsObjectRef()) { - auto str = args[0].AsObjectRef(); - in_idx = this->GetInputIndex(str); + if (String::CanConvertFrom(args[0])) { + in_idx = this->GetInputIndex(args[0].operator String()); } else { in_idx = args[0]; }