In [None]:
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
from astropy.timeseries import LombScargle
from scipy.signal import find_peaks
import pymc as pm

# -------------------------------------------------------------------------
# I/O ---------------------------------------------------------------------
def load_data(path, sheet="Raw Data", first_kyr=30_000):
    """Load the age / δ18O columns from an Excel workbook and prepare them.

    The sheet is read without a header row; rows 10..6123 and columns 2..3
    (0-based positions) are taken as ``Age`` and ``d18o``.  Ages are shifted
    so that the smallest age becomes 0, the record is truncated to
    ``Age <= first_kyr`` and the remaining δ18O values are mean-centred.

    Parameters
    ----------
    path : str or path-like
        Workbook to read (parsed with the ``openpyxl`` engine).
    sheet : str
        Name of the worksheet holding the raw data.
    first_kyr : float
        Upper limit applied to the shifted ``Age`` column, in the units of
        that column.  NOTE(review): the default of 30_000 together with the
        "0-30 kyr" intent suggests ages are stored in years despite the
        parameter name -- confirm against the workbook.

    Returns
    -------
    tuple of numpy.ndarray
        ``(age, d18o)`` as float arrays of equal length.
    """
    raw = pd.read_excel(path, sheet_name=sheet, header=None, engine="openpyxl")

    # header=None can leave the columns as ``object`` dtype; cast once so all
    # arithmetic below is done in float64.
    df = (raw.iloc[10:6124, 2:4]
             .rename(columns={2: "Age", 3: "d18o"})
             .astype(float))

    df["Age"] -= df["Age"].min()          # t = 0 at the smallest age

    # Explicit copy: the filtered frame is modified in place on the next line,
    # which would otherwise be a write into a slice of the previous frame.
    df = df[df["Age"] <= first_kyr].copy()
    df["d18o"] -= df["d18o"].mean()       # centre

    return df["Age"].to_numpy(dtype=float), df["d18o"].to_numpy(dtype=float)

# -------------------------------------------------------------------------
# Spectral helpers ---------------------------------------------------------
def best_period(t, y, fap_level=0.01, min_P=500, max_P=20_000):
    """Return (period, power, fap) of the highest Lomb-Scargle peak.

    The periodogram is evaluated on astropy's automatic grid restricted to
    periods between ``min_P`` and ``max_P``.  Only interior local maxima
    (as found by ``scipy.signal.find_peaks``) are considered, so a maximum
    sitting exactly on the edge of the band is never returned.

    Parameters
    ----------
    t, y : array-like
        Sample times and values.
    fap_level : float
        Currently unused; kept so existing callers that pass it keep working.
    min_P, max_P : float
        Shortest and longest period searched (same units as ``t``).

    Returns
    -------
    tuple
        ``(period, power, fap)`` of the strongest peak, or
        ``(None, None, 1.0)`` when the periodogram has no interior peak.
    """
    f_min, f_max = 1/max_P, 1/min_P

    ls = LombScargle(t, y)
    freq, power = ls.autopower(minimum_frequency=f_min,
                               maximum_frequency=f_max,
                               samples_per_peak=10)

    peaks, _ = find_peaks(power, distance=5)          # avoid double-counting
    if peaks.size == 0:
        return None, None, 1.0

    idx_best = peaks[np.argmax(power[peaks])]

    # The false-alarm probability depends on the width of the frequency band
    # that was searched.  Pass the same limits used for the periodogram;
    # otherwise astropy falls back to its automatic-frequency heuristic and
    # the FAP refers to a different (generally much wider) band.
    fap = ls.false_alarm_probability(power[idx_best],
                                     minimum_frequency=f_min,
                                     maximum_frequency=f_max)
    return 1/freq[idx_best], power[idx_best], fap

# -------------------------------------------------------------------------
# One-harmonic Bayesian fit (identifiable parametrisation) ----------------
def fit_one_harmonic(t, y, P_init,
                     draws=1_000, tune=1_000, chains=4, ta=0.95, seed=42):
    """Fit ``a*cos(ω t) + b*sin(ω t) + c`` to ``y`` with NUTS and summarise it.

    Priors are centred on the angular frequency implied by ``P_init`` (20 %
    relative width) and scaled by the standard deviation of ``y``.

    Parameters
    ----------
    t, y : numpy.ndarray
        Sample times and values to fit.
    P_init : float
        Initial guess of the period.
    draws, tune, chains : int
        Forwarded to ``pm.sample``; at most 4 cores are used.
    ta : float
        ``target_accept`` for the sampler.
    seed : int
        ``random_seed`` for the sampler.

    Returns
    -------
    tuple
        ``(harmonic, summary, trace)`` where ``harmonic`` is the fitted
        sinusoid evaluated at ``t`` (the constant ``c`` is not included),
        ``summary`` is ``{"P", "A", "phi"}`` built from posterior means, and
        ``trace`` is the InferenceData returned by ``pm.sample``.
    """
    data_scale  = y.std()
    omega_guess = 2*np.pi / P_init

    with pm.Model():
        # Random variables are declared in the same order and under the same
        # names as before so the sampler sees an identical model.
        omega   = pm.Normal("ω", mu=omega_guess, sigma=0.2*omega_guess)
        cos_amp = pm.Normal("a",  sigma=2*data_scale)
        sin_amp = pm.Normal("b",  sigma=2*data_scale)
        offset  = pm.Normal("c",  sigma=data_scale)
        noise   = pm.HalfNormal("σ", sigma=data_scale)

        expected = (cos_amp*pm.math.cos(omega*t)
                    + sin_amp*pm.math.sin(omega*t)
                    + offset)
        pm.Normal("obs", mu=expected, sigma=noise, observed=y)

        trace = pm.sample(draws, tune=tune, chains=chains, cores=min(chains,4),
                          target_accept=ta, random_seed=seed,
                          progressbar=True, return_inferencedata=True)

    # Point estimates: posterior means over all chains and draws.
    posterior  = trace.posterior
    a_hat      = posterior["a"].mean().item()
    b_hat      = posterior["b"].mean().item()
    omega_mean = posterior["ω"].mean().item()

    # a*cos(x) + b*sin(x) == A*cos(x + phi) with A = hypot(a, b) and
    # phi = atan2(-b, a); phi is wrapped into [0, 2π).
    amplitude = np.hypot(a_hat, b_hat)
    phase     = (np.arctan2(-b_hat, a_hat)) % (2*np.pi)
    period    = (2*np.pi) / omega_mean

    harmonic = amplitude*np.cos(2*np.pi*t/period + phase)
    summary  = {"P": period, "A": amplitude, "phi": phase}
    return harmonic, summary, trace

