# Imports

In [None]:
import numpy as np
import astropy
from astropy.io import fits
from astropy.wcs import WCS
from astropy import units as u
import glob
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import os
import shutil
import math
import matplotlib.colors as mcolors
import random
from astroML.linear_model import LinearRegression, PolynomialRegression
import specutils
from specutils.fitting import fit_lines
from specutils import SpectralRegion
from specutils.analysis import line_flux
from specutils import Spectrum1D
from scipy.optimize import minimize
import sys
from astropy.modeling.polynomial import Polynomial1D
from scipy.interpolate import interp1d, interp2d
from astropy.modeling.fitting import LevMarLSQFitter, LinearLSQFitter
from astropy.modeling.models import Gaussian1D
from specutils.fitting import fit_generic_continuum
from specutils.fitting import fit_continuum
from astropy.convolution import convolve, Gaussian1DKernel
from scipy.ndimage import convolve as sci_convolve
from scipy.ndimage import convolve1d
from matplotlib.lines import Line2D
from astropy.nddata import NDData
import time
import logging as logger
from scipy import ndimage
from scipy import optimize
from scipy.linalg import lstsq
from scipy import stats
from mpl_toolkits.mplot3d import Axes3D
from specutils.analysis import centroid
from astropy.modeling import models
from scipy.ndimage import uniform_filter1d
from astropy.nddata import StdDevUncertainty
from scipy.ndimage import median_filter
from astropy.coordinates import SkyCoord
from astropy.table import Table, hstack
import RSSMOSPipeline
from RSSMOSPipeline import RSSMOSTools as rss
import importlib
import configparser

In [None]:
# Global matplotlib font settings: a large, bold face keeps the labels and
# titles of the big diagnostic figures below easy to read.
font = dict(family='Times New Roman', weight='bold', size=20)
plt.rc('font', **font)

In [None]:
# Astropy units attached to wavelength values (Angstrom) and flux values
# (erg / cm^2 / s / Angstrom) throughout the notebook.
wave_units = u.AA
flux_units = u.erg * u.cm**-2 * u.s**-1 * wave_units**-1

# SALTRSSCalPipeline

## - File Finder

In [None]:
# directory where all masks are contained
data_dir = '/home/george/Downloads/DATASETS/'
config_name = 'path_file_COSMOS-mask-B.config'
# spectroscopic data should be in  ../<mask>/<grating>/<dither>/reduced
# where reduced is the name of the folder that RSSMOSPipeline outputs into
def get_paths(data_dir, config_name):
    '''
    Read the mask configuration file and build the paths used by the pipeline.

    Parameters
    ----------
    data_dir : str
        Directory containing the configuration file, the mask folders and the
        'HETDEX_catalogues' folder. A trailing separator is optional.
    config_name : str
        File name of the configuration file inside data_dir. It must contain a
        [Mask Info] section with the keys mask_name and mask_id.

    Returns
    -------
    dict
        Keys 'mask_name', 'mask_id', 'HETDEX_path', 'dir_mask', 'dir_red',
        'dir_green', 'dir_blue', 'l_path' and 'u_path'. 'l_path' / 'u_path'
        are relative fragments (with leading and trailing '/') pointing at the
        reduce folder of the Dither- / Dither+ exposures.

    Raises
    ------
    FileNotFoundError
        If the configuration file cannot be read.
    KeyError
        If the [Mask Info] section or one of its keys is missing.
    '''
    # ConfigParser.read() silently skips missing files and returns the list of
    # files it parsed, so an empty list means the config was never loaded.
    config_path = os.path.join(data_dir, config_name)
    config = configparser.ConfigParser()
    if not config.read(config_path):
        raise FileNotFoundError(f'Configuration file not found or unreadable: {config_path}')

    mask_info = config['Mask Info']
    mask_name = mask_info['mask_name']
    mask_id = mask_info['mask_id']

    # path to the HETDEX catalogue (can be any catalogue you are comparing results to);
    # os.path.join keeps this working whether or not data_dir ends with a separator
    HETDEX_path = os.path.join(data_dir, 'HETDEX_catalogues')
    dir_mask = os.path.join(data_dir, mask_name)

    # one directory per grating
    dir_red = dir_mask + '/PG0900'
    dir_green = dir_mask + '/PG2300'
    dir_blue = dir_mask + '/PG3000'

    # reduce folder inside each dither
    reduce_tail = f'/reduce/{mask_name}_{mask_id}/1DSpec_2DSpec_stackAndExtract_iterative/'
    l_path = '/Dither-' + reduce_tail
    u_path = '/Dither+' + reduce_tail

    paths_dict = {
        'mask_name': mask_name,
        'mask_id': mask_id,
        'HETDEX_path': HETDEX_path,
        'dir_mask': dir_mask,
        'dir_red': dir_red,
        'dir_green': dir_green,
        'dir_blue': dir_blue,
        'l_path': l_path,
        'u_path': u_path,
    }

    return paths_dict

In [None]:
def _open_first_match(pattern):
    '''Open the first FITS file matching the glob pattern (IndexError if nothing matches).'''
    return fits.open(glob.glob(pattern)[0])


def _load_grating_spectra(grating_dir, SLIT):
    '''
    Open the flux calibrated spectra of one grating for the given slit.

    Returns (spectra_2d, spectra_1d). spectra_2d holds one opened 2D file per
    existing dither directory (Dither- first, then Dither+), or is None when
    neither directory exists. spectra_1d holds the matching 1D files, or is
    None when any of the used dithers has no 1D file.
    '''
    dithers = [d for d in ('Dither-', 'Dither+') if os.path.exists(f'{grating_dir}/{d}/')]
    if not dithers:
        return None, None

    spectra_2d = [_open_first_match(f'{grating_dir}/{d}/FLUXcal_2D*{SLIT}.fits') for d in dithers]

    # the 1D spectra are optional, but only usable when every dither has one
    matches_1d = [glob.glob(f'{grating_dir}/{d}/FLUXcal_1D*{SLIT}.fits') for d in dithers]
    spectra_1d = [fits.open(found[0]) for found in matches_1d] if all(matches_1d) else None
    return spectra_2d, spectra_1d


def flux_calibrated_specs(SLIT):
    '''
    Finds flux calibrated spectra to be used for reprojection and continuum comparison.

    The red (PG0900), green (PG2300) and blue (PG3000) grating directories are taken from
    get_paths(data_dir, config_name). For each grating the FLUXcal_2D (and FLUXcal_1D, when
    available) file of the slit is opened from every dither directory that exists, so the
    function still works when one dither of a grating is missing.

    Parameters
    ----------
    SLIT : slit identifier inserted into the file name patterns.

    Returns
    -------
    full_settings : np.concatenate of the opened 2D spectra, ordered red -> green -> blue.
    full_settings1D : np.concatenate of the opened 1D spectra in the same order, or [] when
        any grating lacks its 1D spectra.
    HETDEX_range_settings : np.concatenate of the green and blue 1D spectra, or [] in the same case.
    combined_spec : opened '<dir_mask>/FLUXcal_1D*<SLIT>_comb.fits' file, or None if absent.
    color : list with one grating name ('red', 'green', 'blue') per opened 2D spectrum.
    paths : the dictionary returned by get_paths.

    Raises
    ------
    FileNotFoundError : if a grating has neither a Dither- nor a Dither+ directory.
    IndexError : if a dither directory exists but contains no FLUXcal_2D file for the slit.

    The opened FITS files are returned to the caller and are therefore not closed here.
    '''
    paths = get_paths(data_dir, config_name)

    color = []
    settings_2d = []
    settings_1d = []
    for grating in ('red', 'green', 'blue'):
        grating_dir = paths[f'dir_{grating}']
        spectra_2d, spectra_1d = _load_grating_spectra(grating_dir, SLIT)
        if spectra_2d is None:
            raise FileNotFoundError(
                f'No Dither-/ or Dither+/ directory found for the {grating} grating in {grating_dir}')
        settings_2d.append(spectra_2d)
        settings_1d.append(spectra_1d)
        # one colour entry per dither that was actually loaded
        color.extend([grating] * len(spectra_2d))

    full_settings = np.concatenate(settings_2d)

    # the 1D products are all-or-nothing: a single missing grating empties both lists
    if all(spectra is not None for spectra in settings_1d):
        full_settings1D = np.concatenate(settings_1d)
        HETDEX_range_settings = np.concatenate(settings_1d[1:])
    else:
        full_settings1D = []
        HETDEX_range_settings = []

    path_list_1D = glob.glob(f'{paths["dir_mask"]}/FLUXcal_1D*{SLIT}_comb.fits')
    combined_spec = fits.open(path_list_1D[0]) if path_list_1D else None

    #returns from red --> blue order
    return full_settings,full_settings1D,HETDEX_range_settings, combined_spec, color, paths

