Skip to content

Commit

Permalink
Merge pull request #10822 from george-galvin/convolution-issue
Browse files Browse the repository at this point in the history
Made convolve work with units, closing issue #10811
  • Loading branch information
mhvk committed Oct 17, 2020
2 parents 73973e9 + 0ef8d24 commit 079fcb7
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ astropy.constants
astropy.convolution
^^^^^^^^^^^^^^^^^^^

- Methods ``convolve`` and ``convolve_fft`` both now return Quantity arrays if user
input is given in one. [#10822]

astropy.coordinates
^^^^^^^^^^^^^^^^^^^

Expand Down
10 changes: 10 additions & 0 deletions astropy/convolution/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ def convolve(array, kernel, boundary='fill', fill_value=0.,
result[initially_nan] = np.nan

# Convert result to original data type
array_unit = getattr(passed_array, "unit", None)
if array_unit is not None:
result <<= array_unit

if isinstance(passed_array, Kernel):
if isinstance(passed_array, Kernel1D):
new_result = Kernel1D(array=result)
Expand Down Expand Up @@ -567,6 +571,9 @@ def convolve_fft(array, kernel, boundary='fill', fill_value=0.,
if nan_treatment not in ('interpolate', 'fill'):
raise ValueError("nan_treatment must be one of 'interpolate','fill'")

#Get array quantity if it exists
array_unit = getattr(array, "unit", None)

# Convert array dtype to complex
# and ensure that list inputs become arrays
array = _copy_input_if_needed(array, dtype=complex, order='C',
Expand Down Expand Up @@ -759,6 +766,9 @@ def convolve_fft(array, kernel, boundary='fill', fill_value=0.,

fftmult *= kernel_scale

if array_unit is not None:
fftmult <<= array_unit

if return_fft:
return fftmult

Expand Down
21 changes: 21 additions & 0 deletions astropy/convolution/tests/test_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@


class TestConvolve1D:

def test_list(self):
"""
Test that convolve works correctly when inputs are lists
Expand All @@ -65,6 +66,26 @@ def test_tuple(self):
assert_array_almost_equal_nulp(z,
np.array([0., 3.6, 5., 5.6, 5.6, 6.8, 0.]), 10)

@pytest.mark.parametrize(('boundary', 'nan_treatment',
'normalize_kernel', 'preserve_nan', 'dtype'),
itertools.product(BOUNDARY_OPTIONS,
NANHANDLING_OPTIONS,
NORMALIZE_OPTIONS,
PRESERVE_NAN_OPTIONS,
VALID_DTYPES))
def test_quantity(self, boundary, nan_treatment,
normalize_kernel, preserve_nan, dtype):
"""
Test that convolve works correctly when input array is a Quantity
"""

x = np.array([1, 4, 5, 6, 5, 7, 8], dtype=dtype) * u.ph
y = np.array([0.2, 0.6, 0.2], dtype=dtype)
z = convolve(x, y, boundary=boundary, nan_treatment=nan_treatment,
normalize_kernel=normalize_kernel, preserve_nan=preserve_nan)

assert x.unit == z.unit

@pytest.mark.parametrize(('boundary', 'nan_treatment',
'normalize_kernel', 'preserve_nan', 'dtype'),
itertools.product(BOUNDARY_OPTIONS,
Expand Down
17 changes: 17 additions & 0 deletions astropy/convolution/tests/test_convolve_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from astropy.convolution.convolve import convolve_fft, convolve
from astropy.utils.exceptions import AstropyUserWarning
from astropy import units as u
from astropy.utils.compat.context import nullcontext

VALID_DTYPES = ('>f4', '<f4', '>f8', '<f8')
Expand Down Expand Up @@ -69,6 +70,22 @@ def assert_floatclose(x, y):

class TestConvolve1D:

@pytest.mark.parametrize(option_names, options)
def test_quantity(self, boundary, nan_treatment, normalize_kernel):
"""
Test that convolve_fft works correctly when input array is a Quantity
"""

x = np.array([1., 4., 5., 6., 5., 7., 8.], dtype='float64') * u.ph
y = np.array([0.2, 0.6, 0.2], dtype='float64')

with expected_boundary_warning(boundary=boundary):
z = convolve_fft(x, y, boundary=boundary,
nan_treatment=nan_treatment,
normalize_kernel=normalize_kernel)

assert x.unit == z.unit

@pytest.mark.parametrize(option_names, options)
def test_unity_1_none(self, boundary, nan_treatment, normalize_kernel):
'''
Expand Down

0 comments on commit 079fcb7

Please sign in to comment.