In [45]:
#Astropy
import astropy
from astropy.io import fits
from astropy.table import Table

# Dlnpyutils
from dlnpyutils import utils

# Matplotlib
import matplotlib
import matplotlib.pyplot as plt
# %matplotlib inline
matplotlib.rcParams.update({'font.size': 25})

#Numpy/Scipy
import numpy as np
import scipy
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit

class Aetas():
    '''
    A class to calculate a star's extinction, age and mass using PARSEC isochrones with 
    Gaia and 2MASS photometry.
    '''
    def __init__(self,teff,abund,obsphot,distance,isochrones):
        
        '''
        Inputs:
        ------
            teff:       2x1 array
                        Teff and error of star np.array([teff,teff_err])
                    
            abund:      2x1 array
                        first column is [M/H] and [Alpha/M] and the second column is the errors
                    
            obs_phot:   6x2 array 
                        first column is BP,G,RP,J,H and K and the second column is the errors
                        
            distance:   float
                        distance to star in pc
                    
            isochrones: astropy Table 
                        PARSEC isochrone table
            ----------------------------------------------        
            rv:         [float] 
                        Rv value (=Av/E(B_V)) 
        '''
        
        # Teff
        self.teff = teff[0] # temperature
        self.teff_err = teff[1] # temperature error
        
        # Salaris [Fe/H] Asplund 2021 et al.
        self.salfeh = abund[0,0]+np.log10(0.659*(10**(abund[1,0]))+0.341) # Salaris Corrected [Fe/H]
        self.salfeh_err = np.sqrt(abund[0,1]**2+((1-0.341/(0.659*(10**(abund[1,0]))+0.341))*abund[1,1])**2)
        self.phot = obsphot[:,0] # photometry
        self.phot_err = obsphot[:,1] # photometry errors
        
        # Distance modulus
        self.distance = distance
        self.distmod = 5.0*np.log10(distance)-5.0
        
        # Absolute Magnitudes Not Derreddened
        self.absphot = self.phot-self.distmod
        
        # PARSEC isochrones
        self.labels = ['G_BPEDR3mag','GEDR3mag','G_RPEDR3magmag','Jmag','Hmag','Ksmag']
        
        self.iso = isochrones[np.where(isochrones['MH']==self.closest(isochrones['MH'],self.salfeh))]
        self.uniq_ages = np.unique(self.iso['logAge'])
        
        age_idx = []
        for i in range(len(self.uniq_ages)):
            ages, = np.where(self.iso['logAge']==self.uniq_ages[i])
            age_idx.append(np.array([min(ages),max(ages)]))
        
        self.age_idx = np.asarray(age_idx)
                
        # Effective Wavelengths of different passbands in units of microns
        self.leff = {'G_BPEDR3mag':0.5387,'GEDR3mag':0.6419,'G_RPEDR3magmag':0.7667,
                     'Jmag':1.2345,'Hmag':1.6393,'Ksmag':2.1757}
        
    #################
    ### Utilities ###
    #################
        
    def closest(self,data,value):
        '''
        Find nearest value in array to given value

        Inputs:
        ------
            data:  array-like
                   data to search through

            value: float or int
                   value of interest

        Output:
        ------
            close: float or int
                   value in data closest to given value
        '''
        
        data = np.asarray(data)
    
