Skip to content

Commit

Permalink
Merge pull request #86 from kecnry/background-subtract
Browse files Browse the repository at this point in the history
Basic (boxcar) background subtraction
  • Loading branch information
tepickering committed Mar 29, 2022
2 parents 38ff816 + 623248b commit dce3ca4
Show file tree
Hide file tree
Showing 5 changed files with 685 additions and 86 deletions.
383 changes: 355 additions & 28 deletions notebook_sandbox/jwst_boxcar/boxcar_extraction.ipynb

Large diffs are not rendered by default.

209 changes: 209 additions & 0 deletions specreduce/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from dataclasses import dataclass, field

import numpy as np
from astropy.nddata import NDData

from specreduce.extract import _ap_weight_image
from specreduce.tracing import Trace, FlatTrace

__all__ = ['Background']


@dataclass
class Background:
"""
Determine the background from an image for subtraction
Parameters
----------
image : `~astropy.nddata.NDData` or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
extract the background
width : float
width of extraction aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
# required so numpy won't call __rsub__ on individual elements
# https://stackoverflow.com/a/58409215
__array_ufunc__ = None

image: NDData
traces: list = field(default_factory=list)
width: float = 5
statistic: str = 'average'
disp_axis: int = 1
crossdisp_axis: int = 0

def __post_init__(self):
"""
Determine the background from an image for subtraction.
Parameters
----------
image : `~astropy.nddata.NDData` or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
extract the background
width : float
width of each background aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
def _to_trace(trace):
if not isinstance(trace, Trace):
trace = FlatTrace(self.image, trace)

# TODO: this check can be removed if/when implemented as a check in FlatTrace
if isinstance(trace, FlatTrace):
if trace.trace_pos < 1:
raise ValueError('trace_object.trace_pos must be >= 1')
return trace

bkg_wimage = np.zeros_like(self.image, dtype=np.float64)
for trace in self.traces:
trace = _to_trace(trace)
bkg_wimage += _ap_weight_image(trace,
self.width,
self.disp_axis,
self.crossdisp_axis,
self.image.shape)

if np.any(bkg_wimage > 1):
raise ValueError("background regions overlapped")

if self.statistic == 'median':
# make it clear in the expose image that partial pixels are fully-weighted
bkg_wimage[bkg_wimage > 0] = 1

self.bkg_wimage = bkg_wimage
if self.statistic == 'average':
self.bkg_array = np.average(self.image, weights=self.bkg_wimage, axis=0)
elif self.statistic == 'median':
med_image = self.image.copy()
med_image[np.where(self.bkg_wimage) == 0] = np.nan
self.bkg_array = np.nanmedian(med_image, axis=0)
else:
raise ValueError("statistic must be 'average' or 'median'")

@classmethod
def two_sided(cls, image, trace_object, separation, **kwargs):
"""
Determine the background from an image for subtraction centered around
an input trace.
Parameters
----------
image : nddata-compatible image
image with 2-D spectral image data
trace_object: Trace
estimated trace of the spectrum to center the background traces
separation: float
separation from ``trace_object`` for the background regions
width : float
width of each background aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
kwargs['traces'] = [trace_object-separation, trace_object+separation]
return cls(image=image, **kwargs)

@classmethod
def one_sided(cls, image, trace_object, separation, **kwargs):
"""
Determine the background from an image for subtraction above
or below an input trace.
Parameters
----------
image : nddata-compatible image
image with 2-D spectral image data
trace_object: Trace
estimated trace of the spectrum to center the background traces
separation: float
separation from ``trace_object`` for the background, positive will be
above the trace, negative below.
width : float
width of each background aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
kwargs['traces'] = [trace_object+separation]
return cls(image=image, **kwargs)

def bkg_image(self, image=None):
"""
Expose the background tiled to the dimension of ``image``.
Parameters
----------
image : nddata-compatible image or None
image with 2-D spectral image data. If None, will use ``image`` passed
to extract the background.
Returns
-------
array with same shape as ``image``.
"""
if image is None:
image = self.image

return np.tile(self.bkg_array, (image.shape[0], 1))

def sub_image(self, image=None):
"""
Subtract the computed background from ``image``.
Parameters
----------
image : nddata-compatible image or None
image with 2-D spectral image data. If None, will use ``image`` passed
to extract the background.
Returns
-------
array with same shape as ``image``
"""
if image is None:
image = self.image

if isinstance(image, NDData):
# https://docs.astropy.org/en/stable/nddata/mixins/ndarithmetic.html
return image.subtract(self.bkg_image(image)*image.unit)
else:
return image - self.bkg_image(image)

def __rsub__(self, image):
"""
Subtract the background from an image.
"""
return self.sub_image(image)
118 changes: 60 additions & 58 deletions specreduce/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,66 @@
__all__ = ['BoxcarExtract', 'HorneExtract', 'OptimalExtract']


