Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions ccdproc/ccddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# This module implements the base CCDData class.

import copy
import numbers

import numpy as np

from astropy.nddata import NDData
from astropy.nddata.nduncertainty import StdDevUncertainty, NDUncertainty
Expand Down Expand Up @@ -167,6 +170,68 @@ def copy(self):
"""
return copy.deepcopy(self)

def _ccddata_arithmetic(self, other, operation, scale_uncertainty=False):
"""
Perform the common parts of arithmetic operations on CCDData objects

This should only be called when `other` is a Quantity or a number
"""
# THE "1 *" IS NECESSARY to get the right result, at least in
# astropy-0.4dev. Using the np.multiply, etc, methods with a Unit
# and a Quantity is currently broken, but it works with two Quantity
# arguments.
if isinstance(other, u.Quantity):
other_value = other.value
elif isinstance(other, numbers.Number):
other_value = other
else:
raise TypeError("Cannot do arithmetic with type '{0}' "
"and 'CCDData'".format(type(other)))

result_unit = operation(1 * self.unit, other).unit
result_data = operation(self.data, other_value)

if self.uncertainty:
result_uncertainty = self.uncertainty.array
if scale_uncertainty:
result_uncertainty = operation(result_uncertainty, other_value)
result_uncertainty = StdDevUncertainty(result_uncertainty)
else:
result_uncertainty = None

result = CCDData(data=result_data, unit=result_unit,
uncertainty=result_uncertainty,
meta=self.meta)
return result

def multiply(self, other):
if isinstance(other, CCDData):
return super(CCDData, self).multiply(other)

return self._ccddata_arithmetic(other, np.multiply,
scale_uncertainty=True)

def divide(self, other):
if isinstance(other, CCDData):
return super(CCDData, self).divide(other)

return self._ccddata_arithmetic(other, np.divide,
scale_uncertainty=True)

def add(self, other):
if isinstance(other, CCDData):
return super(CCDData, self).add(other)

return self._ccddata_arithmetic(other, np.add,
scale_uncertainty=False)

def subtract(self, other):
if isinstance(other, CCDData):
return super(CCDData, self).subtract(other)

return self._ccddata_arithmetic(other, np.subtract,
scale_uncertainty=False)


def fits_ccddata_reader(filename, hdu=0, unit=None, **kwd):
"""
Expand Down
155 changes: 155 additions & 0 deletions ccdproc/tests/test_ccddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from astropy.tests.helper import pytest
from astropy.utils import NumpyRNGContext
from astropy.nddata import StdDevUncertainty
from astropy import units as u

from ..ccddata import CCDData, electron

Expand Down Expand Up @@ -133,6 +135,159 @@ def test_copy(ccd_data):
assert ccd_copy.meta == ccd_data.meta


@pytest.mark.parametrize('operation,affects_uncertainty', [
("multiply", True),
("divide", True),
])
@pytest.mark.parametrize('operand', [
2.0,
2 * u.dimensionless_unscaled,
2 * u.photon / u.adu,
])
@pytest.mark.parametrize('with_uncertainty', [
True,
False])
@pytest.mark.data_unit(u.adu)
def test_mult_div_overload(ccd_data, operand, with_uncertainty,
operation, affects_uncertainty):
if with_uncertainty:
ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data))
method = ccd_data.__getattribute__(operation)
np_method = np.__getattribute__(operation)
result = method(operand)
assert result is not ccd_data
assert isinstance(result, CCDData)
assert (result.uncertainty is None or
isinstance(result.uncertainty, StdDevUncertainty))
try:
op_value = operand.value
except AttributeError:
op_value = operand

np.testing.assert_array_equal(result.data,
np_method(ccd_data.data, op_value))
if with_uncertainty:
if affects_uncertainty:
np.testing.assert_array_equal(result.uncertainty.array,
np_method(ccd_data.uncertainty.array,
op_value))
else:
np.testing.assert_array_equal(result.uncertainty.array,
ccd_data.uncertainty.array)
else:
assert result.uncertainty is None

