Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def convert_to_tvm_func(pyfunc):
def cfun(args, type_codes, num_args, ret, _):
"""ctypes function"""
num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args
pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args))
pyargs = (C_TO_PY_ARG_SWITCH[ArgTypeCode(type_codes[i])](args[i]) for i in range(num_args))
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
Expand Down Expand Up @@ -117,33 +117,33 @@ def _make_tvm_args(args, temp_args):
for i, arg in enumerate(args):
if isinstance(arg, ObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.OBJECT_HANDLE
type_codes[i] = ArgTypeCode.OBJECT_HANDLE.value
elif arg is None:
values[i].v_handle = None
type_codes[i] = ArgTypeCode.NULL
type_codes[i] = ArgTypeCode.NULL.value
elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (
ArgTypeCode.NDARRAY_HANDLE if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE
ArgTypeCode.NDARRAY_HANDLE.value if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE.value
)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = ArgTypeCode.OBJECT_HANDLE
type_codes[i] = ArgTypeCode.OBJECT_HANDLE.value
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
elif isinstance(arg, Integral):
values[i].v_int64 = arg
type_codes[i] = ArgTypeCode.INT
type_codes[i] = ArgTypeCode.INT.value
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = ArgTypeCode.FLOAT
type_codes[i] = ArgTypeCode.FLOAT.value
elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = ArgTypeCode.STR
values[i].v_type = arg
type_codes[i] = ArgTypeCode.TVM_TYPE.value
elif isinstance(arg, Device):
values[i].v_int64 = _device_to_int64(arg)
type_codes[i] = ArgTypeCode.DLDEVICE
type_codes[i] = ArgTypeCode.DLDEVICE.value
elif isinstance(arg, (bytearray, bytes)):
# from_buffer only taeks in bytearray.
if isinstance(arg, bytes):
Expand All @@ -158,31 +158,31 @@ def _make_tvm_args(args, temp_args):
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = ArgTypeCode.BYTES
type_codes[i] = ArgTypeCode.BYTES.value
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = ArgTypeCode.STR
type_codes[i] = ArgTypeCode.STR.value
elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = _FUNC_CONVERT_TO_OBJECT(arg)
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.OBJECT_HANDLE
type_codes[i] = ArgTypeCode.OBJECT_HANDLE.value
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.MODULE_HANDLE
type_codes[i] = ArgTypeCode.MODULE_HANDLE.value
elif isinstance(arg, PackedFuncBase):
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE
type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE.value
elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg
type_codes[i] = ArgTypeCode.HANDLE
type_codes[i] = ArgTypeCode.HANDLE.value
elif isinstance(arg, ObjectRValueRef):
values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p)
type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG
type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG.value
elif callable(arg):
arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle
type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE
type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE.value
temp_args.append(arg)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
Expand Down Expand Up @@ -237,7 +237,7 @@ def __call__(self, *args):
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
return RETURN_SWITCH[ArgTypeCode(ret_tcode.value)](ret_val)


def __init_handle_by_constructor__(fconstructor, args):
Expand All @@ -260,7 +260,7 @@ def __init_handle_by_constructor__(fconstructor, args):
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE
assert ArgTypeCode(ret_tcode.value) == ArgTypeCode.OBJECT_HANDLE
handle = ret_val.v_handle
return handle

Expand Down Expand Up @@ -294,6 +294,7 @@ def _get_global_func(name, allow_missing=False):
raise ValueError("Cannot find global function %s" % name)



# setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/_ffi/_ctypes/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ctypes
import struct
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import TVMByteArray, ArgTypeCode, Device
from ..runtime_ctypes import TVMByteArray, ArgTypeCode, Device,DataType


class TVMValue(ctypes.Union):
Expand All @@ -30,6 +30,7 @@ class TVMValue(ctypes.Union):
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
('v_type', DataType),
]


Expand Down Expand Up @@ -77,9 +78,9 @@ def _return_device(value):
return Device(arr[0], arr[1])


def _wrap_arg_func(return_f, type_code):
def _wrap_arg_func(return_f, type_code: ArgTypeCode):
def _wrap_func(x):
tcode = ctypes.c_int(type_code)
tcode = ctypes.c_int(type_code.value)
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), ctypes.byref(tcode)))
return return_f(x)

Expand All @@ -97,6 +98,7 @@ def _device_to_int64(dev):
ArgTypeCode.FLOAT: lambda x: x.v_float64,
ArgTypeCode.HANDLE: _return_handle,
ArgTypeCode.NULL: lambda x: None,
ArgTypeCode.TVM_TYPE: lambda x: x.v_type,
ArgTypeCode.STR: lambda x: py_str(x.v_str),
ArgTypeCode.BYTES: _return_bytes,
ArgTypeCode.DLDEVICE: _return_device,
Expand All @@ -107,6 +109,7 @@ def _device_to_int64(dev):
ArgTypeCode.FLOAT: lambda x: x.v_float64,
ArgTypeCode.HANDLE: _return_handle,
ArgTypeCode.NULL: lambda x: None,
ArgTypeCode.TVM_TYPE: lambda x: x.v_type,
ArgTypeCode.STR: lambda x: py_str(x.v_str),
ArgTypeCode.BYTES: _return_bytes,
ArgTypeCode.DLDEVICE: _return_device,
Expand Down
Loading