In [None]:
import sys
print("Confirm that the environment is correct.")
print(sys.executable)

import os
os.environ['CRDS_PATH'] = './crds_cache/jwst_ops'
os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds.stsci.edu'

import shutil
from copy import deepcopy

import glob
import time

from jwst import datamodels as dm
from jwst.pipeline import Detector1Pipeline, Spec2Pipeline
import jwst

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import FormatStrFormatter as fsf
from matplotlib.ticker import NullFormatter as nulf
from matplotlib.ticker import AutoMinorLocator as aml

import numpy as np
from astropy import modeling
from astropy.io import fits
from astropy.stats import sigma_clip

import batman
import emcee
import corner

from scipy.optimize import least_squares
from scipy.sparse import csr_matrix
from scipy.ndimage import median_filter
from scipy import signal

# EXoTiC-LD is an optional dependency: record whether it could be imported so
# that later code can decide whether fixed limb darkening coefficients are usable.
try:
    from exotic_ld import StellarLimbDarkening as SLD
except ImportError:
    # Only a failed import is treated as "package unavailable"; anything else
    # (e.g. KeyboardInterrupt) is allowed to propagate instead of being swallowed.
    print("EXoTiC-LD not found on this system, fixed limb darkening coefficients will not be available.")
    exotic_ld_available = False
else:
    exotic_ld_available = True

In [None]:
def img(array, aspect=1, title=None, vmin=None, vmax=None, norm=None):
    '''
    Image plotting utility to plot the input 2D array.
    
    :param array: 2D array. Image you want to plot.
    :param aspect: float. Aspect ratio. Useful for visualizing narrow arrays.
    :param title: str. Title to give the plot.
    :param vmin: float. Minimum value for color mapping.
    :param vmax: float. Maximum value for color mapping.
    :param norm: str. Type of normalisation scale to use for this image.
                 None (the default) leaves matplotlib's default normalisation in place.
    :return: the figure, the axes, and the image object returned by imshow.
    '''
    fig, ax = plt.subplots(figsize=(20, 25))
    # norm=None is imshow's own default, so a single call covers both the
    # "no norm given" and "norm given" cases without comparing against None.
    im = ax.imshow(array, aspect=aspect, norm=norm, origin="lower", vmin=vmin, vmax=vmax)
    ax.set_title(title)
    return fig, ax, im

In [None]:
def doStage1(filepath, outfile, outdir,
             group_scale={"skip":False},
             dq_init={"skip":False},
             saturation={"skip":False},
             superbias={"skip":False},
             refpix={"skip":False},
             linearity={"skip":False},
             dark_current={"skip":False},
             jump={"skip":True},
             ramp_fit={"skip":False},
             gain_scale={"skip":False},
             one_over_f={"skip":False, "bckg_rows":[1,2,3,4,5,6,-1,-2,-3,-4,-5,-6], "show":False}
             ):
    '''
    Performs Stage 1 calibration on one file.
    
    The JWST Detector1Pipeline is run twice: first with RampFit and GainScale skipped,
    then (after the optional custom 1/f subtraction on the group-level data) a second
    time with every step except RampFit and GainScale skipped.
    
    :param filepath: str. Location of the file you want to correct. The file must be of type *_uncal.fits.
    :param outfile: str. Name to give to the calibrated file.
    :param outdir: str. Location of where to save the calibrated file to.
    :param group_scale, dq_init, saturation, etc.: dict. These are the dictionaries shown in the Detector1Pipeline()
                                                   documentation, which control which steps are run and what parameters
                                                   they are run with. Please consult jwst-pipeline.readthedocs.io for
                                                   more information on these dictionaries.
    :param one_over_f: dict. Keyword "skip" is a bool that sets whether or not to perform this step. Keyword
                       "bckg_rows" contains list of integers that selects which rows of the array are used as
                       background for 1/f subtraction. Keyword "show" is a bool that sets whether the first group
                       of the first integration is shown as it is cleaned. Missing keywords fall back to
                       skip=False, the twelve edge rows of the signature default, and show=False.
    :return: None. A Stage 1 calibrated file *_rateints.fits is saved to the outdir.
    '''
    # Create the output directory if it does not yet exist (no check-then-create race).
    os.makedirs(outdir, exist_ok=True)
    
    # Report initialization of Stage 1.
    print("Performing Stage 1 calibration on file: " + filepath)
    print("Running JWST Stage 1 pipeline starting at GroupScale and stopping before Jump...")
    
    # Collect timestamp to track how long this takes.
    t0 = time.time()
    first_pass_steps = {"group_scale":group_scale,
                        "dq_init":dq_init,
                        "saturation":saturation,
                        "superbias":superbias,
                        "refpix":refpix,
                        "linearity":linearity,
                        "dark_current":dark_current,
                        "jump":jump,
                        "ramp_fit": {"skip": True},
                        "gain_scale": {"skip": True}}
    with Detector1Pipeline.call(filepath, steps=first_pass_steps) as result:
        print("Stage 1 calibrations up to step Jump resolved in %.3f seconds." % (time.time() - t0))
        
        if not one_over_f.get("skip", False):
            # Before we ramp_fit, we perform 1/f subtraction.
            print("Performing pre-RampFit 1/f subtraction...")
            result.data = one_over_f_subtraction(result.data,
                                                 bckg_rows=one_over_f.get("bckg_rows",
                                                                          [1,2,3,4,5,6,-1,-2,-3,-4,-5,-6]),
                                                 show=one_over_f.get("show", False))
        else:
            print("Skipping 1/f subtraction...")
        
        # Now we can resume Stage 1 calibration.
        t02 = time.time()
        print("Resuming Stage 1 calibrations through RampFit and GainScale steps."
              "\nThe RampFit step can take several minutes to hours depending on how big your dataset is,\n"
              "so I suggest you find something else to do in the meantime. Anyways...")
        
        # Everything already applied in the first pass is skipped here; only the
        # caller-supplied ramp_fit and gain_scale settings are honoured.
        second_pass_steps = {"group_scale": {"skip": True},
                             "dq_init": {"skip": True},
                             "saturation": {"skip": True},
                             "superbias": {"skip": True},
                             "refpix": {"skip": True},
                             "linearity": {"skip": True},
                             "dark_current": {"skip": True},
                             "jump": {"skip": True},
                             "ramp_fit":ramp_fit,
                             "gain_scale":gain_scale}
        calibrated = Detector1Pipeline.call(result, output_file=outfile, output_dir=outdir,
                                            steps=second_pass_steps)
        print("Finished final steps of Stage 1 calibrations in %.3f minutes." % ((time.time()-t02)/60))
        
        # The second-pass product is a separate object from the one the `with`
        # statement manages, so release it explicitly once it has been saved.
        close_calibrated = getattr(calibrated, "close", None)
        if calibrated is not result and callable(close_calibrated):
            close_calibrated()
    print("File calibrated and saved.")
    print("Stage 1 calibrations completed in %.3f minutes." % ((time.time() - t0)/60))

In [None]:
def one_over_f_subtraction(data, bckg_rows, show):
    '''
    Subtracts a per-column background level (1/f noise estimate) from every group
    of every integration. The input array is modified in place and also returned.
    Adapted from routine developed by Trevor Foote (tof2@cornell.edu).
    
    :param data: 4D array. Array of integrations x groups x rows x cols.
    :param bckg_rows: list of integers. Indices of the rows to use as the background region.
    :param show: bool. Whether to display the intermediate frames for the first group of
                 the first integration, to check that this step is working properly.
    :return: 4D array that has been subjected to 1/f subtraction.
    '''
    def _preview(frame):
        # Display log10(|frame|) as a throwaway diagnostic figure.
        fig, ax, im = img(np.log10(np.abs(frame)), aspect=5, vmin=None, vmax=None, norm=None)
        plt.show()
        plt.close()

    # Time this step.
    start = time.time()
    n_ints = np.shape(data)[0]
    for i in range(n_ints): # for each integration
        for g in range(np.shape(data)[1]): # for each group
            # Diagnostics are only drawn for the very first frame.
            inspect = (i == 0 and g == 0 and show)

            # Pull out the rows that define the background region.
            bckg = data[i, g, bckg_rows, :]
            if inspect:
                _preview(bckg)

            # Reject outliers in the background rows so CRs aren't propagated through the array.
            bckg = clean(bckg, 3, (5, 1)) # cleans on rows, rejecting at 3 sigma
            if inspect:
                _preview(bckg)

            # Collapse to one mean level per column, then repeat it for every row of the frame.
            column_level = bckg.mean(axis=0)
            full_background = np.array([column_level,]*np.shape(data)[2])
            if inspect:
                _preview(full_background)
                _preview(data[i, g, :, :])

            data[i, g, :, :] = data[i, g, :, :] - full_background

            if inspect:
                _preview(data[i, g, :, :])

        # Report every 1000 integrations (but not on the very first one).
        if i != 0 and i % 1000 == 0:
            elapsed = time.time() - start
            rate = i/elapsed
            remaining = n_ints - i
            print("On integration %.0f. Elapsed time in this step is %.3f seconds." % (i, elapsed))
            print("Average rate of integration processing: %.3f ints/s." % rate)
            print("Estimated time remaining: %.3f seconds.\n" % (remaining/rate))
    print("1/f subtraction completed in %.3f seconds." % (time.time()-start))
    return data

In [None]:
def clean(data, sigma, kernel):
    '''
    Cleans one 2D array with median spatial filtering.
    Adapted from routine developed by Trevor Foote (tof2@cornell.edu).
    
    Pixels whose difference from the median-filtered image is rejected by sigma
    clipping (evaluated along axis 0) are replaced by the median-filtered value.
    
    :param data: 2D array. Array that will be median-filtered.
    :param sigma: float. Sigma at which to reject outliers.
    :param kernel: tuple of odd int. Kernel to use for median filtering.
    :return: cleaned 2D array of floats.
    '''
    medfilt = signal.medfilt2d(data, kernel)
    diff = data - medfilt
    clipped = sigma_clip(diff, sigma=sigma, axis=0)
    # getmaskarray always yields a full boolean array, even when nothing was clipped.
    mask = np.ma.getmaskarray(clipped)
    # Select rather than multiply: with arithmetic masking a rejected NaN/inf pixel
    # stays NaN (NaN*0 is NaN) instead of being replaced by the filtered value.
    return np.where(mask, medfilt, data).astype(float)
In [None]:
def doStage2(filepath, outfile, outdir,
             assign_wcs={"skip":False},
             extract_2d={"skip":False},
             srctype={"skip":False},
             wavecorr={"skip":False},
             flat_field={"skip":False},
             pathloss={"skip":True},
             photom={"skip":True},
             resample_spec={"skip":True},
             extract_1d={"skip":True}
             ):
    '''
    Performs Stage 2 calibration on one file.
    
    :param filepath: str. Location of the file you want to correct. The file must be of type *_rateints.fits.
    :param outfile: str. Name to give to the calibrated file.
    :param outdir: str. Location of where to save the calibrated file to.
    :param assign_wcs, extract_2d, srctype, etc.: dict. These are the dictionaries shown in the Spec2Pipeline()
                                                  documentation, which control which steps are run and what parameters
                                                  they are run with. Please consult jwst-pipeline.readthedocs.io for
                                                  more information on these dictionaries.
    :return: None. A Stage 2 calibrated file *_calints.fits is saved to the outdir.
    '''
    # Create the output directory if it does not yet exist (no check-then-create race).
    os.makedirs(outdir, exist_ok=True)
    
    print("Performing Stage 2 calibration on file: " + filepath)
    print("Running JWST Stage 2 pipeline for spectroscopic data. This stage is a pure wrapper for JWST Stage 2 with no mods. Anyways...")
    t0 = time.time()
    # The pipeline writes its product to outdir itself; the returned object is not needed here.
    Spec2Pipeline.call(filepath, output_file=outfile, output_dir=outdir,
                       steps={"assign_wcs":assign_wcs,
                              "extract_2d":extract_2d,
                              "srctype":srctype,
                              "wavecorr":wavecorr,
                              "flat_field":flat_field,
                              "pathloss":pathloss,
                              "photom":photom,
                              "resample_spec":resample_spec,
                              "extract_1d":extract_1d})
    print("File calibrated and saved.")
    print("Stage 2 calibrations completed in %.3f minutes." % ((time.time() - t0)/60))

