Skip to content

Commit

Permalink
Auto detect fits extension and code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mchalela committed Jun 19, 2021
1 parent f24f297 commit 2777100
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 32 deletions.
150 changes: 125 additions & 25 deletions nirdust.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@
import specutils as su
import specutils.manipulation as sm

# ==============================================================================
# EXCEPTIONS
# ==============================================================================


class HeaderKeywordError(KeyError):
"""Raised when eader keyword not found."""

pass


# ==============================================================================
# BLACK BODY METHODS
Expand Down Expand Up @@ -441,13 +451,66 @@ def nplot(self, ax=None, data_color="firebrick", model_color="navy"):
# ==============================================================================


def _get_science_extension(
hdulist, extension, disp_k, first_wav_k, disp_type_k
):
"""Auto detect fits science extension using the provided keywords."""
if extension is not None:
return extension

keys = {disp_k, first_wav_k, disp_type_k}
extl = []
for ext, hdu in enumerate(hdulist):
if keys.issubset(hdu.header.keys()):
extl.append(ext)

if len(extl) > 1:
raise HeaderKeywordError(
"More than one extension with relevant keywords. "
"Please specify the extension."
)

elif len(extl) == 0:
raise HeaderKeywordError(
"No fits extension found with the requested keywords."
)

return extl[0]


def pix2wavelength(pix_arr, pix_0_wav, pix_disp, z=0):
"""Transform pixel to wavelength assuming linear dispersion.
This transformation assumes a linear dispersion.
Parameters
----------
pix_arr: float or `~numpy.ndarray`
Array of pixels values.
pix_0_wav: float
Wavelength value of the first pixel in pix_arr
pix_disp: float
Wavelength dispersion per pixel. Must be in same
units as pix_0_wav.
z: float
Redshift of object. Use for the scale factor 1 / (1 + z).
"""
scale_factor = 1 / (1 + z)
wave_arr = pix_0_wav + pix_disp * pix_arr # assume linear dispersion
wave_arr *= scale_factor
return wave_arr


def spectrum(
flux,
header,
dispersion_key,
first_wavelength,
dispersion_type,
z=0,
dispersion_key="CD1_1",
first_wavelength="CRVAL1",
dispersion_type="CTYPE1",
**kwargs,
):
"""Instantiate a NirdustSpectrum object from FITS parameters.
Expand All @@ -460,9 +523,6 @@ def spectrum(
header: FITS header
Header of the spectrum.
z: float
Redshif of the galaxy.
dispersion_key: str
Header keyword that gives dispersion in Å/pix. Default is 'CD1_1'
Expand All @@ -474,24 +534,27 @@ def spectrum(
Header keyword that contains the dispersion function type. Default is
``CTYPE1``.
z: float
Redshif of the galaxy.
Return
------
spectrum: ``NirsdustSpectrum``
Return a instance of the class NirdustSpectrum with the entered
parameters.
"""
if header[dispersion_key] <= 0:
pix_0_wav = header[first_wavelength] # wavelength of first pixel
pix_disp = header[dispersion_key] # dispersion Angstrom per pixel

if pix_disp <= 0:
raise ValueError("dispersion must be positive")

spectrum_length = len(flux)
spectral_axis = (
(
header[first_wavelength]
+ header[dispersion_key] * np.arange(0, spectrum_length)
)
/ (1 + z)
* u.AA
)

# unit should be the same as first_wavelength and dispersion_key, AA ?
pixel_axis = np.arange(spectrum_length)
spectral_axis = pix2wavelength(pixel_axis, pix_0_wav, pix_disp, z) * u.AA

spec1d = su.Spectrum1D(
flux=flux * u.adu, spectral_axis=spectral_axis, **kwargs
)
Expand All @@ -509,31 +572,68 @@ def spectrum(
)


def read_spectrum(file_name, extension, z, **kwargs):
def read_spectrum(
file_name,
extension=None,
dispersion_key="CD1_1",
first_wavelength="CRVAL1",
dispersion_type="CTYPE1",
z=0,
**kwargs,
):
"""Read a spectrum in FITS format and store it in a NirdustSpectrum object.
Parameters
----------
file_name: str
Path to where the fits file is stored.
extension: int
Extension of the FITS file where the spectrum is stored.
extension: int or str
Extension of the FITS file where the spectrum is stored. If None the
extension will be automatically identified by searching for the
relevant header keywords. Default is None.
dispersion_key: str
Header keyword that gives dispersion in Å/pix. Default is 'CD1_1'
first_wavelength: str
Header keyword that contains the wavelength of the first pixel. Default
is ``CRVAL1``.
dispersion_type: str
Header keyword that contains the dispersion function type. Default is
``CTYPE1``.
z: float
Redshift of the galaxy.
Redshift of the galaxy. Used to scale the spectral axis with the
cosmological sacle factor 1/(1+z). Default is 0.
Returns
-------
out: NirsdustSpectrum object
Returns an instance of the class NirdustSpectrum.
"""
with fits.open(file_name) as fits_spectrum:

