From 9be01c5e62920644114b1d1357e6d4995056d875 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 10 Aug 2023 19:20:43 +0300 Subject: [PATCH 01/31] Add backend setup --- arrayfire/__init__.py | 4 +- arrayfire/array_api/__init__.py | 0 arrayfire/array_api/array_object.py | 816 ++++++++++++++++++++++ arrayfire/array_api/constants.py | 40 ++ arrayfire/array_api/creation_function.py | 27 + arrayfire/array_api/dtypes.py | 178 +++++ arrayfire/backend/backend.py | 149 +++- arrayfire/backend/wrapped/everything.py | 9 + arrayfire/backend/wrapped/indexing.py | 4 +- arrayfire/config.py | 119 ++++ arrayfire/dtypes/__init__.py | 2 +- arrayfire/dtypes/helpers.py | 2 +- arrayfire/library/array_object.py | 44 +- arrayfire/library/device.py | 3 + tests/array_object/test_initialization.py | 2 +- tests/array_object/test_methods.py | 27 +- tests/array_object/test_operators.py | 2 +- 17 files changed, 1400 insertions(+), 28 deletions(-) create mode 100755 arrayfire/array_api/__init__.py create mode 100755 arrayfire/array_api/array_object.py create mode 100755 arrayfire/array_api/constants.py create mode 100755 arrayfire/array_api/creation_function.py create mode 100755 arrayfire/array_api/dtypes.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 994e2e5..f856c55 100644 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -2,6 +2,7 @@ # array objects "Array", # dtypes + "int8", "int16", "int32", "int64", @@ -16,5 +17,6 @@ "bool", ] -from .dtypes import bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64 +from .dtypes import ( + bool, complex64, complex128, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) from .library.array_object import Array diff --git a/arrayfire/array_api/__init__.py b/arrayfire/array_api/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/arrayfire/array_api/array_object.py b/arrayfire/array_api/array_object.py new file mode 100755 index 0000000..95bed1a --- /dev/null +++ b/arrayfire/array_api/array_object.py @@ -0,0 +1,816 @@ +from __future__ import annotations + +__all__ = ["Array"] + +import types +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union + +from arrayfire import Array as AFArray +from arrayfire.array_api.constants import NestedSequence, SupportsBufferProtocol +from arrayfire.dtypes import Dtype + +if TYPE_CHECKING: + from .constants import PyCapsule + from .dtypes import all_dtypes, dtype_categories, numeric_dtypes, promote_types + + +class Array: + _array: AFArray + + def __new__(cls, *args: Any, **kwargs: Any) -> Array: + raise TypeError( + "The array_api Array object should not be instantiated directly. " + "Use an array creation function, such as asarray(), instead." + ) + + def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: + """ + Helper function for operators to only allow specific input dtypes + + Use like + + other = self._check_allowed_dtypes(other, 'numeric', '__add__') + if other is NotImplemented: + return other + """ + if self.dtype not in dtype_categories[dtype_category]: + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") + if isinstance(other, (int, complex, float, bool)): + other = self._promote_scalar(other) + elif isinstance(other, Array): + if other.dtype not in dtype_categories[dtype_category]: + raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") + else: + return NotImplemented + + # This will raise TypeError for type combinations that are not allowed + # to promote in the spec (even if the NumPy array operator would + # promote them). + res_dtype = promote_types(self.dtype, other.dtype) + if op.startswith("__i"): + # Note: NumPy will allow in-place operators in some cases where + # the type promoted operator does not match the left-hand side + # operand. For example, + + # >>> a = np.array(1, dtype=np.int8) + # >>> a += np.array(1, dtype=np.int16) + + # The spec explicitly disallows this. + if res_dtype != self.dtype: + raise TypeError(f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}") + + return other + + def _promote_scalar(self, scalar): + """ + Returns a promoted version of a Python scalar appropriate for use with + operations on self. + + This may raise an OverflowError in cases where the scalar is an + integer that is too large to fit in a NumPy integer dtype, or + TypeError when the scalar type is incompatible with the dtype of self. + """ + # Note: Only Python scalar types that match the array dtype are + # allowed. + if isinstance(scalar, bool): + if self.dtype not in boolean_dtypes: + raise TypeError("Python bool scalars can only be promoted with bool arrays") + elif isinstance(scalar, int): + if self.dtype in boolean_dtypes: + raise TypeError("Python int scalars cannot be promoted with bool arrays") + if self.dtype in integer_dtypes: + info = np.iinfo(self.dtype) + if not (info.min <= scalar <= info.max): + raise OverflowError("Python int scalars must be within the bounds of the dtype for integer arrays") + # int + array(floating) is allowed + elif isinstance(scalar, float): + if self.dtype not in floating_dtypes: + raise TypeError("Python float scalars can only be promoted with floating-point arrays.") + elif isinstance(scalar, complex): + if self.dtype not in complex_floating_dtypes: + raise TypeError("Python complex scalars can only be promoted with complex floating-point arrays.") + else: + raise TypeError("'scalar' must be a Python scalar") + + # Note: scalars are unconditionally cast to the same dtype as the + # array. + + # Note: the spec only specifies integer-dtype/int promotion + # behavior for integers within the bounds of the integer dtype. + # Outside of those bounds we use the default NumPy behavior (either + # cast or raise OverflowError). + return Array._new(np.array(scalar, self.dtype)) + + @classmethod + def _new(cls, x: Union[Array, bool, int, float, complex, NestedSequence, SupportsBufferProtocol], /) -> Array: + """ + This is a private method for initializing the array API Array + object. + + Functions outside of the array_api submodule should not use this + method. Use one of the creation functions instead, such as + ``asarray``. + + """ + obj = super().__new__(cls) + # Note: The spec does not have array scalars, only 0-D arrays. + if isinstance(x, (bool, int, float, complex)): + # Convert the array scalar to a 0-D array + x = AFArray(x) # type: ignore[arg-type] + if x.dtype not in all_dtypes: # type: ignore[union-attr] + raise TypeError( + f"The array_api namespace does not support the dtype '{x.dtype}'" # type: ignore[union-attr] + ) + obj._array = x # type: ignore[assignment] + return obj + + def __abs__(self: Array, /) -> Array: + """ + Performs the operation __abs__. + """ + if self.dtype not in numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in __abs__") + res = self._array.__abs__() + return self.__class__._new(res) + + def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __and__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__and__(other._array) + return self.__class__._new(res) + + def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> types.ModuleType: + if api_version is not None and not api_version.startswith("2021."): + raise ValueError(f"Unrecognized array API version: {api_version!r}") + from arrayfire import array_api + + return array_api + + def __bool__(self: Array, /) -> bool: + """ + Performs the operation __bool__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("bool is only allowed on arrays with 0 dimensions") + res = self._array.__bool__() + return res + + def __complex__(self: Array, /) -> complex: + """ + Performs the operation __complex__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("complex is only allowed on arrays with 0 dimensions") + res = self._array.__complex__() + return res + + def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: + """ + Performs the operation __dlpack__. + """ + return self._array.__dlpack__(stream=stream) + + def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: + """ + Performs the operation __dlpack_device__. + """ + # Note: device support is required for this + return self._array.__dlpack_device__() + + def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + """ + Performs the operation __eq__. + """ + # Even though "all" dtypes are allowed, we still require them to be + # promotable with each other. + other = self._check_allowed_dtypes(other, "all", "__eq__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__eq__(other._array) + return self.__class__._new(res) + + def __float__(self: Array, /) -> float: + """ + Performs the operation __float__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("float is only allowed on arrays with 0 dimensions") + if self.dtype in _complex_floating_dtypes: + raise TypeError("float is not allowed on complex floating-point arrays") + res = self._array.__float__() + return res + + def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __floordiv__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__floordiv__(other._array) + return self.__class__._new(res) + + def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __ge__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__ge__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__ge__(other._array) + return self.__class__._new(res) + + def __getitem__( + self: Array, + key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], + /, + ) -> Array: + """ + Performs the operation __getitem__. + """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array + res = self._array.__getitem__(key) + return self._new(res) + + def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __gt__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__gt__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__gt__(other._array) + return self.__class__._new(res) + + def __int__(self: Array, /) -> int: + """ + Performs the operation __int__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("int is only allowed on arrays with 0 dimensions") + if self.dtype in _complex_floating_dtypes: + raise TypeError("int is not allowed on complex floating-point arrays") + res = self._array.__int__() + return res + + def __index__(self: Array, /) -> int: + """ + Performs the operation __index__. + """ + res = self._array.__index__() + return res + + def __invert__(self: Array, /) -> Array: + """ + Performs the operation __invert__. + """ + if self.dtype not in _integer_or_boolean_dtypes: + raise TypeError("Only integer or boolean dtypes are allowed in __invert__") + res = self._array.__invert__() + return self.__class__._new(res) + + def __le__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __le__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__le__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__le__(other._array) + return self.__class__._new(res) + + def __lshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __lshift__. + """ + other = self._check_allowed_dtypes(other, "integer", "__lshift__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__lshift__(other._array) + return self.__class__._new(res) + + def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __lt__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__lt__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__lt__(other._array) + return self.__class__._new(res) + + def __matmul__(self: Array, other: Array, /) -> Array: + """ + Performs the operation __matmul__. + """ + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, "numeric", "__matmul__") + if other is NotImplemented: + return other + res = self._array.__matmul__(other._array) + return self.__class__._new(res) + + def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __mod__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__mod__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__mod__(other._array) + return self.__class__._new(res) + + def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __mul__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__mul__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__mul__(other._array) + return self.__class__._new(res) + + def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + """ + Performs the operation __ne__. + """ + other = self._check_allowed_dtypes(other, "all", "__ne__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__ne__(other._array) + return self.__class__._new(res) + + def __neg__(self: Array, /) -> Array: + """ + Performs the operation __neg__. + """ + if self.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in __neg__") + res = self._array.__neg__() + return self.__class__._new(res) + + def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __or__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__or__(other._array) + return self.__class__._new(res) + + def __pos__(self: Array, /) -> Array: + """ + Performs the operation __pos__. + """ + if self.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in __pos__") + res = self._array.__pos__() + return self.__class__._new(res) + + def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __pow__. + """ + from ._elementwise_functions import pow + + other = self._check_allowed_dtypes(other, "numeric", "__pow__") + if other is NotImplemented: + return other + # Note: NumPy's __pow__ does not follow type promotion rules for 0-d + # arrays, so we use pow() here instead. + return pow(self, other) + + def __rshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __rshift__. + """ + other = self._check_allowed_dtypes(other, "integer", "__rshift__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rshift__(other._array) + return self.__class__._new(res) + + def __setitem__( + self, + key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], + value: Union[int, float, bool, Array], + /, + ) -> None: + """ + Performs the operation __setitem__. + """ + # Note: Only indices required by the spec are allowed. See the + # docstring of _validate_index + self._validate_index(key) + if isinstance(key, Array): + # Indexing self._array with array_api arrays can be erroneous + key = key._array + self._array.__setitem__(key, asarray(value)._array) + + def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __sub__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__sub__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__sub__(other._array) + return self.__class__._new(res) + + # PEP 484 requires int to be a subtype of float, but __truediv__ should + # not accept int. + def __truediv__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __truediv__. + """ + other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__truediv__(other._array) + return self.__class__._new(res) + + def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __xor__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__xor__(other._array) + return self.__class__._new(res) + + def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __iadd__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__iadd__") + if other is NotImplemented: + return other + self._array.__iadd__(other._array) + return self + + def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __radd__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__radd__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__radd__(other._array) + return self.__class__._new(res) + + def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __iand__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") + if other is NotImplemented: + return other + self._array.__iand__(other._array) + return self + + def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __rand__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rand__(other._array) + return self.__class__._new(res) + + def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __ifloordiv__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") + if other is NotImplemented: + return other + self._array.__ifloordiv__(other._array) + return self + + def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rfloordiv__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rfloordiv__(other._array) + return self.__class__._new(res) + + def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __ilshift__. + """ + other = self._check_allowed_dtypes(other, "integer", "__ilshift__") + if other is NotImplemented: + return other + self._array.__ilshift__(other._array) + return self + + def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __rlshift__. + """ + other = self._check_allowed_dtypes(other, "integer", "__rlshift__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rlshift__(other._array) + return self.__class__._new(res) + + def __imatmul__(self: Array, other: Array, /) -> Array: + """ + Performs the operation __imatmul__. + """ + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") + if other is NotImplemented: + return other + res = self._array.__imatmul__(other._array) + return self.__class__._new(res) + + def __rmatmul__(self: Array, other: Array, /) -> Array: + """ + Performs the operation __rmatmul__. + """ + # matmul is not defined for scalars, but without this, we may get + # the wrong error message from asarray. + other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") + if other is NotImplemented: + return other + res = self._array.__rmatmul__(other._array) + return self.__class__._new(res) + + def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __imod__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__imod__") + if other is NotImplemented: + return other + self._array.__imod__(other._array) + return self + + def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rmod__. + """ + other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rmod__(other._array) + return self.__class__._new(res) + + def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __imul__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__imul__") + if other is NotImplemented: + return other + self._array.__imul__(other._array) + return self + + def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rmul__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__rmul__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rmul__(other._array) + return self.__class__._new(res) + + def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __ior__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__") + if other is NotImplemented: + return other + self._array.__ior__(other._array) + return self + + def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __ror__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__ror__(other._array) + return self.__class__._new(res) + + def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __ipow__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__ipow__") + if other is NotImplemented: + return other + self._array.__ipow__(other._array) + return self + + def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rpow__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__rpow__") + if other is NotImplemented: + return other + self._array.__rpow__(other._array) + return self + + def __irshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __irshift__. + """ + other = self._check_allowed_dtypes(other, "integer", "__irshift__") + if other is NotImplemented: + return other + self._array.__irshift__(other._array) + return self + + def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: + """ + Performs the operation __rrshift__. + """ + other = self._check_allowed_dtypes(other, "integer", "__rrshift__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rrshift__(other._array) + return self.__class__._new(res) + + def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __isub__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__isub__") + if other is NotImplemented: + return other + self._array.__isub__(other._array) + return self + + def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __rsub__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__rsub__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rsub__(other._array) + return self.__class__._new(res) + + def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __itruediv__. + """ + other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__") + if other is NotImplemented: + return other + self._array.__itruediv__(other._array) + return self + + def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: + """ + Performs the operation __rtruediv__. + """ + other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rtruediv__(other._array) + return self.__class__._new(res) + + def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __ixor__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__") + if other is NotImplemented: + return other + self._array.__ixor__(other._array) + return self + + def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: + """ + Performs the operation __rxor__. + """ + other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__rxor__(other._array) + return self.__class__._new(res) + + def to_device(self: Array, device: Device, /, stream: None = None) -> Array: + if stream is not None: + raise ValueError("The stream argument to to_device() is not supported") + if device == "cpu": + return self + raise ValueError(f"Unsupported device {device!r}") + + @property + def dtype(self) -> Dtype: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.dtype `. + + See its docstring for more information. + """ + return self._array.dtype + + @property + def device(self) -> Device: + return "cpu" + + @property + def mT(self) -> Array: + from .linalg import matrix_transpose + + return matrix_transpose(self) + + @property + def ndim(self) -> int: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.ndim `. + + See its docstring for more information. + """ + return self._array.ndim + + @property + def shape(self) -> Tuple[int, ...]: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.shape `. + + See its docstring for more information. + """ + return self._array.shape + + @property + def size(self) -> int: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.size `. + + See its docstring for more information. + """ + return self._array.size + + @property + def T(self) -> Array: + """ + Array API compatible wrapper for :py:meth:`np.ndarray.T `. + + See its docstring for more information. + """ + # Note: T only works on 2-dimensional arrays. See the corresponding + # note in the specification: + # https://data-apis.org/array-api/latest/API_specification/array_object.html#t + if self.ndim != 2: + raise ValueError( + "x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions." + ) + return self.__class__._new(self._array.T) diff --git a/arrayfire/array_api/constants.py b/arrayfire/array_api/constants.py new file mode 100755 index 0000000..7e45310 --- /dev/null +++ b/arrayfire/array_api/constants.py @@ -0,0 +1,40 @@ +""" +This file defines the types for type annotations. + +These names aren't part of the module namespace, but they are used in the +annotations in the function signatures. The functions in the module are only +valid for inputs that match the given type annotations. +""" + +from __future__ import annotations + +__all__ = [ + "Array", + "Device", + "SupportsDLPack", + "SupportsBufferProtocol", + "PyCapsule", +] + +from typing import Any, Literal, Protocol, TypeVar +from .array_object import Array + +_T_co = TypeVar("_T_co", covariant=True) + + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: + ... + + def __len__(self, /) -> int: + ... + + +Device = Literal["cpu"] # FIXME: add support for other devices +SupportsBufferProtocol = Any +PyCapsule = Any + + +class SupportsDLPack(Protocol): + def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: + ... diff --git a/arrayfire/array_api/creation_function.py b/arrayfire/array_api/creation_function.py new file mode 100755 index 0000000..0009513 --- /dev/null +++ b/arrayfire/array_api/creation_function.py @@ -0,0 +1,27 @@ +from .constants import NestedSequence, SupportsBufferProtocol, Device +from .array_object import Array +from arrayfire.dtypes import Dtype, supported_dtypes +from arrayfire.library.device import supported_devices +from arrayfire import Array as AFArray +from typing import Union, Optional + + +def asarray( + obj: Union[Array, bool, int, float, complex, NestedSequence, SupportsBufferProtocol], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: Optional[bool] = None, +) -> Array: + if dtype not in supported_dtypes: + raise ValueError(f"Unsupported dtype {dtype!r}") + + # if device not in supported_devices: + # raise ValueError(f"Unsupported device {device!r}") + + if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -(2**63)): + raise OverflowError("Integer out of bounds for array dtypes") + + array = AFArray(obj, dtype=dtype, device=device) + return Array._new(array) diff --git a/arrayfire/array_api/dtypes.py b/arrayfire/array_api/dtypes.py new file mode 100755 index 0000000..a6aa3f9 --- /dev/null +++ b/arrayfire/array_api/dtypes.py @@ -0,0 +1,178 @@ +from arrayfire import ( + bool, + complex64, + complex128, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +from arrayfire.dtypes import Dtype + +all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, + bool, +) +boolean_dtypes = (bool,) +real_floating_dtypes = (float32, float64) +floating_dtypes = (float32, float64, complex64, complex128) +complex_floating_dtypes = (complex64, complex128) +integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +signed_integer_dtypes = (int8, int16, int32, int64) +unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +integer_orboolean_dtypes = ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +real_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +numeric_dtypes = ( + float32, + float64, + complex64, + complex128, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) + +dtype_categories = { + "all": all_dtypes, + "real numeric": real_numeric_dtypes, + "numeric": numeric_dtypes, + "integer": integer_dtypes, + "integer or boolean": integer_orboolean_dtypes, + "boolean": boolean_dtypes, + "real floating-point": floating_dtypes, + "complex floating-point": complex_floating_dtypes, + "floating-point": floating_dtypes, +} + + +# Note: the spec defines a restricted type promotion table compared to NumPy. +# In particular, cross-kind promotions like integer + float or boolean + +# integer are not allowed, even for functions that accept both kinds. +# Additionally, NumPy promotes signed integer + uint64 to float64, but this +# promotion is not allowed here. To be clear, Python scalar int objects are +# allowed to promote to floating-point dtypes, but only in array operators +# (see Array._promote_scalar) method in _array_object.py. +_promotion_table = { + (int8, int8): int8, + (int8, int16): int16, + (int8, int32): int32, + (int8, int64): int64, + (int16, int8): int16, + (int16, int16): int16, + (int16, int32): int32, + (int16, int64): int64, + (int32, int8): int32, + (int32, int16): int32, + (int32, int32): int32, + (int32, int64): int64, + (int64, int8): int64, + (int64, int16): int64, + (int64, int32): int64, + (int64, int64): int64, + (uint8, uint8): uint8, + (uint8, uint16): uint16, + (uint8, uint32): uint32, + (uint8, uint64): uint64, + (uint16, uint8): uint16, + (uint16, uint16): uint16, + (uint16, uint32): uint32, + (uint16, uint64): uint64, + (uint32, uint8): uint32, + (uint32, uint16): uint32, + (uint32, uint32): uint32, + (uint32, uint64): uint64, + (uint64, uint8): uint64, + (uint64, uint16): uint64, + (uint64, uint32): uint64, + (uint64, uint64): uint64, + (int8, uint8): int16, + (int8, uint16): int32, + (int8, uint32): int64, + (int16, uint8): int16, + (int16, uint16): int32, + (int16, uint32): int64, + (int32, uint8): int32, + (int32, uint16): int32, + (int32, uint32): int64, + (int64, uint8): int64, + (int64, uint16): int64, + (int64, uint32): int64, + (uint8, int8): int16, + (uint16, int8): int32, + (uint32, int8): int64, + (uint8, int16): int16, + (uint16, int16): int32, + (uint32, int16): int64, + (uint8, int32): int32, + (uint16, int32): int32, + (uint32, int32): int64, + (uint8, int64): int64, + (uint16, int64): int64, + (uint32, int64): int64, + (float32, float32): float32, + (float32, float64): float64, + (float64, float32): float64, + (float64, float64): float64, + (complex64, complex64): complex64, + (complex64, complex128): complex128, + (complex128, complex64): complex128, + (complex128, complex128): complex128, + (float32, complex64): complex64, + (float32, complex128): complex128, + (float64, complex64): complex128, + (float64, complex128): complex128, + (complex64, float32): complex64, + (complex64, float64): complex128, + (complex128, float32): complex128, + (complex128, float64): complex128, + (bool, bool): bool, +} + + +def promote_types(type1: Dtype, type2: Dtype) -> Dtype: + if (type1, type2) in _promotion_table: + return _promotion_table[type1, type2] + raise TypeError(f"{type1} and {type2} cannot be type promoted together") diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index 2a5cd05..e779d11 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -1,11 +1,21 @@ import ctypes import enum from dataclasses import dataclass +import os +import platform +import sys +import traceback +from typing import List, Optional, Tuple +import ctypes +from pathlib import Path + +from arrayfire import config -from ..dtypes.helpers import c_dim_t, to_str +from arrayfire.dtypes.helpers import c_dim_t, to_str, CShape +from arrayfire.dtypes import Dtype, float32 # HACK for osx -backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") +# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # HACK for windows # backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") @@ -27,3 +37,138 @@ class _ErrorCodes(enum.Enum): class ArrayBuffer: address: int length: int = 0 + + +class Backend: + def __init__(self) -> None: + self._clibs = {"cuda": None, "opencl": None, "cpu": None, "unified": None} + + self._backend_map = {0: "unified", 1: "cpu", 2: "cuda", 4: "opencl"} + + self._backend_name_map = {"default": 0, "unified": 0, "cpu": 1, "cuda": 2, "opencl": 4} + + more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." + self.setup_obj = config.setup() + + af_module = __import__(__name__) + self.AF_PYMODULE_PATH = af_module.__path__[0] + "/" if af_module.__path__ else None + + self._name = None + + libnames = reversed(self._libname("forge", head="", ver_major=config.FORGE_VER_MAJOR)) + VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" + + for libname in libnames: + try: + full_libname = libname[0] + libname[1] + ctypes.cdll.LoadLibrary(full_libname) + if VERBOSE_LOADS: + print("Loaded " + full_libname) + break + except OSError: + if VERBOSE_LOADS: + traceback.print_exc() + print("Unable to load " + full_libname) + pass + + out = ctypes.c_void_p(0) + dims = CShape(10, 10, 1, 1) + for name in ("cpu", "opencl", "cuda", ""): + libnames = reversed(self._libname(name)) + for libname in libnames: + try: + full_libname = Path(libname[0]) / Path(libname[1]) + ctypes.cdll.LoadLibrary(str(full_libname)) + _name = "unified" if name == "" else name + clib = ctypes.CDLL(str(full_libname)) + self._clibs[_name] = clib + err = clib.af_randu(ctypes.pointer(out), 4, ctypes.pointer(dims.c_array), float32.c_api_value) + if err == _ErrorCodes.none.value: + self._name = _name + clib.af_release_array(out) + if VERBOSE_LOADS: + print("Loaded " + full_libname) + + if name == "cuda": + nvrtc_name = self._find_nvrtc_builtins_libname(libname[0]) + if nvrtc_name: + ctypes.cdll.LoadLibrary(libname[0] + nvrtc_name) + if VERBOSE_LOADS: + print("Loaded " + libname[0] + nvrtc_name) + else: + if VERBOSE_LOADS: + print("Could not find local nvrtc-builtins library") + break + except OSError: + if VERBOSE_LOADS: + traceback.print_exc() + print("Unable to load " + full_libname) + pass + + # if self._name is None: + # raise RuntimeError("Could not load any ArrayFire libraries.\n" + more_info_str) + + def _libname(self, name, head="af", ver_major=config.AF_VER_MAJOR) -> List[str]: + post = self.setup_obj.post.replace(config._VER_MAJOR_PLACEHOLDER, ver_major) + libname = self.setup_obj.pre + head + name + post + + if self.setup_obj.af_path: + if (self.setup_obj.af_path / "lib64").is_dir(): + path_search = self.setup_obj.af_path / "lib64/" + else: + path_search = self.setup_obj.af_path / "lib/" + else: + if (self.setup_obj.af_path / "lib64").is_dir(): + path_search = self.setup_obj.af_path / "lib64/" + else: + path_search = self.setup_obj.af_path / "lib/" + + if platform.architecture()[0][:2] == "64": + path_site = sys.prefix + "/lib64/" + else: + path_site = sys.prefix + "/lib/" + + path_local = self.AF_PYMODULE_PATH + libpaths = [("", libname), (str(path_site), libname), (str(path_local), libname)] + if self.setup_obj.af_path: # prefer specified AF_PATH if exists + libpaths.append((str(path_search), libname)) + else: + libpaths.insert(2, (str(path_search), libname)) + return libpaths + + def _find_nvrtc_builtins_libname(self, search_path): + filelist = os.listdir(search_path) + for f in filelist: + if "nvrtc-builtins" in f: + return f + return None + + def set_unsafe(self, name: str) -> None: + lib = self._clibs.get(name) + if lib is None: + raise RuntimeError("Backend not found") + self._name = name + + def get_id(self, name: str) -> int: + return self._backend_name_map[name] + + def get_name(self, bk_id: int) -> str: + return self._backend_map.get(bk_id, "unknown") + + def get(self): + return self._clibs.get(self._name) + + def name(self) -> str: + return self._name + + def is_unified(self) -> bool: + return self._name == "unified" + + def parse(self, res: int) -> Tuple[str, ...]: + lst = [] + for key, value in self._backend_name_map.items(): + if value & res: + lst.append(key) + return tuple(lst) + +backend_api = Backend().get() diff --git a/arrayfire/backend/wrapped/everything.py b/arrayfire/backend/wrapped/everything.py index e2a433b..4c4dfa7 100644 --- a/arrayfire/backend/wrapped/everything.py +++ b/arrayfire/backend/wrapped/everything.py @@ -185,6 +185,15 @@ def get_data_ptr(arr: AFArrayType, size: int, dtype: Dtype, /) -> ctypes.Array: return ctypes_array +def copy_array(arr: AFArrayType) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga6040dc6f0eb127402fbf62c1165f0b9d + """ + out = ctypes.c_void_p(0) + safe_call(backend_api.af_copy_array(ctypes.pointer(out), arr)) + return out + + # Arrayfire Functions diff --git a/arrayfire/backend/wrapped/indexing.py b/arrayfire/backend/wrapped/indexing.py index d8fc580..9699a51 100755 --- a/arrayfire/backend/wrapped/indexing.py +++ b/arrayfire/backend/wrapped/indexing.py @@ -63,13 +63,13 @@ def __init__(self, chunk: Union[int, slice]): self.end.value = 1 self.step.value = 1 - elif 0 > self.end >= self.begin and self.step <= 0: # type: ignore[operator] + elif self.begin <= self.end < 0 and self.step <= 0: # type: ignore[operator] self.begin.value = -2 self.end.value = -2 self.step.value = -1 if chunk.stop: - self.end.value = self.end.value - math.copysign(1, self.step.value) + self.end -= math.copysign(1, self.step) # type: ignore[operator] else: raise IndexError("Invalid type while indexing arrayfire.array") diff --git a/arrayfire/config.py b/arrayfire/config.py index 588cbdf..cebb9dc 100644 --- a/arrayfire/config.py +++ b/arrayfire/config.py @@ -1,6 +1,125 @@ +import ctypes +import os import platform +import sys +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Iterator, List, Tuple + +_VER_MAJOR_PLACEHOLDER = "__VER_MAJOR__" +AF_VER_MAJOR = "3" +FORGE_VER_MAJOR = "1" + + +class SupportedPlatforms(Enum): + windows = "Windows" + darwin = "Darwin" # OSX + linux = "Linux" + + @classmethod + def is_cygwin(cls, name: str) -> bool: + return "cyg" in name.lower() def is_arch_x86() -> bool: machine = platform.machine() return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") + + +@dataclass +class Setup: + pre: str + post: str + af_path: Path + cuda_found: bool + + def __iter__(self) -> Iterator: + return iter((self.pre, self.post, self.af_path, self.af_path, self.cuda_found)) + + +def setup() -> Setup: + platform_name = platform.system() + cuda_found = False + + try: + af_path = Path(os.environ["AF_PATH"]) + except KeyError: + af_path = None + + try: + cuda_path = Path(os.environ["CUDA_PATH"]) + except KeyError: + cuda_path = None + + if platform_name == SupportedPlatforms.windows.value or SupportedPlatforms.is_cygwin(platform_name): + if platform_name == SupportedPlatforms.windows.value: + # HACK Supressing crashes caused by missing dlls + # http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup + # https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx + ctypes.windll.kernel32.SetErrorMode(0x0001 | 0x0002) # type: ignore[attr-defined] + + if not af_path: + af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{AF_VER_MAJOR}") + + if cuda_path and (cuda_path / "bin").is_dir() and (cuda_path / "nvvm/bin").is_dir(): + cuda_found = True + + return Setup("", ".dll", af_path, cuda_found) + + if platform_name == SupportedPlatforms.darwin.value: + default_cuda_path = Path("/usr/local/cuda/") + + if not af_path: + af_path = _find_default_path("/opt/arrayfire", "/usr/local") + + if not (cuda_path and default_cuda_path.exists()): + cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "/nvvm/lib").is_dir() + + return Setup("lib", f".{_VER_MAJOR_PLACEHOLDER}.dylib", af_path, cuda_found) + + if platform_name == SupportedPlatforms.linux.value: + default_cuda_path = Path("/usr/local/cuda/") + + if not af_path: + af_path = _find_default_path(f"/opt/arrayfire-{AF_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") + + if not (cuda_path and default_cuda_path.exists()): + if "64" in platform.architecture()[0]: # Check either is 64 bit arch is selected + cuda_found = (default_cuda_path / "lib64").is_dir() and (default_cuda_path / "nvvm/lib64").is_dir() + else: + cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "nvvm/lib").is_dir() + + return Setup("lib", f".so.{_VER_MAJOR_PLACEHOLDER}", af_path, cuda_found) + + raise OSError(f"{platform_name} is not supported.") + + +def _find_default_path(*args: str) -> Path: + for path in args: + default_path = Path(path) + if default_path.exists(): + return default_path + raise ValueError("None of specified default paths were found.") + + +def libnames(setup: Setup) -> List[Tuple[str, str]]: + post = setup.post.replace(_VER_MAJOR_PLACEHOLDER, AF_VER_MAJOR) + libname = setup.pre + "forge" + post + + lib64_path = setup.af_path / "lib64" + search_path = lib64_path if lib64_path.is_dir() else setup.af_path / "lib" + + site_path = Path(sys.prefix) / "lib64" if "64" in platform.architecture()[0] else Path(sys.prefix) / "lib" + + # prefer locally packaged arrayfire libraries if they exist + af_module = __import__(__name__) + local_path = af_module.__path__[0] + "/" if af_module.__path__ else None + + libpaths = [("", libname), (str(site_path), libname), (str(local_path), libname)] + + if setup.af_path: # prefer specified AF_PATH if exists + libpaths.append((str(search_path), libname)) + else: + libpaths.insert(2, (str(search_path), libname)) + return libpaths diff --git a/arrayfire/dtypes/__init__.py b/arrayfire/dtypes/__init__.py index 84f290a..a006696 100644 --- a/arrayfire/dtypes/__init__.py +++ b/arrayfire/dtypes/__init__.py @@ -31,7 +31,7 @@ class Dtype: # Specification required -# int8 - Not Supported, b8? # HACK Dtype("i8", ctypes.c_char, "int8", 4) +int8 = Dtype("i8", ctypes.c_char, "int8", 4) # HACK int8 - Not Supported, b8? int16 = Dtype("h", ctypes.c_short, "short int", 10) int32 = Dtype("i", ctypes.c_int, "int", 5) int64 = Dtype("l", ctypes.c_longlong, "long int", 8) diff --git a/arrayfire/dtypes/helpers.py b/arrayfire/dtypes/helpers.py index cf4d306..3894d74 100644 --- a/arrayfire/dtypes/helpers.py +++ b/arrayfire/dtypes/helpers.py @@ -3,7 +3,7 @@ import ctypes from typing import Tuple, Union -from ..config import is_arch_x86 +from arrayfire.config import is_arch_x86 from . import Dtype from . import bool as af_bool from . import complex64, complex128, float32, float64, int64, supported_dtypes diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index 1ac81fe..ecbe319 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -23,7 +23,7 @@ class Array: def __init__( self, - x: Union[None, Array, py_array.array, int, ctypes.c_void_p, List[Union[int, float]]] = None, + obj: Union[None, Array, py_array.array, int, ctypes.c_void_p, List[Union[int, float]]] = None, dtype: Union[None, Dtype, str] = None, shape: Tuple[int, ...] = (), pointer_source: PointerSource = PointerSource.host, @@ -39,7 +39,7 @@ def __init__( _no_initial_dtype = True dtype = af_float32 - if x is None: + if obj is None: if not shape: # shape is None or empty tuple self.arr = everything.create_handle((), dtype) return @@ -47,32 +47,32 @@ def __init__( self.arr = everything.create_handle(shape, dtype) return - if isinstance(x, Array): - self.arr = everything.retain_array(x.arr) + if isinstance(obj, Array): + self.arr = everything.retain_array(obj.arr) return - if isinstance(x, py_array.array): - _type_char: str = x.typecode - _array_buffer = ArrayBuffer(*x.buffer_info()) + if isinstance(obj, py_array.array): + _type_char: str = obj.typecode + _array_buffer = ArrayBuffer(*obj.buffer_info()) - elif isinstance(x, list): - _array = py_array.array("f", x) # BUG [True, False] -> dtype: f32 # TODO add int and float + elif isinstance(obj, list): + _array = py_array.array("f", obj) # BUG [True, False] -> dtype: f32 # TODO add int and float _type_char = _array.typecode _array_buffer = ArrayBuffer(*_array.buffer_info()) - elif isinstance(x, int) or isinstance(x, ctypes.c_void_p): # TODO - _array_buffer = ArrayBuffer(x if not isinstance(x, ctypes.c_void_p) else x.value) # type: ignore[arg-type] + elif isinstance(obj, int) or isinstance(obj, ctypes.c_void_p): # TODO + _array_buffer = ArrayBuffer(obj if not isinstance(obj, ctypes.c_void_p) else obj.value) # type: ignore if not shape: - raise TypeError("Expected to receive the initial shape due to the x being a data pointer.") + raise TypeError("Expected to receive the initial shape due to the obj being a data pointer.") if _no_initial_dtype: - raise TypeError("Expected to receive the initial dtype due to the x being a data pointer.") + raise TypeError("Expected to receive the initial dtype due to the obj being a data pointer.") _type_char = dtype.typecode else: - raise TypeError("Passed object x is an object of unsupported class.") + raise TypeError("Passed object obj is an object of unsupported class.") if not shape: if _array_buffer.length != 0: @@ -878,7 +878,7 @@ def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, ctypes_array = everything.get_data_ptr(array.arr, array.size, array.dtype) if array.ndim == 1: - return list(ctypes_array) + return ctypes_array[:] out = [] for i in range(array.size): @@ -888,7 +888,7 @@ def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, div = array.shape[j] sub_list.append(idx % div) idx //= div - out.append(ctypes_array[sub_list[::-1]]) # type: ignore[call-overload] # FIXME + out.append(ctypes_array[tuple(sub_list)]) # type: ignore[call-overload] # FIXME return out def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: @@ -898,6 +898,18 @@ def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: array = _reorder(self) if row_major else self return everything.get_data_ptr(array.arr, array.size, array.dtype) + def copy(self) -> Array: # BUG: this is not a deep copy + """ + Performs a deep copy of the array. + + Returns + ------- + out: af.Array() + An identical copy of self. + """ + self.arr = everything.copy_array(self.arr) + return self + IndexKey = Union[int, slice, Tuple[Union[int, slice], ...], Array] diff --git a/arrayfire/library/device.py b/arrayfire/library/device.py index 3395750..a51dc21 100644 --- a/arrayfire/library/device.py +++ b/arrayfire/library/device.py @@ -9,3 +9,6 @@ class PointerSource(enum.Enum): # FIXME device = 0 host = 1 + + +supported_devices = [] diff --git a/tests/array_object/test_initialization.py b/tests/array_object/test_initialization.py index 3e3d6d0..c2adcd1 100644 --- a/tests/array_object/test_initialization.py +++ b/tests/array_object/test_initialization.py @@ -54,4 +54,4 @@ def test_initalization_with_unsupported_argument_types( array_object: Any, dtype: Optional[Dtype], shape: Tuple[int, ...] ) -> None: with pytest.raises(TypeError): - Array(x=array_object, dtype=dtype, shape=shape) + Array(obj=array_object, dtype=dtype, shape=shape) diff --git a/tests/array_object/test_methods.py b/tests/array_object/test_methods.py index 20fde2d..3e924a6 100644 --- a/tests/array_object/test_methods.py +++ b/tests/array_object/test_methods.py @@ -1,9 +1,8 @@ +from arrayfire.dtypes import float32, int32 from arrayfire.library.array_object import Array -# TODO add more tests for different dtypes - -def test_array_getitem() -> None: +def test_array_getitem_by_index() -> None: array = Array([1, 2, 3, 4, 5]) int_item = array[2] @@ -11,6 +10,13 @@ def test_array_getitem() -> None: assert int_item.scalar() == 3 +def test_array_getitem_by_slice() -> None: + array = Array([1, 2, 3, 4, 5]) + + slice_item = array[1:3] + assert slice_item.to_list() == [2, 3] + + def test_scalar() -> None: array = Array([1, 2, 3]) assert array[1].scalar() == 2 @@ -29,3 +35,18 @@ def test_array_to_list() -> None: def test_array_to_list_is_empty() -> None: array = Array() assert array.to_list() == [] + + +def test_array_to_list_comparison() -> None: + array1 = Array([1, 2, 3]) + array2 = Array([1, 2, 3]) + assert array1 is not array2 + assert array1.to_list() == array2.to_list() + + +# BUG +# def test_copy_for_array_with_multiple_elements() -> None: +# array = Array([1, 2, 3]) +# copy = array.copy() +# assert array is not copy +# assert array.to_list() == copy.to_list() diff --git a/tests/array_object/test_operators.py b/tests/array_object/test_operators.py index 423ac54..8ed93c1 100644 --- a/tests/array_object/test_operators.py +++ b/tests/array_object/test_operators.py @@ -9,7 +9,7 @@ Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array] -def _round(list_: List[Union[int, float]], symbols: int = 4) -> List[Union[int, float]]: +def _round(list_: List[Union[int, float]], symbols: int = 3) -> List[Union[int, float]]: # HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/ return [round(x, symbols) for x in list_] From af7c6d3c9039ffc83051c8e4926f012bb6fd10a9 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 10 Aug 2023 22:15:39 +0300 Subject: [PATCH 02/31] Refactoring --- arrayfire/backend/__init__.py | 2 +- arrayfire/backend/backend.py | 121 ++++++++---------- .../{wrapped => c_backend}/__init__.py | 0 .../{wrapped => c_backend}/constant_array.py | 0 .../{wrapped => c_backend}/constants.py | 0 .../{wrapped => c_backend}/indexing.py | 0 .../{wrapped => c_backend}/operators.py | 0 .../reduction_operations.py | 0 .../everything.py => c_backend/unsorted.py} | 0 arrayfire/{ => backend}/config.py | 41 ++---- arrayfire/backend/constants.py | 0 arrayfire/dtypes/helpers.py | 2 +- arrayfire/library/array_object.py | 48 +++---- arrayfire/version.py | 4 + 14 files changed, 97 insertions(+), 121 deletions(-) rename arrayfire/backend/{wrapped => c_backend}/__init__.py (100%) rename arrayfire/backend/{wrapped => c_backend}/constant_array.py (100%) rename arrayfire/backend/{wrapped => c_backend}/constants.py (100%) rename arrayfire/backend/{wrapped => c_backend}/indexing.py (100%) rename arrayfire/backend/{wrapped => c_backend}/operators.py (100%) rename arrayfire/backend/{wrapped => c_backend}/reduction_operations.py (100%) rename arrayfire/backend/{wrapped/everything.py => c_backend/unsorted.py} (100%) rename arrayfire/{ => backend}/config.py (70%) create mode 100755 arrayfire/backend/constants.py diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 1f9c864..1e31ea2 100644 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -23,7 +23,7 @@ ] from .backend import ArrayBuffer -from .wrapped.operators import ( +from .c_backend.operators import ( add, bitand, bitnot, diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index e779d11..fb5a6ff 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -1,32 +1,18 @@ import ctypes import enum -from dataclasses import dataclass import os -import platform import sys import traceback -from typing import List, Optional, Tuple -import ctypes +from dataclasses import dataclass from pathlib import Path +from typing import List, Optional, Tuple -from arrayfire import config - -from arrayfire.dtypes.helpers import c_dim_t, to_str, CShape +from arrayfire.backend import config from arrayfire.dtypes import Dtype, float32 +from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str +from arrayfire.version import ARRAYFIRE_VER_MAJOR, FORGE_VER_MAJOR -# HACK for osx -# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") -# HACK for windows -# backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") - - -def safe_call(c_err: int) -> None: - if c_err == _ErrorCodes.none.value: - return - - err_str = ctypes.c_char_p(0) - backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(c_dim_t(0))) - raise RuntimeError(to_str(err_str)) +VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" class _ErrorCodes(enum.Enum): @@ -39,28 +25,28 @@ class ArrayBuffer: length: int = 0 -class Backend: - def __init__(self) -> None: - self._clibs = {"cuda": None, "opencl": None, "cpu": None, "unified": None} +class BackendDevices(enum.Enum): + unified = 0 # NOTE It is set as Default value on Arrayfire backend + cpu = 1 + cuda = 2 + opencl = 4 - self._backend_map = {0: "unified", 1: "cpu", 2: "cuda", 4: "opencl"} - self._backend_name_map = {"default": 0, "unified": 0, "cpu": 1, "cuda": 2, "opencl": 4} +class Backend: + def __init__(self) -> None: + self._clibs = {device.name: None for device in BackendDevices} more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." - self.setup_obj = config.setup() + self.setup = config.setup() af_module = __import__(__name__) self.AF_PYMODULE_PATH = af_module.__path__[0] + "/" if af_module.__path__ else None self._name = None - libnames = reversed(self._libname("forge", head="", ver_major=config.FORGE_VER_MAJOR)) - VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" - - for libname in libnames: + for libname in self._libnames(config.SupportedLibs.forge, FORGE_VER_MAJOR): + full_libname = libname[0] + libname[1] try: - full_libname = libname[0] + libname[1] ctypes.cdll.LoadLibrary(full_libname) if VERBOSE_LOADS: print("Loaded " + full_libname) @@ -73,13 +59,12 @@ def __init__(self) -> None: out = ctypes.c_void_p(0) dims = CShape(10, 10, 1, 1) - for name in ("cpu", "opencl", "cuda", ""): - libnames = reversed(self._libname(name)) - for libname in libnames: + for device in BackendDevices: + _name = device.name if device != BackendDevices.unified else "" + for libname in self._libnames(config.SupportedLibs.arrayfire): + full_libname = Path(libname[0]) / Path(libname[1]) try: - full_libname = Path(libname[0]) / Path(libname[1]) ctypes.cdll.LoadLibrary(str(full_libname)) - _name = "unified" if name == "" else name clib = ctypes.CDLL(str(full_libname)) self._clibs[_name] = clib err = clib.af_randu(ctypes.pointer(out), 4, ctypes.pointer(dims.c_array), float32.c_api_value) @@ -89,8 +74,8 @@ def __init__(self) -> None: if VERBOSE_LOADS: print("Loaded " + full_libname) - if name == "cuda": - nvrtc_name = self._find_nvrtc_builtins_libname(libname[0]) + if device == BackendDevices.cuda: + nvrtc_name = self._find_nvrtc_builtins_libname(Path(libname[0])) if nvrtc_name: ctypes.cdll.LoadLibrary(libname[0] + nvrtc_name) if VERBOSE_LOADS: @@ -108,39 +93,31 @@ def __init__(self) -> None: # if self._name is None: # raise RuntimeError("Could not load any ArrayFire libraries.\n" + more_info_str) - def _libname(self, name, head="af", ver_major=config.AF_VER_MAJOR) -> List[str]: - post = self.setup_obj.post.replace(config._VER_MAJOR_PLACEHOLDER, ver_major) - libname = self.setup_obj.pre + head + name + post + def _libnames(self, lib: config.SupportedLibs, ver_major: Optional[str] = None) -> List[Tuple[str, str]]: + post = self.setup.post if ver_major is None else ver_major + libname = self.setup.pre + lib.value + post - if self.setup_obj.af_path: - if (self.setup_obj.af_path / "lib64").is_dir(): - path_search = self.setup_obj.af_path / "lib64/" - else: - path_search = self.setup_obj.af_path / "lib/" - else: - if (self.setup_obj.af_path / "lib64").is_dir(): - path_search = self.setup_obj.af_path / "lib64/" - else: - path_search = self.setup_obj.af_path / "lib/" + lib64_path = self.setup.af_path / "lib64" + search_path = lib64_path if lib64_path.is_dir() else self.setup.af_path / "lib" - if platform.architecture()[0][:2] == "64": - path_site = sys.prefix + "/lib64/" - else: - path_site = sys.prefix + "/lib/" + site_path = Path(sys.prefix) / "lib64" if not config.is_arch_x86() else Path(sys.prefix) / "lib" + + # prefer locally packaged arrayfire libraries if they exist + af_module = __import__(__name__) + local_path = af_module.__path__[0] + "/" if af_module.__path__ else None - path_local = self.AF_PYMODULE_PATH - libpaths = [("", libname), (str(path_site), libname), (str(path_local), libname)] - if self.setup_obj.af_path: # prefer specified AF_PATH if exists - libpaths.append((str(path_search), libname)) + libpaths = [("", libname), (str(site_path), libname), (str(local_path), libname)] + + if self.setup.af_path: # prefer specified AF_PATH if exists + libpaths.append((str(search_path), libname)) else: - libpaths.insert(2, (str(path_search), libname)) + libpaths.insert(2, (str(search_path), libname)) return libpaths - def _find_nvrtc_builtins_libname(self, search_path): - filelist = os.listdir(search_path) - for f in filelist: - if "nvrtc-builtins" in f: - return f + def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: + for f in search_path.iterdir(): + if "nvrtc-builtins" in f.name: + return f.name return None def set_unsafe(self, name: str) -> None: @@ -171,4 +148,18 @@ def parse(self, res: int) -> Tuple[str, ...]: lst.append(key) return tuple(lst) + +# HACK for osx +# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") +# HACK for windows +# backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") backend_api = Backend().get() + + +def safe_call(c_err: int) -> None: + if c_err == _ErrorCodes.none.value: + return + + err_str = ctypes.c_char_p(0) + backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(c_dim_t(0))) + raise RuntimeError(to_str(err_str)) diff --git a/arrayfire/backend/wrapped/__init__.py b/arrayfire/backend/c_backend/__init__.py similarity index 100% rename from arrayfire/backend/wrapped/__init__.py rename to arrayfire/backend/c_backend/__init__.py diff --git a/arrayfire/backend/wrapped/constant_array.py b/arrayfire/backend/c_backend/constant_array.py similarity index 100% rename from arrayfire/backend/wrapped/constant_array.py rename to arrayfire/backend/c_backend/constant_array.py diff --git a/arrayfire/backend/wrapped/constants.py b/arrayfire/backend/c_backend/constants.py similarity index 100% rename from arrayfire/backend/wrapped/constants.py rename to arrayfire/backend/c_backend/constants.py diff --git a/arrayfire/backend/wrapped/indexing.py b/arrayfire/backend/c_backend/indexing.py similarity index 100% rename from arrayfire/backend/wrapped/indexing.py rename to arrayfire/backend/c_backend/indexing.py diff --git a/arrayfire/backend/wrapped/operators.py b/arrayfire/backend/c_backend/operators.py similarity index 100% rename from arrayfire/backend/wrapped/operators.py rename to arrayfire/backend/c_backend/operators.py diff --git a/arrayfire/backend/wrapped/reduction_operations.py b/arrayfire/backend/c_backend/reduction_operations.py similarity index 100% rename from arrayfire/backend/wrapped/reduction_operations.py rename to arrayfire/backend/c_backend/reduction_operations.py diff --git a/arrayfire/backend/wrapped/everything.py b/arrayfire/backend/c_backend/unsorted.py similarity index 100% rename from arrayfire/backend/wrapped/everything.py rename to arrayfire/backend/c_backend/unsorted.py diff --git a/arrayfire/config.py b/arrayfire/backend/config.py similarity index 70% rename from arrayfire/config.py rename to arrayfire/backend/config.py index cebb9dc..fb84097 100644 --- a/arrayfire/config.py +++ b/arrayfire/backend/config.py @@ -5,11 +5,14 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Iterator, List, Tuple +from typing import Iterator, List, Optional, Tuple -_VER_MAJOR_PLACEHOLDER = "__VER_MAJOR__" -AF_VER_MAJOR = "3" -FORGE_VER_MAJOR = "1" +from arrayfire.version import ARRAYFIRE_VER_MAJOR, FORGE_VER_MAJOR + + +class SupportedLibs(Enum): + forge = "forge" + arrayfire = "af" class SupportedPlatforms(Enum): @@ -60,7 +63,7 @@ def setup() -> Setup: ctypes.windll.kernel32.SetErrorMode(0x0001 | 0x0002) # type: ignore[attr-defined] if not af_path: - af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{AF_VER_MAJOR}") + af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{ARRAYFIRE_VER_MAJOR}") if cuda_path and (cuda_path / "bin").is_dir() and (cuda_path / "nvvm/bin").is_dir(): cuda_found = True @@ -76,13 +79,13 @@ def setup() -> Setup: if not (cuda_path and default_cuda_path.exists()): cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "/nvvm/lib").is_dir() - return Setup("lib", f".{_VER_MAJOR_PLACEHOLDER}.dylib", af_path, cuda_found) + return Setup("lib", f".{ARRAYFIRE_VER_MAJOR}.dylib", af_path, cuda_found) if platform_name == SupportedPlatforms.linux.value: default_cuda_path = Path("/usr/local/cuda/") if not af_path: - af_path = _find_default_path(f"/opt/arrayfire-{AF_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") + af_path = _find_default_path(f"/opt/arrayfire-{ARRAYFIRE_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") if not (cuda_path and default_cuda_path.exists()): if "64" in platform.architecture()[0]: # Check either is 64 bit arch is selected @@ -90,7 +93,7 @@ def setup() -> Setup: else: cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "nvvm/lib").is_dir() - return Setup("lib", f".so.{_VER_MAJOR_PLACEHOLDER}", af_path, cuda_found) + return Setup("lib", f".so.{ARRAYFIRE_VER_MAJOR}", af_path, cuda_found) raise OSError(f"{platform_name} is not supported.") @@ -101,25 +104,3 @@ def _find_default_path(*args: str) -> Path: if default_path.exists(): return default_path raise ValueError("None of specified default paths were found.") - - -def libnames(setup: Setup) -> List[Tuple[str, str]]: - post = setup.post.replace(_VER_MAJOR_PLACEHOLDER, AF_VER_MAJOR) - libname = setup.pre + "forge" + post - - lib64_path = setup.af_path / "lib64" - search_path = lib64_path if lib64_path.is_dir() else setup.af_path / "lib" - - site_path = Path(sys.prefix) / "lib64" if "64" in platform.architecture()[0] else Path(sys.prefix) / "lib" - - # prefer locally packaged arrayfire libraries if they exist - af_module = __import__(__name__) - local_path = af_module.__path__[0] + "/" if af_module.__path__ else None - - libpaths = [("", libname), (str(site_path), libname), (str(local_path), libname)] - - if setup.af_path: # prefer specified AF_PATH if exists - libpaths.append((str(search_path), libname)) - else: - libpaths.insert(2, (str(search_path), libname)) - return libpaths diff --git a/arrayfire/backend/constants.py b/arrayfire/backend/constants.py new file mode 100755 index 0000000..e69de29 diff --git a/arrayfire/dtypes/helpers.py b/arrayfire/dtypes/helpers.py index 3894d74..20c679e 100644 --- a/arrayfire/dtypes/helpers.py +++ b/arrayfire/dtypes/helpers.py @@ -3,7 +3,7 @@ import ctypes from typing import Tuple, Union -from arrayfire.config import is_arch_x86 +from arrayfire.backend.config import is_arch_x86 from . import Dtype from . import bool as af_bool from . import complex64, complex128, float32, float64, int64, supported_dtypes diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index ecbe319..e775bae 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -7,10 +7,10 @@ from .. import backend from ..backend import ArrayBuffer -from ..backend.wrapped import everything -from ..backend.wrapped.constant_array import create_constant_array -from ..backend.wrapped.indexing import CIndexStructure, IndexStructure -from ..backend.wrapped.reduction_operations import count_all +from ..backend.c_backend import unsorted +from ..backend.c_backend.constant_array import create_constant_array +from ..backend.c_backend.indexing import CIndexStructure, IndexStructure +from ..backend.c_backend.reduction_operations import count_all from ..dtypes import CType from ..dtypes import bool as af_bool from ..dtypes import float32 as af_float32 @@ -41,14 +41,14 @@ def __init__( if obj is None: if not shape: # shape is None or empty tuple - self.arr = everything.create_handle((), dtype) + self.arr = unsorted.create_handle((), dtype) return - self.arr = everything.create_handle(shape, dtype) + self.arr = unsorted.create_handle(shape, dtype) return if isinstance(obj, Array): - self.arr = everything.retain_array(obj.arr) + self.arr = unsorted.retain_array(obj.arr) return if isinstance(obj, py_array.array): @@ -85,13 +85,13 @@ def __init__( if not (offset or strides): if pointer_source == PointerSource.host: - self.arr = everything.create_array(shape, dtype, _array_buffer) + self.arr = unsorted.create_array(shape, dtype, _array_buffer) return - self.arr = everything.device_array(shape, dtype, _array_buffer) + self.arr = unsorted.device_array(shape, dtype, _array_buffer) return - self.arr = everything.create_strided_array( + self.arr = unsorted.create_strided_array( shape, dtype, _array_buffer, offset, strides, pointer_source # type: ignore[arg-type] ) @@ -733,7 +733,7 @@ def __getitem__(self, key: IndexKey, /) -> Array: return out # HACK known issue - out.arr = everything.index_gen(self.arr, ndims, key, _get_indices(key)) # type: ignore[arg-type] + out.arr = unsorted.index_gen(self.arr, ndims, key, _get_indices(key)) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -755,12 +755,12 @@ def __str__(self) -> str: # TODO change the look of array str. E.g., like np.array # if not _in_display_dims_limit(self.shape): # return _metadata_string(self.dtype, self.shape) - return _metadata_string(self.dtype) + everything.array_as_str(self.arr) + return _metadata_string(self.dtype) + unsorted.array_as_str(self.arr) def __repr__(self) -> str: # return _metadata_string(self.dtype, self.shape) # TODO change the look of array representation. E.g., like np.array - return everything.array_as_str(self.arr) + return unsorted.array_as_str(self.arr) def to_device(self, device: Any, /, *, stream: Union[int, Any] = None) -> Array: # TODO implementation and change device type from Any to Device @@ -778,7 +778,7 @@ def dtype(self) -> Dtype: out : Dtype Array data type. """ - return c_api_value_to_dtype(everything.get_ctype(self.arr)) + return c_api_value_to_dtype(unsorted.get_ctype(self.arr)) @property def device(self) -> Any: @@ -811,7 +811,7 @@ def T(self) -> Array: # TODO add check if out.dtype == self.dtype out = Array() - out.arr = everything.transpose(self.arr, False) + out.arr = unsorted.transpose(self.arr, False) return out @property @@ -829,7 +829,7 @@ def size(self) -> int: - This must equal the product of the array's dimensions. """ # NOTE previously - elements() - return everything.get_elements(self.arr) + return unsorted.get_elements(self.arr) @property def ndim(self) -> int: @@ -839,7 +839,7 @@ def ndim(self) -> int: out : int Number of array dimensions (axes). """ - return everything.get_numdims(self.arr) + return unsorted.get_numdims(self.arr) @property def shape(self) -> Tuple[int, ...]: @@ -852,7 +852,7 @@ def shape(self) -> Tuple[int, ...]: Array dimensions. """ # NOTE skipping passing any None values - return everything.get_dims(self.arr)[: self.ndim] + return unsorted.get_dims(self.arr)[: self.ndim] def scalar(self) -> Union[None, int, float, bool, complex]: """ @@ -862,20 +862,20 @@ def scalar(self) -> Union[None, int, float, bool, complex]: if self.is_empty(): return None - return everything.get_scalar(self.arr, self.dtype) + return unsorted.get_scalar(self.arr, self.dtype) def is_empty(self) -> bool: """ Check if the array is empty i.e. it has no elements. """ - return everything.is_empty(self.arr) + return unsorted.is_empty(self.arr) def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, complex]]: if self.is_empty(): return [] array = _reorder(self) if row_major else self - ctypes_array = everything.get_data_ptr(array.arr, array.size, array.dtype) + ctypes_array = unsorted.get_data_ptr(array.arr, array.size, array.dtype) if array.ndim == 1: return ctypes_array[:] @@ -896,7 +896,7 @@ def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: raise RuntimeError("Can not convert an empty array to ctype.") array = _reorder(self) if row_major else self - return everything.get_data_ptr(array.arr, array.size, array.dtype) + return unsorted.get_data_ptr(array.arr, array.size, array.dtype) def copy(self) -> Array: # BUG: this is not a deep copy """ @@ -907,7 +907,7 @@ def copy(self) -> Array: # BUG: this is not a deep copy out: af.Array() An identical copy of self. """ - self.arr = everything.copy_array(self.arr) + self.arr = unsorted.copy_array(self.arr) return self @@ -922,7 +922,7 @@ def _reorder(array: Array) -> Array: return array out = Array() - out.arr = everything.reorder(array.arr, array.ndim) + out.arr = unsorted.reorder(array.arr, array.ndim) return out diff --git a/arrayfire/version.py b/arrayfire/version.py index ae862ac..a402fd7 100644 --- a/arrayfire/version.py +++ b/arrayfire/version.py @@ -11,3 +11,7 @@ VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) + +FORGE_VER_MAJOR = "0" +ARRAYFIRE_VER_MAJOR = "3" +ARRAYFIRE_VER_MINOR = "8" From 1eae6d0d76cef68c88c2b7fe25888019dd5a852d Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 10 Aug 2023 23:06:15 +0300 Subject: [PATCH 03/31] Refactoring --- arrayfire/backend/backend.py | 87 +++++++++++++------------ arrayfire/backend/c_backend/unsorted.py | 10 +++ 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index fb5a6ff..cdd4344 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -5,12 +5,12 @@ import traceback from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from arrayfire.backend import config -from arrayfire.dtypes import Dtype, float32 +from arrayfire.dtypes import float32 from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str -from arrayfire.version import ARRAYFIRE_VER_MAJOR, FORGE_VER_MAJOR +from arrayfire.version import FORGE_VER_MAJOR VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" @@ -34,61 +34,62 @@ class BackendDevices(enum.Enum): class Backend: def __init__(self) -> None: - self._clibs = {device.name: None for device in BackendDevices} - - more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." + self._clibs: Dict[str, Optional[ctypes.CDLL]] = {device.name: None for device in BackendDevices} self.setup = config.setup() - af_module = __import__(__name__) - self.AF_PYMODULE_PATH = af_module.__path__[0] + "/" if af_module.__path__ else None + self._name: Optional[str] = None - self._name = None + self.load_forge_lib() + self.load_backend_libs() + def load_forge_lib(self) -> None: for libname in self._libnames(config.SupportedLibs.forge, FORGE_VER_MAJOR): full_libname = libname[0] + libname[1] try: ctypes.cdll.LoadLibrary(full_libname) if VERBOSE_LOADS: - print("Loaded " + full_libname) + print(" > Loaded " + full_libname) break except OSError: if VERBOSE_LOADS: - traceback.print_exc() + # traceback.print_exc() print("Unable to load " + full_libname) pass - out = ctypes.c_void_p(0) - dims = CShape(10, 10, 1, 1) + def load_backend_libs(self) -> None: for device in BackendDevices: - _name = device.name if device != BackendDevices.unified else "" - for libname in self._libnames(config.SupportedLibs.arrayfire): - full_libname = Path(libname[0]) / Path(libname[1]) - try: - ctypes.cdll.LoadLibrary(str(full_libname)) - clib = ctypes.CDLL(str(full_libname)) - self._clibs[_name] = clib - err = clib.af_randu(ctypes.pointer(out), 4, ctypes.pointer(dims.c_array), float32.c_api_value) - if err == _ErrorCodes.none.value: - self._name = _name - clib.af_release_array(out) - if VERBOSE_LOADS: - print("Loaded " + full_libname) - - if device == BackendDevices.cuda: - nvrtc_name = self._find_nvrtc_builtins_libname(Path(libname[0])) - if nvrtc_name: - ctypes.cdll.LoadLibrary(libname[0] + nvrtc_name) - if VERBOSE_LOADS: - print("Loaded " + libname[0] + nvrtc_name) - else: - if VERBOSE_LOADS: - print("Could not find local nvrtc-builtins library") - break - except OSError: - if VERBOSE_LOADS: - traceback.print_exc() - print("Unable to load " + full_libname) - pass + self.load_backend_lib(device) + + def load_backend_lib(self, device: BackendDevices) -> None: + for libname in self._libnames(config.SupportedLibs.arrayfire): + full_libname = Path(libname[0]) / Path(libname[1]) + try: + ctypes.cdll.LoadLibrary(str(full_libname)) + self._clibs[device.name] = ctypes.CDLL(str(full_libname)) + + if device == BackendDevices.cuda: + self.load_nvrtc_builtins_lib(libname[0]) + + if VERBOSE_LOADS: + print(f"Loaded {full_libname}") + break + + break + except OSError: + if VERBOSE_LOADS: + traceback.print_exc() + print(f"Unable to load {full_libname}") + pass + + def load_nvrtc_builtins_lib(self, lib_path: str) -> None: + nvrtc_name = self._find_nvrtc_builtins_libname(Path(lib_path)) + if nvrtc_name: + ctypes.cdll.LoadLibrary(lib_path + nvrtc_name) + if VERBOSE_LOADS: + print("Loaded " + lib_path + nvrtc_name) + else: + if VERBOSE_LOADS: + print("Could not find local nvrtc-builtins library") # if self._name is None: # raise RuntimeError("Could not load any ArrayFire libraries.\n" + more_info_str) @@ -133,7 +134,7 @@ def get_name(self, bk_id: int) -> str: return self._backend_map.get(bk_id, "unknown") def get(self): - return self._clibs.get(self._name) + return self._clibs.get("cpu") # FIXME: should be self._name def name(self) -> str: return self._name diff --git a/arrayfire/backend/c_backend/unsorted.py b/arrayfire/backend/c_backend/unsorted.py index 4c4dfa7..86054da 100644 --- a/arrayfire/backend/c_backend/unsorted.py +++ b/arrayfire/backend/c_backend/unsorted.py @@ -261,3 +261,13 @@ def where(arr: AFArrayType) -> AFArrayType: out = ctypes.c_void_p(0) safe_call(backend_api.af_where(ctypes.pointer(out), arr)) return out + + +def randu(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__random__func__randu.htm#ga412e2c2f5135bdda218c3487c487d3b5 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + safe_call(backend_api.af_randu(ctypes.pointer(out), *c_shape, dtype.c_api_value)) + return out From 7981cb3c33cd2757a4afb52c5ae4ec473cbf9bb0 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 00:00:56 +0300 Subject: [PATCH 04/31] Refactoring --- arrayfire/backend/backend.py | 47 +++++++++++++++++------------------- arrayfire/backend/config.py | 2 +- arrayfire/logger.py | 18 ++++++++++++++ arrayfire/version.py | 2 +- 4 files changed, 42 insertions(+), 27 deletions(-) create mode 100755 arrayfire/logger.py diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index cdd4344..e0cc871 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -2,14 +2,13 @@ import enum import os import sys -import traceback from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Tuple from arrayfire.backend import config -from arrayfire.dtypes import float32 -from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str +from arrayfire.dtypes.helpers import c_dim_t, to_str +from arrayfire.logger import logger from arrayfire.version import FORGE_VER_MAJOR VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" @@ -41,19 +40,17 @@ def __init__(self) -> None: self.load_forge_lib() self.load_backend_libs() + print(self._clibs) def load_forge_lib(self) -> None: - for libname in self._libnames(config.SupportedLibs.forge, FORGE_VER_MAJOR): + for libname in self._libnames("forge", config.SupportedLibs.forge, FORGE_VER_MAJOR): full_libname = libname[0] + libname[1] try: ctypes.cdll.LoadLibrary(full_libname) - if VERBOSE_LOADS: - print(" > Loaded " + full_libname) + logger.info(f"Loaded {full_libname}") break except OSError: - if VERBOSE_LOADS: - # traceback.print_exc() - print("Unable to load " + full_libname) + logger.warning(f"Unable to load {full_libname}") pass def load_backend_libs(self) -> None: @@ -61,7 +58,10 @@ def load_backend_libs(self) -> None: self.load_backend_lib(device) def load_backend_lib(self, device: BackendDevices) -> None: - for libname in self._libnames(config.SupportedLibs.arrayfire): + # NOTE we still set unified cdll to it's original name later, even if the path search is different + name = device.name if device != BackendDevices.unified else "" + + for libname in self._libnames(name): full_libname = Path(libname[0]) / Path(libname[1]) try: ctypes.cdll.LoadLibrary(str(full_libname)) @@ -70,33 +70,30 @@ def load_backend_lib(self, device: BackendDevices) -> None: if device == BackendDevices.cuda: self.load_nvrtc_builtins_lib(libname[0]) - if VERBOSE_LOADS: - print(f"Loaded {full_libname}") - break - + logger.info(f"Loaded {full_libname}") break except OSError: - if VERBOSE_LOADS: - traceback.print_exc() - print(f"Unable to load {full_libname}") + logger.warning(f"Unable to load {full_libname}") pass def load_nvrtc_builtins_lib(self, lib_path: str) -> None: nvrtc_name = self._find_nvrtc_builtins_libname(Path(lib_path)) if nvrtc_name: ctypes.cdll.LoadLibrary(lib_path + nvrtc_name) - if VERBOSE_LOADS: - print("Loaded " + lib_path + nvrtc_name) + logger.info("Loaded " + lib_path + nvrtc_name) else: - if VERBOSE_LOADS: - print("Could not find local nvrtc-builtins library") + logger.warning("Could not find local nvrtc-builtins library") - # if self._name is None: - # raise RuntimeError("Could not load any ArrayFire libraries.\n" + more_info_str) + if all(value is None for value in self._clibs.values()): + raise RuntimeError( + "Could not load any ArrayFire libraries.\n" + "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information.") - def _libnames(self, lib: config.SupportedLibs, ver_major: Optional[str] = None) -> List[Tuple[str, str]]: + def _libnames( + self, name: str, lib: config.SupportedLibs = config.SupportedLibs.arrayfire, ver_major: Optional[str] = None + ) -> List[Tuple[str, str]]: post = self.setup.post if ver_major is None else ver_major - libname = self.setup.pre + lib.value + post + libname = self.setup.pre + lib.value + name + post lib64_path = self.setup.af_path / "lib64" search_path = lib64_path if lib64_path.is_dir() else self.setup.af_path / "lib" diff --git a/arrayfire/backend/config.py b/arrayfire/backend/config.py index fb84097..687fd81 100644 --- a/arrayfire/backend/config.py +++ b/arrayfire/backend/config.py @@ -11,7 +11,7 @@ class SupportedLibs(Enum): - forge = "forge" + forge = "" arrayfire = "af" diff --git a/arrayfire/logger.py b/arrayfire/logger.py new file mode 100755 index 0000000..6ca0409 --- /dev/null +++ b/arrayfire/logger.py @@ -0,0 +1,18 @@ +import logging + +# Configure the logger +logging.basicConfig(level=logging.DEBUG) + +# Create a logger +logger = logging.getLogger(__name__) + +# Create a console handler and set the level to DEBUG +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) + +# Create a formatter and attach it to the console handler +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") +console_handler.setFormatter(formatter) + +# Add the console handler to the logger +logger.addHandler(console_handler) diff --git a/arrayfire/version.py b/arrayfire/version.py index a402fd7..148cbb7 100644 --- a/arrayfire/version.py +++ b/arrayfire/version.py @@ -12,6 +12,6 @@ VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) -FORGE_VER_MAJOR = "0" +FORGE_VER_MAJOR = "1" ARRAYFIRE_VER_MAJOR = "3" ARRAYFIRE_VER_MINOR = "8" From 4fae55cb2f55b349ad09b077a3cb1217c57fd08b Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 00:27:42 +0300 Subject: [PATCH 05/31] Even more refactoring --- arrayfire/backend/backend.py | 51 +++++++++++++++++------------------- arrayfire/backend/config.py | 5 ++-- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index e0cc871..7416ef7 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -9,7 +9,6 @@ from arrayfire.backend import config from arrayfire.dtypes.helpers import c_dim_t, to_str from arrayfire.logger import logger -from arrayfire.version import FORGE_VER_MAJOR VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" @@ -42,15 +41,19 @@ def __init__(self) -> None: self.load_backend_libs() print(self._clibs) + if all(value is None for value in self._clibs.values()): + raise RuntimeError( + "Could not load any ArrayFire libraries.\n" + "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information.") + def load_forge_lib(self) -> None: - for libname in self._libnames("forge", config.SupportedLibs.forge, FORGE_VER_MAJOR): - full_libname = libname[0] + libname[1] + for libname in self._libnames("forge", config.SupportedLibs.forge): try: - ctypes.cdll.LoadLibrary(full_libname) - logger.info(f"Loaded {full_libname}") + ctypes.cdll.LoadLibrary(str(libname)) + logger.info(f"Loaded {libname}") break except OSError: - logger.warning(f"Unable to load {full_libname}") + logger.warning(f"Unable to load {libname}") pass def load_backend_libs(self) -> None: @@ -62,36 +65,30 @@ def load_backend_lib(self, device: BackendDevices) -> None: name = device.name if device != BackendDevices.unified else "" for libname in self._libnames(name): - full_libname = Path(libname[0]) / Path(libname[1]) try: - ctypes.cdll.LoadLibrary(str(full_libname)) - self._clibs[device.name] = ctypes.CDLL(str(full_libname)) + ctypes.cdll.LoadLibrary(str(libname)) + self._clibs[device.name] = ctypes.CDLL(str(libname)) if device == BackendDevices.cuda: - self.load_nvrtc_builtins_lib(libname[0]) + self.load_nvrtc_builtins_lib(libname.parent) - logger.info(f"Loaded {full_libname}") + logger.info(f"Loaded {libname}") break except OSError: - logger.warning(f"Unable to load {full_libname}") + logger.warning(f"Unable to load {libname}") pass - def load_nvrtc_builtins_lib(self, lib_path: str) -> None: - nvrtc_name = self._find_nvrtc_builtins_libname(Path(lib_path)) + def load_nvrtc_builtins_lib(self, lib_path: Path) -> None: + nvrtc_name = self._find_nvrtc_builtins_libname(lib_path) if nvrtc_name: - ctypes.cdll.LoadLibrary(lib_path + nvrtc_name) - logger.info("Loaded " + lib_path + nvrtc_name) + ctypes.cdll.LoadLibrary(str(lib_path / nvrtc_name)) + logger.info(f"Loaded {lib_path / nvrtc_name}") else: logger.warning("Could not find local nvrtc-builtins library") - if all(value is None for value in self._clibs.values()): - raise RuntimeError( - "Could not load any ArrayFire libraries.\n" - "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information.") - def _libnames( self, name: str, lib: config.SupportedLibs = config.SupportedLibs.arrayfire, ver_major: Optional[str] = None - ) -> List[Tuple[str, str]]: + ) -> List[Path]: post = self.setup.post if ver_major is None else ver_major libname = self.setup.pre + lib.value + name + post @@ -102,15 +99,15 @@ def _libnames( # prefer locally packaged arrayfire libraries if they exist af_module = __import__(__name__) - local_path = af_module.__path__[0] + "/" if af_module.__path__ else None + local_path = Path(af_module.__path__[0]) if af_module.__path__ else Path("") - libpaths = [("", libname), (str(site_path), libname), (str(local_path), libname)] + libpaths = [Path("", libname), site_path / libname, local_path / libname] if self.setup.af_path: # prefer specified AF_PATH if exists - libpaths.append((str(search_path), libname)) + return [search_path / libname] + libpaths else: - libpaths.insert(2, (str(search_path), libname)) - return libpaths + libpaths.insert(2, Path(str(search_path), libname)) + return libpaths def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: for f in search_path.iterdir(): diff --git a/arrayfire/backend/config.py b/arrayfire/backend/config.py index 687fd81..b3c8100 100644 --- a/arrayfire/backend/config.py +++ b/arrayfire/backend/config.py @@ -1,13 +1,12 @@ import ctypes import os import platform -import sys from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Iterator, List, Optional, Tuple +from typing import Iterator -from arrayfire.version import ARRAYFIRE_VER_MAJOR, FORGE_VER_MAJOR +from arrayfire.version import ARRAYFIRE_VER_MAJOR class SupportedLibs(Enum): From 2accd0a9c1c946a70e24ace960835b17cda44d19 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 00:40:41 +0300 Subject: [PATCH 06/31] Refactor backend --- arrayfire/backend/backend.py | 85 +++++++++++++----------------------- 1 file changed, 30 insertions(+), 55 deletions(-) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index 7416ef7..2f2007c 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import List, Optional from arrayfire.backend import config from arrayfire.dtypes.helpers import c_dim_t, to_str @@ -25,28 +25,28 @@ class ArrayBuffer: class BackendDevices(enum.Enum): unified = 0 # NOTE It is set as Default value on Arrayfire backend - cpu = 1 cuda = 2 opencl = 4 + cpu = 1 # NOTE It comes last because we want to keep this order on backend initialization class Backend: def __init__(self) -> None: - self._clibs: Dict[str, Optional[ctypes.CDLL]] = {device.name: None for device in BackendDevices} - self.setup = config.setup() + self.device: Optional[BackendDevices] = None + self.library: Optional[ctypes.CDLL] = None - self._name: Optional[str] = None + self._setup = config.setup() - self.load_forge_lib() - self.load_backend_libs() - print(self._clibs) + self._load_forge_lib() + self._load_backend_libs() - if all(value is None for value in self._clibs.values()): + if not self.device and not self.library: raise RuntimeError( "Could not load any ArrayFire libraries.\n" - "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information.") + "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." + ) - def load_forge_lib(self) -> None: + def _load_forge_lib(self) -> None: for libname in self._libnames("forge", config.SupportedLibs.forge): try: ctypes.cdll.LoadLibrary(str(libname)) @@ -56,21 +56,26 @@ def load_forge_lib(self) -> None: logger.warning(f"Unable to load {libname}") pass - def load_backend_libs(self) -> None: + def _load_backend_libs(self) -> None: for device in BackendDevices: - self.load_backend_lib(device) + self._load_backend_lib(device) - def load_backend_lib(self, device: BackendDevices) -> None: + if self.device: + logger.info(f"Setting {device.name} as backend.") + break + + def _load_backend_lib(self, device: BackendDevices) -> None: # NOTE we still set unified cdll to it's original name later, even if the path search is different name = device.name if device != BackendDevices.unified else "" - for libname in self._libnames(name): + for libname in self._libnames(name, config.SupportedLibs.arrayfire): try: ctypes.cdll.LoadLibrary(str(libname)) - self._clibs[device.name] = ctypes.CDLL(str(libname)) + self.device = device + self.library = ctypes.CDLL(str(libname)) if device == BackendDevices.cuda: - self.load_nvrtc_builtins_lib(libname.parent) + self._load_nvrtc_builtins_lib(libname.parent) logger.info(f"Loaded {libname}") break @@ -78,7 +83,7 @@ def load_backend_lib(self, device: BackendDevices) -> None: logger.warning(f"Unable to load {libname}") pass - def load_nvrtc_builtins_lib(self, lib_path: Path) -> None: + def _load_nvrtc_builtins_lib(self, lib_path: Path) -> None: nvrtc_name = self._find_nvrtc_builtins_libname(lib_path) if nvrtc_name: ctypes.cdll.LoadLibrary(str(lib_path / nvrtc_name)) @@ -86,14 +91,12 @@ def load_nvrtc_builtins_lib(self, lib_path: Path) -> None: else: logger.warning("Could not find local nvrtc-builtins library") - def _libnames( - self, name: str, lib: config.SupportedLibs = config.SupportedLibs.arrayfire, ver_major: Optional[str] = None - ) -> List[Path]: - post = self.setup.post if ver_major is None else ver_major - libname = self.setup.pre + lib.value + name + post + def _libnames(self, name: str, lib: config.SupportedLibs, ver_major: Optional[str] = None) -> List[Path]: + post = self._setup.post if ver_major is None else ver_major + libname = self._setup.pre + lib.value + name + post - lib64_path = self.setup.af_path / "lib64" - search_path = lib64_path if lib64_path.is_dir() else self.setup.af_path / "lib" + lib64_path = self._setup.af_path / "lib64" + search_path = lib64_path if lib64_path.is_dir() else self._setup.af_path / "lib" site_path = Path(sys.prefix) / "lib64" if not config.is_arch_x86() else Path(sys.prefix) / "lib" @@ -103,7 +106,7 @@ def _libnames( libpaths = [Path("", libname), site_path / libname, local_path / libname] - if self.setup.af_path: # prefer specified AF_PATH if exists + if self._setup.af_path: # prefer specified AF_PATH if exists return [search_path / libname] + libpaths else: libpaths.insert(2, Path(str(search_path), libname)) @@ -115,40 +118,12 @@ def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: return f.name return None - def set_unsafe(self, name: str) -> None: - lib = self._clibs.get(name) - if lib is None: - raise RuntimeError("Backend not found") - self._name = name - - def get_id(self, name: str) -> int: - return self._backend_name_map[name] - - def get_name(self, bk_id: int) -> str: - return self._backend_map.get(bk_id, "unknown") - - def get(self): - return self._clibs.get("cpu") # FIXME: should be self._name - - def name(self) -> str: - return self._name - - def is_unified(self) -> bool: - return self._name == "unified" - - def parse(self, res: int) -> Tuple[str, ...]: - lst = [] - for key, value in self._backend_name_map.items(): - if value & res: - lst.append(key) - return tuple(lst) - # HACK for osx # backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # HACK for windows # backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") -backend_api = Backend().get() +backend_api = Backend().library def safe_call(c_err: int) -> None: From cd7dbb67fc6aa68a555d6566ac646e91818515d9 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 00:42:07 +0300 Subject: [PATCH 07/31] Minor refactoring --- arrayfire/backend/backend.py | 6 +++--- arrayfire/backend/config.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index 2f2007c..21d9c09 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -47,7 +47,7 @@ def __init__(self) -> None: ) def _load_forge_lib(self) -> None: - for libname in self._libnames("forge", config.SupportedLibs.forge): + for libname in self._libnames("forge", config.SupportedLibPrefixes.forge): try: ctypes.cdll.LoadLibrary(str(libname)) logger.info(f"Loaded {libname}") @@ -68,7 +68,7 @@ def _load_backend_lib(self, device: BackendDevices) -> None: # NOTE we still set unified cdll to it's original name later, even if the path search is different name = device.name if device != BackendDevices.unified else "" - for libname in self._libnames(name, config.SupportedLibs.arrayfire): + for libname in self._libnames(name, config.SupportedLibPrefixes.arrayfire): try: ctypes.cdll.LoadLibrary(str(libname)) self.device = device @@ -91,7 +91,7 @@ def _load_nvrtc_builtins_lib(self, lib_path: Path) -> None: else: logger.warning("Could not find local nvrtc-builtins library") - def _libnames(self, name: str, lib: config.SupportedLibs, ver_major: Optional[str] = None) -> List[Path]: + def _libnames(self, name: str, lib: config.SupportedLibPrefixes, ver_major: Optional[str] = None) -> List[Path]: post = self._setup.post if ver_major is None else ver_major libname = self._setup.pre + lib.value + name + post diff --git a/arrayfire/backend/config.py b/arrayfire/backend/config.py index b3c8100..3b64992 100644 --- a/arrayfire/backend/config.py +++ b/arrayfire/backend/config.py @@ -9,7 +9,7 @@ from arrayfire.version import ARRAYFIRE_VER_MAJOR -class SupportedLibs(Enum): +class SupportedLibPrefixes(Enum): forge = "" arrayfire = "af" From 5d90490ab7751b28078ea810f35cba7ab1663d14 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 01:38:53 +0300 Subject: [PATCH 08/31] Refactor --- arrayfire/backend/backend.py | 18 +++++++++--------- arrayfire/backend/config.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index 21d9c09..82f7bf5 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -31,21 +31,15 @@ class BackendDevices(enum.Enum): class Backend: - def __init__(self) -> None: - self.device: Optional[BackendDevices] = None - self.library: Optional[ctypes.CDLL] = None + device: BackendDevices + library: ctypes.CDLL + def __init__(self) -> None: self._setup = config.setup() self._load_forge_lib() self._load_backend_libs() - if not self.device and not self.library: - raise RuntimeError( - "Could not load any ArrayFire libraries.\n" - "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." - ) - def _load_forge_lib(self) -> None: for libname in self._libnames("forge", config.SupportedLibPrefixes.forge): try: @@ -64,6 +58,12 @@ def _load_backend_libs(self) -> None: logger.info(f"Setting {device.name} as backend.") break + if not self.device and not self.library: + raise RuntimeError( + "Could not load any ArrayFire libraries.\n" + "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." + ) + def _load_backend_lib(self, device: BackendDevices) -> None: # NOTE we still set unified cdll to it's original name later, even if the path search is different name = device.name if device != BackendDevices.unified else "" diff --git a/arrayfire/backend/config.py b/arrayfire/backend/config.py index 3b64992..d714996 100644 --- a/arrayfire/backend/config.py +++ b/arrayfire/backend/config.py @@ -29,7 +29,7 @@ def is_arch_x86() -> bool: return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") -@dataclass +@dataclass(frozen=True) class Setup: pre: str post: str From cb8fed998db751ca671fd7a3e9cca994fb62919b Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 02:28:42 +0300 Subject: [PATCH 09/31] Refactoring --- arrayfire/__init__.py | 4 +- arrayfire/array_api/constants.py | 1 + arrayfire/array_api/creation_function.py | 10 +- arrayfire/array_api/dtypes.py | 15 +- arrayfire/backend/__init__.py | 28 +-- arrayfire/backend/backend.py | 48 ++-- arrayfire/backend/c_backend/constant_array.py | 3 +- arrayfire/backend/c_backend/error_handler.py | 18 ++ arrayfire/backend/c_backend/indexing.py | 3 +- arrayfire/backend/c_backend/operators.py | 3 +- .../backend/c_backend/reduction_operations.py | 4 +- arrayfire/backend/c_backend/unsorted.py | 20 +- arrayfire/backend/constants.py | 7 + arrayfire/dtypes/helpers.py | 3 +- arrayfire/{backend/config.py => platform.py} | 205 +++++++++--------- tests/array_object/test_methods.py | 9 - 16 files changed, 184 insertions(+), 197 deletions(-) mode change 100644 => 100755 arrayfire/backend/c_backend/constant_array.py create mode 100755 arrayfire/backend/c_backend/error_handler.py mode change 100644 => 100755 arrayfire/backend/c_backend/operators.py mode change 100644 => 100755 arrayfire/backend/c_backend/unsorted.py rename arrayfire/{backend/config.py => platform.py} (84%) mode change 100644 => 100755 diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index f856c55..56d8a23 100644 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -16,7 +16,9 @@ "complex128", "bool", ] - +# fmt: off from .dtypes import ( bool, complex64, complex128, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) from .library.array_object import Array + +# fmt: on diff --git a/arrayfire/array_api/constants.py b/arrayfire/array_api/constants.py index 7e45310..46007d2 100755 --- a/arrayfire/array_api/constants.py +++ b/arrayfire/array_api/constants.py @@ -17,6 +17,7 @@ ] from typing import Any, Literal, Protocol, TypeVar + from .array_object import Array _T_co = TypeVar("_T_co", covariant=True) diff --git a/arrayfire/array_api/creation_function.py b/arrayfire/array_api/creation_function.py index 0009513..76eb1b0 100755 --- a/arrayfire/array_api/creation_function.py +++ b/arrayfire/array_api/creation_function.py @@ -1,9 +1,11 @@ -from .constants import NestedSequence, SupportsBufferProtocol, Device -from .array_object import Array +from typing import Optional, Union + +from arrayfire import Array as AFArray from arrayfire.dtypes import Dtype, supported_dtypes from arrayfire.library.device import supported_devices -from arrayfire import Array as AFArray -from typing import Union, Optional + +from .array_object import Array +from .constants import Device, NestedSequence, SupportsBufferProtocol def asarray( diff --git a/arrayfire/array_api/dtypes.py b/arrayfire/array_api/dtypes.py index a6aa3f9..0260e00 100755 --- a/arrayfire/array_api/dtypes.py +++ b/arrayfire/array_api/dtypes.py @@ -1,18 +1,5 @@ from arrayfire import ( - bool, - complex64, - complex128, - float32, - float64, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) + bool, complex64, complex128, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) from arrayfire.dtypes import Dtype all_dtypes = ( diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 1e31ea2..8008659 100644 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -1,5 +1,5 @@ __all__ = [ - # Backend + # Backend Constants "ArrayBuffer", # Operators "add", @@ -21,25 +21,9 @@ "eq", "neq", ] - -from .backend import ArrayBuffer +# fmt: off from .c_backend.operators import ( - add, - bitand, - bitnot, - bitor, - bitshiftl, - bitshiftr, - bitxor, - div, - eq, - ge, - gt, - le, - lt, - mod, - mul, - neq, - pow, - sub, -) + add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) +from .constants import ArrayBuffer + +# fmt: on diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index 82f7bf5..9f4d33d 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -1,26 +1,17 @@ import ctypes import enum -import os import sys -from dataclasses import dataclass +from enum import Enum from pathlib import Path from typing import List, Optional -from arrayfire.backend import config -from arrayfire.dtypes.helpers import c_dim_t, to_str from arrayfire.logger import logger +from arrayfire.platform import get_platform_config, is_arch_x86 -VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS") == "1" - -class _ErrorCodes(enum.Enum): - none = 0 - - -@dataclass -class ArrayBuffer: - address: int - length: int = 0 +class _LibPrefixes(Enum): + forge = "" + arrayfire = "af" class BackendDevices(enum.Enum): @@ -35,13 +26,13 @@ class Backend: library: ctypes.CDLL def __init__(self) -> None: - self._setup = config.setup() + self._platform_config = get_platform_config() self._load_forge_lib() self._load_backend_libs() def _load_forge_lib(self) -> None: - for libname in self._libnames("forge", config.SupportedLibPrefixes.forge): + for libname in self._libnames("forge", _LibPrefixes.forge): try: ctypes.cdll.LoadLibrary(str(libname)) logger.info(f"Loaded {libname}") @@ -68,7 +59,7 @@ def _load_backend_lib(self, device: BackendDevices) -> None: # NOTE we still set unified cdll to it's original name later, even if the path search is different name = device.name if device != BackendDevices.unified else "" - for libname in self._libnames(name, config.SupportedLibPrefixes.arrayfire): + for libname in self._libnames(name, _LibPrefixes.arrayfire): try: ctypes.cdll.LoadLibrary(str(libname)) self.device = device @@ -91,14 +82,14 @@ def _load_nvrtc_builtins_lib(self, lib_path: Path) -> None: else: logger.warning("Could not find local nvrtc-builtins library") - def _libnames(self, name: str, lib: config.SupportedLibPrefixes, ver_major: Optional[str] = None) -> List[Path]: - post = self._setup.post if ver_major is None else ver_major - libname = self._setup.pre + lib.value + name + post + def _libnames(self, name: str, lib: _LibPrefixes, ver_major: Optional[str] = None) -> List[Path]: + post = self._platform_config.lib_postfix if ver_major is None else ver_major + libname = self._platform_config.lib_prefix + lib.value + name + post - lib64_path = self._setup.af_path / "lib64" - search_path = lib64_path if lib64_path.is_dir() else self._setup.af_path / "lib" + lib64_path = self._platform_config.af_path / "lib64" + search_path = lib64_path if lib64_path.is_dir() else self._platform_config.af_path / "lib" - site_path = Path(sys.prefix) / "lib64" if not config.is_arch_x86() else Path(sys.prefix) / "lib" + site_path = Path(sys.prefix) / "lib64" if not is_arch_x86() else Path(sys.prefix) / "lib" # prefer locally packaged arrayfire libraries if they exist af_module = __import__(__name__) @@ -106,7 +97,7 @@ def _libnames(self, name: str, lib: config.SupportedLibPrefixes, ver_major: Opti libpaths = [Path("", libname), site_path / libname, local_path / libname] - if self._setup.af_path: # prefer specified AF_PATH if exists + if self._platform_config.af_path: # prefer specified AF_PATH if exists return [search_path / libname] + libpaths else: libpaths.insert(2, Path(str(search_path), libname)) @@ -124,12 +115,3 @@ def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: # HACK for windows # backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") backend_api = Backend().library - - -def safe_call(c_err: int) -> None: - if c_err == _ErrorCodes.none.value: - return - - err_str = ctypes.c_char_p(0) - backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(c_dim_t(0))) - raise RuntimeError(to_str(err_str)) diff --git a/arrayfire/backend/c_backend/constant_array.py b/arrayfire/backend/c_backend/constant_array.py old mode 100644 new mode 100755 index 0bf9689..1abcabc --- a/arrayfire/backend/c_backend/constant_array.py +++ b/arrayfire/backend/c_backend/constant_array.py @@ -1,11 +1,12 @@ import ctypes from typing import Tuple, Union +from arrayfire.backend.backend import backend_api from arrayfire.dtypes import Dtype, int64, uint64 from arrayfire.dtypes.helpers import CShape, implicit_dtype -from ..backend import backend_api, safe_call from .constants import AFArrayType +from .error_handler import safe_call def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: diff --git a/arrayfire/backend/c_backend/error_handler.py b/arrayfire/backend/c_backend/error_handler.py new file mode 100755 index 0000000..df402bd --- /dev/null +++ b/arrayfire/backend/c_backend/error_handler.py @@ -0,0 +1,18 @@ +import ctypes +from enum import Enum + +from arrayfire.backend.backend import backend_api +from arrayfire.dtypes.helpers import c_dim_t, to_str + + +class _ErrorCodes(Enum): + none = 0 + + +def safe_call(c_err: int) -> None: + if c_err == _ErrorCodes.none.value: + return + + err_str = ctypes.c_char_p(0) + backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(c_dim_t(0))) + raise RuntimeError(to_str(err_str)) diff --git a/arrayfire/backend/c_backend/indexing.py b/arrayfire/backend/c_backend/indexing.py index 9699a51..897f4bc 100755 --- a/arrayfire/backend/c_backend/indexing.py +++ b/arrayfire/backend/c_backend/indexing.py @@ -4,10 +4,11 @@ import math from typing import Any, Union +from arrayfire.backend.backend import backend_api from arrayfire.library.broadcast import bcast_var -from ..backend import backend_api, safe_call from . import constants +from .error_handler import safe_call class _IndexSequence(ctypes.Structure): diff --git a/arrayfire/backend/c_backend/operators.py b/arrayfire/backend/c_backend/operators.py old mode 100644 new mode 100755 index f2f3499..64b5710 --- a/arrayfire/backend/c_backend/operators.py +++ b/arrayfire/backend/c_backend/operators.py @@ -1,10 +1,11 @@ import ctypes from typing import Callable +from arrayfire.backend.backend import backend_api from arrayfire.library.broadcast import bcast_var -from ..backend import backend_api, safe_call from .constants import AFArrayType +from .error_handler import safe_call # Arithmetic Operators diff --git a/arrayfire/backend/c_backend/reduction_operations.py b/arrayfire/backend/c_backend/reduction_operations.py index d15e7ed..adc8897 100755 --- a/arrayfire/backend/c_backend/reduction_operations.py +++ b/arrayfire/backend/c_backend/reduction_operations.py @@ -1,8 +1,10 @@ import ctypes from typing import Callable, Union -from ..backend import backend_api, safe_call +from arrayfire.backend.backend import backend_api + from .constants import AFArrayType +from .error_handler import safe_call def count_all(x: AFArrayType) -> Union[int, float, complex]: diff --git a/arrayfire/backend/c_backend/unsorted.py b/arrayfire/backend/c_backend/unsorted.py old mode 100644 new mode 100755 index 86054da..860b497 --- a/arrayfire/backend/c_backend/unsorted.py +++ b/arrayfire/backend/c_backend/unsorted.py @@ -1,11 +1,14 @@ import ctypes from typing import Any, Tuple, Union, cast -from ...dtypes import CType, Dtype -from ...dtypes.helpers import CShape, c_dim_t, to_str -from ...library.device import PointerSource -from ..backend import ArrayBuffer, backend_api, safe_call +from arrayfire.backend.backend import backend_api +from arrayfire.backend.constants import ArrayBuffer +from arrayfire.dtypes import CType, Dtype +from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str +from arrayfire.library.device import PointerSource + from .constants import AFArrayType +from .error_handler import safe_call # Array management @@ -271,3 +274,12 @@ def randu(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: c_shape = CShape(*shape) safe_call(backend_api.af_randu(ctypes.pointer(out), *c_shape, dtype.c_api_value)) return out + + +def get_last_error() -> ctypes.c_char_p: + """ + source: https://arrayfire.org/docs/exception_8h.htm#a4f0227c17954d343021313f77e695c8e + """ + out = ctypes.c_char_p(0) + backend_api.af_get_last_error(ctypes.pointer(out), ctypes.pointer(c_dim_t(0))) + return out diff --git a/arrayfire/backend/constants.py b/arrayfire/backend/constants.py index e69de29..20c1418 100755 --- a/arrayfire/backend/constants.py +++ b/arrayfire/backend/constants.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ArrayBuffer: + address: int + length: int = 0 diff --git a/arrayfire/dtypes/helpers.py b/arrayfire/dtypes/helpers.py index 20c679e..30b8d4c 100644 --- a/arrayfire/dtypes/helpers.py +++ b/arrayfire/dtypes/helpers.py @@ -3,7 +3,8 @@ import ctypes from typing import Tuple, Union -from arrayfire.backend.config import is_arch_x86 +from arrayfire.platform import is_arch_x86 + from . import Dtype from . import bool as af_bool from . import complex64, complex128, float32, float64, int64, supported_dtypes diff --git a/arrayfire/backend/config.py b/arrayfire/platform.py old mode 100644 new mode 100755 similarity index 84% rename from arrayfire/backend/config.py rename to arrayfire/platform.py index d714996..734f453 --- a/arrayfire/backend/config.py +++ b/arrayfire/platform.py @@ -1,105 +1,100 @@ -import ctypes -import os -import platform -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Iterator - -from arrayfire.version import ARRAYFIRE_VER_MAJOR - - -class SupportedLibPrefixes(Enum): - forge = "" - arrayfire = "af" - - -class SupportedPlatforms(Enum): - windows = "Windows" - darwin = "Darwin" # OSX - linux = "Linux" - - @classmethod - def is_cygwin(cls, name: str) -> bool: - return "cyg" in name.lower() - - -def is_arch_x86() -> bool: - machine = platform.machine() - return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") - - -@dataclass(frozen=True) -class Setup: - pre: str - post: str - af_path: Path - cuda_found: bool - - def __iter__(self) -> Iterator: - return iter((self.pre, self.post, self.af_path, self.af_path, self.cuda_found)) - - -def setup() -> Setup: - platform_name = platform.system() - cuda_found = False - - try: - af_path = Path(os.environ["AF_PATH"]) - except KeyError: - af_path = None - - try: - cuda_path = Path(os.environ["CUDA_PATH"]) - except KeyError: - cuda_path = None - - if platform_name == SupportedPlatforms.windows.value or SupportedPlatforms.is_cygwin(platform_name): - if platform_name == SupportedPlatforms.windows.value: - # HACK Supressing crashes caused by missing dlls - # http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup - # https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx - ctypes.windll.kernel32.SetErrorMode(0x0001 | 0x0002) # type: ignore[attr-defined] - - if not af_path: - af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{ARRAYFIRE_VER_MAJOR}") - - if cuda_path and (cuda_path / "bin").is_dir() and (cuda_path / "nvvm/bin").is_dir(): - cuda_found = True - - return Setup("", ".dll", af_path, cuda_found) - - if platform_name == SupportedPlatforms.darwin.value: - default_cuda_path = Path("/usr/local/cuda/") - - if not af_path: - af_path = _find_default_path("/opt/arrayfire", "/usr/local") - - if not (cuda_path and default_cuda_path.exists()): - cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "/nvvm/lib").is_dir() - - return Setup("lib", f".{ARRAYFIRE_VER_MAJOR}.dylib", af_path, cuda_found) - - if platform_name == SupportedPlatforms.linux.value: - default_cuda_path = Path("/usr/local/cuda/") - - if not af_path: - af_path = _find_default_path(f"/opt/arrayfire-{ARRAYFIRE_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") - - if not (cuda_path and default_cuda_path.exists()): - if "64" in platform.architecture()[0]: # Check either is 64 bit arch is selected - cuda_found = (default_cuda_path / "lib64").is_dir() and (default_cuda_path / "nvvm/lib64").is_dir() - else: - cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "nvvm/lib").is_dir() - - return Setup("lib", f".so.{ARRAYFIRE_VER_MAJOR}", af_path, cuda_found) - - raise OSError(f"{platform_name} is not supported.") - - -def _find_default_path(*args: str) -> Path: - for path in args: - default_path = Path(path) - if default_path.exists(): - return default_path - raise ValueError("None of specified default paths were found.") +import ctypes +import os +import platform +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Iterator + +from arrayfire.version import ARRAYFIRE_VER_MAJOR + + +def is_arch_x86() -> bool: + machine = platform.machine() + return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") + + +class SupportedPlatforms(Enum): + windows = "Windows" + darwin = "Darwin" # OSX + linux = "Linux" + + @classmethod + def is_cygwin(cls, name: str) -> bool: + return "cyg" in name.lower() + + +@dataclass(frozen=True) +class PlatformConfig: + lib_prefix: str + lib_postfix: str + af_path: Path + cuda_found: bool + + def __iter__(self) -> Iterator: + return iter((self.lib_prefix, self.lib_postfix, self.af_path, self.af_path, self.cuda_found)) + + +def get_platform_config() -> PlatformConfig: + platform_name = platform.system() + cuda_found = False + + try: + af_path = Path(os.environ["AF_PATH"]) + except KeyError: + af_path = None + + try: + cuda_path = Path(os.environ["CUDA_PATH"]) + except KeyError: + cuda_path = None + + if platform_name == SupportedPlatforms.windows.value or SupportedPlatforms.is_cygwin(platform_name): + if platform_name == SupportedPlatforms.windows.value: + # HACK Supressing crashes caused by missing dlls + # http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup + # https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx + ctypes.windll.kernel32.SetErrorMode(0x0001 | 0x0002) # type: ignore[attr-defined] + + if not af_path: + af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{ARRAYFIRE_VER_MAJOR}") + + if cuda_path and (cuda_path / "bin").is_dir() and (cuda_path / "nvvm/bin").is_dir(): + cuda_found = True + + return PlatformConfig("", ".dll", af_path, cuda_found) + + if platform_name == SupportedPlatforms.darwin.value: + default_cuda_path = Path("/usr/local/cuda/") + + if not af_path: + af_path = _find_default_path("/opt/arrayfire", "/usr/local") + + if not (cuda_path and default_cuda_path.exists()): + cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "/nvvm/lib").is_dir() + + return PlatformConfig("lib", f".{ARRAYFIRE_VER_MAJOR}.dylib", af_path, cuda_found) + + if platform_name == SupportedPlatforms.linux.value: + default_cuda_path = Path("/usr/local/cuda/") + + if not af_path: + af_path = _find_default_path(f"/opt/arrayfire-{ARRAYFIRE_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") + + if not (cuda_path and default_cuda_path.exists()): + if "64" in platform.architecture()[0]: # Check either is 64 bit arch is selected + cuda_found = (default_cuda_path / "lib64").is_dir() and (default_cuda_path / "nvvm/lib64").is_dir() + else: + cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "nvvm/lib").is_dir() + + return PlatformConfig("lib", f".so.{ARRAYFIRE_VER_MAJOR}", af_path, cuda_found) + + raise OSError(f"{platform_name} is not supported.") + + +def _find_default_path(*args: str) -> Path: + for path in args: + default_path = Path(path) + if default_path.exists(): + return default_path + raise ValueError("None of specified default paths were found.") diff --git a/tests/array_object/test_methods.py b/tests/array_object/test_methods.py index 3e924a6..c5ec273 100644 --- a/tests/array_object/test_methods.py +++ b/tests/array_object/test_methods.py @@ -1,4 +1,3 @@ -from arrayfire.dtypes import float32, int32 from arrayfire.library.array_object import Array @@ -42,11 +41,3 @@ def test_array_to_list_comparison() -> None: array2 = Array([1, 2, 3]) assert array1 is not array2 assert array1.to_list() == array2.to_list() - - -# BUG -# def test_copy_for_array_with_multiple_elements() -> None: -# array = Array([1, 2, 3]) -# copy = array.copy() -# assert array is not copy -# assert array.to_list() == copy.to_list() From 2c965487ca3e13a22a693c5d8696c2c9a7941355 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 03:04:45 +0300 Subject: [PATCH 10/31] Refactoring --- arrayfire/backend/__init__.py | 7 ++++++ arrayfire/backend/backend.py | 27 +++++++++++---------- arrayfire/backend/c_backend/unsorted.py | 14 +++++++++++ arrayfire/backend/helpers.py | 31 +++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 12 deletions(-) create mode 100755 arrayfire/backend/helpers.py diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 8008659..055b6dd 100644 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -20,10 +20,17 @@ "ge", "eq", "neq", + # Backend + "BackendPlatform", + "set_backend", + "get_backend", ] + # fmt: off +from .backend import BackendPlatform from .c_backend.operators import ( add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) from .constants import ArrayBuffer +from .helpers import get_backend, set_backend # fmt: on diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index 9f4d33d..a8154aa 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -1,3 +1,5 @@ +__all__ = ["BackendPlatform"] + import ctypes import enum import sys @@ -14,7 +16,7 @@ class _LibPrefixes(Enum): arrayfire = "af" -class BackendDevices(enum.Enum): +class BackendPlatform(enum.Enum): unified = 0 # NOTE It is set as Default value on Arrayfire backend cuda = 2 opencl = 4 @@ -22,7 +24,7 @@ class BackendDevices(enum.Enum): class Backend: - device: BackendDevices + platform: BackendPlatform library: ctypes.CDLL def __init__(self) -> None: @@ -42,30 +44,30 @@ def _load_forge_lib(self) -> None: pass def _load_backend_libs(self) -> None: - for device in BackendDevices: - self._load_backend_lib(device) + for platform in BackendPlatform: + self._load_backend_lib(platform) - if self.device: - logger.info(f"Setting {device.name} as backend.") + if self.platform: + logger.info(f"Setting {platform.name} as backend.") break - if not self.device and not self.library: + if not self.platform and not self.library: raise RuntimeError( "Could not load any ArrayFire libraries.\n" "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." ) - def _load_backend_lib(self, device: BackendDevices) -> None: + def _load_backend_lib(self, platform: BackendPlatform) -> None: # NOTE we still set unified cdll to it's original name later, even if the path search is different - name = device.name if device != BackendDevices.unified else "" + name = platform.name if platform != BackendPlatform.unified else "" for libname in self._libnames(name, _LibPrefixes.arrayfire): try: ctypes.cdll.LoadLibrary(str(libname)) - self.device = device + self.platform = platform self.library = ctypes.CDLL(str(libname)) - if device == BackendDevices.cuda: + if platform == BackendPlatform.cuda: self._load_nvrtc_builtins_lib(libname.parent) logger.info(f"Loaded {libname}") @@ -114,4 +116,5 @@ def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: # backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # HACK for windows # backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") -backend_api = Backend().library +backend = Backend() +backend_api = backend.library diff --git a/arrayfire/backend/c_backend/unsorted.py b/arrayfire/backend/c_backend/unsorted.py index 860b497..3da0821 100755 --- a/arrayfire/backend/c_backend/unsorted.py +++ b/arrayfire/backend/c_backend/unsorted.py @@ -276,6 +276,9 @@ def randu(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: return out +# Safe Call Wrapper + + def get_last_error() -> ctypes.c_char_p: """ source: https://arrayfire.org/docs/exception_8h.htm#a4f0227c17954d343021313f77e695c8e @@ -283,3 +286,14 @@ def get_last_error() -> ctypes.c_char_p: out = ctypes.c_char_p(0) backend_api.af_get_last_error(ctypes.pointer(out), ctypes.pointer(c_dim_t(0))) return out + + +# Backend + + +def set_backend(backend_c_value: int, /) -> None: + """ + source: https://arrayfire.org/docs/group__unified__func__setbackend.htm#ga6fde820e8802776b7fc823504b37f1b4 + """ + safe_call(backend_api.af_set_backend(backend_c_value)) + return None diff --git a/arrayfire/backend/helpers.py b/arrayfire/backend/helpers.py new file mode 100755 index 0000000..48b105f --- /dev/null +++ b/arrayfire/backend/helpers.py @@ -0,0 +1,31 @@ +from typing import Union + +from .backend import Backend, BackendPlatform, backend, backend_api +from .c_backend.unsorted import set_backend as c_set_backend + + +def set_backend(platform: Union[BackendPlatform, str]) -> None: + current_active_platform = backend_api.platform + + if isinstance(platform, str): + if platform not in [d.name for d in BackendPlatform]: + raise ValueError(f"{platform} is not a valid name for backend platform.") + platform = BackendPlatform[platform] + + if not isinstance(platform, BackendPlatform): + raise TypeError(f"{platform} is not a valid name for backend platform.") + + if current_active_platform == platform: + raise RuntimeError(f"{platform} is already the active backend platform.") + + if backend_api.platform == BackendPlatform.unified: + c_set_backend(platform.value) + + backend_api._load_backend_lib(platform) # FIXME should not access private API + + if current_active_platform == backend_api.platform: + raise RuntimeError(f"Could not set {platform} as new backend platform. Consider checking logs.") + + +def get_backend() -> Backend: + return backend From a6907e093f6a2ff889d4ef76fd49710191934a02 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 04:49:24 +0300 Subject: [PATCH 11/31] Add backend helpers --- arrayfire/backend/__init__.py | 15 ++- arrayfire/backend/c_backend/unsorted.py | 36 +++++ arrayfire/backend/helpers.py | 167 +++++++++++++++++++++++- 3 files changed, 209 insertions(+), 9 deletions(-) diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 055b6dd..2975dd5 100644 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -22,8 +22,17 @@ "neq", # Backend "BackendPlatform", - "set_backend", + "get_active_backend", # DeprecationWarning + "get_array_backend_name", + "get_array_device_id", + "get_available_backends", # DeprecationWarning "get_backend", + "get_backend_count", + "get_backend_id", # DeprecationWarning + "get_device_id", # DeprecationWarning + "get_dtype_size", + "get_size_of", # DeprecationWarning + "set_backend", ] # fmt: off @@ -31,6 +40,8 @@ from .c_backend.operators import ( add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) from .constants import ArrayBuffer -from .helpers import get_backend, set_backend +from .helpers import ( + get_active_backend, get_array_backend_name, get_array_device_id, get_available_backends, get_backend, + get_backend_count, get_backend_id, get_device_id, get_dtype_size, get_size_of, set_backend) # fmt: on diff --git a/arrayfire/backend/c_backend/unsorted.py b/arrayfire/backend/c_backend/unsorted.py index 3da0821..e8e50e3 100755 --- a/arrayfire/backend/c_backend/unsorted.py +++ b/arrayfire/backend/c_backend/unsorted.py @@ -297,3 +297,39 @@ def set_backend(backend_c_value: int, /) -> None: """ safe_call(backend_api.af_set_backend(backend_c_value)) return None + + +def get_backend_count() -> int: + """ + source: https://arrayfire.org/docs/group__unified__func__getbackendcount.htm#gad38c2dfedfdabfa264afa46d8664e9cd + """ + out = ctypes.c_int(0) + safe_call(backend_api.get().af_get_backend_count(ctypes.pointer(out))) + return out.value + + +def get_device_id(arr: AFArrayType, /) -> int: + """ + source: https://arrayfire.org/docs/group__unified__func__getdeviceid.htm#ga5d94b64dccd1c7cbc7a3a69fa64888c3 + """ + out = ctypes.c_int(0) + safe_call(backend_api.get().af_get_device_id(ctypes.pointer(out), arr)) + return out.value + + +def get_size_of(dtype: Dtype, /) -> int: + """ + source: https://arrayfire.org/docs/util_8h.htm#a8b72cffd10a92a7a2ee7f52dadda5216 + """ + out = ctypes.c_size_t(0) + safe_call(backend_api.get().af_get_size_of(ctypes.pointer(out), dtype.c_api_value)) + return out.value + + +def get_backend_id(arr: AFArrayType, /) -> int: + """ + source: https://arrayfire.org/docs/group__unified__func__getbackendid.htm#ga5fc39e209e1886cf250aec265c0d9079 + """ + out = ctypes.c_int(0) + safe_call(backend_api.get().af_get_backend_id(ctypes.pointer(out), arr)) + return out.value diff --git a/arrayfire/backend/helpers.py b/arrayfire/backend/helpers.py index 48b105f..a6f30ee 100755 --- a/arrayfire/backend/helpers.py +++ b/arrayfire/backend/helpers.py @@ -1,11 +1,42 @@ -from typing import Union +from __future__ import annotations -from .backend import Backend, BackendPlatform, backend, backend_api +import warnings +from typing import TYPE_CHECKING, Union + +from .backend import Backend, BackendPlatform, backend +from .c_backend.unsorted import get_backend_count as c_get_backend_count +from .c_backend.unsorted import get_backend_id as c_get_backend_id +from .c_backend.unsorted import get_device_id as c_get_device_id +from .c_backend.unsorted import get_size_of as c_get_size_of from .c_backend.unsorted import set_backend as c_set_backend +if TYPE_CHECKING: + from arrayfire import Array + from arrayfire.dtypes import Dtype + def set_backend(platform: Union[BackendPlatform, str]) -> None: - current_active_platform = backend_api.platform + """ + Set a specific backend by platform name. + + Parameters + ---------- + platform : Union[BackendPlatform, str] + Name of the backend platform to set. + + Raises + ------ + ValueError + If the given platform name is not a valid name for backend platform. + TypeError + If the given platform is not a valid type for backend platform. + RuntimeError + If the given platform is already the active backend platform. + RuntimeError + If the given platform could not be set as new backend platform. + """ + + current_active_platform = backend.platform if isinstance(platform, str): if platform not in [d.name for d in BackendPlatform]: @@ -13,19 +44,141 @@ def set_backend(platform: Union[BackendPlatform, str]) -> None: platform = BackendPlatform[platform] if not isinstance(platform, BackendPlatform): - raise TypeError(f"{platform} is not a valid name for backend platform.") + raise TypeError(f"{platform} is not a valid type for backend platform.") if current_active_platform == platform: raise RuntimeError(f"{platform} is already the active backend platform.") - if backend_api.platform == BackendPlatform.unified: + if backend.platform == BackendPlatform.unified: c_set_backend(platform.value) - backend_api._load_backend_lib(platform) # FIXME should not access private API + backend._load_backend_lib(platform) # FIXME should not access private API - if current_active_platform == backend_api.platform: + if current_active_platform == backend.platform: raise RuntimeError(f"Could not set {platform} as new backend platform. Consider checking logs.") def get_backend() -> Backend: + """ + Get the current active backend. + + Returns + ------- + value : Backend + Current active backend. + """ + return backend + + +def get_array_backend_name(array: Array) -> str: + """ + Get the name of the backend on which the Array is located. + + Parameters + ---------- + array : Array + The Array to get the backend name of. + + Returns + ------- + value : str + Name of the backend on which the Array is located. + """ + + id_ = c_get_backend_id(array.arr) + return BackendPlatform(id_).name + + +def get_backend_id(array: Array) -> str: + warnings.warn("Was renamed due to unintuitive function name. Now get_array_backend_name().", DeprecationWarning) + return get_array_backend_name(array) + + +def get_backend_count() -> int: + """ + Get a number of available backends. + + Returns + ------- + + value : int + Number of available backends. + """ + + return c_get_backend_count() + + +def get_active_backend() -> Backend: + """ + Get the current active backend. + + value : Backend + Current active backend. + """ + + warnings.warn("A user has access explicitly only to the active backend.", DeprecationWarning) + return get_backend() + + +def get_available_backends() -> Backend: + """ + Get the list of available backends. + + Returns + ------- + value : Backend + Current active backend. + """ + + warnings.warn( + "A user has access explicitly only to the active backend. Thus returning only active backend.", + DeprecationWarning, + ) + return get_active_backend() + + +def get_array_device_id(array: Array) -> int: + """ + Get the id of the device on which the Array was created. + + Parameters + ---------- + array : Array + The Array to get the device id of. + + Returns + ------- + value : int + The id of the device on which the Array was created. + """ + + return c_get_device_id(array.arr) + + +def get_device_id(array: Array) -> int: + warnings.warn("Was renamed due to unintuitive function name. Now get_array_device_id().", DeprecationWarning) + return get_array_device_id(array) + + +def get_dtype_size(dtype: Dtype) -> int: + """ + Get the size of the type represented by Dtype. + + Parameters + ---------- + dtype : Dtype + The type to get the size of. + + Returns + ------- + value : int + The size of the type in bytes. + """ + + return c_get_size_of(dtype) + + +def get_size_of(dtype: Dtype) -> int: + warnings.warn("Was renamed due to unintuitive function name. Now get_dtype_size().", DeprecationWarning) + return get_dtype_size(dtype) From 485ca8d33e225160460c7038c8efa3c488f0f68f Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 16:14:05 +0300 Subject: [PATCH 12/31] Remove c_backend constants --- arrayfire/backend/backend.py | 8 ++++++-- arrayfire/backend/c_backend/constant_array.py | 8 ++++++-- arrayfire/backend/c_backend/constants.py | 4 ---- arrayfire/backend/c_backend/indexing.py | 5 ++--- arrayfire/backend/c_backend/operators.py | 8 ++++++-- arrayfire/backend/c_backend/reduction_operations.py | 8 ++++++-- arrayfire/backend/c_backend/unsorted.py | 8 ++++++-- arrayfire/library/array_object.py | 8 +++++--- 8 files changed, 37 insertions(+), 20 deletions(-) delete mode 100755 arrayfire/backend/c_backend/constants.py diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index a8154aa..5940620 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -5,7 +5,7 @@ import sys from enum import Enum from pathlib import Path -from typing import List, Optional +from typing import Iterator, List, Optional from arrayfire.logger import logger from arrayfire.platform import get_platform_config, is_arch_x86 @@ -18,9 +18,13 @@ class _LibPrefixes(Enum): class BackendPlatform(enum.Enum): unified = 0 # NOTE It is set as Default value on Arrayfire backend + cpu = 1 cuda = 2 opencl = 4 - cpu = 1 # NOTE It comes last because we want to keep this order on backend initialization + + def __iter__(self) -> Iterator: + # NOTE cpu comes last because we want to keep this order priorty during backend initialization + return iter((self.unified, self.cuda, self.opencl, self.cpu)) class Backend: diff --git a/arrayfire/backend/c_backend/constant_array.py b/arrayfire/backend/c_backend/constant_array.py index 1abcabc..37ed085 100755 --- a/arrayfire/backend/c_backend/constant_array.py +++ b/arrayfire/backend/c_backend/constant_array.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import ctypes -from typing import Tuple, Union +from typing import TYPE_CHECKING, Tuple, Union from arrayfire.backend.backend import backend_api from arrayfire.dtypes import Dtype, int64, uint64 from arrayfire.dtypes.helpers import CShape, implicit_dtype -from .constants import AFArrayType from .error_handler import safe_call +if TYPE_CHECKING: + from arrayfire.library.array_object import AFArrayType + def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ diff --git a/arrayfire/backend/c_backend/constants.py b/arrayfire/backend/c_backend/constants.py deleted file mode 100755 index cb8e944..0000000 --- a/arrayfire/backend/c_backend/constants.py +++ /dev/null @@ -1,4 +0,0 @@ -import ctypes - -AFArrayType = ctypes.c_void_p -AFArrayPointerType = ctypes._Pointer diff --git a/arrayfire/backend/c_backend/indexing.py b/arrayfire/backend/c_backend/indexing.py index 897f4bc..c9a05c5 100755 --- a/arrayfire/backend/c_backend/indexing.py +++ b/arrayfire/backend/c_backend/indexing.py @@ -7,7 +7,6 @@ from arrayfire.backend.backend import backend_api from arrayfire.library.broadcast import bcast_var -from . import constants from .error_handler import safe_call @@ -70,7 +69,7 @@ def __init__(self, chunk: Union[int, slice]): self.step.value = -1 if chunk.stop: - self.end -= math.copysign(1, self.step) # type: ignore[operator] + self.end -= math.copysign(1, self.step) # type: ignore[operator, assignment, arg-type] # FIXME else: raise IndexError("Invalid type while indexing arrayfire.array") @@ -238,7 +237,7 @@ def __init__(self) -> None: self.array = index_vec(*self.idxs) @property - def pointer(self) -> constants.AFArrayPointerType: + def pointer(self) -> ctypes._Pointer: return ctypes.pointer(self.array) def __getitem__(self, idx: int) -> IndexStructure: diff --git a/arrayfire/backend/c_backend/operators.py b/arrayfire/backend/c_backend/operators.py index 64b5710..63056c1 100755 --- a/arrayfire/backend/c_backend/operators.py +++ b/arrayfire/backend/c_backend/operators.py @@ -1,12 +1,16 @@ +from __future__ import annotations + import ctypes -from typing import Callable +from typing import TYPE_CHECKING, Callable from arrayfire.backend.backend import backend_api from arrayfire.library.broadcast import bcast_var -from .constants import AFArrayType from .error_handler import safe_call +if TYPE_CHECKING: + from arrayfire.library.array_object import AFArrayType + # Arithmetic Operators diff --git a/arrayfire/backend/c_backend/reduction_operations.py b/arrayfire/backend/c_backend/reduction_operations.py index adc8897..d97197a 100755 --- a/arrayfire/backend/c_backend/reduction_operations.py +++ b/arrayfire/backend/c_backend/reduction_operations.py @@ -1,11 +1,15 @@ +from __future__ import annotations + import ctypes -from typing import Callable, Union +from typing import TYPE_CHECKING, Callable, Union from arrayfire.backend.backend import backend_api -from .constants import AFArrayType from .error_handler import safe_call +if TYPE_CHECKING: + from arrayfire.library.array_object import AFArrayType + def count_all(x: AFArrayType) -> Union[int, float, complex]: # TODO reconsider original arith.count diff --git a/arrayfire/backend/c_backend/unsorted.py b/arrayfire/backend/c_backend/unsorted.py index e8e50e3..274ce4b 100755 --- a/arrayfire/backend/c_backend/unsorted.py +++ b/arrayfire/backend/c_backend/unsorted.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import ctypes -from typing import Any, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Tuple, Union, cast from arrayfire.backend.backend import backend_api from arrayfire.backend.constants import ArrayBuffer @@ -7,9 +9,11 @@ from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str from arrayfire.library.device import PointerSource -from .constants import AFArrayType from .error_handler import safe_call +if TYPE_CHECKING: + from arrayfire.library.array_object import AFArrayType + # Array management diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index e775bae..f2c1850 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -19,11 +19,13 @@ # TODO use int | float in operators -> remove bool | complex support +AFArrayType = ctypes.c_void_p + class Array: def __init__( self, - obj: Union[None, Array, py_array.array, int, ctypes.c_void_p, List[Union[int, float]]] = None, + obj: Union[None, Array, py_array.array, int, AFArrayType, List[Union[int, float]]] = None, dtype: Union[None, Dtype, str] = None, shape: Tuple[int, ...] = (), pointer_source: PointerSource = PointerSource.host, @@ -60,8 +62,8 @@ def __init__( _type_char = _array.typecode _array_buffer = ArrayBuffer(*_array.buffer_info()) - elif isinstance(obj, int) or isinstance(obj, ctypes.c_void_p): # TODO - _array_buffer = ArrayBuffer(obj if not isinstance(obj, ctypes.c_void_p) else obj.value) # type: ignore + elif isinstance(obj, int) or isinstance(obj, AFArrayType): + _array_buffer = ArrayBuffer(obj if not isinstance(obj, AFArrayType) else obj.value) # type: ignore if not shape: raise TypeError("Expected to receive the initial shape due to the obj being a data pointer.") From dd0180bb66edfc948dba917537175c291f7476cc Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 16:26:54 +0300 Subject: [PATCH 13/31] Rename dirs --- arrayfire/backend/__init__.py | 13 ++++---- arrayfire/backend/{backend.py => api.py} | 17 ++++++++-- .../{c_backend => c_library}/__init__.py | 0 .../constant_array.py | 2 +- .../{c_backend => c_library}/error_handler.py | 2 +- .../{c_backend => c_library}/indexing.py | 2 +- .../{c_backend => c_library}/operators.py | 2 +- .../reduction_operations.py | 2 +- .../{c_backend => c_library}/unsorted.py | 2 +- arrayfire/backend/helpers.py | 31 ++++++------------- arrayfire/library/array_object.py | 8 ++--- 11 files changed, 42 insertions(+), 39 deletions(-) rename arrayfire/backend/{backend.py => api.py} (94%) rename arrayfire/backend/{c_backend => c_library}/__init__.py (100%) rename arrayfire/backend/{c_backend => c_library}/constant_array.py (98%) rename arrayfire/backend/{c_backend => c_library}/error_handler.py (84%) rename arrayfire/backend/{c_backend => c_library}/indexing.py (95%) rename arrayfire/backend/{c_backend => c_library}/operators.py (98%) rename arrayfire/backend/{c_backend => c_library}/reduction_operations.py (90%) rename arrayfire/backend/{c_backend => c_library}/unsorted.py (99%) diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 2975dd5..7ac21de 100644 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -20,13 +20,14 @@ "ge", "eq", "neq", - # Backend + # Backend API "BackendPlatform", + "get_backend", + # Backend Helpers "get_active_backend", # DeprecationWarning "get_array_backend_name", "get_array_device_id", "get_available_backends", # DeprecationWarning - "get_backend", "get_backend_count", "get_backend_id", # DeprecationWarning "get_device_id", # DeprecationWarning @@ -36,12 +37,12 @@ ] # fmt: off -from .backend import BackendPlatform -from .c_backend.operators import ( +from .api import BackendPlatform, get_backend +from .c_library.operators import ( add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) from .constants import ArrayBuffer from .helpers import ( - get_active_backend, get_array_backend_name, get_array_device_id, get_available_backends, get_backend, - get_backend_count, get_backend_id, get_device_id, get_dtype_size, get_size_of, set_backend) + get_active_backend, get_array_backend_name, get_array_device_id, get_available_backends, get_backend_count, + get_backend_id, get_device_id, get_dtype_size, get_size_of, set_backend) # fmt: on diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/api.py similarity index 94% rename from arrayfire/backend/backend.py rename to arrayfire/backend/api.py index 5940620..e760195 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/api.py @@ -120,5 +120,18 @@ def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: # backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # HACK for windows # backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") -backend = Backend() -backend_api = backend.library +_backend = Backend() +backend_api = _backend.library + + +def get_backend() -> Backend: + """ + Get the current active backend. + + Returns + ------- + value : Backend + Current active backend. + """ + + return _backend diff --git a/arrayfire/backend/c_backend/__init__.py b/arrayfire/backend/c_library/__init__.py similarity index 100% rename from arrayfire/backend/c_backend/__init__.py rename to arrayfire/backend/c_library/__init__.py diff --git a/arrayfire/backend/c_backend/constant_array.py b/arrayfire/backend/c_library/constant_array.py similarity index 98% rename from arrayfire/backend/c_backend/constant_array.py rename to arrayfire/backend/c_library/constant_array.py index 37ed085..b3a6312 100755 --- a/arrayfire/backend/c_backend/constant_array.py +++ b/arrayfire/backend/c_library/constant_array.py @@ -3,7 +3,7 @@ import ctypes from typing import TYPE_CHECKING, Tuple, Union -from arrayfire.backend.backend import backend_api +from arrayfire.backend.api import backend_api from arrayfire.dtypes import Dtype, int64, uint64 from arrayfire.dtypes.helpers import CShape, implicit_dtype diff --git a/arrayfire/backend/c_backend/error_handler.py b/arrayfire/backend/c_library/error_handler.py similarity index 84% rename from arrayfire/backend/c_backend/error_handler.py rename to arrayfire/backend/c_library/error_handler.py index df402bd..aef3865 100755 --- a/arrayfire/backend/c_backend/error_handler.py +++ b/arrayfire/backend/c_library/error_handler.py @@ -1,7 +1,7 @@ import ctypes from enum import Enum -from arrayfire.backend.backend import backend_api +from arrayfire.backend.api import backend_api from arrayfire.dtypes.helpers import c_dim_t, to_str diff --git a/arrayfire/backend/c_backend/indexing.py b/arrayfire/backend/c_library/indexing.py similarity index 95% rename from arrayfire/backend/c_backend/indexing.py rename to arrayfire/backend/c_library/indexing.py index c9a05c5..1383798 100755 --- a/arrayfire/backend/c_backend/indexing.py +++ b/arrayfire/backend/c_library/indexing.py @@ -4,7 +4,7 @@ import math from typing import Any, Union -from arrayfire.backend.backend import backend_api +from arrayfire.backend.api import backend_api from arrayfire.library.broadcast import bcast_var from .error_handler import safe_call diff --git a/arrayfire/backend/c_backend/operators.py b/arrayfire/backend/c_library/operators.py similarity index 98% rename from arrayfire/backend/c_backend/operators.py rename to arrayfire/backend/c_library/operators.py index 63056c1..15a5f24 100755 --- a/arrayfire/backend/c_backend/operators.py +++ b/arrayfire/backend/c_library/operators.py @@ -3,7 +3,7 @@ import ctypes from typing import TYPE_CHECKING, Callable -from arrayfire.backend.backend import backend_api +from arrayfire.backend.api import backend_api from arrayfire.library.broadcast import bcast_var from .error_handler import safe_call diff --git a/arrayfire/backend/c_backend/reduction_operations.py b/arrayfire/backend/c_library/reduction_operations.py similarity index 90% rename from arrayfire/backend/c_backend/reduction_operations.py rename to arrayfire/backend/c_library/reduction_operations.py index d97197a..b4ff67b 100755 --- a/arrayfire/backend/c_backend/reduction_operations.py +++ b/arrayfire/backend/c_library/reduction_operations.py @@ -3,7 +3,7 @@ import ctypes from typing import TYPE_CHECKING, Callable, Union -from arrayfire.backend.backend import backend_api +from arrayfire.backend.api import backend_api from .error_handler import safe_call diff --git a/arrayfire/backend/c_backend/unsorted.py b/arrayfire/backend/c_library/unsorted.py similarity index 99% rename from arrayfire/backend/c_backend/unsorted.py rename to arrayfire/backend/c_library/unsorted.py index 274ce4b..aa970d7 100755 --- a/arrayfire/backend/c_backend/unsorted.py +++ b/arrayfire/backend/c_library/unsorted.py @@ -3,7 +3,7 @@ import ctypes from typing import TYPE_CHECKING, Any, Tuple, Union, cast -from arrayfire.backend.backend import backend_api +from arrayfire.backend.api import backend_api from arrayfire.backend.constants import ArrayBuffer from arrayfire.dtypes import CType, Dtype from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str diff --git a/arrayfire/backend/helpers.py b/arrayfire/backend/helpers.py index a6f30ee..af19ec2 100755 --- a/arrayfire/backend/helpers.py +++ b/arrayfire/backend/helpers.py @@ -3,12 +3,12 @@ import warnings from typing import TYPE_CHECKING, Union -from .backend import Backend, BackendPlatform, backend -from .c_backend.unsorted import get_backend_count as c_get_backend_count -from .c_backend.unsorted import get_backend_id as c_get_backend_id -from .c_backend.unsorted import get_device_id as c_get_device_id -from .c_backend.unsorted import get_size_of as c_get_size_of -from .c_backend.unsorted import set_backend as c_set_backend +from .api import Backend, BackendPlatform, get_backend +from .c_library.unsorted import get_backend_count as c_get_backend_count +from .c_library.unsorted import get_backend_id as c_get_backend_id +from .c_library.unsorted import get_device_id as c_get_device_id +from .c_library.unsorted import get_size_of as c_get_size_of +from .c_library.unsorted import set_backend as c_set_backend if TYPE_CHECKING: from arrayfire import Array @@ -35,7 +35,7 @@ def set_backend(platform: Union[BackendPlatform, str]) -> None: RuntimeError If the given platform could not be set as new backend platform. """ - + backend = get_backend() current_active_platform = backend.platform if isinstance(platform, str): @@ -52,25 +52,14 @@ def set_backend(platform: Union[BackendPlatform, str]) -> None: if backend.platform == BackendPlatform.unified: c_set_backend(platform.value) - backend._load_backend_lib(platform) # FIXME should not access private API + # NOTE keep in mind that this operation works in-place + # FIXME should not access private API + backend._load_backend_lib(platform) if current_active_platform == backend.platform: raise RuntimeError(f"Could not set {platform} as new backend platform. Consider checking logs.") -def get_backend() -> Backend: - """ - Get the current active backend. - - Returns - ------- - value : Backend - Current active backend. - """ - - return backend - - def get_array_backend_name(array: Array) -> str: """ Get the name of the backend on which the Array is located. diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index f2c1850..ca28952 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -7,10 +7,10 @@ from .. import backend from ..backend import ArrayBuffer -from ..backend.c_backend import unsorted -from ..backend.c_backend.constant_array import create_constant_array -from ..backend.c_backend.indexing import CIndexStructure, IndexStructure -from ..backend.c_backend.reduction_operations import count_all +from ..backend.c_library import unsorted +from ..backend.c_library.constant_array import create_constant_array +from ..backend.c_library.indexing import CIndexStructure, IndexStructure +from ..backend.c_library.reduction_operations import count_all from ..dtypes import CType from ..dtypes import bool as af_bool from ..dtypes import float32 as af_float32 From 6774dd488ccb4fd5d3374404c8b075363a4660e0 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 16:28:51 +0300 Subject: [PATCH 14/31] Fix c approach --- arrayfire/library/array_object.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index ca28952..5176659 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -923,9 +923,7 @@ def _reorder(array: Array) -> Array: if array.ndim == 1: return array - out = Array() - out.arr = unsorted.reorder(array.arr, array.ndim) - return out + return Array(unsorted.reorder(array.arr, array.ndim)) def _metadata_string(dtype: Dtype, dims: Optional[Tuple[int, ...]] = None) -> str: From cf0ba2d61cf98f87d64393590943e8ee24559ad0 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 19:39:54 +0300 Subject: [PATCH 15/31] Fix typings. Fix tests and array_api initialisation --- .isort.cfg | 2 +- arrayfire/__init__.py | 18 +- arrayfire/array_api/__init__.py | 43 ++++ arrayfire/array_api/array_object.py | 204 +++++++++++------- arrayfire/array_api/constants.py | 16 +- arrayfire/array_api/creation_function.py | 39 +++- arrayfire/array_api/data_type_functions.py | 35 +++ arrayfire/array_api/dtypes.py | 92 ++++---- arrayfire/array_api/tests/__init__.py | 0 .../tests/fixme_test_array_object.py | 177 +++++++++++++++ .../tests/test_creation_functions.py | 20 ++ arrayfire/backend/__init__.py | 36 +++- arrayfire/backend/c_library/error_handler.py | 3 +- arrayfire/dtypes/__init__.py | 20 +- arrayfire/library/array_object.py | 4 + arrayfire/library/device.py | 5 +- 16 files changed, 569 insertions(+), 145 deletions(-) mode change 100644 => 100755 arrayfire/__init__.py create mode 100755 arrayfire/array_api/data_type_functions.py create mode 100755 arrayfire/array_api/tests/__init__.py create mode 100755 arrayfire/array_api/tests/fixme_test_array_object.py create mode 100755 arrayfire/array_api/tests/test_creation_functions.py mode change 100644 => 100755 arrayfire/backend/__init__.py diff --git a/.isort.cfg b/.isort.cfg index 4fbfcfb..ec54132 100755 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,3 +1,3 @@ [settings] line_length = 119 -multi_line_output = 4 +profile = black diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py old mode 100644 new mode 100755 index 56d8a23..d0edcf7 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -16,9 +16,19 @@ "complex128", "bool", ] -# fmt: off from .dtypes import ( - bool, complex64, complex128, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) + bool, + complex64, + complex128, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) from .library.array_object import Array - -# fmt: on diff --git a/arrayfire/array_api/__init__.py b/arrayfire/array_api/__init__.py index e69de29..d40c013 100755 --- a/arrayfire/array_api/__init__.py +++ b/arrayfire/array_api/__init__.py @@ -0,0 +1,43 @@ +# flake8: noqa + +__array_api_version__ = "2022.12" + +__all__ = ["__array_api_version__"] + +from .constants import Device + +__all__ += ["Device"] + +from .creation_function import asarray + +__all__ += ["asarray"] + +from .dtypes import ( + bool, + complex64, + complex128, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) + +__all__ += [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "bool", +] diff --git a/arrayfire/array_api/array_object.py b/arrayfire/array_api/array_object.py index 95bed1a..8574d05 100755 --- a/arrayfire/array_api/array_object.py +++ b/arrayfire/array_api/array_object.py @@ -1,19 +1,22 @@ from __future__ import annotations -__all__ = ["Array"] - import types -from enum import IntEnum -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union from arrayfire import Array as AFArray -from arrayfire.array_api.constants import NestedSequence, SupportsBufferProtocol +from arrayfire.array_api.constants import Device, NestedSequence, PyCapsule, SupportsBufferProtocol +from arrayfire.array_api.dtypes import ( + all_dtypes, + boolean_dtypes, + complex_floating_dtypes, + dtype_categories, + floating_dtypes, + integer_or_boolean_dtypes, + numeric_dtypes, + promote_types, +) from arrayfire.dtypes import Dtype -if TYPE_CHECKING: - from .constants import PyCapsule - from .dtypes import all_dtypes, dtype_categories, numeric_dtypes, promote_types - class Array: _array: AFArray @@ -79,10 +82,12 @@ def _promote_scalar(self, scalar): elif isinstance(scalar, int): if self.dtype in boolean_dtypes: raise TypeError("Python int scalars cannot be promoted with bool arrays") - if self.dtype in integer_dtypes: - info = np.iinfo(self.dtype) - if not (info.min <= scalar <= info.max): - raise OverflowError("Python int scalars must be within the bounds of the dtype for integer arrays") + # TODO + # if self.dtype in integer_dtypes: + # info = np.iinfo(self.dtype) + # if not (info.min <= scalar <= info.max): + # raise OverflowError( + # "Python int scalars must be within the bounds of the dtype for integer arrays") # int + array(floating) is allowed elif isinstance(scalar, float): if self.dtype not in floating_dtypes: @@ -100,7 +105,42 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, self.dtype)) + return Array._new(AFArray(scalar, self.dtype, shape=(1,))) + + @staticmethod + def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: + """ + Normalize inputs to two arg functions to fix type promotion rules + + NumPy deviates from the spec type promotion rules in cases where one + argument is 0-dimensional and the other is not. For example: + + >>> import numpy as np + >>> a = np.array([1.0], dtype=np.float32) + >>> b = np.array(1.0, dtype=np.float64) + >>> np.add(a, b) # The spec says this should be float64 + array([2.], dtype=float32) + + To fix this, we add a dimension to the 0-dimension array before passing it + through. This works because a dimension would be added anyway from + broadcasting, so the resulting shape is the same, but this prevents NumPy + from not promoting the dtype. + """ + # Another option would be to use signature=(x1.dtype, x2.dtype, None), + # but that only works for ufuncs, so we would have to call the ufuncs + # directly in the operator methods. One should also note that this + # sort of trick wouldn't work for functions like searchsorted, which + # don't do normal broadcasting, but there aren't any functions like + # that in the array API namespace. + if x1.ndim == 0 and x2.ndim != 0: + # The _array[None] workaround was chosen because it is relatively + # performant. broadcast_to(x1._array, x2.shape) is much slower. We + # could also manually type promote x2, but that is more complicated + # and about the same performance as this. + x1 = Array._new(x1._array[None]) + elif x2.ndim == 0 and x1.ndim != 0: + x2 = Array._new(x2._array[None]) + return (x1, x2) @classmethod def _new(cls, x: Union[Array, bool, int, float, complex, NestedSequence, SupportsBufferProtocol], /) -> Array: @@ -134,6 +174,17 @@ def __abs__(self: Array, /) -> Array: res = self._array.__abs__() return self.__class__._new(res) + def __add__(self: Array, other: Union[int, float, Array], /) -> Array: + """ + Performs the operation __add__. + """ + other = self._check_allowed_dtypes(other, "numeric", "__add__") + if other is NotImplemented: + return other + self, other = self._normalize_two_args(self, other) + res = self._array.__add__(other._array) + return self.__class__._new(res) + def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __and__. @@ -178,12 +229,12 @@ def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: """ return self._array.__dlpack__(stream=stream) - def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: - """ - Performs the operation __dlpack_device__. - """ - # Note: device support is required for this - return self._array.__dlpack_device__() + # def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: + # """ + # Performs the operation __dlpack_device__. + # """ + # # Note: device support is required for this + # return self._array.__dlpack_device__() def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ @@ -205,7 +256,7 @@ def __float__(self: Array, /) -> float: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("float is only allowed on arrays with 0 dimensions") - if self.dtype in _complex_floating_dtypes: + if self.dtype in complex_floating_dtypes: raise TypeError("float is not allowed on complex floating-point arrays") res = self._array.__float__() return res @@ -232,22 +283,22 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__ge__(other._array) return self.__class__._new(res) - def __getitem__( - self: Array, - key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], - /, - ) -> Array: - """ - Performs the operation __getitem__. - """ - # Note: Only indices required by the spec are allowed. See the - # docstring of _validate_index - self._validate_index(key) - if isinstance(key, Array): - # Indexing self._array with array_api arrays can be erroneous - key = key._array - res = self._array.__getitem__(key) - return self._new(res) + # def __getitem__( + # self: Array, + # key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], + # /, + # ) -> Array: + # """ + # Performs the operation __getitem__. + # """ + # # Note: Only indices required by the spec are allowed. See the + # # docstring of _validate_index + # self._validate_index(key) + # if isinstance(key, Array): + # # Indexing self._array with array_api arrays can be erroneous + # key = key._array + # res = self._array.__getitem__(key) + # return self._new(res) def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -267,7 +318,7 @@ def __int__(self: Array, /) -> int: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("int is only allowed on arrays with 0 dimensions") - if self.dtype in _complex_floating_dtypes: + if self.dtype in complex_floating_dtypes: raise TypeError("int is not allowed on complex floating-point arrays") res = self._array.__int__() return res @@ -283,7 +334,7 @@ def __invert__(self: Array, /) -> Array: """ Performs the operation __invert__. """ - if self.dtype not in _integer_or_boolean_dtypes: + if self.dtype not in integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in __invert__") res = self._array.__invert__() return self.__class__._new(res) @@ -370,7 +421,7 @@ def __neg__(self: Array, /) -> Array: """ Performs the operation __neg__. """ - if self.dtype not in _numeric_dtypes: + if self.dtype not in numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __neg__") res = self._array.__neg__() return self.__class__._new(res) @@ -390,23 +441,23 @@ def __pos__(self: Array, /) -> Array: """ Performs the operation __pos__. """ - if self.dtype not in _numeric_dtypes: + if self.dtype not in numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __pos__") res = self._array.__pos__() return self.__class__._new(res) - def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: - """ - Performs the operation __pow__. - """ - from ._elementwise_functions import pow + # def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: + # """ + # Performs the operation __pow__. + # """ + # from ._elementwise_functions import pow - other = self._check_allowed_dtypes(other, "numeric", "__pow__") - if other is NotImplemented: - return other - # Note: NumPy's __pow__ does not follow type promotion rules for 0-d - # arrays, so we use pow() here instead. - return pow(self, other) + # other = self._check_allowed_dtypes(other, "numeric", "__pow__") + # if other is NotImplemented: + # return other + # # Note: NumPy's __pow__ does not follow type promotion rules for 0-d + # # arrays, so we use pow() here instead. + # return pow(self, other) def __rshift__(self: Array, other: Union[int, Array], /) -> Array: """ @@ -419,22 +470,22 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__rshift__(other._array) return self.__class__._new(res) - def __setitem__( - self, - key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], - value: Union[int, float, bool, Array], - /, - ) -> None: - """ - Performs the operation __setitem__. - """ - # Note: Only indices required by the spec are allowed. See the - # docstring of _validate_index - self._validate_index(key) - if isinstance(key, Array): - # Indexing self._array with array_api arrays can be erroneous - key = key._array - self._array.__setitem__(key, asarray(value)._array) + # def __setitem__( + # self, + # key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array], + # value: Union[int, float, bool, Array], + # /, + # ) -> None: + # """ + # Performs the operation __setitem__. + # """ + # # Note: Only indices required by the spec are allowed. See the + # # docstring of _validate_index + # self._validate_index(key) + # if isinstance(key, Array): + # # Indexing self._array with array_api arrays can be erroneous + # key = key._array + # self._array.__setitem__(key, asarray(value)._array) def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -762,15 +813,15 @@ def dtype(self) -> Dtype: """ return self._array.dtype - @property - def device(self) -> Device: - return "cpu" + # @property + # def device(self) -> Device: + # return "cpu" - @property - def mT(self) -> Array: - from .linalg import matrix_transpose + # @property + # def mT(self) -> Array: + # from .linalg import matrix_transpose - return matrix_transpose(self) + # return matrix_transpose(self) @property def ndim(self) -> int: @@ -811,6 +862,7 @@ def T(self) -> Array: # https://data-apis.org/array-api/latest/API_specification/array_object.html#t if self.ndim != 2: raise ValueError( - "x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions." + "x.T requires x to have 2 dimensions. " + "Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions." ) return self.__class__._new(self._array.T) diff --git a/arrayfire/array_api/constants.py b/arrayfire/array_api/constants.py index 46007d2..8f89829 100755 --- a/arrayfire/array_api/constants.py +++ b/arrayfire/array_api/constants.py @@ -8,17 +8,16 @@ from __future__ import annotations +from enum import Enum + __all__ = [ - "Array", "Device", "SupportsDLPack", "SupportsBufferProtocol", "PyCapsule", ] -from typing import Any, Literal, Protocol, TypeVar - -from .array_object import Array +from typing import Any, Iterator, Protocol, TypeVar _T_co = TypeVar("_T_co", covariant=True) @@ -31,7 +30,14 @@ def __len__(self, /) -> int: ... -Device = Literal["cpu"] # FIXME: add support for other devices +class Device(Enum): + cpu = "cpu" + gpu = "gpu" + + def __iter__(self) -> Iterator[Device]: + yield self + + SupportsBufferProtocol = Any PyCapsule = Any diff --git a/arrayfire/array_api/creation_function.py b/arrayfire/array_api/creation_function.py index 76eb1b0..95ea40a 100755 --- a/arrayfire/array_api/creation_function.py +++ b/arrayfire/array_api/creation_function.py @@ -1,11 +1,24 @@ +from __future__ import annotations + from typing import Optional, Union from arrayfire import Array as AFArray -from arrayfire.dtypes import Dtype, supported_dtypes -from arrayfire.library.device import supported_devices +from arrayfire.array_api.array_object import Array +from arrayfire.array_api.constants import Device, NestedSequence, SupportsBufferProtocol +from arrayfire.array_api.dtypes import all_dtypes +from arrayfire.dtypes import Dtype +from arrayfire.library.device import PointerSource + + +def _check_valid_dtype(dtype: Optional[Dtype]) -> None: + # Note: Only spelling dtypes as the dtype objects is supported. -from .array_object import Array -from .constants import Device, NestedSequence, SupportsBufferProtocol + # We use this instead of "dtype in _all_dtypes" because the dtype objects + # define equality with the sorts of things we want to disallow. + for d in (None,) + all_dtypes: + if dtype is d: + return + raise ValueError("dtype must be one of the supported dtypes") def asarray( @@ -16,8 +29,7 @@ def asarray( device: Optional[Device] = None, copy: Optional[bool] = None, ) -> Array: - if dtype not in supported_dtypes: - raise ValueError(f"Unsupported dtype {dtype!r}") + _check_valid_dtype(dtype) # if device not in supported_devices: # raise ValueError(f"Unsupported device {device!r}") @@ -25,5 +37,16 @@ def asarray( if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -(2**63)): raise OverflowError("Integer out of bounds for array dtypes") - array = AFArray(obj, dtype=dtype, device=device) - return Array._new(array) + if device == Device.cpu or device is None: + pointer_source = PointerSource.host + elif device == Device.gpu: + pointer_source = PointerSource.device + else: + raise ValueError(f"Unsupported device {device!r}") + + if isinstance(obj, int | float): + afarray = AFArray([obj], dtype=dtype, shape=(1,), pointer_source=pointer_source) + return Array._new(afarray) + + afarray = AFArray(obj, dtype=dtype, pointer_source=pointer_source) + return Array._new(afarray) diff --git a/arrayfire/array_api/data_type_functions.py b/arrayfire/array_api/data_type_functions.py new file mode 100755 index 0000000..f9cadd2 --- /dev/null +++ b/arrayfire/array_api/data_type_functions.py @@ -0,0 +1,35 @@ +from typing import Union + +from arrayfire import Array as AFArray +from arrayfire.array_api.array_object import Array +from arrayfire.array_api.dtypes import all_dtypes, promote_types +from arrayfire.dtypes import Dtype + + +def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: + """ + Array API compatible wrapper for :py:func:`np.result_type `. + + See its docstring for more information. + """ + # Note: we use a custom implementation that gives only the type promotions + # required by the spec rather than using np.result_type. NumPy implements + # too many extra type promotions like int64 + uint64 -> float64, and does + # value-based casting on scalar arrays. + A = [] + for a in arrays_and_dtypes: + if isinstance(a, Array): + a = a.dtype + elif isinstance(a, AFArray) or a not in all_dtypes: + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + A.append(a) + + if len(A) == 0: + raise ValueError("at least one array or dtype is required") + elif len(A) == 1: + return A[0] + else: + t = A[0] + for t2 in A[1:]: + t = promote_types(t, t2) + return t diff --git a/arrayfire/array_api/dtypes.py b/arrayfire/array_api/dtypes.py index 0260e00..c2a5833 100755 --- a/arrayfire/array_api/dtypes.py +++ b/arrayfire/array_api/dtypes.py @@ -1,31 +1,42 @@ -from arrayfire import ( - bool, complex64, complex128, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64) -from arrayfire.dtypes import Dtype +from __future__ import annotations -all_dtypes = ( - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, +__all__ = [ + "all_dtypes", + "boolean_dtypes", + "real_floating_dtypes", + "floating_dtypes", + "complex_floating_dtypes", + "integer_dtypes", + "signed_integer_dtypes", + "unsigned_integer_dtypes", + "integer_or_boolean_dtypes", + "real_numeric_dtypes", + "numeric_dtypes", + "dtype_categories", + # OG + "bool", + "complex64", + "complex128", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", +] + +from typing import TYPE_CHECKING + +from arrayfire import ( + bool, complex64, complex128, - bool, -) -boolean_dtypes = (bool,) -real_floating_dtypes = (float32, float64) -floating_dtypes = (float32, float64, complex64, complex128) -complex_floating_dtypes = (complex64, complex128) -integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) -signed_integer_dtypes = (int8, int16, int32, int64) -unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) -integer_orboolean_dtypes = ( - bool, + float32, + float64, int8, int16, int32, @@ -35,9 +46,11 @@ uint32, uint64, ) -real_numeric_dtypes = ( - float32, - float64, + +if TYPE_CHECKING: + from arrayfire.dtypes import Dtype + +all_dtypes = ( int8, int16, int32, @@ -46,28 +59,29 @@ uint16, uint32, uint64, -) -numeric_dtypes = ( float32, float64, complex64, complex128, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, + bool, ) +boolean_dtypes = (bool,) +real_floating_dtypes = (float32, float64) +floating_dtypes = (float32, float64, complex64, complex128) +complex_floating_dtypes = (complex64, complex128) +integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +signed_integer_dtypes = (int8, int16, int32, int64) +unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +integer_or_boolean_dtypes = boolean_dtypes + integer_dtypes +real_numeric_dtypes = real_floating_dtypes + integer_dtypes +numeric_dtypes = floating_dtypes + integer_dtypes dtype_categories = { "all": all_dtypes, "real numeric": real_numeric_dtypes, "numeric": numeric_dtypes, "integer": integer_dtypes, - "integer or boolean": integer_orboolean_dtypes, + "integer or boolean": integer_or_boolean_dtypes, "boolean": boolean_dtypes, "real floating-point": floating_dtypes, "complex floating-point": complex_floating_dtypes, diff --git a/arrayfire/array_api/tests/__init__.py b/arrayfire/array_api/tests/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/arrayfire/array_api/tests/fixme_test_array_object.py b/arrayfire/array_api/tests/fixme_test_array_object.py new file mode 100755 index 0000000..ed52637 --- /dev/null +++ b/arrayfire/array_api/tests/fixme_test_array_object.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterator + +import pytest + +from arrayfire import bool, int8, int16, int32, int64, uint64 +from arrayfire.array_api.creation_function import asarray +from arrayfire.array_api.data_type_functions import result_type +from arrayfire.array_api.dtypes import ( + boolean_dtypes, + complex_floating_dtypes, + floating_dtypes, + integer_dtypes, + integer_or_boolean_dtypes, + numeric_dtypes, + real_floating_dtypes, + real_numeric_dtypes, +) + +if TYPE_CHECKING: + from arrayfire.array_api.array_object import Array + + +def test_operators() -> None: + # For every operator, we test that it works for the required type + # combinations and raises TypeError otherwise + binary_op_dtypes = { + "__add__": "numeric", + "__and__": "integer_or_boolean", + "__eq__": "all", + "__floordiv__": "real numeric", + "__ge__": "real numeric", + "__gt__": "real numeric", + "__le__": "real numeric", + "__lshift__": "integer", + "__lt__": "real numeric", + "__mod__": "real numeric", + "__mul__": "numeric", + "__ne__": "all", + "__or__": "integer_or_boolean", + "__pow__": "numeric", + "__rshift__": "integer", + "__sub__": "numeric", + "__truediv__": "floating", + "__xor__": "integer_or_boolean", + } + + # Recompute each time because of in-place ops + def _array_vals() -> Iterator[Array]: + for d in integer_dtypes: + yield asarray(1, dtype=d) + for d in boolean_dtypes: + yield asarray(False, dtype=d) + for d in floating_dtypes: + yield asarray(1.0, dtype=d) + + BIG_INT = int(1e30) + for op, dtypes in binary_op_dtypes.items(): + ops = [op] + if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: + rop = "__r" + op[2:] + iop = "__i" + op[2:] + ops += [rop, iop] + for s in [1, 1.0, 1j, BIG_INT, False]: + for _op in ops: + for a in _array_vals(): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for real floating-point array dtypes + # - a Python int, float, or complex for complex floating-point array dtypes + + if ( + ( + dtypes == "all" + or dtypes == "numeric" + and a.dtype in numeric_dtypes + or dtypes == "real numeric" + and a.dtype in real_numeric_dtypes + or dtypes == "integer" + and a.dtype in integer_dtypes + or dtypes == "integer_or_boolean" + and a.dtype in integer_or_boolean_dtypes + or dtypes == "boolean" + and a.dtype in boolean_dtypes + or dtypes == "floating" + and a.dtype in floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and ( + a.dtype in boolean_dtypes + and type(s) == bool + or a.dtype in integer_dtypes + and type(s) == int + or a.dtype in real_floating_dtypes + and type(s) in [float, int] + or a.dtype in complex_floating_dtypes + and type(s) in [complex, float, int] + ) + ): + if a.dtype in integer_dtypes and s == BIG_INT: + pytest.raises(OverflowError, lambda: getattr(a, _op)(s)) + else: + # ignore warnings from pow(BIG_INT) + pytest.raises(RuntimeWarning, getattr(a, _op)(s)) + getattr(a, _op)(s) + else: + pytest.raises(TypeError, lambda: getattr(a, _op)(s)) + + # Test array op array. + for _op in ops: + for x in _array_vals(): + for y in _array_vals(): + # See the promotion table in NEP 47 or the array + # API spec page on type promotion. Mixed kind + # promotion is not defined. + if ( + x.dtype == uint64 + and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 + and x.dtype in [int8, int16, int32, int64] + or x.dtype in integer_dtypes + and y.dtype not in integer_dtypes + or y.dtype in integer_dtypes + and x.dtype not in integer_dtypes + or x.dtype in boolean_dtypes + and y.dtype not in boolean_dtypes + or y.dtype in boolean_dtypes + and x.dtype not in boolean_dtypes + or x.dtype in floating_dtypes + and y.dtype not in floating_dtypes + or y.dtype in floating_dtypes + and x.dtype not in floating_dtypes + ): + pytest.raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure in-place operators only promote to the same dtype as the left operand. + elif _op.startswith("__i") and result_type(x.dtype, y.dtype) != x.dtype: + pytest.raises(TypeError, lambda: getattr(x, _op)(y)) + # Ensure only those dtypes that are required for every operator are allowed. + elif ( + dtypes == "all" + and ( + x.dtype in boolean_dtypes + and y.dtype in boolean_dtypes + or x.dtype in numeric_dtypes + and y.dtype in numeric_dtypes + ) + or ( + dtypes == "real numeric" + and x.dtype in real_numeric_dtypes + and y.dtype in real_numeric_dtypes + ) + or (dtypes == "numeric" and x.dtype in numeric_dtypes and y.dtype in numeric_dtypes) + or dtypes == "integer" + and x.dtype in integer_dtypes + and y.dtype in integer_dtypes + or dtypes == "integer_or_boolean" + and ( + x.dtype in integer_dtypes + and y.dtype in integer_dtypes + or x.dtype in boolean_dtypes + and y.dtype in boolean_dtypes + ) + or dtypes == "boolean" + and x.dtype in boolean_dtypes + and y.dtype in boolean_dtypes + or dtypes == "floating" + and x.dtype in floating_dtypes + and y.dtype in floating_dtypes + ): + getattr(x, _op)(y) + else: + pytest.raises(TypeError, lambda: getattr(x, _op)(y)) diff --git a/arrayfire/array_api/tests/test_creation_functions.py b/arrayfire/array_api/tests/test_creation_functions.py new file mode 100755 index 0000000..6611882 --- /dev/null +++ b/arrayfire/array_api/tests/test_creation_functions.py @@ -0,0 +1,20 @@ +import pytest + +from arrayfire.array_api import asarray +from arrayfire.array_api.array_object import Array +from arrayfire.array_api.constants import Device +from arrayfire.dtypes import float16 + + +def test_asarray_errors() -> None: + # Test various protections against incorrect usage + pytest.raises(TypeError, lambda: Array([1])) + pytest.raises(TypeError, lambda: asarray(["a"])) + pytest.raises(ValueError, lambda: asarray([1.0], dtype=float16)) + pytest.raises(OverflowError, lambda: asarray(2**100)) + # pytest.raises(OverflowError, lambda: asarray([2**100])) # FIXME + asarray([1], device=Device.cpu) # Doesn't error + pytest.raises(ValueError, lambda: asarray([1], device="gpu")) # type: ignore[arg-type] + + pytest.raises(ValueError, lambda: asarray([1], dtype=int)) # type: ignore[arg-type] + pytest.raises(ValueError, lambda: asarray([1], dtype="i")) # type: ignore[arg-type] diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py old mode 100644 new mode 100755 index 7ac21de..1545d3b --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -36,13 +36,37 @@ "set_backend", ] -# fmt: off from .api import BackendPlatform, get_backend from .c_library.operators import ( - add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) + add, + bitand, + bitnot, + bitor, + bitshiftl, + bitshiftr, + bitxor, + div, + eq, + ge, + gt, + le, + lt, + mod, + mul, + neq, + pow, + sub, +) from .constants import ArrayBuffer from .helpers import ( - get_active_backend, get_array_backend_name, get_array_device_id, get_available_backends, get_backend_count, - get_backend_id, get_device_id, get_dtype_size, get_size_of, set_backend) - -# fmt: on + get_active_backend, + get_array_backend_name, + get_array_device_id, + get_available_backends, + get_backend_count, + get_backend_id, + get_device_id, + get_dtype_size, + get_size_of, + set_backend, +) diff --git a/arrayfire/backend/c_library/error_handler.py b/arrayfire/backend/c_library/error_handler.py index aef3865..cc9be66 100755 --- a/arrayfire/backend/c_library/error_handler.py +++ b/arrayfire/backend/c_library/error_handler.py @@ -14,5 +14,6 @@ def safe_call(c_err: int) -> None: return err_str = ctypes.c_char_p(0) - backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(c_dim_t(0))) + err_len = c_dim_t(0) + backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(err_len)) raise RuntimeError(to_str(err_str)) diff --git a/arrayfire/dtypes/__init__.py b/arrayfire/dtypes/__init__.py index a006696..b244b52 100644 --- a/arrayfire/dtypes/__init__.py +++ b/arrayfire/dtypes/__init__.py @@ -8,6 +8,7 @@ "uint16", "uint32", "uint64", + "float16", "float32", "float64", "complex64", @@ -22,7 +23,7 @@ CType = Type[ctypes._SimpleCData] -@dataclass +@dataclass(frozen=True) class Dtype: typecode: str c_type: CType @@ -39,10 +40,25 @@ class Dtype: uint16 = Dtype("H", ctypes.c_ushort, "unsigned short int", 11) uint32 = Dtype("I", ctypes.c_uint, "unsigned int", 6) uint64 = Dtype("L", ctypes.c_ulonglong, "unsigned long int", 9) +float16 = Dtype("e", ctypes.c_uint16, "half", 12) float32 = Dtype("f", ctypes.c_float, "float", 0) float64 = Dtype("d", ctypes.c_double, "double", 2) complex64 = Dtype("F", ctypes.c_float * 2, "float complext", 1) # type: ignore[arg-type] complex128 = Dtype("D", ctypes.c_double * 2, "double complext", 3) # type: ignore[arg-type] bool = Dtype("b", ctypes.c_bool, "bool", 4) -supported_dtypes = [int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, bool] +supported_dtypes = ( + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float16, + float32, + float64, + complex64, + complex128, + bool, +) diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index 5176659..630f229 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -912,6 +912,10 @@ def copy(self) -> Array: # BUG: this is not a deep copy self.arr = unsorted.copy_array(self.arr) return self + @classmethod + def from_afarray(cls, array: AFArrayType) -> None: + cls.arr = array + IndexKey = Union[int, slice, Tuple[Union[int, slice], ...], Array] diff --git a/arrayfire/library/device.py b/arrayfire/library/device.py index a51dc21..42f8edd 100644 --- a/arrayfire/library/device.py +++ b/arrayfire/library/device.py @@ -6,9 +6,8 @@ class PointerSource(enum.Enum): Source of the pointer. """ - # FIXME - device = 0 - host = 1 + device = 0 # gpu + host = 1 # cpu supported_devices = [] From d5c4cc0e246345fc08a5b470cf527c3c8e89e5a2 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 19:54:28 +0300 Subject: [PATCH 16/31] Change dtypes structure --- arrayfire/__init__.py | 48 ++++++++++++-- arrayfire/array_api/array_object.py | 11 +--- arrayfire/backend/c_library/constant_array.py | 3 +- arrayfire/backend/c_library/error_handler.py | 2 +- arrayfire/backend/c_library/unsorted.py | 3 +- arrayfire/{dtypes/helpers.py => dtypes.py} | 54 ++++++++++++++-- arrayfire/dtypes/__init__.py | 64 ------------------- arrayfire/dtypes/functions.py | 32 ---------- arrayfire/library/array_object.py | 5 +- arrayfire/platform.py | 10 +-- arrayfire/version.py | 1 + 11 files changed, 105 insertions(+), 128 deletions(-) rename arrayfire/{dtypes/helpers.py => dtypes.py} (60%) delete mode 100644 arrayfire/dtypes/__init__.py delete mode 100644 arrayfire/dtypes/functions.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index d0edcf7..8cebdf0 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -1,7 +1,16 @@ -__all__ = [ - # array objects - "Array", - # dtypes +# flake8: noqa +from .version import ARRAYFIRE_VERSION, VERSION + +__all__ = ["__version__"] +__version__ = VERSION + +__all__ += ["__arrayfire_version__"] +__arrayfire_version__ = ARRAYFIRE_VERSION + +__all__ += ["Array"] +from .library.array_object import Array + +__all__ += [ "int8", "int16", "int32", @@ -10,6 +19,7 @@ "uint16", "uint32", "uint64", + "float16", "float32", "float64", "complex64", @@ -20,6 +30,7 @@ bool, complex64, complex128, + float16, float32, float64, int8, @@ -31,4 +42,31 @@ uint32, uint64, ) -from .library.array_object import Array + +__all__ += [ + "get_backend", + "get_active_backend", # DeprecationWarning + "get_array_backend_name", + "get_array_device_id", + "get_available_backends", # DeprecationWarning + "get_backend_count", + "get_backend_id", # DeprecationWarning + "get_device_id", # DeprecationWarning + "get_dtype_size", + "get_size_of", # DeprecationWarning + "set_backend", +] + +from .backend import ( + get_backend, + set_backend, + get_available_backends, + get_backend_count, + get_backend_id, + get_active_backend, + get_array_backend_name, + get_array_device_id, + get_device_id, + get_dtype_size, + get_size_of, +) diff --git a/arrayfire/array_api/array_object.py b/arrayfire/array_api/array_object.py index 8574d05..1ec531d 100755 --- a/arrayfire/array_api/array_object.py +++ b/arrayfire/array_api/array_object.py @@ -6,15 +6,8 @@ from arrayfire import Array as AFArray from arrayfire.array_api.constants import Device, NestedSequence, PyCapsule, SupportsBufferProtocol from arrayfire.array_api.dtypes import ( - all_dtypes, - boolean_dtypes, - complex_floating_dtypes, - dtype_categories, - floating_dtypes, - integer_or_boolean_dtypes, - numeric_dtypes, - promote_types, -) + all_dtypes, boolean_dtypes, complex_floating_dtypes, dtype_categories, floating_dtypes, integer_or_boolean_dtypes, + numeric_dtypes, promote_types) from arrayfire.dtypes import Dtype diff --git a/arrayfire/backend/c_library/constant_array.py b/arrayfire/backend/c_library/constant_array.py index b3a6312..5454cad 100755 --- a/arrayfire/backend/c_library/constant_array.py +++ b/arrayfire/backend/c_library/constant_array.py @@ -4,8 +4,7 @@ from typing import TYPE_CHECKING, Tuple, Union from arrayfire.backend.api import backend_api -from arrayfire.dtypes import Dtype, int64, uint64 -from arrayfire.dtypes.helpers import CShape, implicit_dtype +from arrayfire.dtypes import CShape, Dtype, implicit_dtype, int64, uint64 from .error_handler import safe_call diff --git a/arrayfire/backend/c_library/error_handler.py b/arrayfire/backend/c_library/error_handler.py index cc9be66..35bd945 100755 --- a/arrayfire/backend/c_library/error_handler.py +++ b/arrayfire/backend/c_library/error_handler.py @@ -2,7 +2,7 @@ from enum import Enum from arrayfire.backend.api import backend_api -from arrayfire.dtypes.helpers import c_dim_t, to_str +from arrayfire.dtypes import c_dim_t, to_str class _ErrorCodes(Enum): diff --git a/arrayfire/backend/c_library/unsorted.py b/arrayfire/backend/c_library/unsorted.py index aa970d7..d38d4fa 100755 --- a/arrayfire/backend/c_library/unsorted.py +++ b/arrayfire/backend/c_library/unsorted.py @@ -5,8 +5,7 @@ from arrayfire.backend.api import backend_api from arrayfire.backend.constants import ArrayBuffer -from arrayfire.dtypes import CType, Dtype -from arrayfire.dtypes.helpers import CShape, c_dim_t, to_str +from arrayfire.dtypes import CShape, CType, Dtype, c_dim_t, to_str from arrayfire.library.device import PointerSource from .error_handler import safe_call diff --git a/arrayfire/dtypes/helpers.py b/arrayfire/dtypes.py similarity index 60% rename from arrayfire/dtypes/helpers.py rename to arrayfire/dtypes.py index 30b8d4c..3417034 100644 --- a/arrayfire/dtypes/helpers.py +++ b/arrayfire/dtypes.py @@ -1,13 +1,55 @@ from __future__ import annotations import ctypes -from typing import Tuple, Union +from dataclasses import dataclass +from typing import Tuple, Type, Union from arrayfire.platform import is_arch_x86 -from . import Dtype -from . import bool as af_bool -from . import complex64, complex128, float32, float64, int64, supported_dtypes +CType = Type[ctypes._SimpleCData] +python_bool = bool + + +@dataclass(frozen=True) +class Dtype: + typecode: str + c_type: CType + typename: str + c_api_value: int # Internal use only + + +# Specification required +int8 = Dtype("i8", ctypes.c_char, "int8", 4) # HACK int8 - Not Supported, b8? +int16 = Dtype("h", ctypes.c_short, "short int", 10) +int32 = Dtype("i", ctypes.c_int, "int", 5) +int64 = Dtype("l", ctypes.c_longlong, "long int", 8) +uint8 = Dtype("B", ctypes.c_ubyte, "unsigned_char", 7) +uint16 = Dtype("H", ctypes.c_ushort, "unsigned short int", 11) +uint32 = Dtype("I", ctypes.c_uint, "unsigned int", 6) +uint64 = Dtype("L", ctypes.c_ulonglong, "unsigned long int", 9) +float16 = Dtype("e", ctypes.c_uint16, "half", 12) +float32 = Dtype("f", ctypes.c_float, "float", 0) +float64 = Dtype("d", ctypes.c_double, "double", 2) +complex64 = Dtype("F", ctypes.c_float * 2, "float complext", 1) # type: ignore[arg-type] +complex128 = Dtype("D", ctypes.c_double * 2, "double complext", 3) # type: ignore[arg-type] +bool = Dtype("b", ctypes.c_bool, "bool", 4) + +supported_dtypes = ( + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float16, + float32, + float64, + complex64, + complex128, + bool, +) + c_dim_t = ctypes.c_int if is_arch_x86() else ctypes.c_longlong ShapeType = Tuple[int, ...] @@ -38,8 +80,8 @@ def to_str(c_str: ctypes.c_char_p) -> str: def implicit_dtype(number: Union[int, float], array_dtype: Dtype) -> Dtype: - if isinstance(number, bool): - number_dtype = af_bool + if isinstance(number, python_bool): + number_dtype = bool if isinstance(number, int): number_dtype = int64 elif isinstance(number, float): diff --git a/arrayfire/dtypes/__init__.py b/arrayfire/dtypes/__init__.py deleted file mode 100644 index b244b52..0000000 --- a/arrayfire/dtypes/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -__all__ = [ - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", - "float16", - "float32", - "float64", - "complex64", - "complex128", - "bool", -] - -import ctypes -from dataclasses import dataclass -from typing import Type - -CType = Type[ctypes._SimpleCData] - - -@dataclass(frozen=True) -class Dtype: - typecode: str - c_type: CType - typename: str - c_api_value: int # Internal use only - - -# Specification required -int8 = Dtype("i8", ctypes.c_char, "int8", 4) # HACK int8 - Not Supported, b8? -int16 = Dtype("h", ctypes.c_short, "short int", 10) -int32 = Dtype("i", ctypes.c_int, "int", 5) -int64 = Dtype("l", ctypes.c_longlong, "long int", 8) -uint8 = Dtype("B", ctypes.c_ubyte, "unsigned_char", 7) -uint16 = Dtype("H", ctypes.c_ushort, "unsigned short int", 11) -uint32 = Dtype("I", ctypes.c_uint, "unsigned int", 6) -uint64 = Dtype("L", ctypes.c_ulonglong, "unsigned long int", 9) -float16 = Dtype("e", ctypes.c_uint16, "half", 12) -float32 = Dtype("f", ctypes.c_float, "float", 0) -float64 = Dtype("d", ctypes.c_double, "double", 2) -complex64 = Dtype("F", ctypes.c_float * 2, "float complext", 1) # type: ignore[arg-type] -complex128 = Dtype("D", ctypes.c_double * 2, "double complext", 3) # type: ignore[arg-type] -bool = Dtype("b", ctypes.c_bool, "bool", 4) - -supported_dtypes = ( - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float16, - float32, - float64, - complex64, - complex128, - bool, -) diff --git a/arrayfire/dtypes/functions.py b/arrayfire/dtypes/functions.py deleted file mode 100644 index a08619f..0000000 --- a/arrayfire/dtypes/functions.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Tuple, Union - -from ..library.array_object import Array -from . import Dtype - -# TODO implement functions - - -def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: - return NotImplemented - - -def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: - return NotImplemented - - -def finfo(type: Union[Dtype, Array], /): # type: ignore[no-untyped-def] - # NOTE expected return type -> finfo_object - return NotImplemented - - -def iinfo(type: Union[Dtype, Array], /): # type: ignore[no-untyped-def] - # NOTE expected return type -> iinfo_object - return NotImplemented - - -def isdtype(dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]) -> bool: - return NotImplemented - - -def result_type(*arrays_and_dtypes: Union[Dtype, Array]) -> Dtype: - return NotImplemented diff --git a/arrayfire/library/array_object.py b/arrayfire/library/array_object.py index 630f229..04643b3 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/library/array_object.py @@ -11,10 +11,11 @@ from ..backend.c_library.constant_array import create_constant_array from ..backend.c_library.indexing import CIndexStructure, IndexStructure from ..backend.c_library.reduction_operations import count_all -from ..dtypes import CType +from ..dtypes import CType, Dtype from ..dtypes import bool as af_bool +from ..dtypes import c_api_value_to_dtype from ..dtypes import float32 as af_float32 -from ..dtypes.helpers import Dtype, c_api_value_to_dtype, str_to_dtype +from ..dtypes import str_to_dtype from .device import PointerSource # TODO use int | float in operators -> remove bool | complex support diff --git a/arrayfire/platform.py b/arrayfire/platform.py index 734f453..df0144e 100755 --- a/arrayfire/platform.py +++ b/arrayfire/platform.py @@ -14,7 +14,7 @@ def is_arch_x86() -> bool: return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") -class SupportedPlatforms(Enum): +class _SupportedPlatforms(Enum): windows = "Windows" darwin = "Darwin" # OSX linux = "Linux" @@ -49,8 +49,8 @@ def get_platform_config() -> PlatformConfig: except KeyError: cuda_path = None - if platform_name == SupportedPlatforms.windows.value or SupportedPlatforms.is_cygwin(platform_name): - if platform_name == SupportedPlatforms.windows.value: + if platform_name == _SupportedPlatforms.windows.value or _SupportedPlatforms.is_cygwin(platform_name): + if platform_name == _SupportedPlatforms.windows.value: # HACK Supressing crashes caused by missing dlls # http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup # https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx @@ -64,7 +64,7 @@ def get_platform_config() -> PlatformConfig: return PlatformConfig("", ".dll", af_path, cuda_found) - if platform_name == SupportedPlatforms.darwin.value: + if platform_name == _SupportedPlatforms.darwin.value: default_cuda_path = Path("/usr/local/cuda/") if not af_path: @@ -75,7 +75,7 @@ def get_platform_config() -> PlatformConfig: return PlatformConfig("lib", f".{ARRAYFIRE_VER_MAJOR}.dylib", af_path, cuda_found) - if platform_name == SupportedPlatforms.linux.value: + if platform_name == _SupportedPlatforms.linux.value: default_cuda_path = Path("/usr/local/cuda/") if not af_path: diff --git a/arrayfire/version.py b/arrayfire/version.py index 148cbb7..85630b2 100644 --- a/arrayfire/version.py +++ b/arrayfire/version.py @@ -15,3 +15,4 @@ FORGE_VER_MAJOR = "1" ARRAYFIRE_VER_MAJOR = "3" ARRAYFIRE_VER_MINOR = "8" +ARRAYFIRE_VERSION = "{0}.{1}".format(ARRAYFIRE_VER_MAJOR, ARRAYFIRE_VER_MINOR) From 6621c8f30002a8a190397199322ad0b57a4dc5df Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Fri, 11 Aug 2023 20:16:18 +0300 Subject: [PATCH 17/31] Change project structure --- arrayfire/__init__.py | 14 +- arrayfire/array_api/array_object.py | 11 +- arrayfire/{library => }/array_object.py | 152 +++++++++--------- arrayfire/backend/__init__.py | 72 --------- arrayfire/backend/c_library/__init__.py | 111 +++++++++++++ arrayfire/backend/c_library/constant_array.py | 2 +- arrayfire/backend/c_library/operators.py | 2 +- .../backend/c_library/reduction_operations.py | 2 +- arrayfire/backend/c_library/unsorted.py | 2 +- arrayfire/library/__init__.py | 3 - arrayfire/library/operators.py | 8 +- arrayfire/library/utils.py | 2 +- tests/array_object/test_initialization.py | 2 +- tests/array_object/test_methods.py | 2 +- tests/array_object/test_operators.py | 2 +- tests/test_operators.py | 2 +- 16 files changed, 214 insertions(+), 175 deletions(-) rename arrayfire/{library => }/array_object.py (83%) mode change 100644 => 100755 tests/array_object/test_initialization.py mode change 100644 => 100755 tests/array_object/test_operators.py mode change 100644 => 100755 tests/test_operators.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 8cebdf0..686fa70 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -8,7 +8,7 @@ __arrayfire_version__ = ARRAYFIRE_VERSION __all__ += ["Array"] -from .library.array_object import Array +from .array_object import Array __all__ += [ "int8", @@ -57,16 +57,16 @@ "set_backend", ] -from .backend import ( - get_backend, - set_backend, - get_available_backends, - get_backend_count, - get_backend_id, +from .backend.api import get_backend +from .backend.helpers import ( get_active_backend, get_array_backend_name, get_array_device_id, + get_available_backends, + get_backend_count, + get_backend_id, get_device_id, get_dtype_size, get_size_of, + set_backend, ) diff --git a/arrayfire/array_api/array_object.py b/arrayfire/array_api/array_object.py index 1ec531d..8574d05 100755 --- a/arrayfire/array_api/array_object.py +++ b/arrayfire/array_api/array_object.py @@ -6,8 +6,15 @@ from arrayfire import Array as AFArray from arrayfire.array_api.constants import Device, NestedSequence, PyCapsule, SupportsBufferProtocol from arrayfire.array_api.dtypes import ( - all_dtypes, boolean_dtypes, complex_floating_dtypes, dtype_categories, floating_dtypes, integer_or_boolean_dtypes, - numeric_dtypes, promote_types) + all_dtypes, + boolean_dtypes, + complex_floating_dtypes, + dtype_categories, + floating_dtypes, + integer_or_boolean_dtypes, + numeric_dtypes, + promote_types, +) from arrayfire.dtypes import Dtype diff --git a/arrayfire/library/array_object.py b/arrayfire/array_object.py similarity index 83% rename from arrayfire/library/array_object.py rename to arrayfire/array_object.py index 04643b3..171511d 100755 --- a/arrayfire/library/array_object.py +++ b/arrayfire/array_object.py @@ -5,18 +5,11 @@ import enum from typing import Any, List, Optional, Tuple, Union -from .. import backend -from ..backend import ArrayBuffer -from ..backend.c_library import unsorted -from ..backend.c_library.constant_array import create_constant_array -from ..backend.c_library.indexing import CIndexStructure, IndexStructure -from ..backend.c_library.reduction_operations import count_all -from ..dtypes import CType, Dtype -from ..dtypes import bool as af_bool -from ..dtypes import c_api_value_to_dtype -from ..dtypes import float32 as af_float32 -from ..dtypes import str_to_dtype -from .device import PointerSource +from .backend import c_library as wrapper +from .backend.c_library.indexing import CIndexStructure, IndexStructure +from .backend.constants import ArrayBuffer +from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype +from .library.device import PointerSource # TODO use int | float in operators -> remove bool | complex support @@ -40,18 +33,18 @@ def __init__( if dtype is None: _no_initial_dtype = True - dtype = af_float32 + dtype = float32 if obj is None: if not shape: # shape is None or empty tuple - self.arr = unsorted.create_handle((), dtype) + self.arr = wrapper.create_handle((), dtype) return - self.arr = unsorted.create_handle(shape, dtype) + self.arr = wrapper.create_handle(shape, dtype) return if isinstance(obj, Array): - self.arr = unsorted.retain_array(obj.arr) + self.arr = wrapper.retain_array(obj.arr) return if isinstance(obj, py_array.array): @@ -88,13 +81,13 @@ def __init__( if not (offset or strides): if pointer_source == PointerSource.host: - self.arr = unsorted.create_array(shape, dtype, _array_buffer) + self.arr = wrapper.create_array(shape, dtype, _array_buffer) return - self.arr = unsorted.device_array(shape, dtype, _array_buffer) + self.arr = wrapper.device_array(shape, dtype, _array_buffer) return - self.arr = unsorted.create_strided_array( + self.arr = wrapper.create_strided_array( shape, dtype, _array_buffer, offset, strides, pointer_source # type: ignore[arg-type] ) @@ -133,7 +126,7 @@ def __neg__(self) -> Array: determined by Type Promotion Rules. """ - return _process_c_function(0, self, backend.sub) + return _process_c_function(0, self, wrapper.sub) def __add__(self, other: Union[int, float, Array], /) -> Array: """ @@ -152,7 +145,7 @@ def __add__(self, other: Union[int, float, Array], /) -> Array: An array containing the element-wise sums. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.add) + return _process_c_function(self, other, wrapper.add) def __sub__(self, other: Union[int, float, Array], /) -> Array: """ @@ -174,7 +167,7 @@ def __sub__(self, other: Union[int, float, Array], /) -> Array: An array containing the element-wise differences. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.sub) + return _process_c_function(self, other, wrapper.sub) def __mul__(self, other: Union[int, float, Array], /) -> Array: """ @@ -193,7 +186,7 @@ def __mul__(self, other: Union[int, float, Array], /) -> Array: An array containing the element-wise products. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.mul) + return _process_c_function(self, other, wrapper.mul) def __truediv__(self, other: Union[int, float, Array], /) -> Array: """ @@ -220,7 +213,7 @@ def __truediv__(self, other: Union[int, float, Array], /) -> Array: Specification-compliant libraries may choose to raise an error or return an array containing the element-wise results. If an array is returned, the array must have a real-valued floating-point data type. """ - return _process_c_function(self, other, backend.div) + return _process_c_function(self, other, wrapper.div) def __floordiv__(self, other: Union[int, float, Array], /) -> Array: # TODO @@ -250,7 +243,7 @@ def __mod__(self, other: Union[int, float, Array], /) -> Array: - For input arrays which promote to an integer data type, the result of division by zero is unspecified and thus implementation-defined. """ - return _process_c_function(self, other, backend.mod) + return _process_c_function(self, other, wrapper.mod) def __pow__(self, other: Union[int, float, Array], /) -> Array: """ @@ -272,7 +265,7 @@ def __pow__(self, other: Union[int, float, Array], /) -> Array: An array containing the element-wise results. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.pow) + return _process_c_function(self, other, wrapper.pow) # Array Operators @@ -298,7 +291,7 @@ def __invert__(self) -> Array: """ # FIXME out = Array() - out.arr = backend.bitnot(self.arr) + out.arr = wrapper.bitnot(self.arr) return out def __and__(self, other: Union[int, bool, Array], /) -> Array: @@ -319,7 +312,7 @@ def __and__(self, other: Union[int, bool, Array], /) -> Array: An array containing the element-wise results. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.bitand) + return _process_c_function(self, other, wrapper.bitand) def __or__(self, other: Union[int, bool, Array], /) -> Array: """ @@ -339,7 +332,7 @@ def __or__(self, other: Union[int, bool, Array], /) -> Array: An array containing the element-wise results. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.bitor) + return _process_c_function(self, other, wrapper.bitor) def __xor__(self, other: Union[int, bool, Array], /) -> Array: """ @@ -359,7 +352,7 @@ def __xor__(self, other: Union[int, bool, Array], /) -> Array: An array containing the element-wise results. The returned array must have a data type determined by Type Promotion Rules. """ - return _process_c_function(self, other, backend.bitxor) + return _process_c_function(self, other, wrapper.bitxor) def __lshift__(self, other: Union[int, Array], /) -> Array: """ @@ -379,7 +372,7 @@ def __lshift__(self, other: Union[int, Array], /) -> Array: out : Array An array containing the element-wise results. The returned array must have the same data type as self. """ - return _process_c_function(self, other, backend.bitshiftl) + return _process_c_function(self, other, wrapper.bitshiftl) def __rshift__(self, other: Union[int, Array], /) -> Array: """ @@ -399,7 +392,7 @@ def __rshift__(self, other: Union[int, Array], /) -> Array: out : Array An array containing the element-wise results. The returned array must have the same data type as self. """ - return _process_c_function(self, other, backend.bitshiftr) + return _process_c_function(self, other, wrapper.bitshiftr) # Comparison Operators @@ -420,7 +413,7 @@ def __lt__(self, other: Union[int, float, Array], /) -> Array: out : Array An array containing the element-wise results. The returned array must have a data type of bool. """ - return _process_c_function(self, other, backend.lt) + return _process_c_function(self, other, wrapper.lt) def __le__(self, other: Union[int, float, Array], /) -> Array: """ @@ -439,7 +432,7 @@ def __le__(self, other: Union[int, float, Array], /) -> Array: out : Array An array containing the element-wise results. The returned array must have a data type of bool. """ - return _process_c_function(self, other, backend.le) + return _process_c_function(self, other, wrapper.le) def __gt__(self, other: Union[int, float, Array], /) -> Array: """ @@ -458,7 +451,7 @@ def __gt__(self, other: Union[int, float, Array], /) -> Array: out : Array An array containing the element-wise results. The returned array must have a data type of bool. """ - return _process_c_function(self, other, backend.gt) + return _process_c_function(self, other, wrapper.gt) def __ge__(self, other: Union[int, float, Array], /) -> Array: """ @@ -477,7 +470,7 @@ def __ge__(self, other: Union[int, float, Array], /) -> Array: out : Array An array containing the element-wise results. The returned array must have a data type of bool. """ - return _process_c_function(self, other, backend.ge) + return _process_c_function(self, other, wrapper.ge) def __eq__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ignore[override] """ @@ -496,7 +489,7 @@ def __eq__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ig out : Array An array containing the element-wise results. The returned array must have a data type of bool. """ - return _process_c_function(self, other, backend.eq) + return _process_c_function(self, other, wrapper.eq) def __ne__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ignore[override] """ @@ -515,7 +508,7 @@ def __ne__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ig out : Array An array containing the element-wise results. The returned array must have a data type of bool. """ - return _process_c_function(self, other, backend.neq) + return _process_c_function(self, other, wrapper.neq) # Reflected Arithmetic Operators @@ -523,25 +516,25 @@ def __radd__(self, other: Array, /) -> Array: """ Return other + self. """ - return _process_c_function(other, self, backend.add) + return _process_c_function(other, self, wrapper.add) def __rsub__(self, other: Array, /) -> Array: """ Return other - self. """ - return _process_c_function(other, self, backend.sub) + return _process_c_function(other, self, wrapper.sub) def __rmul__(self, other: Array, /) -> Array: """ Return other * self. """ - return _process_c_function(other, self, backend.mul) + return _process_c_function(other, self, wrapper.mul) def __rtruediv__(self, other: Array, /) -> Array: """ Return other / self. """ - return _process_c_function(other, self, backend.div) + return _process_c_function(other, self, wrapper.div) def __rfloordiv__(self, other: Array, /) -> Array: # TODO @@ -551,13 +544,13 @@ def __rmod__(self, other: Array, /) -> Array: """ Return other % self. """ - return _process_c_function(other, self, backend.mod) + return _process_c_function(other, self, wrapper.mod) def __rpow__(self, other: Array, /) -> Array: """ Return other ** self. """ - return _process_c_function(other, self, backend.pow) + return _process_c_function(other, self, wrapper.pow) # Reflected Array Operators @@ -571,31 +564,31 @@ def __rand__(self, other: Array, /) -> Array: """ Return other & self. """ - return _process_c_function(other, self, backend.bitand) + return _process_c_function(other, self, wrapper.bitand) def __ror__(self, other: Array, /) -> Array: """ Return other | self. """ - return _process_c_function(other, self, backend.bitor) + return _process_c_function(other, self, wrapper.bitor) def __rxor__(self, other: Array, /) -> Array: """ Return other ^ self. """ - return _process_c_function(other, self, backend.bitxor) + return _process_c_function(other, self, wrapper.bitxor) def __rlshift__(self, other: Array, /) -> Array: """ Return other << self. """ - return _process_c_function(other, self, backend.bitshiftl) + return _process_c_function(other, self, wrapper.bitshiftl) def __rrshift__(self, other: Array, /) -> Array: """ Return other >> self. """ - return _process_c_function(other, self, backend.bitshiftr) + return _process_c_function(other, self, wrapper.bitshiftr) # In-place Arithmetic Operators @@ -604,25 +597,25 @@ def __iadd__(self, other: Union[int, float, Array], /) -> Array: """ Return self += other. """ - return _process_c_function(self, other, backend.add) + return _process_c_function(self, other, wrapper.add) def __isub__(self, other: Union[int, float, Array], /) -> Array: """ Return self -= other. """ - return _process_c_function(self, other, backend.sub) + return _process_c_function(self, other, wrapper.sub) def __imul__(self, other: Union[int, float, Array], /) -> Array: """ Return self *= other. """ - return _process_c_function(self, other, backend.mul) + return _process_c_function(self, other, wrapper.mul) def __itruediv__(self, other: Union[int, float, Array], /) -> Array: """ Return self /= other. """ - return _process_c_function(self, other, backend.div) + return _process_c_function(self, other, wrapper.div) def __ifloordiv__(self, other: Union[int, float, Array], /) -> Array: # TODO @@ -632,13 +625,13 @@ def __imod__(self, other: Union[int, float, Array], /) -> Array: """ Return self %= other. """ - return _process_c_function(self, other, backend.mod) + return _process_c_function(self, other, wrapper.mod) def __ipow__(self, other: Union[int, float, Array], /) -> Array: """ Return self **= other. """ - return _process_c_function(self, other, backend.pow) + return _process_c_function(self, other, wrapper.pow) # In-place Array Operators @@ -652,31 +645,31 @@ def __iand__(self, other: Union[int, bool, Array], /) -> Array: """ Return self &= other. """ - return _process_c_function(self, other, backend.bitand) + return _process_c_function(self, other, wrapper.bitand) def __ior__(self, other: Union[int, bool, Array], /) -> Array: """ Return self |= other. """ - return _process_c_function(self, other, backend.bitor) + return _process_c_function(self, other, wrapper.bitor) def __ixor__(self, other: Union[int, bool, Array], /) -> Array: """ Return self ^= other. """ - return _process_c_function(self, other, backend.bitxor) + return _process_c_function(self, other, wrapper.bitxor) def __ilshift__(self, other: Union[int, Array], /) -> Array: """ Return self <<= other. """ - return _process_c_function(self, other, backend.bitshiftl) + return _process_c_function(self, other, wrapper.bitshiftl) def __irshift__(self, other: Union[int, Array], /) -> Array: """ Return self >>= other. """ - return _process_c_function(self, other, backend.bitshiftr) + return _process_c_function(self, other, wrapper.bitshiftr) # Methods @@ -724,19 +717,22 @@ def __getitem__(self, key: IndexKey, /) -> Array: out : Array An array containing the accessed value(s). The returned array must have the same data type as self. """ + + from .dtypes import bool + # TODO # API Specification - key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array]. # consider using af.span to replace ellipsis during refactoring out = Array() ndims = self.ndim - if isinstance(key, Array) and key == af_bool.c_api_value: + if isinstance(key, Array) and key == bool.c_api_value: ndims = 1 - if count_all(key.arr) == 0: # HACK was count() method before + if wrapper.count_all(key.arr) == 0: # HACK was count() method before return out # HACK known issue - out.arr = unsorted.index_gen(self.arr, ndims, key, _get_indices(key)) # type: ignore[arg-type] + out.arr = wrapper.index_gen(self.arr, ndims, key, _get_indices(key)) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -758,12 +754,12 @@ def __str__(self) -> str: # TODO change the look of array str. E.g., like np.array # if not _in_display_dims_limit(self.shape): # return _metadata_string(self.dtype, self.shape) - return _metadata_string(self.dtype) + unsorted.array_as_str(self.arr) + return _metadata_string(self.dtype) + wrapper.array_as_str(self.arr) def __repr__(self) -> str: # return _metadata_string(self.dtype, self.shape) # TODO change the look of array representation. E.g., like np.array - return unsorted.array_as_str(self.arr) + return wrapper.array_as_str(self.arr) def to_device(self, device: Any, /, *, stream: Union[int, Any] = None) -> Array: # TODO implementation and change device type from Any to Device @@ -781,7 +777,7 @@ def dtype(self) -> Dtype: out : Dtype Array data type. """ - return c_api_value_to_dtype(unsorted.get_ctype(self.arr)) + return c_api_value_to_dtype(wrapper.get_ctype(self.arr)) @property def device(self) -> Any: @@ -814,7 +810,7 @@ def T(self) -> Array: # TODO add check if out.dtype == self.dtype out = Array() - out.arr = unsorted.transpose(self.arr, False) + out.arr = wrapper.transpose(self.arr, False) return out @property @@ -832,7 +828,7 @@ def size(self) -> int: - This must equal the product of the array's dimensions. """ # NOTE previously - elements() - return unsorted.get_elements(self.arr) + return wrapper.get_elements(self.arr) @property def ndim(self) -> int: @@ -842,7 +838,7 @@ def ndim(self) -> int: out : int Number of array dimensions (axes). """ - return unsorted.get_numdims(self.arr) + return wrapper.get_numdims(self.arr) @property def shape(self) -> Tuple[int, ...]: @@ -855,7 +851,7 @@ def shape(self) -> Tuple[int, ...]: Array dimensions. """ # NOTE skipping passing any None values - return unsorted.get_dims(self.arr)[: self.ndim] + return wrapper.get_dims(self.arr)[: self.ndim] def scalar(self) -> Union[None, int, float, bool, complex]: """ @@ -865,20 +861,20 @@ def scalar(self) -> Union[None, int, float, bool, complex]: if self.is_empty(): return None - return unsorted.get_scalar(self.arr, self.dtype) + return wrapper.get_scalar(self.arr, self.dtype) def is_empty(self) -> bool: """ Check if the array is empty i.e. it has no elements. """ - return unsorted.is_empty(self.arr) + return wrapper.is_empty(self.arr) def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, complex]]: if self.is_empty(): return [] array = _reorder(self) if row_major else self - ctypes_array = unsorted.get_data_ptr(array.arr, array.size, array.dtype) + ctypes_array = wrapper.get_data_ptr(array.arr, array.size, array.dtype) if array.ndim == 1: return ctypes_array[:] @@ -899,7 +895,7 @@ def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: raise RuntimeError("Can not convert an empty array to ctype.") array = _reorder(self) if row_major else self - return unsorted.get_data_ptr(array.arr, array.size, array.dtype) + return wrapper.get_data_ptr(array.arr, array.size, array.dtype) def copy(self) -> Array: # BUG: this is not a deep copy """ @@ -910,7 +906,7 @@ def copy(self) -> Array: # BUG: this is not a deep copy out: af.Array() An identical copy of self. """ - self.arr = unsorted.copy_array(self.arr) + self.arr = wrapper.copy_array(self.arr) return self @classmethod @@ -928,7 +924,7 @@ def _reorder(array: Array) -> Array: if array.ndim == 1: return array - return Array(unsorted.reorder(array.arr, array.ndim)) + return Array(wrapper.reorder(array.arr, array.ndim)) def _metadata_string(dtype: Dtype, dims: Optional[Tuple[int, ...]] = None) -> str: @@ -944,10 +940,10 @@ def _process_c_function(lhs: Union[int, float, Array], rhs: Union[int, float, Ar elif isinstance(lhs, Array) and isinstance(rhs, (int, float)): lhs_array = lhs.arr - rhs_array = create_constant_array(rhs, lhs.shape, lhs.dtype) + rhs_array = wrapper.create_constant_array(rhs, lhs.shape, lhs.dtype) elif isinstance(lhs, (int, float)) and isinstance(rhs, Array): - lhs_array = create_constant_array(lhs, rhs.shape, rhs.dtype) + lhs_array = wrapper.create_constant_array(lhs, rhs.shape, rhs.dtype) rhs_array = rhs.arr else: diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 1545d3b..e69de29 100755 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -1,72 +0,0 @@ -__all__ = [ - # Backend Constants - "ArrayBuffer", - # Operators - "add", - "sub", - "mul", - "div", - "mod", - "pow", - "bitnot", - "bitand", - "bitor", - "bitxor", - "bitshiftl", - "bitshiftr", - "lt", - "le", - "gt", - "ge", - "eq", - "neq", - # Backend API - "BackendPlatform", - "get_backend", - # Backend Helpers - "get_active_backend", # DeprecationWarning - "get_array_backend_name", - "get_array_device_id", - "get_available_backends", # DeprecationWarning - "get_backend_count", - "get_backend_id", # DeprecationWarning - "get_device_id", # DeprecationWarning - "get_dtype_size", - "get_size_of", # DeprecationWarning - "set_backend", -] - -from .api import BackendPlatform, get_backend -from .c_library.operators import ( - add, - bitand, - bitnot, - bitor, - bitshiftl, - bitshiftr, - bitxor, - div, - eq, - ge, - gt, - le, - lt, - mod, - mul, - neq, - pow, - sub, -) -from .constants import ArrayBuffer -from .helpers import ( - get_active_backend, - get_array_backend_name, - get_array_device_id, - get_available_backends, - get_backend_count, - get_backend_id, - get_device_id, - get_dtype_size, - get_size_of, - set_backend, -) diff --git a/arrayfire/backend/c_library/__init__.py b/arrayfire/backend/c_library/__init__.py index e69de29..a1f4fdf 100755 --- a/arrayfire/backend/c_library/__init__.py +++ b/arrayfire/backend/c_library/__init__.py @@ -0,0 +1,111 @@ +# flake8: noqa + +__all__ = [ + "add", + "sub", + "mul", + "div", + "mod", + "pow", + "bitnot", + "bitand", + "bitor", + "bitxor", + "bitshiftl", + "bitshiftr", + "lt", + "le", + "gt", + "ge", + "eq", + "neq", +] + +from .operators import ( + add, + bitand, + bitnot, + bitor, + bitshiftl, + bitshiftr, + bitxor, + div, + eq, + ge, + gt, + le, + lt, + mod, + mul, + neq, + pow, + sub, +) + +__all__ += [ + "create_array", + "create_handle", + "create_strided_array", + "device_array", + "get_ctype", + "get_elements", + "get_numdims", + "retain_array", + "get_dims", + "get_scalar", + "is_empty", + "get_data_ptr", + "copy_array", + "index_gen", + "transpose", + "reorder", + "array_as_str", + "where", + "randu", + "get_last_error", + "set_backend", + "get_backend_count", + "get_device_id", + "get_size_of", + "get_backend_id", +] + +from .unsorted import ( + array_as_str, + copy_array, + create_array, + create_handle, + create_strided_array, + device_array, + get_backend_count, + get_backend_id, + get_ctype, + get_data_ptr, + get_device_id, + get_dims, + get_elements, + get_last_error, + get_numdims, + get_scalar, + get_size_of, + index_gen, + is_empty, + randu, + reorder, + retain_array, + set_backend, + transpose, + where, +) + +__all__ += ["safe_call"] + +from .error_handler import safe_call + +__all__ += ["count_all"] + +from .reduction_operations import count_all + +__all__ += ["create_constant_array"] + +from .constant_array import create_constant_array diff --git a/arrayfire/backend/c_library/constant_array.py b/arrayfire/backend/c_library/constant_array.py index 5454cad..3192210 100755 --- a/arrayfire/backend/c_library/constant_array.py +++ b/arrayfire/backend/c_library/constant_array.py @@ -9,7 +9,7 @@ from .error_handler import safe_call if TYPE_CHECKING: - from arrayfire.library.array_object import AFArrayType + from arrayfire.array_object import AFArrayType def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: diff --git a/arrayfire/backend/c_library/operators.py b/arrayfire/backend/c_library/operators.py index 15a5f24..8105f90 100755 --- a/arrayfire/backend/c_library/operators.py +++ b/arrayfire/backend/c_library/operators.py @@ -9,7 +9,7 @@ from .error_handler import safe_call if TYPE_CHECKING: - from arrayfire.library.array_object import AFArrayType + from arrayfire.array_object import AFArrayType # Arithmetic Operators diff --git a/arrayfire/backend/c_library/reduction_operations.py b/arrayfire/backend/c_library/reduction_operations.py index b4ff67b..cf5d953 100755 --- a/arrayfire/backend/c_library/reduction_operations.py +++ b/arrayfire/backend/c_library/reduction_operations.py @@ -8,7 +8,7 @@ from .error_handler import safe_call if TYPE_CHECKING: - from arrayfire.library.array_object import AFArrayType + from arrayfire.array_object import AFArrayType def count_all(x: AFArrayType) -> Union[int, float, complex]: diff --git a/arrayfire/backend/c_library/unsorted.py b/arrayfire/backend/c_library/unsorted.py index d38d4fa..e5c5b19 100755 --- a/arrayfire/backend/c_library/unsorted.py +++ b/arrayfire/backend/c_library/unsorted.py @@ -11,7 +11,7 @@ from .error_handler import safe_call if TYPE_CHECKING: - from arrayfire.library.array_object import AFArrayType + from arrayfire.array_object import AFArrayType # Array management diff --git a/arrayfire/library/__init__.py b/arrayfire/library/__init__.py index a040af0..e69de29 100755 --- a/arrayfire/library/__init__.py +++ b/arrayfire/library/__init__.py @@ -1,3 +0,0 @@ -__all__ = ["Array"] - -from .array_object import Array diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index 1f26c68..fdfaeb0 100644 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -1,7 +1,7 @@ from typing import Callable -from .. import backend -from .array_object import Array +from arrayfire import Array +from arrayfire.backend import c_library as wrapper class return_copy: @@ -17,9 +17,9 @@ def __call__(self, x1: Array, x2: Array) -> Array: @return_copy def add(x1: Array, x2: Array, /) -> Array: - return backend.add(x1, x2) # type: ignore[arg-type, return-value] # FIXME + return wrapper.add(x1, x2) # type: ignore[arg-type, return-value, no-any-return] # FIXME @return_copy def sub(x1: Array, x2: Array, /) -> Array: - return backend.sub(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.sub(x1, x2) # type: ignore[arg-type, return-value, no-any-return] diff --git a/arrayfire/library/utils.py b/arrayfire/library/utils.py index 195111d..b67672f 100644 --- a/arrayfire/library/utils.py +++ b/arrayfire/library/utils.py @@ -1,6 +1,6 @@ from typing import Tuple, Union -from .array_object import Array +from arrayfire import Array # TODO implement functions diff --git a/tests/array_object/test_initialization.py b/tests/array_object/test_initialization.py old mode 100644 new mode 100755 index c2adcd1..3c1436d --- a/tests/array_object/test_initialization.py +++ b/tests/array_object/test_initialization.py @@ -4,8 +4,8 @@ import pytest +from arrayfire import Array from arrayfire.dtypes import Dtype, float32, int16 -from arrayfire.library.array_object import Array # TODO add tests for array arguments: device, offset, strides # TODO add tests for all supported dtypes on initialisation diff --git a/tests/array_object/test_methods.py b/tests/array_object/test_methods.py index c5ec273..3fc8498 100644 --- a/tests/array_object/test_methods.py +++ b/tests/array_object/test_methods.py @@ -1,4 +1,4 @@ -from arrayfire.library.array_object import Array +from arrayfire import Array def test_array_getitem_by_index() -> None: diff --git a/tests/array_object/test_operators.py b/tests/array_object/test_operators.py old mode 100644 new mode 100755 index 8ed93c1..81ae6a7 --- a/tests/array_object/test_operators.py +++ b/tests/array_object/test_operators.py @@ -3,8 +3,8 @@ import pytest +from arrayfire import Array from arrayfire.dtypes import bool as af_bool -from arrayfire.library.array_object import Array Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array] diff --git a/tests/test_operators.py b/tests/test_operators.py old mode 100644 new mode 100755 index 1e34a22..188828a --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -1,7 +1,7 @@ from typing import Any +from arrayfire import Array from arrayfire.library import operators -from arrayfire.library.array_object import Array class TestArithmeticOperators: From 985f333b7043092c871152c80367cc158c04007e Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Wed, 16 Aug 2023 15:35:25 +0300 Subject: [PATCH 18/31] Refactor to avoid missunderstanding with import * --- arrayfire/__init__.py | 8 +- arrayfire/array_api/__init__.py | 6 +- .../{array_object.py => _array_object.py} | 4 +- .../array_api/{constants.py => _constants.py} | 0 arrayfire/array_api/_creation_function.py | 161 ++++++++++++ arrayfire/array_api/_data_type_functions.py | 82 ++++++ arrayfire/array_api/{dtypes.py => _dtypes.py} | 0 arrayfire/array_api/_elementwise_functions.py | 242 ++++++++++++++++++ arrayfire/array_api/_indexing_functions.py | 7 + .../array_api/_manipulation_functions.py | 43 ++++ arrayfire/array_api/_searching_functions.py | 21 ++ arrayfire/array_api/_set_functions.py | 38 +++ arrayfire/array_api/_sorting_functions.py | 11 + arrayfire/array_api/_statistical_functions.py | 80 ++++++ arrayfire/array_api/_utility_functions.py | 25 ++ arrayfire/array_api/creation_function.py | 52 ---- arrayfire/array_api/data_type_functions.py | 35 --- .../tests/fixme_test_array_object.py | 8 +- .../tests/fixme_test_elementwise_functions.py | 105 ++++++++ .../tests/test_creation_functions.py | 4 +- arrayfire/array_object.py | 18 +- arrayfire/backend/_backend.py | 239 +++++++++++++++++ .../{helpers.py => _backend_functions.py} | 28 +- .../{c_library => _clib_wrapper}/__init__.py | 10 +- .../_constant_array.py} | 14 +- .../_error_handler.py} | 4 +- .../_indexing.py} | 6 +- .../_operators.py} | 40 +-- .../_reduction_operations.py} | 6 +- .../_unsorted.py} | 65 +++-- arrayfire/backend/api.py | 137 ---------- arrayfire/backend/constants.py | 7 - arrayfire/dtypes.py | 6 +- arrayfire/library/operators.py | 2 +- arrayfire/platform.py | 100 -------- arrayfire/version.py | 4 +- tests/test_operators.py | 4 +- 37 files changed, 1177 insertions(+), 445 deletions(-) rename arrayfire/array_api/{array_object.py => _array_object.py} (97%) rename arrayfire/array_api/{constants.py => _constants.py} (100%) create mode 100755 arrayfire/array_api/_creation_function.py create mode 100755 arrayfire/array_api/_data_type_functions.py rename arrayfire/array_api/{dtypes.py => _dtypes.py} (100%) create mode 100755 arrayfire/array_api/_elementwise_functions.py create mode 100755 arrayfire/array_api/_indexing_functions.py create mode 100755 arrayfire/array_api/_manipulation_functions.py create mode 100755 arrayfire/array_api/_searching_functions.py create mode 100755 arrayfire/array_api/_set_functions.py create mode 100755 arrayfire/array_api/_sorting_functions.py create mode 100755 arrayfire/array_api/_statistical_functions.py create mode 100755 arrayfire/array_api/_utility_functions.py delete mode 100755 arrayfire/array_api/creation_function.py delete mode 100755 arrayfire/array_api/data_type_functions.py create mode 100755 arrayfire/array_api/tests/fixme_test_elementwise_functions.py create mode 100644 arrayfire/backend/_backend.py rename arrayfire/backend/{helpers.py => _backend_functions.py} (80%) rename arrayfire/backend/{c_library => _clib_wrapper}/__init__.py (82%) rename arrayfire/backend/{c_library/constant_array.py => _clib_wrapper/_constant_array.py} (89%) rename arrayfire/backend/{c_library/error_handler.py => _clib_wrapper/_error_handler.py} (67%) rename arrayfire/backend/{c_library/indexing.py => _clib_wrapper/_indexing.py} (94%) rename arrayfire/backend/{c_library/operators.py => _clib_wrapper/_operators.py} (77%) rename arrayfire/backend/{c_library/reduction_operations.py => _clib_wrapper/_reduction_operations.py} (77%) rename arrayfire/backend/{c_library/unsorted.py => _clib_wrapper/_unsorted.py} (78%) delete mode 100644 arrayfire/backend/api.py delete mode 100755 arrayfire/backend/constants.py delete mode 100755 arrayfire/platform.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 686fa70..be66dfa 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -57,8 +57,8 @@ "set_backend", ] -from .backend.api import get_backend -from .backend.helpers import ( +from .backend._backend import get_backend +from .backend._backend_functions import ( get_active_backend, get_array_backend_name, get_array_device_id, @@ -70,3 +70,7 @@ get_size_of, set_backend, ) + +__all__ += ["add", "sub"] + +from .library.operators import add, sub diff --git a/arrayfire/array_api/__init__.py b/arrayfire/array_api/__init__.py index d40c013..1591982 100755 --- a/arrayfire/array_api/__init__.py +++ b/arrayfire/array_api/__init__.py @@ -4,15 +4,15 @@ __all__ = ["__array_api_version__"] -from .constants import Device +from ._constants import Device __all__ += ["Device"] -from .creation_function import asarray +from ._creation_function import asarray __all__ += ["asarray"] -from .dtypes import ( +from ._dtypes import ( bool, complex64, complex128, diff --git a/arrayfire/array_api/array_object.py b/arrayfire/array_api/_array_object.py similarity index 97% rename from arrayfire/array_api/array_object.py rename to arrayfire/array_api/_array_object.py index 8574d05..78313c5 100755 --- a/arrayfire/array_api/array_object.py +++ b/arrayfire/array_api/_array_object.py @@ -4,8 +4,8 @@ from typing import Any, Optional, Tuple, Union from arrayfire import Array as AFArray -from arrayfire.array_api.constants import Device, NestedSequence, PyCapsule, SupportsBufferProtocol -from arrayfire.array_api.dtypes import ( +from arrayfire.array_api._constants import Device, NestedSequence, PyCapsule, SupportsBufferProtocol +from arrayfire.array_api._dtypes import ( all_dtypes, boolean_dtypes, complex_floating_dtypes, diff --git a/arrayfire/array_api/constants.py b/arrayfire/array_api/_constants.py similarity index 100% rename from arrayfire/array_api/constants.py rename to arrayfire/array_api/_constants.py diff --git a/arrayfire/array_api/_creation_function.py b/arrayfire/array_api/_creation_function.py new file mode 100755 index 0000000..9ea4e4f --- /dev/null +++ b/arrayfire/array_api/_creation_function.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple, Union + +from arrayfire import Array as AFArray +from arrayfire.array_api._array_object import Array +from arrayfire.array_api._constants import Device, NestedSequence, SupportsBufferProtocol +from arrayfire.array_api._dtypes import all_dtypes +from arrayfire.dtypes import Dtype +from arrayfire.library.device import PointerSource + + +def _check_valid_dtype(dtype: Optional[Dtype]) -> None: + # Note: Only spelling dtypes as the dtype objects is supported. + + # We use this instead of "dtype in _all_dtypes" because the dtype objects + # define equality with the sorts of things we want to disallow. + for d in (None,) + all_dtypes: + if dtype is d: + return + raise ValueError("dtype must be one of the supported dtypes") + + +def asarray( + obj: Union[Array, bool, int, float, complex, NestedSequence, SupportsBufferProtocol], + /, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + copy: Optional[bool] = None, +) -> Array: + _check_valid_dtype(dtype) + + # if device not in supported_devices: + # raise ValueError(f"Unsupported device {device!r}") + + if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -(2**63)): + raise OverflowError("Integer out of bounds for array dtypes") + + if device == Device.cpu or device is None: + pointer_source = PointerSource.host + elif device == Device.gpu: + pointer_source = PointerSource.device + else: + raise ValueError(f"Unsupported device {device!r}") + + if isinstance(obj, int | float): + afarray = AFArray([obj], dtype=dtype, shape=(1,), pointer_source=pointer_source) + return Array._new(afarray) + + afarray = AFArray(obj, dtype=dtype, pointer_source=pointer_source) + return Array._new(afarray) + + +def arange( + start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def empty( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + return NotImplemented + + +def eye( + n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def full( + shape: Union[int, Tuple[int, ...]], + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def full_like( + x: Array, + /, + fill_value: Union[int, float], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def linspace( + start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, +) -> Array: + return NotImplemented + + +def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: + return NotImplemented + + +def ones( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + return NotImplemented + + +def tril(x: Array, /, *, k: int = 0) -> Array: + return NotImplemented + + +def triu(x: Array, /, *, k: int = 0) -> Array: + return NotImplemented + + +def zeros( + shape: Union[int, Tuple[int, ...]], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, +) -> Array: + return NotImplemented + + +def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_data_type_functions.py b/arrayfire/array_api/_data_type_functions.py new file mode 100755 index 0000000..5dc4d00 --- /dev/null +++ b/arrayfire/array_api/_data_type_functions.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import List, Tuple, Union + +from arrayfire import Array as AFArray +from arrayfire.array_api._array_object import Array +from arrayfire.array_api._dtypes import all_dtypes, promote_types +from arrayfire.dtypes import Dtype + + +def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: + return NotImplemented + + +def broadcast_arrays(*arrays: Array) -> List[Array]: + return NotImplemented + + +def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: + return NotImplemented + + +def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: + return NotImplemented + + +@dataclass +class finfo_object: + bits: int + eps: float + max: float + min: float + smallest_normal: float + dtype: Dtype + + +@dataclass +class iinfo_object: + bits: int + max: int + min: int + dtype: Dtype + + +def finfo(type: Union[Dtype, Array], /) -> finfo_object: + return NotImplemented + + +def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: + return NotImplemented + + +def isdtype(dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]) -> bool: + return NotImplemented + + +def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: + """ + Array API compatible wrapper for :py:func:`np.result_type `. + + See its docstring for more information. + """ + # Note: we use a custom implementation that gives only the type promotions + # required by the spec rather than using np.result_type. NumPy implements + # too many extra type promotions like int64 + uint64 -> float64, and does + # value-based casting on scalar arrays. + A = [] + for a in arrays_and_dtypes: + if isinstance(a, Array): + a = a.dtype + elif isinstance(a, AFArray) or a not in all_dtypes: + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + A.append(a) + + if len(A) == 0: + raise ValueError("at least one array or dtype is required") + elif len(A) == 1: + return A[0] + else: + t = A[0] + for t2 in A[1:]: + t = promote_types(t, t2) + return t diff --git a/arrayfire/array_api/dtypes.py b/arrayfire/array_api/_dtypes.py similarity index 100% rename from arrayfire/array_api/dtypes.py rename to arrayfire/array_api/_dtypes.py diff --git a/arrayfire/array_api/_elementwise_functions.py b/arrayfire/array_api/_elementwise_functions.py new file mode 100755 index 0000000..9903425 --- /dev/null +++ b/arrayfire/array_api/_elementwise_functions.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from ._array_object import Array + +from arrayfire.library import operators + + +def abs(x: Array, /) -> Array: + return NotImplemented + + +def acos(x: Array, /) -> Array: + return NotImplemented + + +def acosh(x: Array, /) -> Array: + return NotImplemented + + +def add(x1: Array, x2: Array, /) -> Array: + return Array._new(operators.add(x1._array, x2._array)) + + +def asin(x: Array, /) -> Array: + return NotImplemented + + +def asinh(x: Array, /) -> Array: + return NotImplemented + + +def atan(x: Array, /) -> Array: + return NotImplemented + + +def atan2(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def atanh(x: Array, /) -> Array: + return NotImplemented + + +def bitwise_and(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def bitwise_invert(x: Array, /) -> Array: + return NotImplemented + + +def bitwise_or(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def bitwise_xor(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def ceil(x: Array, /) -> Array: + return NotImplemented + + +def conj(x: Array, /) -> Array: + return NotImplemented + + +def cos(x: Array, /) -> Array: + return NotImplemented + + +def cosh(x: Array, /) -> Array: + return NotImplemented + + +def divide(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def equal(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def exp(x: Array, /) -> Array: + return NotImplemented + + +def expm1(x: Array, /) -> Array: + return NotImplemented + + +def floor(x: Array, /) -> Array: + return NotImplemented + + +def floor_divide(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def greater(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def greater_equal(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def imag(x: Array, /) -> Array: + return NotImplemented + + +def isfinite(x: Array, /) -> Array: + return NotImplemented + + +def isinf(x: Array, /) -> Array: + return NotImplemented + + +def isnan(x: Array, /) -> Array: + return NotImplemented + + +def less(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def less_equal(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def log(x: Array, /) -> Array: + return NotImplemented + + +def log1p(x: Array, /) -> Array: + return NotImplemented + + +def log2(x: Array, /) -> Array: + return NotImplemented + + +def log10(x: Array, /) -> Array: + return NotImplemented + + +def logaddexp(x1: Array, x2: Array) -> Array: + return NotImplemented + + +def logical_and(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def logical_not(x: Array, /) -> Array: + return NotImplemented + + +def logical_or(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def logical_xor(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def multiply(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def negative(x: Array, /) -> Array: + return NotImplemented + + +def not_equal(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def positive(x: Array, /) -> Array: + return NotImplemented + + +# Note: the function name is different here +def pow(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def real(x: Array, /) -> Array: + return NotImplemented + + +def remainder(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def round(x: Array, /) -> Array: + return NotImplemented + + +def sign(x: Array, /) -> Array: + return NotImplemented + + +def sin(x: Array, /) -> Array: + return NotImplemented + + +def sinh(x: Array, /) -> Array: + return NotImplemented + + +def square(x: Array, /) -> Array: + return NotImplemented + + +def sqrt(x: Array, /) -> Array: + return NotImplemented + + +def subtract(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +def tan(x: Array, /) -> Array: + return NotImplemented + + +def tanh(x: Array, /) -> Array: + return NotImplemented + + +def trunc(x: Array, /) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_indexing_functions.py b/arrayfire/array_api/_indexing_functions.py new file mode 100755 index 0000000..0b19766 --- /dev/null +++ b/arrayfire/array_api/_indexing_functions.py @@ -0,0 +1,7 @@ +from typing import Optional + +from ._array_object import Array + + +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_manipulation_functions.py b/arrayfire/array_api/_manipulation_functions.py new file mode 100755 index 0000000..eebafd4 --- /dev/null +++ b/arrayfire/array_api/_manipulation_functions.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple, Union + +from ._array_object import Array + + +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: + return NotImplemented + + +def expand_dims(x: Array, /, *, axis: int) -> Array: + return NotImplemented + + +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + return NotImplemented + + +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: + return NotImplemented + + +def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array: + return NotImplemented + + +def roll( + x: Array, + /, + shift: Union[int, Tuple[int, ...]], + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, +) -> Array: + return NotImplemented + + +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: + return NotImplemented + + +def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_searching_functions.py b/arrayfire/array_api/_searching_functions.py new file mode 100755 index 0000000..fafe009 --- /dev/null +++ b/arrayfire/array_api/_searching_functions.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Optional, Tuple + +from ._array_object import Array + + +def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + return NotImplemented + + +def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: + return NotImplemented + + +def nonzero(x: Array, /) -> Tuple[Array, ...]: + return NotImplemented + + +def where(condition: Array, x1: Array, x2: Array, /) -> Array: + return NotImplemented \ No newline at end of file diff --git a/arrayfire/array_api/_set_functions.py b/arrayfire/array_api/_set_functions.py new file mode 100755 index 0000000..ea09497 --- /dev/null +++ b/arrayfire/array_api/_set_functions.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import NamedTuple + +from ._array_object import Array + + +class UniqueAllResult(NamedTuple): + values: Array + indices: Array + inverse_indices: Array + counts: Array + + +class UniqueCountsResult(NamedTuple): + values: Array + counts: Array + + +class UniqueInverseResult(NamedTuple): + values: Array + inverse_indices: Array + + +def unique_all(x: Array, /) -> UniqueAllResult: + return NotImplemented + + +def unique_counts(x: Array, /) -> UniqueCountsResult: + return NotImplemented + + +def unique_inverse(x: Array, /) -> UniqueInverseResult: + return NotImplemented + + +def unique_values(x: Array, /) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_sorting_functions.py b/arrayfire/array_api/_sorting_functions.py new file mode 100755 index 0000000..478b781 --- /dev/null +++ b/arrayfire/array_api/_sorting_functions.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from ._array_object import Array + + +def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + return NotImplemented + + +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_statistical_functions.py b/arrayfire/array_api/_statistical_functions.py new file mode 100755 index 0000000..1e47b4b --- /dev/null +++ b/arrayfire/array_api/_statistical_functions.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +from ._array_object import Array +from ._dtypes import Dtype + + +def max( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def mean( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def min( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def prod( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def std( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def sum( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def var( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, +) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/_utility_functions.py b/arrayfire/array_api/_utility_functions.py new file mode 100755 index 0000000..1ca4248 --- /dev/null +++ b/arrayfire/array_api/_utility_functions.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +from ._array_object import Array + + +def all( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented + + +def any( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + return NotImplemented diff --git a/arrayfire/array_api/creation_function.py b/arrayfire/array_api/creation_function.py deleted file mode 100755 index 95ea40a..0000000 --- a/arrayfire/array_api/creation_function.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Union - -from arrayfire import Array as AFArray -from arrayfire.array_api.array_object import Array -from arrayfire.array_api.constants import Device, NestedSequence, SupportsBufferProtocol -from arrayfire.array_api.dtypes import all_dtypes -from arrayfire.dtypes import Dtype -from arrayfire.library.device import PointerSource - - -def _check_valid_dtype(dtype: Optional[Dtype]) -> None: - # Note: Only spelling dtypes as the dtype objects is supported. - - # We use this instead of "dtype in _all_dtypes" because the dtype objects - # define equality with the sorts of things we want to disallow. - for d in (None,) + all_dtypes: - if dtype is d: - return - raise ValueError("dtype must be one of the supported dtypes") - - -def asarray( - obj: Union[Array, bool, int, float, complex, NestedSequence, SupportsBufferProtocol], - /, - *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, -) -> Array: - _check_valid_dtype(dtype) - - # if device not in supported_devices: - # raise ValueError(f"Unsupported device {device!r}") - - if dtype is None and isinstance(obj, int) and (obj > 2**64 or obj < -(2**63)): - raise OverflowError("Integer out of bounds for array dtypes") - - if device == Device.cpu or device is None: - pointer_source = PointerSource.host - elif device == Device.gpu: - pointer_source = PointerSource.device - else: - raise ValueError(f"Unsupported device {device!r}") - - if isinstance(obj, int | float): - afarray = AFArray([obj], dtype=dtype, shape=(1,), pointer_source=pointer_source) - return Array._new(afarray) - - afarray = AFArray(obj, dtype=dtype, pointer_source=pointer_source) - return Array._new(afarray) diff --git a/arrayfire/array_api/data_type_functions.py b/arrayfire/array_api/data_type_functions.py deleted file mode 100755 index f9cadd2..0000000 --- a/arrayfire/array_api/data_type_functions.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Union - -from arrayfire import Array as AFArray -from arrayfire.array_api.array_object import Array -from arrayfire.array_api.dtypes import all_dtypes, promote_types -from arrayfire.dtypes import Dtype - - -def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: - """ - Array API compatible wrapper for :py:func:`np.result_type `. - - See its docstring for more information. - """ - # Note: we use a custom implementation that gives only the type promotions - # required by the spec rather than using np.result_type. NumPy implements - # too many extra type promotions like int64 + uint64 -> float64, and does - # value-based casting on scalar arrays. - A = [] - for a in arrays_and_dtypes: - if isinstance(a, Array): - a = a.dtype - elif isinstance(a, AFArray) or a not in all_dtypes: - raise TypeError("result_type() inputs must be array_api arrays or dtypes") - A.append(a) - - if len(A) == 0: - raise ValueError("at least one array or dtype is required") - elif len(A) == 1: - return A[0] - else: - t = A[0] - for t2 in A[1:]: - t = promote_types(t, t2) - return t diff --git a/arrayfire/array_api/tests/fixme_test_array_object.py b/arrayfire/array_api/tests/fixme_test_array_object.py index ed52637..876e457 100755 --- a/arrayfire/array_api/tests/fixme_test_array_object.py +++ b/arrayfire/array_api/tests/fixme_test_array_object.py @@ -5,9 +5,9 @@ import pytest from arrayfire import bool, int8, int16, int32, int64, uint64 -from arrayfire.array_api.creation_function import asarray -from arrayfire.array_api.data_type_functions import result_type -from arrayfire.array_api.dtypes import ( +from arrayfire.array_api._creation_function import asarray +from arrayfire.array_api._data_type_functions import result_type +from arrayfire.array_api._dtypes import ( boolean_dtypes, complex_floating_dtypes, floating_dtypes, @@ -19,7 +19,7 @@ ) if TYPE_CHECKING: - from arrayfire.array_api.array_object import Array + from arrayfire.array_api._array_object import Array def test_operators() -> None: diff --git a/arrayfire/array_api/tests/fixme_test_elementwise_functions.py b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py new file mode 100755 index 0000000..9f2a0d6 --- /dev/null +++ b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py @@ -0,0 +1,105 @@ +from inspect import getfullargspec + +import pytest + +from .. import _elementwise_functions, asarray +from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, integer_dtypes +from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift + + +def nargs(func): + return len(getfullargspec(func).args) + + +def test_function_types(): + # Test that every function accepts only the required input types. We only + # test the negative cases here (error). The positive cases are tested in + # the array API test suite. + + elementwise_function_input_types = { + # "abs": "numeric", + # "acos": "floating-point", + # "acosh": "floating-point", + "add": "numeric", + # "asin": "floating-point", + # "asinh": "floating-point", + # "atan": "floating-point", + # "atan2": "real floating-point", + # "atanh": "floating-point", + # "bitwise_and": "integer or boolean", + # "bitwise_invert": "integer or boolean", + # "bitwise_left_shift": "integer", + # "bitwise_or": "integer or boolean", + # "bitwise_right_shift": "integer", + # "bitwise_xor": "integer or boolean", + # "ceil": "real numeric", + # "conj": "complex floating-point", + # "cos": "floating-point", + # "cosh": "floating-point", + # "divide": "floating-point", + # "equal": "all", + # "exp": "floating-point", + # "expm1": "floating-point", + # "floor": "real numeric", + # "floor_divide": "real numeric", + # "greater": "real numeric", + # "greater_equal": "real numeric", + # "imag": "complex floating-point", + # "isfinite": "numeric", + # "isinf": "numeric", + # "isnan": "numeric", + # "less": "real numeric", + # "less_equal": "real numeric", + # "log": "floating-point", + # "logaddexp": "real floating-point", + # "log10": "floating-point", + # "log1p": "floating-point", + # "log2": "floating-point", + # "logical_and": "boolean", + # "logical_not": "boolean", + # "logical_or": "boolean", + # "logical_xor": "boolean", + # "multiply": "numeric", + # "negative": "numeric", + # "not_equal": "all", + # "positive": "numeric", + # "pow": "numeric", + # "real": "complex floating-point", + # "remainder": "real numeric", + # "round": "numeric", + # "sign": "numeric", + # "sin": "floating-point", + # "sinh": "floating-point", + # "sqrt": "floating-point", + # "square": "numeric", + # "subtract": "numeric", + # "tan": "floating-point", + # "tanh": "floating-point", + # "trunc": "real numeric", + } + + def _array_vals(): + for d in integer_dtypes: + yield asarray(1, dtype=d) + for d in boolean_dtypes: + yield asarray(False, dtype=d) + for d in floating_dtypes: + yield asarray(1.0, dtype=d) + + for x in _array_vals(): + for func_name, types in elementwise_function_input_types.items(): + dtypes = dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + if nargs(func) == 2: + for y in _array_vals(): + if x.dtype not in dtypes or y.dtype not in dtypes: + pytest.raises(TypeError, lambda: func(x, y)) + else: + if x.dtype not in dtypes: + pytest.raises(TypeError, lambda: func(x)) + + +# def test_bitwise_shift_error() -> None: +# # bitwise shift functions should raise when the second argument is negative +# pytest.raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) +# pytest.raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) diff --git a/arrayfire/array_api/tests/test_creation_functions.py b/arrayfire/array_api/tests/test_creation_functions.py index 6611882..33e3c07 100755 --- a/arrayfire/array_api/tests/test_creation_functions.py +++ b/arrayfire/array_api/tests/test_creation_functions.py @@ -1,8 +1,8 @@ import pytest from arrayfire.array_api import asarray -from arrayfire.array_api.array_object import Array -from arrayfire.array_api.constants import Device +from arrayfire.array_api._array_object import Array +from arrayfire.array_api._constants import Device from arrayfire.dtypes import float16 diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 171511d..a12a901 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -3,11 +3,11 @@ import array as py_array import ctypes import enum +from dataclasses import dataclass from typing import Any, List, Optional, Tuple, Union -from .backend import c_library as wrapper -from .backend.c_library.indexing import CIndexStructure, IndexStructure -from .backend.constants import ArrayBuffer +from .backend import _clib_wrapper as wrapper +from .backend._clib_wrapper._indexing import CIndexStructure, IndexStructure from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype from .library.device import PointerSource @@ -16,6 +16,12 @@ AFArrayType = ctypes.c_void_p +@dataclass(frozen=True) +class _ArrayBuffer: + address: int + length: int = 0 + + class Array: def __init__( self, @@ -49,15 +55,15 @@ def __init__( if isinstance(obj, py_array.array): _type_char: str = obj.typecode - _array_buffer = ArrayBuffer(*obj.buffer_info()) + _array_buffer = _ArrayBuffer(*obj.buffer_info()) elif isinstance(obj, list): _array = py_array.array("f", obj) # BUG [True, False] -> dtype: f32 # TODO add int and float _type_char = _array.typecode - _array_buffer = ArrayBuffer(*_array.buffer_info()) + _array_buffer = _ArrayBuffer(*_array.buffer_info()) elif isinstance(obj, int) or isinstance(obj, AFArrayType): - _array_buffer = ArrayBuffer(obj if not isinstance(obj, AFArrayType) else obj.value) # type: ignore + _array_buffer = _ArrayBuffer(obj if not isinstance(obj, AFArrayType) else obj.value) # type: ignore if not shape: raise TypeError("Expected to receive the initial shape due to the obj being a data pointer.") diff --git a/arrayfire/backend/_backend.py b/arrayfire/backend/_backend.py new file mode 100644 index 0000000..ce71172 --- /dev/null +++ b/arrayfire/backend/_backend.py @@ -0,0 +1,239 @@ +__all__ = ["BackendType"] + +import ctypes +import enum +import os +import platform +import sys +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Iterator, List, Optional + +from arrayfire.logger import logger +from arrayfire.version import ARRAYFIRE_VER_MAJOR + + +def is_arch_x86() -> bool: + machine = platform.machine() + return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") + + +class _LibPrefixes(Enum): + forge = "" + arrayfire = "af" + + +class _SupportedPlatforms(Enum): + windows = "Windows" + darwin = "Darwin" # OSX + linux = "Linux" + + @classmethod + def is_cygwin(cls, name: str) -> bool: + return "cyg" in name.lower() + + +@dataclass(frozen=True) +class _BackendPathConfig: + lib_prefix: str + lib_postfix: str + af_path: Path + cuda_found: bool + + def __iter__(self) -> Iterator: + return iter((self.lib_prefix, self.lib_postfix, self.af_path, self.af_path, self.cuda_found)) + + +def _get_backend_path_config() -> _BackendPathConfig: + platform_name = platform.system() + cuda_found = False + + try: + af_path = Path(os.environ["AF_PATH"]) + except KeyError: + af_path = None + + try: + cuda_path = Path(os.environ["CUDA_PATH"]) + except KeyError: + cuda_path = None + + if platform_name == _SupportedPlatforms.windows.value or _SupportedPlatforms.is_cygwin(platform_name): + if platform_name == _SupportedPlatforms.windows.value: + # HACK Supressing crashes caused by missing dlls + # http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup + # https://msdn.microsoft.com/en-us/_clib/windows/desktop/ms680621.aspx + ctypes.windll.kernel32.SetErrorMode(0x0001 | 0x0002) # type: ignore[attr-defined] + + if not af_path: + af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{ARRAYFIRE_VER_MAJOR}") + + if cuda_path and (cuda_path / "bin").is_dir() and (cuda_path / "nvvm/bin").is_dir(): + cuda_found = True + + return _BackendPathConfig("", ".dll", af_path, cuda_found) + + if platform_name == _SupportedPlatforms.darwin.value: + default_cuda_path = Path("/usr/local/cuda/") + + if not af_path: + af_path = _find_default_path("/opt/arrayfire", "/usr/local") + + if not (cuda_path and default_cuda_path.exists()): + cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "/nvvm/lib").is_dir() + + return _BackendPathConfig("lib", f".{ARRAYFIRE_VER_MAJOR}.dylib", af_path, cuda_found) + + if platform_name == _SupportedPlatforms.linux.value: + default_cuda_path = Path("/usr/local/cuda/") + + if not af_path: + af_path = _find_default_path(f"/opt/arrayfire-{ARRAYFIRE_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") + + if not (cuda_path and default_cuda_path.exists()): + if "64" in platform.architecture()[0]: # Check either is 64 bit arch is selected + cuda_found = (default_cuda_path / "lib64").is_dir() and (default_cuda_path / "nvvm/lib64").is_dir() + else: + cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "nvvm/lib").is_dir() + + return _BackendPathConfig("lib", f".so.{ARRAYFIRE_VER_MAJOR}", af_path, cuda_found) + + raise OSError(f"{platform_name} is not supported.") + + +def _find_default_path(*args: str) -> Path: + for path in args: + default_path = Path(path) + if default_path.exists(): + return default_path + raise ValueError("None of specified default paths were found.") + + +class BackendType(enum.Enum): # TODO change name - avoid using _backend_type - e.g. type + unified = 0 # NOTE It is set as Default value on Arrayfire backend + cpu = 1 + cuda = 2 + opencl = 4 + + def __iter__(self) -> Iterator: + # NOTE cpu comes last because we want to keep this order priorty during backend initialization + return iter((self.unified, self.cuda, self.opencl, self.cpu)) + + +class Backend: + _backend_type: BackendType + # HACK for osx + # _backend.clib = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") + # HACK for windows + # _backend.clib = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") + _clib: ctypes.CDLL + + def __init__(self) -> None: + self._backend_path_config = _get_backend_path_config() + + self._load_forge_lib() + self._load_backend_libs() + + def _load_forge_lib(self) -> None: + for lib_name in self._lib_names("forge", _LibPrefixes.forge): + try: + ctypes.cdll.LoadLibrary(str(lib_name)) + logger.info(f"Loaded {lib_name}") + break + except OSError: + logger.warning(f"Unable to load {lib_name}") + pass + + def _load_backend_libs(self) -> None: + for backend_type in BackendType: + self._load_backend_lib(backend_type) + + if self._backend_type: + logger.info(f"Setting {backend_type.name} as backend.") + break + + if not self._backend_type and not self._clib: + raise RuntimeError( + "Could not load any ArrayFire libraries.\n" + "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." + ) + + def _load_backend_lib(self, _backend_type: BackendType) -> None: + # NOTE we still set unified cdll to it's original name later, even if the path search is different + name = _backend_type.name if _backend_type != BackendType.unified else "" + + for lib_name in self._lib_names(name, _LibPrefixes.arrayfire): + try: + ctypes.cdll.LoadLibrary(str(lib_name)) + self._backend_type = _backend_type + self._clib = ctypes.CDLL(str(lib_name)) + + if _backend_type == BackendType.cuda: + self._load_nvrtc_builtins_lib(lib_name.parent) + + logger.info(f"Loaded {lib_name}") + break + except OSError: + logger.warning(f"Unable to load {lib_name}") + pass + + def _load_nvrtc_builtins_lib(self, lib_path: Path) -> None: + nvrtc_name = self._find_nvrtc_builtins_lib_name(lib_path) + if nvrtc_name: + ctypes.cdll.LoadLibrary(str(lib_path / nvrtc_name)) + logger.info(f"Loaded {lib_path / nvrtc_name}") + else: + logger.warning("Could not find local nvrtc-builtins library") + + def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: Optional[str] = None) -> List[Path]: + post = self._backend_path_config.lib_postfix if ver_major is None else ver_major + lib_name = self._backend_path_config.lib_prefix + lib.value + name + post + + lib64_path = self._backend_path_config.af_path / "lib64" + search_path = lib64_path if lib64_path.is_dir() else self._backend_path_config.af_path / "lib" + + site_path = Path(sys.prefix) / "lib64" if not is_arch_x86() else Path(sys.prefix) / "lib" + + # prefer locally packaged arrayfire libraries if they exist + af_module = __import__(__name__) + local_path = Path(af_module.__path__[0]) if af_module.__path__ else Path("") + + lib_paths = [Path("", lib_name), site_path / lib_name, local_path / lib_name] + + if self._backend_path_config.af_path: # prefer specified AF_PATH if exists + return [search_path / lib_name] + lib_paths + else: + lib_paths.insert(2, Path(str(search_path), lib_name)) + return lib_paths + + def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> Optional[str]: + for f in search_path.iterdir(): + if "nvrtc-builtins" in f.name: + return f.name + return None + + @property + def backend_type(self) -> BackendType: + return self._backend_type + + @property + def clib(self) -> ctypes.CDLL: + return self._clib + + +# Initialize the backend +_backend = Backend() + + +def get_backend() -> Backend: + """ + Get the current active backend. + + Returns + ------- + value : Backend + Current active backend. + """ + + return _backend diff --git a/arrayfire/backend/helpers.py b/arrayfire/backend/_backend_functions.py similarity index 80% rename from arrayfire/backend/helpers.py rename to arrayfire/backend/_backend_functions.py index af19ec2..483dabc 100755 --- a/arrayfire/backend/helpers.py +++ b/arrayfire/backend/_backend_functions.py @@ -3,25 +3,25 @@ import warnings from typing import TYPE_CHECKING, Union -from .api import Backend, BackendPlatform, get_backend -from .c_library.unsorted import get_backend_count as c_get_backend_count -from .c_library.unsorted import get_backend_id as c_get_backend_id -from .c_library.unsorted import get_device_id as c_get_device_id -from .c_library.unsorted import get_size_of as c_get_size_of -from .c_library.unsorted import set_backend as c_set_backend +from ._backend import Backend, BackendType, get_backend +from ._clib_wrapper._unsorted import get_backend_count as c_get_backend_count +from ._clib_wrapper._unsorted import get_backend_id as c_get_backend_id +from ._clib_wrapper._unsorted import get_device_id as c_get_device_id +from ._clib_wrapper._unsorted import get_size_of as c_get_size_of +from ._clib_wrapper._unsorted import set_backend as c_set_backend if TYPE_CHECKING: from arrayfire import Array from arrayfire.dtypes import Dtype -def set_backend(platform: Union[BackendPlatform, str]) -> None: +def set_backend(platform: Union[BackendType, str]) -> None: """ Set a specific backend by platform name. Parameters ---------- - platform : Union[BackendPlatform, str] + platform : Union[BackendType, str] Name of the backend platform to set. Raises @@ -39,17 +39,17 @@ def set_backend(platform: Union[BackendPlatform, str]) -> None: current_active_platform = backend.platform if isinstance(platform, str): - if platform not in [d.name for d in BackendPlatform]: + if platform not in [d.name for d in BackendType]: raise ValueError(f"{platform} is not a valid name for backend platform.") - platform = BackendPlatform[platform] + platform = BackendType[platform] - if not isinstance(platform, BackendPlatform): + if not isinstance(platform, BackendType): raise TypeError(f"{platform} is not a valid type for backend platform.") if current_active_platform == platform: raise RuntimeError(f"{platform} is already the active backend platform.") - if backend.platform == BackendPlatform.unified: + if backend.platform == BackendType.unified: c_set_backend(platform.value) # NOTE keep in mind that this operation works in-place @@ -76,7 +76,7 @@ def get_array_backend_name(array: Array) -> str: """ id_ = c_get_backend_id(array.arr) - return BackendPlatform(id_).name + return BackendType(id_).name def get_backend_id(array: Array) -> str: @@ -106,6 +106,7 @@ def get_active_backend() -> Backend: Current active backend. """ + # TODO do not deprecate warnings.warn("A user has access explicitly only to the active backend.", DeprecationWarning) return get_backend() @@ -120,6 +121,7 @@ def get_available_backends() -> Backend: Current active backend. """ + # TODO do not deprecate warnings.warn( "A user has access explicitly only to the active backend. Thus returning only active backend.", DeprecationWarning, diff --git a/arrayfire/backend/c_library/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py similarity index 82% rename from arrayfire/backend/c_library/__init__.py rename to arrayfire/backend/_clib_wrapper/__init__.py index a1f4fdf..f39f3e3 100755 --- a/arrayfire/backend/c_library/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -21,7 +21,7 @@ "neq", ] -from .operators import ( +from ._operators import ( add, bitand, bitnot, @@ -70,7 +70,7 @@ "get_backend_id", ] -from .unsorted import ( +from ._unsorted import ( array_as_str, copy_array, create_array, @@ -100,12 +100,12 @@ __all__ += ["safe_call"] -from .error_handler import safe_call +from ._error_handler import safe_call __all__ += ["count_all"] -from .reduction_operations import count_all +from ._reduction_operations import count_all __all__ += ["create_constant_array"] -from .constant_array import create_constant_array +from ._constant_array import create_constant_array diff --git a/arrayfire/backend/c_library/constant_array.py b/arrayfire/backend/_clib_wrapper/_constant_array.py similarity index 89% rename from arrayfire/backend/c_library/constant_array.py rename to arrayfire/backend/_clib_wrapper/_constant_array.py index 3192210..1921107 100755 --- a/arrayfire/backend/c_library/constant_array.py +++ b/arrayfire/backend/_clib_wrapper/_constant_array.py @@ -3,10 +3,10 @@ import ctypes from typing import TYPE_CHECKING, Tuple, Union -from arrayfire.backend.api import backend_api +from arrayfire.backend._backend import _backend from arrayfire.dtypes import CShape, Dtype, implicit_dtype, int64, uint64 -from .error_handler import safe_call +from ._error_handler import safe_call if TYPE_CHECKING: from arrayfire.array_object import AFArrayType @@ -20,7 +20,7 @@ def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: c_shape = CShape(*shape) safe_call( - backend_api.af_constant_complex( + _backend.clib.af_constant_complex( ctypes.pointer(out), ctypes.c_double(number.real), ctypes.c_double(number.imag), @@ -40,7 +40,7 @@ def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dty c_shape = CShape(*shape) safe_call( - backend_api.af_constant_long( + _backend.clib.af_constant_long( ctypes.pointer(out), ctypes.c_longlong(int(number.real)), 4, ctypes.pointer(c_shape.c_array) ) ) @@ -51,12 +51,12 @@ def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dt """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga67af670cc9314589f8134019f5e68809 """ - # return backend_api.af_constant_ulong(arr, val, ndims, dims) + # return _backend.clib.af_constant_ulong(arr, val, ndims, dims) out = ctypes.c_void_p(0) c_shape = CShape(*shape) safe_call( - backend_api.af_constant_ulong( + _backend.clib.af_constant_ulong( ctypes.pointer(out), ctypes.c_ulonglong(int(number.real)), 4, ctypes.pointer(c_shape.c_array) ) ) @@ -71,7 +71,7 @@ def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, / c_shape = CShape(*shape) safe_call( - backend_api.af_constant( + _backend.clib.af_constant( ctypes.pointer(out), ctypes.c_double(number), 4, ctypes.pointer(c_shape.c_array), dtype.c_api_value ) ) diff --git a/arrayfire/backend/c_library/error_handler.py b/arrayfire/backend/_clib_wrapper/_error_handler.py similarity index 67% rename from arrayfire/backend/c_library/error_handler.py rename to arrayfire/backend/_clib_wrapper/_error_handler.py index 35bd945..73e7762 100755 --- a/arrayfire/backend/c_library/error_handler.py +++ b/arrayfire/backend/_clib_wrapper/_error_handler.py @@ -1,7 +1,7 @@ import ctypes from enum import Enum -from arrayfire.backend.api import backend_api +from arrayfire.backend._backend import _backend from arrayfire.dtypes import c_dim_t, to_str @@ -15,5 +15,5 @@ def safe_call(c_err: int) -> None: err_str = ctypes.c_char_p(0) err_len = c_dim_t(0) - backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(err_len)) + _backend.clib.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(err_len)) raise RuntimeError(to_str(err_str)) diff --git a/arrayfire/backend/c_library/indexing.py b/arrayfire/backend/_clib_wrapper/_indexing.py similarity index 94% rename from arrayfire/backend/c_library/indexing.py rename to arrayfire/backend/_clib_wrapper/_indexing.py index 1383798..a20990b 100755 --- a/arrayfire/backend/c_library/indexing.py +++ b/arrayfire/backend/_clib_wrapper/_indexing.py @@ -4,10 +4,10 @@ import math from typing import Any, Union -from arrayfire.backend.api import backend_api +from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var -from .error_handler import safe_call +from ._error_handler import safe_call class _IndexSequence(ctypes.Structure): @@ -225,7 +225,7 @@ def __del__(self) -> None: # converted to basic C types so we have to # build the void_p from the value again. arr = ctypes.c_void_p(self.idx.arr) - safe_call(backend_api.af_release_array(arr)) + safe_call(_backend.clib.af_release_array(arr)) class CIndexStructure: diff --git a/arrayfire/backend/c_library/operators.py b/arrayfire/backend/_clib_wrapper/_operators.py similarity index 77% rename from arrayfire/backend/c_library/operators.py rename to arrayfire/backend/_clib_wrapper/_operators.py index 8105f90..78c5515 100755 --- a/arrayfire/backend/c_library/operators.py +++ b/arrayfire/backend/_clib_wrapper/_operators.py @@ -3,10 +3,10 @@ import ctypes from typing import TYPE_CHECKING, Callable -from arrayfire.backend.api import backend_api +from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var -from .error_handler import safe_call +from ._error_handler import safe_call if TYPE_CHECKING: from arrayfire.array_object import AFArrayType @@ -18,42 +18,42 @@ def add(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__add.htm#ga1dfbee755fedd680f4476803ddfe06a7 """ - return _binary_op(backend_api.af_add, lhs, rhs) + return _binary_op(_backend.clib.af_add, lhs, rhs) def sub(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__sub.htm#ga80ff99a2e186c23614ea9f36ffc6f0a4 """ - return _binary_op(backend_api.af_sub, lhs, rhs) + return _binary_op(_backend.clib.af_sub, lhs, rhs) def mul(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__mul.htm#ga5f7588b2809ff7551d38b6a0bd583a02 """ - return _binary_op(backend_api.af_mul, lhs, rhs) + return _binary_op(_backend.clib.af_mul, lhs, rhs) def div(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__div.htm#ga21f3f97755702692ec8976934e75fde6 """ - return _binary_op(backend_api.af_div, lhs, rhs) + return _binary_op(_backend.clib.af_div, lhs, rhs) def mod(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__mod.htm#ga01924d1b59d8886e46fabd2dc9b27e0f """ - return _binary_op(backend_api.af_mod, lhs, rhs) + return _binary_op(_backend.clib.af_mod, lhs, rhs) def pow(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__pow.htm#ga0f28be1a9c8b176a78c4a47f483e7fc6 """ - return _binary_op(backend_api.af_pow, lhs, rhs) + return _binary_op(_backend.clib.af_pow, lhs, rhs) # Bitwise Operators @@ -64,7 +64,7 @@ def bitnot(arr: AFArrayType, /) -> AFArrayType: source: https://arrayfire.org/docs/group__arith__func__bitnot.htm#gaf97e8a38aab59ed2d3a742515467d01e """ out = ctypes.c_void_p(0) - safe_call(backend_api.af_bitnot(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_bitnot(ctypes.pointer(out), arr)) return out @@ -72,77 +72,77 @@ def bitand(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitand.htm#ga45c0779ade4703708596df11cca98800 """ - return _binary_op(backend_api.af_bitand, lhs, rhs) + return _binary_op(_backend.clib.af_bitand, lhs, rhs) def bitor(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitor.htm#ga84c99f77d1d83fd53f949b4d67b5b210 """ - return _binary_op(backend_api.af_bitor, lhs, rhs) + return _binary_op(_backend.clib.af_bitor, lhs, rhs) def bitxor(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitxor.htm#ga8188620da6b432998e55fdd1fad22100 """ - return _binary_op(backend_api.af_bitxor, lhs, rhs) + return _binary_op(_backend.clib.af_bitxor, lhs, rhs) def bitshiftl(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__shiftl.htm#ga3139645aafe6f045a5cab454e9c13137 """ - return _binary_op(backend_api.af_butshiftl, lhs, rhs) + return _binary_op(_backend.clib.af_butshiftl, lhs, rhs) def bitshiftr(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__shiftr.htm#ga4c06b9977ecf96cdfc83b5dfd1ac4895 """ - return _binary_op(backend_api.af_bitshiftr, lhs, rhs) + return _binary_op(_backend.clib.af_bitshiftr, lhs, rhs) def lt(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/arith_8h.htm#ae7aa04bf23b32bb11c4bab8bdd637103 """ - return _binary_op(backend_api.af_lt, lhs, rhs) + return _binary_op(_backend.clib.af_lt, lhs, rhs) def le(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__le.htm#gad5535ce64dbed46d0773fd494e84e922 """ - return _binary_op(backend_api.af_le, lhs, rhs) + return _binary_op(_backend.clib.af_le, lhs, rhs) def gt(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__gt.htm#ga4e65603259515de8939899a163ebaf9e """ - return _binary_op(backend_api.af_gt, lhs, rhs) + return _binary_op(_backend.clib.af_gt, lhs, rhs) def ge(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__ge.htm#ga4513f212e0b0a22dcf4653e89c85e3d9 """ - return _binary_op(backend_api.af_ge, lhs, rhs) + return _binary_op(_backend.clib.af_ge, lhs, rhs) def eq(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__eq.htm#ga76d2da7716831616bb81effa9e163693 """ - return _binary_op(backend_api.af_eq, lhs, rhs) + return _binary_op(_backend.clib.af_eq, lhs, rhs) def neq(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__neq.htm#gae4ee8bd06a410f259f1493fb811ce441 """ - return _binary_op(backend_api.af_neq, lhs, rhs) + return _binary_op(_backend.clib.af_neq, lhs, rhs) def _binary_op(c_func: Callable, lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: diff --git a/arrayfire/backend/c_library/reduction_operations.py b/arrayfire/backend/_clib_wrapper/_reduction_operations.py similarity index 77% rename from arrayfire/backend/c_library/reduction_operations.py rename to arrayfire/backend/_clib_wrapper/_reduction_operations.py index cf5d953..8ef3ed8 100755 --- a/arrayfire/backend/c_library/reduction_operations.py +++ b/arrayfire/backend/_clib_wrapper/_reduction_operations.py @@ -3,9 +3,9 @@ import ctypes from typing import TYPE_CHECKING, Callable, Union -from arrayfire.backend.api import backend_api +from arrayfire.backend._backend import _backend -from .error_handler import safe_call +from ._error_handler import safe_call if TYPE_CHECKING: from arrayfire.array_object import AFArrayType @@ -13,7 +13,7 @@ def count_all(x: AFArrayType) -> Union[int, float, complex]: # TODO reconsider original arith.count - return _reduce_all(x, backend_api.af_count_all) + return _reduce_all(x, _backend.clib.af_count_all) def _reduce_all(arr: AFArrayType, c_func: Callable) -> Union[int, float, complex]: diff --git a/arrayfire/backend/c_library/unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py similarity index 78% rename from arrayfire/backend/c_library/unsorted.py rename to arrayfire/backend/_clib_wrapper/_unsorted.py index e5c5b19..665bc3d 100755 --- a/arrayfire/backend/c_library/unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -3,15 +3,14 @@ import ctypes from typing import TYPE_CHECKING, Any, Tuple, Union, cast -from arrayfire.backend.api import backend_api -from arrayfire.backend.constants import ArrayBuffer +from arrayfire.backend._backend import _backend from arrayfire.dtypes import CShape, CType, Dtype, c_dim_t, to_str from arrayfire.library.device import PointerSource -from .error_handler import safe_call +from ._error_handler import safe_call if TYPE_CHECKING: - from arrayfire.array_object import AFArrayType + from arrayfire.array_object import AFArrayType, _ArrayBuffer # Array management @@ -24,7 +23,7 @@ def create_handle(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: c_shape = CShape(*shape) safe_call( - backend_api.af_create_handle( + _backend.clib.af_create_handle( ctypes.pointer(out), c_shape.original_shape, ctypes.pointer(c_shape.c_array), dtype.c_api_value ) ) @@ -37,11 +36,11 @@ def retain_array(arr: AFArrayType) -> AFArrayType: """ out = ctypes.c_void_p(0) - safe_call(backend_api.af_retain_array(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_retain_array(ctypes.pointer(out), arr)) return out -def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArrayType: +def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga834be32357616d8ab735087c6f681858 """ @@ -49,7 +48,7 @@ def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer c_shape = CShape(*shape) safe_call( - backend_api.af_create_array( + _backend.clib.af_create_array( ctypes.pointer(out), ctypes.c_void_p(array_buffer.address), c_shape.original_shape, @@ -60,7 +59,7 @@ def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer return out -def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArrayType: +def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaad4fc77f872217e7337cb53bfb623cf5 """ @@ -68,7 +67,7 @@ def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer c_shape = CShape(*shape) safe_call( - backend_api.af_device_array( + _backend.clib.af_device_array( ctypes.pointer(out), ctypes.c_void_p(array_buffer.address), c_shape.original_shape, @@ -82,7 +81,7 @@ def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer def create_strided_array( shape: Tuple[int, ...], dtype: Dtype, - array_buffer: ArrayBuffer, + array_buffer: _ArrayBuffer, offset: CType, strides: Tuple[int, ...], pointer_source: PointerSource, @@ -104,7 +103,7 @@ def create_strided_array( strides += (strides[-1],) * (4 - len(strides)) safe_call( - backend_api.af_create_strided_array( + _backend.clib.af_create_strided_array( ctypes.pointer(out), ctypes.c_void_p(array_buffer.address), offset, @@ -124,7 +123,7 @@ def get_ctype(arr: AFArrayType) -> int: """ out = ctypes.c_int() - safe_call(backend_api.af_get_type(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_get_type(ctypes.pointer(out), arr)) return out.value @@ -134,7 +133,7 @@ def get_elements(arr: AFArrayType) -> int: """ out = c_dim_t(0) - safe_call(backend_api.af_get_elements(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_get_elements(ctypes.pointer(out), arr)) return out.value @@ -144,7 +143,7 @@ def get_numdims(arr: AFArrayType) -> int: """ out = ctypes.c_uint(0) - safe_call(backend_api.af_get_numdims(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_get_numdims(ctypes.pointer(out), arr)) return out.value @@ -158,7 +157,7 @@ def get_dims(arr: AFArrayType) -> Tuple[int, ...]: d3 = c_dim_t(0) safe_call( - backend_api.af_get_dims(ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), arr) + _backend.clib.af_get_dims(ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), arr) ) return (d0.value, d1.value, d2.value, d3.value) @@ -168,7 +167,7 @@ def get_scalar(arr: AFArrayType, dtype: Dtype, /) -> Union[None, int, float, boo source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefe2e343a74a84bd43b588218ecc09a3 """ out = dtype.c_type() - safe_call(backend_api.af_get_scalar(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_get_scalar(ctypes.pointer(out), arr)) return cast(Union[None, int, float, bool, complex], out.value) @@ -177,7 +176,7 @@ def is_empty(arr: AFArrayType) -> bool: source: https://arrayfire.org/docs/group__c__api__mat.htm#ga19c749e95314e1c77d816ad9952fb680 """ out = ctypes.c_bool() - safe_call(backend_api.af_is_empty(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_is_empty(ctypes.pointer(out), arr)) return out.value @@ -187,7 +186,7 @@ def get_data_ptr(arr: AFArrayType, size: int, dtype: Dtype, /) -> ctypes.Array: """ c_shape = dtype.c_type * size ctypes_array = c_shape() - safe_call(backend_api.af_get_data_ptr(ctypes.pointer(ctypes_array), arr)) + safe_call(_backend.clib.af_get_data_ptr(ctypes.pointer(ctypes_array), arr)) return ctypes_array @@ -196,7 +195,7 @@ def copy_array(arr: AFArrayType) -> AFArrayType: source: https://arrayfire.org/docs/group__c__api__mat.htm#ga6040dc6f0eb127402fbf62c1165f0b9d """ out = ctypes.c_void_p(0) - safe_call(backend_api.af_copy_array(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_copy_array(ctypes.pointer(out), arr)) return out @@ -224,7 +223,7 @@ def index_gen( source: https://arrayfire.org/docs/group__index__func__index.htm#ga14a7d149dba0ed0b977335a3df9d91e6 """ out = ctypes.c_void_p(0) - safe_call(backend_api.af_index_gen(ctypes.pointer(out), arr, c_dim_t(ndims), indices.pointer)) + safe_call(_backend.clib.af_index_gen(ctypes.pointer(out), arr, c_dim_t(ndims), indices.pointer)) return out @@ -233,7 +232,7 @@ def transpose(arr: AFArrayType, conjugate: bool, /) -> AFArrayType: https://arrayfire.org/docs/group__blas__func__transpose.htm#ga716b2b9bf190c8f8d0970aef2b57d8e7 """ out = ctypes.c_void_p(0) - safe_call(backend_api.af_transpose(ctypes.pointer(out), arr, conjugate)) + safe_call(_backend.clib.af_transpose(ctypes.pointer(out), arr, conjugate)) return out @@ -243,7 +242,7 @@ def reorder(arr: AFArrayType, ndims: int, /) -> AFArrayType: """ out = ctypes.c_void_p(0) c_shape = CShape(*(tuple(reversed(range(ndims))) + tuple(range(ndims, 4)))) - safe_call(backend_api.af_reorder(ctypes.pointer(out), arr, *c_shape)) + safe_call(_backend.clib.af_reorder(ctypes.pointer(out), arr, *c_shape)) return out @@ -254,9 +253,9 @@ def array_as_str(arr: AFArrayType) -> str: - https://arrayfire.org/docs/group__device__func__free__host.htm#ga3f1149a837a7ebbe8002d5d2244e3370 """ arr_str = ctypes.c_char_p(0) - safe_call(backend_api.af_array_to_string(ctypes.pointer(arr_str), "", arr, 4, True)) + safe_call(_backend.clib.af_array_to_string(ctypes.pointer(arr_str), "", arr, 4, True)) py_str = to_str(arr_str) - safe_call(backend_api.af_free_host(arr_str)) + safe_call(_backend.clib.af_free_host(arr_str)) return py_str @@ -265,7 +264,7 @@ def where(arr: AFArrayType) -> AFArrayType: source: https://arrayfire.org/docs/group__scan__func__where.htm#gafda59a3d25d35238592dd09907be9d07 """ out = ctypes.c_void_p(0) - safe_call(backend_api.af_where(ctypes.pointer(out), arr)) + safe_call(_backend.clib.af_where(ctypes.pointer(out), arr)) return out @@ -275,7 +274,7 @@ def randu(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ out = ctypes.c_void_p(0) c_shape = CShape(*shape) - safe_call(backend_api.af_randu(ctypes.pointer(out), *c_shape, dtype.c_api_value)) + safe_call(_backend.clib.af_randu(ctypes.pointer(out), *c_shape, dtype.c_api_value)) return out @@ -287,7 +286,7 @@ def get_last_error() -> ctypes.c_char_p: source: https://arrayfire.org/docs/exception_8h.htm#a4f0227c17954d343021313f77e695c8e """ out = ctypes.c_char_p(0) - backend_api.af_get_last_error(ctypes.pointer(out), ctypes.pointer(c_dim_t(0))) + _backend.clib.af_get_last_error(ctypes.pointer(out), ctypes.pointer(c_dim_t(0))) return out @@ -298,7 +297,7 @@ def set_backend(backend_c_value: int, /) -> None: """ source: https://arrayfire.org/docs/group__unified__func__setbackend.htm#ga6fde820e8802776b7fc823504b37f1b4 """ - safe_call(backend_api.af_set_backend(backend_c_value)) + safe_call(_backend.clib.af_set_backend(backend_c_value)) return None @@ -307,7 +306,7 @@ def get_backend_count() -> int: source: https://arrayfire.org/docs/group__unified__func__getbackendcount.htm#gad38c2dfedfdabfa264afa46d8664e9cd """ out = ctypes.c_int(0) - safe_call(backend_api.get().af_get_backend_count(ctypes.pointer(out))) + safe_call(_backend.clib.get().af_get_backend_count(ctypes.pointer(out))) return out.value @@ -316,7 +315,7 @@ def get_device_id(arr: AFArrayType, /) -> int: source: https://arrayfire.org/docs/group__unified__func__getdeviceid.htm#ga5d94b64dccd1c7cbc7a3a69fa64888c3 """ out = ctypes.c_int(0) - safe_call(backend_api.get().af_get_device_id(ctypes.pointer(out), arr)) + safe_call(_backend.clib.get().af_get_device_id(ctypes.pointer(out), arr)) return out.value @@ -325,7 +324,7 @@ def get_size_of(dtype: Dtype, /) -> int: source: https://arrayfire.org/docs/util_8h.htm#a8b72cffd10a92a7a2ee7f52dadda5216 """ out = ctypes.c_size_t(0) - safe_call(backend_api.get().af_get_size_of(ctypes.pointer(out), dtype.c_api_value)) + safe_call(_backend.clib.get().af_get_size_of(ctypes.pointer(out), dtype.c_api_value)) return out.value @@ -334,5 +333,5 @@ def get_backend_id(arr: AFArrayType, /) -> int: source: https://arrayfire.org/docs/group__unified__func__getbackendid.htm#ga5fc39e209e1886cf250aec265c0d9079 """ out = ctypes.c_int(0) - safe_call(backend_api.get().af_get_backend_id(ctypes.pointer(out), arr)) + safe_call(_backend.clib.get().af_get_backend_id(ctypes.pointer(out), arr)) return out.value diff --git a/arrayfire/backend/api.py b/arrayfire/backend/api.py deleted file mode 100644 index e760195..0000000 --- a/arrayfire/backend/api.py +++ /dev/null @@ -1,137 +0,0 @@ -__all__ = ["BackendPlatform"] - -import ctypes -import enum -import sys -from enum import Enum -from pathlib import Path -from typing import Iterator, List, Optional - -from arrayfire.logger import logger -from arrayfire.platform import get_platform_config, is_arch_x86 - - -class _LibPrefixes(Enum): - forge = "" - arrayfire = "af" - - -class BackendPlatform(enum.Enum): - unified = 0 # NOTE It is set as Default value on Arrayfire backend - cpu = 1 - cuda = 2 - opencl = 4 - - def __iter__(self) -> Iterator: - # NOTE cpu comes last because we want to keep this order priorty during backend initialization - return iter((self.unified, self.cuda, self.opencl, self.cpu)) - - -class Backend: - platform: BackendPlatform - library: ctypes.CDLL - - def __init__(self) -> None: - self._platform_config = get_platform_config() - - self._load_forge_lib() - self._load_backend_libs() - - def _load_forge_lib(self) -> None: - for libname in self._libnames("forge", _LibPrefixes.forge): - try: - ctypes.cdll.LoadLibrary(str(libname)) - logger.info(f"Loaded {libname}") - break - except OSError: - logger.warning(f"Unable to load {libname}") - pass - - def _load_backend_libs(self) -> None: - for platform in BackendPlatform: - self._load_backend_lib(platform) - - if self.platform: - logger.info(f"Setting {platform.name} as backend.") - break - - if not self.platform and not self.library: - raise RuntimeError( - "Could not load any ArrayFire libraries.\n" - "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information." - ) - - def _load_backend_lib(self, platform: BackendPlatform) -> None: - # NOTE we still set unified cdll to it's original name later, even if the path search is different - name = platform.name if platform != BackendPlatform.unified else "" - - for libname in self._libnames(name, _LibPrefixes.arrayfire): - try: - ctypes.cdll.LoadLibrary(str(libname)) - self.platform = platform - self.library = ctypes.CDLL(str(libname)) - - if platform == BackendPlatform.cuda: - self._load_nvrtc_builtins_lib(libname.parent) - - logger.info(f"Loaded {libname}") - break - except OSError: - logger.warning(f"Unable to load {libname}") - pass - - def _load_nvrtc_builtins_lib(self, lib_path: Path) -> None: - nvrtc_name = self._find_nvrtc_builtins_libname(lib_path) - if nvrtc_name: - ctypes.cdll.LoadLibrary(str(lib_path / nvrtc_name)) - logger.info(f"Loaded {lib_path / nvrtc_name}") - else: - logger.warning("Could not find local nvrtc-builtins library") - - def _libnames(self, name: str, lib: _LibPrefixes, ver_major: Optional[str] = None) -> List[Path]: - post = self._platform_config.lib_postfix if ver_major is None else ver_major - libname = self._platform_config.lib_prefix + lib.value + name + post - - lib64_path = self._platform_config.af_path / "lib64" - search_path = lib64_path if lib64_path.is_dir() else self._platform_config.af_path / "lib" - - site_path = Path(sys.prefix) / "lib64" if not is_arch_x86() else Path(sys.prefix) / "lib" - - # prefer locally packaged arrayfire libraries if they exist - af_module = __import__(__name__) - local_path = Path(af_module.__path__[0]) if af_module.__path__ else Path("") - - libpaths = [Path("", libname), site_path / libname, local_path / libname] - - if self._platform_config.af_path: # prefer specified AF_PATH if exists - return [search_path / libname] + libpaths - else: - libpaths.insert(2, Path(str(search_path), libname)) - return libpaths - - def _find_nvrtc_builtins_libname(self, search_path: Path) -> Optional[str]: - for f in search_path.iterdir(): - if "nvrtc-builtins" in f.name: - return f.name - return None - - -# HACK for osx -# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") -# HACK for windows -# backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") -_backend = Backend() -backend_api = _backend.library - - -def get_backend() -> Backend: - """ - Get the current active backend. - - Returns - ------- - value : Backend - Current active backend. - """ - - return _backend diff --git a/arrayfire/backend/constants.py b/arrayfire/backend/constants.py deleted file mode 100755 index 20c1418..0000000 --- a/arrayfire/backend/constants.py +++ /dev/null @@ -1,7 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class ArrayBuffer: - address: int - length: int = 0 diff --git a/arrayfire/dtypes.py b/arrayfire/dtypes.py index 3417034..ddddc09 100644 --- a/arrayfire/dtypes.py +++ b/arrayfire/dtypes.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Tuple, Type, Union -from arrayfire.platform import is_arch_x86 +from arrayfire.backend._backend import is_arch_x86 CType = Type[ctypes._SimpleCData] python_bool = bool @@ -30,8 +30,8 @@ class Dtype: float16 = Dtype("e", ctypes.c_uint16, "half", 12) float32 = Dtype("f", ctypes.c_float, "float", 0) float64 = Dtype("d", ctypes.c_double, "double", 2) -complex64 = Dtype("F", ctypes.c_float * 2, "float complext", 1) # type: ignore[arg-type] -complex128 = Dtype("D", ctypes.c_double * 2, "double complext", 3) # type: ignore[arg-type] +complex64 = Dtype("F", ctypes.c_float * 2, "float complex", 1) # type: ignore[arg-type] +complex128 = Dtype("D", ctypes.c_double * 2, "double complex", 3) # type: ignore[arg-type] bool = Dtype("b", ctypes.c_bool, "bool", 4) supported_dtypes = ( diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index fdfaeb0..57f2cb1 100644 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -1,7 +1,7 @@ from typing import Callable from arrayfire import Array -from arrayfire.backend import c_library as wrapper +from arrayfire.backend import _clib_wrapper as wrapper class return_copy: diff --git a/arrayfire/platform.py b/arrayfire/platform.py deleted file mode 100755 index df0144e..0000000 --- a/arrayfire/platform.py +++ /dev/null @@ -1,100 +0,0 @@ -import ctypes -import os -import platform -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Iterator - -from arrayfire.version import ARRAYFIRE_VER_MAJOR - - -def is_arch_x86() -> bool: - machine = platform.machine() - return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") - - -class _SupportedPlatforms(Enum): - windows = "Windows" - darwin = "Darwin" # OSX - linux = "Linux" - - @classmethod - def is_cygwin(cls, name: str) -> bool: - return "cyg" in name.lower() - - -@dataclass(frozen=True) -class PlatformConfig: - lib_prefix: str - lib_postfix: str - af_path: Path - cuda_found: bool - - def __iter__(self) -> Iterator: - return iter((self.lib_prefix, self.lib_postfix, self.af_path, self.af_path, self.cuda_found)) - - -def get_platform_config() -> PlatformConfig: - platform_name = platform.system() - cuda_found = False - - try: - af_path = Path(os.environ["AF_PATH"]) - except KeyError: - af_path = None - - try: - cuda_path = Path(os.environ["CUDA_PATH"]) - except KeyError: - cuda_path = None - - if platform_name == _SupportedPlatforms.windows.value or _SupportedPlatforms.is_cygwin(platform_name): - if platform_name == _SupportedPlatforms.windows.value: - # HACK Supressing crashes caused by missing dlls - # http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup - # https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx - ctypes.windll.kernel32.SetErrorMode(0x0001 | 0x0002) # type: ignore[attr-defined] - - if not af_path: - af_path = _find_default_path(f"C:/Program Files/ArrayFire/v{ARRAYFIRE_VER_MAJOR}") - - if cuda_path and (cuda_path / "bin").is_dir() and (cuda_path / "nvvm/bin").is_dir(): - cuda_found = True - - return PlatformConfig("", ".dll", af_path, cuda_found) - - if platform_name == _SupportedPlatforms.darwin.value: - default_cuda_path = Path("/usr/local/cuda/") - - if not af_path: - af_path = _find_default_path("/opt/arrayfire", "/usr/local") - - if not (cuda_path and default_cuda_path.exists()): - cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "/nvvm/lib").is_dir() - - return PlatformConfig("lib", f".{ARRAYFIRE_VER_MAJOR}.dylib", af_path, cuda_found) - - if platform_name == _SupportedPlatforms.linux.value: - default_cuda_path = Path("/usr/local/cuda/") - - if not af_path: - af_path = _find_default_path(f"/opt/arrayfire-{ARRAYFIRE_VER_MAJOR}", "/opt/arrayfire/", "/usr/local/") - - if not (cuda_path and default_cuda_path.exists()): - if "64" in platform.architecture()[0]: # Check either is 64 bit arch is selected - cuda_found = (default_cuda_path / "lib64").is_dir() and (default_cuda_path / "nvvm/lib64").is_dir() - else: - cuda_found = (default_cuda_path / "lib").is_dir() and (default_cuda_path / "nvvm/lib").is_dir() - - return PlatformConfig("lib", f".so.{ARRAYFIRE_VER_MAJOR}", af_path, cuda_found) - - raise OSError(f"{platform_name} is not supported.") - - -def _find_default_path(*args: str) -> Path: - for path in args: - default_path = Path(path) - if default_path.exists(): - return default_path - raise ValueError("None of specified default paths were found.") diff --git a/arrayfire/version.py b/arrayfire/version.py index 85630b2..54d94c7 100644 --- a/arrayfire/version.py +++ b/arrayfire/version.py @@ -1,7 +1,7 @@ import os -_MAJOR = "0" -_MINOR = "1" +_MAJOR = "4" +_MINOR = "0" # On main and in a nightly release the patch should be one ahead of the last # released build. _PATCH = "0" diff --git a/tests/test_operators.py b/tests/test_operators.py index 188828a..72c627a 100755 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -1,11 +1,9 @@ -from typing import Any - from arrayfire import Array from arrayfire.library import operators class TestArithmeticOperators: - def setup_method(self, method: Any) -> None: + def setup_method(self) -> None: self.array1 = Array([1, 2, 3]) self.array2 = Array([4, 5, 6]) From eabc241a8e3f8f6a7ec708a452fc121c1091e263 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Wed, 16 Aug 2023 16:21:10 +0300 Subject: [PATCH 19/31] Remove pointer source from array init --- arrayfire/array_api/_creation_function.py | 8 ++-- arrayfire/array_api/_elementwise_functions.py | 4 +- arrayfire/array_api/_searching_functions.py | 2 +- arrayfire/array_object.py | 45 +++++++------------ arrayfire/backend/_clib_wrapper/__init__.py | 9 +++- arrayfire/backend/_clib_wrapper/_base.py | 3 ++ .../backend/_clib_wrapper/_constant_array.py | 2 +- arrayfire/backend/_clib_wrapper/_indexing.py | 17 ++++++- arrayfire/backend/_clib_wrapper/_operators.py | 2 +- .../_clib_wrapper/_reduction_operations.py | 2 +- arrayfire/backend/_clib_wrapper/_unsorted.py | 2 +- 11 files changed, 55 insertions(+), 41 deletions(-) create mode 100755 arrayfire/backend/_clib_wrapper/_base.py diff --git a/arrayfire/array_api/_creation_function.py b/arrayfire/array_api/_creation_function.py index 9ea4e4f..c920124 100755 --- a/arrayfire/array_api/_creation_function.py +++ b/arrayfire/array_api/_creation_function.py @@ -38,17 +38,17 @@ def asarray( raise OverflowError("Integer out of bounds for array dtypes") if device == Device.cpu or device is None: - pointer_source = PointerSource.host + to_device = False elif device == Device.gpu: - pointer_source = PointerSource.device + to_device = True else: raise ValueError(f"Unsupported device {device!r}") if isinstance(obj, int | float): - afarray = AFArray([obj], dtype=dtype, shape=(1,), pointer_source=pointer_source) + afarray = AFArray([obj], dtype=dtype, shape=(1,), to_device=to_device) return Array._new(afarray) - afarray = AFArray(obj, dtype=dtype, pointer_source=pointer_source) + afarray = AFArray(obj, dtype=dtype, to_device=to_device) return Array._new(afarray) diff --git a/arrayfire/array_api/_elementwise_functions.py b/arrayfire/array_api/_elementwise_functions.py index 9903425..8c0bd22 100755 --- a/arrayfire/array_api/_elementwise_functions.py +++ b/arrayfire/array_api/_elementwise_functions.py @@ -1,9 +1,9 @@ from __future__ import annotations -from ._array_object import Array - from arrayfire.library import operators +from ._array_object import Array + def abs(x: Array, /) -> Array: return NotImplemented diff --git a/arrayfire/array_api/_searching_functions.py b/arrayfire/array_api/_searching_functions.py index fafe009..1ef4438 100755 --- a/arrayfire/array_api/_searching_functions.py +++ b/arrayfire/array_api/_searching_functions.py @@ -18,4 +18,4 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: def where(condition: Array, x1: Array, x2: Array, /) -> Array: - return NotImplemented \ No newline at end of file + return NotImplemented diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index a12a901..3d66360 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -1,19 +1,18 @@ from __future__ import annotations import array as py_array -import ctypes -import enum from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from .backend import _clib_wrapper as wrapper -from .backend._clib_wrapper._indexing import CIndexStructure, IndexStructure from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype from .library.device import PointerSource -# TODO use int | float in operators -> remove bool | complex support +if TYPE_CHECKING: + from ctypes import Array as CArray + from enum import Enum -AFArrayType = ctypes.c_void_p +# TODO use int | float in operators -> remove bool | complex support @dataclass(frozen=True) @@ -25,10 +24,10 @@ class _ArrayBuffer: class Array: def __init__( self, - obj: Union[None, Array, py_array.array, int, AFArrayType, List[Union[int, float]]] = None, + obj: Union[None, Array, py_array.array, int, wrapper.AFArrayType, List[Union[int, float]]] = None, dtype: Union[None, Dtype, str] = None, shape: Tuple[int, ...] = (), - pointer_source: PointerSource = PointerSource.host, + to_device: bool = False, offset: Optional[CType] = None, strides: Optional[Tuple[int, ...]] = None, ) -> None: @@ -62,8 +61,10 @@ def __init__( _type_char = _array.typecode _array_buffer = _ArrayBuffer(*_array.buffer_info()) - elif isinstance(obj, int) or isinstance(obj, AFArrayType): - _array_buffer = _ArrayBuffer(obj if not isinstance(obj, AFArrayType) else obj.value) # type: ignore + elif isinstance(obj, int) or isinstance(obj, wrapper.AFArrayType): + _array_buffer = _ArrayBuffer( + obj if not isinstance(obj, wrapper.AFArrayType) else obj.value # type: ignore[arg-type] + ) if not shape: raise TypeError("Expected to receive the initial shape due to the obj being a data pointer.") @@ -86,7 +87,7 @@ def __init__( raise TypeError("Can not create array of requested type from input data type") if not (offset or strides): - if pointer_source == PointerSource.host: + if not to_device: self.arr = wrapper.create_array(shape, dtype, _array_buffer) return @@ -94,7 +95,7 @@ def __init__( return self.arr = wrapper.create_strided_array( - shape, dtype, _array_buffer, offset, strides, pointer_source # type: ignore[arg-type] + shape, dtype, _array_buffer, offset, strides, PointerSource.device # type: ignore[arg-type] ) # Arithmetic Operators @@ -699,7 +700,7 @@ def __dlpack__(self, *, stream: Union[None, int, Any] = None): # type: ignore[n # TODO implementation and expected return type -> PyCapsule return NotImplemented - def __dlpack_device__(self) -> Tuple[enum.Enum, int]: + def __dlpack_device__(self) -> Tuple[Enum, int]: # TODO return NotImplemented @@ -738,7 +739,7 @@ def __getitem__(self, key: IndexKey, /) -> Array: return out # HACK known issue - out.arr = wrapper.index_gen(self.arr, ndims, key, _get_indices(key)) # type: ignore[arg-type] + out.arr = wrapper.index_gen(self.arr, ndims, key, wrapper.get_indices(key)) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -896,7 +897,7 @@ def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, out.append(ctypes_array[tuple(sub_list)]) # type: ignore[call-overload] # FIXME return out - def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: + def to_ctype_array(self, row_major: bool = False) -> CArray: if self.is_empty(): raise RuntimeError("Can not convert an empty array to ctype.") @@ -916,7 +917,7 @@ def copy(self) -> Array: # BUG: this is not a deep copy return self @classmethod - def from_afarray(cls, array: AFArrayType) -> None: + def from_afarray(cls, array: wrapper.AFArrayType) -> None: cls.arr = array @@ -957,15 +958,3 @@ def _process_c_function(lhs: Union[int, float, Array], rhs: Union[int, float, Ar out.arr = c_function(lhs_array, rhs_array) return out - - -def _get_indices(key: IndexKey) -> CIndexStructure: - indices = CIndexStructure() - - if isinstance(key, tuple): - for n in range(len(key)): - indices[n] = IndexStructure(key[n]) - else: - indices[0] = IndexStructure(key) - - return indices diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index f39f3e3..f0b70f0 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -1,6 +1,10 @@ # flake8: noqa -__all__ = [ +__all__ = ["AFArrayType"] + +from ._base import AFArrayType + +__all__ += [ "add", "sub", "mul", @@ -109,3 +113,6 @@ __all__ += ["create_constant_array"] from ._constant_array import create_constant_array + +__all__ += ["get_indices"] +from ._indexing import get_indices diff --git a/arrayfire/backend/_clib_wrapper/_base.py b/arrayfire/backend/_clib_wrapper/_base.py new file mode 100755 index 0000000..92cbc67 --- /dev/null +++ b/arrayfire/backend/_clib_wrapper/_base.py @@ -0,0 +1,3 @@ +import ctypes + +AFArrayType = ctypes.c_void_p diff --git a/arrayfire/backend/_clib_wrapper/_constant_array.py b/arrayfire/backend/_clib_wrapper/_constant_array.py index 1921107..6646656 100755 --- a/arrayfire/backend/_clib_wrapper/_constant_array.py +++ b/arrayfire/backend/_clib_wrapper/_constant_array.py @@ -9,7 +9,7 @@ from ._error_handler import safe_call if TYPE_CHECKING: - from arrayfire.array_object import AFArrayType + from ._base import AFArrayType def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: diff --git a/arrayfire/backend/_clib_wrapper/_indexing.py b/arrayfire/backend/_clib_wrapper/_indexing.py index a20990b..f71767c 100755 --- a/arrayfire/backend/_clib_wrapper/_indexing.py +++ b/arrayfire/backend/_clib_wrapper/_indexing.py @@ -2,13 +2,16 @@ import ctypes import math -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Tuple, Union from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var from ._error_handler import safe_call +if TYPE_CHECKING: + from arrayfire import Array + class _IndexSequence(ctypes.Structure): """ @@ -246,3 +249,15 @@ def __getitem__(self, idx: int) -> IndexStructure: def __setitem__(self, idx: int, value: IndexStructure) -> None: self.array[idx] = value self.idxs[idx] = value + + +def get_indices(key: Union[int, slice, Tuple[Union[int, slice], ...]]) -> CIndexStructure: + indices = CIndexStructure() + + if isinstance(key, tuple): + for n in range(len(key)): + indices[n] = IndexStructure(key[n]) + else: + indices[0] = IndexStructure(key) + + return indices diff --git a/arrayfire/backend/_clib_wrapper/_operators.py b/arrayfire/backend/_clib_wrapper/_operators.py index 78c5515..7fda982 100755 --- a/arrayfire/backend/_clib_wrapper/_operators.py +++ b/arrayfire/backend/_clib_wrapper/_operators.py @@ -9,7 +9,7 @@ from ._error_handler import safe_call if TYPE_CHECKING: - from arrayfire.array_object import AFArrayType + from ._base import AFArrayType # Arithmetic Operators diff --git a/arrayfire/backend/_clib_wrapper/_reduction_operations.py b/arrayfire/backend/_clib_wrapper/_reduction_operations.py index 8ef3ed8..fed81c7 100755 --- a/arrayfire/backend/_clib_wrapper/_reduction_operations.py +++ b/arrayfire/backend/_clib_wrapper/_reduction_operations.py @@ -8,7 +8,7 @@ from ._error_handler import safe_call if TYPE_CHECKING: - from arrayfire.array_object import AFArrayType + from ._base import AFArrayType def count_all(x: AFArrayType) -> Union[int, float, complex]: diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index 665bc3d..cd2fab2 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -10,7 +10,7 @@ from ._error_handler import safe_call if TYPE_CHECKING: - from arrayfire.array_object import AFArrayType, _ArrayBuffer + from ._base import AFArrayType, _ArrayBuffer # Array management From 1ff492dcec045337520da2762b886d07c753ee80 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 17 Aug 2023 04:04:11 +0300 Subject: [PATCH 20/31] operations refactoring --- arrayfire/__init__.py | 156 ++++- arrayfire/array_api/_array_object.py | 19 + arrayfire/array_api/_elementwise_functions.py | 9 +- arrayfire/array_object.py | 22 +- arrayfire/backend/_backend_functions.py | 44 +- arrayfire/backend/_clib_wrapper/__init__.py | 109 ++++ arrayfire/backend/_clib_wrapper/_operators.py | 385 ++++++++++- arrayfire/dtypes.py | 48 +- arrayfire/library/operators.py | 612 +++++++++++++++++- arrayfire/library/operators2.py | 310 +++++++++ arrayfire/logger.py | 4 +- tests/array_object/test_initialization.py | 2 +- tests/test_dtypes.py | 79 +++ 13 files changed, 1711 insertions(+), 88 deletions(-) mode change 100644 => 100755 arrayfire/library/operators.py create mode 100644 arrayfire/library/operators2.py create mode 100755 tests/test_dtypes.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index be66dfa..ad42594 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -7,8 +7,8 @@ __all__ += ["__arrayfire_version__"] __arrayfire_version__ = ARRAYFIRE_VERSION -__all__ += ["Array"] -from .array_object import Array +__all__ += ["Array", "return_copy"] +from .array_object import Array, return_copy __all__ += [ "int8", @@ -71,6 +71,154 @@ set_backend, ) -__all__ += ["add", "sub"] +# __all__ += [ +# "add", +# "sub", +# "mul", +# "div", +# "mod", +# "pow", +# "bitnot", +# "bitand", +# "bitor", +# "bitxor", +# "bitshiftl", +# "bitshiftr", +# "lt", +# "le", +# "gt", +# "ge", +# "eq", +# "neq", +# "sin", +# "cos", +# "tan", +# "asin", +# "acos", +# "atan", +# "atan2", +# "sinh", +# "cosh", +# "tanh", +# "asinh", +# "acosh", +# "atanh", +# "exp", +# "expm1", +# "log", +# "log1p", +# "log2", +# "log10", +# "sqrt", +# "cbrt", +# "hypot", +# "erf", +# "erfc", +# "tgamma", +# "lgamma", +# "pow2", +# "sign", +# "abs", +# "ceil", +# "floor", +# "round", +# "trunc", +# "isinf", +# "isnan", +# "iszero", +# "isinf", +# "isnan", +# "iszero", +# "isinf", +# "isnan", +# "clamp", +# "arg", +# "conjg", +# "cplx", +# "imag", +# "factorial", +# "maxof", +# "minof", +# "real", +# "rem", +# "root", +# "rsqrt", +# "sigmoid", +# "land", +# "lor", +# "lnot", +# ] -from .library.operators import add, sub + +# from .library.operators import ( +# abs, +# acos, +# acosh, +# add, +# arg, +# asin, +# asinh, +# atan, +# atan2, +# atanh, +# bitand, +# bitnot, +# bitor, +# bitshiftl, +# bitshiftr, +# bitxor, +# cbrt, +# ceil, +# clamp, +# conjg, +# cos, +# cosh, +# cplx, +# div, +# eq, +# erf, +# erfc, +# exp, +# expm1, +# factorial, +# floor, +# ge, +# gt, +# hypot, +# imag, +# isinf, +# isnan, +# iszero, +# land, +# le, +# lgamma, +# lnot, +# log, +# log1p, +# log2, +# log10, +# lor, +# lt, +# maxof, +# minof, +# mod, +# mul, +# neq, +# pow, +# pow2, +# real, +# rem, +# root, +# round, +# rsqrt, +# sigmoid, +# sign, +# sin, +# sinh, +# sqrt, +# sub, +# tan, +# tanh, +# tgamma, +# trunc, +# ) diff --git a/arrayfire/array_api/_array_object.py b/arrayfire/array_api/_array_object.py index 78313c5..a6af9b8 100755 --- a/arrayfire/array_api/_array_object.py +++ b/arrayfire/array_api/_array_object.py @@ -165,6 +165,25 @@ def _new(cls, x: Union[Array, bool, int, float, complex, NestedSequence, Support obj._array = x # type: ignore[assignment] return obj + def __str__(self: Array, /) -> str: + """ + Performs the operation __str__. + """ + return self._array.__str__()#.replace("array", "Array") + + # def __repr__(self: Array, /) -> str: + # """ + # Performs the operation __repr__. + # """ + # suffix = f", dtype={self.dtype.name})" + # if 0 in self.shape: + # prefix = "empty(" + # mid = str(self.shape) + # else: + # prefix = "Array(" + # mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) + # return prefix + mid + suffix + def __abs__(self: Array, /) -> Array: """ Performs the operation __abs__. diff --git a/arrayfire/array_api/_elementwise_functions.py b/arrayfire/array_api/_elementwise_functions.py index 8c0bd22..713ad0b 100755 --- a/arrayfire/array_api/_elementwise_functions.py +++ b/arrayfire/array_api/_elementwise_functions.py @@ -6,7 +6,7 @@ def abs(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.abs(x._array)) def acos(x: Array, /) -> Array: @@ -42,11 +42,11 @@ def atanh(x: Array, /) -> Array: def bitwise_and(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.bitand(x1._array, x2._array)) def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.bitshiftl(x1._array, x2._array)) def bitwise_invert(x: Array, /) -> Array: @@ -189,9 +189,8 @@ def positive(x: Array, /) -> Array: return NotImplemented -# Note: the function name is different here def pow(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.pow(x1._array, x2._array)) def real(x: Array, /) -> Array: diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 3d66360..07a633e 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -2,7 +2,7 @@ import array as py_array from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union from .backend import _clib_wrapper as wrapper from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype @@ -21,6 +21,18 @@ class _ArrayBuffer: length: int = 0 +class return_copy: + # TODO merge with process_c_function in array_object + def __init__(self, func: Callable) -> None: + self.func = func + + def __call__(self, *x: Array) -> Array: + out = Array() + # import ipdb; ipdb.set_trace() + out.arr = self.func(*[item.arr for item in x]) + return out + + class Array: def __init__( self, @@ -904,7 +916,7 @@ def to_ctype_array(self, row_major: bool = False) -> CArray: array = _reorder(self) if row_major else self return wrapper.get_data_ptr(array.arr, array.size, array.dtype) - def copy(self) -> Array: # BUG: this is not a deep copy + def copy(self) -> Array: """ Performs a deep copy of the array. @@ -913,8 +925,8 @@ def copy(self) -> Array: # BUG: this is not a deep copy out: af.Array() An identical copy of self. """ - self.arr = wrapper.copy_array(self.arr) - return self + + return return_copy(wrapper.copy_array)(self) # type: ignore[return-value] @classmethod def from_afarray(cls, array: wrapper.AFArrayType) -> None: @@ -935,7 +947,7 @@ def _reorder(array: Array) -> Array: def _metadata_string(dtype: Dtype, dims: Optional[Tuple[int, ...]] = None) -> str: - return "arrayfire.Array()\n" f"Type: {dtype.typename}\n" f"Dims: {str(dims) if dims else ''}" + return "arrayfire.Array()\n" f"Type: {dtype.name}\n" f"Dims: {str(dims) if dims else ''}" def _process_c_function(lhs: Union[int, float, Array], rhs: Union[int, float, Array], c_function: Any) -> Array: diff --git a/arrayfire/backend/_backend_functions.py b/arrayfire/backend/_backend_functions.py index 483dabc..6f4f756 100755 --- a/arrayfire/backend/_backend_functions.py +++ b/arrayfire/backend/_backend_functions.py @@ -15,49 +15,49 @@ from arrayfire.dtypes import Dtype -def set_backend(platform: Union[BackendType, str]) -> None: +def set_backend(backend_type: Union[BackendType, str]) -> None: """ - Set a specific backend by platform name. + Set a specific backend by backend_type name. Parameters ---------- - platform : Union[BackendType, str] - Name of the backend platform to set. + backend_type : Union[BackendType, str] + Name of the backend backend_type to set. Raises ------ ValueError - If the given platform name is not a valid name for backend platform. + If the given backend_type name is not a valid name for backend backend_type. TypeError - If the given platform is not a valid type for backend platform. + If the given backend_type is not a valid type for backend backend_type. RuntimeError - If the given platform is already the active backend platform. + If the given backend_type is already the active backend backend_type. RuntimeError - If the given platform could not be set as new backend platform. + If the given backend_type could not be set as new backend backend_type. """ backend = get_backend() - current_active_platform = backend.platform + current_active_backend_type = backend.backend_type - if isinstance(platform, str): - if platform not in [d.name for d in BackendType]: - raise ValueError(f"{platform} is not a valid name for backend platform.") - platform = BackendType[platform] + if isinstance(backend_type, str): + if backend_type not in [d.name for d in BackendType]: + raise ValueError(f"{backend_type} is not a valid name for backend backend_type.") + backend_type = BackendType[backend_type] - if not isinstance(platform, BackendType): - raise TypeError(f"{platform} is not a valid type for backend platform.") + if not isinstance(backend_type, BackendType): + raise TypeError(f"{backend_type} is not a valid type for backend backend_type.") - if current_active_platform == platform: - raise RuntimeError(f"{platform} is already the active backend platform.") + if current_active_backend_type == backend_type: + raise RuntimeError(f"{backend_type} is already the active backend backend_type.") - if backend.platform == BackendType.unified: - c_set_backend(platform.value) + if backend.backend_type == BackendType.unified: + c_set_backend(backend_type.value) # NOTE keep in mind that this operation works in-place # FIXME should not access private API - backend._load_backend_lib(platform) + backend._load_backend_lib(backend_type) - if current_active_platform == backend.platform: - raise RuntimeError(f"Could not set {platform} as new backend platform. Consider checking logs.") + if current_active_backend_type == backend.backend_type: + raise RuntimeError(f"Could not set {backend_type} as new backend backend_type. Consider checking logs.") def get_array_backend_name(array: Array) -> str: diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index f0b70f0..cd6c061 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -23,27 +23,136 @@ "ge", "eq", "neq", + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "atan2", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "exp", + "expm1", + "log", + "log1p", + "log2", + "log10", + "sqrt", + "cbrt", + "hypot", + "erf", + "erfc", + "tgamma", + "lgamma", + "pow2", + "sign", + "abs", + "ceil", + "floor", + "round", + "trunc", + "isinf", + "isnan", + "iszero", + "isinf", + "isnan", + "iszero", + "isinf", + "isnan", + "clamp", + "arg", + "conjg", + "cplx", + "imag", + "factorial", + "maxof", + "minof", + "real", + "rem", + "root", + "rsqrt", + "sigmoid", + "land", + "lor", + "lnot", ] from ._operators import ( + abs, + acos, + acosh, add, + arg, + asin, + asinh, + atan, + atan2, + atanh, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, + cbrt, + ceil, + clamp, + conjg, + cos, + cosh, + cplx, div, eq, + erf, + erfc, + exp, + expm1, + factorial, + floor, ge, gt, + hypot, + imag, + isinf, + isnan, + iszero, + land, le, + lgamma, + lnot, + log, + log1p, + log2, + log10, + lor, lt, + maxof, + minof, mod, mul, neq, pow, + pow2, + real, + rem, + root, + round, + rsqrt, + sigmoid, + sign, + sin, + sinh, + sqrt, sub, + tan, + tanh, + tgamma, + trunc, ) __all__ += [ diff --git a/arrayfire/backend/_clib_wrapper/_operators.py b/arrayfire/backend/_clib_wrapper/_operators.py index 7fda982..33809cb 100755 --- a/arrayfire/backend/_clib_wrapper/_operators.py +++ b/arrayfire/backend/_clib_wrapper/_operators.py @@ -1,7 +1,7 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Optional from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var @@ -63,9 +63,7 @@ def bitnot(arr: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitnot.htm#gaf97e8a38aab59ed2d3a742515467d01e """ - out = ctypes.c_void_p(0) - safe_call(_backend.clib.af_bitnot(ctypes.pointer(out), arr)) - return out + return _unary_op(_backend.clib.af_bitnot, arr) def bitand(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: @@ -145,7 +143,386 @@ def neq(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: return _binary_op(_backend.clib.af_neq, lhs, rhs) +# Numeric Functions + + +def clamp(arr: AFArrayType, /, lo: float, hi: float) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__clamp.htm#gac4e785c5c877c7905e56f44ef0cb5e61 + """ + # TODO: check if lo and hi are of type float. Can be ArrayFire array as well + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_clamp(ctypes.pointer(out), arr, lo, hi)) + return out + + +def minof(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__min.htm#ga2b842c2d86df978ff68699aeaafca794 + """ + return _binary_op(_backend.clib.af_minof, lhs, rhs) + + +def maxof(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__max.htm#ga0cd47e70cf82b48730a97c59f494b421 + """ + return _binary_op(_backend.clib.af_maxof, lhs, rhs) + + +def rem(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__clamp.htm#gac4e785c5c877c7905e56f44ef0cb5e61 + """ + return _binary_op(_backend.clib.af_rem, lhs, rhs) + + +def abs(arr: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f + """ + return _unary_op(_backend.clib.af_abs, arr) + + +def arg(arr: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__arg.htm#gad04de0f7948688378dcd3628628a7424 + """ + return _unary_op(_backend.clib.af_arg, arr) + + +def sign(arr: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__sign.htm#ga2d55dfb9b25e0a1316b70f01d5b44b35 + """ + return _unary_op(_backend.clib.af_sign, arr) + + +def round(arr: AFArrayType, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__sign.htm#ga2d55dfb9b25e0a1316b70f01d5b44b35 + """ + return _unary_op(_backend.clib.af_round, arr) + + +def trunc(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_trunc, arr) + + +def floor(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_floor, arr) + + +def ceil(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_ceil, arr) + + +def hypot(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _binary_op(_backend.clib.af_hypot, lhs, rhs) + + +def sin(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_sin, arr) + + +def cos(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_cos, arr) + + +def tan(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_tan, arr) + + +def asin(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_asin, arr) + + +def acos(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_acos, arr) + + +def atan(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_atan, arr) + + +def atan2(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _binary_op(_backend.clib.af_atan2, lhs, rhs) + + +def cplx(lhs: AFArrayType, rhs: Optional[AFArrayType], /) -> AFArrayType: + """ + source: + """ + if rhs is None: + return _unary_op(_backend.clib.af_cplx, lhs) + else: + return _binary_op(_backend.clib.af_cplx2, lhs, rhs) + + +def real(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_real, arr) + + +def imag(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_imag, arr) + + +def conjg(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_conjg, arr) + + +def sinh(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_sinh, arr) + + +def cosh(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_cosh, arr) + + +def tanh(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_tanh, arr) + + +def asinh(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_asinh, arr) + + +def acosh(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_acosh, arr) + + +def atanh(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_atanh, arr) + + +def root(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _binary_op(_backend.clib.af_root, lhs, rhs) + + +def pow2(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_pow2, arr) + + +def sigmoid(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_sigmoid, arr) + + +def exp(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_exp, arr) + + +def expm1(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_expm1, arr) + + +def erf(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_erf, arr) + + +def erfc(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_erfc, arr) + + +def log(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_log, arr) + + +def log1p(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_log1p, arr) + + +def log10(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_log10, arr) + + +def log2(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_log2, arr) + + +def sqrt(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_sqrt, arr) + + +def rsqrt(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_rsqrt, arr) + + +def cbrt(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_cbrt, arr) + + +def factorial(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_factorial, arr) + + +def tgamma(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_tgamma, arr) + + +def lgamma(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_lgamma, arr) + + +def iszero(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_iszero, arr) + + +def isinf(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_isinf, arr) + + +def isnan(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_isnan, arr) + + +def land(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _binary_op(_backend.clib.af_and, lhs, rhs) + + +def lor(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _binary_op(_backend.clib.af_or, lhs, rhs) + + +def lnot(arr: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _unary_op(_backend.clib.af_not, arr) + + def _binary_op(c_func: Callable, lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: out = ctypes.c_void_p(0) safe_call(c_func(ctypes.pointer(out), lhs, rhs, bcast_var.get())) return out + + +def _unary_op(c_func: Callable, arr: AFArrayType, /) -> AFArrayType: + out = ctypes.c_void_p(0) + safe_call(c_func(ctypes.pointer(out), arr)) + return out diff --git a/arrayfire/dtypes.py b/arrayfire/dtypes.py index ddddc09..76dd06e 100644 --- a/arrayfire/dtypes.py +++ b/arrayfire/dtypes.py @@ -7,32 +7,39 @@ from arrayfire.backend._backend import is_arch_x86 CType = Type[ctypes._SimpleCData] -python_bool = bool +_python_bool = bool @dataclass(frozen=True) class Dtype: + name: str typecode: str c_type: CType typename: str c_api_value: int # Internal use only + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"arrayfire.{self.name}(typecode<{self.typecode}>)" + # Specification required -int8 = Dtype("i8", ctypes.c_char, "int8", 4) # HACK int8 - Not Supported, b8? -int16 = Dtype("h", ctypes.c_short, "short int", 10) -int32 = Dtype("i", ctypes.c_int, "int", 5) -int64 = Dtype("l", ctypes.c_longlong, "long int", 8) -uint8 = Dtype("B", ctypes.c_ubyte, "unsigned_char", 7) -uint16 = Dtype("H", ctypes.c_ushort, "unsigned short int", 11) -uint32 = Dtype("I", ctypes.c_uint, "unsigned int", 6) -uint64 = Dtype("L", ctypes.c_ulonglong, "unsigned long int", 9) -float16 = Dtype("e", ctypes.c_uint16, "half", 12) -float32 = Dtype("f", ctypes.c_float, "float", 0) -float64 = Dtype("d", ctypes.c_double, "double", 2) -complex64 = Dtype("F", ctypes.c_float * 2, "float complex", 1) # type: ignore[arg-type] -complex128 = Dtype("D", ctypes.c_double * 2, "double complex", 3) # type: ignore[arg-type] -bool = Dtype("b", ctypes.c_bool, "bool", 4) +int8 = Dtype("int8", "i8", ctypes.c_char, "int8", 4) # HACK int8 - Not Supported, b8? +int16 = Dtype("int16", "h", ctypes.c_short, "short int", 10) +int32 = Dtype("int32", "i", ctypes.c_int, "int", 5) +int64 = Dtype("int64", "l", ctypes.c_longlong, "long int", 8) +uint8 = Dtype("uint8", "B", ctypes.c_ubyte, "unsigned_char", 7) +uint16 = Dtype("uint16", "H", ctypes.c_ushort, "unsigned short int", 11) +uint32 = Dtype("uint32", "I", ctypes.c_uint, "unsigned int", 6) +uint64 = Dtype("uint64", "L", ctypes.c_ulonglong, "unsigned long int", 9) +float16 = Dtype("float16", "e", ctypes.c_uint16, "half", 12) +float32 = Dtype("float32", "f", ctypes.c_float, "float", 0) +float64 = Dtype("float64", "d", ctypes.c_double, "double", 2) +complex64 = Dtype("complex64", "F", ctypes.c_float * 2, "float complex", 1) # type: ignore[arg-type] +complex128 = Dtype("complex128", "D", ctypes.c_double * 2, "double complex", 3) # type: ignore[arg-type] +bool = Dtype("bool", "b", ctypes.c_bool, "bool", 4) supported_dtypes = ( int16, @@ -48,6 +55,7 @@ class Dtype: complex64, complex128, bool, + int8 # BUG if place on top of the list ) @@ -79,10 +87,10 @@ def to_str(c_str: ctypes.c_char_p) -> str: return str(c_str.value.decode("utf-8")) # type: ignore[union-attr] -def implicit_dtype(number: Union[int, float], array_dtype: Dtype) -> Dtype: - if isinstance(number, python_bool): +def implicit_dtype(number: Union[int, float, _python_bool, complex], array_dtype: Dtype) -> Dtype: + if isinstance(number, _python_bool): number_dtype = bool - if isinstance(number, int): + elif isinstance(number, int): number_dtype = int64 elif isinstance(number, float): number_dtype = float64 @@ -111,9 +119,9 @@ def c_api_value_to_dtype(value: int) -> Dtype: raise TypeError("There is no supported dtype that matches passed dtype C API value.") -def str_to_dtype(value: int) -> Dtype: +def str_to_dtype(value: str) -> Dtype: for dtype in supported_dtypes: - if value == dtype.typecode or value == dtype.typename: + if value == dtype.typecode or value == dtype.typename or value == dtype.name: return dtype raise TypeError("There is no supported dtype that matches passed dtype typecode.") diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py old mode 100644 new mode 100755 index 57f2cb1..a828e54 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -1,25 +1,587 @@ -from typing import Callable - -from arrayfire import Array -from arrayfire.backend import _clib_wrapper as wrapper - - -class return_copy: - # TODO merge with process_c_function in array_object - def __init__(self, func: Callable) -> None: - self.func = func - - def __call__(self, x1: Array, x2: Array) -> Array: - out = Array() - out.arr = self.func(x1.arr, x2.arr) - return out - - -@return_copy -def add(x1: Array, x2: Array, /) -> Array: - return wrapper.add(x1, x2) # type: ignore[arg-type, return-value, no-any-return] # FIXME - - -@return_copy -def sub(x1: Array, x2: Array, /) -> Array: - return wrapper.sub(x1, x2) # type: ignore[arg-type, return-value, no-any-return] +from __future__ import annotations + +from typing import Union + +from arrayfire import Array, return_copy +from arrayfire.backend import _clib_wrapper as wrapper + + +@return_copy +def add(x1: Array, x2: Array, /) -> Array: + return wrapper.add(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def sub(x1: Array, x2: Array, /) -> Array: + return wrapper.sub(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def mul(x1: Array, x2: Array, /) -> Array: + return wrapper.mul(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def div(x1: Array, x2: Array, /) -> Array: + return wrapper.div(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def mod(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + """ + Calculate the modulus of two arrays or a scalar and an array. + + Parameters + ---------- + x1 : Union[int, float, Array] + The first array or scalar operand. + x2 : Union[int, float, Array] + The second array or scalar operand. + + Returns + ------- + result : Array + The array containing the modulus values after performing the operation. + + Raises + ------ + ValueError + If both operands are scalars or if the arrays' shapes do not match. + """ + + if isinstance(x1, Array) and isinstance(x2, Array): + if x1.shape != x2.shape: + raise ValueError("Array shapes must match.") + elif not isinstance(x1, Array) and not isinstance(x2, Array): + raise ValueError("At least one operand must be an Array.") + + return wrapper.mod(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def pow(x1: Array, x2: Array, /) -> Array: + """ + source: https://arrayfire.org/docs/group__arith__func__pow.htm#ga0f28be1a9c8b176a78c4a47f483e7fc6 + """ + return wrapper.pow(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitnot(x: Array, /) -> Array: + """ + source: https://arrayfire.org/docs/group__arith__func__bitnot.htm#gaf97e8a38aab59ed2d3a742515467d01e + """ + return wrapper.bitnot(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitand(x1: Array, x2: Array, /) -> Array: + """ + source: https://arrayfire.org/docs/group__arith__func__bitand.htm#ga45c0779ade4703708596df11cca98800 + """ + return wrapper.bitand(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitor(x1: Array, x2: Array, /) -> Array: + """ + source: https://arrayfire.org/docs/group__arith__func__bitor.htm#ga84c99f77d1d83fd53f949b4d67b5b210 + """ + return wrapper.bitor(x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def bitxor(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__bitxor.htm#ga8188620da6b432998e55fdd1fad22100 +# """ +# return _binary_op(_backend.clib.af_bitxor, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def bitshiftl(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__shiftl.htm#ga3139645aafe6f045a5cab454e9c13137 +# """ +# return _binary_op(_backend.clib.af_butshiftl, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def bitshiftr(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__shiftr.htm#ga4c06b9977ecf96cdfc83b5dfd1ac4895 +# """ +# return _binary_op(_backend.clib.af_bitshiftr, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def lt(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/arith_8h.htm#ae7aa04bf23b32bb11c4bab8bdd637103 +# """ +# return _binary_op(_backend.clib.af_lt, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def le(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__le.htm#gad5535ce64dbed46d0773fd494e84e922 +# """ +# return _binary_op(_backend.clib.af_le, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def gt(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__gt.htm#ga4e65603259515de8939899a163ebaf9e +# """ +# return _binary_op(_backend.clib.af_gt, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def ge(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__ge.htm#ga4513f212e0b0a22dcf4653e89c85e3d9 +# """ +# return _binary_op(_backend.clib.af_ge, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def eq(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__eq.htm#ga76d2da7716831616bb81effa9e163693 +# """ +# return _binary_op(_backend.clib.af_eq, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def neq(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__neq.htm#gae4ee8bd06a410f259f1493fb811ce441 +# """ +# return _binary_op(_backend.clib.af_neq, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def clamp(x: Array, /, lo: float, hi: float) -> Array: +# return NotImplemented + + +# # """ +# # source: https://arrayfire.org/docs/group__arith__func__clamp.htm#gac4e785c5c877c7905e56f44ef0cb5e61 +# # """ +# # # TODO: check if lo and hi are of type float. Can be ArrayFire array as well +# # out = ctypes.c_void_p(0) +# # safe_call(_backend.clib.af_clamp(ctypes.pointer(out), arr, lo, hi)) +# # return out + + +# @return_copy +# def minof(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__min.htm#ga2b842c2d86df978ff68699aeaafca794 +# """ +# return _binary_op(_backend.clib.af_minof, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def maxof(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__max.htm#ga0cd47e70cf82b48730a97c59f494b421 +# """ +# return _binary_op(_backend.clib.af_maxof, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def rem(x1: Array, x2: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__clamp.htm#gac4e785c5c877c7905e56f44ef0cb5e61 +# """ +# return _binary_op(_backend.clib.af_rem, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def abs(x: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f +# """ +# return _unary_op(_backend.clib.af_abs, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def arg(x: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__arg.htm#gad04de0f7948688378dcd3628628a7424 +# """ +# return _unary_op(_backend.clib.af_arg, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def sign(x: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__sign.htm#ga2d55dfb9b25e0a1316b70f01d5b44b35 +# """ +# return _unary_op(_backend.clib.af_sign, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def round(x: Array, /) -> Array: +# """ +# source: https://arrayfire.org/docs/group__arith__func__sign.htm#ga2d55dfb9b25e0a1316b70f01d5b44b35 +# """ +# return _unary_op(_backend.clib.af_round, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def trunc(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_trunc, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def floor(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_floor, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def ceil(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_ceil, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def hypot(x1: Array, x2: Array, /) -> Array: +# """ +# source: +# """ +# return _binary_op(_backend.clib.af_hypot, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def sin(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_sin, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def cos(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_cos, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def tan(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_tan, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def asin(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_asin, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def acos(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_acos, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def atan(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_atan, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def atan2(x1: Array, x2: Array, /) -> Array: +# """ +# source: +# """ +# return _binary_op(_backend.clib.af_atan2, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def cplx(x1: Array, x2: Optional[Array], /) -> Array: +# """ +# source: +# """ +# if x2 is None: +# return _unary_op(_backend.clib.af_cplx, x1) # type: ignore[arg-type, return-value] +# else: +# return _binary_op(_backend.clib.af_cplx2, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def real(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_real, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def imag(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_imag, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def conjg(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_conjg, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def sinh(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_sinh, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def cosh(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_cosh, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def tanh(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_tanh, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def asinh(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_asinh, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def acosh(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_acosh, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def atanh(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_atanh, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def root(x1: Array, x2: Array, /) -> Array: +# """ +# source: +# """ +# return _binary_op(_backend.clib.af_root, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def pow2(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_pow2, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def sigmoid(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_sigmoid, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def exp(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_exp, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def expm1(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_expm1, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def erf(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_erf, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def erfc(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_erfc, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def log(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_log, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def log1p(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_log1p, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def log10(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_log10, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def log2(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_log2, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def sqrt(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_sqrt, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def rsqrt(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_rsqrt, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def cbrt(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_cbrt, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def factorial(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_factorial, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def tgamma(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_tgamma, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def lgamma(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_lgamma, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def iszero(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_iszero, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def isinf(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_isinf, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def isnan(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_isnan, x) # type: ignore[arg-type, return-value] + + +# @return_copy +# def land(x1: Array, x2: Array, /) -> Array: +# """ +# source: +# """ +# return _binary_op(_backend.clib.af_and, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def lor(x1: Array, x2: Array, /) -> Array: +# """ +# source: +# """ +# return _binary_op(_backend.clib.af_or, x1, x2) # type: ignore[arg-type, return-value] + + +# @return_copy +# def lnot(x: Array, /) -> Array: +# """ +# source: +# """ +# return _unary_op(_backend.clib.af_not, x) # type: ignore[arg-type, return-value] diff --git a/arrayfire/library/operators2.py b/arrayfire/library/operators2.py new file mode 100644 index 0000000..3dc1361 --- /dev/null +++ b/arrayfire/library/operators2.py @@ -0,0 +1,310 @@ +from typing import Callable + +from arrayfire import Array +from arrayfire.backend import _clib_wrapper as wrapper + + +class return_copy: + # TODO merge with process_c_function in array_object + def __init__(self, func: Callable) -> None: + self.func = func + + def __call__(self, x1: Array, x2: Array) -> Array: + out = Array() + out.arr = self.func(x1.arr, x2.arr) + return out + + +@return_copy +def abs(x: Array, /) -> Array: + return wrapper.abs(x) # type: ignore[arg-type, return-value] + + +@return_copy +def acos(x: Array, /) -> Array: + return wrapper.acos(x) # type: ignore[arg-type, return-value] + + +@return_copy +def acosh(x: Array, /) -> Array: + return wrapper.acosh(x) # type: ignore[arg-type, return-value] + + +@return_copy +def add(x1: Array, x2: Array, /) -> Array: + return wrapper.add(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def asin(x: Array, /) -> Array: + return wrapper.asin(x) # type: ignore[arg-type, return-value] + + +@return_copy +def asinh(x: Array, /) -> Array: + return wrapper.asinh(x) # type: ignore[arg-type, return-value] + + +@return_copy +def atan(x: Array, /) -> Array: + return wrapper.atan(x) # type: ignore[arg-type, return-value] + + +@return_copy +def atan2(x1: Array, x2: Array, /) -> Array: + return wrapper.atan2(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def atanh(x: Array, /) -> Array: + return wrapper.atanh(x) # type: ignore[arg-type, return-value] + + +@return_copy +def bitwise_and(x1: Array, x2: Array, /) -> Array: + return wrapper.bitand(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + return wrapper.bitshiftl(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitwise_invert(x: Array, /) -> Array: + return wrapper.bitnot(x) # type: ignore[arg-type, return-value] + + +@return_copy +def bitwise_or(x1: Array, x2: Array, /) -> Array: + return wrapper.bitor(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + return wrapper.bitshiftr(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def bitwise_xor(x1: Array, x2: Array, /) -> Array: + return wrapper.bitxor(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def ceil(x: Array, /) -> Array: + return wrapper.ceil(x) # type: ignore[arg-type, return-value] + + +@return_copy +def conj(x: Array, /) -> Array: + return wrapper.conjg(x) # type: ignore[arg-type, return-value] + + +@return_copy +def cos(x: Array, /) -> Array: + return wrapper.cos(x) # type: ignore[arg-type, return-value] + + +@return_copy +def cosh(x: Array, /) -> Array: + return wrapper.cosh(x) # type: ignore[arg-type, return-value] + + +@return_copy +def divide(x1: Array, x2: Array, /) -> Array: + return wrapper.div(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def equal(x1: Array, x2: Array, /) -> Array: + return wrapper.eq(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def exp(x: Array, /) -> Array: + return wrapper.exp(x) # type: ignore[arg-type, return-value] + + +@return_copy +def expm1(x: Array, /) -> Array: + return wrapper.expm1(x) # type: ignore[arg-type, return-value] + + +@return_copy +def floor(x: Array, /) -> Array: + return wrapper.floor(x) # type: ignore[arg-type, return-value] + + +@return_copy +def floor_divide(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +@return_copy +def greater(x1: Array, x2: Array, /) -> Array: + return wrapper.gt(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def greater_equal(x1: Array, x2: Array, /) -> Array: + return wrapper.ge(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def imag(x: Array, /) -> Array: + return wrapper.imag(x) # type: ignore[arg-type, return-value] + + +@return_copy +def isfinite(x: Array, /) -> Array: + return NotImplemented + + +@return_copy +def isinf(x: Array, /) -> Array: + return wrapper.isinf(x) # type: ignore[arg-type, return-value] + + +@return_copy +def isnan(x: Array, /) -> Array: + return wrapper.isnan(x) # type: ignore[arg-type, return-value] + + +@return_copy +def less(x1: Array, x2: Array, /) -> Array: + return wrapper.le(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def less_equal(x1: Array, x2: Array, /) -> Array: + return wrapper.lt(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def log(x: Array, /) -> Array: + return wrapper.log(x) # type: ignore[arg-type, return-value] + + +@return_copy +def log1p(x: Array, /) -> Array: + return wrapper.log1p(x) # type: ignore[arg-type, return-value] + + +@return_copy +def log2(x: Array, /) -> Array: + return wrapper.log2(x) # type: ignore[arg-type, return-value] + + +@return_copy +def log10(x: Array, /) -> Array: + return wrapper.log10(x) # type: ignore[arg-type, return-value] + + +@return_copy +def logaddexp(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +@return_copy +def logical_and(x1: Array, x2: Array, /) -> Array: + return wrapper.land(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def logical_not(x: Array, /) -> Array: + return wrapper.lnot(x) # type: ignore[arg-type, return-value] + + +@return_copy +def logical_or(x1: Array, x2: Array, /) -> Array: + return wrapper.lor(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def logical_xor(x1: Array, x2: Array, /) -> Array: + return NotImplemented + + +@return_copy +def multiply(x1: Array, x2: Array, /) -> Array: + return wrapper.mul(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def negative(x: Array, /) -> Array: + return wrapper.sub(0, x) # type: ignore[arg-type, return-value] + + +@return_copy +def not_equal(x1: Array, x2: Array, /) -> Array: + return wrapper.neq(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def positive(x: Array, /) -> Array: + return x # type: ignore[arg-type, return-value] + + +@return_copy +def pow(x1: Array, x2: Array, /) -> Array: + return wrapper.pow(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def real(x: Array, /) -> Array: + return wrapper.real(x) # type: ignore[arg-type, return-value] + + +@return_copy +def remainder(x1: Array, x2: Array, /) -> Array: + return wrapper.rem(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def round(x: Array, /) -> Array: + return wrapper.round(x) # type: ignore[arg-type, return-value] + + +@return_copy +def sign(x: Array, /) -> Array: + return wrapper.sign(x) # type: ignore[arg-type, return-value] + + +@return_copy +def sin(x: Array, /) -> Array: + return wrapper.sin(x) # type: ignore[arg-type, return-value] + + +@return_copy +def sinh(x: Array, /) -> Array: + return wrapper.sinh(x) # type: ignore[arg-type, return-value] + + +@return_copy +def square(x: Array, /) -> Array: + return wrapper.pow(x, 2) # type: ignore[arg-type, return-value] + + +@return_copy +def sqrt(x: Array, /) -> Array: + return wrapper.sqrt(x) # type: ignore[arg-type, return-value] + + +@return_copy +def subtract(x1: Array, x2: Array, /) -> Array: + return wrapper.sub(x1, x2) # type: ignore[arg-type, return-value] + + +@return_copy +def tan(x: Array, /) -> Array: + return wrapper.tan(x) # type: ignore[arg-type, return-value] + + +@return_copy +def tanh(x: Array, /) -> Array: + return wrapper.tanh(x) # type: ignore[arg-type, return-value] + + +@return_copy +def trunc(x: Array, /) -> Array: + return wrapper.trunc(x) # type: ignore[arg-type, return-value] diff --git a/arrayfire/logger.py b/arrayfire/logger.py index 6ca0409..25c48d7 100755 --- a/arrayfire/logger.py +++ b/arrayfire/logger.py @@ -1,14 +1,14 @@ import logging # Configure the logger -logging.basicConfig(level=logging.DEBUG) +logging.basicConfig(level=logging.INFO) # Create a logger logger = logging.getLogger(__name__) # Create a console handler and set the level to DEBUG console_handler = logging.StreamHandler() -console_handler.setLevel(logging.DEBUG) +console_handler.setLevel(logging.INFO) # Create a formatter and attach it to the console handler formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") diff --git a/tests/array_object/test_initialization.py b/tests/array_object/test_initialization.py index 3c1436d..430b7c3 100755 --- a/tests/array_object/test_initialization.py +++ b/tests/array_object/test_initialization.py @@ -22,7 +22,7 @@ (Array(shape=(2, 3)), float32, 2, 6, (2, 3), 2), (Array([1, 2, 3]), float32, 1, 3, (3,), 3), (Array(pyarray.array("f", [1, 2, 3])), float32, 1, 3, (3,), 3), - (Array([1], shape=(1,), dtype=float32), float32, 1, 1, (1,), 1), # BUG + (Array([1], shape=(1,), dtype=float32), float32, 1, 1, (1,), 1), (Array(Array([1])), float32, 1, 1, (1,), 1), ], ) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py new file mode 100755 index 0000000..d7ad295 --- /dev/null +++ b/tests/test_dtypes.py @@ -0,0 +1,79 @@ +import ctypes + +import pytest + +from arrayfire.dtypes import Dtype +from arrayfire.dtypes import bool as af_bool +from arrayfire.dtypes import complex128, float32, float64, implicit_dtype, int8, int32, int64, str_to_dtype, uint16 + + +def test_dtype_str_representation() -> None: + assert str(float32) == "float32" + + +def test_dtype_repr_representation() -> None: + assert repr(float32) == "arrayfire.float32(typecode)" + + +def test_dtype_equality() -> None: + dt1 = Dtype("float32", "f", ctypes.c_float, "float", 0) + dt2 = float32 + assert dt1 == dt2 + + +def test_dtype_inequality() -> None: + assert float32 != int8 + + +@pytest.mark.parametrize( + "number,array_dtype,expected_dtype", + [ + (1, int64, int64), + (1.0, float64, float64), + (1.0, float32, float32), + (True, float32, af_bool), + (1 + 2j, complex128, complex128), + ], +) +def test_implicit_dtype(number: int | float | bool | complex, array_dtype: Dtype, expected_dtype: Dtype) -> None: + result_dtype = implicit_dtype(number, array_dtype) + assert result_dtype == expected_dtype + + +def test_implicit_dtype_raises_error_invalid_array_dtype() -> None: + with pytest.raises(TypeError): + implicit_dtype([1], "invalid_dtype") # type: ignore[arg-type] + + +def test_implicit_dtype_raises_error_invalid_number_type() -> None: + with pytest.raises(TypeError): + implicit_dtype("invalid_number", int32) # type: ignore[arg-type] + + +def test_implicit_dtype_raises_error_invalid_combination() -> None: + with pytest.raises(TypeError): + implicit_dtype("invalid_number", float32) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "value,expected_dtype", + [ + ("i8", int8), + ("int", int32), + ("uint16", uint16), + ("float", float32), + ], +) +def test_str_to_dtype(value: str, expected_dtype: Dtype) -> None: + result_dtype = str_to_dtype(value) + assert result_dtype == expected_dtype + + +def test_str_to_dtype_raises_error() -> None: + with pytest.raises(TypeError): + str_to_dtype("invalid_dtype") + + +def test_str_to_dtype_raises_error_case_insensitive() -> None: + with pytest.raises(TypeError): + str_to_dtype("Int") From 8a333767ace75bf829bc1201daafac83ac1b1d80 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Sat, 19 Aug 2023 03:13:54 +0300 Subject: [PATCH 21/31] Temp commit --- arrayfire/__init__.py | 1 + arrayfire/array_api/_array_object.py | 62 +- arrayfire/array_api/_creation_function.py | 6 +- arrayfire/array_api/_elementwise_functions.py | 102 +-- .../tests/fixme_test_elementwise_functions.py | 105 --- .../tests/test_elementwise_functions.py | 111 +++ arrayfire/array_object.py | 15 +- arrayfire/backend/_clib_wrapper/__init__.py | 6 +- .../backend/_clib_wrapper/_constant_array.py | 13 +- arrayfire/backend/_clib_wrapper/_operators.py | 18 +- arrayfire/dtypes.py | 4 + arrayfire/library/operators.py | 646 +++++++----------- 12 files changed, 488 insertions(+), 601 deletions(-) delete mode 100755 arrayfire/array_api/tests/fixme_test_elementwise_functions.py create mode 100755 arrayfire/array_api/tests/test_elementwise_functions.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index ad42594..9ca05f4 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -26,6 +26,7 @@ "complex128", "bool", ] + from .dtypes import ( bool, complex64, diff --git a/arrayfire/array_api/_array_object.py b/arrayfire/array_api/_array_object.py index a6af9b8..d4026e2 100755 --- a/arrayfire/array_api/_array_object.py +++ b/arrayfire/array_api/_array_object.py @@ -109,37 +109,37 @@ def _promote_scalar(self, scalar): @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: - """ - Normalize inputs to two arg functions to fix type promotion rules - - NumPy deviates from the spec type promotion rules in cases where one - argument is 0-dimensional and the other is not. For example: - - >>> import numpy as np - >>> a = np.array([1.0], dtype=np.float32) - >>> b = np.array(1.0, dtype=np.float64) - >>> np.add(a, b) # The spec says this should be float64 - array([2.], dtype=float32) - - To fix this, we add a dimension to the 0-dimension array before passing it - through. This works because a dimension would be added anyway from - broadcasting, so the resulting shape is the same, but this prevents NumPy - from not promoting the dtype. - """ - # Another option would be to use signature=(x1.dtype, x2.dtype, None), - # but that only works for ufuncs, so we would have to call the ufuncs - # directly in the operator methods. One should also note that this - # sort of trick wouldn't work for functions like searchsorted, which - # don't do normal broadcasting, but there aren't any functions like - # that in the array API namespace. - if x1.ndim == 0 and x2.ndim != 0: - # The _array[None] workaround was chosen because it is relatively - # performant. broadcast_to(x1._array, x2.shape) is much slower. We - # could also manually type promote x2, but that is more complicated - # and about the same performance as this. - x1 = Array._new(x1._array[None]) - elif x2.ndim == 0 and x1.ndim != 0: - x2 = Array._new(x2._array[None]) + # """ + # Normalize inputs to two arg functions to fix type promotion rules + + # NumPy deviates from the spec type promotion rules in cases where one + # argument is 0-dimensional and the other is not. For example: + + # >>> import numpy as np + # >>> a = np.array([1.0], dtype=np.float32) + # >>> b = np.array(1.0, dtype=np.float64) + # >>> np.add(a, b) # The spec says this should be float64 + # array([2.], dtype=float32) + + # To fix this, we add a dimension to the 0-dimension array before passing it + # through. This works because a dimension would be added anyway from + # broadcasting, so the resulting shape is the same, but this prevents NumPy + # from not promoting the dtype. + # """ + # # Another option would be to use signature=(x1.dtype, x2.dtype, None), + # # but that only works for ufuncs, so we would have to call the ufuncs + # # directly in the operator methods. One should also note that this + # # sort of trick wouldn't work for functions like searchsorted, which + # # don't do normal broadcasting, but there aren't any functions like + # # that in the array API namespace. + # if x1.ndim == 0 and x2.ndim != 0: + # # The _array[None] workaround was chosen because it is relatively + # # performant. broadcast_to(x1._array, x2.shape) is much slower. We + # # could also manually type promote x2, but that is more complicated + # # and about the same performance as this. + # x1 = Array._new(x1._array[None]) + # elif x2.ndim == 0 and x1.ndim != 0: + # x2 = Array._new(x2._array[None]) return (x1, x2) @classmethod diff --git a/arrayfire/array_api/_creation_function.py b/arrayfire/array_api/_creation_function.py index c920124..845c9a1 100755 --- a/arrayfire/array_api/_creation_function.py +++ b/arrayfire/array_api/_creation_function.py @@ -44,9 +44,9 @@ def asarray( else: raise ValueError(f"Unsupported device {device!r}") - if isinstance(obj, int | float): - afarray = AFArray([obj], dtype=dtype, shape=(1,), to_device=to_device) - return Array._new(afarray) + # if isinstance(obj, int | float): + # afarray = AFArray([obj], dtype=dtype, shape=(1,), to_device=to_device) + # return Array._new(afarray) afarray = AFArray(obj, dtype=dtype, to_device=to_device) return Array._new(afarray) diff --git a/arrayfire/array_api/_elementwise_functions.py b/arrayfire/array_api/_elementwise_functions.py index 713ad0b..36f2bef 100755 --- a/arrayfire/array_api/_elementwise_functions.py +++ b/arrayfire/array_api/_elementwise_functions.py @@ -6,99 +6,99 @@ def abs(x: Array, /) -> Array: - return Array._new(operators.abs(x._array)) + return Array._new(operators.abs(x._array.arr)) def acos(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.acos(x._array.arr)) def acosh(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.acosh(x._array.arr)) def add(x1: Array, x2: Array, /) -> Array: - return Array._new(operators.add(x1._array, x2._array)) + return Array._new(operators.add(x1._array.arr, x2._array.arr)) def asin(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.asin(x._array.arr)) def asinh(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.asinh(x._array.arr)) def atan(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.atan(x._array.arr)) def atan2(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.atan2(x1._array.arr, x2._array.arr)) def atanh(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.atanh(x._array.arr)) def bitwise_and(x1: Array, x2: Array, /) -> Array: - return Array._new(operators.bitand(x1._array, x2._array)) + return Array._new(operators.bitand(x1._array.arr, x2._array.arr)) def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: - return Array._new(operators.bitshiftl(x1._array, x2._array)) + return Array._new(operators.bitshiftl(x1._array.arr, x2._array.arr)) def bitwise_invert(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.bitnot(x._array.arr)) def bitwise_or(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.bitor(x1._array.arr, x2._array.arr)) def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.bitshiftr(x1._array.arr, x2._array.arr)) def bitwise_xor(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.bitxor(x1._array.arr, x2._array.arr)) def ceil(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.ceil(x._array.arr)) def conj(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.conjg(x._array.arr)) def cos(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.cos(x._array.arr)) def cosh(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.cosh(x._array.arr)) def divide(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.div(x1._array.arr, x2._array.arr)) def equal(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.eq(x1._array.arr, x2._array.arr)) def exp(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.exp(x._array.arr)) def expm1(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.expm1(x._array.arr)) def floor(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.floor(x._array.arr)) def floor_divide(x1: Array, x2: Array, /) -> Array: @@ -106,15 +106,15 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: def greater(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.gt(x1._array.arr, x2._array.arr)) def greater_equal(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.ge(x1._array.arr, x2._array.arr)) def imag(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.imag(x._array.arr)) def isfinite(x: Array, /) -> Array: @@ -122,35 +122,35 @@ def isfinite(x: Array, /) -> Array: def isinf(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.isinf(x._array.arr)) def isnan(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.isnan(x._array.arr)) def less(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.lt(x1._array.arr, x2._array.arr)) def less_equal(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.le(x1._array.arr, x2._array.arr)) def log(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.log(x._array.arr)) def log1p(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.log1p(x._array.arr)) def log2(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.log2(x._array.arr)) def log10(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.log10(x._array.arr)) def logaddexp(x1: Array, x2: Array) -> Array: @@ -158,15 +158,15 @@ def logaddexp(x1: Array, x2: Array) -> Array: def logical_and(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.land(x1._array.arr, x2._array.arr)) def logical_not(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.lnot(x._array.arr)) def logical_or(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.lor(x1._array.arr, x2._array.arr)) def logical_xor(x1: Array, x2: Array, /) -> Array: @@ -174,7 +174,7 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: def multiply(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.mul(x1._array.arr, x2._array.arr)) def negative(x: Array, /) -> Array: @@ -190,31 +190,31 @@ def positive(x: Array, /) -> Array: def pow(x1: Array, x2: Array, /) -> Array: - return Array._new(operators.pow(x1._array, x2._array)) + return Array._new(operators.pow(x1._array.arr, x2._array.arr)) def real(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.real(x._array.arr)) def remainder(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.rem(x1._array.arr, x2._array.arr)) def round(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.round(x._array.arr)) def sign(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.sign(x._array.arr)) def sin(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.sin(x._array.arr)) def sinh(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.sinh(x._array.arr)) def square(x: Array, /) -> Array: @@ -222,20 +222,20 @@ def square(x: Array, /) -> Array: def sqrt(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.sqrt(x._array.arr)) def subtract(x1: Array, x2: Array, /) -> Array: - return NotImplemented + return Array._new(operators.sub(x1._array.arr, x2._array.arr)) def tan(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.tan(x._array.arr)) def tanh(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.tanh(x._array.arr)) def trunc(x: Array, /) -> Array: - return NotImplemented + return Array._new(operators.trunc(x._array.arr)) diff --git a/arrayfire/array_api/tests/fixme_test_elementwise_functions.py b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py deleted file mode 100755 index 9f2a0d6..0000000 --- a/arrayfire/array_api/tests/fixme_test_elementwise_functions.py +++ /dev/null @@ -1,105 +0,0 @@ -from inspect import getfullargspec - -import pytest - -from .. import _elementwise_functions, asarray -from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, integer_dtypes -from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift - - -def nargs(func): - return len(getfullargspec(func).args) - - -def test_function_types(): - # Test that every function accepts only the required input types. We only - # test the negative cases here (error). The positive cases are tested in - # the array API test suite. - - elementwise_function_input_types = { - # "abs": "numeric", - # "acos": "floating-point", - # "acosh": "floating-point", - "add": "numeric", - # "asin": "floating-point", - # "asinh": "floating-point", - # "atan": "floating-point", - # "atan2": "real floating-point", - # "atanh": "floating-point", - # "bitwise_and": "integer or boolean", - # "bitwise_invert": "integer or boolean", - # "bitwise_left_shift": "integer", - # "bitwise_or": "integer or boolean", - # "bitwise_right_shift": "integer", - # "bitwise_xor": "integer or boolean", - # "ceil": "real numeric", - # "conj": "complex floating-point", - # "cos": "floating-point", - # "cosh": "floating-point", - # "divide": "floating-point", - # "equal": "all", - # "exp": "floating-point", - # "expm1": "floating-point", - # "floor": "real numeric", - # "floor_divide": "real numeric", - # "greater": "real numeric", - # "greater_equal": "real numeric", - # "imag": "complex floating-point", - # "isfinite": "numeric", - # "isinf": "numeric", - # "isnan": "numeric", - # "less": "real numeric", - # "less_equal": "real numeric", - # "log": "floating-point", - # "logaddexp": "real floating-point", - # "log10": "floating-point", - # "log1p": "floating-point", - # "log2": "floating-point", - # "logical_and": "boolean", - # "logical_not": "boolean", - # "logical_or": "boolean", - # "logical_xor": "boolean", - # "multiply": "numeric", - # "negative": "numeric", - # "not_equal": "all", - # "positive": "numeric", - # "pow": "numeric", - # "real": "complex floating-point", - # "remainder": "real numeric", - # "round": "numeric", - # "sign": "numeric", - # "sin": "floating-point", - # "sinh": "floating-point", - # "sqrt": "floating-point", - # "square": "numeric", - # "subtract": "numeric", - # "tan": "floating-point", - # "tanh": "floating-point", - # "trunc": "real numeric", - } - - def _array_vals(): - for d in integer_dtypes: - yield asarray(1, dtype=d) - for d in boolean_dtypes: - yield asarray(False, dtype=d) - for d in floating_dtypes: - yield asarray(1.0, dtype=d) - - for x in _array_vals(): - for func_name, types in elementwise_function_input_types.items(): - dtypes = dtype_categories[types] - func = getattr(_elementwise_functions, func_name) - if nargs(func) == 2: - for y in _array_vals(): - if x.dtype not in dtypes or y.dtype not in dtypes: - pytest.raises(TypeError, lambda: func(x, y)) - else: - if x.dtype not in dtypes: - pytest.raises(TypeError, lambda: func(x)) - - -# def test_bitwise_shift_error() -> None: -# # bitwise shift functions should raise when the second argument is negative -# pytest.raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) -# pytest.raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) diff --git a/arrayfire/array_api/tests/test_elementwise_functions.py b/arrayfire/array_api/tests/test_elementwise_functions.py new file mode 100755 index 0000000..c88e244 --- /dev/null +++ b/arrayfire/array_api/tests/test_elementwise_functions.py @@ -0,0 +1,111 @@ +from inspect import getfullargspec +from typing import Callable, Iterator, TYPE_CHECKING + +import pytest + +from .. import _elementwise_functions, asarray +from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, integer_dtypes, int8, real_floating_dtypes, int8 +from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift +from .._array_object import Array + + + +def nargs(func: Callable) -> int: + return len(getfullargspec(func).args) + + +def test_function_types() -> None: + # Test that every function accepts only the required input types. We only + # test the negative cases here (error). The positive cases are tested in + # the array API test suite. + + elementwise_function_input_types = { + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "real floating-point", + "atanh": "floating-point", + # "bitwise_and": "integer or boolean", + # "bitwise_invert": "integer or boolean", + # "bitwise_left_shift": "integer", + # "bitwise_or": "integer or boolean", + # "bitwise_right_shift": "integer", + # "bitwise_xor": "integer or boolean", + "ceil": "real numeric", + # "conj": "complex floating-point", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "real numeric", + "floor_divide": "real numeric", + "greater": "real numeric", + "greater_equal": "real numeric", + # "imag": "complex floating-point", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "real numeric", + "less_equal": "real numeric", + "log": "floating-point", + "logaddexp": "real floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + # "logical_and": "boolean", + # "logical_not": "boolean", + # "logical_or": "boolean", + # "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "numeric", + # "real": "complex floating-point", + "remainder": "real numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "real numeric", + } + + def _array_vals() -> Iterator[Array]: + for dt in integer_dtypes: + if dt in {int8}: + continue + yield asarray([1], dtype=dt) + # for d in boolean_dtypes: + # yield asarray(False, dtype=d) + for dt in real_floating_dtypes: + yield asarray([1.0], dtype=dt) + + for x in _array_vals(): + for func_name, types in elementwise_function_input_types.items(): + dtypes = dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + if nargs(func) == 2: + for y in _array_vals(): + if x.dtype not in dtypes or y.dtype not in dtypes: + pytest.raises(TypeError, lambda: func(x, y)) + else: + if x.dtype not in dtypes: + print(func) + pytest.raises(TypeError, lambda: func(x)) + + +# def test_bitwise_shift_error() -> None: +# # bitwise shift functions should raise when the second argument is negative +# pytest.raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) +# pytest.raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 07a633e..2ba45ad 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union from .backend import _clib_wrapper as wrapper -from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype +from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype, float64 from .library.device import PointerSource if TYPE_CHECKING: @@ -28,8 +28,7 @@ def __init__(self, func: Callable) -> None: def __call__(self, *x: Array) -> Array: out = Array() - # import ipdb; ipdb.set_trace() - out.arr = self.func(*[item.arr for item in x]) + out.arr = self.func(*x) return out @@ -69,7 +68,15 @@ def __init__( _array_buffer = _ArrayBuffer(*obj.buffer_info()) elif isinstance(obj, list): - _array = py_array.array("f", obj) # BUG [True, False] -> dtype: f32 # TODO add int and float + # TODO fix an issue when Array can not be created from float values to complex + if _no_initial_dtype: + arr_typecode = "f" + elif dtype.typecode in py_array.typecodes: + arr_typecode = dtype.typecode + else: + raise TypeError(f"Unsupported typecode. Can not create a python array from '{repr(dtype)}'") + + _array = py_array.array(arr_typecode, obj) _type_char = _array.typecode _array_buffer = _ArrayBuffer(*_array.buffer_info()) diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index cd6c061..6fbe9b7 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -67,7 +67,8 @@ "clamp", "arg", "conjg", - "cplx", + "cplx1", + "cplx2", "imag", "factorial", "maxof", @@ -105,7 +106,8 @@ conjg, cos, cosh, - cplx, + cplx1, + cplx2, div, eq, erf, diff --git a/arrayfire/backend/_clib_wrapper/_constant_array.py b/arrayfire/backend/_clib_wrapper/_constant_array.py index 6646656..6ada6e9 100755 --- a/arrayfire/backend/_clib_wrapper/_constant_array.py +++ b/arrayfire/backend/_clib_wrapper/_constant_array.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Tuple, Union from arrayfire.backend._backend import _backend -from arrayfire.dtypes import CShape, Dtype, implicit_dtype, int64, uint64 +from arrayfire.dtypes import CShape, Dtype, implicit_dtype, int64, uint64, is_complex_dtype, complex64, complex128 from ._error_handler import safe_call @@ -12,7 +12,7 @@ from ._base import AFArrayType -def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def _constant_complex(number: Union[int, float, complex], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga5a083b1f3cd8a72a41f151de3bdea1a2 """ @@ -78,14 +78,11 @@ def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, / return out -def create_constant_array(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def create_constant_array(number: Union[int, float, complex], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: dtype = implicit_dtype(number, dtype) - # NOTE complex is not supported in Data API - # if isinstance(number, complex): - # if dtype != complex64 and dtype != complex128: - # dtype = complex64 - # return _constant_complex(number, shape, dtype) + if isinstance(number, complex): + return _constant_complex(number, shape, dtype if is_complex_dtype(dtype) else complex64) if dtype == int64: return _constant_long(number, shape, dtype) diff --git a/arrayfire/backend/_clib_wrapper/_operators.py b/arrayfire/backend/_clib_wrapper/_operators.py index 33809cb..2832cc8 100755 --- a/arrayfire/backend/_clib_wrapper/_operators.py +++ b/arrayfire/backend/_clib_wrapper/_operators.py @@ -1,7 +1,7 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var @@ -91,7 +91,7 @@ def bitshiftl(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__shiftl.htm#ga3139645aafe6f045a5cab454e9c13137 """ - return _binary_op(_backend.clib.af_butshiftl, lhs, rhs) + return _binary_op(_backend.clib.af_bitshiftl, lhs, rhs) def bitshiftr(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: @@ -282,14 +282,18 @@ def atan2(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: return _binary_op(_backend.clib.af_atan2, lhs, rhs) -def cplx(lhs: AFArrayType, rhs: Optional[AFArrayType], /) -> AFArrayType: +def cplx1(arr: AFArrayType, /) -> AFArrayType: """ source: """ - if rhs is None: - return _unary_op(_backend.clib.af_cplx, lhs) - else: - return _binary_op(_backend.clib.af_cplx2, lhs, rhs) + return _unary_op(_backend.clib.af_cplx, arr) + + +def cplx2(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: + """ + source: + """ + return _binary_op(_backend.clib.af_cplx2, lhs, rhs) def real(arr: AFArrayType, /) -> AFArrayType: diff --git a/arrayfire/dtypes.py b/arrayfire/dtypes.py index 76dd06e..0c526da 100644 --- a/arrayfire/dtypes.py +++ b/arrayfire/dtypes.py @@ -59,6 +59,10 @@ def __repr__(self) -> str: ) +def is_complex_dtype(dtype: Dtype) -> _python_bool: + return dtype in {complex64, complex128} + + c_dim_t = ctypes.c_int if is_arch_x86() else ctypes.c_longlong ShapeType = Tuple[int, ...] diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index a828e54..2606c9e 100755 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -4,26 +4,27 @@ from arrayfire import Array, return_copy from arrayfire.backend import _clib_wrapper as wrapper +from arrayfire.dtypes import is_complex_dtype @return_copy def add(x1: Array, x2: Array, /) -> Array: - return wrapper.add(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.add(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy def sub(x1: Array, x2: Array, /) -> Array: - return wrapper.sub(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.sub(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy def mul(x1: Array, x2: Array, /) -> Array: - return wrapper.mul(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.mul(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy def div(x1: Array, x2: Array, /) -> Array: - return wrapper.div(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.div(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy @@ -49,117 +50,76 @@ def mod(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: If both operands are scalars or if the arrays' shapes do not match. """ - if isinstance(x1, Array) and isinstance(x2, Array): - if x1.shape != x2.shape: - raise ValueError("Array shapes must match.") - elif not isinstance(x1, Array) and not isinstance(x2, Array): - raise ValueError("At least one operand must be an Array.") + _check_operands_fit_requirements(x1, x2) - return wrapper.mod(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.mod(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy -def pow(x1: Array, x2: Array, /) -> Array: - """ - source: https://arrayfire.org/docs/group__arith__func__pow.htm#ga0f28be1a9c8b176a78c4a47f483e7fc6 - """ - return wrapper.pow(x1, x2) # type: ignore[arg-type, return-value] +def pow(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) + + return wrapper.pow(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy def bitnot(x: Array, /) -> Array: - """ - source: https://arrayfire.org/docs/group__arith__func__bitnot.htm#gaf97e8a38aab59ed2d3a742515467d01e - """ - return wrapper.bitnot(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.bitnot(x.arr) # type: ignore[arg-type, return-value] @return_copy def bitand(x1: Array, x2: Array, /) -> Array: - """ - source: https://arrayfire.org/docs/group__arith__func__bitand.htm#ga45c0779ade4703708596df11cca98800 - """ - return wrapper.bitand(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.bitand(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @return_copy def bitor(x1: Array, x2: Array, /) -> Array: - """ - source: https://arrayfire.org/docs/group__arith__func__bitor.htm#ga84c99f77d1d83fd53f949b4d67b5b210 - """ - return wrapper.bitor(x1, x2) # type: ignore[arg-type, return-value] + return wrapper.bitor(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def bitxor(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__bitxor.htm#ga8188620da6b432998e55fdd1fad22100 -# """ -# return _binary_op(_backend.clib.af_bitxor, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def bitxor(x1: Array, x2: Array, /) -> Array: + return wrapper.bitxor(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def bitshiftl(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__shiftl.htm#ga3139645aafe6f045a5cab454e9c13137 -# """ -# return _binary_op(_backend.clib.af_butshiftl, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def bitshiftl(x1: Array, x2: Array, /) -> Array: + return wrapper.bitshiftl(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def bitshiftr(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__shiftr.htm#ga4c06b9977ecf96cdfc83b5dfd1ac4895 -# """ -# return _binary_op(_backend.clib.af_bitshiftr, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def bitshiftr(x1: Array, x2: Array, /) -> Array: + return wrapper.bitshiftr(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def lt(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/arith_8h.htm#ae7aa04bf23b32bb11c4bab8bdd637103 -# """ -# return _binary_op(_backend.clib.af_lt, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def lt(x1: Array, x2: Array, /) -> Array: + return wrapper.lt(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def le(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__le.htm#gad5535ce64dbed46d0773fd494e84e922 -# """ -# return _binary_op(_backend.clib.af_le, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def le(x1: Array, x2: Array, /) -> Array: + return wrapper.le(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def gt(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__gt.htm#ga4e65603259515de8939899a163ebaf9e -# """ -# return _binary_op(_backend.clib.af_gt, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def gt(x1: Array, x2: Array, /) -> Array: + return wrapper.gt(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def ge(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__ge.htm#ga4513f212e0b0a22dcf4653e89c85e3d9 -# """ -# return _binary_op(_backend.clib.af_ge, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def ge(x1: Array, x2: Array, /) -> Array: + return wrapper.ge(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def eq(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__eq.htm#ga76d2da7716831616bb81effa9e163693 -# """ -# return _binary_op(_backend.clib.af_eq, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def eq(x1: Array, x2: Array, /) -> Array: + return wrapper.eq(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def neq(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__neq.htm#gae4ee8bd06a410f259f1493fb811ce441 -# """ -# return _binary_op(_backend.clib.af_neq, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def neq(x1: Array, x2: Array, /) -> Array: + return wrapper.neq(x1.arr, x2.arr) # type: ignore[arg-type, return-value] # @return_copy @@ -176,412 +136,318 @@ def bitor(x1: Array, x2: Array, /) -> Array: # # return out -# @return_copy -# def minof(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__min.htm#ga2b842c2d86df978ff68699aeaafca794 -# """ -# return _binary_op(_backend.clib.af_minof, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def minof(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) + return wrapper.minof(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def maxof(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__max.htm#ga0cd47e70cf82b48730a97c59f494b421 -# """ -# return _binary_op(_backend.clib.af_maxof, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def maxof(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) -# @return_copy -# def rem(x1: Array, x2: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__clamp.htm#gac4e785c5c877c7905e56f44ef0cb5e61 -# """ -# return _binary_op(_backend.clib.af_rem, x1, x2) # type: ignore[arg-type, return-value] + return wrapper.maxof(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def abs(x: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f -# """ -# return _unary_op(_backend.clib.af_abs, x) # type: ignore[arg-type, return-value] +@return_copy +def rem(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) + return wrapper.rem(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def arg(x: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__arg.htm#gad04de0f7948688378dcd3628628a7424 -# """ -# return _unary_op(_backend.clib.af_arg, x) # type: ignore[arg-type, return-value] +@return_copy +def abs(x: Array, /) -> Array: + return wrapper.abs(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def sign(x: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__sign.htm#ga2d55dfb9b25e0a1316b70f01d5b44b35 -# """ -# return _unary_op(_backend.clib.af_sign, x) # type: ignore[arg-type, return-value] +@return_copy +def arg(x: Array, /) -> Array: + return wrapper.arg(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def round(x: Array, /) -> Array: -# """ -# source: https://arrayfire.org/docs/group__arith__func__sign.htm#ga2d55dfb9b25e0a1316b70f01d5b44b35 -# """ -# return _unary_op(_backend.clib.af_round, x) # type: ignore[arg-type, return-value] +@return_copy +def sign(x: Array, /) -> Array: + return wrapper.sign(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def trunc(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_trunc, x) # type: ignore[arg-type, return-value] +@return_copy +def round(x: Array, /) -> Array: + return wrapper.round(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def floor(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_floor, x) # type: ignore[arg-type, return-value] +@return_copy +def trunc(x: Array, /) -> Array: + return wrapper.trunc(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def ceil(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_ceil, x) # type: ignore[arg-type, return-value] +@return_copy +def floor(x: Array, /) -> Array: + return wrapper.floor(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def hypot(x1: Array, x2: Array, /) -> Array: -# """ -# source: -# """ -# return _binary_op(_backend.clib.af_hypot, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def ceil(x: Array, /) -> Array: + return wrapper.ceil(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def sin(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_sin, x) # type: ignore[arg-type, return-value] +@return_copy +def hypot(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) -# @return_copy -# def cos(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_cos, x) # type: ignore[arg-type, return-value] + return wrapper.hypot(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def tan(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_tan, x) # type: ignore[arg-type, return-value] +@return_copy +def sin(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.sin(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def asin(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_asin, x) # type: ignore[arg-type, return-value] +@return_copy +def cos(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.cos(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def acos(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_acos, x) # type: ignore[arg-type, return-value] +@return_copy +def tan(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.tan(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def atan(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_atan, x) # type: ignore[arg-type, return-value] +@return_copy +def asin(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.asin(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def atan2(x1: Array, x2: Array, /) -> Array: -# """ -# source: -# """ -# return _binary_op(_backend.clib.af_atan2, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def acos(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.acos(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def cplx(x1: Array, x2: Optional[Array], /) -> Array: -# """ -# source: -# """ -# if x2 is None: -# return _unary_op(_backend.clib.af_cplx, x1) # type: ignore[arg-type, return-value] -# else: -# return _binary_op(_backend.clib.af_cplx2, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def atan(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.atan(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def real(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_real, x) # type: ignore[arg-type, return-value] +@return_copy +def atan2(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) + return wrapper.atan2(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def imag(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_imag, x) # type: ignore[arg-type, return-value] +@return_copy +def cplx(x1: Union[int, float, Array], x2: Union[int, float, Array, None], /) -> Array: + if x2 is None: + return wrapper.cplx1(x1) # type: ignore[arg-type, return-value] + else: + return wrapper.cplx2(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def conjg(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_conjg, x) # type: ignore[arg-type, return-value] +@return_copy +def real(x: Array, /) -> Array: -# @return_copy -# def sinh(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_sinh, x) # type: ignore[arg-type, return-value] + return wrapper.real(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def cosh(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_cosh, x) # type: ignore[arg-type, return-value] +@return_copy +def imag(x: Array, /) -> Array: + return wrapper.imag(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def tanh(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_tanh, x) # type: ignore[arg-type, return-value] +@return_copy +def conjg(x: Array, /) -> Array: + return wrapper.conjg(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def asinh(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_asinh, x) # type: ignore[arg-type, return-value] +@return_copy +def sinh(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.sinh(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def acosh(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_acosh, x) # type: ignore[arg-type, return-value] +@return_copy +def cosh(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.cosh(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def atanh(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_atanh, x) # type: ignore[arg-type, return-value] +@return_copy +def tanh(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.tanh(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def root(x1: Array, x2: Array, /) -> Array: -# """ -# source: -# """ -# return _binary_op(_backend.clib.af_root, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def asinh(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.asinh(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def pow2(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_pow2, x) # type: ignore[arg-type, return-value] +@return_copy +def acosh(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.acosh(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def sigmoid(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_sigmoid, x) # type: ignore[arg-type, return-value] +@return_copy +def atanh(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.atanh(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def exp(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_exp, x) # type: ignore[arg-type, return-value] +@return_copy +def root(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: + _check_operands_fit_requirements(x1, x2) + return wrapper.root(x1.arr, x2.arr) -# @return_copy -# def expm1(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_expm1, x) # type: ignore[arg-type, return-value] +@return_copy +def pow2(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.pow2(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def erf(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_erf, x) # type: ignore[arg-type, return-value] +@return_copy +def sigmoid(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.sigmoid(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def erfc(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_erfc, x) # type: ignore[arg-type, return-value] +@return_copy +def exp(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.exp(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def log(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_log, x) # type: ignore[arg-type, return-value] +@return_copy +def expm1(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.expm1(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def log1p(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_log1p, x) # type: ignore[arg-type, return-value] +@return_copy +def erf(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.erf(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def log10(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_log10, x) # type: ignore[arg-type, return-value] +@return_copy +def erfc(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.erfc(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def log2(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_log2, x) # type: ignore[arg-type, return-value] +@return_copy +def log(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.log(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def sqrt(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_sqrt, x) # type: ignore[arg-type, return-value] +@return_copy +def log1p(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.log1p(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def rsqrt(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_rsqrt, x) # type: ignore[arg-type, return-value] + +@return_copy +def log10(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.log10(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def cbrt(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_cbrt, x) # type: ignore[arg-type, return-value] +@return_copy +def log2(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.log2(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def factorial(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_factorial, x) # type: ignore[arg-type, return-value] +@return_copy +def sqrt(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.sqrt(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def tgamma(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_tgamma, x) # type: ignore[arg-type, return-value] +@return_copy +def rsqrt(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.rsqrt(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def lgamma(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_lgamma, x) # type: ignore[arg-type, return-value] +@return_copy +def cbrt(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.cbrt(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def iszero(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_iszero, x) # type: ignore[arg-type, return-value] +@return_copy +def factorial(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.factorial(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def isinf(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_isinf, x) # type: ignore[arg-type, return-value] +@return_copy +def tgamma(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.tgamma(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def isnan(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_isnan, x) # type: ignore[arg-type, return-value] +@return_copy +def lgamma(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.lgamma(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def land(x1: Array, x2: Array, /) -> Array: -# """ -# source: -# """ -# return _binary_op(_backend.clib.af_and, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def iszero(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.iszero(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def lor(x1: Array, x2: Array, /) -> Array: -# """ -# source: -# """ -# return _binary_op(_backend.clib.af_or, x1, x2) # type: ignore[arg-type, return-value] +@return_copy +def isinf(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.isinf(x.arr) # type: ignore[arg-type, return-value] -# @return_copy -# def lnot(x: Array, /) -> Array: -# """ -# source: -# """ -# return _unary_op(_backend.clib.af_not, x) # type: ignore[arg-type, return-value] +@return_copy +def isnan(x: Array, /) -> Array: + _check_array_values_not_complex(x) + return wrapper.isnan(x.arr) # type: ignore[arg-type, return-value] + + +@return_copy +def land(x1: Array, x2: Array, /) -> Array: + return wrapper.land(x1.arr, x2.arr) # type: ignore[arg-type, return-value] + + +@return_copy +def lor(x1: Array, x2: Array, /) -> Array: + return wrapper.lor(x1.arr, x2.arr) # type: ignore[arg-type, return-value] + + +@return_copy +def lnot(x: Array, /) -> Array: + return wrapper.lnot(x.arr) # type: ignore[arg-type, return-value] + + +def _check_operands_fit_requirements(x1: Union[int, float, Array], x2: Union[int, float, Array]) -> None: + if isinstance(x1, Array) and isinstance(x2, Array): + if x1.shape != x2.shape: + raise ValueError("Array shapes must match.") + + if not isinstance(x1, Array) and not isinstance(x2, Array): + raise ValueError("At least one operand must be an Array.") + + +def _check_array_values_not_complex(x: Array) -> None: + # if is_complex_dtype(x.dtype): + # raise TypeError("Values of an Array should not be the complex numbers.") + pass From cb849e3782744a8c772adbd9fee5570b49f00929 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 1 Sep 2023 21:54:53 +0300 Subject: [PATCH 22/31] Add lots of code --- arrayfire/__init__.py | 305 +++++++++--------- arrayfire/array_api/_array_object.py | 2 +- ...py => fixme_test_elementwise_functions.py} | 222 ++++++------- arrayfire/array_object.py | 136 ++++---- arrayfire/backend/_backend.py | 1 + arrayfire/backend/_backend_functions.py | 131 +++++++- arrayfire/backend/_clib_wrapper/__init__.py | 2 + arrayfire/backend/_clib_wrapper/_unsorted.py | 47 +++ arrayfire/library/data.py | 130 ++++++++ arrayfire/library/operators.py | 4 +- arrayfire/library/utils.py | 51 ++- examples/helloworld.py | 57 ++++ tests/_helpers.py | 6 + ...perators.py => test_operator_overrides.py} | 10 +- tests/test_data.py | 67 ++++ tests/test_dtypes.py | 3 +- tests/test_operators.py | 11 + 17 files changed, 838 insertions(+), 347 deletions(-) rename arrayfire/array_api/tests/{test_elementwise_functions.py => fixme_test_elementwise_functions.py} (97%) mode change 100755 => 100644 create mode 100644 arrayfire/library/data.py create mode 100644 examples/helloworld.py create mode 100644 tests/_helpers.py rename tests/array_object/{test_operators.py => test_operator_overrides.py} (91%) mode change 100755 => 100644 create mode 100644 tests/test_data.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 9ca05f4..7bb8c64 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -72,154 +72,161 @@ set_backend, ) -# __all__ += [ -# "add", -# "sub", -# "mul", -# "div", -# "mod", -# "pow", -# "bitnot", -# "bitand", -# "bitor", -# "bitxor", -# "bitshiftl", -# "bitshiftr", -# "lt", -# "le", -# "gt", -# "ge", -# "eq", -# "neq", -# "sin", -# "cos", -# "tan", -# "asin", -# "acos", -# "atan", -# "atan2", -# "sinh", -# "cosh", -# "tanh", -# "asinh", -# "acosh", -# "atanh", -# "exp", -# "expm1", -# "log", -# "log1p", -# "log2", -# "log10", -# "sqrt", -# "cbrt", -# "hypot", -# "erf", -# "erfc", -# "tgamma", -# "lgamma", -# "pow2", -# "sign", -# "abs", -# "ceil", -# "floor", -# "round", -# "trunc", -# "isinf", -# "isnan", -# "iszero", -# "isinf", -# "isnan", -# "iszero", -# "isinf", -# "isnan", -# "clamp", -# "arg", -# "conjg", -# "cplx", -# "imag", -# "factorial", -# "maxof", -# "minof", -# "real", -# "rem", -# "root", -# "rsqrt", -# "sigmoid", -# "land", -# "lor", -# "lnot", -# ] +__all__ += [ + "add", + "sub", + "mul", + "div", + "mod", + "pow", + "bitnot", + "bitand", + "bitor", + "bitxor", + "bitshiftl", + "bitshiftr", + "lt", + "le", + "gt", + "ge", + "eq", + "neq", + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "atan2", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + "exp", + "expm1", + "log", + "log1p", + "log2", + "log10", + "sqrt", + "cbrt", + "hypot", + "erf", + "erfc", + "tgamma", + "lgamma", + "pow2", + "sign", + "abs", + "ceil", + "floor", + "round", + "trunc", + "isinf", + "isnan", + "iszero", + "isinf", + "isnan", + "iszero", + "isinf", + "isnan", + "clamp", + "arg", + "conjg", + "cplx", + "imag", + "factorial", + "maxof", + "minof", + "real", + "rem", + "root", + "rsqrt", + "sigmoid", + "land", + "lor", + "lnot", +] + + +from .library.operators import ( + abs, + acos, + acosh, + add, + arg, + asin, + asinh, + atan, + atan2, + atanh, + bitand, + bitnot, + bitor, + bitshiftl, + bitshiftr, + bitxor, + cbrt, + ceil, + conjg, + cos, + cosh, + cplx, + div, + eq, + erf, + erfc, + exp, + expm1, + factorial, + floor, + ge, + gt, + hypot, + imag, + isinf, + isnan, + iszero, + land, + le, + lgamma, + lnot, + log, + log1p, + log2, + log10, + lor, + lt, + maxof, + minof, + mod, + mul, + neq, + pow, + pow2, + real, + rem, + root, + round, + rsqrt, + sigmoid, + sign, + sin, + sinh, + sqrt, + sub, + tan, + tanh, + tgamma, + trunc, +) + +__all__ += [ + "constant", + "range" +] -# from .library.operators import ( -# abs, -# acos, -# acosh, -# add, -# arg, -# asin, -# asinh, -# atan, -# atan2, -# atanh, -# bitand, -# bitnot, -# bitor, -# bitshiftl, -# bitshiftr, -# bitxor, -# cbrt, -# ceil, -# clamp, -# conjg, -# cos, -# cosh, -# cplx, -# div, -# eq, -# erf, -# erfc, -# exp, -# expm1, -# factorial, -# floor, -# ge, -# gt, -# hypot, -# imag, -# isinf, -# isnan, -# iszero, -# land, -# le, -# lgamma, -# lnot, -# log, -# log1p, -# log2, -# log10, -# lor, -# lt, -# maxof, -# minof, -# mod, -# mul, -# neq, -# pow, -# pow2, -# real, -# rem, -# root, -# round, -# rsqrt, -# sigmoid, -# sign, -# sin, -# sinh, -# sqrt, -# sub, -# tan, -# tanh, -# tgamma, -# trunc, -# ) +from arrayfire.library.data import constant, range diff --git a/arrayfire/array_api/_array_object.py b/arrayfire/array_api/_array_object.py index d4026e2..6e65056 100755 --- a/arrayfire/array_api/_array_object.py +++ b/arrayfire/array_api/_array_object.py @@ -27,7 +27,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Array: "Use an array creation function, such as asarray(), instead." ) - def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: + def _check_allowed_dtypes(self, other: Union[bool, int, float, Array], dtype_category: str, op: str) -> Array: """ Helper function for operators to only allow specific input dtypes diff --git a/arrayfire/array_api/tests/test_elementwise_functions.py b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py old mode 100755 new mode 100644 similarity index 97% rename from arrayfire/array_api/tests/test_elementwise_functions.py rename to arrayfire/array_api/tests/fixme_test_elementwise_functions.py index c88e244..3e5261c --- a/arrayfire/array_api/tests/test_elementwise_functions.py +++ b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py @@ -1,111 +1,111 @@ -from inspect import getfullargspec -from typing import Callable, Iterator, TYPE_CHECKING - -import pytest - -from .. import _elementwise_functions, asarray -from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, integer_dtypes, int8, real_floating_dtypes, int8 -from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift -from .._array_object import Array - - - -def nargs(func: Callable) -> int: - return len(getfullargspec(func).args) - - -def test_function_types() -> None: - # Test that every function accepts only the required input types. We only - # test the negative cases here (error). The positive cases are tested in - # the array API test suite. - - elementwise_function_input_types = { - "abs": "numeric", - "acos": "floating-point", - "acosh": "floating-point", - "add": "numeric", - "asin": "floating-point", - "asinh": "floating-point", - "atan": "floating-point", - "atan2": "real floating-point", - "atanh": "floating-point", - # "bitwise_and": "integer or boolean", - # "bitwise_invert": "integer or boolean", - # "bitwise_left_shift": "integer", - # "bitwise_or": "integer or boolean", - # "bitwise_right_shift": "integer", - # "bitwise_xor": "integer or boolean", - "ceil": "real numeric", - # "conj": "complex floating-point", - "cos": "floating-point", - "cosh": "floating-point", - "divide": "floating-point", - "equal": "all", - "exp": "floating-point", - "expm1": "floating-point", - "floor": "real numeric", - "floor_divide": "real numeric", - "greater": "real numeric", - "greater_equal": "real numeric", - # "imag": "complex floating-point", - "isfinite": "numeric", - "isinf": "numeric", - "isnan": "numeric", - "less": "real numeric", - "less_equal": "real numeric", - "log": "floating-point", - "logaddexp": "real floating-point", - "log10": "floating-point", - "log1p": "floating-point", - "log2": "floating-point", - # "logical_and": "boolean", - # "logical_not": "boolean", - # "logical_or": "boolean", - # "logical_xor": "boolean", - "multiply": "numeric", - "negative": "numeric", - "not_equal": "all", - "positive": "numeric", - "pow": "numeric", - # "real": "complex floating-point", - "remainder": "real numeric", - "round": "numeric", - "sign": "numeric", - "sin": "floating-point", - "sinh": "floating-point", - "sqrt": "floating-point", - "square": "numeric", - "subtract": "numeric", - "tan": "floating-point", - "tanh": "floating-point", - "trunc": "real numeric", - } - - def _array_vals() -> Iterator[Array]: - for dt in integer_dtypes: - if dt in {int8}: - continue - yield asarray([1], dtype=dt) - # for d in boolean_dtypes: - # yield asarray(False, dtype=d) - for dt in real_floating_dtypes: - yield asarray([1.0], dtype=dt) - - for x in _array_vals(): - for func_name, types in elementwise_function_input_types.items(): - dtypes = dtype_categories[types] - func = getattr(_elementwise_functions, func_name) - if nargs(func) == 2: - for y in _array_vals(): - if x.dtype not in dtypes or y.dtype not in dtypes: - pytest.raises(TypeError, lambda: func(x, y)) - else: - if x.dtype not in dtypes: - print(func) - pytest.raises(TypeError, lambda: func(x)) - - -# def test_bitwise_shift_error() -> None: -# # bitwise shift functions should raise when the second argument is negative -# pytest.raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) -# pytest.raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) +from inspect import getfullargspec +from typing import Callable, Iterator, TYPE_CHECKING + +import pytest + +from .. import _elementwise_functions, asarray +from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, integer_dtypes, int8, real_floating_dtypes, int8 +from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift +from .._array_object import Array + + + +def nargs(func: Callable) -> int: + return len(getfullargspec(func).args) + + +def test_function_types() -> None: + # Test that every function accepts only the required input types. We only + # test the negative cases here (error). The positive cases are tested in + # the array API test suite. + + elementwise_function_input_types = { + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "real floating-point", + "atanh": "floating-point", + # "bitwise_and": "integer or boolean", + # "bitwise_invert": "integer or boolean", + # "bitwise_left_shift": "integer", + # "bitwise_or": "integer or boolean", + # "bitwise_right_shift": "integer", + # "bitwise_xor": "integer or boolean", + "ceil": "real numeric", + # "conj": "complex floating-point", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "real numeric", + "floor_divide": "real numeric", + "greater": "real numeric", + "greater_equal": "real numeric", + # "imag": "complex floating-point", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "real numeric", + "less_equal": "real numeric", + "log": "floating-point", + "logaddexp": "real floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + # "logical_and": "boolean", + # "logical_not": "boolean", + # "logical_or": "boolean", + # "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "numeric", + # "real": "complex floating-point", + "remainder": "real numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "real numeric", + } + + def _array_vals() -> Iterator[Array]: + for dt in integer_dtypes: + if dt in {int8}: + continue + yield asarray([1], dtype=dt) + # for d in boolean_dtypes: + # yield asarray(False, dtype=d) + for dt in real_floating_dtypes: + yield asarray([1.0], dtype=dt) + + for x in _array_vals(): + for func_name, types in elementwise_function_input_types.items(): + dtypes = dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + if nargs(func) == 2: + for y in _array_vals(): + if x.dtype not in dtypes or y.dtype not in dtypes: + pytest.raises(TypeError, lambda: func(x, y)) + else: + if x.dtype not in dtypes: + print(func) + pytest.raises(TypeError, lambda: func(x)) + + +# def test_bitwise_shift_error() -> None: +# # bitwise shift functions should raise when the second argument is negative +# pytest.raises(ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))) +# pytest.raises(ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))) diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 2ba45ad..774801c 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -2,10 +2,10 @@ import array as py_array from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, cast from .backend import _clib_wrapper as wrapper -from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype, float64 +from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, float64, str_to_dtype from .library.device import PointerSource if TYPE_CHECKING: @@ -35,12 +35,12 @@ def __call__(self, *x: Array) -> Array: class Array: def __init__( self, - obj: Union[None, Array, py_array.array, int, wrapper.AFArrayType, List[Union[int, float]]] = None, - dtype: Union[None, Dtype, str] = None, - shape: Tuple[int, ...] = (), + obj: None | Array | py_array.array | int | wrapper.AFArrayType | list[int | float] = None, + dtype: None | Dtype | str = None, + shape: tuple[int, ...] = (), to_device: bool = False, - offset: Optional[CType] = None, - strides: Optional[Tuple[int, ...]] = None, + offset: CType | None = None, + strides: tuple[int, ...] | None = None, ) -> None: _no_initial_dtype = False # HACK, FIXME @@ -154,7 +154,7 @@ def __neg__(self) -> Array: """ return _process_c_function(0, self, wrapper.sub) - def __add__(self, other: Union[int, float, Array], /) -> Array: + def __add__(self, other: int | float | Array, /) -> Array: """ Calculates the sum for each element of an array instance with the respective element of the array other. @@ -162,7 +162,7 @@ def __add__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance (augend array). Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Addend array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -173,7 +173,7 @@ def __add__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.add) - def __sub__(self, other: Union[int, float, Array], /) -> Array: + def __sub__(self, other: int | float | Array, /) -> Array: """ Calculates the difference for each element of an array instance with the respective element of the array other. @@ -184,7 +184,7 @@ def __sub__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance (minuend array). Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Subtrahend array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -195,7 +195,7 @@ def __sub__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.sub) - def __mul__(self, other: Union[int, float, Array], /) -> Array: + def __mul__(self, other: int | float | Array, /) -> Array: """ Calculates the product for each element of an array instance with the respective element of the array other. @@ -203,7 +203,7 @@ def __mul__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -214,7 +214,7 @@ def __mul__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.mul) - def __truediv__(self, other: Union[int, float, Array], /) -> Array: + def __truediv__(self, other: int | float | Array, /) -> Array: """ Evaluates self_i / other_i for each element of an array instance with the respective element of the array other. @@ -223,7 +223,7 @@ def __truediv__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -241,11 +241,11 @@ def __truediv__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.div) - def __floordiv__(self, other: Union[int, float, Array], /) -> Array: + def __floordiv__(self, other: int | float | Array, /) -> Array: # TODO return NotImplemented - def __mod__(self, other: Union[int, float, Array], /) -> Array: + def __mod__(self, other: int | float | Array, /) -> Array: """ Evaluates self_i % other_i for each element of an array instance with the respective element of the array other. @@ -254,7 +254,7 @@ def __mod__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a real-valued data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. Returns @@ -271,7 +271,7 @@ def __mod__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.mod) - def __pow__(self, other: Union[int, float, Array], /) -> Array: + def __pow__(self, other: int | float | Array, /) -> Array: """ Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of other_i (the exponent), where other_i is the corresponding element of the @@ -281,7 +281,7 @@ def __pow__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance whose elements correspond to the exponentiation base. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array whose elements correspond to the exponentiation exponent. Must be compatible with self (see Broadcasting). Should have a numeric data type. @@ -320,7 +320,7 @@ def __invert__(self) -> Array: out.arr = wrapper.bitnot(self.arr) return out - def __and__(self, other: Union[int, bool, Array], /) -> Array: + def __and__(self, other: int | bool | Array, /) -> Array: """ Evaluates self_i & other_i for each element of an array instance with the respective element of the array other. @@ -329,7 +329,7 @@ def __and__(self, other: Union[int, bool, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, bool, Array] + other: int | bool | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -340,7 +340,7 @@ def __and__(self, other: Union[int, bool, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.bitand) - def __or__(self, other: Union[int, bool, Array], /) -> Array: + def __or__(self, other: int | bool | Array, /) -> Array: """ Evaluates self_i | other_i for each element of an array instance with the respective element of the array other. @@ -349,7 +349,7 @@ def __or__(self, other: Union[int, bool, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, bool, Array] + other: int | bool | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -360,7 +360,7 @@ def __or__(self, other: Union[int, bool, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.bitor) - def __xor__(self, other: Union[int, bool, Array], /) -> Array: + def __xor__(self, other: int | bool | Array, /) -> Array: """ Evaluates self_i ^ other_i for each element of an array instance with the respective element of the array other. @@ -369,7 +369,7 @@ def __xor__(self, other: Union[int, bool, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, bool, Array] + other: int | bool | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Returns @@ -380,7 +380,7 @@ def __xor__(self, other: Union[int, bool, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.bitxor) - def __lshift__(self, other: Union[int, Array], /) -> Array: + def __lshift__(self, other: int | Array, /) -> Array: """ Evaluates self_i << other_i for each element of an array instance with the respective element of the array other. @@ -389,7 +389,7 @@ def __lshift__(self, other: Union[int, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, Array] + other: int | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Each element must be greater than or equal to 0. @@ -400,7 +400,7 @@ def __lshift__(self, other: Union[int, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.bitshiftl) - def __rshift__(self, other: Union[int, Array], /) -> Array: + def __rshift__(self, other: int | Array, /) -> Array: """ Evaluates self_i >> other_i for each element of an array instance with the respective element of the array other. @@ -409,7 +409,7 @@ def __rshift__(self, other: Union[int, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, Array] + other: int | Array Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. Each element must be greater than or equal to 0. @@ -422,7 +422,7 @@ def __rshift__(self, other: Union[int, Array], /) -> Array: # Comparison Operators - def __lt__(self, other: Union[int, float, Array], /) -> Array: + def __lt__(self, other: int | float | Array, /) -> Array: """ Computes the truth value of self_i < other_i for each element of an array instance with the respective element of the array other. @@ -431,7 +431,7 @@ def __lt__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. Returns @@ -441,7 +441,7 @@ def __lt__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.lt) - def __le__(self, other: Union[int, float, Array], /) -> Array: + def __le__(self, other: int | float | Array, /) -> Array: """ Computes the truth value of self_i <= other_i for each element of an array instance with the respective element of the array other. @@ -450,7 +450,7 @@ def __le__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. Returns @@ -460,7 +460,7 @@ def __le__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.le) - def __gt__(self, other: Union[int, float, Array], /) -> Array: + def __gt__(self, other: int | float | Array, /) -> Array: """ Computes the truth value of self_i > other_i for each element of an array instance with the respective element of the array other. @@ -469,7 +469,7 @@ def __gt__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. Returns @@ -479,7 +479,7 @@ def __gt__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.gt) - def __ge__(self, other: Union[int, float, Array], /) -> Array: + def __ge__(self, other: int | float | Array, /) -> Array: """ Computes the truth value of self_i >= other_i for each element of an array instance with the respective element of the array other. @@ -488,7 +488,7 @@ def __ge__(self, other: Union[int, float, Array], /) -> Array: ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, Array] + other: int | float | Array Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. Returns @@ -498,7 +498,7 @@ def __ge__(self, other: Union[int, float, Array], /) -> Array: """ return _process_c_function(self, other, wrapper.ge) - def __eq__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ignore[override] + def __eq__(self, other: int | float | bool | Array, /) -> Array: # type: ignore[override] """ Computes the truth value of self_i == other_i for each element of an array instance with the respective element of the array other. @@ -507,7 +507,7 @@ def __eq__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ig ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, bool, Array] + other: int | float | bool | Array Other array. Must be compatible with self (see Broadcasting). May have any data type. Returns @@ -517,7 +517,7 @@ def __eq__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ig """ return _process_c_function(self, other, wrapper.eq) - def __ne__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ignore[override] + def __ne__(self, other: int | float | bool | Array, /) -> Array: # type: ignore[override] """ Computes the truth value of self_i != other_i for each element of an array instance with the respective element of the array other. @@ -526,7 +526,7 @@ def __ne__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ig ---------- self : Array Array instance. Should have a numeric data type. - other: Union[int, float, bool, Array] + other: int | float | bool | Array Other array. Must be compatible with self (see Broadcasting). May have any data type. Returns @@ -618,42 +618,42 @@ def __rrshift__(self, other: Array, /) -> Array: # In-place Arithmetic Operators - def __iadd__(self, other: Union[int, float, Array], /) -> Array: + def __iadd__(self, other: int | float | Array, /) -> Array: # TODO discuss either we need to support complex and bool as other input type """ Return self += other. """ return _process_c_function(self, other, wrapper.add) - def __isub__(self, other: Union[int, float, Array], /) -> Array: + def __isub__(self, other: int | float | Array, /) -> Array: """ Return self -= other. """ return _process_c_function(self, other, wrapper.sub) - def __imul__(self, other: Union[int, float, Array], /) -> Array: + def __imul__(self, other: int | float | Array, /) -> Array: """ Return self *= other. """ return _process_c_function(self, other, wrapper.mul) - def __itruediv__(self, other: Union[int, float, Array], /) -> Array: + def __itruediv__(self, other: int | float | Array, /) -> Array: """ Return self /= other. """ return _process_c_function(self, other, wrapper.div) - def __ifloordiv__(self, other: Union[int, float, Array], /) -> Array: + def __ifloordiv__(self, other: int | float | Array, /) -> Array: # TODO return NotImplemented - def __imod__(self, other: Union[int, float, Array], /) -> Array: + def __imod__(self, other: int | float | Array, /) -> Array: """ Return self %= other. """ return _process_c_function(self, other, wrapper.mod) - def __ipow__(self, other: Union[int, float, Array], /) -> Array: + def __ipow__(self, other: int | float | Array, /) -> Array: """ Return self **= other. """ @@ -667,31 +667,31 @@ def __imatmul__(self, other: Array, /) -> Array: # In-place Bitwise Operators - def __iand__(self, other: Union[int, bool, Array], /) -> Array: + def __iand__(self, other: int | bool | Array, /) -> Array: """ Return self &= other. """ return _process_c_function(self, other, wrapper.bitand) - def __ior__(self, other: Union[int, bool, Array], /) -> Array: + def __ior__(self, other: int | bool | Array, /) -> Array: """ Return self |= other. """ return _process_c_function(self, other, wrapper.bitor) - def __ixor__(self, other: Union[int, bool, Array], /) -> Array: + def __ixor__(self, other: int | bool | Array, /) -> Array: """ Return self ^= other. """ return _process_c_function(self, other, wrapper.bitxor) - def __ilshift__(self, other: Union[int, Array], /) -> Array: + def __ilshift__(self, other: int | Array, /) -> Array: """ Return self <<= other. """ return _process_c_function(self, other, wrapper.bitshiftl) - def __irshift__(self, other: Union[int, Array], /) -> Array: + def __irshift__(self, other: int | Array, /) -> Array: """ Return self >>= other. """ @@ -703,7 +703,7 @@ def __abs__(self) -> Array: # TODO return NotImplemented - def __array_namespace__(self, *, api_version: Optional[str] = None) -> Any: + def __array_namespace__(self, *, api_version: str | None = None) -> Any: # TODO return NotImplemented @@ -715,11 +715,11 @@ def __complex__(self) -> complex: # TODO return NotImplemented - def __dlpack__(self, *, stream: Union[None, int, Any] = None): # type: ignore[no-untyped-def] + def __dlpack__(self, *, stream: int | Any | None = None): # type: ignore[no-untyped-def] # TODO implementation and expected return type -> PyCapsule return NotImplemented - def __dlpack_device__(self) -> Tuple[Enum, int]: + def __dlpack_device__(self) -> tuple[Enum, int]: # TODO return NotImplemented @@ -735,7 +735,7 @@ def __getitem__(self, key: IndexKey, /) -> Array: ---------- self : Array Array instance. - key : Union[int, slice, Tuple[Union[int, slice, ], ...], Array] + key : int | slice | tuple[int | slice, ...] | Array Index key. Returns @@ -747,7 +747,7 @@ def __getitem__(self, key: IndexKey, /) -> Array: from .dtypes import bool # TODO - # API Specification - key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array]. + # API Specification - key: Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], array]. # consider using af.span to replace ellipsis during refactoring out = Array() ndims = self.ndim @@ -772,7 +772,7 @@ def __int__(self) -> int: def __len__(self) -> int: return self.shape[0] if self.shape else 0 - def __setitem__(self, key: IndexKey, value: Union[int, float, bool, Array], /) -> None: + def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> None: # TODO return NotImplemented # type: ignore[return-value] # FIXME @@ -787,7 +787,7 @@ def __repr__(self) -> str: # TODO change the look of array representation. E.g., like np.array return wrapper.array_as_str(self.arr) - def to_device(self, device: Any, /, *, stream: Union[int, Any] = None) -> Array: + def to_device(self, device: Any, /, *, stream: int | Any = None) -> Array: # TODO implementation and change device type from Any to Device return NotImplemented @@ -867,7 +867,7 @@ def ndim(self) -> int: return wrapper.get_numdims(self.arr) @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: """ Array dimensions. @@ -879,7 +879,7 @@ def shape(self) -> Tuple[int, ...]: # NOTE skipping passing any None values return wrapper.get_dims(self.arr)[: self.ndim] - def scalar(self) -> Union[None, int, float, bool, complex]: + def scalar(self) -> int | float | bool | complex | None: """ Return the first element of the array """ @@ -895,7 +895,7 @@ def is_empty(self) -> bool: """ return wrapper.is_empty(self.arr) - def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, complex]]: + def to_list(self, row_major: bool = False) -> list[int | float | bool | complex]: if self.is_empty(): return [] @@ -903,7 +903,7 @@ def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, ctypes_array = wrapper.get_data_ptr(array.arr, array.size, array.dtype) if array.ndim == 1: - return ctypes_array[:] + return cast(list, ctypes_array[:]) # HACK out = [] for i in range(array.size): @@ -940,7 +940,7 @@ def from_afarray(cls, array: wrapper.AFArrayType) -> None: cls.arr = array -IndexKey = Union[int, slice, Tuple[Union[int, slice], ...], Array] +IndexKey = int | slice | tuple[int | slice, ...] | Array def _reorder(array: Array) -> Array: @@ -953,11 +953,11 @@ def _reorder(array: Array) -> Array: return Array(wrapper.reorder(array.arr, array.ndim)) -def _metadata_string(dtype: Dtype, dims: Optional[Tuple[int, ...]] = None) -> str: +def _metadata_string(dtype: Dtype, dims: tuple[int, ...] | None = None) -> str: return "arrayfire.Array()\n" f"Type: {dtype.name}\n" f"Dims: {str(dims) if dims else ''}" -def _process_c_function(lhs: Union[int, float, Array], rhs: Union[int, float, Array], c_function: Any) -> Array: +def _process_c_function(lhs: int | float | Array, rhs: int | float | Array, c_function: Any) -> Array: out = Array() if isinstance(lhs, Array) and isinstance(rhs, Array): diff --git a/arrayfire/backend/_backend.py b/arrayfire/backend/_backend.py index ce71172..7791a4d 100644 --- a/arrayfire/backend/_backend.py +++ b/arrayfire/backend/_backend.py @@ -115,6 +115,7 @@ class BackendType(enum.Enum): # TODO change name - avoid using _backend_type - cpu = 1 cuda = 2 opencl = 4 + oneapi = 8 def __iter__(self) -> Iterator: # NOTE cpu comes last because we want to keep this order priorty during backend initialization diff --git a/arrayfire/backend/_backend_functions.py b/arrayfire/backend/_backend_functions.py index 6f4f756..98f9e70 100755 --- a/arrayfire/backend/_backend_functions.py +++ b/arrayfire/backend/_backend_functions.py @@ -1,20 +1,30 @@ from __future__ import annotations import warnings +from enum import Enum from typing import TYPE_CHECKING, Union from ._backend import Backend, BackendType, get_backend +from ._clib_wrapper._unsorted import cublas_set_math_mode from ._clib_wrapper._unsorted import get_backend_count as c_get_backend_count from ._clib_wrapper._unsorted import get_backend_id as c_get_backend_id from ._clib_wrapper._unsorted import get_device_id as c_get_device_id +from ._clib_wrapper._unsorted import get_native_id as c_get_native_id from ._clib_wrapper._unsorted import get_size_of as c_get_size_of +from ._clib_wrapper._unsorted import get_stream as c_get_stream from ._clib_wrapper._unsorted import set_backend as c_set_backend +from ._clib_wrapper._unsorted import set_native_id as c_set_native_id if TYPE_CHECKING: from arrayfire import Array from arrayfire.dtypes import Dtype +class CublasMathMode(Enum): + default = 0 + tensor_op = 1 + + def set_backend(backend_type: Union[BackendType, str]) -> None: """ Set a specific backend by backend_type name. @@ -22,7 +32,7 @@ def set_backend(backend_type: Union[BackendType, str]) -> None: Parameters ---------- backend_type : Union[BackendType, str] - Name of the backend backend_type to set. + Name of the backend type to set. Raises ------ @@ -173,3 +183,122 @@ def get_dtype_size(dtype: Dtype) -> int: def get_size_of(dtype: Dtype) -> int: warnings.warn("Was renamed due to unintuitive function name. Now get_dtype_size().", DeprecationWarning) return get_dtype_size(dtype) + + +# Previously module arrayfire.cuda + + +def _check_if_cuda_used() -> None: + backend = get_backend() + if backend.backend_type != BackendType.cuda: + raise RuntimeError( + f"Can not get the CUDA stream id because the other backend is in use: {backend.backend_type}." + ) + + +def get_stream(index: int) -> int: + warnings.warn("Was renamed due to unintuitive function name. Now get_cuda_stream().", DeprecationWarning) + return get_cuda_stream(index) + + +def get_cuda_stream(index: int) -> int: + """ + Get the CUDA stream used for the device id by ArrayFire. + + Parameters + ---------- + idx : int + Specifies the index of the device. + + Returns + ------- + value : int + Denoting the stream id. + + Raises + ------ + RuntimeError + If the current backend type is not CUDA. + """ + _check_if_cuda_used() + + return c_get_stream(index) + + +def get_native_id(index: int) -> int: + warnings.warn("Was renamed due to unintuitive function name. Now get_native_cuda_id().", DeprecationWarning) + return get_native_cuda_id(index) + + +def get_native_cuda_id(index: int) -> int: + """ + Get native (unsorted) CUDA device id. + + Parameters + ---------- + idx : int + Specifies the (sorted) index of the device. + + Returns + ------- + value : int + Denoting the native cuda id. + + Raises + ------ + RuntimeError + If the current backend type is not CUDA. + """ + _check_if_cuda_used() + + return c_get_native_id(index) + + +def set_native_id(index: int) -> None: + warnings.warn("Was renamed due to unintuitive function name. Now get_native_cuda_id().", DeprecationWarning) + return set_native_cuda_id(index) + + +def set_native_cuda_id(index: int) -> None: + """ + Set native (unsorted) CUDA device id. + + Parameters + ---------- + idx : int + Specifies the (unsorted) native index of the device. + + Raises + ------ + RuntimeError + If the current backend type is not CUDA. + """ + _check_if_cuda_used() + + return c_set_native_id(index) + + +def set_cublas_mode(mode: Union[CublasMathMode, int] = CublasMathMode.default) -> None: + """ + Set cuBLAS math mode for CUDA backend. It enables the Tensor Core usage if available on CUDA backend GPUs. + + Parameters + ---------- + mode : Union[CublasMathMode, int] + Specify the mode available within CublasMathMode enum. + + Raises + ------ + ValueError + If the given math mode int value is not a valid value for cuBLAS math mode. + RuntimeError + If the current backend type is not CUDA. + """ + if isinstance(mode, int): + if mode not in [m.value for m in CublasMathMode]: + raise ValueError(f"{mode} is not supported as cublas math mode.") + mode = CublasMathMode(mode) + + _check_if_cuda_used() + + return cublas_set_math_mode(mode.value) diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index 6fbe9b7..34fea3f 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -183,6 +183,7 @@ "get_device_id", "get_size_of", "get_backend_id", + "af_range" ] from ._unsorted import ( @@ -211,6 +212,7 @@ set_backend, transpose, where, + af_range ) __all__ += ["safe_call"] diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index cd2fab2..255f741 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -278,6 +278,16 @@ def randu(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: return out +def af_range(shape: Tuple[int, ...], axis: int, dtype: Dtype, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__data__func__range.htm#gadd6c9b479692454670a51e00ea5b26d5 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + safe_call(_backend.clib.af_range(ctypes.pointer(out), 4, c_shape.c_array, axis, dtype.c_api_value)) + return out + + # Safe Call Wrapper @@ -335,3 +345,40 @@ def get_backend_id(arr: AFArrayType, /) -> int: out = ctypes.c_int(0) safe_call(_backend.clib.get().af_get_backend_id(ctypes.pointer(out), arr)) return out.value + + +# Cuda specific + + +def get_stream(index: int) -> int: + """ + source: https://arrayfire.org/docs/group__cuda__mat.htm#ga8323b850f80afe9878b099f647b0a7e5 + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.get().afcu_get_stream(ctypes.pointer(out), index)) + return out.value + + +def get_native_id(index: int) -> int: + """ + source: https://arrayfire.org/docs/group__cuda__mat.htm#gaf38af1cbbf4be710cc8cbd95d20b24c4 + """ + out = ctypes.c_int(0) + safe_call(_backend.clib.get().afcu_get_native_id(ctypes.pointer(out), index)) + return out.value + + +def set_native_id(index: int) -> None: + """ + source: https://arrayfire.org/docs/group__cuda__mat.htm#ga966f4c6880e90ce91d9599c90c0db378 + """ + safe_call(_backend.clib.get().afcu_set_native_id(index)) + return None + + +def cublas_set_math_mode(mode: int) -> None: + """ + source: https://arrayfire.org/docs/group__cuda__mat.htm#gac23ea38f0bff77a0e12555f27f47aa4f + """ + safe_call(_backend.clib.get().afcu_cublasSetMathMode(mode)) + return None diff --git a/arrayfire/library/data.py b/arrayfire/library/data.py new file mode 100644 index 0000000..0418c9c --- /dev/null +++ b/arrayfire/library/data.py @@ -0,0 +1,130 @@ +from typing import Callable, Tuple, Union, cast + +from typing_extensions import ParamSpec + +from arrayfire import Array +from arrayfire.backend import _clib_wrapper as wrapper +from arrayfire.dtypes import Dtype, float32 + +_pyrange = range + +P = ParamSpec("P") + + +def _afarray_as_array(func: Callable[P, Array]) -> Callable[P, Array]: + """ + Decorator that converts a function returning an array to return an ArrayFire Array. + + Parameters + ---------- + func : Callable[P, Array] + The original function that returns an array. + + Returns + ------- + Callable[P, Array] + A decorated function that returns an ArrayFire Array. + """ + + def decorated(*args: P.args, **kwargs: P.kwargs) -> Array: + out = Array() + result = func(*args, **kwargs) + out.arr = result # type: ignore[assignment] + return out + + return decorated + + +@_afarray_as_array +def constant(scalar: Union[int, float, complex], shape: Tuple[int, ...] = (1,), dtype: Dtype = float32) -> Array: + """ + Create a multi-dimensional array filled with a constant value. + + Parameters + ---------- + scalar : Union[int, float, complex] + The value to fill each element of the constant array with. + + shape : Tuple[int, ...], optional, default: (1,) + The shape of the constant array. + + dtype : Dtype, optional, default: float32 + Data type of the array. + + Returns + ------- + Array + A multi-dimensional ArrayFire array filled with the specified value. + + Notes + ----- + The shape parameter determines the dimensions of the resulting array: + - If shape is (x1,), the output is a 1D array of size (x1,). + - If shape is (x1, x2), the output is a 2D array of size (x1, x2). + - If shape is (x1, x2, x3), the output is a 3D array of size (x1, x2, x3). + - If shape is (x1, x2, x3, x4), the output is a 4D array of size (x1, x2, x3, x4). + """ + result = wrapper.create_constant_array(scalar, shape, dtype) + return cast(Array, result) # HACK actually it return AFArrayType, but decorator makes it an ArrayFire Array. + + +@_afarray_as_array +def range(shape: Tuple[int, ...], axis: int = 0, dtype: Dtype = float32) -> Array: + """ + Create a multi-dimensional array using the length of a dimension as a range. + + Parameters + ---------- + shape : Tuple[int, ...] + The shape of the resulting array. Each element represents the length + of a corresponding dimension. + + axis : int, optional, default: 0 + The dimension along which the range is calculated. + + dtype : Dtype, optional, default: float32 + Data type of the array. + + Returns + ------- + Array + A multi-dimensional ArrayFire array whose elements along `axis` fall + between [0, self.ndims[axis]-1]. + + Raises + ------ + ValueError + If axis value is greater than the number of axes in resulting Array. + + Notes + ----- + The `shape` parameter determines the dimensions of the resulting array: + - If shape is (x1,), the output is a 1D array of size (x1,). + - If shape is (x1, x2), the output is a 2D array of size (x1, x2). + - If shape is (x1, x2, x3), the output is a 3D array of size (x1, x2, x3). + - If shape is (x1, x2, x3, x4), the output is a 4D array of size (x1, x2, x3, x4). + + Examples + -------- + >>> import arrayfire as af + >>> a = af.range((3, 2)) # axis is not specified, range is along the first dimension. + >>> af.display(a) # The data ranges from [0 - 2] (3 elements along the first dimension) + [3 2 1 1] + 0.0000 0.0000 + 1.0000 1.0000 + 2.0000 2.0000 + + >>> a = af.range((3, 2), axis=1) # axis is 1, range is along the second dimension. + >>> af.display(a) # The data ranges from [0 - 1] (2 elements along the second dimension) + [3 2 1 1] + 0.0000 1.0000 + 0.0000 1.0000 + 0.0000 1.0000 + """ + if axis > len(shape): + raise ValueError( + f"Can not calculate along {axis} dimension. The resulting Array is set to has {len(shape)} dimensions." + ) + + result = wrapper.af_range(shape, axis, dtype) + return cast(Array, result) # HACK actually it return AFArrayType, but decorator makes it an ArrayFire Array. diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index 2606c9e..9facb9b 100755 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -448,6 +448,6 @@ def _check_operands_fit_requirements(x1: Union[int, float, Array], x2: Union[int def _check_array_values_not_complex(x: Array) -> None: - # if is_complex_dtype(x.dtype): - # raise TypeError("Values of an Array should not be the complex numbers.") + if is_complex_dtype(x.dtype): + raise TypeError("Values of an Array should not be the complex numbers.") pass diff --git a/arrayfire/library/utils.py b/arrayfire/library/utils.py index b67672f..5a55602 100644 --- a/arrayfire/library/utils.py +++ b/arrayfire/library/utils.py @@ -1,13 +1,50 @@ -from typing import Tuple, Union - from arrayfire import Array -# TODO implement functions - -def all(x: Array, /, *, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False) -> Array: +def all_true(array: Array, axis: int | None = None) -> Array: return NotImplemented -def any(x: Array, /, *, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False) -> Array: - return NotImplemented +# from time import time +# import math + +# def timeit(af_func, *args): +# """ +# Function to time arrayfire functions. + +# Parameters +# ---------- + +# af_func : arrayfire function + +# *args : arguments to `af_func` + +# Returns +# -------- + +# t : Time in seconds +# """ + +# sample_trials = 3 + +# sample_time = 1E20 + +# for i in range(sample_trials): +# start = time() +# res = af_func(*args) +# eval(res) +# sync() +# sample_time = min(sample_time, time() - start) + +# if (sample_time >= 0.5): +# return sample_time + +# num_iters = max(math.ceil(1.0 / sample_time), 3.0) + +# start = time() +# for i in range(int(num_iters)): +# res = af_func(*args) +# eval(res) +# sync() +# sample_time = (time() - start) / num_iters +# return sample_time diff --git a/examples/helloworld.py b/examples/helloworld.py new file mode 100644 index 0000000..d8e2415 --- /dev/null +++ b/examples/helloworld.py @@ -0,0 +1,57 @@ +#!/usr/bin/python + +####################################################### +# Copyright (c) 2015, ArrayFire +# All rights reserved. +# +# This file is distributed under 3-clause BSD license. +# The complete license agreement can be obtained at: +# http://arrayfire.com/licenses/BSD-3-Clause +######################################################## + +import arrayfire as af + +# Display backend information +# af.info() + +print("Create a 5-by-3 matrix of random floats on the GPU\n") +A = af.randu(5, 3, 1, 1, af.Dtype.f32) +print(A) + +print("Element-wise arithmetic\n") +B = af.sin(A) + 1.5 +print(B) + +print("Negate the first three elements of second column\n") +B[0:3, 1] = B[0:3, 1] * -1 +print(B) + +print("Fourier transform the result\n") +C = af.fft(B) +print(C) + +print("Grab last row\n") +c = C[-1, :] +print(c) + +print("Scan Test\n") +r = af.constant(2, 16, 4, 1, 1) +print(r) + +print("Scan\n") +S = af.scan(r, 0, af.BINARYOP.MUL) +print(S) + +print("Create 2-by-3 matrix from host data\n") +d = [1, 2, 3, 4, 5, 6] +print(af.Array(d, shape=(2, 3))) + +print("Copy last column onto first\n") +D[:, 0] = D[:, -1] +print(D) + +print("Sort A and print sorted array and corresponding indices\n") +[sorted_vals, sorted_idxs] = af.sort_index(A) +print(A) +print(sorted_vals) +print(sorted_idxs) diff --git a/tests/_helpers.py b/tests/_helpers.py new file mode 100644 index 0000000..1a45730 --- /dev/null +++ b/tests/_helpers.py @@ -0,0 +1,6 @@ +from typing import List, Union + + +def round_to(list_: List[Union[int, float, complex, bool]], symbols: int = 3) -> List[Union[int, float]]: + # HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/ + return [round(x, symbols) for x in list_] diff --git a/tests/array_object/test_operators.py b/tests/array_object/test_operator_overrides.py old mode 100755 new mode 100644 similarity index 91% rename from tests/array_object/test_operators.py rename to tests/array_object/test_operator_overrides.py index 81ae6a7..cd9a68e --- a/tests/array_object/test_operators.py +++ b/tests/array_object/test_operator_overrides.py @@ -5,15 +5,11 @@ from arrayfire import Array from arrayfire.dtypes import bool as af_bool +from tests._helpers import round_to Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array] -def _round(list_: List[Union[int, float]], symbols: int = 3) -> List[Union[int, float]]: - # HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/ - return [round(x, symbols) for x in list_] - - def pytest_generate_tests(metafunc: Any) -> None: if "array_origin" in metafunc.fixturenames: metafunc.parametrize( @@ -77,8 +73,8 @@ def test_arithmetic_operators( ires = iop(array, operand) rres = op(operand, array) - assert _round(res.to_list()) == _round(ires.to_list()) == _round(ref) - assert _round(rres.to_list()) == _round(rref) + assert round_to(res.to_list()) == round_to(ires.to_list()) == round_to(ref) + assert round_to(rres.to_list()) == round_to(rref) assert res.dtype == ires.dtype == rres.dtype assert res.ndim == ires.ndim == rres.ndim diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..a3e60f3 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,67 @@ +import pytest + +from arrayfire.dtypes import int64 +from arrayfire.library import data + + +# Test cases for the constant function +def test_constant_1d() -> None: + result = data.constant(42, (5,)) + assert result.shape == (5,) + assert result.scalar() == 42 + + +def test_constant_2d() -> None: + result = data.constant(3.14, (3, 4)) + assert result.shape == (3, 4) + assert round(result.scalar(), 2) == 3.14 + + +def test_constant_3d() -> None: + result = data.constant(0, (2, 2, 2), dtype=int64) + assert result.shape == (2, 2, 2) + assert result.scalar() == 0 + assert result.dtype == int64 + + +def test_constant_default_shape() -> None: + result = data.constant(1.0) + assert result.shape == (1,) + assert result.scalar() == 1.0 + + +# TODO add error handling +# def test_constant_invalid_dtype() -> None: +# with pytest.raises(ValueError): +# data.constant(42, (3, 3), dtype="invalid_dtype") + + +# Test cases for the range function +def test_range_1d() -> None: + result = data.range((5,)) + assert result.shape == (5,) + + +def test_range_2d() -> None: + result = data.range((3, 4)) + assert result.shape == (3, 4) + + +def test_range_3d() -> None: + result = data.range((2, 2, 2)) + assert result.shape == (2, 2, 2) + + +def test_range_with_axis() -> None: + result = data.range((3, 4), axis=1) + assert result.shape == (3, 4) + + +def test_range_with_dtype() -> None: + result = data.range((4, 3), dtype=int64) + assert result.dtype == int64 + + +def test_range_with_invalid_axis() -> None: + with pytest.raises(ValueError): + data.range((2, 3, 4), axis=4) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index d7ad295..eecc843 100755 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,4 +1,5 @@ import ctypes +from typing import Union import pytest @@ -35,7 +36,7 @@ def test_dtype_inequality() -> None: (1 + 2j, complex128, complex128), ], ) -def test_implicit_dtype(number: int | float | bool | complex, array_dtype: Dtype, expected_dtype: Dtype) -> None: +def test_implicit_dtype(number: Union[int, float, bool, complex], array_dtype: Dtype, expected_dtype: Dtype) -> None: result_dtype = implicit_dtype(number, array_dtype) assert result_dtype == expected_dtype diff --git a/tests/test_operators.py b/tests/test_operators.py index 72c627a..d5a3f02 100755 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -1,5 +1,6 @@ from arrayfire import Array from arrayfire.library import operators +from tests._helpers import round_to class TestArithmeticOperators: @@ -16,3 +17,13 @@ def test_sub(self) -> None: res = operators.sub(self.array1, self.array2) res_sum = self.array1 - self.array2 assert res.to_list() == res_sum.to_list() == [-3, -3, -3] + + def test_mul(self) -> None: + res = operators.mul(self.array1, self.array2) + res_product = self.array1 * self.array2 + assert res.to_list() == res_product.to_list() == [4, 10, 18] + + def test_div(self) -> None: + res = operators.div(self.array1, self.array2) + res_quotient = self.array1 / self.array2 + assert round_to(res.to_list()) == round_to(res_quotient.to_list()) == [0.25, 0.4, 0.5] From 0e185e88f60f699cf346305b467395d69f28bf97 Mon Sep 17 00:00:00 2001 From: Anton Date: Sat, 2 Sep 2023 02:16:09 +0300 Subject: [PATCH 23/31] Add random. Add data --- arrayfire/__init__.py | 6 +- arrayfire/_array_helpers.py | 31 + arrayfire/array_object.py | 2 +- arrayfire/backend/_clib_wrapper/__init__.py | 22 +- arrayfire/backend/_clib_wrapper/_random.py | 62 ++ arrayfire/backend/_clib_wrapper/_unsorted.py | 21 +- arrayfire/library/data.py | 859 ++++++++++++++++++- arrayfire/library/random.py | 175 ++++ tests/test_data.py | 67 -- tests/test_random.py | 93 ++ tests/wip_test_data.py | 148 ++++ 11 files changed, 1375 insertions(+), 111 deletions(-) create mode 100644 arrayfire/_array_helpers.py create mode 100644 arrayfire/backend/_clib_wrapper/_random.py create mode 100644 arrayfire/library/random.py delete mode 100644 tests/test_data.py create mode 100644 tests/test_random.py create mode 100644 tests/wip_test_data.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 7bb8c64..3034d33 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -226,7 +226,9 @@ __all__ += [ "constant", - "range" + "range", + "identity", + "flat" ] -from arrayfire.library.data import constant, range +from arrayfire.library.data import constant, range, identity, flat diff --git a/arrayfire/_array_helpers.py b/arrayfire/_array_helpers.py new file mode 100644 index 0000000..ea7cb99 --- /dev/null +++ b/arrayfire/_array_helpers.py @@ -0,0 +1,31 @@ +from typing import Callable + +from typing_extensions import ParamSpec + +from arrayfire import Array + +P = ParamSpec("P") + + +def afarray_as_array(func: Callable[P, Array]) -> Callable[P, Array]: + """ + Decorator that converts a function returning an array to return an ArrayFire Array. + + Parameters + ---------- + func : Callable[P, Array] + The original function that returns an array. + + Returns + ------- + Callable[P, Array] + A decorated function that returns an ArrayFire Array. + """ + + def decorated(*args: P.args, **kwargs: P.kwargs) -> Array: + out = Array() + result = func(*args, **kwargs) + out.arr = result # type: ignore[assignment] + return out + + return decorated diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 774801c..ecdc190 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, cast from .backend import _clib_wrapper as wrapper -from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, float64, str_to_dtype +from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype from .library.device import PointerSource if TYPE_CHECKING: diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index 34fea3f..26097b5 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -176,23 +176,26 @@ "reorder", "array_as_str", "where", - "randu", "get_last_error", "set_backend", "get_backend_count", "get_device_id", "get_size_of", "get_backend_id", - "af_range" + "af_range", + "identity", + "flat", ] from ._unsorted import ( + af_range, array_as_str, copy_array, create_array, create_handle, create_strided_array, device_array, + flat, get_backend_count, get_backend_id, get_ctype, @@ -204,15 +207,14 @@ get_numdims, get_scalar, get_size_of, + identity, index_gen, is_empty, - randu, reorder, retain_array, set_backend, transpose, where, - af_range ) __all__ += ["safe_call"] @@ -229,3 +231,15 @@ __all__ += ["get_indices"] from ._indexing import get_indices + +__all__ += ["create_random_engine", "release_random_engine", "AFRandomEngine"] +from ._random import ( + AFRandomEngine, + create_random_engine, + random_engine_get_type, + random_engine_set_type, + release_random_engine, + random_engine_set_seed, + random_engine_get_seed, + randu, random_uniform +) diff --git a/arrayfire/backend/_clib_wrapper/_random.py b/arrayfire/backend/_clib_wrapper/_random.py new file mode 100644 index 0000000..a7ba44d --- /dev/null +++ b/arrayfire/backend/_clib_wrapper/_random.py @@ -0,0 +1,62 @@ +import ctypes + +from arrayfire.backend._backend import _backend +from arrayfire.dtypes import CShape, Dtype + +from ._base import AFArrayType +from ._error_handler import safe_call + +AFRandomEngine = ctypes.c_void_p + + +def create_random_engine(engine_type: int, seed: int, /) -> AFRandomEngine: + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_create_random_engine(ctypes.pointer(out), engine_type, ctypes.c_longlong(seed))) + return out + + +def release_random_engine(engine: AFRandomEngine, /) -> None: + safe_call(_backend.clib.af_release_random_engine(engine)) + return None + + +def random_engine_set_type(engine: AFRandomEngine, engine_type: int, /) -> None: + safe_call(_backend.clib.af_random_engine_set_type(ctypes.pointer(engine), engine_type)) + return None + + +def random_engine_get_type(engine: AFRandomEngine, /) -> int: + out = ctypes.c_int(0) + safe_call(_backend.clib.af_random_engine_get_type(ctypes.pointer(out), engine)) + return out.value + + +def random_engine_set_seed(engine: AFRandomEngine, seed: int, /) -> None: + safe_call(_backend.clib.af_random_engine_set_seed(ctypes.pointer(engine), ctypes.c_longlong(seed))) + return None + + +def random_engine_get_seed(engine: AFRandomEngine, /) -> int: + out = ctypes.c_longlong(0) + safe_call(_backend.clib.af_random_engine_get_seed(ctypes.pointer(out), engine)) + return out.value + + +def randu(shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__random__func__randu.htm#ga412e2c2f5135bdda218c3487c487d3b5 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + safe_call(_backend.clib.af_randu(ctypes.pointer(out), 4, c_shape.c_array, dtype.c_api_value)) + return out + + +def random_uniform(shape: tuple[int, ...], dtype: Dtype, engine: AFRandomEngine, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__random__func__randu.htm#ga2ca76d970cfac076f9006755582a4a4c + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + safe_call(_backend.clib.af_random_uniform(ctypes.pointer(out), 4, c_shape.c_array, dtype.c_api_value, engine)) + return out diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index 255f741..3096358 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -268,23 +268,32 @@ def where(arr: AFArrayType) -> AFArrayType: return out -def randu(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def af_range(shape: Tuple[int, ...], axis: int, dtype: Dtype, /) -> AFArrayType: """ - source: https://arrayfire.org/docs/group__random__func__randu.htm#ga412e2c2f5135bdda218c3487c487d3b5 + source: https://arrayfire.org/docs/group__data__func__range.htm#gadd6c9b479692454670a51e00ea5b26d5 """ out = ctypes.c_void_p(0) c_shape = CShape(*shape) - safe_call(_backend.clib.af_randu(ctypes.pointer(out), *c_shape, dtype.c_api_value)) + safe_call(_backend.clib.af_range(ctypes.pointer(out), 4, c_shape.c_array, axis, dtype.c_api_value)) return out -def af_range(shape: Tuple[int, ...], axis: int, dtype: Dtype, /) -> AFArrayType: +def identity(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ - source: https://arrayfire.org/docs/group__data__func__range.htm#gadd6c9b479692454670a51e00ea5b26d5 + source: """ out = ctypes.c_void_p(0) c_shape = CShape(*shape) - safe_call(_backend.clib.af_range(ctypes.pointer(out), 4, c_shape.c_array, axis, dtype.c_api_value)) + safe_call(_backend.clib.af_identity(ctypes.pointer(out), 4, c_shape.c_array, dtype.c_api_value)) + return out + + +def flat(arr: AFArrayType) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__manip__func__flat.htm#gac6dfb22cbd3b151ddffb9a4ddf74455e + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_flat(ctypes.pointer(out), arr)) return out diff --git a/arrayfire/library/data.py b/arrayfire/library/data.py index 0418c9c..24308ea 100644 --- a/arrayfire/library/data.py +++ b/arrayfire/library/data.py @@ -1,41 +1,14 @@ -from typing import Callable, Tuple, Union, cast - -from typing_extensions import ParamSpec +from typing import Tuple, Union, cast from arrayfire import Array +from arrayfire._array_helpers import afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper from arrayfire.dtypes import Dtype, float32 _pyrange = range -P = ParamSpec("P") - - -def _afarray_as_array(func: Callable[P, Array]) -> Callable[P, Array]: - """ - Decorator that converts a function returning an array to return an ArrayFire Array. - - Parameters - ---------- - func : Callable[P, Array] - The original function that returns an array. - Returns - ------- - Callable[P, Array] - A decorated function that returns an ArrayFire Array. - """ - - def decorated(*args: P.args, **kwargs: P.kwargs) -> Array: - out = Array() - result = func(*args, **kwargs) - out.arr = result # type: ignore[assignment] - return out - - return decorated - - -@_afarray_as_array +@afarray_as_array def constant(scalar: Union[int, float, complex], shape: Tuple[int, ...] = (1,), dtype: Dtype = float32) -> Array: """ Create a multi-dimensional array filled with a constant value. @@ -68,7 +41,7 @@ def constant(scalar: Union[int, float, complex], shape: Tuple[int, ...] = (1,), return cast(Array, result) # HACK actually it return AFArrayType, but decorator makes it an ArrayFire Array. -@_afarray_as_array +@afarray_as_array def range(shape: Tuple[int, ...], axis: int = 0, dtype: Dtype = float32) -> Array: """ Create a multi-dimensional array using the length of a dimension as a range. @@ -128,3 +101,827 @@ def range(shape: Tuple[int, ...], axis: int = 0, dtype: Dtype = float32) -> Arra result = wrapper.af_range(shape, axis, dtype) return cast(Array, result) # HACK actually it return AFArrayType, but decorator makes it an ArrayFire Array. + + +# def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=Dtype.f32): +# """ +# Create a multi dimensional array using the number of elements in the array as the range. + +# Parameters +# ---------- +# val : scalar. +# Value of each element of the constant array. + +# d0 : int. +# Length of first dimension. + +# d1 : optional: int. default: None. +# Length of second dimension. + +# d2 : optional: int. default: None. +# Length of third dimension. + +# d3 : optional: int. default: None. +# Length of fourth dimension. + +# tile_dims : optional: tuple of ints. default: None. +# The number of times the data is tiled. + +# dtype : optional: af.Dtype. default: af.Dtype.f32. +# Data type of the array. + +# Returns +# ------- + +# out : af.Array +# Multi dimensional array whose elements are along `dim` fall between [0 - self.elements() - 1]. + +# Examples +# -------- +# >>> import arrayfire as af +# >>> import arrayfire as af +# >>> a = af.iota(3,3) # tile_dim is not specified, data is not tiled +# >>> af.display(a) # the elements range from [0 - 8] (9 elements) +# [3 3 1 1] +# 0.0000 3.0000 6.0000 +# 1.0000 4.0000 7.0000 +# 2.0000 5.0000 8.0000 + +# >>> b = af.iota(3,3,tile_dims(1,2)) # Asking to tile along second dimension. +# >>> af.display(b) +# [3 6 1 1] +# 0.0000 3.0000 6.0000 0.0000 3.0000 6.0000 +# 1.0000 4.0000 7.0000 1.0000 4.0000 7.0000 +# 2.0000 5.0000 8.0000 2.0000 5.0000 8.0000 +# """ +# out = Array() +# dims = dim4(d0, d1, d2, d3) +# td=[1]*4 + +# if tile_dims is not None: +# for i in _brange(len(tile_dims)): +# td[i] = tile_dims[i] + +# tdims = dim4(td[0], td[1], td[2], td[3]) + +# safe_call(backend.get().af_iota(c_pointer(out.arr), 4, c_pointer(dims), +# 4, c_pointer(tdims), dtype.value)) +# return out + + +@afarray_as_array +def identity(shape: Tuple[int, ...], dtype: Dtype = float32) -> Array: + """ + Create an identity matrix or batch of identity matrices. + + Parameters + ---------- + shape : Tuple[int, ...] + The shape of the resulting identity array or batch of arrays. + Must have at least 2 values. + + dtype : Dtype, optional, default: float32 + Data type of the array. + + Returns + ------- + Array + A multi-dimensional ArrayFire array where the first two dimensions + form an identity matrix or batch of matrices. + + Notes + ----- + The `shape` parameter determines the dimensions of the resulting array: + - If shape is (x1, x2), the output is a 2D array of size (x1, x2). + - If shape is (x1, x2, x3), the output is a 3D array of size (x1, x2, x3). + - If shape is (x1, x2, x3, x4), the output is a 4D array of size (x1, x2, x3, x4). + + Raises + ------ + ValueError + If shape is not a tuple or has less than two values. + + Examples + -------- + >>> import arrayfire as af + >>> identity_matrix = af.identity((3, 3)) # Create a 3x3 identity matrix + >>> af.display(identity_matrix) + [3 3 1 1] + 1.0000 0.0000 0.0000 + 0.0000 1.0000 0.0000 + 0.0000 0.0000 1.0000 + + >>> identity_batch = af.identity((2, 2, 3)) # Create a batch of 3 identity 2x2 matrices + >>> af.display(identity_batch) + [2 2 3 1] + 1.0000 0.0000 1.0000 0.0000 1.0000 0.0000 + 0.0000 1.0000 0.0000 1.0000 0.0000 1.0000 + """ + + if not isinstance(shape, tuple) or len(shape) < 2: + raise ValueError("Argument shape must be a tuple with at least 2 values.") + + result = wrapper.identity(shape, dtype) + return cast(Array, result) + + +# def diag(a, num=0, extract=True): +# """ +# Create a diagonal matrix or Extract the diagonal from a matrix. + +# Parameters +# ---------- +# a : af.Array. +# 1 dimensional or 2 dimensional arrayfire array. + +# num : optional: int. default: 0. +# The index of the diagonal. +# - num == 0 signifies the diagonal. +# - num > 0 signifies super diagonals. +# - num < 0 signifies sub diagonals. + +# extract : optional: bool. default: True. +# - If True , diagonal is extracted. `a` has to be 2D. +# - If False, diagonal matrix is created. `a` has to be 1D. + +# Returns +# ------- + +# out : af.Array +# - if extract is True, `out` contains the num'th diagonal from `a`. +# - if extract is False, `out` contains `a` as the num'th diagonal. +# """ +# out = Array() +# if extract: +# safe_call(backend.get().af_diag_extract(c_pointer(out.arr), a.arr, c_int_t(num))) +# else: +# safe_call(backend.get().af_diag_create(c_pointer(out.arr), a.arr, c_int_t(num))) +# return out + + +# def join(dim, first, second, third=None, fourth=None): +# """ +# Join two or more arrayfire arrays along a specified dimension. + +# Parameters +# ---------- + +# dim: int. +# Dimension along which the join occurs. + +# first : af.Array. +# Multi dimensional arrayfire array. + +# second : af.Array. +# Multi dimensional arrayfire array. + +# third : optional: af.Array. default: None. +# Multi dimensional arrayfire array. + +# fourth : optional: af.Array. default: None. +# Multi dimensional arrayfire array. + +# Returns +# ------- + +# out : af.Array +# An array containing the input arrays joined along the specified dimension. + +# Examples +# --------- + +# >>> import arrayfire as af +# >>> a = af.randu(2, 3) +# >>> b = af.randu(2, 3) +# >>> c = af.join(0, a, b) +# >>> d = af.join(1, a, b) +# >>> af.display(a) +# [2 3 1 1] +# 0.9508 0.2591 0.7928 +# 0.5367 0.8359 0.8719 + +# >>> af.display(b) +# [2 3 1 1] +# 0.3266 0.6009 0.2442 +# 0.6275 0.0495 0.6591 + +# >>> af.display(c) +# [4 3 1 1] +# 0.9508 0.2591 0.7928 +# 0.5367 0.8359 0.8719 +# 0.3266 0.6009 0.2442 +# 0.6275 0.0495 0.6591 + +# >>> af.display(d) +# [2 6 1 1] +# 0.9508 0.2591 0.7928 0.3266 0.6009 0.2442 +# 0.5367 0.8359 0.8719 0.6275 0.0495 0.6591 +# """ +# out = Array() +# if third is None and fourth is None: +# safe_call(backend.get().af_join(c_pointer(out.arr), dim, first.arr, second.arr)) +# else: +# c_void_p_4 = c_void_ptr_t * 4 +# c_array_vec = c_void_p_4(first.arr, second.arr, 0, 0) +# num = 2 +# if third is not None: +# c_array_vec[num] = third.arr +# num += 1 +# if fourth is not None: +# c_array_vec[num] = fourth.arr +# num += 1 + +# safe_call(backend.get().af_join_many(c_pointer(out.arr), dim, num, c_pointer(c_array_vec))) +# return out + + +# def tile(a, d0, d1=1, d2=1, d3=1): +# """ +# Tile an array along specified dimensions. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# d0: int. +# The number of times `a` has to be tiled along first dimension. + +# d1: optional: int. default: 1. +# The number of times `a` has to be tiled along second dimension. + +# d2: optional: int. default: 1. +# The number of times `a` has to be tiled along third dimension. + +# d3: optional: int. default: 1. +# The number of times `a` has to be tiled along fourth dimension. + +# Returns +# ------- + +# out : af.Array +# An array containing the input after tiling the the specified number of times. + +# Examples +# --------- + +# >>> import arrayfire as af +# >>> a = af.randu(2, 3) +# >>> b = af.tile(a, 2) +# >>> c = af.tile(a, 1, 2) +# >>> d = af.tile(a, 2, 2) +# >>> af.display(a) +# [2 3 1 1] +# 0.9508 0.2591 0.7928 +# 0.5367 0.8359 0.8719 + +# >>> af.display(b) +# [4 3 1 1] +# 0.4107 0.9518 0.4198 +# 0.8224 0.1794 0.0081 +# 0.4107 0.9518 0.4198 +# 0.8224 0.1794 0.0081 + +# >>> af.display(c) +# [2 6 1 1] +# 0.4107 0.9518 0.4198 0.4107 0.9518 0.4198 +# 0.8224 0.1794 0.0081 0.8224 0.1794 0.0081 + +# >>> af.display(d) +# [4 6 1 1] +# 0.4107 0.9518 0.4198 0.4107 0.9518 0.4198 +# 0.8224 0.1794 0.0081 0.8224 0.1794 0.0081 +# 0.4107 0.9518 0.4198 0.4107 0.9518 0.4198 +# 0.8224 0.1794 0.0081 0.8224 0.1794 0.0081 +# """ +# out = Array() +# safe_call(backend.get().af_tile(c_pointer(out.arr), a.arr, d0, d1, d2, d3)) +# return out + + +# def reorder(a, d0=1, d1=0, d2=2, d3=3): +# """ +# Reorder the dimensions of the input. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# d0: optional: int. default: 1. +# The location of the first dimension in the output. + +# d1: optional: int. default: 0. +# The location of the second dimension in the output. + +# d2: optional: int. default: 2. +# The location of the third dimension in the output. + +# d3: optional: int. default: 3. +# The location of the fourth dimension in the output. + +# Returns +# ------- + +# out : af.Array +# - An array containing the input aftern reordering its dimensions. + +# Note +# ------ +# - `af.reorder(a, 1, 0)` is the same as `transpose(a)` + +# Examples +# -------- +# >>> import arrayfire as af +# >>> a = af.randu(5, 5, 3) +# >>> af.display(a) +# [5 5 3 1] +# 0.4107 0.0081 0.6600 0.1046 0.8395 +# 0.8224 0.3775 0.0764 0.8827 0.1933 +# 0.9518 0.3027 0.0901 0.1647 0.7270 +# 0.1794 0.6456 0.5933 0.8060 0.0322 +# 0.4198 0.5591 0.1098 0.5938 0.0012 + +# 0.8703 0.9250 0.4387 0.6530 0.4224 +# 0.5259 0.3063 0.3784 0.5476 0.5293 +# 0.1443 0.9313 0.4002 0.8577 0.0212 +# 0.3253 0.8684 0.4390 0.8370 0.1103 +# 0.5081 0.6592 0.4718 0.0618 0.4420 + +# 0.8355 0.6767 0.1033 0.9426 0.9276 +# 0.4878 0.6742 0.2119 0.4817 0.8662 +# 0.2055 0.4523 0.5955 0.9097 0.3578 +# 0.1794 0.1236 0.3745 0.6821 0.6263 +# 0.5606 0.7924 0.9165 0.6056 0.9747 + + +# >>> b = af.reorder(a, 2, 0, 1) +# >>> af.display(b) +# [3 5 5 1] +# 0.4107 0.8224 0.9518 0.1794 0.4198 +# 0.8703 0.5259 0.1443 0.3253 0.5081 +# 0.8355 0.4878 0.2055 0.1794 0.5606 + +# 0.0081 0.3775 0.3027 0.6456 0.5591 +# 0.9250 0.3063 0.9313 0.8684 0.6592 +# 0.6767 0.6742 0.4523 0.1236 0.7924 + +# 0.6600 0.0764 0.0901 0.5933 0.1098 +# 0.4387 0.3784 0.4002 0.4390 0.4718 +# 0.1033 0.2119 0.5955 0.3745 0.9165 + +# 0.1046 0.8827 0.1647 0.8060 0.5938 +# 0.6530 0.5476 0.8577 0.8370 0.0618 +# 0.9426 0.4817 0.9097 0.6821 0.6056 + +# 0.8395 0.1933 0.7270 0.0322 0.0012 +# 0.4224 0.5293 0.0212 0.1103 0.4420 +# 0.9276 0.8662 0.3578 0.6263 0.9747 +# """ +# out = Array() +# safe_call(backend.get().af_reorder(c_pointer(out.arr), a.arr, d0, d1, d2, d3)) +# return out + + +# def shift(a, d0, d1=0, d2=0, d3=0): +# """ +# Shift the input along each dimension. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# d0: int. +# The amount of shift along first dimension. + +# d1: optional: int. default: 0. +# The amount of shift along second dimension. + +# d2: optional: int. default: 0. +# The amount of shift along third dimension. + +# d3: optional: int. default: 0. +# The amount of shift along fourth dimension. + +# Returns +# ------- + +# out : af.Array +# - An array the same shape as `a` after shifting it by the specified amounts. + +# Examples +# -------- +# >>> import arrayfire as af +# >>> a = af.randu(3, 3) +# >>> b = af.shift(a, 2) +# >>> c = af.shift(a, 1, -1) +# >>> af.display(a) +# [3 3 1 1] +# 0.7269 0.3569 0.3341 +# 0.7104 0.1437 0.0899 +# 0.5201 0.4563 0.5363 + +# >>> af.display(b) +# [3 3 1 1] +# 0.7104 0.1437 0.0899 +# 0.5201 0.4563 0.5363 +# 0.7269 0.3569 0.3341 + +# >>> af.display(c) +# [3 3 1 1] +# 0.4563 0.5363 0.5201 +# 0.3569 0.3341 0.7269 +# 0.1437 0.0899 0.7104 +# """ +# out = Array() +# safe_call(backend.get().af_shift(c_pointer(out.arr), a.arr, d0, d1, d2, d3)) +# return out + + +# def moddims(a, d0, d1=1, d2=1, d3=1): +# """ +# Modify the shape of the array without changing the data layout. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# d0: int. +# The first dimension of output. + +# d1: optional: int. default: 1. +# The second dimension of output. + +# d2: optional: int. default: 1. +# The third dimension of output. + +# d3: optional: int. default: 1. +# The fourth dimension of output. + +# Returns +# ------- + +# out : af.Array +# - An containing the same data as `a` with the specified shape. +# - The number of elements in `a` must match `d0 x d1 x d2 x d3`. +# """ +# out = Array() +# dims = dim4(d0, d1, d2, d3) +# safe_call(backend.get().af_moddims(c_pointer(out.arr), a.arr, 4, c_pointer(dims))) +# return out + + +@afarray_as_array +def flat(array: Array) -> Array: + """ + Flatten the input multi-dimensional array into a 1D array. + + Parameters + ---------- + array : Array + The input multi-dimensional array to be flattened. + + Returns + ------- + Array + A 1D array containing all the elements from the input array. + + Examples + -------- + >>> import arrayfire as af + >>> arr = af.randu(3, 2) # Create a 3x2 random array + >>> flattened = af.flat(arr) # Flatten the array + >>> af.display(flattened) + [6 1 1 1] + 0.8364 + 0.5604 + 0.6352 + 0.0062 + 0.7052 + 0.1676 + """ + result = wrapper.flat(array.arr) + return cast(Array, result) + + +# def flip(a, dim=0): +# """ +# Flip an array along a dimension. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# dim : optional: int. default: 0. +# The dimension along which the flip is performed. + +# Returns +# ------- + +# out : af.Array +# The output after flipping `a` along `dim`. + +# Examples +# --------- + +# >>> import arrayfire as af +# >>> a = af.randu(3, 3) +# >>> af.display(a) +# [3 3 1 1] +# 0.7269 0.3569 0.3341 +# 0.7104 0.1437 0.0899 +# 0.5201 0.4563 0.5363 + +# >>> af.display(b) +# [3 3 1 1] +# 0.5201 0.4563 0.5363 +# 0.7104 0.1437 0.0899 +# 0.7269 0.3569 0.3341 + +# >>> af.display(c) +# [3 3 1 1] +# 0.3341 0.3569 0.7269 +# 0.0899 0.1437 0.7104 +# 0.5363 0.4563 0.5201 + +# """ +# out = Array() +# safe_call(backend.get().af_flip(c_pointer(out.arr), a.arr, c_int_t(dim))) +# return out + + +# def lower(a, is_unit_diag=False): +# """ +# Extract the lower triangular matrix from the input. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# is_unit_diag: optional: bool. default: False. +# Flag specifying if the diagonal elements are 1. + +# Returns +# ------- + +# out : af.Array +# An array containing the lower triangular elements from `a`. +# """ +# out = Array() +# safe_call(backend.get().af_lower(c_pointer(out.arr), a.arr, is_unit_diag)) +# return out + + +# def upper(a, is_unit_diag=False): +# """ +# Extract the upper triangular matrix from the input. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# is_unit_diag: optional: bool. default: False. +# Flag specifying if the diagonal elements are 1. + +# Returns +# ------- + +# out : af.Array +# An array containing the upper triangular elements from `a`. +# """ +# out = Array() +# safe_call(backend.get().af_upper(c_pointer(out.arr), a.arr, is_unit_diag)) +# return out + + +# def select(cond, lhs, rhs): +# """ +# Select elements from one of two arrays based on condition. + +# Parameters +# ---------- + +# cond : af.Array +# Conditional array + +# lhs : af.Array or scalar +# numerical array whose elements are picked when conditional element is True + +# rhs : af.Array or scalar +# numerical array whose elements are picked when conditional element is False + +# Returns +# -------- + +# out: af.Array +# An array containing elements from `lhs` when `cond` is True and `rhs` when False. + +# Examples +# --------- + +# >>> import arrayfire as af +# >>> a = af.randu(3,3) +# >>> b = af.randu(3,3) +# >>> cond = a > b +# >>> res = af.select(cond, a, b) + +# >>> af.display(a) +# [3 3 1 1] +# 0.4107 0.1794 0.3775 +# 0.8224 0.4198 0.3027 +# 0.9518 0.0081 0.6456 + +# >>> af.display(b) +# [3 3 1 1] +# 0.7269 0.3569 0.3341 +# 0.7104 0.1437 0.0899 +# 0.5201 0.4563 0.5363 + +# >>> af.display(res) +# [3 3 1 1] +# 0.7269 0.3569 0.3775 +# 0.8224 0.4198 0.3027 +# 0.9518 0.4563 0.6456 +# """ +# out = Array() + +# is_left_array = isinstance(lhs, Array) +# is_right_array = isinstance(rhs, Array) + +# if not (is_left_array or is_right_array): +# raise TypeError("Atleast one input needs to be of type arrayfire.array") + +# elif is_left_array and is_right_array: +# safe_call(backend.get().af_select(c_pointer(out.arr), cond.arr, lhs.arr, rhs.arr)) + +# elif _is_number(rhs): +# safe_call(backend.get().af_select_scalar_r(c_pointer(out.arr), cond.arr, lhs.arr, c_double_t(rhs))) +# else: +# safe_call(backend.get().af_select_scalar_l(c_pointer(out.arr), cond.arr, c_double_t(lhs), rhs.arr)) + +# return out + + +# def replace(lhs, cond, rhs): +# """ +# Select elements from one of two arrays based on condition. + +# Parameters +# ---------- + +# lhs : af.Array or scalar +# numerical array whose elements are replaced with `rhs` when conditional element is False + +# cond : af.Array +# Conditional array + +# rhs : af.Array or scalar +# numerical array whose elements are picked when conditional element is False + +# Examples +# --------- +# >>> import arrayfire as af +# >>> a = af.randu(3,3) +# >>> af.display(a) +# [3 3 1 1] +# 0.4107 0.1794 0.3775 +# 0.8224 0.4198 0.3027 +# 0.9518 0.0081 0.6456 + +# >>> cond = (a >= 0.25) & (a <= 0.75) +# >>> af.display(cond) +# [3 3 1 1] +# 1 0 1 +# 0 1 1 +# 0 0 1 + +# >>> af.replace(a, cond, 0.3333) +# >>> af.display(a) +# [3 3 1 1] +# 0.3333 0.1794 0.3333 +# 0.8224 0.3333 0.3333 +# 0.9518 0.0081 0.3333 + +# """ +# is_right_array = isinstance(rhs, Array) + +# if is_right_array: +# safe_call(backend.get().af_replace(lhs.arr, cond.arr, rhs.arr)) +# else: +# safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, c_double_t(rhs))) + + +# def pad(a, beginPadding, endPadding, padFillType=PAD.ZERO): +# """ +# Pad an array + +# This function will pad an array with the specified border size. +# Newly padded values can be filled in several different ways. + +# Parameters +# ---------- + +# a: af.Array +# A multi dimensional input arrayfire array. + +# beginPadding: tuple of ints. default: (0, 0, 0, 0). + +# endPadding: tuple of ints. default: (0, 0, 0, 0). + +# padFillType: optional af.PAD default: af.PAD.ZERO +# specifies type of values to fill padded border with + +# Returns +# ------- +# output: af.Array +# A padded array + +# Examples +# --------- +# >>> import arrayfire as af +# >>> a = af.randu(3,3) +# >>> af.display(a) +# [3 3 1 1] +# 0.4107 0.1794 0.3775 +# 0.8224 0.4198 0.3027 +# 0.9518 0.0081 0.6456 + +# >>> padded = af.pad(a, (1, 1), (1, 1), af.ZERO) +# >>> af.display(padded) +# [5 5 1 1] +# 0.0000 0.0000 0.0000 0.0000 0.0000 +# 0.0000 0.4107 0.1794 0.3775 0.0000 +# 0.0000 0.8224 0.4198 0.3027 0.0000 +# 0.0000 0.9518 0.0081 0.6456 0.0000 +# 0.0000 0.0000 0.0000 0.0000 0.0000 +# """ +# out = Array() +# begin_dims = dim4(beginPadding[0], beginPadding[1], beginPadding[2], beginPadding[3]) +# end_dims = dim4(endPadding[0], endPadding[1], endPadding[2], endPadding[3]) + +# safe_call( +# backend.get().af_pad( +# c_pointer(out.arr), a.arr, 4, c_pointer(begin_dims), 4, c_pointer(end_dims), padFillType.value +# ) +# ) +# return out + + +# def lookup(a, idx, dim=0): +# """ +# Lookup the values of input array based on index. + +# Parameters +# ---------- + +# a : af.Array. +# Multi dimensional array. + +# idx : is lookup indices + +# dim : optional: int. default: 0. +# Specifies the dimension for indexing + +# Returns +# ------- + +# out : af.Array +# An array containing values at locations specified by 'idx' + +# Examples +# --------- + +# >>> import arrayfire as af +# >>> arr = af.Array([1,0,3,4,5,6], (2,3)) +# >>> af.display(arr) +# [2 3 1 1] +# 1.0000 3.0000 5.0000 +# 0.0000 4.0000 6.0000 + +# >>> idx = af.array([0, 2]) +# >>> af.lookup(arr, idx, 1) +# [2 2 1 1] +# 1.0000 5.0000 +# 0.0000 6.0000 + +# >>> idx = af.array([0]) +# >>> af.lookup(arr, idx, 0) +# [2 1 1 1] +# 0.0000 +# 2.0000 +# """ +# out = Array() +# safe_call(backend.get().af_lookup(c_pointer(out.arr), a.arr, idx.arr, c_int_t(dim))) +# return out diff --git a/arrayfire/library/random.py b/arrayfire/library/random.py new file mode 100644 index 0000000..86d32f8 --- /dev/null +++ b/arrayfire/library/random.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from enum import Enum +from typing import cast + +from arrayfire import Array +from arrayfire.backend import _clib_wrapper as wrapper +from arrayfire.dtypes import Dtype, float32 +from arrayfire._array_helpers import afarray_as_array + + +class RandomEngineType(Enum): + PHILOX = 100 # PHILOX_4X32_10 + THREEFRY = 200 # THREEFRY_2X32_16 + MERSENNE = 300 # MERSENNE_GP11213 + + +class RandomEngine: + """ + Class to handle random number generator engines. + + Parameters + ---------- + engine_type : RandomEngineType, optional, default: RandomEngineType.PHILOX + Specifies the type of random engine to be created. Can be one of: + - RandomEngineType.PHILOX + - RandomEngineType.THREEFRY + - RandomEngineType.MERSENNE + + seed : int, optional, default: 0 + Specifies the seed for the random engine. + """ + + def __init__(self, engine_type: RandomEngineType = RandomEngineType.PHILOX, seed: int = 0) -> None: + """ + Initialize a random engine instance. + + Parameters + ---------- + engine_type : RandomEngineType, optional, default: RandomEngineType.PHILOX + Specifies the type of random engine to be created. + + seed : int, optional, default: 0 + Specifies the seed for the random engine. + """ + self._engine = wrapper.create_random_engine(engine_type.value, seed) + + def __del__(self) -> None: + """ + Destructor to release the random engine resources. + """ + wrapper.release_random_engine(self._engine) + return None + + def set_type(self, engine_type: RandomEngineType) -> None: + """ + Set the type of the random engine. + + Parameters + ---------- + engine_type : RandomEngineType + Specifies the type of the random engine to be set. + """ + wrapper.random_engine_set_type(self._engine, engine_type.value) + return None + + def get_type(self) -> RandomEngineType: + """ + Get the type of the random engine. + + Returns + ------- + RandomEngineType + The type of the random engine. + """ + engine_type_value = wrapper.random_engine_get_type(self._engine) + return RandomEngineType(engine_type_value) + + def set_seed(self, seed: int) -> None: + """ + Set the seed for the random engine. + + Parameters + ---------- + seed : int + Specifies the seed to be set for the random engine. + """ + wrapper.random_engine_set_seed(self._engine, seed) + return None + + def get_seed(self) -> int: + """ + Get the seed for the random engine. + + Returns + ------- + int + The seed value of the random engine. + """ + return wrapper.random_engine_get_seed(self._engine) + + def get_engine(self) -> wrapper.AFRandomEngine: + """ + Get the ArrayFire random engine handle. + + Returns + ------- + wrapper.AFRandomEngine + The ArrayFire random engine handle associated with this RandomEngine instance. + """ + return self._engine + + @classmethod + def from_engine(cls, engine: wrapper.AFRandomEngine) -> RandomEngine: + """ + Create a RandomEngine instance from an existing RandomEngine handle. + + Parameters + ---------- + engine : wrapper.AFRandomEngine + The existing RandomEngine handle. + + Returns + ------- + RandomEngine + A new RandomEngine instance created from the provided engine handle. + """ + instance = cls.__new__(cls) + instance._engine = engine + return instance + + +@afarray_as_array +def randu(shape: tuple[int, ...], /, *, dtype: Dtype = float32, engine: RandomEngine | None = None) -> Array: + """ + Create a multi-dimensional array containing values from a uniform distribution. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the resulting array. Must have at least 1 element, e.g., shape=(3,) + + dtype : Dtype, optional, default: float32 + Data type of the array. + + engine : RandomEngine | None, optional, default: None + If engine is None, uses a default engine created by ArrayFire. + + Returns + ------- + Array + A multi-dimensional array whose elements are sampled uniformly between [0, 1]. + + Notes + ----- + The `shape` parameter determines the dimensions of the resulting array: + - If shape is (x1,), the output is a 1D array of size (x1,). + - If shape is (x1, x2), the output is a 2D array of size (x1, x2). + - If shape is (x1, x2, x3), the output is a 3D array of size (x1, x2, x3). + - If shape is (x1, x2, x3, x4), the output is a 4D array of size (x1, x2, x3, x4). + + Raises + ------ + ValueError + If shape is not a tuple or has less than one value. + """ + if not isinstance(shape, tuple) or not shape: + raise ValueError("Argument shape must be a tuple with at least 1 value.") + + if engine is None: + result = wrapper.randu(shape, dtype) + return cast(Array, result) + + result = wrapper.random_uniform(shape, dtype, engine.get_engine()) + return cast(Array, result) diff --git a/tests/test_data.py b/tests/test_data.py deleted file mode 100644 index a3e60f3..0000000 --- a/tests/test_data.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytest - -from arrayfire.dtypes import int64 -from arrayfire.library import data - - -# Test cases for the constant function -def test_constant_1d() -> None: - result = data.constant(42, (5,)) - assert result.shape == (5,) - assert result.scalar() == 42 - - -def test_constant_2d() -> None: - result = data.constant(3.14, (3, 4)) - assert result.shape == (3, 4) - assert round(result.scalar(), 2) == 3.14 - - -def test_constant_3d() -> None: - result = data.constant(0, (2, 2, 2), dtype=int64) - assert result.shape == (2, 2, 2) - assert result.scalar() == 0 - assert result.dtype == int64 - - -def test_constant_default_shape() -> None: - result = data.constant(1.0) - assert result.shape == (1,) - assert result.scalar() == 1.0 - - -# TODO add error handling -# def test_constant_invalid_dtype() -> None: -# with pytest.raises(ValueError): -# data.constant(42, (3, 3), dtype="invalid_dtype") - - -# Test cases for the range function -def test_range_1d() -> None: - result = data.range((5,)) - assert result.shape == (5,) - - -def test_range_2d() -> None: - result = data.range((3, 4)) - assert result.shape == (3, 4) - - -def test_range_3d() -> None: - result = data.range((2, 2, 2)) - assert result.shape == (2, 2, 2) - - -def test_range_with_axis() -> None: - result = data.range((3, 4), axis=1) - assert result.shape == (3, 4) - - -def test_range_with_dtype() -> None: - result = data.range((4, 3), dtype=int64) - assert result.dtype == int64 - - -def test_range_with_invalid_axis() -> None: - with pytest.raises(ValueError): - data.range((2, 3, 4), axis=4) diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 0000000..83f52a6 --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,93 @@ +import pytest + +from arrayfire import Array +from arrayfire.backend import _clib_wrapper as wrapper +from arrayfire.library import random +from arrayfire.library.random import RandomEngine, RandomEngineType + +# Test cases for the Random Engine + + +def test_random_engine_creation() -> None: + # Test creating a random engine with default values + engine = RandomEngine() + assert engine.get_type() == RandomEngineType.PHILOX + assert engine.get_seed() == 0 + engine.set_type(RandomEngineType.THREEFRY) + assert engine.get_type() == RandomEngineType.THREEFRY + engine.set_seed(42) + assert engine.get_seed() == 42 + + +def test_random_engine_from_handle() -> None: + # Test creating a random engine from an existing handle + handle = wrapper.create_random_engine(RandomEngineType.MERSENNE.value, 1232) + engine = RandomEngine.from_engine(handle) + assert engine.get_type() == RandomEngineType.MERSENNE + assert engine.get_seed() == 1232 + + +def test_random_engine_deletion() -> None: + # Test engine deletion and resource release + engine = RandomEngine() + del engine # This should release the engine's resources + + +# Test cases for the randu function + + +def test_randu_shape_1d() -> None: + shape = (10,) + result: Array = random.randu(shape) + assert isinstance(result, Array) + assert result.shape == shape + + +def test_randu_shape_2d() -> None: + shape = (5, 7) + result: Array = random.randu(shape) + assert isinstance(result, Array) + assert result.shape == shape + + +def test_randu_shape_3d() -> None: + shape = (3, 4, 6) + result: Array = random.randu(shape) + assert isinstance(result, Array) + assert result.shape == shape + + +def test_randu_shape_4d() -> None: + shape = (2, 2, 3, 5) + result: Array = random.randu(shape) + assert isinstance(result, Array) + assert result.shape == shape + + +def test_randu_default_engine() -> None: + shape = (5, 5) + result: Array = random.randu(shape) + assert isinstance(result, Array) + assert result.shape == shape + + +def test_randu_custom_engine() -> None: + shape = (3, 3) + custom_engine = RandomEngine(RandomEngineType.THREEFRY, seed=42) + result: Array = random.randu(shape, engine=custom_engine) + assert isinstance(result, Array) + assert result.shape == shape + + +def test_randu_invalid_shape() -> None: + # Test with an invalid shape (empty tuple) + with pytest.raises(ValueError): + shape = () + random.randu(shape) + + +def test_randu_invalid_shape_type() -> None: + # Test with an invalid shape (non-tuple) + with pytest.raises(ValueError): + shape = [5, 5] + random.randu(shape) # type: ignore[arg-type] diff --git a/tests/wip_test_data.py b/tests/wip_test_data.py new file mode 100644 index 0000000..b27ce3b --- /dev/null +++ b/tests/wip_test_data.py @@ -0,0 +1,148 @@ +import pytest + +from arrayfire.dtypes import int64 +from arrayfire.library import data +from arrayfire import Array + + +# Test cases for the constant function +def test_constant_1d() -> None: + result = data.constant(42, (5,)) + assert result.shape == (5,) + assert result.scalar() == 42 + + +def test_constant_2d() -> None: + result = data.constant(3.14, (3, 4)) + assert result.shape == (3, 4) + assert round(result.scalar(), 2) == 3.14 # type: ignore[arg-type] + + +def test_constant_3d() -> None: + result = data.constant(0, (2, 2, 2), dtype=int64) + assert result.shape == (2, 2, 2) + assert result.scalar() == 0 + assert result.dtype == int64 + + +def test_constant_default_shape() -> None: + result = data.constant(1.0) + assert result.shape == (1,) + assert result.scalar() == 1.0 + + +# TODO add error handling +# def test_constant_invalid_dtype() -> None: +# with pytest.raises(ValueError): +# data.constant(42, (3, 3), dtype="invalid_dtype") + + +# Test cases for the range function +def test_range_1d() -> None: + result = data.range((5,)) + assert result.shape == (5,) + + +def test_range_2d() -> None: + result = data.range((3, 4)) + assert result.shape == (3, 4) + + +def test_range_3d() -> None: + result = data.range((2, 2, 2)) + assert result.shape == (2, 2, 2) + + +def test_range_with_axis() -> None: + result = data.range((3, 4), axis=1) + assert result.shape == (3, 4) + + +def test_range_with_dtype() -> None: + result = data.range((4, 3), dtype=int64) + assert result.dtype == int64 + + +def test_range_with_invalid_axis() -> None: + with pytest.raises(ValueError): + data.range((2, 3, 4), axis=4) + + +# Test cases for the identity function + + +def _is_identity_matrix(arr: Array) -> bool: + rows, cols = arr.shape + if rows != cols: + return False + for i in range(rows): + for j in range(cols): + if i == j and arr[i, j] != 1: + return False + elif i != j and arr[i, j] != 0: + return False + return True + + +# Test cases for the identity function +def test_identity_2d() -> None: + result = data.identity((3, 3)) + assert result.shape == (3, 3) + assert _is_identity_matrix(result) + + +def test_identity_3d() -> None: + result = data.identity((2, 2, 2)) + assert result.shape == (2, 2, 2) + assert custom_all(result, lambda x: x == 1.0) + + +def test_identity_4d() -> None: + result = data.identity((2, 2, 2, 2)) + assert result.shape == (2, 2, 2, 2) + assert custom_all(result, lambda x: x == 1.0) + + +def test_identity_with_dtype() -> None: + result = data.identity((3, 3), dtype=int64) + assert result.shape == (3, 3) + assert result.dtype == int64 + + +def test_identity_invalid_shape() -> None: + with pytest.raises(ValueError): + data.identity((1,)) + + +def test_identity_invalid_shape2() -> None: + with pytest.raises(ValueError): + data.identity((3,)) + + +# Custom function to check if all elements in an array meet a condition +def custom_all(arr: Array) -> bool: + for element in data.flat(arr): + if not element: + return False + return True + + +# Test cases for the flat function +def test_flat_2d(): + input_array = af.randu(3, 2) # Create a 3x2 random array + result = data.flat(input_array) + assert result.shape == (6,) # Flattened shape should be 6 elements in 1D + assert custom_all(result == input_array) + + +def test_flat_3d(): + input_array = af.randu(2, 2, 2) # Create a 2x2x2 random array + result = data.flat(input_array) + assert result.shape == (8,) # Flattened shape should be 8 elements in 1D + assert custom_all(result == input_array) + + +def test_flat_empty_array(): + input_array = af.Array() # Create an empty array + result = data.flat(input_array) + assert result.shape == (0,) # Flattened shape of an empty array should be (0,) From 92b080107ae35b8668b08c3c1a667f4b2c52c303 Mon Sep 17 00:00:00 2001 From: Anton Date: Sat, 2 Sep 2023 04:17:59 +0300 Subject: [PATCH 24/31] Fix tests. Add utils --- arrayfire/__init__.py | 10 +- arrayfire/array_api/_array_object.py | 2 +- .../tests/fixme_test_elementwise_functions.py | 7 +- arrayfire/array_object.py | 6 +- arrayfire/backend/_clib_wrapper/__init__.py | 9 +- .../backend/_clib_wrapper/_constant_array.py | 3 +- .../backend/_clib_wrapper/_error_handler.py | 2 +- arrayfire/backend/_clib_wrapper/_indexing.py | 5 +- arrayfire/backend/_clib_wrapper/_unsorted.py | 21 ++- arrayfire/dtypes.py | 2 +- .../{operators2.py => old_operators.py} | 0 arrayfire/library/operators.py | 155 +++++++++--------- arrayfire/library/random.py | 2 +- arrayfire/library/utils.py | 38 ++++- examples/helloworld.py | 57 ------- tests/test_data.py | 137 ++++++++++++++++ tests/wip_test_data.py | 148 ----------------- 17 files changed, 292 insertions(+), 312 deletions(-) rename arrayfire/library/{operators2.py => old_operators.py} (100%) delete mode 100644 examples/helloworld.py create mode 100644 tests/test_data.py delete mode 100644 tests/wip_test_data.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 3034d33..0ece943 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -223,12 +223,6 @@ trunc, ) +__all__ += ["constant", "range", "identity", "flat"] -__all__ += [ - "constant", - "range", - "identity", - "flat" -] - -from arrayfire.library.data import constant, range, identity, flat +from arrayfire.library.data import constant, flat, identity, range diff --git a/arrayfire/array_api/_array_object.py b/arrayfire/array_api/_array_object.py index 6e65056..ece5cf5 100755 --- a/arrayfire/array_api/_array_object.py +++ b/arrayfire/array_api/_array_object.py @@ -169,7 +169,7 @@ def __str__(self: Array, /) -> str: """ Performs the operation __str__. """ - return self._array.__str__()#.replace("array", "Array") + return self._array.__str__() # .replace("array", "Array") # def __repr__(self: Array, /) -> str: # """ diff --git a/arrayfire/array_api/tests/fixme_test_elementwise_functions.py b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py index 3e5261c..853bee6 100644 --- a/arrayfire/array_api/tests/fixme_test_elementwise_functions.py +++ b/arrayfire/array_api/tests/fixme_test_elementwise_functions.py @@ -1,13 +1,12 @@ from inspect import getfullargspec -from typing import Callable, Iterator, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Iterator import pytest from .. import _elementwise_functions, asarray -from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, integer_dtypes, int8, real_floating_dtypes, int8 -from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._array_object import Array - +from .._dtypes import boolean_dtypes, dtype_categories, floating_dtypes, int8, integer_dtypes, real_floating_dtypes +from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift def nargs(func: Callable) -> int: diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index ecdc190..0f1e467 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -707,9 +707,9 @@ def __array_namespace__(self, *, api_version: str | None = None) -> Any: # TODO return NotImplemented - def __bool__(self) -> bool: - # TODO consider using scalar() and is_scalar() - return NotImplemented + # def __bool__(self) -> bool: + # # TODO consider using scalar() and is_scalar() + # return NotImplemented def __complex__(self) -> complex: # TODO diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index 26097b5..e0ce07e 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -189,6 +189,8 @@ from ._unsorted import ( af_range, + all_true, + all_true_all, array_as_str, copy_array, create_array, @@ -236,10 +238,11 @@ from ._random import ( AFRandomEngine, create_random_engine, + random_engine_get_seed, random_engine_get_type, + random_engine_set_seed, random_engine_set_type, + random_uniform, + randu, release_random_engine, - random_engine_set_seed, - random_engine_get_seed, - randu, random_uniform ) diff --git a/arrayfire/backend/_clib_wrapper/_constant_array.py b/arrayfire/backend/_clib_wrapper/_constant_array.py index 6ada6e9..e8d1d58 100755 --- a/arrayfire/backend/_clib_wrapper/_constant_array.py +++ b/arrayfire/backend/_clib_wrapper/_constant_array.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Tuple, Union from arrayfire.backend._backend import _backend -from arrayfire.dtypes import CShape, Dtype, implicit_dtype, int64, uint64, is_complex_dtype, complex64, complex128 +from arrayfire.dtypes import CShape, Dtype, complex64, implicit_dtype, int64, is_complex_dtype, uint64 from ._error_handler import safe_call @@ -51,7 +51,6 @@ def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dt """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga67af670cc9314589f8134019f5e68809 """ - # return _backend.clib.af_constant_ulong(arr, val, ndims, dims) out = ctypes.c_void_p(0) c_shape = CShape(*shape) diff --git a/arrayfire/backend/_clib_wrapper/_error_handler.py b/arrayfire/backend/_clib_wrapper/_error_handler.py index 73e7762..b6b61fa 100755 --- a/arrayfire/backend/_clib_wrapper/_error_handler.py +++ b/arrayfire/backend/_clib_wrapper/_error_handler.py @@ -15,5 +15,5 @@ def safe_call(c_err: int) -> None: err_str = ctypes.c_char_p(0) err_len = c_dim_t(0) - _backend.clib.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(err_len)) + _backend.clib.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(err_len)) # BUG somewhere raise RuntimeError(to_str(err_str)) diff --git a/arrayfire/backend/_clib_wrapper/_indexing.py b/arrayfire/backend/_clib_wrapper/_indexing.py index f71767c..e06acf1 100755 --- a/arrayfire/backend/_clib_wrapper/_indexing.py +++ b/arrayfire/backend/_clib_wrapper/_indexing.py @@ -2,16 +2,13 @@ import ctypes import math -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import Any, Tuple, Union from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var from ._error_handler import safe_call -if TYPE_CHECKING: - from arrayfire import Array - class _IndexSequence(ctypes.Structure): """ diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index 3096358..4e810af 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -288,7 +288,7 @@ def identity(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: return out -def flat(arr: AFArrayType) -> AFArrayType: +def flat(arr: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__manip__func__flat.htm#gac6dfb22cbd3b151ddffb9a4ddf74455e """ @@ -297,6 +297,25 @@ def flat(arr: AFArrayType) -> AFArrayType: return out +def all_true(arr: AFArrayType, axis: int, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(out), arr, axis)) + return out + + +def all_true_all(arr: AFArrayType, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value # NOTE imag is always set to 0 in C library + + # Safe Call Wrapper diff --git a/arrayfire/dtypes.py b/arrayfire/dtypes.py index 0c526da..083e332 100644 --- a/arrayfire/dtypes.py +++ b/arrayfire/dtypes.py @@ -55,7 +55,7 @@ def __repr__(self) -> str: complex64, complex128, bool, - int8 # BUG if place on top of the list + int8, # BUG if place on top of the list ) diff --git a/arrayfire/library/operators2.py b/arrayfire/library/old_operators.py similarity index 100% rename from arrayfire/library/operators2.py rename to arrayfire/library/old_operators.py diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index 9facb9b..4413aee 100755 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -1,33 +1,34 @@ from __future__ import annotations -from typing import Union +from typing import Union, cast -from arrayfire import Array, return_copy +from arrayfire import Array +from arrayfire._array_helpers import afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper from arrayfire.dtypes import is_complex_dtype -@return_copy +@afarray_as_array def add(x1: Array, x2: Array, /) -> Array: - return wrapper.add(x1.arr, x2.arr) # type: ignore[arg-type, return-value] + return cast(Array, wrapper.add(x1.arr, x2.arr)) -@return_copy +@afarray_as_array def sub(x1: Array, x2: Array, /) -> Array: return wrapper.sub(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def mul(x1: Array, x2: Array, /) -> Array: return wrapper.mul(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def div(x1: Array, x2: Array, /) -> Array: return wrapper.div(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def mod(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: """ Calculate the modulus of two arrays or a scalar and an array. @@ -52,77 +53,81 @@ def mod(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) - return wrapper.mod(x1.arr, x2.arr) # type: ignore[arg-type, return-value] + x1_ = x1.arr if isinstance(x1, Array) else x1 + x2_ = x2.arr if isinstance(x2, Array) else x2 + result = wrapper.mod(x1_, x2_) + return cast(Array, result) -@return_copy + +@afarray_as_array def pow(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.pow(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def bitnot(x: Array, /) -> Array: return wrapper.bitnot(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def bitand(x1: Array, x2: Array, /) -> Array: return wrapper.bitand(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def bitor(x1: Array, x2: Array, /) -> Array: return wrapper.bitor(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def bitxor(x1: Array, x2: Array, /) -> Array: return wrapper.bitxor(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def bitshiftl(x1: Array, x2: Array, /) -> Array: return wrapper.bitshiftl(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def bitshiftr(x1: Array, x2: Array, /) -> Array: return wrapper.bitshiftr(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def lt(x1: Array, x2: Array, /) -> Array: return wrapper.lt(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def le(x1: Array, x2: Array, /) -> Array: return wrapper.le(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def gt(x1: Array, x2: Array, /) -> Array: return wrapper.gt(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def ge(x1: Array, x2: Array, /) -> Array: return wrapper.ge(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def eq(x1: Array, x2: Array, /) -> Array: return wrapper.eq(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def neq(x1: Array, x2: Array, /) -> Array: return wrapper.neq(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -# @return_copy +# @afarray_as_array # def clamp(x: Array, /, lo: float, hi: float) -> Array: # return NotImplemented @@ -136,113 +141,113 @@ def neq(x1: Array, x2: Array, /) -> Array: # # return out -@return_copy +@afarray_as_array def minof(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.minof(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def maxof(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.maxof(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def rem(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.rem(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def abs(x: Array, /) -> Array: return wrapper.abs(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def arg(x: Array, /) -> Array: return wrapper.arg(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def sign(x: Array, /) -> Array: return wrapper.sign(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def round(x: Array, /) -> Array: return wrapper.round(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def trunc(x: Array, /) -> Array: return wrapper.trunc(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def floor(x: Array, /) -> Array: return wrapper.floor(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def ceil(x: Array, /) -> Array: return wrapper.ceil(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def hypot(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.hypot(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def sin(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.sin(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def cos(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.cos(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def tan(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.tan(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def asin(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.asin(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def acos(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.acos(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def atan(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.atan(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def atan2(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.atan2(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def cplx(x1: Union[int, float, Array], x2: Union[int, float, Array, None], /) -> Array: if x2 is None: return wrapper.cplx1(x1) # type: ignore[arg-type, return-value] @@ -250,190 +255,188 @@ def cplx(x1: Union[int, float, Array], x2: Union[int, float, Array, None], /) -> return wrapper.cplx2(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def real(x: Array, /) -> Array: - return wrapper.real(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def imag(x: Array, /) -> Array: return wrapper.imag(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def conjg(x: Array, /) -> Array: - return wrapper.conjg(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def sinh(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.sinh(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def cosh(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.cosh(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def tanh(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.tanh(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def asinh(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.asinh(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def acosh(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.acosh(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def atanh(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.atanh(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def root(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.root(x1.arr, x2.arr) -@return_copy +@afarray_as_array def pow2(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.pow2(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def sigmoid(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.sigmoid(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def exp(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.exp(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def expm1(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.expm1(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def erf(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.erf(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def erfc(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.erfc(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def log(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.log(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def log1p(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.log1p(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def log10(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.log10(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def log2(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.log2(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def sqrt(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.sqrt(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def rsqrt(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.rsqrt(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def cbrt(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.cbrt(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def factorial(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.factorial(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def tgamma(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.tgamma(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def lgamma(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.lgamma(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def iszero(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.iszero(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def isinf(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.isinf(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def isnan(x: Array, /) -> Array: _check_array_values_not_complex(x) return wrapper.isnan(x.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def land(x1: Array, x2: Array, /) -> Array: return wrapper.land(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def lor(x1: Array, x2: Array, /) -> Array: return wrapper.lor(x1.arr, x2.arr) # type: ignore[arg-type, return-value] -@return_copy +@afarray_as_array def lnot(x: Array, /) -> Array: return wrapper.lnot(x.arr) # type: ignore[arg-type, return-value] diff --git a/arrayfire/library/random.py b/arrayfire/library/random.py index 86d32f8..33c463c 100644 --- a/arrayfire/library/random.py +++ b/arrayfire/library/random.py @@ -4,9 +4,9 @@ from typing import cast from arrayfire import Array +from arrayfire._array_helpers import afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper from arrayfire.dtypes import Dtype, float32 -from arrayfire._array_helpers import afarray_as_array class RandomEngineType(Enum): diff --git a/arrayfire/library/utils.py b/arrayfire/library/utils.py index 5a55602..cfafd5e 100644 --- a/arrayfire/library/utils.py +++ b/arrayfire/library/utils.py @@ -1,8 +1,42 @@ +from typing import cast + from arrayfire import Array +from arrayfire._array_helpers import afarray_as_array +from arrayfire.backend import _clib_wrapper as wrapper + + +@afarray_as_array +def _all_true(array: Array, axis: int, /) -> Array: + result = wrapper.all_true(array.arr, axis) + return cast(Array, result) + + +def all_true(array: Array, axis: int | None = None) -> bool | Array: + """ + Check if all the elements along a specified dimension are true. + + Parameters + ---------- + array : Array + Multi-dimensional ArrayFire array. + + axis : int, optional, default: None + Dimension along which the product is required. + + Returns + ------- + bool | Array + An ArrayFire array containing True if all elements in `array` along the specified dimension are True. + If `axis` is `None`, the output is True if `array` does not have any zeros, else False. + Note + ---- + If `axis` is `None`, output is True if the array does not have any zeros, else False. + """ + if axis is None: + return wrapper.all_true_all(array.arr) -def all_true(array: Array, axis: int | None = None) -> Array: - return NotImplemented + return _all_true(array, axis) # from time import time diff --git a/examples/helloworld.py b/examples/helloworld.py deleted file mode 100644 index d8e2415..0000000 --- a/examples/helloworld.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/python - -####################################################### -# Copyright (c) 2015, ArrayFire -# All rights reserved. -# -# This file is distributed under 3-clause BSD license. -# The complete license agreement can be obtained at: -# http://arrayfire.com/licenses/BSD-3-Clause -######################################################## - -import arrayfire as af - -# Display backend information -# af.info() - -print("Create a 5-by-3 matrix of random floats on the GPU\n") -A = af.randu(5, 3, 1, 1, af.Dtype.f32) -print(A) - -print("Element-wise arithmetic\n") -B = af.sin(A) + 1.5 -print(B) - -print("Negate the first three elements of second column\n") -B[0:3, 1] = B[0:3, 1] * -1 -print(B) - -print("Fourier transform the result\n") -C = af.fft(B) -print(C) - -print("Grab last row\n") -c = C[-1, :] -print(c) - -print("Scan Test\n") -r = af.constant(2, 16, 4, 1, 1) -print(r) - -print("Scan\n") -S = af.scan(r, 0, af.BINARYOP.MUL) -print(S) - -print("Create 2-by-3 matrix from host data\n") -d = [1, 2, 3, 4, 5, 6] -print(af.Array(d, shape=(2, 3))) - -print("Copy last column onto first\n") -D[:, 0] = D[:, -1] -print(D) - -print("Sort A and print sorted array and corresponding indices\n") -[sorted_vals, sorted_idxs] = af.sort_index(A) -print(A) -print(sorted_vals) -print(sorted_idxs) diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..c7f4022 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,137 @@ +import pytest + +from arrayfire import Array +from arrayfire.dtypes import int64 +from arrayfire.library import data, random +from arrayfire.library.utils import all_true + +# Test cases for the constant function + + +def test_constant_1d() -> None: + result = data.constant(42, (5,)) + assert result.shape == (5,) + assert result.scalar() == 42 + + +def test_constant_2d() -> None: + result = data.constant(3.14, (3, 4)) + assert result.shape == (3, 4) + assert round(result.scalar(), 2) == 3.14 # type: ignore[arg-type] + + +def test_constant_3d() -> None: + result = data.constant(0, (2, 2, 2), dtype=int64) + assert result.shape == (2, 2, 2) + assert result.scalar() == 0 + assert result.dtype == int64 + + +def test_constant_default_shape() -> None: + result = data.constant(1.0) + assert result.shape == (1,) + assert result.scalar() == 1.0 + + +# TODO add error handling +# def test_constant_invalid_dtype() -> None: +# with pytest.raises(ValueError): +# data.constant(42, (3, 3), dtype="invalid_dtype") + + +# Test cases for the range function + + +def test_range_1d() -> None: + result = data.range((5,)) + assert result.shape == (5,) + + +def test_range_2d() -> None: + result = data.range((3, 4)) + assert result.shape == (3, 4) + + +def test_range_3d() -> None: + result = data.range((2, 2, 2)) + assert result.shape == (2, 2, 2) + + +def test_range_with_axis() -> None: + result = data.range((3, 4), axis=1) + assert result.shape == (3, 4) + + +def test_range_with_dtype() -> None: + result = data.range((4, 3), dtype=int64) + assert result.dtype == int64 + + +def test_range_with_invalid_axis() -> None: + with pytest.raises(ValueError): + data.range((2, 3, 4), axis=4) + + +# Test cases for the identity function + +# BUG in to_list() +# def test_identity_2x2() -> None: +# result = data.identity((2, 2)) +# expected = [[1, 0], [0, 1]] +# assert result.to_list() == expected + + +# def test_identity_3x3() -> None: +# result = data.identity((3, 3)) +# expected = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] +# assert result.to_list() == expected + + +# def test_identity_2x2x2() -> None: +# result = data.identity((2, 2, 2)) +# expected = [[[1, 0], [0, 1]], [[1, 0], [0, 1]]] +# assert result.to_list() == expected + + +# Test cases for the flat function + + +def test_flat_empty_array() -> None: + arr = Array() + flattened = data.flat(arr) + assert flattened.shape == () + + +def test_flat_1d() -> None: + arr = random.randu((5,)) + flattened = data.flat(arr) + assert flattened.shape == (5,) + assert all_true(flattened == arr, 0) + + +def test_flat_2d() -> None: + arr = random.randu((3, 2)) + flattened = data.flat(arr) + assert flattened.shape == (6,) + assert all_true(flattened == data.flat(arr), 0) + + +def test_flat_3d() -> None: + arr = random.randu((3, 2, 4)) + flattened = data.flat(arr) + assert flattened.shape == (24,) + assert all_true(flattened == data.flat(arr), 0) + + +def test_flat_4d() -> None: + arr = random.randu((3, 2, 4, 5)) + flattened = data.flat(arr) + assert flattened.shape == (120,) + assert all_true(flattened == data.flat(arr), 0) + + +def test_flat_large_array() -> None: + arr = random.randu((1000, 1000)) + flattened = data.flat(arr) + assert flattened.shape == (1000000,) + assert all_true(flattened == data.flat(arr), 0) diff --git a/tests/wip_test_data.py b/tests/wip_test_data.py deleted file mode 100644 index b27ce3b..0000000 --- a/tests/wip_test_data.py +++ /dev/null @@ -1,148 +0,0 @@ -import pytest - -from arrayfire.dtypes import int64 -from arrayfire.library import data -from arrayfire import Array - - -# Test cases for the constant function -def test_constant_1d() -> None: - result = data.constant(42, (5,)) - assert result.shape == (5,) - assert result.scalar() == 42 - - -def test_constant_2d() -> None: - result = data.constant(3.14, (3, 4)) - assert result.shape == (3, 4) - assert round(result.scalar(), 2) == 3.14 # type: ignore[arg-type] - - -def test_constant_3d() -> None: - result = data.constant(0, (2, 2, 2), dtype=int64) - assert result.shape == (2, 2, 2) - assert result.scalar() == 0 - assert result.dtype == int64 - - -def test_constant_default_shape() -> None: - result = data.constant(1.0) - assert result.shape == (1,) - assert result.scalar() == 1.0 - - -# TODO add error handling -# def test_constant_invalid_dtype() -> None: -# with pytest.raises(ValueError): -# data.constant(42, (3, 3), dtype="invalid_dtype") - - -# Test cases for the range function -def test_range_1d() -> None: - result = data.range((5,)) - assert result.shape == (5,) - - -def test_range_2d() -> None: - result = data.range((3, 4)) - assert result.shape == (3, 4) - - -def test_range_3d() -> None: - result = data.range((2, 2, 2)) - assert result.shape == (2, 2, 2) - - -def test_range_with_axis() -> None: - result = data.range((3, 4), axis=1) - assert result.shape == (3, 4) - - -def test_range_with_dtype() -> None: - result = data.range((4, 3), dtype=int64) - assert result.dtype == int64 - - -def test_range_with_invalid_axis() -> None: - with pytest.raises(ValueError): - data.range((2, 3, 4), axis=4) - - -# Test cases for the identity function - - -def _is_identity_matrix(arr: Array) -> bool: - rows, cols = arr.shape - if rows != cols: - return False - for i in range(rows): - for j in range(cols): - if i == j and arr[i, j] != 1: - return False - elif i != j and arr[i, j] != 0: - return False - return True - - -# Test cases for the identity function -def test_identity_2d() -> None: - result = data.identity((3, 3)) - assert result.shape == (3, 3) - assert _is_identity_matrix(result) - - -def test_identity_3d() -> None: - result = data.identity((2, 2, 2)) - assert result.shape == (2, 2, 2) - assert custom_all(result, lambda x: x == 1.0) - - -def test_identity_4d() -> None: - result = data.identity((2, 2, 2, 2)) - assert result.shape == (2, 2, 2, 2) - assert custom_all(result, lambda x: x == 1.0) - - -def test_identity_with_dtype() -> None: - result = data.identity((3, 3), dtype=int64) - assert result.shape == (3, 3) - assert result.dtype == int64 - - -def test_identity_invalid_shape() -> None: - with pytest.raises(ValueError): - data.identity((1,)) - - -def test_identity_invalid_shape2() -> None: - with pytest.raises(ValueError): - data.identity((3,)) - - -# Custom function to check if all elements in an array meet a condition -def custom_all(arr: Array) -> bool: - for element in data.flat(arr): - if not element: - return False - return True - - -# Test cases for the flat function -def test_flat_2d(): - input_array = af.randu(3, 2) # Create a 3x2 random array - result = data.flat(input_array) - assert result.shape == (6,) # Flattened shape should be 6 elements in 1D - assert custom_all(result == input_array) - - -def test_flat_3d(): - input_array = af.randu(2, 2, 2) # Create a 2x2x2 random array - result = data.flat(input_array) - assert result.shape == (8,) # Flattened shape should be 8 elements in 1D - assert custom_all(result == input_array) - - -def test_flat_empty_array(): - input_array = af.Array() # Create an empty array - result = data.flat(input_array) - assert result.shape == (0,) # Flattened shape of an empty array should be (0,) From b7923668fe7d3a86578c975c94d794b57db689c8 Mon Sep 17 00:00:00 2001 From: Anton Date: Sat, 2 Sep 2023 04:40:09 +0300 Subject: [PATCH 25/31] Add typings from 3.10 --- arrayfire/_array_helpers.py | 2 +- arrayfire/array_object.py | 3 +- arrayfire/backend/_backend.py | 6 ++-- arrayfire/backend/_backend_functions.py | 10 +++--- .../backend/_clib_wrapper/_constant_array.py | 12 +++---- arrayfire/backend/_clib_wrapper/_indexing.py | 10 +++--- arrayfire/backend/_clib_wrapper/_operators.py | 3 +- .../_clib_wrapper/_reduction_operations.py | 7 ++-- arrayfire/backend/_clib_wrapper/_unsorted.py | 34 +++++++------------ arrayfire/dtypes.py | 6 ++-- arrayfire/library/broadcast.py | 3 +- arrayfire/library/data.py | 18 +++++----- arrayfire/library/old_operators.py | 2 +- arrayfire/library/operators.py | 26 +++++++------- tests/_helpers.py | 5 +-- tests/array_object/test_initialization.py | 6 ++-- tests/array_object/test_operator_overrides.py | 17 +++++----- tests/test_dtypes.py | 3 +- 18 files changed, 82 insertions(+), 91 deletions(-) diff --git a/arrayfire/_array_helpers.py b/arrayfire/_array_helpers.py index ea7cb99..ebccc87 100644 --- a/arrayfire/_array_helpers.py +++ b/arrayfire/_array_helpers.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from typing_extensions import ParamSpec diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 0f1e467..0966036 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -1,8 +1,9 @@ from __future__ import annotations import array as py_array +from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, cast from .backend import _clib_wrapper as wrapper from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype diff --git a/arrayfire/backend/_backend.py b/arrayfire/backend/_backend.py index 7791a4d..b151679 100644 --- a/arrayfire/backend/_backend.py +++ b/arrayfire/backend/_backend.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Iterator, List, Optional +from typing import Iterator from arrayfire.logger import logger from arrayfire.version import ARRAYFIRE_VER_MAJOR @@ -187,7 +187,7 @@ def _load_nvrtc_builtins_lib(self, lib_path: Path) -> None: else: logger.warning("Could not find local nvrtc-builtins library") - def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: Optional[str] = None) -> List[Path]: + def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: str | None = None) -> list[Path]: post = self._backend_path_config.lib_postfix if ver_major is None else ver_major lib_name = self._backend_path_config.lib_prefix + lib.value + name + post @@ -208,7 +208,7 @@ def _lib_names(self, name: str, lib: _LibPrefixes, ver_major: Optional[str] = No lib_paths.insert(2, Path(str(search_path), lib_name)) return lib_paths - def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> Optional[str]: + def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> str | None: for f in search_path.iterdir(): if "nvrtc-builtins" in f.name: return f.name diff --git a/arrayfire/backend/_backend_functions.py b/arrayfire/backend/_backend_functions.py index 98f9e70..58ad611 100755 --- a/arrayfire/backend/_backend_functions.py +++ b/arrayfire/backend/_backend_functions.py @@ -2,7 +2,7 @@ import warnings from enum import Enum -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from ._backend import Backend, BackendType, get_backend from ._clib_wrapper._unsorted import cublas_set_math_mode @@ -25,13 +25,13 @@ class CublasMathMode(Enum): tensor_op = 1 -def set_backend(backend_type: Union[BackendType, str]) -> None: +def set_backend(backend_type: BackendType | str) -> None: """ Set a specific backend by backend_type name. Parameters ---------- - backend_type : Union[BackendType, str] + backend_type : BackendType | str Name of the backend type to set. Raises @@ -278,13 +278,13 @@ def set_native_cuda_id(index: int) -> None: return c_set_native_id(index) -def set_cublas_mode(mode: Union[CublasMathMode, int] = CublasMathMode.default) -> None: +def set_cublas_mode(mode: CublasMathMode | int = CublasMathMode.default) -> None: """ Set cuBLAS math mode for CUDA backend. It enables the Tensor Core usage if available on CUDA backend GPUs. Parameters ---------- - mode : Union[CublasMathMode, int] + mode : CublasMathMode | int Specify the mode available within CublasMathMode enum. Raises diff --git a/arrayfire/backend/_clib_wrapper/_constant_array.py b/arrayfire/backend/_clib_wrapper/_constant_array.py index e8d1d58..52b2fad 100755 --- a/arrayfire/backend/_clib_wrapper/_constant_array.py +++ b/arrayfire/backend/_clib_wrapper/_constant_array.py @@ -1,7 +1,7 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Tuple, Union +from typing import TYPE_CHECKING from arrayfire.backend._backend import _backend from arrayfire.dtypes import CShape, Dtype, complex64, implicit_dtype, int64, is_complex_dtype, uint64 @@ -12,7 +12,7 @@ from ._base import AFArrayType -def _constant_complex(number: Union[int, float, complex], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def _constant_complex(number: int | float | complex, shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga5a083b1f3cd8a72a41f151de3bdea1a2 """ @@ -32,7 +32,7 @@ def _constant_complex(number: Union[int, float, complex], shape: Tuple[int, ...] return out -def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def _constant_long(number: int | float, shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga10f1c9fad1ce9e9fefd885d5a1d1fd49 """ @@ -47,7 +47,7 @@ def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dty return out -def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def _constant_ulong(number: int | float, shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga67af670cc9314589f8134019f5e68809 """ @@ -62,7 +62,7 @@ def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dt return out -def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def _constant(number: int | float, shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#gafc51b6a98765dd24cd4139f3bde00670 """ @@ -77,7 +77,7 @@ def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, / return out -def create_constant_array(number: Union[int, float, complex], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def create_constant_array(number: int | float | complex, shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: dtype = implicit_dtype(number, dtype) if isinstance(number, complex): diff --git a/arrayfire/backend/_clib_wrapper/_indexing.py b/arrayfire/backend/_clib_wrapper/_indexing.py index e06acf1..bca8678 100755 --- a/arrayfire/backend/_clib_wrapper/_indexing.py +++ b/arrayfire/backend/_clib_wrapper/_indexing.py @@ -2,7 +2,7 @@ import ctypes import math -from typing import Any, Tuple, Union +from typing import Any from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var @@ -36,7 +36,7 @@ class _IndexSequence(ctypes.Structure): # More about _fields_ purpose: https://docs.python.org/3/library/ctypes.html#structures-and-unions _fields_ = [("begin", ctypes.c_double), ("end", ctypes.c_double), ("step", ctypes.c_double)] - def __init__(self, chunk: Union[int, slice]): + def __init__(self, chunk: int | slice): self.begin = ctypes.c_double(0) self.end = ctypes.c_double(-1) self.step = ctypes.c_double(1) @@ -128,9 +128,7 @@ class ParallelRange(_IndexSequence): """ - def __init__( - self, start: Union[int, float], stop: Union[int, float, None] = None, step: Union[int, float, None] = None - ) -> None: + def __init__(self, start: int | float, stop: int | float | None = None, step: int | float | None = None) -> None: if not stop: stop = start start = 0 @@ -248,7 +246,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None: self.idxs[idx] = value -def get_indices(key: Union[int, slice, Tuple[Union[int, slice], ...]]) -> CIndexStructure: +def get_indices(key: int | slice | tuple[int | slice, ...]) -> CIndexStructure: indices = CIndexStructure() if isinstance(key, tuple): diff --git a/arrayfire/backend/_clib_wrapper/_operators.py b/arrayfire/backend/_clib_wrapper/_operators.py index 2832cc8..e7e0e28 100755 --- a/arrayfire/backend/_clib_wrapper/_operators.py +++ b/arrayfire/backend/_clib_wrapper/_operators.py @@ -1,7 +1,8 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING from arrayfire.backend._backend import _backend from arrayfire.library.broadcast import bcast_var diff --git a/arrayfire/backend/_clib_wrapper/_reduction_operations.py b/arrayfire/backend/_clib_wrapper/_reduction_operations.py index fed81c7..900be6c 100755 --- a/arrayfire/backend/_clib_wrapper/_reduction_operations.py +++ b/arrayfire/backend/_clib_wrapper/_reduction_operations.py @@ -1,7 +1,8 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Callable, Union +from collections.abc import Callable +from typing import TYPE_CHECKING from arrayfire.backend._backend import _backend @@ -11,12 +12,12 @@ from ._base import AFArrayType -def count_all(x: AFArrayType) -> Union[int, float, complex]: +def count_all(x: AFArrayType) -> int | float | complex: # TODO reconsider original arith.count return _reduce_all(x, _backend.clib.af_count_all) -def _reduce_all(arr: AFArrayType, c_func: Callable) -> Union[int, float, complex]: +def _reduce_all(arr: AFArrayType, c_func: Callable) -> int | float | complex: real = ctypes.c_double(0) imag = ctypes.c_double(0) safe_call(c_func(ctypes.pointer(real), ctypes.pointer(imag), arr)) diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index 4e810af..946cd72 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -1,7 +1,7 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Any, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, cast from arrayfire.backend._backend import _backend from arrayfire.dtypes import CShape, CType, Dtype, c_dim_t, to_str @@ -15,7 +15,7 @@ # Array management -def create_handle(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def create_handle(shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga3b8f5cf6fce69aa1574544bc2d44d7d0 """ @@ -40,7 +40,7 @@ def retain_array(arr: AFArrayType) -> AFArrayType: return out -def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, /) -> AFArrayType: +def create_array(shape: tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga834be32357616d8ab735087c6f681858 """ @@ -59,7 +59,7 @@ def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffe return out -def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, /) -> AFArrayType: +def device_array(shape: tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaad4fc77f872217e7337cb53bfb623cf5 """ @@ -79,11 +79,11 @@ def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffe def create_strided_array( - shape: Tuple[int, ...], + shape: tuple[int, ...], dtype: Dtype, array_buffer: _ArrayBuffer, offset: CType, - strides: Tuple[int, ...], + strides: tuple[int, ...], pointer_source: PointerSource, /, ) -> AFArrayType: @@ -147,7 +147,7 @@ def get_numdims(arr: AFArrayType) -> int: return out.value -def get_dims(arr: AFArrayType) -> Tuple[int, ...]: +def get_dims(arr: AFArrayType) -> tuple[int, ...]: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga8b90da50a532837d9763e301b2267348 """ @@ -162,13 +162,13 @@ def get_dims(arr: AFArrayType) -> Tuple[int, ...]: return (d0.value, d1.value, d2.value, d3.value) -def get_scalar(arr: AFArrayType, dtype: Dtype, /) -> Union[None, int, float, bool, complex]: +def get_scalar(arr: AFArrayType, dtype: Dtype, /) -> int | float | complex | bool | None: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefe2e343a74a84bd43b588218ecc09a3 """ out = dtype.c_type() safe_call(_backend.clib.af_get_scalar(ctypes.pointer(out), arr)) - return cast(Union[None, int, float, bool, complex], out.value) + return cast(int | float | complex | bool | None, out.value) def is_empty(arr: AFArrayType) -> bool: @@ -205,17 +205,7 @@ def copy_array(arr: AFArrayType) -> AFArrayType: def index_gen( arr: AFArrayType, ndims: int, - key: Union[ - int, - slice, - Tuple[ - Union[ - int, - slice, - ], - ..., - ], - ], + key: int | slice | tuple[int | slice, ...], indices: Any, /, ) -> AFArrayType: @@ -268,7 +258,7 @@ def where(arr: AFArrayType) -> AFArrayType: return out -def af_range(shape: Tuple[int, ...], axis: int, dtype: Dtype, /) -> AFArrayType: +def af_range(shape: tuple[int, ...], axis: int, dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__range.htm#gadd6c9b479692454670a51e00ea5b26d5 """ @@ -278,7 +268,7 @@ def af_range(shape: Tuple[int, ...], axis: int, dtype: Dtype, /) -> AFArrayType: return out -def identity(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: +def identity(shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: """ diff --git a/arrayfire/dtypes.py b/arrayfire/dtypes.py index 083e332..0462fcb 100644 --- a/arrayfire/dtypes.py +++ b/arrayfire/dtypes.py @@ -2,7 +2,7 @@ import ctypes from dataclasses import dataclass -from typing import Tuple, Type, Union +from typing import Type from arrayfire.backend._backend import is_arch_x86 @@ -64,7 +64,7 @@ def is_complex_dtype(dtype: Dtype) -> _python_bool: c_dim_t = ctypes.c_int if is_arch_x86() else ctypes.c_longlong -ShapeType = Tuple[int, ...] +ShapeType = tuple[int, ...] class CShape(tuple): @@ -91,7 +91,7 @@ def to_str(c_str: ctypes.c_char_p) -> str: return str(c_str.value.decode("utf-8")) # type: ignore[union-attr] -def implicit_dtype(number: Union[int, float, _python_bool, complex], array_dtype: Dtype) -> Dtype: +def implicit_dtype(number: int | float | _python_bool | complex, array_dtype: Dtype) -> Dtype: if isinstance(number, _python_bool): number_dtype = bool elif isinstance(number, int): diff --git a/arrayfire/library/broadcast.py b/arrayfire/library/broadcast.py index 90f4cea..65b031f 100644 --- a/arrayfire/library/broadcast.py +++ b/arrayfire/library/broadcast.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any class Bcast: diff --git a/arrayfire/library/data.py b/arrayfire/library/data.py index 24308ea..71b24f1 100644 --- a/arrayfire/library/data.py +++ b/arrayfire/library/data.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union, cast +from typing import cast from arrayfire import Array from arrayfire._array_helpers import afarray_as_array @@ -7,18 +7,20 @@ _pyrange = range +# TODO add more error handles + @afarray_as_array -def constant(scalar: Union[int, float, complex], shape: Tuple[int, ...] = (1,), dtype: Dtype = float32) -> Array: +def constant(scalar: int | float | complex, shape: tuple[int, ...] = (1,), dtype: Dtype = float32) -> Array: """ Create a multi-dimensional array filled with a constant value. Parameters ---------- - scalar : Union[int, float, complex] + scalar : int | float | complex The value to fill each element of the constant array with. - shape : Tuple[int, ...], optional, default: (1,) + shape : tuple[int, ...], optional, default: (1,) The shape of the constant array. dtype : Dtype, optional, default: float32 @@ -42,13 +44,13 @@ def constant(scalar: Union[int, float, complex], shape: Tuple[int, ...] = (1,), @afarray_as_array -def range(shape: Tuple[int, ...], axis: int = 0, dtype: Dtype = float32) -> Array: +def range(shape: tuple[int, ...], axis: int = 0, dtype: Dtype = float32) -> Array: """ Create a multi-dimensional array using the length of a dimension as a range. Parameters ---------- - shape : Tuple[int, ...] + shape : tuple[int, ...] The shape of the resulting array. Each element represents the length of a corresponding dimension. @@ -170,13 +172,13 @@ def range(shape: Tuple[int, ...], axis: int = 0, dtype: Dtype = float32) -> Arra @afarray_as_array -def identity(shape: Tuple[int, ...], dtype: Dtype = float32) -> Array: +def identity(shape: tuple[int, ...], dtype: Dtype = float32) -> Array: """ Create an identity matrix or batch of identity matrices. Parameters ---------- - shape : Tuple[int, ...] + shape : tuple[int, ...] The shape of the resulting identity array or batch of arrays. Must have at least 2 values. diff --git a/arrayfire/library/old_operators.py b/arrayfire/library/old_operators.py index 3dc1361..9924482 100644 --- a/arrayfire/library/old_operators.py +++ b/arrayfire/library/old_operators.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from arrayfire import Array from arrayfire.backend import _clib_wrapper as wrapper diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index 4413aee..3b86f52 100755 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Union, cast +from typing import cast from arrayfire import Array from arrayfire._array_helpers import afarray_as_array @@ -29,15 +29,15 @@ def div(x1: Array, x2: Array, /) -> Array: @afarray_as_array -def mod(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def mod(x1: int | float | Array, x2: int | float | Array, /) -> Array: """ Calculate the modulus of two arrays or a scalar and an array. Parameters ---------- - x1 : Union[int, float, Array] + x1 : int |float |Array The first array or scalar operand. - x2 : Union[int, float, Array] + x2 : int |float |Array The second array or scalar operand. Returns @@ -61,7 +61,7 @@ def mod(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: @afarray_as_array -def pow(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def pow(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.pow(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @@ -142,21 +142,21 @@ def neq(x1: Array, x2: Array, /) -> Array: @afarray_as_array -def minof(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def minof(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.minof(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @afarray_as_array -def maxof(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def maxof(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.maxof(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @afarray_as_array -def rem(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def rem(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.rem(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @@ -198,7 +198,7 @@ def ceil(x: Array, /) -> Array: @afarray_as_array -def hypot(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def hypot(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.hypot(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @@ -242,13 +242,13 @@ def atan(x: Array, /) -> Array: @afarray_as_array -def atan2(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def atan2(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.atan2(x1.arr, x2.arr) # type: ignore[arg-type, return-value] @afarray_as_array -def cplx(x1: Union[int, float, Array], x2: Union[int, float, Array, None], /) -> Array: +def cplx(x1: int | float | Array, x2: int | float | Array | None, /) -> Array: if x2 is None: return wrapper.cplx1(x1) # type: ignore[arg-type, return-value] else: @@ -307,7 +307,7 @@ def atanh(x: Array, /) -> Array: @afarray_as_array -def root(x1: Union[int, float, Array], x2: Union[int, float, Array], /) -> Array: +def root(x1: int | float | Array, x2: int | float | Array, /) -> Array: _check_operands_fit_requirements(x1, x2) return wrapper.root(x1.arr, x2.arr) @@ -441,7 +441,7 @@ def lnot(x: Array, /) -> Array: return wrapper.lnot(x.arr) # type: ignore[arg-type, return-value] -def _check_operands_fit_requirements(x1: Union[int, float, Array], x2: Union[int, float, Array]) -> None: +def _check_operands_fit_requirements(x1: int | float | Array, x2: int | float | Array) -> None: if isinstance(x1, Array) and isinstance(x2, Array): if x1.shape != x2.shape: raise ValueError("Array shapes must match.") diff --git a/tests/_helpers.py b/tests/_helpers.py index 1a45730..038062e 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,6 +1,3 @@ -from typing import List, Union - - -def round_to(list_: List[Union[int, float, complex, bool]], symbols: int = 3) -> List[Union[int, float]]: +def round_to(list_: list[int | float | complex | bool], symbols: int = 3) -> list[int | float]: # HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/ return [round(x, symbols) for x in list_] diff --git a/tests/array_object/test_initialization.py b/tests/array_object/test_initialization.py index 430b7c3..108a05e 100755 --- a/tests/array_object/test_initialization.py +++ b/tests/array_object/test_initialization.py @@ -1,6 +1,6 @@ import array as pyarray import math -from typing import Any, Optional, Tuple +from typing import Any import pytest @@ -27,7 +27,7 @@ ], ) def test_initialization_with_different_arguments( - array: Array, res_dtype: Dtype, res_ndim: int, res_size: int, res_shape: Tuple[int, ...], res_len: int + array: Array, res_dtype: Dtype, res_ndim: int, res_size: int, res_shape: tuple[int, ...], res_len: int ) -> None: assert array.dtype == res_dtype assert array.ndim == res_ndim @@ -51,7 +51,7 @@ def test_initialization_with_different_arguments( ], ) def test_initalization_with_unsupported_argument_types( - array_object: Any, dtype: Optional[Dtype], shape: Tuple[int, ...] + array_object: Any, dtype: Dtype | None, shape: tuple[int, ...] ) -> None: with pytest.raises(TypeError): Array(obj=array_object, dtype=dtype, shape=shape) diff --git a/tests/array_object/test_operator_overrides.py b/tests/array_object/test_operator_overrides.py index cd9a68e..49fc4ff 100644 --- a/tests/array_object/test_operator_overrides.py +++ b/tests/array_object/test_operator_overrides.py @@ -1,5 +1,6 @@ import operator -from typing import Any, Callable, List, Union +from collections.abc import Callable +from typing import Any import pytest @@ -7,7 +8,7 @@ from arrayfire.dtypes import bool as af_bool from tests._helpers import round_to -Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array] +Operator = Callable[[int | float | Array, int | float | Array], Array] def pytest_generate_tests(metafunc: Any) -> None: @@ -52,9 +53,9 @@ def pytest_generate_tests(metafunc: Any) -> None: def test_arithmetic_operators( - array_origin: List[Union[int, float]], + array_origin: list[int | float], arithmetic_operator: str, - operand: Union[int, float, List[Union[int, float]]], + operand: int | float | list[int | float], ) -> None: op = getattr(operator, arithmetic_operator) iop = getattr(operator, "i" + arithmetic_operator) @@ -84,7 +85,7 @@ def test_arithmetic_operators( def test_arithmetic_operators_expected_to_raise_error( - array_origin: List[Union[int, float]], arithmetic_operator: str, false_operand: Any + array_origin: list[int | float], arithmetic_operator: str, false_operand: Any ) -> None: array = Array(array_origin) op = getattr(operator, arithmetic_operator) @@ -93,9 +94,9 @@ def test_arithmetic_operators_expected_to_raise_error( def test_comparison_operators( - array_origin: List[Union[int, float]], + array_origin: list[int | float], comparison_operator: Operator, - operand: Union[int, float, List[Union[int, float]]], + operand: int | float | list[int | float], ) -> None: if isinstance(operand, list): ref = [comparison_operator(x, y) for x, y in zip(array_origin, operand)] @@ -111,7 +112,7 @@ def test_comparison_operators( def test_comparison_operators_expected_to_raise_error( - array_origin: List[Union[int, float]], comparison_operator: Operator, false_operand: Any + array_origin: list[int | float], comparison_operator: Operator, false_operand: Any ) -> None: array = Array(array_origin) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index eecc843..d7ad295 100755 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -1,5 +1,4 @@ import ctypes -from typing import Union import pytest @@ -36,7 +35,7 @@ def test_dtype_inequality() -> None: (1 + 2j, complex128, complex128), ], ) -def test_implicit_dtype(number: Union[int, float, bool, complex], array_dtype: Dtype, expected_dtype: Dtype) -> None: +def test_implicit_dtype(number: int | float | bool | complex, array_dtype: Dtype, expected_dtype: Dtype) -> None: result_dtype = implicit_dtype(number, array_dtype) assert result_dtype == expected_dtype From f626d534a2a2824a5de0bbf77da3d0e9d193f9c1 Mon Sep 17 00:00:00 2001 From: Anton Date: Sat, 2 Sep 2023 18:42:02 +0300 Subject: [PATCH 26/31] Add reduction operations. Add array methods --- arrayfire/__init__.py | 12 +- arrayfire/_array_helpers.py | 31 --- arrayfire/array_object.py | 182 ++++++++++++++---- arrayfire/backend/_clib_wrapper/__init__.py | 21 +- arrayfire/backend/_clib_wrapper/_indexing.py | 2 +- .../_clib_wrapper/_reduction_operations.py | 76 ++++++++ arrayfire/backend/_clib_wrapper/_unsorted.py | 55 ++++-- arrayfire/library/data.py | 2 +- arrayfire/library/operators.py | 2 +- arrayfire/library/random.py | 2 +- arrayfire/library/utils.py | 53 +++-- .../library/vector_algorithms/__init__.py | 3 + .../vector_algorithms/reduction_operations.py | 81 ++++++++ tests/test_data.py | 2 +- tests/vector_algorithms/__init__.py | 0 .../test_reduction_operations.py | 29 +++ 16 files changed, 438 insertions(+), 115 deletions(-) delete mode 100644 arrayfire/_array_helpers.py create mode 100644 arrayfire/library/vector_algorithms/__init__.py create mode 100644 arrayfire/library/vector_algorithms/reduction_operations.py create mode 100644 tests/vector_algorithms/__init__.py create mode 100644 tests/vector_algorithms/test_reduction_operations.py diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 0ece943..9f64689 100755 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -7,8 +7,8 @@ __all__ += ["__arrayfire_version__"] __arrayfire_version__ = ARRAYFIRE_VERSION -__all__ += ["Array", "return_copy"] -from .array_object import Array, return_copy +__all__ += ["Array"] +from .array_object import Array __all__ += [ "int8", @@ -226,3 +226,11 @@ __all__ += ["constant", "range", "identity", "flat"] from arrayfire.library.data import constant, flat, identity, range + +__all__ += ["randu"] + +from arrayfire.library.random import randu + +__all__ += ["all_true", "any_true"] + +from arrayfire.library.vector_algorithms import all_true, any_true diff --git a/arrayfire/_array_helpers.py b/arrayfire/_array_helpers.py deleted file mode 100644 index ebccc87..0000000 --- a/arrayfire/_array_helpers.py +++ /dev/null @@ -1,31 +0,0 @@ -from collections.abc import Callable - -from typing_extensions import ParamSpec - -from arrayfire import Array - -P = ParamSpec("P") - - -def afarray_as_array(func: Callable[P, Array]) -> Callable[P, Array]: - """ - Decorator that converts a function returning an array to return an ArrayFire Array. - - Parameters - ---------- - func : Callable[P, Array] - The original function that returns an array. - - Returns - ------- - Callable[P, Array] - A decorated function that returns an ArrayFire Array. - """ - - def decorated(*args: P.args, **kwargs: P.kwargs) -> Array: - out = Array() - result = func(*args, **kwargs) - out.arr = result # type: ignore[assignment] - return out - - return decorated diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 0966036..aa2572d 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -1,12 +1,15 @@ from __future__ import annotations import array as py_array +import ctypes from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, ParamSpec, cast from .backend import _clib_wrapper as wrapper -from .dtypes import CType, Dtype, c_api_value_to_dtype, float32, str_to_dtype +from .dtypes import CType, Dtype +from .dtypes import bool as afbool +from .dtypes import c_api_value_to_dtype, float32, str_to_dtype from .library.device import PointerSource if TYPE_CHECKING: @@ -22,16 +25,32 @@ class _ArrayBuffer: length: int = 0 -class return_copy: - # TODO merge with process_c_function in array_object - def __init__(self, func: Callable) -> None: - self.func = func +P = ParamSpec("P") - def __call__(self, *x: Array) -> Array: + +def afarray_as_array(func: Callable[P, Array]) -> Callable[P, Array]: + """ + Decorator that converts a function returning an array to return an ArrayFire Array. + + Parameters + ---------- + func : Callable[P, Array] + The original function that returns an array. + + Returns + ------- + Callable[P, Array] + A decorated function that returns an ArrayFire Array. + """ + + def decorated(*args: P.args, **kwargs: P.kwargs) -> Array: out = Array() - out.arr = self.func(*x) + result = func(*args, **kwargs) + out.arr = result # type: ignore[assignment] return out + return decorated + class Array: def __init__( @@ -43,6 +62,7 @@ def __init__( offset: CType | None = None, strides: tuple[int, ...] | None = None, ) -> None: + self.arr = ctypes.c_void_p(0) # FIXME _no_initial_dtype = False # HACK, FIXME if isinstance(dtype, str): @@ -744,22 +764,19 @@ def __getitem__(self, key: IndexKey, /) -> Array: out : Array An array containing the accessed value(s). The returned array must have the same data type as self. """ - - from .dtypes import bool - # TODO # API Specification - key: Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], array]. # consider using af.span to replace ellipsis during refactoring out = Array() ndims = self.ndim - if isinstance(key, Array) and key == bool.c_api_value: + if isinstance(key, Array) and key == afbool.c_api_value: ndims = 1 if wrapper.count_all(key.arr) == 0: # HACK was count() method before return out # HACK known issue - out.arr = wrapper.index_gen(self.arr, ndims, key, wrapper.get_indices(key)) # type: ignore[arg-type] + out.arr = wrapper.index_gen(self.arr, ndims, wrapper.get_indices(key)) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -773,9 +790,37 @@ def __int__(self) -> int: def __len__(self) -> int: return self.shape[0] if self.shape else 0 + # BUG def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> None: - # TODO - return NotImplemented # type: ignore[return-value] # FIXME + out = Array() + ndims = self.ndim + + is_array_with_bool = isinstance(key, Array) and type(key) == afbool + + if is_array_with_bool: + ndims = 1 + num = wrapper.count_all(key.arr) + if num == 0: + return + + if isinstance(value, int | float | complex | bool): + dims = _get_processed_index(key, self.shape) + if is_array_with_bool: + ndims = 1 + other_arr = wrapper.create_constant_array(value, (int(num),), self.dtype) + else: + other_arr = wrapper.create_constant_array(value, dims, self.dtype) + del_other = True + else: + other_arr = value.arr + del_other = False + + indices = wrapper.get_indices(key) + out.arr = wrapper.assign_gen(self.arr, other_arr, ndims, indices) + wrapper.release_array(self.arr) + if del_other: + wrapper.release_array(other_arr) + self.arr = out.arr def __str__(self) -> str: # TODO change the look of array str. E.g., like np.array @@ -788,6 +833,13 @@ def __repr__(self) -> str: # TODO change the look of array representation. E.g., like np.array return wrapper.array_as_str(self.arr) + def __del__(self) -> None: + if not self.arr.value: + return + + wrapper.release_array(self.arr) + self.arr.value = 0 + def to_device(self, device: Any, /, *, stream: int | Any = None) -> Array: # TODO implementation and change device type from Any to Device return NotImplemented @@ -812,18 +864,14 @@ def device(self) -> Any: return NotImplemented @property - def mT(self) -> Array: - # TODO - return NotImplemented - - @property + @afarray_as_array def T(self) -> Array: """ Transpose of the array. Returns ------- - out : Array + Array Two-dimensional array whose first and last dimensions (axes) are permuted in reverse order relative to original array. The returned array must have the same data type as the original array. @@ -836,9 +884,12 @@ def T(self) -> Array: raise TypeError(f"Array should be at least 2-dimensional. Got {self.ndim}-dimensional array") # TODO add check if out.dtype == self.dtype - out = Array() - out.arr = wrapper.transpose(self.arr, False) - return out + return cast(Array, wrapper.transpose(self.arr, False)) + + @property + @afarray_as_array + def H(self) -> Array: + return cast(Array, wrapper.transpose(self.arr, True)) @property def size(self) -> int: @@ -847,7 +898,7 @@ def size(self) -> int: Returns ------- - out : int + int Number of elements in an array Note @@ -862,7 +913,7 @@ def ndim(self) -> int: """ Number of array dimensions (axes). - out : int + int Number of array dimensions (axes). """ return wrapper.get_numdims(self.arr) @@ -874,13 +925,37 @@ def shape(self) -> tuple[int, ...]: Returns ------- - out : tuple[int, ...] + tuple[int, ...] Array dimensions. """ # NOTE skipping passing any None values return wrapper.get_dims(self.arr)[: self.ndim] - def scalar(self) -> int | float | bool | complex | None: + @property + def offset(self) -> int: + """ + Return the offset of the first element relative to the raw pointer. + + Returns + ------- + int + The offset in number of elements. + """ + return wrapper.get_offset(self.arr) + + @property + def strides(self) -> tuple[int, ...]: + """ + Return the distance in bytes between consecutive elements for each dimension. + + Returns + ------- + tuple[int, ...] + The strides for each dimension. + """ + return wrapper.get_strides(self.arr)[: self.ndim] + + def scalar(self) -> int | float | bool | complex | None: # FIXME """ Return the first element of the array """ @@ -924,6 +999,7 @@ def to_ctype_array(self, row_major: bool = False) -> CArray: array = _reorder(self) if row_major else self return wrapper.get_data_ptr(array.arr, array.size, array.dtype) + @afarray_as_array def copy(self) -> Array: """ Performs a deep copy of the array. @@ -934,14 +1010,14 @@ def copy(self) -> Array: An identical copy of self. """ - return return_copy(wrapper.copy_array)(self) # type: ignore[return-value] + return cast(Array, wrapper.copy_array(self.arr)) @classmethod def from_afarray(cls, array: wrapper.AFArrayType) -> None: cls.arr = array -IndexKey = int | slice | tuple[int | slice, ...] | Array +IndexKey = int | float | complex | bool | wrapper.ParallelRange | slice | tuple[int | slice, ...] | Array def _reorder(array: Array) -> Array: @@ -958,9 +1034,8 @@ def _metadata_string(dtype: Dtype, dims: tuple[int, ...] | None = None) -> str: return "arrayfire.Array()\n" f"Type: {dtype.name}\n" f"Dims: {str(dims) if dims else ''}" +@afarray_as_array def _process_c_function(lhs: int | float | Array, rhs: int | float | Array, c_function: Any) -> Array: - out = Array() - if isinstance(lhs, Array) and isinstance(rhs, Array): lhs_array = lhs.arr rhs_array = rhs.arr @@ -976,5 +1051,46 @@ def _process_c_function(lhs: int | float | Array, rhs: int | float | Array, c_fu else: raise TypeError(f"{type(rhs)} is not supported and can not be passed to C binary function.") - out.arr = c_function(lhs_array, rhs_array) + return cast(Array, c_function(lhs_array, rhs_array)) + + +def _get_processed_index(key: IndexKey, shape: tuple[int, ...]) -> tuple[int, ...]: + if isinstance(key, tuple): + return tuple(_index_to_afindex(key[i], shape[i]) for i in range(len(key))) + + return (_index_to_afindex(key, shape[0]),) + shape[1:] + + +def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.ParallelRange | Array, dim: int) -> int: + if isinstance(key, int | float | complex | bool): + out = 1 + elif isinstance(key, slice): + out = _slice_to_length(key, dim) + elif isinstance(key, wrapper.ParallelRange): + out = _slice_to_length(key.S, dim) + elif isinstance(key, Array): + if key.dtype == afbool: + out = int(sum(key)) # FIXME af.sum + else: + out = key.size + else: + raise IndexError(f"Invalid key type {type(key)}.") + return out + + +def _slice_to_length(key: slice, dim: int) -> int: + if key.start is None: + start = 0 + elif key.start < 0: + start = dim - key.start + + if key.stop is None: + stop = dim + elif key.stop < 0: + stop = dim - key.stop + + if key.step is None: + step = 1 + + return int(((stop - start - 1) / step) + 1) diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index e0ce07e..54d302f 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -189,9 +189,9 @@ from ._unsorted import ( af_range, - all_true, - all_true_all, array_as_str, + assign_gen, + cast, copy_array, create_array, create_handle, @@ -207,11 +207,14 @@ get_elements, get_last_error, get_numdims, + get_offset, get_scalar, get_size_of, + get_strides, identity, index_gen, is_empty, + release_array, reorder, retain_array, set_backend, @@ -225,14 +228,24 @@ __all__ += ["count_all"] -from ._reduction_operations import count_all +from ._reduction_operations import ( + all_true, + all_true_all, + any_true, + any_true_all, + count_all, + sum, + sum_all, + sum_nan, + sum_nan_all, +) __all__ += ["create_constant_array"] from ._constant_array import create_constant_array __all__ += ["get_indices"] -from ._indexing import get_indices +from ._indexing import ParallelRange, get_indices __all__ += ["create_random_engine", "release_random_engine", "AFRandomEngine"] from ._random import ( diff --git a/arrayfire/backend/_clib_wrapper/_indexing.py b/arrayfire/backend/_clib_wrapper/_indexing.py index bca8678..a2d9d08 100755 --- a/arrayfire/backend/_clib_wrapper/_indexing.py +++ b/arrayfire/backend/_clib_wrapper/_indexing.py @@ -246,7 +246,7 @@ def __setitem__(self, idx: int, value: IndexStructure) -> None: self.idxs[idx] = value -def get_indices(key: int | slice | tuple[int | slice, ...]) -> CIndexStructure: +def get_indices(key: int | slice | tuple[int | slice, ...]) -> CIndexStructure: # BUG indices = CIndexStructure() if isinstance(key, tuple): diff --git a/arrayfire/backend/_clib_wrapper/_reduction_operations.py b/arrayfire/backend/_clib_wrapper/_reduction_operations.py index 900be6c..1f33fe8 100755 --- a/arrayfire/backend/_clib_wrapper/_reduction_operations.py +++ b/arrayfire/backend/_clib_wrapper/_reduction_operations.py @@ -22,3 +22,79 @@ def _reduce_all(arr: AFArrayType, c_func: Callable) -> int | float | complex: imag = ctypes.c_double(0) safe_call(c_func(ctypes.pointer(real), ctypes.pointer(imag), arr)) return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def all_true(arr: AFArrayType, axis: int, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(out), arr, axis)) + return out + + +def all_true_all(arr: AFArrayType, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value # NOTE imag is always set to 0 in C library + + +def any_true(arr: AFArrayType, axis: int, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__reduce__func__any__true.htm#ga7c275cda2cfc8eb0bd20ea86472ca0d5 + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(out), arr, axis)) + return out + + +def any_true_all(arr: AFArrayType, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__any__true.htm#ga47d991276bb5bf8cdba8340e8751e536 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value # NOTE imag is always set to 0 in C library + + +def sum(arr: AFArrayType, axis: int, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#gacd4917c2e916870ebdf54afc2f61d533 + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_sum(ctypes.pointer(out), arr, axis)) + return out + + +def sum_all(arr: AFArrayType, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#gabc009d04df0faf29ba1e381c7badde58 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value # NOTE imag is always set to 0 in C library + + +def sum_nan(arr: AFArrayType, axis: int, nan_value: float, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#ga52461231e2d9995f689b7f23eea0e798 + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_sum_nan(ctypes.pointer(out), arr, axis, ctypes.c_double(nan_value))) + return out + + +def sum_nan_all(arr: AFArrayType, nan_value: float, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#gabc009d04df0faf29ba1e381c7badde58 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr, ctypes.c_double(nan_value))) + return real.value # NOTE imag is always set to 0 in C library diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index 946cd72..ebdbc0f 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -1,7 +1,8 @@ from __future__ import annotations import ctypes -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any +from typing import cast as typing_cast from arrayfire.backend._backend import _backend from arrayfire.dtypes import CShape, CType, Dtype, c_dim_t, to_str @@ -162,13 +163,29 @@ def get_dims(arr: AFArrayType) -> tuple[int, ...]: return (d0.value, d1.value, d2.value, d3.value) +def get_strides(arr: AFArrayType) -> tuple[int, ...]: + """ + source: https://arrayfire.org/docs/group__internal__func__strides.htm#gaff91b376156ce0ad7180af6e68faab51 + """ + s0 = c_dim_t(0) + s1 = c_dim_t(0) + s2 = c_dim_t(0) + s3 = c_dim_t(0) + safe_call( + _backend.clib.af_get_strides( + ctypes.pointer(s0), ctypes.pointer(s1), ctypes.pointer(s2), ctypes.pointer(s3), arr + ) + ) + return (s0.value, s1.value, s2.value, s3.value) + + def get_scalar(arr: AFArrayType, dtype: Dtype, /) -> int | float | complex | bool | None: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefe2e343a74a84bd43b588218ecc09a3 """ out = dtype.c_type() safe_call(_backend.clib.af_get_scalar(ctypes.pointer(out), arr)) - return cast(int | float | complex | bool | None, out.value) + return typing_cast(int | float | complex | bool | None, out.value) def is_empty(arr: AFArrayType) -> bool: @@ -199,13 +216,21 @@ def copy_array(arr: AFArrayType) -> AFArrayType: return out +def cast(arr: AFArrayType, dtype: Dtype, /) -> AFArrayType: + """ + source: https://arrayfire.org/docs/group__arith__func__cast.htm#gab0cb307d6f9019ac8cbbbe9b8a4d6b9b + """ + out = ctypes.c_void_p(0) + safe_call(_backend.clib.af_cast(ctypes.pointer(out), arr, dtype.c_api_value)) + return out + + # Arrayfire Functions def index_gen( arr: AFArrayType, ndims: int, - key: int | slice | tuple[int | slice, ...], indices: Any, /, ) -> AFArrayType: @@ -287,23 +312,29 @@ def flat(arr: AFArrayType, /) -> AFArrayType: return out -def all_true(arr: AFArrayType, axis: int, /) -> AFArrayType: +def assign_gen(lhs: AFArrayType, rhs: AFArrayType, ndims: int, indices: Any, /) -> AFArrayType: """ - source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + source: https://arrayfire.org/docs/group__index__func__assign.htm#ga93cd5199c647dce0e3b823f063b352ae """ out = ctypes.c_void_p(0) - safe_call(_backend.clib.af_all_true(ctypes.pointer(out), arr, axis)) + safe_call(_backend.clib.af_assign_gen(ctypes.pointer(out), lhs, ndims, indices.pointer, rhs)) return out -def all_true_all(arr: AFArrayType, /) -> complex: +def release_array(arr: AFArrayType, /) -> None: """ - source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + source: https://arrayfire.org/docs/group__c__api__mat.htm#gad6c58648ed0db398e170dabf045e8309 """ - real = ctypes.c_double(0) - imag = ctypes.c_double(0) - safe_call(_backend.clib.af_all_true(ctypes.pointer(real), ctypes.pointer(imag), arr)) - return real.value # NOTE imag is always set to 0 in C library + safe_call(_backend.clib.af_release_array(arr)) + + +def get_offset(arr: AFArrayType, /) -> int: + """ + source: https://arrayfire.org/docs/group__internal__func__offset.htm#ga303cb334026bdb5cab86e038951d6a5a + """ + out = c_dim_t(0) + safe_call(_backend.clib.af_get_offset(ctypes.pointer(out), arr)) + return out.value # Safe Call Wrapper diff --git a/arrayfire/library/data.py b/arrayfire/library/data.py index 71b24f1..b8080a0 100644 --- a/arrayfire/library/data.py +++ b/arrayfire/library/data.py @@ -1,7 +1,7 @@ from typing import cast from arrayfire import Array -from arrayfire._array_helpers import afarray_as_array +from arrayfire.array_object import afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper from arrayfire.dtypes import Dtype, float32 diff --git a/arrayfire/library/operators.py b/arrayfire/library/operators.py index 3b86f52..954b1f6 100755 --- a/arrayfire/library/operators.py +++ b/arrayfire/library/operators.py @@ -3,7 +3,7 @@ from typing import cast from arrayfire import Array -from arrayfire._array_helpers import afarray_as_array +from arrayfire.array_object import afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper from arrayfire.dtypes import is_complex_dtype diff --git a/arrayfire/library/random.py b/arrayfire/library/random.py index 33c463c..35fd66a 100644 --- a/arrayfire/library/random.py +++ b/arrayfire/library/random.py @@ -4,7 +4,7 @@ from typing import cast from arrayfire import Array -from arrayfire._array_helpers import afarray_as_array +from arrayfire.array_object import afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper from arrayfire.dtypes import Dtype, float32 diff --git a/arrayfire/library/utils.py b/arrayfire/library/utils.py index cfafd5e..b1e08e5 100644 --- a/arrayfire/library/utils.py +++ b/arrayfire/library/utils.py @@ -1,46 +1,43 @@ -from typing import cast +from typing import cast as typing_cast -from arrayfire import Array -from arrayfire._array_helpers import afarray_as_array +from arrayfire.array_object import Array, afarray_as_array from arrayfire.backend import _clib_wrapper as wrapper +from arrayfire.dtypes import Dtype @afarray_as_array -def _all_true(array: Array, axis: int, /) -> Array: - result = wrapper.all_true(array.arr, axis) - return cast(Array, result) - - -def all_true(array: Array, axis: int | None = None) -> bool | Array: +def cast(array: Array, dtype: Dtype, /) -> Array: """ - Check if all the elements along a specified dimension are true. + Cast an array to a specified type. Parameters ---------- array : Array - Multi-dimensional ArrayFire array. - - axis : int, optional, default: None - Dimension along which the product is required. + Multi-dimensional arrayfire array to be cast. + dtype : Dtype + The target data type to which the array will be cast. Must be one of the following: + - Dtype.int8 for signed 8-bit integer + - Dtype.int16 for signed 16-bit integer + - Dtype.int32 for signed 32-bit integer + - Dtype.int64 for signed 64-bit integer + - Dtype.uint8 for unsigned 8-bit integer + - Dtype.uint16 for unsigned 16-bit integer + - Dtype.uint32 for unsigned 32-bit integer + - Dtype.uint64 for unsigned 64-bit integer + - Dtype.float16 for 16-bit floating-point + - Dtype.float32 for 32-bit floating-point + - Dtype.float64 for 64-bit floating-point + - Dtype.complex64 for 64-bit complex number + - Dtype.complex128 for 128-bit complex number + - Dtype.bool for boolean Returns ------- - bool | Array - An ArrayFire array containing True if all elements in `array` along the specified dimension are True. - If `axis` is `None`, the output is True if `array` does not have any zeros, else False. - - Note - ---- - If `axis` is `None`, output is True if the array does not have any zeros, else False. + Array + An array containing the values from `array` after conversion to the specified `dtype`. """ - if axis is None: - return wrapper.all_true_all(array.arr) - - return _all_true(array, axis) - + return typing_cast(Array, wrapper.cast(array.arr, dtype)) -# from time import time -# import math # def timeit(af_func, *args): # """ diff --git a/arrayfire/library/vector_algorithms/__init__.py b/arrayfire/library/vector_algorithms/__init__.py new file mode 100644 index 0000000..8941ff4 --- /dev/null +++ b/arrayfire/library/vector_algorithms/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["any_true", "all_true"] + +from .reduction_operations import all_true, any_true diff --git a/arrayfire/library/vector_algorithms/reduction_operations.py b/arrayfire/library/vector_algorithms/reduction_operations.py new file mode 100644 index 0000000..141d48f --- /dev/null +++ b/arrayfire/library/vector_algorithms/reduction_operations.py @@ -0,0 +1,81 @@ +from collections.abc import Callable +from typing import Any, cast + +from arrayfire import Array +from arrayfire.array_object import afarray_as_array +from arrayfire.backend import _clib_wrapper as wrapper + + +@afarray_as_array +def _reduce_to_array(func: Callable, array: Array, axis: int, /, **kwargs: Any) -> Array: + result = func(array.arr, axis, **kwargs) + return cast(Array, result) + + +def all_true(array: Array, axis: int | None = None) -> bool | Array: + """ + Check if all the elements along a specified dimension are true. + + Parameters + ---------- + array : Array + Multi-dimensional ArrayFire array. + + axis : int, optional, default: None + Dimension along which the product is required. + + Returns + ------- + bool | Array + An ArrayFire array containing True if all elements in `array` along the specified dimension are True. + If `axis` is `None`, the output is True if `array` does not have any zeros, else False. + + Note + ---- + If `axis` is `None`, output is True if the array does not have any zeros, else False. + """ + if axis is None: + return bool(wrapper.all_true_all(array.arr)) + + return _reduce_to_array(wrapper.all_true, array, axis) + + +def any_true(array: Array, axis: int | None = None) -> bool | Array: + """ + Check if any of the elements along a specified dimension are true. + + Parameters + ---------- + array : Array + Multi-dimensional ArrayFire array. + + axis : int, optional, default: None + Dimension along which the product is required. + + Returns + ------- + bool | Array + An ArrayFire array containing True if any of the elements in `array` along the specified dimension are True. + If `axis` is `None`, the output is True if `array` does not have any zeros, else False. + + Note + ---- + If `axis` is `None`, output is True if the array does not have any zeros, else False. + """ + if axis is None: + return bool(wrapper.any_true_all(array.arr)) + + return _reduce_to_array(wrapper.any_true, array, axis) + + +def sum(array: Array, /, *, axis: int | None = None, nan_value: float | None = None) -> bool | Array: + if axis is None: + if nan_value is None: + return bool(wrapper.sum_all(array.arr)) + + return bool(wrapper.sum_nan_all(array.arr, nan_value)) + + if nan_value is None: + return _reduce_to_array(wrapper.sum, array, axis) + + return _reduce_to_array(wrapper.sum_nan, array, axis, nan_value=nan_value) diff --git a/tests/test_data.py b/tests/test_data.py index c7f4022..86eac95 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,7 +3,7 @@ from arrayfire import Array from arrayfire.dtypes import int64 from arrayfire.library import data, random -from arrayfire.library.utils import all_true +from arrayfire.library.vector_algorithms import all_true # Test cases for the constant function diff --git a/tests/vector_algorithms/__init__.py b/tests/vector_algorithms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/vector_algorithms/test_reduction_operations.py b/tests/vector_algorithms/test_reduction_operations.py new file mode 100644 index 0000000..79ba73d --- /dev/null +++ b/tests/vector_algorithms/test_reduction_operations.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING + +import pytest + +from arrayfire import Array +from arrayfire.library import data +from arrayfire.library import vector_algorithms as va + +# if TYPE_CHECKING: +# from arrayfire import Array + + +@pytest.fixture +def true_array() -> Array: + return data.constant(1, (5, 5)) + + +# BUG Array.__setitem__ +# @pytest.fixture +# def false_array() -> Array: +# arr = data.constant(1, (5, 5)) +# arr[2, 2] = 0 # Set one element to False +# return arr + + +# BUG Array.to_list() +# def test_all_true_with_axis(true_array: Array) -> None: +# result = va.all_true(true_array, axis=0) +# assert result.to_list() == [True, True, True, True, True] From 850e658d938c0e126a4192bb95006d1954fc1cec Mon Sep 17 00:00:00 2001 From: Anton Date: Sat, 2 Sep 2023 20:58:41 +0300 Subject: [PATCH 27/31] Fix bug with sum method. Add some test cases --- arrayfire/array_object.py | 5 ++- .../_clib_wrapper/_reduction_operations.py | 4 +-- .../library/vector_algorithms/__init__.py | 4 +-- .../vector_algorithms/reduction_operations.py | 32 +++++++++++++++++-- .../test_reduction_operations.py | 27 ++++++++++++---- 5 files changed, 58 insertions(+), 14 deletions(-) diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index aa2572d..493f53f 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -817,6 +817,7 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No indices = wrapper.get_indices(key) out.arr = wrapper.assign_gen(self.arr, other_arr, ndims, indices) + wrapper.release_array(self.arr) if del_other: wrapper.release_array(other_arr) @@ -1070,7 +1071,9 @@ def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.Parall out = _slice_to_length(key.S, dim) elif isinstance(key, Array): if key.dtype == afbool: - out = int(sum(key)) # FIXME af.sum + from arrayfire.library.vector_algorithms import sum as af_sum + + out = int(af_sum(key)) else: out = key.size else: diff --git a/arrayfire/backend/_clib_wrapper/_reduction_operations.py b/arrayfire/backend/_clib_wrapper/_reduction_operations.py index 1f33fe8..8a341a3 100755 --- a/arrayfire/backend/_clib_wrapper/_reduction_operations.py +++ b/arrayfire/backend/_clib_wrapper/_reduction_operations.py @@ -77,7 +77,7 @@ def sum_all(arr: AFArrayType, /) -> complex: """ real = ctypes.c_double(0) imag = ctypes.c_double(0) - safe_call(_backend.clib.sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + safe_call(_backend.clib.af_sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) return real.value # NOTE imag is always set to 0 in C library @@ -96,5 +96,5 @@ def sum_nan_all(arr: AFArrayType, nan_value: float, /) -> complex: """ real = ctypes.c_double(0) imag = ctypes.c_double(0) - safe_call(_backend.clib.sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr, ctypes.c_double(nan_value))) + safe_call(_backend.clib.af_sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr, ctypes.c_double(nan_value))) return real.value # NOTE imag is always set to 0 in C library diff --git a/arrayfire/library/vector_algorithms/__init__.py b/arrayfire/library/vector_algorithms/__init__.py index 8941ff4..06a3bd4 100644 --- a/arrayfire/library/vector_algorithms/__init__.py +++ b/arrayfire/library/vector_algorithms/__init__.py @@ -1,3 +1,3 @@ -__all__ = ["any_true", "all_true"] +__all__ = ["any_true", "all_true", "sum"] -from .reduction_operations import all_true, any_true +from .reduction_operations import all_true, any_true, sum diff --git a/arrayfire/library/vector_algorithms/reduction_operations.py b/arrayfire/library/vector_algorithms/reduction_operations.py index 141d48f..aab2cf6 100644 --- a/arrayfire/library/vector_algorithms/reduction_operations.py +++ b/arrayfire/library/vector_algorithms/reduction_operations.py @@ -68,12 +68,38 @@ def any_true(array: Array, axis: int | None = None) -> bool | Array: return _reduce_to_array(wrapper.any_true, array, axis) -def sum(array: Array, /, *, axis: int | None = None, nan_value: float | None = None) -> bool | Array: +def sum(array: Array, /, *, axis: int | None = None, nan_value: float | None = None) -> int | float | complex | Array: + """ + Calculate the sum of elements along a specified dimension or the entire array. + + Parameters + ---------- + array : Array + Multi-dimensional array to calculate the sum of. + + axis : int or None, optional, default: None + The dimension along which the sum is calculated. + If None, the sum of all elements in the entire array is calculated. + + nan_value : float or None, optional, default: None + The value to replace NaN (Not-a-Number) values in the array before summing. If None, NaN values are ignored. + + Returns + ------- + result : Array or bool or scalar + - If `axis` is None and `nan_value` is None, returns a boolean indicating if the sum contains NaN or Inf. + - If `axis` is None and `nan_value` is not None, returns a boolean indicating if the sum contains NaN or + Inf after replacing NaN values. + - If `axis` is not None, returns an Array containing the sum along the specified dimension. + - If `axis` is not None and `nan_value` is not None, returns an Array containing the sum along the specified + dimension after replacing NaN values. + """ + if axis is None: if nan_value is None: - return bool(wrapper.sum_all(array.arr)) + return wrapper.sum_all(array.arr) - return bool(wrapper.sum_nan_all(array.arr, nan_value)) + return wrapper.sum_nan_all(array.arr, nan_value) if nan_value is None: return _reduce_to_array(wrapper.sum, array, axis) diff --git a/tests/vector_algorithms/test_reduction_operations.py b/tests/vector_algorithms/test_reduction_operations.py index 79ba73d..899c3a4 100644 --- a/tests/vector_algorithms/test_reduction_operations.py +++ b/tests/vector_algorithms/test_reduction_operations.py @@ -15,15 +15,30 @@ def true_array() -> Array: return data.constant(1, (5, 5)) -# BUG Array.__setitem__ -# @pytest.fixture -# def false_array() -> Array: -# arr = data.constant(1, (5, 5)) -# arr[2, 2] = 0 # Set one element to False -# return arr +@pytest.fixture +def false_array() -> Array: + arr = data.constant(1, (5, 5)) + arr[2, 2] = 0 # Set one element to False + return arr # BUG Array.to_list() # def test_all_true_with_axis(true_array: Array) -> None: # result = va.all_true(true_array, axis=0) # assert result.to_list() == [True, True, True, True, True] + + +# Test cases for the sum function + + +@pytest.fixture +def sample_array() -> Array: + return Array([1, 2, 3, 4]) + + +def test_sum_no_axis_no_nan_value(sample_array: Array) -> None: + result = va.sum(sample_array) + assert result == 10 # Sum of all elements is 1 + 2 + 3 + 4 = 10 + + +# TODO add more test cases From 771aa96aa16c2ade69e82401b5fea4423cf30be3 Mon Sep 17 00:00:00 2001 From: Anton Date: Sat, 2 Sep 2023 21:56:21 +0300 Subject: [PATCH 28/31] Fix typing --- arrayfire/array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 493f53f..6612cfe 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -1073,7 +1073,7 @@ def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.Parall if key.dtype == afbool: from arrayfire.library.vector_algorithms import sum as af_sum - out = int(af_sum(key)) + out = int(af_sum(key)) # type: ignore[arg-type] else: out = key.size else: From 895fabd461f580a9a8add05178c1d5bae3d9ecbe Mon Sep 17 00:00:00 2001 From: Anton Date: Sun, 3 Sep 2023 00:42:55 +0300 Subject: [PATCH 29/31] On-call changes. Added vanilla sync, get_device --- .github/workflows/ci.yml | 2 +- arrayfire/backend/_clib_wrapper/__init__.py | 2 +- .../backend/_clib_wrapper/_constant_array.py | 3 ++- arrayfire/backend/_clib_wrapper/_unsorted.py | 22 ++++++++++++++++++- arrayfire/library/device.py | 10 +++++++++ setup.py | 4 ++-- 6 files changed, 37 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d43f12b..ccb1c9f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ on: - master env: - DEFAULT_PYTHON_VERSION: "3.8" + DEFAULT_PYTHON_VERSION: "3.10" defaults: run: diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index 54d302f..a5ffaa2 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -219,7 +219,7 @@ retain_array, set_backend, transpose, - where, + where, get_device, sync ) __all__ += ["safe_call"] diff --git a/arrayfire/backend/_clib_wrapper/_constant_array.py b/arrayfire/backend/_clib_wrapper/_constant_array.py index 52b2fad..f585b96 100755 --- a/arrayfire/backend/_clib_wrapper/_constant_array.py +++ b/arrayfire/backend/_clib_wrapper/_constant_array.py @@ -78,7 +78,8 @@ def _constant(number: int | float, shape: tuple[int, ...], dtype: Dtype, /) -> A def create_constant_array(number: int | float | complex, shape: tuple[int, ...], dtype: Dtype, /) -> AFArrayType: - dtype = implicit_dtype(number, dtype) + if not dtype: + dtype = implicit_dtype(number, dtype) if isinstance(number, complex): return _constant_complex(number, shape, dtype if is_complex_dtype(dtype) else complex64) diff --git a/arrayfire/backend/_clib_wrapper/_unsorted.py b/arrayfire/backend/_clib_wrapper/_unsorted.py index ebdbc0f..7eee5a3 100755 --- a/arrayfire/backend/_clib_wrapper/_unsorted.py +++ b/arrayfire/backend/_clib_wrapper/_unsorted.py @@ -122,7 +122,7 @@ def get_ctype(arr: AFArrayType) -> int: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga0dda6898e1c0d9a43efb56cd6a988c9b """ - out = ctypes.c_int() + out = ctypes.c_int(0) safe_call(_backend.clib.af_get_type(ctypes.pointer(out), arr)) return out.value @@ -349,6 +349,26 @@ def get_last_error() -> ctypes.c_char_p: return out +# Device + + +# FIXME +def sync(device_id: int) -> None: + """ + source: https://arrayfire.org/docs/group__device__func__sync.htm#ga9dbc7f1e99d70170ad567c480b6ddbde + """ + safe_call(_backend.clib.af_sync(device_id)) + + +def get_device() -> int: + """ + source: https://arrayfire.org/docs/group__device__func__set.htm#ga54120b126cfcb1b0b3ee25e0fc66b8a4 + """ + out = ctypes.c_int(0) + safe_call(_backend.clib.af_get_device(ctypes.pointer(out))) + return out.value + + # Backend diff --git a/arrayfire/library/device.py b/arrayfire/library/device.py index 42f8edd..ab5d96d 100644 --- a/arrayfire/library/device.py +++ b/arrayfire/library/device.py @@ -1,5 +1,7 @@ import enum +from arrayfire.backend import _clib_wrapper as wrapper + class PointerSource(enum.Enum): """ @@ -10,4 +12,12 @@ class PointerSource(enum.Enum): host = 1 # cpu +def get_device() -> int: # FIXME + return wrapper.get_device() + + +def sync(device_id: int) -> None: # FIXME + return wrapper.sync(device_id) + + supported_devices = [] diff --git a/setup.py b/setup.py index 170f5e9..806ea96 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def fix_url_dependencies(req: str) -> str: "License :: OSI Approved :: BSD License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.10", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", @@ -92,6 +92,6 @@ def fix_url_dependencies(req: str) -> str: install_requires=install_requirements, extras_require=extras, include_package_data=True, - python_requires=">=3.8.0", + python_requires=">=3.10.0", zip_safe=False, ) From a46a265641d1a86f3dd6ccc862f117d5dbbfb36a Mon Sep 17 00:00:00 2001 From: Anton Date: Tue, 12 Sep 2023 16:22:25 +0300 Subject: [PATCH 30/31] Hotfixes --- arrayfire/array_object.py | 8 +++----- arrayfire/backend/_clib_wrapper/__init__.py | 4 +++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 6612cfe..1dab106 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -790,16 +790,14 @@ def __int__(self) -> int: def __len__(self) -> int: return self.shape[0] if self.shape else 0 - # BUG def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> None: - out = Array() ndims = self.ndim is_array_with_bool = isinstance(key, Array) and type(key) == afbool if is_array_with_bool: ndims = 1 - num = wrapper.count_all(key.arr) + num = wrapper.count_all(key.arr) # type: ignore[union-attr] if num == 0: return @@ -816,12 +814,12 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No del_other = False indices = wrapper.get_indices(key) - out.arr = wrapper.assign_gen(self.arr, other_arr, ndims, indices) + out = wrapper.assign_gen(self.arr, other_arr, ndims, indices) wrapper.release_array(self.arr) if del_other: wrapper.release_array(other_arr) - self.arr = out.arr + self.arr = out def __str__(self) -> str: # TODO change the look of array str. E.g., like np.array diff --git a/arrayfire/backend/_clib_wrapper/__init__.py b/arrayfire/backend/_clib_wrapper/__init__.py index a5ffaa2..f409341 100755 --- a/arrayfire/backend/_clib_wrapper/__init__.py +++ b/arrayfire/backend/_clib_wrapper/__init__.py @@ -202,6 +202,7 @@ get_backend_id, get_ctype, get_data_ptr, + get_device, get_device_id, get_dims, get_elements, @@ -218,8 +219,9 @@ reorder, retain_array, set_backend, + sync, transpose, - where, get_device, sync + where, ) __all__ += ["safe_call"] From 03f0e2db2f2114831f5a939e066ecf27e1550af3 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 15 Sep 2023 16:22:37 +0300 Subject: [PATCH 31/31] Minir changes in str notes --- arrayfire/array_object.py | 1 + arrayfire/backend/_backend_functions.py | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index 1dab106..654f351 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -954,6 +954,7 @@ def strides(self) -> tuple[int, ...]: """ return wrapper.get_strides(self.arr)[: self.ndim] + # TODO rename front_to_host or smth. Extend doc: move first element of array from gpu to cpu def scalar(self) -> int | float | bool | complex | None: # FIXME """ Return the first element of the array diff --git a/arrayfire/backend/_backend_functions.py b/arrayfire/backend/_backend_functions.py index 58ad611..dd64530 100755 --- a/arrayfire/backend/_backend_functions.py +++ b/arrayfire/backend/_backend_functions.py @@ -90,7 +90,7 @@ def get_array_backend_name(array: Array) -> str: def get_backend_id(array: Array) -> str: - warnings.warn("Was renamed due to unintuitive function name. Now get_array_backend_name().", DeprecationWarning) + warnings.warn("Was renamed. Now get_array_backend_name() in main repo.", DeprecationWarning) return get_array_backend_name(array) diff --git a/setup.py b/setup.py index 806ea96..e519b47 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,7 @@ def fix_url_dependencies(req: str) -> str: "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Software Development :: Libraries", ], - keywords="arrayfire parallel computing gpu cpu opencl", + keywords="arrayfire parallel computing gpu cpu opencl oneapi", packages=find_packages(), install_requires=install_requirements, extras_require=extras,