# Aetas Code

In [4]:
#Astropy
import astropy
from astropy.io import fits
from astropy.table import Table

#emcee
import emcee
import corner as corner
import time

# Matplotlib
import matplotlib
import matplotlib.pyplot as plt
# %matplotlib inline
matplotlib.rcParams.update({'font.size': 25})

#Numpy/Scipy
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline as IUS
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit

class Aetas():
    '''
    A class to calculate a star's age and extinction using PARSEC isochrones and the extinction law
    of Cardelli, Clayton & Mathis (1989) by fitting Gaia (BP, RP) and 2MASS (J, H, K) photometry.

    Calculations that fail or fall outside the isochrone grid are flagged with the value 999999.0.
    '''
    def __init__(self,teff,abund,obsphot,distance,isochrones,rv=3.1):
        
        '''
        teff: [array] Teff and its error, np.array([teff,teff_err])
        abund: [2x2 array] rows are [M/H] and [Alpha/M]; first column is the value, second the error
        obsphot: [5x2 array] first column is apparent GBP,GRP,J,H,K and the second column their errors
        distance: [float] distance to star in pc
        isochrones: [astropy Table] PARSEC isochrone table
        rv: [float] Rv value (=Av/E(B-V))
        '''
        
        # Observed quantities
        self.teff = teff[0] # temperature
        self.teff_err = teff[1] # temperature error
        
        # Salaris-corrected [Fe/H] using Asplund et al. 2021 solar abundances
        # (the Asplund 2009 version used 0.655 and 0.345 in place of 0.659 and 0.341)
        self.salfeh = abund[0,0]+np.log10(0.659*(10**(abund[1,0]))+0.341)
        self.salfeh_err = np.sqrt(abund[0,1]**2+((1-0.341/(0.659*(10**(abund[1,0]))+0.341))*abund[1,1])**2)
        self.phot = obsphot[:,0] # photometry
        self.phot_err = obsphot[:,1] # photometry errors
        
        # Distance modulus
        self.distance = distance
        self.distmod = 5.0*np.log10(distance)-5.0
        
        # Absolute magnitudes, not dereddened
        self.absphot = self.phot-self.distmod
        
        # PARSEC isochrones
        self.rv = rv
        self.labels = ['G_BPmag','G_RPmag','Jmag','Hmag','Ksmag']
        self.phot_cov = np.cov(np.array([isochrones[label] for label in self.labels]))
        
        # keep only the isochrones at the grid [M/H] closest to the Salaris-corrected [Fe/H]
        self.iso = isochrones[np.where(isochrones['MH']==self.closest(isochrones['MH'],self.salfeh))]
        self.uniq_ages = np.unique(self.iso['logAge'])
        
        # first and last (inclusive) row index of each age in self.iso
        age_idx = []
        for lgage in self.uniq_ages:
            rows, = np.where(self.iso['logAge']==lgage)
            age_idx.append(np.array([min(rows),max(rows)]))
        
        self.age_idx = np.asarray(age_idx)
                
        # Effective wavelengths of the passbands in microns
        self.leff = {'G_BPmag':0.5387,'Gmag':0.6419,'G_RPmag':0.7667,'Jmag':1.2345,'Hmag':1.6393,'Ksmag':2.1757}
        
    #################
    ### Utilities ###
    #################
        
    def closest(self,data,value):
        '''
        Find nearest value in array to given value
        
        Inputs:
        ------
            data: data to search through 
            value: value of interest

        Output:
        ------
            the element of data closest to value (the first such element on ties)
        '''
        
        data = np.asarray(data)
    
        return data[(np.abs(np.subtract(data,value))).argmin()]
    
    def neighbors(self,data,value):
        '''
        Find the two distinct values in data that bracket the given value
    
        Inputs:
        ------
            data: data to search through (must contain at least two distinct values)
            value: value of interest
        
        Output:
        ------
            close1: closest value at or below the given value
            close2: closest value above the given value

        If value lies outside the range of data, the two distinct values nearest to it (both on the
        same side of it) are returned instead, in increasing order.

        Raises:
        ------
            ValueError: if data has fewer than two distinct values
        '''
    
        uniq = np.unique(np.asarray(data)) # sorted, distinct
        if uniq.size < 2:
            raise ValueError('neighbors needs at least two distinct values in data')
        
        # uniq[idx-1] <= value < uniq[idx]; clipping handles values at or beyond either end
        idx = int(np.clip(np.searchsorted(uniq,value,side='right'),1,uniq.size-1))
    
        return uniq[idx-1],uniq[idx]
    
    def mad(self,data):
        '''
        Calculate the median absolute deviation of the data, ignoring NaNs
        '''
        return np.nanmedian(np.abs(data-np.nanmedian(data)))

    def _isochrone(self,lgage):
        '''
        Return the rows of self.iso with logAge == lgage (one single-age isochrone).

        A boolean mask is used instead of slicing between self.age_idx bounds so that the last row
        of each isochrone is kept and the table does not need to be sorted by age.
        '''
        return self.iso[self.iso['logAge']==lgage]

    def _teff_order(self,iso_):
        '''
        Return row indices that sort iso_ by increasing logTe, keeping the first row for each
        repeated logTe (splines need strictly increasing abscissae).
        '''
        sidx = np.argsort(iso_['logTe'])
        _, uidx = np.unique(np.asarray(iso_['logTe'])[sidx],return_index=True)
        return sidx[uidx]
    
    ##################
    ### Extinction ###
    ##################
    
    def ccm_a(self,x):
        '''
        a(x) function from Cardelli, Clayton & Mathis (1989), eqs. 2a, 3a and 4a
    
        Input:
        -----
            x: inverse effective wavelength in units of 1/micron (0.3 <= x <= 8)
        
        Output:
        ------
            a: a function value

        Raises:
        ------
            ValueError: if x is outside the range covered by the law
        '''
        if 0.3 <= x < 1.1:
            # infrared
            return 0.574*(x**1.61)
    
        if 1.1 <= x <= 3.3:
            # optical / near-infrared
            y = x - 1.82
            return (1.+0.17699*y-0.50447*(y**2)-0.02427*(y**3)+0.72085*(y**4)+
                    0.01979*(y**5)-0.77530*(y**6)+0.32999*(y**7))
    
        if 3.3 < x <= 8.0:
            # ultraviolet; the Fa curvature term only applies shortward of x = 5.9
            a = 1.752-0.316*x-0.104/((x-4.67)**2+0.341)
            if x >= 5.9:
                a += -0.04473*((x-5.9)**2)-0.009779*((x-5.9)**3)
            return a

        raise ValueError('x = {} is outside the 0.3-8 1/micron range of the CCM law'.format(x))
    
    def ccm_b(self,x):
        '''
        b(x) function from Cardelli, Clayton & Mathis (1989), eqs. 2b, 3b and 4b
    
        Input:
        -----
            x: inverse effective wavelength in units of 1/micron (0.3 <= x <= 8)
        
        Output:
        ------
            b: b function value

        Raises:
        ------
            ValueError: if x is outside the range covered by the law
        '''
        if 0.3 <= x < 1.1:
            # infrared
            return -0.527*(x**1.61)
    
        if 1.1 <= x <= 3.3:
            # optical / near-infrared
            y = x - 1.82
            return (1.41338*y+2.28305*(y**2)+1.07233*(y**3)-5.38434*(y**4)-
                    0.62251*(y**5)+5.30260*(y**6)-2.09002*(y**7))
    
        if 3.3 < x <= 8.0:
            # ultraviolet; the Fb curvature term only applies shortward of x = 5.9
            b = -3.090+1.825*x+1.206/((x-4.62)**2+0.263)
            if x >= 5.9:
                b += 0.2130*((x-5.9)**2)+0.1207*((x-5.9)**3)
            return b

        raise ValueError('x = {} is outside the 0.3-8 1/micron range of the CCM law'.format(x))
    
    def ccm_alav(self,wave):
        '''
        Calculate A_lambda/A_V = a(x) + b(x)/Rv
    
        Inputs:
        ------
            wave: effective wavelength in units of micron
        
        Output:
        ------
            alav: A_lambda/A_V
        '''
        x=1/wave
        alav = self.ccm_a(x)+self.ccm_b(x)/self.rv
        return alav

    def _ak_ratios(self):
        '''
        Return A_lambda/A_Ks for BP, RP, J, H and Ks (the last entry is 1 by definition).
        '''
        alav_k = self.ccm_alav(self.leff['Ksmag'])
        return np.array([self.ccm_alav(self.leff[label])/alav_k for label in self.labels[:-1]]+[1.0])
    
    def extinction(self):
        '''
        Calculate the K band extinction and its uncertainty

        For each color X-Ks (X = BP, RP, J, H), a straight line in Teff is fit to the isochrone
        colors of points within 100 K of the star's Teff. The star's color excess over that line,
        divided by E(X-Ks)/A_Ks from the CCM law, is one estimate of A_Ks. The four estimates are
        combined with an inverse-variance weighted mean.

        Output:
        ------
            ak: K band extinction (also stored as self.ak)
            ak_err: uncertainty in ak (also stored as self.ak_err)
        '''

        ### Color excess to K band extinction ratios, E(X-Ks)/A_Ks, from Cardelli et al.
        alav_k = self.ccm_alav(self.leff['Ksmag'])
        e_k = np.array([(self.ccm_alav(self.leff[label])-alav_k)/alav_k for label in self.labels[:-1]])

        ### Isochrone points with temperatures within 100 K of the star's Teff, sorted by Teff
        teffcut = np.where((self.iso['logTe']<np.log10(self.teff+100.))&
                           (self.iso['logTe']>np.log10(self.teff-100.)))
        iso_ = self.iso[teffcut]
        order = self._teff_order(iso_)
        iso_teffs = 10**iso_['logTe'][order]

        calc_aks = 999999.0*np.ones(4)
        # slopes d(color)/d(Teff); the sentinel makes failed fits' errors huge, so their weight ~0
        coeff0s = 999999.0*np.ones(4)
        for i,label in enumerate(self.labels[:-1]):
            try:
                coeff = np.polyfit(iso_teffs,(iso_[label]-iso_['Ksmag'])[order],1)
            except (TypeError,ValueError,np.linalg.LinAlgError):
                # e.g. no isochrone points near this Teff: keep the 999999.0 sentinels
                continue
            color_line = np.poly1d(coeff)
            calc_aks[i] = ((self.phot[i]-self.phot[-1])-color_line(self.teff))/e_k[i]
            coeff0s[i] = coeff[0]

        # the X-Ks color carries the errors of both X and Ks; the Teff error enters via the slope
        calc_aks_err = np.sqrt((np.square(self.phot_err[:-1])+np.square(self.phot_err[-1])+
                                np.square(coeff0s*self.teff_err))/np.square(e_k))

        ### Weighted Mean
        wgts = np.square(np.reciprocal(calc_aks_err))
        wgts_sum = np.sum(wgts)

        ak = np.sum(np.multiply(calc_aks,wgts))/wgts_sum
        ak_err = np.sqrt(np.reciprocal(wgts_sum))
        
        self.ak = ak
        self.ak_err = ak_err
        
        return ak, ak_err
    
    #################################
    ### Magnitudes, Ages & Masses ###
    #################################

    def _intrinsic_mags(self,age,lgteff,verbose=False):
        '''
        Interpolate the isochrone magnitudes (in self.labels order) at log10(Teff) = lgteff and the
        given age in Gyr. No distance modulus or extinction is applied.

        For an age on the grid, a spline in logTe is evaluated directly. Otherwise the splines of
        the two neighboring grid ages are combined with a straight line in age (Gyr).

        Returns None when lgteff is outside the logTe range of an isochrone that is used.
        '''
        lgage = np.log10(age*10**9)
        
        # age is on the grid: use a single isochrone
        if lgage in self.uniq_ages:
            iso_ = self._isochrone(lgage)
            if lgteff < np.min(iso_['logTe']) or lgteff > np.max(iso_['logTe']):
                return None
            order = self._teff_order(iso_)
            slogTe = iso_['logTe'][order]
            return np.array([IUS(slogTe,iso_[label][order])(lgteff) for label in self.labels])
        
        # otherwise interpolate between the two neighboring ages
        lgage_lo,lgage_hi = self.neighbors(self.uniq_ages,lgage)
        if verbose:
            print('[age_lo,age_hi]: ',[10**lgage_lo/10**9,10**lgage_hi/10**9])

        iso_lo = self._isochrone(lgage_lo)
        iso_hi = self._isochrone(lgage_hi)

        # the star's Teff must be covered by both isochrones
        for iso_ in (iso_lo,iso_hi):
            if lgteff < np.min(iso_['logTe']) or lgteff > np.max(iso_['logTe']):
                return None

        order_lo = self._teff_order(iso_lo)
        order_hi = self._teff_order(iso_hi)
        slogTe_lo = iso_lo['logTe'][order_lo]
        slogTe_hi = iso_hi['logTe'][order_hi]

        age_lo = 10**lgage_lo/10**9
        age_hi = 10**lgage_hi/10**9
        mags = np.empty(len(self.labels))
        for i,label in enumerate(self.labels):
            mag_lo = IUS(slogTe_lo,iso_lo[label][order_lo])(lgteff)
            mag_hi = IUS(slogTe_hi,iso_hi[label][order_hi])(lgteff)
            # straight line through the two isochrone magnitudes, evaluated at the requested age
            age_line = np.poly1d(np.squeeze(np.polyfit([age_lo,age_hi],[mag_lo,mag_hi],1)))
            mags[i] = age_line(age)

        return mags
    
    def teff_2_absmags(self,age,ak,teff,verbose=False):
        '''
        Calculate the intrinsic absolute magnitudes (BP, RP, J, H, K) of a star. No distance
        modulus and no extinction are applied.
        
        Inputs:
        ------
            age: age of a star in Gyr
            ak: not used (the result is extinction free); kept so (age, ak, teff) calls still work
            teff: temperature of a star
            verbose: print the bracketing isochrone ages when interpolating in age
        
        Output:
        ------
            calc_mags: intrinsic absolute magnitudes, or all 999999.0 if teff is outside the
                       isochrone Teff range
        '''
        mags = self._intrinsic_mags(age,np.log10(teff),verbose)
        if mags is None:
            return 999999.0*np.ones(5)
        return mags
        
    def teff_2_appmags(self,teff,age,verbose=False):
        '''
        Calculate the expected apparent magnitudes (BP, RP, J, H, K) of a star
        
        Inputs:
        ------
            teff: Teff of star 
            age: age of star in Gyr
            verbose: print the bracketing isochrone ages when interpolating in age

        The K band extinction self.ak (set by extinction()) is scaled to each band with the CCM law.
        
        Output:
        ------
            calc_mags: isochrone magnitudes + distance modulus + extinction, or all 999999.0 if teff
                       is outside the isochrone Teff range
        '''
        extincts = self._ak_ratios()*self.ak
        mags = self._intrinsic_mags(age,np.log10(teff),verbose)
        if mags is None:
            return 999999.0*np.ones(5)
        return mags+self.distmod+extincts
        
    def get_age(self,guess_ages=np.linspace(0.,17.)[::3],verbose=False):
        '''
        Find the best fitting age of a star. The observed photometry is fit with teff_2_appmags
        (at the extinction self.ak) starting from each initial age guess. The fit with the smallest
        chi-squared is kept.
        
        Inputs:
        ------
            guess_ages: [array] initial guesses for ages in Gyr
            verbose: print the best fit age
            
        Output:
        ------
            age: best fit age in Gyr (also stored as self.age), or 999999.0 if every fit failed
        '''
        
        curve_ages = []
        curve_chi = []

        for guess in guess_ages: 
            try:
                # best fit age for this starting guess
                popt,_ = curve_fit(self.teff_2_appmags,self.teff,self.phot,p0=guess,
                                   bounds=(0.,17.),method='trf',sigma=self.phot_err,
                                   absolute_sigma=True,maxfev=5000)
                fit_age = popt[0]
                fit_mags = np.asarray(self.teff_2_appmags(self.teff,fit_age))
                fit_chi = np.sum((fit_mags-self.phot)**2/self.phot_err**2)
            except Exception:
                # failed fit: flag it so it is never the minimum chi-squared
                fit_age,fit_chi = 999999.0,999999.0

            # append both together so the two lists stay aligned
            curve_ages.append(fit_age)
            curve_chi.append(fit_chi)
        
        # smallest chi-squared and its corresponding age
        idx = np.asarray(curve_chi).argmin()
        age = np.asarray(curve_ages)[idx]
        self.age = age

        if verbose:
            print('Best Fit Age [Gyr]:',np.round(age,3))

        return age
    
    def get_mass(self):
        '''
        Calculate the mass of a star from the already calculated age (self.age). A 4th order
        polynomial in logAge is fit to the masses of the isochrone points within 100 K of the
        star's Teff, then evaluated at self.age.

        Output:
        ------
        mass: mass of star (also stored as self.mass), or 999999.0 if self.age is 999999.0 or fewer
              than two isochrone points lie near the star's Teff
        '''

        if self.age == 999999.0:
            return 999999.0
        
        teffcut = np.where((self.iso['logTe']<np.log10(self.teff+100.))&
                           (self.iso['logTe']>np.log10(self.teff-100.)))
        iso_ = self.iso[teffcut]
        
        if np.size(iso_) < 2:
            return 999999.0
        
        ### calculate the mass using interpolation
        coeffs = np.polyfit(iso_['logAge'],iso_['Mass'],4)
        interpol = np.poly1d(coeffs)
        mass = interpol(np.log10(self.age*10**9))
        self.mass = mass
        
        return mass
        
    def mass_2_age(self,mass):
        '''
        Calculate the age of a star of the given mass at the star's Salaris-corrected [Fe/H].

        For each of the two isochrone masses bracketing `mass`, the age is interpolated linearly
        in [M/H] at self.salfeh. The two ages are then interpolated linearly in mass.

        NOTE(review): self.iso holds a single [M/H] value, so each [M/H] interpolation appears to
        have only one point and interp1d raises, giving 999999.0. Confirm whether the full
        multi-metallicity isochrone table was intended here.

        Input:
        -----
            mass: mass of star

        Output:
        ------
            age: age of star in Gyr, or 999999.0 if the calculation fails
        '''

        try:
            mass_lo,mass_hi = self.neighbors(self.iso['Mass'],mass)

            iso_lo = self.iso[np.where(self.iso['Mass']==mass_lo)]
            iso_hi = self.iso[np.where(self.iso['Mass']==mass_hi)]

            # lower mass
            sidx_lo = np.argsort(iso_lo['MH'])
            sMH_lo = iso_lo['MH'][sidx_lo]
            _, uidx_lo = np.unique(sMH_lo,return_index=True)
            sMH_lo = sMH_lo[uidx_lo]

            spl_lo = interp1d(sMH_lo,10**iso_lo['logAge'][sidx_lo][uidx_lo]/10**9)

            # higher mass
            sidx_hi = np.argsort(iso_hi['MH'])
            sMH_hi = iso_hi['MH'][sidx_hi]
            _, uidx_hi = np.unique(sMH_hi,return_index=True)
            sMH_hi = sMH_hi[uidx_hi]

            spl_hi = interp1d(sMH_hi,10**iso_hi['logAge'][sidx_hi][uidx_hi]/10**9)

            # interpolate the two ages in mass
            final_spl = interp1d([mass_lo,mass_hi],[spl_lo(self.salfeh),spl_hi(self.salfeh)])
            age = final_spl(mass)
            return age
        except Exception:
            return 999999.0
        
        
    ######################
    ### MCMC Functions ###
    ######################
    
    def lnL(self,theta,x,y):
        '''
        The ln(likelihood) of the distance-corrected photometry self.absphot given
        theta = (ak, age in Gyr).

        The model is the intrinsic isochrone magnitudes plus the CCM-scaled extinction. The
        likelihood is a multivariate Gaussian whose covariance is self.phot_cov plus the squared
        photometric errors on the diagonal.

        x, y: not used; kept for the existing call signature

        Returns -inf when the model is unavailable (Teff off the isochrones) or the covariance is not
        positive definite.
        '''
        ak, age = theta

        absmags = self.teff_2_absmags(age,ak,self.teff)
        if np.any(absmags==999999.0):
            return -np.inf

        modl = absmags + ak*self._ak_ratios()
        resid = self.absphot - modl
        sig2 = self.phot_cov + np.diag(np.square(self.phot_err))

        sign, logdet = np.linalg.slogdet(sig2)
        if sign <= 0:
            return -np.inf

        # r^T Sigma^-1 r without forming the inverse explicitly
        chisq = np.dot(resid,np.linalg.solve(sig2,resid))
        return -0.5*(chisq + logdet + len(resid)*np.log(2*np.pi))
    
    def lnPrior(self,theta):
        '''
        The ln(Prior) for the Ak and age for a star: flat for 0 < ak < 1 and 0 < age < 17 Gyr
        
        Inputs:
        ------
            theta: the parameters of interest, (ak, age)
        '''
        ak, age = theta
        if 0.0 < ak < 1.0 and 0.0 < age < 17.0:
            return 0.0
        return -np.inf
    
    def lnProb(self,theta,x,y):
        '''
        Calculate the ln(probability) = ln(prior) + ln(likelihood)
        '''
        lnP = self.lnPrior(theta)
        if not np.isfinite(lnP):
            return -np.inf
        return lnP + self.lnL(theta,x,y)
    
