Skip to content

Commit

Permalink
Merge 5a432e1 into bc47905
Browse files Browse the repository at this point in the history
  • Loading branch information
javerbukh committed Sep 13, 2019
2 parents bc47905 + 5a432e1 commit f5d5254
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 0 deletions.
24 changes: 24 additions & 0 deletions docs/analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ Each of the width analysis functions are applied to this spectrum below:
<Quantity 89.99999455 GHz>
Template Comparison
-------------------

With one observed spectrum and n template spectra, the process to do template matching is:

1. Move the templates as close as possible to the observed spectrum by doing redshifting, matching the resolution,
and matching the wavelength spacing
2. Then loop through all of the n template spectra and run chi square on them using the observed spectrum as the
expected frequency and the template spectrum as the observed
3. Once you have a corresponding chi square for each template spectrum, return the lowest chi square and its
corresponding template spectrum (normalized) and the index of the template spectrum if the original template
parameter is iterable

An example of how to do template matching in is:

.. code-block:: python
>>> from specutils.analysis import template_comparison
>>> spec_axis = np.linspace(0, 50, 50) * u.AA
>>> spec = Spectrum1D(spectral_axis=spec_axis, flux=np.random.randn(50) * u.Jy, uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
>>> spec1 = Spectrum1D(spectral_axis=spec_axis, flux=np.random.randn(50) * u.Jy, uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
>>> tm_result = template_comparison.template_match(spec, spec1) # doctest:+FLOAT_CMP
Reference/API
-------------
.. automodapi:: specutils.analysis
Expand Down
1 change: 1 addition & 0 deletions specutils/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .uncertainty import * # noqa
from .location import * # noqa
from .width import * # noqa
from .template_comparison import *
138 changes: 138 additions & 0 deletions specutils/analysis/template_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from ..spectra.spectrum1d import Spectrum1D
from ..spectra.spectrum_collection import SpectrumCollection
from ..manipulation import FluxConservingResampler
from ..manipulation import LinearInterpolatedResampler
from ..manipulation import SplineInterpolatedResampler
import numpy as np

def _normalize_for_template_matching(observed_spectrum, template_spectrum):
"""
Calculate a scale factor to be applied to the template spectrum so the total flux in both spectra will be the same.
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
the observed spectrum
template_spectrum : :class:`~specutils.Spectrum1D`
the template spectrum, which needs to be normalized in order to be compared with the observed_spectrum
Returns
-------
`float`
A float which will normalize the template spectrum's flux so that it can be compared to the observed spectrum
"""
num = np.sum((observed_spectrum.flux*template_spectrum.flux)/(observed_spectrum.uncertainty.array**2))
denom = np.sum((template_spectrum.flux/observed_spectrum.uncertainty.array)**2)
return num/denom

def _resample(resample_method):
"""
Find the user preferred method of resampling the template spectrum to fit the observed spectrum.
Parameters
----------
resample_method: `string`
The type of resampling to be done on the template spectrum
Returns
-------
:class:`~specutils.ResamplerBase`
This is the actual class that will handle the resampling
"""
if resample_method == "flux_conserving":
return FluxConservingResampler()
elif resample_method == "linear_interpolated":
return LinearInterpolatedResampler()
elif resample_method == "spline_interpolated":
return SplineInterpolatedResampler()
else:
return None

def _template_match(observed_spectrum, template_spectrum, resample_method):
"""
Resample the template spectrum to match the wavelength of the observed spectrum.
Then, calculate chi2 on the flux of the two spectra.
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
the observed spectrum
template_spectrum : :class:`~specutils.Spectrum1D`
the template spectrum, which will be resampled to match the wavelength of the observed_spectrum
Returns
-------
normalized_template_spectrum : :class:`~specutils.Spectrum1D`
the template_spectrum that has been normalized
chi2 : `float`
the chi2 of the flux of the observed_spectrum and the flux of the normalized_template_spectrum
"""
# Resample template
if _resample(resample_method) != 0:
fluxc_resample = _resample(resample_method)
template_obswavelength = fluxc_resample(template_spectrum, observed_spectrum.wavelength)

# Normalize spectra
normalization = _normalize_for_template_matching(observed_spectrum, template_obswavelength)

# Numerator
num_right = normalization*template_obswavelength.flux
num = observed_spectrum.flux - num_right

# Denominator
denom = observed_spectrum.uncertainty.array * observed_spectrum.flux.unit

# Get chi square
result = (num/denom)**2
chi2 = np.sum(result)

# Create normalized template spectrum, which will be returned with corresponding chi2
normalized_template_spectrum = Spectrum1D(spectral_axis=template_spectrum.spectral_axis,
flux=template_spectrum.flux*normalization)

return normalized_template_spectrum, chi2

def template_match(observed_spectrum, spectral_templates, resample_method="flux_conserving"):
"""
Find which spectral templates is the best fit to an observed spectrum by computing the chi-squared. If two
template_spectra have the same chi2, the first template is returned
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
the observed spectrum
spectral_templates : :class:`~specutils.Spectrum1D` or :class:`~specutils.SpectrumCollection` or `list` or anything
that will give a single :class:`~specutils.Spectrum1D` when iterated over.
The template spectra, which will be resampled and normalized and compared to the observed_spectrum, where the
smallest chi2 and normalized_template_spectrum will be returned
Returns
-------
normalized_template_spectrum : :class:`~specutils.Spectrum1D`
the template_spectrum that has been normalized
chi2 : `float`
the chi2 of the flux of the observed_spectrum and the flux of the normalized_template_spectrum
"""
if hasattr(spectral_templates, 'flux') and len(spectral_templates.flux.shape)==1:
normalized_spectral_template, chi2 = _template_match(observed_spectrum, spectral_templates, resample_method)
return normalized_spectral_template, chi2

# Loop through spectra in list and return spectrum with lowest chi square
# and its corresponding chi square
else:
chi2_min = None
smallest_chi_spec = None

index = 0
try:
for spectrum in spectral_templates:
normalized_spectral_template, chi2 = _template_match(observed_spectrum, spectrum, resample_method)
if chi2_min is None or chi2 < chi2_min:
chi2_min = chi2
smallest_chi_spec = normalized_spectral_template
smallest_chi_index = index
index+=1
except Exception as e:
print("Parameter spectral_templates is not iterable. The following error was fired: {}".format(e))

return smallest_chi_spec, chi2_min, smallest_chi_index
122 changes: 122 additions & 0 deletions specutils/tests/test_template_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import astropy.units as u
import numpy as np
from astropy.nddata import StdDevUncertainty

from ..spectra.spectrum1d import Spectrum1D
from ..spectra.spectrum_collection import SpectrumCollection
from ..analysis import template_comparison
from astropy.tests.helper import quantity_allclose


# TODO: Add some tests that are outliers: where the observed and template do not overlap (what happens?),
# TODO: where there is minimal overlap (1 point)?

def test_template_match_spectrum():
"""
Test template_match when both observed and template spectra have the same wavelength axis
"""
# Seed np.random so that results are consistent
np.random.seed(42)

# Create test spectra
spec_axis = np.linspace(0, 50, 50) * u.AA
spec = Spectrum1D(spectral_axis=spec_axis,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

spec1 = Spectrum1D(spectral_axis=spec_axis,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

# Get result from template_match
tm_result = template_comparison.template_match(spec, spec1)

# Create new spectrum for comparison
spec_result = Spectrum1D(spectral_axis=spec_axis,
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1))

assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy)
assert tm_result[1] == 40093.28353756253

def test_template_match_with_resample():
"""
Test template_match when both observed and template spectra have different wavelength axis using resampling
"""
np.random.seed(42)

# Create test spectra
spec_axis1 = np.linspace(0, 50, 50) * u.AA
spec_axis2 = np.linspace(0, 50, 50) * u.AA
spec = Spectrum1D(spectral_axis=spec_axis1,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

spec1 = Spectrum1D(spectral_axis=spec_axis2,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

# Get result from template_match
tm_result = template_comparison.template_match(spec, spec1)

# Create new spectrum for comparison
spec_result = Spectrum1D(spectral_axis=spec_axis1,
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1))

assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy)
assert tm_result[1] == 40093.28353756253

def test_template_match_list():
"""
Test template_match when template spectra are in a list
"""
np.random.seed(42)

# Create test spectra
spec_axis1 = np.linspace(0, 50, 50) * u.AA
spec_axis2 = np.linspace(0, 50, 50) * u.AA
spec = Spectrum1D(spectral_axis=spec_axis1,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

spec1 = Spectrum1D(spectral_axis=spec_axis2,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
spec2 = Spectrum1D(spectral_axis=spec_axis2,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

# Combine spectra into list
template_list = [spec1, spec2]

# Get result from template_match
tm_result = template_comparison.template_match(spec, template_list)

assert tm_result[1] == 40093.28353756253

def test_template_match_spectrum_collection():
"""
Test template_match when template spectra are in a SpectrumCollection object
"""
np.random.seed(42)

# Create test spectra
spec_axis1 = np.linspace(0, 50, 50) * u.AA
spec_axis2 = np.linspace(0, 50, 50) * u.AA
spec = Spectrum1D(spectral_axis=spec_axis1,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

spec1 = Spectrum1D(spectral_axis=spec_axis2,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))
spec2 = Spectrum1D(spectral_axis=spec_axis2,
flux=np.random.randn(50) * u.Jy,
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy'))

# Combine spectra into SpectrumCollection object
spec_coll = SpectrumCollection.from_spectra([spec1, spec2])

# Get result from template_match
tm_result = template_comparison.template_match(spec, spec_coll)

assert tm_result[1] == 40093.28353756253

0 comments on commit f5d5254

Please sign in to comment.