In [1]:
from keckdrpframework.primitives.base_primitive import BasePrimitive

In [2]:
from astropy.io import fits
from astropy.visualization import ZScaleInterval
from astropy.visualization.mpl_normalize import ImageNormalize
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np

In [3]:
class OpenFits(BasePrimitive):
    """Primitive that opens the FITS file named in the action arguments."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Open ``args.filename`` and reset the processed-extension counter.

        Returns:
            The action arguments, now carrying ``hdul`` (the object returned
            by ``fits.open``) and ``handled`` set to 0.
        """
        args = self.action.args
        args.hdul = fits.open(args.filename)
        # Counter incremented later by the overscan-removal step.
        args.handled = 0
        return args

In [4]:
class Branch(BasePrimitive):
    """Queue one 'get_header' event for each of FITS extensions 1 through 8."""

    def __init__(self, action, context):
        super().__init__(action, context)
        # Fixed: the original read `contest.pipeline_logger`, a NameError.
        self.logger = context.pipeline_logger

    def _perform(self):
        """Record the extension numbers and push one event per extension.

        Returns:
            The action arguments, with ``extension`` set to the array 1..8.
        """
        args = self.action.args
        args.extension = np.arange(1, 9)

        for ext in args.extension:
            # Events are queued through the processing context; the
            # primitive itself does not provide a push_event method.
            # NOTE(review): the payload is a single HDU, whereas GetHeader
            # reads `args.hdul` and `args.extension` -- confirm the intended
            # event payload.
            self.context.push_event('get_header', args.hdul[ext])
        return args

In [5]:
class GetHeader(BasePrimitive):
    """Get the header of the selected FITS extension."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Store the header of ``args.hdul[args.extension]`` as ``args.hdr0``.

        Returns:
            The action arguments, now carrying ``hdr0``.
        """
        args = self.action.args
        # Keep the header on the arguments (the original bound it to a local
        # that was discarded) so the binning and overscan steps can read it.
        args.hdr0 = args.hdul[args.extension].header
        return args

In [6]:
class GetBinning(BasePrimitive):
    """Get binning dimensions from the header."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Split the comma-separated 'BINNING' header value into a list.

        Returns:
            The action arguments, with ``binning`` set to the list of
            string fields from ``hdr0['BINNING']``.
        """
        args = self.action.args
        args.binning = args.hdr0['BINNING'].split(',')
        # Fixed: the original returned `self.section.args`, which does not exist.
        return args

In [7]:
class GetOverscan(BasePrimitive):
    """Get the parameters of the overscan region of each detector."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Read the pre/post-scan sizes from the header, scaled by binning.

        PRECOL and POSTPIX are divided by the first binning factor, PRELINE
        and POSTLINE by the second (integer division).

        Returns:
            The action arguments, with ``precol``, ``postpix``, ``preline``
            and ``postline`` set.
        """
        args = self.action.args
        header = args.hdr0
        bin_first = int(args.binning[0])
        bin_second = int(args.binning[1])

        args.precol = int(header['PRECOL']) // bin_first
        # Fixed keyword: the original looked up 'PPOSTPIX'.
        args.postpix = int(header['POSTPIX']) // bin_first
        args.preline = int(header['PRELINE']) // bin_second
        args.postline = int(header['POSTLINE']) // bin_second
        return args

In [8]:
class GetBias(BasePrimitive):
    """Get the bias from the overscan region."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Compute a per-row bias as the median of the trailing overscan columns.

        ``x1:x2`` indexes the first axis of ``args.data`` and ``y1:y2`` the
        second; the median is taken along the second axis.

        Returns:
            The action arguments, with ``x1``, ``x2``, ``y1``, ``y2`` and the
            int64 array ``bias`` set.
        """
        args = self.action.args
        args.x1 = 0
        args.x2 = args.height
        args.y1 = args.width - args.postpix + 1
        args.y2 = args.width

        # Fixed: the original sliced with bare x1/x2/y1/y2, which were never
        # defined as local names.
        overscan = args.data[args.x1:args.x2, args.y1:args.y2]
        args.bias = np.array(np.median(overscan, axis=1), dtype=np.int64)
        return args

