In [1]:
%%HTML

<!-- Stretch every `.container` element to the full page width (overrides the default fixed width). -->
<style>.container { width:100% !important; }</style>

<!-- Hide every `div.input` element - presumably the code-input areas of the classic Notebook UI,
     so that only outputs/widgets are visible. NOTE(review): verify the selector if run in JupyterLab. -->
<style>
div.input {
    display:none;
}
</style>

In [2]:
%%javascript

// Toggle the visibility of the element with id "header" (presumably the notebook's top header bar);
// since it is a toggle, running this cell a second time shows it again.
$('#header').toggle();

// Replace the output area's scroll check with one that always answers "no",
// presumably so long outputs are never collapsed into a scrollable box.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [3]:
# Enabling the `widget` backend.
# This requires jupyter-matplotlib a.k.a. ipympl.
# ipympl can be installed via pip or conda.
%matplotlib widget

#::: imports
import os
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, HTML, FileLink
from scipy.interpolate import UnivariateSpline

#::: local imports
import computer
from mcmc import mcmc_fit
from mcmc_output import mcmc_output
from tabulate import tabulate

#::: plotting settings (seaborn theme applied to every figure in this notebook)
import seaborn as sns
sns.set(
    context='paper',
    style='ticks',
    palette='deep',
    font='sans-serif',
    font_scale=1.5,
    color_codes=True,
)
# draw the x and y tick marks pointing into the axes
sns.set_style({"xtick.direction": "in", "ytick.direction": "in"})
# rc override on top of the current context: marker edge width of 1
sns.set_context(rc={'lines.markeredgewidth': 1})

#::: globals (don't do this at home, kids)
# Shared state that the widget callbacks read and fill in; every entry starts as None.
data = dict.fromkeys(('time', 'flux', 'flux_err'))
model = dict.fromkeys(('time', 'flux'))
params = dict.fromkeys(('radius_planet', 'radius_star', 'epoch', 'period', 'a', 'incl'))
line = None  # the model Line2D, created once a target has been plotted

In [4]:
###########################################################################
#::: Define all widgets
###########################################################################
#::: step 1 -- dropdown choosing the data set; the first entry is a placeholder
widget_dropdown_targets = widgets.Dropdown(
    options=['Select a target...', 'WASP-189b', 'KELT-3b', 'TOI-560c'],
    description='Target:',
)

#::: step 2 -- sliders for the initial-guess parameters, hidden until a target is chosen
style = {'description_width': '250px'}
layout = {'width': '600px', 'visibility': 'hidden'}
widget_floatslider_radius_planet = widgets.FloatSlider(
    min=0.1, max=22., value=1.,
    description='Radius of the planet (Earth radii):',
    style=style, layout=layout,
)
widget_floatslider_radius_star = widgets.FloatSlider(
    min=0.1, max=3., value=1.,
    description='Radius of the star (Solar radii):',
    style=style, layout=layout,
)
widget_floatslider_epoch = widgets.FloatSlider(
    min=0., max=1., step=0.01, value=0.3,
    description='Mid-transit time (days):',
    style=style, layout=layout,
)

#::: step 3 -- button that starts the MCMC run, also hidden at first
layout = {'visibility': 'hidden'}
widget_button_run = widgets.Button(
    description='Investigate',
    tooltip='Click here to start the algorithm that fits the model to your data.',
    layout=layout,
)

#::: output areas (CSS margin order is top/right/bottom/left)
widget_output0 = widgets.Output(layout={'border': 'solid 2px', 'margin': '10px 0px 10px 0px'})  # chat
widget_output1 = widgets.Output()  # light curve
widget_output2 = widgets.Output()  # histograms
widget_output3 = widgets.Output()  # table

#::: result tabs; only the light-curve tab exists up front,
#::: the 'Histograms' and 'Table' tabs are added by run() after a fit
tab = widgets.Tab(children=[widget_output1])
tab.set_title(0, 'Light curve')

#::: widget images (one icon above each step of the workflow)
def _load_image_widget(path, fmt, width=150):
    """Read an image file and wrap its raw bytes in an ipywidgets Image.

    path:  image file path, relative to the notebook's working directory
    fmt:   format string forwarded to widgets.Image (e.g. 'png', 'jpg')
    width: display width forwarded to widgets.Image
    Returns the new widgets.Image; the file handle is closed on return.
    """
    with open(path, "rb") as fh:
        return widgets.Image(value=fh.read(), format=fmt, width=width)

# One call per icon instead of three copy-pasted blocks; this also stops the
# `file`/`image` temporaries from leaking into the notebook namespace.
im1 = _load_image_widget("images/icon1.png", 'png')
im2 = _load_image_widget("images/icon2.png", 'png')
im3 = _load_image_widget("images/icon3.jpg", 'jpg')
    
