Skip to content

Commit

Permalink
Cleaned-up extra.array_api.ArrayStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Sep 29, 2021
1 parent f5ec0a8 commit 88e6792
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
74 changes: 37 additions & 37 deletions hypothesis-python/src/hypothesis/extra/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,28 +266,14 @@ def __init__(self, xp, elements_strategy, dtype, shape, fill, unique):
self.array_size = math.prod(shape)
self.builtin = find_castable_builtin_for_dtype(xp, dtype)

def set_value(self, result, i, val, strategy=None):
strategy = strategy or self.elements_strategy
try:
result[i] = val
except TypeError as e:
raise InvalidArgument(
f"Could not add generated array element {val!r} "
f"of dtype {type(val)} to array of dtype {result.dtype}."
) from e
self.check_set_value(val, result[i], strategy)

def check_set_value(self, val, val_0d, strategy):
if self.builtin is bool:
finite = True
else:
finite = self.xp.isfinite(val_0d)
finite = self.builtin is bool or self.xp.isfinite(val_0d)
if finite and self.builtin(val_0d) != val:
raise InvalidArgument(
f"Generated array element {val!r} from strategy {strategy} "
f"cannot be represented as dtype {self.dtype}. "
f"cannot be represented with dtype {self.dtype}. "
f"Array module {self.xp.__name__} instead "
f"represents the element as {val_0d!r}. "
f"represents the element as {val_0d}. "
"Consider using a more precise elements strategy, "
"for example passing the width argument to floats()."
)
Expand All @@ -302,26 +288,32 @@ def do_draw(self, data):
# elements strategy does not produce reusable values), so we must
# generate a fully dense array with a freshly drawn value for each
# entry.

# This could legitimately be a xp.empty, but the performance gains
# for that are likely marginal, so there's really not much point
# risking undefined behaviour shenanigans.
result = self.xp.zeros(self.array_size, dtype=self.dtype)

if self.unique:
seen = set()
elems = st.lists(
elems = data.draw(
st.lists(
self.elements_strategy,
min_size=self.array_size,
max_size=self.array_size,
unique=True,
unique=self.unique,
)
for i, v in enumerate(data.draw(elems)):
self.set_value(result, i, v)
else:
for i in range(self.array_size):
val = data.draw(self.elements_strategy)
self.set_value(result, i, val)
)
try:
result = self.xp.asarray(elems, dtype=self.dtype)
except Exception as e:
if len(elems) <= 6:
f_elems = str(elems)
else:
f_elems = f"[{elems[0]}, {elems[1]}, ..., {elems[-2]}, {elems[-1]}]"
types = tuple({type(e) for e in elems})
f_types = f"type {types[0]}" if len(types) == 1 else f"types {types}"
raise InvalidArgument(
f"Generated elements {f_elems} from strategy "
f"{self.elements_strategy} could not be converted "
f"to array of dtype {self.dtype}. "
f"Consider if elements of {f_types} "
f"are compatible with {self.dtype}."
) from e
for i in range(self.array_size):
self.check_set_value(elems[i], result[i], self.elements_strategy)
else:
# We draw arrays as "sparse with an offset". We assume not every
# element will be assigned and so first draw a single value from our
Expand All @@ -338,7 +330,7 @@ def do_draw(self, data):
f"with fill value {fill_val!r}"
) from e
sample = result[0]
self.check_set_value(fill_val, sample, strategy=self.fill)
self.check_set_value(fill_val, sample, self.fill)
if self.unique and not self.xp.all(self.xp.isnan(result)):
raise InvalidArgument(
f"Array module {self.xp.__name__} did not recognise fill "
Expand Down Expand Up @@ -371,7 +363,14 @@ def do_draw(self, data):
continue
else:
seen.add(val)
self.set_value(result, i, val)
try:
result[i] = val
except Exception as e:
raise InvalidArgument(
f"Could not add generated array element {val!r} "
f"of type {type(val)} to array of dtype {result.dtype}."
) from e
self.check_set_value(val, result[i], self.elements_strategy)
assigned.add(i)

result = self.xp.reshape(result, self.shape)
Expand Down Expand Up @@ -459,8 +458,9 @@ def _arrays(
hundreds or more elements, having a fill value is essential if you want
your tests to run in reasonable time.
"""

check_xp_attributes(xp, ["zeros", "full", "all", "isnan", "isfinite", "reshape"])
check_xp_attributes(
xp, ["asarray", "zeros", "full", "all", "isnan", "isfinite", "reshape"]
)

if isinstance(dtype, st.SearchStrategy):
return dtype.flatmap(
Expand Down
1 change: 1 addition & 0 deletions hypothesis-python/tests/array_api/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def count_unique(x):
# TODO: The Array API makes boolean indexing optional, so in the future this
# will need to be reworked if we want to test libraries other than NumPy.
# If not possible, errors should be caught and the test skipped.
# See https://github.com/data-apis/array-api/issues/249
filtered_x = x[~nan_index]
unique_x = xp.unique(filtered_x)
n_unique += unique_x.size
Expand Down

0 comments on commit 88e6792

Please sign in to comment.