Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Separate ArgTypeCode from DLDataTypeCode #5730

Merged
merged 1 commit into from Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
8 changes: 4 additions & 4 deletions include/tvm/runtime/c_runtime_api.h
Expand Up @@ -87,12 +87,14 @@ typedef enum {
} TVMDeviceExtType;

/*!
* \brief The type code in used in the TVM FFI.
* \brief The type code in used in the TVM FFI for argument passing.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by TVM API calls.
kTVMArgInt = kDLInt,
kTVMArgFloat = kDLFloat,
kTVMOpaqueHandle = 3U,
kTVMNullptr = 4U,
kTVMDataType = 5U,
Expand All @@ -115,9 +117,7 @@ typedef enum {
// The following section of code is used for non-reserved types.
kTVMExtReserveEnd = 64U,
kTVMExtEnd = 128U,
// The rest of the space is used for custom, user-supplied datatypes
kTVMCustomBegin = 129U,
} TVMTypeCode;
} TVMArgTypeCode;

/*!
* \brief The Device information, abstract away common device types.
Expand Down
37 changes: 8 additions & 29 deletions include/tvm/runtime/data_type.h
Expand Up @@ -45,7 +45,8 @@ class DataType {
kInt = kDLInt,
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMTypeCode::kTVMOpaqueHandle,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kCustomBegin = 129
};
/*! \brief default constructor */
DataType() {}
Expand Down Expand Up @@ -248,7 +249,7 @@ TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);

/*!
* \brief convert a string to TVM type.
Expand All @@ -265,38 +266,16 @@ inline DLDataType String2DLDataType(std::string s);
inline std::string DLDataType2String(DLDataType t);

// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
switch (static_cast<int>(type_code)) {
case kDLInt:
return "int";
case kDLUInt:
return "uint";
case kDLFloat:
return "float";
case kTVMStr:
return "str";
case kTVMBytes:
return "bytes";
case kTVMOpaqueHandle:
case DataType::kHandle:
return "handle";
case kTVMNullptr:
return "NULL";
case kTVMDLTensorHandle:
return "ArrayHandle";
case kTVMDataType:
return "DLDataType";
case kTVMContext:
return "TVMContext";
case kTVMPackedFuncHandle:
return "FunctionHandle";
case kTVMModuleHandle:
return "ModuleHandle";
case kTVMNDArrayHandle:
return "NDArrayContainer";
case kTVMObjectHandle:
return "Object";
case kTVMObjectRValueRefArg:
return "ObjectRValueRefArg";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
Expand All @@ -311,8 +290,8 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (DataType(t).is_void()) {
return os << "void";
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
if (t.code < DataType::kCustomBegin) {
os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
Expand Down
49 changes: 47 additions & 2 deletions include/tvm/runtime/packed_func.h
Expand Up @@ -327,9 +327,16 @@ class TVMArgs {
inline TVMArgValue operator[](int i) const;
};

/*!
* \brief Convert argument type code to string.
* \param type_code The input type code.
* \return The corresponding string repr.
*/
inline const char* ArgTypeCode2Str(int type_code);

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE)

