In [None]:
import sys
import os.path
import glob

from core.framework import Framework
from config.framework_config import ConfigClass

from pipelines.kcwi_pipeline import Kcwi_pipeline
from models.arguments import Arguments
import subprocess
import time

from pipelines.base_pipeline import Base_pipeline

from primitives.kcwi_primitives import *
from primitives.kcwi_file_primitives import *
# NOTE(review): a `global` statement at module (cell) scope is a no-op — it does
# not create or bind `arcs`. fit_center_profile reads its arcs from
# self.context.arcs, not from a notebook global.
global arcs

In [None]:
class Kcwi_pipeline(Base_pipeline):
    """
    Pipeline to process KCWI data

    Maps framework events to actions through ``event_table`` and routes each
    ingested file to a processing chain in ``action_planner``. This definition
    shadows the ``Kcwi_pipeline`` imported from ``pipelines.kcwi_pipeline``.

    Each ``event_table`` entry is ``event: (action, <label>, next_event)``:
    ``action`` names the method/primitive to run and ``next_event`` (``None``
    ends the chain) is the event that follows it. The middle element looks like
    a status label such as ``"trim_overscan_started"`` — TODO confirm its exact
    use against Base_pipeline.
    """

    event_table = {
        "next_file": ("ingest_file", "file_ingested", "file_ingested"),
        "file_ingested": ("action_planner", None, None),
        # BIAS
        "process_bias": ("process_bias", None, None),
        # CONTBARS PROCESSING
        "process_contbars": ("process_contbars", "contbars_processing_started", "contbar_subtract_overscan"),
        "contbar_subtract_overscan": ("subtract_overscan", "subtract_overscan_started", "contbar_trim_overscan"),
        "contbar_trim_overscan": ("trim_overscan", "trim_overscan_started", "contbar_correct_gain"),
        "contbar_correct_gain": ("correct_gain", "gain_correction_started", "contbar_find_bars"),
        "contbar_find_bars": ("find_bars", "find_bars_started", "contbar_trace_bars"),
        "contbar_trace_bars": ("trace_bars", "trace_bars_started", None),
        # ARCS PROCESSING
        "process_arc": ("process_arc", "arcs_processing_started", "arcs_subtract_overscan"),
        "arcs_subtract_overscan": ("subtract_overscan", "subtract_overscan_started", "arcs_trim_overscan"),
        "arcs_trim_overscan": ("trim_overscan", "trim_overscan_started", "arcs_correct_gain"),
        "arcs_correct_gain": ("correct_gain", "gain_correction_started", "arcs_extract_arcs"),
        "arcs_extract_arcs": ("extract_arcs", "extract_arcs_started", "arcs_arc_offsets"),
        "arcs_arc_offsets":  ("arc_offsets", "arc_offset_started", "arcs_calc_prelim_disp"),
        "arcs_calc_prelim_disp": ("calc_prelim_disp", "prelim_disp_started", "arcs_read_atlas"),
        "arcs_read_atlas": ("read_atlas", "read_atlas_started", "arcs_fit_center"),
        "arcs_fit_center": ("fit_center_profile", "fit_center_started", None),
        # FLAT
        "process_flat": ("process_flat", None, None),
        # OBJECT frames are not wired up yet; action_planner skips them (with a
        # warning) until "process_object" is registered here.
        #"process_object": ("process_object", None, "save_png"),
        #"save_png": ("save_png", None, None)
    }

    def __init__(self):
        """
        Constructor
        """
        Base_pipeline.__init__(self)
        self.cnt = 0

    def action_planner(self, action, context):
        """Route an ingested file to the processing chain for its image type.

        :param action: framework action; ``action.args`` must carry ``imtype``
            and ``groupid``.
        :param context: framework context; used for ``push_event`` and for
            ``config.instrument.bias_min_nframes``.

        BIAS frames get a dedicated ``Arguments`` object describing the master
        bias to build. CONTBARS/FLAT/ARCLAMP/OBJECT frames forward
        ``action.args`` unchanged, matched by substring in that order. Frames
        with no type, an unknown type, or a type whose event is not registered
        in ``event_table`` are logged and skipped.
        """
        imtype = action.args.imtype
        groupid = action.args.groupid
        self.logger.info("******* FILE TYPE DETERMINED AS %s" % imtype)
        self.logger.info("******* GROUPID is %s " % groupid)
        if not imtype:
            # a None imtype would otherwise crash on the substring tests below
            self.logger.warning("******* No image type available; file skipped")
            return
        if imtype == "BIAS":
            bias_args = Arguments(name="bias_args",
                                  groupid=groupid,
                                  want_type="BIAS",
                                  new_type="MASTER_BIAS",
                                  min_files=context.config.instrument.bias_min_nframes,
                                  new_file_name="master_bias_%s.fits" % groupid)
            context.push_event("process_bias", bias_args)
        elif "CONTBARS" in imtype:
            context.push_event("process_contbars", action.args)
        elif "FLAT" in imtype:
            context.push_event("process_flat", action.args)
        elif "ARCLAMP" in imtype:
            context.push_event("process_arc", action.args)
        elif "OBJECT" in imtype:
            # only queue the event once a handler for it exists in event_table
            if "process_object" in self.event_table:
                context.push_event("process_object", action.args)
            else:
                self.logger.warning(
                    "******* No process_object event registered; %s file skipped" % imtype)
        else:
            self.logger.warning("******* Unrecognized image type %s; no action taken" % imtype)