#::: widget labels -- one prompt above each step
lab1 = widgets.Label(value='Which target should we look at, Detective?')
lab2 = widgets.Label(value='We need good clues to start digging deeper...')
lab3 = widgets.Label(value='I think we are ready to start investigating, Detective!')

#::: chat messages that run() displays in the chat output:
#::: chat1 before the MCMC fit, chat2 once the results are in
chat1 = widgets.HTML(value=(
    '<style>p{word-wrap: break-word}</style> <p>'
    'Roger that, Detective! '
    'Thanks for all the clues. '
    'We are running a full investigation now... '
    '</p>'
))
chat2 = widgets.HTML(value=(
    '<style>p{word-wrap: break-word}</style> <p>'
    'Detective, look at this! '
    'The investigation was successful, we nailed down exactly what happened. '
    'We also found some old case files in the archive that gave us extra insights. '
    'Have a look, we prepared you a case file with all the details: '
    '</p>'
))
    

    
###########################################################################
#::: Clean up and set up the new figure
###########################################################################
# Close all open figures (e.g. left over from a previous execution of this cell).
plt.close('all')
# Create the single light-curve figure while capturing output into widget_output1
# (the 'Light curve' tab); `fig` and `ax` are the globals every callback below draws on.
with widget_output1:
    fig, ax = plt.subplots(1, figsize=(12, 6), tight_layout=True)
    # NOTE(review): per the matplotlib docs pyplot.show() shows all open figures and takes
    # no figure argument (its parameter is `block`); with ipympl, display(fig.canvas) is
    # the explicit way to show one figure - confirm before changing.
    plt.show(fig)
# Keep the chat output and the result tabs hidden until a target / a run needs them.
widget_output0.layout.visibility = 'hidden'
tab.layout.visibility = 'hidden'



###########################################################################
#::: Define all interactive actions
###########################################################################
#::: define the function to execute after the selection
def load_and_plot_data(target_name):
    """Load the selected target's light curve and reset the UI for it.

    Callback of the `interactive` wrapper around `widget_dropdown_targets`, so it
    runs whenever the selected target changes.

    target_name: the dropdown value. The placeholder 'Select a target...' clears
        the plot, shows a '?' and hides every later step (sliders, run button,
        chat output and result tabs).

    For a real target it fills the global dicts `data`, `model` (time grid only)
    and `params` from `computer`, moves the three initial-guess sliders to the
    target's starting values, plots the data with a flat placeholder model line
    and rebinds the global `line` to that model line. Returns None.
    """
    # Only `line` is rebound here; data/model/params are mutated in place.
    global line

    #::: catching the blank state
    if target_name == 'Select a target...':

        #::: handle the figure
        with widget_output1:
            ax.clear()
            ax.set(xlabel='Time (Days)', ylabel='Relative Brightness', title='No target selected')
            # centre the placeholder in axes coordinates, independent of data limits
            ax.text(0.5, 0.5, '?', fontsize=64, ha='center', va='center', transform=ax.transAxes)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            # redraw explicitly, as the target branch below does; previously this
            # branch never pushed the cleared axes to the interactive canvas
            fig.canvas.draw_idle()

        #::: hide everything that belongs to the later steps
        for control in (im2, lab2,
                        widget_floatslider_radius_planet,
                        widget_floatslider_radius_star,
                        widget_floatslider_epoch,
                        im3, lab3, widget_button_run,
                        widget_output0, tab):
            control.layout.visibility = 'hidden'

        #::: end here
        return None

    #::: unhide the sliders (but keep the button hidden)
    ax.xaxis.set_visible(True)
    ax.yaxis.set_visible(True)
    for control in (im2, lab2,
                    widget_floatslider_radius_planet,
                    widget_floatslider_radius_star,
                    widget_floatslider_epoch,
                    tab):
        control.layout.visibility = 'visible'

    #::: load data
    data['time'], data['flux'], data['flux_err'], model['time'] = computer.load_data(target_name)

    #::: load frozen params and initial guesses for the sliders
    radius_planet, radius_star, epoch, period, a, incl = computer.load_params(target_name)
    params['radius_planet'] = radius_planet
    params['radius_star'] = radius_star
    params['epoch'] = epoch
    params['period'] = period
    params['a'] = a
    params['incl'] = incl

    #::: set the sliders from the LOCAL copies, not from `params`.
    # Every slider assignment that changes a value fires update_plot() through the
    # `interactive` wrapper, and update_plot() writes the *current* slider values
    # back into params. Reading params['radius_star'] / params['epoch'] after the
    # first assignment therefore returned the previous target's slider values, so
    # only the first changed slider was ever reset when switching targets.
    widget_floatslider_radius_planet.value = radius_planet
    widget_floatslider_radius_star.value = radius_star
    widget_floatslider_epoch.value = epoch

    #::: plot the data and an empty (flat) model
    with widget_output1:
        ax.clear()
        ax.set(xlabel='Time (Days)', ylabel='Relative Brightness', title='Light curve of '+target_name)
        ax.plot(data['time'], data['flux'], 'b.', label='The evidence')
        line, = ax.plot(model['time'], np.ones_like(model['time']), color='silver', marker=None, linestyle='-', label='Your first clues')
        ax.legend(loc='lower left')
        fig.canvas.draw_idle() #this change was pivotal to make the interactive plot work


    
