Skip to content

Commit

Permalink
Merge pull request #180 from GeminiDRSoftware/enh/write1DSpectra
Browse files Browse the repository at this point in the history
add write1DSpectra() primitive
  • Loading branch information
chris-simpson committed Apr 2, 2021
2 parents ee2c381 + ca240a8 commit 7962037
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 10 deletions.
43 changes: 33 additions & 10 deletions geminidr/core/parameters_spect.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
# This parameter file contains the parameters related to the primitives located
# in the primitives_spect.py file, in alphabetical order.
from astropy import units as u
from geminidr.core import parameters_generic
from astropy import table, units as u
from astropy.io import registry

from astrodata import AstroData
from geminidr.core import parameters_generic
from gempy.library import config, astrotools as at


def list_of_ints_check(value):
[int(x) for x in str(value).split(',')]
return True


def table_writing_formats():
t = registry.get_formats(table.Table, readwrite="Write")
return {fmt: "" for fmt, dep in t["Format", "Deprecated"] if dep != "Yes"}


def validate_regions_float(value):
at.parse_user_regions(value, dtype=float, allow_step=False)
return True
Expand Down Expand Up @@ -68,11 +80,6 @@ class determineDistortionConfig(config.Config):
interactive = config.Field("Display interactive fitter?", bool, False)


def min_lines_check(value):
[int(x) for x in str(value).split(',')]
return True


class determineWavelengthSolutionConfig(config.Config):
suffix = config.Field("Filename suffix", str, "_wavelengthSolutionDetermined", optional=True)
order = config.RangeField("Order of fitting polynomial", int, 2, min=1)
Expand All @@ -87,7 +94,7 @@ class determineWavelengthSolutionConfig(config.Config):
default="natural")
fwidth = config.RangeField("Feature width in pixels", float, None, min=2., optional=True)
min_lines = config.Field("Minimum number of lines to fit each segment", (str, int), '15,20',
check=min_lines_check)
check=list_of_ints_check)
central_wavelength = config.RangeField("Estimated central wavelength (nm)", float, None,
min=300., max=25000., optional=True)
dispersion = config.Field("Estimated dispersion (nm/pixel)", float, None, optional=True)
Expand Down Expand Up @@ -416,5 +423,21 @@ class traceAperturesConfig(config.Config):
max_missed = config.RangeField("Maximum number of steps to miss before a line is lost", int, 5, min=0)
debug = config.Field("Draw aperture traces on image display?", bool, False)

def setDefaults(self):
self.order = 2

class write1DSpectraConfig(config.Config):
#format = config.Field("Format for writing", str, "ascii")
format = config.ChoiceField("Format for writing", str,
allowed=table_writing_formats(),
default="ascii", optional=False)
header = config.Field("Write full FITS header?", bool, False)
extension = config.Field("Filename extension", str, "dat")
apertures = config.Field("Apertures to write", (str, int), None,
optional=True, check=list_of_ints_check)
dq = config.Field("Write Data Quality values?", bool, False)
var = config.Field("Write Variance values?", bool, False)
overwrite = config.Field("Overwrite existing files?", bool, False)

def validate(self):
config.Config.validate(self)
if self.header and not self.format.startswith("ascii"):
raise ValueError("FITS header can only be written with ASCII formats")
104 changes: 104 additions & 0 deletions geminidr/core/primitives_spect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from astropy import units as u
from astropy.io.ascii.core import InconsistentTableError
from astropy.io.registry import IORegistryError
from astropy.io import fits
from astropy.utils.exceptions import AstropyUserWarning
from astropy.modeling import Model, fitting, models
from astropy.stats import sigma_clip, sigma_clipped_stats
from astropy.table import Table, vstack
Expand Down Expand Up @@ -2938,6 +2940,108 @@ def averaging_func(data, mask=None, variance=None):
ad.update_filename(suffix=sfx, strip=True)
return adinputs

def write1DSpectra(self, adinputs=None, **params):
"""
Write 1D spectra to files listing the wavelength and data (and
optionally variance and mask) in one of a range of possible formats.
Parameters
----------
adinputs : list of :class:`~astrodata.AstroData`
Science data as 2D spectral images.
format : str
format for writing output files
header : bool
write FITS header before data values?
extension : str
extension to be used in output filenames
apertures : str
comma-separated list of aperture numbers to write
dq : bool
write DQ (mask) plane?
var : bool
write VAR (variance) plane?
overwrite : bool
overwrite existing files?
Returns
-------
list of :class:`~astrodata.AstroData`
The unmodified input files.
"""
# dict of {format parameter: (Table format, file suffix)}
log = self.log
log.debug(gt.log_message("primitive", self.myself(), "starting"))
fmt = params["format"]
header = params["header"]
extension = params["extension"]
apertures = params["apertures"]
if apertures:
these_apertures = [int(x) for x in str(apertures).split(",")]
write_dq = params["dq"]
write_var = params["var"]
overwrite = params["overwrite"]

