Skip to content

Commit

Permalink
Merge pull request #509 from brechmos-stsci/snr-threshold-nd
Browse files Browse the repository at this point in the history
Added in snr_thresholding with 1D and 3D tests
  • Loading branch information
eteq committed Oct 2, 2019
2 parents 60a1653 + 10bffa8 commit 53a00de
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 1 deletion.
37 changes: 37 additions & 0 deletions docs/manipulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,43 @@ Or similarly, expressed in pixels:
>>> spec_w_unc.uncertainty # doctest: +ELLIPSIS
StdDevUncertainty([0.18714535, ..., 0.18714535])
S/N Threshold Mask
------------------

It is useful to be able to find all the spaxels in an ND spectrum
in which the signal to noise ratio is greater than some threshold.
This method implements this functionality so that a `~specutils.Spectrum1D`
object, `~specutils.SpectrumCollection` or an :class:`~astropy.nddata.NDData` derived
object may be passed in as the first parameter. The second parameter
is a floating point threshold.

For example, first a spectrum with flux and uncertainty is created, and
then call the ``snr_threshold`` method:

.. code-block:: python
>>> import numpy as np
>>> from astropy.nddata import StdDevUncertainty
>>> import astropy.units as u
>>> from specutils import Spectrum1D
>>> from specutils.manipulation import snr_threshold
>>> np.random.seed(42)
>>> wavelengths = np.arange(0, 10)*u.um
>>> flux = 100*np.abs(np.random.randn(10))*u.Jy
>>> uncertainty = StdDevUncertainty(np.abs(np.random.randn(10))*u.Jy)
>>> spectrum = Spectrum1D(spectral_axis=wavelengths, flux=flux, uncertainty=uncertainty)
>>> spectrum_masked = snr_threshold(spectrum, 50) #doctest:+SKIP
>>> # To create a masked flux array
>>> flux_masked = spectrum_masked.flux #doctest:+SKIP
>>> flux_masked[spectrum_masked.mask] = np.nan #doctest:+SKIP
The output ``spectrum_masked`` is a shallow copy of the input ``spectrum``
with the ``mask`` attribute set to False where the S/N is greater than 50
and True elsewhere. It is this way to be consistent with ``astropy.nddata``.

.. note:: The mask attribute is the only attribute modified by ``snr_threshold()``. To
retrieve the masked flux data use ``spectrum.masked.flux_masked``.

Reference/API
-------------