In [None]:
def doStage3(filesdir, outdir,
             trace_aperture={"hcut1":0,
                             "hcut2":-1,
                             "vcut1":0,
                             "vcut2":-1},
             frames_to_reject = [],
             loss_stats_step={"skip":False},
             mask_flagged_pixels={"skip":False},
             iteration_outlier_removal={"skip":False, "n":2, "sigma":10},
             spatialfilter_outlier_removal={"skip":False, "sigma":3, "kernel":(1,15)},
             laplacianfilter_outlier_removal={"skip":False, "sigma":50},
             second_bckg_subtract={"skip":False,"bckg_rows":[0,1,2,3,4,5,6,-6,-5,-4,-3,-2,-1]},
             track_source_location={"skip":False,"reject_disper":False,"reject_spatial":True}
            ):
    '''
    Performs custom Stage 3 calibrations on all *_calints.fits files located in the filesdir.
    Can be run on *.fits files that have already been run on this step, if you want only to
    load the data from those *.fits files.
    
    :param filesdir: str. Directory where the *_calints.fits files you want to calibrate are stored.
    :param outdir: str. Directory where you want the additionally-calibrated .fits files to be stored,
                   as well as any output images for reference.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace. They are used as slice bounds, so
                           negative values count from the end of the axis.
    :param frames_to_reject: list of int. Indices of frames that you want to reject, for reasons like
                             a high-gain antenna move, a satellite crossing, or an anomalous drop/rise in flux.
                             The list passed in is copied and is not modified.
    :param loss_stats_step: dict. Keyword "skip" sets whether DQ-flag loss statistics are reported.
    :param mask_flagged_pixels: dict. Keyword "skip" sets whether DQ-flagged pixels are masked.
    :param iteration_outlier_removal: dict. Keywords "skip", "n" and "sigma" for temporal outlier rejection.
    :param spatialfilter_outlier_removal: dict. Keywords "skip", "sigma" and "kernel" for median filtering.
    :param laplacianfilter_outlier_removal: dict. Keywords "skip" and "sigma" for Laplacian filtering.
    :param second_bckg_subtract: dict. Keywords "skip" and "bckg_rows" for the second background subtraction.
    :param track_source_location: dict. Keywords "skip", "reject_disper" and "reject_spatial" for
                                  rejecting frames based on source motion.
    :return: None. Calibrated postprocessed_#.fits files are written to the outdir.
    :raises FileNotFoundError: if no *_calints.fits files are found in filesdir.
    '''
    # Work on a private copy: appending to the shared default list (or to the caller's
    # list) would make rejected frames accumulate across calls.
    frames_to_reject = list(frames_to_reject)
    
    # Create the output directories if they do not yet exist.
    img_dir = os.path.join(outdir, "output_imgs_calibration")
    os.makedirs(img_dir, exist_ok=True)
    
    files = sorted(glob.glob(os.path.join(filesdir,'*_calints.fits')))
    if not files:
        raise FileNotFoundError("No *_calints.fits files were found in: " + filesdir)
    print("Performing Stage 3 calibrations on the following files:")
    t0 = time.time()
    for path in files:
        with fits.open(path) as hdul:
            print(hdul[0].header["FILENAME"])
    
    # First, we need to stitch all the segments together. This step is not optional.
    # It can, however, be run on additionally-calibrated files if you point it to the
    # directory where those are held. stitch_files returns an extra trailing item for
    # post-processed inputs, so only the first six entries are unpacked here.
    stitched = stitch_files(files)
    segments, errors, segstarts, wavelengths, dqflags, times = stitched[:6]
    
    hcut1, hcut2 = trace_aperture["hcut1"], trace_aperture["hcut2"]
    vcut1, vcut2 = trace_aperture["vcut1"], trace_aperture["vcut2"]
    
    print("Read all files and collected needed data.")
    print("Creating the aperture for extraction of the data and saving an image of it for reference...")
    # Only one frame of the aperture is ever plotted, so a 2D array is sufficient.
    aperture = np.ones(np.shape(segments)[1:])
    aperture[hcut1:hcut2, vcut1:vcut2] = 0
    fig, ax, im = img(aperture, aspect=5)
    plt.savefig(os.path.join(outdir, "output_imgs_calibration/trace_aperture.pdf"), dpi=300)
    plt.close(fig)
    fig, ax, im = img(segments[0, hcut1:hcut2, vcut1:vcut2], aspect=5)
    plt.savefig(os.path.join(outdir, "output_imgs_calibration/trace_extracted_region.pdf"), dpi=300)
    plt.close(fig)
    
    plt.imshow(np.log10(np.abs(segments[0,:,:])))
    plt.title("Frame 0 log10 abs before Stage 3 corrections.")
    plt.show()
    plt.close()
    
    if not loss_stats_step["skip"]:
        num_pixels_lost_to_DQ_flags = loss_stats(dqflags, trace_aperture, outdir)
        print("%.0f pixels were marked by DQ flags." % num_pixels_lost_to_DQ_flags)
    
    if not mask_flagged_pixels["skip"]:
        segments, num_pixels_lost_by_flagging = mask_flagged(segments, dqflags, trace_aperture)
    else:
        num_pixels_lost_by_flagging = 0
    
    if not iteration_outlier_removal["skip"]:
        segments, num_pixels_lost_by_iteration = iterate_outlier_removal(segments, dqflags, trace_aperture,
                                                                         n=iteration_outlier_removal["n"],
                                                                         sigma=iteration_outlier_removal["sigma"])
    else:
        num_pixels_lost_by_iteration = 0
    
    if not spatialfilter_outlier_removal["skip"]:
        segments, num_pixels_lost_by_filtering = spatial_outlier_removal(segments, trace_aperture,
                                                                         sigma=spatialfilter_outlier_removal["sigma"],
                                                                         kernel=spatialfilter_outlier_removal["kernel"])
    else:
        num_pixels_lost_by_filtering = 0
        
    if not laplacianfilter_outlier_removal["skip"]:
        segments, num_pixels_lost_by_laplacian = laplacian_outlier_removal(segments, errors, trace_aperture,
                                                                           sigma=laplacianfilter_outlier_removal["sigma"],
                                                                           verbose=True)
    else:
        num_pixels_lost_by_laplacian = 0
    
    if not second_bckg_subtract["skip"]:
        segments = bckg_subtract(segments,
                                 bckg_rows=second_bckg_subtract["bckg_rows"])
    
    plt.imshow(np.log10(np.abs(segments[0,:,:])))
    plt.title("Frame 0 log10 abs after Stage 3 corrections.")
    plt.show()
    plt.close()
    
    if not track_source_location["skip"]:
        frames_rejected_by_source_motion = gaussian_source_track(segments,
                                                                 reject_dispersion_direction=track_source_location["reject_disper"],
                                                                 reject_spatial_direction=track_source_location["reject_spatial"])
        for frame in frames_rejected_by_source_motion:
            if frame not in frames_to_reject:
                frames_to_reject.append(frame)
    
    totalmasked = num_pixels_lost_by_flagging + num_pixels_lost_by_iteration + num_pixels_lost_by_filtering + num_pixels_lost_by_laplacian
    # Measure the aperture from the slice itself: subtracting the cut values is wrong
    # for negative (from-the-end) cuts such as the default hcut2 = vcut2 = -1.
    n_rows, n_cols = np.shape(segments[:, hcut1:hcut2, vcut1:vcut2])[1:]
    total = np.shape(segments)[0]*n_rows*n_cols
    fraction = totalmasked/total if total else 0.0
    print("In all, %.0f trace pixels out of %.0f were masked (fraction of %.2f) and %.0f frames will be skipped." % (totalmasked, total, fraction, len(frames_to_reject)))
    
    print("Stage 3 calibrations completed in %.3f minutes." % ((time.time()-t0)/60))
    
    print("Writing calibrated fits file as several .fits file...")
    for i, path in enumerate(files):
        outfile = os.path.join(outdir, "postprocessed_{0:g}.fits".format(i))
        shutil.copy(path, outfile)
        with fits.open(outfile, mode="update") as fits_file:
            # Need to write new objects "segstarts" and "frames_to_reject" to fits file,
            # and update the science data with the calibrated frames of this segment.

            # Create REJECT ImageHDU to contain frames_to_reject object.
            REJECT = fits.ImageHDU(np.array(frames_to_reject), name='REJECT')
            # Append REJECT to the hdulist of fits_file.
            fits_file.append(REJECT)

            # Create SEGSTARTS ImageHDU to contain segstarts object.
            SEGSTARTS = fits.ImageHDU(np.array(segstarts), name='SEGSTARTS')
            # Append SEGSTARTS to the hdulist of fits_file.
            fits_file.append(SEGSTARTS)

            # Write the calibrated frames belonging to this segment; segstarts[i] is
            # used as the index one past the last frame of segment i.
            first_frame = 0 if i == 0 else segstarts[i-1]
            fits_file['SCI'].data = segments[first_frame:segstarts[i],:,:]

            # All modified headers get written out.
            # NOTE(review): this rewrites a file that is also open in update mode;
            # fits_file.flush() may be the safer call -- confirm on the target platform.
            fits_file.writeto(outfile, overwrite=True)
    print("Wrote calibrated postprocessed_#.fits files.")

In [None]:
def stitch_files(files):
    '''
    Reads all *_calints.fits files given by the filepaths provided
    and stitches them together along the integration axis.
    
    :param files: list of str. Paths of the files to read, in the order they should be stitched.
                  If the first path contains "postprocessed", every file is additionally
                  expected to hold a REJECT extension.
    :return: segments, errors, segstarts, wavelengths, dqflags, times; plus frames_to_reject
             (read from the last file) as a seventh item when reading post-processed files.
             segstarts holds, for each file, the cumulative number of integrations read
             up to and including that file. wavelengths is a list with one entry per file.
    :raises ValueError: if files is empty.
    '''
    if not files:
        raise ValueError("stitch_files needs at least one file to read.")
    
    postprocessed = "postprocessed" in files[0]
    if postprocessed:
        print("Reading post-processed files, adjusting outputs accordingly...")
    
    data_chunks, err_chunks, dq_chunks, time_chunks = [], [], [], []
    segstarts = []
    wavelengths = []
    frames_to_reject = None
    
    for file in files:
        print("Attempting to locate file: " + file)
        with dm.open(file) as fitsfile:
            print("Loaded the file " + file + " successfully.")
            
            # Need to retrieve the wavelength object, science object, and DQ object from this.
            # Copies are taken so that nothing refers to the model once it is closed.
            data = np.array(fitsfile.data)
            data_chunks.append(data)
            err_chunks.append(np.array(fitsfile.err))
            dq_chunks.append(np.array(fitsfile.dq))
            time_chunks.append(np.array(fitsfile.int_times["int_mid_MJD_UTC"]))
            wavelengths.append(np.array(fitsfile.wavelength))
            
            # Running total of integrations: add this file's count to the previous
            # total only (summing the whole list would double count earlier files).
            previous_total = segstarts[-1] if segstarts else 0
            segstarts.append(np.shape(data)[0] + previous_total)
        
        if postprocessed:
            with fits.open(file) as f:
                frames_to_reject = f["REJECT"].data
        
        print("Retrieved segments, wavelengths, DQ flags, and times from file: " + file)
        print("Closing file and moving on to next one...")
    
    segments = np.concatenate(data_chunks, 0)
    errors = np.concatenate(err_chunks, 0)
    dqflags = np.concatenate(dq_chunks, 0)
    times = np.concatenate(time_chunks, 0)
    
    if postprocessed:
        return segments, errors, segstarts, wavelengths, dqflags, times, frames_to_reject
    return segments, errors, segstarts, wavelengths, dqflags, times

In [None]:
def loss_stats(dqflags, trace_aperture, outdir):
    '''
    Gets pixel loss statistics for the region inside of the trace.
    
    :param dqflags: 3D array. Integrations x rows x cols of data quality flags.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace. They are used as slice bounds, so
                           negative values count from the end of the axis.
    :param outdir: str. Directory containing the "output_imgs_calibration" folder, where an
                   image of the flagged aperture is saved.
    :return: int of how many pixels were affected by DQ flags.
    '''
    hcut1, hcut2 = trace_aperture["hcut1"], trace_aperture["hcut2"]
    vcut1, vcut2 = trace_aperture["vcut1"], trace_aperture["vcut2"]
    
    # Flags inside the trace aperture, for every integration.
    trace_dq = dqflags[:, hcut1:hcut2, vcut1:vcut2]
    n_frames = np.shape(dqflags)[0]
    # Measure the aperture from the slice itself: subtracting the cut values is wrong
    # for negative (from-the-end) cuts.
    n_rows, n_cols = np.shape(trace_dq)[1:]
    pixels_per_frame = n_rows*n_cols
    
    print("Checking to see how much of the trace data will be lost to dq flags...")
    lost_per_frame = np.count_nonzero(trace_dq, axis=(1, 2))
    total_lost = np.sum(lost_per_frame)
    
    # Report what unique nonzero flags appear inside the trace across the whole dataset.
    unique = [int(flag) for flag in np.unique(trace_dq) if flag != 0]
    print("The following flags were reported:")
    print(unique)
    
    # Creates an image of the aperture with flags. The mask is 1 (hidden) outside the
    # trace and 0 (shown) inside, so only flags inside the aperture are displayed.
    outside_trace = np.ones(np.shape(dqflags)[1:])
    outside_trace[hcut1:hcut2, vcut1:vcut2] = 0
    flagged_arr = np.where(dqflags[0, :, :] > 0, 1, 0)
    fig, ax, im = img(np.ma.masked_array(flagged_arr, outside_trace), aspect=5)
    plt.savefig(os.path.join(outdir, "output_imgs_calibration/trace_aperture_with_flags.pdf"), dpi=300)
    plt.close(fig)
    
    # Reports pixel loss stats, guarding the percentages against an empty aperture.
    total_pixels = n_frames*pixels_per_frame
    total_percent = 100*(total_lost/total_pixels) if total_pixels else 0.0
    median_lost = np.median(lost_per_frame)
    median_percent = 100*(median_lost/pixels_per_frame) if pixels_per_frame else 0.0
    print("The total number of lost trace pixels is %.3f." % total_lost)
    print("The total percentage of pixels being lost is %.3f." % total_percent)
    print("The median number of lost trace pixels is %.3f." % median_lost)
    print("The median percentage of pixels being lost is %.3f." % median_percent)
    print("Creating a histogram showing pixel loss statistics...")

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.hist(lost_per_frame, density=True, bins=70)
    ax.set_xlabel('number of flagged trace pixels')
    ax.set_ylabel('frequency')
    plt.show()
    plt.close()
    
    return total_lost

In [None]:
def mask_flagged(segments, dqflags, trace_aperture):
    '''
    Replace all pixels flagged by the data quality array with their medians in time.
    The segments array is modified in place and also returned.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param dqflags: 3D array. Integrations x rows x cols of data quality flags.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace. They are used as slice bounds, so
                           negative values count from the end of the axis.
    :return: segments array with flagged pixels masked, and int of how many pixels in the trace were affected
             by this process.
    '''
    # Get borders to evaluate.
    hcut1, hcut2, vcut1, vcut2 = trace_aperture["hcut1"], trace_aperture["hcut2"], trace_aperture["vcut1"], trace_aperture["vcut2"]
    
    # Turn dqflags into a boolean mask: True wherever any flag is set.
    dq_mask = np.asarray(dqflags) > 0
    
    print("Masking flagged pixels inside and outside of the trace...")
    t0 = time.time()
    
    # Count flagged samples inside the trace by slicing, so that negative
    # (from-the-end) cuts select the same region as everywhere else.
    masked_flagged = np.count_nonzero(dqflags[:, hcut1:hcut2, vcut1:vcut2])
    
    # Only pixels with at least one flagged sample can change, so visit just those.
    flagged_rows, flagged_cols = np.nonzero(dq_mask.any(axis=0))
    for i, j in zip(flagged_rows, flagged_cols):
        # Track temporal variations and replace any flagged samples with the temporal median
        # that was calculated out of only unmasked values.
        pixel_mask = dq_mask[:, i, j]
        pmed = np.ma.median(np.ma.masked_array(segments[:, i, j], pixel_mask))
        segments[:, i, j] = np.where(pixel_mask, pmed, segments[:, i, j])
    print("Masked %.0f flagged trace pixels in %.3f seconds." % (masked_flagged, time.time() - t0))
    return segments, masked_flagged

In [None]:
def iterate_outlier_removal(segments, dqflags, trace_aperture, n, sigma):
    '''
    Iterate and remove outliers to reject CRs. Does not use masked values to compute the time median.
    The segments array is modified in place and also returned.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param dqflags: 3D array. Integrations x rows x cols of data quality flags.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace. They are used as slice bounds, so
                           negative values count from the end of the axis.
    :param n: int. Number of times to iterate.
    :param sigma: float. Sigma at which to reject outliers.
    :return: segments array with CRs rejected, and int of how many pixels in the trace were affected
             by this process.
    '''
    # Get borders to evaluate.
    hcut1, hcut2, vcut1, vcut2 = trace_aperture["hcut1"], trace_aperture["hcut2"], trace_aperture["vcut1"], trace_aperture["vcut2"]
    
    # Resolve the cuts into explicit index ranges the same way slicing would, so that
    # negative (from-the-end) cuts still identify which pixels lie inside the trace.
    n_rows, n_cols = np.shape(segments)[1], np.shape(segments)[2]
    trace_rows = range(*slice(hcut1, hcut2).indices(n_rows))
    trace_cols = range(*slice(vcut1, vcut2).indices(n_cols))
    
    # Turn dqflags into a boolean mask: True wherever any flag is set.
    dq_mask = np.asarray(dqflags) > 0
    
    print("Masking pixels with substantial time variations...")
    t0 = time.time()
    masked_iter = 0
    
    for iteration in range(n):
        print("On iteration %.0f..." % iteration)
        t02 = time.time()
        for i in range(n_rows):
            row_in_trace = i in trace_rows
            for j in range(n_cols):
                # Track temporal variation of a single pixel and flag any sample deviating from
                # its median by more than sigma standard deviations. Flagged samples are left
                # out of the median and standard deviation.
                pixel = segments[:, i, j]
                unflagged = np.ma.masked_array(pixel, dq_mask[:, i, j])
                pmed = np.ma.median(unflagged)
                psigma = np.std(unflagged)
                is_outlier = np.where(np.abs(pixel - pmed)>sigma*psigma, True, False)
                if row_in_trace and j in trace_cols:
                    masked_iter += np.count_nonzero(is_outlier)
                segments[:, i, j] = np.where(is_outlier, pmed, pixel)
        print("Performed round %.0g of %.2f-sigma temporal outlier rejection in %.3f seconds." % (iteration, sigma, time.time()-t02))
    print("Masked %.0f trace pixels for significant temporal variations in %.3f seconds." % (masked_iter, time.time()-t0))
    return segments, masked_iter

In [None]:
def spatial_outlier_removal(segments, trace_aperture, sigma, kernel):
    '''
    Median filter every integration to remove hot pixels.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace. They are used as slice bounds, so
                           negative (end-relative) values are allowed.
    :param sigma: float. Sigma at which to reject outliers. Passed straight to clean().
    :param kernel: tuple of odd ints. Kernel to use for spatial filtering. Passed straight to clean().
    :return: cleaned copy of segments (the input is left untouched), and int of how many pixels
             inside the trace aperture differ between the cleaned and the input data.
    '''
    # Get borders to evaluate.
    hcut1, hcut2, vcut1, vcut2 = trace_aperture["hcut1"], trace_aperture["hcut2"], trace_aperture["vcut1"], trace_aperture["vcut2"]
    
    print("Performing hot pixel masking through spatial median filtering...")
    masked_filter = 0
    cleaned_segments = np.zeros_like(segments)
    
    t0 = time.time()
    for i in range(np.shape(segments)[0]):
        # Clean the array.
        cleaned_segments[i, :, :] = clean(segments[i, :, :], sigma, kernel)
        
        # Count the aperture pixels that were changed. The comparison is done directly on the
        # slices; no buffer of shape (hcut2-hcut1, vcut2-vcut1) is allocated because that shape
        # is negative (and np.empty raises) whenever the cuts are given as negative indices.
        changed = cleaned_segments[i, hcut1:hcut2, vcut1:vcut2] != segments[i, hcut1:hcut2, vcut1:vcut2]
        masked_filter += np.count_nonzero(changed)
    print("Masked %.0f trace pixels for significant spatial variation in %.3f seconds." % (masked_filter, time.time()-t0))
    print("Performed median filtering in %.3f seconds." % (time.time()-t0))
    return cleaned_segments, masked_filter

