Skip to content

Commit

Permalink
[Object][FFI] Introduce runtime::String::CanConvertFrom (#5718)
Browse files Browse the repository at this point in the history
* [Object][FFI] Introduce runtime::String::CanConvertFrom

* Update container.h
  • Loading branch information
junrushao committed Jun 3, 2020
1 parent 927510a commit 9151a51
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 47 deletions.
9 changes: 9 additions & 0 deletions include/tvm/runtime/container.h
Expand Up @@ -1301,6 +1301,15 @@ class String : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }

/*!
* \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
* \param val The value to be checked
* \return A boolean indicating if val can be converted to String
*/
static bool CanConvertFrom(const TVMArgValue& val) {
return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
}

/*!
* \brief Hash the binary bytes
* \param data The data pointer
Expand Down
2 changes: 1 addition & 1 deletion src/ir/attrs.cc
Expand Up @@ -37,7 +37,7 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un
runtime::TVMArgValue val = args[i + 1];
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kTVMStr) {
} else if (String::CanConvertFrom(val)) {
dict.Set(key, val.operator String());
} else {
dict.Set(key, val.operator PrimExpr());
Expand Down
52 changes: 16 additions & 36 deletions src/node/container.cc
Expand Up @@ -292,29 +292,16 @@ TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)

TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kTVMStr) {
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str";
CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of the map to be object";
data.emplace(
std::make_pair(String(args[i].operator std::string()), args[i + 1].operator ObjectRef()));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].IsObjectRef<ObjectRef>()) << "key of map need to be object";
CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of map to be object";
data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef()));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
ObjectRef k =
String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
ObjectRef v = args[i + 1];
data.emplace(std::move(k), std::move(v));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
});

TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand All @@ -331,27 +318,20 @@ TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* re
CHECK(ptr->IsInstance<MapNode>());

auto* n = static_cast<const MapNode*>(ptr);
if (args[1].type_code() == kTVMStr) {
auto it = n->data.find(String(args[1].operator std::string()));
CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
*ret = (*it).second;
} else {
auto it = n->data.find(args[1].operator ObjectRef());
CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
auto it = n->data.find(String::CanConvertFrom(args[1]) ? args[1].operator String()
: args[1].operator ObjectRef());
CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
*ret = (*it).second;
});

TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<MapNode>());
const MapNode* n = static_cast<const MapNode*>(ptr);
if (args[1].type_code() == kTVMStr) {
*ret = static_cast<int64_t>(n->data.count(String(args[1].operator std::string())));
} else {
*ret = static_cast<int64_t>(n->data.count(args[1].operator ObjectRef()));
}
int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String()
: args[1].operator ObjectRef());
*ret = cnt;
});

TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
9 changes: 9 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Expand Up @@ -121,6 +121,15 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
return val->data == rhs.operator std::string();
}
break;
case kTVMObjectHandle:
if (rhs.IsObjectRef<String>()) {
if (auto* val = lhs.as<tir::StringImmNode>()) {
return rhs.operator String() == val->value;
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator String() == val->data;
}
}
break;
default:
CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
}
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/graph/debug/graph_runtime_debug.cc
Expand Up @@ -20,6 +20,7 @@
/*!
* \file graph_runtime_debug.cc
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -173,7 +174,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name,
});
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
if (String::CanConvertFrom(args[0])) {
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
} else {
this->DebugGetNodeOutput(args[0], args[1]);
Expand Down
15 changes: 6 additions & 9 deletions src/runtime/graph/graph_runtime.cc
Expand Up @@ -390,17 +390,17 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
int in_idx = this->GetInputIndex(args[0]);
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
}
});
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
int in_idx = this->GetInputIndex(args[0]);
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
Expand All @@ -417,11 +417,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
if (args[0].type_code() == kTVMStr) {
in_idx = this->GetInputIndex(args[0]);
} else if (args[0].IsObjectRef<runtime::String>()) {
auto str = args[0].AsObjectRef<runtime::String>();
in_idx = this->GetInputIndex(str);
if (String::CanConvertFrom(args[0])) {
in_idx = this->GetInputIndex(args[0].operator String());
} else {
in_idx = args[0];
}
Expand Down

0 comments on commit 9151a51

Please sign in to comment.