Skip to content

Commit

Permalink
Merge 028d0f7 into dadab3b
Browse files Browse the repository at this point in the history
  • Loading branch information
ibusko committed Oct 23, 2019
2 parents dadab3b + 028d0f7 commit da6bc0f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
64 changes: 64 additions & 0 deletions specutils/analysis/correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np

import astropy.units as u


def _normalize(observed_spectrum, template_spectrum):
"""
Calculate a scale factor to be applied to the template spectrum so the
total flux in both spectra will be the same.
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
The observed spectrum.
template_spectrum : :class:`~specutils.Spectrum1D`
The template spectrum, which needs to be normalized in order to be
compared with the observed spectrum.
Returns
-------
`float`
A float which will normalize the template spectrum's flux so that it
can be compared to the observed spectrum.
"""
num = np.sum((observed_spectrum.flux*template_spectrum.flux)/
(observed_spectrum.uncertainty.array**2))
denom = np.sum((template_spectrum.flux/
observed_spectrum.uncertainty.array)**2)

return num/denom


def template_correlate(observed_spectrum, template_spectrum):
"""
Compute cross-correlation of the observed and template spectra
Parameters
----------
observed_spectrum : :class:`~specutils.Spectrum1D`
The observed spectrum.
template_spectrum : :class:`~specutils.Spectrum1D`
The template spectrum, which will be correlated with
the observed spectrum.
Returns
-------
tuple : (`~astropy.units.Quantity`, `~astropy.units.Quantity`)
Arrays with correlation values and lags in km/s
"""
# Normalize template
normalization = _normalize(observed_spectrum, template_spectrum)

# Correlation
corr = np.correlate(observed_spectrum.flux.value,
(template_spectrum.flux.value * normalization),
mode='full')

# Lag in km/s
equiv = getattr(u.equivalencies, 'doppler_{0}'.format(
observed_spectrum.velocity_convention))(observed_spectrum.rest_value)

lag = observed_spectrum.spectral_axis.to(u.km / u.s, equivalencies=equiv)

return (corr * u.dimensionless_unscaled, lag)
59 changes: 59 additions & 0 deletions specutils/tests/test_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import astropy.units as u

from astropy.nddata import StdDevUncertainty
from astropy.modeling import models

from ..spectra.spectrum1d import Spectrum1D
from ..analysis import correlation

SIZE = 40


def test_correlation():
"""
Test correlation when both observed and template spectra have the same wavelength axis
"""
# Seed np.random so that results are consistent
np.random.seed(41)

# Create test spectra
spec_axis = np.linspace(5000., 5040., num=SIZE) * u.AA

# Two narrow Gaussians are offset from each other so
# as to generate a correlation peak at a expected lag.
f1 = np.random.randn(SIZE) * u.Jy
f2 = np.random.randn(SIZE) * u.Jy
g1 = models.Gaussian1D(amplitude=30 * u.Jy, mean=5020 * u.AA, stddev=2 * u.AA)
g2 = models.Gaussian1D(amplitude=30 * u.Jy, mean=5023 * u.AA, stddev=2 * u.AA)

flux1 = f1 + g1(spec_axis)
flux2 = f2 + g2(spec_axis)

# Observed spectrum must have a rest wavelength value set in.
spec1 = Spectrum1D(spectral_axis=spec_axis,
flux=flux1,
uncertainty=StdDevUncertainty(np.random.sample(SIZE), unit='Jy'),
velocity_convention='optical',
rest_value=spec_axis[int(SIZE/2)])

spec2 = Spectrum1D(spectral_axis=spec_axis,
flux=flux2,
uncertainty=StdDevUncertainty(np.random.sample(SIZE), unit='Jy'))

# Get result from correlation
corr, lag = correlation.template_correlate(spec1, spec2)

# Check units
assert corr.unit == u.dimensionless_unscaled
assert lag.unit == u.km / u.s

# Check that lag at mid-point is zero and lags are symmetrical
midpoint = int(len(lag) / 2)
assert int((lag[midpoint]).value) == 0
np.testing.assert_almost_equal(lag[midpoint+10].value, (-(lag[midpoint-10])).value, 0.01)

# Check position of correlation peak.
maximum = np.argmax(corr)
assert maximum == 36
np.testing.assert_almost_equal(lag[maximum].value, 980., 0.1)

0 comments on commit da6bc0f

Please sign in to comment.