In [87]:
### Allen B Davis
### Yale University
###
### First year project
### Spring 2015

###Created after v4 of the full program
###v3 Now with a callable simultation class.
###v4 Now with subplots
###v6 Softer colors and TESS region

###Full_v6 now with FAP calculations and Doppler survey searches. Removing support for SNR and K plotting. Use RV_Obs_Simple_v6 for that

### As of 5/18/15, hosted on GitHub. The _v# has been removed from the name, since version control is provided.

In [88]:
import numpy as np
import math
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.ticker as plticker
from matplotlib.colors import LinearSegmentedColormap
import scipy as sp
import scipy.signal
import pickle


plt.close()

In [89]:
### Constants, all in SI
mEarth = 5.97219E24
mJup = 1.898E27
mSun = 1.9891E30
G = 6.67384E-11 
au = 149597870700.
day = 86400.
pi = math.pi

In [90]:
### Define the Planet class, which contains the physical parameters of one planetary system

class Planet:
    
    def __init__(self,mass,massStar,period,semimajoraxis):
        
        self.mass = mass #in mEarth
        self.massStar = massStar #in mSun
        self.period = period #in years
        self.ecc = np.random.random()*0.1 #random ecc b/w 0 and 0.1
        self.incl = 90. #in degrees
        self.sini = np.sin(self.incl*pi/180.)
        self.sma = semimajoraxis #in AU
        self.w = np.random.random()*2*pi #random phase
        self.t0 = 2457129.53734 #JD of start
        self.tp = self.t0+(np.random.random()*self.period) #random tp        
        
        #crash if period and sma were both unspecified
        if period == -1 and semimajoraxis == -1:
            raise ValueError('Period or semimajor axis was not defined for planet')
    
        #get SMA if not specified
        if semimajoraxis == -1:
            totmass = (self.mass*mEarth)+(massStar*mSun)
            self.sma = (1/au) * ( ( (self.period*day*365.242)**2 * G*totmass ) / ( 4.*pi*pi))**(1./3.)
            
        #get period if not specified
        #WARNING NOT CHANGED FOR RANDOMS
        if period == -1:
            totmass = (self.mass*mEarth)+(massStar*mSun)
            self.period = (1/(day*365.242))* ( ((semimajoraxis*au)**3 * 4.*pi*pi) / (G*totmass) )**0.5
    
        self.rvamp = 29.8*((mEarth/mJup)*self.mass*self.sini)/((self.massStar*self.sma)**(0.5)) #in m/s
        
    
    # Creates RV observations
    def makeObs(self,ndays,offset_threshold,std_True):
        
        # Make all the random offsets first; we will draw from this list for each obs.
        roffsets = [np.random.random()*(2*offset_threshold) - (offset_threshold) for i in range(ndays)] 
        rnext = 0 #tracks the next element of the random list to pull from
        
        # Create the list of observation times according to the blind Doppler survey strategy
        obsList = np.arange(0,3) + roffsets[rnext:rnext+3]
        rnext = obsList.size
        
        if self.scatter(obsList+self.t0,True) > std_True: # Keep a separation of 1 day
            newtimes = np.arange(obsList.size,ndays)
            obsList = np.concatenate((obsList,newtimes+roffsets[rnext:ndays]))
#             print zip(np.arange(obsList.size),obsList),'\n'
        else: # Ramp it up to 5 day separation
            sep = 5 #separation in days b/w next obs
            n = 3 #how many more obs this round
            newtimes = np.arange(obsList.size-1+sep,obsList.size+n*sep,sep)
            obsList = np.concatenate((obsList,newtimes+roffsets[rnext:rnext+n]))
            rnext = obsList.size
#             print zip(np.arange(obsList.size),obsList),'\n'

            if self.scatter(obsList+self.t0 > std_True,True): # Keep a 5 day separation
                n = ndays - obsList.size + 1
                newtimes = np.arange(obsList[obsList.size-1]+sep,obsList[obsList.size-1]+n*sep,sep)
                obsList = np.concatenate((obsList,newtimes+roffsets[rnext:ndays]))
#                 print zip(np.arange(obsList.size),obsList),'\n'
            else: # Ramp it up to 10 day separation and keep it there
                sep = 10
                n = ndays-obsList.size + 1
                newtimes = np.arange(obsList[obsList.size-1]+sep,obsList[obsList.size-1]+n*sep,sep)        
                obsList = np.concatenate((obsList,newtimes+roffsets[rnext:ndays]))
