In [1]:
#Astropy
import astropy
from astropy.io import fits
from astropy.table import Table
from astropy import units as u

# Dlnpyutils
from dlnpyutils.utils import bspline,mad,interp

# dust_extinction
from dust_extinction.parameter_averages import CCM89,O94,F99,F04,VCG04,GCC09,M14,F19,D22

# functools
from functools import partial

# Matplotlib
import matplotlib
import matplotlib.pyplot as plt
# %matplotlib inline
matplotlib.rcParams.update({'font.size': 25})

#Numpy/Scipy
import numpy as np
import scipy
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit

# pdb
import pdb

# tqdm 
from tqdm.notebook import tqdm

class Aetas():
    '''
    A class to calculate a star's extinction, age and mass using PARSEC isochrones with
    Gaia (E)DR3 and 2MASS photometry.

    Typical use:

        star = Aetas(star_data, isochrones)
        star.extinction()        # fills star.ext (run automatically if needed)
        result = star.get_age()  # dict with the adopted age and delta_int_IMF

    Missing or failed values are flagged with BAD_VALUE (999999.0) throughout.
    '''

    # sentinel used for missing / failed values throughout the class
    BAD_VALUE = 999999.0

    def __init__(self, star_data, isochrones, ext_law='CCM89', rv=3.1,
                 teff_extrap_limit=100, debug=False):
        '''
        Inputs:
        ------
                    star_data: Table row / single-row table (pandas or astropy)
                               Observed and calculated properties of ONE star with the
                               following columns:

                               'TEFF', 'TEFF_ERR', 'LOGG', 'LOGG_ERR', 'FE_H', 'FE_H_ERR',
                               'ALPHA_FE', 'ALPHA_FE_ERR', 'BP', 'G', 'RP', 'J', 'H', 'K',
                               'BP_ERR', 'G_ERR', 'RP_ERR', 'J_ERR', 'H_ERR', 'K_ERR',
                               'DISTANCE'

                               'DISTANCE' must be in units of parsecs

                   isochrones: Table (pandas dataframe or astropy table)
                               PARSEC isochrone table with the following columns:

                               'MH', 'Mass', 'delta_int_IMF', 'logAge', 'logTe', 'logg',
                               'label', 'BPmag', 'Gmag', 'RPmag', 'Jmag', 'Hmag', 'Ksmag'

                               'label' is the PARSEC evolutionary phase label.

                               'delta_int_IMF' is the difference in adjacent 'int_IMF'
                               values for each isochrone (i.e. int_IMF[i+1]-int_IMF[i])
                               with the last value repeated as the difference returns
                               one less element

                      ext_law: string, optional
                               extinction law to use. Default is CCM89.

                               Available Extinction Laws:
                               -------------------------

                               CCM89 - Cardelli, Clayton, & Mathis 1989
                               O94 - O'Donnell 1994
                               F99 - Fitzpatrick 1999
                               F04 - Fitzpatrick 2004
                               VCG04 - Valencic, Clayton, & Gordon 2004
                               GCC09 - Gordon, Cartledge, & Clayton 2009
                               M14 - Maiz Apellaniz et al 2014
                               F19 - Fitzpatrick, Massa, Gordon, Bohlin & Clayton 2019
                               D22 - Decleir et al. 2022

                           rv: float, optional
                               Rv (=Av/E(B-V)) extinction law slope. Default is 3.1.
                               The chosen dust_extinction model may reject Rv values
                               outside the range it supports.

            teff_extrap_limit: float
                               limit (K) for maximum allowable temperature outside
                               isochrone range that will be extrapolated

                        debug: bool
                               print useful information to the screen
        '''
        # Other (set first so every later step can use them)
        self.debug = debug
        self.teff_extrap_limit = teff_extrap_limit

        # Teff and log(g)
        self.teff = self._as_scalar(star_data['TEFF'])          # temperature
        self.teff_err = self._as_scalar(star_data['TEFF_ERR'])  # temperature error
        self.logg = self._as_scalar(star_data['LOGG'])          # log(g)
        self.logg_err = self._as_scalar(star_data['LOGG_ERR'])  # log(g) error

        # Salaris corrected [Fe/H]
        self.salfeh, self.salfeh_err = self.salaris_metallicity(
            self._as_scalar(star_data['FE_H']), self._as_scalar(star_data['FE_H_ERR']),
            self._as_scalar(star_data['ALPHA_FE']), self._as_scalar(star_data['ALPHA_FE_ERR']))

        # observed photometry, one slot per band in obs_phot_labels order
        self.obs_phot_labels = ['BP', 'G', 'RP', 'J', 'H', 'K']
        self.phot = np.array([self._as_scalar(star_data[band])
                              for band in self.obs_phot_labels])
        self.phot_err = np.array([self._as_scalar(star_data[band + '_ERR'])
                                  for band in self.obs_phot_labels])

        # Distance modulus (DISTANCE in pc)
        self.distmod = 5.0 * np.log10(self._as_scalar(star_data['DISTANCE'])) - 5.0

        # PARSEC isochrones
        self.iso_phot_labels = ['BPmag', 'Gmag', 'RPmag', 'Jmag', 'Hmag', 'Ksmag']
        self.iso_interp_labels = ['BPmag', 'Gmag', 'RPmag', 'Jmag', 'Hmag', 'Ksmag',
                                  'logg', 'delta_int_IMF']

        iso_table = self._as_table(isochrones)
        # a stable sort keeps each isochrone's points in their original order, which the
        # per-phase row ranges built below rely on
        iso_table = iso_table[np.argsort(np.asarray(iso_table['logAge']), kind='stable')]

        best_mh = self.closest(iso_table['MH'], self.salfeh)
        self.iso = iso_table[np.asarray(iso_table['MH']) == best_mh]
        self.uniq_ages = np.unique(np.asarray(self.iso['logAge']))

        # inclusive [first, last] row ranges of each phase for every age
        self.age_idx = {}
        self.age_idx_3 = self._age_index_for_label(3)
        self.age_idx_7 = self._age_index_for_label(7)

        # Extinction
        self.rv = rv
        self.leff = np.array([0.5387, 0.6419, 0.7667, 1.2345, 1.6393, 2.1757])  # BP, G, RP, J, H, K (microns)
        self.extlaw_coeff = self.extcoeff(law=ext_law, rv=self.rv)
        self.ext = None  # filled by extinction()

        # state written by teff_2_appmags()
        self.did_extrap = 0
        self.delta_int_IMF = self.BAD_VALUE

    #################
    ### Utilities ###
    #################

    @staticmethod
    def _as_scalar(value):
        '''
        Convert a single-star table entry (scalar, or length-1 column/Series) to a float.

        Raises ValueError when the entry holds more than one value, because Aetas works
        on one star at a time.
        '''
        values = np.asarray(value, dtype=float).ravel()
        if values.size != 1:
            raise ValueError('Aetas expects data for a single star, got {} values'.format(values.size))
        return float(values[0])

    @staticmethod
    def _as_table(isochrones):
        '''
        Return the isochrones as an astropy Table; pandas DataFrames are converted.
        '''
        if isinstance(isochrones, Table):
            return isochrones
        if hasattr(isochrones, 'iloc'):  # pandas DataFrame
            return Table.from_pandas(isochrones)
        return Table(isochrones)

    def _is_bad(self, values):
        '''
        Boolean mask of entries that are non-finite or flagged with BAD_VALUE.
        '''
        values = np.asarray(values, dtype=float)
        return ~np.isfinite(values) | (values >= self.BAD_VALUE)

    def closest(self, data, value):
        '''
        Find nearest value in array to given value.

        Inputs:
        ------
             data: array-like
                   data to search through (NaN entries are ignored)

            value: float or int
                   value of interest

        Output:
        ------
            close: float or int
                   value in data closest to given value
        '''
        data = np.asarray(data)
        # nanargmin so that NaN entries can never be selected
        return data[np.nanargmin(np.abs(data - value))]

    def neighbors(self, data, value):
        '''
        Find the two grid values that bracket the given value.

        Inputs:
        ------
              data: array-like
                    data to search through (duplicates are ignored)

             value: float or int
                    value of interest

        Output:
        ------
            close1: float
                    closest grid value at or below the given value (or the lowest
                    grid value when value is below the grid)

            close2: float
                    closest grid value above the given value (or the highest grid
                    value when value is above the grid)

        Raises ValueError if data has fewer than two distinct values.
        '''
        grid = np.unique(np.asarray(data, dtype=float))
        if grid.size < 2:
            raise ValueError('neighbors() needs at least two distinct values')
        # index of the first grid value strictly above `value`, clipped so that values
        # outside the grid use the two nearest edge values
        upper = int(np.searchsorted(grid, value, side='right'))
        upper = min(max(upper, 1), grid.size - 1)
        return grid[upper - 1], grid[upper]

    def _phase_age_index(self, label):
        '''
        For each age in self.uniq_ages, find the first and last row of self.iso that
        belongs to PARSEC evolutionary phase `label`.

        Output:
        ------
            index: (len(self.uniq_ages), 2) int array of inclusive [first, last] row
                   indices; [-1, -1] where that age has no rows with this label.
        '''
        log_age = np.asarray(self.iso['logAge'])
        phase = np.asarray(self.iso['label'])
        index = np.full((len(self.uniq_ages), 2), -1, dtype=int)
        for k, age in enumerate(self.uniq_ages):
            rows, = np.nonzero((log_age == age) & (phase == label))
            if rows.size:
                index[k] = rows.min(), rows.max()
        return index

    def _age_index_for_label(self, label):
        '''
        Cached per-label version of _phase_age_index(); labels other than 3 and 7 are
        computed on first use.
        '''
        if getattr(self, 'age_idx', None) is None:
            self.age_idx = {}
        if label not in self.age_idx:
            self.age_idx[label] = self._phase_age_index(label)
        return self.age_idx[label]

    def _isochrone_at(self, age_index, k):
        '''
        Rows of self.iso for the k-th unique age in one phase, or None if that age has
        no rows in the phase.
        NOTE(review): assumes a phase's rows are contiguous within an isochrone.
        '''
        first, last = age_index[k]
        if first < 0:
            return None
        return self.iso[first:last + 1]

    ###################
    ### Metallicity ###
    ###################

    def salaris_metallicity(self, metal, metal_err, alpha, alpha_err):
        '''
        Calculate the Salaris corrected metallicity (Salaris et al. 1993) using updated solar
        parameters from Asplund et al. 2021.

        Inputs:
        ------
                 metal: float (or array)
                        [Fe/H] of a star

             metal_err: float (or array)
                        error in [Fe/H] of a star

                 alpha: float (or array)
                        [alpha/Fe] of a star

             alpha_err: float (or array)
                        error in [alpha/Fe] of a star

        Outputs:
        -------
                salfeh: float
                        Salaris corrected metallicity

            salfeh_err: float
                        error in Salaris corrected metallicity
        '''
        alpha_lin = 10**np.asarray(alpha, dtype=float)
        salfeh = metal + np.log10(0.659 * alpha_lin + 0.341)
        # d(salfeh)/d(alpha) = 10**alpha / (10**alpha + 0.341/0.659)
        # NOTE(review): 0.341/0.659 = 0.51745 but the constant below is 0.517372 -- confirm
        salfeh_err = np.sqrt(metal_err**2 + (alpha_lin / (0.517372 + alpha_lin) * alpha_err)**2)
        return salfeh, salfeh_err

    ##################
    ### Extinction ###
    ##################

    def extcoeff(self, law='CCM89', rv=3.1):
        '''
        Calculate the relative extinction law coefficients for the BP, G, RP, J, H, Ks bands
        for a given Rv and extinction law.

        Input:
        -----
                    rv: float
                        Rv (=Av/E(B-V)) extinction law slope. Default is 3.1

                   law: str
                        extinction law to use

                        Available Extinction Laws:
                        -------------------------

                        CCM89 - Cardelli, Clayton, & Mathis 1989
                        O94 - O'Donnell 1994
                        F99 - Fitzpatrick 1999
                        F04 - Fitzpatrick 2004
                        VCG04 - Valencic, Clayton, & Gordon 2004
                        GCC09 - Gordon, Cartledge, & Clayton 2009
                        M14 - Maiz Apellaniz et al 2014
                        F19 - Fitzpatrick, Massa, Gordon, Bohlin & Clayton 2019
                        D22 - Decleir et al. 2022

        Output:
        ------
             ext_coeff: array
                        extinction law A(lambda)/A(V) evaluated at self.leff
                        (BP, G, RP, J, H, and K)

        Raises ValueError for an unknown law.
        '''
        models = {'CCM89': CCM89, 'O94': O94, 'F99': F99, 'F04': F04, 'VCG04': VCG04,
                  'GCC09': GCC09, 'M14': M14, 'F19': F19, 'D22': D22}
        if law not in models:
            raise ValueError('Unknown extinction law {!r}; choose one of {}'.format(
                law, ', '.join(models)))

        ext_model = models[law](Rv=rv)

        # the models are evaluated at wavenumber 1/lambda
        return ext_model(np.reciprocal(np.asarray(self.leff) * u.micron))

    def extinction(self, teff_window=500.0):
        '''
        Calculate the extinctions for the BP, G, RP, J, H, and Ks bands.

        Intrinsic BP-G, G-RP, G-J, G-H and G-Ks colors at the star's Teff are read off
        b-splines of color vs log(Teff) through the isochrone points within
        `teff_window` K of the star's Teff. A(G) is the least-squares scale of the
        reddening vector that matches the observed-minus-intrinsic colors; the other
        bands follow from the extinction-law ratios A(band)/A(G).

        Inputs:
        ------
            teff_window: float, optional
                         half-width (K) of the Teff window used to select isochrone
                         points. Default is 500.

        Output:
        ------
            ext: 6x2 array
                 first column is the extinction values and the second is the errors
                 (BP, G, RP, J, H, Ks); all BAD_VALUE when no usable colors exist.
                 Also stored as self.ext.
        '''
        bad_ext = self.BAD_VALUE * np.ones((6, 2))

        if self.debug:
            print('### Running Aetas.extinction() ###')
            print('Inputs from Aetas.__init__()')
            print('Salaris Corrected Metallicity:', self.salfeh)
            print('Temperature:', self.teff)

        # isochrone magnitude pairs for the colors BP-G, G-RP, G-J, G-H, G-Ks
        color_pairs = [('BPmag', 'Gmag'), ('Gmag', 'RPmag'), ('Gmag', 'Jmag'),
                       ('Gmag', 'Hmag'), ('Gmag', 'Ksmag')]

        # pick isochrone points with temperatures within teff_window of the star's Teff
        all_logte = np.asarray(self.iso['logTe'])
        in_window = ((all_logte < np.log10(self.teff + teff_window)) &
                     (all_logte > np.log10(self.teff - teff_window)))

        # check to make sure there are enough isochrone points
        if not np.any(in_window):
            self.ext = bad_ext
            return bad_ext.copy()

        iso_ = self.iso[in_window]

        # observed colors in the same orientation: BP-G, G-RP, G-J, G-H, G-Ks
        obs_colors = np.delete(self.phot - self.phot[1], 1)
        obs_colors[1:] = -obs_colors[1:]
        obs_colors_err = np.delete(np.sqrt(self.phot_err**2 + self.phot_err[1]**2), 1)

        if self.debug:
            print('Calculated Observed Colors:')
            print('Observed Colors:', obs_colors)
            print('Observed Color Errors:', obs_colors_err)

        # relative extinction vector A(band)/A(G)
        ext_vec = self.extlaw_coeff / self.extlaw_coeff[1]

        # "reddening" vector: color excess per unit A(G),
        # E(BP-G)/A(G) = A(BP)/A(G) - 1 and E(G-X)/A(G) = 1 - A(X)/A(G)
        red_vec = np.delete(1 - self.extlaw_coeff / self.extlaw_coeff[1], 1)
        red_vec[0] = -red_vec[0]

        # clamp the Teff to the isochrone range when it lies outside it
        logTe = np.asarray(iso_['logTe'])
        use_lgteff = np.log10(self.teff)
        if use_lgteff < np.min(logTe) or use_lgteff > np.max(logTe):
            use_lgteff = self.closest(logTe, use_lgteff)

        # intrinsic colors (and d color / d logTeff) from a b-spline; retry with
        # extrapolation before giving up on a color
        iso_colors = self.BAD_VALUE * np.ones(5)
        iso_colors_deriv = self.BAD_VALUE * np.ones(5)
        for i, (blue, red) in enumerate(color_pairs):
            color = np.asarray(iso_[blue]) - np.asarray(iso_[red])
            for spline_kwargs in ({}, {'extrapolate': True}):
                try:
                    bspl = bspline(logTe, color, **spline_kwargs)
                    value = float(np.squeeze(bspl(use_lgteff)))
                    slope = float(np.squeeze(bspl.derivative()(use_lgteff)))
                except Exception:
                    continue
                iso_colors[i] = value
                iso_colors_deriv[i] = slope
                break

        if self.debug:
            print('Isochrone Colors:', iso_colors)

        # calculate the color excesses and errors;
        # d color/d Teff = (d color/d logTeff) / (Teff ln 10)
        color_diff = obs_colors - iso_colors
        color_errs = np.abs((iso_colors_deriv * self.teff_err) / (self.teff * np.log(10)))
        color_diff_err = np.sqrt(obs_colors_err**2 + color_errs**2)

        # keep only positive color excesses; this also drops colors whose spline failed
        # (a BAD_VALUE intrinsic color gives a large negative difference)
        use = color_diff > 0

        # if all bad return bad values
        if not np.any(use):
            if self.debug:
                print('All Colors are bad')
                print('Max Iso Teff:', 10**np.nanmax(logTe))
                print('Min Iso Teff:', 10**np.nanmin(logTe))
                print('Obs Teff:', self.teff)
            self.ext = bad_ext
            return bad_ext.copy()

        # least-squares A(G) along the reddening vector
        red_norm2 = np.linalg.norm(red_vec[use])**2
        ag = np.dot(red_vec[use], color_diff[use]) / red_norm2
        # NOTE(review): errors are summed linearly (as if fully correlated); for
        # independent errors this would be sqrt(sum(r_i**2 sigma_i**2)) / |r|**2 -- confirm
        ag_err = np.dot(red_vec[use], color_diff_err[use]) / red_norm2

        ext = self.BAD_VALUE * np.ones((6, 2))
        ext[:, 0] = ext_vec * ag
        ext[:, 1] = ext_vec * ag_err

        self.ext = ext

        if self.debug:
            iso_colors_extincted = iso_colors + red_vec * ag
            ext_chi = np.sum((obs_colors - iso_colors_extincted)**2 / obs_colors_err**2)
            print('A(G)+ Error:', ag, ag_err)
            print('All Extinctions:', ext[:, 0])
            print('chisq:', ext_chi)
            print('resid:', obs_colors - iso_colors_extincted)

        return ext

    ##########################################
    ### Magnitudes, Log(g), delta_int_IMF  ###
    ##########################################

    def _outside_extrap_limit(self, iso_, teff, tag='iso'):
        '''
        True when `teff` lies more than self.teff_extrap_limit K outside the Teff
        range spanned by the isochrone points `iso_`.
        '''
        teff_min = 10**np.min(iso_['logTe'])
        teff_max = 10**np.max(iso_['logTe'])
        outside = (teff_min - teff > self.teff_extrap_limit or
                   teff - teff_max > self.teff_extrap_limit)
        if outside and self.debug:
            print('Teff outside extrapolation limit of', tag)
            print('max ' + tag, np.log10(teff_max))
            print('min ' + tag, np.log10(teff_min))
            print('Teff', np.log10(teff))
            print('Lower - Teff', teff_min - teff)
        return outside

    def _interp_quantities(self, iso_, lgteff, extincts, extrap):
        '''
        Interpolate every quantity in self.iso_interp_labels along one isochrone at
        log(Teff) = lgteff. Magnitudes get the distance modulus and extinction added;
        log(g) and delta_int_IMF are returned as interpolated.

        Output:
        ------
            calc_: array of len(self.iso_interp_labels); BAD_VALUE where the
                   interpolation failed
        '''
        n_phot = len(self.iso_phot_labels)
        calc_ = self.BAD_VALUE * np.ones(len(self.iso_interp_labels))
        logte = iso_['logTe']
        for i, quantity in enumerate(self.iso_interp_labels):
            try:
                if extrap:
                    value = interp(logte, iso_[quantity], lgteff,
                                   assume_sorted=False, extrapolate=True)
                else:
                    value = bspline(logte, iso_[quantity])(lgteff)
                value = float(np.squeeze(value))
            except Exception:
                continue  # leave BAD_VALUE for this quantity
            if i < n_phot:
                value += self.distmod + extincts[i]
            calc_[i] = value
        return calc_

    def teff_2_appmags(self, teff, age, label, extrap=False):
        '''
        Calculate the expected apparent magnitudes and log(g) of a star at a given Teff
        and age along one PARSEC evolutionary phase. The interpolated 'delta_int_IMF' is
        stored as self.delta_int_IMF and not returned; whether extrapolation was used is
        stored as self.did_extrap. For more information on 'delta_int_IMF' see __init__().

        If self.ext has not been computed yet, self.extinction() is run first.

        Inputs:
        ------
                teff: float
                      Teff of star (K)

                 age: float
                      age of star (Gyr)

               label: int
                      PARSEC evolutionary phase label ('label' column) to use

              extrap: bool
                      False: b-spline interpolation in Teff, no extrapolation
                      True: interpolate with extrapolation in Teff

        Output:
        ------
               calc_: 7 element array
                      BP, G, RP, J, H, and Ks apparent magnitudes (distance modulus and
                      extinction added) followed by log(g). All BAD_VALUE when the age is
                      not positive, the phase has no points at the needed age(s), or the
                      Teff is more than self.teff_extrap_limit outside the isochrone(s).
        '''
        n_out = len(self.iso_interp_labels) - 1
        bad_out = self.BAD_VALUE * np.ones(n_out)

        teff = float(np.squeeze(teff))
        lgteff = np.log10(teff)

        if getattr(self, 'ext', None) is None:
            self.extinction()

        # copy so the zero-replacement below never touches self.ext
        extincts = np.array(self.ext[:, 0], dtype=float)
        if extincts[1] > 100.:
            print('Bad extinctions replaced with 0.0')
            extincts = np.zeros_like(extincts)

        self.did_extrap = 0
        self.delta_int_IMF = self.BAD_VALUE

        age = float(np.squeeze(age))
        if not np.isfinite(age) or age <= 0:
            return bad_out
        lgage = np.log10(age * 10**9)

        if self.debug:
            print('Running Aetas.teff_2_appmags()')
            print('Teff:', teff)
            print('Extinctions:', extincts)

        age_index = self._age_index_for_label(label)

        # Figure out if age is (numerically) one of the ages in the isochrone table
        exact, = np.nonzero(np.isclose(self.uniq_ages, lgage, rtol=0.0, atol=1e-8))

        if exact.size:
            ### pick out a single isochrone
            iso_ = self._isochrone_at(age_index, exact[0])
            if iso_ is None:
                if self.debug:
                    print('No isochrone points for label', label, 'at age', age)
                return bad_out
            if self._outside_extrap_limit(iso_, teff, 'iso'):
                return bad_out

            calc_ = self._interp_quantities(iso_, lgteff, extincts, extrap)

        else:
            # find the 2 isochrone ages that bracket the requested age
            lgage_lo, lgage_hi = self.neighbors(self.uniq_ages, lgage)
            if self.debug:
                print('[age_lo,age_hi]: ', [10**lgage_lo / 10**9, 10**lgage_hi / 10**9])

            # younger and older isochrones
            iso_lo = self._isochrone_at(age_index, int(np.searchsorted(self.uniq_ages, lgage_lo)))
            iso_hi = self._isochrone_at(age_index, int(np.searchsorted(self.uniq_ages, lgage_hi)))
            if iso_lo is None or iso_hi is None:
                if self.debug:
                    print('No isochrone points for label', label, 'near age', age)
                return bad_out

            ### Temperature Check
            if (self._outside_extrap_limit(iso_lo, teff, 'iso_lo') or
                    self._outside_extrap_limit(iso_hi, teff, 'iso_hi')):
                return bad_out

            calc_lo = self._interp_quantities(iso_lo, lgteff, extincts, extrap)
            calc_hi = self._interp_quantities(iso_hi, lgteff, extincts, extrap)

            # straight line in age (Gyr) through the two isochrones
            age_lo = 10**lgage_lo / 10**9
            age_hi = 10**lgage_hi / 10**9
            frac = (age - age_lo) / (age_hi - age_lo)
            calc_ = calc_lo + frac * (calc_hi - calc_lo)

            # a failed quantity on either isochrone makes the interpolated value bad too
            calc_[self._is_bad(calc_lo) | self._is_bad(calc_hi)] = self.BAD_VALUE

        self.did_extrap = int(bool(extrap))
        self.delta_int_IMF = calc_[-1]  # store int_IMF separately
        return calc_[:-1]               # drop int_IMF from the returned values

    ###########
    ### Age ###
    ###########

    def _fit_phase(self, label, guess_ages, extrap):
        '''
        Fit the age along one PARSEC phase with curve_fit, once per starting guess.

        Output:
        ------
            best: dict for the lowest chi-square successful fit, with keys 'label',
                  'age', 'age_err', 'chisq', 'rms', 'did_extrap', 'delta_int_IMF';
                  None if no guess gave a model free of BAD_VALUE entries
        '''
        model = partial(self.teff_2_appmags, label=label, extrap=extrap)

        # observed magnitudes + log(g); photometric errors floored at 0.01 mag
        obs_quants = np.append(np.copy(self.phot), self.logg)
        obs_quants_err = np.append(np.maximum(self.phot_err, 0.01), self.logg_err)

        best = None
        for j, guess in enumerate(guess_ages):
            try:
                popt, pcov = curve_fit(model, self.teff, obs_quants, p0=[guess],
                                       method='lm', sigma=obs_quants_err,
                                       absolute_sigma=True, maxfev=5000)
                fit_age = float(popt[0])
                # re-evaluate at the best age; this also refreshes self.did_extrap and
                # self.delta_int_IMF for that age
                calc = np.asarray(model(self.teff, fit_age), dtype=float)
            except Exception as err:
                if self.debug:
                    print(j + 1, 'label', label, 'guess', guess, 'failed:', err)
                continue

            if not np.isfinite(fit_age) or np.any(self._is_bad(calc)):
                if self.debug:
                    print(j + 1, 'label', label, 'guess', guess, 'gave a bad model')
                continue

            resid = calc - obs_quants
            age_var = np.asarray(pcov, dtype=float).reshape(-1)[0]
            fit = {'label': label,
                   'age': fit_age,
                   'age_err': float(np.sqrt(age_var)) if np.isfinite(age_var) and age_var >= 0 else self.BAD_VALUE,
                   'chisq': float(np.sum(resid**2 / obs_quants_err**2)),
                   'rms': float(np.std(resid)),
                   'did_extrap': int(self.did_extrap),
                   'delta_int_IMF': float(self.delta_int_IMF)}

            if self.debug:
                print('Calc App Mags + logg Label', label, ':', calc)
                print('Observed Mags + logg:', obs_quants)
                print('Observed Mags + logg Errors:', obs_quants_err)
                print(j + 1, guess, fit['age'], fit['chisq'], fit['rms'])

            if best is None or fit['chisq'] < best['chisq']:
                best = fit

        return best

    def _combine_fits(self, good_fits):
        '''
        Merge successful per-phase fits into (age, delta_int_IMF, did_extrap).

        One fit: adopted as is. Several fits: delta_int_IMF-weighted mean age
        sum(w*age)/sum(w) and weighted delta_int_IMF sum(w**2)/sum(w). No fits: BAD_VALUE.
        '''
        if not good_fits:
            return self.BAD_VALUE, self.BAD_VALUE, 0
        if len(good_fits) == 1:
            fit = good_fits[0]
            return fit['age'], fit['delta_int_IMF'], fit['did_extrap']

        weights = np.array([fit['delta_int_IMF'] for fit in good_fits], dtype=float)
        usable = (not np.any(self._is_bad(weights)) and np.all(weights >= 0)
                  and weights.sum() > 0)
        if not usable:
            # NOTE(review): the plan says "if delta_int_IMF is zero use closest";
            # interpreted here as the phase fit with the lowest chi-square -- confirm
            fit = min(good_fits, key=lambda f: f['chisq'])
            return fit['age'], fit['delta_int_IMF'], fit['did_extrap']

        ages = np.array([fit['age'] for fit in good_fits], dtype=float)
        age = float(np.sum(weights * ages) / np.sum(weights))
        delta_int_IMF = float(np.sum(weights**2) / np.sum(weights))
        did_extrap = int(any(fit['did_extrap'] for fit in good_fits))
        return age, delta_int_IMF, did_extrap

    def get_age(self, guess_ages=None, labels=(3, 7)):
        '''
        Fit the star's age by matching the observed BP, G, RP, J, H, Ks magnitudes and
        log(g) to the PARSEC isochrones with scipy.optimize.curve_fit.

        Procedure:
          1. fit each phase in `labels` without Teff extrapolation
          2. if every phase fails, refit all phases with Teff extrapolation
          3. if exactly one phase succeeds, adopt its best age and delta_int_IMF
          4. if several succeed, adopt the delta_int_IMF-weighted mean age
             sum(w*age)/sum(w) and weighted delta_int_IMF sum(w**2)/sum(w),
             where w is each phase's delta_int_IMF

        self.extinction() is run first if it has not been run yet.

        Inputs:
        ------
            guess_ages: array-like of float, optional
                        starting ages (Gyr) for curve_fit; each phase is fit once per
                        guess and its lowest chi-square fit is kept. Default is 10 ages
                        evenly spaced across the isochrone age grid.

                labels: sequence of int, optional
                        PARSEC phase labels to fit. Default is (3, 7).

        Output:
        ------
                result: dict
                        'age'           adopted age (Gyr)
                        'delta_int_IMF' adopted delta_int_IMF
                        'did_extrap'    1 if an adopted fit used Teff extrapolation
                        'phase_fits'    {label: best-fit dict from that phase or None}
                        'age' and 'delta_int_IMF' are BAD_VALUE when no phase could be
                        fit. The dict is also stored as self.age_result.
        '''
        if getattr(self, 'ext', None) is None:
            self.extinction()

        if guess_ages is None:
            grid_ages = 10**np.asarray(self.uniq_ages, dtype=float) / 10**9
            guess_ages = np.linspace(grid_ages.min(), grid_ages.max(), 10)
        guess_ages = np.atleast_1d(np.asarray(guess_ages, dtype=float))

        if self.debug:
            print('Running Aetas.get_age()')
            print('guess_ages:', guess_ages)

        # 1. each phase without extrapolation
        phase_fits = {label: self._fit_phase(label, guess_ages, extrap=False)
                      for label in labels}

        # 2. if all phases failed, retry them with extrapolation
        if all(fit is None for fit in phase_fits.values()):
            if self.debug:
                print('All phases failed without extrapolation; retrying with extrapolation')
            phase_fits = {label: self._fit_phase(label, guess_ages, extrap=True)
                          for label in labels}

        # 3-4. adopt the single good fit, or the delta_int_IMF-weighted combination
        good_fits = [fit for fit in phase_fits.values() if fit is not None]
        age, delta_int_IMF, did_extrap = self._combine_fits(good_fits)

        if self.debug:
            print('Adopted age:', age, 'delta_int_IMF:', delta_int_IMF)

        self.age_result = {'age': age,
                           'delta_int_IMF': delta_int_IMF,
                           'did_extrap': did_extrap,
                           'phase_fits': phase_fits}
        return self.age_result

In [None]:
# Plan for Aetas.get_age():
#
# 1. Run Aetas with PARSEC label (LEVEL) = 3 and no extrapolation. Save the best age and delta_int_IMF.
#
# 2. Run Aetas with PARSEC label (LEVEL) = 7 only, with no extrapolation. Save the best age and delta_int_IMF.
#
# 3. If both fits fail, rerun them with extrapolation.
#
# 4. If either one succeeds, save its best age and delta_int_IMF.
#
# 5. If both succeed, calculate the delta_int_IMF-weighted age, sum(w_i * age_i) / sum(w_i),
#    and the weighted delta_int_IMF, sum(w_i**2) / sum(w_i), where w_i = delta_int_IMF_i.
#
# Example of the weighted delta_int_IMF:
#
#     delta_int_IMF = [1, 4]
#     wgted_delta_int_IMF = (1**2 + 4**2) / (1 + 4) = 17 / 5 = 3.4
#
# If delta_int_IMF is zero, use the closest solution instead.