/
helpers.py
492 lines (410 loc) · 19.3 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
import sys
from itertools import chain
import warnings
from functools import wraps
from numpy import intp, bool_, array, broadcast_shapes
import numpy.testing
from pytest import fail
from hypothesis import assume, note
from hypothesis.strategies import (integers, none, one_of, lists, just,
builds, shared, composite, sampled_from,
nothing, tuples as hypothesis_tuples)
from hypothesis.extra.numpy import (arrays, mutually_broadcastable_shapes as
mbs, BroadcastableShapes, valid_tuple_axes)
from ..ndindex import ndindex
from ..shapetools import remove_indices, unremove_indices
from .._crt import prod
# Hypothesis strategies for generating indices. Note that some of these
# strategies are nominally already defined in hypothesis, but we redefine them
# here because the hypothesis definitions are too restrictive. For example,
# hypothesis's slices strategy does not generate slices with negative indices.
# Similarly, hypothesis.extra.numpy.basic_indices only generates tuples.
# Small bounded integers for use as array indices and slice arguments.
nonnegative_ints = integers(0, 10)
negative_ints = integers(-10, -1)
# Any integer in [-10, 10], mixing the two bounded strategies above.
ints = lambda: one_of(nonnegative_ints, negative_ints)
def slices(start=one_of(none(), ints()), stop=one_of(none(), ints()),
           step=one_of(none(), ints())):
    """
    Strategy for slice objects.

    Unlike hypothesis's built-in slices(), each of start/stop/step may be
    None or a (possibly negative) bounded integer.
    """
    # builds() draws one value from each strategy and calls slice() on them.
    parts = (start, stop, step)
    return builds(slice, *parts)
# just() always generates exactly the one given value.
ellipses = lambda: just(...)
newaxes = lambda: just(None)
# hypotheses.strategies.tuples only generates tuples of a fixed size
def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False):
    """Strategy for variable-length tuples drawn from `elements`."""
    as_lists = lists(elements, min_size=min_size, max_size=max_size,
                     unique_by=unique_by, unique=unique)
    return as_lists.map(tuple)
# Upper bounds on the number of elements in generated arrays. The filters
# below only multiply the nonzero dimensions, since numpy misbehaves on empty
# arrays whose other dimensions are huge.
MAX_ARRAY_SIZE = 100000
SHORT_MAX_ARRAY_SIZE = 1000
# Strategy for arbitrary array shapes: tuples of dimension sizes in [0, 10].
shapes = tuples(integers(0, 10)).filter(
    # numpy gives errors with empty arrays with large shapes.
    # See https://github.com/numpy/numpy/issues/15753
    lambda shape: prod([i for i in shape if i]) < MAX_ARRAY_SIZE)
# Like shapes, but with at least n dimensions and a smaller element bound.
_short_shapes = lambda n: tuples(integers(0, 10), min_size=n).filter(
    # numpy gives errors with empty arrays with large shapes.
    # See https://github.com/numpy/numpy/issues/15753
    lambda shape: prod([i for i in shape if i]) < SHORT_MAX_ARRAY_SIZE)
# short_shapes should be used in place of shapes in any test function that
# uses ndindices, boolean_arrays, or tuples
# shared() makes every strategy that uses short_shapes within one example
# draw the same shape value.
short_shapes = shared(_short_shapes(0))
_integer_arrays = arrays(intp, short_shapes)
# Indexing a 0-d array with () yields a numpy integer scalar.
integer_scalars = arrays(intp, ()).map(lambda x: x[()])
# Integer array indices: a scalar, an intp array, or its nested-list form.
integer_arrays = one_of(integer_scalars, _integer_arrays.flatmap(lambda x: one_of(just(x), just(x.tolist()))))
# We need to make sure shapes for boolean arrays are generated in a way that
# makes them related to the test array shape. Otherwise, it will be very
# difficult for the boolean array index to match along the test array, which
# means we won't test any behavior other than IndexError.
@composite
def subsequences(draw, sequence):
    """Strategy for a contiguous subsequence of a value drawn from `sequence`."""
    full = draw(sequence)
    # Draw the left endpoint first, then a right endpoint at or after it.
    left = draw(integers(0, max(0, len(full) - 1)))
    right = draw(integers(left, len(full)))
    return full[left:right]
