From 29d5c1dc0d523a7d012c1a48a5ba06ee7d4e0115 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Tue, 13 Jun 2023 22:03:30 +0300 Subject: [PATCH 1/6] af object refactoring --- .github/workflows/ci.yml | 2 +- arrayfire/__init__.py | 9 + arrayfire/array_object.py | 935 ++++++++++++++++++++++ arrayfire/backend/__init__.py | 10 + arrayfire/backend/backend.py | 26 + arrayfire/backend/constant_array.py | 84 ++ arrayfire/backend/library.py | 238 ++++++ arrayfire/backend/operators.py | 161 ++++ arrayfire/config.py | 6 + arrayfire/device.py | 10 + arrayfire/dtypes/__init__.py | 39 + arrayfire/dtypes/functions.py | 32 + arrayfire/dtypes/helpers.py | 76 ++ arrayfire/operators.py | 25 + arrayfire/utils.py | 13 + setup.cfg | 4 +- tests/array_object/__init__.py | 0 tests/array_object/test_initialization.py | 53 ++ tests/array_object/test_methods.py | 31 + tests/array_object/test_operators.py | 121 +++ tests/test_operators.py | 20 + 21 files changed, 1892 insertions(+), 3 deletions(-) create mode 100755 arrayfire/array_object.py create mode 100644 arrayfire/backend/__init__.py create mode 100644 arrayfire/backend/backend.py create mode 100644 arrayfire/backend/constant_array.py create mode 100644 arrayfire/backend/library.py create mode 100644 arrayfire/backend/operators.py create mode 100644 arrayfire/config.py create mode 100644 arrayfire/device.py create mode 100644 arrayfire/dtypes/__init__.py create mode 100644 arrayfire/dtypes/functions.py create mode 100644 arrayfire/dtypes/helpers.py create mode 100644 arrayfire/operators.py create mode 100644 arrayfire/utils.py create mode 100644 tests/array_object/__init__.py create mode 100644 tests/array_object/test_initialization.py create mode 100644 tests/array_object/test_methods.py create mode 100644 tests/array_object/test_operators.py create mode 100644 tests/test_operators.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1300c1d..6b032a4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ on: - master env: - DEFAULT_PYTHON_VERSION: "3.10" + DEFAULT_PYTHON_VERSION: "3.8" jobs: build: diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index e69de29..675b27a 100644 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -0,0 +1,9 @@ +__all__ = [ + # array objects + "Array", + # dtypes + "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64", + "complex64", "complex128", "bool"] + +from .array_object import Array +from .dtypes import bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64 diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py new file mode 100755 index 0000000..5777871 --- /dev/null +++ b/arrayfire/array_object.py @@ -0,0 +1,935 @@ +from __future__ import annotations + +import array as py_array +import ctypes +import enum +from typing import Any, List, Optional, Tuple, Union + +from . import backend +from .backend import ArrayBuffer, library +from .backend.operators import count_all +from .backend.constant_array import create_constant_array +from .device import PointerSource +from .dtypes import CType +from .dtypes import bool as af_bool +from .dtypes import float32 as af_float32 +from .dtypes.helpers import Dtype, c_api_value_to_dtype, str_to_dtype + +# TODO use int | float in operators -> remove bool | complex support + + +class Array: + def __init__( + self, x: Union[None, Array, py_array.array, int, ctypes.c_void_p, List[Union[int, float]]] = None, + dtype: Union[None, Dtype, str] = None, shape: Tuple[int, ...] = (), + pointer_source: PointerSource = PointerSource.host, offset: Optional[CType] = None, + strides: Optional[Tuple[int, ...]] = None) -> None: + _no_initial_dtype = False # HACK, FIXME + + if isinstance(dtype, str): + dtype = str_to_dtype(dtype) # type: ignore[arg-type] + + if dtype is None: + _no_initial_dtype = True + dtype = af_float32 + + if x is None: + if not shape: # shape is None or empty tuple + self.arr = library.create_handle((), dtype) + return + + self.arr = library.create_handle(shape, dtype) + return + + if isinstance(x, Array): + self.arr = library.retain_array(x.arr) + return + + if isinstance(x, py_array.array): + _type_char: str = x.typecode + _array_buffer = ArrayBuffer(*x.buffer_info()) + + elif isinstance(x, list): + _array = py_array.array("f", x) # BUG [True, False] -> dtype: f32 # TODO add int and float + _type_char = _array.typecode + _array_buffer = ArrayBuffer(*_array.buffer_info()) + + elif isinstance(x, int) or isinstance(x, ctypes.c_void_p): # TODO + _array_buffer = ArrayBuffer(x if not isinstance(x, ctypes.c_void_p) else x.value) # type: ignore[arg-type] + + if not shape: + raise TypeError("Expected to receive the initial shape due to the x being a data pointer.") + + if _no_initial_dtype: + raise TypeError("Expected to receive the initial dtype due to the x being a data pointer.") + + _type_char = dtype.typecode + + else: + raise TypeError("Passed object x is an object of unsupported class.") + + if not shape: + if _array_buffer.length != 0: + shape = (_array_buffer.length, ) + else: + RuntimeError("Shape and buffer length are size invalid.") + + if not _no_initial_dtype and dtype.typecode != _type_char: + raise TypeError("Can not create array of requested type from input data type") + + if not (offset or strides): + if pointer_source == PointerSource.host: + self.arr = library.create_array(shape, dtype, _array_buffer) + return + + self.arr = library.device_array(shape, dtype, _array_buffer) + return + + self.arr = library.create_strided_array( + shape, dtype, _array_buffer, offset, strides, pointer_source) # type: ignore[arg-type] + + # Arithmetic Operators + + def __pos__(self) -> Array: + """ + Evaluates +self_i for each element of an array instance. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the evaluated result for each element. The returned array must have the same data type + as self. + """ + return self + + def __neg__(self) -> Array: + """ + Evaluates +self_i for each element of an array instance. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the evaluated result for each element in self. The returned array must have a data type + determined by Type Promotion Rules. + + """ + return _process_c_function(0, self, backend.sub) + + def __add__(self, other: Union[int, float, Array], /) -> Array: + """ + Calculates the sum for each element of an array instance with the respective element of the array other. + + Parameters + ---------- + self : Array + Array instance (augend array). Should have a numeric data type. + other: Union[int, float, Array] + Addend array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise sums. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.add) + + def __sub__(self, other: Union[int, float, Array], /) -> Array: + """ + Calculates the difference for each element of an array instance with the respective element of the array other. + + The result of self_i - other_i must be the same as self_i + (-other_i) and must be governed by the same + floating-point rules as addition (see array.__add__()). + + Parameters + ---------- + self : Array + Array instance (minuend array). Should have a numeric data type. + other: Union[int, float, Array] + Subtrahend array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise differences. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.sub) + + def __mul__(self, other: Union[int, float, Array], /) -> Array: + """ + Calculates the product for each element of an array instance with the respective element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise products. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.mul) + + def __truediv__(self, other: Union[int, float, Array], /) -> Array: + """ + Evaluates self_i / other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array should have a floating-point data type + determined by Type Promotion Rules. + + Note + ---- + - If one or both of self and other have integer data types, the result is implementation-dependent, as type + promotion between data type “kinds” (e.g., integer versus floating-point) is unspecified. + Specification-compliant libraries may choose to raise an error or return an array containing the element-wise + results. If an array is returned, the array must have a real-valued floating-point data type. + """ + return _process_c_function(self, other, backend.div) + + def __floordiv__(self, other: Union[int, float, Array], /) -> Array: + # TODO + return NotImplemented + + def __mod__(self, other: Union[int, float, Array], /) -> Array: + """ + Evaluates self_i % other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a real-valued data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. + + Returns + ------- + out : Array + An array containing the element-wise results. Each element-wise result must have the same sign as the + respective element other_i. The returned array must have a real-valued floating-point data type determined + by Type Promotion Rules. + + Note + ---- + - For input arrays which promote to an integer data type, the result of division by zero is unspecified and + thus implementation-defined. + """ + return _process_c_function(self, other, backend.mod) + + def __pow__(self, other: Union[int, float, Array], /) -> Array: + """ + Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of + an array instance to the power of other_i (the exponent), where other_i is the corresponding element of the + array other. + + Parameters + ---------- + self : Array + Array instance whose elements correspond to the exponentiation base. Should have a numeric data type. + other: Union[int, float, Array] + Other array whose elements correspond to the exponentiation exponent. Must be compatible with self + (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.pow) + + # Array Operators + + def __matmul__(self, other: Array, /) -> Array: + # TODO get from blas - make vanilla version and not copy af.matmul as is + return NotImplemented + + # Bitwise Operators + + def __invert__(self) -> Array: + """ + Evaluates ~self_i for each element of an array instance. + + Parameters + ---------- + self : Array + Array instance. Should have an integer or boolean data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have the same data type as self. + """ + # FIXME + out = Array() + out.arr = backend.bitnot(self.arr) + return out + + def __and__(self, other: Union[int, bool, Array], /) -> Array: + """ + Evaluates self_i & other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, bool, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.bitand) + + def __or__(self, other: Union[int, bool, Array], /) -> Array: + """ + Evaluates self_i | other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, bool, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.bitor) + + def __xor__(self, other: Union[int, bool, Array], /) -> Array: + """ + Evaluates self_i ^ other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, bool, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type determined + by Type Promotion Rules. + """ + return _process_c_function(self, other, backend.bitxor) + + def __lshift__(self, other: Union[int, Array], /) -> Array: + """ + Evaluates self_i << other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + Each element must be greater than or equal to 0. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have the same data type as self. + """ + return _process_c_function(self, other, backend.bitshiftl) + + def __rshift__(self, other: Union[int, Array], /) -> Array: + """ + Evaluates self_i >> other_i for each element of an array instance with the respective element of the + array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a numeric data type. + Each element must be greater than or equal to 0. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have the same data type as self. + """ + return _process_c_function(self, other, backend.bitshiftr) + + # Comparison Operators + + def __lt__(self, other: Union[int, float, Array], /) -> Array: + """ + Computes the truth value of self_i < other_i for each element of an array instance with the respective + element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type of bool. + """ + return _process_c_function(self, other, backend.lt) + + def __le__(self, other: Union[int, float, Array], /) -> Array: + """ + Computes the truth value of self_i <= other_i for each element of an array instance with the respective + element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type of bool. + """ + return _process_c_function(self, other, backend.le) + + def __gt__(self, other: Union[int, float, Array], /) -> Array: + """ + Computes the truth value of self_i > other_i for each element of an array instance with the respective + element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type of bool. + """ + return _process_c_function(self, other, backend.gt) + + def __ge__(self, other: Union[int, float, Array], /) -> Array: + """ + Computes the truth value of self_i >= other_i for each element of an array instance with the respective + element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, Array] + Other array. Must be compatible with self (see Broadcasting). Should have a real-valued data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type of bool. + """ + return _process_c_function(self, other, backend.ge) + + def __eq__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ignore[override] + """ + Computes the truth value of self_i == other_i for each element of an array instance with the respective + element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, bool, Array] + Other array. Must be compatible with self (see Broadcasting). May have any data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type of bool. + """ + return _process_c_function(self, other, backend.eq) + + def __ne__(self, other: Union[int, float, bool, Array], /) -> Array: # type: ignore[override] + """ + Computes the truth value of self_i != other_i for each element of an array instance with the respective + element of the array other. + + Parameters + ---------- + self : Array + Array instance. Should have a numeric data type. + other: Union[int, float, bool, Array] + Other array. Must be compatible with self (see Broadcasting). May have any data type. + + Returns + ------- + out : Array + An array containing the element-wise results. The returned array must have a data type of bool. + """ + return _process_c_function(self, other, backend.neq) + + # Reflected Arithmetic Operators + + def __radd__(self, other: Array, /) -> Array: + """ + Return other + self. + """ + return _process_c_function(other, self, backend.add) + + def __rsub__(self, other: Array, /) -> Array: + """ + Return other - self. + """ + return _process_c_function(other, self, backend.sub) + + def __rmul__(self, other: Array, /) -> Array: + """ + Return other * self. + """ + return _process_c_function(other, self, backend.mul) + + def __rtruediv__(self, other: Array, /) -> Array: + """ + Return other / self. + """ + return _process_c_function(other, self, backend.div) + + def __rfloordiv__(self, other: Array, /) -> Array: + # TODO + return NotImplemented + + def __rmod__(self, other: Array, /) -> Array: + """ + Return other % self. + """ + return _process_c_function(other, self, backend.mod) + + def __rpow__(self, other: Array, /) -> Array: + """ + Return other ** self. + """ + return _process_c_function(other, self, backend.pow) + + # Reflected Array Operators + + def __rmatmul__(self, other: Array, /) -> Array: + # TODO + return NotImplemented + + # Reflected Bitwise Operators + + def __rand__(self, other: Array, /) -> Array: + """ + Return other & self. + """ + return _process_c_function(other, self, backend.bitand) + + def __ror__(self, other: Array, /) -> Array: + """ + Return other | self. + """ + return _process_c_function(other, self, backend.bitor) + + def __rxor__(self, other: Array, /) -> Array: + """ + Return other ^ self. + """ + return _process_c_function(other, self, backend.bitxor) + + def __rlshift__(self, other: Array, /) -> Array: + """ + Return other << self. + """ + return _process_c_function(other, self, backend.bitshiftl) + + def __rrshift__(self, other: Array, /) -> Array: + """ + Return other >> self. + """ + return _process_c_function(other, self, backend.bitshiftr) + + # In-place Arithmetic Operators + + def __iadd__(self, other: Union[int, float, Array], /) -> Array: + # TODO discuss either we need to support complex and bool as other input type + """ + Return self += other. + """ + return _process_c_function(self, other, backend.add) + + def __isub__(self, other: Union[int, float, Array], /) -> Array: + """ + Return self -= other. + """ + return _process_c_function(self, other, backend.sub) + + def __imul__(self, other: Union[int, float, Array], /) -> Array: + """ + Return self *= other. + """ + return _process_c_function(self, other, backend.mul) + + def __itruediv__(self, other: Union[int, float, Array], /) -> Array: + """ + Return self /= other. + """ + return _process_c_function(self, other, backend.div) + + def __ifloordiv__(self, other: Union[int, float, Array], /) -> Array: + # TODO + return NotImplemented + + def __imod__(self, other: Union[int, float, Array], /) -> Array: + """ + Return self %= other. + """ + return _process_c_function(self, other, backend.mod) + + def __ipow__(self, other: Union[int, float, Array], /) -> Array: + """ + Return self **= other. + """ + return _process_c_function(self, other, backend.pow) + + # In-place Array Operators + + def __imatmul__(self, other: Array, /) -> Array: + # TODO + return NotImplemented + + # In-place Bitwise Operators + + def __iand__(self, other: Union[int, bool, Array], /) -> Array: + """ + Return self &= other. + """ + return _process_c_function(self, other, backend.bitand) + + def __ior__(self, other: Union[int, bool, Array], /) -> Array: + """ + Return self |= other. + """ + return _process_c_function(self, other, backend.bitor) + + def __ixor__(self, other: Union[int, bool, Array], /) -> Array: + """ + Return self ^= other. + """ + return _process_c_function(self, other, backend.bitxor) + + def __ilshift__(self, other: Union[int, Array], /) -> Array: + """ + Return self <<= other. + """ + return _process_c_function(self, other, backend.bitshiftl) + + def __irshift__(self, other: Union[int, Array], /) -> Array: + """ + Return self >>= other. + """ + return _process_c_function(self, other, backend.bitshiftr) + + # Methods + + def __abs__(self) -> Array: + # TODO + return NotImplemented + + def __array_namespace__(self, *, api_version: Optional[str] = None) -> Any: + # TODO + return NotImplemented + + def __bool__(self) -> bool: + # TODO consider using scalar() and is_scalar() + return NotImplemented + + def __complex__(self) -> complex: + # TODO + return NotImplemented + + def __dlpack__(self, *, stream: Union[None, int, Any] = None): # type: ignore[no-untyped-def] + # TODO implementation and expected return type -> PyCapsule + return NotImplemented + + def __dlpack_device__(self) -> Tuple[enum.Enum, int]: + # TODO + return NotImplemented + + def __float__(self) -> float: + # TODO + return NotImplemented + + def __getitem__(self, key: Union[int, slice, Tuple[Union[int, slice, ], ...], Array], /) -> Array: + """ + Returns self[key]. + + Parameters + ---------- + self : Array + Array instance. + key : Union[int, slice, Tuple[Union[int, slice, ], ...], Array] + Index key. + + Returns + ------- + out : Array + An array containing the accessed value(s). The returned array must have the same data type as self. + """ + # TODO + # API Specification - key: Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], array]. + # consider using af.span to replace ellipsis during refactoring + out = Array() + ndims = self.ndim + + if isinstance(key, Array) and key == af_bool.c_api_value: + ndims = 1 + if count_all(key.arr) == 0: + return out + + # HACK known issue + out.arr = library.index_gen(self.arr, ndims, key) # type: ignore[arg-type] + return out + + def __index__(self) -> int: + # TODO + return NotImplemented + + def __int__(self) -> int: + # TODO + return NotImplemented + + def __len__(self) -> int: + return self.shape[0] if self.shape else 0 + + def __setitem__( + self, key: Union[int, slice, Tuple[Union[int, slice, ], ...], Array], + value: Union[int, float, bool, Array], /) -> None: + # TODO + return NotImplemented # type: ignore[return-value] # FIXME + + def __str__(self) -> str: + # TODO change the look of array str. E.g., like np.array + # if not _in_display_dims_limit(self.shape): + # return _metadata_string(self.dtype, self.shape) + return _metadata_string(self.dtype) + library.array_as_str(self.arr) + + def __repr__(self) -> str: + # return _metadata_string(self.dtype, self.shape) + # TODO change the look of array representation. E.g., like np.array + return library.array_as_str(self.arr) + + def to_device(self, device: Any, /, *, stream: Union[int, Any] = None) -> Array: + # TODO implementation and change device type from Any to Device + return NotImplemented + + # Attributes + + @property + def dtype(self) -> Dtype: + """ + Data type of the array elements. + + Returns + ------- + out : Dtype + Array data type. + """ + return c_api_value_to_dtype(library.get_ctype(self.arr)) + + @property + def device(self) -> Any: + # TODO + return NotImplemented + + @property + def mT(self) -> Array: + # TODO + return NotImplemented + + @property + def T(self) -> Array: + """ + Transpose of the array. + + Returns + ------- + out : Array + Two-dimensional array whose first and last dimensions (axes) are permuted in reverse order relative to + original array. The returned array must have the same data type as the original array. + + Note + ---- + - The array instance must be two-dimensional. If the array instance is not two-dimensional, an error + should be raised. + """ + if self.ndim < 2: + raise TypeError(f"Array should be at least 2-dimensional. Got {self.ndim}-dimensional array") + + # TODO add check if out.dtype == self.dtype + out = Array() + out.arr = library.transpose(self.arr, False) + return out + + @property + def size(self) -> int: + """ + Number of elements in an array. + + Returns + ------- + out : int + Number of elements in an array + + Note + ---- + - This must equal the product of the array's dimensions. + """ + # NOTE previously - elements() + return library.get_elements(self.arr) + + @property + def ndim(self) -> int: + """ + Number of array dimensions (axes). + + out : int + Number of array dimensions (axes). + """ + return library.get_numdims(self.arr) + + @property + def shape(self) -> Tuple[int, ...]: + """ + Array dimensions. + + Returns + ------- + out : tuple[int, ...] + Array dimensions. + """ + # NOTE skipping passing any None values + return library.get_dims(self.arr)[:self.ndim] + + def scalar(self) -> Union[None, int, float, bool, complex]: + """ + Return the first element of the array + """ + # TODO change the logic of this method + if self.is_empty(): + return None + + return library.get_scalar(self.arr, self.dtype) + + def is_empty(self) -> bool: + """ + Check if the array is empty i.e. it has no elements. + """ + return library.is_empty(self.arr) + + def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, complex]]: + if self.is_empty(): + return [] + + array = _reorder(self) if row_major else self + ctypes_array = library.get_data_ptr(array.arr, array.size, array.dtype) + + if array.ndim == 1: + return list(ctypes_array) + + out = [] + for i in range(array.size): + idx = i + sub_list = [] + for j in range(array.ndim): + div = array.shape[j] + sub_list.append(idx % div) + idx //= div + out.append(ctypes_array[sub_list[::-1]]) # type: ignore[call-overload] # FIXME + return out + + def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: + if self.is_empty(): + raise RuntimeError("Can not convert an empty array to ctype.") + + array = _reorder(self) if row_major else self + return library.get_data_ptr(array.arr, array.size, array.dtype) + + +def _reorder(array: Array) -> Array: + """ + Returns a reordered array to help interoperate with row major formats. + """ + if array.ndim == 1: + return array + + out = Array() + out.arr = library.reorder(array.arr, array.ndim) + return out + + +def _metadata_string(dtype: Dtype, dims: Optional[Tuple[int, ...]] = None) -> str: + return ( + "arrayfire.Array()\n" + f"Type: {dtype.typename}\n" + f"Dims: {str(dims) if dims else ''}") + + +def _process_c_function(lhs: Union[int, float, Array], rhs: Union[int, float, Array], c_function: Any) -> Array: + out = Array() + + if isinstance(lhs, Array) and isinstance(rhs, Array): + lhs_array = lhs.arr + rhs_array = rhs.arr + + elif isinstance(lhs, Array) and isinstance(rhs, (int, float)): + lhs_array = lhs.arr + rhs_array = create_constant_array(rhs, lhs.shape, lhs.dtype) + + elif isinstance(lhs, (int, float)) and isinstance(rhs, Array): + lhs_array = create_constant_array(lhs, rhs.shape, rhs.dtype) + rhs_array = rhs.arr + + else: + raise TypeError(f"{type(rhs)} is not supported and can not be passed to C binary function.") + + out.arr = c_function(lhs_array, rhs_array) + return out diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py new file mode 100644 index 0000000..30253b4 --- /dev/null +++ b/arrayfire/backend/__init__.py @@ -0,0 +1,10 @@ +__all__ = [ + # Backend + "ArrayBuffer", + # Operators + "add", "sub", "mul", "div", "mod", "pow", "bitnot", "bitand", "bitor", "bitxor", "bitshiftl", "bitshiftr", "lt", + "le", "gt", "ge", "eq", "neq"] + +from .backend import ArrayBuffer +from .operators import ( + add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py new file mode 100644 index 0000000..a154555 --- /dev/null +++ b/arrayfire/backend/backend.py @@ -0,0 +1,26 @@ +import ctypes +import enum +from dataclasses import dataclass + +from ..dtypes.helpers import c_dim_t, to_str + +backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # Mock + + +def safe_call(c_err: int) -> None: + if c_err == _ErrorCodes.none.value: + return + + err_str = ctypes.c_char_p(0) + backend_api.af_get_last_error(ctypes.pointer(err_str), ctypes.pointer(c_dim_t(0))) + raise RuntimeError(to_str(err_str)) + + +class _ErrorCodes(enum.Enum): + none = 0 + + +@dataclass +class ArrayBuffer: + address: int + length: int = 0 diff --git a/arrayfire/backend/constant_array.py b/arrayfire/backend/constant_array.py new file mode 100644 index 0000000..85ba3d8 --- /dev/null +++ b/arrayfire/backend/constant_array.py @@ -0,0 +1,84 @@ +import ctypes +from typing import Tuple, Union + +from ..dtypes import Dtype, int64, uint64 +from ..dtypes.helpers import CShape, implicit_dtype +from .backend import backend_api, safe_call + +AFArray = ctypes.c_void_p + + +def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__data__func__constant.htm#ga5a083b1f3cd8a72a41f151de3bdea1a2 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_constant_complex( + ctypes.pointer(out), ctypes.c_double(number.real), ctypes.c_double(number.imag), 4, + ctypes.pointer(c_shape.c_array), dtype.c_api_value) + ) + return out + + +def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__data__func__constant.htm#ga10f1c9fad1ce9e9fefd885d5a1d1fd49 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_constant_long( + ctypes.pointer(out), ctypes.c_longlong(number.real), 4, ctypes.pointer(c_shape.c_array)) + ) + return out + + +def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__data__func__constant.htm#ga67af670cc9314589f8134019f5e68809 + """ + # return backend_api.af_constant_ulong(arr, val, ndims, dims) + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_constant_ulong( + ctypes.pointer(out), ctypes.c_ulonglong(number.real), 4, ctypes.pointer(c_shape.c_array)) + ) + return out + + +def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__data__func__constant.htm#gafc51b6a98765dd24cd4139f3bde00670 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_constant( + ctypes.pointer(out), ctypes.c_double(number), 4, ctypes.pointer(c_shape.c_array), dtype.c_api_value) + ) + return out + + +def create_constant_array(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: + dtype = implicit_dtype(number, dtype) + + # NOTE complex is not supported in Data API + # if isinstance(number, complex): + # if dtype != complex64 and dtype != complex128: + # dtype = complex64 + # return _constant_complex(number, shape, dtype) + + if dtype == int64: + return _constant_long(number, shape, dtype) + + if dtype == uint64: + return _constant_ulong(number, shape, dtype) + + return _constant(number, shape, dtype) diff --git a/arrayfire/backend/library.py b/arrayfire/backend/library.py new file mode 100644 index 0000000..9c87aa5 --- /dev/null +++ b/arrayfire/backend/library.py @@ -0,0 +1,238 @@ +import ctypes +from typing import Tuple, Union, cast + +from arrayfire.array import _get_indices # HACK replace with refactored one + +from ..device import PointerSource +from ..dtypes import CType, Dtype +from ..dtypes.helpers import CShape, c_dim_t, to_str +from .backend import ArrayBuffer, backend_api, safe_call + +AFArrayPointer = ctypes._Pointer +AFArray = ctypes.c_void_p + +# HACK, TODO replace for actual bcast_var after refactoring ~ https://github.com/arrayfire/arrayfire/pull/2871 +_bcast_var = False + +# Array management + + +def create_handle(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga3b8f5cf6fce69aa1574544bc2d44d7d0 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_create_handle( + ctypes.pointer(out), c_shape.original_shape, ctypes.pointer(c_shape.c_array), dtype.c_api_value) + ) + return out + + +def retain_array(arr: AFArray) -> AFArray: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga7ed45b3f881c0f6c80c5cf2af886dbab + """ + out = ctypes.c_void_p(0) + + safe_call( + backend_api.af_retain_array(ctypes.pointer(out), arr) + ) + return out + + +def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga834be32357616d8ab735087c6f681858 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_create_array( + ctypes.pointer(out), ctypes.c_void_p(array_buffer.address), c_shape.original_shape, + ctypes.pointer(c_shape.c_array), dtype.c_api_value) + ) + return out + + +def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#gaad4fc77f872217e7337cb53bfb623cf5 + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + safe_call( + backend_api.af_device_array( + ctypes.pointer(out), ctypes.c_void_p(array_buffer.address), c_shape.original_shape, + ctypes.pointer(c_shape.c_array), dtype.c_api_value) + ) + return out + + +def create_strided_array( + shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, offset: CType, strides: Tuple[int, ...], + pointer_source: PointerSource, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__internal__func__create.htm#gad31241a3437b7b8bc3cf49f85e5c4e0c + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*shape) + + if offset is None: + offset = c_dim_t(0) + + if strides is None: + strides = (1, c_shape[0], c_shape[0]*c_shape[1], c_shape[0]*c_shape[1]*c_shape[2]) + + if len(strides) < 4: + strides += (strides[-1], ) * (4 - len(strides)) + + safe_call( + backend_api.af_create_strided_array( + ctypes.pointer(out), ctypes.c_void_p(array_buffer.address), offset, c_shape.original_shape, + ctypes.pointer(c_shape.c_array), CShape(*strides).c_array, dtype.c_api_value, pointer_source.value) + ) + return out + + +def get_ctype(arr: AFArray) -> int: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga0dda6898e1c0d9a43efb56cd6a988c9b + """ + out = ctypes.c_int() + + safe_call( + backend_api.af_get_type(ctypes.pointer(out), arr) + ) + return out.value + + +def get_elements(arr: AFArray) -> int: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga6845bbe4385a60a606b88f8130252c1f + """ + out = c_dim_t(0) + + safe_call( + backend_api.af_get_elements(ctypes.pointer(out), arr) + ) + return out.value + + +def get_numdims(arr: AFArray) -> int: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefa019d932ff58c2a829ce87edddd2a8 + """ + out = ctypes.c_uint(0) + + safe_call( + backend_api.af_get_numdims(ctypes.pointer(out), arr) + ) + return out.value + + +def get_dims(arr: AFArray) -> Tuple[int, ...]: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga8b90da50a532837d9763e301b2267348 + """ + d0 = c_dim_t(0) + d1 = c_dim_t(0) + d2 = c_dim_t(0) + d3 = c_dim_t(0) + + safe_call( + backend_api.af_get_dims(ctypes.pointer(d0), ctypes.pointer(d1), ctypes.pointer(d2), ctypes.pointer(d3), arr) + ) + return (d0.value, d1.value, d2.value, d3.value) + + +def get_scalar(arr: AFArray, dtype: Dtype, /) -> Union[None, int, float, bool, complex]: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefe2e343a74a84bd43b588218ecc09a3 + """ + out = dtype.c_type() + safe_call( + backend_api.af_get_scalar(ctypes.pointer(out), arr) + ) + return cast(Union[None, int, float, bool, complex], out.value) + + +def is_empty(arr: AFArray) -> bool: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga19c749e95314e1c77d816ad9952fb680 + """ + out = ctypes.c_bool() + safe_call( + backend_api.af_is_empty(ctypes.pointer(out), arr) + ) + return out.value + + +def get_data_ptr(arr: AFArray, size: int, dtype: Dtype, /) -> ctypes.Array: + """ + source: https://arrayfire.org/docs/group__c__api__mat.htm#ga6040dc6f0eb127402fbf62c1165f0b9d + """ + c_shape = dtype.c_type * size + ctypes_array = c_shape() + safe_call( + backend_api.af_get_data_ptr(ctypes.pointer(ctypes_array), arr) + ) + return ctypes_array + + +# Arrayfire Functions + + +def index_gen(arr: AFArray, ndims: int, key: Union[int, slice, Tuple[Union[int, slice, ], ...]], /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__index__func__index.htm#ga14a7d149dba0ed0b977335a3df9d91e6 + """ + out = ctypes.c_void_p(0) + safe_call( + backend_api.af_index_gen(ctypes.pointer(out), arr, c_dim_t(ndims), _get_indices(key).pointer) + ) + return out + + +def transpose(arr: AFArray, conjugate: bool, /) -> AFArray: + """ + https://arrayfire.org/docs/group__blas__func__transpose.htm#ga716b2b9bf190c8f8d0970aef2b57d8e7 + """ + out = ctypes.c_void_p(0) + safe_call( + backend_api.af_transpose(ctypes.pointer(out), arr, conjugate) + ) + return out + + +def reorder(arr: AFArray, ndims: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__manip__func__reorder.htm#ga57383f4d00a3a86eab08dddd52c3ad3d + """ + out = ctypes.c_void_p(0) + c_shape = CShape(*(tuple(reversed(range(ndims))) + tuple(range(ndims, 4)))) + safe_call( + backend_api.af_reorder(ctypes.pointer(out), arr, *c_shape) + ) + return out + + +def array_as_str(arr: AFArray) -> str: + """ + source: + - https://arrayfire.org/docs/group__print__func__tostring.htm#ga01f32ef2420b5d4592c6e4b4964b863b + - https://arrayfire.org/docs/group__device__func__free__host.htm#ga3f1149a837a7ebbe8002d5d2244e3370 + """ + arr_str = ctypes.c_char_p(0) + safe_call( + backend_api.af_array_to_string(ctypes.pointer(arr_str), "", arr, 4, True) + ) + py_str = to_str(arr_str) + safe_call( + backend_api.af_free_host(arr_str) + ) + return py_str diff --git a/arrayfire/backend/operators.py b/arrayfire/backend/operators.py new file mode 100644 index 0000000..aa83a7d --- /dev/null +++ b/arrayfire/backend/operators.py @@ -0,0 +1,161 @@ +import ctypes +from typing import Callable, Union + +from .backend import backend_api, safe_call + +AFArray = ctypes.c_void_p + +# HACK, TODO replace for actual bcast_var after refactoring ~ https://github.com/arrayfire/arrayfire/pull/2871 +_bcast_var = False + +# Arithmetic Operators + + +def add(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__add.htm#ga1dfbee755fedd680f4476803ddfe06a7 + """ + return _binary_op(backend_api.af_add, lhs, rhs) + + +def sub(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__sub.htm#ga80ff99a2e186c23614ea9f36ffc6f0a4 + """ + return _binary_op(backend_api.af_sub, lhs, rhs) + + +def mul(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__mul.htm#ga5f7588b2809ff7551d38b6a0bd583a02 + """ + return _binary_op(backend_api.af_mul, lhs, rhs) + + +def div(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__div.htm#ga21f3f97755702692ec8976934e75fde6 + """ + return _binary_op(backend_api.af_div, lhs, rhs) + + +def mod(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__mod.htm#ga01924d1b59d8886e46fabd2dc9b27e0f + """ + return _binary_op(backend_api.af_mod, lhs, rhs) + + +def pow(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__pow.htm#ga0f28be1a9c8b176a78c4a47f483e7fc6 + """ + return _binary_op(backend_api.af_pow, lhs, rhs) + + +# Bitwise Operators + +def bitnot(arr: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__bitnot.htm#gaf97e8a38aab59ed2d3a742515467d01e + """ + out = ctypes.c_void_p(0) + safe_call( + backend_api.af_bitnot(ctypes.pointer(out), arr) + ) + return out + + +def bitand(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__bitand.htm#ga45c0779ade4703708596df11cca98800 + """ + return _binary_op(backend_api.af_bitand, lhs, rhs) + + +def bitor(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__bitor.htm#ga84c99f77d1d83fd53f949b4d67b5b210 + """ + return _binary_op(backend_api.af_bitor, lhs, rhs) + + +def bitxor(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__bitxor.htm#ga8188620da6b432998e55fdd1fad22100 + """ + return _binary_op(backend_api.af_bitxor, lhs, rhs) + + +def bitshiftl(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__shiftl.htm#ga3139645aafe6f045a5cab454e9c13137 + """ + return _binary_op(backend_api.af_butshiftl, lhs, rhs) + + +def bitshiftr(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__shiftr.htm#ga4c06b9977ecf96cdfc83b5dfd1ac4895 + """ + return _binary_op(backend_api.af_bitshiftr, lhs, rhs) + + +def lt(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/arith_8h.htm#ae7aa04bf23b32bb11c4bab8bdd637103 + """ + return _binary_op(backend_api.af_lt, lhs, rhs) + + +def le(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__le.htm#gad5535ce64dbed46d0773fd494e84e922 + """ + return _binary_op(backend_api.af_le, lhs, rhs) + + +def gt(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__gt.htm#ga4e65603259515de8939899a163ebaf9e + """ + return _binary_op(backend_api.af_gt, lhs, rhs) + + +def ge(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__ge.htm#ga4513f212e0b0a22dcf4653e89c85e3d9 + """ + return _binary_op(backend_api.af_ge, lhs, rhs) + + +def eq(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__eq.htm#ga76d2da7716831616bb81effa9e163693 + """ + return _binary_op(backend_api.af_eq, lhs, rhs) + + +def neq(lhs: AFArray, rhs: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__arith__func__neq.htm#gae4ee8bd06a410f259f1493fb811ce441 + """ + return _binary_op(backend_api.af_neq, lhs, rhs) + + +def _binary_op(c_func: Callable, lhs: AFArray, rhs: AFArray, /) -> AFArray: + out = ctypes.c_void_p(0) + safe_call(c_func(ctypes.pointer(out), lhs, rhs, _bcast_var)) + return out + + +def count_all(x: AFArray) -> Union[int, float, complex]: + # TODO reconsider original arith.count + return _reduce_all(x, backend_api.af_count_all) + + +def _reduce_all(arr: AFArray, c_func: Callable) -> Union[int, float, complex]: + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(c_func(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j diff --git a/arrayfire/config.py b/arrayfire/config.py new file mode 100644 index 0000000..588cbdf --- /dev/null +++ b/arrayfire/config.py @@ -0,0 +1,6 @@ +import platform + + +def is_arch_x86() -> bool: + machine = platform.machine() + return platform.architecture()[0][0:2] == "32" and (machine[-2:] == "86" or machine[0:3] == "arm") diff --git a/arrayfire/device.py b/arrayfire/device.py new file mode 100644 index 0000000..fde5d6a --- /dev/null +++ b/arrayfire/device.py @@ -0,0 +1,10 @@ +import enum + + +class PointerSource(enum.Enum): + """ + Source of the pointer. + """ + # FIXME + device = 0 + host = 1 diff --git a/arrayfire/dtypes/__init__.py b/arrayfire/dtypes/__init__.py new file mode 100644 index 0000000..e9a181a --- /dev/null +++ b/arrayfire/dtypes/__init__.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +__all__ = [ + "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "complex64", + "complex128", "bool"] + +import ctypes +from dataclasses import dataclass +from typing import Type + +CType = Type[ctypes._SimpleCData] + + +@dataclass +class Dtype: + typecode: str + c_type: CType + typename: str + c_api_value: int # Internal use only + + +# Specification required +# int8 - Not Supported, b8? # HACK Dtype("i8", ctypes.c_char, "int8", 4) +int16 = Dtype("h", ctypes.c_short, "short int", 10) +int32 = Dtype("i", ctypes.c_int, "int", 5) +int64 = Dtype("l", ctypes.c_longlong, "long int", 8) +uint8 = Dtype("B", ctypes.c_ubyte, "unsigned_char", 7) +uint16 = Dtype("H", ctypes.c_ushort, "unsigned short int", 11) +uint32 = Dtype("I", ctypes.c_uint, "unsigned int", 6) +uint64 = Dtype("L", ctypes.c_ulonglong, "unsigned long int", 9) +float32 = Dtype("f", ctypes.c_float, "float", 0) +float64 = Dtype("d", ctypes.c_double, "double", 2) +complex64 = Dtype("F", ctypes.c_float*2, "float complext", 1) # type: ignore[arg-type] +complex128 = Dtype("D", ctypes.c_double*2, "double complext", 3) # type: ignore[arg-type] +bool = Dtype("b", ctypes.c_bool, "bool", 4) + +supported_dtypes = [ + int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128, bool +] diff --git a/arrayfire/dtypes/functions.py b/arrayfire/dtypes/functions.py new file mode 100644 index 0000000..155dc61 --- /dev/null +++ b/arrayfire/dtypes/functions.py @@ -0,0 +1,32 @@ +from typing import Tuple, Union + +from ..array_object import Array +from . import Dtype + +# TODO implement functions + + +def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: + return NotImplemented + + +def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: + return NotImplemented + + +def finfo(type: Union[Dtype, Array], /): # type: ignore[no-untyped-def] + # NOTE expected return type -> finfo_object + return NotImplemented + + +def iinfo(type: Union[Dtype, Array], /): # type: ignore[no-untyped-def] + # NOTE expected return type -> iinfo_object + return NotImplemented + + +def isdtype(dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]) -> bool: + return NotImplemented + + +def result_type(*arrays_and_dtypes: Union[Dtype, Array]) -> Dtype: + return NotImplemented diff --git a/arrayfire/dtypes/helpers.py b/arrayfire/dtypes/helpers.py new file mode 100644 index 0000000..cf4d306 --- /dev/null +++ b/arrayfire/dtypes/helpers.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import ctypes +from typing import Tuple, Union + +from ..config import is_arch_x86 +from . import Dtype +from . import bool as af_bool +from . import complex64, complex128, float32, float64, int64, supported_dtypes + +c_dim_t = ctypes.c_int if is_arch_x86() else ctypes.c_longlong +ShapeType = Tuple[int, ...] + + +class CShape(tuple): + def __new__(cls, *args: int) -> CShape: + cls.original_shape = len(args) + return tuple.__new__(cls, args) + + def __init__(self, x1: int = 1, x2: int = 1, x3: int = 1, x4: int = 1) -> None: + self.x1 = x1 + self.x2 = x2 + self.x3 = x3 + self.x4 = x4 + + def __repr__(self) -> str: + return f"{self.__class__.__name__}{self.x1, self.x2, self.x3, self.x4}" + + @property + def c_array(self): # type: ignore[no-untyped-def] + c_shape = c_dim_t * 4 # ctypes.c_int | ctypes.c_longlong * 4 + return c_shape(c_dim_t(self.x1), c_dim_t(self.x2), c_dim_t(self.x3), c_dim_t(self.x4)) + + +def to_str(c_str: ctypes.c_char_p) -> str: + return str(c_str.value.decode("utf-8")) # type: ignore[union-attr] + + +def implicit_dtype(number: Union[int, float], array_dtype: Dtype) -> Dtype: + if isinstance(number, bool): + number_dtype = af_bool + if isinstance(number, int): + number_dtype = int64 + elif isinstance(number, float): + number_dtype = float64 + elif isinstance(number, complex): + number_dtype = complex128 + else: + raise TypeError(f"{type(number)} is not supported and can not be converted to af.Dtype.") + + if not (array_dtype == float32 or array_dtype == complex64): + return number_dtype + + if number_dtype == float64: + return float32 + + if number_dtype == complex128: + return complex64 + + return number_dtype + + +def c_api_value_to_dtype(value: int) -> Dtype: + for dtype in supported_dtypes: + if value == dtype.c_api_value: + return dtype + + raise TypeError("There is no supported dtype that matches passed dtype C API value.") + + +def str_to_dtype(value: int) -> Dtype: + for dtype in supported_dtypes: + if value == dtype.typecode or value == dtype.typename: + return dtype + + raise TypeError("There is no supported dtype that matches passed dtype typecode.") diff --git a/arrayfire/operators.py b/arrayfire/operators.py new file mode 100644 index 0000000..7cb4c07 --- /dev/null +++ b/arrayfire/operators.py @@ -0,0 +1,25 @@ +from typing import Callable + +from . import backend +from .array_object import Array + + +class return_copy: + # TODO merge with process_c_function in array_object + def __init__(self, func: Callable) -> None: + self.func = func + + def __call__(self, x1: Array, x2: Array) -> Array: + out = Array() + out.arr = self.func(x1.arr, x2.arr) + return out + + +@return_copy +def add(x1: Array, x2: Array, /) -> Array: + return backend.add(x1, x2) + + +@return_copy +def sub(x1: Array, x2: Array, /) -> Array: + return backend.sub(x1, x2) diff --git a/arrayfire/utils.py b/arrayfire/utils.py new file mode 100644 index 0000000..195111d --- /dev/null +++ b/arrayfire/utils.py @@ -0,0 +1,13 @@ +from typing import Tuple, Union + +from .array_object import Array + +# TODO implement functions + + +def all(x: Array, /, *, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False) -> Array: + return NotImplemented + + +def any(x: Array, /, *, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False) -> Array: + return NotImplemented diff --git a/setup.cfg b/setup.cfg index 050e9ce..bd675bf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ classifiers = License :: OSI Approved :: BSD License Programming Language :: Python Programming Language :: Python :: 3 - Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.8 Topic :: Scientific/Engineering Topic :: Scientific/Engineering :: Artificial Intelligence Topic :: Scientific/Engineering :: Information Analysis @@ -27,7 +27,7 @@ packages = find: install_requires = scikit-build python_requires = - >=3.10.0 + >=3.8.0 [options.packages.find] include = arrayfire diff --git a/tests/array_object/__init__.py b/tests/array_object/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/array_object/test_initialization.py b/tests/array_object/test_initialization.py new file mode 100644 index 0000000..81f90e5 --- /dev/null +++ b/tests/array_object/test_initialization.py @@ -0,0 +1,53 @@ +import array as pyarray +import math +from typing import Any, Optional, Tuple + +import pytest + +from arrayfire.array_object import Array +from arrayfire.dtypes import Dtype, float32, int16 + +# TODO add tests for array arguments: device, offset, strides +# TODO add tests for all supported dtypes on initialisation +# TODO add test generation + + +@pytest.mark.parametrize( + "array, res_dtype, res_ndim, res_size, res_shape, res_len", [ + (Array(), float32, 0, 0, (), 0), + (Array(dtype=int16), int16, 0, 0, (), 0), + (Array(dtype="short int"), int16, 0, 0, (), 0), + (Array(dtype="h"), int16, 0, 0, (), 0), + (Array(shape=(2, 3)), float32, 2, 6, (2, 3), 2), + (Array([1, 2, 3]), float32, 1, 3, (3,), 3), + (Array(pyarray.array("f", [1, 2, 3])), float32, 1, 3, (3,), 3), + (Array([1], shape=(1,), dtype=float32), float32, 1, 1, (1,), 1), # BUG + (Array(Array([1])), float32, 1, 1, (1,), 1) + ]) +def test_initialization_with_different_arguments( + array: Array, res_dtype: Dtype, res_ndim: int, res_size: int, res_shape: Tuple[int, ...], + res_len: int) -> None: + assert array.dtype == res_dtype + assert array.ndim == res_ndim + assert array.size == res_size + # NOTE math.prod from empty object returns 1, but it should work for other cases + if res_size != 0: + assert array.size == math.prod(res_shape) + assert array.shape == res_shape + assert len(array) == res_len + + +@pytest.mark.parametrize( + "array_object, dtype, shape", [ + (None, "hello world", ()), + ([[1, 2, 3], [1, 2, 3]], None, ()), + (1, None, ()), + (1, None, (1,)), + ((5, 5), None, ()), + ({1: 2, 3: 4}, None, ()) + ] +) +def test_initalization_with_unsupported_argument_types( + array_object: Any, dtype: Optional[Dtype], shape: Tuple[int, ...]) -> None: + with pytest.raises(TypeError): + Array(x=array_object, dtype=dtype, shape=shape) diff --git a/tests/array_object/test_methods.py b/tests/array_object/test_methods.py new file mode 100644 index 0000000..f1292a6 --- /dev/null +++ b/tests/array_object/test_methods.py @@ -0,0 +1,31 @@ +from arrayfire.array_object import Array + +# TODO add more tests for different dtypes + + +def test_array_getitem() -> None: + array = Array([1, 2, 3, 4, 5]) + + int_item = array[2] + assert array.dtype == int_item.dtype + assert int_item.scalar() == 3 + + +def test_scalar() -> None: + array = Array([1, 2, 3]) + assert array[1].scalar() == 2 + + +def test_scalar_is_empty() -> None: + array = Array() + assert array.scalar() is None + + +def test_array_to_list() -> None: + array = Array([1, 2, 3]) + assert array.to_list() == [1, 2, 3] + + +def test_array_to_list_is_empty() -> None: + array = Array() + assert array.to_list() == [] diff --git a/tests/array_object/test_operators.py b/tests/array_object/test_operators.py new file mode 100644 index 0000000..3634cd3 --- /dev/null +++ b/tests/array_object/test_operators.py @@ -0,0 +1,121 @@ +import operator +from typing import Any, Callable, List, Union + +import pytest + +from arrayfire.array_object import Array +from arrayfire.dtypes import bool as af_bool + +Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array] + + +def _round(list_: List[Union[int, float]], symbols: int = 4) -> List[Union[int, float]]: + # HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/ + return [round(x, symbols) for x in list_] + + +def pytest_generate_tests(metafunc: Any) -> None: + if "array_origin" in metafunc.fixturenames: + metafunc.parametrize("array_origin", [ + [1, 2, 3], + # [4.2, 7.5, 5.41] # FIXME too big difference between python pow and af backend + ]) + if "arithmetic_operator" in metafunc.fixturenames: + metafunc.parametrize("arithmetic_operator", [ + "add", # __add__, __iadd__, __radd__ + "sub", # __sub__, __isub__, __rsub__ + "mul", # __mul__, __imul__, __rmul__ + "truediv", # __truediv__, __itruediv__, __rtruediv__ + # "floordiv", # __floordiv__, __ifloordiv__, __rfloordiv__ # TODO + "mod", # __mod__, __imod__, __rmod__ + "pow", # __pow__, __ipow__, __rpow__, + ]) + if "array_operator" in metafunc.fixturenames: + metafunc.parametrize("array_operator", [ + operator.matmul, + operator.imatmul + ]) + if "comparison_operator" in metafunc.fixturenames: + metafunc.parametrize("comparison_operator", [ + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne + ]) + if "operand" in metafunc.fixturenames: + metafunc.parametrize("operand", [ + 2, + 1.5, + [9, 9, 9], + ]) + if "false_operand" in metafunc.fixturenames: + metafunc.parametrize("false_operand", [ + (1, 2, 3), + ("2"), + {2.34, 523.2}, + "15" + ]) + + +def test_arithmetic_operators( + array_origin: List[Union[int, float]], arithmetic_operator: str, + operand: Union[int, float, List[Union[int, float]]]) -> None: + op = getattr(operator, arithmetic_operator) + iop = getattr(operator, "i" + arithmetic_operator) + + if isinstance(operand, list): + ref = [op(x, y) for x, y in zip(array_origin, operand)] + rref = [op(y, x) for x, y in zip(array_origin, operand)] + operand = Array(operand) # type: ignore[assignment] + else: + ref = [op(x, operand) for x in array_origin] + rref = [op(operand, x) for x in array_origin] + + array = Array(array_origin) + + res = op(array, operand) + ires = iop(array, operand) + rres = op(operand, array) + + assert _round(res.to_list()) == _round(ires.to_list()) == _round(ref) + assert _round(rres.to_list()) == _round(rref) + + assert res.dtype == ires.dtype == rres.dtype + assert res.ndim == ires.ndim == rres.ndim + assert res.size == ires.size == ires.size + assert res.shape == ires.shape == rres.shape + assert len(res) == len(ires) == len(rres) + + +def test_arithmetic_operators_expected_to_raise_error( + array_origin: List[Union[int, float]], arithmetic_operator: str, false_operand: Any) -> None: + array = Array(array_origin) + op = getattr(operator, arithmetic_operator) + with pytest.raises(TypeError): + op(array, false_operand) + + +def test_comparison_operators( + array_origin: List[Union[int, float]], comparison_operator: Operator, + operand: Union[int, float, List[Union[int, float]]]) -> None: + if isinstance(operand, list): + ref = [comparison_operator(x, y) for x, y in zip(array_origin, operand)] + operand = Array(operand) # type: ignore[assignment] + else: + ref = [comparison_operator(x, operand) for x in array_origin] + + array = Array(array_origin) + res = comparison_operator(array, operand) # type: ignore[arg-type] + + assert res.to_list() == ref + assert res.dtype == af_bool + + +def test_comparison_operators_expected_to_raise_error( + array_origin: List[Union[int, float]], comparison_operator: Operator, false_operand: Any) -> None: + array = Array(array_origin) + + with pytest.raises(TypeError): + comparison_operator(array, false_operand) diff --git a/tests/test_operators.py b/tests/test_operators.py new file mode 100644 index 0000000..f13d0fe --- /dev/null +++ b/tests/test_operators.py @@ -0,0 +1,20 @@ +from typing import Any + +from arrayfire import operators +from arrayfire.array_object import Array + + +class TestArithmeticOperators: + def setup_method(self, method: Any) -> None: + self.array1 = Array([1, 2, 3]) + self.array2 = Array([4, 5, 6]) + + def test_add(self) -> None: + res = operators.add(self.array1, self.array2) + res_sum = self.array1 + self.array2 + assert res.to_list() == res_sum.to_list() == [5, 7, 9] + + def test_sub(self) -> None: + res = operators.sub(self.array1, self.array2) + res_sum = self.array1 - self.array2 + assert res.to_list() == res_sum.to_list() == [-3, -3, -3] From 4f26024fa8c99d4c1a8e91971f4be505fc54b765 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Wed, 14 Jun 2023 00:27:04 +0300 Subject: [PATCH 2/6] Structure refactoring --- arrayfire/__init__.py | 2 +- arrayfire/array/__init__.py | 0 arrayfire/{ => array}/array_object.py | 58 ++--- arrayfire/backend/backend.py | 3 +- arrayfire/backend/{library.py => wrapped.py} | 4 +- arrayfire/dtypes/functions.py | 2 +- arrayfire/library/__init__.py | 0 arrayfire/library/broadcast.py | 86 +++++++ arrayfire/{ => library}/device.py | 0 arrayfire/library/index.py | 257 +++++++++++++++++++ arrayfire/{ => library}/operators.py | 4 +- arrayfire/{ => library}/utils.py | 2 +- tests/array_object/test_initialization.py | 2 +- tests/array_object/test_methods.py | 2 +- tests/array_object/test_operators.py | 2 +- tests/test_operators.py | 4 +- 16 files changed, 386 insertions(+), 42 deletions(-) create mode 100755 arrayfire/array/__init__.py rename arrayfire/{ => array}/array_object.py (92%) rename arrayfire/backend/{library.py => wrapped.py} (98%) create mode 100755 arrayfire/library/__init__.py create mode 100755 arrayfire/library/broadcast.py rename arrayfire/{ => library}/device.py (100%) create mode 100755 arrayfire/library/index.py rename arrayfire/{ => library}/operators.py (88%) rename arrayfire/{ => library}/utils.py (89%) diff --git a/arrayfire/__init__.py b/arrayfire/__init__.py index 675b27a..9775705 100644 --- a/arrayfire/__init__.py +++ b/arrayfire/__init__.py @@ -5,5 +5,5 @@ "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "complex64", "complex128", "bool"] -from .array_object import Array +from .array.array_object import Array from .dtypes import bool, complex64, complex128, float32, float64, int16, int32, int64, uint8, uint16, uint32, uint64 diff --git a/arrayfire/array/__init__.py b/arrayfire/array/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/arrayfire/array_object.py b/arrayfire/array/array_object.py similarity index 92% rename from arrayfire/array_object.py rename to arrayfire/array/array_object.py index 5777871..f31f551 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array/array_object.py @@ -5,15 +5,15 @@ import enum from typing import Any, List, Optional, Tuple, Union -from . import backend -from .backend import ArrayBuffer, library -from .backend.operators import count_all -from .backend.constant_array import create_constant_array -from .device import PointerSource -from .dtypes import CType -from .dtypes import bool as af_bool -from .dtypes import float32 as af_float32 -from .dtypes.helpers import Dtype, c_api_value_to_dtype, str_to_dtype +from .. import backend +from ..backend import ArrayBuffer, wrapped +from ..backend.operators import count_all +from ..backend.constant_array import create_constant_array +from ..library.device import PointerSource +from ..dtypes import CType +from ..dtypes import bool as af_bool +from ..dtypes import float32 as af_float32 +from ..dtypes.helpers import Dtype, c_api_value_to_dtype, str_to_dtype # TODO use int | float in operators -> remove bool | complex support @@ -35,14 +35,14 @@ def __init__( if x is None: if not shape: # shape is None or empty tuple - self.arr = library.create_handle((), dtype) + self.arr = wrapped.create_handle((), dtype) return - self.arr = library.create_handle(shape, dtype) + self.arr = wrapped.create_handle(shape, dtype) return if isinstance(x, Array): - self.arr = library.retain_array(x.arr) + self.arr = wrapped.retain_array(x.arr) return if isinstance(x, py_array.array): @@ -79,13 +79,13 @@ def __init__( if not (offset or strides): if pointer_source == PointerSource.host: - self.arr = library.create_array(shape, dtype, _array_buffer) + self.arr = wrapped.create_array(shape, dtype, _array_buffer) return - self.arr = library.device_array(shape, dtype, _array_buffer) + self.arr = wrapped.device_array(shape, dtype, _array_buffer) return - self.arr = library.create_strided_array( + self.arr = wrapped.create_strided_array( shape, dtype, _array_buffer, offset, strides, pointer_source) # type: ignore[arg-type] # Arithmetic Operators @@ -722,11 +722,11 @@ def __getitem__(self, key: Union[int, slice, Tuple[Union[int, slice, ], ...], Ar if isinstance(key, Array) and key == af_bool.c_api_value: ndims = 1 - if count_all(key.arr) == 0: + if count_all(key.arr) == 0: # HACK was count() method before return out # HACK known issue - out.arr = library.index_gen(self.arr, ndims, key) # type: ignore[arg-type] + out.arr = wrapped.index_gen(self.arr, ndims, key) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -750,12 +750,12 @@ def __str__(self) -> str: # TODO change the look of array str. E.g., like np.array # if not _in_display_dims_limit(self.shape): # return _metadata_string(self.dtype, self.shape) - return _metadata_string(self.dtype) + library.array_as_str(self.arr) + return _metadata_string(self.dtype) + wrapped.array_as_str(self.arr) def __repr__(self) -> str: # return _metadata_string(self.dtype, self.shape) # TODO change the look of array representation. E.g., like np.array - return library.array_as_str(self.arr) + return wrapped.array_as_str(self.arr) def to_device(self, device: Any, /, *, stream: Union[int, Any] = None) -> Array: # TODO implementation and change device type from Any to Device @@ -773,7 +773,7 @@ def dtype(self) -> Dtype: out : Dtype Array data type. """ - return c_api_value_to_dtype(library.get_ctype(self.arr)) + return c_api_value_to_dtype(wrapped.get_ctype(self.arr)) @property def device(self) -> Any: @@ -806,7 +806,7 @@ def T(self) -> Array: # TODO add check if out.dtype == self.dtype out = Array() - out.arr = library.transpose(self.arr, False) + out.arr = wrapped.transpose(self.arr, False) return out @property @@ -824,7 +824,7 @@ def size(self) -> int: - This must equal the product of the array's dimensions. """ # NOTE previously - elements() - return library.get_elements(self.arr) + return wrapped.get_elements(self.arr) @property def ndim(self) -> int: @@ -834,7 +834,7 @@ def ndim(self) -> int: out : int Number of array dimensions (axes). """ - return library.get_numdims(self.arr) + return wrapped.get_numdims(self.arr) @property def shape(self) -> Tuple[int, ...]: @@ -847,7 +847,7 @@ def shape(self) -> Tuple[int, ...]: Array dimensions. """ # NOTE skipping passing any None values - return library.get_dims(self.arr)[:self.ndim] + return wrapped.get_dims(self.arr)[:self.ndim] def scalar(self) -> Union[None, int, float, bool, complex]: """ @@ -857,20 +857,20 @@ def scalar(self) -> Union[None, int, float, bool, complex]: if self.is_empty(): return None - return library.get_scalar(self.arr, self.dtype) + return wrapped.get_scalar(self.arr, self.dtype) def is_empty(self) -> bool: """ Check if the array is empty i.e. it has no elements. """ - return library.is_empty(self.arr) + return wrapped.is_empty(self.arr) def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, complex]]: if self.is_empty(): return [] array = _reorder(self) if row_major else self - ctypes_array = library.get_data_ptr(array.arr, array.size, array.dtype) + ctypes_array = wrapped.get_data_ptr(array.arr, array.size, array.dtype) if array.ndim == 1: return list(ctypes_array) @@ -891,7 +891,7 @@ def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: raise RuntimeError("Can not convert an empty array to ctype.") array = _reorder(self) if row_major else self - return library.get_data_ptr(array.arr, array.size, array.dtype) + return wrapped.get_data_ptr(array.arr, array.size, array.dtype) def _reorder(array: Array) -> Array: @@ -902,7 +902,7 @@ def _reorder(array: Array) -> Array: return array out = Array() - out.arr = library.reorder(array.arr, array.ndim) + out.arr = wrapped.reorder(array.arr, array.ndim) return out diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index a154555..c7d57b8 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -4,7 +4,8 @@ from ..dtypes.helpers import c_dim_t, to_str -backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # Mock +# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # Mock +backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll") def safe_call(c_err: int) -> None: diff --git a/arrayfire/backend/library.py b/arrayfire/backend/wrapped.py similarity index 98% rename from arrayfire/backend/library.py rename to arrayfire/backend/wrapped.py index 9c87aa5..5918a8f 100644 --- a/arrayfire/backend/library.py +++ b/arrayfire/backend/wrapped.py @@ -1,9 +1,9 @@ import ctypes from typing import Tuple, Union, cast -from arrayfire.array import _get_indices # HACK replace with refactored one +# from arrayfire.array import _get_indices # HACK replace with refactored one -from ..device import PointerSource +from ..library.device import PointerSource from ..dtypes import CType, Dtype from ..dtypes.helpers import CShape, c_dim_t, to_str from .backend import ArrayBuffer, backend_api, safe_call diff --git a/arrayfire/dtypes/functions.py b/arrayfire/dtypes/functions.py index 155dc61..865980c 100644 --- a/arrayfire/dtypes/functions.py +++ b/arrayfire/dtypes/functions.py @@ -1,6 +1,6 @@ from typing import Tuple, Union -from ..array_object import Array +from ..array.array_object import Array from . import Dtype # TODO implement functions diff --git a/arrayfire/library/__init__.py b/arrayfire/library/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/arrayfire/library/broadcast.py b/arrayfire/library/broadcast.py new file mode 100755 index 0000000..8946670 --- /dev/null +++ b/arrayfire/library/broadcast.py @@ -0,0 +1,86 @@ +class _bcast(object): + _flag = False + + def get(self): + return _bcast._flag + + def set(self, flag): + _bcast._flag = flag + + def toggle(self): + _bcast._flag ^= True + + +_bcast_var = _bcast() + + +def broadcast(func, *args): + """ + Function to perform broadcast operations. + + This function can be used directly or as an annotation in the following manner. + + Example + ------- + + Using broadcast as an annotation + + >>> import arrayfire as af + >>> @af.broadcast + ... def add(a, b): + ... return a + b + ... + >>> a = af.randu(2,3) + >>> b = af.randu(2,1) # b is a different size + >>> # Trying to add arrays of different sizes raises an exceptions + >>> c = add(a, b) # This call does not raise an exception because of the annotation + >>> af.display(a) + [2 3 1 1] + 0.4107 0.9518 0.4198 + 0.8224 0.1794 0.0081 + + >>> af.display(b) + [2 1 1 1] + 0.7269 + 0.7104 + + >>> af.display(c) + [2 3 1 1] + 1.1377 1.6787 1.1467 + 1.5328 0.8898 0.7185 + + Using broadcast as function + + >>> import arrayfire as af + >>> add = lambda a,b: a + b + >>> a = af.randu(2,3) + >>> b = af.randu(2,1) # b is a different size + >>> # Trying to add arrays of different sizes raises an exceptions + >>> c = af.broadcast(add, a, b) # This call does not raise an exception + >>> af.display(a) + [2 3 1 1] + 0.4107 0.9518 0.4198 + 0.8224 0.1794 0.0081 + + >>> af.display(b) + [2 1 1 1] + 0.7269 + 0.7104 + + >>> af.display(c) + [2 3 1 1] + 1.1377 1.6787 1.1467 + 1.5328 0.8898 0.7185 + + """ + + def wrapper(*func_args): + _bcast_var.toggle() + res = func(*func_args) + _bcast_var.toggle() + return res + + if len(args) == 0: + return wrapper + else: + return wrapper(*args) diff --git a/arrayfire/device.py b/arrayfire/library/device.py similarity index 100% rename from arrayfire/device.py rename to arrayfire/library/device.py diff --git a/arrayfire/library/index.py b/arrayfire/library/index.py new file mode 100755 index 0000000..55cc808 --- /dev/null +++ b/arrayfire/library/index.py @@ -0,0 +1,257 @@ +import ctypes +import math +import numbers +from typing import Union + + +class IndexSequence(ctypes.Structure): + """ + arrayfire equivalent of slice + + Attributes + ---------- + + begin: number + Start of the sequence. + + end : number + End of sequence. + + step : number + Step size. + + Parameters + ---------- + + S: slice or number. + + """ + _fields_ = [("begin", ctypes.c_double), + ("end", ctypes.c_double), + ("step", ctypes.c_double)] + + def __init__(self, S: Union[numbers.Number, slice, None]): + self.begin = ctypes.c_double(0) + self.end = ctypes.c_double(-1) + self.step = ctypes.c_double(1) + + if isinstance(slice, numbers.Number): + self.begin = ctypes.c_double(S) + self.end = ctypes.c_double(S) + + elif isinstance(S, slice): + if S.step: + self.step = ctypes.c_double(S.step) + if S.step < 0: + self.begin, self.end = self.end, self.begin + + if S.start: + self.begin = ctypes.c_double(S.start) + + if S.stop: + self.end = ctypes.c_double(S.stop) + + # handle special cases + if self.begin >= 0 and self.end >= 0 and self.end <= self.begin and self.step >= 0: + self.begin = 1 + self.end = 1 + self.step = 1 + + elif self.begin < 0 and self.end < 0 and self.end >= self.begin and self.step <= 0: + self.begin = -2 + self.end = -2 + self.step = -1 + + if S.stop: + self.end = self.end - math.copysign(1, self.step) + else: + raise IndexError("Invalid type while indexing arrayfire.array") + + +class ParallelRange(IndexSequence): + + """ + Class used to parallelize for loop. + + Inherits from Seq. + + Attributes + ---------- + + S: slice + + Parameters + ---------- + + start: number + Beginning of parallel range. + + stop : number + End of parallel range. + + step : number + Step size for parallel range. + + Examples + -------- + + >>> import arrayfire as af + >>> a = af.randu(3, 3) + >>> b = af.randu(3, 1) + >>> c = af.constant(0, 3, 3) + >>> for ii in af.ParallelRange(3): + ... c[:, ii] = a[:, ii] + b + ... + >>> af.display(a) + [3 3 1 1] + 0.4107 0.1794 0.3775 + 0.8224 0.4198 0.3027 + 0.9518 0.0081 0.6456 + + >>> af.display(b) + [3 1 1 1] + 0.7269 + 0.7104 + 0.5201 + + >>> af.display(c) + [3 3 1 1] + 1.1377 0.9063 1.1045 + 1.5328 1.1302 1.0131 + 1.4719 0.5282 1.1657 + + """ + def __init__(self, start, stop=None, step=None): + + if (stop is None): + stop = start + start = 0 + + self.S = slice(start, stop, step) + super(ParallelRange, self).__init__(self.S) + + def __iter__(self): + return self + + def next(self): + """ + Function called by the iterator in Python 2 + """ + if _bcast_var.get() is True: + _bcast_var.toggle() + raise StopIteration + else: + _bcast_var.toggle() + return self + + def __next__(self): + """ + Function called by the iterator in Python 3 + """ + return self.next() + +class IndexUnion(ctypes.Union): + _fields_ = [("arr", ctypes.c_void_p), + ("seq", IndexSequence)] + + +class IndexStructure(ctypes.Structure): + _fields_ = [("idx", IndexUnion), + ("isSeq", ctypes.c_bool), + ("isBatch", ctypes.c_bool)] + + """ + Container for the index class in arrayfire C library + + Attributes + ---------- + idx.arr: ctypes.c_void_p + - Default 0 + + idx.seq: af.Seq + - Default af.Seq(0, -1, 1) + + isSeq : bool + - Default True + + isBatch : bool + - Default False + + Parameters + ----------- + + idx: key + - If of type af.Array, self.idx.arr = idx, self.isSeq = False + - If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True + - Default:, self.idx.seq = af.Seq(idx) + + Note + ---- + + Implemented for internal use only. Use with extreme caution. + + """ + + def __init__(self, idx): + self.idx = IndexUnion() + self.isBatch = False + self.isSeq = True + + if isinstance(idx, ctypes.c_void_p): + + arr = ctypes.c_void_p(0) + + if (idx.type() == Dtype.b8.value): + safe_call(backend.get().af_where(c_pointer(arr), idx.arr)) + else: + safe_call(backend.get().af_retain_array(c_pointer(arr), idx.arr)) + + self.idx.arr = arr + self.isSeq = False + elif isinstance(idx, ParallelRange): + self.idx.seq = idx + self.isBatch = True + else: + self.idx.seq = Seq(idx) + + def __del__(self) -> None: + if not self.isSeq: + # ctypes field variables are automatically + # converted to basic C types so we have to + # build the void_p from the value again. + arr = c_void_ptr_t(self.idx.arr) + backend.get().af_release_array(arr) + + +class _Index4(object): + def __init__(self) -> None: + index_vec = IndexStructure * 4 + _span = IndexStructure(slice(None)) + self.array = index_vec(_span, _span, _span, _span) + # Do not lose those idx as self.array keeps + # no reference to them. Otherwise the destructor + # is prematurely called + self.idxs = [_span, _span, _span, _span] + + @property + def pointer(self): + return ctypes.pointer(self.array) + + def __getitem__(self, idx): + return self.array[idx] + + def __setitem__(self, idx, value): + self.array[idx] = value + self.idxs[idx] = value + + +def _get_indices(key): + inds = _Index4() + if isinstance(key, tuple): + n_idx = len(key) + for n in range(n_idx): + inds[n] = Index(key[n]) + else: + inds[0] = Index(key) + + return inds diff --git a/arrayfire/operators.py b/arrayfire/library/operators.py similarity index 88% rename from arrayfire/operators.py rename to arrayfire/library/operators.py index 7cb4c07..4d27fdd 100644 --- a/arrayfire/operators.py +++ b/arrayfire/library/operators.py @@ -1,7 +1,7 @@ from typing import Callable -from . import backend -from .array_object import Array +from .. import backend +from ..array.array_object import Array class return_copy: diff --git a/arrayfire/utils.py b/arrayfire/library/utils.py similarity index 89% rename from arrayfire/utils.py rename to arrayfire/library/utils.py index 195111d..7a787bc 100644 --- a/arrayfire/utils.py +++ b/arrayfire/library/utils.py @@ -1,6 +1,6 @@ from typing import Tuple, Union -from .array_object import Array +from ..array.array_object import Array # TODO implement functions diff --git a/tests/array_object/test_initialization.py b/tests/array_object/test_initialization.py index 81f90e5..207a852 100644 --- a/tests/array_object/test_initialization.py +++ b/tests/array_object/test_initialization.py @@ -4,7 +4,7 @@ import pytest -from arrayfire.array_object import Array +from arrayfire.array.array_object import Array from arrayfire.dtypes import Dtype, float32, int16 # TODO add tests for array arguments: device, offset, strides diff --git a/tests/array_object/test_methods.py b/tests/array_object/test_methods.py index f1292a6..15a8c17 100644 --- a/tests/array_object/test_methods.py +++ b/tests/array_object/test_methods.py @@ -1,4 +1,4 @@ -from arrayfire.array_object import Array +from arrayfire.array.array_object import Array # TODO add more tests for different dtypes diff --git a/tests/array_object/test_operators.py b/tests/array_object/test_operators.py index 3634cd3..290ce6d 100644 --- a/tests/array_object/test_operators.py +++ b/tests/array_object/test_operators.py @@ -3,7 +3,7 @@ import pytest -from arrayfire.array_object import Array +from arrayfire.array.array_object import Array from arrayfire.dtypes import bool as af_bool Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array] diff --git a/tests/test_operators.py b/tests/test_operators.py index f13d0fe..04c4741 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -1,7 +1,7 @@ from typing import Any -from arrayfire import operators -from arrayfire.array_object import Array +from arrayfire.library import operators +from arrayfire.array.array_object import Array class TestArithmeticOperators: From dcfef3dde0ea84c8ed68f4ab1bea9f257810606f Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Wed, 14 Jun 2023 02:09:47 +0300 Subject: [PATCH 3/6] Change index and bcast --- arrayfire/array/__init__.py | 3 + arrayfire/array/array_object.py | 47 +++--- arrayfire/backend/__init__.py | 2 +- arrayfire/backend/wrapped/__init__.py | 0 .../index.py => backend/wrapped/_indexing.py} | 138 +++++++++--------- .../backend/{ => wrapped}/constant_array.py | 7 +- .../{wrapped.py => wrapped/everything.py} | 19 ++- arrayfire/backend/{ => wrapped}/operators.py | 16 +- .../backend/wrapped/reduction_operations.py | 18 +++ arrayfire/library/broadcast.py | 1 + 10 files changed, 140 insertions(+), 111 deletions(-) create mode 100755 arrayfire/backend/wrapped/__init__.py rename arrayfire/{library/index.py => backend/wrapped/_indexing.py} (56%) rename arrayfire/backend/{ => wrapped}/constant_array.py (94%) rename arrayfire/backend/{wrapped.py => wrapped/everything.py} (93%) rename arrayfire/backend/{ => wrapped}/operators.py (89%) create mode 100755 arrayfire/backend/wrapped/reduction_operations.py diff --git a/arrayfire/array/__init__.py b/arrayfire/array/__init__.py index e69de29..a040af0 100755 --- a/arrayfire/array/__init__.py +++ b/arrayfire/array/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["Array"] + +from .array_object import Array diff --git a/arrayfire/array/array_object.py b/arrayfire/array/array_object.py index f31f551..a5bc519 100755 --- a/arrayfire/array/array_object.py +++ b/arrayfire/array/array_object.py @@ -6,14 +6,15 @@ from typing import Any, List, Optional, Tuple, Union from .. import backend -from ..backend import ArrayBuffer, wrapped -from ..backend.operators import count_all -from ..backend.constant_array import create_constant_array -from ..library.device import PointerSource +from ..backend import ArrayBuffer +from ..backend.wrapped import everything +from ..backend.wrapped.constant_array import create_constant_array +from ..backend.wrapped.reduction_operations import count_all from ..dtypes import CType from ..dtypes import bool as af_bool from ..dtypes import float32 as af_float32 from ..dtypes.helpers import Dtype, c_api_value_to_dtype, str_to_dtype +from ..library.device import PointerSource # TODO use int | float in operators -> remove bool | complex support @@ -35,14 +36,14 @@ def __init__( if x is None: if not shape: # shape is None or empty tuple - self.arr = wrapped.create_handle((), dtype) + self.arr = everything.create_handle((), dtype) return - self.arr = wrapped.create_handle(shape, dtype) + self.arr = everything.create_handle(shape, dtype) return if isinstance(x, Array): - self.arr = wrapped.retain_array(x.arr) + self.arr = everything.retain_array(x.arr) return if isinstance(x, py_array.array): @@ -79,13 +80,13 @@ def __init__( if not (offset or strides): if pointer_source == PointerSource.host: - self.arr = wrapped.create_array(shape, dtype, _array_buffer) + self.arr = everything.create_array(shape, dtype, _array_buffer) return - self.arr = wrapped.device_array(shape, dtype, _array_buffer) + self.arr = everything.device_array(shape, dtype, _array_buffer) return - self.arr = wrapped.create_strided_array( + self.arr = everything.create_strided_array( shape, dtype, _array_buffer, offset, strides, pointer_source) # type: ignore[arg-type] # Arithmetic Operators @@ -726,7 +727,7 @@ def __getitem__(self, key: Union[int, slice, Tuple[Union[int, slice, ], ...], Ar return out # HACK known issue - out.arr = wrapped.index_gen(self.arr, ndims, key) # type: ignore[arg-type] + out.arr = everything.index_gen(self.arr, ndims, key) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -750,12 +751,12 @@ def __str__(self) -> str: # TODO change the look of array str. E.g., like np.array # if not _in_display_dims_limit(self.shape): # return _metadata_string(self.dtype, self.shape) - return _metadata_string(self.dtype) + wrapped.array_as_str(self.arr) + return _metadata_string(self.dtype) + everything.array_as_str(self.arr) def __repr__(self) -> str: # return _metadata_string(self.dtype, self.shape) # TODO change the look of array representation. E.g., like np.array - return wrapped.array_as_str(self.arr) + return everything.array_as_str(self.arr) def to_device(self, device: Any, /, *, stream: Union[int, Any] = None) -> Array: # TODO implementation and change device type from Any to Device @@ -773,7 +774,7 @@ def dtype(self) -> Dtype: out : Dtype Array data type. """ - return c_api_value_to_dtype(wrapped.get_ctype(self.arr)) + return c_api_value_to_dtype(everything.get_ctype(self.arr)) @property def device(self) -> Any: @@ -806,7 +807,7 @@ def T(self) -> Array: # TODO add check if out.dtype == self.dtype out = Array() - out.arr = wrapped.transpose(self.arr, False) + out.arr = everything.transpose(self.arr, False) return out @property @@ -824,7 +825,7 @@ def size(self) -> int: - This must equal the product of the array's dimensions. """ # NOTE previously - elements() - return wrapped.get_elements(self.arr) + return everything.get_elements(self.arr) @property def ndim(self) -> int: @@ -834,7 +835,7 @@ def ndim(self) -> int: out : int Number of array dimensions (axes). """ - return wrapped.get_numdims(self.arr) + return everything.get_numdims(self.arr) @property def shape(self) -> Tuple[int, ...]: @@ -847,7 +848,7 @@ def shape(self) -> Tuple[int, ...]: Array dimensions. """ # NOTE skipping passing any None values - return wrapped.get_dims(self.arr)[:self.ndim] + return everything.get_dims(self.arr)[:self.ndim] def scalar(self) -> Union[None, int, float, bool, complex]: """ @@ -857,20 +858,20 @@ def scalar(self) -> Union[None, int, float, bool, complex]: if self.is_empty(): return None - return wrapped.get_scalar(self.arr, self.dtype) + return everything.get_scalar(self.arr, self.dtype) def is_empty(self) -> bool: """ Check if the array is empty i.e. it has no elements. """ - return wrapped.is_empty(self.arr) + return everything.is_empty(self.arr) def to_list(self, row_major: bool = False) -> List[Union[None, int, float, bool, complex]]: if self.is_empty(): return [] array = _reorder(self) if row_major else self - ctypes_array = wrapped.get_data_ptr(array.arr, array.size, array.dtype) + ctypes_array = everything.get_data_ptr(array.arr, array.size, array.dtype) if array.ndim == 1: return list(ctypes_array) @@ -891,7 +892,7 @@ def to_ctype_array(self, row_major: bool = False) -> ctypes.Array: raise RuntimeError("Can not convert an empty array to ctype.") array = _reorder(self) if row_major else self - return wrapped.get_data_ptr(array.arr, array.size, array.dtype) + return everything.get_data_ptr(array.arr, array.size, array.dtype) def _reorder(array: Array) -> Array: @@ -902,7 +903,7 @@ def _reorder(array: Array) -> Array: return array out = Array() - out.arr = wrapped.reorder(array.arr, array.ndim) + out.arr = everything.reorder(array.arr, array.ndim) return out diff --git a/arrayfire/backend/__init__.py b/arrayfire/backend/__init__.py index 30253b4..cd01bcd 100644 --- a/arrayfire/backend/__init__.py +++ b/arrayfire/backend/__init__.py @@ -6,5 +6,5 @@ "le", "gt", "ge", "eq", "neq"] from .backend import ArrayBuffer -from .operators import ( +from .wrapped.operators import ( add, bitand, bitnot, bitor, bitshiftl, bitshiftr, bitxor, div, eq, ge, gt, le, lt, mod, mul, neq, pow, sub) diff --git a/arrayfire/backend/wrapped/__init__.py b/arrayfire/backend/wrapped/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/arrayfire/library/index.py b/arrayfire/backend/wrapped/_indexing.py similarity index 56% rename from arrayfire/library/index.py rename to arrayfire/backend/wrapped/_indexing.py index 55cc808..14fd288 100755 --- a/arrayfire/library/index.py +++ b/arrayfire/backend/wrapped/_indexing.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import ctypes import math -import numbers from typing import Union +from arrayfire.array.array_object import Array +from arrayfire.backend.wrapped import everything +from arrayfire.dtypes import bool as af_bool + -class IndexSequence(ctypes.Structure): +class _IndexSequence(ctypes.Structure): """ arrayfire equivalent of slice @@ -23,36 +28,38 @@ class IndexSequence(ctypes.Structure): Parameters ---------- - S: slice or number. + chunk: slice or number. """ - _fields_ = [("begin", ctypes.c_double), - ("end", ctypes.c_double), - ("step", ctypes.c_double)] + # More about _fields_ purpose: https://docs.python.org/3/library/ctypes.html#structures-and-unions + _fields_ = [ + ("begin", ctypes.c_double), + ("end", ctypes.c_double), + ("step", ctypes.c_double)] - def __init__(self, S: Union[numbers.Number, slice, None]): + def __init__(self, chunk: Union[int, slice]): self.begin = ctypes.c_double(0) self.end = ctypes.c_double(-1) self.step = ctypes.c_double(1) - if isinstance(slice, numbers.Number): - self.begin = ctypes.c_double(S) - self.end = ctypes.c_double(S) + if isinstance(chunk, int): + self.begin = ctypes.c_double(chunk) + self.end = ctypes.c_double(chunk) - elif isinstance(S, slice): - if S.step: - self.step = ctypes.c_double(S.step) - if S.step < 0: + elif isinstance(chunk, slice): + if chunk.step: + self.step = ctypes.c_double(chunk.step) + if chunk.step < 0: self.begin, self.end = self.end, self.begin - if S.start: - self.begin = ctypes.c_double(S.start) + if chunk.start: + self.begin = ctypes.c_double(chunk.start) - if S.stop: - self.end = ctypes.c_double(S.stop) + if chunk.stop: + self.end = ctypes.c_double(chunk.stop) # handle special cases - if self.begin >= 0 and self.end >= 0 and self.end <= self.begin and self.step >= 0: + if self.begin >= 0 and self.end >= 0 and 0 <= self.end <= self.begin and self.step >= 0: self.begin = 1 self.end = 1 self.step = 1 @@ -62,23 +69,23 @@ def __init__(self, S: Union[numbers.Number, slice, None]): self.end = -2 self.step = -1 - if S.stop: + if chunk.stop: self.end = self.end - math.copysign(1, self.step) else: raise IndexError("Invalid type while indexing arrayfire.array") -class ParallelRange(IndexSequence): +class ParallelRange(_IndexSequence): """ Class used to parallelize for loop. - Inherits from Seq. + Inherits from _IndexSequence. Attributes ---------- - S: slice + chunk: slice Parameters ---------- @@ -121,19 +128,19 @@ class ParallelRange(IndexSequence): 1.4719 0.5282 1.1657 """ - def __init__(self, start, stop=None, step=None): - if (stop is None): + def __init__(self, start, stop=None, step=None): + if not stop: stop = start start = 0 - self.S = slice(start, stop, step) - super(ParallelRange, self).__init__(self.S) + self.chunk = slice(start, stop, step) + super().__init__(self.chunk) - def __iter__(self): + def __iter__(self) -> ParallelRange: return self - def next(self): + def next(self) -> ParallelRange: """ Function called by the iterator in Python 2 """ @@ -144,21 +151,24 @@ def next(self): _bcast_var.toggle() return self - def __next__(self): + def __next__(self) -> ParallelRange: """ Function called by the iterator in Python 3 """ return self.next() -class IndexUnion(ctypes.Union): - _fields_ = [("arr", ctypes.c_void_p), - ("seq", IndexSequence)] + +class _IndexUnion(ctypes.Union): + _fields_ = [ + ("arr", ctypes.c_void_p), + ("seq", _IndexSequence)] -class IndexStructure(ctypes.Structure): - _fields_ = [("idx", IndexUnion), - ("isSeq", ctypes.c_bool), - ("isBatch", ctypes.c_bool)] +class _IndexStructure(ctypes.Structure): + _fields_ = [ + ("idx", _IndexUnion), + ("isSeq", ctypes.c_bool), + ("isBatch", ctypes.c_bool)] """ Container for the index class in arrayfire C library @@ -168,8 +178,8 @@ class IndexStructure(ctypes.Structure): idx.arr: ctypes.c_void_p - Default 0 - idx.seq: af.Seq - - Default af.Seq(0, -1, 1) + idx.seq: af._IndexSequence + - Default af._IndexSequence(0, -1, 1) isSeq : bool - Default True @@ -183,7 +193,7 @@ class IndexStructure(ctypes.Structure): idx: key - If of type af.Array, self.idx.arr = idx, self.isSeq = False - If of type af.ParallelRange, self.idx.seq = idx, self.isBatch = True - - Default:, self.idx.seq = af.Seq(idx) + - Default:, self.idx.seq = af._IndexSequence(idx) Note ---- @@ -193,48 +203,44 @@ class IndexStructure(ctypes.Structure): """ def __init__(self, idx): - self.idx = IndexUnion() + self.idx = _IndexUnion() self.isBatch = False self.isSeq = True - if isinstance(idx, ctypes.c_void_p): - - arr = ctypes.c_void_p(0) - - if (idx.type() == Dtype.b8.value): - safe_call(backend.get().af_where(c_pointer(arr), idx.arr)) + if isinstance(idx, Array): + if idx.dtype == af_bool: + self.idx.arr = everything.where(idx.arr) else: - safe_call(backend.get().af_retain_array(c_pointer(arr), idx.arr)) + self.idx.arr = everything.retain_array(idx.arr) - self.idx.arr = arr self.isSeq = False + elif isinstance(idx, ParallelRange): self.idx.seq = idx self.isBatch = True + else: - self.idx.seq = Seq(idx) + self.idx.seq = _IndexSequence(idx) def __del__(self) -> None: if not self.isSeq: # ctypes field variables are automatically # converted to basic C types so we have to # build the void_p from the value again. - arr = c_void_ptr_t(self.idx.arr) + arr = ctypes.c_void_p(self.idx.arr) backend.get().af_release_array(arr) -class _Index4(object): +class _CIndexStructure: def __init__(self) -> None: - index_vec = IndexStructure * 4 - _span = IndexStructure(slice(None)) - self.array = index_vec(_span, _span, _span, _span) - # Do not lose those idx as self.array keeps - # no reference to them. Otherwise the destructor + index_vec = _IndexStructure * 4 + # NOTE Do not lose those idx as self.array keeps no reference to them. Otherwise the destructor # is prematurely called - self.idxs = [_span, _span, _span, _span] + self.idxs = [_IndexStructure(slice(None))] * 4 + self.array = index_vec(*self.idxs) @property - def pointer(self): + def pointer(self) -> everything.AFArrayPointer: return ctypes.pointer(self.array) def __getitem__(self, idx): @@ -245,13 +251,13 @@ def __setitem__(self, idx, value): self.idxs[idx] = value -def _get_indices(key): - inds = _Index4() +def get_indices(key): + indices = _CIndexStructure() + if isinstance(key, tuple): - n_idx = len(key) - for n in range(n_idx): - inds[n] = Index(key[n]) + for n in range(len(key)): + indices[n] = _IndexStructure(key[n]) else: - inds[0] = Index(key) + indices[0] = _IndexStructure(key) - return inds + return indices diff --git a/arrayfire/backend/constant_array.py b/arrayfire/backend/wrapped/constant_array.py similarity index 94% rename from arrayfire/backend/constant_array.py rename to arrayfire/backend/wrapped/constant_array.py index 85ba3d8..f98f327 100644 --- a/arrayfire/backend/constant_array.py +++ b/arrayfire/backend/wrapped/constant_array.py @@ -1,9 +1,10 @@ import ctypes from typing import Tuple, Union -from ..dtypes import Dtype, int64, uint64 -from ..dtypes.helpers import CShape, implicit_dtype -from .backend import backend_api, safe_call +from arrayfire.dtypes import Dtype, int64, uint64 +from arrayfire.dtypes.helpers import CShape, implicit_dtype + +from ..backend import backend_api, safe_call AFArray = ctypes.c_void_p diff --git a/arrayfire/backend/wrapped.py b/arrayfire/backend/wrapped/everything.py similarity index 93% rename from arrayfire/backend/wrapped.py rename to arrayfire/backend/wrapped/everything.py index 5918a8f..4d93b9b 100644 --- a/arrayfire/backend/wrapped.py +++ b/arrayfire/backend/wrapped/everything.py @@ -3,10 +3,10 @@ # from arrayfire.array import _get_indices # HACK replace with refactored one -from ..library.device import PointerSource -from ..dtypes import CType, Dtype -from ..dtypes.helpers import CShape, c_dim_t, to_str -from .backend import ArrayBuffer, backend_api, safe_call +from ...library.device import PointerSource +from ...dtypes import CType, Dtype +from ...dtypes.helpers import CShape, c_dim_t, to_str +from ..backend import ArrayBuffer, backend_api, safe_call AFArrayPointer = ctypes._Pointer AFArray = ctypes.c_void_p @@ -236,3 +236,14 @@ def array_as_str(arr: AFArray) -> str: backend_api.af_free_host(arr_str) ) return py_str + + +def where(arr: AFArray) -> AFArray: + """ + source: https://arrayfire.org/docs/group__scan__func__where.htm#gafda59a3d25d35238592dd09907be9d07 + """ + out = ctypes.c_void_p(0) + safe_call( + backend_api.af_where(ctypes.pointer(out), arr) + ) + return out diff --git a/arrayfire/backend/operators.py b/arrayfire/backend/wrapped/operators.py similarity index 89% rename from arrayfire/backend/operators.py rename to arrayfire/backend/wrapped/operators.py index aa83a7d..663b5b6 100644 --- a/arrayfire/backend/operators.py +++ b/arrayfire/backend/wrapped/operators.py @@ -1,7 +1,7 @@ import ctypes -from typing import Callable, Union +from typing import Callable -from .backend import backend_api, safe_call +from ..backend import backend_api, safe_call AFArray = ctypes.c_void_p @@ -147,15 +147,3 @@ def _binary_op(c_func: Callable, lhs: AFArray, rhs: AFArray, /) -> AFArray: out = ctypes.c_void_p(0) safe_call(c_func(ctypes.pointer(out), lhs, rhs, _bcast_var)) return out - - -def count_all(x: AFArray) -> Union[int, float, complex]: - # TODO reconsider original arith.count - return _reduce_all(x, backend_api.af_count_all) - - -def _reduce_all(arr: AFArray, c_func: Callable) -> Union[int, float, complex]: - real = ctypes.c_double(0) - imag = ctypes.c_double(0) - safe_call(c_func(ctypes.pointer(real), ctypes.pointer(imag), arr)) - return real.value if imag.value == 0 else real.value + imag.value * 1j diff --git a/arrayfire/backend/wrapped/reduction_operations.py b/arrayfire/backend/wrapped/reduction_operations.py new file mode 100755 index 0000000..438bd11 --- /dev/null +++ b/arrayfire/backend/wrapped/reduction_operations.py @@ -0,0 +1,18 @@ +import ctypes +from typing import Callable, Union + +from ..backend import backend_api, safe_call + +AFArray = ctypes.c_void_p + + +def count_all(x: AFArray) -> Union[int, float, complex]: + # TODO reconsider original arith.count + return _reduce_all(x, backend_api.af_count_all) + + +def _reduce_all(arr: AFArray, c_func: Callable) -> Union[int, float, complex]: + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(c_func(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j diff --git a/arrayfire/library/broadcast.py b/arrayfire/library/broadcast.py index 8946670..cd05b2c 100755 --- a/arrayfire/library/broadcast.py +++ b/arrayfire/library/broadcast.py @@ -1,3 +1,4 @@ + class _bcast(object): _flag = False From e07879057dccaa4f19f577859fe2eb49cc0ee217 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 15 Jun 2023 15:28:23 +0300 Subject: [PATCH 4/6] Rename indexing file --- .../wrapped/{_indexing.py => indexing.py} | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) rename arrayfire/backend/wrapped/{_indexing.py => indexing.py} (89%) diff --git a/arrayfire/backend/wrapped/_indexing.py b/arrayfire/backend/wrapped/indexing.py similarity index 89% rename from arrayfire/backend/wrapped/_indexing.py rename to arrayfire/backend/wrapped/indexing.py index 14fd288..7c3bf57 100755 --- a/arrayfire/backend/wrapped/_indexing.py +++ b/arrayfire/backend/wrapped/indexing.py @@ -59,18 +59,18 @@ def __init__(self, chunk: Union[int, slice]): self.end = ctypes.c_double(chunk.stop) # handle special cases - if self.begin >= 0 and self.end >= 0 and 0 <= self.end <= self.begin and self.step >= 0: - self.begin = 1 - self.end = 1 - self.step = 1 + if 0 <= self.end.value <= self.begin.value and self.step.value >= 0: + self.begin.value = 1 + self.end.value = 1 + self.step.value = 1 - elif self.begin < 0 and self.end < 0 and self.end >= self.begin and self.step <= 0: - self.begin = -2 - self.end = -2 - self.step = -1 + elif 0 > self.end.value >= self.begin.value and self.step.value <= 0: + self.begin.value = -2 + self.end.value = -2 + self.step.value = -1 if chunk.stop: - self.end = self.end - math.copysign(1, self.step) + self.end.value = self.end.value - math.copysign(1, self.step.value) else: raise IndexError("Invalid type while indexing arrayfire.array") From 22bd8477600c537258047073c6ac4febd54d40ed Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 15 Jun 2023 19:59:27 +0300 Subject: [PATCH 5/6] Fix imports and typings --- arrayfire/array/array_object.py | 15 +++- arrayfire/backend/wrapped/constant_array.py | 17 +++-- arrayfire/backend/wrapped/constants.py | 4 ++ arrayfire/backend/wrapped/everything.py | 51 +++++++------- arrayfire/backend/wrapped/indexing.py | 68 ++++++++----------- arrayfire/backend/wrapped/operators.py | 48 +++++++------ .../backend/wrapped/reduction_operations.py | 7 +- arrayfire/library/broadcast.py | 32 +++++---- setup.cfg | 5 ++ tests/test_operators.py | 2 +- 10 files changed, 128 insertions(+), 121 deletions(-) create mode 100755 arrayfire/backend/wrapped/constants.py diff --git a/arrayfire/array/array_object.py b/arrayfire/array/array_object.py index a5bc519..600668d 100755 --- a/arrayfire/array/array_object.py +++ b/arrayfire/array/array_object.py @@ -9,6 +9,7 @@ from ..backend import ArrayBuffer from ..backend.wrapped import everything from ..backend.wrapped.constant_array import create_constant_array +from ..backend.wrapped.indexing import CIndexStructure, IndexStructure from ..backend.wrapped.reduction_operations import count_all from ..dtypes import CType from ..dtypes import bool as af_bool @@ -727,7 +728,7 @@ def __getitem__(self, key: Union[int, slice, Tuple[Union[int, slice, ], ...], Ar return out # HACK known issue - out.arr = everything.index_gen(self.arr, ndims, key) # type: ignore[arg-type] + out.arr = everything.index_gen(self.arr, ndims, key, _get_indices(key)) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -934,3 +935,15 @@ def _process_c_function(lhs: Union[int, float, Array], rhs: Union[int, float, Ar out.arr = c_function(lhs_array, rhs_array) return out + + +def _get_indices(key: Union[int, slice, Tuple[Union[int, slice, ], ...], Array]) -> CIndexStructure: + indices = CIndexStructure() + + if isinstance(key, tuple): + for n in range(len(key)): + indices[n] = IndexStructure(key[n]) + else: + indices[0] = IndexStructure(key) + + return indices diff --git a/arrayfire/backend/wrapped/constant_array.py b/arrayfire/backend/wrapped/constant_array.py index f98f327..596c95f 100644 --- a/arrayfire/backend/wrapped/constant_array.py +++ b/arrayfire/backend/wrapped/constant_array.py @@ -5,11 +5,10 @@ from arrayfire.dtypes.helpers import CShape, implicit_dtype from ..backend import backend_api, safe_call +from .constants import AFArrayType -AFArray = ctypes.c_void_p - -def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: +def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga5a083b1f3cd8a72a41f151de3bdea1a2 """ @@ -24,7 +23,7 @@ def _constant_complex(number: Union[int, float], shape: Tuple[int, ...], dtype: return out -def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: +def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga10f1c9fad1ce9e9fefd885d5a1d1fd49 """ @@ -33,12 +32,12 @@ def _constant_long(number: Union[int, float], shape: Tuple[int, ...], dtype: Dty safe_call( backend_api.af_constant_long( - ctypes.pointer(out), ctypes.c_longlong(number.real), 4, ctypes.pointer(c_shape.c_array)) + ctypes.pointer(out), ctypes.c_longlong(int(number.real)), 4, ctypes.pointer(c_shape.c_array)) ) return out -def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: +def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#ga67af670cc9314589f8134019f5e68809 """ @@ -48,12 +47,12 @@ def _constant_ulong(number: Union[int, float], shape: Tuple[int, ...], dtype: Dt safe_call( backend_api.af_constant_ulong( - ctypes.pointer(out), ctypes.c_ulonglong(number.real), 4, ctypes.pointer(c_shape.c_array)) + ctypes.pointer(out), ctypes.c_ulonglong(int(number.real)), 4, ctypes.pointer(c_shape.c_array)) ) return out -def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: +def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__data__func__constant.htm#gafc51b6a98765dd24cd4139f3bde00670 """ @@ -67,7 +66,7 @@ def _constant(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, / return out -def create_constant_array(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: +def create_constant_array(number: Union[int, float], shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: dtype = implicit_dtype(number, dtype) # NOTE complex is not supported in Data API diff --git a/arrayfire/backend/wrapped/constants.py b/arrayfire/backend/wrapped/constants.py new file mode 100755 index 0000000..cb8e944 --- /dev/null +++ b/arrayfire/backend/wrapped/constants.py @@ -0,0 +1,4 @@ +import ctypes + +AFArrayType = ctypes.c_void_p +AFArrayPointerType = ctypes._Pointer diff --git a/arrayfire/backend/wrapped/everything.py b/arrayfire/backend/wrapped/everything.py index 4d93b9b..10153f3 100644 --- a/arrayfire/backend/wrapped/everything.py +++ b/arrayfire/backend/wrapped/everything.py @@ -1,23 +1,16 @@ import ctypes -from typing import Tuple, Union, cast +from typing import Any, Tuple, Union, cast -# from arrayfire.array import _get_indices # HACK replace with refactored one - -from ...library.device import PointerSource from ...dtypes import CType, Dtype from ...dtypes.helpers import CShape, c_dim_t, to_str +from ...library.device import PointerSource from ..backend import ArrayBuffer, backend_api, safe_call - -AFArrayPointer = ctypes._Pointer -AFArray = ctypes.c_void_p - -# HACK, TODO replace for actual bcast_var after refactoring ~ https://github.com/arrayfire/arrayfire/pull/2871 -_bcast_var = False +from .constants import AFArrayType # Array management -def create_handle(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: +def create_handle(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga3b8f5cf6fce69aa1574544bc2d44d7d0 """ @@ -31,7 +24,7 @@ def create_handle(shape: Tuple[int, ...], dtype: Dtype, /) -> AFArray: return out -def retain_array(arr: AFArray) -> AFArray: +def retain_array(arr: AFArrayType) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga7ed45b3f881c0f6c80c5cf2af886dbab """ @@ -43,7 +36,7 @@ def retain_array(arr: AFArray) -> AFArray: return out -def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArray: +def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga834be32357616d8ab735087c6f681858 """ @@ -58,7 +51,7 @@ def create_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer return out -def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArray: +def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaad4fc77f872217e7337cb53bfb623cf5 """ @@ -75,7 +68,7 @@ def device_array(shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer def create_strided_array( shape: Tuple[int, ...], dtype: Dtype, array_buffer: ArrayBuffer, offset: CType, strides: Tuple[int, ...], - pointer_source: PointerSource, /) -> AFArray: + pointer_source: PointerSource, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__internal__func__create.htm#gad31241a3437b7b8bc3cf49f85e5c4e0c """ @@ -99,7 +92,7 @@ def create_strided_array( return out -def get_ctype(arr: AFArray) -> int: +def get_ctype(arr: AFArrayType) -> int: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga0dda6898e1c0d9a43efb56cd6a988c9b """ @@ -111,7 +104,7 @@ def get_ctype(arr: AFArray) -> int: return out.value -def get_elements(arr: AFArray) -> int: +def get_elements(arr: AFArrayType) -> int: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga6845bbe4385a60a606b88f8130252c1f """ @@ -123,7 +116,7 @@ def get_elements(arr: AFArray) -> int: return out.value -def get_numdims(arr: AFArray) -> int: +def get_numdims(arr: AFArrayType) -> int: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefa019d932ff58c2a829ce87edddd2a8 """ @@ -135,7 +128,7 @@ def get_numdims(arr: AFArray) -> int: return out.value -def get_dims(arr: AFArray) -> Tuple[int, ...]: +def get_dims(arr: AFArrayType) -> Tuple[int, ...]: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga8b90da50a532837d9763e301b2267348 """ @@ -150,7 +143,7 @@ def get_dims(arr: AFArray) -> Tuple[int, ...]: return (d0.value, d1.value, d2.value, d3.value) -def get_scalar(arr: AFArray, dtype: Dtype, /) -> Union[None, int, float, bool, complex]: +def get_scalar(arr: AFArrayType, dtype: Dtype, /) -> Union[None, int, float, bool, complex]: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#gaefe2e343a74a84bd43b588218ecc09a3 """ @@ -161,7 +154,7 @@ def get_scalar(arr: AFArray, dtype: Dtype, /) -> Union[None, int, float, bool, c return cast(Union[None, int, float, bool, complex], out.value) -def is_empty(arr: AFArray) -> bool: +def is_empty(arr: AFArrayType) -> bool: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga19c749e95314e1c77d816ad9952fb680 """ @@ -172,7 +165,7 @@ def is_empty(arr: AFArray) -> bool: return out.value -def get_data_ptr(arr: AFArray, size: int, dtype: Dtype, /) -> ctypes.Array: +def get_data_ptr(arr: AFArrayType, size: int, dtype: Dtype, /) -> ctypes.Array: """ source: https://arrayfire.org/docs/group__c__api__mat.htm#ga6040dc6f0eb127402fbf62c1165f0b9d """ @@ -187,18 +180,20 @@ def get_data_ptr(arr: AFArray, size: int, dtype: Dtype, /) -> ctypes.Array: # Arrayfire Functions -def index_gen(arr: AFArray, ndims: int, key: Union[int, slice, Tuple[Union[int, slice, ], ...]], /) -> AFArray: +def index_gen( + arr: AFArrayType, ndims: int, key: Union[int, slice, Tuple[Union[int, slice, ], ...]], + indices: Any, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__index__func__index.htm#ga14a7d149dba0ed0b977335a3df9d91e6 """ out = ctypes.c_void_p(0) safe_call( - backend_api.af_index_gen(ctypes.pointer(out), arr, c_dim_t(ndims), _get_indices(key).pointer) + backend_api.af_index_gen(ctypes.pointer(out), arr, c_dim_t(ndims), indices.pointer) ) return out -def transpose(arr: AFArray, conjugate: bool, /) -> AFArray: +def transpose(arr: AFArrayType, conjugate: bool, /) -> AFArrayType: """ https://arrayfire.org/docs/group__blas__func__transpose.htm#ga716b2b9bf190c8f8d0970aef2b57d8e7 """ @@ -209,7 +204,7 @@ def transpose(arr: AFArray, conjugate: bool, /) -> AFArray: return out -def reorder(arr: AFArray, ndims: int, /) -> AFArray: +def reorder(arr: AFArrayType, ndims: int, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__manip__func__reorder.htm#ga57383f4d00a3a86eab08dddd52c3ad3d """ @@ -221,7 +216,7 @@ def reorder(arr: AFArray, ndims: int, /) -> AFArray: return out -def array_as_str(arr: AFArray) -> str: +def array_as_str(arr: AFArrayType) -> str: """ source: - https://arrayfire.org/docs/group__print__func__tostring.htm#ga01f32ef2420b5d4592c6e4b4964b863b @@ -238,7 +233,7 @@ def array_as_str(arr: AFArray) -> str: return py_str -def where(arr: AFArray) -> AFArray: +def where(arr: AFArrayType) -> AFArrayType: """ source: https://arrayfire.org/docs/group__scan__func__where.htm#gafda59a3d25d35238592dd09907be9d07 """ diff --git a/arrayfire/backend/wrapped/indexing.py b/arrayfire/backend/wrapped/indexing.py index 7c3bf57..149c6df 100755 --- a/arrayfire/backend/wrapped/indexing.py +++ b/arrayfire/backend/wrapped/indexing.py @@ -2,11 +2,12 @@ import ctypes import math -from typing import Union +from typing import Any, Union -from arrayfire.array.array_object import Array -from arrayfire.backend.wrapped import everything -from arrayfire.dtypes import bool as af_bool +from arrayfire.library.broadcast import bcast_var + +from ..backend import backend_api, safe_call +from . import constants class _IndexSequence(ctypes.Structure): @@ -59,12 +60,12 @@ def __init__(self, chunk: Union[int, slice]): self.end = ctypes.c_double(chunk.stop) # handle special cases - if 0 <= self.end.value <= self.begin.value and self.step.value >= 0: + if 0 <= self.end <= self.begin and self.step >= 0: self.begin.value = 1 self.end.value = 1 self.step.value = 1 - elif 0 > self.end.value >= self.begin.value and self.step.value <= 0: + elif 0 > self.end >= self.begin and self.step <= 0: self.begin.value = -2 self.end.value = -2 self.step.value = -1 @@ -129,7 +130,9 @@ class ParallelRange(_IndexSequence): """ - def __init__(self, start, stop=None, step=None): + def __init__( + self, start: Union[int, float], stop: Union[int, float, None] = None, + step: Union[int, float, None] = None) -> None: if not stop: stop = start start = 0 @@ -144,11 +147,11 @@ def next(self) -> ParallelRange: """ Function called by the iterator in Python 2 """ - if _bcast_var.get() is True: - _bcast_var.toggle() + if bcast_var.get() is True: + bcast_var.toggle() raise StopIteration else: - _bcast_var.toggle() + bcast_var.toggle() return self def __next__(self) -> ParallelRange: @@ -164,7 +167,7 @@ class _IndexUnion(ctypes.Union): ("seq", _IndexSequence)] -class _IndexStructure(ctypes.Structure): +class IndexStructure(ctypes.Structure): _fields_ = [ ("idx", _IndexUnion), ("isSeq", ctypes.c_bool), @@ -202,20 +205,21 @@ class _IndexStructure(ctypes.Structure): """ - def __init__(self, idx): + def __init__(self, idx: Any) -> None: self.idx = _IndexUnion() self.isBatch = False self.isSeq = True - if isinstance(idx, Array): - if idx.dtype == af_bool: - self.idx.arr = everything.where(idx.arr) - else: - self.idx.arr = everything.retain_array(idx.arr) + # FIXME cyclic reimport + # if isinstance(idx, Array): + # if idx.dtype == af_bool: + # self.idx.arr = everything.where(idx.arr) + # else: + # self.idx.arr = everything.retain_array(idx.arr) - self.isSeq = False + # self.isSeq = False - elif isinstance(idx, ParallelRange): + if isinstance(idx, ParallelRange): self.idx.seq = idx self.isBatch = True @@ -228,36 +232,24 @@ def __del__(self) -> None: # converted to basic C types so we have to # build the void_p from the value again. arr = ctypes.c_void_p(self.idx.arr) - backend.get().af_release_array(arr) + safe_call(backend_api.af_release_array(arr)) -class _CIndexStructure: +class CIndexStructure: def __init__(self) -> None: - index_vec = _IndexStructure * 4 + index_vec = IndexStructure * 4 # NOTE Do not lose those idx as self.array keeps no reference to them. Otherwise the destructor # is prematurely called - self.idxs = [_IndexStructure(slice(None))] * 4 + self.idxs = [IndexStructure(slice(None))] * 4 self.array = index_vec(*self.idxs) @property - def pointer(self) -> everything.AFArrayPointer: + def pointer(self) -> constants.AFArrayPointerType: return ctypes.pointer(self.array) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> IndexStructure: return self.array[idx] - def __setitem__(self, idx, value): + def __setitem__(self, idx: int, value: IndexStructure) -> None: self.array[idx] = value self.idxs[idx] = value - - -def get_indices(key): - indices = _CIndexStructure() - - if isinstance(key, tuple): - for n in range(len(key)): - indices[n] = _IndexStructure(key[n]) - else: - indices[0] = _IndexStructure(key) - - return indices diff --git a/arrayfire/backend/wrapped/operators.py b/arrayfire/backend/wrapped/operators.py index 663b5b6..1e1f5a8 100644 --- a/arrayfire/backend/wrapped/operators.py +++ b/arrayfire/backend/wrapped/operators.py @@ -1,52 +1,50 @@ import ctypes from typing import Callable -from ..backend import backend_api, safe_call - -AFArray = ctypes.c_void_p +from arrayfire.library.broadcast import bcast_var -# HACK, TODO replace for actual bcast_var after refactoring ~ https://github.com/arrayfire/arrayfire/pull/2871 -_bcast_var = False +from ..backend import backend_api, safe_call +from .constants import AFArrayType # Arithmetic Operators -def add(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def add(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__add.htm#ga1dfbee755fedd680f4476803ddfe06a7 """ return _binary_op(backend_api.af_add, lhs, rhs) -def sub(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def sub(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__sub.htm#ga80ff99a2e186c23614ea9f36ffc6f0a4 """ return _binary_op(backend_api.af_sub, lhs, rhs) -def mul(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def mul(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__mul.htm#ga5f7588b2809ff7551d38b6a0bd583a02 """ return _binary_op(backend_api.af_mul, lhs, rhs) -def div(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def div(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__div.htm#ga21f3f97755702692ec8976934e75fde6 """ return _binary_op(backend_api.af_div, lhs, rhs) -def mod(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def mod(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__mod.htm#ga01924d1b59d8886e46fabd2dc9b27e0f """ return _binary_op(backend_api.af_mod, lhs, rhs) -def pow(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def pow(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__pow.htm#ga0f28be1a9c8b176a78c4a47f483e7fc6 """ @@ -55,7 +53,7 @@ def pow(lhs: AFArray, rhs: AFArray, /) -> AFArray: # Bitwise Operators -def bitnot(arr: AFArray, /) -> AFArray: +def bitnot(arr: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitnot.htm#gaf97e8a38aab59ed2d3a742515467d01e """ @@ -66,84 +64,84 @@ def bitnot(arr: AFArray, /) -> AFArray: return out -def bitand(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def bitand(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitand.htm#ga45c0779ade4703708596df11cca98800 """ return _binary_op(backend_api.af_bitand, lhs, rhs) -def bitor(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def bitor(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitor.htm#ga84c99f77d1d83fd53f949b4d67b5b210 """ return _binary_op(backend_api.af_bitor, lhs, rhs) -def bitxor(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def bitxor(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__bitxor.htm#ga8188620da6b432998e55fdd1fad22100 """ return _binary_op(backend_api.af_bitxor, lhs, rhs) -def bitshiftl(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def bitshiftl(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__shiftl.htm#ga3139645aafe6f045a5cab454e9c13137 """ return _binary_op(backend_api.af_butshiftl, lhs, rhs) -def bitshiftr(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def bitshiftr(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__shiftr.htm#ga4c06b9977ecf96cdfc83b5dfd1ac4895 """ return _binary_op(backend_api.af_bitshiftr, lhs, rhs) -def lt(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def lt(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/arith_8h.htm#ae7aa04bf23b32bb11c4bab8bdd637103 """ return _binary_op(backend_api.af_lt, lhs, rhs) -def le(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def le(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__le.htm#gad5535ce64dbed46d0773fd494e84e922 """ return _binary_op(backend_api.af_le, lhs, rhs) -def gt(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def gt(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__gt.htm#ga4e65603259515de8939899a163ebaf9e """ return _binary_op(backend_api.af_gt, lhs, rhs) -def ge(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def ge(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__ge.htm#ga4513f212e0b0a22dcf4653e89c85e3d9 """ return _binary_op(backend_api.af_ge, lhs, rhs) -def eq(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def eq(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__eq.htm#ga76d2da7716831616bb81effa9e163693 """ return _binary_op(backend_api.af_eq, lhs, rhs) -def neq(lhs: AFArray, rhs: AFArray, /) -> AFArray: +def neq(lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: """ source: https://arrayfire.org/docs/group__arith__func__neq.htm#gae4ee8bd06a410f259f1493fb811ce441 """ return _binary_op(backend_api.af_neq, lhs, rhs) -def _binary_op(c_func: Callable, lhs: AFArray, rhs: AFArray, /) -> AFArray: +def _binary_op(c_func: Callable, lhs: AFArrayType, rhs: AFArrayType, /) -> AFArrayType: out = ctypes.c_void_p(0) - safe_call(c_func(ctypes.pointer(out), lhs, rhs, _bcast_var)) + safe_call(c_func(ctypes.pointer(out), lhs, rhs, bcast_var.get())) return out diff --git a/arrayfire/backend/wrapped/reduction_operations.py b/arrayfire/backend/wrapped/reduction_operations.py index 438bd11..f56b9a6 100755 --- a/arrayfire/backend/wrapped/reduction_operations.py +++ b/arrayfire/backend/wrapped/reduction_operations.py @@ -2,16 +2,15 @@ from typing import Callable, Union from ..backend import backend_api, safe_call +from .constants import AFArrayType -AFArray = ctypes.c_void_p - -def count_all(x: AFArray) -> Union[int, float, complex]: +def count_all(x: AFArrayType) -> Union[int, float, complex]: # TODO reconsider original arith.count return _reduce_all(x, backend_api.af_count_all) -def _reduce_all(arr: AFArray, c_func: Callable) -> Union[int, float, complex]: +def _reduce_all(arr: AFArrayType, c_func: Callable) -> Union[int, float, complex]: real = ctypes.c_double(0) imag = ctypes.c_double(0) safe_call(c_func(ctypes.pointer(real), ctypes.pointer(imag), arr)) diff --git a/arrayfire/library/broadcast.py b/arrayfire/library/broadcast.py index cd05b2c..aacdbcd 100755 --- a/arrayfire/library/broadcast.py +++ b/arrayfire/library/broadcast.py @@ -1,21 +1,24 @@ +from typing import Any, Callable -class _bcast(object): - _flag = False - def get(self): - return _bcast._flag +class Bcast: + def __init__(self) -> None: + self._flag: bool = False - def set(self, flag): - _bcast._flag = flag + def get(self) -> bool: + return self._flag - def toggle(self): - _bcast._flag ^= True + def set(self, flag: bool) -> None: + self._flag = flag + def toggle(self) -> None: + self._flag ^= True -_bcast_var = _bcast() +bcast_var: Bcast = Bcast() -def broadcast(func, *args): + +def broadcast(func: Callable[..., Any], *args: Any) -> Any: """ Function to perform broadcast operations. @@ -74,14 +77,13 @@ def broadcast(func, *args): 1.5328 0.8898 0.7185 """ - - def wrapper(*func_args): - _bcast_var.toggle() + def wrapper(*func_args: Any) -> Any: + bcast_var.toggle() res = func(*func_args) - _bcast_var.toggle() + bcast_var.toggle() return res if len(args) == 0: - return wrapper + return wrapper() else: return wrapper(*args) diff --git a/setup.cfg b/setup.cfg index bd675bf..a19c97b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,11 @@ test = line_length = 119 multi_line_output = 4 +[tool:pytest] +addopts = --cache-clear --cov=arrayfire --flake8 --isort -s +console_output_style = classic +markers = mypy + [flake8] exclude = venv application-import-names = arrayfire diff --git a/tests/test_operators.py b/tests/test_operators.py index 04c4741..252eaa4 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -1,7 +1,7 @@ from typing import Any -from arrayfire.library import operators from arrayfire.array.array_object import Array +from arrayfire.library import operators class TestArithmeticOperators: From 75370b2e053f7ff99e1047cc86f5ffd185b1c2b8 Mon Sep 17 00:00:00 2001 From: Anton Chernyatevich Date: Thu, 15 Jun 2023 20:02:26 +0300 Subject: [PATCH 6/6] Add comment for placed hack --- arrayfire/backend/backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arrayfire/backend/backend.py b/arrayfire/backend/backend.py index c7d57b8..3e6e5c1 100644 --- a/arrayfire/backend/backend.py +++ b/arrayfire/backend/backend.py @@ -4,7 +4,9 @@ from ..dtypes.helpers import c_dim_t, to_str -# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") # Mock +# HACK for osx +# backend_api = ctypes.CDLL("/opt/arrayfire//lib/libafcpu.3.dylib") +# HACK for windows backend_api = ctypes.CDLL("C:/Program Files/ArrayFire/v3/lib/afcpu.dll")