Skip to content

Commit

Permalink
Resampling now working
Browse files Browse the repository at this point in the history
  • Loading branch information
mchalela committed Jul 10, 2021
1 parent a90501b commit 1a0a7c8
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 34 deletions.
109 changes: 90 additions & 19 deletions nirdust.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class HeaderKeywordError(KeyError):


def _remove_internals(attrs_dict):
"""Remove internal attributes of a class.
Used when the initialization attributes are required, but attrs.asdict(obj)
also returns internal attributes. The convention for internal attributes
here is one underscore.
"""
new_dict = attrs_dict.copy()
for k in attrs_dict.keys():
if k.endswith("_"):
Expand Down Expand Up @@ -685,31 +691,63 @@ def read_spectrum(file_name, extension=None, z=0):
# RESAMPLE SPECTRA TO MATCH SPECTRAL RESOLUTIONS
# ==============================================================================

def _downscale(low_disp_sp, high_disp_sp):

input_spectra = high_disp_sp.spec1d_
resample_axis = low_disp_sp.spectral_axis
def _rescale(sp, reference_sp):
"""Resample a given spectrum to a reference spectrum.
The first spectrum will be resampled to have the same spectral_axis as
the reference spectrum. The resampling algorithm is the specutils method
FluxConservingResampler.
resampler = FluxConservingResampler(extrapolation_treatment='nan_fill')
output_sp = resampler(input_spectra, resample_axis)
Notes
-----
nan values may occur at the edges where the resampler is forced
to extrapolate.
"""
input_sp1d = sp.spec1d_
resample_axis = reference_sp.spectral_axis

resampled_freq_axis = output_sp.spectral_axis.to(u.Hz)
resampler = FluxConservingResampler(extrapolation_treatment="nan_fill")
output_sp1d = resampler(input_sp1d, resample_axis)

kwargs = _remove_internals(attr.asdict(high_disp_sp))
resampled_freq_axis = output_sp1d.spectral_axis.to(u.Hz)

kwargs = _remove_internals(attr.asdict(sp))
kwargs.update(
flux=output_sp.flux,
flux=output_sp1d.flux,
frequency_axis=resampled_freq_axis,
)
return low_disp_sp, NirdustSpectrum(**kwargs)
return NirdustSpectrum(**kwargs)


def _clean_and_match(sp1, sp2):
"""Clean nan values and apply the same mask to both spectrums."""
# nan values occur in the flux variable
# check for invalid values in both spectrums
mask = np.isfinite(sp1.flux) & np.isfinite(sp2.flux)

sp_list = []
for sp in [sp1, sp2]:
kw = _remove_internals(attr.asdict(sp))
kw.update(flux=sp.flux[mask], frequency_axis=sp.frequency_axis[mask])
sp_list.append(NirdustSpectrum(**kw))

def spectrum_resampling(first_sp, second_sp):
return sp_list


def spectrum_resampling(
first_sp,
second_sp,
scaling="downscale",
clean=True,
):
"""Resample the higher resolution spectrum.
Spectrum_resampling uses the spectral_axis of the lower resolution spectrum
to resample the higher resolution one. To do so this function uses the
FluxConservingResampler() class of 'Specutils'. The order of the input
spectra is arbitrary and the order in the output is the same as in the
input. Only the higher resolution spectrum will be modified, the lower
Spectrum_resampling uses the spectral_axis of the lower resolution
spectrum to resample the higher resolution one. To do so this function
uses the FluxConservingResampler() class of 'Specutils'. The order of the
input spectra is arbitrary and the order in the output is the same as in
the input. Only the higher resolution spectrum will be modified, the lower
resolution spectrum will be unaltered. It is recomended to run
spectrum_resampling after 'cut_edges'.
Expand All @@ -719,18 +757,51 @@ def spectrum_resampling(first_sp, second_sp):
second_sp: NirdustSpectrum object
scaling: string
If 'downscale' the higher resolution spectrum will be resampled to
match the lower resolution spectrum. If 'upscale' the lower resolution
spectrum.
clean: bool
Flag to indicate if the spectrums have to be cleaned by nan values
after the rescaling procedure. nan values occur at the edges of the
resampled spectrum when it is forced to extrapolate beyond the
spectral range of the reference spectrum.
Return
------
out: NirdustSpectrum, NirdustSpectrum
"""
if scaling.lower() not in ["downscale", "upscale"]:
raise ValueError(
"Unknown scaling mode. Must be 'downscale' or 'upscale'."
)

first_disp = first_sp.spectral_dispersion
second_disp = second_sp.spectral_dispersion

# Larger numerical dispersion means lower resolution!
if first_disp > second_disp:
first_sp, second_sp = _downscale(first_sp, second_sp)
# Check type of rescaling
if scaling == "downscale":
second_sp = _rescale(second_sp, reference_sp=first_sp)
else:
first_sp = _rescale(first_sp, reference_sp=second_sp)

elif first_disp < second_disp:
second_sp, first_sp = _downscale(second_sp, first_sp)
if scaling == "downscale":
first_sp = _rescale(first_sp, reference_sp=second_sp)
else:
second_sp = _rescale(second_sp, reference_sp=first_sp)

else:
# they have the same dispersion, is that equivalent
# to equal spectral_axis?
pass