# Boolean array indices, shaped either like the shared test-array shape or a
# contiguous subsequence of it, so the mask can actually match the array.
_boolean_arrays = arrays(bool_, one_of(subsequences(short_shapes), short_shapes))
# Indexing a 0-d array with () yields a numpy bool_ scalar.
boolean_scalars = arrays(bool_, ()).map(lambda x: x[()])
# Boolean array indices: a scalar, a bool_ array, or its nested-list form.
boolean_arrays = one_of(boolean_scalars, _boolean_arrays.flatmap(lambda x: one_of(just(x), just(x.tolist()))))
def _doesnt_raise(idx):
    """Return True when ndindex(idx) can be constructed without raising."""
    try:
        ndindex(idx)
    except (IndexError, ValueError, NotImplementedError):
        return False
    else:
        return True
# Tuple indices mixing every supported index type, filtered to those that
# ndindex() accepts without raising.
Tuples = tuples(one_of(ellipses(), ints(), slices(), newaxes(),
                integer_arrays, boolean_arrays)).filter(_doesnt_raise)
# Any single valid index: int, slice, ellipsis, newaxis, tuple, or array.
ndindices = one_of(
    ints(),
    slices(),
    ellipses(),
    newaxes(),
    Tuples,
    integer_arrays,
    boolean_arrays,
).filter(_doesnt_raise)
# Note: We could use something like this:
# mutually_broadcastable_shapes = shared(integers(1, 32).flatmap(lambda i: mbs(num_shapes=i).filter(
# lambda broadcastable_shapes: prod([i for i in broadcastable_shapes.result_shape if i]) < MAX_ARRAY_SIZE)))
@composite
def _mutually_broadcastable_shapes(draw, *, shapes=short_shapes, min_shapes=0, max_shapes=32, min_side=0):
    """
    Strategy for BroadcastableShapes(input_shapes, result_shape): between
    min_shapes and max_shapes mutually broadcastable shapes, plus the shape
    they broadcast to.
    """
    # mutually_broadcastable_shapes() with the default inputs doesn't generate
    # very interesting examples (see
    # https://github.com/HypothesisWorks/hypothesis/issues/3170). It's very
    # difficult to get it to do so by tweaking the max_* parameters, because
    # making them too big leads to generating too large shapes and filtering
    # too much. So instead, we trick it into generating more interesting
    # examples by telling it to create shapes that broadcast against some base
    # shape.

    # Unfortunately, this, along with the filtering below, has a downside that
    # it tends to generate a result shape of () more often than you might
    # like. But it generates enough "real" interesting shapes that both of
    # these workarounds are worth doing (plus I don't know if any other better
    # way of handling the situation).
    base_shape = draw(shapes)
    input_shapes, result_shape = draw(
        mbs(
            num_shapes=max_shapes,
            base_shape=base_shape,
            min_side=min_side,
        ))
    # The hypothesis mutually_broadcastable_shapes doesn't allow num_shapes to
    # be a strategy. It's tempting to do something like num_shapes =
    # draw(integers(min_shapes, max_shapes)), but this shrinks poorly. See
    # https://github.com/HypothesisWorks/hypothesis/issues/3151. So instead of
    # using a strategy to draw the number of shapes, we just generate max_shapes
    # shapes and pick a subset of them.
    final_input_shapes = draw(lists(sampled_from(input_shapes),
                                    min_size=min_shapes, max_size=max_shapes))
    # Note: result_shape is input_shapes broadcasted with base_shape, but
    # base_shape itself is not part of input_shapes. We "really" want our base
    # shape to be (). We are only using it here to trick
    # mutually_broadcastable_shapes into giving more interesting examples.
    final_result_shape = broadcast_shapes(*final_input_shapes)
    # The broadcast compatible shapes can be bigger than the base shape. This
    # is already somewhat limited by the mutually_broadcastable_shapes
    # defaults, and pretty unlikely, but we filter again here just to be safe.
    if not prod([i for i in final_result_shape if i]) < SHORT_MAX_ARRAY_SIZE: # pragma: no cover
        # Fix: report the shape that is actually filtered (final_result_shape),
        # not the intermediate result_shape from mbs().
        note(f"Filtering the shape {final_result_shape} (too many elements)")
        assume(False)
    return BroadcastableShapes(final_input_shapes, final_result_shape)