In [None]:
def laplacian_outlier_removal(segments, errors, trace_aperture, sigma=50, verbose=False):
    '''
    Convolves a Laplacian kernel with each frame of the segments array to find spatial outliers
    and replaces them with the median of the surrounding 3x3 pixels.
    
    Each frame is processed independently: a noise model is built from the frame and its
    uncertainties, the frame is subsampled by a factor of 2, convolved in 2D with the Laplacian
    kernel, clipped to positive values, binned back to the original size and divided by the
    noise model. Pixels whose scaled Laplacian is non-zero and deviates from the frame median
    by at least sigma times that median are treated as outliers.
    
    :param segments: 3D array. The segments(t,x,y) array from which outliers will be removed.
                     It is modified in place.
    :param errors: 3D array. Uncertainties matching segments, used to estimate the read noise.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace. Interpreted like slice bounds
                           (upper edge exclusive, negative values count from the end), matching
                           how the other cleaning steps use the aperture.
    :param sigma: float. Threshold of deviation from median of Laplacian image, above which a pixel
                  will be flagged as an outlier and masked.
    :param verbose: bool. If True, occasionally prints out a progress report and shows
                    diagnostic plots for the first frame.
    :return: segments(t,x,y) array with spatial outliers replaced, and int of how many replaced
             pixels were inside the trace aperture.
    '''
    # Get borders to evaluate.
    hcut1, hcut2, vcut1, vcut2 = trace_aperture["hcut1"], trace_aperture["hcut2"], trace_aperture["vcut1"], trace_aperture["vcut2"]
    
    nsteps, nrows, ncols = np.shape(segments)
    # Resolve the aperture to explicit half-open index ranges so negative cuts work.
    row_lo, row_hi, _ = slice(hcut1, hcut2).indices(nrows)
    col_lo, col_hi, _ = slice(vcut1, vcut2).indices(ncols)
    
    l = 0.25*np.array([[0,-1,0],[-1,4,-1],[0,-1,0]])
    
    bad_pix_removed = 0
    t0 = time.time()
    # Report roughly every 10 percent, but at least every frame: int(nsteps*.1) is 0 for short
    # series and must not be used as a modulus.
    report_every = max(int(nsteps*.1), 1)
    
    print("Cleaning %.1f-sigma outliers with Laplacian edge detection..." % sigma)
    for k in range(nsteps):
        # Iterate over frames.
        print("On frame %.0f..." % k)
        frame = segments[k,:,:]
        
        # Estimate readnoise.
        errf = errors[k,:,:]**2 - frame # remove shot noise from errors to get read noise variance.
        errf[errf < 0] = 0 # enforce positivity.
        rn = np.mean(np.sqrt(errf)) # mean readnoise is our estimate.
        
        if (verbose and k == 0):
            print("Estimated readnoise: %.10f" % rn)
        
        # Build noise model.
        NOISE = np.sqrt(median_filter(np.abs(frame),size=5)+rn**2)
        positive = NOISE[NOISE > 0]
        # Really want to avoid nans: floor the model at its smallest positive value, or at 1
        # when the whole model is zero.
        NOISE[NOISE <= 0] = positive.min() if positive.size else 1.0
        if (verbose and k == 0):
            plt.figure(figsize=(20,5))
            plt.imshow(np.log10(NOISE))
            plt.title("Noise model for LED")
            plt.show()
            plt.close()
        
        # Subsample the array: every pixel becomes a 2x2 block of identical values.
        subsample = np.repeat(np.repeat(frame, 2, axis=0), 2, axis=1)
        
        # Convolve subsample with laplacian. This has to be a genuine 2D convolution; the
        # symmetric boundary keeps flat regions at zero all the way to the frame edges.
        lap_img = signal.convolve2d(subsample, l, mode='same', boundary='symm')
        lap_img[lap_img < 0] = 0 # force positivity
        
        # Resample to original size by averaging each 2x2 block.
        resample = lap_img.reshape(nrows, 2, ncols, 2).mean(axis=(1, 3))
        
        # Divide by subsample factor times noise model.
        scaled_resample = resample/(2*NOISE)
        
        # Spot outliers.
        med = np.median(scaled_resample)
        if (verbose and k == 0):
            print("Median of scaled resampled laplacian image: %.10f" % med)
        # Anything within sigma*med of the median, or exactly zero, is kept. The rest are rays.
        flagged = ~(np.abs(scaled_resample-med) < med*sigma) & (scaled_resample != 0)
        
        if (verbose and k == 0):
            plt.figure(figsize=(20,5))
            plt.imshow(flagged.astype(float))
            plt.title("Where CRs and hot pixels were detected")
            plt.colorbar()
            plt.show()
            plt.close()
        
        # Correct frames
        bad_rows, bad_cols = np.nonzero(flagged)
        for i, j in zip(bad_rows, bad_cols):
            # Clamp the window at the lower edges; a negative slice start would select nothing.
            window = segments[k, max(i-1, 0):i+2, max(j-1, 0):j+2]
            segments[k,i,j] = np.median(window) # replace with local median
            if (row_lo <= i < row_hi and col_lo <= j < col_hi):
                bad_pix_removed += 1
            
        # Report progress.
        steps_taken = k + 1
        if (verbose and steps_taken % report_every == 0):
            elapsed = time.time()-t0
            remaining = elapsed*(nsteps-steps_taken)/steps_taken
            print("%.0f-percent done. Time elapsed: %.0f seconds. Estimated time remaining: %.0f seconds." % (steps_taken*100/nsteps, elapsed, remaining))
    print("Iterations complete. Removed %.0f spatial outliers in %.0f seconds." % (bad_pix_removed, time.time()-t0))
    return segments, bad_pix_removed

In [None]:
def bckg_subtract(segments, bckg_rows):
    '''
    Remove a per-column background level from every integration.
    
    For each integration the rows listed in bckg_rows are sigma-clipped at 3 sigma, clipped
    values are filled with the median of that region, and the column-wise mean of the result
    is subtracted from every row of the integration. The input array is updated in place.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param bckg_rows: list of integers. Indices of the rows to use as the background region.
    :return: segments with background subtracted.
    '''
    print("Performing additional background subtraction...")
    start = time.time()
    
    n_ints = np.shape(segments)[0]
    n_rows = np.shape(segments)[1]
    for idx in range(n_ints):
        # Clip outliers in the background rows and patch them with the region median.
        clipped = sigma_clip(segments[idx, bckg_rows, :], sigma=3)
        patched = clipped.filled(fill_value=np.ma.median(clipped))
        
        # One background value per column, repeated for every row of the frame.
        column_level = patched.mean(axis=0)
        frame_background = np.array([column_level,]*n_rows)
        segments[idx, :, :] = segments[idx, :, :] - frame_background
        
        if (idx != 0 and idx%1000 == 0):
            # Report progress every 1,000 integrations.
            elapsed = time.time()-start
            rate = idx/elapsed
            left = n_ints - idx
            print("On integration %.0f. Elapsed time is %.3f seconds." % (idx, elapsed))
            print("Average rate of integration processing: %.3f ints/s." % rate)
            print("Estimated time remaining: %.3f seconds.\n" % (left/rate))
    print("Additional background subtraction completed in %.3f seconds." % (time.time()-start))
    return segments

In [None]:
def gaussian_source_track(segments, reject_dispersion_direction=True, reject_spatial_direction=False):
    '''
    Tracks the location of the trace between frames and reports frame numbers that show
    significant deviations from the usual location.
    
    For every integration the frame is collapsed along each axis, the resulting profile is
    normalized to a peak of 1 and fit with a 1D Gaussian; the fitted mean is taken as the
    source position along that axis.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param reject_dispersion_direction: bool. Whether to reject outliers of position in the dispersion direction.
    :param reject_spatial_direction: bool. Whether to reject outliers of position in the spatial direction.
    :return: reject_frames list of ints. Indices of the integrations whose fitted position lies
             more than 3 standard deviations from the median position in at least one of the
             enabled directions. Every index appears at most once.
    '''
    reject_frames = []
    t0 = time.time()
    source_pos_disp = []
    source_pos_cros = []
    n_ints = np.shape(segments)[0]
    # One fitter instance is enough for all fits.
    fitter = modeling.fitting.LevMarLSQFitter()
    print("Fitting source position for each integration...")
    for k in range(n_ints):
        # First find the dispersion axis position.
        profile = np.sum(segments[k, :, :], axis=0)
        profile = profile/np.max(profile) # normalize amplitude to 1 for ease of fit
        # NOTE(review): the starting mean of 100 is hard-coded; confirm it suits the subarray in use.
        model = modeling.models.Gaussian1D(amplitude=1, mean=100, stddev=1)
        fitted_model = fitter(model, np.arange(np.shape(profile)[0]), profile)
        source_pos_disp.append(fitted_model.mean[0])

        # Then find the cross dispersion axis position.
        profile = np.sum(segments[k, :, :], axis=1)
        profile = profile/np.max(profile) # normalize amplitude to 1 for ease of fit
        # NOTE(review): the starting mean of 14 is hard-coded; confirm it suits the subarray in use.
        model = modeling.models.Gaussian1D(amplitude=1, mean=14, stddev=1)
        fitted_model = fitter(model, np.arange(np.shape(profile)[0]), profile)
        source_pos_cros.append(fitted_model.mean[0])

        if (k%500 == 0 and k != 0):
            # Report progress every 500 integrations.
            elapsed_time = time.time()-t0
            iterrate = k/elapsed_time
            iterremain = n_ints - k
            print("On integration %.0f. Elapsed time is %.3f seconds." % (k, elapsed_time))
            print("Average rate of integration processing: %.3f ints/s." % iterrate)
            print("Estimated time remaining: %.3f seconds.\n" % (iterremain/iterrate))
    print("Fit source positions in %.3f minutes." % ((time.time()-t0)/60))
    
    # Now that we have the positions, build list of integration indices to reject for being too far off.
    mpd = np.median(source_pos_disp)
    spd = np.std(source_pos_disp)
    mpc = np.median(source_pos_cros)
    spc = np.std(source_pos_cros)

    print("Median position: " + str(mpd) + ", " + str(mpc))
    print("Sigma: " + str(spd) + ", " + str(spc))
    for k in range(n_ints):
        off_in_disp = reject_dispersion_direction and np.abs(mpd - source_pos_disp[k]) > 3*spd
        off_in_cros = reject_spatial_direction and np.abs(mpc - source_pos_cros[k]) > 3*spc
        # A frame that is off in both directions is still a single rejected frame.
        if off_in_disp or off_in_cros:
            reject_frames.append(k)
    skipped = len(reject_frames)
    print("%.0f integrations had source positions significantly off from the median position and will be rejected." % skipped)
    return reject_frames

In [None]:
def doStage4(filepaths, outdir,
             trace_aperture={"hcut1":0,
                             "hcut2":-1,
                             "vcut1":0,
                             "vcut2":-1},
             extract_light_curves={"skip":False,
                                   "wavbins":np.linspace(0.6,5.3,70),
                                   "ext_type":"box"},
             median_normalize_curves={"skip":False},
             sigma_clip_curves={"skip":False,
                                "b":100,
                                "clip_at":5},
             fix_transit_times={"skip":False,
                                "epoch":None},
             plot_light_curves={"skip":False},
             save_light_curves={"skip":False}
            ):
    '''
    Performs Stage 4 extractions on the files located at filepaths.
    
    The option dictionaries are only read by this function, never modified.
    
    :param filepaths: lst of str. Location of the postprocessed_*.fits files you want to extract spectra from.
    :param outdir: str. Location where you want output images and text files to be saved to.
    :param trace_aperture: dict. Keywords are "hcut1", "hcut2", "vcut1", "vcut2", all integers
                           denoting the rows and columns respectively that define the edges of
                           the aperture bounding the trace.
    :param extract_light_curves: dict. "skip" (bool), plus "wavbins" (bin edges) and "ext_type"
                                 which are handed to extract_curves. Every later step needs the
                                 extracted curves, so skipping this step ends the stage early.
    :param median_normalize_curves: dict. "skip" (bool). Whether to median-normalize all curves.
    :param sigma_clip_curves: dict. "skip" (bool), plus "b" and "clip_at" which are handed to clip_curve.
    :param fix_transit_times: dict. "skip" (bool), plus "epoch" which is handed to fix_times.
    :param plot_light_curves: dict. "skip" (bool). Whether to save plots to outdir/output_imgs_extraction.
    :param save_light_curves: dict. "skip" (bool). Whether to write curves to outdir/output_txts_extraction.
    :return: None. The curves are written out as .txt files and plots.
    '''
    print("Performing Stage 4 extractions of spectra from the data located at: {}".format(filepaths))
    
    # Grab the needed info from the file.
    segments, errors, segstarts, wavelengths, dqflags, times, frames_to_reject = stitch_files(filepaths)
    
    # Build the aperture object: 0 inside the trace aperture (kept), 1 elsewhere (masked).
    aperture = np.ones(np.shape(segments))
    aperture[:,
             trace_aperture["hcut1"]:trace_aperture["hcut2"],
             trace_aperture["vcut1"]:trace_aperture["vcut2"]] = 0
    
    # Everything below operates on the extracted curves, so without them there is nothing to do.
    if extract_light_curves["skip"]:
        print("Light curve extraction was skipped, so there are no curves to normalize, clip, plot or save.")
        print("Stage 4 finished.")
        return
    
    wavbins = extract_light_curves["wavbins"]
    wlc, slc, times, central_lams = extract_curves(segments, errors, times, aperture, segstarts, wavelengths, frames_to_reject,
                                                   wavbins=wavbins,
                                                   ext_type=extract_light_curves["ext_type"])
        
    if not median_normalize_curves["skip"]:
        wlc = median_normalize(wlc)
        for i, lc in enumerate(slc):
            slc[i] = median_normalize(lc)
    
    if not sigma_clip_curves["skip"]:
        wlc = clip_curve(wlc,
                         b=sigma_clip_curves["b"],
                         clip_at=sigma_clip_curves["clip_at"])
        for i, lc in enumerate(slc):
            slc[i] = clip_curve(lc,
                                b=sigma_clip_curves["b"],
                                clip_at=sigma_clip_curves["clip_at"])
            
    if not fix_transit_times["skip"]:
        print("Fixing transit timestamps...")
        times=fix_times(times, wlc=wlc,
                        epoch=fix_transit_times["epoch"])
        print("Fixed.")
        
    if not plot_light_curves["skip"]:
        print("Generating output plots of extracted light curves...")
        imgs_outdir = os.path.join(outdir, "output_imgs_extraction")
        os.makedirs(imgs_outdir, exist_ok=True)
        plot_curve(times, wlc, "White light curve", "wlc", imgs_outdir)
        for lc, central_lam in zip(slc, central_lams):
            plot_curve(times, lc, "Spectroscopic light curve at %.3f micron" % central_lam, "slc_%.3fmu" % central_lam, imgs_outdir)
        print("Plots generated.")
    
    if not save_light_curves["skip"]:
        print("Writing all curves to .txt files...")
        txts_outdir = os.path.join(outdir, "output_txts_extraction")
        os.makedirs(txts_outdir, exist_ok=True)
        write_light_curve(times, wlc, "wlc", txts_outdir)
        # Each spectroscopic curve is named after the two bin edges that bound it.
        for lc, low_edge, high_edge in zip(slc, wavbins[:-1], wavbins[1:]):
            write_light_curve(times, lc, "slc_%.3fmu_%.3fmu" % (low_edge, high_edge), txts_outdir)
        print("Files written.")
    print("Stage 4 finished.")

In [None]:
def _build_wavelength_masks(wavelength, bin_edges):
    '''
    Build one mask per wavelength bin, using 0 for pixels inside the bin and 1 for pixels outside.
    
    A pixel is inside a bin when its wavelength is non-zero and low <= wavelength <= high.
    '''
    masks = []
    for low, high in bin_edges:
        below_high = np.where(wavelength <= high, wavelength, 0)
        in_bin = np.where(below_high >= low, below_high, 0)
        masks.append(np.where(in_bin != 0, 0, 1))
    return masks