#::: update the plot with the light curve model
def update_plot(radius_planet, radius_star, epoch):
    """Recompute the model light curve for the slider values and gate the run step.

    Callback of the `interactive` wrapper around the three initial-guess sliders,
    so it runs whenever one of them changes value.

    radius_planet: planet radius from its slider (Earth radii, per the slider label)
    radius_star:   stellar radius from its slider (Solar radii, per the slider label)
    epoch:         mid-transit time from its slider (days, per the slider label)

    Writes the three values into the global `params`, stores the model in
    model['flux'], updates the y-data of the global `line` (if it exists yet) and
    shows im3/lab3/run-button only when computer.check_initial_guess accepts the
    guess. Does nothing while the dropdown shows the placeholder. Returns None.
    """
    #::: catching the blank state
    if widget_dropdown_targets.value == 'Select a target...':
        return None #abort

    #::: set params (the global dicts are only mutated, never rebound, so no `global` is needed)
    params['radius_planet'] = radius_planet
    params['radius_star'] = radius_star
    params['epoch'] = epoch

    #::: retrieve the light curve model
    model['flux'] = computer.calc_flux_model(params['radius_planet'], params['radius_star'], params['epoch'], params['period'], params['a'], params['incl'], model['time'])

    #::: update the model line. `line` is still None while load_and_plot_data sets the
    #::: sliders for the very first target (the line is drawn only afterwards).
    if line is not None:
        try:
            line.set_ydata(model['flux'])
        except Exception:
            # Still best-effort during target switches (the original swallowed all
            # errors here), but no longer hides KeyboardInterrupt / SystemExit.
            pass
    # use the light-curve output, as load_and_plot_data does, so draw errors land next to the plot
    with widget_output1:
        fig.canvas.draw_idle() #this change was pivotal to make the interactive plot work

    #::: unhide/hide the run step, depending on how close the initial guess is to the truth
    good_guess = computer.check_initial_guess(widget_dropdown_targets.value, params['radius_planet'], params['radius_star'], params['epoch'])
    visibility = 'visible' if good_guess else 'hidden'
    for control in (im3, lab3, widget_button_run):
        control.layout.visibility = visibility
    
    

#::: define the run button
def _set_controls_disabled(disabled):
    """Enable (False) or disable (True) the run button, the dropdown and the three sliders."""
    for control in (widget_button_run,
                    widget_dropdown_targets,
                    widget_floatslider_radius_planet,
                    widget_floatslider_radius_star,
                    widget_floatslider_epoch):
        control.disabled = disabled


def _plot_posterior_models(posterior_samples, n_draws=20):
    """Overplot up to `n_draws` posterior model light curves (each plus a spline baseline) on `ax`.

    For each draw the transit model is evaluated on the fine grid model['time'].
    A weighted UnivariateSpline fitted to the data-minus-model residuals on
    data['time'] supplies a baseline that is added back before plotting.
    model['flux'] is left holding the last draw's model.
    """
    # The weights do not depend on the draw, so compute them once.
    yerr_weights = data['flux_err']/np.nanmean(data['flux_err'])
    weights = 1./yerr_weights
    smoothing = np.sum(weights)

    # never index past the available samples (the original assumed at least 20)
    n_draws = min(n_draws, len(posterior_samples['radius_planet']))
    for i in range(n_draws):
        radius_planet = posterior_samples['radius_planet'][i]
        radius_star = posterior_samples['radius_star'][i]
        epoch = posterior_samples['epoch'][i]

        #::: compute the ellc model (on finer time grid)
        model['flux'] = computer.calc_flux_model(radius_planet, radius_star, epoch, params['period'], params['a'], params['incl'], model['time'])

        #::: compute the baseline: train a spline on the residuals on the data time grid,
        #::: then evaluate it on the finer model time grid
        model_on_data = computer.calc_flux_model(radius_planet, radius_star, epoch, params['period'], params['a'], params['incl'], data['time'])
        spl = UnivariateSpline(data['time'], data['flux']-model_on_data, w=weights, s=smoothing)
        baseline = spl(model['time'])

        #::: plot
        ax.plot(model['time'], model['flux']+baseline, 'r-', alpha=0.1)