#     def MLE_age_ext(self,):
#         return
        
#     def mass_2_age_PARSEC(self,mass,parsec):
#         '''
#         Input:
#         -----
#             mass: mass of star

#         Output:
#         ------
#             age: age of star
#         '''

#         try:
#             mass_lo,mass_hi = self.neighbors(self.iso['Mass'],mass)

#             iso_lo = self.iso[np.where(self.iso['Mass']==mass_lo)]
#             iso_hi = self.iso[np.where(self.iso['Mass']==mass_hi)]

#             # younger
#             sidx_lo = np.argsort(iso_lo['MH'])
#             sMH_lo = iso_lo['MH'][sidx_lo]
#             _, uidx_lo = np.unique(sMH_lo,return_index=True)
#             sMH_lo = sMH_lo[uidx_lo]

#             spl_lo = interp1d(sMH_lo,10**iso_lo['logAge'][sidx_lo][uidx_lo]/10**9)

#             # older
#             sidx_hi = np.argsort(iso_hi['MH'])
#             sMH_hi = iso_hi['MH'][sidx_hi]
#             _, uidx_hi = np.unique(sMH_hi,return_index=True)
#             sMH_hi = sMH_hi[uidx_hi]

#             spl_hi = interp1d(sMH_hi,10**iso_hi['logAge'][sidx_hi][uidx_hi]/10**9)

