In [None]:
#SUPERcomputer Searching PeriodOgrams for Transits To EmeRge (SUPER-SPOTTER)

import numpy as np
import matplotlib.pyplot as plt
import batman
import scipy.signal as signal
import time
from tqdm import tqdm
from astropy import units as u
from astropy.io import ascii
#import transitleastsquares
from transitleastsquares import transitleastsquares

In [None]:
def binner(bins, time, flux, eflux):
    """Bin a (typically phase-folded) light curve into inverse-variance weighted means.

    Parameters
    ----------
    bins : sequence of float
        Increasing bin edges. Bins are half-open ``[lo, hi)`` except the last
        one, which also includes its right edge (the ``np.histogram`` convention).
    time : array_like
        Time (or phase) of every measurement. Note: inside this function the
        name shadows the ``time`` module.
    flux, eflux : array_like
        Measured flux and its 1-sigma uncertainty, same length as ``time``.

    Returns
    -------
    means : list
        Weighted mean flux per bin (weights ``1/eflux**2``). Bins with fewer
        than 5 points fall back to a flux of 1.
    errors : list
        Standard error of the mean per bin (``std/sqrt(N)``); ``nan`` for empty bins.
    midpoints : list
        Centre of every bin.
    """
    bins = np.asarray(bins, dtype=float)
    time = np.asarray(time)
    flux = np.asarray(flux)
    eflux = np.asarray(eflux)

    means, errors = [], []
    midpoints = [(bins[i] + bins[i+1])/2 for i in range(len(bins)-1)]
    last = len(bins) - 2
    for i in range(len(bins)-1):
        lo, hi = bins[i], bins[i+1]
        # Half-open bins so a point exactly on an interior edge is counted once
        # (the old strict `>`/`<` test dropped it); the last bin is closed.
        upper = (time <= hi) if i == last else (time < hi)
        in_bin = (time >= lo) & upper
        n_in_bin = np.count_nonzero(in_bin)
        if n_in_bin < 5:
            mean = 1  # too few points for a meaningful average
        else:
            mean = np.average(flux[in_bin], weights=1/eflux[in_bin]**2)
        means.append(mean)
        if n_in_bin == 0:
            errors.append(np.nan)  # avoid 0/0 and the empty-slice RuntimeWarning
        else:
            errors.append(np.std(flux[in_bin])/np.sqrt(n_in_bin))
    return means, errors, midpoints

In [None]:
data_file_name = 'Final_KARPA_Data_BRO2.csv'  # Input light curve: col1 = time, col2 = flux, col3 = flux uncertainty

# Injection grid. Periods are in days and radii in Jupiter radii (see transit_generator).
input_periods = list(np.logspace(np.log10(3), np.log10(40), 100))[0:10]  # first 10 of 100 log-spaced periods in [3, 40] d
input_radii = list(np.linspace(0.5, 2, 11))                               # 11 radii in [0.5, 2]
n_checks = 10                                                             # random-phase injections per (period, radius) cell

detections_file_name = 'detections_ownd_KARPA2_set_0.csv'  # CSV receiving the summed detection counts (radii x periods)

# Seed for NumPy's global RNG (used for the random injection phases).
# None keeps the previous behaviour: different phases on every run.
# Set an int to make a run reproducible.
random_seed = None
if random_seed is not None:
    np.random.seed(random_seed)

#Define constants (SI units)
r_earth = u.astrophys.earthRad.to(u.m)
r_star = 0.96*u.astrophys.R_sun.to(u.m)
m_star = 0.95*u.astrophys.M_sun.to(u.kg)
day_in_sec = u.day.to(u.second)
G = 6.6743*10**-11
r_jupiter = u.astrophys.jupiterRad.to(u.m)
resonances = [1./5.,1./4.,1./3., 1./2., 1., 2., 3., 4., 5.]  # not referenced by the functions in this notebook

#Reading in the data
data = ascii.read(data_file_name)

mask = (data['col3'] < 0.05)  # drop measurements with a flux uncertainty >= 0.05
data = data[mask]
assert len(data) > 0, "No measurements left after the col3 < 0.05 uncertainty cut"