def extract_curves(segments, errors, times, aperture, segstarts, wavelengths, frames_to_reject, wavbins, ext_type="box"):
    '''
    Extract a white light curve and spectroscopic light curves from the trace.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param errors: 3D array. Integrations x rows x cols of uncertainties.
    :param times: 1D array. Timestamps of integrations.
    :param aperture: 3D array. Mask that defines where the trace is (non-zero entries are excluded).
    :param segstarts: list of ints. Defines where new segment files begin.
    :param wavelengths: list of lists of floats. The wavelength solutions for each segment.
    :param frames_to_reject: list of ints. Frames that will not be added into the light curve.
    :param wavbins: list of floats. The edges defining each spectroscopic light curve. The ith bin
                    counts pixels that have wavelength solution wavbins[i] <= wav <= wavbins[i+1]
                    (both edges inclusive), excluding pixels whose wavelength is exactly 0.
    :param ext_type: str. Choices are "box" or "opt".
    :return: white light curve (list with one value per kept integration), spectroscopic light
             curves (list with one 1D array per wavelength bin), timestamps of the kept
             integrations, and the bin centres rounded to 3 decimals. None of the curves are
             normalized here.
    '''
    # Get just the trace that you want to sum over.
    # NOTE(review): numpy does not copy the data when building a masked array by default, so the
    # "opt" branch below presumably writes through to segments. Confirm callers expect that.
    trace = np.ma.masked_array(segments, aperture)
    
    # The bins depend only on wavbins, so their edges and centres are computed exactly once.
    bin_edges = list(zip(wavbins[:-1], wavbins[1:]))
    central_lams = [(low + high)/2 for low, high in bin_edges]
    
    # Set membership keeps the per-integration rejection check cheap.
    rejected = set(frames_to_reject)
    
    oneDspec = []
    times_with_skips = []
    masks = None
    n_ints = np.shape(segments)[0]

    t0 = time.time()
    if ext_type == "opt":
        spatial_profile = get_spatial_profile(trace, ext_type=ext_type)
    for k in range(n_ints):
        if k in rejected:
            print("Integration %.0f will be skipped." % k)
        else:
            # Not a rejected frame, so proceed.
            print("Gathering 1D spectrum of integration %.0f..." % k)
            times_with_skips.append(times[k])
            
            # When we are at the start of a new segment, we have to rebuild the wavelength masks.
            if (masks is None or k in segstarts):
                print("Building wavelength masks...")
                # Use the wavelength solution of the latest segment that starts at or before k,
                # falling back to the first (or last available) solution.
                seg_idx = max((i for i, start in enumerate(segstarts) if start <= k), default=0)
                seg_idx = min(seg_idx, len(wavelengths) - 1)
                masks = _build_wavelength_masks(wavelengths[seg_idx], bin_edges)
                print("Masks built.")
            
            if ext_type == "opt":
                profile  = spatial_profile[k,:,:]
                errors2  = errors[k,:,:]**2
                errors2 -= trace[k,:,:]
                errors2[errors2<=0] = 10**-8
                f = np.sum(trace[k,:,:], axis=0)
                V = errors2+np.abs(f[np.newaxis,:]*profile)
                trace[k,:,:] = (profile * trace[k,:,:] / V)/np.sum(profile ** 2 / V, axis=0)
            
            # Sum the frame once per bin, masking everything outside the bin and the aperture.
            spectrum = []
            for mask in masks:
                spectrum.append(np.sum(np.ma.masked_array(np.ma.masked_array(np.copy(trace[k, :, :]), mask), aperture[k, :, :])))
            oneDspec.append(spectrum)
            print("Collected spectrum.")
        if (k%1000 == 0 and k != 0):
            elapsed_time = time.time()-t0
            iterrate = k/elapsed_time
            iterremain = n_ints - k
            print("On integration %.0f. Elapsed time is %.3f seconds." % (k, elapsed_time))
            print("Average rate of integration processing: %.3f ints/s." % iterrate)
            print("Estimated time remaining: %.3f seconds.\n" % (iterremain/iterrate))
    print("Gathered 1D spectra in %.3f minutes." % ((time.time()-t0)/60))
    
    # We now have the oneDspec object. We're going to sum the oneDspec into a wlc,
    # then reorganize the oneDspec into a bunch of spectroscopic light curves.
    print("Producing wlc from 1D spectra...")
    wlc = [np.sum(spectrum) for spectrum in oneDspec]
    print("Generated wlc.")
    
    print("Reshaping oneDspec into slc...")
    slc = [np.array([spectrum[i] for spectrum in oneDspec]) for i in range(len(bin_edges))]
    # Now each object in slc is a full time series corresponding to just one wavelength bin.
    print("Reshaped. Returning wlc and slc...")
    
    return wlc, slc, times_with_skips, np.round(central_lams, 3)

In [None]:
def get_spatial_profile(segments, ext_type="box"):
    '''
    Builds a spatial profile for extraction.
    If ext_type "box", profile is uniform 1s.
    If ext_type "opt", profile is an optimum profile, built frame by frame with polycol.
    
    :param segments: 3D array. Integrations x rows x cols of data.
    :param ext_type: str. Choices are "box" or "opt".
    :return: 3D array with the same shape as segments holding the profile of every frame.
    :raises ValueError: if ext_type is neither "box" nor "opt".
    '''
    if ext_type == "box":
        return np.ones_like(segments)
    if ext_type == "opt":
        # The profile holds fractions between 0 and 1, so it must be floating point even if
        # the data are not.
        P = np.empty_like(segments, dtype=float)
        # Iterate through frames. polycol fits along columns, so it needs the full 2D frame.
        for k in range(np.shape(P)[0]):
            P[k,:,:] = polycol(segments[k,:,:], poly_order=4, threshold=3)
        return P
    raise ValueError("Unrecognized ext_type '{}', choices are 'box' or 'opt'.".format(ext_type))

In [None]:
def polycol(trace, poly_order=4, threshold=3):
    '''
    Computes spatial profile fit to the trace using a polynomial of the specified order, fitting along columns.
    
    Each column is fit repeatedly (at most 20 fits). After every fit the pixel with the largest
    residual, in units of the residual standard deviation, is compared against threshold; if it
    exceeds it, the pixel is replaced and the column is fit again. Interior pixels are replaced
    by the mean of their two neighbours, pixels on either edge by the median of the current fit.
    
    :param trace: 2D array. A single frame of the trace, form trace[y,x].
    :param poly_order: int. Order of polynomial to fit as the profile.
    :param threshold: float. Sigma threshold at which to mask polynomial fit outliers.
    :return: P[y,x] array. An array profile to use for optimal extraction, clipped to be
             non-negative and normalized so that every column sums to 1.
    '''
    # Initialize P as a list.
    P = []
    n_rows = np.shape(trace)[0]
    rows = range(n_rows)
    
    # Iterate on columns.
    for i in range(np.shape(trace)[1]):
        col = deepcopy(trace[:,i])
        j = 0
        while True:
            p_coeff = np.polyfit(rows,col,deg=poly_order)
            p_col = np.polyval(p_coeff, rows)

            res = np.array(col-p_col)
            scatter = np.std(res)
            j += 1
            if scatter == 0:
                # Perfect fit: there are no outliers to look for.
                break
            dev = np.abs(res)/scatter
            max_dev_idx = np.argmax(dev)

            if (dev[max_dev_idx] > threshold and j < 20):
                if 0 < max_dev_idx < n_rows - 1:
                    col[max_dev_idx] = (col[max_dev_idx-1]+col[max_dev_idx+1])/2
                else:
                    # Edge pixels have only one neighbour. Index -1 would silently wrap around
                    # to the opposite end of the column, so both edges use the fit median.
                    col[max_dev_idx] = np.median(p_col)
                continue
            break
        P.append(p_col)
    P = np.array(P).T
    P[P < 0] = 0 # enforce positivity
    P /= np.sum(P,axis=0) # normalize on columns
    return P

In [None]:
def median_normalize(lc):
    '''
    Scale a light curve by its own median.
    
    :param lc: 1D array. The light curve to normalize.
    :return: lc divided by the median of lc.
    '''
    median_level = np.median(lc)
    normalized = lc/median_level
    return normalized

In [None]:
def fix_times(times, wlc=None, epoch=None):
    '''
    Fixes times in the times array so that the mid-transit time is 0.
    
    :param times: 1D array. Times in MJD, not corrected for mid-transit. Updated in place.
    :param wlc: 1D array or None. If not None and epoch is None, the epoch is defined
                as the time when wlc hits its minimum (first occurrence).
    :param epoch: float. If not None, the mid-transit time used to correct the times arrays.
                  If both epoch and wlc are None, the mean of times is used.
    :return: corrected times array.
    '''
    if epoch is None:
        if wlc is not None:
            # argmin works for lists and arrays alike and returns the first minimum.
            epoch = times[int(np.argmin(wlc))]
        else:
            epoch = np.mean(times)
        
    for i, stamp in enumerate(times):
        times[i] = stamp - epoch
        
    return times

In [None]:
def clip_curve(lc, b, clip_at):
    '''
    Sigma clip the given light curve.
    
    The curve is processed in consecutive windows of b points (the last window may be
    shorter). Within each window, points further than clip_at standard deviations from the
    window median are replaced by that median. The curve is updated in place.
    
    :param lc: 1D array. The light curve to clip.
    :param b: int. Number of points per window.
    :param clip_at: float. Sigma at which to clip.
    :return: the clipped light curve.
    '''
    clipcount = 0
    # Stop at the end of the curve: a window starting at or beyond len(lc) is empty and would
    # only produce NaN statistics.
    for start in range(0, len(lc), b):
        window = lc[start:start+b]
        smed = np.median(window)
        ssig = np.std(window)
        outliers = np.abs(window-smed) > clip_at*ssig
        clipcount += np.count_nonzero(outliers)
        lc[start:start+b] = np.where(outliers, smed, window)
    print("Clipped %.0f values from the given light curve." % clipcount)
    return lc

In [None]:
def plot_curve(t, lc, title, outfile, outdir):
    '''
    Scatter-plot a light curve against time and save the figure as a .pdf file.
    
    :param t: 1D array. Values for the time axis.
    :param lc: 1D array. Flux values of the light curve.
    :param title: str. Title to give the plot.
    :param outfile: str. Name of the output file, without extension.
    :param outdir: str. Directory the plot is saved into.
    '''
    destination = os.path.join(outdir, outfile + ".pdf")
    
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.scatter(t, lc, s=3)
    ax.set(xlabel="time since mid-transit [days]",
           ylabel="relative flux [no units]",
           title=title)
    plt.savefig(destination, dpi=300)
    plt.close()

In [None]:
def write_light_curve(t, lc, outfile, outdir):
    '''
    Save a light curve as a two-column text file at outdir/outfile.txt.
    
    The file starts with a header line, followed by one "time   flux" line per point.
    The output directory is created if it does not exist yet.
    
    :param t: 1D array. Timestamps of the light curve.
    :param lc: 1D array. Flux values of the light curve.
    :param outfile: str. Name of the output file, without extension.
    :param outdir: str. Directory the file is written into.
    '''
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    
    target = '{}/{}.txt'.format(outdir,outfile)
    with open(target, mode='w') as file:
        file.write("#time[MJD]     flux[normalized]\n")
        file.writelines('{:8}   {:8}\n'.format(ti, lci) for ti, lci in zip(t, lc))

In [None]:
def _stage5_lm_bounds(priors_dict, priors_type):
    # Turn the priors into hard [lower, upper] bounds for the LM fitter (mean +/- 3 sigma if gaussian).
    if priors_type == "gaussian":
        return {key: [priors_dict[key][0]-3*priors_dict[key][1],
                      priors_dict[key][0]+3*priors_dict[key][1]]
                for key in priors_dict.keys()}
    return priors_dict


def _stage5_parse_slc_name(slc_path):
    # Pull the save tag and wavelength bounds out of a file named slc_#.###mu_#.###mu.txt.
    slc_file = os.path.basename(slc_path)
    return slc_file[4:19], float(slc_file[4:9]), float(slc_file[12:17])


def _stage5_plot_residuals(times, residuals, savepath, show):
    # Scatter-plot fit residuals against time and save the figure to savepath.
    plt.figure(figsize=(20,5))
    plt.scatter(times, residuals)
    plt.xlabel('time since mid-transit [days]')
    plt.ylabel('normalized residuals')
    plt.savefig(savepath, dpi=300)
    if show:
        plt.show()
    plt.close()


def _stage5_plot_lm_fit(times, lc, model, residuals, uncertainty, fitsdir, fit_name, residuals_name, show):
    # Save the LM fit plot and its residuals plot into fitsdir.
    os.makedirs(fitsdir, exist_ok=True)
    plt.figure(figsize=(20,5))
    plt.errorbar(times, lc, yerr=np.full(shape=len(lc), fill_value=uncertainty), fmt='o', color='black', label='obs', alpha=0.1, zorder=1)
    plt.plot(times, model, 'r-', label="fitted_function", zorder=10)
    plt.xlabel('time since mid-transit [days]')
    plt.ylabel('relative flux [no units]')
    plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')
    plt.savefig(os.path.join(fitsdir, fit_name), dpi=300)
    if show:
        plt.show()
    plt.close()
    _stage5_plot_residuals(times, residuals, os.path.join(fitsdir, residuals_name), show)


