In [None]:
#General imports
import pandas as pd
from glob import glob
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
from sklearn.cluster import DBSCAN
from astropy.stats import SigmaClip
from photutils.background import SExtractorBackground
import os
import time
from datetime import datetime
import sys

#BANZAI pipeline imports
import requests
from banzai.calibrations import make_master_calibrations
from banzai_nres import settings
from banzai import dbs
from banzai.utils.stage_utils import run_pipeline_stages
from banzai.logs import set_log_level
from banzai.context import Context
import logging
from banzai_nres.frames import NRESFrameFactory
from banzai.data import DataProduct
from banzai_nres.wavelength import IdentifyFeatures
from banzai_nres.flats import FlatLoader
from banzai_nres.wavelength import ArcLoader, LineListLoader, WavelengthCalibrate
from banzai_nres.qc.qc_wavelength import AssessWavelengthSolution

#SQAT imports
import matplotlib.dates as mdates
from photutils.segmentation import make_2dgaussian_kernel
from photutils.background import Background2D, MedianBackground
from photutils.utils import calc_total_error
from photutils.segmentation import detect_sources
from photutils.segmentation import SourceCatalog
from photutils.segmentation import deblend_sources
from astropy.convolution import convolve
from photutils.background import MADStdBackgroundRMS
from banzai.utils.stats import robust_standard_deviation


# SQAT (Spectrum Quality Analysis Tool)

This class should provide everything you need to analyse NRES spectra. Example usage is shown in the cells below the class definition.

## NOTE:
To run a developer version of the BANZAI NRES pipeline (i.e. whatever version of the code you are currently working on), install the pipeline in editable mode like so:
```
pip install -e /path/to/banzai-nres
```

Because the install is editable, switching branches in git switches the version of the code that SQAT runs.


Plot customizability is still a work in progress; feel free to edit the plotting methods.