#         data = data[np.where(np.isfinite(data)==True)]
        
        return data[(np.abs(np.subtract(data,value))).argmin()]
    
    def neighbors(self,data,value):
        '''
        Find values of two elements closest to the given value

        Inputs:
        ------
            data:   array-like
                    data to search through

            value:  float or int
                    value of interest

        Output:
        ------
            close1: float or int
                    closest value under the given value

            close2: float or int
                    closest value over the given value
        '''
    
        data = np.asarray(data)
        close1 = data[(np.abs(np.subtract(data,value))).argmin()]
        data = data[np.where(data!=close1)]
        close2 = data[(np.abs(np.subtract(data,value))).argmin()]
    
        return close1,close2
    
    def mad(self,data,normal=False):
        '''
        Calculate the median absolute deviation of the data

        Inputs:
        ------
            data:   array-like
                    data to calculate the median absolute deviation of

            normal: True or False
                    if True the MAD will be scaled to the normal distribution

        Output:
        ------
            mad:    float
                    median absolute deviation in the data
        '''

        if normal == True:
            mad = np.nanmedian(np.abs(data-np.nanmedian(data)))/scipy.stats.norm.ppf(0.75)
            return mad

        mad = np.nanmedian(np.abs(data-np.nanmedian(data)))
        return mad
    
    ##################
    ### Extinction ###
    ##################
    
    def ccm_a(self,x):
        '''
        a(x) function from CCM et al. 1989

        Input:
        -----
            x: float
               effective wavelength in units of 1/micron

        Output:
        ------
            a: float
               value of the "a" function from CCM et al. 89  
        '''
        if 0.3 <= x < 1.1:
            a = 0.574*(x**1.61)
            return a
    
        elif 1.1 <= x < 3.3:
            y = x - 1.82
            a = (1.+0.17699*y-0.50477*(y**2)-0.02427*(y**3)+0.72085*(y**4)+
                 0.01979*(y**5)-0.77530*(y**6)+0.32999*(y**7))
            return a
    
        elif 3.3 <= x < 8.0:
            if x < 5.9:
                a = 1.752-0.136*x-0.104/((x-4.67)**2+0.341)
                return a
        
            else:
                fa = -0.04473*((x-5.9)**2)+0.1207*((x-5.9)**3)
                a = 1.752-0.136*x-0.104/((x-4.67)**2+0.341)+fa
                return a       
    
    def ccm_b(self,x):
        '''
        b(x) function from ccm et al. 1989

        Input:
        -----
            x: float
               effective wavelength in units of 1/micron

        Output:
        ------
            b: float
               value of the "b" function from CCM et al. 89 
        '''
        if 0.3 <= x < 1.1:
            b = -0.527*(x**1.61)
            return b
    
        elif 1.1 <= x <= 3.3:
            y = x - 1.82
            b = (1.41338*y+2.28305*(y**2)+1.07233*(y**3)-5.38434*(y**4)-
                 0.62251*(y**5)+5.30260*(y**6)-2.09002*(y**7))
            return b
    
        elif 3.3 <= x < 8.0:
            if x < 5.9:
                b = -3.090+1.825*x+1.206/((x-4.62)**2+0.263)
                return b
        
            else:
                fb = 0.2130*((x-5.9)**2)+0.1207*((x-5.9)**3)
                b = -3.090+1.825*x+1.206/((x-4.62)**2+0.263)+fb
                return b
    
    def ccm_alav(self,wave):
        '''
        Calculate the relative extinction of one band to the V band using
        Cardelli et al. 1989 (A\lambda/Av)

        Inputs:
        ------
            wave: float
                  effective wavelength in units of micron

            rv:   float
                  Rv Slope of the extinction law (= Av/E(B-V))

        Output:
        ------
            alav: float
                  A\lambda/Av
        '''
        x=1/wave
        alav = self.ccm_a(x)+self.ccm_b(x)/3.1
        return alav
    
    def fitz_alebv(self,wave):
        '''
        Calculate the relative extinction to E(B-V) using Fitzpatrick 1999 extinction law
        '''
        recip_anchors = np.array([0.000,0.377,0.820,1.667,1.828,2.141,2.433,3.704,3.846])
        alebv_anchors = np.array([0.000,0.265,0.829,2.688,3.055,3.806,4.315,6.265,6.591])
        
        x=1/wave
        spl = interp1d(recip_anchors,alebv_anchors)
        alebv = spl(x)
        return alebv
    
    def extinction(self,law='cardelli'):
        '''
        Calulate the extinctions for the BP, G, RP, J, H, and K bands
        
        Input:
        -----
            law: str
                 the extinction law used ('cardelli' or 'fitzpatrick')
        
        Output:
        ------
            ext: 6x2 array
                 first column is the extinction values and the second is the errors
            
        '''
        
        # isochrone magnitude labels
        color_labels = np.array([['G_BPEDR3mag','GEDR3mag','GEDR3mag','GEDR3mag','GEDR3mag'],
                                 ['GEDR3mag','G_RPEDR3magmag','Jmag','Hmag','Ksmag']]).T
        
        # pick isochrone points with temperatures within 200 K of the star's Teff
        teffcut = np.where((self.iso['logTe']<np.log10(self.teff+200.))&
                           (self.iso['logTe']>np.log10(self.teff-200.)))
        iso_ = self.iso[teffcut]
    
        # get colors and errors
        obs_colors = np.delete(self.phot[:]-self.phot[1],1)
        obs_colors[1:] = -1*obs_colors[1:]
        print(obs_colors)
        obs_colors_err = np.delete(np.sqrt(self.phot_err[:]**2+self.phot_err[1]**2),1)
        print(obs_colors_err)
        
        # create "reddening" vector (slightly modified from the normal definition)
        if law == 'Cardelli':
            red_vec = np.array([self.fitz_alebv(self.leff['G_BPEDR3mag'])/self.fitz_alebv(self.leff['GEDR3mag'])-1.,
                               1.-self.fitz_alebv(self.leff['G_RPEDR3magmag'])/self.fitz_alebv(self.leff['GEDR3mag']),
                               1.-self.fitz_alebv(self.leff['Jmag'])/self.fitz_alebv(self.leff['GEDR3mag']),
                               1.-self.fitz_alebv(self.leff['Hmag'])/self.fitz_alebv(self.leff['GEDR3mag']),
                               1.-self.fitz_alebv(self.leff['Ksmag'])/self.fitz_alebv(self.leff['GEDR3mag'])])
            
            ext_vec = np.array([self.fitz_alebv(self.leff['G_BPEDR3mag'])/self.fitz_alebv(self.leff['GEDR3mag']),1.,
                               self.fitz_alebv(self.leff['G_RPEDR3magmag'])/self.fitz_alebv(self.leff['GEDR3mag']),
                               self.fitz_alebv(self.leff['Jmag'])/self.fitz_alebv(self.leff['GEDR3mag']),
                               self.fitz_alebv(self.leff['Hmag'])/self.fitz_alebv(self.leff['GEDR3mag']),
                               self.fitz_alebv(self.leff['Ksmag'])/self.fitz_alebv(self.leff['GEDR3mag'])])
            
        else:
            red_vec = np.array([self.ccm_alav(self.leff['G_BPEDR3mag'])/self.ccm_alav(self.leff['GEDR3mag'])-1.,
                               1.-self.ccm_alav(self.leff['G_RPEDR3magmag'])/self.ccm_alav(self.leff['GEDR3mag']),
                               1.-self.ccm_alav(self.leff['Jmag'])/self.ccm_alav(self.leff['GEDR3mag']),
                               1.-self.ccm_alav(self.leff['Hmag'])/self.ccm_alav(self.leff['GEDR3mag']),
                               1.-self.ccm_alav(self.leff['Ksmag'])/self.ccm_alav(self.leff['GEDR3mag'])])
            
            ext_vec = np.array([self.ccm_alav(self.leff['G_BPEDR3mag'])/self.ccm_alav(self.leff['GEDR3mag']),1.,
                               self.ccm_alav(self.leff['G_RPEDR3magmag'])/self.ccm_alav(self.leff['GEDR3mag']),
                               self.ccm_alav(self.leff['Jmag'])/self.ccm_alav(self.leff['GEDR3mag']),
                               self.ccm_alav(self.leff['Hmag'])/self.ccm_alav(self.leff['GEDR3mag']),
                               self.ccm_alav(self.leff['Ksmag'])/self.ccm_alav(self.leff['GEDR3mag'])])
        
        # calculate the intrinsic colors using a b-spline
        b_vec = 999999.0*np.ones(5)
        b_vec_deriv = 999999.0*np.ones(5)
        
        for i in range(5):
            try:
                # Interpolate the color-Teff relation using a b-spline
                color = (iso_[color_labels[i,0]]-iso_[color_labels[i,1]])
                logTe = iso_['logTe']

                bspl = utils.bspline(logTe,color)
                b_vec[i] = bspl(np.log10(self.teff)[i])
                b_vec_deriv[i] = bspl.derivative()(np.log10(self.teff)[i])
                
            except:
                b_vec[i] = 999999.0
                b_vec_deriv[i] = 999999.0
                
        # calculate the extinctions and errors
        color_diff = obs_colors-b_vec
        
        ag = np.dot(red_vec,color_diff)/np.linalg.norm(red_vec)**2
        ag_err = np.dot(red_vec,obs_colors_err)/np.linalg.norm(red_vec)**2
        
        ext = 999999.0*np.ones((6,2))
        ext[:,0] = ext_vec*ag
        ext[:,1] = ext_vec*ag_err
        
        self.ext = ext
        
        return ext
    
