# Notebook version of the SHERLOCK algorithm
### used to construct a probabilistic period spacing pattern (PSP) for g-mode asteroseismology

Created by Jordan Van Beeck, using Joey Mombarg's SHERLOCK algorithm.

In [1]:
# import statements
import matplotlib
matplotlib.use("TkAgg")
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import interpolate
from matplotlib.widgets import Button

In [2]:
# set rc parameters for plotting
fontsize = 14
matplotlib.rc('xtick', labelsize=fontsize)
matplotlib.rc('ytick', labelsize=fontsize)

### User input is required: adjust the following parameters before running

In [3]:
# Azimuthal order to search for. m = -1 retrograde, m = 0 zonal, m = 1 prograde.
m = 1
# Maximum number of skipped modes in the pattern before terminating.
max_skip = 2
# Maximum number of modes towards smaller periods before terminating.
max_modes_smaller_periods = 35
# Maximum number of modes towards larger periods before terminating.
max_modes_larger_periods  = 35
# KIC number
KIC =  '07760680' #
# Work directory, no trailing slash.
WORK_DIR = '/home/jordanv/Documents/TESTGRID/sherlock'
# Always use the entire grid for the search window and expected value.
always_use_entire_grid = False

### Read the observational information used to construct the pattern

In [4]:
def ensure_kic_of_correct_length(KIC: str, correct_length: int=9) -> str:
    """Ensures that the KIC number is of the correct length for loading in specific data (while reconstructing the names of these files). 

    Parameters
    ----------
    KIC : str
        KIC number without 'KIC' prefix.
    correct_length : int, optional
        Specified 'correct' length of the number, by default 9

    Returns
    -------
    str
        KIC number (without 'KIC' prefix) of the specified length.
    """
    return KIC.rjust(correct_length, '0')

