diff --git a/specutils/manipulation/resample.py b/specutils/manipulation/resample.py index d5a113139..b25ec47b8 100644 --- a/specutils/manipulation/resample.py +++ b/specutils/manipulation/resample.py @@ -6,9 +6,9 @@ from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance from ..spectra import Spectrum1D +from ..spectra.spectrum_collection import runs_on_spectrum_collection -__all__ = ['ResamplerBase', 'FluxConservingResampler', - 'LinearInterpolatedResampler', 'SplineInterpolatedResampler'] +__all__ = ['ResamplerBase', 'FluxConservingResampler', 'LinearInterpolatedResampler', 'SplineInterpolatedResampler'] class ResamplerBase(ABC): @@ -137,6 +137,7 @@ def _resample_matrix(self, orig_lamb, fin_lamb): # set bins that don't overlap 100% with original bins # to zero by checking edges, and applying generated mask + # to zero by checking edges, and applying generated mask left_clip = np.where(fin_edges[:-1] - orig_edges[0] < 0, 0, 1) right_clip = np.where(orig_edges[-1] - fin_edges[1:] < 0, 0, 1) keep_overlapping_matrix = left_clip * right_clip @@ -145,6 +146,7 @@ def _resample_matrix(self, orig_lamb, fin_lamb): return resamp_mat + @runs_on_spectrum_collection def resample1d(self, orig_spectrum, fin_lamb): """ Create a re-sampling matrix to be used in re-sampling spectra in a way diff --git a/specutils/spectra/spectrum_collection.py b/specutils/spectra/spectrum_collection.py index cbfcc7600..7e8b80cbd 100644 --- a/specutils/spectra/spectrum_collection.py +++ b/specutils/spectra/spectrum_collection.py @@ -1,4 +1,5 @@ import logging +import functools import astropy.units as u import numpy as np @@ -244,3 +245,34 @@ def __repr__(self): Uncertainty type: {}""".format( self.ndim, self.shape, self.flux.unit, self.spectral_axis.unit, self.uncertainty.uncertainty_type if self.uncertainty is not None else None) + + +def runs_on_spectrum_collection(func): + """ + Decorator to wrap Spectrum1D manipulation functionality so that it can + run on a SpectrumCollection + + For this decorator to work, the first argument in your function/method + should be the Spectrum1D or SpectrumCollection object, although if it's + not it will simply pass the original function. + + Another fallout, right now it's expecting this to be attached to a class + method, since it's expecting the first argument to be self, and the second + argument to be the spectrum1D or SpectrumCollection + """ + @functools.wraps(func) + + def wrapper(sarg, farg, *args, **kwargs): + + if isinstance(farg, SpectrumCollection): + new_spectra = [] + + for spec in farg: + new_spectra.append(func(sarg, spec, *args, **kwargs)) + + return SpectrumCollection.from_spectra(new_spectra) + + else: + return func(sarg, farg, *args, **kwargs) + + return wrapper diff --git a/specutils/tests/test_resample.py b/specutils/tests/test_resample.py index 05abd19d8..74db85b6c 100644 --- a/specutils/tests/test_resample.py +++ b/specutils/tests/test_resample.py @@ -4,7 +4,7 @@ from astropy.nddata import InverseVariance, StdDevUncertainty from astropy.tests.helper import assert_quantity_allclose -from ..spectra.spectrum1d import Spectrum1D +from ..spectra import Spectrum1D, SpectrumCollection from ..tests.spectral_examples import simulated_spectra from ..manipulation.resample import FluxConservingResampler, LinearInterpolatedResampler, SplineInterpolatedResampler @@ -162,3 +162,20 @@ def test_expanded_grid_interp_spline(): assert_quantity_allclose(results.flux, np.array([np.nan, 3.98808594, 6.94042969, 6.45869141, 5.89921875, 7.29736328, np.nan, np.nan, np.nan])*u.mJy) + + +def test_spectrum_collection(): + """ + Test if spectrum collection decorator is working with resample + """ + + spec = Spectrum1D(spectral_axis=np.linspace(0, 50, 50) * u.AA, + flux = np.random.randn(50) * u.Jy, + uncertainty = StdDevUncertainty(np.random.sample(50), unit='Jy')) + spec1 = Spectrum1D(spectral_axis=np.linspace(20, 60, 50) * u.AA, + flux = np.random.randn(50) * u.Jy, + uncertainty = StdDevUncertainty(np.random.sample(50), unit='Jy')) + spec_coll = SpectrumCollection.from_spectra([spec, spec1]) + + inst = FluxConservingResampler() + results = inst(spec_coll, np.linspace(20,40,40))