## - Declare WCS

In [None]:
def WCS2World(hdu):
    '''
    Convert the pixel grid of the primary HDU to world coordinates using its header WCS
    (Angstrom in our case).

    This only returns wavelengths for non reprojected data. For reprojected data it returns a
    linear array that needs to be put into an exponential function as described in the HDU header.

    For 2D data the pixel indices along axis 1 are evaluated at second pixel coordinate 0 and the
    first world axis is returned; for 1D data every pixel index is converted directly. Any other
    dimensionality returns None.
    '''
    wcs = WCS(hdu[0].header)
    n_dims = hdu[0].data.ndim

    if n_dims == 2:
        columns = np.arange(hdu[0].data.shape[1])
        zeros = np.zeros_like(columns)
        return wcs.pixel_to_world_values(columns, zeros)[0]

    if n_dims == 1:
        return wcs.pixel_to_world_values(np.arange(hdu[0].data.shape[0]))

    return None

## - Sensitivity Function

In [None]:
def find_sens_func(SDSS_star_folder,sci_spec, starSLITS):   

    '''
    Using the observed and known star spectra, a sensitivity function is made by dividing the observed
    spectra (counts) by the known spectra (flux units).

    For every slit in starSLITS the observed 2D star spectrum is corrected with check_skysub, summed
    along axis 0 and divided by the SDSS spectrum of the same star interpolated onto the observed
    wavelengths. The per-star curves are then averaged at every observed wavelength over the stars that
    have a non-NaN value there (0 where none has). Returns the sensitivity function to be used in
    flux_calibration, as well as other parameters.

    Parameters
    ----------
    SDSS_star_folder : path prefix; files are located with glob(f'{prefix}*SLIT{num}.fits').
    sci_spec : path of the science FITS file (opened here and returned).
    starSLITS : non-empty sequence of slit numbers of the standard stars.

    Returns
    -------
    sorted_combined_sensitivity : combined sensitivity values, sorted by wavelength.
    sci_spec : the opened science HDU list.
    sorted_wavelengths : all observed star wavelengths, sorted.
    sci_data : primary science data with chip-gap pixels set to NaN (modified in place).
    file_name : base name of the science file.
    star_spec_cal : the SDSS HDU list of the last star processed.
    star_diagnostics : directory in which the diagnostic plots are saved.

    Raises
    ------
    ValueError : if starSLITS is empty.

    Side effects: creates the 'star_diagnostics' directory three levels above SDSS_star_folder, saves
    PNG figures there and leaves the matplotlib figures open.
    '''
    if len(starSLITS) == 0:
        raise ValueError('starSLITS is empty: at least one star slit is needed to build a sensitivity function')

    # pull data and file names from FITS files
    SDSS_star_path = str(SDSS_star_folder)
    star_diagnostics = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(SDSS_star_path))), 'star_diagnostics')
    os.makedirs(star_diagnostics, exist_ok=True)
    sci_spec_path = str(sci_spec)
    sci_spec = fits.open(sci_spec_path)
    file_name = os.path.basename(sci_spec_path).split('/')[-1]

    # make a list of SDSS stars and corresponding observed stars
    paths=get_paths(data_dir,config_name)
    SDSS_star_specs = [fits.open(glob.glob(f'{SDSS_star_path}*SLIT{num}.fits')[0]) for num in starSLITS]
    # NOTE(review): `color` and `path` are not defined in this function, so they must exist as
    # notebook-level globals (presumably a grating directory and a dither reduce path) - confirm.
    observed_star_specs = [fits.open(glob.glob(f'{color}{path}2D_noSky_{paths["mask_name"]}_{paths["mask_id"]}_SLIT{num}.fits')[0]) for num in starSLITS]
    print(starSLITS)

    sci_data = sci_spec[0].data
    nmad = 1.4826 * np.median(np.abs(sci_data - np.median(sci_data)))
    atol = .00001 * nmad
    zero_mask_sci = np.isclose(sci_data, 0, atol=atol)

    # very simple chip gap zapper, as RSSMOS outputs chip gaps = 0, and we can't average zeros.
    # chip gaps will have values at zero and very close to zero, so we mask out those values and the
    # neighbouring 8 pixels on either side (along axis 1) which might pop through.
    # Only 2D science data is supported here.
    n_cols = sci_data.shape[1]
    gap_rows, gap_cols = np.nonzero(zero_mask_sci)
    zero_mask_sci_ext = np.zeros_like(zero_mask_sci, dtype=bool)
    for offset in range(-8, 9):
        # clipping keeps the window inside the image; the clipped index is always within the window
        zero_mask_sci_ext[gap_rows, np.clip(gap_cols + offset, 0, n_cols - 1)] = True
    sci_data[zero_mask_sci_ext]=np.nan

    # per-star sensitivity curves, their wavelength grids and interpolators
    sens_funcs = []
    observed_wavs = []
    interp_stars = []

    # loop over all pairs of specs to generate sens functions for each star, as well as interpolation functions for each star.
    for star_spec_cal, star_spec_obs, slit in zip(SDSS_star_specs,observed_star_specs, starSLITS):
        print(slit)
        star_data_obs = star_spec_obs[0].data
        print(star_data_obs.shape)
        star_data_obs = star_data_obs.astype(float)

        # fix poor star spec sky subtraction
        star_data_obs=check_skysub(cont=5,sky=1, data = star_data_obs)
        # second call is only for its diagnostic image of the corrected data; its return value is unused
        check_skysub(cont=5,sky=1, data = star_data_obs)

        # mask out values very close to zero to zap chip gaps and prevent infinities
        star_data_obs= np.nansum(star_data_obs, axis = 0)
        nmad = 1.4826 * np.median(np.abs(star_data_obs - np.median(star_data_obs)))
        atol = .02 * nmad
        star_data_obs[np.isclose(star_data_obs, 0, atol=atol)]=np.nan

        # wavelength array of the observed star
        obs_wavs = WCS2World(star_spec_obs)

        # interpolate the SDSS spectrum onto the observed wavelengths
        calibrated_wavs = 10**star_spec_cal[1].data['loglam']
        calibrated_flux = star_spec_cal[1].data['flux']*10**-17
        for i in range(1, len(calibrated_flux)):
            # forward-fill values close to zero to avoid infinities in the division below
            # NOTE(review): atol=10e-17 equals 1e-16, i.e. 10 in the 1e-17-scaled flux values - confirm intended
            if np.isclose(calibrated_flux[i], 0, atol=10e-17):
                calibrated_flux[i] = calibrated_flux[i-1]
        print(obs_wavs.shape)
        print(calibrated_wavs.shape)
        print(calibrated_flux.shape)
        calibrated_flux_interp_func = interp1d(calibrated_wavs,calibrated_flux, bounds_error =False, fill_value = np.nan)
        calibrated_flux_interp = calibrated_flux_interp_func(obs_wavs)
        calibrated_wavs = calibrated_wavs*wave_units
        calibrated_flux = calibrated_flux*flux_units

        # sensitivity = observed counts / known flux
        obs_flux = (star_data_obs)*u.ph
        sensitivity_function = obs_flux/calibrated_flux_interp

        sens_funcs.append(sensitivity_function)
        observed_wavs.append(obs_wavs)
        interp_stars.append(interp1d(obs_wavs, sensitivity_function, bounds_error =False, fill_value = np.nan))

        # plots the SDSS star spectrum and the observed star spectrum from SALT.
        fig_star, (ax1_star, ax2_star) = plt.subplots(2, 1, sharex=True,figsize=(20,20))
        ax1_star.set_title(f'Slit {slit} Star Spectrum from SDSS')
        ax1_star.plot(10**star_spec_cal[1].data['loglam'], star_spec_cal[1].data['flux']*10**-17)
        ax1_star.set_ylabel('Flux (erg/cm$^2$/s/Å)')
        ax2_star.set_title(f'Observed Slit {slit} Star Spectrum in PG2300')
        ax2_star.plot(obs_wavs, star_data_obs, 'darkorange')
        ax2_star.set_xlabel('Wavelength (Å)')
        ax2_star.set_ylabel('Counts')
        file_info = star_spec_cal.fileinfo(0)
        fits_file_name = os.path.basename(file_info['file'].name)
        plot_filename_base = os.path.splitext(fits_file_name)[0]
        plot_cal_filename = f'{plot_filename_base}_calibration.png'
        plot_cal_full_path = os.path.join(star_diagnostics, plot_cal_filename)
        # savefig always overwrites existing files; it has no `overwrite` keyword
        plt.savefig(plot_cal_full_path, facecolor='w', transparent=False)

        # plot the interpolated star spectrum and sensitivity function
        fig_interp, (ax1_interp, ax3_interp) = plt.subplots(2, 1, sharex=True,figsize=(20,20))
        ax1_interp.plot(calibrated_wavs, calibrated_flux, 'o', label = 'Calibrated Star Spectrum')
        ax1_interp.set_title('Star Spectrum Interpolation')
        ax1_interp.plot(obs_wavs, calibrated_flux_interp, 'o', mfc='none', label = 'Interpolated Star Spectrum', color = 'darkorange')
        ax1_interp.set_ylabel('Flux (erg/cm$^2$/s/Å)')
        ax3_interp.set_title('Single Sensitivity Function')
        ax3_interp.set_xlabel('Wavelength (Å)')
        ax3_interp.plot(obs_wavs, sensitivity_function, color = 'blue')
        ax3_interp.set_ylabel('Counts/Flux (#/erg/cm$^2$/s/Å)')

        lines, labels = ax1_interp.get_legend_handles_labels()
        lines3, labels3 = ax3_interp.get_legend_handles_labels()
        ax1_interp.legend(lines + lines3, labels + labels3, loc='upper right')

    # evaluate every star's sensitivity on the union of all observed wavelengths at once
    all_observed_wavs = np.concatenate([np.asarray(wavs) for wavs in observed_wavs])
    sens_matrix = np.array([np.asarray(interp_star(all_observed_wavs), dtype=float) for interp_star in interp_stars])

    # average over the stars that have a non-NaN value; wavelengths with none get 0
    valid_counts = np.count_nonzero(~np.isnan(sens_matrix), axis=0)
    sens_sums = np.nansum(sens_matrix, axis=0)
    final_sens = np.zeros_like(sens_sums)
    np.divide(sens_sums, valid_counts, out=final_sens, where=valid_counts > 0)

    sorted_indices = np.argsort(all_observed_wavs)
    sorted_wavelengths = all_observed_wavs[sorted_indices]
    sorted_combined_sensitivity = final_sens[sorted_indices]

    # plot the individual sensitivity functions together with the combined one
    fig, ax1 = plt.subplots(figsize=(20,10))
    # blue is reserved for the combined curve
    named_colors = [name for name in mcolors.TABLEAU_COLORS.keys() if "blue" not in name]
    for i, (wavs, sens) in enumerate(zip(observed_wavs, sens_funcs)):
        random_color = random.choice(named_colors)
        # only the first constituent gets a real legend label
        label = "Constituents" if i == 0 else " "
        ax1.plot(wavs, sens, alpha=0.5, color=random_color, label=label)

    ax1.set_title('Combined Sensitivity Function')
    ax1.plot(sorted_wavelengths, sorted_combined_sensitivity, label = 'Full Sensitivity Function', color = 'blue')
    ax1.set_xlabel('Wavelength (Å)')
    ax1.set_ylabel('Count/Flux (#/erg/cm$^2$/s/Å)')

    lines, labels = ax1.get_legend_handles_labels()
    ax1.legend(lines, labels, loc='upper right')
    plot_sens_filename = 'combined_sens_function.png'
    plot_sens_full_path = os.path.join(star_diagnostics, plot_sens_filename)
    plt.savefig(plot_sens_full_path, facecolor='w', transparent=False)

    return sorted_combined_sensitivity, sci_spec,sorted_wavelengths, sci_data, file_name, star_spec_cal, star_diagnostics