#     def get_ak_teff_chi(self):
#         '''
#         Calculate the teff and ak using the extinction function and chi-squared statistic
#         '''
        
#         obs_colors = self.phot[:4]-self.phot[-1]
#         obs_colors_err = np.sqrt(np.square(self.phot_err)[:4]+self.phot_err[-1]**2)
        
#         abpak_1 = self.ccm_alav(self.leff['G_BPEDR3mag'])/self.ccm_alav(self.leff['Ksmag'])-1.0
#         arpak_1 = self.ccm_alav(self.leff['G_RPEDR3magmag'])/self.ccm_alav(self.leff['Ksmag'])-1.0
#         ajak_1 = self.ccm_alav(self.leff['Jmag'])/self.ccm_alav(self.leff['Ksmag'])-1.0
#         ahak_1 = self.ccm_alav(self.leff['Hmag'])/self.ccm_alav(self.leff['Ksmag'])-1.0
        
#         test_teffs = np.linspace(self.teff-500,self.teff+500,5)
#         aks = 999999.0*np.ones(len(test_teffs))
#         aks_err = 999999.0*np.ones(len(test_teffs))
#         chis = 999999.0*np.ones(len(test_teffs))
#         rmses = 999999.0*np.ones(len(test_teffs))
        
#         for i in range(len(test_teffs)):
#             aks[i], aks_err[i], intcolors = self.extinction(test_teffs[i],self.teff_err)
#             chis[i] = sum((intcolors-obs_colors)**2/obs_colors_err**2)
#             rmses[i] = np.sqrt(sum((intcolors-obs_colors)**2)/len(intcolors))
            
#         idx = chis.argmin()
#         teff = test_teffs[idx]
#         ak = aks[idx]
#         ak_err = aks_err[idx]
#         rmse = rmses[idx]
#         chi = chis[idx]
        