#                 print zip(np.arange(obsList.size),obsList),'\n'

        # Set the JD
        self.obsList = obsList + self.t0

        #record true RV curve
        f = self.calcTrueAnomaly(self.period,self.tp,self.ecc,self.obsList)
        self.RV_True = self.rvamp*(np.cos(self.w+f) + self.ecc*np.cos(self.w))
        
        #add jitter scaled to a normal distribution
        errList = np.random.normal(0, std_True, ndays)
 
        #record observed RV curve
        self.RV_Obs = self.RV_True + errList
        
        #record observated scatter in RV curve
        self.std_Obs = np.std(self.RV_Obs)

        
    def scatter(self,obsList,lastThree):
        f = self.calcTrueAnomaly(self.period,self.tp,self.ecc,obsList)
        self.RV_True = self.rvamp*(np.cos(self.w+f) + self.ecc*np.cos(self.w))
        if(lastThree):
            size = np.size(self.RV_True)
            return np.std(self.RV_True[size-3:size])
        else:
            return np.std(self.RV_True)
    
        
    # Calculates the true anomaly from the time and orbital elements.
    def calcTrueAnomaly(self, P, tp, e, t):
        
        phase = (t-tp)/P #phase at each obsList time
        M = 2.*pi*(phase - np.floor(phase)) #Mean Anom array: at each obsList time
        E1 = self.calcKepler(M, np.array([e]))
        
        n1 = 1. + e
        n2 = 1. - e
        
        #True Anomaly:
        return 2.*np.arctan(np.sqrt(n1/n2)*np.tan(E1/2.))
    
    
    #returns Eccentric anomaly, given mean anomaly and eccentricity
    def calcKepler(self, Marr_in, eccarr_in):
        
        
        nm = np.size(Marr_in)
        nec = np.size(eccarr_in)
        
        if nec == 1 and nm > 1:
            eccarr = eccarr_in #[eccarr_in for x in range(nm)]
        else:
            eccarr = eccarr_in
        
        if nec > 1 and nm == 1:
            Marr = Marr_in #[Marr_in for x in range(nec)]
        else:
            Marr = Marr_in
    
        conv = 1.E-12 #threshold for convergence
        k = 0.85 #some parameter for guessing ecc
        ssm = np.sign(np.sin(Marr))
        Earr = Marr+(ssm*k*eccarr)  #first guess at E
        fiarr = (Earr-(eccarr*np.sin(Earr))-Marr)  #E - e*sin(E)-M    ; should go to 0 when converges
        convd = np.where(abs(fiarr) > conv) #which indices are unconverged

        count = 0
        while np.size(convd) > 0:
            count += 1

            M = np.copy(Marr[convd]) #we only run the unconverged elements
            ecc = eccarr #[convd] ??
            E = np.copy(Earr[convd])
            fi = np.copy(fiarr[convd])
            
            fip = 1.-ecc*np.cos(E) #;d/dE(fi) ;i.e.,  fi^(prime)
            fipp = ecc*np.sin(E)  #;d/dE(d/dE(fi)) ;i.e.,  fi^(\prime\prime)
            fippp = 1.-fip #;d/dE(d/dE(d/dE(fi))) ;i.e.,  fi^(\prime\prime\prime)

            d1 = -fi/fip                             #;first order correction to E
            d2 = -fi/(fip+(d1*fipp/2.))                #;second order correction to E
            d3 = -fi/(fip+(d2*fipp/2.)+(d2*d2*fippp/6.)) #;third order correction to E
            
            E += d3 #apply correction to E
            
#             print np.size(Earr),np.size(E)
            Earr[convd] = E #update values
            
            fiarr = (Earr-eccarr*np.sin(Earr)-Marr)     #;how well did we do?
            convd = np.where(abs(fiarr) > conv)   #;test for convergence; update indices

            if count > 100:
                print "WARNING!  Kepler's equation not solved!!!"
                break
        
        return Earr
        
        
    # Lomb Scargle periodogram
    def LSP(self,freqs):
        self.freqs = freqs[:]
        self.P_G = sp.signal.lombscargle(self.obsList, self.RV_Obs,self.freqs)
        self.P_G_max = np.max(self.P_G)
    
    
    # Determine FAP by shuffling velocities
    def calc_FAP(self,niter):
                
        max_powers = [None]*niter #contains max power of each iteration

        self.RV_Obs_scram = np.copy(self.RV_Obs) #initialize
        
        # iterate
        for n in range(0,niter):

            # Shuffle observations
            np.random.shuffle(self.RV_Obs_scram)
            
            # periodogram; max power
            P_G_new = sp.signal.lombscargle(self.obsList, self.RV_Obs_scram, self.freqs)
            max_powers[n] = max(P_G_new)


        self.max_powers_sort = np.sort(max_powers)
        
    
    # Returns the periodogram power corresponding to the given FAP (in percent)
    def calc_FAP_power(self,FAP):
        size = len(self.max_powers_sort)
        return self.max_powers_sort[(size-1) - (size*FAP/100.)]
    
    
    # Returns whether the planet met the FAP threshold or not
    def isDetected(self,threshold_FAP):
        self.threshold_power = self.calc_FAP_power(threshold_FAP)        
        return (self.P_G_max > self.threshold_power)
            

    

