Skip to content

Commit

Permalink
Merge 7eef14f into bc47905
Browse files Browse the repository at this point in the history
  • Loading branch information
javerbukh committed Sep 16, 2019
2 parents bc47905 + 7eef14f commit b9f9b92
Show file tree
Hide file tree
Showing 6 changed files with 396 additions and 4 deletions.
24 changes: 24 additions & 0 deletions docs/analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ Each of the width analysis functions are applied to this spectrum below:
<Quantity 89.99999455 GHz>
Template Comparison
-------------------

With one observed spectrum and n template spectra, the process to do template matching is:

1. Move the templates as close as possible to the observed spectrum by doing redshifting, matching the resolution,
and matching the wavelength spacing
2. Then loop through all of the n template spectra and run chi square on them using the observed spectrum as the
expected frequency and the template spectrum as the observed
3. Once you have a corresponding chi square for each template spectrum, return the lowest chi square and its
corresponding template spectrum (normalized) and the index of the template spectrum if the original template
parameter is iterable

An example of how to do template matching in is:

.. code-block:: python
>>> from specutils.analysis import template_comparison
>>> spec_axis = np.linspace(0, 50, 50) * u.AA
>>> observed_spectrum = Spectrum1D(spectral_axis=spec_axis, flux=np.random.randn(50) * u.Jy, uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
>>> spectral_template = Spectrum1D(spectral_axis=spec_axis, flux=np.random.randn(50) * u.Jy, uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
>>> tm_result = template_comparison.template_match(observed_spectrum, spectral_template) # doctest:+FLOAT_CMP
Reference/API
-------------
.. automodapi:: specutils.analysis
Expand Down
1 change: 1 addition & 0 deletions specutils/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .uncertainty import * # noqa
from .location import * # noqa
from .width import * # noqa
from .template_comparison import *
162 changes: 162 additions & 0 deletions specutils/analysis/template_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import numpy as np

from ..manipulation import (FluxConservingResampler,
LinearInterpolatedResampler,
SplineInterpolatedResampler)
from ..spectra.spectrum1d import Spectrum1D

def _normalize_for_template_matching(observed_spectrum, template_spectrum):
"""
Calculate a scale factor to be applied to the template spectrum so the
total flux in both spectra will be the same.
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
The observed spectrum.
template_spectrum : :class:`~specutils.Spectrum1D`
The template spectrum, which needs to be normalized in order to be
compared with the observed spectrum.
Returns
-------
`float`
A float which will normalize the template spectrum's flux so that it
can be compared to the observed spectrum.
"""
num = np.sum((observed_spectrum.flux*template_spectrum.flux)/
(observed_spectrum.uncertainty.array**2))
denom = np.sum((template_spectrum.flux/
observed_spectrum.uncertainty.array)**2)

return num/denom


def _resample(resample_method):
"""
Find the user preferred method of resampling the template spectrum to fit
the observed spectrum.
Parameters
----------
resample_method: `string`
The type of resampling to be done on the template spectrum
Returns
-------
:class:`~specutils.ResamplerBase`
This is the actual class that will handle the resampling
"""
if resample_method == "flux_conserving":
return FluxConservingResampler()

if resample_method == "linear_interpolated":
return LinearInterpolatedResampler()

if resample_method == "spline_interpolated":
return SplineInterpolatedResampler()

return None


def _template_match(observed_spectrum, template_spectrum, resample_method):
"""
Resample the template spectrum to match the wavelength of the observed
spectrum. Then, calculate chi2 on the flux of the two spectra.
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
The observed spectrum.
template_spectrum : :class:`~specutils.Spectrum1D`
The template spectrum, which will be resampled to match the wavelength
of the observed spectrum.
Returns
-------
normalized_template_spectrum : :class:`~specutils.Spectrum1D`
The normalized spectrum template.
chi2 : `float`
The chi2 of the flux of the observed spectrum and the flux of the
normalized template spectrum.
"""
# Resample template
if _resample(resample_method) != 0:
fluxc_resample = _resample(resample_method)
template_obswavelength = fluxc_resample(template_spectrum,
observed_spectrum.wavelength)

# Normalize spectra
normalization = _normalize_for_template_matching(observed_spectrum,
template_obswavelength)

# Numerator
num_right = normalization * template_obswavelength.flux
num = observed_spectrum.flux - num_right

# Denominator
denom = observed_spectrum.uncertainty.array * observed_spectrum.flux.unit

# Get chi square
result = (num/denom)**2
chi2 = np.sum(result.value)

# Create normalized template spectrum, which will be returned with
# corresponding chi2
normalized_template_spectrum = Spectrum1D(
spectral_axis=template_spectrum.spectral_axis,
flux=template_spectrum.flux*normalization)

return normalized_template_spectrum, chi2


def template_match(observed_spectrum, spectral_templates,
resample_method="flux_conserving"):
"""
Find which spectral templates is the best fit to an observed spectrum by
computing the chi-squared. If two template_spectra have the same chi2, the
first template is returned.
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
The observed spectrum.
spectral_templates : :class:`~specutils.Spectrum1D` or :class:`~specutils.SpectrumCollection` or `list`
That will give a single :class:`~specutils.Spectrum1D` when iterated
over. The template spectra, which will be resampled, normalized, and
compared to the observed spectrum, where the smallest chi2 and
normalized template spectrum will be returned.
Returns
-------
normalized_template_spectrum : :class:`~specutils.Spectrum1D`
The template spectrum that has been normalized.
chi2 : `float`
The chi2 of the flux of the observed_spectrum and the flux of the
normalized template spectrum.
smallest_chi_index : `int`
The index of the spectrum with the smallest chi2 in spectral templates.
"""
if hasattr(spectral_templates, 'flux') and len(spectral_templates.flux.shape) == 1:
normalized_spectral_template, chi2 = _template_match(
observed_spectrum, spectral_templates, resample_method)

return normalized_spectral_template, chi2

# At this point, the template spectrum is either a ``SpectrumCollection``
# or a multi-dimensional``Spectrum1D``. Loop through the object and return
# the template spectrum with the lowest chi square and its corresponding
# chi square.
chi2_min = None
smallest_chi_spec = None

for index, spectrum in enumerate(spectral_templates):
normalized_spectral_template, chi2 = _template_match(
observed_spectrum, spectrum, resample_method)

if chi2_min is None or chi2 < chi2_min:
chi2_min = chi2
smallest_chi_spec = normalized_spectral_template
smallest_chi_index = index

return smallest_chi_spec, chi2_min, smallest_chi_index
49 changes: 45 additions & 4 deletions specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from copy import deepcopy

import numpy as np
from astropy import units as u
from astropy.nddata import NDDataRef
from astropy.nddata import NDDataRef, NDUncertainty
from astropy.utils.decorators import lazyproperty
from astropy.nddata import NDUncertainty
from ..wcs import WCSWrapper, WCSAdapter

from ..wcs import WCSAdapter, WCSWrapper
from .spectrum_mixin import OneDSpectrumMixin

__all__ = ['Spectrum1D']
Expand Down Expand Up @@ -123,6 +124,47 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None,
raise ValueError('Flux axis ({}) and uncertainty ({}) shapes must be the same.'.format(
flux.shape, self.uncertainty.array.shape))

def __getitem__(self, item):
"""
Override the class indexer. We do this here because there are two cases
for slicing on a ``Spectrum1D``:
1.) When the flux is one dimensional, indexing represents a single
flux value at a particular spectral axis bin, and returns a new
``Spectrum1D`` where all attributes are sliced.
2.) When flux is multi-dimensional (i.e. several fluxes over the
same spectral axis), indexing returns a new ``Spectrum1D`` with
the sliced flux range and a deep copy of all other attributes.
The first case is handled by the parent class, while the second is
handled here.
"""
if len(self.flux.shape) > 1:
return self._copy(
flux=self.flux[item], uncertainty=self.uncertainty[item]
if self.uncertainty is not None else None)

return super().__getitem__(item)

def _copy(self, **kwargs):
"""
Peform deep copy operations on each attribute of the ``Spectrum1D``
object.
"""
alt_kwargs = dict(
flux=deepcopy(self.flux),
spectral_axis=deepcopy(self.spectral_axis),
uncertainty=deepcopy(self.uncertainty),
wcs=deepcopy(self.wcs),
mask=deepcopy(self.mask),
meta=deepcopy(self.meta),
unit=deepcopy(self.unit),
velocity_convention=deepcopy(self.velocity_convention),
rest_value=deepcopy(self.rest_value))

alt_kwargs.update(kwargs)

return self.__class__(**alt_kwargs)

@property
def frequency(self):
Expand Down Expand Up @@ -284,7 +326,6 @@ def __repr__(self):

return result


def spectral_resolution(self, true_dispersion, delta_dispersion, axis=-1):
"""Evaluate the probability distribution of the spectral resolution.
Expand Down
17 changes: 17 additions & 0 deletions specutils/tests/test_slicing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import astropy.units as u
import astropy.wcs as fitswcs
from astropy.tests.helper import quantity_allclose
import numpy as np
from numpy.testing import assert_allclose

Expand Down Expand Up @@ -72,6 +73,7 @@ def test_slicing():
assert sub_spec.frequency.unit == u.GHz
assert np.allclose(sub_spec.frequency.value, np.array([74948.1145, 59958.4916, 49965.40966667, 42827.494]))


def test_slicing_with_fits():
my_wcs = fitswcs.WCS(header={'CDELT1': 1, 'CRVAL1': 6562.8, 'CUNIT1': 'Angstrom',
'CTYPE1': 'WAVE', 'RESTFRQ': 1400000000, 'CRPIX1': 25})
Expand All @@ -82,3 +84,18 @@ def test_slicing_with_fits():
assert isinstance(spec_slice, Spectrum1D)
assert spec_slice.flux.size == 4
assert np.allclose(spec_slice.wcs.pixel_to_world([6, 7, 8, 9]).value, spec.wcs.pixel_to_world([6, 7, 8, 9]).value)


def test_slicing_multidim():
spec = Spectrum1D(spectral_axis=np.arange(10) * u.AA,
flux=np.random.sample((5, 10)) * u.Jy)

spec1 = spec[0]
spec2 = spec[1:3]

assert spec1.flux[0] == spec.flux[0][0]
assert quantity_allclose(spec1.spectral_axis, spec.spectral_axis)
assert spec.flux.shape[1:] == spec1.flux.shape

assert quantity_allclose(spec2.flux, spec.flux[1:3])
assert quantity_allclose(spec2.spectral_axis, spec.spectral_axis)
Loading

0 comments on commit b9f9b92

Please sign in to comment.