From 690938936bb274770ee4a002c5ed4791bb88ecc9 Mon Sep 17 00:00:00 2001 From: Matthew Craig Date: Sun, 27 Apr 2014 17:26:12 -0500 Subject: [PATCH 1/2] Override basic arithmetic operations to allow operand to be scalar or Quantity or CCDData Closes #38 --- ccdproc/ccddata.py | 65 +++++++++++++++ ccdproc/tests/test_ccddata.py | 152 ++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) diff --git a/ccdproc/ccddata.py b/ccdproc/ccddata.py index cf852831..e14490c7 100644 --- a/ccdproc/ccddata.py +++ b/ccdproc/ccddata.py @@ -2,6 +2,9 @@ # This module implements the base CCDData class. import copy +import numbers + +import numpy as np from astropy.nddata import NDData from astropy.nddata.nduncertainty import StdDevUncertainty, NDUncertainty @@ -167,6 +170,68 @@ def copy(self): """ return copy.deepcopy(self) + def _ccddata_arithmetic(self, other, operation, scale_uncertainty=False): + """ + Perform the common parts of arithmetic operations on CCDData objects + + This should only be called when `other` is a Quantity or a number + """ + # THE "1 *" IS NECESSARY to get the right result, at least in + # astropy-0.4dev. Using the np.multiply, etc, methods with a Unit + # and a Quantity is currently broken, but it works with two Quantity + # arguments. + if isinstance(other, u.Quantity): + other_value = other.value + elif isinstance(other, numbers.Number): + other_value = other + else: + raise TypeError("Cannot do arithmetic with type '{0}' " + "and 'CCDData'".format(type(other))) + + result_unit = operation(1 * self.unit, other).unit + result_data = operation(self.data, other_value) + + if self.uncertainty: + result_uncertainty = self.uncertainty.array + if scale_uncertainty: + result_uncertainty = operation(result_uncertainty, other_value) + result_uncertainty = StdDevUncertainty(result_uncertainty) + else: + result_uncertainty = None + + result = CCDData(data=result_data, unit=result_unit, + uncertainty=result_uncertainty, + meta=self.meta) + return result + + def multiply(self, other): + if isinstance(other, CCDData): + return super(CCDData, self).multiply(other) + + return self._ccddata_arithmetic(other, np.multiply, + scale_uncertainty=True) + + def divide(self, other): + if isinstance(other, CCDData): + return super(CCDData, self).divide(other) + + return self._ccddata_arithmetic(other, np.divide, + scale_uncertainty=True) + + def add(self, other): + if isinstance(other, CCDData): + return super(CCDData, self).add(other) + + return self._ccddata_arithmetic(other, np.add, + scale_uncertainty=False) + + def subtract(self, other): + if isinstance(other, CCDData): + return super(CCDData, self).subtract(other) + + return self._ccddata_arithmetic(other, np.subtract, + scale_uncertainty=False) + def fits_ccddata_reader(filename, hdu=0, unit=None, **kwd): """ diff --git a/ccdproc/tests/test_ccddata.py b/ccdproc/tests/test_ccddata.py index f995842f..3e1462fc 100644 --- a/ccdproc/tests/test_ccddata.py +++ b/ccdproc/tests/test_ccddata.py @@ -6,6 +6,8 @@ from astropy.tests.helper import pytest from astropy.utils import NumpyRNGContext +from astropy.nddata import StdDevUncertainty +from astropy import units as u from ..ccddata import CCDData, electron @@ -133,6 +135,156 @@ def test_copy(ccd_data): assert ccd_copy.meta == ccd_data.meta +@pytest.mark.parametrize('operation,affects_uncertainty', [ + ("multiply", True), + ("divide", True), + ]) +@pytest.mark.parametrize('operand', [ + 2.0, + 2 * u.dimensionless_unscaled, + 2 * u.photon / u.adu, + ]) +@pytest.mark.parametrize('with_uncertainty', [ + True, + False]) +@pytest.mark.data_unit(u.adu) +def test_mult_div_overload(ccd_data, operand, with_uncertainty, + operation, affects_uncertainty): + if with_uncertainty: + ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data)) + method = ccd_data.__getattribute__(operation) + np_method = np.__getattribute__(operation) + result = method(operand) + assert result is not ccd_data + assert isinstance(result, CCDData) + assert (result.uncertainty is None or + isinstance(result.uncertainty, StdDevUncertainty)) + try: + op_value = operand.value + except AttributeError: + op_value = operand + + np.testing.assert_array_equal(result.data, + np_method(ccd_data.data, op_value)) + if with_uncertainty: + if affects_uncertainty: + np.testing.assert_array_equal(result.uncertainty.array, + np_method(ccd_data.uncertainty.array, + op_value)) + else: + np.testing.assert_array_equal(result.uncertainty.array, + ccd_data.uncertainty.array) + else: + assert result.uncertainty is None + + if isinstance(operand, u.Quantity): + assert result.unit == np_method(ccd_data.unit, operand.unit) + else: + assert result.unit == ccd_data.unit + + +@pytest.mark.parametrize('operation,affects_uncertainty', [ + ("add", False), + ("subtract", False), + ]) +@pytest.mark.parametrize('operand,expect_failure', [ + (2.0, u.UnitsError), # fail--units don't match image + (2 * u.dimensionless_unscaled, u.UnitsError), # same + (2 * u.adu, False), + ]) +@pytest.mark.parametrize('with_uncertainty', [ + True, + False]) +@pytest.mark.data_unit(u.adu) +def test_add_sub_overload(ccd_data, operand, expect_failure, with_uncertainty, + operation, affects_uncertainty): + if with_uncertainty: + ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data)) + method = ccd_data.__getattribute__(operation) + np_method = np.__getattribute__(operation) + if expect_failure: + with pytest.raises(expect_failure): + result = method(operand) + return + else: + result = method(operand) + assert result is not ccd_data + assert isinstance(result, CCDData) + assert (result.uncertainty is None or + isinstance(result.uncertainty, StdDevUncertainty)) + try: + op_value = operand.value + except AttributeError: + op_value = operand + + np.testing.assert_array_equal(result.data, + np_method(ccd_data.data, op_value)) + if with_uncertainty: + if affects_uncertainty: + np.testing.assert_array_equal(result.uncertainty.array, + np_method(ccd_data.uncertainty.array, + op_value)) + else: + np.testing.assert_array_equal(result.uncertainty.array, + ccd_data.uncertainty.array) + else: + assert result.uncertainty is None + + if isinstance(operand, u.Quantity): + assert (result.unit == ccd_data.unit and result.unit == operand.unit) + else: + assert result.unit == ccd_data.unit + + +def test_arithmetic_overload_fails(ccd_data): + with pytest.raises(TypeError): + ccd_data.multiply("five") + + with pytest.raises(TypeError): + ccd_data.divide("five") + + with pytest.raises(TypeError): + ccd_data.add("five") + + with pytest.raises(TypeError): + ccd_data.subtract("five") + + +def test_arithmetic_overload_ccddata_operand(ccd_data): + ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data)) + operand = ccd_data.copy() + result = ccd_data.add(operand) + assert len(result.meta) == 0 + np.testing.assert_array_equal(result.data, + 2 * ccd_data.data) + np.testing.assert_array_equal(result.uncertainty.array, + np.sqrt(2) * ccd_data.uncertainty.array) + + result = ccd_data.subtract(operand) + assert len(result.meta) == 0 + np.testing.assert_array_equal(result.data, + 0 * ccd_data.data) + np.testing.assert_array_equal(result.uncertainty.array, + np.sqrt(2) * ccd_data.uncertainty.array) + + result = ccd_data.multiply(operand) + assert len(result.meta) == 0 + np.testing.assert_array_equal(result.data, + ccd_data.data ** 2) + expected_uncertainty = (np.sqrt(2) * np.abs(ccd_data.data) * + ccd_data.uncertainty.array) + np.testing.assert_allclose(result.uncertainty.array, + expected_uncertainty) + + result = ccd_data.divide(operand) + assert len(result.meta) == 0 + np.testing.assert_array_equal(result.data, + np.ones_like(ccd_data.data)) + expected_uncertainty = (np.sqrt(2) / np.abs(ccd_data.data) * + ccd_data.uncertainty.array) + np.testing.assert_allclose(result.uncertainty.array, + expected_uncertainty) + if __name__ == '__main__': test_ccddata_empty() test_ccddata_simple() From 2cb9119ab4c23ee8b5e71cff3b6af1411eb84d48 Mon Sep 17 00:00:00 2001 From: Matthew Craig Date: Sun, 27 Apr 2014 20:13:41 -0500 Subject: [PATCH 2/2] Work around issue with multiplying units when one is dimensionless The issue is reported at astropy/astropy#2377 --- ccdproc/tests/test_ccddata.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ccdproc/tests/test_ccddata.py b/ccdproc/tests/test_ccddata.py index 3e1462fc..ddb999e4 100644 --- a/ccdproc/tests/test_ccddata.py +++ b/ccdproc/tests/test_ccddata.py @@ -178,7 +178,10 @@ def test_mult_div_overload(ccd_data, operand, with_uncertainty, assert result.uncertainty is None if isinstance(operand, u.Quantity): - assert result.unit == np_method(ccd_data.unit, operand.unit) + # Need the "1 *" below to force arguments to be Quantity to work around + # astropy/astropy#2377 + expected_unit = np_method(1 * ccd_data.unit, 1 * operand.unit).unit + assert result.unit == expected_unit else: assert result.unit == ccd_data.unit