mutually_broadcastable_shapes = shared(_mutually_broadcastable_shapes())
def _fill_shape(draw,
                *,
                result_shape,
                skip_axes,
                skip_axes_values):
    """
    Build one shape whose non-skip entries broadcast against result_shape
    (each is either the matching result_shape entry or 1), and whose entries
    at skip_axes positions are drawn from skip_axes_values instead.

    Not a strategy itself: it is called with the draw function from inside a
    @composite strategy.
    """
    # The minimum number of dimensions needed so every skip axis (positive or
    # negative) indexes into the shape.
    max_n = max([i + 1 if i >= 0 else -i for i in skip_axes], default=0)
    assume(max_n <= len(skip_axes) + len(result_shape))
    dim = draw(integers(min_value=max_n, max_value=len(skip_axes) + len(result_shape)))
    # 'placeholder' marks entries not yet assigned; None marks skip positions.
    new_shape = ['placeholder']*dim
    for i in skip_axes:
        assume(new_shape[i] is not None) # skip_axes must be unique
        new_shape[i] = None
    # Fill from the right. j tracks the next result_shape entry from the
    # right and only advances when a non-skip entry is filled.
    j = -1
    for i in range(-1, -dim - 1, -1):
        if new_shape[i] is None:
            new_shape[i] = draw(skip_axes_values)
        else:
            new_shape[i] = draw(sampled_from([result_shape[j], 1]))
            j -= 1
    while new_shape and new_shape[0] == 'placeholder': # pragma: no cover
        # Can happen if positive and negative skip_axes refer to the same
        # entry
        new_shape.pop(0)
    # This will happen if the skip axes are too large
    assume('placeholder' not in new_shape)
    if prod([i for i in new_shape if i]) >= SHORT_MAX_ARRAY_SIZE:
        note(f"Filtering the shape {new_shape} (too many elements)")
        assume(False)
    return tuple(new_shape)
skip_axes_with_broadcasted_shape_type = shared(sampled_from([int, tuple, list]))
@composite
def _mbs_and_skip_axes(
    draw,
    shapes=short_shapes,
    min_shapes=0,
    max_shapes=32,
    skip_axes_type_st=skip_axes_with_broadcasted_shape_type,
    skip_axes_values=integers(0, 20),
    num_skip_axes=None,
):
    """
    mutually_broadcastable_shapes except skip_axes() axes might not be
    broadcastable

    The result_shape will be None in the position of skip_axes.
    """
    skip_axes_type = draw(skip_axes_type_st)
    _result_shape = draw(shapes)
    # A 0-d base shape cannot support a requested fixed number of skip axes.
    if _result_shape == ():
        assume(num_skip_axes is None)
    ndim = len(_result_shape)
    num_shapes = draw(integers(min_value=min_shapes, max_value=max_shapes))
    if not num_shapes:
        assume(num_skip_axes is None)
        num_skip_axes = 0
    if not ndim:
        # Degenerate case: every shape is () and there are no skip axes.
        return BroadcastableShapes([()]*num_shapes, ()), ()
    if num_skip_axes is not None:
        # Exactly num_skip_axes skip axes were requested per shape.
        min_skip_axes = max_skip_axes = num_skip_axes
    else:
        min_skip_axes = 0
        max_skip_axes = None
    # int and single tuple cases must be limited to N to ensure that they are
    # correct for all shapes
    if skip_axes_type == int:
        # A single int skip axis, shared by every shape.
        assume(num_skip_axes in [None, 1])
        skip_axes = draw(valid_tuple_axes(ndim, min_size=1, max_size=1))[0]
        _skip_axes = [(skip_axes,)]*num_shapes
    elif skip_axes_type == tuple:
        # One tuple of skip axes, shared by every shape.
        skip_axes = draw(tuples(integers(-ndim, ndim-1), min_size=min_skip_axes,
                                max_size=max_skip_axes, unique=True))
        _skip_axes = [skip_axes]*num_shapes
    elif skip_axes_type == list:
        # A separate tuple of skip axes for each shape.
        # NOTE(review): the bound here is integers(-ndim, ndim+1), unlike the
        # tuple case's integers(-ndim, ndim-1) — confirm whether the larger
        # (out-of-range) values are intentional; _fill_shape filters invalid
        # combinations with assume().
        skip_axes = []
        for i in range(num_shapes):
            skip_axes.append(draw(tuples(integers(-ndim, ndim+1), min_size=min_skip_axes,
                                         max_size=max_skip_axes, unique=True)))
        _skip_axes = skip_axes
    shapes = []
    for i in range(num_shapes):
        shapes.append(_fill_shape(draw, result_shape=_result_shape, skip_axes=_skip_axes[i],
                                  skip_axes_values=skip_axes_values))
    non_skip_shapes = [remove_indices(shape, sk) for shape, sk in
                       zip(shapes, _skip_axes)]
    # Broadcasting the result _fill_shape may produce a shape different from
    # _result_shape because it might not have filled all dimensions, or it
    # might have chosen 1 for a dimension every time. Ideally we would just be
    # using shapes from mutually_broadcastable_shapes, but I don't know how to
    # reverse inject skip axes into shapes in general (see the comment in
    # unremove_indices). So for now, we just use the actual broadcast of the
    # non-skip shapes. Note that we use np.broadcast_shapes here instead of
    # ndindex.broadcast_shapes because test_broadcast_shapes itself uses this
    # strategy.
    broadcasted_shape = broadcast_shapes(*non_skip_shapes)
    return BroadcastableShapes(shapes, broadcasted_shape), skip_axes
