-
-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from ..spectra.spectrum1d import Spectrum1D | ||
from ..spectra.spectrum_collection import SpectrumCollection | ||
from ..manipulation import FluxConservingResampler | ||
from ..manipulation import LinearInterpolatedResampler | ||
from ..manipulation import SplineInterpolatedResampler | ||
import numpy as np | ||
|
||
def _normalize_for_template_matching(observed_spectrum, template_spectrum): | ||
""" | ||
Calculate a scale factor to be applied to the template spectrum so the total flux in both spectra will be the same. | ||
Parameters | ||
---------- | ||
observed_spectrum : :class:`~specutils.Spectrum1D` | ||
the observed spectrum | ||
template_spectrum : :class:`~specutils.Spectrum1D` | ||
the template spectrum, which needs to be normalized in order to be compared with the observed_spectrum | ||
Returns | ||
------- | ||
`float` | ||
A float which will normalize the template spectrum's flux so that it can be compared to the observed spectrum | ||
""" | ||
num = np.sum((observed_spectrum.flux*template_spectrum.flux)/(observed_spectrum.uncertainty.array**2)) | ||
denom = np.sum((template_spectrum.flux/observed_spectrum.uncertainty.array)**2) | ||
return num/denom | ||
|
||
def _resample(resample_method): | ||
""" | ||
Find the user preferred method of resampling the template spectrum to fit the observed spectrum. | ||
Parameters | ||
---------- | ||
resample_method: `string` | ||
The type of resampling to be done on the template spectrum | ||
Returns | ||
------- | ||
:class:`~specutils.ResamplerBase` | ||
This is the actual class that will handle the resampling | ||
""" | ||
if resample_method == "flux_conserving": | ||
return FluxConservingResampler() | ||
elif resample_method == "linear_interpolated": | ||
return LinearInterpolatedResampler() | ||
elif resample_method == "spline_interpolated": | ||
return SplineInterpolatedResampler() | ||
else: | ||
return None | ||
|
||
def _template_match(observed_spectrum, template_spectrum, resample_method): | ||
""" | ||
Resample the template spectrum to match the wavelength of the observed spectrum. | ||
Then, calculate chi2 on the flux of the two spectra. | ||
Parameters | ||
---------- | ||
observed_spectrum : :class:`~specutils.Spectrum1D` | ||
the observed spectrum | ||
template_spectrum : :class:`~specutils.Spectrum1D` | ||
the template spectrum, which will be resampled to match the wavelength of the observed_spectrum | ||
Returns | ||
------- | ||
normalized_template_spectrum : :class:`~specutils.Spectrum1D` | ||
the template_spectrum that has been normalized | ||
chi2 : `float` | ||
the chi2 of the flux of the observed_spectrum and the flux of the normalized_template_spectrum | ||
""" | ||
# Resample template | ||
if _resample(resample_method) != 0: | ||
fluxc_resample = _resample(resample_method) | ||
template_obswavelength = fluxc_resample(template_spectrum, observed_spectrum.wavelength) | ||
|
||
# Normalize spectra | ||
normalization = _normalize_for_template_matching(observed_spectrum, template_obswavelength) | ||
|
||
# Numerator | ||
num_right = normalization*template_obswavelength.flux | ||
num = observed_spectrum.flux - num_right | ||
|
||
# Denominator | ||
denom = observed_spectrum.uncertainty.array * observed_spectrum.flux.unit | ||
|
||
# Get chi square | ||
result = (num/denom)**2 | ||
chi2 = np.sum(result) | ||
|
||
# Create normalized template spectrum, which will be returned with corresponding chi2 | ||
normalized_template_spectrum = Spectrum1D(spectral_axis=template_spectrum.spectral_axis, | ||
flux=template_spectrum.flux*normalization) | ||
|
||
return normalized_template_spectrum, chi2 | ||
|
||
def template_match(observed_spectrum, spectral_templates, resample_method="flux_conserving"): | ||
""" | ||
Find which spectral templates is the best fit to an observed spectrum by computing the chi-squared. If two | ||
template_spectra have the same chi2, the first template is returned | ||
Parameters | ||
---------- | ||
observed_spectrum : :class:`~specutils.Spectrum1D` | ||
the observed spectrum | ||
spectral_templates : :class:`~specutils.Spectrum1D` or :class:`~specutils.SpectrumCollection` or `list` or anything | ||
that will give a single :class:`~specutils.Spectrum1D` when iterated over. | ||
The template spectra, which will be resampled and normalized and compared to the observed_spectrum, where the | ||
smallest chi2 and normalized_template_spectrum will be returned | ||
Returns | ||
------- | ||
normalized_template_spectrum : :class:`~specutils.Spectrum1D` | ||
the template_spectrum that has been normalized | ||
chi2 : `float` | ||
the chi2 of the flux of the observed_spectrum and the flux of the normalized_template_spectrum | ||
""" | ||
if hasattr(spectral_templates, 'flux') and len(spectral_templates.flux.shape)==1: | ||
normalized_spectral_template, chi2 = _template_match(observed_spectrum, spectral_templates, resample_method) | ||
return normalized_spectral_template, chi2 | ||
|
||
# Loop through spectra in list and return spectrum with lowest chi square | ||
# and its corresponding chi square | ||
else: | ||
chi2_min = None | ||
smallest_chi_spec = None | ||
|
||
index = 0 | ||
try: | ||
for spectrum in spectral_templates: | ||
normalized_spectral_template, chi2 = _template_match(observed_spectrum, spectrum, resample_method) | ||
if chi2_min is None or chi2 < chi2_min: | ||
chi2_min = chi2 | ||
smallest_chi_spec = normalized_spectral_template | ||
smallest_chi_index = index | ||
index+=1 | ||
except Exception as e: | ||
print("Parameter spectral_templates is not iterable. The following error was fired: {}".format(e)) | ||
|
||
return smallest_chi_spec, chi2_min, smallest_chi_index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import astropy.units as u | ||
import numpy as np | ||
from astropy.nddata import StdDevUncertainty | ||
|
||
from ..spectra.spectrum1d import Spectrum1D | ||
from ..spectra.spectrum_collection import SpectrumCollection | ||
from ..analysis import template_comparison | ||
from astropy.tests.helper import quantity_allclose | ||
|
||
|
||
# TODO: Add some tests that are outliers: where the observed and template do not overlap (what happens?), | ||
# TODO: where there is minimal overlap (1 point)? | ||
|
||
def test_template_match_spectrum(): | ||
""" | ||
Test template_match when both observed and template spectra have the same wavelength axis | ||
""" | ||
# Seed np.random so that results are consistent | ||
np.random.seed(42) | ||
|
||
# Create test spectra | ||
spec_axis = np.linspace(0, 50, 50) * u.AA | ||
spec = Spectrum1D(spectral_axis=spec_axis, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
spec1 = Spectrum1D(spectral_axis=spec_axis, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
# Get result from template_match | ||
tm_result = template_comparison.template_match(spec, spec1) | ||
|
||
# Create new spectrum for comparison | ||
spec_result = Spectrum1D(spectral_axis=spec_axis, | ||
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1)) | ||
|
||
assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy) | ||
assert tm_result[1] == 40093.28353756253 | ||
|
||
def test_template_match_with_resample(): | ||
""" | ||
Test template_match when both observed and template spectra have different wavelength axis using resampling | ||
""" | ||
np.random.seed(42) | ||
|
||
# Create test spectra | ||
spec_axis1 = np.linspace(0, 50, 50) * u.AA | ||
spec_axis2 = np.linspace(0, 50, 50) * u.AA | ||
spec = Spectrum1D(spectral_axis=spec_axis1, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
spec1 = Spectrum1D(spectral_axis=spec_axis2, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
# Get result from template_match | ||
tm_result = template_comparison.template_match(spec, spec1) | ||
|
||
# Create new spectrum for comparison | ||
spec_result = Spectrum1D(spectral_axis=spec_axis1, | ||
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1)) | ||
|
||
assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy) | ||
assert tm_result[1] == 40093.28353756253 | ||
|
||
def test_template_match_list(): | ||
""" | ||
Test template_match when template spectra are in a list | ||
""" | ||
np.random.seed(42) | ||
|
||
# Create test spectra | ||
spec_axis1 = np.linspace(0, 50, 50) * u.AA | ||
spec_axis2 = np.linspace(0, 50, 50) * u.AA | ||
spec = Spectrum1D(spectral_axis=spec_axis1, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
spec1 = Spectrum1D(spectral_axis=spec_axis2, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
spec2 = Spectrum1D(spectral_axis=spec_axis2, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
# Combine spectra into list | ||
template_list = [spec1, spec2] | ||
|
||
# Get result from template_match | ||
tm_result = template_comparison.template_match(spec, template_list) | ||
|
||
assert tm_result[1] == 40093.28353756253 | ||
|
||
def test_template_match_spectrum_collection(): | ||
""" | ||
Test template_match when template spectra are in a SpectrumCollection object | ||
""" | ||
np.random.seed(42) | ||
|
||
# Create test spectra | ||
spec_axis1 = np.linspace(0, 50, 50) * u.AA | ||
spec_axis2 = np.linspace(0, 50, 50) * u.AA | ||
spec = Spectrum1D(spectral_axis=spec_axis1, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
spec1 = Spectrum1D(spectral_axis=spec_axis2, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
spec2 = Spectrum1D(spectral_axis=spec_axis2, | ||
flux=np.random.randn(50) * u.Jy, | ||
uncertainty=StdDevUncertainty(np.random.sample(50), unit='Jy')) | ||
|
||
# Combine spectra into SpectrumCollection object | ||
spec_coll = SpectrumCollection.from_spectra([spec1, spec2]) | ||
|
||
# Get result from template_match | ||
tm_result = template_comparison.template_match(spec, spec_coll) | ||
|
||
assert tm_result[1] == 40093.28353756253 |