#         return ak, ak_err, teff, chi, rmse
    
    #############################################
    ### Gonz\'{a}lez Hern\'{a}dez & Bonifacio ###
    #############################################
    
    def ghb_jk_teff(self,jk):
        '''
        Calculate the photometric Teff of a star using Gonz\'{a}lez Hern\'{a}dez & Bonifacio 2009
        
        Input:
        -----
            jk:   float
                  J - K color
                
        Output:
        ------
            teff: float
                  photometric teff
        '''
        b = np.array([0.6517,0.6312,0.0168,-0.0381,0.0256,0.0013])
        
        theta_eff = b[0]+b[1]*jk+b[2]*(jk**2)+b[3]*(jk*self.salfeh)+b[4]*self.salfeh+b[5]*(self.salfeh**2)
        teff = 5040/theta_eff
        self.ghb_teff = teff
        return teff
    
    ########################################################
    ### Separated Magnitudes, Extinctions, Ages & Masses ###
    ########################################################
    
    def teff_2_appmags(self,teff,age,verbose=False):
        '''
        Calculate the expected apparent magnitude of a star
        
        Inputs:
        ------
            teff:     float
                      Teff of star 
                      
            age:      float
                      age of star
        
        Output:
        ------
            calc_mag: 6x2 array 
                      expected intrinsic magnitude for the given temperature
        '''
        
        #Some calcualtions to set things up
        lgteff = np.log10(teff)
        
        extincts = self.ext
        
        lgage = np.log10(age*10**9)
        
        # Figure out if age is actually in the ages given in the isochrone table
        if lgage in self.uniq_ages:
            
            ### pick out a single isochrone 
            aidx, = np.where(self.uniq_ages==lgage)
            iso_ = self.iso[self.age_idx[int(aidx)][0]:self.age_idx[int(aidx)][1]]
            
            if lgteff < min(iso_['logTe']) or lgteff > max(iso_['logTe']):
                return np.array([999999.0, 999999.0, 999999.0, 999999.0, 999999.0])
            
            ### use a spline to get the apparent mags
            calc_mags = 999999.0*np.ones(5)
            for i in range(5):
                mag_spl = utils.bslpine(iso_['logTe'],iso_[self.labels[i]])
                calc_mags[i] = mag_spl(lgteff)+self.distmod+extincts[i]
                
            return calc_mags
            
        else:
            lgage_lo,lgage_hi = self.neighbors(self.uniq_ages,lgage)
            if verbose:
                print('[age_lo,age_hi]: ',[10**lgage_lo/10**9,10**lgage_hi/10**9])
            
            ### Pick out single isochrones
        
            # younger
            aidx_lo, = np.where(self.uniq_ages==lgage_lo)
            iso_lo = self.iso[self.age_idx[int(aidx_lo)][0]:self.age_idx[int(aidx_lo)][1]]
            
            # older
            aidx_hi, = np.where(self.uniq_ages==lgage_hi)
            iso_hi = self.iso[self.age_idx[int(aidx_hi)][0]:self.age_idx[int(aidx_hi)][1]]
            
            ### Temperature Check
            if lgteff < min(iso_lo['logTe']) or lgteff > max(iso_lo['logTe']):
                return np.array([999999.0, 999999.0, 999999.0, 999999.0, 999999.0])
            if lgteff < min(iso_hi['logTe']) or lgteff > max(iso_hi['logTe']):
                return np.array([999999.0, 999999.0, 999999.0, 999999.0, 999999.0])
            
            ### use a b-spline to get the apparent mags
            age_lo = 10**lgage_lo/10**9
            age_hi = 10**lgage_hi/10**9
            calc_mags = 999999.0*np.ones(6)
            for i in range(6):
                mag_spl_lo = utils.bslpine(iso_lo['logTe'],iso_lo[self.labels[i]])
                mag_spl_hi = utils.bslpine(iso_hi['logTe'],iso_hi[self.labels[i]])
                age_spl_interp = np.poly1d(np.squeeze(np.polyfit([age_lo,age_hi],
                                                                 [mag_spl_lo(lgteff),mag_spl_hi(lgteff)],1)))
                calc_mags[i] = age_spl_interp(age)+self.distmod+extincts[i]

            return calc_mags
    
    def get_age(self,guess_ages=np.linspace(0.,17.)[::10],verbose=False):
        '''
        Find best fitting age for a star by searching chisq space given initial guesses for
        age and extinction.
        
        Inputs:
        ------
            guess_ages: array 
                        initial guesses for ages in Gyr

            
        Output:
        ------
            age:        float 
                        best age according to chi^2 space search
                        
            chi:        float
                        best chi^2 according to chi^2 space search
                        
            rms:        float
                        RMSE of the result
            
        '''
            
        # initialize lists
        
        curve_ages = []
        curve_chi = []
        curve_rms = []

        # loop over age and ak space 
        for j in range(len(guess_ages)): 
            try:
                # calculate best fit parameters and covariance matrix
                popt,pcov = curve_fit(self.teff_2_appmags,self.teff,self.phot,p0=guess_ages[j],
                                      bounds=(0.,17.),method='trf',sigma=self.phot_err,
                                      absolute_sigma=True,maxfev=5000)

                # populate lists
                curve_ages.append(popt[0])
                curve_mags = np.asarray(self.teff_2_appmags(self.teff,popt[0]))
                curve_chi.append(sum((curve_mags-self.phot)**2/self.phot_err**2))
                curve_rms.append(np.sqrt(sum((curve_mags-self.phot)**2)/len(curve_mags)))

            except:
                # populate lists
                curve_ages.append(999999.0)
                curve_chi.append(999999.0)
                curve_rms.append(999999.0)
        
        # find smallest chisq value and corresponding age and Ak
        idx = np.asarray(curve_chi).argmin()
        age = np.asarray(curve_ages)[idx]
        chi = np.asarray(curve_chi)[idx]
        rms = np.asarray(curve_rms)[idx]
        self.age = age
        self.chi = chi
        self.rms = rms

        return age, chi, rms
    
    ###########################################
    ### curve_fit Magnitudes, Ages & Masses ###
    ###########################################
    
    def teff_2_absmags_curvefit(self,teff,ak,age):
        '''
        Calculate the intrinsic absolute magnitude of a star
        
        Inputs:
        ------
            teff:      float
                       temperature of a star
            
            ak:        float 
                       K band extinction of a star
            
            age:       float 
                       age of a star in Gyr
        
        Output:
        ------
            calc_mags: 5x1 array 
                       calculated absolute magnitudes
        '''
        
        #Some calcualtions to set things up
        lgteff = np.log10(teff)
        lgage = np.log10(age*10**9)
        
        abpak = (self.ccm_alav(self.leff['G_BPEDR3mag'])/self.ccm_alav(self.leff['Ksmag']))
        arpak = (self.ccm_alav(self.leff['G_RPEDR3magmag'])/self.ccm_alav(self.leff['Ksmag']))
        ajak = (self.ccm_alav(self.leff['Jmag'])/self.ccm_alav(self.leff['Ksmag']))
        ahak = (self.ccm_alav(self.leff['Hmag'])/self.ccm_alav(self.leff['Ksmag']))
        
        extincts = np.array([abpak,arpak,ajak,ahak,1.0])*ak
        
        # Figure out if age is actually in the ages given in the isochrone table
        if lgage in self.uniq_ages:
            
            ### pick out a single isochrone 
            aidx, = np.where(self.uniq_ages==lgage)
            iso_ = self.iso[self.age_idx[int(aidx)][0]:self.age_idx[int(aidx)][1]]
            
            if lgteff < min(iso_['logTe']) or lgteff > max(iso_['logTe']):
                return np.array([999999.0, 999999.0, 999999.0, 999999.0, 999999.0])
            
            ### sort so temp is always increasing
            sidx = np.argsort(iso_['logTe'])
            slogTe = iso_['logTe'][sidx]
            _, uidx = np.unique(slogTe,return_index=True)
            slogTe = slogTe[uidx]
            
            ### use a spline to get the apparent mags
            calc_mags = 999999.0*np.ones(5)
            for i in range(5):
                mag_spl = utils.bslpine(slogTe,iso_[self.labels[i]][sidx][uidx])
                calc_mags[i] = mag_spl(lgteff)+extincts[i]
                
            return calc_mags
            
        else:
            lgage_lo,lgage_hi = self.neighbors(self.uniq_ages,lgage)
            if verbose:
                print('[age_lo,age_hi]: ',[10**lgage_lo/10**9,10**lgage_hi/10**9])
            
            ### Pick out single isochrones
        
            # younger
            aidx_lo, = np.where(self.uniq_ages==lgage_lo)
            iso_lo = self.iso[self.age_idx[int(aidx_lo)][0]:self.age_idx[int(aidx_lo)][1]]
            
            # older
            aidx_hi, = np.where(self.uniq_ages==lgage_hi)
            iso_hi = self.iso[self.age_idx[int(aidx_hi)][0]:self.age_idx[int(aidx_hi)][1]]
            
            ### Temperature Check
            if lgteff < min(iso_lo['logTe']) or lgteff > max(iso_lo['logTe']):
                return np.array([999999.0, 999999.0, 999999.0, 999999.0, 999999.0])
            if lgteff < min(iso_hi['logTe']) or lgteff > max(iso_hi['logTe']):
                return np.array([999999.0, 999999.0, 999999.0, 999999.0, 999999.0])
            
            ### sort so temp is always increasing
            
            # younger
            sidx_lo = np.argsort(iso_lo['logTe'])
            slogTe_lo = iso_lo['logTe'][sidx_lo]
            _, uidx_lo = np.unique(slogTe_lo,return_index=True)
            slogTe_lo = slogTe_lo[uidx_lo]
            
            # older
            sidx_hi = np.argsort(iso_hi['logTe'])
            slogTe_hi = iso_hi['logTe'][sidx_hi]
            _, uidx_hi = np.unique(slogTe_hi,return_index=True)
            slogTe_hi = slogTe_hi[uidx_hi]
            
            ### use a spline to get the apparent mags
            age_lo = 10**lgage_lo/10**9
            age_hi = 10**lgage_hi/10**9
            calc_mags = 999999.0*np.ones(5)
            for i in range(5):
                mag_spl_lo = utils.bslpine(slogTe_lo,iso_lo[self.labels[i]][sidx_lo][uidx_lo])
                mag_spl_hi = utils.bslpine(slogTe_hi,iso_hi[self.labels[i]][sidx_hi][uidx_hi])
                age_spl_interp = interp1d(np.squeeze(np.polyfit([age_lo,age_hi],
                                                                [mag_spl_lo(lgteff),mag_spl_hi(lgteff)],1)))
                calc_mags[i] = age_spl_interp(age)+extincts[i]

            return calc_mags
    
    def get_ak_age_curvefit(self,guess_exts=np.linspace(0.,0.75,num=5),guess_ages=np.linspace(0.,17.,num=5)):
        # make on order 10 for guesses
        '''
        Find best fitting age and Ak values for a star by searching chisq space given initial guesses for
        age and extinction.
        
        Inputs:
        ------
            guess_ages: array
                        initial guesses for ages in Gyr
            
            guess_Exts: array 
                        initial guesses for the K band extinction
            
        Output:
        ------
            ak:         float
                        K band extinction of star
                        
            age:        float
                        age of star
                        
            chi:        float  
                        final chi squared value
                        
            rms:        float 
                        root mean squared value
            
        '''
            
        # initialize lists
        
        curve_aks = 999999.0*np.ones((len(guess_exts),len(guess_ages)))
        curve_ages = 999999.0*np.ones((len(guess_exts),len(guess_ages)))
        curve_chis = 999999.0*np.ones((len(guess_exts),len(guess_ages)))
        curve_rms = 999999.0*np.ones((len(guess_exts),len(guess_ages)))

        # loop over age and ak space 
        for i in range(len(guess_exts)):
            for j in range(len(guess_ages)): 
                try:
                    # calculate best fit parameters and covariance matrix
                    popt,_ = curve_fit(self.teff_2_absmags_curvefit,self.teff,self.absphot,
                                       p0=[guess_exts[i],guess_ages[j]],bounds=((0.,0.),(0.75,17.)),
                                       method='trf',sigma=self.phot_err,absolute_sigma=True,maxfev=5000)

                    # populate lists
                    curve_aks[i][j] = popt[0]
                    curve_ages[i][j] = popt[1]
                    curve_mags = np.asarray(self.teff_2_absmags_curvefit(self.teff,popt[0],popt[1]))
                    curve_chis[i][j]=sum((curve_mags-self.absphot)**2/self.phot_err**2)
                    curve_rms[i][j] = np.sqrt(sum(np.square(curve_mags-self.absphot))/len(curve_mags))

                except:
                    #populate lists
                    curve_aks[i][j] = 999999.0
                    curve_ages[i][j] = 999999.0
                    curve_chis[i][j] = 999999.0
                    curve_rms[i][j] = 999999.0
        
        # find smallest chisq value and corresponding age and Ak
        idx = np.argmin(curve_chis)
        ak = curve_aks.flatten()[idx]
        age = curve_ages.flatten()[idx]
        chi = curve_chis.flatten()[idx]
        rms = curve_rms.flatten()[idx]
        
        self.ak = ak
        self.age = age
        self.chi = chi
        self.rms = rms
        
        return ak,age,chi,rms
    
    ##############
    ### Masses ###
    ##############
    
    def get_mass(self,age):
        '''
        Calculate the mass of a star from its age
        
        Input:
        -----
            age:      2x1 array
                      age of a star in Gyr and associated error

        Output:
        ------
            mass:     float 
                      mass of star in solar masses
            
            mass_err: float
                      error in the calculated mass of the star
        '''

        if self.age == 999999.0:
            return np.array([999999.0,999999.0])
        
        teffcut = np.where((self.iso['logTe']<np.log10(self.teff+200.))&
                           (self.iso['logTe']>np.log10(self.teff-200.)))
        iso_ = self.iso[teffcut]
        
        if np.size(iso_) < 2:
            return np.array([999999.0,999999.0])
        
        ### calculate the mass and error using interpolation
        bspl = utils.bspline(iso_['logAge'],iso_['Mass'])
        mass = bspl(np.log10(age[0]*10**9))
        mass_deriv = bspl.derivative()(np.log10(age[0]*10**9))
        mass_err = np.abs(mass_deriv)*age[1]
        
        self.mass = np.array([mass,mass_err])
        
        return mass, mass_err
        
    def mass_2_age(self,mass):
        '''
        Calculate the age of a star given a mass using PARSEC isochrones

        Input:
        -----
            mass: float
                  mass of star in solar masses

        Output:
        ------
            age:  float
                  age of star in Gyr
        '''
        