fluxx = fits_spectrum[extension].data
header = fits.getheader(file_name)

single_spectrum = spectrum(flux=fluxx, header=header, z=z, **kwargs)
with fits.open(file_name) as hdulist:

ext = _get_science_extension(
hdulist,
extension,
dispersion_key,
first_wavelength,
dispersion_type,
)
flux = hdulist[ext].data
header = hdulist[ext].header

single_spectrum = spectrum(
flux,
header,
dispersion_key,
first_wavelength,
dispersion_type,
z,
**kwargs,
)

return single_spectrum

Expand Down
33 changes: 26 additions & 7 deletions test_nirdust.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,28 @@
@pytest.fixture(scope="session")
def NGC4945_continuum():
file_name = TEST_PATH / "cont03.fits"
spect = nd.read_spectrum(file_name, 0, 0.00188)
spect = nd.read_spectrum(file_name, 0, z=0.00188)
return spect


@pytest.fixture(scope="session")
def NGC4945_continuum_rest_frame():
file_name = TEST_PATH / "cont03.fits"
spect = nd.read_spectrum(file_name, 0, 0)
spect = nd.read_spectrum(file_name, 0, z=0)
return spect


@pytest.fixture(scope="session")
def NGC4945_external_continuum_400pc():
file_name = TEST_PATH / "external_spectrum_400pc_N4945.fits"
spect = nd.read_spectrum(file_name, 0, 0.00188)
spect = nd.read_spectrum(file_name, 0, z=0.00188)
return spect


@pytest.fixture(scope="session")
def NGC4945_external_continuum_200pc():
file_name = TEST_PATH / "external_spectrum_200pc_N4945.fits"
spect = nd.read_spectrum(file_name, 0, 0.00188)
spect = nd.read_spectrum(file_name, 0, z=0.00188)
return spect


Expand Down Expand Up @@ -104,6 +104,13 @@ def snth_spectrum_1000(NGC4945_continuum_rest_frame):
# ==============================================================================


def test_read_spectrum():
# read with no extension and wrong keyword
file_name = TEST_PATH / "external_spectrum_200pc_N4945.fits"
with pytest.raises(nd.HeaderKeywordError):
nd.read_spectrum(file_name, dispersion_key="CD11")


def test_match(NGC4945_continuum):
spectrum = NGC4945_continuum
assert spectrum.spectral_axis.shape == spectrum.flux.shape
Expand Down Expand Up @@ -185,7 +192,7 @@ def test_sp_correction_second_if(
NGC4945_continuum, NGC4945_external_continuum_200pc
):
spectrum = nd.read_spectrum(
TEST_PATH / "cont01.fits", 0, 0.00188
TEST_PATH / "cont01.fits", 0, z=0.00188
).cut_edges(19600, 22900)
external_spectrum = NGC4945_external_continuum_200pc.cut_edges(
19600, 22900
Expand All @@ -200,7 +207,7 @@ def test_sp_correction_third_if(
):
spectrum = NGC4945_external_continuum_200pc.cut_edges(19600, 22900)
external_spectrum = nd.read_spectrum(
TEST_PATH / "cont01.fits", 0, 0.00188
TEST_PATH / "cont01.fits", 0, z=0.00188
).cut_edges(19600, 22900)
prepared = nd.sp_correction(spectrum, external_spectrum)
expected_len = len(external_spectrum.spectral_axis)
Expand Down Expand Up @@ -404,7 +411,7 @@ def test_fit_blackbody(NGC4945_continuum_rest_frame):
def test_nplot(fig_test, fig_ref):

spectrum = (
nd.read_spectrum(TEST_PATH / "cont03.fits", 0, 0.00188)
nd.read_spectrum(TEST_PATH / "cont03.fits", 0, z=0.00188)
.cut_edges(19500, 22900)
.normalize()
)
Expand All @@ -429,3 +436,15 @@ def test_nplot(fig_test, fig_ref):
ax_ref.set_xlabel("Frequency [Hz]")
ax_ref.set_ylabel("Normalized Energy [arbitrary units]")
ax_ref.legend()


def test_pix2wavelength():
pix_array = np.arange(10, 50)
pix_0_wav = 5.0
pix_disp = np.pi
z = 0.1

expected = (pix_0_wav + pix_disp * pix_array) / (1 + z)
result = nd.pix2wavelength(pix_array, 5.0, np.pi, 0.1)

np.testing.assert_almost_equal(result, expected, decimal=14)

0 comments on commit 2777100

Please sign in to comment.