## - Flux Calibration

In [None]:
def custom_poly(x,coeffs):
    '''
    Evaluate the quadratic a*x**2 + b*x + c at x.

    coeffs must unpack into exactly three values (a, b, c), highest power first.
    '''
    quadratic, linear, constant = coeffs
    return quadratic * x**2 + linear * x + constant

In [None]:
def objective_function(coeffs, x, y):
    '''Sum of squared residuals between y and the quadratic custom_poly(x, coeffs).'''
    residuals = y - custom_poly(x, coeffs)
    return np.sum(residuals**2)

In [None]:
# used to constrain max throughput
def derivative_constraint(coeffs, x_c):
    '''
    Derivative of the quadratic a*x**2 + b*x + c evaluated at x_c.

    coeffs must unpack into exactly three values (a, b, c); c does not contribute.
    Used as an equality constraint (value == 0) to pin the extremum of the fit at x_c.
    '''
    quadratic, linear, _constant = coeffs
    return 2 * quadratic * x_c + linear

In [None]:
def flux_calibration(sci_spec_path, SDSS_star_path,Order, starSLITS, max_throughput_wav):


    '''
    Uses the sensitivity function to flux calibrate 2D science data.

    The combined sensitivity function from find_sens_func is smoothed with a quadratic: AstroML's
    PolynomialRegression provides an initial guess, then scipy.optimize.minimize (SLSQP) refines it,
    optionally constraining the derivative of the quadratic to be zero at max_throughput_wav (the
    wavelength of maximum throughput). The 2D science spectrum is then divided column by column by
    the smooth sensitivity, as we don't expect flux to vary noticeably in the spatial dimension, and
    written to '<parent of SDSS_star_path>/FLUXcal_<file_name>'. A 1D extraction is produced with extract1D.

    Parameters
    ----------
    sci_spec_path : path of the 2D science FITS file.
    SDSS_star_path : path prefix of the SDSS star spectra (passed on to find_sens_func).
    Order : degree of the AstroML polynomial. Its coefficients are unpacked into three values, so
        only Order == 2 works.
    starSLITS : slit numbers of the standard stars.
    max_throughput_wav : wavelength at which the fitted quadratic must have zero slope, or None
        for an unconstrained fit.

    Returns
    -------
    cal_plot_data, uncal_plot_data, obs_sci_wavs, hdu, newFITS, sci_data, star_diagnostics, y_smooth_sci

    Raises
    ------
    RuntimeError : if the constrained fit does not converge.
    NotImplementedError : if the science data is not 2D.
    '''
    # create the sensitivity function from the star spectra
    sensitivity_function, sci_spec,all_observed_wavs, sci_data, file_name, _, star_diagnostics= find_sens_func(SDSS_star_path,sci_spec_path, starSLITS)

    # only wavelengths with a non-zero sensitivity take part in the fit
    non_zero_indices = np.nonzero(sensitivity_function)[0]
    x_data = all_observed_wavs[non_zero_indices]
    y_data = sensitivity_function[non_zero_indices]

    obs_sci_wavs = WCS2World(sci_spec)

    # regression using astroml to get initial guesses
    smooth_sensitivity_func = PolynomialRegression(degree = Order)
    smooth_sensitivity_func.fit(x_data[:,None], y_data)

    # NOTE(review): custom_poly expects (a, b, c) with the highest power first; confirm that astroML's
    # coef_ uses the same ordering, otherwise this initial guess is reversed.
    a_model, b_model, c_model = smooth_sensitivity_func.coef_
    print(a_model, b_model, c_model)
    initial_guess = [a_model,b_model,c_model]

    if max_throughput_wav is not None:
        x_c = max_throughput_wav
        # force the extremum of the quadratic to sit at x_c (used when dealing with blue data)
        constraints = [
            {'type': 'eq', 'fun': lambda coeffs: derivative_constraint(coeffs, x_c)},
        ]
    else:
        constraints = []

    result = minimize(
        objective_function,
        initial_guess,
        args=(x_data, y_data),
        constraints=constraints,
        method = 'SLSQP',
        options={'maxiter': 100000}
    )

    if not result.success:
        # without fitted coefficients nothing below can be computed, so fail with the optimizer's reason
        print("No solution found")
        raise RuntimeError(f'Sensitivity function fit did not converge: {result.message}')

    a, b, c = result.x[0], result.x[1], result.x[2]
    fitted_coeffs = a,b,c
    print("Fitted coefficients:", fitted_coeffs)

    x_full = np.linspace(x_data[0]-1000,x_data[-1]+1000, x_data.shape[0])
    y_smooth_sci = custom_poly(obs_sci_wavs, fitted_coeffs)
    y_smooth = custom_poly(x_data, fitted_coeffs)
    y_smooth_full = custom_poly(x_full, fitted_coeffs)
    smooth_sensitivity_y_astroml = smooth_sensitivity_func.predict(x_full[:,None])

    # plot the sensitivity function
    plt.figure(figsize=(20,10))
    plt.xlabel('Wavelength (Å)')
    plt.ylabel('Count/Flux (#/erg/cm$^2$/s/Å)')
    plt.title('Smooth Function')
    plt.plot(x_data, y_data, label = 'Combined Sensitivity Function', color = 'blue')
    plt.plot(x_full, y_smooth_full, label = 'Smooth Function', color = 'r')
    plt.plot(x_data, y_smooth,linestyle = 'dashed', color = 'r')

    # plot the smooth sensitivity function, along with the polynomial for reference
    plt.plot(x_full, smooth_sensitivity_y_astroml, color = 'black', label = 'Initial Guess', linestyle='dashed')
    plt.plot([],[], label = 'Order: '+ str( Order), color = 'w')
    plt.legend()

    plot_smooth_sens_filename = 'combined_smooth_sens_function.png'
    plot_smooth_sens_full_path = os.path.join(star_diagnostics, plot_smooth_sens_filename)
    # savefig always overwrites existing files; it has no `overwrite` keyword
    plt.savefig(plot_smooth_sens_full_path, facecolor='w', transparent=False)

    # called only for its diagnostic image of the science data; the return value is unused
    check_skysub(cont=19, sky =5, data = sci_data)

    if sci_data.ndim != 2:
        # the former 1D branch referenced an undefined sensitivity array and could never reach
        # the return statement, so non-2D input is rejected explicitly
        raise NotImplementedError(
            f'flux_calibration only supports 2D science spectra, got {sci_data.ndim}D data')

    header = sci_spec[0].header
    # NOTE(review): find_sens_func multiplies the SDSS flux by 10**-17 before dividing, so the
    # calibrated values may already be in erg/cm^2/s/Ang rather than 1E-17 units - confirm this keyword.
    header.append(('FLUXUNIT', '1E-17 erg/cm^2/s/Ang'))
    hdu=fits.PrimaryHDU(None, header)

    # divide every column by the smooth sensitivity at that column's wavelength
    hdu.data = sci_data / y_smooth_sci[:sci_data.shape[1]]
    print(hdu.shape)

    output_dir = os.path.abspath(os.path.join(SDSS_star_path, ".."))
    newFITS= fits.HDUList()
    newFITS.append(hdu)
    newFITS.writeto(output_dir+"/FLUXcal_"+(file_name), overwrite=True)

    obs_sci_wavs = obs_sci_wavs*wave_units
    uncal_plot_data = np.sum(sci_data, axis = 0)

    # the 1D product reuses the file name with its first character replaced by '1'
    path1D = output_dir+"/FLUXcal_"+('1' + file_name[1:])
    cal_plot_data, newFITS1D = extract1D(newFITS, path1D, Order = 6)

    # plots the uncalibrated and flux calibrated science spectrum.
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True,figsize=(20,20))
    ax1.set_title('Observed Science Spectrum')
    ax1.plot(obs_sci_wavs, uncal_plot_data, color = 'darkorange')
    ax1.set_ylabel('Counts')
    ax2.set_title('Flux Calibrated Science Spectrum')
    ax2.plot(obs_sci_wavs, cal_plot_data, color='blue')
    ax2.set_xlabel('Wavelength (Å)')
    ax2.set_ylabel('Flux (erg/cm$^2$/s/Å)')

    return cal_plot_data, uncal_plot_data, obs_sci_wavs, hdu, newFITS, sci_data, star_diagnostics, y_smooth_sci