/*!
* \brief Type traits for runtime type check during FFI conversion.
Expand Down Expand Up @@ -394,7 +401,7 @@ class TVMPODValue_ {
} else {
if (type_code_ == kTVMNullptr) return nullptr;
LOG(FATAL) << "Expect "
<< "DLTensor* or NDArray but get " << TypeCode2Str(type_code_);
<< "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_);
return nullptr;
}
}
Expand Down Expand Up @@ -982,6 +989,44 @@ inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(
inline PackedFunc::FType PackedFunc::body() const { return body_; }

// internal namespace
inline const char* ArgTypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt:
return "int";
case kDLUInt:
return "uint";
case kDLFloat:
return "float";
case kTVMStr:
return "str";
case kTVMBytes:
return "bytes";
case kTVMOpaqueHandle:
return "handle";
case kTVMNullptr:
return "NULL";
case kTVMDLTensorHandle:
return "ArrayHandle";
case kTVMDataType:
return "DLDataType";
case kTVMContext:
return "TVMContext";
case kTVMPackedFuncHandle:
return "FunctionHandle";
case kTVMModuleHandle:
return "ModuleHandle";
case kTVMNDArrayHandle:
return "NDArrayContainer";
case kTVMObjectHandle:
return "Object";
case kTVMObjectRValueRefArg:
return "ObjectRValueRefArg";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
}
}

namespace detail {

template <bool stop, std::size_t I, typename F>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
Expand Up @@ -740,7 +740,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kTVMCustomBegin)) {
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
Expand Down
Expand Up @@ -18,14 +18,14 @@
package org.apache.tvm;

// Type code used in API calls
public enum TypeCode {
public enum ArgTypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);

public final int id;

private TypeCode(int id) {
private ArgTypeCode(int id) {
this.id = id;
}

Expand Down
14 changes: 7 additions & 7 deletions jvm/core/src/main/java/org/apache/tvm/Function.java
Expand Up @@ -80,7 +80,7 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a
* @param isResident Whether this is a resident function in jvm
*/
Function(long handle, boolean isResident) {
super(TypeCode.FUNC_HANDLE);
super(ArgTypeCode.FUNC_HANDLE);
this.handle = handle;
this.isResident = isResident;
}
Expand Down Expand Up @@ -187,7 +187,7 @@ public Function pushArg(String arg) {
* @return this
*/
public Function pushArg(NDArrayBase arg) {
int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
return this;
}
Expand All @@ -198,7 +198,7 @@ public Function pushArg(NDArrayBase arg) {
* @return this
*/
public Function pushArg(Module arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id);
return this;
}

Expand All @@ -208,7 +208,7 @@ public Function pushArg(Module arg) {
* @return this
*/
public Function pushArg(Function arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id);
return this;
}

Expand Down Expand Up @@ -249,12 +249,12 @@ private static void pushArgToStack(Object arg) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) {
NDArrayBase nd = (NDArrayBase) arg;
int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
} else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id);
} else if (arg instanceof TVMValue) {
TVMValue tvmArg = (TVMValue) arg;
switch (tvmArg.typeCode) {
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/Module.java
Expand Up @@ -45,7 +45,7 @@ private static Function getApi(String name) {
}

Module(long handle) {
super(TypeCode.MODULE_HANDLE);
super(ArgTypeCode.MODULE_HANDLE);
this.handle = handle;
}

Expand Down Expand Up @@ -138,7 +138,7 @@ public String typeKey() {
*/
public static Module load(String path, String fmt) {
TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke();
assert ret.typeCode == TypeCode.MODULE_HANDLE;
assert ret.typeCode == ArgTypeCode.MODULE_HANDLE;
return ret.asModule();
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java
Expand Up @@ -27,7 +27,7 @@ public class NDArrayBase extends TVMValue {
private boolean isReleased = false;

NDArrayBase(long handle, boolean isView) {
super(TypeCode.ARRAY_HANDLE);
super(ArgTypeCode.ARRAY_HANDLE);
this.handle = handle;
this.isView = isView;
}
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/TVMValue.java
Expand Up @@ -18,9 +18,9 @@
package org.apache.tvm;

public class TVMValue {
public final TypeCode typeCode;
public final ArgTypeCode typeCode;

public TVMValue(TypeCode tc) {
public TVMValue(ArgTypeCode tc) {
typeCode = tc;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java
Expand Up @@ -21,7 +21,7 @@ public class TVMValueBytes extends TVMValue {
public final byte[] value;

public TVMValueBytes(byte[] value) {
super(TypeCode.BYTES);
super(ArgTypeCode.BYTES);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java
Expand Up @@ -21,7 +21,7 @@ public class TVMValueDouble extends TVMValue {
public final double value;

public TVMValueDouble(double value) {
super(TypeCode.FLOAT);
super(ArgTypeCode.FLOAT);
this.value = value;
}

Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java
Expand Up @@ -18,13 +18,13 @@
package org.apache.tvm;

/**
* Java class related to TVM handles (TypeCode.HANDLE)
* Java class related to TVM handles (ArgTypeCode.HANDLE)
*/
public class TVMValueHandle extends TVMValue {
public final long value;

public TVMValueHandle(long value) {
super(TypeCode.HANDLE);
super(ArgTypeCode.HANDLE);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java
Expand Up @@ -21,7 +21,7 @@ public class TVMValueLong extends TVMValue {
public final long value;

public TVMValueLong(long value) {
super(TypeCode.INT);
super(ArgTypeCode.INT);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java
Expand Up @@ -19,6 +19,6 @@

public class TVMValueNull extends TVMValue {
public TVMValueNull() {
super(TypeCode.NULL);
super(ArgTypeCode.NULL);
}
}
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueString.java
Expand Up @@ -21,7 +21,7 @@ public class TVMValueString extends TVMValue {
public final String value;

public TVMValueString(String value) {
super(TypeCode.STR);
super(ArgTypeCode.STR);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Expand Up @@ -23,7 +23,7 @@
# top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.runtime_ctypes import DataTypeCode, DataType
from ._ffi import register_object, register_func, register_extension, get_global_func

# top-level alias
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/_ffi/_ctypes/object.py
Expand Up @@ -18,7 +18,7 @@
"""Runtime Object api"""
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .ndarray import _register_ndarray, NDArrayBase


Expand Down Expand Up @@ -60,12 +60,12 @@ def _return_object(x):
obj.handle = handle
return obj

RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, ArgTypeCode.OBJECT_HANDLE)

C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG)


class PyNativeObject:
Expand Down