#         try:
        mass_lo,mass_hi = neighbors(self.iso['Mass'],mass)

        iso_lo = self.iso[np.where(self.iso['Mass']==mass_lo)]
        iso_hi = self.iso[np.where(self.iso['Mass']==mass_hi)]

        # younger
        spl_lo = utils.bspline(iso_lo['MH'],10**iso_lo['logAge']/10**9)

        # older
        spl_hi = utils.bspline(iso_hi['MH'],10**iso_hi['logAge']/10**9)

        # final spline
        final_spl = interp1d([mass_lo,mass_hi],[spl_lo(self.salfeh),spl_hi(self.salfeh)])
        age = final_spl(mass)
        return age


# Extra Functions

In [3]:
from tqdm import tqdm_notebook

leff = {'BP':0.5387,'G':0.6419,'RP':0.7667,'J':1.2345,'H':1.6393,'K':2.1757}
def ccm_a(x):
    '''
    a(x) function from Cardelli et al. 1989
    
    Input:
    -----
        x: effective wavelength in units of 1/micron
        
    Output:
    ------
        a: a function value  
    '''
    if 0.3 <= x < 1.1:
        a = 0.574*(x**1.61)
        return a
    
    elif 1.1 <= x < 3.3:
        y = x - 1.82
        a = (1.+0.17699*y-0.50477*(y**2)-0.02427*(y**3)+0.72085*(y**4)+
                0.01979*(y**5)-0.77530*(y**6)+0.32999*(y**7))
        return a
    
    elif 3.3 <= x < 8.0:
        if x < 5.9:
            a = 1.752-0.136*x-0.104/((x-4.67)**2+0.341)
            return a
        
        else:
            fa = -0.04473*((x-5.9)**2)+0.1207*((x-5.9)**3)
            a = 1.752-0.136*x-0.104/((x-4.67)**2+0.341)+fa
            return a       
    
