# 

In [None]:
"""
kepler_individ_lc.py

Script to download, clean, and fit individual Kepler transits for a specified target,
then create an O–C diagram using the Exotic Library.

Important: NASA Exoplanet Archive typically provides T0 in BJD (~2454xxx),
while Kepler data from lightkurve is in BKJD = BJD - 2454833.0.
Hence, we subtract 2454833 from T0 before phase-folding and fitting.
"""

# Notebook setup cell: IPython shell (`!`) and magic (`%`) commands, so this
# section only runs inside Jupyter/IPython, not as a plain Python script.
# Install extra packages used by the fitting / N-body tooling.
!pip install ultranest rebound
# Clone EXOTIC into the current working directory and check out its `tess` branch.
!git clone https://github.com/rzellem/EXOTIC
%cd EXOTIC
!git checkout tess
%cd ..
import sys
# NOTE(review): the clone above lands in the notebook's working directory, but
# the path appended here is hard-coded to F:/EXOTIC. The two only agree if the
# notebook is started from F:/ -- confirm the intended location.
sys.path.append("F:/EXOTIC")

!git clone https://github.com/pearsonkyle/Nbody-AI
# Same caveat as above: hard-coded F:/ paths for the Nbody-AI checkout.
sys.path.append('F:/Nbody-AI/nbody')

sys.path.insert(0, 'F:/Nbody-AI')
    

import sys
import os
import copy
import json
import pickle
import argparse
import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from io import StringIO
from pandas import read_csv
from scipy.ndimage import binary_dilation, label
from scipy.signal import savgol_filter, medfilt

import lightkurve as lk
from astropy import constants as const
from astropy import units as u
from wotan import flatten

# EXOTIC imports (installed via "pip install exotic")
from exotic.api.elca import lc_fitter
from exotic.api.output_aavso import OutputFiles

# For transit modeling
from pylightcurve import exotethys

# ---------------------------------------------------------------------
# 1) Helper functions
# ---------------------------------------------------------------------

def tap_query(base_url, query, dataframe=True):
    """
    Table Access Protocol (TAP) query to NASA Exoplanet Archive.
    Builds a URL from 'base_url' and 'query', sends a GET request,
    returns response as a DataFrame or raw text.
    """
    uri_full = base_url
    for k in query:
        if k != "format":
            uri_full += f"{k} {query[k]} "
    uri_full = f"{uri_full[:-1]} &format={query.get('format', 'csv')}"
    uri_full = uri_full.replace(' ', '+')
    print("Query:", uri_full)

    response = requests.get(uri_full, timeout=300)
    if dataframe:
        return read_csv(StringIO(response.text))
    else:
        return response.text

def nea_scrape(target=None):
    """
    Fetch transiting-planet parameters from the NASA Exoplanet Archive.

    Queries the ``pscomppars`` table (restricted to ``tran_flag = 1``) via
    ``tap_query``. When ``target`` is truthy (e.g. ``'Kepler-18 b'``) the
    query is further restricted to that ``pl_name``.

    Returns
    -------
    pandas.DataFrame
        One row per matching planet with columns such as pl_name, pl_orbper,
        pl_tranmid, st_rad, st_logg, ...
    """
    columns = (
        "pl_name", "hostname", "pl_radj", "pl_radjerr1", "ra", "dec",
        "pl_ratdor", "pl_ratdorerr1", "pl_ratdorerr2",
        "pl_orbincl", "pl_orbinclerr1", "pl_orbinclerr2",
        "pl_orbper", "pl_orbpererr1", "pl_orbpererr2", "pl_orbeccen",
        "pl_orbsmax", "pl_orbsmaxerr1", "pl_orbsmaxerr2",
        "pl_orblper", "pl_tranmid", "pl_tranmiderr1", "pl_tranmiderr2",
        "pl_ratror", "pl_ratrorerr1", "pl_ratrorerr2",
        "st_teff", "st_tefferr1", "st_tefferr2",
        "st_met", "st_meterr1", "st_meterr2",
        "st_logg", "st_loggerr1", "st_loggerr2",
        "st_mass", "st_rad", "st_raderr1",
    )

    where_clause = "tran_flag = 1"
    if target:
        where_clause = f"{where_clause} and pl_name = '{target}'"

    # Key order matters: tap_query emits the clauses in dict order.
    query = {
        "select": ",".join(columns),
        "from": "pscomppars",
        "where": where_clause,
        "format": "csv",
    }
    return tap_query("https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=", query)