def _get_boxcar_weights(center, hwidth, npix):
"""
Compute weights given an aperture center, half width,
and number of pixels
"""
weights = np.zeros((npix))

# pixels with full weight
fullpixels = [max(0, int(center - hwidth + 1)),
min(int(center + hwidth), npix)]
weights[fullpixels[0]:fullpixels[1]] = 1.0

# pixels at the edges of the boxcar with partial weight
if fullpixels[0] > 0:
w = hwidth - (center - fullpixels[0] + 0.5)
if w >= 0:
weights[fullpixels[0] - 1] = w
else:
weights[fullpixels[0]] = 1. + w
if fullpixels[1] < npix:
weights[fullpixels[1]] = hwidth - (fullpixels[1] - center - 0.5)

return weights


def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape):

"""
Create a weight image that defines the desired extraction aperture.
Parameters
----------
trace : Trace
trace object
width : float
width of extraction aperture in pixels
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
image_shape : tuple with 2 elements
size (shape) of image
Returns
-------
wimage : 2D image
weight image defining the aperture
"""
wimage = np.zeros(image_shape)
hwidth = 0.5 * width
image_sizes = image_shape[crossdisp_axis]

# loop in dispersion direction and compute weights.
for i in range(image_shape[disp_axis]):
# TODO trace must handle transposed data (disp_axis == 0)
wimage[:, i] = _get_boxcar_weights(trace[i], hwidth, image_sizes)

return wimage


@dataclass
class BoxcarExtract(SpecreduceOperation):
"""
Expand Down Expand Up @@ -67,64 +127,6 @@ def __call__(self, image, trace_object, width=5,
The extracted 1d spectrum with flux expressed in the same
units as the input image, or u.DN, and pixel units
"""
def _get_boxcar_weights(center, hwidth, npix):
"""
Compute weights given an aperture center, half width,
and number of pixels
"""
weights = np.zeros((npix))

# pixels with full weight
fullpixels = [max(0, int(center - hwidth + 1)),
min(int(center + hwidth), npix)]
weights[fullpixels[0]:fullpixels[1]] = 1.0

# pixels at the edges of the boxcar with partial weight
if fullpixels[0] > 0:
w = hwidth - (center - fullpixels[0] + 0.5)
if w >= 0:
weights[fullpixels[0] - 1] = w
else:
weights[fullpixels[0]] = 1. + w
if fullpixels[1] < npix:
weights[fullpixels[1]] = hwidth - (fullpixels[1] - center - 0.5)

return weights

def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape):

"""
Create a weight image that defines the desired extraction aperture.
Parameters
----------
trace : Trace
trace object
width : float
width of extraction aperture in pixels
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
image_shape : tuple with 2 elements
size (shape) of image
Returns
-------
wimage : 2D image
weight image defining the aperture
"""
wimage = np.zeros(image_shape)
hwidth = 0.5 * width
image_sizes = image_shape[crossdisp_axis]

# loop in dispersion direction and compute weights.
for i in range(image_shape[disp_axis]):
# TODO trace must handle transposed data (disp_axis == 0)
wimage[:, i] = _get_boxcar_weights(trace[i], hwidth, image_sizes)

return wimage

# TODO: this check can be removed if/when implemented as a check in FlatTrace
if isinstance(trace_object, FlatTrace):
if trace_object.trace_pos < 1:
Expand Down
44 changes: 44 additions & 0 deletions specreduce/tests/test_background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np

import astropy.units as u
from astropy.nddata import CCDData

from specreduce.background import Background
from specreduce.tracing import FlatTrace


# NOTE: same test image as in test_extract.py
# Test image is comprised of 30 rows with 10 columns each. Row content
# is row index itself. This makes it easy to predict what should be the
# value extracted from a region centered at any arbitrary Y position.
image = np.ones(shape=(30, 10))
for j in range(image.shape[0]):
image[j, ::] *= j
image = CCDData(image, unit=u.Jy)


def test_background():
#
# Try combinations of extraction center, and even/odd
# extraction aperture sizes.
#
trace_pos = 15.0
trace = FlatTrace(image, trace_pos)
bkg_sep = 5
bkg_width = 2
# all the following should be equivalent:
bg1 = Background(image, [trace-bkg_sep, trace+bkg_sep], width=bkg_width)
bg2 = Background.two_sided(image, trace, bkg_sep, width=bkg_width)
bg3 = Background.two_sided(image, trace_pos, bkg_sep, width=bkg_width)
assert np.allclose(bg1.bkg_array, bg2.bkg_array)
assert np.allclose(bg1.bkg_array, bg3.bkg_array)

# test that creating a one_sided background works
Background.one_sided(image, trace, bkg_sep, width=bkg_width)

# test that image subtraction works
sub1 = image - bg1
sub2 = bg1.sub_image(image)
sub3 = bg1.sub_image()
assert np.allclose(sub1, sub2)
assert np.allclose(sub1, sub3)
Loading

0 comments on commit dce3ca4

Please sign in to comment.