mbs_and_skip_axes = shared(_mbs_and_skip_axes())
# Projections of the shared strategy above, so tests can draw the shapes and
# the skip axes separately but consistently within one example.
mutually_broadcastable_shapes_with_skipped_axes = mbs_and_skip_axes.map(
    lambda i: i[0])
skip_axes_st = mbs_and_skip_axes.map(lambda i: i[1])
@composite
def _cross_shapes_and_skip_axes(draw):
    """
    Strategy for (BroadcastableShapes, skip_axes) for cross-product tests:
    two shapes with one size-3 skip axis each, plus a skip axis inserted into
    the broadcasted shape.
    """
    (shapes, _broadcasted_shape), skip_axes = draw(_mbs_and_skip_axes(
        shapes=_short_shapes(2),
        min_shapes=2,
        max_shapes=2,
        num_skip_axes=1,
        # TODO: Test other skip axes types
        skip_axes_type_st=just(list),
        # The skipped dimension always has size 3 (cross-product vectors).
        skip_axes_values=just(3),
    ))
    # Insert a size-3 axis into the broadcasted shape at a random position
    # and record that position as the result's skip axis.
    broadcasted_skip_axis = draw(integers(-len(_broadcasted_shape)-1, len(_broadcasted_shape)))
    broadcasted_shape = unremove_indices(_broadcasted_shape,
                                         [broadcasted_skip_axis], val=3)
    skip_axes.append((broadcasted_skip_axis,))
    return BroadcastableShapes(shapes, broadcasted_shape), skip_axes
cross_shapes_and_skip_axes = shared(_cross_shapes_and_skip_axes())
# Projections of the shared strategy above.
cross_shapes = cross_shapes_and_skip_axes.map(lambda i: i[0])
cross_skip_axes = cross_shapes_and_skip_axes.map(lambda i: i[1])
@composite
def cross_arrays_st(draw):
    """Strategy for a pair of integer arrays shaped per cross_shapes."""
    shapes, _broadcasted_shape = draw(cross_shapes)
    # Sanity check
    assert len(shapes) == 2
    # We need to generate fairly random arrays. Otherwise, if they are too
    # similar to each other, like two arange arrays would be, the cross
    # product will be 0. We also disable the fill feature in arrays() for the
    # same reason, as it would otherwise generate too many vectors that are
    # colinear.
    elements = integers(-100, 100)
    first = draw(arrays(dtype=int, shape=shapes[0], elements=elements, fill=nothing()))
    second = draw(arrays(dtype=int, shape=shapes[1], elements=elements, fill=nothing()))
    return first, second