In [9]:
class PlotBias(BasePrimitive):
    """Plot the bias of each CCD."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Plot ``args.bias`` with fixed axis limits and show the figure.

        Returns:
            The action arguments, unchanged, so the next event in the chain
            receives them like it does from every other primitive.
        """
        plt.plot(self.action.args.bias)
        plt.xlim(-500, 4500)
        plt.ylim(800, 1800)
        plt.xlabel('PIXEL')
        plt.ylabel('#COUNTS')
        plt.show()
        # Fixed: the original fell off the end and returned None.
        return self.action.args

In [10]:
class BiasSubtraction(BasePrimitive):
    """Remove the previously computed bias from the detector data."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Subtract ``args.bias`` from ``args.data``.

        Returns:
            The action arguments, with ``data`` replaced by the difference.
        """
        args = self.action.args
        # The added trailing axis lets each bias entry apply along the
        # second axis of the data.
        bias_column = args.bias[:, None]
        args.data = args.data - bias_column
        return args

In [11]:
class GetMinMax(BasePrimitive):
    """Get the minimum and maximum for zscaling -- not including overscan regions."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Update the running display limits from the central part of the data.

        The sampled window excludes the pre/post-scan regions plus a further
        10% margin of the height and width on each side. ``args.vmin`` and
        ``args.vmax`` keep the lowest and highest limits seen so far, with
        ``vmin`` clamped at 0.

        Returns:
            The action arguments, with ``x1``, ``x2``, ``y1``, ``y2``,
            ``zscale``, ``vmin`` and ``vmax`` set.
        """
        args = self.action.args
        row_margin = args.height * 0.10
        col_margin = args.width * 0.10

        args.x1 = int(args.preline + row_margin)
        args.x2 = int(args.height - args.postline - row_margin)
        args.y1 = int(args.precol + col_margin)
        args.y2 = int(args.width - args.postpix - col_margin)

        # No other step creates the interval object, so build it on first use.
        zscale = getattr(args, 'zscale', None)
        if zscale is None:
            zscale = ZScaleInterval()
            args.zscale = zscale

        # Fixed: the original sliced rows x1:x1 (always empty) using
        # undefined bare names.
        window = args.data[args.x1:args.x2, args.y1:args.y2]
        tmp_vmin, tmp_vmax = zscale.get_limits(window)

        # vmin/vmax may not exist yet when the first detector is handled.
        vmin = getattr(args, 'vmin', None)
        vmax = getattr(args, 'vmax', None)

        if vmin is None or tmp_vmin < vmin:
            vmin = tmp_vmin
        # Fixed: the original compared with `<` and overwrote vmin here.
        if vmax is None or tmp_vmax > vmax:
            vmax = tmp_vmax
        if vmin < 0:
            vmin = 0

        args.vmin = vmin
        args.vmax = vmax
        return args

In [12]:
class RemoveOverscan(BasePrimitive):
    """Remove overscan regions from CCD images."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Trim the leading and trailing overscan columns and collect the result.

        Returns:
            The action arguments, with ``data`` trimmed along its second
            axis, ``handled`` incremented, and the trimmed array appended
            to ``alldata``.
        """
        args = self.action.args
        args.data = args.data[:, args.precol:args.width - args.postpix]
        args.handled += 1

        # Create the accumulator only once. The original reset it to [] on
        # every call, so it could never hold more than one detector.
        if getattr(args, 'alldata', None) is None:
            args.alldata = []
        args.alldata.append(args.data)
        return args