#             # final spline
#             final_spl = interp1d([mass_lo,mass_hi],[spl_lo(self.salfeh),spl_hi(self.salfeh)])
#             age = final_spl(mass)
#             return age
#         except:
#             return 999999.0
        
#     def mass_2_age_MESA(self,star_mass,mesa):
#         '''

#         Input:
#         -----
#             star_mass: star_mass of star

#         Output:
#         ------
#             age: age of star
#         '''

#         try:
#             star_mass_lo,star_mass_hi = self.neighbors(mesa['star_mass'],star_mass)

#             iso_lo = mesa[np.where(mesa['star_mass']==star_mass_lo)]
#             iso_hi = mesa[np.where(mesa['star_mass']==star_mass_hi)]

#             # younger
#             sidx_lo = np.argsort(iso_lo['MH'])
#             sMH_lo = iso_lo['MH'][sidx_lo]
#             _, uidx_lo = np.unique(sMH_lo,return_index=True)
#             sMH_lo = sMH_lo[uidx_lo]

#             spl_lo = interp1d(sMH_lo,10**iso_lo['log10_isochrone_age_yr'][sidx_lo][uidx_lo]/10**9)

#             # older
#             sidx_hi = np.argsort(iso_hi['MH'])
#             sMH_hi = iso_hi['MH'][sidx_hi]
#             _, uidx_hi = np.unique(sMH_hi,return_index=True)
#             sMH_hi = sMH_hi[uidx_hi]