def sigma_clip(ogdata, dt, iterations=1, sigma=3.0):
    """
    Iteratively flag outliers relative to a Savitzky-Golay smoothed trend.

    Each iteration smooths the currently-kept points with a quadratic
    Savitzky-Golay filter of window length ``dt`` (in points), and drops every
    point whose residual is not within ``sigma`` standard deviations.

    Parameters
    ----------
    ogdata : array-like
        1-D data to clip (not modified).
    dt : int
        Savitzky-Golay window length in points.
    iterations : int
        Number of clipping passes.
    sigma : float
        Clipping threshold in units of the residual standard deviation
        (default 3).

    Returns
    -------
    data : ndarray
        Copy of ``ogdata`` with clipped points set to NaN. Integer input is
        promoted to float so NaN can be stored.
    std : float
        Standard deviation of the kept points about the final smooth fit.
    """
    ogdata = np.asarray(ogdata)
    mask = np.ones(ogdata.shape, dtype=bool)
    for _ in range(iterations):
        mdata = savgol_filter(ogdata[mask], dt, 2)
        res = ogdata[mask] - mdata
        std = np.nanstd(res)
        # NaN residuals compare False and are therefore dropped as well.
        mask[mask] = np.abs(res) < sigma*std

    mdata = savgol_filter(ogdata[mask], dt, 2)
    # Copy; promote non-float dtypes so that NaN assignment cannot fail.
    if np.issubdtype(ogdata.dtype, np.inexact):
        data = ogdata.copy()
    else:
        data = ogdata.astype(float)
    data[~mask] = np.nan
    return data, np.std(ogdata[mask] - mdata)

def check_std(time, flux, dt=0.5):
    """
    Estimate the point-to-point scatter of a light curve.

    The light curve is sorted by time and smoothed with a quadratic
    Savitzky-Golay filter. The standard deviation of (flux - smoothed flux) is
    returned, computed in the same sorted order.

    Parameters
    ----------
    time : array-like
        Observation times in days.
    flux : array-like
        Flux values matching ``time``.
    dt : float
        Smoothing timescale in hours. It is converted to a number of points
        using the mean cadence. The window is at least 31 points (2*15 + 1)
        but never longer than the data.

    Returns
    -------
    float
        ``nanstd`` of the residuals, or NaN if fewer than 3 points are given.
    """
    time = np.asarray(time)
    flux = np.asarray(flux)
    si = np.argsort(time)
    t_sorted = time[si]
    f_sorted = flux[si]

    npts = len(f_sorted)
    if npts < 3:
        return np.nan

    tdt = np.diff(t_sorted).mean()
    # dt [hours] / (24 * cadence [days]) = smoothing timescale in points
    nwin = dt/(24*tdt) if tdt > 0 else 0
    wsize = 1 + 2*int(max(15, nwin))
    # savgol_filter (mode='interp') needs window_length <= number of points.
    wsize = min(wsize, npts if npts % 2 else npts - 1)
    sflux = savgol_filter(f_sorted, wsize, 2)
    # Residuals must use the same (sorted) ordering as the smoothed curve.
    return np.nanstd(f_sorted - sflux)

def stellar_mass(logg, rs):
    """
    Return the stellar mass (in M_sun) implied by surface gravity and radius.

    Uses M = g R^2 / G, with g = 10**logg in cm s^-2 and R = rs in R_sun.
    """
    surface_gravity = 10**logg * (u.cm/u.s**2)
    radius_sq = (rs*u.R_sun)**2
    mass = radius_sq * surface_gravity / const.G
    return mass.to(u.M_sun).value

def sa(m, P):
    """
    Semi-major axis (AU) from Kepler's third law.

    a = (G M P^2 / (4 pi^2))^(1/3), with the star mass ``m`` in M_sun and the
    period ``P`` in days. The planet mass is neglected.
    """
    # Parenthesise before squaring: P*u.day**2 would be P * day^2, not P^2 * day^2.
    period = P*u.day
    a = (const.G*m*u.M_sun*period**2/(4*np.pi**2))**(1./3)
    return a.to(u.AU).value