In [None]:
# Wire the KCWI pipeline into the processing framework. The framework reads its
# own settings from config.cfg; instrument parameters are loaded from instr.cfg
# and attached as framework.config.instrument (primitives read them through
# their context's config.instrument — presumably this same object; confirm).
pipeline = Kcwi_pipeline()
framework = Framework(pipeline, "config.cfg")
framework.config.instrument = ConfigClass("instr.cfg")

# Progress markers for the log.
framework.logger.info("Framework initialized")
framework.logger.info("Checking path for files")

In [None]:
# Root directory of the raw KCWI data. Set the KCWI_DATA_DIR environment
# variable to override the machine-specific default instead of editing this cell.
# NOTE(review): nothing in this notebook reads `path` yet — the input files are
# listed explicitly (with KCWI_DATA_1 paths) in the next cell.
path = os.environ.get("KCWI_DATA_DIR", '/Users/lrizzi/KCWI_DATA/')

In [None]:
# Raw KCWI frames to process, queued in this order. Named `input_files` rather
# than `list` so the builtin `list` is not shadowed for the rest of the kernel.
input_files = ['/Users/lrizzi/KCWI_DATA_1/kb181012_00014.fits',
               '/Users/lrizzi/KCWI_DATA_1/kb181012_00016.fits']
for file_name in input_files:
    # 'next_file' maps to the ingest_file action in Kcwi_pipeline.event_table
    framework.append_event('next_file', Arguments(name=file_name))

