Skip to content

Commit

Permalink
Implementing BaseTemplate to manage Trace attr mutability
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Oct 25, 2022
1 parent 67b6b8b commit 907457e
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 54 deletions.
6 changes: 3 additions & 3 deletions specreduce/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from astropy import units as u

from specreduce.extract import _ap_weight_image, _to_spectrum1d_pixels
from specreduce.tracing import Trace, FlatTrace
from specreduce.tracing import FlatTrace, BaseTrace

__all__ = ['Background']

Expand Down Expand Up @@ -77,7 +77,7 @@ def __post_init__(self):
cross-dispersion axis
"""
def _to_trace(trace):
if not isinstance(trace, Trace):
if not isinstance(trace, BaseTrace):
trace = FlatTrace(self.image, trace)

# TODO: this check can be removed if/when implemented as a check in FlatTrace
Expand All @@ -93,7 +93,7 @@ def _to_trace(trace):
self.bkg_array = np.zeros(self.image.shape[self.disp_axis])
return

if isinstance(self.traces, Trace):
if isinstance(self.traces, BaseTrace):
self.traces = [self.traces]

bkg_wimage = np.zeros_like(self.image, dtype=np.float64)
Expand Down
16 changes: 8 additions & 8 deletions specreduce/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from astropy.nddata import NDData

from specreduce.core import SpecreduceOperation
from specreduce.tracing import Trace, FlatTrace
from specreduce.tracing import FlatTrace, BaseTrace
from specutils import Spectrum1D

__all__ = ['BoxcarExtract', 'HorneExtract', 'OptimalExtract']
Expand Down Expand Up @@ -88,7 +88,7 @@ def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape):
Parameters
----------
trace : `~specreduce.tracing.Trace`, required
trace : `~specreduce.tracing.BaseTrace`, required
trace object
width : float, required
width of extraction aperture in pixels
Expand Down Expand Up @@ -139,7 +139,7 @@ class BoxcarExtract(SpecreduceOperation):
----------
image : nddata-compatible image
image with 2-D spectral image data
trace_object : Trace
trace_object : BaseTrace
trace object
width : float
width of extraction aperture in pixels
Expand All @@ -154,7 +154,7 @@ class BoxcarExtract(SpecreduceOperation):
The extracted 1d spectrum expressed in DN and pixel units
"""
image: NDData
trace_object: Trace
trace_object: BaseTrace
width: float = 5
disp_axis: int = 1
crossdisp_axis: int = 0
Expand All @@ -173,7 +173,7 @@ def __call__(self, image=None, trace_object=None, width=None,
----------
image : nddata-compatible image
image with 2-D spectral image data
trace_object : Trace
trace_object : BaseTrace
trace object
width : float
width of extraction aperture in pixels [default: 5]
Expand Down Expand Up @@ -230,7 +230,7 @@ class HorneExtract(SpecreduceOperation):
NDData object must specify uncertainty and a mask. An array
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
trace_object : `~specreduce.tracing.Trace`, required
trace_object : `~specreduce.tracing.BaseTrace`, required
The associated 1D trace object created for the 2D image.
disp_axis : int, optional
Expand Down Expand Up @@ -264,7 +264,7 @@ class HorneExtract(SpecreduceOperation):
"""
image: NDData
trace_object: Trace
trace_object: BaseTrace
bkgrd_prof: Model = field(default=models.Polynomial1D(2))
variance: np.ndarray = field(default=None)
mask: np.ndarray = field(default=None)
Expand Down Expand Up @@ -293,7 +293,7 @@ def __call__(self, image=None, trace_object=None,
NDData object must specify uncertainty and a mask. An array
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
trace_object : `~specreduce.tracing.Trace`, required
trace_object : `~specreduce.tracing.BaseTrace`, required
The associated 1D trace object created for the 2D image.
disp_axis : int, optional
Expand Down
144 changes: 101 additions & 43 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from copy import deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, field
import warnings

from astropy.modeling import fitting, models
Expand All @@ -10,29 +10,22 @@
from scipy.interpolate import UnivariateSpline
import numpy as np

__all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']
__all__ = ['BaseTrace', 'Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']


@dataclass
class Trace:
@dataclass(frozen=True)
class BaseTrace:
"""
Basic tracing class that by default traces the middle of the image.
Parameters
----------
image : `~astropy.nddata.CCDData`
Image to be traced
Properties
----------
shape : tuple
Shape of the array describing the trace
A dataclass common to all Trace objects.
"""
image: CCDData
_trace_pos: (float, np.ndarray) = field(repr=False)
_trace: np.ndarray = field(repr=False)

def __post_init__(self):
self.trace_pos = self.image.shape[0] / 2
self.trace = np.ones_like(self.image[0]) * self.trace_pos
# this class only exists to catch __post_init__ calls in its
# subclasses, so that super().__post_init__ calls work correctly.
pass

def __getitem__(self, i):
return self.trace[i]
Expand All @@ -59,7 +52,7 @@ def _bound_trace(self):
Mask trace positions that are outside the upper/lower bounds of the image.
"""
ny = self.image.shape[0]
self.trace = np.ma.masked_outside(self.trace, 0, ny-1)
object.__setattr__(self, '_trace', np.ma.masked_outside(self._trace, 0, ny - 1))

def __add__(self, delta):
"""
Expand All @@ -77,9 +70,60 @@ def __sub__(self, delta):
"""
return self.__add__(-delta)

def shift(self, delta):
"""
Shift the trace by delta pixels perpendicular to the axis being traced
Parameters
----------
delta : float
Shift to be applied to the trace
"""
# act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
object.__setattr__(self, '_trace', np.asarray(self._trace.data) + delta)
object.__setattr__(self, '_trace_pos', self._trace_pos + delta)
self._bound_trace()

@property
def shape(self):
return self._trace.shape

@property
def trace(self):
return self._trace

@property
def trace_pos(self):
return self._trace_pos

@staticmethod
def _default_trace_attrs(image):
"""
Compute a default trace position and trace array using only
the image dimensions.
"""
trace_pos = image.shape[0] / 2
trace = np.ones_like(image[0]) * trace_pos
return trace_pos, trace


@dataclass(init=False, frozen=True)
class Trace(BaseTrace):
"""
Basic tracing class that by default traces the middle of the image.
Parameters
----------
image : `~astropy.nddata.CCDData`
Image to be traced
"""
def __init__(self, image):
trace_pos, trace = self._default_trace_attrs(image)
super().__init__(image, trace_pos, trace)


@dataclass
class FlatTrace(Trace):
@dataclass(init=False, frozen=True)
class FlatTrace(BaseTrace):
"""
Trace that is constant along the axis being traced
Expand All @@ -92,10 +136,11 @@ class FlatTrace(Trace):
trace_pos : float
Position of the trace
"""
trace_pos: float

def __post_init__(self):
self.set_position(self.trace_pos)
def __init__(self, image, trace_pos):
_, trace = self._default_trace_attrs(image)
super().__init__(image, trace_pos, trace)
self.set_position(trace_pos)

def set_position(self, trace_pos):
"""
Expand All @@ -106,13 +151,13 @@ def set_position(self, trace_pos):
trace_pos : float
Position of the trace
"""
self.trace_pos = trace_pos
self.trace = np.ones_like(self.image[0]) * self.trace_pos
object.__setattr__(self, '_trace_pos', trace_pos)
object.__setattr__(self, '_trace', np.ones_like(self.image[0]) * trace_pos)
self._bound_trace()


@dataclass
class ArrayTrace(Trace):
@dataclass(init=False, frozen=True)
class ArrayTrace(BaseTrace):
"""
Define a trace given an array of trace positions
Expand All @@ -121,25 +166,27 @@ class ArrayTrace(Trace):
trace : `numpy.ndarray`
Array containing trace positions
"""
trace: np.ndarray
def __init__(self, image, trace):
trace_pos, _ = self._default_trace_attrs(image)
super().__init__(image, trace_pos, trace)

def __post_init__(self):
nx = self.image.shape[1]
nt = len(self.trace)
nt = len(trace)
if nt != nx:
if nt > nx:
# truncate trace to fit image
self.trace = self.trace[0:nx]
trace = trace[0:nx]
else:
# assume trace starts at beginning of image and pad out trace to fit.
# padding will be the last value of the trace, but will be masked out.
padding = np.ma.MaskedArray(np.ones(nx - nt) * self.trace[-1], mask=True)
self.trace = np.ma.hstack([self.trace, padding])
padding = np.ma.MaskedArray(np.ones(nx - nt) * trace[-1], mask=True)
trace = np.ma.hstack([trace, padding])
object.__setattr__(self, '_trace', trace)
self._bound_trace()


@dataclass
class KosmosTrace(Trace):
@dataclass(init=False, frozen=True)
class KosmosTrace(BaseTrace):
"""
Trace the spectrum aperture in an image.
Expand Down Expand Up @@ -192,14 +239,25 @@ class KosmosTrace(Trace):
4) add other interpolation modes besides spline, maybe via
specutils.manipulation methods?
"""
bins: int = 20
guess: float = None
window: int = None
peak_method: str = 'gaussian'
bins: int
guess: float
window: int
peak_method: str
_crossdisp_axis = 0
_disp_axis = 1

def __post_init__(self):
def _process_init_kwargs(self, **kwargs):
for attr, value in kwargs.items():
object.__setattr__(self, attr, value)

def __init__(self, image, bins=20, guess=None, window=None, peak_method='gaussian'):
# This method will assign the user supplied value (or default) to the attrs:
self._process_init_kwargs(
bins=bins, guess=guess, window=window, peak_method=peak_method
)
trace_pos, trace = self._default_trace_attrs(image)
super().__init__(image, trace_pos, trace)

# handle multiple image types and mask uncaught invalid values
if isinstance(self.image, NDData):
img = np.ma.masked_invalid(np.ma.masked_array(self.image.data,
Expand All @@ -223,7 +281,7 @@ def __post_init__(self):

if not isinstance(self.bins, int):
warnings.warn('TRACE: Converting bins to int')
self.bins = int(self.bins)
object.__setattr__(self, 'bins', int(self.bins))

if self.bins < 4:
raise ValueError('bins must be >= 4')
Expand All @@ -240,7 +298,7 @@ def __post_init__(self):
"length of the image's spatial direction")
elif self.window is not None and not isinstance(self.window, int):
warnings.warn('TRACE: Converting window to int')
self.window = int(self.window)
object.__setattr__(self, 'window', int(self.window))

# set max peak location by user choice or wavelength with max avg flux
ztot = img.sum(axis=self._disp_axis) / img.shape[self._disp_axis]
Expand Down Expand Up @@ -343,4 +401,4 @@ def __post_init__(self):
warnings.warn("TRACE ERROR: No valid points found in trace")
trace_y = np.tile(np.nan, len(x_bins))

self.trace = np.ma.masked_invalid(trace_y)
object.__setattr__(self, '_trace', np.ma.masked_invalid(trace_y))

0 comments on commit 907457e

Please sign in to comment.