def run(arg):
    """Run the MCMC fit for the selected target and present the results.

    Click callback of widget_button_run; `arg` is the clicked button (unused).

    Disables all controls (they stay disabled after a successful run), shows the
    chat output and, inside it, runs mcmc_fit / mcmc_output. It then overplots
    posterior draws on the light curve, saves it as
    results/<target>/light_curve.pdf, adds the 'Histograms' and 'Table' tabs and
    links to the result files.

    If the fit or its post-processing raises, the Output widget shows the
    traceback; inside IPython, ipywidgets' Output context manager also suppresses
    the exception. In that case the controls are re-enabled and the function
    returns instead of failing later on missing results.
    """
    target = widget_dropdown_targets.value
    results_dir = os.path.join('results', target)

    #::: disable all widgets for luser-proofness
    _set_controls_disabled(True)

    #::: make the chat visible
    widget_output0.layout.visibility = 'visible'

    #::: print statements and other output need an "output widget" (figures do not)
    completed = False
    with widget_output0:

        #::: start the MCMC run
        display(chat1)
        mcmc_fit(target, params)

        #::: get the MCMC output
        display(chat2)
        posterior_samples, fig_hist, table = mcmc_output(target, params=params)
        table_html = tabulate(table, tablefmt='html', headers=['Name', 'Median', 'Lower Error', 'Upper Error', 'Case Note'])
        completed = True

    if not completed:
        # the traceback is already shown in the chat output; let the user try again
        _set_controls_disabled(False)
        return None

    #::: plot the fit
    _plot_posterior_models(posterior_samples)

    #::: add a legend (one invisible proxy line stands in for all red curves;
    #::: np.nan instead of np.NaN, which NumPy 2.0 removed)
    ax.plot(np.nan, np.nan, 'r-', label='Your final investigation results')
    ax.legend(loc='lower left')
    # push the new curves to the interactive canvas, as load_and_plot_data/update_plot do
    fig.canvas.draw_idle()

    #::: save the light curve figure
    os.makedirs(results_dir, exist_ok=True)
    fname1 = os.path.join(results_dir, 'light_curve.pdf')
    fig.savefig(fname1, bbox_inches='tight')

    #::: add hist and table children to the tab widget
    tab.children = [widget_output1, widget_output2, widget_output3]
    for index, title in enumerate(('Light curve', 'Histograms', 'Table')):
        tab.set_title(index, title)

    #::: show the file links next to each result
    with widget_output1:
        display(FileLink(fname1))

    with widget_output2:
        # plt.show() takes no figure argument and would show every open figure;
        # displaying the canvas shows exactly this figure
        display(fig_hist.canvas)
        display(FileLink(os.path.join(results_dir, 'histograms.pdf')))

    with widget_output3:
        # display() of a plain str shows its repr (raw HTML source); wrap it so it renders
        display(HTML(table_html))
        display(FileLink(os.path.join(results_dir, 'table.txt')))
    
    

#::: dropdown -> load_and_plot_data: reload and replot whenever the target changes
w1 = widgets.interactive(
    load_and_plot_data,
    target_name=widget_dropdown_targets,
)

#::: sliders -> update_plot: recompute the model line whenever a slider moves
w2 = widgets.interactive(
    update_plot,
    radius_planet=widget_floatslider_radius_planet,
    radius_star=widget_floatslider_radius_star,
    epoch=widget_floatslider_epoch,
)

#::: button -> run: start the MCMC investigation on click
widget_button_run.on_click(run)



###########################################################################
#::: Display everything neatly
###########################################################################
#::: one centred column per step: choose target | tune initial guess | run
box1 = widgets.VBox([im1, lab1, w1], layout={'align_items': 'center'})
box2 = widgets.VBox([im2, lab2, w2], layout={'align_items': 'center'})
box3 = widgets.VBox([im3, lab3, widget_button_run], layout={'align_items': 'center'})

#::: the three steps side by side, then the chat output, then the result tabs
box4 = widgets.HBox([box1, box2, box3], layout={'border': 'solid 2px'})
box5 = widgets.VBox([box4, widget_output0, tab])

display(box5)  # displaying the outermost box renders all of its children

VBox(children=(HBox(children=(VBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x03\x00…