def ccm_b(x):
    '''
    b(x) function from Cardelli et al. 1989
    
    Input:
    -----
        x: effective wavelength in units of 1/micron
        
    Output:
    ------
        b: b function value 
    '''
    if 0.3 <= x < 1.1:
        b = -0.527*(x**1.61)
        return b
    
    elif 1.1 <= x <= 3.3:
        y = x - 1.82
        b = (1.41338*y+2.28305*(y**2)+1.07233*(y**3)-5.38434*(y**4)-
                0.62251*(y**5)+5.30260*(y**6)-2.09002*(y**7))
        return b
    
    elif 3.3 <= x < 8.0:
        if x < 5.9:
            b = -3.090+1.825*x+1.206/((x-4.62)**2+0.263)
            return b
        
        else:
            fb = 0.2130*((x-5.9)**2)+0.1207*((x-5.9)**3)
            b = -3.090+1.825*x+1.206/((x-4.62)**2+0.263)+fb
            return b
    
def ccm_alav(wave,rv):
    '''
    Calculate A\lambda/Av
    
    Inputs:
    ------
        wave: effective wavelength in units of micron
        rv: Rv value (=Av/E(B_V))
        
    Output:
    ------
        alav: A\lambda/Av
    '''
    x=1/wave
    alav = ccm_a(x)+ccm_b(x)/rv
    return alav

ejk_ak = (ccm_alav(leff['J'],3.1)-ccm_alav(leff['K'],3.1))/ccm_alav(leff['K'],3.1)
ebv_ak = (ccm_alav(0.445,3.1)-ccm_alav(0.551,3.1))/ccm_alav(leff['K'],3.1)

def fitz_alebv(wave):
    '''
    Fitzpatrick 1999 extinction law
    '''
    recip_anchors = np.array([0.000,0.377,0.820,1.667,1.828,2.141,2.433,3.704,3.846])#[:-2]
    alebv_anchors = np.array([0.000,0.265,0.829,2.688,3.055,3.806,4.315,6.265,6.591])#[:-2]

    x=1/wave
    spl = interp1d(recip_anchors,alebv_anchors)
    alebv = spl(x)
    return alebv

def closest(data,value):
    '''
    Find nearest value in array to given value
        
    Inputs:
    ------
        data: data to search through 
        value: value of interest
    '''
        
    data = np.asarray(data)
#     data = data[np.where(np.isfinite(data)==True)]
    return data[(np.abs(np.subtract(data,value))).argmin()]

def neighbors(data,value):
    '''
    Find values of two elements closest to the given value

    Inputs:
    ------
        data: data to search through 
        value: value of interest

    Output:
    ------
        close1: closest value under the given value
        close2: closest value over the given value
    '''

    data = np.asarray(data)
    close1 = data[(np.abs(np.subtract(data,value))).argmin()]
    data = data[np.where(data!=close1)]
    close2 = data[(np.abs(np.subtract(data,value))).argmin()]

    return close1,close2

def mad(dat):
    return np.nanmedian(np.abs(dat-np.nanmedian(dat)))

def mass_2_age_PARSEC(mass,salfeh,isochrones):
    '''
    Inputs:
    ------
        mass: star mass in Msun
        salfeh: Salaris corrected [Fe/H]
        isochrones: table of PARSEC isochrones
    
    Output:
    ------
        age: age in Gyr
    '''
    
    feh_lo,feh_hi = neighbors(isochrones['MH'],salfeh)
    
    iso_lo = isochrones[np.where(isochrones['MH']==feh_lo)]
    iso_hi = isochrones[np.where(isochrones['MH']==feh_hi)]
    
    try:
        # lower [Fe/H]
        sidx_lo = np.argsort(iso_lo['Mass'])
        smass_lo = iso_lo['Mass'][sidx_lo]
        _,uidx_lo = np.unique(smass_lo,return_index=True)
        smass_lo = smass_lo[uidx_lo]
        sage_lo = 10**iso_lo['logAge'][sidx_lo][uidx_lo]/10**9

        spl_lo = interp1d(smass_lo,sage_lo)

        # higher [Fe/H]
        sidx_hi = np.argsort(iso_hi['Mass'])
        smass_hi = iso_hi['Mass'][sidx_hi]
        _,uidx_hi = np.unique(smass_hi,return_index=True)
        smass_hi = smass_hi[uidx_hi]
        sage_hi = 10**iso_hi['logAge'][sidx_hi][uidx_hi]/10**9

        spl_hi = interp1d(smass_hi,sage_hi)

        # final spline
        final_spl = interp1d([feh_lo,feh_hi],[spl_lo(mass),spl_hi(mass)])
        age = final_spl(salfeh)
        return age
    except:
        return 999999.0