In [None]:
def check_skysub(cont, sky, data):
    '''
    Show the data as a grey-scale image and remove residual sky estimated from its edge rows.

    The per-column median (ignoring NaNs) of the first two rows and of the last two rows is
    measured; the smaller of the two is subtracted from every row. The image is drawn on a new
    figure with NaNs replaced and the 5th-95th percentile range as colour limits.

    cont and sky are accepted but not used. Returns the corrected data; the input is not modified.
    '''
    plt.figure(figsize=(30,10))
    preview = np.nan_to_num(data)
    low_cut = np.percentile(preview, 5)
    high_cut = np.percentile(preview, 95)
    plt.imshow(preview, cmap='gray', aspect='auto', vmin=low_cut, vmax=high_cut)
    plt.gca().invert_yaxis()

    # residual sky level per column at each edge of the slit
    first_rows_sky = np.nanmedian(data[:2], axis=0)
    last_rows_sky = np.nanmedian(data[-2:], axis=0)
    return data - np.minimum(first_rows_sky, last_rows_sky)

## - RSSMOS Extraction

In [None]:
#run RSSMOSPipelines trace fitting and extraction routine, then save 1d extractions
def extract1D(hdu,path, Order=4):
    '''
    Extraction routine that utilises RSSMOSPipeline's profile fitter, which returns trace centers and
    sigmas at each pixel of a given 2D spectrum (hdu). The centers are fitted with a polynomial of the
    given Order and a Gaussian profile is built around the fitted trace. The product of the profile
    and the 2D data is summed along axis 0 to give the extracted 1D spectrum, which is saved as a
    FITS file in 'path' (overwriting any existing file).

    Parameters
    ----------
    hdu : HDU list whose primary HDU holds the 2D spectrum. When it has exactly two HDUs, the second
        one (saved wavelengths of combined spectra) is copied to the output.
    path : output file path.
    Order : degree of the polynomial fitted to the trace centers.

    Returns
    -------
    signal : the 1D extracted spectrum (masked array).
    fits1D : the HDU list that was written to 'path'.
    '''
    print('Now extracting 1D')
    # multiply by scalar since values are extremely small and can throw off profile fitter.
    scalar = 1e20
    scaled_data = hdu[0].data*scalar
    data_norm  = np.ma.masked_array(scaled_data, mask=np.isnan(scaled_data))

    # utilise RSSMOSPipeline profile fitting routine.
    importlib.reload(rss)
    rss_extract = rss.finalExtraction(data_norm)
    centers = rss_extract[3]
    sigmas = 2*rss_extract[4]

    # create dispersion and spatial arrays to be used for fitting. RSSMOSPipeline returns -99, -99 for
    # masked data, to be easily masked out.
    disp=np.arange(data_norm.shape[1])
    spatial = np.arange(data_norm.shape[0])
    mask=np.greater(centers, 0)

    # fit the trace on the valid columns only, but evaluate it on EVERY column so that the profile
    # loop/broadcast below can index it by dispersion pixel even when some columns are masked
    coeffs = np.polyfit(disp[mask], centers[mask], Order)
    trace_center = np.polyval(coeffs, disp)

    # valid columns keep their own sigma; masked columns (sentinel values) fall back to the median
    trace_sigma =np.median(sigmas[mask])
    column_sigma = np.where(mask, sigmas, trace_sigma)

    # make weighted profile and sum product of profile and 2D data. Divide out original scalar to
    # obtain 1D extracted spectra.
    offsets = spatial[:, None] - trace_center[None, :]
    weighted_prof = np.exp(-(offsets**2)/(2*column_sigma[None, :]**2))
    signal_norm=np.nansum(data_norm*weighted_prof, axis = 0)
    signal = signal_norm/scalar

    # save extraction to 'path'
    fits_hdu = fits.PrimaryHDU(signal.filled(fill_value=np.nan), hdu[0].header.copy())
    if len(hdu) ==2:
        # for combined spectra with saved wavelengths in second hdu.
        fits_wcs = fits.ImageHDU(hdu[1].data, hdu[1].header.copy())
        fits1D = fits.HDUList([fits_hdu, fits_wcs])
    else:
        # for individual spectra that can be run with WCS2World.
        fits1D = fits.HDUList([fits_hdu])
    fits1D.writeto(path, overwrite=True)

    return signal, fits1D