# Times over which fake transits are produced, the per-point flux errors and the
# default transit parameters, in transit_generator's positional order:
# [r_planet (R_Jup), t0 (d), period (d), m_star (kg), r_star (m), inc (deg), ecc, w (deg)]
times = data['col1']
transit_params = [2, 5, 10, m_star, r_star, 90, 0, 90]
e_flux = data['col3']

In [None]:
def transit_time(period):
    """Approximate transit duration, in the same units as ``period`` (days).

    The semi-major axis comes from Kepler's third law, using the module-level
    ``G``, ``m_star`` (kg), ``r_star`` (m) and ``day_in_sec``. It is expressed
    in stellar radii, and the duration is ``P / (pi * a/R_star)``. No
    impact-parameter or planet-radius term is included.
    """
    period_s = period*day_in_sec
    a_over_rstar = float(((period_s**2)*G*m_star/(4*np.pi**2))**(1./3.))/r_star
    return period/(np.pi*a_over_rstar)

#Generating noise and fake data
def noise_generator(times, std):
    """Return white Gaussian noise with standard deviation ``std``, one draw per element of ``times``.

    Draws from NumPy's global random state, so the values depend on any earlier
    ``np.random.seed`` call.
    """
    n_samples = times.size
    unit_normal = np.random.randn(n_samples)
    return unit_normal*std
       
def fake_data_generator(times, noise):
    """Build a synthetic normalised light curve: a flat baseline of 1 plus ``noise``.

    Parameters
    ----------
    times : array_like
        Sample times; only their shape is used.
    noise : array_like or float
        Additive noise broadcastable to ``times`` (e.g. from ``noise_generator``).

    Returns
    -------
    numpy.ndarray
        Floating-point array with the shape of ``times``.
    """
    # Force a float baseline: np.ones_like on integer times would give an int
    # array, and adding float noise to it in place raises a casting error.
    flux = np.ones_like(times, dtype=float)
    flux += noise
    return flux

#Generating a fake transit
def transit_generator(times, r_planet, t0, per, m_star, r_star, inc, ecc, w):
    """Model a transit light curve with ``batman`` evaluated at ``times``.

    Parameters
    ----------
    times : array_like
        Evaluation times, in the same units as ``t0`` and ``per`` (days).
    r_planet : float
        Planet radius in Jupiter radii (scaled by the module-level ``r_jupiter``).
    t0 : float
        Time of inferior conjunction.
    per : float
        Orbital period in days (converted with ``day_in_sec`` for Kepler's law).
    m_star, r_star : float
        Stellar mass and radius in SI units (kg, m), consistent with ``G``.
    inc, w : float
        Orbital inclination and longitude of periastron, in degrees.
    ecc : float
        Orbital eccentricity.

    Returns
    -------
    The output of ``batman.TransitModel.light_curve`` for these parameters.
    """
    # Semi-major axis from Kepler's third law, expressed in stellar radii.
    per_seconds = per*day_in_sec
    a_in_rstar = float(((per_seconds**2)*G*m_star/(4*np.pi**2))**(1./3.))/r_star

    params = batman.TransitParams()
    params.t0 = float(t0)
    params.per = float(per)
    params.rp = float(r_planet*r_jupiter)/r_star  # planet radius in stellar radii
    params.a = a_in_rstar
    params.inc = float(inc)
    params.ecc = float(ecc)
    params.w = float(w)
    # Four-coefficient "nonlinear" limb darkening. The original notes label
    # these coefficients "T4500 log4.0" (presumably Teff = 4500 K, log g = 4.0).
    params.limb_dark = "nonlinear"
    params.u = [1.2, -0.47, -0.22, 0.24]

    model = batman.TransitModel(params, times)
    return model.light_curve(params)

