diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
index 0bb40bf40..d249db4f2 100644
--- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
+++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
@@ -128,30 +128,55 @@ cdef inline int prepare_ctypes_arg(
         vector.vector[void*]& data_addresses,
         arg,
         const size_t idx) except -1:
-    if isinstance(arg, ctypes_bool):
+    cdef object arg_type = type(arg)
+    if arg_type is ctypes_bool:
         return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_int8):
+    elif arg_type is ctypes_int8:
         return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_int16):
+    elif arg_type is ctypes_int16:
         return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_int32):
+    elif arg_type is ctypes_int32:
         return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_int64):
+    elif arg_type is ctypes_int64:
         return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_uint8):
+    elif arg_type is ctypes_uint8:
         return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_uint16):
+    elif arg_type is ctypes_uint16:
         return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_uint32):
+    elif arg_type is ctypes_uint32:
         return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_uint64):
+    elif arg_type is ctypes_uint64:
         return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_float):
+    elif arg_type is ctypes_float:
         return prepare_arg[float](data, data_addresses, arg.value, idx)
-    elif isinstance(arg, ctypes_double):
+    elif arg_type is ctypes_double:
         return prepare_arg[double](data, data_addresses, arg.value, idx)
     else:
-        return 1
+        # If no exact types are found, fall back to slower `isinstance` check
+        if isinstance(arg, ctypes_bool):
+            return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_int8):
+            return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_int16):
+            return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_int32):
+            return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_int64):
+            return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_uint8):
+            return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_uint16):
+            return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_uint32):
+            return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_uint64):
+            return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_float):
+            return prepare_arg[float](data, data_addresses, arg.value, idx)
+        elif isinstance(arg, ctypes_double):
+            return prepare_arg[double](data, data_addresses, arg.value, idx)
+        else:
+            return 1
 
 
 cdef inline int prepare_numpy_arg(
@@ -159,36 +184,67 @@ cdef inline int prepare_numpy_arg(
         vector.vector[void*]& data_addresses,
         arg,
         const size_t idx) except -1:
-    if isinstance(arg, numpy_bool):
+    cdef object arg_type = type(arg)
+    if arg_type is numpy_bool:
         return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_int8):
+    elif arg_type is numpy_int8:
         return prepare_arg[int8_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_int16):
+    elif arg_type is numpy_int16:
         return prepare_arg[int16_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_int32):
+    elif arg_type is numpy_int32:
         return prepare_arg[int32_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_int64):
+    elif arg_type is numpy_int64:
         return prepare_arg[int64_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_uint8):
+    elif arg_type is numpy_uint8:
         return prepare_arg[uint8_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_uint16):
+    elif arg_type is numpy_uint16:
         return prepare_arg[uint16_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_uint32):
+    elif arg_type is numpy_uint32:
         return prepare_arg[uint32_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_uint64):
+    elif arg_type is numpy_uint64:
         return prepare_arg[uint64_t](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_float16):
+    elif arg_type is numpy_float16:
         return prepare_arg[__half_raw](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_float32):
+    elif arg_type is numpy_float32:
         return prepare_arg[float](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_float64):
+    elif arg_type is numpy_float64:
         return prepare_arg[double](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_complex64):
+    elif arg_type is numpy_complex64:
         return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
-    elif isinstance(arg, numpy_complex128):
+    elif arg_type is numpy_complex128:
         return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
     else:
-        return 1
+        # If no exact types are found, fall back to slower `isinstance` check
+        if isinstance(arg, numpy_bool):
+            return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_int8):
+            return prepare_arg[int8_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_int16):
+            return prepare_arg[int16_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_int32):
+            return prepare_arg[int32_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_int64):
+            return prepare_arg[int64_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_uint8):
+            return prepare_arg[uint8_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_uint16):
+            return prepare_arg[uint16_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_uint32):
+            return prepare_arg[uint32_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_uint64):
+            return prepare_arg[uint64_t](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_float16):
+            return prepare_arg[__half_raw](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_float32):
+            return prepare_arg[float](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_float64):
+            return prepare_arg[double](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_complex64):
+            return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
+        elif isinstance(arg, numpy_complex128):
+            return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
+        else:
+            return 1
 
 
 cdef class ParamHolder:
@@ -207,12 +263,14 @@ cdef class ParamHolder:
         cdef size_t n_args = len(kernel_args)
         cdef size_t i
         cdef int not_prepared
+        cdef object arg_type
         self.data = vector.vector[voidptr](n_args, nullptr)
         self.data_addresses = vector.vector[voidptr](n_args)
         for i, arg in enumerate(kernel_args):
-            if isinstance(arg, Buffer):
+            arg_type = type(arg)
+            if arg_type is Buffer:
                 # we need the address of where the actual buffer address is stored
-                if isinstance(arg.handle, int):
+                if type(arg.handle) is int:
                     # see note below on handling int arguments
                     prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
                     continue
@@ -220,7 +278,7 @@ cdef class ParamHolder:
                     # it's a CUdeviceptr:
                     self.data_addresses[i] = (arg.handle.getPtr())
                     continue
-            elif isinstance(arg, int):
+            elif arg_type is int:
                 # Here's the dilemma: We want to have a fast path to pass in Python
                 # integers as pointer addresses, but one could also (mistakenly) pass
                 # it with the intention of passing a scalar integer. It's a mistake
@@ -228,13 +286,13 @@ cdef class ParamHolder:
                 # call here is to treat it as a pointer address, without any warning!
                 prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
                 continue
-            elif isinstance(arg, float):
+            elif arg_type is float:
                 prepare_arg[double](self.data, self.data_addresses, arg, i)
                 continue
-            elif isinstance(arg, complex):
+            elif arg_type is complex:
                 prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                 continue
-            elif isinstance(arg, bool):
+            elif arg_type is bool:
                 prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
                 continue
 
@@ -243,7 +301,30 @@ cdef class ParamHolder:
                 not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
             if not_prepared:
                 # TODO: revisit this treatment if we decide to cythonize cuda.core
-                if isinstance(arg, driver.CUgraphConditionalHandle):
+                if arg_type is driver.CUgraphConditionalHandle:
+                    prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i)
+                    continue
+                # If no exact types are found, fall back to slower `isinstance` check
+                elif isinstance(arg, Buffer):
+                    if isinstance(arg.handle, int):
+                        prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
+                        continue
+                    else:
+                        self.data_addresses[i] = (arg.handle.getPtr())
+                        continue
+                elif isinstance(arg, int):
+                    prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
+                    continue
+                elif isinstance(arg, float):
+                    prepare_arg[double](self.data, self.data_addresses, arg, i)
+                    continue
+                elif isinstance(arg, complex):
+                    prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
+                    continue
+                elif isinstance(arg, bool):
+                    prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
+                    continue
+                elif isinstance(arg, driver.CUgraphConditionalHandle):
                     prepare_arg[intptr_t](self.data, self.data_addresses, int(arg), i)
                     continue
                 # TODO: support ctypes/numpy struct