#             spl_hi = interp1d(sMH_hi,10**iso_hi['log10_isochrone_age_yr'][sidx_hi][uidx_hi]/10**9)

#             # final spline
#             final_spl = interp1d([star_mass_lo,star_mass_hi],[spl_lo(self.salfeh),spl_hi(self.salfeh)])
#             age = final_spl(star_mass)
#             return age
#         except:
#             return 999999.0

# PARSEC

In [5]:
# Load the PARSEC grid and keep only the columns the Aetas class uses.
PARSEC_FILE = '/Users/joshuapovick/Desktop/Research/parsec/parsec_massive_lite.fits.gz'
massive = fits.getdata(PARSEC_FILE)
# NOTE(review): label 3 is presumably the RGB stage in PARSEC's evolutionary labels -- confirm
massive = Table(massive[np.where(massive['label']==3.0)])
# remember the original row order before regrouping the rows by age
massive['index'] = np.arange(len(massive))
# stable sort: rows that share an age keep their original relative order
massive = massive[np.argsort(massive['logAge'],kind='stable')]
massive = massive['index','MH','Mass','logAge','logTe','logg','Gmag','G_BPmag','G_RPmag','Jmag','Hmag','Ksmag']

In [8]:
# Covariance of the BP, RP, J, H, Ks isochrone magnitudes (same construction as Aetas.phot_cov)
mcov = np.cov(np.array([massive[band] for band in ('G_BPmag', 'G_RPmag', 'Jmag', 'Hmag', 'Ksmag')]))