Expand Down
1 change: 1 addition & 0 deletions specutils/manipulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .estimate_uncertainty import * # noqa
from .extract_spectral_region import * # noqa
from .utils import * # noqa
from .manipulation import * # noqa
from .resample import *
82 changes: 82 additions & 0 deletions specutils/manipulation/manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
A module for analysis tools dealing with uncertainties or error analysis in
spectra.
"""

import copy
import numpy as np
import operator

__all__ = ['snr_threshold']


def snr_threshold(spectrum, value, op=operator.gt):
"""
Calculate the mean S/N of the spectrum based on the flux and uncertainty
in the spectrum. This will be calculated over the regions, if they
are specified.
Parameters
----------
spectrum : `~specutils.Spectrum1D`, `~specutils.SpectrumCollection` or `~astropy.nddata.NDData`
The spectrum object overwhich the S/N threshold will be calculated.
value: ``float``
Threshold value to be applied to flux / uncertainty.
op: One of operator.gt, operator.ge, operator.lt, operator.le or
the str equivalent '>', '>=', '<', '<='
The mathematical operator to apply for thresholding.
Returns
-------
spectrum: `~specutils.Spectrum1D`
Output object with ``spectrum.mask`` set based on threshold.
Notes
-----
The input object will need to have the uncertainty defined in order for the SNR
to be calculated.
"""

# Setup the mapping
operator_mapping = {
'>': operator.gt,
'<': operator.lt,
'>=': operator.ge,
'<=': operator.le
}

if not hasattr(spectrum, 'uncertainty') or spectrum.uncertainty is None:
raise Exception("S/N thresholding requires the uncertainty be defined.")

if (op not in [operator.gt, operator.ge, operator.lt, operator.le] and
op not in operator_mapping.keys()):
raise ValueError('Threshold operator must be a string or operator that represents ' +
'greater-than, less-than, greater-than-or-equal or ' +
'less-than-or-equal')

# If the operator passed in is a string, then map to the
# operator method.
if isinstance(op, str):
op = operator_mapping[op]

# Spectrum1D
if hasattr(spectrum, 'flux'):
data = spectrum.flux

# NDData
elif hasattr(spectrum, 'data'):
data = spectrum.data * (spectrum.unit if spectrum.unit is not None else 1)
else:
raise ValueError('Could not find data attribute.')

# NDData convention: Masks should follow the numpy convention that valid
# data points are marked by False and invalid ones with True.
mask = ~op(data / (spectrum.uncertainty.quantity), value)

spectrum_out = copy.copy(spectrum)
spectrum_out._mask = mask

return spectrum_out
2 changes: 1 addition & 1 deletion specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from astropy import units as u
from astropy import constants as cnst
from astropy.nddata import NDDataRef
from astropy.nddata import NDDataRef, NDUncertainty
from astropy.utils.decorators import lazyproperty

from ..wcs import WCSAdapter, WCSWrapper
Expand Down
123 changes: 123 additions & 0 deletions specutils/tests/test_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import operator
import pytest
import numpy as np

import astropy.units as u
from astropy.modeling import models
from astropy.nddata import StdDevUncertainty, NDData
from astropy.tests.helper import quantity_allclose
from specutils.wcs.wcs_wrapper import WCSWrapper

from ..spectra import Spectrum1D, SpectralRegion, SpectrumCollection
from ..manipulation import snr_threshold


def test_snr_threshold():

np.random.seed(42)

# Setup 1D spectrum
wavelengths = np.arange(0, 10)*u.um
flux = 100*np.abs(np.random.randn(10))*u.Jy
uncertainty = StdDevUncertainty(np.abs(np.random.randn(10))*u.Jy)
spectrum = Spectrum1D(spectral_axis=wavelengths, flux=flux, uncertainty=uncertainty)

spectrum_masked = snr_threshold(spectrum, 50)
assert all([x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, operator.gt)
assert all([x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, '>')
assert all([x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, operator.ge)
assert all([x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, '>=')
assert all([x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, operator.lt)
assert all([not x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, '<')
assert all([not x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, operator.le)
assert all([not x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

spectrum_masked = snr_threshold(spectrum, 50, '<=')
assert all([not x==y for x,y in zip(spectrum_masked.mask, [False, True, False, False, True, True, False, False, False, True])])

# Setup 3D spectrum
np.random.seed(42)
wavelengths = np.arange(0, 10)*u.um
flux = 100*np.abs(np.random.randn(3, 4, 10))*u.Jy
uncertainty = StdDevUncertainty(np.abs(np.random.randn(3, 4, 10))*u.Jy)
spectrum = Spectrum1D(spectral_axis=wavelengths, flux=flux, uncertainty=uncertainty)

spectrum_masked = snr_threshold(spectrum, 50)

masked_true = np.array([[[ False, True, True, False, True, True, False, False, False, False],
[True, False, True, False, False, True, False, False, False, False],
[ False, True, True, False, False, True, False, True, False, False],
[ False, False, True, False, False, False, True, False, False, True]],
[[ False, True, True, True, False, False, False, False, False, False],
[True, True, False, False, False, False, False, True, False, True],
[ False, True, False, False, False, False, True, False, True, True],
[ False, False, True, False, False, False, True, False, False, False]],
[[ False, False, False, True, False, False, False, False, False, True],
[True, False, False, False, False, False, True, False, True, False],
[ False, True, True, True, True, True, False, True, True, True],
[ False, True, False, False, True, True, True, False, False, False]]])

assert all([x==y for x,y in zip(spectrum_masked.mask.ravel(), masked_true.ravel())])


# Setup 3D NDData
np.random.seed(42)
flux = 100*np.abs(np.random.randn(3, 4, 10))*u.Jy
uncertainty = StdDevUncertainty(np.abs(np.random.randn(3, 4, 10))*u.Jy)
spectrum = NDData(data=flux, uncertainty=uncertainty)

spectrum_masked = snr_threshold(spectrum, 50)

masked_true = np.array([[[ False, True, True, False, True, True, False, False, False, False],
[True, False, True, False, False, True, False, False, False, False],
[ False, True, True, False, False, True, False, True, False, False],
[ False, False, True, False, False, False, True, False, False, True]],
[[ False, True, True, True, False, False, False, False, False, False],
[True, True, False, False, False, False, False, True, False, True],
[ False, True, False, False, False, False, True, False, True, True],
[ False, False, True, False, False, False, True, False, False, False]],
[[ False, False, False, True, False, False, False, False, False, True],
[True, False, False, False, False, False, True, False, True, False],
[ False, True, True, True, True, True, False, True, True, True],
[ False, True, False, False, True, True, True, False, False, False]]])

assert all([x==y for x,y in zip(spectrum_masked.mask.ravel(), masked_true.ravel())])


# Test SpectralCollection
np.random.seed(42)
flux = u.Quantity(np.random.sample((5, 10)), unit='Jy')
spectral_axis = u.Quantity(np.arange(50).reshape((5, 10)), unit='AA')
wcs = np.array([WCSWrapper.from_array(x).wcs for x in spectral_axis])
uncertainty = StdDevUncertainty(np.random.sample((5, 10)), unit='Jy')
mask = np.ones((5, 10)).astype(bool)
meta = [{'test': 5, 'info': [1, 2, 3]} for i in range(5)]

spec_coll = SpectrumCollection(
flux=flux, spectral_axis=spectral_axis, wcs=wcs,
uncertainty=uncertainty, mask=mask, meta=meta)

spec_coll_masked = snr_threshold(spec_coll, 3)
print(spec_coll_masked.mask)

ma = np.array([[True, True, True, True, True, True, True, False, False, True],
[True, False, True, True, True, True, True, True, False, True],
[True, True, False, True, True, True, True, False, True, True],
[True, True, True, False, False, True, True, True, True, True],
[True, True, True, True, True, True, True, True, False, True]])

assert all([x==y for x,y in zip(spec_coll_masked.mask.ravel(), ma.ravel())])

0 comments on commit 53a00de

Please sign in to comment.