## - Continuum Analysis

In [None]:
def measure_continuum(dispersion_axis, signal, order, window, spectral_region):
    '''
    Measures the continuum level of a spectrum using Specutils and returns Spectrum1D objects for both
    the original spectrum and a continuum-subtracted spectrum, because Specutils line flux routines
    require continuum subtracted spectra.

    The per-pixel uncertainty attached to both spectra comes from a polynomial of degree `order`
    fitted to the spectrum: the squared residuals are median-filtered over `window` pixels and the
    square root of the result is used as a standard-deviation uncertainty.

    NaNs in `signal` are replaced with zeros first. `spectral_region` is passed to fit_continuum as
    its `window` argument.

    Returns (spec1D_with_continuum, spec1D_no_continuum, continuum, new_y), where new_y is the
    polynomial model evaluated on dispersion_axis.
    '''
    # first convert nans to zeros
    signal = np.nan_to_num(signal)

    # polynomial model of the spectrum, used only to estimate the noise
    fitter = LevMarLSQFitter()
    fitted_poly = fitter(models.Polynomial1D(degree=order), dispersion_axis, signal)
    new_y = fitted_poly(dispersion_axis)
    squared_residuals = (signal - new_y)**2

    # local noise level from the filtered squared residuals
    filtered_squares = median_filter(squared_residuals, size=window, mode='nearest')
    uncertainty = StdDevUncertainty(np.sqrt(filtered_squares)*flux_units)

    # spectrum including its continuum, then the continuum measured from it
    spec1D_with_continuum = Spectrum1D(
        spectral_axis=dispersion_axis,
        flux=signal*flux_units,
        uncertainty=uncertainty,
    )
    continuum = fit_continuum(spec1D_with_continuum, window=spectral_region)(dispersion_axis)

    # continuum subtracted spectrum with the same uncertainty
    spec1D_no_continuum = Spectrum1D(
        spectral_axis=dispersion_axis,
        flux=(signal - continuum.value)*flux_units,
        uncertainty=uncertainty,
    )

    return spec1D_with_continuum, spec1D_no_continuum, continuum,new_y

In [None]:
def read_src(RA,Dec,HETDEX_path,version = 'v3.2'):
    '''
    Currently only usable if your literature spectra are from HETDEX. Reads the hetdex_sc1_{version}.ecsv source table
    and the hetdex_sc1_spec_{version}.fits spectra file found in HETDEX_path, following a similar structure to how HETDEX
    suggests finding information on spectra. With inputted RA and Dec values (in degrees) derived from HETDEX, finds the
    object and returns its spectra arrays, wavelength array, and other useful information to be used for continuum comparison.

    Returns, in order: source_table, hetdex_hdu (left open; the caller owns it), spec, spec_err, wave_rect (wavelengths
    with wave_units attached), source_coords, coord, sel_match (boolean mask of catalogue rows within 1 arcsec of the
    requested position), source_id and name of the first matching row.

    Raises IndexError if no catalogue source lies within 1 arcsec of (RA, Dec).
    '''

    # open source table and spec files to access spectra, wavelengths
    source_table = Table.read(os.path.join(HETDEX_path, 'hetdex_sc1_{}.ecsv'.format(version)))
    hetdex_hdu = fits.open(os.path.join(HETDEX_path, 'hetdex_sc1_spec_{}.fits'.format(version)))
    spec = hetdex_hdu['SPEC'].data
    spec_err = hetdex_hdu['SPEC_ERR'].data
    wave_rect = hetdex_hdu['WAVELENGTH'].data*wave_units

    # uses coordinates to find target information
    source_coords = SkyCoord(ra = source_table['RA'], dec= source_table['DEC'])
    coord = SkyCoord(ra=RA*u.deg, dec=Dec*u.deg)
    sel_match = source_coords.separation(coord) < 1.*u.arcsec

    # fail with a message that names the position instead of a bare "index 0 is out of bounds";
    # nothing is handed back on this path, so the FITS file is closed before raising
    if not np.any(sel_match):
        hetdex_hdu.close()
        raise IndexError(f'No HETDEX source within 1 arcsec of RA={RA}, Dec={Dec} in hetdex_sc1_{version}.ecsv')

    # if several rows match, the first one in catalogue order is used
    source_id = source_table['source_id'][sel_match][0]
    name = source_table['source_name'][sel_match][0]
    return source_table, hetdex_hdu, spec, spec_err, wave_rect, source_coords, coord, sel_match, source_id, name

