diff --git a/conda-recipe/run_test.bat b/conda-recipe/run_test.bat index 031bc6e69a..fd9bf19494 100644 --- a/conda-recipe/run_test.bat +++ b/conda-recipe/run_test.bat @@ -3,7 +3,9 @@ call "%ONEAPI_ROOT%\compiler\latest\env\vars.bat" @echo on -python -m numba.runtests -b -v -m -- numba_dppy.tests +export NUMBA_DEBUG=1 + +python -m numba.runtests -b -v -m -- numba_dppy.tests.test_usmarray.TestUsmArray.test_numba_usmarray_as_ndarray IF %ERRORLEVEL% NEQ 0 exit /B 1 exit /B 0 diff --git a/conda-recipe/run_test.sh b/conda-recipe/run_test.sh index 8a30af0c51..06d516c5f8 100644 --- a/conda-recipe/run_test.sh +++ b/conda-recipe/run_test.sh @@ -8,6 +8,7 @@ source ${ONEAPI_ROOT}/tbb/latest/env/vars.sh set -x -python -m numba.runtests -b -v -m -- numba_dppy.tests +export NUMBA_DEBUG=1 +python -m numba.runtests -b -v -m -- numba_dppy.tests.test_usmarray.TestUsmArray.test_numba_usmarray_as_ndarray exit 0 diff --git a/numba_dppy/dppy_rt.c b/numba_dppy/dppy_rt.c new file mode 100644 index 0000000000..83fd6949b2 --- /dev/null +++ b/numba_dppy/dppy_rt.c @@ -0,0 +1,171 @@ +#include "_pymodule.h" +#include "core/runtime/nrt_external.h" +#include "assert.h" +#include +#if !defined _WIN32 + #include +#else + #include +#endif + +NRT_ExternalAllocator usmarray_allocator; +NRT_external_malloc_func internal_allocator = NULL; +NRT_external_free_func internal_free = NULL; +void *(*get_queue_internal)(void) = NULL; +void (*free_queue_internal)(void*) = NULL; + +void * save_queue_allocator(size_t size, void *opaque) { + // Allocate a pointer-size more space than neded. + int new_size = size + sizeof(void*); + // Get the current queue + void *cur_queue = get_queue_internal(); // this makes a copy + // Use that queue to allocate. + void *data = internal_allocator(new_size, cur_queue); + // Set first pointer-sized data in allocated space to be the current queue. + *(void**)data = cur_queue; + // Return the pointer after this queue in memory. + return (char*)data + sizeof(void*); +} + +void save_queue_deallocator(void *data, void *opaque) { + // Compute original allocation location by subtracting the length + // of the queue pointer from the data location that Numba thinks + // starts the object. + void *orig_data = (char*)data - sizeof(void*); + // Get the queue from the original data by derefencing the first qword. + void *obj_queue = *(void**)orig_data; + // Free the space using the correct queue. + internal_free(orig_data, obj_queue); + // Free the queue itself. + free_queue_internal(obj_queue); +} + +void usmarray_memsys_init(void) { + #if !defined _WIN32 + char *lib_name = "libDPCTLSyclInterface.so"; + char *malloc_name = "DPCTLmalloc_shared"; + char *free_name = "DPCTLfree_with_queue"; + char *get_queue_name = "DPCTLQueueMgr_GetCurrentQueue"; + char *free_queue_name = "DPCTLQueue_Delete"; + + void *sycldl = dlopen(lib_name, RTLD_NOW); + assert(sycldl != NULL); + internal_allocator = (NRT_external_malloc_func)dlsym(sycldl, malloc_name); + usmarray_allocator.malloc = save_queue_allocator; + if (usmarray_allocator.malloc == NULL) { + printf("Did not find %s in %s\n", malloc_name, lib_name); + exit(-1); + } + + usmarray_allocator.realloc = NULL; + + internal_free = (NRT_external_free_func)dlsym(sycldl, free_name); + usmarray_allocator.free = save_queue_deallocator; + if (usmarray_allocator.free == NULL) { + printf("Did not find %s in %s\n", free_name, lib_name); + exit(-1); + } + + get_queue_internal = (void *(*)(void))dlsym(sycldl, get_queue_name); + if (get_queue_internal == NULL) { + printf("Did not find %s in %s\n", get_queue_name, lib_name); + exit(-1); + } + usmarray_allocator.opaque_data = NULL; + + free_queue_internal = (void (*)(void*))dlsym(sycldl, free_queue_name); + if (free_queue_internal == NULL) { + printf("Did not find %s in %s\n", free_queue_name, lib_name); + exit(-1); + } + #else + char *lib_name = "libDPCTLSyclInterface.dll"; + char *malloc_name = "DPCTLmalloc_shared"; + char *free_name = "DPCTLfree_with_queue"; + char *get_queue_name = "DPCTLQueueMgr_GetCurrentQueue"; + char *free_queue_name = "DPCTLQueue_Delete"; + + HMODULE sycldl = LoadLibrary(lib_name); + assert(sycldl != NULL); + internal_allocator = (NRT_external_malloc_func)GetProcAddress(sycldl, malloc_name); + usmarray_allocator.malloc = save_queue_allocator; + if (usmarray_allocator.malloc == NULL) { + printf("Did not find %s in %s\n", malloc_name, lib_name); + exit(-1); + } + + usmarray_allocator.realloc = NULL; + + internal_free = (NRT_external_free_func)GetProcAddress(sycldl, free_name); + usmarray_allocator.free = save_queue_deallocator; + if (usmarray_allocator.free == NULL) { + printf("Did not find %s in %s\n", free_name, lib_name); + exit(-1); + } + + get_queue_internal = (void *(*)(void))GetProcAddress(sycldl, get_queue_name); + if (get_queue_internal == NULL) { + printf("Did not find %s in %s\n", get_queue_name, lib_name); + exit(-1); + } + usmarray_allocator.opaque_data = NULL; + + free_queue_internal = (void (*)(void*))GetProcAddress(sycldl, free_queue_name); + if (free_queue_internal == NULL) { + printf("Did not find %s in %s\n", free_queue_name, lib_name); + exit(-1); + } + #endif +} + +void * usmarray_get_ext_allocator(void) { + return (void*)&usmarray_allocator; +} + +static PyObject * +get_external_allocator(PyObject *self, PyObject *args) { + return PyLong_FromVoidPtr(usmarray_get_ext_allocator()); +} + +static PyMethodDef ext_methods[] = { +#define declmethod_noargs(func) { #func , ( PyCFunction )func , METH_NOARGS, NULL } + declmethod_noargs(get_external_allocator), + {NULL}, +#undef declmethod_noargs +}; + +static PyObject * +build_c_helpers_dict(void) +{ + PyObject *dct = PyDict_New(); + if (dct == NULL) + goto error; + +#define _declpointer(name, value) do { \ + PyObject *o = PyLong_FromVoidPtr(value); \ + if (o == NULL) goto error; \ + if (PyDict_SetItemString(dct, name, o)) { \ + Py_DECREF(o); \ + goto error; \ + } \ + Py_DECREF(o); \ +} while (0) + + _declpointer("usmarray_get_ext_allocator", &usmarray_get_ext_allocator); + +#undef _declpointer + return dct; +error: + Py_XDECREF(dct); + return NULL; +} + +MOD_INIT(_dppy_rt) { + PyObject *m; + MOD_DEF(m, "numba_dppy._dppy_rt", "No docs", ext_methods) + if (m == NULL) + return MOD_ERROR_VAL; + usmarray_memsys_init(); + PyModule_AddObject(m, "c_helpers", build_c_helpers_dict()); + return MOD_SUCCESS_VAL(m); +} diff --git a/numba_dppy/numpy_usm_shared.py b/numba_dppy/numpy_usm_shared.py new file mode 100644 index 0000000000..0f058bc778 --- /dev/null +++ b/numba_dppy/numpy_usm_shared.py @@ -0,0 +1,430 @@ +import numpy as np +from inspect import getmembers, isfunction, isclass, isbuiltin +from numbers import Number +import numba +from types import FunctionType as ftype, BuiltinFunctionType as bftype +from numba import types +from numba.extending import typeof_impl, register_model, type_callable, lower_builtin +from numba.np import numpy_support +from numba.core.pythonapi import box, allocator +from llvmlite import ir +import llvmlite.llvmpy.core as lc +import llvmlite.binding as llb +from numba.core import types, cgutils, config +import builtins +import sys +from ctypes.util import find_library +from numba.core.typing.templates import builtin_registry as templates_registry +from numba.core.typing.npydecl import registry as typing_registry +from numba.core.imputils import builtin_registry as lower_registry +import importlib +import functools +import inspect +from numba.core.typing.templates import CallableTemplate +from numba.np.arrayobj import _array_copy + +import dpctl.dptensor.numpy_usm_shared as nus +from dpctl.dptensor.numpy_usm_shared import ndarray, functions_list, class_list + + +debug = config.DEBUG + +def dprint(*args): + if debug: + print(*args) + sys.stdout.flush() + +# # This code makes it so that Numba can contain calls into the DPPLSyclInterface library. +# sycl_mem_lib = find_library('DPCTLSyclInterface') +# dprint("sycl_mem_lib:", sycl_mem_lib) +# # Load the symbols from the DPPL Sycl library. +# llb.load_library_permanently(sycl_mem_lib) + +import dpctl +from dpctl.memory import MemoryUSMShared +import numba_dppy._dppy_rt + +# functions_list = [o[0] for o in getmembers(np) if isfunction(o[1]) or isbuiltin(o[1])] +# class_list = [o for o in getmembers(np) if isclass(o[1])] + +# Register the helper function in dppl_rt so that we can insert calls to them via llvmlite. +for py_name, c_address in numba_dppy._dppy_rt.c_helpers.items(): + llb.add_symbol(py_name, c_address) + + +# This class creates a type in Numba. +class UsmSharedArrayType(types.Array): + def __init__( + self, + dtype, + ndim, + layout, + readonly=False, + name=None, + aligned=True, + addrspace=None, + ): + # This name defines how this type will be shown in Numba's type dumps. + name = "UsmArray:ndarray(%s, %sd, %s)" % (dtype, ndim, layout) + super(UsmSharedArrayType, self).__init__( + dtype, + ndim, + layout, + py_type=ndarray, + readonly=readonly, + name=name, + addrspace=addrspace, + ) + + # Tell Numba typing how to combine UsmSharedArrayType with other ndarray types. + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method == "__call__": + for inp in inputs: + if not isinstance(inp, (UsmSharedArrayType, types.Array, types.Number)): + return None + + return UsmSharedArrayType + else: + return None + + +# This tells Numba how to create a UsmSharedArrayType when a usmarray is passed +# into a njit function. +@typeof_impl.register(ndarray) +def typeof_ta_ndarray(val, c): + try: + dtype = numpy_support.from_dtype(val.dtype) + except NotImplementedError: + raise ValueError("Unsupported array dtype: %s" % (val.dtype,)) + layout = numpy_support.map_layout(val) + readonly = not val.flags.writeable + return UsmSharedArrayType(dtype, val.ndim, layout, readonly=readonly) + + +# This tells Numba to use the default Numpy ndarray data layout for +# object of type UsmArray. +register_model(UsmSharedArrayType)(numba.core.datamodel.models.ArrayModel) + +# This tells Numba how to convert from its native representation +# of a UsmArray in a njit function back to a Python UsmArray. +@box(UsmSharedArrayType) +def box_array(typ, val, c): + nativearycls = c.context.make_array(typ) + nativeary = nativearycls(c.context, c.builder, value=val) + if c.context.enable_nrt: + np_dtype = numpy_support.as_dtype(typ.dtype) + dtypeptr = c.env_manager.read_const(c.env_manager.add_const(np_dtype)) + # Steals NRT ref + newary = c.pyapi.nrt_adapt_ndarray_to_python(typ, val, dtypeptr) + return newary + else: + parent = nativeary.parent + c.pyapi.incref(parent) + return parent + + +# This tells Numba to use this function when it needs to allocate a +# UsmArray in a njit function. +@allocator(UsmSharedArrayType) +def allocator_UsmArray(context, builder, size, align): + context.nrt._require_nrt() + + mod = builder.module + u32 = ir.IntType(32) + + # Get the Numba external allocator for USM memory. + ext_allocator_fnty = ir.FunctionType(cgutils.voidptr_t, []) + ext_allocator_fn = mod.get_or_insert_function( + ext_allocator_fnty, name="usmarray_get_ext_allocator" + ) + ext_allocator = builder.call(ext_allocator_fn, []) + # Get the Numba function to allocate an aligned array with an external allocator. + fnty = ir.FunctionType(cgutils.voidptr_t, [cgutils.intp_t, u32, cgutils.voidptr_t]) + fn = mod.get_or_insert_function( + fnty, name="NRT_MemInfo_alloc_safe_aligned_external" + ) + fn.return_value.add_attribute("noalias") + if isinstance(align, builtins.int): + align = context.get_constant(types.uint32, align) + else: + assert align.type == u32, "align must be a uint32" + return builder.call(fn, [size, align, ext_allocator]) + + +registered = False + +def is_usm_callback(obj): + if isinstance(obj, numba.core.runtime._nrt_python._MemInfo): + mobj = obj + while isinstance(mobj, numba.core.runtime._nrt_python._MemInfo): + ea = mobj.external_allocator + d = mobj.data + dppl_rt_allocator = numba_dppy._dppy_rt.get_external_allocator() + if ea == dppl_rt_allocator: + return True + mobj = mobj.parent + if isinstance(mobj, ndarray): + mobj = mobj.base + return False + +def numba_register(): + global registered + if not registered: + registered = True + ndarray.add_external_usm_checker(is_usm_callback) + numba_register_typing() + numba_register_lower_builtin() + + +# Copy a function registered as a lowerer in Numba but change the +# "np" import in Numba to point to usmarray instead of NumPy. +def copy_func_for_usmarray(f, usmarray_mod): + import copy as cc + + # Make a copy so our change below doesn't affect anything else. + gglobals = cc.copy(f.__globals__) + # Make the "np"'s in the code use usmarray instead of Numba's default NumPy. + gglobals["np"] = usmarray_mod + # Create a new function using the original code but the new globals. + g = ftype(f.__code__, gglobals, None, f.__defaults__, f.__closure__) + # Some other tricks to make sure the function copy works. + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ + return g + + +def types_replace_array(x): + return tuple([z if z != types.Array else UsmSharedArrayType for z in x]) + + +def numba_register_lower_builtin(): + todo = [] + todo_builtin = [] + todo_getattr = [] + + # For all Numpy identifiers that have been registered for typing in Numba... + # this registry contains functions, getattrs, setattrs, casts and constants...need to do them all? FIX FIX FIX + for ig in lower_registry.functions: + impl, func, types = ig + # If it is a Numpy function... + if isinstance(func, ftype): + if func.__module__ == np.__name__: + # If we have overloaded that function in the usmarray module (always True right now)... + if func.__name__ in functions_list: + todo.append(ig) + if isinstance(func, bftype): + if func.__module__ == np.__name__: + # If we have overloaded that function in the usmarray module (always True right now)... + if func.__name__ in functions_list: + todo.append(ig) + + for lg in lower_registry.getattrs: + func, attr, types = lg + types_with_usmarray = types_replace_array(types) + if UsmSharedArrayType in types_with_usmarray: + dprint( + "lower_getattr:", func, type(func), attr, type(attr), types, type(types) + ) + todo_getattr.append((func, attr, types_with_usmarray)) + + for lg in todo_getattr: + lower_registry.getattrs.append(lg) + + for impl, func, types in todo + todo_builtin: + try: + usmarray_func = eval("dpctl.dptensor.numpy_usm_shared." + func.__name__) + except: + dprint("failed to eval", func.__name__) + continue + dprint( + "need to re-register lowerer for usmarray", impl, func, types, usmarray_func + ) + new_impl = copy_func_for_usmarray(impl, nus) + lower_registry.functions.append((new_impl, usmarray_func, types)) + + +def argspec_to_string(argspec): + first_default_arg = len(argspec.args) - len(argspec.defaults) + non_def = argspec.args[:first_default_arg] + arg_zip = list(zip(argspec.args[first_default_arg:], argspec.defaults)) + combined = [a + "=" + str(b) for a, b in arg_zip] + return ",".join(non_def + combined) + + +def numba_register_typing(): + todo = [] + todo_classes = [] + todo_getattr = [] + + # For all Numpy identifiers that have been registered for typing in Numba... + for ig in typing_registry.globals: + val, typ = ig + dprint("Numpy registered:", val, type(val), typ, type(typ)) + # If it is a Numpy function... + if isinstance(val, (ftype, bftype)): + # If we have overloaded that function in the usmarray module (always True right now)... + if val.__name__ in functions_list: + todo.append(ig) + if isinstance(val, type): + if isinstance(typ, numba.core.types.functions.Function): + todo.append(ig) + elif isinstance(typ, numba.core.types.functions.NumberClass): + pass + #todo_classes.append(ig) + + for tgetattr in templates_registry.attributes: + if tgetattr.key == types.Array: + todo_getattr.append(tgetattr) + + for val, typ in todo_classes: + dprint("todo_classes:", val, typ, type(typ)) + + try: + dptype = eval("dpctl.dptensor.numpy_usm_shared." + val.__name__) + except: + dprint("failed to eval", val.__name__) + continue + + typing_registry.register_global(dptype, numba.core.types.NumberClass(typ.instance_type)) + + for val, typ in todo: + assert len(typ.templates) == 1 + # template is the typing class to invoke generic() upon. + template = typ.templates[0] + dprint("need to re-register for usmarray", val, typ, typ.typing_key) + try: + dpval = eval("dpctl.dptensor.numpy_usm_shared." + val.__name__) + except: + dprint("failed to eval", val.__name__) + continue + """ + if debug: + print("--------------------------------------------------------------") + print("need to re-register for usmarray", val, typ, typ.typing_key) + print("val:", val, type(val), "dir val", dir(val)) + print("typ:", typ, type(typ), "dir typ", dir(typ)) + print("typing key:", typ.typing_key) + print("name:", typ.name) + print("key:", typ.key) + print("templates:", typ.templates) + print("template:", template, type(template)) + print("dpval:", dpval, type(dpval)) + print("--------------------------------------------------------------") + """ + + class_name = "DparrayTemplate_" + val.__name__ + + @classmethod + def set_key_original(cls, key, original): + cls.key = key + cls.original = original + + def generic_impl(self): + original_typer = self.__class__.original.generic(self.__class__.original) + ot_argspec = inspect.getfullargspec(original_typer) + astr = argspec_to_string(ot_argspec) + + typer_func = """def typer({}): + original_res = original_typer({}) + if isinstance(original_res, types.Array): + return UsmSharedArrayType(dtype=original_res.dtype, ndim=original_res.ndim, layout=original_res.layout) + + return original_res""".format( + astr, ",".join(ot_argspec.args) + ) + + try: + gs = globals() + ls = locals() + gs["original_typer"] = ls["original_typer"] + exec(typer_func, globals(), locals()) + except NameError as ne: + print("NameError in exec:", ne) + sys.exit(0) + except: + print("exec failed!", sys.exc_info()[0]) + sys.exit(0) + + try: + exec_res = eval("typer") + except NameError as ne: + print("NameError in eval:", ne) + sys.exit(0) + except: + print("eval failed!", sys.exc_info()[0]) + sys.exit(0) + + return exec_res + + new_usmarray_template = type( + class_name, + (template,), + {"set_class_vars": set_key_original, "generic": generic_impl}, + ) + + new_usmarray_template.set_class_vars(dpval, template) + + assert callable(dpval) + type_handler = types.Function(new_usmarray_template) + typing_registry.register_global(dpval, type_handler) + + # Handle usmarray attribute typing. + for tgetattr in todo_getattr: + class_name = tgetattr.__name__ + "_usmarray" + dprint("tgetattr:", tgetattr, type(tgetattr), class_name) + + @classmethod + def set_key(cls, key): + cls.key = key + + def getattr_impl(self, attr): + if attr.startswith("resolve_"): + def wrapper(*args, **kwargs): + attr_res = tgetattr.__getattribute__(self, attr)(*args, **kwargs) + if isinstance(attr_res, types.Array): + return UsmSharedArrayType( + dtype=attr_res.dtype, + ndim=attr_res.ndim, + layout=attr_res.layout, + ) + + return wrapper + else: + return tgetattr.__getattribute__(self, attr) + + new_usmarray_template = type( + class_name, + (tgetattr,), + {"set_class_vars": set_key, "__getattribute__": getattr_impl}, + ) + + new_usmarray_template.set_class_vars(UsmSharedArrayType) + templates_registry.register_attr(new_usmarray_template) + + +@typing_registry.register_global(nus.as_ndarray) +class DparrayAsNdarray(CallableTemplate): + def generic(self): + def typer(arg): + return types.Array(dtype=arg.dtype, ndim=arg.ndim, layout=arg.layout) + + return typer + + +@typing_registry.register_global(nus.from_ndarray) +class DparrayFromNdarray(CallableTemplate): + def generic(self): + def typer(arg): + return UsmSharedArrayType(dtype=arg.dtype, ndim=arg.ndim, layout=arg.layout) + + return typer + + +@lower_registry.lower(nus.as_ndarray, UsmSharedArrayType) +def usmarray_conversion_as(context, builder, sig, args): + return _array_copy(context, builder, sig, args) + + +@lower_registry.lower(nus.from_ndarray, types.Array) +def usmarray_conversion_from(context, builder, sig, args): + return _array_copy(context, builder, sig, args) diff --git a/numba_dppy/tests/test_usmarray.py b/numba_dppy/tests/test_usmarray.py new file mode 100644 index 0000000000..abf1a78ec6 --- /dev/null +++ b/numba_dppy/tests/test_usmarray.py @@ -0,0 +1,200 @@ +import numba +import numpy +import unittest + +import dpctl.dptensor.numpy_usm_shared as usmarray + + +@numba.njit() +def numba_mul_add(a): + return a * 2.0 + 13 + + +@numba.njit() +def numba_add_const(a): + return a + 13 + + +@numba.njit() +def numba_mul(a, b): # a is usmarray, b is numpy + return a * b + + +@numba.njit() +def numba_mul_usmarray_asarray(a, b): # a is usmarray, b is numpy + return a * usmarray.asarray(b) + + +# @numba.njit() +# def f7(a): # a is usmarray +# # implicit conversion of a to numpy.ndarray +# b = numpy.ones(10) +# c = a * b +# d = a.argsort() # with no implicit conversion this fails + + +@numba.njit +def numba_usmarray_as_ndarray(a): + return usmarray.as_ndarray(a) + + +@numba.njit +def numba_usmarray_from_ndarray(a): + return usmarray.from_ndarray(a) + + +@numba.njit() +def numba_usmarray_ones(): + return usmarray.ones(10) + + +@numba.njit +def numba_usmarray_empty(): + return usmarray.empty((10, 10)) + + +@numba.njit() +def numba_identity(a): + return a + + +@numba.njit +def numba_shape(x): + return x.shape + + +@numba.njit +def numba_T(x): + return x.T + + +class TestUsmArray(unittest.TestCase): + def ndarray(self): + """Create NumPy array""" + return numpy.ones(10) + + def usmarray(self): + """Create dpCtl USM array""" + return usmarray.ones(10) + + def test_python_numpy(self): + """Testing Python Numpy""" + z2 = numba_mul_add.py_func(self.ndarray()) + self.assertEqual(type(z2), numpy.ndarray, z2) + + def test_numba_numpy(self): + """Testing Numba Numpy""" + z2 = numba_mul_add(self.ndarray()) + self.assertEqual(type(z2), numpy.ndarray, z2) + + def test_usmarray_ones(self): + """Testing usmarray ones""" + a = usmarray.ones(10) + self.assertIsInstance(a, usmarray.ndarray, type(a)) + self.assertTrue(usmarray.has_array_interface(a)) + + def test_usmarray_usmarray_as_ndarray(self): + """Testing usmarray.usmarray.as_ndarray""" + nd1 = self.usmarray().as_ndarray() + self.assertEqual(type(nd1), numpy.ndarray, nd1) + + def test_usmarray_as_ndarray(self): + """Testing usmarray.as_ndarray""" + nd2 = usmarray.as_ndarray(self.usmarray()) + self.assertEqual(type(nd2), numpy.ndarray, nd2) + + def test_usmarray_from_ndarray(self): + """Testing usmarray.from_ndarray""" + nd2 = usmarray.as_ndarray(self.usmarray()) + dp1 = usmarray.from_ndarray(nd2) + self.assertIsInstance(dp1, usmarray.ndarray, type(dp1)) + self.assertTrue(usmarray.has_array_interface(dp1)) + + def test_usmarray_multiplication(self): + """Testing usmarray multiplication""" + c = self.usmarray() * 5 + self.assertIsInstance(c, usmarray.ndarray, type(c)) + self.assertTrue(usmarray.has_array_interface(c)) + + def test_python_usmarray_mul_add(self): + """Testing Python usmarray""" + c = self.usmarray() * 5 + b = numba_mul_add.py_func(c) + self.assertIsInstance(b, usmarray.ndarray, type(b)) + self.assertTrue(usmarray.has_array_interface(b)) + + @unittest.expectedFailure + def test_numba_usmarray_mul_add(self): + """Testing Numba usmarray""" + # fails if run tests in bunch + c = self.usmarray() * 5 + b = numba_mul_add(c) + self.assertIsInstance(b, usmarray.ndarray, type(b)) + self.assertTrue(usmarray.has_array_interface(b)) + + def test_python_mixing_usmarray_and_numpy_ndarray(self): + """Testing Python mixing usmarray and numpy.ndarray""" + h = numba_mul.py_func(self.usmarray(), self.ndarray()) + self.assertIsInstance(h, usmarray.ndarray, type(h)) + self.assertTrue(usmarray.has_array_interface(h)) + + def test_numba_usmarray_2(self): + """Testing Numba usmarray 2""" + d = numba_identity(self.usmarray()) + self.assertIsInstance(d, usmarray.ndarray, type(d)) + self.assertTrue(usmarray.has_array_interface(d)) + + @unittest.expectedFailure + def test_numba_usmarray_constructor_from_numpy_ndarray(self): + """Testing Numba usmarray constructor from numpy.ndarray""" + e = numba_mul_usmarray_asarray(self.usmarray(), self.ndarray()) + self.assertIsInstance(e, usmarray.ndarray, type(e)) + + def test_numba_mixing_usmarray_and_constant(self): + """Testing Numba mixing usmarray and constant""" + g = numba_add_const(self.usmarray()) + self.assertIsInstance(g, usmarray.ndarray, type(g)) + self.assertTrue(usmarray.has_array_interface(g)) + + def test_numba_mixing_usmarray_and_numpy_ndarray(self): + """Testing Numba mixing usmarray and numpy.ndarray""" + h = numba_mul(self.usmarray(), self.ndarray()) + self.assertIsInstance(h, usmarray.ndarray, type(h)) + self.assertTrue(usmarray.has_array_interface(h)) + + def test_numba_usmarray_functions(self): + """Testing Numba usmarray functions""" + f = numba_usmarray_ones() + self.assertIsInstance(f, usmarray.ndarray, type(f)) + self.assertTrue(usmarray.has_array_interface(f)) + + def test_numba_usmarray_as_ndarray(self): + """Testing Numba usmarray.as_ndarray""" + nd3 = numba_usmarray_as_ndarray(self.usmarray()) + self.assertEqual(type(nd3), numpy.ndarray, nd3) + + def test_numba_usmarray_from_ndarray(self): + """Testing Numba usmarray.from_ndarray""" + nd3 = numba_usmarray_as_ndarray(self.usmarray()) + dp2 = numba_usmarray_from_ndarray(nd3) + self.assertIsInstance(dp2, usmarray.ndarray, type(dp2)) + self.assertTrue(usmarray.has_array_interface(dp2)) + + def test_numba_usmarray_empty(self): + """Testing Numba usmarray.empty""" + dp3 = numba_usmarray_empty() + self.assertIsInstance(dp3, usmarray.ndarray, type(dp3)) + self.assertTrue(usmarray.has_array_interface(dp3)) + + def test_numba_usmarray_shape(self): + """Testing Numba usmarray.shape""" + s1 = numba_shape(numba_usmarray_empty()) + self.assertIsInstance(s1, tuple, type(s1)) + self.assertEqual(s1, (10, 10)) + + @unittest.expectedFailure + def test_numba_usmarray_T(self): + """Testing Numba usmarray.T""" + dp4 = numba_T(numba_usmarray_empty()) + self.assertIsInstance(dp4, usmarray.ndarray, type(dp4)) + self.assertTrue(usmarray.has_array_interface(dp4)) diff --git a/setup.py b/setup.py index 37ad0bfc68..d833951745 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,23 @@ import os from setuptools import Extension, find_packages, setup from Cython.Build import cythonize +from numba.core.extending import include_path import versioneer +import sys def get_ext_modules(): ext_modules = [] + numba_dir = include_path() + + ext_dppy = Extension( + name="numba_dppy._dppy_rt", + sources=["numba_dppy/dppy_rt.c"], + include_dirs=[numba_dir + "/numba"], + depends=[numba_dir + "/numba/core/runtime/nrt_external.h", numba_dir + "/numba/core/runtime/nrt.h", numba_dir + "/numba/_pymodule.h"], + ) + ext_modules += [ext_dppy] dpnp_present = False try: @@ -64,6 +75,10 @@ def get_ext_modules(): "Topic :: Software Development :: Compilers", ], cmdclass=versioneer.get_cmdclass(), + entry_points={ + "numba_extensions": [ + "init = numba_dppy.numpy_usm_shared:numba_register", + ]}, ) setup(**metadata)