In [91]:
### Simulator Class

class Simulator:
    def __init__(self,nobs,sigma,jitter,massStar,massRange,logmasses,periodRange,logperiods,ntimes):
        self.nobs = nobs #days
        self.sigma = sigma #m/s
        self.jitter = jitter #m/s
        self.massStar = massStar #Msun
        self.logmasses = logmasses
        self.logperiods = logperiods
        
        self.minmass = massRange[0] #mEarths
        self.maxmass = massRange[1]
        self.massbins = massRange[2]
        
        self.minperiod = periodRange[0] #years
        self.maxperiod = periodRange[1]
        self.periodbins = periodRange[2]
        
        self.ntimes =  ntimes #how many planets averaged per point

        
        # Other astrophysical parameters
        self.semimajoraxis = -1 # sma of systems in AU; but "-1" iff we give period later.
        self.eccentricity = 0.0
        
        # Observational parameters
        self.offset_threshold = 2./24. #maximum allowed offset time (+/-) for observations in days. Float. 
        self.std_True = sigma #injected noise in m/s. Float.

        # Periodogram parameters
        pmin = self.minperiod #minimum period in days
        pmax = self.maxperiod #max period in days
        pres = 0.01 #period resolution

        # Detection parameter
        self.threshold_FAP = 1.0 # tolerated FAP percentage for a detection
        self.FAP_niter = 100 #number of FAP shuffle iterations. Integer
    
        # DO NOT ALTER: generated by choices above
        self.pers = np.linspace(pmin,pmax,num=pmax/pres)
        self.freqs = 2.* pi / self.pers
    
    def runSim(self):
        ### Set up the simulation

        #Select mass and period bins

        if self.logmasses:
            self.masses = np.logspace(np.log10(self.minmass),np.log10(self.maxmass+mbinsize),self.massbins+1)
            extram = mtemp[np.size(mtemp)]*(mtemp[np.size(mtemp)]/mtemp[np.size(mtemp)-1]) #this gives an extra bin on the end for injecting unc in last bin
            self.masses = np.concatenate((mtemp,extram))
        else:
            self.masses = np.linspace(self.minmass,self.maxmass,self.massbins+1)

        if self.logperiods:
            self.periods = np.logspace(np.log10(self.minperiod),np.log10(self.maxperiod),self.periodbins+1)
        else:
            self.periods = np.linspace(self.minperiod,self.maxperiod,self.periodbins+1)

       
        #Create the M x N grid of periods and masses.
        PMgrid = np.zeros((len(self.periods),len(self.masses)))
        
        print 'self.masses:',self.masses
        print 'self.periods:',self.periods
        ### Run the simulation: populate the grid

        def SNR(K,Nobs,jit,err):
            return K * np.sqrt(Nobs/(jit**2 + err**2))

        # Vary only for debugging. Else: start = 0, stop = len(masses) or len(periods)
        start_m = 0
        stop_m = len(self.masses)

        start_p = 0
        stop_p = len(self.periods)
        
        
        counter = 0
        tot = np.size(PMgrid)
        for m_ind in range(start_m,stop_m): #MIGHT NEED TO SUB 1 HERE AND BELOW FOR DM AND DP REASONS
            for p_ind in range(start_p,stop_p):
                detections = np.zeros(self.ntimes)
                
                for n in range(0,self.ntimes):
                    # Assign mass and period for the grid point
                    massPlanet = self.masses[m_ind]
                    period = self.periods[p_ind]

                    # Create a planet object with given astrophysical parameters
                    p = Planet(massPlanet,self.massStar,period,self.semimajoraxis)

                    # Create observations according to blind survey strategy
                    p.makeObs(self.nobs,self.offset_threshold,self.std_True)
                    
                    # Make a LS periodogram of the observations
                    p.LSP(self.freqs)
                    
                    # MC for FAP
                    p.calc_FAP(self.FAP_niter)
                    
                    # Store FAP
                    if p.isDetected(self.threshold_FAP):
                        detections[n] = 1
                        
                PMgrid[p_ind,m_ind] = np.mean(detections)
                print '%.1f %%' %(counter*100./tot)
                counter += 1

        # After PMgrid finished, save as instance variable
        self.PMgrid = PMgrid
        print PMgrid
        return self.PMgrid