# PARSEC

In [37]:
massive = Table(fits.getdata('/Users/joshuapovick/Desktop/Research/parsec/parsec36_DR2_EDR3.fits'))
# massive = Table(massive[np.where(massive['label']==3.0)])
# massive['index'] = np.arange(len(massive))
massive = massive[np.argsort(massive['logAge'])]
massive = massive['MH','Mass','logAge','logTe','logg','GEDR3mag','G_BPEDR3mag','G_RPEDR3mag',
                  'Jmag','Hmag','Ksmag']

# ICR

In [38]:
allicr = fits.getdata('/Users/joshuapovick/Desktop/Research/MS_Young/magstream_youngstars_gaiaedr32mass.060122.fits')

In [39]:
allicr.columns

ColDefs(
    name = 'NAME'; format = '6A'
    name = 'RA'; format = 'D'
    name = 'DEC'; format = 'D'
    name = 'GLON'; format = 'D'
    name = 'GLAT'; format = 'D'
    name = 'MLON'; format = 'D'
    name = 'MLAT'; format = 'D'
    name = 'GAIA_ID'; format = '22A'
    name = 'GAIA_G'; format = 'E'
    name = 'GAIA_BP'; format = 'E'
    name = 'GAIA_RP'; format = 'E'
    name = 'GAIA_PARALLAX'; format = 'E'
    name = 'GAIA_E_PARALLAX'; format = 'E'
    name = 'GAIA_PMRA'; format = 'E'
    name = 'GAIA_E_PMRA'; format = 'E'
    name = 'GAIA_PMDEC'; format = 'E'
    name = 'GAIA_E_PMDEC'; format = 'E'
    name = 'GAIA_DISTANCE'; format = 'E'
    name = 'GAIA_E_DISTANCE'; format = 'E'
    name = 'V'; format = 'E'
    name = 'BV'; format = 'E'
    name = 'EBV'; format = 'E'
    name = 'MIKE_VHELIO'; format = 'E'
    name = 'MIKE_E_VHELIO'; format = 'E'
    name = 'MIKE_VLSR'; format = 'E'
    name = 'MIKE_VGSR'; format = 'E'
    name = 'MIKE_TEFF'; format = 'E'
    name = 'MIKE_E_TEFF';

# Gaia EDR3

In [40]:
### Replace nan's with zeropoint error
bperr = np.nan_to_num(allicr['GAIAEDR3_BPERR'],nan=0.0027901700)
gerr = np.nan_to_num(allicr['GAIAEDR3_GERR'],nan=0.0027553202)
rperr = np.nan_to_num(allicr['GAIAEDR3_RPERR'],nan=0.0037793818)

# Calculate Ages

In [47]:
icr_ext = 999999.0*np.ones((len(allicr),6))
icr_ext_err = 999999.0*np.ones((len(allicr),6))
icr_age = 999999.0*np.ones(len(allicr))
icr_chi = 999999.0*np.ones(len(allicr))
icr_rms = 999999.0*np.ones(len(allicr))
icr_mass = 999999.0*np.ones(len(allicr))
icr_mass_err = 999999.0*np.ones(len(allicr))

for i in tqdm_notebook(range(len(allicr))):
    
    # Initialize Aetas 
    te = np.array([allicr['BEST_TEFF'][i],allicr['BEST_E_TEFF'][i]])
    ab = np.array([[allicr['BEST_METAL'][i],0.0],
                   [allicr['BEST_E_METAL'][i],0.0]]).T
    op = np.array([[allicr['GAIAEDR3_BPMAG'][i],
                    allicr['GAIAEDR3_GMAG'][i],
                    allicr['GAIAEDR3_RPMAG'][i],
                    allicr['JMAG'][i],allicr['HMAG'][i],
                    allicr['KSMAG'][i]],
                   [bperr[i],gerr[i],rperr[i],allicr['JERR'][i],
                    allicr['HERR'][i],allicr['KSERR'][i]]]).T
    di = (1/allicr['GAIAEDR3_PARALLAX'][i])*1000

    CalcAge = Aetas(te,ab,op,di,massive)
    
    exts = CalcAge.extinction()
    icr_ext[i,:],icr_ext[i,:] = exts[:,0],exts[:,1]
    icr_age[i],icr_chi[i],icr_rms[i] = CalcAge.get_age()
    
    print('-------------------')

    icr_mass[i],icr_mass_err[i] = CalcAge.get_mass(icr_age[i])
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm_notebook(range(len(allicr))):


HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))

[0.21365738 0.40678406 1.0328865  1.1928864  1.1248865 ]
[0.04324441 0.02619823 0.0550027  0.09855606 0.15634993]
-------------------
[ 4.0218353e-02  6.4577103e-02 -8.2952148e+01 -8.2952148e+01
 -8.2952148e+01]
[7.7417968e-03 7.3279906e-03 9.9899998e+00 9.9899998e+00 9.9899998e+00]
-------------------
[  1.1268234   1.0268784 -79.95391   -79.95391   -79.95391  ]
[0.13860199 0.04169502 9.990001   9.990001   9.990001  ]
-------------------
[ 4.3584824e-02  4.3918610e-02 -8.2795822e+01 -8.2795822e+01
 -8.2795822e+01]