In [None]:
class fit_center_profile(Base_primitive):
    """Fit a wavelength solution for the central rows of every bar spectrum.

    For each bar in ``self.context.arcs`` a grid of dispersions within +/-5% of
    ``self.context.prelim_disp`` is scanned. For every trial, the bar spectrum
    (rows ``minrow:maxrow``) is resampled onto the atlas wavelengths and
    cross-correlated with the tapered atlas spectrum. The dispersion and shift
    that maximise the (interpolated) correlation peak define that bar's
    polynomial coefficients (highest order first, as used by ``np.polyval``),
    which are appended to ``self.action.args.centcoeff``.
    """

    def __init__(self, action, context):
        Base_primitive.__init__(self, action, context)
        # one 5-element coefficient list per bar, filled by _perform
        self.action.args.centcoeff = []

    def _grating_coeffs(self, zero_wave, disp, ybin):
        """Return the 4th-degree wavelength polynomial [c4, c3, c2, c1, c0].

        ``zero_wave`` and ``disp`` become the constant and linear terms; the
        higher-order terms are derived from instrument PIX/FCAM, the binning
        ``ybin`` and ``self.action.args.rho``, evaluated at the angle beta
        implied by ``disp``.
        """
        instr = self.context.config.instrument
        rho = self.action.args.rho
        scale = instr.PIX * ybin / instr.FCAM
        cosbeta = disp / (instr.PIX * ybin) * rho * instr.FCAM * 1.e-4
        # clip so acos stays in its domain
        if cosbeta > 1.:
            cosbeta = 1.
        beta = math.acos(cosbeta)
        return [
            scale ** 4 * math.sin(beta) / 24. / rho * 1.e4,
            -scale ** 3 * math.cos(beta) / 6. / rho * 1.e4,
            -scale ** 2 * math.sin(beta) / 2. / rho * 1.e4,
            disp,
            zero_wave,
        ]

    def _plot_bar_summary(self, imlab, values, y_label):
        """Plot one value per bar, then pause or wait for <cr> per config.

        Dotted vertical lines are drawn between groups of 5 bars (presumably
        slice boundaries, cf. the "Slice %d" per-bar titles).
        """
        instr = self.context.config.instrument
        p = figure(title=imlab, x_axis_label="Bar #", y_axis_label=y_label)
        p.scatter(range(len(values)), values, marker='x')
        ylim = [min(values), max(values)]
        for ix in range(1, 24):
            sx = ix * 5 - 0.5
            p.line([sx, sx], ylim, color='black', line_dash='dotted')
        p.x_range = Range1d(-1, 120)
        bokeh_plot(p)
        if instr.interactive >= 2:
            input("Next? <cr>: ")
        else:
            time.sleep(instr.plot_pause)

    def _perform(self):
        """ Fit central region

        At this point we have the offsets between bars and the approximate
        offset from the reference bar to the atlas spectrum and the approximate
        dispersion.

        :returns: ``self.action.args`` with ``centcoeff`` populated.
        """
        self.logger.info("Finding wavelength solution for central region")
        args = self.action.args
        instr = self.context.config.instrument
        # TODO: drive this from configuration (formerly KcwiConf.INTER >= 2)
        do_inter = False
        # image label for the summary plots
        imlab = "Img # %d (%s) Sl: %s Fl: %s Gr: %s" % \
                (args.ccddata.header['FRAMENO'], args.illum,
                 args.ifuname, args.filter, args.grating)
        ybin = args.ybinsize
        # per-bar starting constant term, before the cross-correlation shift
        p0 = args.cwave + np.array(self.context.baroffs) * self.context.prelim_disp \
            - args.offset_wave
        # brute-force scan within +/- max_ddisp (fraction) of the preliminary dispersion
        max_ddisp = 0.05
        nn = int(max_ddisp * abs(self.context.prelim_disp) / args.refdisp *
                 (args.maxrow - args.minrow) / 3.0)
        nn = min(max(nn, 10), 25)
        self.logger.info("N disp. samples: %d" % nn)
        disps = self.context.prelim_disp * (1.0 + max_ddisp *
                                            (np.arange(0, nn+1) - nn/2.) * 2.0 / nn)
        centwave = []
        centdisp = []

        subxvals = args.xvals[args.minrow:args.maxrow]
        refwave = np.asarray(args.refwave)

        for b, bs in enumerate(self.context.arcs):
            maxima = []
            shifts = []
            subspec = bs[args.minrow:args.maxrow]
            for disp in disps:
                coeff = self._grating_coeffs(p0[b], disp, ybin)
                # wavelength range covered by the fitted rows
                wl0 = np.polyval(coeff, args.xvals[args.minrow])
                wl1 = np.polyval(coeff, args.xvals[args.maxrow])
                minwvl = np.nanmin([wl0, wl1])
                maxwvl = np.nanmax([wl0, wl1])
                # first atlas index >= minwvl and last index <= maxwvl
                # (IndexError, as before, if the range misses the atlas)
                minrw = np.flatnonzero(refwave >= minwvl)[0]
                maxrw = np.flatnonzero(refwave <= maxwvl)[-1]
                subrefwvl = refwave[minrw:maxrw]
                refslice = args.reflux[minrw:maxrw]
                # bell cosine taper to avoid edge effects in the correlation
                tkwgt = signal.windows.tukey(len(refslice), alpha=instr.TAPERFRAC)
                # multiply into a NEW array: the slice is a view of args.reflux,
                # and an in-place *= would re-taper the shared atlas every trial
                subrefspec = refslice * tkwgt
                # resample the bar spectrum onto the atlas wavelengths
                waves = np.polyval(coeff, subxvals)
                obsint = interpolate.interp1d(waves, subspec, kind='cubic',
                                              bounds_error=False,
                                              fill_value='extrapolate')
                intspec = obsint(subrefwvl) * tkwgt
                # cross-correlate and keep only the central third of the lags
                nsamp = len(subrefwvl)
                offar = np.arange(1 - nsamp, nsamp)
                xcorr = np.correlate(intspec, subrefspec, mode='full')
                x0c = int(len(xcorr) / 3)
                x1c = int(2 * (len(xcorr) / 3))
                xcorr_central = xcorr[x0c:x1c]
                peak = xcorr_central.argmax()
                maxima.append(xcorr_central[peak])
                shifts.append(offar[x0c:x1c][peak])
            # refine the best dispersion/shift on a fine interpolated grid
            int_max = interpolate.interp1d(disps, maxima, kind='cubic',
                                           bounds_error=False,
                                           fill_value='extrapolate')
            int_shift = interpolate.interp1d(disps, shifts, kind='cubic',
                                             bounds_error=False,
                                             fill_value='extrapolate')
            xdisps = np.linspace(min(disps), max(disps), num=nn*100)
            maxima_res = int_max(xdisps)
            shifts_res = int_shift(xdisps) * args.refdisp
            best = maxima_res.argmax()
            best_disp = xdisps[best]
            best_shift = shifts_res[best]
            coeff = self._grating_coeffs(p0[b] - best_shift, best_disp, ybin)
            scoeff = pascal_shift(coeff, args.x0)
            self.logger.info("Central Fit: Bar#, Cdisp, Coefs: "
                             "%3d  %.4f  %.2f  %.4f  %13.5e %13.5e" %
                             (b, best_disp, scoeff[4], scoeff[3], scoeff[2],
                              scoeff[1]))
            centwave.append(coeff[4])
            centdisp.append(coeff[3])
            args.centcoeff.append(coeff)

            if instr.interactive >= 1:
                # correlation peak versus trial dispersion for this bar
                p = figure(title="Bar %d, Slice %d" % (b, int(b/5)),
                           x_axis_label="Central dispersion (Ang/px)", y_axis_label="X-Corr Peak Value")
                p.scatter(disps, maxima, color='red')
                p.line(xdisps, maxima_res)
                p.line([best_disp, best_disp], [min(maxima), max(maxima)], color='green')
                bokeh_plot(p)
                if do_inter:
                    q = input("<cr> - Next, q to quit: ")
                    if 'Q' in q.upper():
                        do_inter = False
                else:
                    time.sleep(0.01)

        # summary plots need at least one bar (min()/max() of an empty list fail)
        if instr.interactive >= 1 and centwave:
            self._plot_bar_summary(imlab, centwave, "Central Wavelength (A)")
            self._plot_bar_summary(imlab, centdisp, "Central Dispersion (A)")
        return args

In [None]:
# Start the framework; presumably this processes the 'next_file' events queued
# with append_event above through Kcwi_pipeline.event_table (confirm whether it blocks).
framework.start()


In [None]:
# NOTE(review): no cell in this notebook binds `arcs` (a module-level
# `global arcs` does not create it), so this raises NameError on a fresh kernel
# unless one of the star imports defines it. The primitives keep their arcs
# on the context (`self.context.arcs`).
arcs

In [None]:
# NOTE(review): this repeats the previous cell. `arcs` is not bound by any cell
# in this notebook, so on a fresh kernel this raises NameError unless a star
# import defines it.
arcs