Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME] Introduce RValue reference(move) support to TypedPackedFunc #5271

Merged
merged 3 commits into from
Apr 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class PrimExpr : public BaseExpr {

private:
// Internal function for conversion.
friend class runtime::TVMPODValue_;
friend struct runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
};

Expand Down Expand Up @@ -451,22 +451,24 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return PrimExpr(static_cast<int>(value_.v_int64));
template<>
struct PackedFuncValueConverter<PrimExpr> {
// common rule for both RetValue and ArgValue.
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
return PrimExpr(val.operator int());
}
if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double()));
}
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>();
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));
}
if (type_code_ == kDLFloat) {
return PrimExpr(static_cast<float>(value_.v_float64));
}

TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return PrimExpr::FromObject_(ObjectPtr<Object>(ptr));
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_IR_EXPR_H_
1 change: 1 addition & 0 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h>

#include <type_traits>
#include <vector>
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ typedef enum {
kTVMStr = 11U,
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down Expand Up @@ -290,7 +291,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code);

/*!
* \brief C type of packed function.
Expand Down
20 changes: 20 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

#include <cstring>
#include <initializer_list>
Expand Down Expand Up @@ -590,6 +591,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
}
}

template<>
struct PackedFuncValueConverter<::tvm::runtime::String> {
static String From(const TVMArgValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}

static String From(const TVMRetValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}
};

} // namespace runtime
} // namespace tvm

Expand Down
16 changes: 16 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ class ObjectPtr {
data_->IncRef();
}
}
/*!
* \brief Move an ObjectPtr from an RValueRef argument.
* \param ref The rvalue reference.
* \return the moved result.
*/
static ObjectPtr<T> MoveFromRValueRefArg(Object** ref) {
ObjectPtr<T> ptr;
ptr.data_ = *ref;
*ref = nullptr;
return ptr;
}
// friend classes
friend class Object;
friend class ObjectRef;
Expand All @@ -489,6 +500,7 @@ class ObjectPtr {
friend class TVMArgsSetter;
friend class TVMRetValue;
friend class TVMArgValue;
friend class TVMMovableArgValue_;
template <typename RelayRefType, typename ObjType>
friend RelayRefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
Expand Down Expand Up @@ -550,6 +562,10 @@ class ObjectRef {
bool unique() const {
return data_.unique();
}
/*! \return The use count of the ptr, for debug purposes */
int use_count() const {
return data_.use_count();
}
/*!
* \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type.
Expand Down
Loading