Skip to content

Commit

Permalink
Merge branch 'enh/wavecal_3.2' into release/3.2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
KathleenLabrie committed Apr 13, 2024
2 parents 85ed49a + e86ebd2 commit b984d9d
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 74 deletions.
1 change: 1 addition & 0 deletions geminidr/core/parameters_spect.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class determineWavelengthSolutionConfig(config.core_1Dfitting_config):
check=list_of_ints_check)
debug_alternative_centers = config.Field("Try alternative wavelength centers?", bool, False)
interactive = config.Field("Display interactive fitter?", bool, False)
verbose = config.Field("Print additional fitting information?", bool, False)

def setDefaults(self):
del self.function
Expand Down
30 changes: 21 additions & 9 deletions geminidr/core/primitives_spect.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def _get_fit1d_input_data(ext, exptime, spec_table):
uiparams = UIParameters(config)
visualizer = fit1d.Fit1DVisualizer({"x": all_waves, "y": all_zpt, "weights": all_weights},
fitting_parameters=all_fp_init,
tab_name_fmt="CCD {}",
tab_name_fmt=lambda i: f"CCD {i+1}",
xlabel=f'Wavelength ({xunits})',
ylabel=f'Sensitivity ({yunits})',
domains=all_domains,
Expand Down Expand Up @@ -1354,7 +1354,7 @@ def determineWavelengthSolution(self, adinputs=None, **params):
visualizer = WavelengthSolutionVisualizer(
reconstruct_points, all_fp_init,
modal_message="Re-extracting 1D spectra",
tab_name_fmt="Slit {}",
tab_name_fmt=lambda i: f"Slit {i+1}",
xlabel="Fitted wavelength (nm)", ylabel="Non-linear component (nm)",
domains=domains,
title="Wavelength Solution",
Expand All @@ -1375,12 +1375,12 @@ def determineWavelengthSolution(self, adinputs=None, **params):
input_data, fit1d, acceptable_fit = wavecal.get_automated_fit(
ext, uiparams, p=self, linelist=linelist, bad_bits=DQ.not_signal)
if not acceptable_fit:
log.warning("No acceptable wavelength solution found "
f"for {ext.id}")

wavecal.update_wcs_with_solution(ext, fit1d, input_data, config)
wavecal.save_fit_as_pdf(input_data["spectrum"], fit1d.points[~fit1d.mask],
fit1d.image[~fit1d.mask], ad.filename)
log.warning("No acceptable wavelength solution found")
else:
wavecal.update_wcs_with_solution(ext, fit1d, input_data, config)
wavecal.save_fit_as_pdf(
input_data["spectrum"], fit1d.points[~fit1d.mask],
fit1d.image[~fit1d.mask], ad.filename)

# Timestamp and update the filename
gt.mark_history(ad, primname=self.myself(), keyword=timestamp_key)
Expand Down Expand Up @@ -2987,7 +2987,7 @@ def recalc_fn(ad: AstroData, ui_parms: UIParameters):
ui_params = UIParameters(config, reinit_params=reinit_params, extras=reinit_extras)
visualizer = fit1d.Fit1DVisualizer(lambda ui_params: recalc_fn(ad, ui_params),
fitting_parameters=[fit1d_params]*count,
tab_name_fmt="Slit {}",
tab_name_fmt=lambda i: f"Slit {i+1}",
xlabel='Row',
ylabel='Signal',
domains=all_shapes,
Expand Down Expand Up @@ -3529,6 +3529,18 @@ def _get_spectrophotometry(self, filename, in_vacuo=False):
* u.Unit("erg cm-2 s-1") / u.Hz)
return spec_table

def _apply_wavelength_model_bounds(self, model=None, ext=None):
# Apply bounds to an astropy.modeling.models.Chebyshev1D to indicate
# the range of parameter space to explore
# The default here is 2% tolerance in central wavelength and dispersion
for i, (pname, pvalue) in enumerate(zip(model.param_names, model.parameters)):
if i == 0: # central wavelength
prange = 0.02 * pvalue
elif i == 1: # half the wavelength extent (~dispersion)
prange = 0.02 * abs(pvalue)
else: # higher-order terms
prange = 1
getattr(model, pname).bounds = (pvalue - prange, pvalue + prange)

# -----------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion geminidr/gmos/primitives_gmos_longslit.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def reconstruct_points(ui_params=None):
extras = {"row": RangeField("Row of data to operate on", int, int(nrows/2), min=1, max=nrows)}
uiparams = UIParameters(config, reinit_params=reinit_params, extras=extras)
visualizer = fit1d.Fit1DVisualizer(reconstruct_points, all_fp_init,
tab_name_fmt="CCD {}",
tab_name_fmt=lambda i: f"CCD {i+1}",
xlabel='x (pixels)', ylabel='counts',
domains=all_domains,
title="Normalize Flat",
Expand Down
12 changes: 12 additions & 0 deletions geminidr/gmos/primitives_gmos_spect.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,15 @@ def _get_arc_linelist(self, waves=None):
filename = os.path.join(lookup_dir,
'CuAr_GMOS{}.dat'.format('_mixord' if use_second_order else ''))
return wavecal.LineList(filename)

def _apply_wavelength_model_bounds(self, model=None, ext=None):
# Apply bounds to an astropy.modeling.models.Chebyshev1D to indicate
# the range of parameter space to explore
for i, (pname, pvalue) in enumerate(zip(model.param_names, model.parameters)):
if i == 0: # central wavelength
prange = 10
elif i == 1: # half the wavelength extent (~dispersion)
prange = 0.02 * abs(pvalue)
else: # higher-order terms
prange = 1
getattr(model, pname).bounds = (pvalue - prange, pvalue + prange)
34 changes: 28 additions & 6 deletions geminidr/interactive/fit/fit1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
listeners=None,
band_model=None,
extra_masks=None,
default_model=None,
):
"""Create base class with given parameters as initial model inputs.
Expand All @@ -179,6 +180,9 @@ def __init__(
extra_masks : dict of boolean arrays
points to display but not use in fit
default_model : callable
function to evaluate model if self.fit is None
"""
super().__init__()

Expand All @@ -195,6 +199,7 @@ def __init__(
self.domain = domain
self.fit = None
self.listeners = listeners
self.default_model = default_model

self.section = section
self.data = bm.ColumnDataSource({"x": [], "y": [], "mask": []})
Expand Down Expand Up @@ -514,6 +519,12 @@ def update_mask(self):
self.data.data["mask"] = mask

def evaluate(self, x):
if self.fit is None:
# fit_1D.evaluate() always returns an array so we need to also
retval = self.default_model(x)
if isinstance(retval, float):
return np.array([retval])
return retval
return self.fit.evaluate(x)


Expand Down Expand Up @@ -887,7 +898,10 @@ def model_change_handler(self, model):
model : :class:`~geminidr.interactive.fit.fit1d.InteractiveModel1D`
The model that has changed.
"""
rms_str = "--" if np.isnan(model.fit.rms) else f"{model.fit.rms:.4f}"
try:
rms_str = "--" if np.isnan(model.fit.rms) else f"{model.fit.rms:.4f}"
except AttributeError:
rms_str = "--"

rms = (
f'<div class="info_panel">'
Expand Down Expand Up @@ -968,6 +982,7 @@ def __init__(
enable_regions=True,
central_plot=True,
extra_masks=None,
default_model=None,
):
"""Panel for visualizing a 1-D fit, perhaps in a tab.
Expand Down Expand Up @@ -1021,6 +1036,9 @@ def __init__(
extra_masks : dict of boolean arrays
points to display but not use in the fit
default_model : callable
function to evaluate model if self.fit is None
"""
# Just to get the doc later
self.visualizer = visualizer
Expand Down Expand Up @@ -1050,6 +1068,7 @@ def __init__(
weights,
band_model=band_model,
extra_masks=extra_masks,
default_model=default_model,
)

self.model.add_listener(self.model_change_handler)
Expand Down Expand Up @@ -1528,7 +1547,7 @@ def __init__(
fitting_parameters,
modal_message=None,
modal_button_label=None,
tab_name_fmt="{}",
tab_name_fmt=None,
xlabel="x",
ylabel="y",
domains=None,
Expand Down Expand Up @@ -1571,8 +1590,8 @@ def __init__(
If set and if modal_message was set, this will be used for the
label on the recalculate button. It is not required.
tab_name_fmt : str
Format string for naming the tabs
tab_name_fmt : callable
Turns ext.id into a title for the tab name
xlabel : str
String label for X axis
Expand Down Expand Up @@ -1751,6 +1770,9 @@ def kickoff_modal(attr, old, new):

elif turbo_tabs:
self.turbo = TabsTurboInjector(self.tabs)

if tab_name_fmt is None:
tab_name_fmt = lambda i: f"Extension {i+1}"

for i in range(self.nfits):
extra_masks = {}
Expand Down Expand Up @@ -1780,12 +1802,12 @@ def kickoff_modal(attr, old, new):

if turbo_tabs:
self.turbo.add_tab(
tui.component, title=tab_name_fmt.format(i + 1)
tui.component, title=str(tab_name_fmt(i))
)

else:
tab = bm.TabPanel(
child=tui.component, title=tab_name_fmt.format(i + 1)
child=tui.component, title=str(tab_name_fmt(i))
)

self.tabs.tabs.append(tab)
Expand Down
2 changes: 1 addition & 1 deletion geminidr/interactive/fit/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def interactive_trace_apertures(ext, fit1d_params, ui_params: UIParameters):
fitting_parameters=fit_par_list,
help_text=help_text,
primitive_name="traceApertures",
tab_name_fmt="Aperture {}",
tab_name_fmt=lambda i: f"Aperture {i+1}",
title="Interactive Trace Apertures",
xlabel=xlabel,
ylabel=ylabel,
Expand Down
9 changes: 7 additions & 2 deletions geminidr/interactive/fit/wavecal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from gempy.library.matching import match_sources
from gempy.library.tracing import cwt_ricker, pinpoint_peaks
from gempy.library.fitting import fit_1D

from .fit1d import (
Fit1DPanel,
Expand Down Expand Up @@ -117,7 +118,7 @@ def __init__(
The weights of the fit.
meta : dict
A dictionary of metadata.
A dictionary of metadata. This is the "all_input_data"
kwargs : dict
Any additional keyword arguments.
Expand All @@ -132,6 +133,7 @@ def __init__(
"wavelengths": np.zeros_like(meta["spectrum"]),
"spectrum": meta["spectrum"],
}
kwargs["default_model"] = meta["init_models"][0]

self.spectrum = bm.ColumnDataSource(spectrum_data_dict)

Expand Down Expand Up @@ -483,7 +485,10 @@ def linear_model(model):
"""Return only the linear part of a model. It doesn't work for
splines, which is why it's not in the InteractiveModel1D class
"""
model = model.fit._models
if model.fit is None:
model = model.default_model
else:
model = model.fit.model
new_model = model.__class__(
degree=1, c0=model.c0, c1=model.c1, domain=model.domain
)
Expand Down
11 changes: 6 additions & 5 deletions gempy/library/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def __call__(self, model, in_coords, ref_coords, in_weights=None,
farg = (model_copy, in_coords, tree)
p0, *_ = model_to_fit_params(model_copy)

arg_names = inspect.getfullargspec(self._opt_method).args
argspec = inspect.getfullargspec(self._opt_method)
arg_names, kwarg_names = argspec.args, argspec.kwonlyargs
args = [self.objective_function]
if arg_names[1] == 'x0':
args.append(p0)
Expand All @@ -381,15 +382,15 @@ def __call__(self, model, in_coords, ref_coords, in_weights=None,
else:
raise ValueError("Don't understand argument {}".format(arg_names[1]))

if 'args' in arg_names:
if 'args' in arg_names or 'args' in kwarg_names:
kwargs['args'] = farg

if 'method' in arg_names:
kwargs['method'] = self._method

if 'minimizer_kwargs' in arg_names:
kwargs['minimizer_kwargs'] = {'args': farg,
'method': 'Nelder-Mead'}
#if 'minimizer_kwargs' in arg_names:
# kwargs['minimizer_kwargs'] = {'args': farg,
# 'method': 'Nelder-Mead'}

result = self._opt_method(*args, **kwargs)

Expand Down

0 comments on commit b984d9d

Please sign in to comment.