@composite
def _matmul_shapes_and_skip_axes(draw):
    """
    Strategy for (BroadcastableShapes, skip_axes) for matmul tests: two
    shapes whose two skip positions are filled in as (..., n, m) and
    (..., m, k), plus two skip positions for the (..., n, k) result shape.
    """
    (shapes, _broadcasted_shape), skip_axes = draw(_mbs_and_skip_axes(
        shapes=_short_shapes(2),
        min_shapes=2,
        max_shapes=2,
        num_skip_axes=2,
        # TODO: Test other skip axes types
        skip_axes_type_st=just(list),
        # The skip entries are overwritten with n/m/k below, so the drawn
        # placeholder value is irrelevant.
        skip_axes_values=just(None),
    ))
    # Pick two positions to insert into the broadcasted shape.
    broadcasted_skip_axes = draw(hypothesis_tuples(*[
        integers(-len(_broadcasted_shape)-1, len(_broadcasted_shape))
    ]*2))
    try:
        broadcasted_shape = unremove_indices(_broadcasted_shape,
                                             broadcasted_skip_axes)
    except NotImplementedError:
        # TODO: unremove_indices only works with both positive or both negative
        assume(False)
    # Make sure the indices are unique
    assume(len(set(broadcasted_skip_axes)) == len(broadcasted_skip_axes))
    skip_axes.append(broadcasted_skip_axes)
    # (n, m) @ (m, k) -> (n, k)
    n, m, k = draw(hypothesis_tuples(integers(0, 10), integers(0, 10),
                                     integers(0, 10)))
    # Overwrite the skip positions with the matmul dimensions.
    shape1, shape2 = map(list, shapes)
    ax1, ax2 = skip_axes[0]
    shape1[ax1] = n
    shape1[ax2] = m
    ax1, ax2 = skip_axes[1]
    shape2[ax1] = m
    shape2[ax2] = k
    broadcasted_shape = list(broadcasted_shape)
    ax1, ax2 = skip_axes[2]
    broadcasted_shape[ax1] = n
    broadcasted_shape[ax2] = k
    shapes = (tuple(shape1), tuple(shape2))
    broadcasted_shape = tuple(broadcasted_shape)
    return BroadcastableShapes(shapes, broadcasted_shape), skip_axes
matmul_shapes_and_skip_axes = shared(_matmul_shapes_and_skip_axes())
# Projections of the shared strategy above.
matmul_shapes = matmul_shapes_and_skip_axes.map(lambda i: i[0])
matmul_skip_axes = matmul_shapes_and_skip_axes.map(lambda i: i[1])
@composite
def matmul_arrays_st(draw):
    """Strategy for a pair of integer arrays shaped per matmul_shapes."""
    shapes, _broadcasted_shape = draw(matmul_shapes)
    # Sanity check
    assert len(shapes) == 2
    elements = integers(-100, 100)
    first = draw(arrays(dtype=int, shape=shapes[0], elements=elements))
    second = draw(arrays(dtype=int, shape=shapes[1], elements=elements))
    return first, second