In [13]:
class CreateBottomRow(BasePrimitive):
    """Build the first row of the detector mosaic -- ext#: 1 - 4."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _pre_condition(self):
        """Return True only once the handled counter is no longer below 8."""
        return not (self.action.args.handled < 8)

    def _perform(self):
        """Join the first four collected arrays side by side and flip vertically.

        Returns:
            The action arguments, with ``bottom`` set to the flipped row.
        """
        args = self.action.args
        first_four = args.alldata[:4]
        args.bottom = np.flipud(np.concatenate(first_four, axis=1))
        return args

In [14]:
class CreateTopRow(BasePrimitive):
    """Creating second row of detector mosaic -- presumably ext#: 5 - 8."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Flip every collected array after the first four and join them.

        Each array is reversed along all of its axes (``np.flip`` with no
        axis argument) before being concatenated along the second axis.

        Returns:
            The action arguments, with ``top`` set to the joined row.
        """
        args = self.action.args
        # Use a local loop variable: the original iterated with
        # `self.action.args.arr` as the loop target, leaving a stray `arr`
        # attribute on the shared arguments.
        flipped = [np.flip(ccd) for ccd in args.alldata[4:]]
        args.top = np.concatenate(flipped, axis=1)
        return args

In [15]:
class StackRows(BasePrimitive):
    """Join the top and bottom mosaic rows and rotate the combined image."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Place ``top`` above ``bottom`` along axis 0, then apply ``np.rot90``.

        Returns:
            The action arguments, with ``fulldata`` set to the rotated stack.
        """
        args = self.action.args
        stacked = np.concatenate((args.top, args.bottom), axis=0)
        args.fulldata = np.rot90(stacked)
        return args

In [16]:
class DetectorMosaic(BasePrimitive):
    """Creating mosaic of DEIMOS 8 CCDs."""

    def __init__(self, action, context):
        super().__init__(action, context)
        self.logger = context.pipeline_logger

    def _perform(self):
        """Render the stacked mosaic without axes and save it as a JPEG.

        The image is written to ``args.filename + '.jpg'`` at 300 dpi.

        Returns:
            The action arguments, with ``norm`` set to the normalization used.
        """
        args = self.action.args
        # NOTE(review): plots `fulldata` (the array produced by StackRows)
        # instead of the per-detector `data` the original referenced --
        # confirm this is the intended image.
        mosaic = args.fulldata

        # Fixed: the original passed `vmin == vmin` / `vmax == vmax`, i.e.
        # booleans as positional arguments, instead of keyword limits.
        args.norm = ImageNormalize(mosaic, vmin=args.vmin, vmax=args.vmax)

        fig = plt.figure(frameon=False)
        # Axes spanning the whole figure so the saved image has no border.
        fig.add_axes([0, 0, 1, 1])
        plt.axis('off')
        plt.imshow(mosaic, cmap='gray', norm=args.norm)
        # Fixed: the original `savefig(filename, + '.jpg', ...)` applied a
        # unary plus to a string, which raises TypeError.
        plt.savefig(args.filename + '.jpg', dpi=300)
        return args

In [17]:
# Maps each event name to a tuple of
# (primitive class name, state label, next event name or None).
event_table = {
    # --- Open the file and fan out per extension ---
    "open_fits_file": ("OpenFits", "opening_fits_file", "branch"),
    "branch": ("Branch", "getting_extensions", None),

    # --- Per-extension chain: header info, bias, overscan removal ---
    "get_header": ("GetHeader", "getting_fits_header", "get_binning"),
    "get_binning": ("GetBinning", "getting_binning_dimensions", "get_overscan"),
    "get_overscan": ("GetOverscan", "getting_overscan_regions", "get_bias"),
    "get_bias": ("GetBias", "getting_bias", "plot_bias"),
    "plot_bias": ("PlotBias", "plotting_bias", "bias_subtraction"),
    "bias_subtraction": ("BiasSubtraction", "bias_subtraction_starting", "get_min_max"),
    "get_min_max": ("GetMinMax", "getting_vmin_vmax", "remove_overscan_regions"),
    "remove_overscan_regions": ("RemoveOverscan", "removing_overscan_regions", "create_bottom_row"),

    # --- Mosaic assembly ---
    "create_bottom_row": ("CreateBottomRow", "creating_bottom_row", "create_top_row"),
    "create_top_row": ("CreateTopRow", "creating_top_row", "stack_rows"),
    "stack_rows": ("StackRows", "stacking_rows", "create_mosaic"),
    "create_mosaic": ("DetectorMosaic", "creating_full_detector_mosaic", None),
}

##