In [None]:
def continuum_comparison(SLIT,RA,Dec, HETDEX_path,mult_bias):
    '''
    Uses specutils to measure continua over a common wavelength range and compares the measurements between literature
    spectra (HETDEX in this case) and observed spectra from SALT.

    SLIT is handed to flux_calibrated_specs to collect the observed spectra. RA and Dec (degrees) and HETDEX_path are handed
    to read_src to find the literature spectrum; if HETDEX_path is empty/None the "HETDEX_path" entry returned by
    flux_calibrated_specs is used instead. Observed fluxes are multiplied by mult_bias, the multiplicative bias derived
    from literature comparison, before anything is measured.

    Produces a three panel figure (continua, individual spectra, combined spectrum, each against HETDEX), prints the median
    continuum levels in units of 1e-16, and saves the figure as continuum_comparison_SLIT{SLIT} inside dir_mask.

    Returns the HETDEX continuum, the continuum of the combined observed spectrum, and the list of continua of the
    individual observed spectra.
    '''

    # finds the paths, specs, and other important information from flux_calibrated_specs and read_src
    full_settings,full_settings1D,HETDEX_range_settings, combined_spec1D, color, path_dict = flux_calibrated_specs(SLIT)
    # honour the caller's HETDEX_path; fall back to the configured one only when nothing usable was passed
    hetdex_dir = HETDEX_path if HETDEX_path else path_dict["HETDEX_path"]
    source_table, hetdex_hdu, spec, spec_err, wave_rect, source_coords, coord, sel_match, source_id, name = read_src(RA,Dec,hetdex_dir,version = 'v3.2')
    HETDEX_data = spec[sel_match][0]*1e-17

    # get plots ready; every panel is labelled explicitly rather than through pyplot's "current axes"
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True,figsize=(20,30))
    for panel in (ax1, ax2, ax3):
        panel.set_xlabel('Wavelength (Å)')
        panel.set_ylabel('Flux (erg/cm$^2$/s/Å)')
    ax1.set_title('Observed Continua vs. HETDEX Continuum')
    ax2.set_title('Constituent Spectra vs. HETDEX Spectrum')
    ax3.set_title('Combined Spectrum vs. HETDEX Spectrum')
    ax2.plot(wave_rect, HETDEX_data, label ='HETDEX Spectrum')
    ax3.plot(wave_rect, HETDEX_data, label ='HETDEX Spectrum')

    # make a Spectrum1D object and fit continua for literature spectra
    HETDEX_Spec = Spectrum1D(spectral_axis = wave_rect, flux = HETDEX_data*flux_units)
    HETDEX_continuum_fit = fit_continuum(HETDEX_Spec)
    HETDEX_continuum = HETDEX_continuum_fit(wave_rect)
    ax1.plot(wave_rect, HETDEX_continuum, label = 'HETDEX Continuum')

    # run through each individual spectrum that lies in the literature spectrum range (in this case, blue and green
    # settings only), make Spectrum1D objects for each, measure continua, and plot.
    sci_continua = []
    print('Green --> Blue')
    for hdu in HETDEX_range_settings:
        sci_data = hdu[0].data*mult_bias
        waves=WCS2World(hdu)*wave_units
        # we use the minimum and maximum values of the overlap wavelengths and make a spectral region from them.
        spectral_region = SpectralRegion(max(wave_rect[0], waves[0]), min(wave_rect[-1], waves[-1]))
        sci_signal_spec1d, sci_signal_nocont_spec1d, sci_continuum, sci_new_y = measure_continuum(waves, sci_data, order=6,window=100, spectral_region = spectral_region)
        ax1.plot(waves, sci_continuum, label = 'Observed Continuum')
        ax2.plot(waves, sci_data, alpha = 0.3, label = 'Individual Spectrum')

        # HETDEX continuum fitted and evaluated over the same overlap region, so the two medians are comparable
        HETDEX_continuum_fit_range = fit_continuum(HETDEX_Spec, window = spectral_region)
        spec_range = np.arange(spectral_region.lower.value, spectral_region.upper.value)*wave_units
        HETDEX_continuum_range = HETDEX_continuum_fit_range(spec_range)
        print(f'HETDEX = {np.round(np.median(HETDEX_continuum_range)/1e-16,8)}',f'Observed = {np.round(np.median(sci_continuum)/1e-16,8)}')
        sci_continua.append(sci_continuum)

    # does the same thing as the for loop above, but for the reprojected spectrum
    sci_data_comb = combined_spec1D[0].data*mult_bias
    waves_comb = combined_spec1D[1].data*wave_units
    spectral_region_comb = SpectralRegion(max(wave_rect[0], waves_comb[0]),min(wave_rect[-1], waves_comb[-1]))
    sci_signal_spec1d_comb, sci_signal_nocont_spec1d_comb, sci_continuum_comb, sci_new_y_comb =measure_continuum(waves_comb, sci_data_comb, order=6,window=100, spectral_region=spectral_region_comb)
    HETDEX_continuum_fit_range_comb = fit_continuum(HETDEX_Spec, window = spectral_region_comb)
    # NOTE(review): unlike the loop above, this is evaluated on the full HETDEX wavelength array rather than only
    # on the overlap region -- confirm that is intended.
    HETDEX_continuum_range_comb = HETDEX_continuum_fit_range_comb(wave_rect)
    print(f'HETDEX Combined = {np.round(np.median(HETDEX_continuum_range_comb)/1e-16,8)}',f'Observed Combined = {np.round(np.median(sci_continuum_comb)/1e-16,8)}')
    ax1.plot(waves_comb, sci_continuum_comb, label = 'Observed Reprojected Continuum')
    ax3.plot(waves_comb, sci_data_comb, label = 'Observed Reprojected Spectrum',alpha = 0.5, color = 'purple')
    ax1.set_ylim(ax2.get_ylim())
    ax3.set_ylim(ax2.get_ylim())
    # the pyplot-level limits below act on pyplot's current axes, as they always have here
    plt.xlim(wave_rect[0].value-200, wave_rect[-1].value+200)
    ax1.legend(loc = 'upper left')
    ax2.legend(loc = 'upper left')
    ax3.legend(loc = 'upper left')
    #plt.xlim(4600,5250)
    plt.ylim(-1e-16,4e-16)
    # savefig always replaces an existing file and has no `overwrite` keyword, so none is passed
    plt.savefig(path_dict["dir_mask"]+"/"+f'continuum_comparison_SLIT{SLIT}', facecolor='w', transparent=False)

    return HETDEX_continuum, sci_continuum_comb, sci_continua




    


## - Full-Optical Spectroscopic Combination 

In [None]:
def _mean_over_settings(stack):
    '''
    NaN-aware mean of stack along axis 0. Positions where every entry is NaN, or where stack has no entries along
    axis 0 at all, come back as NaN.
    '''
    valid = ~np.isnan(stack)
    counts = valid.sum(axis=0)
    totals = np.where(valid, stack, 0.0).sum(axis=0)
    averaged = np.full(stack.shape[1:], np.nan)
    # divide only where at least one setting contributed; everything else keeps its NaN
    np.divide(totals, counts, out=averaged, where=counts > 0)
    return averaged