if isinstance(operand, u.Quantity):
# Need the "1 *" below to force arguments to be Quantity to work around
# astropy/astropy#2377
expected_unit = np_method(1 * ccd_data.unit, 1 * operand.unit).unit
assert result.unit == expected_unit
else:
assert result.unit == ccd_data.unit


@pytest.mark.parametrize('operation,affects_uncertainty', [
("add", False),
("subtract", False),
])
@pytest.mark.parametrize('operand,expect_failure', [
(2.0, u.UnitsError), # fail--units don't match image
(2 * u.dimensionless_unscaled, u.UnitsError), # same
(2 * u.adu, False),
])
@pytest.mark.parametrize('with_uncertainty', [
True,
False])
@pytest.mark.data_unit(u.adu)
def test_add_sub_overload(ccd_data, operand, expect_failure, with_uncertainty,
operation, affects_uncertainty):
if with_uncertainty:
ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data))
method = ccd_data.__getattribute__(operation)
np_method = np.__getattribute__(operation)
if expect_failure:
with pytest.raises(expect_failure):
result = method(operand)
return
else:
result = method(operand)
assert result is not ccd_data
assert isinstance(result, CCDData)
assert (result.uncertainty is None or
isinstance(result.uncertainty, StdDevUncertainty))
try:
op_value = operand.value
except AttributeError:
op_value = operand

np.testing.assert_array_equal(result.data,
np_method(ccd_data.data, op_value))
if with_uncertainty:
if affects_uncertainty:
np.testing.assert_array_equal(result.uncertainty.array,
np_method(ccd_data.uncertainty.array,
op_value))
else:
np.testing.assert_array_equal(result.uncertainty.array,
ccd_data.uncertainty.array)
else:
assert result.uncertainty is None

if isinstance(operand, u.Quantity):
assert (result.unit == ccd_data.unit and result.unit == operand.unit)
else:
assert result.unit == ccd_data.unit


def test_arithmetic_overload_fails(ccd_data):
with pytest.raises(TypeError):
ccd_data.multiply("five")

with pytest.raises(TypeError):
ccd_data.divide("five")

with pytest.raises(TypeError):
ccd_data.add("five")

with pytest.raises(TypeError):
ccd_data.subtract("five")


def test_arithmetic_overload_ccddata_operand(ccd_data):
ccd_data.uncertainty = StdDevUncertainty(np.ones_like(ccd_data))
operand = ccd_data.copy()
result = ccd_data.add(operand)
assert len(result.meta) == 0
np.testing.assert_array_equal(result.data,
2 * ccd_data.data)
np.testing.assert_array_equal(result.uncertainty.array,
np.sqrt(2) * ccd_data.uncertainty.array)

result = ccd_data.subtract(operand)
assert len(result.meta) == 0
np.testing.assert_array_equal(result.data,
0 * ccd_data.data)
np.testing.assert_array_equal(result.uncertainty.array,
np.sqrt(2) * ccd_data.uncertainty.array)

result = ccd_data.multiply(operand)
assert len(result.meta) == 0
np.testing.assert_array_equal(result.data,
ccd_data.data ** 2)
expected_uncertainty = (np.sqrt(2) * np.abs(ccd_data.data) *
ccd_data.uncertainty.array)
np.testing.assert_allclose(result.uncertainty.array,
expected_uncertainty)

result = ccd_data.divide(operand)
assert len(result.meta) == 0
np.testing.assert_array_equal(result.data,
np.ones_like(ccd_data.data))
expected_uncertainty = (np.sqrt(2) / np.abs(ccd_data.data) *
ccd_data.uncertainty.array)
np.testing.assert_allclose(result.uncertainty.array,
expected_uncertainty)

if __name__ == '__main__':
test_ccddata_empty()
test_ccddata_simple()
Expand Down