In [None]:
class SQAT():
    
    def __init__(self, files=None):
        if files:
            self.files = sorted(files)
            self.lampflats_path = None
            self.doubles_path = None
            self.context_path = None
            
    def download_data(self, start_date=None, end_date=None, site=None, doubles_path=None, lampflats_path=None, stacked=True):
        """Download DOUBLE frames and their lampflats from the LCO archive.

        Parameters
        ----------
        start_date, end_date : str
            Date range to query, formatted ``YYYY-mm-dd``.
        site : str
            One of ``'lsc'``, ``'tlv'``, ``'elp'``.
        doubles_path : str
            Existing directory the DOUBLE frames are written to.
        lampflats_path : str
            Existing directory the associated lampflats are written to.
        stacked : bool, optional
            If true, query basename ``double-bin1x1-110``; otherwise query
            basename ``a92``.

        Notes
        -----
        All arguments are stored on the instance. If any validation fails a
        message is printed and nothing is downloaded.
        """
        self.stacked = stacked
        self.site = site
        self.start_date = start_date
        self.end_date = end_date
        # Normalise each path on its own so a trailing slash never matters.
        self.doubles_path = doubles_path.rstrip('/') if doubles_path else doubles_path
        self.lampflats_path = lampflats_path.rstrip('/') if lampflats_path else lampflats_path

        inputs_ok = True
        #check format of dates
        if start_date and end_date:
            try:
                time.strptime(self.start_date, '%Y-%m-%d')
                time.strptime(self.end_date, '%Y-%m-%d')
            except ValueError:
                print('Please format input dates as YYYY-mm-dd')
                inputs_ok = False
        #check that site is part of list of sites
        sites = ['lsc', 'tlv', 'elp']
        if site and site not in sites:
            print(f'Please choose a site from the following: {sites}')
            inputs_ok = False
        if not (self.start_date and self.end_date and self.site and self.doubles_path and self.lampflats_path):
            print('Please Provide the date range of interest (YYYY-MM-DD), site name (eg. "lsc"), and paths to write files')
            inputs_ok = False
        if not inputs_ok:
            # Do not hit the archive with inputs already known to be bad.
            return

        # The stacked and unstacked queries differ only in the basename filter.
        basename = 'double-bin1x1-110' if stacked else 'a92'
        query = (f'https://archive-api.lco.global/frames/?reduction_level=92&site_id={self.site}'
                 f'&configuration_type=DOUBLE&basename={basename}'
                 f'&start={self.start_date}&end={self.end_date}&public=true')
        # NOTE(review): only the 'results' of the first response are used; if the
        # archive paginates, later pages are not fetched -- confirm.
        for rec in requests.get(query).json()['results']:
            with open(f'{self.doubles_path}/{rec["filename"]}', 'wb') as f:
                f.write(requests.get(rec['url']).content)

        #get associated lampflats
        for file in glob(f'{self.doubles_path}/*.fz', recursive=True):
            with fits.open(file) as double_frame:
                flat_name = double_frame['SPECTRUM'].header.get('L1IDFLAT')
            if not flat_name:
                # No flat recorded in the header; nothing to fetch for this frame.
                continue
            super_flat = flat_name[:-8] # gives the basename of the lampflat
            matches = requests.get(f'https://archive-api.lco.global/frames/?basename_exact={super_flat}').json()['results']
            if not matches:
                print(f'No archive record found for lampflat {super_flat}')
                continue
            archive_record = matches[0]
            with open(f'{self.lampflats_path}/{archive_record["filename"]}', 'wb') as f:
                f.write(requests.get(archive_record['url']).content)
    
    def setup_pipeline(self, processed_path, db_path=None):
        """Create the BANZAI database and build the pipeline context.

        Parameters
        ----------
        processed_path : str
            Directory stored as ``processed_path`` in the context; ``~`` is
            expanded.
        db_path : str, optional
            Directory in which ``test.db`` is created. Defaults to the current
            working directory.

        Notes
        -----
        Sets the ``DB_ADDRESS``, ``CONFIGDB_URL`` and
        ``OPENTSDB_PYTHON_METRICS_TEST_MODE`` environment variables, mutates
        ``banzai_nres.settings``, and stores ``self.context`` and
        ``self.frame_factory``.
        """
        import subprocess

        print('Please do not run this more than one time. Database is set up in your current directory if not provided')
        set_log_level('DEBUG')

        # Strip trailing slashes *before* building the address, and set the
        # address whether or not the caller's path ended with a slash.
        self.db_path = os.path.expanduser(db_path).rstrip('/') if db_path else db_path
        if db_path:
            os.environ['DB_ADDRESS'] = f'sqlite:///{self.db_path}/test.db'
        else:
            os.environ['DB_ADDRESS'] = 'sqlite:///test.db'
        os.environ['CONFIGDB_URL'] = 'http://configdb.lco.gtn/sites'
        os.environ['OPENTSDB_PYTHON_METRICS_TEST_MODE'] = 'True'
        db_address = os.environ['DB_ADDRESS']
        # Argument lists avoid shell interpretation of the user-supplied path.
        # check=False ignores a non-zero exit status, as os.system did.
        subprocess.run(['banzai_nres_create_db', f'--db-address={db_address}'], check=False)

        settings.processed_path= os.path.join(os.getcwd(), 'test_data')
        settings.fpack=True
        settings.db_address = db_address
        settings.reduction_level = 92

        # set up the context object.
        import banzai.main
        context = banzai.main.parse_args(settings, parse_system_args=False)
        context = vars(context)
        context['no_bpm'] = True
        context['processed_path'] = os.path.expanduser(processed_path)
        context['post_to_archive'] = False
        context['no_file_cache'] = False
        self.context = Context(context)

        # initialize the DB with some instruments from ConfigDB
        # (the database was already created above, so it is not created again)
        subprocess.run(['banzai_update_db', f'--db-address={db_address}',
                        f'--configdb-address={os.environ["CONFIGDB_URL"]}'], check=False)

        # wow, after all that you can actually open an image!
        self.frame_factory = NRESFrameFactory()
    
    def run_pipeline(self, lampflats_path=None, doubles_path=None):
        """Register the lampflats and run the wavelength stages on every double.

        Parameters
        ----------
        lampflats_path, doubles_path : str, optional
            Directories containing the ``*.fz`` lampflats and doubles. When
            omitted, the values already stored on the instance are used. An
            explicit argument takes precedence.

        Raises
        ------
        ValueError
            If no lampflat or doubles directory is available.
        RuntimeError
            If ``setup_pipeline`` has not been run yet.
        """
        # getattr covers both "attribute never set" and "attribute set to None".
        lampflats_path = lampflats_path or getattr(self, 'lampflats_path', None)
        if not lampflats_path:
            raise ValueError('Please provide a path to your lampflats')
        self.lampflats_path = lampflats_path

        doubles_path = doubles_path or getattr(self, 'doubles_path', None)
        if not doubles_path:
            raise ValueError('Please provide a path to your doubles')
        self.doubles_path = doubles_path

        if getattr(self, 'context', None) is None:
            raise RuntimeError('Please run the .setup_pipeline method before running the pipeline')

        #Load flats in
        frame_factory = NRESFrameFactory()
        for image_path in sorted(glob(f'{self.lampflats_path}/*.fz', recursive=True)):
            cal_image = frame_factory.open({'path': image_path}, self.context)
            record = cal_image.to_db_record(DataProduct(None, filename=os.path.basename(image_path),
                                                        filepath=os.path.dirname(image_path)))
            dbs.save_calibration_info(record, os.environ['DB_ADDRESS'])

        # Stages are applied to each double in the order listed here.
        stages = [
            FlatLoader(self.context), # should pull info from the flats on disk and add them to the image
            ArcLoader(self.context),
            LineListLoader(self.context), # should pull info from the line list and add them to the image
            IdentifyFeatures(self.context), # Create the features table
            WavelengthCalibrate(self.context), # should now have all the data it needs!
            AssessWavelengthSolution(self.context),
        ]

        for path in sorted(glob(f'{self.doubles_path}/*.fz', recursive=True)):
            image = frame_factory.open({'path': path}, self.context)
            for stage in stages:
                image = stage.do_stage(image)
            image.write(self.context)
           
    def get_data(self, path):
        """Locate the ``*.fz`` files under ``path`` and read their headers.

        Parameters
        ----------
        path : str
            Directory (glob wildcards allowed) holding the files; ``~`` is
            expanded and trailing slashes are ignored.

        Sets
        ----
        self.stacked_files : sorted list of matching file paths.
        self.header : the ``SPECTRUM`` extension header of each file.
        self.times : last four characters of each header's ``DAY-OBS``.
        """
        path = os.path.expanduser(path).rstrip('/')
        # Sort so that headers and times follow a stable file order rather than
        # whatever order the filesystem happens to return.
        self.stacked_files = sorted(glob(f'{path}/*.fz', recursive=True))
        self.header = []
        for f in self.stacked_files:
            # Close each file once its header has been copied out.
            with fits.open(f) as hdul:
                self.header.append(hdul['SPECTRUM'].header.copy())
        self.times = [h['DAY-OBS'][-4:] for h in self.header]
    
    def get_sources_old(self, max_dist=1, min_cluster_size=20):
        """Group features across files into per-source tables (old table layout).

        Reads the ``FEATURES`` extension of every file in
        ``self.stacked_files``, tags each row with its file's entry in
        ``self.times``, and clusters all rows with DBSCAN.

        Parameters
        ----------
        max_dist : float, optional
            DBSCAN ``eps``.
        min_cluster_size : int, optional
            DBSCAN ``min_samples``.

        Sets
        ----
        self.features : list of per-file feature tables.
        self.sources : list of tables, one per DBSCAN label. Index 0 is the
            noise table (label -1). Every other table has one NaN-filled row
            for each date on which the source is missing and is sorted by
            ``date-obs``.
        """
        self.features = []
        for f in self.stacked_files:
            with fits.open(f) as hdul:
                self.features.append(pd.DataFrame(hdul['FEATURES'].data))
        #First concatenate tables
        for i, f in enumerate(self.features):
            f.insert(14, 'date-obs', [self.times[i]]*len(f))
        feature_stack = pd.concat(self.features)

        #Then collate using DBSCAN
        cluster_columns = ['id', 'xcentroid', 'ycentroid', 'wavelength', 'fiber', 'order']
        clustering = DBSCAN(eps = max_dist, min_samples = min_cluster_size).fit_predict(feature_stack.loc[:, cluster_columns].to_numpy())
        feature_stack.insert(15, 'labels', clustering)

        # groupby iterates labels in ascending order, so no explicit sort is needed.
        grouped = {label: grp.reset_index() for label, grp in feature_stack.groupby('labels')}
        # DBSCAN labels unclustered points -1. The loop below (and the plotting
        # methods) skip index 0 as the noise table, so keep that slot even when
        # no noise was found.
        noise = grouped.pop(-1, None)
        if noise is None:
            noise = feature_stack.iloc[:0].reset_index()
        sources = [noise] + list(grouped.values())

        #Fill in missing dates with nans
        for s in sources[1:]:
            missing = set(self.times).difference(s['date-obs'])
            for m in missing:
                # Build the row by column name so it fits any table width.
                s.loc[len(s)] = [m if col == 'date-obs' else np.nan for col in s.columns]
            # Sort once, and also when nothing was missing, so every table is
            # in date order.
            s.sort_values('date-obs', inplace=True)
            s.drop(['labels', 'index'], inplace=True, axis=1)
            # NOTE(review): the original called s.rename(columns={'id':'line'})
            # without keeping the result, so the column is still named 'id'.
            # Left unrenamed here to preserve the existing table schema.

        self.sources = sources
    
    def get_sources_new(self, max_dist=1, min_cluster_size=20):
        """Group features across files into per-source tables (new table layout).

        Reads the ``FEATURES`` extension of every file in
        ``self.stacked_files``, tags each row with its file's entry in
        ``self.times``, and clusters all rows with DBSCAN.

        Parameters
        ----------
        max_dist : float, optional
            DBSCAN ``eps``.
        min_cluster_size : int, optional
            DBSCAN ``min_samples``.

        Sets
        ----
        self.features : list of per-file feature tables.
        self.sources : list of tables, one per DBSCAN label. Index 0 is the
            noise table (label -1). Every other table has one NaN-filled row
            for each date on which the source is missing and is sorted by
            ``date-obs``.
        """
        self.features = []
        for f in self.stacked_files:
            with fits.open(f) as hdul:
                self.features.append(pd.DataFrame(hdul['FEATURES'].data))
        #First concatenate tables
        for i, f in enumerate(self.features):
            f.insert(12, 'date-obs', [self.times[i]]*len(f))
        feature_stack = pd.concat(self.features)

        #Then collate using DBSCAN
        cluster_columns = ['id', 'xcentroid', 'ycentroid', 'wavelength', 'fiber', 'order']
        clustering = DBSCAN(eps = max_dist, min_samples = min_cluster_size).fit_predict(feature_stack.loc[:, cluster_columns].to_numpy())
        feature_stack.insert(12, 'labels', clustering)

        # groupby iterates labels in ascending order, so no explicit sort is needed.
        grouped = {label: grp.reset_index() for label, grp in feature_stack.groupby('labels')}
        # DBSCAN labels unclustered points -1. The loop below (and the plotting
        # methods) skip index 0 as the noise table, so keep that slot even when
        # no noise was found.
        noise = grouped.pop(-1, None)
        if noise is None:
            noise = feature_stack.iloc[:0].reset_index()
        sources = [noise] + list(grouped.values())

        #Fill in missing dates with nans
        for s in sources[1:]:
            missing = set(self.times).difference(s['date-obs'])
            for m in missing:
                # Build the row by column name so it fits any table width.
                s.loc[len(s)] = [m if col == 'date-obs' else np.nan for col in s.columns]
            # Sort once, and also when nothing was missing, so every table is
            # in date order.
            s.sort_values('date-obs', inplace=True)
            s.drop(['labels', 'index'], inplace=True, axis=1)
            # NOTE(review): the original called s.rename(columns={'id':'line'})
            # without keeping the result, so the column is still named 'id'.
            # Left unrenamed here to preserve the existing table schema.

        self.sources = sources
    
    def ds9_to_python_id(self, coords):
        new = []
        if coords==None:
            raise ValueError('Please Provide set of ds9 coordinates')
        else:
            pass
        try:
            for coord in coords:
                new.append([coord[0]-1, coord[1]-1])
        except TypeError:
            new = [coords[0]-1, coords[1]-1]
        return new

    def search(self, coords, tbl, r=3):
        #Search on coord at a time
        new_coords = self.ds9_to_python_id(coords)
        table = tbl.dropna()
        x_table = np.mean(table['xcentroid'])
        y_table = np.mean(table['ycentroid'])
        try:
            res = []
            for coord in new_coords:
                if (coord[0] - x_table)**2 + (coord[1] - y_table)**2 <= r**2:
                    res.append(tbl)
        except TypeError:
            if (new_coords[0] - x_table)**2 + (new_coords[1] - y_table)**2 <= r**2:
                res = tbl
        return res
    
    def extract(self):
        """Detect sources in every file with photutils and group them by position.

        For each file in ``self.stacked_files`` the ``SPECTRUM`` data is
        background-subtracted, smoothed, segmented and deblended. The resulting
        catalogues are clustered with DBSCAN on centroid and eccentricity.

        Sets
        ----
        self.features : per-file tables with ``date-obs`` plus the kept
            catalogue columns.
        self.sources : list of tables, one per DBSCAN label. Index 0 is the
            noise table (label -1). Every other table has one NaN-filled row
            for each date on which the source is missing and is sorted by
            ``date-obs``.
        """
        # The bad-pixel mask of the first file is applied to every image.
        with fits.open(self.stacked_files[0]) as hdul:
            bpm_mask = np.array(hdul['BPM'].data, dtype=bool)
        tables = []
        times = []
        # Load one image at a time instead of holding every image in memory.
        for path in self.stacked_files:
            with fits.open(path) as hdul:
                spectrum = hdul['SPECTRUM']
                data = np.array(spectrum.data)
                times.append(spectrum.header['DAY-OBS'][-4:])

            #Background estimation
            bkg = Background2D(data, (50, 50), filter_size=(3, 3)) #This is just an example not actual background estimation
            data -= bkg.background  # subtract the background
            data -= bkg.background_rms
            threshold = 1.5*bkg.background_rms

            #Image segmentation
            kernel = make_2dgaussian_kernel(5, size=3, mode='center')
            convolved_data = convolve(data, kernel)

            #Get total error
            err = calc_total_error(data, bkg.background_rms, 1)

            print('Making a segment map')
            print('Masking')
            segment_map = detect_sources(convolved_data, threshold, npixels = 5, connectivity = 4, mask = bpm_mask)

            print('Deblending')
            segment_map = deblend_sources(convolved_data, segment_map, npixels=5, nlevels=64, contrast=0.001, progress_bar=False)
            cat = SourceCatalog(data, segment_map, convolved_data=convolved_data, error=err)
            xerr = np.sqrt(cat.covar_sigx2)
            yerr = np.sqrt(cat.covar_sigy2)
            tbl = cat.to_table()
            tbl.add_columns([xerr, yerr], names = ['xcentroid_err', 'ycentroid_err'])
            tables.append(tbl.to_pandas())

        # Catalogue columns dropped by position; the columns used below
        # (xcentroid, ycentroid, eccentricity) must be among those kept.
        dropped = [0,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19]
        self.features = [t.drop(columns=t.columns[dropped]) for t in tables]
        for i, f in enumerate(self.features):
            f.insert(0, 'date-obs', [times[i]]*len(f))

        feature_stack = pd.concat(self.features)

        #Then collate using DBSCAN
        clustering = DBSCAN(eps = 1, min_samples = 20).fit_predict(feature_stack.loc[:, ['xcentroid', 'ycentroid', 'eccentricity']].to_numpy())
        feature_stack.insert(0, 'labels', clustering)

        # groupby iterates labels in ascending order, so no explicit sort is needed.
        grouped = {label: grp.reset_index() for label, grp in feature_stack.groupby('labels')}
        # DBSCAN labels unclustered points -1. Index 0 is skipped below as the
        # noise table, so keep that slot even when no noise was found.
        noise = grouped.pop(-1, None)
        if noise is None:
            noise = feature_stack.iloc[:0].reset_index()
        sources = [noise] + list(grouped.values())

        #Fill in missing dates with nans
        for s in sources[1:]:
            missing = set(times).difference(s['date-obs'])
            for m in missing:
                # Build the row by column name so it fits any table width.
                s.loc[len(s)] = [m if col == 'date-obs' else np.nan for col in s.columns]
            s.sort_values('date-obs', inplace=True)
            s.drop(['labels', 'index'], inplace=True, axis=1)
        self.sources = sources
    
    #-------------------PLOTTING--------------------------------
    # Applies the 'seaborn' matplotlib style globally when the class body is executed.
    # NOTE(review): Matplotlib 3.6 deprecated this style name in favour of
    # 'seaborn-v0_8' -- confirm against the installed Matplotlib version.
    plt.style.use('seaborn')
    
    def wavelength_std_plot(self, clim = None):
        """Show the per-pixel scatter of the wavelength solution across files.

        Stacks the ``WAVELENGTH`` extension of every file in
        ``self.stacked_files``, takes the standard deviation along the file
        axis, and shows it as an image followed by a histogram.

        Parameters
        ----------
        clim : sequence of two floats, optional
            Colour limits for the image. Defaults to the median of the
            standard-deviation image plus/minus three median absolute
            deviations.
        """
        wave = []
        for f in self.stacked_files:
            with fits.open(f) as hdul:
                wave.append(np.array(hdul['WAVELENGTH'].data))
        stdim = np.std(wave, axis = 0)
        if clim is None:
            # NaN-aware statistics so a few undefined pixels do not turn the
            # default colour limits into NaN.
            median = np.nanmedian(stdim)
            # The module is imported as ``stats`` in this file, not ``scipy``.
            std = stats.median_abs_deviation(stdim, axis=None, nan_policy='omit')
            clim = [median - 3*std, median+3*std]
        plt.imshow(stdim, origin='lower', cmap = 'inferno', clim=clim)
        plt.grid(visible=False)
        plt.colorbar()
        plt.show()

        plt.hist(stdim.flatten(), range=(0.0003, .015), bins = 100)
        plt.title('Noise Distribution')
        plt.xlabel('Standard Deviation (pixel)')
        plt.ylabel('Number of wavellengths')
        plt.show()
    
    def features_std(self):
        """Plot a histogram of each source's wavelength standard deviation.

        Index 0 of ``self.sources`` is skipped, and so are sources with fewer
        than two complete (NaN-free) rows.
        """
        std = []
        for tbl in self.sources[1:]:
            wavelength = tbl.dropna()['wavelength']
            if len(wavelength)>1:
                std.append(np.std(wavelength))
        plt.title('Standard Deviation of calculated wavelengths of features')
        plt.xlabel('Standard Deviation')
        plt.ylabel('Number of Features')
        # With fewer than 100 sources the bin count would be 0, which
        # plt.hist rejects; always use at least one bin.
        plt.hist(std, bins = max(1, len(self.sources)//100))
        plt.show()
    
    def get_some_features(self, coords):
        self.found = []
        try:
            for coord in coords:
                for tbl in self.sources:
                    finds = self.search(coord, tbl, r=5)
                    if len(finds)!=0:
                        self.found.append(finds)
        except TypeError:
            for tbl in self.sources:
                finds = self.search(coord, tbl, r=5)
                if len(finds)!=0:
                    self.found.append(finds)
        return self.found
                    
    def plot_feature_coords(self, coords):
        """Plot x, y and wavelength offsets from their medians for each found feature.

        One three-panel figure is drawn per table in ``self.found``. The
        centroid panels use the table's ``centroid_err`` column as error bars.
        ``coords`` is accepted but not used by this method.
        """
        plt.style.use('seaborn')

        first_day = datetime.strptime(self.times[0], '%m%d')
        last_day = datetime.strptime(self.times[-1], '%m%d')
        for feature in self.found:
            fig, axes = plt.subplots(1,  3, figsize=(20,5), sharey=False)
            ax1, ax2, ax3 = axes
            dates = feature['date-obs']
            n_dates = len(dates)
            # A single locator object is shared by all three axes below.
            locator = mdates.AutoDateLocator(minticks=n_dates, maxticks=n_dates).get_locator(first_day, last_day)

            for ax, column in ((ax1, 'xcentroid'), (ax2, 'ycentroid')):
                ax.errorbar(dates, feature[column]-np.nanmedian(feature[column]), yerr = feature['centroid_err'], ms = 6, fmt='.')
            ax3.errorbar(dates, feature['wavelength']-np.nanmedian(feature['wavelength']), yerr = None, ms = 6, fmt='.')

            # Titles show the value stored at index label 0 among the complete rows.
            complete = feature.dropna()
            ax1.set_title('x position of ThAr cal points '+ str(complete['xcentroid'][0]))
            ax2.set_title('y position of ThAr cal points ' + str(complete['ycentroid'][0]))
            ax3.set_title('wavelength of feature ' + str(complete['wavelength'][0]))

            for ax in (ax1, ax2, ax3):
                ax.tick_params(axis = 'x', rotation=70, labelsize=7)
            for ax in (ax1, ax2, ax3):
                ax.xaxis.set_major_locator(locator)
                ax.xaxis.set_minor_locator(locator)
            plt.setp((ax1, ax2), ylabel='pixel value-median', )
            plt.setp(ax3, ylabel = 'wavelength - median')
            plt.show()
        
    def plot_extracted_coords(self, coords):
        """Plot x and y centroid offsets from their medians for each found feature.

        One two-panel figure is drawn per table in ``self.found``.
        ``coords`` is accepted but not used by this method.
        """
        plt.style.use('seaborn')
        for feature in self.found:
            fig, axes = plt.subplots(1,  2, figsize=(20,5), sharey=False)
            ax1, ax2 = axes
            dates = feature['date-obs']
            for ax, column in ((ax1, 'xcentroid'), (ax2, 'ycentroid')):
                ax.errorbar(dates, feature[column]-np.nanmedian(feature[column]), ms = 6, fmt='.')

            # Titles show the value stored at index label 0 among the complete rows.
            complete = feature.dropna()
            ax1.set_title('x position of ThAr cal points '+ str(complete['xcentroid'][0]))
            ax2.set_title('y position of ThAr cal points ' + str(complete['ycentroid'][0]))

            for ax in (ax1, ax2):
                ax.tick_params(axis = 'x', rotation=70, labelsize=7)
            plt.setp((ax1, ax2), ylabel='pixel value-median', xlabel= 'Date-Obs')
            plt.show()
    
    def centroid_std_map(self, cmin=0, cmax=0.5):
        """Scatter each source at its mean position, coloured by its x-centroid scatter.

        Parameters
        ----------
        cmin, cmax : float, optional
            Lower and upper limits of the colour scale.

        Notes
        -----
        Index 0 of ``self.sources`` is skipped. The 'seaborn' and
        'dark_background' styles are applied globally and remain active
        afterwards.
        """
        plt.style.use('seaborn')
        xcentroid_stds = []
        x=[]
        y=[]
        for s in self.sources[1:]:
            xcentroid_stds.append(np.std(s['xcentroid']))
            x.append(np.nanmean(s['xcentroid']))
            y.append(np.nanmean(s['ycentroid']))
        plt.style.use('dark_background')
        plt.figure(figsize=(30,30))
        plt.scatter(x, y, c=xcentroid_stds, cmap='inferno', vmin=cmin, vmax=cmax)
        plt.axis('off')
        # Pass a real boolean: the string 'off' is truthy.
        plt.grid(False)
        cbar = plt.colorbar()
        tick_font_size = 30
        cbar.ax.tick_params(labelsize=tick_font_size)
        plt.show()
    
    def rmse(self, data):
        mean = np.nanmean(data)
        return np.sqrt(np.sum((data-mean)**2)/len(data))
    
    def centroid_rmse_plot(self, x_range = (0,1)):
        """Plot histograms of the x and y centroid RMSE of every source.

        Parameters
        ----------
        x_range : tuple of two floats, optional
            Range passed to both histograms.

        Index 0 of ``self.sources`` is skipped.
        """
        clusters = self.sources[1:]
        rmse_x = [self.rmse(cluster['xcentroid'].to_numpy()) for cluster in clusters]
        rmse_y = [self.rmse(cluster['ycentroid'].to_numpy()) for cluster in clusters]

        fig, (top, bottom) = plt.subplots(2, 1, figsize=(15, 12), sharex=True)
        top.hist(rmse_x, bins = 100, label='xcentroid', range=x_range)
        bottom.hist(rmse_y, bins = 100, label='ycentroid', range=x_range)
        for panel in (top, bottom):
            panel.legend()
        plt.setp((top, bottom), ylabel='Number of Features')
        top.set_title('RMSE of x and y of Features')
        bottom.set_xlabel('RMSE')
        plt.show()
    
    def violin_plot(self, vmax=0.3, vmin=-0.3):
        """Draw per-date violin plots of x, y and wavelength offsets.

        For every source (index 0 of ``self.sources`` is skipped) each value's
        offset from that source's 3-sigma-clipped mean is computed. The offsets
        of all sources are pooled per date.

        Parameters
        ----------
        vmax, vmin : float, optional
            Offsets outside ``[vmin, vmax]`` are discarded before plotting.
        """
        columns = (('xcentroid', 'xdiffs'), ('ycentroid', 'ydiffs'), ('wavelength', 'wavediffs'))
        dfs = []
        for source in self.sources[1:]:
            # Rows without an x centroid are treated as missing in every
            # column, which matches the padded NaN rows.
            no_x = np.isnan(source['xcentroid'].to_numpy(dtype=float))
            # Take the dates from the table itself so each offset stays paired
            # with its own row, whatever the table's length or row order.
            diffs = {'date-obs': source['date-obs'].to_numpy()}
            for column, name in columns:
                center = np.mean(stats.sigmaclip(source[column].dropna(), 3, 3)[0])
                offsets = source[column].to_numpy(dtype=float) - center
                diffs[name] = np.where(no_x, np.nan, offsets)
            dfs.append(pd.DataFrame(diffs))

        #Concatenate all the difference dfs
        diff_df = pd.concat(dfs)
        diff_df = pd.melt(diff_df, id_vars=['date-obs'], value_vars=['xdiffs', 'ydiffs', 'wavediffs'], var_name='category', value_name = 'diffs')
        diff_df = diff_df[(diff_df['diffs'] >= vmin)&(diff_df['diffs'] <= vmax)]

        import seaborn
        seaborn.set(style='whitegrid', rc={"figure.figsize":(40, 20)}, font_scale=2)
        fig, axes = plt.subplots(3, 1, figsize = (40,20))
        palette = {'xdiffs':'tab:blue', 'ydiffs':'tab:orange', 'wavediffs':'tab:green'}
        seaborn.violinplot(x='date-obs', y='diffs', data=diff_df.loc[diff_df['category']=='xdiffs'], scale='count', hue='category', palette=palette, cut=0, bw=0.1, ax=axes[0])
        seaborn.violinplot(x='date-obs', y='diffs', data=diff_df.loc[diff_df['category']=='ydiffs'], scale='count',  hue='category', palette=palette, cut=0, bw=0.1, ax=axes[1])
        g = seaborn.violinplot(x='date-obs', y='diffs', data=diff_df.loc[diff_df['category']=='wavediffs'], scale='count',  hue='category', palette=palette, cut=0, bw=0.1, ax=axes[2])
        g.set(ylim=(-0.01,0.01))
        axes[0].set_title('x centroid variation')
        axes[1].set_title('y centroid variation')
        axes[2].set_title('wavelength variation')
        plt.show()
    
    def signaltonoise(self, data, sigma=3.0):
        """Estimate a signal-to-noise ratio for ``data``.

        The noise is the sigma-clipped MAD-based background RMS. The result is
        ``sqrt(robust_std(data)**2 / noise**2 - 1)``.

        Parameters
        ----------
        data : array-like
            Image data.
        sigma : float, optional
            Clipping threshold used for the background RMS estimate.
        """
        # Honour the caller's ``sigma`` instead of a hard-coded 3.0.
        sigma_clip = SigmaClip(sigma=sigma)
        bkgrms = MADStdBackgroundRMS(sigma_clip)
        noise = bkgrms(data)
        snr = np.sqrt((robust_standard_deviation(data)**2/noise**2) - 1)
        return snr
    
    def rvprecsn(self):
        """Plot the ``RVPRECSN`` header value of every file against ``self.times``.

        The values are also stored on ``self.rv``.
        """
        self.rv = [h['RVPRECSN'] for h in self.header]
        plt.style.use('seaborn')
        # Label the line so the legend call below has an entry to show.
        plt.plot(self.times, self.rv, label='RVPRECSN')
        plt.legend()
        plt.xlabel('Times')
        plt.ylabel('RVPRECSN')
        plt.title('LSC RV Precision Comparison')
        plt.tick_params(axis = 'x', rotation=70, labelsize=10)
        plt.show()
    
    def cal_and_sci(self, coords):
        """Plot neighbouring fiber-0 ('cal') and fiber-1 ('sci') sources together.

        Sources (index 0 of ``self.sources`` is skipped) are split by the mean
        of their ``fiber`` column. Each cal source is paired with the first
        sci source found near it, and at most the first 50 cal sources are
        plotted. ``coords`` is accepted but not used by this method.
        """
        cal = []
        sci = []
        for s in self.sources[1:]:
            fiber = np.nanmean(s['fiber'])
            if fiber == 0:
                cal.append(s)
            elif fiber == 1:
                sci.append(s)

        #Then do a radial search looking for each sci for each cal
        def radial_search(comp_tbl, tbls, r=5):
            """Return the first table within ``r`` of ``comp_tbl``, or None."""
            # The reference position does not change while scanning ``tbls``.
            coord = [np.nanmean(comp_tbl['xcentroid']), np.nanmean(comp_tbl['ycentroid'])]
            for tbl in tbls:
                x = np.nanmean(tbl['xcentroid'])
                y = np.nanmean(tbl['ycentroid'])
                # Sum of absolute offsets (|dx| + |dy|), not Euclidean distance.
                if np.abs(coord[0] - x) + np.abs(coord[1] - y) <= r:
                    return tbl
            return None

        res = [radial_search(t, sci, r=20) for t in cal]

        for cal_tbl, sci_tbl in list(zip(cal, res))[:50]:
            if sci_tbl is None:
                continue
            fig, (ax1, ax2, ax3) = plt.subplots(1,  3, figsize=(20,5), sharey=False)
            # Each table is plotted against its own dates so values are never
            # paired with another table's rows.
            cal_dates = cal_tbl['date-obs']
            sci_dates = sci_tbl['date-obs']
            for ax, column in ((ax1, 'xcentroid'), (ax2, 'ycentroid'), (ax3, 'wavelength')):
                ax.errorbar(cal_dates, cal_tbl[column]-np.nanmedian(cal_tbl[column]), yerr = None, ms = 6, fmt='.', label='cal fiber')
                ax.errorbar(sci_dates, sci_tbl[column]-np.nanmedian(sci_tbl[column]), yerr = None, ms = 6, mfc = 'red', fmt='.', label='sci fiber')

            complete = cal_tbl.dropna()
            ax1.set_title('x position of ThAr cal and sci points '+ str(complete['xcentroid'][0]))
            ax2.set_title('y position of ThAr cal and sci points ' + str(complete['ycentroid'][0]))
            ax3.set_title('wavelength of feature ' + str(complete['wavelength'][0]))

            ax1.legend()
            ax2.legend()

            for ax in (ax1, ax2, ax3):
                ax.tick_params(axis = 'x', rotation=70, labelsize=9)
            plt.setp((ax1, ax2), ylabel='pixel value-median', xlabel='Date Obs')
            plt.setp(ax3, ylabel = 'wavelength - median')
            plt.show()
                
    def do_all(self, coords, vmax=0.3, vmin=-0.3, x_range = (0,1), cmin=0, cmax=0.5):
        """Produce every available plot in sequence.

        Parameters
        ----------
        coords : sequence
            DS9 coordinates of the features to plot individually.
        vmax, vmin : float, optional
            Passed to ``violin_plot``.
        x_range : tuple, optional
            Passed to ``centroid_rmse_plot``.
        cmin, cmax : float, optional
            Passed to ``centroid_std_map``.
        """
        # plot_feature_coords reads self.found; populate it if the caller has
        # not run get_some_features yet.
        if not hasattr(self, 'found'):
            self.get_some_features(coords)
        self.violin_plot(vmin=vmin, vmax=vmax)
        self.centroid_rmse_plot(x_range=x_range)
        self.centroid_std_map(cmin=cmin, cmax=cmax)
        self.plot_feature_coords(coords)
        self.features_std()
        self.wavelength_std_plot()
        self.rvprecsn()
        self.cal_and_sci(coords)

# Use case 1: Start from scratch

In [None]:
#List DS9 x-y coordinates of well resolved features picked from various areas of the image
ds9_xy = [[2014, 2175], [2104, 2173], [1686, 2119], [1378, 2273], [1130, 2363], [2207, 3256], [2133, 3169], [3040, 2607], [3185, 2613], [2008, 3088], [738, 3647], [940, 3435], [3164, 3529], [2210, 1974], [1827, 1919], [1860, 1855], [1659, 1681], [1936, 3539], [2248, 3915], [3086, 2119]]

#more than 30 days is not recommended (to save your RAM)
test = SQAT()

test.download_data(start_date = '2022-09-01', end_date = '2022-09-02', site = 'lsc', doubles_path = '/home/pkottapalli/nres_tests/chile_sept_doubles/', lampflats_path = '/home/pkottapalli/nres_tests/lampflats/chile_sept_flats/')
# '~' is expanded explicitly so the path is usable as a real directory.
test.setup_pipeline(processed_path = os.path.expanduser('~/SQAT_toolkit'), db_path = None) #Sets up db and BANZAI context, and loads in images
test.run_pipeline() #runs feature identification, wavelength calibration, and wavelength qc

#get_data takes the directory holding the processed files; glob wildcards are allowed, '~' must be expanded
test.get_data(path = os.path.expanduser('~/nres_tests/new_features/lsc/nres01/202209*/processed')) #Open fits files and gets all the data we might need

#new and old refer to the new feature tables and the old feature tables.
test.get_sources_new() #Collates sources, result is a list of pandas tables, each table is a source in every image.
test.get_some_features(ds9_xy) #finds source tables for requested coordinates
test.do_all(ds9_xy) #Plots all available plots

# Use case 2: To use your own files already downloaded

In [None]:
sqat = SQAT()
# format of coords [[x1, y1], [x2,y2],...]
#for single coordinate do [[x1, y1]]
ds9_xy = [[2031, 2179],[2121, 2177], [1702, 2124], [1409, 2141], [1203, 2294], [2231, 3265], [2148, 3158], [3053, 2537], [3059, 2612], [2031, 3098], [1352, 3500], [963, 3448], [3191, 3535], [2226, 1977], [1841, 1923], [1861, 1859], [1673, 1684], [1960, 3550], [2276, 3929], [3100, 2121]]

# '~' is expanded explicitly because glob does not expand it.
sqat.get_data(path=os.path.expanduser('~/nres_tests/doubles/'))
sqat.get_sources_old()
sqat.get_some_features(ds9_xy) # same instance as above; this cell defines no 'sqat2'
sqat.do_all(coords = ds9_xy)

# Available convenience functions

In [None]:
#search for one source given ds9 coordinates
sqat.get_some_features([[2031,2179]])

#convert coordinates from ds9 to python
new_coords = sqat.ds9_to_python_id([[2031,2179]])

#run image segmentation on provided files
sqat.stacked_files = sorted(glob('/path/to/files/*.fz')) #must be a list of file paths, not a directory string
sqat.extract() #stores the list of source tables on sqat.sources
image_data = fits.open(sqat.stacked_files[0])['SPECTRUM'].data #any image array works here
sqat.signaltonoise(image_data) #Calculates signal to noise for an image

# Available convenience variables

In [None]:
#After running setup_pipeline()
test.context #Banzai context object
test.frame_factory #NRES Frame Factory object

#After running get_data()
sqat.stacked_files #list of provided files
sqat.header #list of SPECTRUM extension headers, one per file
sqat.times #mmdd formatted times

#After running get_sources_old() or get_sources_new()
sqat.features #list of feature tables from images
sqat.sources #a list of tables where each table is one source over every image provided. missing values are NaNs

#After running get_some_features()
sqat.found #the sources tables found after searching for sources at specific DS9 coordinates

#After running rvprecsn()
sqat.rv #RVPRECSN header value of each provided image

# Use case 3: Use BANZAI utilities

In [None]:
# Calculate the sigma clipped mean of an image
from banzai.utils.stats import sigma_clipped_mean

# Uses the 'sqat' instance created in use case 2 (no 'sqat2' is defined in this notebook).
sm_mean = sigma_clipped_mean(fits.open(sqat.stacked_files[0])['SPECTRUM'].data, sigma = 3)
print(sm_mean)

# Use case 4: Use data to explore on your own

In [None]:
#Plot de-trended (point-to-point difference) centroids and wavelengths
for f in sqat.found:
    # Work on positional arrays: the tables may have been re-sorted by date,
    # so index labels do not follow row order.
    x_diffs = np.diff(f['xcentroid'].to_numpy(dtype=float))
    y_diffs = np.diff(f['ycentroid'].to_numpy(dtype=float))
    wave_diffs = np.diff(f['wavelength'].to_numpy(dtype=float))
    t = f['date-obs'].to_numpy()

    fig, (ax1, ax2, ax3) = plt.subplots(1,  3, figsize=(20,5), sharey=False)
    for ax, diffs in ((ax1, x_diffs), (ax2, y_diffs), (ax3, wave_diffs)):
        ax.errorbar(t[1:], diffs, yerr = None, ms = 6, fmt='.')
        # Mean difference drawn across the whole panel, ignoring missing dates.
        ax.axhline(y = np.nanmean(diffs), color='r', linestyle='--')

    complete = f.dropna()
    ax1.set_title('detrended x position of ThAr cal points '+ str(complete['xcentroid'].iloc[0])[:-5])
    ax2.set_title('detrended y position of ThAr cal points ' + str(complete['ycentroid'].iloc[0])[:-5])
    ax3.set_title('wavelength of feature ' + str(complete['wavelength'].iloc[0])[:-5])
    for ax in (ax1, ax2, ax3):
        ax.tick_params(axis = 'x', rotation=70, labelsize=9)
    plt.setp((ax1, ax2), ylabel='difference from previous point', )
    plt.setp(ax3, ylabel = 'wavelength difference from previous point')
    plt.show()