From 7489387f900a709a6ebb27848833804087a2d6d8 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 22 Sep 2023 18:07:07 +0300 Subject: [PATCH 1/2] Add inclusive scan operations. Move constants separetely. Fix minor issues --- arrayfire_wrapper/defines.py | 18 -------- arrayfire_wrapper/lib/_constants.py | 37 ++++++++++++++++ arrayfire_wrapper/lib/_error_handler.py | 8 +--- arrayfire_wrapper/lib/array_layout.py | 3 +- .../lib/image_processing/image_moments.py | 3 +- arrayfire_wrapper/lib/statistics.py | 14 +------ .../inclusive_scan_operations.py | 42 +++++++++++++++++++ 7 files changed, 86 insertions(+), 39 deletions(-) create mode 100644 arrayfire_wrapper/lib/_constants.py diff --git a/arrayfire_wrapper/defines.py b/arrayfire_wrapper/defines.py index e293eda..a075159 100644 --- a/arrayfire_wrapper/defines.py +++ b/arrayfire_wrapper/defines.py @@ -3,7 +3,6 @@ import ctypes import platform from dataclasses import dataclass -from enum import Enum from typing import Type @@ -48,20 +47,3 @@ def __repr__(self) -> str: def c_array(self): # type: ignore[no-untyped-def] c_shape = CDimT * 4 # ctypes.c_int | ctypes.c_longlong * 4 return c_shape(CDimT(self.x1), CDimT(self.x2), CDimT(self.x3), CDimT(self.x4)) - - -class Moment(Enum): - M00 = 1 - M01 = 2 - M10 = 4 - M11 = 8 - FIRST_ORDER = M00 | M01 | M10 | M11 - - -class PointerSource(Enum): - """ - Source of the pointer. - """ - - device = 0 # gpu - host = 1 # cpu diff --git a/arrayfire_wrapper/lib/_constants.py b/arrayfire_wrapper/lib/_constants.py new file mode 100644 index 0000000..96d69cb --- /dev/null +++ b/arrayfire_wrapper/lib/_constants.py @@ -0,0 +1,37 @@ +from enum import Enum + + +class BinaryOperator(Enum): + ADD = 0 + MUL = 1 + MIN = 2 + MAX = 3 + + +class ErrorCodes(Enum): + none = 0 + + +class Moment(Enum): + M00 = 1 + M01 = 2 + M10 = 4 + M11 = 8 + FIRST_ORDER = M00 | M01 | M10 | M11 + + +class PointerSource(Enum): + device = 0 # gpu + host = 1 # cpu + + +class TopK(Enum): + DEFAULT = 0 + MIN = 1 + MAX = 2 + + +class VarianceBias(Enum): + DEFAULT = 0 + SAMPLE = 1 + POPULATION = 2 diff --git a/arrayfire_wrapper/lib/_error_handler.py b/arrayfire_wrapper/lib/_error_handler.py index 295a457..290a903 100644 --- a/arrayfire_wrapper/lib/_error_handler.py +++ b/arrayfire_wrapper/lib/_error_handler.py @@ -1,17 +1,13 @@ import ctypes -from enum import Enum from arrayfire_wrapper._backend import _backend from arrayfire_wrapper.defines import CDimT from arrayfire_wrapper.dtypes import to_str - - -class _ErrorCodes(Enum): - none = 0 +from arrayfire_wrapper.lib._constants import ErrorCodes def safe_call(c_err: int) -> None: - if c_err == _ErrorCodes.none.value: + if c_err == ErrorCodes.none.value: return err_str = ctypes.c_char_p(0) diff --git a/arrayfire_wrapper/lib/array_layout.py b/arrayfire_wrapper/lib/array_layout.py index 0bd2d40..f20f53c 100644 --- a/arrayfire_wrapper/lib/array_layout.py +++ b/arrayfire_wrapper/lib/array_layout.py @@ -1,8 +1,9 @@ import ctypes from arrayfire_wrapper._backend import _backend -from arrayfire_wrapper.defines import AFArray, ArrayBuffer, CDimT, CShape, CType, PointerSource +from arrayfire_wrapper.defines import AFArray, ArrayBuffer, CDimT, CShape, CType from arrayfire_wrapper.dtypes import Dtype +from arrayfire_wrapper.lib._constants import PointerSource from arrayfire_wrapper.lib._error_handler import safe_call diff --git a/arrayfire_wrapper/lib/image_processing/image_moments.py b/arrayfire_wrapper/lib/image_processing/image_moments.py index e31f531..9b4bf74 100644 --- a/arrayfire_wrapper/lib/image_processing/image_moments.py +++ b/arrayfire_wrapper/lib/image_processing/image_moments.py @@ -1,7 +1,8 @@ import ctypes from arrayfire_wrapper._backend import _backend -from arrayfire_wrapper.defines import AFArray, Moment +from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._constants import Moment from arrayfire_wrapper.lib._error_handler import safe_call diff --git a/arrayfire_wrapper/lib/statistics.py b/arrayfire_wrapper/lib/statistics.py index 8b59439..8db6958 100644 --- a/arrayfire_wrapper/lib/statistics.py +++ b/arrayfire_wrapper/lib/statistics.py @@ -1,23 +1,11 @@ import ctypes -from enum import Enum from arrayfire_wrapper._backend import _backend from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._constants import TopK, VarianceBias from arrayfire_wrapper.lib._error_handler import safe_call -class VarianceBias(Enum): - DEFAULT = 0 - SAMPLE = 1 - POPULATION = 2 - - -class TopK(Enum): - DEFAULT = 0 - MIN = 1 - MAX = 2 - - def corrcoef(x: AFArray, y: AFArray, /) -> complex: """ source: https://arrayfire.org/docs/group__stat__func__corrcoef.htm#ga26b894c86731234136bfe1342453d8a7 diff --git a/arrayfire_wrapper/lib/vector_algorithms/inclusive_scan_operations.py b/arrayfire_wrapper/lib/vector_algorithms/inclusive_scan_operations.py index e69de29..4a887e6 100644 --- a/arrayfire_wrapper/lib/vector_algorithms/inclusive_scan_operations.py +++ b/arrayfire_wrapper/lib/vector_algorithms/inclusive_scan_operations.py @@ -0,0 +1,42 @@ +import ctypes + +from arrayfire_wrapper._backend import _backend +from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._constants import BinaryOperator +from arrayfire_wrapper.lib._error_handler import safe_call + + +def accum(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__scan__func__accum.htm#ga50d499e844e0b63e338cb3ea50439629 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_accum(ctypes.pointer(out), arr, ctypes.c_int(dim))) + return out + + +def scan(arr: AFArray, dim: int, op: BinaryOperator, inclusive_scan: bool, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__scan__func__scan.htm#ga1c864e22826f61bec2e9b6c61aa93fce + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_scan(ctypes.pointer(out), arr, dim, op.value, inclusive_scan)) + return out + + +def scan_by_key(key: AFArray, arr: AFArray, dim: int, op: BinaryOperator, inclusive_scan: bool, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__scan__func__scanbykey.htm#gaaae150e0f197782782f45340d137b027 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_scan(ctypes.pointer(out), key, arr, dim, op.value, inclusive_scan)) + return out + + +def where(arr: AFArray, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__scan__func__where.htm#gafda59a3d25d35238592dd09907be9d07 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_where(ctypes.pointer(out), arr)) + return out From 1c9e2b14bb7800cb9cecc809249a7adba376a627 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 22 Sep 2023 20:18:16 +0300 Subject: [PATCH 2/2] Add vector operators. Add importing --- README.md | 2 +- arrayfire_wrapper/lib/__init__.py | 105 ++++++- arrayfire_wrapper/lib/_constants.py | 52 +++- arrayfire_wrapper/lib/_error_handler.py | 2 +- arrayfire_wrapper/lib/_utility.py | 4 +- .../lib/image_processing/filters.py | 17 + .../numerical_differentiation.py | 33 ++ .../vector_algorithms/reduction_operations.py | 294 ++++++++++++++++++ .../lib/vector_algorithms/set_operations.py | 32 ++ .../lib/vector_algorithms/sort_operations.py | 47 +++ 10 files changed, 574 insertions(+), 14 deletions(-) create mode 100644 arrayfire_wrapper/lib/image_processing/filters.py create mode 100644 arrayfire_wrapper/lib/vector_algorithms/numerical_differentiation.py create mode 100644 arrayfire_wrapper/lib/vector_algorithms/reduction_operations.py create mode 100644 arrayfire_wrapper/lib/vector_algorithms/set_operations.py create mode 100644 arrayfire_wrapper/lib/vector_algorithms/sort_operations.py diff --git a/README.md b/README.md index c5f396b..c009916 100644 --- a/README.md +++ b/README.md @@ -18,4 +18,4 @@ Arrayfire python C library wrapper - [ ] Signal Processing - [x] Statistics - [ ] Unified API Functions -- [ ] Vector Algorithms +- [x] Vector Algorithms diff --git a/arrayfire_wrapper/lib/__init__.py b/arrayfire_wrapper/lib/__init__.py index cc86a85..355d136 100644 --- a/arrayfire_wrapper/lib/__init__.py +++ b/arrayfire_wrapper/lib/__init__.py @@ -388,6 +388,101 @@ from .mathematical_functions.trigonometric_functions import acos, asin, atan, atan2, cos, sin, tan +# Vector Algorithms + +__all__ += [ + "accum", + "scan", + "scan_by_key", + "where", +] + +from .vector_algorithms.inclusive_scan_operations import accum, scan, scan_by_key, where + +__all__ += [ + "diff1", + "diff2", + "gradient", +] + +from .vector_algorithms.numerical_differentiation import diff1, diff2, gradient + +__all__ += [ + "all_true", + "all_true_all", + "all_true_by_key", + "any_true", + "any_true_all", + "any_true_by_key", + "count", + "count_all", + "count_by_key", + "imax", + "imax_all", + "imin", + "imin_all", + "max", + "max_all", + "max_by_key", + "max_ragged", + "min", + "min_all", + "product", + "product_all", + "product_nan", + "product_nan_all", + "sum", + "sum_all", + "sum_nan", + "sum_nan_all", +] + +from .vector_algorithms.reduction_operations import ( + all_true, + all_true_all, + all_true_by_key, + any_true, + any_true_all, + any_true_by_key, + count, + count_all, + count_by_key, + imax, + imax_all, + imin, + imin_all, + max, + max_all, + max_by_key, + max_ragged, + min, + min_all, + product, + product_all, + product_nan, + product_nan_all, + sum, + sum_all, + sum_nan, + sum_nan_all, +) + +__all__ += [ + "set_intersect", + "set_union", + "set_unique", +] + +from .vector_algorithms.set_operations import set_intersect, set_union, set_unique + +__all__ += [ + "sort", + "sort_by_key", + "sort_index", +] + +from .vector_algorithms.sort_operations import sort, sort_by_key, sort_index + # Functions to Work with Internal Array Layout __all__ += [ @@ -403,8 +498,6 @@ # Statistics __all__ += [ - "TopK", - "VarianceBias", "corrcoef", "cov", "mean", @@ -423,8 +516,6 @@ ] from .statistics import ( - TopK, - VarianceBias, corrcoef, cov, mean, @@ -441,3 +532,9 @@ var_all_weighted, var_weighted, ) + +# Constants + +__all__ += ["BinaryOperator", "Moment", "Pad", "PointerSource", "TopK", "VarianceBias"] + +from ._constants import BinaryOperator, Moment, Pad, PointerSource, TopK, VarianceBias diff --git a/arrayfire_wrapper/lib/_constants.py b/arrayfire_wrapper/lib/_constants.py index 96d69cb..efce5f7 100644 --- a/arrayfire_wrapper/lib/_constants.py +++ b/arrayfire_wrapper/lib/_constants.py @@ -1,18 +1,51 @@ from enum import Enum -class BinaryOperator(Enum): +class BinaryOperator(Enum): # Binary Operators ADD = 0 MUL = 1 MIN = 2 MAX = 3 -class ErrorCodes(Enum): - none = 0 +class ErrorCodes(Enum): # Error Values + NONE = 0 + # 100-199 Errors in environment + NO_MEM = 101 + DRIVER = 102 + RUNTIME = 103 -class Moment(Enum): + # 200-299 Errors in input parameters + INVALID_ARRAY = 201 + ARG = 202 + SIZE = 203 + TYPE = 204 + DIFF_TYPE = 205 + BATCH = 207 + DEVICE = 208 + + # 300-399 Errors for missing software features + NOT_SUPPORTED = 301 + NOT_CONFIGURED = 302 + NONFREE = 303 + + # 400-499 Errors for missing hardware features + NO_DBL = 401 + NO_GFX = 402 + NO_HALF = 403 + + # 500-599 Errors specific to the heterogeneous API + LOAD_LIB = 501 + LOAD_SYM = 502 + ARR_BKND_MISMATCH = 503 + + # 900-999 Errors from upstream libraries and runtimes + INTERNAL = 998 + UNKNOWN = 999 + + +class Moment(Enum): # Image moments types M00 = 1 M01 = 2 M10 = 4 @@ -20,18 +53,25 @@ class Moment(Enum): FIRST_ORDER = M00 | M01 | M10 | M11 +class Pad(Enum): # Edge padding types + ZERO = 0 + SYM = 1 + CLAMP_TO_EDGE = 2 + PERIODIC = 3 + + class PointerSource(Enum): device = 0 # gpu host = 1 # cpu -class TopK(Enum): +class TopK(Enum): # Top-K ordering DEFAULT = 0 MIN = 1 MAX = 2 -class VarianceBias(Enum): +class VarianceBias(Enum): # Variance Bias types DEFAULT = 0 SAMPLE = 1 POPULATION = 2 diff --git a/arrayfire_wrapper/lib/_error_handler.py b/arrayfire_wrapper/lib/_error_handler.py index 290a903..f1922d7 100644 --- a/arrayfire_wrapper/lib/_error_handler.py +++ b/arrayfire_wrapper/lib/_error_handler.py @@ -7,7 +7,7 @@ def safe_call(c_err: int) -> None: - if c_err == ErrorCodes.none.value: + if c_err == ErrorCodes.NONE.value: return err_str = ctypes.c_char_p(0) diff --git a/arrayfire_wrapper/lib/_utility.py b/arrayfire_wrapper/lib/_utility.py index 63a5041..e52dc2e 100644 --- a/arrayfire_wrapper/lib/_utility.py +++ b/arrayfire_wrapper/lib/_utility.py @@ -9,12 +9,12 @@ def binary_op(c_func: Callable, lhs: AFArray, rhs: AFArray, /) -> AFArray: - out = AFArray(0) + out = AFArray.create_null_pointer() safe_call(c_func(ctypes.pointer(out), lhs, rhs, bcast_var.get())) return out def unary_op(c_func: Callable, arr: AFArray, /) -> AFArray: - out = AFArray(0) + out = AFArray.create_null_pointer() safe_call(c_func(ctypes.pointer(out), arr)) return out diff --git a/arrayfire_wrapper/lib/image_processing/filters.py b/arrayfire_wrapper/lib/image_processing/filters.py new file mode 100644 index 0000000..c648690 --- /dev/null +++ b/arrayfire_wrapper/lib/image_processing/filters.py @@ -0,0 +1,17 @@ +import ctypes + +from arrayfire_wrapper._backend import _backend +from arrayfire_wrapper.defines import AFArray, CDimT +from arrayfire_wrapper.lib._constants import Pad +from arrayfire_wrapper.lib._error_handler import safe_call + + +def maxfilt(arr: AFArray, wind_lenght: int, wind_width: int, edge_pad: Pad) -> AFArray: + """ + source: https://arrayfire.org/docs/group__image__func__maxfilt.htm#ga97e07bf5f5c58752d23d1772586b71f4 + """ + out = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_maxfilt(ctypes.pointer(out), arr, CDimT(wind_lenght), CDimT(wind_width), edge_pad.value) + ) + return out diff --git a/arrayfire_wrapper/lib/vector_algorithms/numerical_differentiation.py b/arrayfire_wrapper/lib/vector_algorithms/numerical_differentiation.py new file mode 100644 index 0000000..7346b77 --- /dev/null +++ b/arrayfire_wrapper/lib/vector_algorithms/numerical_differentiation.py @@ -0,0 +1,33 @@ +import ctypes + +from arrayfire_wrapper._backend import _backend +from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._error_handler import safe_call + + +def diff1(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__calc__func__diff1.htm#gad3be33ce8114f65c188645e958fce171 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_diff1(ctypes.pointer(out), arr, ctypes.c_int(dim))) + return out + + +def diff2(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__calc__func__diff2.htm#gafc7b2d05e4e85aeb3e8b3239f598f70c + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_diff2(ctypes.pointer(out), arr, ctypes.c_int(dim))) + return out + + +def gradient(arr: AFArray, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__calc__func__grad.htm#gadb342e6765c1536125261b035f7eee59 + """ + out_dx = AFArray.create_null_pointer() + out_dy = AFArray.create_null_pointer() + safe_call(_backend.clib.af_gradient(ctypes.pointer(out_dx), ctypes.pointer(out_dy), arr)) + return (out_dx, out_dy) diff --git a/arrayfire_wrapper/lib/vector_algorithms/reduction_operations.py b/arrayfire_wrapper/lib/vector_algorithms/reduction_operations.py new file mode 100644 index 0000000..114a3e0 --- /dev/null +++ b/arrayfire_wrapper/lib/vector_algorithms/reduction_operations.py @@ -0,0 +1,294 @@ +import ctypes + +from arrayfire_wrapper._backend import _backend +from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._error_handler import safe_call + + +def all_true(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_all_true(ctypes.pointer(out), arr, dim)) + return out + + +def all_true_all(arr: AFArray, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__all__true.htm#ga068708be5177a0aa3788af140bb5ebd6 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def all_true_by_key(keys: AFArray, values: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/algorithm_8h.htm#a65fa5577c81a2c2fcf7406bf48cc014a + """ + out_keys = AFArray.create_null_pointer() + out_values = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_all_true_by_key( + ctypes.pointer(out_keys), ctypes.pointer(out_values), keys, values, ctypes.c_int(dim) + ) + ) + return (out_keys, out_values) + + +def any_true(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__any__true.htm#ga7c275cda2cfc8eb0bd20ea86472ca0d5 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_all_true(ctypes.pointer(out), arr, dim)) + return out + + +def any_true_all(arr: AFArray, /) -> int | float | bool | complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__any__true.htm#ga47d991276bb5bf8cdba8340e8751e536 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_all_true(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def any_true_by_key(keys: AFArray, values: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__reduce__func__anytrue__by__key.htm#ga973fd650f8a57533f675cfd7ad6f0718 + """ + out_keys = AFArray.create_null_pointer() + out_values = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_any_true_by_key( + ctypes.pointer(out_keys), ctypes.pointer(out_values), keys, values, ctypes.c_int(dim) + ) + ) + return (out_keys, out_values) + + +def count(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__count.htm#gaf2664c25ee6ca30aa3f5aa77db789f95 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_count(ctypes.pointer(out), arr, dim)) + return out + + +def count_all(arr: AFArray, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__count.htm#ga38699c5ce172c15e9850a9eda6050da5 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_count_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def count_by_key(keys: AFArray, values: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__reduce__func__count__by__key.htm#ga96b01fd7375b3a3cb065ba860885e723 + """ + out_keys = AFArray.create_null_pointer() + out_values = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_count_by_key( + ctypes.pointer(out_keys), ctypes.pointer(out_values), keys, values, ctypes.c_int(dim) + ) + ) + return (out_keys, out_values) + + +def imax(arr: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__reduce__func__max.htm#gaf0e6a523e2e435d5409d5d8cb843d8a2 + """ + out = AFArray.create_null_pointer() + out_idx = AFArray.create_null_pointer() + safe_call(_backend.clib.af_imax(ctypes.pointer(out), ctypes.pointer(out_idx), arr, ctypes.c_int(dim))) + return (out, out_idx) + + +def imax_all(arr: AFArray, /) -> tuple[complex, int]: + """ + source: https://arrayfire.org/docs/group__reduce__func__max.htm#gaea009bd51145be2fcc688b2390725401 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + out_idx = ctypes.c_uint(0) + safe_call(_backend.clib.af_imax_all(ctypes.pointer(real), ctypes.pointer(imag), ctypes.pointer(out_idx), arr)) + complex_value = real.value if imag.value == 0 else real.value + imag.value * 1j + return (complex_value, out_idx.value) + + +def max(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__max.htm#ga267f32b8dbb1b508e8738e3748d8dc3f + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_max(ctypes.pointer(out), arr, dim)) + return out + + +def max_all(arr: AFArray, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__max.htm#ga5f71ab6056943723149585d2aebade7c + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_max_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def max_ragged(arr: AFArray, ragged_len: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__reduce__func__max.htm#ga564bbeca8e4c243355979a6cb5dc4970 + """ + out_values = AFArray.create_null_pointer() + out_idx = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_max_ragged( + ctypes.pointer(out_values), ctypes.pointer(out_idx), arr, ragged_len, ctypes.c_int(dim) + ) + ) + return (out_values, out_idx) + + +def max_by_key(keys: AFArray, values: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__reduce__func__max__by__key.htm#ga002d03c0ebd674644c8a6831ebb775e2 + """ + out_keys = AFArray.create_null_pointer() + out_values = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_max_by_key( + ctypes.pointer(out_keys), ctypes.pointer(out_values), keys, values, ctypes.c_int(dim) + ) + ) + return (out_keys, out_values) + + +def imin(arr: AFArray, dim: int, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__reduce__func__min.htm#ga2f65943090e0c2317bd682c25594b901 + """ + out = AFArray.create_null_pointer() + out_idx = AFArray.create_null_pointer() + safe_call(_backend.clib.af_imin(ctypes.pointer(out), ctypes.pointer(out_idx), arr, ctypes.c_int(dim))) + return (out, out_idx) + + +def imin_all(arr: AFArray, /) -> tuple[complex, int]: + """ + source: https://arrayfire.org/docs/group__reduce__func__min.htm#gae75785af0fdfcbb1f4c34461235f5206 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + out_idx = ctypes.c_uint(0) + safe_call(_backend.clib.af_imin_all(ctypes.pointer(real), ctypes.pointer(imag), ctypes.pointer(out_idx), arr)) + complex_value = real.value if imag.value == 0 else real.value + imag.value * 1j + return (complex_value, out_idx.value) + + +def min(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__min.htm#ga2ac4c8d9ba613dbc9bfec0bee7be8eb8 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_min(ctypes.pointer(out), arr, dim)) + return out + + +def min_all(arr: AFArray, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__min.htm#gab10198ae7ead1dc10f220d576f118104 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_min_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def product(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__product.htm#ga2be338d39be30ad22dddf658a4f5676e + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_product(ctypes.pointer(out), arr, dim)) + return out + + +def product_all(arr: AFArray, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__product.htm#gad226a6ec77c12fd16cf42e3fe3264e22 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_product_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def product_nan(arr: AFArray, dim: int, nan_value: float, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__product.htm#ga1d25447c16d492767ba7efa7ee72a36e + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_product_nan(ctypes.pointer(out), arr, dim, ctypes.c_double(nan_value))) + return out + + +def product_nan_all(arr: AFArray, nan_value: float, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__product.htm#gaca78d54c53a33b419bfdb5c64accbc7b + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call( + _backend.clib.af_product_nan_all(ctypes.pointer(real), ctypes.pointer(imag), arr, ctypes.c_double(nan_value)) + ) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def sum(arr: AFArray, dim: int, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#gacd4917c2e916870ebdf54afc2f61d533 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_sum(ctypes.pointer(out), arr, dim)) + return out + + +def sum_all(arr: AFArray, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#gabc009d04df0faf29ba1e381c7badde58 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call(_backend.clib.af_sum_all(ctypes.pointer(real), ctypes.pointer(imag), arr)) + return real.value if imag.value == 0 else real.value + imag.value * 1j + + +def sum_nan(arr: AFArray, dim: int, nan_value: float, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#ga52461231e2d9995f689b7f23eea0e798 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_sum_nan(ctypes.pointer(out), arr, dim, ctypes.c_double(nan_value))) + return out + + +def sum_nan_all(arr: AFArray, nan_value: float, /) -> complex: + """ + source: https://arrayfire.org/docs/group__reduce__func__sum.htm#gabc009d04df0faf29ba1e381c7badde58 + """ + real = ctypes.c_double(0) + imag = ctypes.c_double(0) + safe_call( + _backend.clib.af_sum_nan_all(ctypes.pointer(real), ctypes.pointer(imag), arr, ctypes.c_double(nan_value)) + ) + return real.value if imag.value == 0 else real.value + imag.value * 1j diff --git a/arrayfire_wrapper/lib/vector_algorithms/set_operations.py b/arrayfire_wrapper/lib/vector_algorithms/set_operations.py new file mode 100644 index 0000000..759e80d --- /dev/null +++ b/arrayfire_wrapper/lib/vector_algorithms/set_operations.py @@ -0,0 +1,32 @@ +import ctypes + +from arrayfire_wrapper._backend import _backend +from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._error_handler import safe_call + + +def set_intersect(first: AFArray, second: AFArray, is_unique: bool, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__set__func__intersect.htm#ga985f9332c5f858eec66c717881ef2607 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_set_intersect(ctypes.pointer(out), first, second, ctypes.c_bool(is_unique))) + return out + + +def set_union(first: AFArray, second: AFArray, is_unique: bool, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__set__func__union.htm#gaabeead0c0dc360db9398e9703dbb273f + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_set_union(ctypes.pointer(out), first, second, ctypes.c_bool(is_unique))) + return out + + +def set_unique(arr: AFArray, is_sorted: bool, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__set__func__unique.htm#ga6afa1de48cbbc4b2df530c2530087943 + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_set_intersect(ctypes.pointer(out), ctypes.c_bool(is_sorted))) + return out diff --git a/arrayfire_wrapper/lib/vector_algorithms/sort_operations.py b/arrayfire_wrapper/lib/vector_algorithms/sort_operations.py new file mode 100644 index 0000000..e1b5ce3 --- /dev/null +++ b/arrayfire_wrapper/lib/vector_algorithms/sort_operations.py @@ -0,0 +1,47 @@ +import ctypes + +from arrayfire_wrapper._backend import _backend +from arrayfire_wrapper.defines import AFArray +from arrayfire_wrapper.lib._error_handler import safe_call + + +def sort(arr: AFArray, dim: int, is_ascending: bool, /) -> AFArray: + """ + source: https://arrayfire.org/docs/group__sort__func__sort.htm#gac4460d605452515d07ee8432f906aa8e + """ + out = AFArray.create_null_pointer() + safe_call(_backend.clib.af_sort(ctypes.pointer(out), arr, ctypes.c_uint(dim), ctypes.c_bool(is_ascending))) + return out + + +def sort_by_key(keys: AFArray, values: AFArray, dim: int, is_ascending: bool, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__sort__func__sort__keys.htm#ga7d4fcaf229ece5fbbe30a638d9a60b8a + """ + out_keys = AFArray.create_null_pointer() + out_values = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_sort_by_key( + ctypes.pointer(out_keys), + ctypes.pointer(out_values), + keys, + values, + ctypes.c_uint(dim), + ctypes.c_bool(is_ascending), + ) + ) + return (out_keys, out_values) + + +def sort_index(arr: AFArray, dim: int, is_ascending: bool, /) -> tuple[AFArray, AFArray]: + """ + source: https://arrayfire.org/docs/group__sort__func__sort__index.htm#ga55675cd825c320db87398b1010b6ae41 + """ + out = AFArray.create_null_pointer() + out_idx = AFArray.create_null_pointer() + safe_call( + _backend.clib.af_sort_index( + ctypes.pointer(out), ctypes.pointer(out_idx), arr, ctypes.c_uint(dim), ctypes.c_bool(is_ascending) + ) + ) + return (out, out_idx)