def _top_distinct_peaks(periods, power, n_peaks=3, min_rel_sep=0.05):
    """Pick the ``n_peaks`` highest-power entries whose periods are mutually separated.

    Candidates are visited in descending power; ties keep their original order.
    A candidate is rejected if ``abs(p - p_sel) / p_sel < min_rel_sep`` for any
    already-selected period ``p_sel``.
    Returns ``(selected_periods, selected_powers)`` as lists.
    Raises ValueError if fewer than ``n_peaks`` such peaks exist.
    """
    periods = np.asarray(periods)
    power = np.asarray(power)
    chosen = []
    # Work with indices, not values: mapping a power value back to a period via
    # list.index() always returned the first of any tied entries.
    for idx in np.argsort(-power, kind='stable'):
        p = periods[idx]
        if not any(abs(p - periods[j]) / periods[j] < min_rel_sep for j in chosen):
            chosen.append(idx)
            if len(chosen) == n_peaks:
                break
    if len(chosen) < n_peaks:
        raise ValueError(
            f"Only {len(chosen)} peak(s) separated by >= {min_rel_sep:.0%} found; need {n_peaks}."
        )
    return [periods[j] for j in chosen], [power[j] for j in chosen]


def tls(times, transit_flux, yerr, min_p, max_p, plot, stats): #tls
    """Run Transit Least Squares and return the three strongest, mutually distinct peaks.

    Parameters
    ----------
    times, transit_flux, yerr : array_like
        Light curve passed straight to ``transitleastsquares``.
    min_p, max_p : float
        Period search range (``period_min`` / ``period_max``).
    plot : bool
        If true, draw the SDE periodogram. Previously the periodogram was drawn
        on every call regardless of this flag.
    stats : bool
        Currently unused; kept for interface compatibility.

    Returns
    -------
    top_3_periods : list
        Periods of the three highest-SDE peaks. Each is at least 5 % (relative
        to the higher-ranked peak's period) away from every higher-ranked selection.
    top_3_SDE : list
        The corresponding SDE values, in descending order.
    results
        The full TLS results object.

    Raises
    ------
    ValueError
        If the periodogram does not contain three sufficiently separated peaks.
    """
    model = transitleastsquares(times, transit_flux, yerr)
    #Stellar parameters from Mentel 2018
    results = model.power(transit_depth_min = 0.001, use_threads = 96, R_star = 0.96, R_star_min = 0.81, R_star_max = 1.11, M_star = 0.95, M_star_min = 0.85, M_star_max = 1.05, period_min = min_p, period_max = max_p, oversampling_factor = 3, show_progress_bar = False)

    if plot:
        _, ax = plt.subplots(figsize=(9, 6))
        ax.plot(results.periods, results.power, c='black')
        ax.set(xlabel='Period (d)', ylabel='SDE', title='TLS periodogram')
        plt.show()

    top_3_periods, top_3_SDE = _top_distinct_peaks(results.periods, results.power, n_peaks=3, min_rel_sep=0.05)
    return top_3_periods, top_3_SDE, results

def retrieved_condition(retrieved_p, inserted_p, sde, sde_threshold=6, rel_tol=0.05): #A function that determines whether a fake transit was retrieved
    """Decide whether an injected period was recovered by the period search.

    A candidate counts as a recovery when its SDE is strictly above
    ``sde_threshold`` and its period lies within ``rel_tol`` of ``inserted_p``
    (relative to ``inserted_p``). Candidates are checked in order and the
    first match wins.

    Parameters
    ----------
    retrieved_p : sequence of float
        Candidate periods (e.g. the three returned by ``tls``).
    inserted_p : float
        Period of the injected transit.
    sde : sequence of float
        SDE of each candidate, aligned with ``retrieved_p``.
    sde_threshold, rel_tol : float
        Detection thresholds; the defaults reproduce the original hard-coded 6 and 5 %.

    Returns
    -------
    (bool, int)
        ``(True, index)`` of the first matching candidate, otherwise ``(False, -1)``.
    """
    print("Retrieved periods: ", retrieved_p)
    print("SDE: ", sde)
    # zip instead of range(3): works for any number of candidates and cannot
    # raise IndexError when fewer than three are supplied.
    for i, (period, power) in enumerate(zip(retrieved_p, sde)):
        if power > sde_threshold and abs(period - inserted_p)/inserted_p < rel_tol:
            return True, i
    return False, -1