reduce_kwargs = sampled_from([{}, {'negative_int': False}, {'negative_int': True}])
def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Same as numpy.testing.assert_equal except it also requires the shapes and
    dtypes to be equal.
    """
    # Value equality first (this is what numpy.testing.assert_equal checks).
    numpy.testing.assert_equal(actual, desired, err_msg=err_msg,
                               verbose=verbose)
    # Then require the array metadata to match exactly as well.
    for attr in ('shape', 'dtype'):
        got, expected = getattr(actual, attr), getattr(desired, attr)
        assert got == expected, err_msg or f"{got} != {expected}"
def warnings_are_errors(f):
    """Decorator that runs f with every warning promoted to an error."""
    @wraps(f)
    def wrapper(*args, **kwargs):
        # catch_warnings() restores the previous filter state on exit.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            return f(*args, **kwargs)
    return wrapper
@warnings_are_errors
def check_same(a, idx, *, raw_func=lambda a, idx: a[idx],
               ndindex_func=lambda a, index: a[index.raw],
               same_exception=True, assert_equal=assert_equal):
    """
    Check that a raw index idx produces the same result on an array a before
    and after being transformed by ndindex.

    Tests that raw_func(a, idx) == ndindex_func(a, ndindex(idx)) or that they
    raise the same exception. If same_exception=False, it will still check
    that they both raise an exception, but will not require the exception type
    and message to be the same.

    By default, raw_func(a, idx) is a[idx] and ndindex_func(a, index) is
    a[index.raw].

    The assert_equal argument changes the function used to test equality. By
    default it is the custom assert_equal() function in this file that extends
    numpy.testing.assert_equal. If the func functions return something other
    than arrays, assert_equal should be set to something else, like

    def assert_equal(x, y):
        assert x == y
    """
    exception = None
    try:
        # Handle list indices that NumPy treats as tuple indices with a
        # deprecation warning. We want to test against the post-deprecation
        # behavior.
        e_inner = None
        try:
            try:
                # warnings_are_errors promotes warnings to exceptions, so
                # NumPy's indexing warnings are caught as Warning below.
                a_raw = raw_func(a, idx)
            except Warning as w:
                # In NumPy < 1.23, this is a FutureWarning. In 1.23 the
                # deprecation was removed and lists are always interpreted as
                # array indices.
                if ("Using a non-tuple sequence for multidimensional indexing is deprecated" in w.args[0]): # pragma: no cover
                    idx = array(idx, dtype=intp)
                    a_raw = raw_func(a, idx)
                elif "Out of bound index found. This was previously ignored when the indexing result contained no elements. In the future the index error will be raised. This error occurs either due to an empty slice, or if an array has zero elements even before indexing." in w.args[0]:
                    # Only require that both forms raise, not that the
                    # exceptions match, for this warning-turned-error case.
                    same_exception = False
                    raise IndexError
                else: # pragma: no cover
                    fail(f"Unexpected warning raised: {w}")
        except Exception:
            # Capture here and re-raise below so the outer except sees the
            # real exception rather than one swallowed by the warning logic.
            _, e_inner, _ = sys.exc_info()
        if e_inner:
            raise e_inner
    except Exception as e:
        exception = e
    try:
        index = ndindex(idx)
        a_ndindex = ndindex_func(a, index)
    except Exception as e:
        if not exception:
            # NOTE(review): if ndindex(idx) itself raised, `index` is unbound
            # here and this fail() f-string would raise NameError — confirm.
            fail(f"Raw form does not raise but ndindex form does ({e!r}): {index})") # pragma: no cover
        if same_exception:
            # Require the exact same exception type and arguments.
            assert type(e) == type(exception), (e, exception)
            assert e.args == exception.args, (e.args, exception.args)
    else:
        if exception:
            fail(f"ndindex form did not raise but raw form does ({exception!r}): {index})") # pragma: no cover
    if not exception:
        assert_equal(a_raw, a_ndindex)
def iterslice(start_range=(-10, 10),
              stop_range=(-10, 10),
              step_range=(-10, 10),
              one_two_args=True
              ):
    """
    Yield tuples of slice arguments: (start,), (start, stop), and
    (start, stop, step), where each argument takes every value in its range
    followed by None.

    one_two_args is unnecessary if the args are being passed to slice(),
    since slice() already canonicalizes missing arguments to None. We do it
    for Slice to test that behavior.
    """
    def options(rng):
        # Every value in the range, then None.
        return chain(range(*rng), [None])
    if one_two_args:
        for start in options(start_range):
            yield (start,)
        for start in options(start_range):
            for stop in options(stop_range):
                yield (start, stop)
    for start in options(start_range):
        for stop in options(stop_range):
            for step in options(step_range):
                yield (start, stop, step)
chunk_shapes = short_shapes
@composite
def chunk_sizes(draw, shapes=chunk_shapes):
    """
    Strategy for a tuple of chunk sizes, one per dimension of a drawn shape,
    each in [1, 10], with the total chunk size below MAX_ARRAY_SIZE.
    """
    shape = draw(shapes)
    ndim = len(shape)
    # Named predicate avoids shadowing `shape` inside the filter lambda.
    not_too_big = lambda chunks: prod(chunks) < MAX_ARRAY_SIZE
    return draw(tuples(integers(1, 10), min_size=ndim,
                       max_size=ndim).filter(not_too_big))