In [92]:
def makeSubplot(sim,ax,returnIm,minval,maxval):
    
    ### SET THESE
    logVal = False
    logX = False
    logY = False
    colorscheme = 'green'
    tesslimit = True
    fsize = 11 #for title: 11 if subplot plot
    ### END SET
    
    
    # Colors
    def custom_div_cmap(numcolors=100, name='custom_div_cmap',
                    mincol='white', maxcol=colorscheme):
        cmap = LinearSegmentedColormap.from_list(name=name, 
                                             colors =[mincol, maxcol],
                                             N=numcolors)
        return cmap
    custom_map = custom_div_cmap(100, mincol='white' ,maxcol=colorscheme)
    
    X, Y = np.meshgrid(sim.periods, sim.masses)

    if logVal:
        PMgrid_plt = np.log10(sim.PMgrid)
    else:
        PMgrid_plt = np.copy(sim.PMgrid)
    
    print 'plotted grid:\n',np.transpose(PMgrid_plt)
    print 'plotted grid clipped:\n',np.transpose(PMgrid_plt)[0:X.size-1,0:Y.size-1]

    
    im = ax.pcolormesh(X,Y, np.transpose(PMgrid_plt),cmap=custom_map,vmin=minval,vmax=maxval)
#     im = ax.imshow(np.transpose(PMgrid_plt),cmap=custom_map,vmin=minval,vmax=maxval)

    startx, endx = np.min(X), np.max(X)
    starty, endy = 0, np.max(Y)
        
    ax.set_xlim(startx,endx)
    ax.set_ylim(starty,endy)
    
    locy = plticker.MultipleLocator(base=1.0) # this locator puts ticks at regular intervals
    ax.yaxis.set_major_locator(locy)

    
    if logX:
        ax.set_xscale('log')
        ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.1f'))
        ax.set_xticks([0.1,1.0])

    if logY:
        ax.set_yscale('log')   
        
    if tesslimit:
        tessmaxper = 20./365.242
        tessminmass = 1.3**3. #min radius is 1.3 R_e. Mass goes as the cube of this.
        cornersx = [startx,tessmaxper,tessmaxper,startx]
        cornersy = [tessminmass,tessminmass,endy,endy]
        ax.fill(cornersx,cornersy,color='blue',alpha=0.5)
        ax.annotate('TESS', xy=(startx+(.1*(tessmaxper-startx)),0.5*(endy+tessminmass)),
                    xytext=(startx+(.1*(tessmaxper-startx)),0.5*(endy+tessminmass)),color='white',fontsize=11,weight='bold')
    
    ax.set_title(r'%.1f $M_{\odot}$ Star' %sim.massStar,fontsize=fsize)
    
    if returnIm:
        
        #  Make a new PMgrid where we set the threshold according to minval and maxval
        PMgrid_temp1 = np.where(sim.PMgrid < minval,minval,sim.PMgrid)
        PMgrid_scaled = np.where(PMgrid_temp1 > maxval,maxval,PMgrid_temp1)
        
        return im