[6.332996e-03 6.114058e-03 9.990000e+00 9.990000e+00 9.990000e+00]
-------------------
[-0.00905037 -0.04155922  0.13014221  0.32814217 -0.38785744]
[0.00557211 0.00856651 0.13400446        inf        inf]
-------------------


  obs_colors_err = np.delete(np.sqrt(self.phot_err[:]**2+self.phot_err[1]**2),1)


[-0.01161003 -0.04424477  0.2198677   0.7838669  -0.1381321 ]
[0.00601186 0.00755195 0.14700328 0.21500224        inf]
-------------------


  self.distmod = 5.0*np.log10(distance)-5.0


[-3.9329529e-02 -1.0694885e-01 -8.2852394e+01 -8.2852394e+01
 -8.2852394e+01]
[5.008427e-03 9.393211e-03 9.990000e+00 9.990000e+00 9.990000e+00]
-------------------
[-0.02789497 -0.04609871 -0.04076958  0.49623108  0.6482315 ]
[0.00310177 0.00379229 0.11100252        inf        inf]
-------------------
[-0.03507328 -0.07432938 -0.1805811  -0.14358139 -0.07058048]
[0.00119854 0.00115026 0.0280063  0.0410043  0.06900255]
-------------------
[-0.01966476 -0.03941536 -0.11248779 -0.20048809 -0.19948769]
[0.00077257 0.00068584 0.02200236 0.03300157 0.03500148]
-------------------
[-7.5777054e-02 -1.5547752e-01 -8.2781883e+01 -8.2781883e+01
 -8.2781883e+01]
[6.1941994e-03 9.4536832e-03 9.9899998e+00 9.9899998e+00 9.9899998e+00]
-------------------
[0.14259338 0.24210167 0.46204948 0.93704987 0.3280487 ]
[0.00686282 0.00724859 0.15800223 0.20300172        inf]
-------------------
[-0.08298302 -0.1577053  -0.34553337 -0.686533    0.30746746]
[0.00413593 0.0029761  0.1420018         inf        

[ -0.12174416  -0.25867462 -83.47793    -83.47793    -83.47793   ]
[4.8655267e-03 6.7533837e-03 9.9899998e+00 9.9899998e+00 9.9899998e+00]
-------------------
[-0.10884476 -0.2247715  -0.74749756 -0.27049828  0.7795019 ]
[0.00507289 0.00338646 0.15400359        inf        inf]
-------------------
[0.3519678  0.52070236 1.1331291  1.4951296  1.5451298 ]
[0.00123199 0.00130854 0.02500091 0.03000075 0.03700061]
-------------------
[ -0.11590385  -0.2443428  -83.51489    -83.51489    -83.51489   ]
[3.987954e-03 8.072807e-03 9.990000e+00 9.990000e+00 9.990000e+00]
-------------------
[-2.9907227e-02 -7.3873520e-02 -8.2912575e+01 -8.2912575e+01
 -8.2912575e+01]
[5.8489437e-03 9.6444208e-03 9.9899998e+00 9.9899998e+00 9.9899998e+00]
-------------------
[ 0.00957584  0.0077858   0.03284645 -0.17815399  0.0138464 ]
[0.00174947 0.00232096 0.03900728 0.1050027  0.15500183]
-------------------
[0.06666279 0.13551617 0.35001278 0.393013   0.43501282]
[0.0007015  0.0006176  0.02000258 0.02400216 0.0

In [42]:
for i in range(len(allicr)):
    print('########## {} ##########'.format(allicr['NAME'][i]))
    print(icr_ext[i])
    print(icr_ext_err[i])

########## DI1020 ##########
[0.16997277 0.13952233 0.1074299  0.04780287 0.0302799  0.01919642]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1085 ##########
[14.83437512 12.17681279  9.3759457   4.17199565  2.64267905  1.67536755]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1294 ##########
[14.8578872  12.19611271  9.39080632  4.17860815  2.64686762  1.67802296]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1304 ##########
[14.83399685 12.17650229  9.37570662  4.17188926  2.64261166  1.67532483]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1316 ##########
[inf inf inf inf inf inf]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1340 ##########
[inf inf inf inf inf inf]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1353 ##########
[14.83429677 12.17674848  9.37589618  4.17197361  2.64266509  1.6753587 ]
[999999. 999999. 999999. 999999. 999999. 999999.]
########## DI1354 ##########
[i

In [43]:
icr_ext[np.where(allicr['GAIAEDR3_PARALLAX']/allicr['GAIAEDR3_PARALLAX_ERROR']>3)]

array([[1.69972765e-01, 1.39522327e-01, 1.07429899e-01, 4.78028653e-02,
        3.02799047e-02, 1.91964173e-02],
       [           inf,            inf,            inf,            inf,
                   inf,            inf],
       [7.12774831e-02, 5.85081988e-02, 4.50503514e-02, 2.00459640e-02,
        1.26977719e-02, 8.04995029e-03],
       [4.56580745e-02, 3.74784797e-02, 2.88578133e-02, 1.28408030e-02,
        8.13378629e-03, 5.15654053e-03],
       [1.22328201e-01, 1.00413236e-01, 7.73165410e-02, 3.44033853e-02,
        2.17922340e-02, 1.38155262e-02],
       [7.97294000e-02, 6.54459639e-02, 5.03923164e-02, 2.24229674e-02,
        1.42034439e-02, 9.00449453e-03],
       [8.51186085e-02, 6.98697015e-02, 5.37985216e-02, 2.39386197e-02,
        1.51635078e-02, 9.61314201e-03],
       [           inf,            inf,            inf,            inf,
                   inf,            inf],
       [1.38904911e-01, 1.14020246e-01, 8.77937151e-02, 3.90653923e-02,
        2.47453024e-02, 

In [44]:
icr_age

array([999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 999999., 999999., 999999., 999999., 999999., 999999.,
       999999., 9999