def resolver(k, SLIT, favoring_threshold=10500):
    '''
    Given a target (slit) and a resolution decreasing parameter k, reprojects all observed settings onto an exponential
    wavelength array with decreasing resolution at higher wavelengths.

    The exponential wavelength array is A*(exp(L/k)+1) for L = 1..numbins, built from the minimum and maximum wavelength
    over all settings. Each setting is linearly interpolated onto this array, with NaN wherever the setting has no data
    (zeros in the input data are treated as missing and are replaced by NaN in place). The settings are then combined
    pixel by pixel: settings whose maximum resolution per pixel (wavelength/CDELT1) exceeds favoring_threshold are averaged
    together, and wherever none of them has data the average of the settings below favoring_threshold is used instead.

    Side effects: prints diagnostics, draws two figures, saves resolution_per_pixel_{SLIT}.png and the combined 2D FITS
    file in dir_mask, and calls extract1D to write the combined 1D FITS file.

    Returns L, true_wavs, A, numbins, full_settings, combined_data, hdulist.
    '''

    # finds the paths, specs, and other important information from flux_calibrated_specs and get_paths
    full_settings,full_settings1D,HETDEX_range_settings, combined_spec, color_list, paths_dict = flux_calibrated_specs(SLIT)
    paths = get_paths(data_dir, config_name)

    # here we find total pixels for our slit size, which is used to find central resolution values. these can be changed to cater to other programs.
    slit_width = 2*u.arcsec
    arcsec_per_pix = 0.1267*u.arcsec*(u.pix**-1)
    tot_pix = (slit_width/arcsec_per_pix)/2

    # wavelength array of every setting, computed once and reused below
    wavelength_arrays = [WCS2World(hdu) for hdu in full_settings]

    # find the minimum and maximum wavelengths
    min_wav = min(wav[0] for wav in wavelength_arrays)
    max_wav = max(wav[-1] for wav in wavelength_arrays)
    print(max_wav)
    print(min_wav)
    lambda_start = min_wav
    lambda_end = max_wav

    A=lambda_start/(1+np.exp(1/k))
    numbins = int(k*np.log((lambda_end/A)-1))
    L=np.linspace(1, numbins, numbins)

    # calculate true wavelengths and find resolution per pixel values (undefined for the first pixel, hence NaN)
    true_wavs = A*(np.exp(L/k)+1)
    resolution_pix = np.full(L.shape, np.nan)
    resolution_pix[1:] = true_wavs[1:]/np.diff(true_wavs)

    # report the central wavelength and central resolution of each setting, and get its resolution per pixel
    resolution_settings = []
    for hdu, wavelength_array in zip(full_settings, wavelength_arrays):
        cent_val = np.median(wavelength_array)
        print(cent_val)
        wav_slope = hdu[0].header['CDELT1']*wave_units
        per_res_element = wav_slope*tot_pix
        R_cent = cent_val*wave_units/per_res_element
        print(R_cent)
        resolution_settings.append(wavelength_array/hdu[0].header['CDELT1'])

    # interpolate every setting onto the exponential wavelength array
    data_list = []
    for hdu, wavelength_array in zip(full_settings, wavelength_arrays):
        dat = hdu[0].data
        # zeros mark missing data; note this edits the HDU's array in place, so full_settings is returned with NaNs
        dat[dat==0]=np.nan
        interp_func = interp1d(wavelength_array,dat,bounds_error=False, fill_value=np.nan)
        data_list.append(interp_func(true_wavs))

    # all settings must have the same spatial size to be stacked. If they do not, use the -F key in RSSMOSPipeline
    # and make sure slit size remains consistent.
    print(*(arr.shape for arr in data_list[:3]))
    stacked = np.stack(data_list, axis=0)
    print(stacked.shape)

    # favor higher resolution data based on favoring_threshold (default set so green is favored over red):
    # average the high resolution settings, and fall back to the average of the low resolution settings
    # only at pixels where no high resolution setting has data.
    peak_resolution = np.array([np.max(res) for res in resolution_settings])
    high_res_avg = _mean_over_settings(stacked[peak_resolution > favoring_threshold])
    low_res_avg = _mean_over_settings(stacked[peak_resolution < favoring_threshold])
    combined_data = np.where(np.isnan(high_res_avg), low_res_avg, high_res_avg)

    # plot the wavelength vs L array.
    plt.figure(figsize=(20,10))
    plt.plot(L, true_wavs)
    plt.xlabel('L')
    plt.ylabel('Wavelength')
    plt.title('Wavelength vs. L')

    # plot the resolution vs wavelength for each of the settings and for the projected wavelength
    plt.figure(figsize=(20,10))
    colors_used = []
    for arr, res, color in zip(wavelength_arrays[::-1], resolution_settings[::-1], color_list[::-1]):
        # one legend entry per color, however many settings share it
        if color in colors_used:
            plt.plot(arr, res, color = color)
        else:
            plt.plot(arr, res, color = color, label = f'{color} settings')
        colors_used.append(color)
    plt.ylabel('Resolution Per Pixel')
    plt.xlabel('Wavelength')
    plt.title('Resolution Per Pixel vs. Wavelength')
    plt.plot(true_wavs, resolution_pix, label = 'Reprojected', color = 'purple')
    plt.axhline(favoring_threshold, color = 'black', linestyle = 'dashed', label = 'Threshold')
    plt.legend()
    # savefig always replaces an existing file and has no `overwrite` keyword, so none is passed
    plt.savefig(paths["dir_mask"]+f"/resolution_per_pixel_{SLIT}.png", facecolor='w', transparent=False)

    print(f'Max Resolution = {np.max(resolution_pix[1:])}')
    print(f'Min Resolution = {np.min(resolution_pix[1:])}')
    print(f'Numbins = {numbins}')
    print(f'A = {A}')
    print(f'True Wavelengths: {true_wavs}')

    # save hdu and associated non-linear wavelength array; the header of the last setting is used as the template
    new_image_header = full_settings[-1][0].header.copy()
    new_image_header['CRVAL1'] = L[0]
    new_image_header['CRPIX1'] = 1
    new_image_header['CDELT1'] = np.diff(L)[0]
    new_image_header['CD1_1'] = np.diff(L)[0]
    new_image_header['CUNIT1'] = 'Custom'
    new_image_header['A'] = A
    new_image_header['k'] = k
    new_image_header['COMMENT1'] = 'To get non-linear wavelengths, perform following: A*(exp(WCS/k)+1)'
    new_image_header['COMMENT2'] = 'Alternatively, wavelengths can be accessed in the second HDU'
    new_image_hdu = fits.PrimaryHDU(combined_data, header=new_image_header)
    new_wcs_hdu = fits.ImageHDU(true_wavs, name='WAVELENGTHS')
    new_wcs_hdu.header['EXTNAME'] = 'WAVELENGTHS'
    new_wcs_hdu.header['CUNIT'] = 'Angstrom'
    new_wcs_hdu.header['COMMENT'] = 'Non-linear wavelength array'
    hdulist = fits.HDUList([new_image_hdu,new_wcs_hdu])
    path2D = paths["dir_mask"]+f'/FLUXcal_2D_noSky_{paths["mask_name"]}_{paths["mask_id"]}_{SLIT}_comb.fits'
    hdulist.writeto(path2D, overwrite = True)
    path1D = paths["dir_mask"]+f'/FLUXcal_1D_noSky_{paths["mask_name"]}_{paths["mask_id"]}_{SLIT}_comb.fits'
    extract1D(hdulist, path1D)

    return L, true_wavs, A, numbins, full_settings, combined_data,hdulist

## - Line Inspector