# -------------------------------------------------------------------------
# Full extraction loop -----------------------------------------------------
def extract_harmonics_auto(t, y,
                           fap_stop=0.01, min_rel_sep=0.05,
                           max_harmonics=30, **sampler_kw):
    """Iteratively remove significant harmonics from ``y``.

    Each round finds the strongest Lomb-Scargle peak of the current
    residuals, fits a single harmonic at that period and subtracts it.

    The loop stops when any of the following holds:

    * ``max_harmonics`` harmonics have been extracted;
    * no peak is found or its false-alarm probability exceeds ``fap_stop``;
    * the strongest peak lies within a relative distance ``min_rel_sep`` of
      an already extracted period.

    Parameters
    ----------
    t, y : array-like
        Sample times and values; ``y`` is not modified.
    fap_stop : float
        Largest false-alarm probability still treated as significant.
    min_rel_sep : float
        Minimum relative separation ``|P0 - P| / P0`` from every accepted
        period.
    max_harmonics : int
        Upper bound on the number of extraction rounds.
    **sampler_kw
        Forwarded unchanged to ``fit_one_harmonic``.

    Returns
    -------
    tuple
        ``(model, params, residuals, traces)``: the sum of all fitted
        harmonics, a DataFrame with columns ``P``, ``A``, ``phi`` sorted by
        decreasing period (empty when nothing was extracted), the final
        residuals, and the list of per-harmonic traces.
    """
    # Private float copy: the in-place subtraction below must neither touch
    # the caller's array nor fail on an integer-typed input.
    residuals = np.array(y, dtype=float)
    harmonics, traces, params = [], [], []

    for _ in range(max_harmonics):

        P0, _power, fap = best_period(t, residuals, fap_level=fap_stop)
        if P0 is None or fap > fap_stop:
            break                                             # nothing significant left

        # The residuals are unchanged if this peak is rejected, so the next
        # round would find exactly the same peak again.  Stop instead of
        # recomputing an identical periodogram for every remaining round.
        if any(abs(P0 - p["P"])/P0 < min_rel_sep for p in params):
            break

        harm, p, tr = fit_one_harmonic(t, residuals, P0, **sampler_kw)
        residuals  -= harm
        harmonics.append(harm)
        traces.append(tr)
        params.append(p)

    model = np.sum(harmonics, axis=0) if harmonics else np.zeros_like(residuals)

    # Explicit columns keep ``sort_values("P")`` valid when nothing was found.
    table = (pd.DataFrame(params, columns=["P", "A", "phi"])
               .sort_values("P", ascending=False))
    return model, table, residuals, traces

# -------------------------------------------------------------------------
# Diagnostics & plotting ---------------------------------------------------
def show_results(t, y, model, residuals):
    """Plot the data, the summed harmonic model and the residuals.

    All three series are drawn on one figure against ``t/1e3`` and the
    figure is shown immediately.

    Parameters
    ----------
    t : numpy.ndarray
        Sample times; divided by 1e3 for the x axis.
    y, model, residuals : array-like
        Series to draw, each the same length as ``t``.
    """
    t_scaled = t/1e3

    # (series, line style) in drawing order; order also fixes legend order.
    curves = (
        (y,         dict(lw=0.4, color="0.3", label="data (δ¹⁸O ‰)")),
        (model,     dict(lw=1.2, label="sum of harmonics")),
        (residuals, dict(lw=0.4, label="residuals")),
    )

    plt.figure(figsize=(9, 6))
    for series, style in curves:
        plt.plot(t_scaled, series, **style)

    plt.xlabel("time (kyr BP)")
    plt.legend()
    plt.grid(alpha=.3)
    plt.show()

# -------------------------------------------------------------------------
# Example run --------------------------------------------------------------
if __name__ == "__main__":
    # Load the record, truncated at 30_000 on the shifted age axis.
    data_path = "./data/d18O NGRIP 21.04.24.xlsx"
    t, y = load_data(data_path, first_kyr=30_000)

    # Sampler settings forwarded to fit_one_harmonic for every harmonic.
    # 8 chains are requested; fit_one_harmonic caps the core count at 4.
    sampler_settings = dict(draws=1_500, tune=1_000, chains=8, ta=0.9)

    model, params, res, traces = extract_harmonics_auto(
        t, y, fap_stop=0.01, **sampler_settings)

    # Table of fitted periods, amplitudes and phases, longest period first.
    print(params)
    show_results(t, y, model, res)


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [ω, a, b, c, σ]
Sampling 8 chains for 1_000 tune and 1_500 draw iterations (8_000 + 12_000 draws total) took 21 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [ω, a, b, c, σ]
Sampling 8 chains for 1_000 tune and 1_500 draw iterations (8_000 + 12_000 draws total) took 27 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  