for ad in adinputs:
aperture_map = dict(zip(range(len(ad)), ad.hdr.get("APERTURE")))
if apertures is None:
these_apertures = sorted(list(aperture_map.values()))
for aperture in these_apertures:
indices = [k for k, v in aperture_map.items() if v == aperture]
if len(indices) > 2:
log.warning(f"{ad.filename} has more than one aperture "
f"numbered {aperture} - continuing")
continue
elif not indices:
log.warning(f"{ad.filename} does not have an aperture "
f"numbered {aperture} - continuing")
continue

ext = ad[indices.pop()]
if ext.data.ndim != 1:
log.warning(f"{ad.filename} aperture {aperture} is not a "
"1D array - continuing")
continue

data_unit = u.Unit(ext.hdr.get("BUNIT"))
t = Table((ext.wcs(range(ext.data.size)), ext.data),
names=("wavelength", "data"),
units=(ext.wcs.output_frame.unit[0], str(data_unit)))
if write_dq:
t.add_column(ext.mask, name="dq")
if write_var:
t.add_column(ext.variance, name="variance")
t["variance"].unit = str(data_unit ** 2)
var_col = len(t.colnames)

filename = (os.path.splitext(ad.filename)[0] +
f"_{aperture:03d}.{extension}")
log.stdinfo(f"Writing {filename}")
try:
if header:
with open(filename, "w" if overwrite else "x") as f:
for line in (repr(ext.phu) + repr(ext.hdr)).split("\n"):
if line != " " * len(line):
f.write(f"# {line.strip()}\n")
t.write(f, format=fmt)
elif fmt == "fits":
# Table.write isn't happy with the unit 'electron'
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=AstropyUserWarning)
thdu = fits.table_to_hdu(t)
if "TUNIT2" not in thdu.header:
thdu.header["TUNIT2"] = str(data_unit)
if write_var and f"TUNIT{var_col}" not in thdu.header:
thdu.header[f"TUNIT{var_col}"] = str(data_unit ** 2)
hlist = fits.HDUList([fits.PrimaryHDU(), thdu])
hlist.writeto(filename, overwrite=overwrite)
else:
t.write(filename, format=fmt, overwrite=overwrite)
except OSError:
log.warning(f"{filename} already exists - cannot write")

return adinputs

def _read_and_convert_linelist(self, filename, w2=None, in_vacuo=False):
"""
Reads a standard-format linelist and returns a list of wavelengths
Expand Down
65 changes: 65 additions & 0 deletions geminidr/gmos/tests/spect/test_write_1d_spectra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from glob import glob
import pytest

import numpy as np
from astropy.table import Table

import astrodata, gemini_instruments
from geminidr.gmos.primitives_gmos_longslit import GMOSLongslit

input_files = ["S20190206S0108_fluxCalibrated.fits"]
formats = [("ascii", "dat", "ascii.basic"),
("fits", "fits", "fits"),
("ascii.csv", "csv", "ascii.csv")]

@pytest.mark.gmosls
@pytest.mark.preprocessed_data
@pytest.mark.parametrize("ad", input_files, indirect=True)
@pytest.mark.parametrize("output_format, extension, input_format", formats)
def test_write_spectrum(ad, output_format, extension, input_format, change_working_dir):

with change_working_dir():
nfiles = len(glob("*"))
p = GMOSLongslit([ad])
p.write1DSpectra(apertures=1, format=output_format,
extension=extension)
assert len(glob("*")) == nfiles + 1
t = Table.read(ad.filename.replace(".fits", f"_001.{extension}"),
format=input_format)
assert len(t) == ad[0].data.size
np.testing.assert_allclose(t["data"].data, ad[0].data, atol=1e-9)
p.write1DSpectra(apertures=None, format=output_format,
extension=extension, overwrite=True)
assert len(glob("*")) == nfiles + len(ad)


# Local Fixtures and Helper Functions ------------------------------------------
@pytest.fixture(scope='function')
def ad(path_to_inputs, request):
"""
Returns the pre-processed spectrum file.
Parameters
----------
path_to_inputs : pytest.fixture
Fixture defined in :mod:`astrodata.testing` with the path to the
pre-processed input file.
request : pytest.fixture
PyTest built-in fixture containing information about parent test.
Returns
-------
AstroData
Input spectrum processed up to right before the
`determineWavelengthSolution` primitive.
"""
filename = request.param
path = os.path.join(path_to_inputs, filename)

if os.path.exists(path):
ad = astrodata.open(path)
else:
raise FileNotFoundError(path)

return ad

0 comments on commit 7962037

Please sign in to comment.