Skip to content

Commit

Permalink
BUG: check_broadcast cannot return None
Browse files Browse the repository at this point in the history
  • Loading branch information
lpsinger committed Aug 21, 2023
1 parent 102454b commit bf84b7f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 34 deletions.
7 changes: 4 additions & 3 deletions astropy/modeling/core.py
Expand Up @@ -1061,11 +1061,12 @@ def _validate_input_shapes(self, inputs, argnames, model_set_axis):
)
)

input_shape = check_broadcast(*all_shapes)
if input_shape is None:
try:
input_shape = check_broadcast(*all_shapes)
except IncompatibleShapeError as e:

Check warning on line 1066 in astropy/modeling/core.py

View check run for this annotation

Codecov / codecov/patch

astropy/modeling/core.py#L1066

Added line #L1066 was not covered by tests
raise ValueError(
"All inputs must have identical shapes or must be scalars."
)
) from e

return input_shape

Expand Down
40 changes: 9 additions & 31 deletions astropy/modeling/tests/test_core.py
Expand Up @@ -11,7 +11,6 @@
from numpy.testing import assert_allclose, assert_equal

import astropy
import astropy.modeling.core as core
import astropy.units as u
from astropy.convolution import convolve_models
from astropy.modeling import models
Expand Down Expand Up @@ -984,40 +983,19 @@ def test__validate_input_shape():
def test__validate_input_shapes():
model = models.Gaussian1D()
model._n_models = 2
inputs = [mk.MagicMock() for _ in range(3)]
argnames = mk.MagicMock()
model_set_axis = mk.MagicMock()
all_shapes = [mk.MagicMock() for _ in inputs]

# Successful validation
with mk.patch.object(
Model, "_validate_input_shape", autospec=True, side_effect=all_shapes
) as mkValidate:
with mk.patch.object(core, "check_broadcast", autospec=True) as mkCheck:
assert mkCheck.return_value == model._validate_input_shapes(
inputs, argnames, model_set_axis
)
assert mkCheck.call_args_list == [mk.call(*all_shapes)]
assert mkValidate.call_args_list == [
mk.call(model, _input, idx, argnames, model_set_axis, True)
for idx, _input in enumerate(inputs)
]
# Full success
inputs = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])]
assert (2, 2) == model._validate_input_shapes(inputs, model.inputs, 1)

# Fail check_broadcast
MESSAGE = r"All inputs must have identical shapes or must be scalars"
with mk.patch.object(
Model, "_validate_input_shape", autospec=True, side_effect=all_shapes
) as mkValidate:
with mk.patch.object(
core, "check_broadcast", autospec=True, return_value=None
) as mkCheck:
with pytest.raises(ValueError, match=MESSAGE):
model._validate_input_shapes(inputs, argnames, model_set_axis)
assert mkCheck.call_args_list == [mk.call(*all_shapes)]
assert mkValidate.call_args_list == [
mk.call(model, _input, idx, argnames, model_set_axis, True)
for idx, _input in enumerate(inputs)
]

# Fails because the input shape of the second input has one more axis which
# for which the first input can be broadcasted to
inputs = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8], [9, 10]])]
with pytest.raises(ValueError, match=MESSAGE):
model._validate_input_shapes(inputs, model.inputs, 1)


def test__remove_axes_from_shape():
Expand Down
8 changes: 8 additions & 0 deletions astropy/modeling/tests/test_models.py
Expand Up @@ -120,6 +120,14 @@ def test_inconsistent_input_shapes():
y.shape = (1, 10)
result = g(x, y)
assert result.shape == (10, 10)
# incompatible shapes do _not_ work
g = Gaussian2D()
x = np.arange(-1.0, 1, 0.2)
y = np.arange(-1.0, 1, 0.1)
with pytest.raises(
ValueError, match="All inputs must have identical shapes or must be scalars"
):
g(x, y)


def test_custom_model_bounding_box():
Expand Down
2 changes: 2 additions & 0 deletions docs/changes/modeling/15209.api.rst
@@ -0,0 +1,2 @@
Creating a model instance with parameters that have incompatible shapes will
now raise a ``ValueError`` rather than an ``IncompatibleShapeError``.

0 comments on commit bf84b7f

Please sign in to comment.