Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: arrayfire/arrayfire-binary-python-wrapper
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: master
Choose a base ref
...
head repository: sakchal/arrayfire-binary-python-wrapper
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: master
Choose a head ref
Able to merge. These branches can be automatically merged.
  • 5 commits
  • 2 files changed
  • 1 contributor

Commits on Mar 13, 2024

  1. utility functions for tests

    Chaluvadi committed Mar 13, 2024

    Verified

    This commit was signed with the committer’s verified signature.
    tamird Tamir Duberstein
    Copy the full SHA
    25fbfd0 View commit details

Commits on Mar 28, 2024

  1. Fixed utility function merge conflict

    Chaluvadi committed Mar 28, 2024
    Copy the full SHA
    8c7d809 View commit details
  2. added unit tests for range function

    Chaluvadi committed Mar 28, 2024

    Verified

    This commit was signed with the committer’s verified signature.
    dtolnay David Tolnay
    Copy the full SHA
    2fbaace View commit details
  3. Readability changes to cosntants tests

    Chaluvadi committed Mar 28, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
    Copy the full SHA
    1bc0f3a View commit details
  4. readability changes pt.2

    Chaluvadi committed Mar 28, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    bc96cbc View commit details
Showing with 113 additions and 27 deletions.
  1. +52 −27 tests/test_constants.py
  2. +61 −0 tests/test_range.py
79 changes: 52 additions & 27 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
@@ -2,8 +2,23 @@

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper
from arrayfire_wrapper.dtypes import (
Dtype,
c32,
c64,
c_api_value_to_dtype,
f16,
f32,
f64,
s16,
s32,
s64,
u8,
u16,
u32,
u64,
)

invalid_shape = (
random.randint(1, 10),
@@ -14,6 +29,9 @@
)


all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]


@pytest.mark.parametrize(
"shape",
[
@@ -27,7 +45,7 @@
def test_constant_shape(shape: tuple) -> None:
"""Test if constant creates an array with the correct shape."""
number = 5.0
dtype = dtypes.s16
dtype = s16

result = wrapper.constant(number, shape, dtype)

@@ -46,9 +64,8 @@ def test_constant_shape(shape: tuple) -> None:
)
def test_constant_complex_shape(shape: tuple) -> None:
"""Test if constant_complex creates an array with the correct shape."""
dtype = dtypes.c32
dtype = c32

dtype = dtypes.c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

@@ -71,7 +88,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
)
def test_constant_long_shape(shape: tuple) -> None:
"""Test if constant_long creates an array with the correct shape."""
dtype = dtypes.s64
dtype = s64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

@@ -93,7 +110,7 @@ def test_constant_long_shape(shape: tuple) -> None:
)
def test_constant_ulong_shape(shape: tuple) -> None:
"""Test if constant_ulong creates an array with the correct shape."""
dtype = dtypes.u64
dtype = u64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

@@ -109,15 +126,15 @@ def test_constant_shape_invalid() -> None:
"""Test if constant handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
number = 5.0
dtype = dtypes.s16
dtype = s16

wrapper.constant(number, invalid_shape, dtype)


def test_constant_complex_shape_invalid() -> None:
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.c32
dtype = c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

@@ -128,7 +145,7 @@ def test_constant_complex_shape_invalid() -> None:
def test_constant_long_shape_invalid() -> None:
"""Test if constant_long handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.s64
dtype = s64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

@@ -139,7 +156,7 @@ def test_constant_long_shape_invalid() -> None:
def test_constant_ulong_shape_invalid() -> None:
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.u64
dtype = u64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

@@ -148,50 +165,47 @@ def test_constant_ulong_shape_invalid() -> None:


@pytest.mark.parametrize(
"dtype_index",
[i for i in range(13)],
"dtype",
all_types,
)
def test_constant_dtype(dtype_index: int) -> None:
def test_constant_dtype(dtype: Dtype) -> None:
"""Test if constant creates an array with the correct dtype."""
if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()):
if is_cmplx_type(dtype) or not is_system_supported(dtype):
pytest.skip()

dtype = dtypes.c_api_value_to_dtype(dtype_index)

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
shape = (2, 2)
if isinstance(value, (int, float)):
result = wrapper.constant(value, shape, dtype)
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


@pytest.mark.parametrize(
"dtype_index",
[i for i in range(13)],
"dtype",
all_types,
)
def test_constant_complex_dtype(dtype_index: int) -> None:
def test_constant_complex_dtype(dtype: Dtype) -> None:
"""Test if constant_complex creates an array with the correct dtype."""
if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()):
if not is_cmplx_type(dtype) or not is_system_supported(dtype):
pytest.skip()

dtype = dtypes.c_api_value_to_dtype(dtype_index)
rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
shape = (2, 2)

if isinstance(value, (int, float, complex)):
result = wrapper.constant_complex(value, shape, dtype)
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def test_constant_long_dtype() -> None:
"""Test if constant_long creates an array with the correct dtype."""
dtype = dtypes.s64
dtype = s64

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
@@ -200,14 +214,14 @@ def test_constant_long_dtype() -> None:
if isinstance(value, (int, float)):
result = wrapper.constant_long(value, shape, dtype)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def test_constant_ulong_dtype() -> None:
"""Test if constant_ulong creates an array with the correct dtype."""
dtype = dtypes.u64
dtype = u64

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
@@ -216,6 +230,17 @@ def test_constant_ulong_dtype() -> None:
if isinstance(value, (int, float)):
result = wrapper.constant_ulong(value, shape, dtype)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def is_cmplx_type(dtype: Dtype) -> bool:
return dtype == c32 or dtype == c64


def is_system_supported(dtype: Dtype) -> bool:
if dtype in [f64, c64] and not wrapper.get_dbl_support():
return False

return True
61 changes: 61 additions & 0 deletions tests/test_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import random

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_range_shape(shape: tuple) -> None:
"""Test if the range function output an AFArray with the correct shape"""
dim = 2
dtype = dtypes.s16

result = wrapper.range(shape, dim, dtype)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


def test_range_invalid_shape() -> None:
"""Test if range function correctly handles an invalid shape"""
with pytest.raises(TypeError):
shape = (
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
)
dim = 2
dtype = dtypes.s16

wrapper.range(shape, dim, dtype)


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_range_invalid_dim(shape: tuple) -> None:
"""Test if the range function can properly handle and invalid dimension given"""
with pytest.raises(RuntimeError):
dim = random.randint(4, 10)
dtype = dtypes.s16

wrapper.range(shape, dim, dtype)