# Querying Kepler and TESS for Transit Data

This notebook walks through the complete workflow of downloading Kepler (or TESS) data, cleaning light curves, fitting individual transits, and finally building an Observed minus Calculated (O–C) diagram.

Each code block below now contains extensive comments so that it can serve as a step-by-step training module.

## 1. Environment Setup and Helper Functions

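This section installs the required packages, clones the EXOTIC and Nbody-AI repositories, and defines the utility functions used throughout the notebook. The same cell then runs the full pipeline: it downloads the light curves, cleans and flattens them, fits each transit, and plots the O–C diagram.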

In [None]:
"""
kepler_individ_lc.py

Script to download, clean, and fit individual Kepler transits for a specified target,
then create an O–C diagram. Notebooks or other scripts can be similarly adapted.

Important: NASA Exoplanet Archive typically provides T0 in BJD (~2454xxx),
while Kepler data from lightkurve is in BKJD = BJD - 2454833.0.
Hence, we subtract 2454833 from T0 before phase-folding and fitting.
"""

# --- Environment setup: install and clone required packages ---
!pip install ultranest rebound
!git clone https://github.com/rzellem/EXOTIC
%cd EXOTIC
!git checkout tess
%cd ..
import sys
sys.path.append("F:/EXOTIC")

!git clone https://github.com/pearsonkyle/Nbody-AI
sys.path.append('F:/Nbody-AI/nbody')
# Add the cloned repository to the Python path so we can import it
sys.path.insert(0, 'F:/Nbody-AI')

import sys
import os
import copy
import json
import pickle
import argparse
import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from io import StringIO
from pandas import read_csv
from scipy.ndimage import binary_dilation, label
from scipy.signal import savgol_filter, medfilt

import lightkurve as lk
from astropy import constants as const
from astropy import units as u
from wotan import flatten

# EXOTIC imports (installed via "pip install exotic")
from exotic.api.elca import lc_fitter
from exotic.api.output_aavso import OutputFiles

# For transit modeling
from pylightcurve import exotethys

# ---------------------------------------------------------------------
# 1) Helper functions
# ---------------------------------------------------------------------

def tap_query(base_url, query, dataframe=True):
    """
    Table Access Protocol (TAP) query to NASA Exoplanet Archive.

    Builds a URL from 'base_url' and 'query', sends a GET request and
    returns the response as a DataFrame or as raw text.

    Parameters
    ----------
    base_url : str
        Endpoint prefix; the query clauses are appended directly to it.
    query : dict
        Ordered mapping of clause keyword -> clause text (e.g. "select",
        "from", "where"). The optional "format" key selects the output
        format and defaults to 'csv'.
    dataframe : bool
        If True parse the response body as CSV into a pandas DataFrame,
        otherwise return the body text unchanged.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status code. Without this
        check an error body would be parsed as if it were CSV data.
    """
    # Join the clauses with single spaces. Joining (instead of appending a
    # trailing space and slicing it off) keeps base_url intact even when
    # 'query' holds no clauses.
    clauses = " ".join(f"{k} {query[k]}" for k in query if k != "format")
    uri_full = f"{base_url}{clauses} &format={query.get('format', 'csv')}"
    uri_full = uri_full.replace(' ', '+')
    print("Query:", uri_full)

    response = requests.get(uri_full, timeout=300)
    if response.status_code != 200:
        raise RuntimeError(f"Query failed: status {response.status_code}\n{response.text}")

    if dataframe:
        return read_csv(StringIO(response.text))
    return response.text

def nea_scrape(target=None):
    """
    Fetch planet parameters from the NASA Exoplanet Archive.

    Queries the 'pscomppars' table for transiting planets (tran_flag = 1).
    When 'target' is given (e.g., 'Kepler-18 b') the query is further
    restricted to that planet name. Returns whatever tap_query returns
    for the request (a DataFrame with one column per selected field,
    such as pl_name, pl_orbper, st_rad, st_logg, ...).
    """
    # Columns requested from the archive, in the order they are selected.
    columns = [
        "pl_name", "hostname", "pl_radj", "pl_radjerr1", "ra", "dec",
        "pl_ratdor", "pl_ratdorerr1", "pl_ratdorerr2",
        "pl_orbincl", "pl_orbinclerr1", "pl_orbinclerr2",
        "pl_orbper", "pl_orbpererr1", "pl_orbpererr2",
        "pl_orbeccen", "pl_orbsmax", "pl_orbsmaxerr1", "pl_orbsmaxerr2",
        "pl_orblper", "pl_tranmid", "pl_tranmiderr1", "pl_tranmiderr2",
        "pl_ratror", "pl_ratrorerr1", "pl_ratrorerr2",
        "st_teff", "st_tefferr1", "st_tefferr2",
        "st_met", "st_meterr1", "st_meterr2",
        "st_logg", "st_loggerr1", "st_loggerr2",
        "st_mass", "st_rad", "st_raderr1",
    ]

    condition = "tran_flag = 1"
    if target:
        condition = f"{condition} and pl_name = '{target}'"

    # Key order matters: tap_query emits the clauses in dict order.
    request = {
        "select": ",".join(columns),
        "from": "pscomppars",
        "where": condition,
        "format": "csv",
    }
    return tap_query("https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=", request)

def sigma_clip(ogdata, dt, iterations=1):
    """
    Iteratively flag outliers relative to a Savitzky-Golay smooth.

    Each pass smooths the currently accepted samples with a 2nd-order
    Savitzky-Golay filter of window length 'dt' points and rejects the
    samples whose residual is 3 sigma or more from the smooth.

    Returns a copy of 'ogdata' with rejected samples set to NaN, together
    with the standard deviation of the residuals of the surviving samples.
    The input array itself is left untouched.
    """
    def _residuals(selection):
        # Residuals of the selected samples about their own smooth.
        kept = ogdata[selection]
        return kept - savgol_filter(kept, dt, 2)

    keep = np.ones(ogdata.shape, dtype=bool)
    for _ in range(iterations):
        resid = _residuals(keep)
        scatter = np.nanstd(resid)
        # Update only the entries that are still accepted.
        keep[keep] = np.abs(resid) < 3*scatter

    final_scatter = np.std(_residuals(keep))

    clipped = copy.deepcopy(ogdata)
    clipped[~keep] = np.nan
    return clipped, final_scatter

def check_std(time, flux, dt=0.5):
    """
    Estimate photometric scatter about a Savitzky-Golay smooth.

    The samples are sorted by time, smoothed with a 2nd-order
    Savitzky-Golay filter, and the NaN-aware standard deviation of the
    residuals (sorted flux minus smooth) is returned.

    Parameters
    ----------
    time, flux : array-like
        Time stamps (days) and matching flux values; need not be sorted.
    dt : float
        Smoothing time scale in hours. It is converted to a number of
        samples using the mean cadence, with a floor of 15 samples.

    Returns
    -------
    float
        Standard deviation of the residuals.
    """
    time = np.asarray(time)
    flux = np.asarray(flux)

    order = np.argsort(time)
    cadence = np.diff(time[order]).mean()
    # At least 15 points or so in the smoothing window; "1 + 2*n" keeps
    # the window length odd.
    wsize = 1 + 2*int(max(15, dt/(24*cadence)))

    sorted_flux = flux[order]
    smooth = savgol_filter(sorted_flux, wsize, 2)
    # The smooth is in time-sorted order, so it has to be compared with the
    # flux in that same order (subtracting it from the unsorted flux mixes
    # unrelated samples whenever the input is not already sorted).
    return np.nanstd(sorted_flux - smooth)

def stellar_mass(logg, rs):
    """
    Stellar mass in solar masses from surface gravity and radius.

    'logg' is log10 of the surface gravity in cgs (cm/s^2) and 'rs' the
    stellar radius in solar radii; the mass follows from M = g R^2 / G.
    """
    surface_gravity = 10**logg*(u.cm/u.s**2)
    radius = rs*u.R_sun
    mass = radius**2 * surface_gravity / const.G
    return mass.to(u.M_sun).value

def sa(m, P):
    """
    Semi-major axis (AU) from star mass m (Msun) and period P (days).

    Uses Kepler's third law, a^3 = G M P^2 / (4 pi^2).
    """
    period = P*u.day
    # The period must be squared as a whole quantity: "P*u.day**2" squares
    # only the unit (P * day^2), which drops one factor of P.
    a_cubed = const.G*(m*u.M_sun)*period**2/(4*np.pi**2)
    return (a_cubed**(1./3)).to(u.AU).value

def parse_args():
    """
    Build the command-line interface and parse the known options.

    Unrecognised arguments are ignored (parse_known_args), and only the
    namespace of recognised options is returned.
    """
    cli = argparse.ArgumentParser()

    # Options that take a value: (flags, type, default, help text)
    valued_options = (
        (("-t", "--target"), str, "Kepler-1710 b",
         "Kepler target, e.g. 'Kepler-19 b'."),
        (("-o", "--output"), str, "kepler_individ_lc_output",
         "Output directory for results."),
        (("--quarter",), int, 0,
         "Which Kepler quarter to process (0=all)."),
    )
    for flags, kind, fallback, text in valued_options:
        cli.add_argument(*flags, type=kind, default=fallback, help=text)

    # Boolean switches: (flags, help text)
    switches = (
        (("-r", "--reprocess"), "Reprocess even if existing results found."),
        (("--ars",), "Include a/R* in the fit bounds."),
        (("--tls",), "Perform TLS search on final residuals (not fully used here)."),
    )
    for flags, text in switches:
        cli.add_argument(*flags, action='store_true', default=False, help=text)

    known, _unknown = cli.parse_known_args()
    return known

# ---------------------------------------------------------------------
# 2) Main Script
# ---------------------------------------------------------------------
# The following block ties everything together:
#   1. Search for public light curve data.
#   2. Retrieve or estimate stellar/planet parameters.
#   3. Clean and flatten each quarter of data.
#   4. Fit individual transits and compute timing residuals.
#   5. Produce an O--C diagram summarizing all epochs.

if __name__ == "__main__":
    # NOTE: the pipeline is kept at module scope (no main() wrapper) so that
    # its results (all_epochs, all_oc, all_oc_err, period, ...) stay available
    # as globals to whatever runs after this block.
    args = parse_args()
    os.makedirs(args.output, exist_ok=True)

    # Planet-specific directory
    planetdir = os.path.join(args.output, args.target.replace(' ', '_').replace('-', '_'))
    if not os.path.exists(planetdir):
        os.mkdir(planetdir)
    elif os.path.exists(os.path.join(planetdir, "global_fit.png")) and not args.reprocess:
        # If there's a 'global_fit.png' and no --reprocess, refuse to overwrite.
        # NOTE(review): nothing in this block writes 'global_fit.png'; the
        # marker is presumably produced by another tool - confirm.
        raise Exception("Target appears already processed. Use --reprocess to overwrite.")

    # Compact identifier used as a file-name prefix (e.g. "kepler1710b")
    planetname = args.target.lower().replace(' ', '').replace('-', '')

    # 1) Search for Kepler data
    search_result = lk.search_targetpixelfile(args.target, mission='Kepler')
    print(search_result)
    if len(search_result) == 0:
        raise Exception(f"No Kepler data found for target: {args.target}")

    # 2) Load or scrape prior planet parameters
    prior_path = os.path.join(planetdir, planetname + "_prior.json")
    if os.path.exists(prior_path):
        # Context manager so the cached prior file is always closed
        with open(prior_path, "r", encoding='utf8') as jf:
            prior = json.load(jf)
    else:
        # Query NASA Exoplanet Archive
        nea_df = nea_scrape(args.target)
        if len(nea_df) == 0:
            raise Exception(f"No planet found in NASA Exoplanet Archive: {args.target}")

        prior = {}
        for key in nea_df.columns:
            value = nea_df[key].values[0]
            # Store plain Python scalars: numpy integer scalars (e.g. a column
            # parsed as int64) are rejected by json.dump below.
            prior[key] = value.item() if isinstance(value, np.generic) else value

        # if pl_ratdor is NaN, compute from pl_orbsmax / st_rad
        if np.isnan(prior.get('pl_ratdor', np.nan)) and not np.isnan(prior.get('pl_orbsmax', np.nan)):
            prior['pl_ratdor'] = (
                (prior['pl_orbsmax']*u.AU)/(prior['st_rad']*u.R_sun)
            ).to(u.dimensionless_unscaled).value
            # No catalogued uncertainty in this case: assume +/-10%
            prior['pl_ratdorerr1'] = 0.1 * prior['pl_ratdor']
            prior['pl_ratdorerr2'] = -0.1 * prior['pl_ratdor']

        # Save local prior
        with open(prior_path, 'w', encoding='utf8') as jf:
            json.dump(prior, jf, indent=4)

    # 3) Limb darkening for Kepler band
    #    If these are NaN, set defaults
    st_logg = prior.get('st_logg', 4.4)
    if np.isnan(st_logg): st_logg = 4.4
    st_teff = prior.get('st_teff', 5700)
    if np.isnan(st_teff): st_teff = 5700
    st_met = prior.get('st_met', 0.0)
    if np.isnan(st_met): st_met = 0.0

    u0, u1, u2, u3 = exotethys(st_logg, st_teff, st_met,
                               'Kepler', method='claret', stellar_model='phoenix')

    # 4) Prepare structure to store data from each quarter
    sv = {
        'lightcurves': [],
        'quarters': {},
        'time': [],
        'flux': [],
        'flux_err': [],
        'trend': [],
        'quarter': []
    }

    # If there's a 'quarter' column in search_result, we can check it
    if "quarter" in search_result.table.colnames:
        unique_quarters = np.unique(search_result.table["quarter"])
    else:
        # fallback: each row is "quarterlike"
        unique_quarters = np.arange(len(search_result))

    # If user wants only a specific quarter
    if args.quarter != 0 and args.quarter not in unique_quarters:
        raise Exception(f"No data for quarter {args.quarter} in search_result.")

    # 5) Download & Clean each quarter
    for idx in range(len(search_result)):

        if "quarter" in search_result.table.colnames:
            quarter = search_result.table["quarter"][idx]
            if args.quarter != 0 and quarter != args.quarter:
                continue
        else:
            quarter = idx + 1  # fallback label

        print(f"Downloading Quarter {quarter} ...")
        try:
            tpf = search_result[idx].download(quality_bitmask='hard')
        except Exception as err:
            print(f"Failed to download quarter {quarter}: {err}")
            continue

        # Aperture selection: start from the pipeline aperture and keep a
        # dilated version only if it lowers the smoothed-residual scatter.
        lc = tpf.to_lightcurve(aperture_mask=tpf.pipeline_mask)
        nmask = np.isnan(lc.flux.value)
        lstd = check_std(lc.time.value[~nmask], lc.flux.value[~nmask])

        aper_final = tpf.pipeline_mask
        for it in [1, 2]:
            bigger_mask = binary_dilation(tpf.pipeline_mask, iterations=it)
            lcd = tpf.to_lightcurve(aperture_mask=bigger_mask)
            nmaskd = np.isnan(lcd.flux.value)
            lstdd = check_std(lcd.time.value[~nmaskd], lcd.flux.value[~nmaskd])
            if lstdd < lstd:
                aper_final = bigger_mask
                lstd = lstdd

        lc = tpf.to_lightcurve(aperture_mask=aper_final)
        tpf.plot(aperture_mask=aper_final)
        plt.savefig(os.path.join(planetdir, f"{planetname}_quarter_{quarter}_aperture.png"))
        plt.close()

        # Remove first ~30 min after big gaps
        time_arr = lc.time.value  # Kepler BKJD
        flux_arr = lc.flux.value
        tmask = np.ones_like(time_arr, dtype=bool)

        srt = np.argsort(time_arr)
        dt_ = np.diff(time_arr[srt])
        if len(dt_) and not np.isnan(dt_).all():
            median_dt = np.nanmedian(dt_)
            # 30 min in days => 30./1440. => how many points
            ndt = int(np.round((30./1440.) / median_dt))
        else:
            ndt = 30  # fallback

        # dt_ is indexed in time-sorted order, so map positions back through
        # 'srt' before touching the (original-order) mask.
        tmask[srt[:ndt]] = False
        big_gap = dt_ > (1./24.)  # 1 hour
        for gap_idx in np.flatnonzero(big_gap):
            # dt_[k] = t[k+1] - t[k]: the first sample AFTER the gap is k+1.
            # (Starting the slice at k would mask the sample before the gap.)
            tmask[srt[gap_idx+1:gap_idx+1+ndt]] = False

        nmask2 = ~np.isnan(flux_arr)
        tmask = tmask & nmask2

        time_c = time_arr[tmask]
        flux_c = flux_arr[tmask]

        # Remove outliers: normalise by a running median, then sigma-clip
        mflux = medfilt(flux_c, kernel_size=15)
        rflux = flux_c / mflux
        newflux, std_ = sigma_clip(rflux, dt=15, iterations=1)
        maskNaN = np.isnan(newflux)
        time_c = time_c[~maskNaN]
        flux_c = flux_c[~maskNaN]

        # Flatten (remove stellar variability), one continuous segment at a
        # time. Samples that never get flattened stay NaN so the NaN filter
        # after the loop drops them; leaving them as raw (unnormalised) flux
        # would mix detector-scale values into the ~1.0 normalised curve.
        dflux = np.full(len(time_c), np.nan)
        dtrend = np.full(len(time_c), np.nan)

        # A new segment starts after every gap longer than half a day. The
        # cumulative sum labels each sample, including the first one after
        # a gap, with the id of the segment it belongs to.
        seg_label = np.zeros(len(time_c), dtype=int)
        seg_label[1:] = np.cumsum(np.diff(time_c) > 0.5)

        for seg_id in np.unique(seg_label):
            seg = (seg_label == seg_id)
            if seg.sum() < 5:
                # too short to detrend; remains NaN and is discarded later
                continue
            flc, tlc = flatten(time_c[seg], flux_c[seg],
                               window_length=2.0, return_trend=True,
                               method='biweight')
            dflux[seg] = flc
            dtrend[seg] = tlc

        # Store in state
        sv['quarters'][quarter] = True
        sv['time'].append(time_c)
        sv['flux'].append(dflux)
        # Relative uncertainty estimated as sqrt(trend) / median(trend)
        sv['flux_err'].append((dtrend**0.5)/np.nanmedian(dtrend))
        sv['trend'].append(dtrend)
        sv['quarter'].append(np.ones(len(time_c))*quarter)

        # Plot
        plt.figure()
        plt.plot(time_c, flux_c, 'k.', label='Flux')
        plt.plot(time_c, dtrend, 'r--', label='Trend')
        plt.title(f"{args.target} - Quarter {quarter}")
        plt.xlabel("Time [BKJD]")  # BKJD = BJD - 2454833
        plt.ylabel("Flux")
        plt.legend()
        outfig = os.path.join(planetdir, f"{planetname}_quarter_{quarter}_trend.png")
        plt.savefig(outfig)
        plt.close()

    if not sv['time']:
        raise Exception("No valid quarters processed. Check logs for errors.")

    # Concatenate all quarters
    time_all = np.concatenate(sv['time'])
    flux_all = np.concatenate(sv['flux'])
    fluxerr_all = np.concatenate(sv['flux_err'])
    trend_all = np.concatenate(sv['trend'])
    quarter_all = np.concatenate(sv['quarter'])

    # Remove nans
    nanmask = np.isnan(flux_all) | np.isnan(fluxerr_all)
    time_all = time_all[~nanmask]
    flux_all = flux_all[~nanmask]
    fluxerr_all = fluxerr_all[~nanmask]
    quarter_all = quarter_all[~nanmask]
    trend_all = trend_all[~nanmask]

    # Save combined light curve to CSV (flux multiplied back by the trend)
    df = pd.DataFrame({
        'time_bkjd': time_all,
        'flux': flux_all * trend_all,
        'flux_err': fluxerr_all * trend_all,
        'quarter': quarter_all
    })
    outcsv = os.path.join(planetdir, f"{planetname}_lightcurve.csv")
    df.to_csv(outcsv, index=False)
    print("Saved combined lightcurve:", outcsv)

    # 6) Transit Fitting / O-C

    # Convert T0 from BJD to BKJD by subtracting 2454833.0
    # if we have a valid pl_tranmid
    if not np.isnan(prior.get('pl_tranmid', np.nan)):
        tmid_bjd = float(prior['pl_tranmid'])
    else:
        tmid_bjd = 2454950.0  # fallback guess

    tmid_bkjd = tmid_bjd - 2454833.0  # *CRITICAL SHIFT*

    # Orbital period
    if not np.isnan(prior.get('pl_orbper', np.nan)):
        period = float(prior['pl_orbper'])
    else:
        period = 10.0

    # a/R*
    if not np.isnan(prior.get('pl_ratdor', np.nan)):
        ars = float(prior['pl_ratdor'])
    else:
        ars = 15.0

    # rprs = (Rplanet/Rstar)
    if not np.isnan(prior.get('pl_radj', np.nan)):
        rprs = float((prior['pl_radj']*u.R_jup)/(prior['st_rad']*u.R_sun))
    else:
        rprs = 0.05

    inc = prior.get('pl_orbincl', 88.0)
    if np.isnan(inc): inc = 88.0

    ecc = prior.get('pl_orbeccen', 0.0)
    if np.isnan(ecc): ecc = 0.0

    omega = prior.get('pl_orblper', 0.0)
    if np.isnan(omega): omega = 0.0

    tpars = {
        'rprs': rprs,
        'ars': ars,
        'per': period,
        'inc': inc,
        'tmid': tmid_bkjd,
        'omega': omega,
        'ecc': ecc,
        'a1': 0, 'a2': 0,
        'u0': u0, 'u1': u1, 'u2': u2, 'u3': u3
    }

    # Quick flux_err guess
    if len(flux_all):
        phot_std = np.nanstd(flux_all[flux_all < 1.1])  # rough
        fluxerr0 = phot_std / flux_all
    else:
        fluxerr0 = np.array([])

    # Phase times
    tphase = (time_all - tmid_bkjd)/period
    # approximate transit duration (as a fraction of the orbital phase)
    pdur = 2.0 * np.arctan(1./ars) / (2*np.pi) if ars > 0 else 0.05

    # We'll define event epochs from floor(tphase)
    events = np.unique(np.floor(tphase))

    # Prepare arrays to store O-C
    all_epochs = []
    all_oc = []
    all_oc_err = []

    # Bounds (the 'tmid' entry is replaced per epoch inside the loop)
    mybounds = {
        'rprs': [0, 3*rprs],
        'tmid': [tmid_bkjd - 0.3, tmid_bkjd + 0.3],
        'inc': [70, 90]
    }
    if args.ars:
        mybounds['ars'] = [ars*0.5, ars*2.0]

    # Iterate through each expected transit epoch and
    # fit a local light curve model to refine the mid-transit time.
    for e in events:
        # mask within ~2 x pdur of the predicted transit
        intrans = (tphase >= e - 2*pdur) & (tphase <= e + 2*pdur)
        if intrans.sum() < 10:
            continue

        # update guess Tmid for that epoch
        guess_tmid = tmid_bkjd + e*period
        local_bounds = copy.deepcopy(mybounds)
        # allow +/- 0.2 * (period*pdur) around guess
        local_bounds['tmid'] = [guess_tmid - 0.2*period*pdur,
                                guess_tmid + 0.2*period*pdur]
        tpars['tmid'] = guess_tmid

        # No airmass trend for space-based data: pass zeros
        airmass = np.zeros(intrans.sum())
        try:
            fit = lc_fitter(
                time_all[intrans], flux_all[intrans],
                fluxerr0[intrans],
                airmass,
                tpars,
                local_bounds
            )
        except Exception as err:
            # Best-effort: skip this epoch, but report why the fit failed and
            # let KeyboardInterrupt/SystemExit propagate (a bare except would
            # swallow those too).
            print(f"Failed to fit transit at epoch {int(e)}: {err}")
            continue

        # Reject fits whose transit depth is consistent with zero
        rprs2 = fit.parameters['rprs']**2
        rprs2err = 2 * fit.parameters['rprs'] * fit.errors['rprs']
        if (rprs2 - rprs2err <= 0):
            print(f"Skipping epoch {int(e)}: rprs^2 ~ 0 or negative.")
            continue

        lcdata = {
            'time': fit.time,           # BKJD
            'flux': fit.data,
            'residuals': fit.residuals,
            'phase': fit.phase,
            'pars': fit.parameters,
            'errors': fit.errors,
            'rchi2': fit.chi2/len(fit.time),
            'epoch': int(e)
        }
        sv['lightcurves'].append(lcdata)

        # Observed Tmid (in BKJD)
        Tobs_bkjd = fit.parameters['tmid']
        Tobs_err = fit.errors['tmid']

        # Predicted Tmid for epoch e in BKJD
        Tcalc_bkjd = tmid_bkjd + e*period
        OC_bkjd = Tobs_bkjd - Tcalc_bkjd
        OC_err = Tobs_err  # ignoring ephemeris uncertainty

        all_epochs.append(e)
        all_oc.append(OC_bkjd)
        all_oc_err.append(OC_err)

        # Plot bestfit
        fig, ax = fit.plot_bestfit(title=f"{args.target}, Quarter Fit, E={int(e)}")
        outname = f"{planetname}_E{int(e)}_fit.png"
        plt.savefig(os.path.join(planetdir, outname))
        plt.close()

    # Dump pickled results
    with open(os.path.join(planetdir, f"{planetname}_data.pkl"), 'wb') as pf:
        pickle.dump(sv, pf)

    # 7) O-C Diagram
    # Compile all of the measured transit times and compare them
    # to the predictions from the ephemeris. The result is a classic
    # Observed minus Calculated (O-C) plot showing timing variations.
    all_epochs = np.array(all_epochs)
    all_oc = np.array(all_oc)
    all_oc_err = np.array(all_oc_err)

    plt.figure(figsize=(8,5))
    plt.errorbar(all_epochs, all_oc, yerr=all_oc_err, fmt='o',
                 color='blue', ecolor='gray', capsize=3,
                 label="O-C (BKJD)")
    plt.axhline(0, color='k', linestyle='--', alpha=0.7)
    plt.xlabel("Epoch")
    plt.ylabel("O - C [days, BKJD]")
    plt.title(f"O-C Diagram: {args.target}")
    plt.legend()
    plt.tight_layout()

    oc_plot_file = os.path.join(planetdir, f"{planetname}_OC_diagram.png")
    plt.savefig(oc_plot_file)
    plt.show()

    print("Done! O-C diagram saved at:", oc_plot_file)
    print("Remember these times are in BKJD. (BJD - 2454833.0)")


## 2. Lomb–Scargle Transit Timing Analysis

Here we implement an iterative Lomb–Scargle procedure that identifies periodic signals in the transit timing variations, together with helper routines for querying the NASA Exoplanet Archive and writing the results to CSV files.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import requests
from io import StringIO
from astropy.timeseries import LombScargle

# Set plot tick label sizes for consistency: both axes use size 18 through
# matplotlib's global rcParams.
plt.rcParams['xtick.labelsize'] = 18
plt.rcParams['ytick.labelsize'] = 18

###############################################################################
# LOMB–SCARGLE TTV ANALYSIS (Iterative Removal of Dominant Periods)
###############################################################################
# This cell contains a helper function that applies a weighted
# Lomb--Scargle periodogram multiple times, removing the strongest
# signal after each iteration. The goal is to identify periodic
# signatures in the transit timing variations.

def run_lomb_scargle_TTV(all_epochs, all_oc, all_oc_err, period, show_plots=True,
                         max_err_min=10.0):
    """
    Tutorial function: given arrays of epochs and O--C values this routine
    demonstrates how to clean the timing data and perform an iterative
    Lomb--Scargle period search.
    1) Converts O–C from days to minutes.
    2) Keeps only points whose O–C uncertainty is below 'max_err_min' minutes.
    3) Optionally plots the TTV curve (with an 8th-order polynomial fit) vs shifted epoch.
    4) Runs weighted Lomb–Scargle iteratively:
         - In each iteration, considers only periods < 40 epochs.
         - Finds the dominant peak, computes its amplitude and FAP.
         - Subtracts the best-fit sinusoid from the current signal.
         - Repeats for a total of 10 iterations.
    5) Returns a list of results (one per iteration) and the final mask used.

    Parameters
    ----------
    all_epochs : array-like
        Transit epochs.
    all_oc, all_oc_err : array-like
        O–C values and their uncertainties, in days.
    period : float
        Accepted for call compatibility; not used by the analysis.
    show_plots : bool
        If True, display the O–C plots and one periodogram per iteration.
    max_err_min : float
        Uncertainty cut in minutes (default 10, the value this routine has
        always applied).

    Returns
    -------
    results : list of dict
        One entry per iteration with keys 'iteration', 'dominant_period',
        'power', 'amplitude' and 'fap'.
    mask : ndarray of bool
        Selection applied to the input arrays by the uncertainty cut.

    Raises
    ------
    ValueError
        If no point passes the uncertainty cut.
    """
    # Accept plain lists as well as arrays
    all_epochs = np.asarray(all_epochs, dtype=float)
    all_oc = np.asarray(all_oc, dtype=float)
    all_oc_err = np.asarray(all_oc_err, dtype=float)

    # Convert O–C from days to minutes
    oc_min = all_oc * 24.0 * 60.0
    oc_err_min = all_oc_err * 24.0 * 60.0

    # Filter on the timing uncertainty
    mask = oc_err_min < max_err_min
    if np.sum(mask) == 0:
        raise ValueError(
            f"No O–C points remain with uncertainties < {max_err_min:g} min!"
        )

    epochs_f = all_epochs[mask]
    ttv_f = oc_min[mask]
    ttv_err_f = oc_err_min[mask]

    # Create a shifted epoch axis for plotting (shift so minimum epoch is zero)
    x_data = epochs_f
    x_shifted = x_data - np.min(x_data)

    # Plot the TTV curve with error bars and an 8th-order polynomial fit
    if show_plots:
        plt.figure(figsize=(8, 5))
        plt.errorbar(x_shifted, ttv_f, yerr=ttv_err_f, fmt='o', color='black',
                     ecolor='gray', capsize=3)
        plt.title("Figure 1a: Observed-Calculated (O–C) Plot")
        plt.xlabel("Epochs (shifted)")
        plt.ylabel("TTV Amplitude (minutes)")
        plt.tight_layout()
        plt.show()

        # Polynomial fit for visualization (8th order)
        coeffs = np.polyfit(x_shifted, ttv_f, 8)
        poly_fit = np.poly1d(coeffs)
        x_fit = np.linspace(np.min(x_shifted), np.max(x_shifted), 500)
        y_fit = poly_fit(x_fit)

        plt.figure(figsize=(8, 5))
        plt.errorbar(x_shifted, ttv_f, yerr=ttv_err_f, fmt='o', color='black',
                     ecolor='gray', capsize=3, label='Data')
        plt.plot(x_fit, y_fit, color='black', label='8th Order Fit')
        plt.title("Figure 1b: O–C Plot with Polynomial Fit")
        plt.xlabel("Epochs (shifted)")
        plt.ylabel("TTV Amplitude (minutes)")
        plt.legend()
        plt.tight_layout()
        plt.show()

    # Mean-center the TTV data for Lomb–Scargle
    y_data = ttv_f - np.mean(ttv_f)

    # Set the current signal to be iteratively cleaned
    current_signal = y_data.copy()
    results = []

    # Iteratively remove the dominant periodic signal 10 times
    for i in range(10):
        ls_current = LombScargle(epochs_f, current_signal, ttv_err_f)
        frequency, power = ls_current.autopower(
            minimum_frequency=1.0/50.0,
            maximum_frequency=0.5,
            samples_per_peak=10
        )
        periods = 1.0 / frequency  # Periods in epochs

        # Limit to periods less than 40 epochs
        valid = periods < 40
        if not np.any(valid):
            print(f"No valid periods (periods < 40 epochs) found in iteration {i+1}.")
            break

        periods_valid = periods[valid]
        power_valid = power[valid]

        idx_max = np.argmax(power_valid)
        dom_period = periods_valid[idx_max]
        dom_power = power_valid[idx_max]

        # Amplitude estimate from the peak power and the signal variance
        # NOTE(review): confirm the factor 2*sqrt(power*var) is the intended
        # amplitude convention.
        var_current = np.var(current_signal)
        amplitude = 2.0 * np.sqrt(dom_power * var_current) if dom_power > 0 else 0.0
        fap = ls_current.false_alarm_probability(dom_power)

        results.append({
            'iteration': i + 1,
            'dominant_period': dom_period,
            'power': dom_power,
            'amplitude': amplitude,
            'fap': fap
        })

        print(f"\n===== Iteration {i+1} =====")
        print(f"Dominant Period (epochs) = {dom_period:.3f}")
        print(f"Power                  = {dom_power:.4f}")
        print(f"Amplitude (min)        = {amplitude:.3f}")
        print(f"FAP                    = {fap:.4g}")

        if show_plots:
            plt.figure(figsize=(8, 5))
            plt.plot(periods, power, label=f"Iteration {i+1} Power Spectrum")
            plt.scatter(dom_period, dom_power, color='red', zorder=3,
                        label=f"Dominant Period ({dom_period:.2f} epochs)")
            plt.xlabel("Period (epochs)")
            plt.ylabel("Power")
            plt.title(f"Lomb–Scargle Periodogram (Iteration {i+1})")
            # The view is zoomed to 2-10 epochs; a dominant period outside
            # that range is reported above but falls off the plot.
            plt.xlim(2, 10)
            plt.legend()
            plt.tight_layout()
            plt.show()

        # Subtract the best–fit sinusoid for the dominant period from the current signal
        model = ls_current.model(epochs_f, frequency=1.0/dom_period)
        current_signal = current_signal - model

    return results, mask

###############################################################################
# NASA EXOPLANET ARCHIVE QUERY FUNCTIONS
###############################################################################

def tap_query(base_url, query, dataframe=True):
    """
    Send a TAP query to 'base_url' and return the result.

    The clauses in 'query' (every key except "format") are appended to
    'base_url' in dict order, followed by the output format (default
    'csv'); spaces are then replaced by '+'.

    Parameters
    ----------
    base_url : str
        Endpoint prefix the query clauses are appended to.
    query : dict
        Mapping of clause keyword -> clause text, plus optional "format".
    dataframe : bool
        If True (default) parse the CSV body into a pandas DataFrame,
        otherwise return the raw response text. The flag is accepted here
        because this definition replaces any earlier 'tap_query' in the
        same session, and callers passing it should keep working.

    Raises
    ------
    RuntimeError
        If the server answers with a non-200 status code.
    """
    # Joining the clauses leaves base_url intact even if there are none
    clauses = " ".join(f"{k} {query[k]}" for k in query if k != "format")
    uri_full = f"{base_url}{clauses} &format={query.get('format', 'csv')}"
    uri_full = uri_full.replace(' ', '+')
    print("Query URL:", uri_full)

    response = requests.get(uri_full, timeout=60)
    if response.status_code != 200:
        raise RuntimeError(f"Query failed: status {response.status_code}\n{response.text}")

    if not dataframe:
        return response.text
    return pd.read_csv(StringIO(response.text))

def query_inner_params(target_name):
    """
    Look up a planet's orbital/stellar parameters in the NASA Exoplanet Archive.

    Requests st_mass, pl_bmasse, pl_orbper, pl_orbeccen, pl_orblper and
    pl_orbincl from the 'ps' table for 'target_name'. Among the returned
    rows, the one with the most non-null values in those columns is used
    (the first such row on ties). Returns a dict with the keys
    st_mass_msun, pl_bmasse, pl_orbper_days, pl_orbeccen, pl_orblper and
    pl_orbincl_deg, or an empty dict if the archive returned no rows.
    """
    wanted = ["st_mass", "pl_bmasse", "pl_orbper", "pl_orbeccen", "pl_orblper", "pl_orbincl"]

    request = {
        "select": ", ".join(wanted),
        "from": "ps",
        "where": f"pl_name = '{target_name}'",
        "format": "csv",
    }
    table = tap_query("https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=", request)
    if len(table) == 0:
        print(f"No rows returned for {target_name}. Possibly unrecognized planet name.")
        return {}

    # Count the non-null entries per row; argmax picks the first row that
    # reaches the maximum.
    filled_per_row = table[wanted].notna().sum(axis=1).to_numpy()
    best = table.iloc[int(filled_per_row.argmax())]

    # Output key -> archive column
    renamed = {
        "st_mass_msun": "st_mass",
        "pl_bmasse": "pl_bmasse",
        "pl_orbper_days": "pl_orbper",
        "pl_orbeccen": "pl_orbeccen",
        "pl_orblper": "pl_orblper",
        "pl_orbincl_deg": "pl_orbincl",
    }
    return {out_key: best.get(column, np.nan) for out_key, column in renamed.items()}

###############################################################################
# Append results to ocean.csv
###############################################################################

def write_to_ocean_csv(params, amplitude_dom, dominant_period, power_dominant,
                       second_dominant_period, power_second_dominant, filename="ocean.csv"):
    """
    Append one result row to 'filename', creating the file if needed.

    'params' supplies the stellar/planet values (keys st_mass_msun,
    pl_bmasse, pl_orbper_days, pl_orbeccen, pl_orbincl_deg, pl_orblper;
    missing keys become NaN). The remaining arguments are written to the
    dominant / second-dominant period columns. The file is rewritten in
    full on each call.
    """
    # Column header -> key in 'params' (order defines the CSV column order)
    from_params = (
        ("Stellar Mass (Msun)", "st_mass_msun"),
        ("Inner Mass (Mearth)", "pl_bmasse"),
        ("Inner Period (days)", "pl_orbper_days"),
        ("Inner Eccentricity", "pl_orbeccen"),
        ("Inner Inclination", "pl_orbincl_deg"),
        ("Inner Omega", "pl_orblper"),
    )
    record = {header: params.get(key, np.nan) for header, key in from_params}
    record["Amplitude of Dominant Period Test (P1)"] = amplitude_dom
    record["Dominant Period Planet 1"] = dominant_period
    record["Dominant Period Power Planet 1"] = power_dominant
    record["Second Dominant Period Planet 1"] = second_dominant_period
    record["Second Dominant Period Power Planet 1"] = power_second_dominant

    creating = not os.path.exists(filename)
    if creating:
        # Start from an empty table that already has the expected columns
        existing = pd.DataFrame(columns=list(record))
    else:
        existing = pd.read_csv(filename)

    combined = pd.concat([existing, pd.DataFrame([record])], ignore_index=True)
    combined.to_csv(filename, index=False)

    if creating:
        print(f"Created {filename} and wrote first row.")
    else:
        print(f"Appended a new row to existing {filename}.")

###############################################################################
# Write planet_name.csv with Epoch and O–C (days)
###############################################################################

def write_planet_csv(target_name, all_epochs, all_oc, all_oc_err,
                     max_err_min=20.0, output_dir="E:/"):
    """
    Write "<target_name>.csv" (spaces and dashes replaced by underscores) with
    columns Epoch, O-C (days) and O-C Error (days), so the timing data can be
    imported into other analysis tools.

    Only points whose timing uncertainty is strictly below ``max_err_min``
    minutes are written; O–C values and errors stay in days in the file.

    Parameters
    ----------
    target_name : str
        Planet name used to build the output file name.
    all_epochs, all_oc, all_oc_err : array-like
        Epoch numbers, O–C values (days) and O–C uncertainties (days).
        Lists are accepted; they are converted with ``np.asarray``.
    max_err_min : float, optional
        Uncertainty cut in minutes. Defaults to 20.0, the value this function
        has always applied.
    output_dir : str, optional
        Directory that receives the CSV. Defaults to "E:/".

    Raises
    ------
    ValueError
        If no point has an uncertainty below ``max_err_min`` minutes.
    """
    epochs = np.asarray(all_epochs)
    oc_days = np.asarray(all_oc)
    oc_err_days = np.asarray(all_oc_err)

    # The cut is expressed in minutes, the stored uncertainties are in days.
    oc_err_min = oc_err_days * 24.0 * 60.0
    mask = oc_err_min < max_err_min
    if not mask.any():
        # Report the threshold that was actually applied.
        raise ValueError(
            f"No O–C points remain with uncertainties < {max_err_min:g} min, so no CSV created."
        )

    # Construct a filename based on the target name
    safe_name = target_name.replace(' ', '_').replace('-', '_')
    planet_csv_filename = os.path.join(output_dir, f"{safe_name}.csv")

    df = pd.DataFrame({
        "Epoch": epochs[mask],
        "O-C (days)": oc_days[mask],
        "O-C Error (days)": oc_err_days[mask]
    })
    df.to_csv(planet_csv_filename, index=False)
    print(f"Wrote {len(df)} rows of O–C data with errors to {planet_csv_filename}")


###############################################################################
# MAIN: Example usage (ensure your data arrays are defined)
###############################################################################

if __name__ == "__main__":
    # Prerequisites: all_epochs, all_oc and all_oc_err must already exist
    # (from a file or an earlier processing step) before this runs.
    #
    # all_epochs : numpy array of epochs (in epochs)
    # all_oc     : numpy array of O–C values (in days)
    # all_oc_err : numpy array of uncertainties in O–C (in days)

    # Target planet to analyse (edit for your own system)
    target_name = "Kepler-1710 b"
    # Passed to the analysis for compatibility; the iterative analysis
    # itself does not use it directly.
    period = 7.64159

    # --- Iterative Lomb–Scargle analysis ---
    results, final_mask = run_lomb_scargle_TTV(
        all_epochs, all_oc, all_oc_err, period, show_plots=True
    )

    # Earliest iteration whose false-alarm probability is below 0.1, if any
    dominant_result = next((res for res in results if res['fap'] < 0.1), None)

    if dominant_result is None:
        print("No dominant period with FAP below 0.1 was found; nothing appended to ocean.csv.")
    else:
        amplitude_dom = dominant_result['amplitude']
        dom_period = dominant_result['dominant_period']
        power_dom = dominant_result['power']
        # The "second dominant" columns simply mirror the first signal
        sec_period, power_sec = dom_period, power_dom

        # --- Extra system parameters from the NASA Exoplanet Archive ---
        params = query_inner_params(target_name)

        # --- Record the summary row in ocean.csv ---
        write_to_ocean_csv(
            params, amplitude_dom, dom_period, power_dom,
            sec_period, power_sec, filename="E:/ocean.csv",
        )

    # --- Export Epoch / O–C table ---
    # The saved timing data can be explored in external programs
    # or shared with collaborators.
    write_planet_csv(target_name, all_epochs, all_oc, all_oc_err)

    print("\nAll tasks complete.")

## 3. Visualizing Harmonics of the O–C Curve

This cell demonstrates how to combine the dominant Lomb–Scargle signals and overlay them on the observed minus calculated data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.timeseries import LombScargle

# -----------------------------------------------
# Harmonic overlay: rebuild the strongest sinusoids
# reported by the Lomb--Scargle analysis and draw
# them on top of the measured TTV data.
# -----------------------------------------------
# Inputs expected from earlier cells:
#   all_epochs, all_oc, all_oc_err
#   final_mask from run_lomb_scargle_TTV
#   results (list of dictionaries) from run_lomb_scargle_TTV

# O–C and its uncertainty, days -> minutes
oc_min = all_oc * 24 * 60
oc_err_min = all_oc_err * 24 * 60

# Keep only the points selected by final_mask
epochs_f, ttv_f, ttv_err_f = (
    arr[final_mask] for arr in (all_epochs, oc_min, oc_err_min)
)

# Plot x-axis starts at the first retained epoch
first_epoch = np.min(epochs_f)
x_shifted = epochs_f - first_epoch

# Remove the mean before fitting (as done in the LS analysis)
mean_ttv = np.mean(ttv_f)
y_data = ttv_f - mean_ttv

# One Lomb-Scargle object over the whole filtered dataset
ls_orig = LombScargle(epochs_f, y_data, ttv_err_f)

# Dense epoch grid so the sinusoids are drawn smoothly
num_points = 500
epochs_dense = np.linspace(first_epoch, np.max(epochs_f), num_points)
x_dense_shifted = epochs_dense - first_epoch

# Measured TTVs with error bars
plt.figure(figsize=(10,6))
plt.errorbar(x_shifted, ttv_f, yerr=ttv_err_f, fmt='o', color='black',
             ecolor='gray', capsize=3)

# At most two signals are shown
num_harmonics = min(2, len(results))
colors = plt.cm.viridis(np.linspace(0, 1, num_harmonics))

# Running sum of the individual sinusoids on the dense grid
composite_dense = np.zeros(num_points)

for i, color_i in zip(range(num_harmonics), colors):
    period_i = results[i]['dominant_period']
    freq_i = 1.0 / period_i

    # Single-frequency model evaluated on the dense grid
    model_dense_i = ls_orig.model(epochs_dense, frequency=freq_i)
    composite_dense += model_dense_i

    # Dashed curve per signal, lifted by the mean so it sits on the
    # same absolute scale as the data points
    plt.plot(x_dense_shifted, model_dense_i + mean_ttv,
             linestyle='--', color=color_i,
             label=f'Dominant Signal {i+1} (Period = {period_i:.2f} epochs)')

# Restore the mean on the summed model as well
composite_dense += mean_ttv

# The composite curve is computed above but its plot is currently disabled:
#plt.plot(x_dense_shifted, composite_dense, color='blue', lw=2,
#         label='Composite 2-Harmonic Fit')

# Labels, legend and layout
plt.xlabel("Epochs", fontsize=14)
plt.ylabel("TTV Amplitude (minutes)", fontsize=14)
plt.title("Composite 2-Harmonic Curve Fit", fontsize=16)
plt.legend(loc = 'upper left')
plt.tight_layout()
plt.show()


## 4. Chi-Squared Diagnostics for Model Fits

Finally we compute goodness-of-fit statistics for models derived from the Lomb–Scargle analysis and visualize the results.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.timeseries import LombScargle

# -----------------------------------------------
# Assess the quality of the harmonic fit using
# reduced chi-square statistics and visualise the
# model together with the raw TTV measurements.
# -----------------------------------------------
# Inputs expected from earlier cells:
#   all_epochs, all_oc, all_oc_err
#   final_mask from run_lomb_scargle_TTV
#   results (list of dictionaries) from run_lomb_scargle_TTV

# Factor used to convert the plot's x-axis from epochs to days.
# NOTE(review): the example main cell sets period = 7.64159 for its target;
# confirm that 33.6 is the right value for the planet being plotted.
DAYS_PER_EPOCH = 33.6


def _chi2_summary(residuals, sigma, n_params):
    """Return (chi2, degrees of freedom, reduced chi2) for the given residuals."""
    chi2 = np.sum((residuals / sigma) ** 2)
    ndof = len(residuals) - n_params
    return chi2, ndof, chi2 / ndof


# 1) Convert O–C to minutes
oc_min     = all_oc     * 24 * 60
oc_err_min = all_oc_err * 24 * 60

# 2) Filter data using final_mask
epochs_f  = all_epochs[final_mask]
ttv_f     = oc_min[final_mask]
ttv_err_f = oc_err_min[final_mask]

# 3) Shift epochs so the plot starts at zero, then express them in days
x_shifted_epochs = epochs_f - np.min(epochs_f)
x_shifted_days   = x_shifted_epochs * DAYS_PER_EPOCH

# 4) Mean-center the TTV data
mean_ttv = np.mean(ttv_f)
y_data   = ttv_f - mean_ttv

# 5) Create Lomb-Scargle object
ls_orig = LombScargle(epochs_f, y_data, ttv_err_f)

# 6) Compute χ² for the null model (y=0); one parameter (the mean)
chi2_null, ndof_null, chi2r_null = _chi2_summary(y_data, ttv_err_f, 1)
print(f"Chi² (null model y=0):           {chi2_null:.2f}")
print(f"Reduced χ² (dof={ndof_null}):    {chi2r_null:.2f}")

# 7) Compute χ² for the dominant-period fit; two fitted parameters (amp & phase)
# NOTE(review): LombScargle's default fit_mean=True also fits an offset term;
# confirm whether it should count toward the parameter totals used here.
P1            = results[0]['dominant_period']
f1            = 1.0 / P1
model1_epochs = ls_orig.model(epochs_f, frequency=f1)
resid1        = y_data - model1_epochs
chi2_dom, ndof_dom, chi2r_dom = _chi2_summary(resid1, ttv_err_f, 2)
print(f"Chi² (dominant P={P1:.2f} epochs):    {chi2_dom:.2f}")
print(f"Reduced χ² (dof={ndof_dom}):    {chi2r_dom:.2f}")

# 8) Compute χ² for the two-period composite (dominant + second-dominant)
if len(results) > 1:
    P2            = results[1]['dominant_period']
    f2            = 1.0 / P2
    # The fitted sinusoid is used as returned: it is already in minutes, and
    # multiplying it by its own semi-amplitude would square the amplitude.
    model2_epochs = ls_orig.model(epochs_f, frequency=f2)

    composite_epochs = model1_epochs + model2_epochs
    resid_comp       = y_data - composite_epochs
    # 4 fitted params: amp & phase × 2
    chi2_comp, ndof_comp, chi2r_comp = _chi2_summary(resid_comp, ttv_err_f, 4)
    print(f"Chi² (P₁={P1:.2f}, P₂={P2:.2f} epochs):  {chi2_comp:.2f}")
    print(f"Reduced χ² (dof={ndof_comp}):           {chi2r_comp:.2f}")

# 9) Plot data and up to two harmonics
num_points     = 500
epochs_dense   = np.linspace(np.min(epochs_f), np.max(epochs_f), num_points)
x_dense_shift  = (epochs_dense - np.min(epochs_f)) * DAYS_PER_EPOCH

plt.figure(figsize=(7,4))
plt.errorbar(
    x_shifted_days,
    ttv_f,
    yerr=ttv_err_f,
    fmt='o', color='black', ecolor='gray', capsize=3
)

num_harmonics = min(2, len(results))
colors        = ['blue', 'green'][:num_harmonics]
# Running sum of the plotted sinusoids (kept for inspection; not drawn here)
composite_dense = np.zeros(num_points)

for i in range(num_harmonics):
    P_i = results[i]['dominant_period']
    f_i = 1.0 / P_i
    # Same unscaled single-frequency model that enters the χ² values above
    model_dense = ls_orig.model(epochs_dense, frequency=f_i)

    composite_dense += model_dense

    # Lift each curve by the mean so it sits on the data's absolute scale
    plt.plot(
        x_dense_shift,
        model_dense + mean_ttv,
        linestyle='--',
        color=colors[i],
        alpha=0.5,
        label=f"Signal {i+1} (P={P_i:.2f} epochs)"
    )

plt.xlabel("Days", fontsize=14)
plt.ylabel("TTV Amplitude (minutes)", fontsize=14)
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()
