Skip to content

Commit

Permalink
Resort to median fitting of the sky to avoid an infinite loop in subt…
Browse files Browse the repository at this point in the history
…ractSky. Triggered by EXTHD-193.
  • Loading branch information
KathleenLabrie committed Jul 28, 2023
1 parent ee7751b commit f639bf3
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
1 change: 1 addition & 0 deletions geminidr/core/parameters_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class subtractSkyConfig(config.Config):
offset_sky = config.Field("Apply offset to sky frame to match science frame?", bool, False)
sky = config.ListField("Sky frame to subtract", (str, AstroData), None, optional=True, single=True)
save_sky = config.Field("Save sky frame to disk?", bool, False)
debug_threshold = config.Field("Convergence threshold when scaling", float, 0.001)


class skyCorrectConfig(parameters_stack.stackSkyFramesConfig, subtractSkyConfig):
Expand Down
27 changes: 24 additions & 3 deletions geminidr/core/primitives_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from scipy.interpolate import interp1d
from scipy.ndimage import binary_dilation

from gempy.utils.errors import ConvergenceError

from . import parameters_preprocess


Expand Down Expand Up @@ -1626,6 +1628,7 @@ def subtractSky(self, adinputs=None, **params):
reset_sky = params["reset_sky"]
scale = params["scale_sky"]
zero = params["offset_sky"]
debug_threshold = params["debug_threshold"]
if scale and zero:
log.warning("Both the scale_sky and offset_sky parameters are set. "
"Setting offset_sky=False.")
Expand All @@ -1635,7 +1638,7 @@ def subtractSky(self, adinputs=None, **params):
# in gt.measure_bg_from_image()
sampling = 1 if adinputs[0].instrument() == 'GNIRS' else 10
skyfunc = partial(gt.measure_bg_from_image, value_only=True,
sampling=sampling)
sampling=sampling, gaussfit=True)

for ad, ad_sky in zip(*gt.make_lists(adinputs, params["sky"],
force_ad=True)):
Expand All @@ -1652,8 +1655,26 @@ def subtractSky(self, adinputs=None, **params):
f"the science frame {ad.filename}")
if scale or zero:
# This actually does the sky subtraction as well
factors = [gt.sky_factor(ext, ext_sky, skyfunc, multiplicative=scale)
for ext, ext_sky in zip(ad, ad_sky)]
try:
factors = [gt.sky_factor(ext, ext_sky, skyfunc,
multiplicative=scale,
threshold=debug_threshold)
for ext, ext_sky in zip(ad, ad_sky)]
except ConvergenceError as error:
log.warning(f"The scaling of sky using a gaussian fit "
f"did not converge. \n"
f"Using the median method instead.")
skyfunc = partial(gt.measure_bg_from_image,
value_only=True,
sampling=sampling, gaussfit=False)
try:
factors = [gt.sky_factor(ext, ext_sky, skyfunc,
multiplicative=scale)
for ext, ext_sky in zip(ad, ad_sky)]
except ConvergenceError as error:
log.error(f"Failed to scaled sky.")
raise(error)

for ext_sky, factor in zip(ad_sky, factors):
log.fullinfo("Applying {} of {} to extension {}".
format("scaling" if scale else "offset",
Expand Down
19 changes: 17 additions & 2 deletions gempy/gemini/gemini_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import astrodata
from astrodata import Section

from gempy.utils.errors import ConvergenceError

ArrayInfo = namedtuple("ArrayInfo", "detector_shape origins array_shapes "
"extensions")

Expand Down Expand Up @@ -1993,21 +1995,34 @@ def sky_factor(nd1, nd2, skyfunc, multiplicative=False, threshold=0.001):
-------
float : factor to apply to "sky" to match "science"
"""
log = logutils.get_logger(__name__)

factor = 0
if multiplicative:
current_sky = 1
current_sky = 1.
# A subtlety here: deepcopy-ing an AD slice will create a full AD
# object, and so skyfunc() will return a list instead of the single
# float value we want. So make sure the copy is a single slice too
if isinstance(nd1, astrodata.AstroData) and nd1.is_single:
ndcopy = deepcopy(nd1)[0]
else:
ndcopy = deepcopy(nd1)
while abs(current_sky) > threshold:
iter = 1
max_iter = 100 # normally converges in < 10 iterations
while abs(current_sky) > threshold and iter <= max_iter:
f = skyfunc(ndcopy) / skyfunc(nd2)
ndcopy.subtract(nd2.multiply(f))
current_sky *= f
factor += current_sky
iter += 1
#print('iter upon exit: ', iter)
if iter > max_iter:
log.warning(f"Failed to converge.\n"
f" Reached: {abs(current_sky)} while threshold = {threshold}\n"
f" Final factor = {factor}")
nd2.divide(current_sky) # reset to original value
raise(ConvergenceError)

nd1.subtract(nd2.multiply(factor / current_sky))
nd2.divide(factor) # reset to original value
else:
Expand Down
2 changes: 2 additions & 0 deletions gempy/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ConvergenceError(RuntimeError):
pass

0 comments on commit f639bf3

Please sign in to comment.