def _stage5_plot_mcmc_fit(times, lc, model, residuals, uncertainty, title, fitsdir, fit_name, residuals_name, show):
    # Save the MCMC fit plot and its residuals plot into fitsdir.
    os.makedirs(fitsdir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.plot(times, lc, lw=3)
    ax.errorbar(times, lc, yerr=[uncertainty for i in lc], fmt="none", capsize=3)
    ax.plot(times, model, color="red")
    ax.set_xlabel("time since mid-transit [MJD]")
    ax.set_ylabel("normalized flux [DN/s]")
    ax.set_title(title)
    plt.savefig(os.path.join(fitsdir, fit_name), dpi=300)
    if show:
        plt.show()
    plt.close()
    _stage5_plot_residuals(times, residuals, os.path.join(fitsdir, residuals_name), show)


def _stage5_plot_spectrum(wavelengths, values, errors, spec_outdir, filename):
    # Best-effort plot of one transit spectrum; a failure is reported and then ignored.
    try:
        fig, ax = plot_transit_spectrum(wavelengths, values, errors,
                                        ymin=0.95*min(values), ymax=1.05*max(values))
        os.makedirs(spec_outdir, exist_ok=True)
        plt.savefig(os.path.join(spec_outdir, filename), dpi=300)
        plt.show()
        plt.close()
    except Exception as err:
        print("Error encountered in plotting depths, passing...")
        print("Reason: {!r}".format(err))


def _stage5_write_spectrum(wavelengths, halfwidths, values, errors, spec_outdir):
    # Best-effort write of one transit spectrum; a failure is reported and then ignored.
    try:
        write_transit_spectrum(wavelengths, halfwidths, values, errors, spec_outdir)
    except Exception as err:
        print("Error encountered in saving depths, passing...")
        print("Reason: {!r}".format(err))


def doStage5(curvesdir, outdir, exoplanet_params, systematics, spectral_range,
             do_fit={"WLC_LM":True,
                     "WLC_MCMC":True,
                     "spec_LM":True,
                     "spec_MCMC":True,},
             limb_darkening_model={"model_type":"quadratic",
                                   "stellar_params":None,
                                   "initial_guess":[0.1,0.1],},
             fixed_param={"LD_coeffs":False,
                          "t0":False,
                          "period":False,
                          "aoR":False,
                          "impact":False,#"inc":False,
                          "ecc":False,
                          "lop":False,
                          "offset":False},
             MCMC_depth_type="standard",
             priors_dict={"t0":[-0.1, 0.1],
                         "period":[0, 100],
                         "aoR":[0.00001, 10],
                         "impact":[0,100],#"inc":[80,90],
                         "ecc":[0, 1],
                         "lop":[0, 90],
                         "offset":[-0.5, 0.5]},
             priors_type="uniform",
             reject_threshold=3,
             raise_alarm=10,
             exoticLD={"available":False,
                       "ld_data_path":None,
                       "ld_grid":'kurucz',
                       "ld_interpolate_type":'trilinear'},
             save_plots={"WLC_LM":True,
                         "WLC_MCMC":True,
                         "spec_LM":True,
                         "spec_MCMC":True}):
    '''
    Performs Stage 5 LM and MCMC fitting of the light curves in the specified directory.

    The white light curve ("wlc.txt") is fitted first (LM, then MCMC), and the fitted values
    are written back into exoplanet_params. Each spectroscopic light curve ("slc*" files) is then
    fitted with every orbital parameter held fixed (LM, then MCMC). Plots, the transit spectra
    and a per-bin SDNR table are written under outdir.

    :param curvesdir: str. Where the .txt files of the curves you want to analyze are stored.
    :param outdir: str. Directory under which all plots and output tables are written.
    :param exoplanet_params: dict of float. Contains keywords "t0", "period", "rp", "aoR", "impact", "ecc", "lop".
                             The white light curve fits update this dict in place.
    :param systematics: tuple of float. Contains parameters for a linear-in-time fit a+b*(x-0.5).
    :param spectral_range: tuple of float. Spectral range being covered by the white light curve.
    :param do_fit: dict of bools. Keywords "WLC_LM", "WLC_MCMC", "spec_LM", "spec_MCMC" switch each fit on or off.
    :param limb_darkening_model: dict. Contains "model_type" str which defines model choice (e.g. quadratic, 4-param),
                                 "stellar_params" tuple of (M_H, Teff, logg) or None if not using, "initial_guess"
                                 keyword containing tuple of floats.
    :param fixed_param. dict of bools. Keywords are parameters that can be held fixed or opened for fitting.
                        If True, parameter will be held fixed. If False, parameter is allowed to be fitted.
                        Only "LD_coeffs" is honoured for the spectroscopic fits; everything else is fixed there.
    :param MCMC_depth_type: str. Options are "standard" or "ldcta".
    :param priors_dict: dict of 2-element lists. Read as [lower, upper] bounds, or as [mean, sigma]
                        when priors_type is "gaussian".
    :param priors_type: str. "gaussian" makes the LM fits use mean +/- 3 sigma as bounds; any other
                        value makes them use priors_dict directly as bounds. Passed on to MCMCfit as is.
    :param reject_threshold: float. Passed to reject_outliers as sigma; also used in the SDNR file name.
    :param raise_alarm: passed through to reject_outliers.
    :param exoticLD: dict. Contains "available" bool for whether EXoTiC-LD is on this system,
                     "ld_data_path" str of where the exotic_ld_data directory is located.
    :param save_plots: dict of bools. Same keywords as do_fit; whether to save plots for each fit.
    :return: None. Results are printed and written to disk.
    '''
    original_systematics = deepcopy(systematics)
    # Read out wlc first.
    wlc_path = os.path.join(curvesdir, "wlc.txt")
    wlc, times = read_light_curve(wlc_path)
    WLC_fitsdir = os.path.join(outdir, "WLC_fits")

    # Perform LM fit of wlc.
    if do_fit["WLC_LM"]:
        LM_priors = _stage5_lm_bounds(priors_dict, priors_type)
        fit_params, fit_model, residuals, uncertainty = LMfit(times, wlc, exoplanet_params, systematics, limb_darkening_model,
                                                              fixed_param, LM_priors, exoticLD, spectral_range)
        # Update exoplanet_params based on fit_params.
        systematics = (fit_params["a"], fit_params["b"])
        for key in fit_params.keys():
            if key not in ("a", "b"):
                exoplanet_params[key] = fit_params[key]

        if save_plots["WLC_LM"]:
            _stage5_plot_lm_fit(times, wlc, fit_model, residuals, uncertainty, WLC_fitsdir,
                                "wlc_fitLM.pdf", "wlc_fitLM_residuals.pdf", show=True)
    else:
        uncertainty = 0.01 # high uncertainty if not pre-fit

    # Perform MCMC fit of wlc with initial guesses updated from fit_params.
    if do_fit["WLC_MCMC"]:
        theta, depth, depth_err, model, residuals, res_err = MCMCfit(times, wlc, [uncertainty for i in wlc],
                                                                     exoplanet_params, systematics,
                                                                     limb_darkening_model,
                                                                     fixed_param, exoticLD, spectral_range,
                                                                     depth_type=MCMC_depth_type,
                                                                     priors_dict=priors_dict,
                                                                     priors_type=priors_type,
                                                                     N_walkers = 80,
                                                                     N_steps = 30000)
        # depth is either indexable (three depth flavours) or a plain scalar.
        try:
            print("Obtained grazing broadband depth of %.2f +/- %.2f." % (depth[2], depth_err[2]))
        except (TypeError, IndexError):
            print("Obtained broadband depth of %.2f +/- %.2f." % (depth, depth_err))

        # Update exoplanet_params based on theta.
        for key in theta.keys():
            exoplanet_params[key] = theta[key]

        if save_plots["WLC_MCMC"]:
            _stage5_plot_mcmc_fit(times, wlc, model, residuals, uncertainty,
                                  "White light curve with MCMC fit", WLC_fitsdir,
                                  "wlc_fitMCMC.pdf", "wlc_fitMCMC_residuals.pdf", show=True)

    # Reset systematics.
    systematics = original_systematics

    # Perform LM fits on spectroscopic light curves.
    spectro_fixed_param = {"LD_coeffs":fixed_param["LD_coeffs"],
                           "t0":True,
                           "period":True,
                           "aoR":True,
                           "impact":True,#"inc":True,
                           "ecc":True,
                           "lop":True,
                           "offset":True}
    slc_paths = sorted(glob.glob(os.path.join(curvesdir, "slc*")))
    spec_fitsdir = os.path.join(outdir, "spec_fits")
    spec_uncertainties = []
    spec_updated_guesses = []
    LM_residuals = []

    original_exoplanet_params = deepcopy(exoplanet_params)
    if do_fit["spec_LM"]:
        LM_priors = _stage5_lm_bounds(priors_dict, priors_type)
        for slc_path in slc_paths:
            # These files have names slc_#.###mu_#.###mu.txt, can get spectral range out of these.
            slc, times = read_light_curve(slc_path)
            savetag, min_wav, max_wav = _stage5_parse_slc_name(slc_path)
            slc_spectral_range = (min_wav, max_wav)

            exoplanet_params = deepcopy(original_exoplanet_params) # prevents original guess from being modified.
            print(exoplanet_params)
            fit_params, fit_model, residuals, uncertainty = LMfit(times, slc, exoplanet_params, systematics, limb_darkening_model,
                                                                  spectro_fixed_param, LM_priors, exoticLD, slc_spectral_range)
            spec_updated_guesses.append(deepcopy(fit_params))
            spec_uncertainties.append(uncertainty)
            LM_residuals.append(residuals)

            if save_plots["spec_LM"]:
                # Save plots of the LM fits to the spec curves.
                _stage5_plot_lm_fit(times, slc, fit_model, residuals, uncertainty, spec_fitsdir,
                                    "slc_fitLM_{}.pdf".format(savetag),
                                    "slc_fitLM_residuals_{}.pdf".format(savetag), show=False)
    else:
        spec_uncertainties = [0.01 for i in slc_paths]
        # One independent copy per curve so the per-bin guesses cannot alias each other.
        spec_updated_guesses = [deepcopy(original_exoplanet_params) for i in slc_paths]
        # No LM residuals exist; the placeholders keep the zip below from being empty,
        # which would otherwise skip every MCMC fit.
        LM_residuals = [None for i in slc_paths]

    # Now do MCMC fits.
    rp_vals = []
    rp_errs = []
    alt_depths = []
    alt_depth_errs = []
    depths = []
    depth_errs = []
    wavelengths = []
    halfwidths = []
    original_exoplanet_params = deepcopy(exoplanet_params)
    n_rejected = []
    SDNRs = []
    if do_fit["spec_MCMC"]:
        for slc_path, guess, uncertainty, residual in zip(slc_paths, spec_updated_guesses, spec_uncertainties, LM_residuals):
            # These files have names slc_#.###mu_#.###mu.txt, can get spectral range out of these.
            slc, times = read_light_curve(slc_path)
            savetag, min_wav, max_wav = _stage5_parse_slc_name(slc_path)
            slc_spectral_range = (min_wav, max_wav)
            wavelengths.append((min_wav+max_wav)/2)
            halfwidths.append((max_wav-min_wav)/2)

            # Use LM residuals to spot outliers from the fit and delete them.
            if residual is None:
                # No LM fit was run for this curve, so there is nothing to reject against.
                n_reject = 0
            else:
                slc, times, n_reject = reject_outliers(slc, times, residual, sigma=reject_threshold, raise_alarm=raise_alarm)

            # guess has to be used to update exoplanet_params.
            exoplanet_params = deepcopy(original_exoplanet_params) # prevents original guess from being modified.
            for key in exoplanet_params.keys():
                if key not in guess.keys():
                    guess[key] = exoplanet_params[key]

            theta, depth, depth_err, model, residuals, res_err = MCMCfit(times, slc, [uncertainty for i in slc],
                                                                         guess, systematics,
                                                                         limb_darkening_model,
                                                                         spectro_fixed_param, exoticLD, slc_spectral_range,
                                                                         depth_type=MCMC_depth_type,
                                                                         priors_dict=priors_dict,
                                                                         priors_type=priors_type)
            print("For wavelength range {}, {}:".format(min_wav, max_wav))
            SDNR = np.std(residuals)*10**6
            try:
                # Outputs as 100*rp/rs, standard, ldcta.
                # Index everything before appending so the lists cannot get out of step.
                rp_val, alt_depth, main_depth = depth[0], depth[1], depth[2]
                rp_err, alt_depth_err, main_depth_err = depth_err[0], depth_err[1], depth_err[2]
            except (TypeError, IndexError):
                # Scalar depth: only the main depth list is filled.
                print("Obtained depth %.2f +/- %.2f, at SDNR = %.2f ppm." % (depth, depth_err, SDNR))
                depths.append(depth)
                depth_errs.append(depth_err)
            else:
                rp_vals.append(rp_val)
                rp_errs.append(rp_err)

                depths.append(main_depth)
                depth_errs.append(main_depth_err)

                alt_depths.append(alt_depth)
                alt_depth_errs.append(alt_depth_err)

                print("Obtained overlap depth %.2f +/- %.2f, at SDNR = %.2f ppm." % (main_depth, main_depth_err, SDNR))

            n_rejected.append(n_reject)
            SDNRs.append(SDNR)
            if save_plots["spec_MCMC"]:
                # Save plots of fits.
                _stage5_plot_mcmc_fit(times, slc, model, residuals, uncertainty,
                                      "Spectroscopic light curve (%.3f micron) with MCMC fit" % wavelengths[-1],
                                      spec_fitsdir,
                                      "slc_fitMCMC_{}.pdf".format(savetag),
                                      "slc_fitMCMC_residuals_{}.pdf".format(savetag), show=False)

        if save_plots["spec_MCMC"]:
            # Save plots of spec curve MCMC results.
            spec_plotdir = os.path.join(outdir, "spectrum")
            _stage5_plot_spectrum(wavelengths, depths, depth_errs, spec_plotdir, "slc_transitspectrum.pdf")
            if rp_vals:
                _stage5_plot_spectrum(wavelengths, rp_vals, rp_errs, spec_plotdir, "slc_transitspectrum_rprs.pdf")
            if alt_depths:
                _stage5_plot_spectrum(wavelengths, alt_depths, alt_depth_errs, spec_plotdir, "slc_transitspectrum_standard.pdf")

        # Write out transit spectrum.
        _stage5_write_spectrum(wavelengths, halfwidths, depths, depth_errs,
                               os.path.join(outdir, "spectrum"))
        if rp_vals:
            _stage5_write_spectrum(wavelengths, halfwidths, rp_vals, rp_errs,
                                   os.path.join(outdir, "spectrum_rprs"))
        if alt_depths:
            _stage5_write_spectrum(wavelengths, halfwidths, alt_depths, alt_depth_errs,
                                   os.path.join(outdir, "spectrum_rprs2"))

        # Write out the number of rejected points and the SDNR of every bin.
        try:
            if outdir:
                os.makedirs(outdir, exist_ok=True)
            with open(os.path.join(outdir,"{}_sigma_SDNR.txt".format(reject_threshold)), mode="w") as f:
                f.write("wavelength [mu]    n_reject [na]    SDNR [ppm]:\n")
                for wav, n_reject, SDNR in zip(wavelengths, n_rejected, SDNRs):
                    f.write("{}    {}    {}\n".format(wav, n_reject, SDNR))
        except Exception as err:
            print("Error encountered in saving SDNRs, passing...")
            print("Reason: {!r}".format(err))

In [None]:
def read_light_curve(filepath):
    '''
    Reads out the light curve .txt located at filepath.

    The file is expected to hold whitespace-separated rows of "time flux". Blank lines and
    lines whose first token starts with '#' are skipped wherever they occur, so an empty or
    comment-only file yields two empty arrays instead of raising.

    :param filepath: str. Where the light curve .txt object is located.
    :return: tuple (lc, t) of numpy arrays. lc holds the second column (flux) and t the
             first column (time), in file order.
    '''
    lc = []
    t = []
    with open(filepath) as f:
        for line in f:
            fields = line.split()
            # Skip blank lines and comments/headers.
            if not fields or fields[0].startswith('#'):
                continue
            t.append(float(fields[0]))
            lc.append(float(fields[1]))
    return np.array(lc), np.array(t)

In [None]:
def LMfit(t, lc, exoplanet_params, systematics, limb_darkening_model, fixed_param, priors_dict, exoticLD, spectral_range):
    '''
    Performs Levenberg-Marquardt fit of transit model to provided light curve.

    Chooses the limb darkening coefficients (from the initial guess, or from EXoTiC-LD when
    exoticLD["available"] is set), then hands the fit over to jwst_ls.

    :param t: array of float. Time stamps of the light curve.
    :param lc: array of float. Flux values of the light curve.
    :param exoplanet_params: dict. Read for "t0", "period", "rp", "aoR", "impact", "ecc", "lop"
                             (and "offset" for a custom grid). Modified in place: "model_type"
                             and "LD_coeffs" are set here.
    :param systematics: tuple (a, b) of float. Passed on to jwst_ls.
    :param limb_darkening_model: dict with "model_type", "stellar_params" and "initial_guess".
    :param fixed_param: dict of bools. Passed on to jwst_ls.
    :param priors_dict: dict. Passed on to jwst_ls.
    :param exoticLD: dict. "available", "ld_grid", "ld_data_path", "ld_interpolate_type", and
                     "custom_model_path" when "ld_grid" is "custom".
    :param spectral_range: tuple of float. Both ends are multiplied by 10**4 before being given
                           to EXoTiC-LD (presumably micron -> Angstrom; confirm against caller).
    :return: fit_theta, fit_corr_bat_lc, residuals (all as returned by jwst_ls) and
             err = np.std(residuals).
    '''
    lin_a, lin_b = systematics

    # Read all three entries before touching exoplanet_params.
    ld_model_type = limb_darkening_model["model_type"]
    stellar_params = limb_darkening_model["stellar_params"]
    ld_guess = limb_darkening_model["initial_guess"]
    exoplanet_params["model_type"] = ld_model_type
    exoplanet_params["LD_coeffs"] = ld_guess

    if exoticLD["available"]:
        def scaled_range():
            # Wavelength window in the units handed to EXoTiC-LD.
            return [(10**4)*spectral_range[0], (10**4)*spectral_range[1]]

        if exoticLD["ld_grid"] == "custom":
            # Custom stellar model: the grid is read from a text file.
            model_file = exoticLD["custom_model_path"]

            s_wvs = (np.genfromtxt(model_file, skip_header = 2, usecols = [0]).T)*1e4
            s_mus = np.flip(np.genfromtxt(model_file, skip_header = 1, max_rows = 1))
            stellar_intensity = np.flip(np.genfromtxt(model_file, skip_header = 2)[:,1:],axis = 1)

            sld = SLD(ld_data_path=exoticLD["ld_data_path"], ld_model="custom",
                      custom_wavelengths=s_wvs, custom_mus=s_mus, custom_stellar_model=stellar_intensity)

            exoplanet_params["LD_coeffs"] = sld.compute_quadratic_ld_coeffs(wavelength_range=scaled_range(),
                                                                            mode="JWST_NIRSpec_prism")
            # The second coefficient is shifted by the "offset" parameter.
            computed = exoplanet_params["LD_coeffs"]
            exoplanet_params["LD_coeffs"] = [computed[0], computed[1]+exoplanet_params["offset"]]
        else:
            # Standard grids: the stellar parameters select the model.
            M_H, Teff, logg = stellar_params

            sld = SLD(M_H=M_H, Teff=Teff, logg=logg,
                      ld_model=exoticLD["ld_grid"], ld_data_path=exoticLD["ld_data_path"],
                      interpolate_type=exoticLD["ld_interpolate_type"], verbose=True)

            chosen = exoplanet_params["model_type"]
            if chosen in ("quadratic", "kipping2013"):
                exoplanet_params["LD_coeffs"] = sld.compute_quadratic_ld_coeffs(wavelength_range=scaled_range(),
                                                                                mode="JWST_NIRSpec_prism", mu_min=0.0)
            elif chosen == "square-root":
                exoplanet_params["LD_coeffs"] = sld.compute_squareroot_ld_coeffs(wavelength_range=scaled_range(),
                                                                                 mode="JWST_NIRSpec_prism", mu_min=0.0)
            elif chosen == "nonlinear":
                exoplanet_params["LD_coeffs"] = sld.compute_4_parameter_non_linear_ld_coeffs(wavelength_range=scaled_range(),
                                                                                             mode="JWST_NIRSpec_prism", mu_min=0.0)

    # Kipping2013: batman is run as "quadratic", with (q1, q2) turned into (u1, u2).
    # This overrides any coefficients obtained from EXoTiC-LD above.
    # NOTE(review): jwst_ls seeds the fit with these u values while fit_model/residuals_ treat
    # the fitted entries as (q1, q2) when using_kipping is set -- confirm this is intended.
    using_kipping = exoplanet_params["model_type"] == "kipping2013"
    if using_kipping:
        print("Kipping2013 formulation found, switching coefficients...")
        exoplanet_params["model_type"] = "quadratic"
        q1, q2 = limb_darkening_model["initial_guess"]
        root_q1 = np.sqrt(q1)
        exoplanet_params["LD_coeffs"] = [2*root_q1*q2, root_q1*(1-2*q2)]

    # Set up params for batman
    params = batman.TransitParams()
    params.t0 = exoplanet_params["t0"]                       #time of inferior conjunction in days
    params.per = exoplanet_params["period"]                  #orbital period in days
    params.rp = exoplanet_params["rp"]                       #planet radius (in units of stellar radii)
    params.a = exoplanet_params["aoR"]                       #semi-major axis (in units of stellar radii)
    params.inc = (180/np.pi)*np.arccos(exoplanet_params["impact"]/params.a)  #orbital inclination (in degrees), from impact parameter
    params.ecc = exoplanet_params["ecc"]                     #eccentricity
    params.w = exoplanet_params["lop"]                       #longitude of periastron (in degrees)
    params.u = exoplanet_params["LD_coeffs"]                 #limb darkening coefficients [u1, u2]
    params.limb_dark = exoplanet_params["model_type"]        #limb darkening model

    print("Using ld coeffs initial guess: ", params.u)
    residuals, theta_guess, fit_theta, fit_corr_bat_lc = jwst_ls(exoplanet_params, fixed_param, priors_dict,
                                                                 lin_a, lin_b, t, lc, using_kipping)

    # The scatter of the residuals is returned as the fit uncertainty.
    err = np.std(residuals)
    print("Standard deviation of the residuals: %.0f ppm" % (err*10**6))

    return fit_theta, fit_corr_bat_lc, residuals, err

In [None]:
def jwst_ls(exoplanet_params, fixed_param, priors_dict, a, b, t, lc, using_kipping):
    '''
    Builds the batman model and the initial guess, then runs the least squares fit.

    :param exoplanet_params: dict. Read for "t0", "period", "rp", "aoR", "impact", "ecc", "lop",
                             "LD_coeffs", "model_type", and for every parameter that is not fixed.
    :param fixed_param: dict of bools. Parameters whose value is False are added to the guess
                        and therefore fitted.
    :param priors_dict: dict. Passed on to fit_model.
    :param a: float. Initial guess for the "a" entry of the fit.
    :param b: float. Initial guess for the "b" entry of the fit.
    :param t: array of float. Time stamps of the light curve.
    :param lc: array of float. Flux values of the light curve.
    :param using_kipping: bool. Passed on to fit_model and residuals_.
    :return: residuals, theta_guess, fit_theta, fit_corr_bat_lc. Note that theta_guess is the
             dict handed to the fitter, which residuals_ updates in place during the fit.
    '''
    #---BATMAN MODEL-----
    params = batman.TransitParams()
    params.t0 = exoplanet_params["t0"]                       #time of inferior conjunction in days
    params.per = exoplanet_params["period"]                  #orbital period in days
    params.rp = exoplanet_params["rp"]                       #planet radius (in units of stellar radii)
    params.a = exoplanet_params["aoR"]                       #semi-major axis (in units of stellar radii)
    params.inc = (180/np.pi)*np.arccos(exoplanet_params["impact"]/params.a)  #orbital inclination (in degrees), from impact parameter
    params.ecc = exoplanet_params["ecc"]                     #eccentricity
    params.w = exoplanet_params["lop"]                       #longitude of periastron (in degrees)
    params.u = exoplanet_params["LD_coeffs"]                 #limb darkening coefficients [u1, u2]
    params.limb_dark = exoplanet_params["model_type"]        #limb darkening model

    # Only one model object is needed; it is passed on to fit_model and residuals_.
    init_bat_model = batman.TransitModel(params, t)          #initializes model

    #---DEFINE GUESSES-----
    theta_guess = {"a": a, "b": b, "rp": params.rp}
    for parameter, is_fixed in fixed_param.items():
        if not is_fixed:
            theta_guess[parameter] = exoplanet_params[parameter]

    #---LEAST SQUARES FITTING-----
    # fit_model accepts flux_err but does not currently use it.
    flux_err = np.full(shape=len(lc), fill_value=np.std(lc))

    (fit_theta, fit_theta_arr, num_of_ld_coeffs, modified_keys,
     fit_corr_bat_lc, _fit_bat_lc, _linearfit, original_LDs) = fit_model(theta_guess, init_bat_model,
                                                                         params, t, lc, flux_err, priors_dict, using_kipping)

    residuals = residuals_(fit_theta_arr, fit_theta, num_of_ld_coeffs,
                           modified_keys, init_bat_model, params, t, lc, original_LDs, using_kipping)

    return residuals, theta_guess, fit_theta, fit_corr_bat_lc

In [None]:
def fit_model(theta_guess, init_bat_model, params, t, flux, flux_err, priors_dict, using_kipping):
    '''
    Runs scipy's least_squares on residuals_ and reports the quality of the fit.

    :param theta_guess: dict. Initial guess. The "LD_coeffs" entry, if present, is a list that
                        is flattened into one fitted value per coefficient. This dict is updated
                        in place by residuals_ while the fit runs.
    :param init_bat_model: passed through to residuals_ and modify_model.
    :param params: batman parameter object. params.limb_dark and params.u are read here.
    :param t: array of float. Time stamps of the light curve.
    :param flux: array of float. Flux values of the light curve.
    :param flux_err: unused; kept for interface compatibility.
    :param priors_dict: dict of [lower, upper]. Bounds for every fitted key found in it;
                        other non-LD keys are left unbounded.
    :param using_kipping: bool. With a quadratic law, bounds the limb darkening values to
                          [0, 1] instead of [-1, 2].
    :return: fit_theta, fit_theta_arr, num_of_ld_coeffs, modified_keys, fit_corr_bat_lc,
             fit_bat_lc, linearfit, original_LDs.
    '''
    original_theta_guess = deepcopy(theta_guess)

    # Check limb dark type
    if (params.limb_dark == "quadratic" and using_kipping):
        print("Kipping2013 formulation found, switching priors appropriately...")
        ld_lower, ld_upper = 0, 1
    else:
        ld_lower, ld_upper = -1, 2

    # Need to unpack the limb darkening coefficients and build bounds object
    theta_arr = []
    lower_bounds = []
    upper_bounds = []
    modified_keys = []
    num_of_ld_coeffs = 0
    for key, value in theta_guess.items():
        if key == "LD_coeffs":
            for ld_coeff in value:
                num_of_ld_coeffs += 1
                theta_arr.append(ld_coeff)
                modified_keys.append("LD_coeff")
                lower_bounds.append(ld_lower)
                upper_bounds.append(ld_upper)
        elif key in priors_dict:
            theta_arr.append(value)
            modified_keys.append(key)
            lower_bounds.append(priors_dict[key][0])
            upper_bounds.append(priors_dict[key][1])
        else:
            theta_arr.append(value)
            modified_keys.append(key)
            lower_bounds.append(-np.inf)
            upper_bounds.append(np.inf)
    bounds = (lower_bounds, upper_bounds)
    original_LDs = deepcopy(params.u)
    opt_result = least_squares(residuals_,
                               np.array(theta_arr),
                               bounds=bounds,
                               args=(theta_guess, num_of_ld_coeffs, modified_keys, init_bat_model, params, t, flux, original_LDs, using_kipping))
    fit_theta_arr = opt_result.x
    print("Fitted: ", fit_theta_arr)
    print("Least squares finished with status:", opt_result.status)
    print("Output message: ", opt_result.message)
    print("Success status: ", opt_result.success)

    # Return fit_theta to dictionary format, and unpack fitted LD coeffs back into list.
    fit_theta = {}
    fitted_LD_coeffs = []
    for fitted_param, modified_key in zip(fit_theta_arr, modified_keys):
        if modified_key == "LD_coeff":
            fitted_LD_coeffs.append(fitted_param)
        else:
            fit_theta[modified_key] = fitted_param
    if fitted_LD_coeffs:
        fit_theta["LD_coeffs"] = fitted_LD_coeffs

    # This call must come before fit_theta is printed or handed to modify_model:
    # residuals_ rewrites fit_theta["LD_coeffs"] (converting them when using_kipping is set).
    res = residuals_(fit_theta_arr, fit_theta, num_of_ld_coeffs, modified_keys, init_bat_model, params, t, flux, original_LDs, using_kipping)

    # Every fitted scalar costs one degree of freedom, so count the flattened array:
    # the theta dict holds all limb darkening coefficients under a single key.
    dof = len(flux)-len(fit_theta_arr)
    chi2 = sum(res*res)
    rchi2 = (res**2).sum()/dof

    print('Guess', original_theta_guess)
    print('Fitted', fit_theta)

    fit_corr_bat_lc, fit_bat_lc, linearfit = modify_model(fit_theta, init_bat_model, params, t, original_LDs)

    print('Chi-square =', chi2)
    print('Deg of freedom =', dof)
    print('Reduced Chi-square =', rchi2)
    return fit_theta, fit_theta_arr, num_of_ld_coeffs, modified_keys, fit_corr_bat_lc, fit_bat_lc, linearfit, original_LDs

In [None]:
def residuals_(theta_arr, theta, num_of_ld_coeffs, modified_keys, init_bat_model, params, t, flux, original_LDs, using_kipping):
    '''
    Residuals between the observed flux and the systematics-scaled transit model.
    
    The flat vector theta_arr is first written back into the dictionary theta (theta is
    modified in place), then the model is evaluated through modify_model.
    
    :param theta_arr: 1D array. Current parameter values, ordered like modified_keys.
    :param theta: dict. Parameter dictionary that receives the values held in theta_arr.
    :param num_of_ld_coeffs: int. Not read by this function.
    :param modified_keys: list of str. Name of each theta_arr entry; every limb darkening
                          entry carries the label "LD_coeff".
    :param init_bat_model: object exposing light_curve(params); forwarded to modify_model.
    :param params: transit parameter object; forwarded to modify_model.
    :param t: 1D array. Timestamps; forwarded to modify_model.
    :param flux: 1D array. Observed flux that the model is subtracted from.
    :param original_LDs: list of float. Reference limb darkening coefficients; forwarded to modify_model.
    :param using_kipping: bool. If True, the "LD_coeff" entries are read as (q1, q2) and
                          turned into quadratic (u1, u2) before being stored in theta.
    :return: flux minus the full model.
    '''
    fitting_ld = "LD_coeff" in modified_keys
    if fitting_ld:
        ld_values = []
    first_ld_done = False
    for position, (value, key) in enumerate(zip(theta_arr, modified_keys)):
        if "LD_coeff" not in key:
            # Ordinary scalar parameter: copy it straight into the dictionary.
            theta[key] = value
            continue
        if not using_kipping:
            ld_values.append(value)
        elif not first_ld_done:
            # First LD entry is q1; q2 is taken from the next slot of theta_arr.
            ld_values.append(2*np.sqrt(theta_arr[position])*theta_arr[position+1])
            first_ld_done = True
        else:
            # Later LD entry is q2; q1 is taken from the previous slot of theta_arr.
            ld_values.append(np.sqrt(theta_arr[position-1])*(1-2*theta_arr[position]))
    if fitting_ld:
        theta["LD_coeffs"] = ld_values
    
    full_bat_model = modify_model(theta, init_bat_model, params, t, original_LDs)[0]
    
    return flux-full_bat_model

In [None]:
def modify_model(theta, init_bat_model, params, t, original_LDs):
    '''
    Pushes the values in theta onto params and evaluates the transit model
    multiplied by a linear-in-time baseline.
    
    :param theta: dict. Must contain "a", "b" and "rp"; may contain "LD_coeffs", "offset",
                  "t0", "period", "aoR", "impact", "ecc" and "lop".
    :param init_bat_model: object exposing light_curve(params).
    :param params: transit parameter object. Its attributes are overwritten in place.
    :param t: 1D array. Timestamps used for the baseline a + b*t.
    :param original_LDs: list of float. Reference limb darkening coefficients; "offset"
                         is added to the second one.
    :return: full model (light curve times baseline), bare light curve, and baseline.
    '''
    intercept = theta["a"]
    slope = theta["b"]
    params.rp = theta["rp"]
    
    # Limb darkening: an "offset" entry takes precedence over "LD_coeffs" when both are present.
    if "LD_coeffs" in theta:
        params.u = theta["LD_coeffs"]
    if "offset" in theta:
        params.u = [original_LDs[0], original_LDs[1] + theta["offset"]]
    
    # "aoR" is applied before "impact" because the inclination below reads params.a.
    for theta_key, attribute in (("t0", "t0"), ("period", "per"), ("aoR", "a")):
        if theta_key in theta:
            setattr(params, attribute, theta[theta_key])
    if "impact" in theta:
        # Impact parameter -> inclination in degrees.
        params.inc = (180/np.pi)*np.arccos(theta["impact"]/params.a)
    for theta_key, attribute in (("ecc", "ecc"), ("lop", "w")):
        if theta_key in theta:
            setattr(params, attribute, theta[theta_key])
    
    model_lc = init_bat_model.light_curve(params)
    polyfit = intercept + slope*t
    
    return model_lc*polyfit, model_lc, polyfit

In [None]:
def MCMCfit(t, lc, err, exoplanet_params, systematics,
            limb_darkening_model, fixed_param, exoticLD, spectral_range, depth_type="standard",
            priors_dict={"t0":[-0.1, 0.1],
                         "period":[0, 100],
                         "aoR":[0.00001, 10],
                         "impact":[0, 100],#"inc":[80,90],
                         "ecc":[0, 1],
                         "lop":[0, 90],
                         "offset":[-0.5, 0.5]},
            priors_type="uniform",
            N_walkers = 32,
            N_steps = 5000):
    '''
    Fits a transit model to the given transit curve using MCMC.
    
    :param t: 1D array. Timestamps of each point in the transit curve, in days.
    :param lc: 1D array. The normalized flux at each point in the transit curve.
    :param err: 1D array. Errors on the transit curve.
    :param exoplanet_params: dict. Contains keywords "t0", "period", "rp", "aoR", "impact", "ecc", "lop",
                             plus "offset" when the offset is fitted or a custom LD grid is used.
                             The keys "model_type" and "LD_coeffs" are written into this dict.
    :param systematics: tuple of float. Contains parameters (a, b) for a linear-in-time baseline a+b*t.
    :param limb_darkening_model: dict. Contains "model_type" str which defines model choice (e.g. quadratic, 4-param),
                                 "stellar_params" tuple of (M_H, Teff, logg) or None if not using, "initial_guess"
                                 keyword containing tuple of floats. The permissible model types are "quadratic",
                                 "kipping2013", "square-root", "nonlinear". For "kipping2013" the initial guess
                                 is (q1, q2) and always overrides any EXoTiC-LD coefficients.
    :param fixed_param: dict of bools. Keywords are parameters that can be held fixed or opened for fitting.
                        If True, parameter will be held fixed. If False, parameter is allowed to be fitted.
    :param exoticLD: dict. Contains "available" bool for whether EXoTiC-LD is on this system, "ld_grid" str
                     (grid name, or "custom"), "ld_data_path" str of where the exotic_ld_data directory is
                     located, "ld_interpolate_type" str (standard grids) and "custom_model_path" str
                     (custom grid only).
    :param spectral_range: tuple of float. Spectral range being covered.
    :param depth_type: str. Choices are "standard", "ldcta" "rp/rs", or "all" for (rp/rs)^2,
                       limb darkening-corrected time averaged, rp/rs, or to output all types.
                       "ldcta" is for grazing transit geometries.
    :param priors_dict: dict. Priors for each parameter. If type "uniform", each
                        entry is a list of the endpoints. If type "gaussian", each 
                        entry is mu and sigma. Only "ecc", "period", "impact", "lop", "aoR" and
                        "offset" are read. The baseline (a, b), rp and the limb darkening
                        coefficients always use hard-coded uniform bounds.
    :param priors_type: str. Choices are "uniform" or "gaussian". Type of priors to use.
    :param N_walkers: int. Number of emcee walkers.
    :param N_steps: int. Number of steps per walker; the first 20% are discarded as burn-in.
    :return: theta (dict of posterior medians; "LD_coeffs" stay in (q1, q2) form when
             "kipping2013" is used), depth, depth_err (both in percent; lists ordered
             ["rp/rs", "standard", "ldcta"] when depth_type is "all"), model, residuals,
             and the standard deviation of the residuals. Six Nones if a ValueError is raised.
    '''
    def _kipping_to_quadratic(q1, q2):
        # Kipping (2013) sampling parameters (q1, q2) -> quadratic coefficients [u1, u2].
        return [2*np.sqrt(q1)*q2, np.sqrt(q1)*(1-2*q2)]
    
    def _unpack_theta(theta_values, modified_keys):
        # Rebuild the parameter dict from a flat vector; all "LD_coeff" entries are gathered under "LD_coeffs".
        theta_dict = {}
        ld_values = []
        for value, key in zip(theta_values, modified_keys):
            if "LD_coeff" in key:
                ld_values.append(value)
            else:
                theta_dict[key] = value
        if "LD_coeff" in modified_keys:
            theta_dict["LD_coeffs"] = ld_values
        return theta_dict
    
    def _apply_theta(target, theta_dict, original_LDs, using_kipping):
        # Overwrite the attributes of a batman parameter object with whichever parameters were fitted.
        target.rp = theta_dict["rp"]
        if "LD_coeffs" in theta_dict:
            target.u = theta_dict["LD_coeffs"]
        if "offset" in theta_dict:
            target.u = [original_LDs[0], original_LDs[1] + theta_dict["offset"]]
        if "t0" in theta_dict:
            target.t0 = theta_dict["t0"]
        if "period" in theta_dict:
            target.per = theta_dict["period"]
        if "aoR" in theta_dict:
            target.a = theta_dict["aoR"]
        if "impact" in theta_dict:
            # Impact parameter -> inclination in degrees, using the (possibly just updated) a/R*.
            target.inc = (180/np.pi)*np.arccos(theta_dict["impact"]/target.a)
        if "ecc" in theta_dict:
            target.ecc = theta_dict["ecc"]
        if "lop" in theta_dict:
            target.w = theta_dict["lop"]
        if using_kipping and "LD_coeffs" in theta_dict:
            # Fitted LD values are (q1, q2); batman needs (u1, u2). Applied last so it wins,
            # and skipped when the coefficients are held fixed (target.u is then already in u-space).
            target.u = _kipping_to_quadratic(*theta_dict["LD_coeffs"])
    
    try:
        # Begin unpacking systematics and stellar_params.
        a, b = systematics
        exoplanet_params["model_type"], stellar_params, exoplanet_params["LD_coeffs"] = (limb_darkening_model["model_type"],
                                                                                         limb_darkening_model["stellar_params"],
                                                                                         limb_darkening_model["initial_guess"])
        
        if (exoticLD["available"]):
            # NOTE(review): the 10**4 factor presumably converts micron to Angstrom for EXoTiC-LD - confirm.
            wavelength_range = [(10**4)*spectral_range[0], (10**4)*spectral_range[1]]
            # Check for custom model.
            if exoticLD["ld_grid"] == "custom":
                file_path = exoticLD["custom_model_path"]
                
                s_wvs = (np.genfromtxt(file_path, skip_header = 2, usecols = [0]).T)*1e4
                s_mus = np.flip(np.genfromtxt(file_path, skip_header = 1, max_rows = 1))
                stellar_intensity = np.flip(np.genfromtxt(file_path, skip_header = 2)[:,1:],axis = 1)
                
                sld = SLD(ld_data_path=exoticLD["ld_data_path"], ld_model="custom",
                          custom_wavelengths=s_wvs, custom_mus=s_mus, custom_stellar_model=stellar_intensity)
                
                exoplanet_params["LD_coeffs"] = sld.compute_quadratic_ld_coeffs(wavelength_range=wavelength_range,
                                                                                mode="JWST_NIRSpec_prism")
                exoplanet_params["LD_coeffs"] = [exoplanet_params["LD_coeffs"][0], exoplanet_params["LD_coeffs"][1]+exoplanet_params["offset"]]

            else:
                # Get LD coefficients from EXoTiC-LD using standard grids.
                M_H, Teff, logg = stellar_params

                sld = SLD(M_H=M_H, Teff=Teff, logg=logg,
                          ld_model=exoticLD["ld_grid"], ld_data_path=exoticLD["ld_data_path"],
                          interpolate_type=exoticLD["ld_interpolate_type"], verbose=True)

                if exoplanet_params["model_type"] in ("quadratic", "kipping2013"):
                    exoplanet_params["LD_coeffs"] = sld.compute_quadratic_ld_coeffs(wavelength_range=wavelength_range,
                                                                                    mode="JWST_NIRSpec_prism", mu_min=0.0)
                if exoplanet_params["model_type"] == "square-root":
                    exoplanet_params["LD_coeffs"] = sld.compute_squareroot_ld_coeffs(wavelength_range=wavelength_range,
                                                                                     mode="JWST_NIRSpec_prism", mu_min=0.0)
                if exoplanet_params["model_type"] == "nonlinear":
                    exoplanet_params["LD_coeffs"] = sld.compute_4_parameter_non_linear_ld_coeffs(wavelength_range=wavelength_range,
                                                                                                 mode="JWST_NIRSpec_prism", mu_min=0.0)
        
        # Special condition for those using Kipping2013: batman is run with the quadratic law,
        # while the sampler works on (q1, q2).
        using_kipping = exoplanet_params["model_type"] == "kipping2013"
        if using_kipping:
            print("Kipping2013 formulation found, switching coefficients...")
            exoplanet_params["model_type"] = "quadratic"
            q1_guess, q2_guess = limb_darkening_model["initial_guess"]
            exoplanet_params["LD_coeffs"] = _kipping_to_quadratic(q1_guess, q2_guess)
        
        # Initialize the batman transit model.
        
        params = batman.TransitParams()
        params.t0 = exoplanet_params["t0"]                       #time of inferior conjunction in days
        params.per = exoplanet_params["period"]                  #orbital period in days
        params.rp = exoplanet_params["rp"]                       #planet radius (in units of stellar radii)
        params.a = exoplanet_params["aoR"]                       #semi-major axis (in units of stellar radii)
        params.inc = (180/np.pi)*np.arccos(exoplanet_params["impact"]/params.a) #orbital inclination (in degrees)
        params.ecc = exoplanet_params["ecc"]                     #eccentricity
        params.w = exoplanet_params["lop"]                       #longitude of periastron (in degrees)
        params.u = exoplanet_params["LD_coeffs"]                 #limb darkening coefficients [u1, u2]
        params.limb_dark = exoplanet_params["model_type"]        #limb darkening model
        
        print("Using ld coeffs initial guess: ", params.u)
        
        m = batman.TransitModel(params, t)                       #initializes model
        
        # Initialize the guess. The baseline and rp are always fitted and come first,
        # so rp is always the third sampled dimension.
        theta_guess = {}
    
        theta_guess["a"] = a
        theta_guess["b"] = b
        theta_guess["rp"] = params.rp
    
        for parameter in fixed_param.keys():
            if not fixed_param[parameter]:
                theta_guess[parameter] = exoplanet_params[parameter]
        if using_kipping and "LD_coeffs" in theta_guess:
            # The sampler treats the LD dimensions as (q1, q2), so the walkers must start
            # from the (q1, q2) guess rather than the converted (u1, u2).
            theta_guess["LD_coeffs"] = [q1_guess, q2_guess]
        
        # Convert the guess into an array.
        theta_arr = []
        num_of_ld_coeffs = 0
        modified_keys = []
        for key in theta_guess.keys():
            if key=="LD_coeffs":
                for ld_coeff in theta_guess[key]:
                    num_of_ld_coeffs += 1
                    theta_arr.append(ld_coeff)
                    modified_keys.append("LD_coeff")
            else:
                theta_arr.append(theta_guess[key])
                modified_keys.append(key)
        theta_arr = np.array(theta_arr)
        
        # We redefine these functions in terms of the batman model.
        # Defines the log-likelihood.
        def log_likelihood(theta, modified_keys, num_of_ld_coeffs, x, y, yerr, original_LDs):
            # theta is an array. Need to turn it back into a dictionary.
            theta_dict = _unpack_theta(theta, modified_keys)
            _apply_theta(params, theta_dict, original_LDs, using_kipping)
            
            polyfit = theta_dict["a"] + theta_dict["b"]*x
            
            model = m.light_curve(params)*polyfit
            sigma2 = yerr**2
            return -0.5 * np.sum((y - model) ** 2 / sigma2 + np.log(sigma2))

        # Defines the prior.
        def log_prior(theta, modified_keys, num_of_ld_coeffs, original_LDs):
            # theta is an array. Need to turn it back into a dictionary.
            theta_dict = _unpack_theta(theta, modified_keys)
            
            # Hard-coded uniform bounds on the baseline and the planet radius.
            if not -5 < theta_dict["a"] < 5:
                return -np.inf
            if not -5 < theta_dict["b"] < 5:
                return -np.inf
            if not 0.01 < theta_dict["rp"] < 100:
                return -np.inf
            
            # Hard-coded uniform bounds on the limb darkening coefficients.
            if using_kipping:
                umin, umax = (0, 1)
            elif exoplanet_params["model_type"] == "quadratic":
                umin, umax = (-1, 2)
            else:
                umin, umax = (-5, 5)
            if "LD_coeffs" in theta_dict:
                for u in theta_dict["LD_coeffs"]:
                    if not umin <= u <= umax:
                        return -np.inf
            
            # User-supplied priors.
            # NOTE(review): "t0" has an entry in priors_dict but is not in this tuple, so a
            # fitted t0 is unconstrained - confirm whether that is intentional.
            prior_prob = 0 # stays 0 for uniform priors, accumulates log-density for gaussian priors
            for key in ("ecc","period","impact","lop","aoR","offset"):
                if key in modified_keys:
                    if priors_type == "uniform":
                        if not priors_dict[key][0] <= theta_dict[key] <= priors_dict[key][1]:
                            return -np.inf
                    if priors_type == "gaussian":
                        gauss_mu = priors_dict[key][0]
                        gauss_sig = priors_dict[key][1]
                        prior_prob += np.log(1.0/(np.sqrt(2*np.pi)*gauss_sig))-0.5*(theta_dict[key]-gauss_mu)**2/gauss_sig**2
            return prior_prob

        # Defines the probability.
        def log_probability(theta, modified_keys, num_of_ld_coeffs, x, y, yerr, original_LDs):
            lp = log_prior(theta, modified_keys, num_of_ld_coeffs, original_LDs)
            if not np.isfinite(lp):
                return -np.inf
            return lp + log_likelihood(theta, modified_keys, num_of_ld_coeffs, x, y, yerr, original_LDs)

        # Start the walkers in a small Gaussian ball around the guess.
        pos = theta_arr + 1e-3 * np.random.randn(N_walkers, theta_arr.shape[0])
        nwalkers, ndim = pos.shape
        print("Fitting %.0f parameters to data..." % ndim)
        original_LDs = deepcopy(params.u)
        sampler = emcee.EnsembleSampler(
            nwalkers, ndim, log_probability, args=(modified_keys, num_of_ld_coeffs, t, lc, np.array(err), original_LDs,),)
        sampler.run_mcmc(pos, N_steps, progress=True)
        
        # Trace plot of every walker for each parameter.
        samples = sampler.get_chain()
        labels = list(modified_keys)
        fig, axes = plt.subplots(ndim, figsize=(10, 7), sharex=True)
        for i in range(ndim):
            ax = axes[i]
            ax.plot(samples[:, :, i], "k", alpha=0.3)
            ax.set_xlim(0, len(samples))
            ax.set_ylabel(labels[i])
            ax.yaxis.set_label_coords(-0.1, 0.5)

        axes[-1].set_xlabel("step number")

        # Histogram of the full (burn-in included) chain for each parameter.
        n = np.shape(samples[:, :, 1])[0]*np.shape(samples[:, :, 1])[1]
        fig, axes = plt.subplots(ndim, figsize=(10, 7), sharex=False)
        for i in range(ndim):
            ax = axes[i]
            post = np.reshape(samples[:, :, i], (n))
            ax.hist(post, 100, alpha=0.3)
            ax.set_xlim(min(post), max(post))
            ax.set_ylabel(labels[i])

        # Posterior medians after discarding the first 20% of steps.
        flat_samples = sampler.get_chain(discard=int(0.2*N_steps), flat=True)
        medians = []
        posteriors = []
        for i in range(ndim):
            medians.append(np.percentile(flat_samples[:, i], 50))
            posteriors.append(flat_samples[:, i])
        
        # Need to turn theta back into a dict.
        theta = _unpack_theta(medians, modified_keys)
        
        fig = corner.corner(
            flat_samples, labels=labels
        )
        plt.show()
        plt.close(fig)
        plt.close()
        
        # Reinitialize model.
        best_params = batman.TransitParams()
        best_params.t0 = exoplanet_params["t0"]                       #time of inferior conjunction in days
        best_params.per = exoplanet_params["period"]                  #orbital period in days
        best_params.rp = exoplanet_params["rp"]                       #planet radius (in units of stellar radii)
        best_params.a = exoplanet_params["aoR"]                       #semi-major axis (in units of stellar radii)
        best_params.inc = (180/np.pi)*np.arccos(exoplanet_params["impact"]/best_params.a) #orbital inclination (in degrees)
        best_params.ecc = exoplanet_params["ecc"]                     #eccentricity
        best_params.w = exoplanet_params["lop"]                       #longitude of periastron (in degrees)
        best_params.u = exoplanet_params["LD_coeffs"]                 #limb darkening coefficients [u1, u2]
        best_params.limb_dark = exoplanet_params["model_type"]        #limb darkening model
        
        # Replace default params with fitted params as applicable. This goes through the same
        # routine as the likelihood so that Kipping (q1, q2) medians are converted to (u1, u2).
        _apply_theta(best_params, theta, original_LDs, using_kipping)
        
        best_model = batman.TransitModel(best_params, t)    #initializes model
        polyfit = theta["a"] + theta["b"]*t
        model = best_model.light_curve(best_params)*polyfit

        residuals = lc - model
        sdnr = np.std(residuals)
        print("Standard deviation of the residuals: %.0f ppm" % (sdnr*10**6))
        
        if depth_type == "all":
            depth_types = ["rp/rs", "standard", "ldcta"]
        else:
            depth_types = [depth_type]
        # Spread of the rp posterior, propagated into every depth uncertainty below.
        rp_std = np.std(posteriors[modified_keys.index("rp")])
        depths = []
        depth_errs = []
        for current_type in depth_types:
            if current_type == "rp/rs":
                depth = 100*theta["rp"]
                depth_err = 100*rp_std
            if current_type == "standard":
                depth = 100*theta["rp"]**2
                depth_err = 100*2*theta["rp"]*rp_std
            if current_type == "ldcta":
                # Code from Ryan MacDonald.
                # Compute angles from star-planet line to R_p = R_s intersection
                if "aoR" in theta.keys():
                    aoR = theta["aoR"]
                else:
                    aoR = exoplanet_params["aoR"]
                if "impact" in theta.keys():
                    inc = (180/np.pi)*np.arccos(theta["impact"]/aoR)
                else:
                    inc = (180/np.pi)*np.arccos(exoplanet_params["impact"]/aoR)
                bo  = aoR*np.cos(inc*np.pi/180)
                bo_sq = bo**2
                rs = 1
                rs_sq = rs**2
                rp = theta["rp"]
                rp_sq = rp**2

                arg_phi_1 = ((bo_sq * rs_sq) + rp_sq - rs_sq)/(2 * (bo * rs) * rp)
                arg_phi_2 = ((bo_sq * rs_sq) + rs_sq - rp_sq)/(2 * (bo * rs) * rs)

                phi_1 = np.arccos(arg_phi_1)  # Angle at planet centre
                phi_2 = np.arccos(arg_phi_2)  # Angle at star centre

                # Evaluate the overlapping area analytically
                A_overlap = (rp_sq * (phi_1 - 0.5 * np.sin(2.0 * phi_1)) +
                             rs_sq * (phi_2 - 0.5 * np.sin(2.0 * phi_2)))
                A_s = np.pi*rs_sq
                depth = 100*A_overlap/A_s

                # The tedious process of computing the depth error
                dphi_1 = (rs_sq*(bo_sq-1)-rp_sq)/(rp*np.sqrt(rs_sq*rs_sq*(-(bo_sq-1)**2)+2*rs_sq*(bo_sq+1)*rp_sq-rp_sq*rp_sq))
                dphi_2 = 2*rp/(np.sqrt(rs_sq*rs_sq*(-(bo_sq-1)**2)+2*rs_sq*(bo_sq+1)*rp_sq-rp_sq*rp_sq))
                depth_err = 100*(2*rp*(phi_1-0.5*np.sin(2 * phi_1))
                                 + rp_sq*dphi_1*(1-np.cos(2 * phi_1))
                                 + rs_sq*dphi_2*(1-np.cos(2 * phi_2)))*rp_std
            depths.append(depth)
            depth_errs.append(depth_err)
        if len(depths) == 1:
            return theta, depths[0], depth_errs[0], model, residuals, sdnr
        else:
            return theta, depths, depth_errs, model, residuals, sdnr
    except ValueError:
        print("Encountered value error, returning Nones...")
        return None, None, None, None, None, None

In [None]:
def reject_outliers(curve, timestamps, residuals, sigma, raise_alarm):
    '''
    Removes the points whose residual lies more than sigma standard deviations
    from the mean residual.
    
    :param curve: 1D array. Light curve to clean.
    :param timestamps: 1D array. Timestamps matching curve.
    :param residuals: 1D array. Residuals used to decide which points are outliers.
    :param sigma: float. Rejection threshold, in standard deviations of the residuals.
    :param raise_alarm: int. Prints "alarm!" when at least this many points are rejected.
    :return: curve and timestamps with the outlying points deleted.
    '''
    deviation = np.abs(residuals-np.mean(residuals))
    threshold = sigma*np.std(residuals)
    outliers = np.nonzero(deviation > threshold)[0]
    n_rejected = len(outliers)
    print("{} outliers were deleted at these times: ".format(n_rejected), timestamps[outliers])
    if n_rejected >= raise_alarm:
        print("alarm!")
    return np.delete(curve, outliers), np.delete(timestamps, outliers)

In [None]:
def plot_transit_spectrum(wavelengths, depths, depth_errs, ymin=0, ymax=4):
    '''
    Plots the transit spectrum against bin index, labelling every tenth bin
    with its wavelength. The figure is returned, not saved.
    
    :param wavelengths: 1D array-like. Wavelength of each bin; used only for tick labels.
    :param depths: 1D array-like. Transit depth of each bin.
    :param depth_errs: 1D array-like. Uncertainty on each depth.
    :param ymin: float. Lower limit of the y axis.
    :param ymax: float. Upper limit of the y axis.
    :return: fig, ax of the plot.
    '''
    # Points are spaced evenly by bin index rather than by wavelength.
    bin_index = list(range(len(depths)))
    
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.scatter(bin_index, depths, s=5, color="k", marker="s")
    ax.errorbar(bin_index, depths, yerr=depth_errs, fmt="none", capsize=0, ecolor="k", elinewidth=3)
    
    ax.set_xlabel("wavelength [micron]", fontsize=14)
    ax.set_ylabel("transit depth [percent]", fontsize=14)
    ax.set_ylim(ymin, ymax)
    ax.xaxis.set_major_formatter(fsf("%.1f"))
    ax.yaxis.set_major_formatter(fsf("%.2f"))
    ax.xaxis.set_minor_formatter(nulf())
    ax.xaxis.set_minor_locator(aml(4))
    
    # Label every tenth bin with its wavelength rounded to 3 decimals.
    labelled_bins = bin_index[::10]
    rounded_wavelengths = np.round(np.array(wavelengths),decimals=3)
    ax.set_xticks(labelled_bins)
    ax.set_xticklabels([str(rounded_wavelengths[tick]) for tick in labelled_bins])
    
    ax.tick_params(which="both", axis="both", direction="in", pad=5, labelsize=10)
    return fig, ax

In [None]:
def write_transit_spectrum(wavelengths, halfwidths, depths, depth_errs, outdir):
    '''
    Writes the transit spectrum to <outdir>/transit_spectrum.txt as a
    POSEIDON-compatible text file, creating outdir if it does not exist.
    
    :param wavelengths: iterable of float. Central wavelength of each bin.
    :param halfwidths: iterable of float. Halfwidth of each bin.
    :param depths: iterable of float. Transit depths; divided by 100 on output.
    :param depth_errs: iterable of float. Depth uncertainties; divided by 100 on output.
    :param outdir: str. Directory the file is written to.
    '''
    header = "#wavelength[micron]     halfwidth[micron]     depth[na]     uncertainty[na]\n"
    row_format = '{:5.4}   {:5.4}   {:8.6}   {:8.6}\n'
    
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    
    with open('{}/transit_spectrum.txt'.format(outdir), mode='w') as file:
        file.write(header)
        # Rows are formatted lazily, one bin at a time, as they are written.
        file.writelines(row_format.format(wav, hw, depth/100, depth_err/100)
                        for wav, hw, depth, depth_err in zip(wavelengths, halfwidths, depths, depth_errs))

Everything below this point is for you to edit! These cells run the functions defined above on a dataset of your choice.

In [None]:
# Perform Stage 1 calibrations on the file located at filepath. Output file outfile will be sent to outdir.
# filepath, outfile and outdir are notebook-level variables; later cells reassign some of them.
filepath = "./JWST_WD1856-selected/jw02358030001_04101_00001-seg001_nrs1_uncal.fits"
outfile = "WD1856" # will have suffix _rateints.fits attached to it when saved
outdir = "./rateints"

# Each keyword is an options dict for one Stage 1 step. Here every step has "skip":False
# except jump, which has "skip":True.
# NOTE(review): "bckg_rows" and "show" are options of the one_over_f step; their exact meaning
# is defined by doStage1 - presumably the rows used as background and a plotting toggle.
doStage1(filepath, outfile, outdir,
         group_scale={"skip":False},
         dq_init={"skip":False},
         saturation={"skip":False},
         superbias={"skip":False},
         refpix={"skip":False},
         linearity={"skip":False},
         dark_current={"skip":False},
         jump={"skip":True},
         ramp_fit={"skip":False},
         gain_scale={"skip":False},
         one_over_f={"skip":False, "bckg_rows":[1,2,3,4,5,6,-1,-2,-3,-4,-5,-6], "show":False}
         )

In [None]:
# Perform Stage 2 calibrations on every *_rateints.fits file found in datadir, one file at a time.
# Output file outfile will be sent to outdir.
datadir = "./rateints"
filepaths = sorted(glob.glob(os.path.join(datadir, "*_rateints.fits")))
outdir = "./calints"

for filepath in filepaths:
    # NOTE(review): str.replace substitutes every occurrence of "rateints" in the full path,
    # including the directory name, so outfile looks like "./calints/<name>_calints.fits".
    # Confirm that doStage2 expects a full path here rather than a bare file stem.
    outfile = str.replace(filepath, "rateints", "calints")
    # Each keyword is an options dict for one Stage 2 step. The first four steps are run;
    # flat_field, pathloss, photom, resample_spec and extract_1d are skipped.
    doStage2(filepath, outfile, outdir,
             assign_wcs={"skip":False},
             extract_2d={"skip":False},
             srctype={"skip":False},
             wavecorr={"skip":False},
             flat_field={"skip":True},
             pathloss={"skip":True},
             photom={"skip":True},
             resample_spec={"skip":True},
             extract_1d={"skip":True}
             )

In [None]:
# Perform Stage 3 calibrations on all *_calints.fits files located in filesdir. Output files will be sent to outdir.
filesdir = "./calints"
outdir = "./postprocessed"

# Aperture cut values. These two variables are reused by the Stage 4 cell, so both stages
# are given the same "hcut1"/"hcut2".
hcut1 = 8
hcut2 = 17

# Passed to doStage3 as frames_to_reject; empty here.
frames_to_reject = []

# Each remaining keyword is an options dict for one Stage 3 step. Only
# laplacianfilter_outlier_removal is skipped in this configuration.
doStage3(filesdir, outdir,
         trace_aperture={"hcut1":hcut1,
                         "hcut2":hcut2,
                         "vcut1":0,
                         "vcut2":432},
         frames_to_reject = frames_to_reject,
         loss_stats_step={"skip":False},
         mask_flagged_pixels={"skip":False},
         iteration_outlier_removal={"skip":False, "n":2, "sigma":10},
         spatialfilter_outlier_removal={"skip":False, "sigma":4.5, "kernel":(1,15)},
         laplacianfilter_outlier_removal={"skip":True, "sigma":10},
         second_bckg_subtract={"skip":False,"bckg_rows":[0,1,2,-2,-1]},
         track_source_location={"skip":False,"reject_disper":False,"reject_spatial":False}
        )

In [None]:
# Build wavelength bin edges between wavmin and wavmax at constant resolving power R.
wavmin = 0.6
wavmax = 5.3
R = 50

# Every edge sits a fraction 1/R above the previous one; edges are collected
# for as long as they do not exceed wavmax.
you_can_use_these_bins = []
bin_edge = wavmin
while bin_edge <= wavmax:
    you_can_use_these_bins.append(bin_edge)
    bin_edge += bin_edge/R

# Round the edges to 3 decimals and show them.
you_can_use_these_bins = np.round(you_can_use_these_bins, 3)
#you_can_use_these_bins = np.array([0.523, 0.883])
print(you_can_use_these_bins)

In [None]:
# Perform Stage 4 extractions on the postprocessed files written by Stage 3. Output files will be sent to outdir.
# The input directory is spelled out instead of being read from the shared `outdir` variable:
# this cell reassigns `outdir` below, so re-running it (or running it after the Stage 5 cell)
# would otherwise glob inside ./spectra_Blouin and silently pass an empty file list to doStage4.
postprocessed_dir = "./postprocessed" # must match the outdir used in the Stage 3 cell
filepaths = sorted(glob.glob(os.path.join(postprocessed_dir, "postprocessed_*")))
outdir = "./spectra_Blouin"

# Wavelength bin edges come from the bin-building cell; epoch is handed to the fix_transit_times step.
wavbins = you_can_use_these_bins
epoch = 60061.50686491065 # MJD

# hcut1 and hcut2 are the values defined in the Stage 3 cell, so both stages share one aperture.
# sigma_clip_curves is the only step skipped in this configuration.
doStage4(filepaths, outdir,
         trace_aperture={"hcut1":hcut1,
                         "hcut2":hcut2,
                         "vcut1":0,
                         "vcut2":432},
         extract_light_curves={"skip":False,
                               "wavbins":wavbins,
                               "ext_type":"box"},
         median_normalize_curves={"skip":False},
         sigma_clip_curves={"skip":True,
                            "b":100,
                            "clip_at":5},
         fix_transit_times={"skip":False,
                            "epoch":epoch},
         plot_light_curves={"skip":False},
         save_light_curves={"skip":False}
        )

In [None]:
# Perform Stage 5 light curve fitting on the .txt files located in filesdir. Output files will be sent to outdir.
filesdir = os.path.join("./spectra_Blouin", "output_txts_extraction")
outdir = "./spectra_Blouin"

# Set path to custom LD model.
# NOTE: this is an absolute path on the original author's machine; change it for your system.
path_to_custom_LD_model = "/Users/abby/code_dev/jwst/WD_1856/blouin_WD1856_LDmodel/Imu_bestfitJWST.txt"

# Initialize exoplanet_params.
# The orbit is described by "impact" (impact parameter) rather than an inclination;
# the old "inc" entry is left commented out.
exoplanet_params = {}
exoplanet_params["period"] = 1.407939217
exoplanet_params["t0"] = 0
exoplanet_params["rp"] = 7.28
exoplanet_params["aoR"] = 335.22 #349.83347059561754
#exoplanet_params["inc"] = 88.778
exoplanet_params["impact"] = 7.26 #7.92797143651862
exoplanet_params["ecc"] = 0.0
exoplanet_params["lop"] = 90
exoplanet_params["offset"] = 0.0

# Initialize systematics guess.
systematics = (1,0) # (b, m) for mx + b

# Input full wavelength range.
# Taken from the smallest and largest bin edge produced by the bin-building cell.
spectral_range = (np.min(you_can_use_these_bins),np.max(you_can_use_these_bins))

# do_fit and save_plots share the same four keys (white-light / spectroscopic, LM / MCMC).
# In fixed_param, True holds a parameter fixed and False opens it for fitting (see the MCMCfit
# docstring); here t0, aoR, impact and offset are fitted.
# Because priors_type is "gaussian", each priors_dict entry is read as [mu, sigma] by MCMCfit.
# NOTE(review): presumably doStage5 forwards these dictionaries to the fitting routines
# unchanged - confirm against doStage5.
# NOTE: "ld_data_path" is also an absolute path on the original author's machine.
doStage5(filesdir, outdir, exoplanet_params, systematics, spectral_range,
         do_fit={"WLC_LM":True,
                 "WLC_MCMC":True,
                 "spec_LM":True,
                 "spec_MCMC":True,},
         limb_darkening_model={"model_type":"quadratic",
                               "stellar_params":None,
                               "initial_guess":[0,0],},
         fixed_param={"LD_coeffs":True,
                      "t0":False,
                      "period":True,
                      "aoR":False,
                      "impact":False,#"inc":False,
                      "ecc":True,
                      "lop":True,
                      "offset":False},
         MCMC_depth_type="all",
         priors_dict={"t0":[0.0, 0.001],
                      "period":[1.407939217, 0.000000016], #1.407939217 from ttv, 1.4079389 from wd 1856 proposal
                      "aoR":[336, 14],
                      "impact":[7.16, 0.65],#"inc":[88.778, 0.059],
                      "ecc":[0, 0.8],
                      "lop":[0, 90],
                      "offset":[0.2,0.25]},
         priors_type="gaussian",
         reject_threshold=2.5,
         raise_alarm=10,
         exoticLD={"available":True,
                   "ld_data_path":"/Users/abby/opt/anaconda3/exotic_ld_data-3.1.2",
                   "ld_grid":'custom',
                   "ld_interpolate_type":'trilinear',
                   "custom_model_path":path_to_custom_LD_model},
         save_plots={"WLC_LM":True,
                     "WLC_MCMC":True,
                     "spec_LM":True,
                     "spec_MCMC":True})