From ff2f8f0c1b53dff0608fa9ba1beb3b4f9913a512 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Dec 2022 08:47:08 -0600 Subject: [PATCH] [FFI] Remove string conversion of tvm::runtime::DataType Previously, the FFI would automatically convert all instances of `tvm::runtime::DataType` to a string for FFI usage. This added an unnecessary formatting/parsing step to every FFI call with `tvm::runtime::DataType` arguments, and resulted in duplicate parsing/formatting implementations in C++ and Python. This commit updates the FFI to pass `tvm::runtime::DataType` directly, using the existing `TVMArgTypeCode::kTVMDataType` type code. The `tvm.DataType` wrapper class is updated with additional methods for backwards compatibility (e.g. `"float" in dtype`, `bits = int(dtype[-2:])`), which can be phased out over time. --- python/tvm/_ffi/_ctypes/packed_func.py | 41 ++--- python/tvm/_ffi/_ctypes/types.py | 9 +- python/tvm/_ffi/runtime_ctypes.py | 226 ++++++++++++++++--------- python/tvm/tir/ir_builder.py | 12 +- src/runtime/c_runtime_api.cc | 8 +- src/runtime/data_type.cc | 33 ++++ 6 files changed, 214 insertions(+), 115 deletions(-) create mode 100644 src/runtime/data_type.cc diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index ee6ed05a74f7..a103535892fd 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -75,7 +75,7 @@ def convert_to_tvm_func(pyfunc): def cfun(args, type_codes, num_args, ret, _): """ctypes function""" num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args - pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)) + pyargs = (C_TO_PY_ARG_SWITCH[ArgTypeCode(type_codes[i])](args[i]) for i in range(num_args)) # pylint: disable=broad-except try: rv = local_pyfunc(*pyargs) @@ -117,33 +117,33 @@ def _make_tvm_args(args, temp_args): for i, arg in enumerate(args): if isinstance(arg, ObjectBase): values[i].v_handle = 
arg.handle - type_codes[i] = ArgTypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE.value elif arg is None: values[i].v_handle = None - type_codes[i] = ArgTypeCode.NULL + type_codes[i] = ArgTypeCode.NULL.value elif isinstance(arg, NDArrayBase): values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) type_codes[i] = ( - ArgTypeCode.NDARRAY_HANDLE if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE + ArgTypeCode.NDARRAY_HANDLE.value if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE.value ) elif isinstance(arg, PyNativeObject): values[i].v_handle = arg.__tvm_object__.handle - type_codes[i] = ArgTypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE.value elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode elif isinstance(arg, Integral): values[i].v_int64 = arg - type_codes[i] = ArgTypeCode.INT + type_codes[i] = ArgTypeCode.INT.value elif isinstance(arg, Number): values[i].v_float64 = arg - type_codes[i] = ArgTypeCode.FLOAT + type_codes[i] = ArgTypeCode.FLOAT.value elif isinstance(arg, DataType): - values[i].v_str = c_str(str(arg)) - type_codes[i] = ArgTypeCode.STR + values[i].v_type = arg + type_codes[i] = ArgTypeCode.TVM_TYPE.value elif isinstance(arg, Device): values[i].v_int64 = _device_to_int64(arg) - type_codes[i] = ArgTypeCode.DLDEVICE + type_codes[i] = ArgTypeCode.DLDEVICE.value elif isinstance(arg, (bytearray, bytes)): # from_buffer only taeks in bytearray. 
if isinstance(arg, bytes): @@ -158,31 +158,31 @@ def _make_tvm_args(args, temp_args): arr.size = len(arg) values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) temp_args.append(arr) - type_codes[i] = ArgTypeCode.BYTES + type_codes[i] = ArgTypeCode.BYTES.value elif isinstance(arg, string_types): values[i].v_str = c_str(arg) - type_codes[i] = ArgTypeCode.STR + type_codes[i] = ArgTypeCode.STR.value elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): arg = _FUNC_CONVERT_TO_OBJECT(arg) values[i].v_handle = arg.handle - type_codes[i] = ArgTypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE.value temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle - type_codes[i] = ArgTypeCode.MODULE_HANDLE + type_codes[i] = ArgTypeCode.MODULE_HANDLE.value elif isinstance(arg, PackedFuncBase): values[i].v_handle = arg.handle - type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE.value elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg - type_codes[i] = ArgTypeCode.HANDLE + type_codes[i] = ArgTypeCode.HANDLE.value elif isinstance(arg, ObjectRValueRef): values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p) - type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG + type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG.value elif callable(arg): arg = convert_to_tvm_func(arg) values[i].v_handle = arg.handle - type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE.value temp_args.append(arg) else: raise TypeError("Don't know how to handle type %s" % type(arg)) @@ -237,7 +237,7 @@ def __call__(self, *args): raise get_last_ffi_error() _ = temp_args _ = args - return RETURN_SWITCH[ret_tcode.value](ret_val) + return RETURN_SWITCH[ArgTypeCode(ret_tcode.value)](ret_val) def __init_handle_by_constructor__(fconstructor, args): @@ -260,7 +260,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise 
get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE + assert ArgTypeCode(ret_tcode.value) == ArgTypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -294,6 +294,7 @@ def _get_global_func(name, allow_missing=False): raise ValueError("Cannot find global function %s" % name) + # setup return handle for function type _object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 38d3cd72b55d..3acf114fdaa3 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -19,7 +19,7 @@ import ctypes import struct from ..base import py_str, check_call, _LIB -from ..runtime_ctypes import TVMByteArray, ArgTypeCode, Device +from ..runtime_ctypes import TVMByteArray, ArgTypeCode, Device,DataType class TVMValue(ctypes.Union): @@ -30,6 +30,7 @@ class TVMValue(ctypes.Union): ("v_float64", ctypes.c_double), ("v_handle", ctypes.c_void_p), ("v_str", ctypes.c_char_p), + ('v_type', DataType), ] @@ -77,9 +78,9 @@ def _return_device(value): return Device(arr[0], arr[1]) -def _wrap_arg_func(return_f, type_code): +def _wrap_arg_func(return_f, type_code: ArgTypeCode): def _wrap_func(x): - tcode = ctypes.c_int(type_code) + tcode = ctypes.c_int(type_code.value) check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), ctypes.byref(tcode))) return return_f(x) @@ -97,6 +98,7 @@ def _device_to_int64(dev): ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.TVM_TYPE: lambda x: x.v_type, ArgTypeCode.STR: lambda x: py_str(x.v_str), ArgTypeCode.BYTES: _return_bytes, ArgTypeCode.DLDEVICE: _return_device, @@ -107,6 +109,7 @@ def _device_to_int64(dev): ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.TVM_TYPE: lambda x: 
x.v_type, ArgTypeCode.STR: lambda x: py_str(x.v_str), ArgTypeCode.BYTES: _return_bytes, ArgTypeCode.DLDEVICE: _return_device, diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index fa12bf9ce37a..210a423798db 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -16,15 +16,21 @@ # under the License. """Common runtime ctypes.""" # pylint: disable=invalid-name + import ctypes +import enum +import functools import json import numpy as np +from typing import Union, Optional + from .base import _LIB, check_call +import tvm tvm_shape_index_t = ctypes.c_int64 -class ArgTypeCode(object): +class ArgTypeCode(enum.Enum): """Type code used in API calls""" INT = 0 @@ -51,7 +57,7 @@ class TVMByteArray(ctypes.Structure): _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)] -class DataTypeCode(object): +class DataTypeCode(enum.Enum): """DataType code in DLTensor.""" INT = 0 @@ -64,14 +70,11 @@ class DataTypeCode(object): class DataType(ctypes.Structure): """TVM datatype structure""" - _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)] - CODE2STR = { - DataTypeCode.INT: "int", - DataTypeCode.UINT: "uint", - DataTypeCode.FLOAT: "float", - DataTypeCode.HANDLE: "handle", - DataTypeCode.BFLOAT: "bfloat", - } + _fields_ = [ + ("_type_code", ctypes.c_uint8), + ("bits", ctypes.c_uint8), + ("lanes", ctypes.c_uint16), + ] NUMPY2STR = { np.dtype(np.bool_): "bool", np.dtype(np.int8): "int8", @@ -87,23 +90,82 @@ class DataType(ctypes.Structure): np.dtype(np.float64): "float64", np.dtype(np.float_): "float64", } - STR2DTYPE = { - "bool": {"type_code": DataTypeCode.UINT, "bits": 1, "lanes": 1}, - "int8": {"type_code": DataTypeCode.INT, "bits": 8, "lanes": 1}, - "int16": {"type_code": DataTypeCode.INT, "bits": 16, "lanes": 1}, - "int32": {"type_code": DataTypeCode.INT, "bits": 32, "lanes": 1}, - "int64": {"type_code": DataTypeCode.INT, "bits": 64, "lanes": 
1}, - "uint8": {"type_code": DataTypeCode.UINT, "bits": 8, "lanes": 1}, - "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1}, - "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1}, - "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, - "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, - "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, - "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, - } - def __init__(self, type_str): + def __init__( + self, + value: Union[int, str, np.dtype, "DataType"], + bits: Optional[int] = None, + lanes: Optional[int] = None, + ): super(DataType, self).__init__() + + if isinstance(value, (str, np.dtype)): + value = self._unpack_str(value) + + if isinstance(value, DataType): + assert bits is None + assert lanes is None + self.type_code = value.type_code + self.bits = value.bits + self.lanes = value.lanes + else: + assert bits is not None + assert lanes is not None + self.type_code = value + self.bits = bits + self.lanes = lanes + + @property + def type_code(self) -> DataTypeCode: + """The type code of the datatype + + This internal field must be a `ctypes.c_uint8` to match the + struct definition. This wrapper allows the Python API to + present the `enum.Enum` subclass. + """ + return DataTypeCode(self._type_code) + + @type_code.setter + def type_code(self, val: Union[DataTypeCode, int]): + # Round trip through DataTypeCode ensures that the integer + # provided is valid. 
+ if isinstance(val, int): + val = DataTypeCode(val) + + self._type_code = val.value + + def with_lanes(self, lanes: int) -> "DataType": + """Return the current datatype with the specified lanes""" + return DataType(self.type_code, self.bits, lanes) + + @classmethod + def _ffi_string_to_data_type_func(cls): + func = getattr(cls, "_string_to_data_type_func", None) + if func: + return func + + import tvm # pylint: disable=import-outside-toplevel + + cls._string_to_data_type = func = tvm._ffi.registry.get_global_func( + "runtime.String2DLDataType" + ) + return func + + @classmethod + def _ffi_data_type_to_string_func(cls): + func = getattr(cls, "_data_type_to_string_func", None) + if func: + return func + + import tvm # pylint: disable=import-outside-toplevel + + cls._data_type_to_string = func = tvm._ffi.registry.get_global_func( + "runtime.DLDataType2String" + ) + return func + + @classmethod + def _unpack_str(cls, type_str): numpy_str_map = DataType.NUMPY2STR if type_str in numpy_str_map: type_str = numpy_str_map[type_str] @@ -112,66 +174,18 @@ def __init__(self, type_str): assert isinstance(type_str, str) - str_dtype_map = DataType.STR2DTYPE - if type_str in str_dtype_map: - dtype_map = str_dtype_map[type_str] - self.bits = dtype_map["bits"] - self.type_code = dtype_map["type_code"] - self.lanes = dtype_map["lanes"] - return - - arr = type_str.split("x") - head = arr[0] - self.lanes = int(arr[1]) if len(arr) > 1 else 1 - bits = 32 - - if head.startswith("int"): - self.type_code = DataTypeCode.INT - head = head[3:] - elif head.startswith("uint"): - self.type_code = DataTypeCode.UINT - head = head[4:] - elif head.startswith("float"): - self.type_code = DataTypeCode.FLOAT - head = head[5:] - elif head.startswith("handle"): - self.type_code = DataTypeCode.HANDLE - bits = 64 - head = "" - elif head.startswith("bfloat"): - self.type_code = DataTypeCode.BFLOAT - head = head[6:] - elif head.startswith("custom"): - # pylint: disable=import-outside-toplevel - import 
tvm.runtime._ffi_api - - low, high = head.find("["), head.find("]") - if not low or not high or low >= high: - raise ValueError("Badly formatted custom type string %s" % type_str) - type_name = head[low + 1 : high] - self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name) - head = head[high + 1 :] - else: - raise ValueError("Do not know how to handle type %s" % type_str) - bits = int(head) if head else bits - self.bits = bits + return cls._ffi_string_to_data_type_func()(type_str) - def __repr__(self): - # pylint: disable=import-outside-toplevel - if self.bits == 1 and self.lanes == 1: - return "bool" - if self.type_code in DataType.CODE2STR: - type_name = DataType.CODE2STR[self.type_code] - else: - import tvm.runtime._ffi_api + def __str__(self): + return self._ffi_data_type_to_string_func()(self) - type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code) - x = "%s%d" % (type_name, self.bits) - if self.lanes != 1: - x += "x%d" % self.lanes - return x + def __repr__(self): + return f'DataType("{str(self)}")' def __eq__(self, other): + if isinstance(other, (str, np.dtype)): + other = DataType(other) + return ( self.bits == other.bits and self.type_code == other.type_code @@ -181,6 +195,56 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return (self.type_code, self.bits, self.lanes).__hash__() + + def __contains__(self, search): + """Backwards compatibility wrapper + + To support use of the datatype as a string. Use should be + avoided in the future. + + Example + ------- + + .. code-block:: python + + # Old method, supported by this wrapper + is_floating_point = "float" in dtype + + # New method, preferred + is_floating_point = dtype.type_code == DataTypeCode.FLOAT + """ + return search in str(self) + + def __getitem__(self, index): + """Backwards compatibility wrapper + + To support use of the datatype as a string. Use should be + avoided in the future. 
+ + Example + ------- + + .. code-block:: python + + # Old method, supported by this wrapper + bits = int(dtype[-2:]) + + # New method, preferred + bits = dtype.bits + """ + return str(self)[index] + + @property + def dtype(self): + """Converter attribute to allow use as a np.dtype + + See https://numpy.org/doc/stable/reference/arrays.dtypes.html, + under section "Types with .dtype" + """ + return str(self) + RPC_SESS_MASK = 128 diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index ce8cd1b403bc..558f4ad2f3a7 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -20,6 +20,8 @@ from tvm.runtime import ObjectGeneric, convert, const from tvm.ir import container as _container +from typing import Union + from . import stmt as _stmt from . import expr as _expr from . import buffer as _buffer @@ -72,10 +74,10 @@ class BufferVar(ObjectGeneric): """ - def __init__(self, builder, buffer, content_type): + def __init__(self, builder, buffer, content_type: Union[str, tvm.DataType]): self._builder = builder self._buffer = buffer - self._content_type = content_type + self._content_type = tvm.DataType(content_type) def asobject(self): return self._buffer @@ -108,8 +110,8 @@ def __setitem__(self, index, value): index = self._normalize_index(index) value = convert(value) - value_element = value.dtype.split("x", maxsplit=1)[0] - content_element = self._content_type.split("x", maxsplit=1)[0] + value_element = value.dtype.with_lanes(1) + content_element = self._content_type.with_lanes(1) if value_element != content_element: raise ValueError( "data type does not match content type %s vs %s" % (value.dtype, self._content_type) @@ -244,7 +246,7 @@ def for_range(self, begin, end, name="i", dtype=None, kind="serial"): # auto infer dtype when it's not specified def get_dtype(expr): if isinstance(expr, _expr.PrimExpr): - if not expr.dtype.startswith("int"): + if not expr.dtype.type_code == tvm.DataTypeCode.INT: raise NotImplementedError( 
f"Infer loop_var dtype failed:" f" unsupported dtype in loop begin or end {expr.dtype}" diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 744cc799a51b..ce498cde229f 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -476,13 +476,9 @@ int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int (static_cast<const PackedFuncObj*>(func)) ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. - if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { + if (rv.type_code() == kTVMStr || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); - if (rv.type_code() != kTVMDataType) { - e->ret_str = *rv.ptr<std::string>(); - } else { - e->ret_str = rv.operator std::string(); - } + e->ret_str = *rv.ptr<std::string>(); if (rv.type_code() == kTVMBytes) { e->ret_bytes.data = e->ret_str.c_str(); e->ret_bytes.size = e->ret_str.length(); diff --git a/src/runtime/data_type.cc b/src/runtime/data_type.cc new file mode 100644 index 000000000000..1332ad4f999d --- /dev/null +++ b/src/runtime/data_type.cc @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file data_type.cc + * \brief Data-type handling + */ +#include <tvm/runtime/data_type.h> +#include <tvm/runtime/registry.h> + +namespace tvm { +namespace runtime { +TVM_REGISTER_GLOBAL("runtime.String2DLDataType").set_body_typed(String2DLDataType); +TVM_REGISTER_GLOBAL("runtime.DLDataType2String").set_body_typed(DLDataType2String); + +} // namespace runtime +} // namespace tvm