In [5]:
def read_frequency_list(WORK_DIR: str, KIC: str, combinations_included: bool= True) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Read frequency list from Van Beeck et al. (2021, A&A, 655, A59) for a given KIC number, picking their extraction strategy with the highest f_sv factor.

    Parameters
    ----------
    WORK_DIR : str
        Work directory.
    KIC : str
        KIC number without 'KIC' prefix.
    combinations_included : bool, optional
        Whether to also include combination frequencies in the pattern, by default True

    Returns
    -------
    P_obs : np.ndarray
        Contains the observed periods in days.
    pe : np.ndarray
        Contains the uncertainties on the periods.
    A : np.ndarray
        Contains the amplitudes of the detected signals.
    ae : np.ndarray
        Contains the uncertainties on the amplitudes.
    phase : np.ndarray
        Contains the phases of the observed signals in radians.
    phe : np.ndarray
        Contains the uncertainties on the phases.
    nonlin_id : np.ndarray
        Contains the non-linear mode ID (e.g. 'freq1+freq2').
    """
    # read the data from the CSV file
    csv_directory = f'{WORK_DIR}/example_input_data/'
    ascii_file_name = f'amplitudes_frequencies_phases_KIC0{KIC}_strategy_5.asc'
    df_names = ['freq', 'sigma_freq', 'ampl', 'sigma_ampl', 'phase', 'sigma_phase', 'nr', 'nonlin_id']
    df = pd.read_csv(f'{csv_directory}{ascii_file_name}', sep = '\t', header = 9, names=df_names)
    # store the data into local variables
    if not combinations_included:
        i = 0
        while 1:
            nonlin_id = df['nonlin_id'][i]
            if '+' in nonlin_id or '-' in nonlin_id or '*' in nonlin_id:
                break
            else:
                i += 1
        n = i-1
        freq = np.array(df['freq'][0:n])
        fe = np.array(df['sigma_freq'][0:n])
        A = np.log10(np.array(df['ampl'][0:n]))
        ae = np.array(df['sigma_ampl'][0:n])
        phase = np.array(df['phase'][0:n])
        phe = np.array(df['sigma_phase'][0:n])
        nonlin_id = np.array(df['nonlin_id'][0:n])
    else:
        freq = np.array(df['freq'])
        fe = np.array(df['sigma_freq'])
        A = np.log10(np.array(df['ampl'])/np.min(np.array(df['ampl'])))
        ae = np.array(df['sigma_ampl'])
        phase = np.array(df['phase'])
        phe = np.array(df['sigma_phase'])
        nonlin_id = np.array(df['nonlin_id'])
    # compute the periods of the observed signals
    P_obs = 1/freq
    pe = fe/freq**2
    # return the observed signal characteristics
    return P_obs, pe, A, ae, phase, phe, nonlin_id

In [6]:
P_obs, pe, A, ae, phase, phe, nonlin_id = read_frequency_list(WORK_DIR, KIC, combinations_included = True)

### Machinery functions used for the PSP search

In [7]:
def get_deltaP_sel(x, n, deltaP_obs1):
    '''
    Get the difference of period-spacing of a specific radial order.

    -- Input --
    n: radial order
    deltaP_obs1: previous observed difference in period-spacing

    -- Output --
    deltaP_sel: expected difference in period spacings (P_n - P_n+1) and (P_n+1 - P_n+2) in seconds, given deltaP_obs1. p(deltaP_2 | deltaP_1)
    deltaP_all: expected difference in period spacings for the entire grid. p(deltaP)
    '''
    dP0 = np.array(x[str(n-1)])   - np.array(x[str(n)])
    dP1 = np.array(x[str(n)])   - np.array(x[str(n+1)])
    dP2 = np.array(x[str(n+1)]) - np.array(x[str(n+2)])
    # Previous period-spacing
    deltaP_prev = (dP1 - dP0)*86400
    # Next periods-spacing
    deltaP_ = (dP2 - dP1)*86400
    deltaP_ = np.array(deltaP_)
    ddp = np.abs(deltaP_prev - deltaP_obs1)
    deltaP_sort = deltaP_[np.argsort(ddp)]
    # Select 500 previous period-spacings that are closest to the observed one.
    deltaP_sel_ = deltaP_sort[0:500]
    # Filter nans in the grid.
    deltaP_sel = [x_ for x_ in deltaP_sel_ if not np.isnan(x_)]
    deltaP_all = [x_ for x_ in deltaP_ if not np.isnan(x_)]
    return deltaP_sel, deltaP_all

In [8]:
def deltaP_expected(x, deltaP_obs1, skipped_radial_order):
    '''
    Interpolate the distribution of differences in period spacing, and compute the most probable one.
    Also compute a search window based on the min/max value that is 0.001*p_max, where p_max is the highest probability.
    Lastely, integrate the probability distribution to get the normalization factor.

    -- Input --
    deltaP_obs1:          observed difference in period-spacings DeltaP_2 - DeltaP_1.
    skipped_radial_order: was a radial skipped in the pattern?

    -- Output --
    deltaP_min/max: search window.
    max_prob:       most likely difference in period spacing.
    p_trans_ipol:   interpolator for the PDF of differences in period-spacings.
    norm:           inverse of the integral of the PDF.
    '''
    deltaP_sel = []
    deltaP_all = []

    for n in np.arange(2,98,1):
        deltaP_sel_, deltaP_all_ = get_deltaP_sel(x, n,deltaP_obs1)
        deltaP_sel.extend(deltaP_sel_)
        deltaP_all.extend(deltaP_all_)

    deltaP_sel = np.array(deltaP_sel)
    deltaP_all = np.array(deltaP_all)

    # If a radial order is skipped, we take the estimate over the entire grid, i.e. compute p(deltaP) instead of p(deltaP_2 | deltaP_1).
    if skipped_radial_order:
        deltaP_sel = deltaP_all

    bin_low = np.min(deltaP_sel)
    bin_up  = np.max(deltaP_sel)


    PDF_bin, bins_edge = np.histogram(deltaP_sel, bins=10000, density=True, range = [bin_low, bin_up])
    bins = bins_edge[1:] + 0.5*(bins_edge[:-1] - bins_edge[1:])
    max_prob = bins[np.argmax(PDF_bin)]

    p_trans_ipol = interpolate.interp1d(bins, PDF_bin, kind='quadratic', fill_value='extrapolate')


    xx = np.linspace(np.min(bins), np.max(bins), 10000)
    p_trans_pos = p_trans_ipol(xx)
    #p_trans_pos[p_trans_pos <= 0] = 0.
    norm = 1/np.trapz(p_trans_pos)

    p_scale = np.array(p_trans_ipol(xx)/np.max(p_trans_ipol(xx)))
    #xmax = xx[np.argmax(p_scale)]
    #deltaP_min = np.max(xx[(p_scale < 10**-3) & (xx < xmax)]) #-3
    #deltaP_max = np.min(xx[(p_scale < 10**-3) & (xx > xmax)]) #-3
    # This cutoff value works well for the SPB grid. Make the plots below to verify.
    deltaP_min = np.min(xx[p_scale > 10**-3])
    deltaP_max = np.max(xx[p_scale > 10**-3])

    # Make plots of the probability distributions.
    if True:
        fig, ax = plt.subplots()
        plt.subplots_adjust(left = 0.14, bottom = 0.15, right = 0.95, top = 0.95)
        ax.plot(xx, np.log10(p_trans_ipol(xx)/np.max(p_trans_ipol(xx))), color = 'k')
        ax.vlines([deltaP_min, deltaP_max], ymin = 0, ymax = 1, color = 'grey')
        ax.vlines([max_prob], ymin = 0, ymax = 1, color = 'g')
        ax.axhline(-3, color ='grey', linestyle ='dashed')
        ax.set_xlabel(r'$\Delta P_1 - \Delta P_2$', fontsize = 14)
        ax.set_ylabel(r'$\log P/P_{\rm max}$', fontsize = 14)
        ax.set_ylim(-4,1.5)
        ax.set_xlim(-2000, 2000)

    return deltaP_min, deltaP_max, max_prob, p_trans_ipol, norm

In [9]:
def p_emis(A_potential, in_log):
    '''
    Compute the emission probability.
    '''
    if in_log:
        A_potential = 10**A_potential
    return A_potential/np.sum(A_potential)

### Interactive Figure Class

In [10]:
class InteractivePlot:
    def __init__(self, WORK_DIR, KIC, m, P_obs, A, pe, nonlin_id, fontsize, fig=None, multi_ax=None):
        # - file
        # work directory
        self.WORK_DIR = WORK_DIR
        # KIC nr
        self.KIC = KIC
        # - mode
        # azimuthal order
        self.m = m
        # - figure
        # store the figure
        self.fig = plt.gcf() if fig is None else fig
        # get the axes
        self.ax = self.fig.gca() if multi_ax is None else multi_ax
        # get the button axes
        self.search_ax = plt.axes([0.25, 0.93, 0.1, 0.05])
        # get the reset axes
        self.reset_ax = plt.axes([0.75, 0.93, 0.1, 0.05])
        # get the save axes
        self.save_ax = plt.axes([0.50, 0.93, 0.1, 0.05])
        # define the font size
        self.fontsize = fontsize
        # - observational data
        # store observed periods
        self.P_obs = P_obs
        # store observed amplitudes
        self.A = A
        # store nonlinear ids
        self.nonlin_id = nonlin_id
        # store observed period errors
        self.pe = pe
        # - global variables
        # store the initial period index
        self.initial_period_index = []
        # store the PSP dictionary
        self.psp_dict = {}
        # - data initialization
        # load pickled data
        self.load_pickled_data()
        # add the observed signals to the plot
        self.add_observed_signals()
    
    def __enter__(self):
        self.connect()
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disconnect()
        
    def connect(self):
        """Install event handlers for the plot.
        """
        # add the picker
        self.picker = self.fig.canvas.mpl_connect('pick_event', self.onpick)
        
        # add the search button
        self.search_button = Button(self.search_ax, 'Search')
        self.search_connection = self.search_button.on_clicked(self.button_search)  
    
        # add the reset button
        self.reset_button = Button(self.reset_ax, 'Reset')
        self.reset_connection = self.reset_button.on_clicked(self.reset_selections)
        
        # add the save button
        self.save_button = Button(self.save_ax, 'Save')
        self.save_connection = self.save_button.on_clicked(self.save_selections)
    
    def disconnect(self):
        """Remove event handlers for the plot.
        """
        # disconnect all event handlers
        self.fig.canvas.mpl_disconnect(self.picker)
        self.search_button.disconnect(self.search_connection)
        self.reset_button.disconnect(self.reset_connection)
        self.save_button.disconnect(self.save_connection)

    def load_pickled_data(self):
        with open(f'{self.WORK_DIR}/grids/gyre_per_l1m{self.m}_ext.pkl', 'rb') as f:
            self.x = pickle.load(f)
            
    def add_observed_signals(self):
        self.lines = []
        for i, p in enumerate(self.P_obs):
            line_ = self.ax[0].vlines(x=p, ymin=0, ymax=self.A[i], color='k', picker=True)
            self.lines.append(line_)
            
    def onpick(self, event):
        if event.artist in self.lines:
            ind = self.lines.index(event.artist)
            if ind not in self.initial_period_index:
                self.initial_period_index.append(ind)
                event.artist.set_color('red')
                self.fig.canvas.draw_idle()
            else:
                self.initial_period_index.remove(ind)
                event.artist.set_color('k')
                self.fig.canvas.draw_idle()
    
    def button_search(self, event):
        # store the initial period indices in local variables
        ip1 = self.initial_period_index[0]
        ip2 = self.initial_period_index[1]
        ip3 = self.initial_period_index[2]
        print(f"Selected points: {self.initial_period_index}")
        # Sort the periods such that the order in which you click the initial periods is not important.
        P_ini_sort = np.sort([P_obs[ip1], P_obs[ip2], P_obs[ip3]])
        P_obs1_ini = P_obs1 = P_ini_sort[0]
        P_obs2_ini = P_obs2 = P_ini_sort[1]
        P_obs3_ini = P_obs3 = P_ini_sort[2]
        print(P_obs1_ini, P_obs2_ini, P_obs3_ini)
        # Add vertical lines to probability plot
        self.ax[2].vlines(x=P_obs1, ymin=0, ymax=1, color = 'r')
        self.ax[2].vlines(x=P_obs2, ymin=0, ymax=1, color = 'r')
        self.ax[2].vlines(x=P_obs3, ymin=0, ymax=1, color = 'r')
        # store internal variables for the computation of the probabilities
        pattern = []
        restart = []
        direction   = []
        uncertainty = []
        probability_tot   = []
        probability_trans = []
        probability_emis  = []
        nonlin_id_arr = []
        pattern.extend([P_obs1, P_obs2, P_obs3])
        # Keep track of the direction of the pattern building (towards lower/higher periods) and properly order them afterwards.
        direction.extend(['i', 'i', 'i'])
        uncertainty.extend([pe[ip1], pe[ip2], pe[ip3]])
        # We set the probabilities of the manually selected modes equal to 1.
        probability_tot.extend([1.,1.,1.])
        probability_trans.extend([1.,1.,1.])
        probability_emis.extend([1.,1.,1.])
        nonlin_id_arr.extend([nonlin_id[ip1][1:-1], nonlin_id[ip2][1:-1], nonlin_id[ip3][1:-1]])
        skipped_radial_order = False
        go_on = 0
        skipped = 0
        
        # compute the search window based on the entire grid of models
        deltaP_all_min, deltaP_all_max, _, _, _ = deltaP_expected(self.x, np.nan, True)
        
        # perform a search towards the right (longer periods)
        print('Towards longer periods (right)...')
        while go_on < max_modes_larger_periods:
            DeltaP_obs1 = np.abs(P_obs1 - P_obs2)*86400
            DeltaP_obs2 = np.abs(P_obs2 - P_obs3)*86400
            deltaP_obs1 = DeltaP_obs1 - DeltaP_obs2

            if always_use_entire_grid:
                skipped_radial_order = True
            # Compute the expected difference in period-spacing for the next period, and the search interval.
            deltaP_min, deltaP_max, deltaP_exp, p_trans_ipol, norm  = deltaP_expected(self.x, deltaP_obs1, skipped_radial_order)
            # Compute the most probable period spacing, and the expected minimum and maximum period spacing.
            DeltaP_exp  = DeltaP_obs2 - deltaP_exp
            # Take the smallest search window between the value computed using the deltaP_obs1 value and the value using the entire grid of models.
            DeltaP_up   = DeltaP_obs2 - np.max([deltaP_all_min, deltaP_min])
            DeltaP_low  = DeltaP_obs2 - np.min([deltaP_all_max, deltaP_max])

            DeltaP_all  =  (self.P_obs - P_obs3)*86400
            # Keep only periods with spacings within the search interval and positive spacings.
            get = np.array([DeltaP_low < DeltaP_all]) & np.array([DeltaP_all < DeltaP_up]) & np.array([DeltaP_all > 0])

            # Retry with a larger search window if no candidates are found.
            if np.sum(get) == 0:
                DeltaP_low  = DeltaP_obs1 + deltaP_all_min
                DeltaP_up   = DeltaP_obs1 + deltaP_all_max
                get = np.array([DeltaP_low < DeltaP_all]) & np.array([DeltaP_all < DeltaP_up]) & np.array([DeltaP_all > 0])

            if np.sum(get) > 0:
                P_potential = self.P_obs[get[0]]
                np.set_printoptions(precision=16)
                print('pobs', P_potential)

                A_potential = self.A[get[0]]
                print('Aobs', A_potential)

                pe_potential = pe[get[0]]
                # Compute the emission probabilities of all candidates.
                p_emission  = p_emis(A_potential, in_log = True)
                DeltaP_potential = (P_potential - P_obs3)*86400
                deltaP_obs2 = DeltaP_obs1 - DeltaP_potential
                p_transition = norm * p_trans_ipol(deltaP_obs2)
                # Interpolation can give probabilities slightly smaller than 0. Set these to 0.
                p_transition[p_transition < 0] = 0

                sig = np.array([p_emission > 0.]) & np.array([p_transition > 0])

                if np.sum(sig) > 0:
                    # Compute the total probability and normalize.
                    p_total = p_transition * p_emission
                    p_total /= np.sum(p_total)

                    P_next = P_potential[np.argmax(p_total)]
                    print(f'{go_on} New period found right at {P_next} with probability {np.max(p_total)} (p_trans = {p_transition[np.argmax(p_total)]}, p_emis = {p_emission[np.argmax(p_total)]})')
                    print(self.nonlin_id[P_obs == P_next][0])
                    pattern.extend([P_next])
                    uncertainty.extend([pe_potential[np.argmax(p_total)]])
                    probability_tot.extend([np.max(p_total)])
                    probability_trans.extend([p_transition[np.argmax(p_total)]])
                    probability_emis.extend([p_emission[np.argmax(p_total)]])
                    nonlin_id_arr.extend([self.nonlin_id[P_obs == P_next][0][1:-1]])
                    self.ax[0].axvspan((DeltaP_low/86400.)+P_obs3, (DeltaP_up/86400.)+P_obs3, ymin=-go_on, ymax=10, facecolor='g', alpha=0.5)
                    self.ax[0].vlines(x=P_obs3+(DeltaP_exp/86400.), ymin=0, ymax=np.max(self.A), color='g', linestyle='dashed')
                    self.ax[0].vlines(x=P_next, ymin=0, ymax=self.A[P_obs == P_next], color='r')
                    self.ax[2].vlines(x=P_next, ymin=0, ymax=np.max(p_total), color='k')
                    self.ax[0].text(P_next, self.A[P_obs == P_next], str(go_on+1), color='r', fontsize=10, ha='center' )
                    skipped_radial_order = False

                else:
                    P_next = P_obs3 + (DeltaP_exp/86400.)
                    self.ax[0].vlines(x=P_next, ymin=0, ymax=np.max(self.A), color='grey', linestyle='dashed')
                    print(f'{go_on} skipping radial order because of low emission/transition probability')
                    pattern.extend([np.nan])
                    uncertainty.extend([np.nan])
                    probability_tot.extend([np.nan])
                    probability_trans.extend([np.nan])
                    probability_emis.extend([np.nan])
                    nonlin_id_arr.extend(['missing'])

                    skipped_radial_order = True
                    skipped += 1
            else:
                P_next = P_obs3 + (DeltaP_exp/86400.)
                self.ax[0].vlines(x=P_next, ymin=0, ymax=np.max(self.A), color='grey', linestyle='dashed')
                self.ax[0].axvspan((DeltaP_low/86400.)+P_obs3, (DeltaP_up/86400.)+P_obs3, facecolor='grey', alpha=0.5)
                print((DeltaP_low/86400.)+P_obs3, (DeltaP_exp/86400.)+P_obs3,  (DeltaP_up/86400.)+P_obs3)
                print(f'{go_on} skipping radial order')
                skipped_radial_order = True
                skipped += 1
                pattern.extend([np.nan])
                uncertainty.extend([np.nan])
                probability_tot.extend([np.nan])
                probability_trans.extend([np.nan])
                probability_emis.extend([np.nan])
                nonlin_id_arr.extend(['missing'])

            direction.extend('r')

            P_obs1 = P_obs2
            P_obs2 = P_obs3
            P_obs3 = P_next

            go_on += 1
            if skipped > max_skip:
                print('Skipped too many modes. Stopping.')
                break
        
        P_obs1 = P_obs1_ini
        P_obs2 = P_obs2_ini
        P_obs3 = P_obs3_ini

        skipped_radial_order = False
        go_on = 0
        skipped = 0
        
        # perform a search towards the left (shorter periods)
        print('Towards shorter periods (left)...')
        while go_on < max_modes_smaller_periods:
            DeltaP_obs1 = np.abs(P_obs1 - P_obs2)*86400
            DeltaP_obs2 = np.abs(P_obs2 - P_obs3)*86400
            deltaP_obs1 = DeltaP_obs1 - DeltaP_obs2
            if always_use_entire_grid:
                skipped_radial_order = True
            deltaP_min, deltaP_max, deltaP_exp, p_trans_ipol, norm  = deltaP_expected(self.x, deltaP_obs1, skipped_radial_order)
            DeltaP_exp  = DeltaP_obs1 + deltaP_exp
            DeltaP_low  = DeltaP_obs1 + np.max([deltaP_all_min, deltaP_min])
            DeltaP_up   = DeltaP_obs1 + np.min([deltaP_all_max, deltaP_max])


            DeltaP_all  =  (P_obs1 - P_obs)*86400
            get = np.array([DeltaP_low < DeltaP_all]) & np.array([DeltaP_all < DeltaP_up]) & np.array([DeltaP_all > 0])
            if np.sum(get) == 0:
                DeltaP_low  = DeltaP_obs1 + deltaP_all_min
                DeltaP_up   = DeltaP_obs1 + deltaP_all_max
                get = np.array([DeltaP_low < DeltaP_all]) & np.array([DeltaP_all < DeltaP_up]) & np.array([DeltaP_all > 0])

            if np.sum(get) > 0:
                P_potential = self.P_obs[get[0]]
                A_potential = self.A[get[0]]
                pe_potential = self.pe[get[0]]

                p_emission  = p_emis(A_potential, in_log = True)
                DeltaP_potential = (P_obs1 - P_potential)*86400
                deltaP_obs2 =  DeltaP_potential - DeltaP_obs1
                p_transition = norm * p_trans_ipol(deltaP_obs2)

                print(p_transition)
                p_transition[p_transition < 0] = 0

                sig = np.array([p_emission > 0.0]) & np.array([p_transition > 0])
                if np.sum(sig) > 0:
                    p_total = p_transition * p_emission
                    p_total /= np.sum(p_total)

                    P_next = P_potential[np.argmax(p_total)]
                    print(f'{go_on} New period found left at {P_next} with probability {np.max(p_total)} (p_trans = {p_transition[np.argmax(p_total)]}, p_emis = {p_emission[np.argmax(p_total)]})')
                    print(nonlin_id[P_obs == P_next][0])

                    pattern.extend([P_next])
                    uncertainty.extend([pe_potential[np.argmax(p_total)]])
                    probability_tot.extend([np.max(p_total)])
                    probability_trans.extend([p_transition[np.argmax(p_total)]])
                    probability_emis.extend([p_emission[np.argmax(p_total)]])
                    nonlin_id_arr.extend([nonlin_id[P_obs == P_next][0][1:-1]])

                    self.ax[0].vlines(x=P_next, ymin=0, ymax=self.A[P_obs == P_next], color='r')
                    self.ax[0].text(P_next, self.A[P_obs == P_next], str(go_on+1), color='r', fontsize=10, ha='center' )
                    self.ax[2].vlines(x=P_next, ymin=0, ymax=np.max(p_total), color='k')
                    self.ax[0].axvspan(P_obs1-(DeltaP_low/86400.), P_obs1-(DeltaP_up/86400.), ymin=-go_on, ymax=10, facecolor='g', alpha=0.5)
                    self.ax[0].vlines(x=P_obs1-(DeltaP_exp/86400.), ymin=0, ymax=np.max(self.A), color='g', linestyle='dashed')
                    skipped_radial_order = False
                else:
                    P_next = P_obs1 - (DeltaP_exp/86400.)
                    self.ax[0].vlines(x=P_next, ymin=0, ymax=np.max(self.A), color='grey', linestyle='dashed')
                    self.ax[0].axvspan(P_obs1-(DeltaP_low/86400.), P_obs1-(DeltaP_up/86400.), facecolor='grey', alpha=0.5)
                    print(f'{go_on} skipping radial order because of low emission/transition probability')
                    skipped+=1
                    pattern.extend([np.nan])
                    uncertainty.extend([np.nan])
                    probability_tot.extend([np.nan])
                    probability_trans.extend([np.nan])
                    probability_emis.extend([np.nan])
                    nonlin_id_arr.extend(['missing'])
                    skipped_radial_order = True

            else:
                P_next = P_obs1 - (DeltaP_exp/86400.)
                self.ax[0].vlines(x=P_next, ymin=0, ymax=np.max(self.A), color='grey', linestyle='dashed')
                self.ax[0].axvspan(P_obs1-(DeltaP_low/86400.), P_obs1-(DeltaP_up/86400.), facecolor='grey', alpha=0.5)
                print(f'{go_on} skipping radial order')
                pattern.extend([np.nan])
                uncertainty.extend([np.nan])
                probability_tot.extend([np.nan])
                probability_trans.extend([np.nan])
                probability_emis.extend([np.nan])
                nonlin_id_arr.extend(['missing'])
                skipped+=1
                skipped_radial_order = True

            direction.extend('l')

            P_obs3 = P_obs2
            P_obs2 = P_obs1
            P_obs1 = P_next

            go_on += 1
            if skipped > max_skip:
                print('Skipped too many modes. Stopping.')
                break
        pattern           = np.array(pattern)
        direction         = np.array(direction)
        uncertainty       = np.array(uncertainty)
        probability_tot   = np.array(probability_tot)
        probability_trans = np.array(probability_trans)
        probability_emis  = np.array(probability_emis)
        nonlin_id_arr     = np.array(nonlin_id_arr)
        
        # Stitch the patterns towards lower periods and towards higher periods together.
        pattern           = np.array(list(np.flip(pattern[direction == 'l'])) + list(pattern[direction != 'l']))
        uncertainty       = np.array(list(np.flip(uncertainty[direction == 'l'])) + list(uncertainty[direction != 'l']))
        probability_tot   = np.array(list(np.flip(probability_tot[direction == 'l'])) + list(probability_tot[direction != 'l']))
        probability_trans = np.array(list(np.flip(probability_trans[direction == 'l'])) + list(probability_trans[direction != 'l']))
        probability_emis  = np.array(list(np.flip(probability_emis[direction == 'l'])) + list(probability_emis[direction != 'l']))
        nonlin_id_arr     = np.array(list(np.flip(nonlin_id_arr[direction == 'l'])) + list(nonlin_id_arr[direction != 'l']))
        
        # Plot the found period-spacing pattern.
        for i in range(len(pattern)-2):
            if i == -1:
                if np.isnan(pattern[1]):
                    continue
                else:
                    self.ax[1].plot(pattern[i], (pattern[i+1] - pattern[i])*86400, '-o', color ='k')
            else:
                if np.isnan(pattern[i]) or np.isnan(pattern[i+1]) or np.isnan(pattern[i+2]):
                    continue
                else:
                    dP1 = (pattern[i+1] - pattern[i])*86400
                    dP2 = (pattern[i+2] - pattern[i+1])*86400
                    self.ax[1].plot([pattern[i], pattern[i+1]], [dP1, dP2] , '-o', color = 'k')
        
        dp_ = (pattern[1:] - pattern[:-1])*86400
        self.ax[1].plot(pattern[:-1], dp_, '--o', color='k', label='HMM')

        self.ax[0].set_xlim(0.9*np.nanmin(pattern), 1.1*np.nanmax(pattern))

        self.psp_dict['pattern']           = pattern
        self.psp_dict['uncertainty']       = uncertainty
        self.psp_dict['total_prob']        = probability_tot
        self.psp_dict['transmission_prob'] = probability_trans
        self.psp_dict['emission_prob']     = probability_emis
        self.psp_dict['initial_periods_indices'] = [ip1, ip2, ip3]
        self.psp_dict['nonlin_id']         = nonlin_id_arr
        
    def reset_selections(self, event):
        '''
        Reset button to restart if the initial suggestion was not OK.
        '''
        for ind in self.initial_period_index:
            self.lines[ind].set_color('k')
        self.initial_period_index.clear()
        self.fig.canvas.draw_idle()
        self.ax[0].clear()
        self.ax[1].clear()
        self.ax[2].clear()
        self.lines.clear()
        for i, p in enumerate(self.P_obs):
            line_ = self.ax[0].vlines(x=p, ymin=0, ymax=self.A[i], color='k', picker=True)
            self.lines.append(line_)  # Re-add lines
        self.picker = self.fig.canvas.mpl_connect('pick_event', self.onpick)
        self.ax[0].set_xlim(0,4)
        self.ax[2].set_xlabel(r'$P\,[d]$', fontsize=self.fontsize)
        self.ax[1].set_ylabel(r'$\Delta P\,[s]$', fontsize=self.fontsize)
        self.ax[2].set_ylabel(r'$p_{\rm tot}$', fontsize=self.fontsize)
        self.ax[0].set_ylabel(r'$\log A$', fontsize=self.fontsize)
        self.ax[0].set_ylim(0, 1.1*np.max(self.A))
        self.ax[2].set_ylim(0, 1.05)
        self.ax[0].set_xlim(0,4)
        self.ax[0].set_title(f'm = {self.m}')
        print('=== RESET ===')
        
    def save_selections(self, event):
        '''
        Button to save the found period-spacing pattern to a dictionary. The following quantities are saved.

        pattern:                 periods of the modes in the pattern (days).
        uncertainty:             observational errors on the periods of the modes.
        total_prob:              total probability = p_trans*p_emis
        transmission_prob:       transmission probability
        emission_prob:           emission probability
        initial_periods_indices: indices of the three initial modes
        nonlin_id_arr:           non-linear mode ID (i.e. combination frequency)
        '''
        self.fig.savefig(f'{self.WORK_DIR}/PSP_KIC0{self.KIC}_strategy_5_l1m{self.m}_test.png', dpi=300)

        with open(f'{self.WORK_DIR}/PSP_KIC0{self.KIC}_strategy_5_l1m{self.m}_test.pkl', 'wb') as f:
            pickle.dump(self.psp_dict, f)
        print(f'Pattern saved as PSP_KIC0{self.KIC}_strategy_5_l1m{self.m}_test.pkl')
        

# Run the interactive sherlock script as

In [11]:
# create figure object
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(8,6))
# set y limit
ax[0].set_ylim(0,3);
# create the interactive object
my_interactive_plot = InteractivePlot(WORK_DIR=WORK_DIR, KIC=KIC, m=m, P_obs=P_obs, A=A, pe=pe, nonlin_id=nonlin_id, fontsize=fontsize, fig=fig, multi_ax=ax)
# connect handlers
my_interactive_plot.connect()
# plot actions
fig.canvas.draw_idle()
# layout
ax[2].set_xlabel(r'$P\,[d]$', fontsize=fontsize)
ax[1].set_ylabel(r'$\Delta P\,[s]$', fontsize=fontsize)
ax[2].set_ylabel(r'$p_{\rm tot}$', fontsize=fontsize)
ax[0].set_ylabel(r'$\log A$', fontsize=fontsize)
ax[0].set_ylim(0, 1.1*np.max(A))
ax[2].set_ylim(0, 1.05)
ax[0].set_xlim(0,4)
ax[0].set_title(f'm = {m}')
# show
plt.show(block=False)

Selected points: [21, 25, 23]
0.5689151090606535 0.5794579567123641 0.5877468885437472


  ax.plot(xx, np.log10(p_trans_ipol(xx)/np.max(p_trans_ipol(xx))), color = 'k')
  ax.plot(xx, np.log10(p_trans_ipol(xx)/np.max(p_trans_ipol(xx))), color = 'k')


Towards longer periods (right)...
pobs [0.596227338449971  0.6003022108489772 0.5914787740152606]
Aobs [0.3119702954274389 0.4053199121065496 0.2674025021647655]
0 New period found right at 0.596227338449971 with probability 0.9994007474739766 (p_trans = 0.002351349075712709, p_emis = 0.31824197348928573)
'freq_1 + freq_4'
pobs [0.6050234482569427]
Aobs [0.0943355794178334]
1 New period found right at 0.6050234482569427 with probability 1.0 (p_trans = 0.0009690745043048432, p_emis = 1.0)
'freq_1 + freq_71'
pobs [0.6088711510977506]
Aobs [0.240545209825661]
2 New period found right at 0.6088711510977506 with probability 1.0 (p_trans = 2.7089239336062512e-05, p_emis = 1.0)
'freq_13 + freq_20'
0.6112772868725891 0.6126153516389321 0.6299853455521847
3 skipping radial order
0.6100730804476069 0.6163308617280341 0.6287811391272025
4 skipping radial order
0.6136850882370825 0.6200176813650565 0.632393146916678
5 skipping radial order
Skipped too many modes. Stopping.
Towards shorter periods 

  fig, ax = plt.subplots()


pobs [1.09483313481865   1.0882580909172044]
Aobs [1.5237323913090075 0.0568773132353544]
5 New period found right at 1.09483313481865 with probability 0.9984986104056401 (p_trans = 0.0006866117303740542, p_emis = 0.9669957779894293)
'freq_6'
pobs [1.1175112113357548 1.1113456870471874]
Aobs [2.0805475045507724 0.0308712083428735]
6 New period found right at 1.1113456870471874 with probability 1.0 (p_trans = 0.00028212408750714066, p_emis = 0.008840306780705444)
'freq_1 + freq_107'
pobs [1.1175112113357548 1.127315717724157  1.1229857674264674]
Aobs [2.0805475045507724 0.1597152982975094 0.1018356422796512]
7 New period found right at 1.127315717724157 with probability 0.9269169707930328 (p_trans = 0.00025447440607738345, p_emis = 0.011735554640678722)
'freq_1 + freq_17'
pobs [1.139936169463765  1.1491010212601918 1.1347228663467865
 1.1458380028813484]
Aobs [0.9121729917988908 0.2753587073335317 0.1537732815519657
 0.1260863436162694]
8 New period found right at 1.139936169463765 with

In [47]:
my_interactive_plot.disconnect()