def iterator(periods, radii, n_checks, stats): #Function iterating over different periods and radii
    """Injection-recovery grid: add fake transits to the observed flux and test whether TLS recovers them.

    For each (period, radius) pair, ``n_checks`` transits are injected, each
    with a random ``t0`` drawn uniformly from ``[0, period)``. Every injection
    goes into a fresh copy of the observed flux (``data['col2']``) and is
    searched with ``main`` over periods 3 to ``2.1 * period``. After each
    radius, the running detection counts (summed over checks, shape
    ``(len(radii), len(periods))``) are written to ``detections_file_name``
    as a checkpoint.

    Parameters
    ----------
    periods, radii : sequence of float
        Injection grid (periods in days, radii in Jupiter radii).
    n_checks : int
        Number of random-phase injections per grid cell.
    stats : bool
        Forwarded to ``main``.

    Returns
    -------
    sdes, detections : numpy.ndarray
        Both of shape ``(n_checks, len(radii), len(periods))``. ``detections``
        is 1 where the injected period was recovered. ``sdes`` holds the SDE of
        the matching peak, or 0 when the period was not recovered.
    """
    shape = (n_checks, len(radii), len(periods))
    sdes, detections = np.zeros(shape), np.zeros(shape)
    # Local copy so the module-level transit_params is not mutated as a side effect.
    params = list(transit_params)
    for p_idx, period in enumerate(tqdm(periods, desc='periods')):
        for r_idx, radius in enumerate(tqdm(radii, desc='radii', leave=False)):
            params[0], params[2] = radius, period
            for k in tqdm(range(n_checks), desc='checks', leave=False):
                params[1] = np.random.uniform(0, period)  # random t0 within one period
                flux = np.copy(data['col2'])  # fresh copy of the observed flux for this injection
                results = main(times, flux, e_flux, 3, period*2.1, False, stats, True, params)
                detected, number = retrieved_condition(results[0], period, results[1])
                if detected:
                    sdes[k, r_idx, p_idx] = results[1][number]
                    detections[k, r_idx, p_idx] = 1
            # Checkpoint partial results so a long run can be inspected or resumed.
            np.savetxt(detections_file_name, np.sum(detections, axis=0), delimiter=',')
    return sdes, detections

def start(periods, radii, n_checks, stats=False):
    """Entry point for the injection-recovery run; returns ``iterator``'s ``(sdes, detections)``."""
    sdes_and_detections = iterator(periods, radii, n_checks, stats)
    return sdes_and_detections

def main(times, transit_flux, yerr, min_p, max_p, plot, stats, add_transit, transit_params):
    """Optionally inject a synthetic transit into a light curve, then search it with ``tls``.

    Parameters
    ----------
    times, transit_flux, yerr : array_like
        Light curve. ``transit_flux`` is not modified; the injected light curve
        is a new array.
    min_p, max_p : float
        Period search range forwarded to ``tls``.
    plot, stats : bool
        Forwarded to ``tls``.
    add_transit : bool
        If true, add ``transit_generator(times, *transit_params) - 1`` to the flux.
    transit_params : sequence
        Positional arguments for ``transit_generator`` after ``times``:
        (r_planet, t0, per, m_star, r_star, inc, ecc, w).

    Returns
    -------
    Whatever ``tls`` returns: ``(top_3_periods, top_3_SDE, results)``.
    """
    if add_transit:
        # Subtracting 1 from the model (presumably normalised to 1 out of
        # transit) turns it into an additive dip. Build a new array rather than
        # using `+=`, so the caller's flux array is left untouched.
        transit_flux = transit_flux + (transit_generator(times, *transit_params) - 1)
    return tls(times, transit_flux, yerr, min_p, max_p, plot, stats)

In [None]:
# Run the full injection-recovery grid; heat_map_data is the (sdes, detections) pair.
heat_map_data = start(periods=input_periods, radii=input_radii, n_checks=n_checks)
print('Finished')