In [1]:
from IPython.display import HTML

# Raw HTML + JavaScript: hides every `div.input` element once the page is ready
# and adds a button that flips input visibility on each click. Uses jQuery-style
# `$(...)` calls, presumably relying on the classic Notebook DOM — TODO confirm
# for JupyterLab.
toggle_code_markup = '''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Click here to toggle on/off the raw code."></form>'''

HTML(toggle_code_markup)

In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 14 11:26:33 2020

@author:
Maximilian N. Günther
MIT Kavli Institute for Astrophysics and Space Research, 
Massachusetts Institute of Technology,
77 Massachusetts Avenue,
Cambridge, MA 02109, 
USA
Email: maxgue@mit.edu
Web: www.mnguenther.com
"""

from __future__ import print_function, division, absolute_import
%matplotlib inline

#::: modules
import warnings
import matplotlib
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=matplotlib.cbook.deprecation.MatplotlibDeprecationWarning) 
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from glob import glob
from pprint import pprint
from astropy.stats import sigma_clip
from matplotlib.widgets import SpanSelector
import ipywidgets as widgets
import IPython
from IPython.display import display, HTML, Markdown, clear_output, Javascript
display(HTML("<style>.container { width:80% !important; }</style>"))
from tkinter import Tk, filedialog

#::: special modules
try:
    from wotan import flatten, t14
except ImportError:
    pass

#::: my modules
import allesfitter
from allesfitter import tessplot
from allesfitter.detection.periodicity import estimate_period
from allesfitter.detection.transit_search import get_tls_kwargs_by_tic, tls_search, 
from allesfitter.time_series import slide_clip
from allesfitter.io import read_csv
from exoworlds.tess import tessio

#::: plotting settings
import seaborn as sns
sns.set(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1.5, color_codes=True)
sns.set_style({"xtick.direction": "in","ytick.direction": "in"})
sns.set_context(rc={'lines.markeredgewidth': 1})

%load_ext autoreload
%autoreload 2

In [3]:
"""
Extension for disabling autoscrolling long output, which is super annoying sometimes
Usage:
    %load_ext disable_autoscroll