In [None]:
def line_inspector(hdu_dir,SLIT, emline, line_range, obj_name, mult_bias):
    '''
    Given the directory of a target, slit number, emission line wavelength range and type, will plot the 1D spectrum and
    measure the integrated line flux using Specutils line_flux.

    hdu_dir is prepended as-is to the file pattern FLUXcal_1D*SLIT{SLIT}*.fits, so it should end with a path separator.
    emline is a key of the rest wavelength table (rest_wavelengths.txt in data_dir) and line_range is a 'start-end'
    string of wavelengths bracketing the line. obj_name is used as the plot title and mult_bias multiplies the fluxes.

    Prints the redshift factors, central wavelengths and line fluxes (integrated and Gaussian), plots the
    continuum-subtracted spectrum with the Gaussian fit, and returns the dictionary of rest wavelengths.

    Raises IndexError if no 1D file for this slit is found in hdu_dir.
    '''
    import re  # only needed here; kept local so the notebook's import cell is untouched

    # open 1D fits file. The glob alone would also accept longer slit numbers (SLIT7 matches SLIT70),
    # so only keep files where the slit number is not followed by another digit.
    slit_tag = re.compile(rf'SLIT{re.escape(str(SLIT))}(?!\d)')
    candidates = [found for found in glob.glob(f'{hdu_dir}FLUXcal_1D*SLIT{SLIT}*.fits')
                  if slit_tag.search(os.path.basename(found))]
    if not candidates:
        raise IndexError(f'No FLUXcal_1D file for SLIT{SLIT} found in {hdu_dir}')

    # everything needed is copied out of the file inside the with-block so the file can be closed afterwards
    with fits.open(candidates[0]) as hdu:
        signal = hdu[0].data*mult_bias
        if len(hdu) ==2:
            # reprojected spectra carry their non-linear wavelength array in the second HDU
            wavelengths = np.array(hdu[1].data)
        else:
            wavelengths = WCS2World(hdu)

    plt.figure(figsize=(20,10))
    plt.xlabel('Wavelength')
    plt.ylabel('Flux (erg/cm$^2$/s/Å)')
    plt.title(obj_name)

    # open dictionary of emission line rest wavelengths
    path_to_rest_wavs = data_dir+'rest_wavelengths.txt'
    emline_dict = {}
    with open(path_to_rest_wavs) as emline_file:
        next(emline_file)  # first line is skipped (treated as a header)
        for line in emline_file:
            if not line.strip():
                continue  # tolerate blank lines, e.g. at the end of the file
            emline_type, rest_wav = line.split()
            emline_dict[str(emline_type)] = float(rest_wav)
    em = emline_dict[emline]
    print(em)

    line_start,line_end = [float(x) for x in line_range.split('-')]
    continuum_region = SpectralRegion(wavelengths[0]*wave_units, wavelengths[-1]*wave_units)

    # find em line region and indices, make Spectrum1D objects both before and after continuum subtraction of spectra.
    spec1D,spec1D_nocont,continuum, new_y = measure_continuum(wavelengths*wave_units, signal, order = 6, window = 100, spectral_region = continuum_region)
    emline_indices = np.where((spec1D_nocont.spectral_axis >= line_start*wave_units) & (spec1D_nocont.spectral_axis <= line_end*wave_units))[0]
    emline_fluxes = spec1D_nocont.flux[emline_indices]
    # wavelength of the brightest pixel in the range, used as the starting guess for the Gaussian mean
    emline_cent = spec1D_nocont.spectral_axis[emline_indices[np.argmax(emline_fluxes)]]

    spectral_region = SpectralRegion(line_start*wave_units, line_end*wave_units)
    center = centroid(spec1D_nocont, spectral_region)

    # get line flux of emission line, with and without the continuum
    line_flux_cont = line_flux(spec1D, regions=spectral_region)
    line_flux_true = line_flux(spec1D_nocont, regions=spectral_region)

    plt.xlim(line_start-150,line_end+150)
    plt.axvline(line_start, label = 'emission line range')
    plt.axvline(line_end)

    # gaussian fitting of emission line
    g_init = Gaussian1D(amplitude=np.max(emline_fluxes), mean=emline_cent, stddev=1*wave_units, fixed={'mean': False})
    fit_g = fit_lines(spec1D_nocont,g_init, window=(line_start*wave_units, line_end*wave_units))
    gauss = fit_g(spec1D_nocont.spectral_axis)
    amplitude = fit_g.amplitude.value
    stddev = fit_g.stddev.value
    # area under a Gaussian: amplitude * sqrt(2*pi) * sigma
    fit_g_area = amplitude * np.sqrt(2 * np.pi) * stddev

    print(f'Gaussian 1+z = {np.round(fit_g.mean/(em*wave_units),5)}')
    print(f'Centroid 1+z = {np.round(center/(em*wave_units),5)}')

    print(f'Central Gaussian Wavelength = {np.round(fit_g.mean.value, 1)}')
    print(f'Central Centroid Wavelength = {np.round(center, 1)}')
    print(line_flux_cont/1e-16)
    print(f'Line Flux = {np.round(line_flux_true/1e-16, 8)}, Error ={np.round(line_flux_true.uncertainty/1e-16,8)} ')
    print(f'Gaussian Line Flux: {np.round(fit_g_area/1e-16,8)}')
    plt.plot(spec1D_nocont.spectral_axis,gauss, color = 'g')
    plt.plot(spec1D_nocont.spectral_axis,spec1D_nocont.flux, color = 'b', alpha =0.8)
    plt.ylim(-0.5e-16, 2.25e-16)
    plt.axhline(0, color = 'black')
    plt.legend()
    return emline_dict

# Run Below

In [None]:
# Multiplicative bias correction applied to the observed fluxes further down; set it to 1 for no correction.
mult_bias = 10 ** 0.286
# Paths used by the cells that run the pipeline below.
path_dict = get_paths(data_dir, config_name)

## - Flux Calibration

In [None]:
# declare dither, grating (color), slit, max throughput wavelength, and desired paths for SDSS star spectra and science spectra.
dith = "-"
color = path_dict["dir_green"]
SLIT = 12
max_throughput_wav = 5000

# pick the sub-path that belongs to the chosen dither. Anything other than "+" or "-" is rejected here:
# otherwise `path` would be left unset, or silently keep a value assigned by another cell of this notebook.
if dith == "+":
    path = path_dict["u_path"]
elif dith == "-":
    path = path_dict["l_path"]
else:
    raise ValueError(f'dith must be "+" or "-", got {dith!r}')

SDSS_star_path = f'{color}/Dither{dith}/SDSS_Stars/'
sci_spec_path =f'{color}{path}2D_noSky_{path_dict["mask_name"]}_{path_dict["mask_id"]}_SLIT{str(SLIT)}.fits'
# slits holding the SDSS stars, and the order passed on to flux_calibration
starSLITS = [2,8]
Order=2

In [None]:
# Flux-calibrate the science spectrum declared in the cell above against the SDSS star spectra.
(cal_plot_data, uncal_plot_data, obs_sci_wavs, hdu, newFITS,
 sci_data, star_diagnostics, sensitivity_function) = flux_calibration(
    sci_spec_path, SDSS_star_path, Order, starSLITS, max_throughput_wav)

## - Reprojection

In [None]:
# Choose the slit to reproject, then combine all of its settings onto the exponential wavelength array.
slit_num = 12
(L, true_wavs, A, numbins,
 full_settings, combined_data, hdulist) = resolver(k=9000, SLIT='SLIT' + str(slit_num))

## - Line Inspection

In [None]:
# Inputs for the line inspection: directory holding the 1D spectrum, emission line type and wavelength range of
# interest, object name (used as the plot title), and the multiplicative bias to correct for.

# emission line and target
obj_name = 'HETDEX J100041.45+021331.8'
emline = 'SII6731'
line_range = '5023.5-5037.5'

# directory to read from -- reprojected spectrum:
path = path_dict["dir_mask"]+'/'
# ... or, to check an individual setting instead:
#path = path_dict["dir_red"]+'/Dither+/'

emline_dict = line_inspector(
    hdu_dir=path,
    SLIT=7,
    emline=emline,
    line_range=line_range,
    obj_name=obj_name,
    mult_bias=mult_bias,
)

## - Continuum Analysis

In [None]:
# Slit number and sky coordinates of the target whose continuum is measured and compared to the literature spectrum.
RA = 150.241455078125
Dec = 2.2567691802978516
SLIT = 7

HETDEX_continuum, sci_continuum_comb, sci_continua = continuum_comparison(
    SLIT=SLIT,
    RA=RA,
    Dec=Dec,
    HETDEX_path=path_dict["HETDEX_path"],
    mult_bias=mult_bias,
)