In [93]:
def makePlots():
    
    ### SET SIMULATOR PARAMS
    nobsVals= [200,500]
    sigmaVals = [0.2,0.5]
    massStar = 0.8 #Msun
    massRange = [0.1,5.,500] #min,max,bins  in mEarths
    logmasses = False
    periodRange = [0.02,2.,500] #min,max,bins  in years
    logperiods = True
    jitter = 0.0 #m/s
    plotvar = 'SNR' #options include 'SNR' and 'K'
    contvar = 'K'
    minval = 5.8
    maxval = 7.6
    cbarlabel = 'White-Noise Detection Threshold'
    ### END SET
    
    
    ### Run Simulations
    sim1 = Simulator(nobs=nobsVals[0],sigma=sigmaVals[1],jitter=jitter,massStar=massStar,massRange=massRange,
                     logmasses=logmasses,periodRange=periodRange,logperiods=logperiods,plotvar=plotvar)
    sim2 = Simulator(nobs=nobsVals[1],sigma=sigmaVals[1],jitter=jitter,massStar=massStar,massRange=massRange,
                     logmasses=logmasses,periodRange=periodRange,logperiods=logperiods,plotvar=plotvar)
    sim3 = Simulator(nobs=nobsVals[0],sigma=sigmaVals[0],jitter=jitter,massStar=massStar,massRange=massRange,
                     logmasses=logmasses,periodRange=periodRange,logperiods=logperiods,plotvar=plotvar)
    sim4 = Simulator(nobs=nobsVals[1],sigma=sigmaVals[0],jitter=jitter,massStar=massStar,massRange=massRange,
                     logmasses=logmasses,periodRange=periodRange,logperiods=logperiods,plotvar=plotvar)
    sim1.runSim()
    sim2.runSim()
    sim3.runSim()
    sim4.runSim()
    
    
    ### Plot
    figBig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex='col', sharey='row')

    figBig.subplots_adjust(right=0.8)
    
    im = makeSubplot(sim1,ax1,True,minval,maxval,plotvar,contvar) #arbitrarily grab the first cbar. They are all the same.
    makeSubplot(sim2,ax2,True,minval,maxval,plotvar,contvar)
    makeSubplot(sim3,ax3,True,minval,maxval,plotvar,contvar)
    makeSubplot(sim4,ax4,True,minval,maxval,plotvar,contvar)
    
    
    cax = figBig.add_axes([0.83,0.1,0.03,0.8])
    cbar = figBig.colorbar(im,ticks=[minval,maxval],cax=cax)
    cbarticklist = ['<%.1f' %minval , '>%.1f' %maxval]
    cbar.set_label(cbarlabel,rotation=270,labelpad=0)
    cbar.set_clim([minval, maxval])
    cbar.ax.set_yticklabels(cbarticklist)
    
    figBig.text(0.5, 0.04, 'Period (years)', ha='center',fontsize=12)
    figBig.text(0.06, 0.5, r'Mass ($M_{\oplus}$)', va='center', rotation='vertical',fontsize=12)
    
    plt.subplots_adjust(hspace = 0.15, wspace = 0.1)
    
    plt.savefig('apr15_highnobs.png', format='png', dpi=300)


In [94]:
def makePlot():
    ### SET SIMULATOR PARAMS
    nobsVal = 100
    sigmaVal = 0.5
    massStar = 0.8 #Msun
    massRange = [0.1,5.,3] #min,max,bins  in mEarths
    logmasses = False
    periodRange = [0.02,2.,3] #min,max,bins  in years
    logperiods = False
    jitter = 0.0 #m/s
    minval = 0.0
    maxval = 1.0
    ntimes = 5 #integer: number of planets averaged per grid point
    cbarlabel = 'Fraction Detected'
    onlyLimitTicks = False
    data_provided = True
    ### END SET
    
    
    fig,ax = plt.subplots(1,1)
    
    if not data_provided:
    
        ### Run Simulations
        sim = Simulator(nobs=nobsVal,sigma=sigmaVal,jitter=jitter,massStar=massStar,massRange=massRange,
                         logmasses=logmasses,periodRange=periodRange,logperiods=logperiods,ntimes=ntimes)
        sim.runSim()

        #Save sim,ax
        f = open('store.pckl', 'w')
        pickle.dump([sim,fig,ax], f)
        f.close()
    else:
        #Load im
        f = open('store.pckl')
        sim,fig,ax = pickle.load(f)
        f.close()
        
    im = makeSubplot(sim,ax,True,minval,maxval)
    
    if onlyLimitTicks:
        cbar = fig.colorbar(im)#,ticks=[minval,maxval])
        cbarticklist = ['%.0f' %minval , '%.0f' %maxval]
        cbar.ax.set_yticklabels(cbarticklist)
    else:
        cbar = fig.colorbar(im)
    cbar.set_label(cbarlabel,rotation=270,labelpad=20)
    cbar.set_clim([minval, maxval])
    
    
    ax.set_xlabel('Period (years)')
    ax.set_ylabel(r'Mass ($M_{\oplus}$)')
    
    print 'ntimes = ',ntimes
    print 'resolution =',massRange[2],'x',periodRange[2]
    
#     plt.savefig('apr13_tess.png', format='png', dpi=300)


In [95]:
# makePlots()
# plt.show()
# plt.close()

In [96]:
makePlot()
plt.show()
plt.close()

plotted grid:
[[ 0.   0.4  0.   0. ]
 [ 1.   0.   0.   0.2]
 [ 1.   0.6  0.6  0. ]
 [ 1.   1.   1.   0. ]]
plotted grid clipped:
[[ 0.   0.4  0.   0. ]
 [ 1.   0.   0.   0.2]
 [ 1.   0.6  0.6  0. ]
 [ 1.   1.   1.   0. ]]
ntimes =  5
resolution = 3 x 3