In [23]:
import scipy.linalg

# mcov is a covariance matrix (real symmetric), so eigvalsh returns real eigenvalues directly;
# its output is ascending, so reverse it to list the largest first
scipy.linalg.eigvalsh(mcov)[::-1]

array([1.79404246e+01, 7.89318142e-01, 2.96791812e-03, 9.40699355e-04,
       1.29786846e-04])

In [28]:
np.array([0, 3, 4]).shape

(3,)

In [29]:
np.diag(np.arange(1, 4))

array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]])

In [31]:
a = np.array([[102, 340, 567],
              [1, 2, 3],
              [25, 49, 81]])
b = np.arange(1, 4)

In [33]:
a @ (a @ b)

array([465548,   3609,  92407])

x*(x+1)*(x+2)
<class 'str'>
<class 'int'>
60
<class 'int'>


In [36]:
# compile a constant arithmetic string into a code object, then evaluate it
x = compile(source='55/88', filename='test', mode='eval')
eval(x)


0.625

In [66]:
import sympy as sp
from sympy import symbols

# symbolic variables
x, y = symbols('x y')

# the expression to differentiate
expr = x**3 + 2*y
print('expr:',expr)

# second x-derivative, stringified, compiled, and evaluated with x bound to 1
f = compile(str(sp.diff(expr, x, 2)), 'test', 'eval')
print('expr at x = 1: ',eval(f,{'x':1}))

expr: x**3 + 2*y
expr at x = 1:  6