def parse_args():
    """
    Build the command-line interface for the Kepler analysis and parse sys.argv.

    Arguments the parser does not recognise are ignored (``parse_known_args``).
    Returns the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Target / output / data selection
    add("-t", "--target", type=str, default="Kepler-1710 b",
        help="Kepler target, e.g. 'Kepler-19 b'.")
    add("-o", "--output", type=str, default="kepler_individ_lc_output",
        help="Output directory for results.")
    add("--quarter", type=int, default=0,
        help="Which Kepler quarter to process (0=all).")

    # Boolean switches
    add("-r", "--reprocess", action='store_true', default=False,
        help="Reprocess even if existing results found.")
    add("--ars", action='store_true', default=False,
        help="Include a/R* in the fit bounds.")
    add("--tls", action='store_true', default=False,
        help="Perform TLS search on final residuals (not fully used here).")

    known, _unknown = parser.parse_known_args()
    return known

# ---------------------------------------------------------------------
# 2) Main Script
# ---------------------------------------------------------------------

if __name__ == "__main__":
    args = parse_args()
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    # Planet-specific directory
    planetdir = os.path.join(args.output, args.target.replace(' ', '_').replace('-', '_'))
    if not os.path.exists(planetdir):
        os.mkdir(planetdir)
    else:
        # If there's a 'global_fit.png' and no --reprocess, skip.
        # NOTE(review): nothing in this script writes global_fit.png, so this
        # guard only fires on output from some other run/tool -- confirm.
        if os.path.exists(os.path.join(planetdir, "global_fit.png")) and not args.reprocess:
            raise Exception("Target appears already processed. Use --reprocess to overwrite.")

    planetname = args.target.lower().replace(' ', '').replace('-', '')

    # 1) Search for Kepler data
    search_result = lk.search_targetpixelfile(args.target, mission='Kepler')
    print(search_result)
    if len(search_result) == 0:
        raise Exception(f"No Kepler data found for target: {args.target}")

    # 2) Load or scrape prior planet parameters
    prior_path = os.path.join(planetdir, planetname + "_prior.json")
    if os.path.exists(prior_path):
        with open(prior_path, "r", encoding="utf8") as jf:
            prior = json.load(jf)
    else:
        # Query NASA Exoplanet Archive
        nea_df = nea_scrape(args.target)
        if len(nea_df) == 0:
            raise Exception(f"No planet found in NASA Exoplanet Archive: {args.target}")

        prior = {}
        for key in nea_df.columns:
            value = nea_df[key].values[0]
            # numpy scalars (e.g. int64 for an integer-valued column) are not
            # JSON-serializable; store native Python types instead.
            prior[key] = value.item() if isinstance(value, np.generic) else value

        # if pl_ratdor is NaN, compute from pl_orbsmax / st_rad
        if np.isnan(prior.get('pl_ratdor', np.nan)) and not np.isnan(prior.get('pl_orbsmax', np.nan)):
            prior['pl_ratdor'] = (
                (prior['pl_orbsmax']*u.AU)/(prior['st_rad']*u.R_sun)
            ).to(u.dimensionless_unscaled).value
            prior['pl_ratdorerr1'] = 0.1 * prior['pl_ratdor']
            prior['pl_ratdorerr2'] = -0.1 * prior['pl_ratdor']

        # Save local prior
        with open(prior_path, 'w', encoding='utf8') as jf:
            json.dump(prior, jf, indent=4)

    # 3) Limb darkening for Kepler band
    #    If these are NaN, set defaults
    st_logg = prior.get('st_logg', 4.4)
    if np.isnan(st_logg): st_logg = 4.4
    st_teff = prior.get('st_teff', 5700)
    if np.isnan(st_teff): st_teff = 5700
    st_met = prior.get('st_met', 0.0)
    if np.isnan(st_met): st_met = 0.0

    u0, u1, u2, u3 = exotethys(st_logg, st_teff, st_met,
                               'Kepler', method='claret', stellar_model='phoenix')

    # 4) Prepare structure to store data from each quarter
    sv = {
        'lightcurves': [],
        'quarters': {},
        'time': [],
        'flux': [],
        'flux_err': [],
        'trend': [],
        'quarter': []
    }

    # If there's a 'quarter' column in search_result, we can check it
    has_quarter_col = "quarter" in search_result.table.colnames
    if has_quarter_col:
        unique_quarters = np.unique(search_result.table["quarter"])
    else:
        # fallback: each row is "quarterlike", labelled 1..N (same labels as below)
        unique_quarters = np.arange(1, len(search_result) + 1)

    # If user wants only a specific quarter
    if args.quarter != 0 and args.quarter not in unique_quarters:
        raise Exception(f"No data for quarter {args.quarter} in search_result.")

    # 5) Download & Clean each quarter
    for idx in range(len(search_result)):

        if has_quarter_col:
            quarter = search_result.table["quarter"][idx]
        else:
            quarter = idx + 1  # fallback label
        if args.quarter != 0 and quarter != args.quarter:
            continue

        print(f"Downloading Quarter {quarter} ...")
        try:
            tpf = search_result[idx].download(quality_bitmask='hard')
        except Exception as err:
            print(f"Failed to download quarter {quarter}: {err}")
            continue

        # Aperture selection: grow the pipeline mask if that lowers the scatter
        lc = tpf.to_lightcurve(aperture_mask=tpf.pipeline_mask)
        nmask = np.isnan(lc.flux.value)
        lstd = check_std(lc.time.value[~nmask], lc.flux.value[~nmask])

        aper_final = tpf.pipeline_mask
        for it in [1, 2]:
            bigger_mask = binary_dilation(tpf.pipeline_mask, iterations=it)
            lcd = tpf.to_lightcurve(aperture_mask=bigger_mask)
            nmaskd = np.isnan(lcd.flux.value)
            lstdd = check_std(lcd.time.value[~nmaskd], lcd.flux.value[~nmaskd])
            if lstdd < lstd:
                aper_final = bigger_mask
                lstd = lstdd

        lc = tpf.to_lightcurve(aperture_mask=aper_final)
        tpf.plot(aperture_mask=aper_final)
        plt.savefig(os.path.join(planetdir, f"{planetname}_quarter_{quarter}_aperture.png"))
        plt.close()

        # Work on time-sorted arrays so neighbouring indices are neighbouring
        # samples (the gap detection below relies on this).
        time_raw = lc.time.value  # Kepler BKJD
        srt = np.argsort(time_raw)
        time_arr = time_raw[srt]
        flux_arr = lc.flux.value[srt]
        tmask = np.ones_like(time_arr, dtype=bool)

        # Remove first ~30 min of data and ~30 min after big gaps
        dt_ = np.diff(time_arr)
        if len(dt_) and not np.isnan(dt_).all():
            median_dt = np.nanmedian(dt_)
        else:
            median_dt = np.nan
        if np.isfinite(median_dt) and median_dt > 0:
            # 30 min in days => 30./1440. => how many points
            ndt = int(np.round((30./1440.) / median_dt))
        else:
            ndt = 30  # fallback

        tmask[0:ndt] = False
        # dt_[i] is the spacing between samples i and i+1, so the first
        # sample after a >1 h gap is i+1.
        for gap_idx in np.flatnonzero(dt_ > (1./24.)):
            tmask[gap_idx + 1:gap_idx + 1 + ndt] = False

        tmask &= np.isfinite(flux_arr) & np.isfinite(time_arr)

        time_c = time_arr[tmask]
        flux_c = flux_arr[tmask]

        # medfilt / Savitzky-Golay below use 15-point windows.
        if len(time_c) < 15:
            print(f"Quarter {quarter}: too few good points ({len(time_c)}), skipping.")
            continue

        # Remove outliers
        mflux = medfilt(flux_c, kernel_size=15)
        rflux = flux_c / mflux
        newflux, std_ = sigma_clip(rflux, dt=15, iterations=1)
        maskNaN = np.isnan(newflux)
        time_c = time_c[~maskNaN]
        flux_c = flux_c[~maskNaN]
        if len(time_c) == 0:
            print(f"Quarter {quarter}: no points left after outlier removal, skipping.")
            continue

        # Flatten (remove stellar variability)
        dflux = np.copy(flux_c)
        dtrend = np.ones(len(time_c))

        # Segment id for EVERY point: a new segment starts after any gap > 0.5 d.
        seg_ids = np.concatenate([[0], np.cumsum(np.diff(time_c) > 0.5)])
        for seg_id in np.unique(seg_ids):
            seg = (seg_ids == seg_id)
            if seg.sum() < 5:
                # Too short for the biweight filter; normalise by the segment
                # median so these points are on the same scale as the rest.
                seg_med = np.nanmedian(flux_c[seg])
                dflux[seg] = flux_c[seg] / seg_med
                dtrend[seg] = seg_med
                continue
            flc, tlc = flatten(time_c[seg], flux_c[seg],
                               window_length=2.0, return_trend=True,
                               method='biweight')
            dflux[seg] = flc
            dtrend[seg] = tlc

        # Store in state
        sv['quarters'][quarter] = True
        sv['time'].append(time_c)
        sv['flux'].append(dflux)
        sv['flux_err'].append((dtrend**0.5)/np.nanmedian(dtrend))
        sv['trend'].append(dtrend)
        sv['quarter'].append(np.ones(len(time_c))*quarter)

        # Plot
        plt.figure()
        plt.plot(time_c, flux_c, 'k.', label='Flux')
        plt.plot(time_c, dtrend, 'r--', label='Trend')
        plt.title(f"{args.target} - Quarter {quarter}")
        plt.xlabel("Time [BKJD]")  # BKJD = BJD - 2454833
        plt.ylabel("Flux")
        plt.legend()
        outfig = os.path.join(planetdir, f"{planetname}_quarter_{quarter}_trend.png")
        plt.savefig(outfig)
        plt.close()

    if not sv['time']:
        raise Exception("No valid quarters processed. Check logs for errors.")

    # Concatenate all quarters
    time_all = np.concatenate(sv['time'])
    flux_all = np.concatenate(sv['flux'])
    fluxerr_all = np.concatenate(sv['flux_err'])
    trend_all = np.concatenate(sv['trend'])
    quarter_all = np.concatenate(sv['quarter'])

    # Remove nans
    nanmask = np.isnan(flux_all) | np.isnan(fluxerr_all)
    time_all = time_all[~nanmask]
    flux_all = flux_all[~nanmask]
    fluxerr_all = fluxerr_all[~nanmask]
    quarter_all = quarter_all[~nanmask]
    trend_all = trend_all[~nanmask]

    # Save combined light curve to CSV (flux re-multiplied by the trend)
    df = pd.DataFrame({
        'time_bkjd': time_all,
        'flux': flux_all * trend_all,
        'flux_err': fluxerr_all * trend_all,
        'quarter': quarter_all
    })
    outcsv = os.path.join(planetdir, f"{planetname}_lightcurve.csv")
    df.to_csv(outcsv, index=False)
    print("Saved combined lightcurve:", outcsv)

    # 6) Transit Fitting / O-C

    # Convert T0 from BJD to BKJD by subtracting 2454833.0
    # if we have a valid pl_tranmid
    if not np.isnan(prior.get('pl_tranmid', np.nan)):
        tmid_bjd = float(prior['pl_tranmid'])
    else:
        tmid_bjd = 2454950.0  # fallback guess

    tmid_bkjd = tmid_bjd - 2454833.0  # *CRITICAL SHIFT*

    # Orbital period
    if not np.isnan(prior.get('pl_orbper', np.nan)):
        period = float(prior['pl_orbper'])
    else:
        period = 10.0

    # a/R*
    if not np.isnan(prior.get('pl_ratdor', np.nan)):
        ars = float(prior['pl_ratdor'])
    else:
        ars = 15.0

    # rprs = (Rplanet/Rstar): derive from radii when both are known,
    # otherwise use the archive's direct ratio, otherwise a generic guess.
    pl_radj = prior.get('pl_radj', np.nan)
    st_rad = prior.get('st_rad', np.nan)
    pl_ratror = prior.get('pl_ratror', np.nan)
    if not np.isnan(pl_radj) and not np.isnan(st_rad):
        rprs = float((pl_radj*u.R_jup)/(st_rad*u.R_sun))
    elif not np.isnan(pl_ratror):
        rprs = float(pl_ratror)
    else:
        rprs = 0.05

    inc = prior.get('pl_orbincl', 88.0)
    if np.isnan(inc): inc = 88.0

    ecc = prior.get('pl_orbeccen', 0.0)
    if np.isnan(ecc): ecc = 0.0

    omega = prior.get('pl_orblper', 0.0)
    if np.isnan(omega): omega = 0.0

    tpars = {
        'rprs': rprs,
        'ars': ars,
        'per': period,
        'inc': inc,
        'tmid': tmid_bkjd,
        'omega': omega,
        'ecc': ecc,
        'a1': 0, 'a2': 0,
        'u0': u0, 'u1': u1, 'u2': u2, 'u3': u3
    }

    # Quick flux_err guess
    if len(flux_all):
        phot_std = np.nanstd(flux_all[flux_all < 1.1])  # rough
        fluxerr0 = phot_std / flux_all
    else:
        fluxerr0 = np.array([])

    # Phase times (in orbits since the reference mid-transit)
    tphase = (time_all - tmid_bkjd)/period
    # approximate transit duration (in phase units)
    pdur = 2.0 * np.arctan(1./ars) / (2*np.pi) if ars > 0 else 0.05

    # Assign each point to its NEAREST transit epoch (floor would attribute
    # the pre-transit half of each window to the previous epoch).
    events = np.unique(np.round(tphase))

    # Prepare arrays to store O-C
    all_epochs = []
    all_oc = []
    all_oc_err = []

    # Bounds
    mybounds = {
        'rprs': [0, 3*rprs],
        'tmid': [tmid_bkjd - 0.3, tmid_bkjd + 0.3],
        'inc': [70, 90]
    }
    if args.ars:
        mybounds['ars'] = [ars*0.5, ars*2.0]

    for e in events:
        # mask within ~2 x pdur of the predicted transit
        intrans = (tphase >= e - 2*pdur) & (tphase <= e + 2*pdur)
        if intrans.sum() < 10:
            continue

        # update guess Tmid for that epoch
        guess_tmid = tmid_bkjd + e*period
        local_bounds = copy.deepcopy(mybounds)
        # allow +/- 0.2 * (period*pdur) around guess
        local_bounds['tmid'] = [guess_tmid - 0.2*period*pdur,
                                guess_tmid + 0.2*period*pdur]
        tpars['tmid'] = guess_tmid

        airmass = np.zeros(intrans.sum())
        try:
            fit = lc_fitter(
                time_all[intrans], flux_all[intrans],
                fluxerr0[intrans],
                airmass,
                tpars,
                local_bounds
            )
        except Exception as err:
            print(f"Failed to fit transit at epoch {int(e)}: {err}")
            continue

        rprs2 = fit.parameters['rprs']**2
        rprs2err = 2 * fit.parameters['rprs'] * fit.errors['rprs']
        if (rprs2 - rprs2err <= 0):
            print(f"Skipping epoch {int(e)}: rprs^2 ~ 0 or negative.")
            continue

        lcdata = {
            'time': fit.time,           # BKJD
            'flux': fit.data,
            'residuals': fit.residuals,
            'phase': fit.phase,
            'pars': fit.parameters,
            'errors': fit.errors,
            'rchi2': fit.chi2/len(fit.time),
            'epoch': int(e)
        }
        sv['lightcurves'].append(lcdata)

        # Observed Tmid (in BKJD)
        Tobs_bkjd = fit.parameters['tmid']
        Tobs_err = fit.errors['tmid']

        # Predicted Tmid for epoch e in BKJD
        Tcalc_bkjd = tmid_bkjd + e*period
        OC_bkjd = Tobs_bkjd - Tcalc_bkjd
        OC_err = Tobs_err  # ignoring ephemeris uncertainty

        all_epochs.append(e)
        all_oc.append(OC_bkjd)
        all_oc_err.append(OC_err)

        # Plot bestfit
        fig, ax = fit.plot_bestfit(title=f"{args.target}, Quarter Fit, E={int(e)}")
        outname = f"{planetname}_E{int(e)}_fit.png"
        plt.savefig(os.path.join(planetdir, outname))
        plt.close()

    # Dump pickled results
    with open(os.path.join(planetdir, f"{planetname}_data.pkl"), 'wb') as pf:
        pickle.dump(sv, pf)

    # 7) O-C Diagram
    all_epochs = np.array(all_epochs)
    all_oc = np.array(all_oc)
    all_oc_err = np.array(all_oc_err)

    plt.figure(figsize=(8,5))
    plt.errorbar(all_epochs, all_oc, yerr=all_oc_err, fmt='o',
                 color='blue', ecolor='gray', capsize=3,
                 label="O-C (BKJD)")
    plt.axhline(0, color='k', linestyle='--', alpha=0.7)
    plt.xlabel("Epoch")
    plt.ylabel("O - C [days, BKJD]")
    plt.title(f"O-C Diagram: {args.target}")
    plt.legend()
    plt.tight_layout()

    oc_plot_file = os.path.join(planetdir, f"{planetname}_OC_diagram.png")
    plt.savefig(oc_plot_file)
    plt.show()

    print("Done! O-C diagram saved at:", oc_plot_file)
    print("Remember these times are in BKJD. (BJD - 2454833.0)")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import requests
from io import StringIO
from astropy.timeseries import LombScargle
from scipy.signal import find_peaks

###############################################################################
# LOMB–SCARGLE TTV ANALYSIS
###############################################################################

def run_lomb_scargle_TTV(all_epochs, all_oc, all_oc_err, show_plots=True,
                         max_err_min=3.0, poly_degree=19, plot_label="Kepler-1710b"):
    """
    Lomb–Scargle periodogram of transit-timing variations (TTVs).

    Steps
    -----
    1) Convert O–C and its uncertainty from days to minutes.
    2) Keep finite points whose O–C uncertainty is below ``max_err_min``
       minutes (default 3 min).
    3) Optionally plot TTV vs. epoch with a least-squares polynomial of degree
       ``poly_degree`` overlaid. The degree is capped at (number of distinct
       epochs - 1).
    4) Run an error-weighted Lomb–Scargle on the mean-subtracted TTVs over
       periods of ~2..200 epochs, and pick the dominant and second-dominant
       peaks.

    Parameters
    ----------
    all_epochs, all_oc, all_oc_err : array-like
        Epoch numbers, O–C values [days] and O–C uncertainties [days].
    show_plots : bool
        Draw the O–C plot and the periodogram.
    max_err_min : float
        Uncertainty cut in minutes.
    poly_degree : int
        Degree of the display-only polynomial drawn over the O–C plot.
    plot_label : str
        Planet label used in the plot titles.

    Returns
    -------
    tuple
        (amplitude_dom [min], dominant_period [epochs], power_dominant,
         second_dominant_period [epochs], power_second_dominant, mask).
        ``mask`` is the boolean selection applied to the input arrays.

    Raises
    ------
    ValueError
        If no point survives the uncertainty cut.
    """
    all_epochs = np.asarray(all_epochs, dtype=float)

    # Convert to minutes
    oc_min = np.asarray(all_oc, dtype=float) * 24.0 * 60.0
    oc_err_min = np.asarray(all_oc_err, dtype=float) * 24.0 * 60.0

    # Keep finite points with uncertainty below the cut
    mask = (np.isfinite(all_epochs) & np.isfinite(oc_min)
            & np.isfinite(oc_err_min) & (oc_err_min < max_err_min))
    if np.sum(mask) == 0:
        raise ValueError(f"No O–C points remain with uncertainties < {max_err_min:g} min!")

    epochs_f = all_epochs[mask]
    ttv_f = oc_min[mask]
    ttv_err_f = oc_err_min[mask]

    # Optional: TTV vs. epoch with a polynomial trend overlaid
    if show_plots:
        plt.figure(figsize=(8,5))
        plt.errorbar(epochs_f, ttv_f, yerr=ttv_err_f, fmt='o', color='black', ecolor='gray', capsize=3)

        # Polynomial.fit rescales the epochs to [-1, 1] before fitting, which
        # keeps high degrees numerically stable (np.polyfit on raw epoch
        # numbers is badly conditioned). The degree is capped so the fit is
        # never underdetermined.
        n_unique = len(np.unique(epochs_f))
        if n_unique > 1:
            deg = min(poly_degree, n_unique - 1)
            trend = np.polynomial.Polynomial.fit(epochs_f, ttv_f, deg)
            x_fit = np.linspace(np.min(epochs_f), np.max(epochs_f), 200)
            plt.plot(x_fit, trend(x_fit), color='black', linewidth=2)

        plt.title(f"Figure 4a: {plot_label} Observed-Calculated (O-C) Plot", fontsize=20, pad=20)
        plt.xlabel("Epochs", fontsize=18)
        plt.ylabel("TTV Amplitude (minutes)", fontsize=18)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.tight_layout()
        plt.show()

    # Mean-center TTV data for Lomb–Scargle
    y_data = ttv_f - np.mean(ttv_f)

    # Weighted Lomb–Scargle
    ls = LombScargle(epochs_f, y_data, ttv_err_f)

    # Frequency range => periods ~2..200 epochs
    frequency, power = ls.autopower(
        minimum_frequency=1.0/200.0,
        maximum_frequency=0.5,
        samples_per_peak=10
    )
    periods = 1.0 / frequency  # in epochs

    # Identify top peaks
    peaks, _ = find_peaks(power, height=0.01, distance=5)
    if len(peaks) == 0:
        # fallback: top 2 powers
        sorted_indices = np.argsort(power)[::-1]
        peaks = sorted_indices[:2]

    peak_periods = periods[peaks]
    peak_powers = power[peaks]

    # Dominant & second-dominant
    if len(peak_periods) == 1:
        dominant_period = peak_periods[0]
        power_dominant = peak_powers[0]
        second_dominant_period = peak_periods[0]
        power_second_dominant = peak_powers[0]
    else:
        sort_idx = np.argsort(peak_powers)[::-1]
        dominant_period = peak_periods[sort_idx[0]]
        power_dominant = peak_powers[sort_idx[0]]

        # Second peak: strongest one at least 5% away from the dominant period
        second_dominant_period = None
        power_second_dominant = None
        for idx in sort_idx[1:]:
            if abs(peak_periods[idx] - dominant_period) > 0.05 * dominant_period:
                second_dominant_period = peak_periods[idx]
                power_second_dominant = peak_powers[idx]
                break
        if second_dominant_period is None and len(sort_idx) > 1:
            second_dominant_period = peak_periods[sort_idx[1]]
            power_second_dominant = peak_powers[sort_idx[1]]
        elif second_dominant_period is None:
            second_dominant_period = dominant_period
            power_second_dominant = power_dominant

    # amplitude ~ 2 * sqrt(power * var(y_data)), in minutes (heuristic)
    var_y = np.var(y_data)
    amplitude_dom = 2.0 * np.sqrt(power_dominant * var_y) if power_dominant > 0 else 0.0
    amplitude_2nd = 2.0 * np.sqrt(power_second_dominant * var_y) if power_second_dominant > 0 else 0.0

    print("\n===== LOMB-SCARGLE RESULTS =====")
    print(f"Dominant Period (epochs)       = {dominant_period:.3f}")
    print(f"Dominant Power                 = {power_dominant:.4f}")
    print(f"Amplitude (Dominant) [min]     = {amplitude_dom:.3f}")
    print(f"Second Dominant Period (epochs)= {second_dominant_period:.3f}")
    print(f"Second Dominant Power          = {power_second_dominant:.4f}")
    print(f"Amplitude (Second) [min]       = {amplitude_2nd:.3f}")

    # Optional FAP (best effort: report, but never abort, on failure)
    try:
        fap_val = ls.false_alarm_probability(power_dominant)
        print(f"FAP of top peak: {fap_val:.4g}")
    except Exception as err:
        print(f"FAP computation failed: {err}")

    # Optional periodogram plot: highlight the dominant (highest) peak
    if show_plots:
        plt.figure(figsize=(8,5))
        # Periods are plotted directly in epochs
        plt.plot(periods, power)
        # Zoom: show up to 50 epochs
        plt.xlim(0, 50)

        plt.scatter(
            dominant_period,
            power_dominant,
            color='red',
            zorder=3,
            label=f"Dominant Period ({dominant_period:.2f} epochs)"
        )

        plt.xlabel("Epochs", fontsize=18)
        plt.ylabel("Power", fontsize=18)
        plt.title(f"Figure 4b: {plot_label} O-C Lomb-Scargle Periodogram", fontsize=20, pad=20)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        #plt.legend()
        plt.tight_layout()
        plt.show()

    # Return main results + final mask
    return (amplitude_dom, dominant_period, power_dominant,
            second_dominant_period, power_second_dominant, mask)

###############################################################################
# NASA EXOPLANET ARCHIVE QUERY
###############################################################################

def tap_query(base_url, query):
    """
    Run a synchronous TAP query and return the CSV result as a DataFrame.

    ``query`` maps ADQL clause keywords to their text, in clause order. The
    special ``format`` key becomes the ``&format=`` URL parameter (default
    ``csv``). Spaces in the final URL are encoded as ``+``.

    Raises
    ------
    RuntimeError
        If the server responds with any status other than 200.
    """
    clause_text = "".join(f"{key} {value} " for key, value in query.items() if key != "format")
    url = (base_url + clause_text)[:-1] + f" &format={query.get('format', 'csv')}"
    url = url.replace(' ', '+')
    print("Query URL:", url)

    resp = requests.get(url, timeout=60)
    if resp.status_code == 200:
        return pd.read_csv(StringIO(resp.text))
    raise RuntimeError(f"Query failed: status {resp.status_code}\n{resp.text}")

def query_inner_params(target_name):
    """
    Look up the inner planet's host and orbital parameters in the ``ps`` table.

    Requests st_mass (M_sun), pl_bmasse (M_earth), pl_orbper (days),
    pl_orbeccen, pl_orblper (omega) and pl_orbincl (deg). The ``ps`` table can
    return several rows per planet; the first row with the most non-null
    values among those columns is used.

    Returns
    -------
    dict
        Keys st_mass_msun, pl_bmasse, pl_orbper_days, pl_orbeccen,
        pl_orblper, pl_orbincl_deg, or ``{}`` when no rows come back.
    """
    wanted = ["st_mass","pl_bmasse","pl_orbper","pl_orbeccen","pl_orblper","pl_orbincl"]
    df = tap_query(
        "https://exoplanetarchive.ipac.caltech.edu/TAP/sync?query=",
        {
            "select": "st_mass, pl_bmasse, pl_orbper, pl_orbeccen, pl_orblper, pl_orbincl",
            "from": "ps",
            "where": f"pl_name = '{target_name}'",
            "format": "csv",
        },
    )
    if len(df) == 0:
        print(f"No rows returned for {target_name}. Possibly unrecognized planet name.")
        return {}

    # argmax returns the FIRST maximum, so ties keep the earliest row.
    completeness = df[wanted].notna().sum(axis=1).to_numpy()
    row = df.iloc[int(np.argmax(completeness))]

    key_map = (
        ("st_mass_msun", "st_mass"),
        ("pl_bmasse", "pl_bmasse"),
        ("pl_orbper_days", "pl_orbper"),
        ("pl_orbeccen", "pl_orbeccen"),
        ("pl_orblper", "pl_orblper"),
        ("pl_orbincl_deg", "pl_orbincl"),
    )
    return {out_key: row.get(src_key, np.nan) for out_key, src_key in key_map}

###############################################################################
# APPEND RESULTS TO ocean.csv
###############################################################################

def write_to_ocean_csv(
    params,
    amplitude_dom,
    dominant_period,
    power_dominant,
    second_dominant_period,
    power_second_dominant,
    filename="ocean.csv"
):
    """
    Append one summary row (inner-planet params + Lomb–Scargle results) to a CSV.

    Parameters
    ----------
    params : dict
        Output of ``query_inner_params``. Missing keys are written as NaN.
    amplitude_dom, dominant_period, power_dominant,
    second_dominant_period, power_second_dominant : float
        Values returned by ``run_lomb_scargle_TTV``.
    filename : str
        CSV path. It is created with a header if it does not exist yet.
    """
    # (column name, value) in output order; the single source of truth for
    # both the header and the row.
    row_items = [
        ("Stellar Mass (Msun)",   params.get("st_mass_msun", np.nan)),
        ("Inner Mass (Mearth)",   params.get("pl_bmasse", np.nan)),
        ("Inner Period (days)",   params.get("pl_orbper_days", np.nan)),
        ("Inner Eccentricity",    params.get("pl_orbeccen", np.nan)),
        ("Inner Inclination",     params.get("pl_orbincl_deg", np.nan)),
        ("Inner Omega",           params.get("pl_orblper", np.nan)),
        ("Amplitude of Dominant Period Test (P1)", amplitude_dom),
        ("Dominant Period Planet 1", dominant_period),
        ("Dominant Period Power Planet 1", power_dominant),
        ("Second Dominant Period Planet 1", second_dominant_period),
        ("Second Dominant Period Power Planet 1", power_second_dominant),
    ]
    columns = [name for name, _ in row_items]
    # Build the row directly instead of concatenating onto an empty frame
    # (concat with empty/all-NA frames is deprecated in recent pandas).
    new_df = pd.DataFrame([dict(row_items)], columns=columns)

    if not os.path.exists(filename):
        new_df.to_csv(filename, index=False)
        print(f"Created {filename} and wrote first row.")
    else:
        # Re-read and concat (rather than appending raw text) so values stay
        # aligned by column name even if the existing file's order differs.
        df = pd.concat([pd.read_csv(filename), new_df], ignore_index=True)
        df.to_csv(filename, index=False)
        print(f"Appended a new row to existing {filename}.")

###############################################################################
# WRITE planet_name.csv with O-C vs EPOCH (2 columns)
###############################################################################

def write_planet_csv(target_name, all_epochs, all_oc, all_oc_err,
                     max_err_min=15.0, outdir="E:/"):
    """
    Write "<target_name>.csv" with columns ``Epoch`` and ``O-C (days)``.

    Spaces and hyphens in ``target_name`` become underscores in the file name.
    Only points whose O–C uncertainty is below ``max_err_min`` minutes
    (default 15 min) are written.

    Parameters
    ----------
    target_name : str
        Planet name, e.g. "Kepler-1710 b".
    all_epochs, all_oc, all_oc_err : array-like
        Epochs, O–C [days] and O–C uncertainties [days].
    max_err_min : float
        Uncertainty cut in minutes.
    outdir : str
        Output directory (default "E:/").

    Raises
    ------
    ValueError
        If no point passes the uncertainty cut (no file is written).
    """
    all_epochs = np.asarray(all_epochs)
    all_oc = np.asarray(all_oc)
    oc_err_min = np.asarray(all_oc_err) * 24.0 * 60.0

    mask = oc_err_min < max_err_min
    if not np.any(mask):
        raise ValueError(
            f"No O–C points remain with uncertainties < {max_err_min:g} min, so no CSV created."
        )

    safe_name = target_name.replace(' ', '_').replace('-', '_')
    planet_csv_filename = os.path.join(outdir, f"{safe_name}.csv")

    df = pd.DataFrame({
        "Epoch": all_epochs[mask],
        "O-C (days)": all_oc[mask],  # still in days
    })

    df.to_csv(planet_csv_filename, index=False)
    print(f"Wrote {len(df)} rows of O-C data to {planet_csv_filename}")

###############################################################################
# EXAMPLE MAIN: use your real data here
###############################################################################

if __name__ == "__main__":
    # This cell expects all_epochs, all_oc and all_oc_err to already exist in
    # the kernel namespace (e.g. produced by the transit-fitting cell).
    target_name = "Kepler-1710 b"  # or whichever planet name you want

    # Step 1: Lomb–Scargle periodogram of the TTVs
    amplitude_dom, dom_period, power_dom, sec_period, power_sec, final_mask = (
        run_lomb_scargle_TTV(all_epochs, all_oc, all_oc_err, show_plots=True)
    )

    # Step 2: inner-planet/host parameters from the NASA Exoplanet Archive
    params = query_inner_params(target_name)

    # Step 3: append the summary row to the aggregate CSV
    write_to_ocean_csv(params, amplitude_dom, dom_period, power_dom,
                       sec_period, power_sec, filename="E:/ocean.csv")

    # Step 4: per-planet CSV of Epoch vs. O-C (days), filtered by uncertainty
    write_planet_csv(target_name, all_epochs, all_oc, all_oc_err)

    print("\nAll tasks complete.")