You can also put the js snippet below in profile_dir/static/js/custom.js
"""

from IPython.display import display, Javascript

# JavaScript override of the OutputArea hook that decides whether long output is
# put into a scroll box; returning false disables autoscrolling. Presumably only
# effective in the classic Notebook (global `IPython` JS object) — TODO confirm.
disable_js = """
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}
"""

def load_ipython_extension(ip):
    """Inject ``disable_js`` into the page and report that autoscrolling is off.

    Follows the IPython extension entry-point signature; ``ip`` (the shell
    instance) is not used.
    """
    display(Javascript(disable_js))
    print("autoscrolling long output is disabled")

# Defining an extension hook inside a notebook cell does not run it, so call it
# explicitly to actually apply the setting in this notebook.
load_ipython_extension(get_ipython())

![allesfitter](_static/_logos/logo_circ_tess_transit_search.png)

In [4]:
#::: globals
# Names assigned at cell (module) level are already globals; `global` statements
# here would be no-ops, so plain assignments are used.
INPUT = {}      # input widgets (and dialog results), keyed by name
VBOXES = {}     # not used in the visible cells
BUTTONS = {}    # not used in the visible cells
DROPDOWNS = {}  # not used in the visible cells
layout = {'width': '180px'}
layout_wide = {'width': '360px'}
layout_textbox = {'width': '120px'}
layout_checkbox = {}

## Load the data

In [5]:
#TODO:
#    run tls in here (spawn it in a terminal to run fast)

In [6]:
# Text boxes for the target's TIC ID and an optional local light-curve file;
# both are read by load() further down.
INPUT['tic_id'] = widgets.Text(description='TIC ID:', placeholder='123456789', value=None)
INPUT['datafile'] = widgets.Text(description='Input file:', placeholder='Leave empty to automatically download the light curve', value='')
display(INPUT['tic_id'], INPUT['datafile'])



def select_datafile():
    root = Tk()
    root.withdraw()
    root.call('wm', 'attributes', '.', '-topmost', True)
    INPUT['datafile_dialog'] = filedialog.askopenfilename()
    %gui tk
    if INPUT['datafile_dialog'] != '':
        INPUT['datafile'].value = INPUT['datafile_dialog']
im = widgets.interact_manual(select_datafile);
im.widget.children[0].description = 'Open file dialog' #https://stackoverflow.com/a/51361461/4718101



def _load_light_curve(tic_id, datafile):
    """Return (time, flux, flux_err) for one target, or (None, None, None) if nothing loads.

    Sources are tried in order: the user-supplied file (only if a path was
    given), a local lookup via tessio, and finally a download via lightkurve.
    """
    if datafile:
        try:
            time, flux, flux_err = read_csv(datafile)
            print(' > Input file succesfully loaded.')
            return time, flux, flux_err
        except Exception as e:
            print(' > Could not read the input file:', e)
    try:
        time, flux, flux_err = tessio.get(tic_id, unpack=True)
        print(' > Data succesfully loaded locally via tessio.')
        return time, flux, flux_err
    except Exception:
        pass
    try:
        print(' > No local file found, trying to download (might take a minute)...')
        time, flux, flux_err = tessio.get_via_lightkurve(tic_id)
        print(' > Data succesfully downloaded via lightkurve.')
        return time, flux, flux_err
    except Exception:
        print(' > Failed to load any data.')
        return None, None, None


def _plot_transit_durations(R_star, M_star):
    """Plot wotan's expected transit duration vs. orbital period for the given star.

    The twin y-axis is the left axis scaled by 3, i.e. the suggested minimum
    detrending window length (at least 3x the transit duration).
    """
    fig, ax = plt.subplots(figsize=(8,6))
    periods = np.linspace(0.5, 101, 100)
    for small_planet, label in [(True, 'small planet'), (False, 'large planet')]:
        tdur = [t14(R_s=R_star, M_s=M_star, P=p, small_planet=small_planet) for p in periods]
        ax.plot(periods, tdur, label=label)
    ax.set(xlabel='Planet orbital period (days)', ylabel='Transit duration (days)')
    ax.legend()
    ax2 = ax.twinx()
    ax2.set(ylim=[3*x for x in ax.get_ylim()], ylabel='Minimum detrending\nwindow length (days)')


def load():
    """Load stellar parameters and the light curve for the TIC ID entered above.

    Sets the globals tic_id, time, flux, flux_err and tls_kwargs that later cells
    read, then plots expected transit durations to guide the choice of the
    detrending window. If no data can be loaded, time/flux/flux_err are reset
    to None so that a previously loaded target's data is not reused under the
    new TIC ID.
    """
    global tic_id
    global time
    global flux
    global flux_err
    global tls_kwargs

    # Stray whitespace would otherwise end up in the TIC query and in output file names.
    tic_id = INPUT['tic_id'].value.strip()
    datafile = INPUT['datafile'].value.strip()

    if len(tic_id)==0:
        print('Please enter a valid TIC ID first.')
        return None

    print('Loading TIC', tic_id, '...')
    try:
        tls_kwargs = get_tls_kwargs_by_tic(tic_id=tic_id)
        print(' > Stellar parameters succesfully loaded. Rstar=', tls_kwargs['R_star'], 'Rsun, M_star=', tls_kwargs['M_star'], 'Msun')
    except Exception:
        tls_kwargs = {'R_star': 1., 'M_star': 1.}
        print(' > Stellar parameters could not be loaded. Using solar analog instead.')
    print('Loading data ...')

    time, flux, flux_err = _load_light_curve(tic_id, datafile)

    _plot_transit_durations(tls_kwargs['R_star'], tls_kwargs['M_star'])

im = widgets.interact_manual(load);
im.widget.children[0].description = 'Load / Refresh' #https://stackoverflow.com/a/51361461/4718101

Text(value='', description='TIC ID:', placeholder='123456789')

Text(value='', description='Input file:', placeholder='Leave empty to automatically download the light curve')

interactive(children=(Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widge…

interactive(children=(Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widge…

## Tailor your detrending
In this interactive plot, you have various options:

- **clip**: Clip flares and outliers in the plot. This only affects the plot and nothing else.
- **mask**: Mask out certain regions for the analysis. This **will** affect all other operations.
- **mask_ranges**: If you choose to mask regions, enter them here as BJD_TDB − 2457000 (i.e. TESS BTJD), with individual ranges separated by " and ".
  Example: "1970-1977 and 1980-1981 and 1991.1-1991.2" (without the quotes).
- **method**: Choose between biweight, cosine, and rspline (biweight is good for wild noise, cosine is good for very sinusoidal noise).
- **window_length**: Choose the detrending window, in days (1 day is a good first guess). Try to keep it at least 3x the transit duration you expect.

Note: For speed, only every second data point is plotted - but all data is analyzed.

In [7]:
#==============================================================================
#::: layout
#==============================================================================
# Four output panes shown as tabs; optimize() redraws each of them on every refresh.
out1, out2, out3, out4 = (widgets.Output() for _ in range(4))
tab = widgets.Tab(children=[out1, out2, out3, out4])
for _tab_index, _tab_title in enumerate(['Data & Trend', 'Data Periodogram', 'Residuals', 'Residual Periodogram']):
    tab.set_title(_tab_index, _tab_title)
display(tab)


#==============================================================================
#::: manually mask bad regions
#==============================================================================
def get_mask(mask, mask_ranges, time_values=None, offset=2457000):
    """Build a boolean mask of data points that fall inside user-specified time ranges.

    Parameters
    ----------
    mask : bool
        If False (or ``mask_ranges`` is empty), nothing is masked.
    mask_ranges : str
        Ranges of the form "m0-m1" joined by " and ", e.g. "1970-1977 and 1980-1981".
        Bounds are relative to ``offset``.
    time_values : array-like, optional
        Time stamps to mask; defaults to the global ``time`` loaded in an earlier cell.
    offset : float, optional
        Added to each bound before comparing with ``time_values`` (default 2457000,
        i.e. bounds are given as time - 2457000).

    Returns
    -------
    mask2 : ndarray of bool
        Same length as ``time_values``; True marks points strictly inside a range.
    limits : list of [m0, m1]
        The (un-offset) ranges that were applied, e.g. for shading plots.

    If any range cannot be parsed, a warning is issued and no mask is applied at all.
    """
    if time_values is None:
        time_values = time
    mask2 = np.zeros(len(time_values), dtype=bool)
    if not mask or len(mask_ranges) == 0:
        return mask2, []

    limits = []
    for m in mask_ranges.split(' and '):
        try:
            m0, m1 = [float(x) for x in m.split('-')]
        except ValueError:
            # Discard everything (including ranges already applied) so the result
            # matches the warning: the whole mask_ranges input is ignored.
            warnings.warn('Your mask_ranges syntax was invalid and has been ignored. The parameter "mask" has been turned off automatically.')
            return np.zeros(len(time_values), dtype=bool), []
        limits.append([m0, m1])
        mask2[(time_values > offset + m0) & (time_values < offset + m1)] = True
    return mask2, limits
    
def plot_mask(axes, limits):
    """Shade every masked time interval in red on each of the given axes.

    axes   : a single axis object or a collection of them (anything with ``axvspan``).
    limits : list of [xmin, xmax] pairs, as returned by get_mask; nothing is drawn if empty.
    """
    if len(limits) == 0:
        return
    for axis in np.atleast_1d(axes):
        for bounds in limits:
            lo, hi = bounds
            axis.axvspan(lo, hi, color='r', alpha=0.3)

                
#==============================================================================
#::: manually optimize detrending
#==============================================================================
def optimize(clip, mask, mask_ranges, method, window_length):
    """Detrend the loaded light curve with wotan and redraw the four result tabs.

    Parameters
    ----------
    clip : bool
        Passed to ``tessplot`` for the 'Data & Trend' tab only (display-level
        outlier clipping); it does not affect the analysis.
    mask : bool
        Whether to exclude the ranges given in ``mask_ranges`` (see get_mask).
    mask_ranges : str
        Ranges such as "1970-1977 and 1980-1981".
    method : str
        wotan ``flatten`` detrending method.
    window_length : float
        wotan ``flatten`` window length (days).

    Side effects: sets the globals wotan_kwargs, fig1-fig4, time2, flux3 and
    flux_err2, which the TLS search and save cells read.
    """
    global wotan_kwargs
    global fig1
    global fig2
    global fig3
    global fig4
    global time2
    global flux3
    global flux_err2
    import contextlib
    import io

    try:
        time*flux*flux_err #check if they work
    except Exception:
        print('Please load a valid data file first.')
        return None

    #::: mask (if chosen)
    mask2, limits = get_mask(mask, mask_ranges) #mask2 is an array the size of time
    time2 = time[~mask2]
    flux2 = flux[~mask2]
    flux_err2 = flux_err[~mask2]

    #::: wotan kwargs
    wotan_kwargs = {
        'slide_clip': {'window_length': 1, 'low': 20, 'high': 3},
        'flatten': {'method': method, 'window_length': window_length},
    }

    #::: clip and detrend, discarding anything the helpers print.
    #::: redirect_stdout restores sys.stdout even if one of these calls raises,
    #::: so the kernel's output can never be left silenced.
    with contextlib.redirect_stdout(io.StringIO()):
        flux3 = sigma_clip(flux2, sigma_upper=3, sigma_lower=20)
        flux3 = slide_clip(time2, flux3, **wotan_kwargs['slide_clip']) #slide_clip is super fast (<1 second for a TESS 2 min lightcurve for a single Sector)
        flux3, trend = flatten(time2, flux3, return_trend=True, **wotan_kwargs['flatten']) #flatten is super fast, (<1 second for a TESS 2 min lightcurve for a single Sector)

    #::: data & trend (every second point only, for plotting speed)
    with out1:
        out1.clear_output()
        axes = tessplot(time[::2], flux[::2], clip=clip)
        axes = tessplot(time2[::2], trend[::2], axes=axes, shade=False, marker=None, color='r', linestyle='-', linewidth=2)
        plot_mask(axes, limits) #if chosen
        fig1 = np.atleast_1d(axes)[0].get_figure()
        plt.show(fig1)

    #::: data periodogram (takes masks into account); presumably displayed by
    #::: estimate_period itself via show_plot=True — no explicit plt.show here.
    with out2:
        out2.clear_output()
        axes = estimate_period(time2, flux2, flux_err2, options={'show_plot':True, 'save_plot':False, 'return_plot':True})[2]
        fig2 = np.atleast_1d(axes)[0].get_figure()

    #::: residuals
    with out3:
        out3.clear_output()
        axes = tessplot(time2[::2], flux3[::2])
        fig3 = np.atleast_1d(axes)[0].get_figure()
        plt.show(fig3)

    #::: residual periodogram
    with out4:
        out4.clear_output()
        axes = estimate_period(time2, flux3, flux_err2, options={'show_plot':True, 'save_plot':False, 'return_plot':True})[2]
        fig4 = np.atleast_1d(axes)[0].get_figure()

                
#==============================================================================
#::: main
#==============================================================================
im = widgets.interact_manual(
    optimize,
    clip=True,
    mask=False,
    mask_ranges='',
    method=['biweight', 'cosine', 'rspline'],
    window_length=1.,
)
# Children 0-4 are the five controls above (in keyword order); index 5 is the run button.
im.widget.children[5].description = 'Plot / Refresh' #https://stackoverflow.com/a/51361461/4718101 

Tab(children=(Output(), Output(), Output(), Output()), _titles={'0': 'Data & Trend', '1': 'Data Periodogram', …

interactive(children=(Checkbox(value=True, description='clip'), Checkbox(value=False, description='mask'), Tex…

## Run a Transit-Least-Squares search

In [8]:
def search(period_min, period_max):
    """Run a Transit Least Squares search on the detrended light curve from the cell above.

    Parameters
    ----------
    period_min, period_max : int
        Orbital-period search range (days), taken from the sliders.

    Stores the period range and a progress-bar flag in the global tls_kwargs and
    passes the already-detrended flux (hence wotan_kwargs=None).
    """
    try:
        tls_kwargs['R_star']
        time2, flux3, flux_err2  # produced by optimize(); NameError if it never ran
    except (NameError, KeyError, TypeError):
        print('Please finish the above steps first.')
        return None

    if period_min >= period_max:
        print('Please choose a period_min that is smaller than period_max.')
        return None

    tls_kwargs['period_min'] = period_min
    tls_kwargs['period_max'] = period_max
    tls_kwargs['show_progress_bar'] = True

    tls_search(time2, flux3, flux_err2, 
               wotan_kwargs=None, #use the detrending from above 
               tls_kwargs=tls_kwargs,
               options={'show_plot':'2'})

im = widgets.interact_manual(search, 
                             period_min=widgets.IntSlider(min=0, max=99, step=1, value=1), 
                             period_max=widgets.IntSlider(min=1, max=100, step=1, value=20));
im.widget.children[2].description = 'TLS Search' #https://stackoverflow.com/a/51361461/4718101

interactive(children=(IntSlider(value=1, description='period_min', max=99), IntSlider(value=20, description='p…

## Save the results

In [9]:
INPUT['outdir'] = widgets.Text(value='', description='Output directory:')
display(INPUT['outdir'])


def select_outdir():
    root = Tk()
    root.withdraw()
    root.call('wm', 'attributes', '.', '-topmost', True)
    INPUT['outdir_dialog'] = filedialog.askdirectory()
    %gui tk
    if INPUT['outdir_dialog'] != '':
        INPUT['outdir'].value = INPUT['outdir_dialog']
im = widgets.interact_manual(select_outdir);
im.widget.children[0].description = 'Open file dialog' #https://stackoverflow.com/a/51361461/4718101


def save():
    """Save the detrending settings and the four result figures to the output directory.

    Writes TIC_<tic_id>_wotan_kwargs.json and four PDFs (data & trend, data
    periodogram, residuals, residual periodogram). The directory is created if
    it does not exist; an empty directory field saves into the current working
    directory (os.path.join('', name) == name).
    """
    try:
        # Globals produced by load() and optimize(); NameError if those never ran.
        tic_id, wotan_kwargs, fig1, fig2, fig3, fig4
    except NameError:
        print('Please get some results first.')
        return None

    outdir = INPUT['outdir'].value

    def _path(suffix):
        # e.g. <outdir>/TIC_<tic_id>_residuals.pdf
        return os.path.join(outdir, 'TIC_' + tic_id + '_' + suffix)

    try:
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        allesfitter.write_dic(_path('wotan_kwargs.json'), wotan_kwargs)
        fig1.savefig(_path('data_and_trend.pdf'))
        fig2.savefig(_path('data_periodogram.pdf'))
        fig3.savefig(_path('residuals.pdf'))
        fig4.savefig(_path('residuals_periodogram.pdf'))
        print('Saved all files succesfully.')
    except Exception as e:
        # Report the real cause (bad path, permissions, ...) instead of a generic hint.
        print('Saving failed:', e)
    
    
im = widgets.interact_manual(save);
im.widget.children[0].description = 'Save' #https://stackoverflow.com/a/51361461/4718101

Text(value='', description='Output directory:')

interactive(children=(Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widge…

interactive(children=(Button(description='Run Interact', style=ButtonStyle()), Output()), _dom_classes=('widge…