if clean:
first_sp, second_sp = _clean_and_match(first_sp, second_sp)

return first_sp, second_sp

Expand Down Expand Up @@ -888,10 +959,10 @@ def sp_correction(nuclear_spectrum, external_spectrum):
Parameters
----------
nuclear_spectrum: NirdustSpectrum object
Instance of NirdusSpectrum containing the nuclear spectrum.
Instance of NirdustSpectrum containing the nuclear spectrum.
external_spectrum: NirdustSpectrum object
Instance of NirdusSpectrum containing the external spectrum.
Instance of NirdustSpectrum containing the external spectrum.
Return
------
Expand Down
63 changes: 48 additions & 15 deletions test_nirdust.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,15 +611,17 @@ def test_number_of_lines(NGC4945_continuum_rest_frame):

assert len(positions[0]) == 2


def test_spectral_dispersion(NGC4945_continuum_rest_frame):

sp = NGC4945_continuum_rest_frame

dispersion = sp.spectral_dispersion.value
expected = sp.header['CD1_1']
expected = sp.header["CD1_1"]

np.testing.assert_almost_equal(dispersion, expected, decimal=14)


def test_mask_spectrum_1(NGC4945_continuum_rest_frame):

spectrum = NGC4945_continuum_rest_frame
Expand Down Expand Up @@ -699,7 +701,7 @@ def test_sp_correction_with_mask(
assert len(dust.flux) == 544


def test_spectrum_resampling1():
def test_spectrum_resampling_downscale():

rng = np.random.default_rng(75)

Expand All @@ -726,16 +728,26 @@ def test_spectrum_resampling1():
frequency_axis=new_axis.to(u.Hz, equivalencies=u.spectral()),
z=0,
)
print(low_disp_sp.spectral_dispersion)
print(high_disp_sp.spectral_dispersion)

print(len(low_disp_sp.flux), len(high_disp_sp.flux))
f_sp, s_sp = nd.spectrum_resampling(low_disp_sp, high_disp_sp)
print(len(f_sp.flux), len(s_sp.flux))

# check without cleaning nan values
f_sp, s_sp = nd.spectrum_resampling(
low_disp_sp, high_disp_sp, scaling="downscale", clean=False
)

assert len(f_sp.flux) == len(s_sp.flux)
assert len(s_sp.flux) == 500

# check cleaning nan values.
# we know only 1 nan occurs for these spectrums
f_sp, s_sp = nd.spectrum_resampling(
low_disp_sp, high_disp_sp, scaling="downscale", clean=True
)

assert len(f_sp.flux) == len(s_sp.flux)
assert len(s_sp.flux) == 499

def test_spectrum_resampling2():

def test_spectrum_resampling_upscale():

g1 = models.Gaussian1D(0.6, 21200, 10)
g2 = models.Gaussian1D(-0.3, 22000, 15)
Expand All @@ -747,26 +759,43 @@ def test_spectrum_resampling2():
y = g1(axis.value) + g2(axis.value) + rng.normal(0.0, 0.03, axis.shape)
y_tot = (y + 0.0001 * axis.value + 1000) * u.adu

snth_line_spectrum = NirdustSpectrum(
low_disp_sp = NirdustSpectrum(
flux=y_tot,
frequency_axis=axis.to(u.Hz, equivalencies=u.spectral()),
z=0,
)

# same as axis but half the points, hence twice the dispersion
new_axis = np.arange(1500, 2500, 2) * u.Angstrom
new_flux = np.ones(len(new_axis)) * u.adu

n_snth_line_spectrum = NirdustSpectrum(
high_disp_sp = NirdustSpectrum(
flux=new_flux,
frequency_axis=new_axis.to(u.Hz, equivalencies=u.spectral()),
z=0,
)

# check without cleaning nan values
f_sp, s_sp = nd.spectrum_resampling(
low_disp_sp, high_disp_sp, scaling="upscale", clean=False
)

assert len(f_sp.flux) == len(s_sp.flux)
assert len(s_sp.flux) == 1000

# check cleaning nan values.
# we know only 1 nan occurs for these spectrums
f_sp, s_sp = nd.spectrum_resampling(
snth_line_spectrum, n_snth_line_spectrum
low_disp_sp, high_disp_sp, scaling="upscale", clean=True
)

assert len(f_sp.flux) == 500
assert len(f_sp.flux) == len(s_sp.flux)
assert len(s_sp.flux) == 999


def test_spectrum_resampling_invalid_scaling():
with pytest.raises(ValueError):
nd.spectrum_resampling(None, None, scaling="equal")


@pytest.mark.parametrize("true_temp", [500.0, 1000.0, 5000.0])
Expand All @@ -793,6 +822,10 @@ def test_fit_blackbody_with_resampling(
)

snth_bb_temp = (
f_sp.normalize().convert_to_frequency().fit_blackbody(2000.).temperature
f_sp.normalize()
.convert_to_frequency()
.fit_blackbody(2000.0)
.temperature
)
np.testing.assert_almost_equal(snth_bb_temp.value, true_temp, decimal=8)

np.testing.assert_almost_equal(snth_bb_temp.value, true_temp, decimal=1)

0 comments on commit 1a0a7c8

Please sign in to comment.