In [None]:
import warnings
warnings.filterwarnings("ignore", message="invalid value encountered in scalar subtract")

import pandas as pd
import numpy as np
import copy
import matplotlib.pyplot as plt
import emcee
import corner
import os
from mpl_toolkits.mplot3d import Axes3D  # for 3D plotting
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
from matplotlib.colors import LogNorm
from sklearn.cluster import DBSCAN  # for clustering the MCMC samples
import plotly.express as px

###############################################################################
# DBSCAN cluster function (unchanged)
###############################################################################
def cluster_mcmc_samples(chain_samples, eps=0.1, min_samples=30):
    """Group samples with DBSCAN and summarise every cluster found.

    Parameters:
      chain_samples: 2-D numpy array, one sample per row.
      eps, min_samples: forwarded unchanged to ``DBSCAN``.

    Returns:
      A list with one dict per cluster holding the per-column "mean" and
      "std" of its members, the member count "n_samples", and "indices",
      the row positions of the members inside ``chain_samples``.
      Points labelled -1 (noise) by DBSCAN are not part of any cluster.
    """
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(chain_samples).labels_

    def summarise(label):
        # Row positions of every sample carrying this cluster label.
        member_idx = np.flatnonzero(labels == label)
        members = chain_samples[member_idx]
        return {
            "mean": np.mean(members, axis=0),
            "std": np.std(members, axis=0),
            "n_samples": len(members),
            "indices": member_idx,
        }

    # Label -1 marks noise points; they are skipped.
    return [summarise(label) for label in set(labels) if label != -1]

###############################################################################
# Standard Gelman-Rubin convergence diagnostic for a set of chains.
###############################################################################
def gelman_rubin(chains):
    """
    Computes the Gelman-Rubin statistic for a set of chains.

    Input:
      chains: numpy array with shape (n_steps, n_walkers, n_dim)
    Returns:
      rhat: numpy array of length n_dim with the R̂ values.
            An entry is inf or nan when the within-chain variance of that
            parameter is zero (e.g. every chain is constant), and nan when
            there are fewer than two steps or fewer than two walkers.
    Raises:
      ValueError: if ``chains`` is not 3-dimensional.
    """
    chains = np.asarray(chains, dtype=float)
    if chains.ndim != 3:
        raise ValueError(
            f"chains must have shape (n_steps, n_walkers, n_dim); got {chains.shape}"
        )
    n_steps = chains.shape[0]

    # All parameters are handled at once: reductions over axis 0 run along
    # the steps of each walker, the remaining reductions run across walkers.
    with np.errstate(divide="ignore", invalid="ignore"):
        chain_means = chains.mean(axis=0)                    # (n_walkers, n_dim)
        between = n_steps * chain_means.var(axis=0, ddof=1)  # B
        within = chains.var(axis=0, ddof=1).mean(axis=0)     # W
        var_hat = ((n_steps - 1) / n_steps) * within + between / n_steps
        return np.sqrt(var_hat / within)

###############################################################################
# Gelman-Rubin diagnostic computed on samples in one cluster
###############################################################################
def gelman_rubin_cluster(chain_post, cluster_indices, n_walkers):
    """
    R̂ restricted to the samples of a single cluster.

    ``cluster_indices`` index the row-major flattening of ``chain_post``
    (i.e. ``chain_post.reshape(-1, n_dim)``), so flat index ``i`` maps to
    step ``i // walkers`` and walker ``i % walkers``. The samples are
    regrouped per walker, walkers contributing fewer than two samples are
    dropped, and every remaining walker is cut to the length of the
    shortest one before ``gelman_rubin`` is applied.

    Parameters:
      chain_post: numpy array of shape (n_steps, n_walkers, n_dim)
      cluster_indices: 1D array of flattened indices (from DBSCAN clustering)
      n_walkers: accepted for interface compatibility; the walker count
                 actually used is ``chain_post.shape[1]``.

    Returns:
      rhat: numpy array of length n_dim, or None when fewer than two
            walkers have at least two samples in the cluster.
    """
    _, walker_count, _ = chain_post.shape

    # One bucket per walker, filled in the order the indices are given.
    per_walker = [[] for _ in range(walker_count)]
    for flat_idx in cluster_indices:
        step_no, walker_no = divmod(flat_idx, walker_count)
        per_walker[walker_no].append(chain_post[step_no, walker_no, :])

    usable = [np.array(rows) for rows in per_walker if len(rows) >= 2]
    if len(usable) < 2:
        # Not enough chains in this cluster for a between-chain variance.
        return None

    shortest = min(len(rows) for rows in usable)
    # Resulting shape: (shortest, n_usable_walkers, n_dim)
    stacked = np.stack([rows[:shortest] for rows in usable], axis=1)
    return gelman_rubin(stacked)

###############################################################################
# Plot predicted vs. actual
###############################################################################
def plot_predicted_vs_actual(y_true, y_pred, model_name, target_names):
    """Show one predicted-vs-actual scatter figure per target.

    ``y_true`` and ``y_pred`` are indexed as 2-D arrays whose column ``i``
    belongs to ``target_names[i]``. Each figure also carries a dashed red
    identity line spanning the range of the actual values.
    """
    for col_idx, target in enumerate(target_names):
        actual = y_true[:, col_idx]
        predicted = y_pred[:, col_idx]
        lo, hi = min(actual), max(actual)

        plt.figure()
        plt.scatter(actual, predicted, alpha=0.7)
        plt.plot([lo, hi], [lo, hi], 'r--')
        plt.xlabel(f"Actual {target}")
        plt.ylabel(f"Predicted {target}")
        plt.title(f"{model_name} - Predicted vs. Actual: {target}")
        plt.show()

###############################################################################
# Get bin mask
###############################################################################
def get_bin_mask(df_all, period_bin, ecc_bin):
    """Boolean row mask selecting one (period-ratio, eccentricity) bin.

    The period bin is chosen from the "ratio" column and the eccentricity
    bin from the "out_ecc" column. Labels "p1".."p5" use closed intervals;
    any other period label selects 2.5 < ratio <= 2.8. "e1" and "e2" are
    explicit; any other eccentricity label selects the top bin, whose upper
    edge is log10(0.08) (so "out_ecc" is presumably stored as log10 of the
    eccentricity).
    """
    ratio = df_all["ratio"]

    closed_ratio_bins = (
        ("p1", 1.9, 1.99),
        ("p2", 2.01, 2.1),
        ("p3", 1.4, 1.49),
        ("p4", 1.51, 1.6),
        ("p5", 2.2, 2.5),
    )
    for label, low, high in closed_ratio_bins:
        if period_bin == label:
            period_mask = (ratio >= low) & (ratio <= high)
            break
    else:
        # Fallback bin: open at the lower edge so 2.5 itself stays in "p5".
        period_mask = (ratio > 2.5) & (ratio <= 2.8)

    ecc = df_all["out_ecc"]
    if ecc_bin == "e1":
        ecc_mask = (ecc <= -2.0)
    elif ecc_bin == "e2":
        ecc_mask = (ecc > -2.0) & (ecc <= -1.5228)
    else:
        ecc_mask = (ecc > -1.5228) & (ecc <= np.log10(0.08))

    return period_mask & ecc_mask

###############################################################################
# Build ModelB for a bin
###############################################################################
def build_modelB_for_bin(df_all, period_bin, ecc_bin):
    """Train and evaluate Model B (TTV observables -> outer planet) for one bin.

    Rows are restricted to the requested bin (via ``get_bin_mask``) and to
    ``AmpP1 <= 500``; rows with a missing feature or target are dropped.
    A random forest is fitted on standard-scaled features using an 80/20
    train/validation split, validation scatter plots are shown, and
    per-target metrics are printed.

    Returns:
      (model, scaler, metrics) where ``metrics`` maps "MAE", "RMSE", "MAFE"
      and "R2" to their mean over the two targets, or (None, None, None)
      when the bin is empty or keeps fewer than 5 complete rows.
    """
    print(f"\n=== build_modelB_for_bin => period='{period_bin}', eccentricity='{ecc_bin}'")
    feature_cols = ["star_m", "inn_p", "inn_m", "inn_ecc", "inn_inc", "inn_omega", "AmpP1", "DomP1"]
    target_cols = ["out_m", "out_p"]
    nothing = (None, None, None)

    in_bin = get_bin_mask(df_all, period_bin, ecc_bin)
    amp_ok = df_all["AmpP1"] <= 500
    subset = df_all[in_bin & amp_ok].copy()
    if subset.empty:
        print("No data for this bin.")
        return nothing

    subset = subset.dropna(subset=(feature_cols + target_cols))
    if len(subset) < 5:
        print("Not enough data after dropping NA for this bin.")
        return nothing

    X_tr, X_val, y_tr, y_val = train_test_split(
        subset[feature_cols], subset[target_cols], test_size=0.2, random_state=42
    )

    scaler = StandardScaler()
    X_tr_scaled = scaler.fit_transform(X_tr)
    X_val_scaled = scaler.transform(X_val)

    forest = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)
    forest.fit(X_tr_scaled, y_tr)
    val_pred = forest.predict(X_val_scaled)

    plot_predicted_vs_actual(y_val.values, val_pred,
                             f"Model B ({period_bin}, {ecc_bin})", target_cols)

    # Collect each metric per target; the returned value is the mean over targets.
    per_target = {"MAE": [], "RMSE": [], "MAFE": [], "R2": []}
    for col_idx, name in enumerate(target_cols):
        truth = y_val[name]
        guess = val_pred[:, col_idx]
        mae = mean_absolute_error(truth, guess)
        r2 = r2_score(truth, guess)
        rmse = np.sqrt(mean_squared_error(truth, guess))
        # The 1e-8 offset keeps the fractional error finite for a zero target.
        mafe = np.mean(np.abs((truth - guess) / (truth + 1e-8)))
        print(f"ModelB({period_bin}, {ecc_bin}) - {name}: MAE={mae:.4f}, RMSE={rmse:.4f}, MAFE={mafe:.4f}, R2={r2:.4f}")
        per_target["MAE"].append(mae)
        per_target["RMSE"].append(rmse)
        per_target["MAFE"].append(mafe)
        per_target["R2"].append(r2)

    error_metrics_B = {key: np.mean(values) for key, values in per_target.items()}
    return forest, scaler, error_metrics_B

###############################################################################
# Build ModelA for a bin
###############################################################################
def build_modelA_for_bin(df_all, period_bin, ecc_bin):
    """Train and evaluate Model A (system parameters -> TTV observables) for one bin.

    Rows are restricted to the requested bin (via ``get_bin_mask``) and to
    ``AmpP1 <= 500``; rows with a missing feature or target are dropped.
    A random forest is fitted on standard-scaled features using an 80/20
    train/validation split, validation scatter plots are shown, and
    per-target metrics are printed.

    Parameters:
      df_all: DataFrame holding the feature columns, the targets "AmpP1" and
              "DomP1", and the "ratio" column used for binning.
      period_bin, ecc_bin: bin labels understood by ``get_bin_mask``.

    Returns:
      (model, scaler, metrics) where ``metrics`` maps "MAE", "RMSE", "MAFE"
      and "R2" to their mean over the two targets, or (None, None, None)
      when the bin has fewer than 5 rows before or after dropping rows with
      missing values.
    """
    print(f"\n=== build_modelA_for_bin => period='{period_bin}', eccentricity='{ecc_bin}' ===")
    featA = ["star_m", "inn_p", "inn_m", "inn_ecc", "inn_inc", "inn_omega",
             "out_p", "out_m", "out_ecc", "out_inc", "out_omega"]
    targA = ["AmpP1", "DomP1"]

    mask_bin = get_bin_mask(df_all, period_bin, ecc_bin)
    mask_amp = df_all["AmpP1"] <= 500
    df = df_all[mask_bin & mask_amp].copy()
    if df.empty or len(df) < 5:
        print("Not enough data for this bin.")
        return None, None, None

    df = df.dropna(subset=(featA + targA))
    # The size check must be repeated here: dropping incomplete rows can
    # leave too few samples for a train/validation split (same rule as
    # build_modelB_for_bin applies after its dropna).
    if df.empty or len(df) < 5:
        print("Not enough data after dropping NA for this bin.")
        return None, None, None

    X_tr, X_val, y_tr, y_val = train_test_split(df[featA], df[targA],
                                                test_size=0.2, random_state=42)
    scA = StandardScaler()
    X_tr_sc = scA.fit_transform(X_tr)
    X_val_sc = scA.transform(X_val)
    rfA = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)
    rfA.fit(X_tr_sc, y_tr)
    y_pred_val = rfA.predict(X_val_sc)
    plot_predicted_vs_actual(y_val.values, y_pred_val,
                             f"Model A ({period_bin}, {ecc_bin})", targA)

    # Collect each metric per target; the returned value is the mean over targets.
    per_target = {"MAE": [], "RMSE": [], "MAFE": [], "R2": []}
    for i, col_ in enumerate(targA):
        truth = y_val[col_]
        guess = y_pred_val[:, i]
        mae_ = mean_absolute_error(truth, guess)
        r2_ = r2_score(truth, guess)
        rmse_ = np.sqrt(mean_squared_error(truth, guess))
        # The 1e-8 offset keeps the fractional error finite for a zero target.
        mafe_ = np.mean(np.abs((truth - guess) / (truth + 1e-8)))
        print(f"ModelA({period_bin}, {ecc_bin}) - {col_}: MAE={mae_:.4f}, RMSE={rmse_:.4f}, MAFE={mafe_:.4f}, R2={r2_:.4f}")
        per_target["MAE"].append(mae_)
        per_target["RMSE"].append(rmse_)
        per_target["MAFE"].append(mafe_)
        per_target["R2"].append(r2_)

    error_metrics_A = {key: np.mean(values) for key, values in per_target.items()}
    return rfA, scA, error_metrics_A

###############################################################################
# LOG-PROB for MCMC
###############################################################################
def log_prob_3params_layer1(theta_, bin_name, param_fixed, modelA, scA, actual_amp, actual_freq):
    """Log-probability of one walker position for the 3-parameter outer-planet fit.

    Parameters:
      theta_: sequence (out_p, out_e, out_m); out_e is in log scale.
      bin_name: eccentricity bin label; "e1" and "e2" are explicit, any
                other value is treated as the top bin.
      param_fixed: dict with the fixed Model A inputs ("star_m", "inn_p",
                   "inn_m", "inn_ecc", "inn_inc", "inn_omega", "out_inc",
                   "out_omega") and the prior limits "period_prior_min/max",
                   "period_bin_min/max", "mass_prior_min/max". It is only
                   read, never modified.
      modelA, scA: regressor and scaler exposing ``predict`` / ``transform``.
      actual_amp, actual_freq: observed values compared against Model A's
                               two outputs, in that order.

    Returns:
      -inf when any prior bound is violated, otherwise -0.5 * SSE where
      SSE is the summed squared difference between Model A's prediction
      and (actual_amp, actual_freq).
    """
    out_p = theta_[0]
    out_e = theta_[1]  # in log scale
    out_m = theta_[2]

    # All prior checks come first and only read param_fixed, so rejected
    # proposals cost no copying and no model evaluation.

    # Period prior
    if out_p < param_fixed["period_prior_min"] or out_p > param_fixed["period_prior_max"]:
        return -np.inf
    # Period ratio must lie inside the selected period bin
    period_ratio = out_p / param_fixed["inn_p"]
    if period_ratio < param_fixed["period_bin_min"] or period_ratio > param_fixed["period_bin_max"]:
        return -np.inf
    # e must be <= 1 => out_e <= 0
    if out_e > 0.0:
        return -np.inf
    if bin_name == "e1":
        if out_e >= -2.0:
            return -np.inf
    elif bin_name == "e2":
        if out_e < -2.0 or out_e > -1.5228:
            return -np.inf
    else:
        if out_e <= -1.5228 or out_e > np.log10(0.08):
            return -np.inf
    # Mass prior
    if out_m < param_fixed["mass_prior_min"] or out_m > param_fixed["mass_prior_max"]:
        return -np.inf

    # SSE from modelA. A shallow copy is enough: only three top-level keys
    # are overridden and the caller's dict stays untouched.
    featA = ["star_m", "inn_p", "inn_m", "inn_ecc", "inn_inc", "inn_omega",
             "out_p", "out_m", "out_ecc", "out_inc", "out_omega"]
    values = dict(param_fixed)
    values["out_p"] = out_p
    values["out_ecc"] = out_e
    values["out_m"] = out_m
    # Named columns are kept so a scaler fitted on a DataFrame sees the
    # same feature names it was trained with.
    X_dfA = pd.DataFrame([[values[name] for name in featA]], columns=featA)
    pred_ = modelA.predict(scA.transform(X_dfA))[0]
    amp_diff = pred_[0] - actual_amp
    freq_diff = pred_[1] - actual_freq
    return -0.5 * (amp_diff**2 + freq_diff**2)

###############################################################################
# MCMC function with per-cluster Gelman-Rubin convergence diagnostics and
# computation of an evidence estimate (lnZ) for Bayes factor comparisons.
###############################################################################
# Column order of the Model A inputs built in this section.
_LAYER1_FEATURES_A = ["star_m", "inn_p", "inn_m", "inn_ecc", "inn_inc", "inn_omega",
                      "out_p", "out_m", "out_ecc", "out_inc", "out_omega"]

# Period-ratio limits per period-bin label; any other label uses (2.5, 2.8).
_LAYER1_PERIOD_LIMITS = {
    "p1": (1.9, 1.99),
    "p2": (2.01, 2.1),
    "p3": (1.4, 1.49),
    "p4": (1.51, 1.6),
    "p5": (2.2, 2.5),
}


def _layer1_predict(param_fixed, outer_rows, modelA, scA):
    """Run Model A on rows of (out_p, log10 out_ecc, out_m); returns an (n, 2) array."""
    outer_rows = np.atleast_2d(np.asarray(outer_rows, dtype=float))
    n_rows = len(outer_rows)
    data = {name: np.full(n_rows, param_fixed[name], dtype=float)
            for name in _LAYER1_FEATURES_A
            if name not in ("out_p", "out_ecc", "out_m")}
    data["out_p"] = outer_rows[:, 0]
    data["out_ecc"] = outer_rows[:, 1]
    data["out_m"] = outer_rows[:, 2]
    frame = pd.DataFrame(data, columns=_LAYER1_FEATURES_A)
    return np.asarray(modelA.predict(scA.transform(frame)))


def _layer1_sse(pred, actual_amp, actual_freq):
    """Summed squared error between one (amp, freq) prediction and the observed pair."""
    return (pred[0] - actual_amp)**2 + (pred[1] - actual_freq)**2


def _layer1_floor_log10(e_lin):
    """log10 of a linear eccentricity, floored at -9.0 for values <= 1e-12."""
    return -9.0 if e_lin <= 1e-12 else np.log10(e_lin)


def _layer1_seed_axis(center, lo, hi, n_cluster, n_uniform):
    """Start values for one axis: a clipped N(center, 0.1) group plus a uniform group pinned to both edges."""
    clustered = np.clip(center + 0.1*np.random.randn(n_cluster), lo, hi)
    if n_uniform >= 2:
        # Two walkers sit exactly on the interval ends, the rest are uniform.
        spread = np.concatenate(([lo, hi], np.random.uniform(lo, hi, size=n_uniform-2)))
    else:
        spread = np.random.uniform(lo, hi, size=n_uniform)
    seeds = np.concatenate([clustered, spread])
    np.random.shuffle(seeds)
    return seeds


def do_mcmc_3param_layer1(period_bin, bin_name, old_param,
                          start_p, p_std, start_m, start_e,
                          modelA, scA,
                          actual_amp, actual_freq,
                          nsteps=400, nwalkers=30,
                          stretch_scale=2.0):
    """Sample (out_p, log10 out_e, out_m) with emcee for one bin and summarise the result.

    Parameters:
      period_bin: period-ratio bin label ("p1".."p5"; anything else uses 2.5..2.8).
      bin_name: eccentricity bin label ("e1", "e2"; anything else is the top bin).
      old_param: dict of fixed system parameters (must contain the Model A
                 inputs other than out_p/out_ecc/out_m). It is deep-copied,
                 not modified.
      start_p, p_std: centre and scale of the period prior (start_p ± 3*p_std).
      start_m: mass guess; the mass prior is [start_m/3, min(3*start_m, 50)],
               with the lower edge floored at 0.
      start_e: log-scale eccentricity guess used to seed walkers.
      modelA, scA: regressor and scaler exposing ``predict`` / ``transform``.
      actual_amp, actual_freq: observed values the prediction is compared to.
      nsteps, nwalkers: chain length and ensemble size. The first half of
                        the chain is discarded and the rest thinned by 2.
      stretch_scale: scale ``a`` of the emcee stretch move; must be > 1.

    Returns:
      (solutions, lnZ): ``solutions`` is a list with one dict per DBSCAN
      cluster (or a single dict for the whole chain when no cluster is
      found) holding "final_param_dict", "p_mean", "e_mean", "mass_mean",
      "ttvAmp", "ttvDomP" (each a (value, std) pair) and "SSE"; ``lnZ`` is
      a harmonic-mean estimate of the log evidence.

    Raises:
      ValueError: if ``stretch_scale`` is not greater than 1.
    """
    if stretch_scale <= 1.0:
        # With a == 1 the stretch factor is always exactly 1, i.e. every
        # proposal equals the walker's current position and nothing moves.
        raise ValueError("stretch_scale must be > 1")

    param_fixed = copy.deepcopy(old_param)
    bin_lo, bin_hi = _LAYER1_PERIOD_LIMITS.get(period_bin, (2.5, 2.8))
    param_fixed["period_bin_min"] = bin_lo
    param_fixed["period_bin_max"] = bin_hi
    param_fixed["period_prior_min"] = start_p - 3.0 * p_std
    param_fixed["period_prior_max"] = start_p + 3.0 * p_std
    param_fixed["mass_prior_min"]   = max((1.0/3.0)*start_m, 0.0)
    param_fixed["mass_prior_max"]   = min(3.0*start_m, 50.0)

    def local_logprob(theta_):
        return log_prob_3params_layer1(theta_, bin_name, param_fixed,
                                       modelA, scA,
                                       actual_amp, actual_freq)

    ndim = 3
    # 70% of the walkers start near the guess, the rest cover the allowed range.
    n_cluster = int(np.ceil(nwalkers * 0.7))
    n_uniform = nwalkers - n_cluster

    # Initialize period (p)
    p_allowed_min = old_param["inn_p"] * bin_lo
    p_allowed_max = old_param["inn_p"] * bin_hi
    p_init = _layer1_seed_axis(start_p, p_allowed_min, p_allowed_max, n_cluster, n_uniform)

    # Initialize eccentricity (e), in log scale
    if bin_name == "e1":
        lower_e, upper_e = -9.0, -2.001
    elif bin_name == "e2":
        lower_e, upper_e = -2.0, -1.5228
    else:
        lower_e, upper_e = -1.5228, np.log10(0.08)
    e_init = _layer1_seed_axis(start_e, lower_e, upper_e, n_cluster, n_uniform)

    # Initialize mass (m): uniform within ±4 of the guess, clipped to the prior
    m_cluster = start_m + (np.random.rand(n_cluster)-0.5)*8
    m_cluster = np.clip(m_cluster, param_fixed["mass_prior_min"], param_fixed["mass_prior_max"])
    m_uniform = np.random.uniform(param_fixed["mass_prior_min"], param_fixed["mass_prior_max"], size=n_uniform)
    m_init = np.concatenate([m_cluster, m_uniform])
    np.random.shuffle(m_init)

    pos0 = np.column_stack((p_init, e_init, m_init))
    print("\n=== do_mcmc_3param_layer1 => initial setup ===")
    print(f"  period_bin={period_bin}, e-bin={bin_name}")
    print(f"  start_p={start_p:.6f}, p_std={p_std:.6f}, start_m={start_m:.6f}, start_e={start_e:.6f}")
    print(f"  period_prior=[{param_fixed['period_prior_min']:.6f}..{param_fixed['period_prior_max']:.6f}]")
    print(f"  period_bin_range=[{param_fixed['period_bin_min']:.6f}..{param_fixed['period_bin_max']:.6f}]")
    print(f"  mass_prior=[{param_fixed['mass_prior_min']:.6f}..{param_fixed['mass_prior_max']:.6f}]")
    for i, pos_ in enumerate(pos0):
        print(f"    walker {i}: p={pos_[0]:.6f}, e={pos_[1]:.6f}, m={pos_[2]:.6f}")

    # Evaluate the log-probability at each initial walker position
    log_probs = np.array([local_logprob(pos) for pos in pos0])
    print("Initial log probabilities:", log_probs)
    if not np.all(np.isfinite(log_probs)):
        print("Warning: Some initial log probabilities are -inf. Adjust your priors or initial guesses.")

    # The stretch scale is passed through an explicit move object rather
    # than the sampler's `a` keyword.
    sampler = emcee.EnsembleSampler(nwalkers, ndim, local_logprob,
                                    moves=emcee.moves.StretchMove(a=stretch_scale))
    sampler.run_mcmc(pos0, nsteps, progress=True)

    # Plot chain traces (convert eccentricity to linear scale for plotting)
    chain_plot = sampler.get_chain().copy()
    chain_plot[:, :, 1] = 10**(chain_plot[:, :, 1])
    labels_3 = ["out_p", "out_e", "out_m"]
    fig_chain, axes_chain = plt.subplots(ndim, 1, figsize=(6, 6), sharex=True)
    for d in range(ndim):
        axes_chain[d].plot(chain_plot[:, :, d], alpha=0.5)
        axes_chain[d].set_ylabel(labels_3[d])
    axes_chain[-1].set_xlabel("Step")
    plt.tight_layout()
    plt.show()

    # Post burn-in chain in 3D: shape (n_steps_post, n_walkers, ndim)
    chain_post = sampler.get_chain(discard=nsteps//2, thin=2)
    # chain_flat keeps eccentricity in log scale (what Model A consumes);
    # chain_lin holds the same rows with eccentricity converted to linear
    # scale, which is what clustering, summary statistics and plots use.
    chain_flat = chain_post.reshape(-1, ndim).copy()
    chain_lin = chain_flat.copy()
    chain_lin[:, 1] = 10**(chain_lin[:, 1])

    # SSE at the overall posterior mean. The mean eccentricity is taken in
    # linear scale and converted back to log scale for Model A.
    overall_mean = np.mean(chain_lin, axis=0)
    overall_row = [overall_mean[0], _layer1_floor_log10(overall_mean[1]), overall_mean[2]]
    overall_pred = _layer1_predict(param_fixed, [overall_row], modelA, scA)[0]
    print(f"Chain plot SSE = {_layer1_sse(overall_pred, actual_amp, actual_freq):.6f}")

    # Cluster the flattened chain using DBSCAN
    clusters = cluster_mcmc_samples(chain_lin, eps=0.4, min_samples=500)
    print("\n=== DBSCAN clustering results ===")
    for iC, c_ in enumerate(clusters):
        print(f"Cluster {iC}: {c_['n_samples']} samples")
        # Gelman-Rubin for this cluster separately
        rhat_cluster = gelman_rubin_cluster(chain_post, c_["indices"], nwalkers)
        if rhat_cluster is not None:
            print(f"  Cluster {iC} Gelman-Rubin R̂: {rhat_cluster}")
        else:
            print(f"  Cluster {iC}: Not enough samples per walker to compute R̂ reliably.")
    if not clusters:
        print("No clusters => fallback to overall chain mean.")
        rhat_all = gelman_rubin(chain_post)
        print(f"Overall chain Gelman-Rubin R̂: {rhat_all}")
        # Same convention as the DBSCAN clusters: statistics over the
        # linear-eccentricity samples.
        clusters = [{"mean": np.mean(chain_lin, axis=0),
                     "std": np.std(chain_lin, axis=0),
                     "n_samples": len(chain_lin),
                     "indices": np.arange(len(chain_lin))}]

    solutions = []
    for c_ in clusters:
        p_mean, mean_lin_e, m_mean = c_["mean"][0], c_["mean"][1], c_["mean"][2]
        p_std_, std_lin_e, m_std_ = c_["std"][0], c_["std"][1], c_["std"][2]
        mean_log_e = _layer1_floor_log10(mean_lin_e)
        if mean_lin_e <= 1e-12:
            std_log_e = 0.0
        else:
            # First-order propagation of the linear std into log10 space.
            std_log_e = std_lin_e / (mean_lin_e*np.log(10))

        # TTV predictions for every sample in the cluster, in one batch.
        ttv_preds = _layer1_predict(param_fixed, chain_flat[c_["indices"]], modelA, scA)
        predAmp_mean = np.mean(ttv_preds[:, 0])
        predAmp_std  = np.std(ttv_preds[:, 0])
        predDomP_mean = np.mean(ttv_preds[:, 1])
        predDomP_std = np.std(ttv_preds[:, 1])

        newD_final = copy.deepcopy(param_fixed)
        newD_final["out_p"]   = p_mean
        newD_final["out_ecc"] = mean_log_e
        newD_final["out_m"]   = m_mean
        newD_final["predAmp"] = predAmp_mean
        newD_final["predDomP"] = predDomP_mean

        # SSE of Model A evaluated at the cluster mean
        mean_pred = _layer1_predict(param_fixed, [[p_mean, mean_log_e, m_mean]], modelA, scA)[0]
        sse_ = _layer1_sse(mean_pred, actual_amp, actual_freq)

        solutions.append({"final_param_dict": newD_final,
                          "p_mean": (p_mean, p_std_),
                          "e_mean": (mean_log_e, std_log_e),
                          "mass_mean": (m_mean, m_std_),
                          "ttvAmp": (predAmp_mean, predAmp_std),
                          "ttvDomP": (predDomP_mean, predDomP_std),
                          "SSE": sse_})

    print("\nConvergence solutions for this bin combo:")
    for iSol, sol_ in enumerate(solutions):
        p_val, p_err = sol_["p_mean"]
        m_val, m_err = sol_["mass_mean"]
        e_log, e_err = sol_["e_mean"]
        e_val = 10**(e_log)
        e_err_lin = 10**(e_log + e_err) - 10**(e_log) if e_val>1e-12 else 0.0
        ttv_amp, ttv_amp_err = sol_["ttvAmp"]
        ttv_domP, ttv_domP_err = sol_["ttvDomP"]
        sse_ = sol_["SSE"]
        print(f"  Solution {iSol}:")
        print(f"    Outer Period = {p_val:.6f} ± {p_err:.6f}")
        print(f"    Outer Mass   = {m_val:.6f} ± {m_err:.6f}")
        print(f"    Outer Ecc (linear) = {e_val:.6f} ± {e_err_lin:.6f}")
        print(f"    Predicted TTV Amp   = {ttv_amp:.6f} ± {ttv_amp_err:.6f}")
        print(f"    Predicted TTV DomP  = {ttv_domP:.6f} ± {ttv_domP_err:.6f}")
        print(f"    SSE = {sse_:.6f}\n")

    # Bayesian evidence via the harmonic mean estimator, evaluated in log
    # space: ln Z = ln N - logsumexp(-ln L). Exponentiating ln L directly
    # underflows to 0 for poor fits and turns the estimate into -inf.
    log_prob_post = sampler.get_log_prob(discard=nsteps//2, thin=2, flat=True)
    log_prob_post = log_prob_post[np.isfinite(log_prob_post)]
    if len(log_prob_post) > 0:
        lnZ = np.log(len(log_prob_post)) - np.logaddexp.reduce(-log_prob_post)
    else:
        lnZ = -np.inf
    print(f"Estimated log evidence (lnZ) = {lnZ:.6f}")

    # Corner plots for visualization. All three figures show the post
    # burn-in samples with eccentricity in linear scale.
    p_min, p_max = np.min(chain_lin[:, 0]), np.max(chain_lin[:, 0])
    m_min, m_max = np.min(chain_lin[:, 2]), np.max(chain_lin[:, 2])
    pad_p = 0.1*(p_max-p_min)
    pad_m = 0.1*(m_max-m_min)
    full_range = [(p_min-pad_p, p_max+pad_p),
                  (np.min(chain_lin[:, 1]), np.max(chain_lin[:, 1])),
                  (m_min-pad_m, m_max+pad_m)]

    corner.corner(chain_lin,
                  labels=labels_3,
                  bins=50,
                  range=full_range,
                  color='red',
                  show_titles=True,
                  title_fmt=".3f")
    plt.title("Corner Plot: All Samples")
    plt.show()

    corner.corner(chain_lin,
                  labels=labels_3,
                  bins=50,
                  range=full_range,
                  plot_density=True,
                  plot_datapoints=False,
                  smooth=1,
                  smooth1d=1,
                  fill_contours=True,
                  contour_kwargs={},
                  color=None,
                  show_titles=True,
                  title_fmt=".3f")
    plt.show()

    corner.corner(chain_lin,
                  labels=labels_3,
                  bins=50,
                  range=full_range,
                  color='blue',
                  show_titles=True,
                  title_fmt=".3f")
    plt.title("Corner Plot: Post Burn-In")
    plt.show()

    return solutions, lnZ

###############################################################################
# SINGLE-LAYER pipeline
###############################################################################
def run_singlelayer_pipeline_for_system(sysdict, modelB_map, modelA_map):
    """Run the single-layer MCMC for every (period bin, eccentricity bin) of one system.

    Parameters:
      sysdict: dict of system parameters; it is deep-copied, not modified.
               Must provide the Model A/B input columns plus "AmpP1" and
               "DomP1" for every bin that is actually sampled.
      modelB_map: maps (period_bin, ecc_bin) -> (model, scaler, metrics) for
                  Model B. Used to seed the outer mass and period; when the
                  bin has no Model B a default guess is used instead.
      modelA_map: maps (period_bin, ecc_bin) -> (model, scaler, metrics) for
                  Model A, which the MCMC likelihood requires.

    Returns:
      A list with one dict per sampled bin holding "bins", "solutions",
      "lnZ", "errA" and "errB". Bins without a Model A are reported on
      stdout and left out, because the likelihood cannot be evaluated
      for them.
    """
    baseParam = copy.deepcopy(sysdict)
    period_bins = ["p1", "p2", "p3", "p4", "p5", "p6"]
    ecc_bins = ["e1", "e2", "e3"]
    # Default period-ratio guess per bin when Model B is missing (else 2.65).
    default_ratio = {"p1": 1.95, "p2": 2.05, "p3": 1.445, "p4": 1.555, "p5": 2.35}
    # Starting log-scale eccentricity per bin (else -1.3).
    default_log_e = {"e1": -2.5, "e2": -1.8}
    featB = ["star_m", "inn_p", "inn_m", "inn_ecc", "inn_inc", "inn_omega", "AmpP1", "DomP1"]

    combos_final = []
    for pbin in period_bins:
        for ebin in ecc_bins:
            (rfA, scA, errA) = modelA_map.get((pbin, ebin), (None, None, None))
            if rfA is None or scA is None:
                # Model builders return (None, None, None) for sparse bins;
                # passing that on would crash inside the likelihood.
                print(f"Skipping bin ({pbin}, {ebin}): no Model A available.")
                continue

            (rfB, scB, errB) = modelB_map.get((pbin, ebin), (None, None, None))
            if rfB is None or scB is None:
                guess_p = baseParam["inn_p"] * default_ratio.get(pbin, 2.65)
                guess_m = 10.0
            else:
                rowB = [[baseParam[name] for name in featB]]
                X_dfB = pd.DataFrame(rowB, columns=featB)
                outB_ = rfB.predict(scB.transform(X_dfB))[0]
                guess_m = outB_[0]
                guess_p = outB_[1]

            p_std_for_mcmc = 0.05*guess_p
            if p_std_for_mcmc < 1e-6:
                # Guard against a zero or negative period guess.
                p_std_for_mcmc = max(0.5, 0.05*baseParam["inn_p"])
            guess_e = default_log_e.get(ebin, -1.3)

            sol_list, lnZ = do_mcmc_3param_layer1(
                period_bin=pbin, bin_name=ebin, old_param=baseParam,
                start_p=guess_p, p_std=p_std_for_mcmc,
                start_m=guess_m, start_e=guess_e,
                modelA=rfA, scA=scA,
                actual_amp=baseParam["AmpP1"], actual_freq=baseParam["DomP1"],
                nsteps=400, nwalkers=40
            )
            combos_final.append({"bins": (pbin, ebin), "solutions": sol_list, "lnZ": lnZ,
                                 "errA": errA, "errB": errB})
    return combos_final

###############################################################################
# 3D / scatter plot helpers (unchanged)
###############################################################################
def set_3d_axes_and_reverse_mass(ax):
    """Re-apply the current 3D limits of ``ax`` with the y axis flipped.

    The x and z limits are written back unchanged; the y limits are
    written back in swapped order, which reverses that axis.
    """
    x_lo, x_hi = ax.get_xlim3d()
    y_lo, y_hi = ax.get_ylim3d()
    z_lo, z_hi = ax.get_zlim3d()

    ax.set_xlim3d(x_lo, x_hi)
    ax.set_ylim3d(y_hi, y_lo)  # swapped on purpose: reverses the y axis
    ax.set_zlim3d(z_lo, z_hi)

def reverse_mass_axis_plotly(fig):
    """Ask ``fig`` to draw the y axis of its 3D scene in reversed order."""
    reversed_y = {"yaxis": {"autorange": "reversed"}}
    fig.update_layout(scene=reversed_y)

def three_3d_diff_plots_with_mass_reversed(X,Y,Z, title_extra=""):
    """Show up to three 3D views of the points (X, Y, Z) with the y axis reversed.

    A matplotlib scatter is always shown; a triangulated surface follows
    when there are at least three points; an interactive plotly scatter is
    shown last. ``title_extra`` is appended to every title.
    """
    def finish(axis, heading):
        # Label, draw, flip the y axis, then display the matplotlib figure.
        axis.set_xlabel("X")
        axis.set_ylabel("Y (mass reversed)")
        axis.set_zlabel("Z")
        axis.set_title(heading)
        plt.draw()
        set_3d_axes_and_reverse_mass(axis)
        plt.show()

    scatter_ax = plt.figure().add_subplot(111, projection='3d')
    scatter_ax.scatter(X, Y, Z, c='b', marker='o')
    finish(scatter_ax, f"3D Scatter {title_extra}")

    if len(X) >= 3:
        # The surface view is only drawn when at least three points exist.
        surface_ax = plt.figure().add_subplot(111, projection='3d')
        surface_ax.plot_trisurf(X, Y, Z, edgecolor='gray', linewidth=0.2, alpha=0.5)
        finish(surface_ax, f"3D Wire/TriSurf {title_extra}")

    interactive = px.scatter_3d(x=X, y=Y, z=Z, title=f"Interactive 3D {title_extra}")
    reverse_mass_axis_plotly(interactive)
    interactive.show()

def plot_3d_actual_pred_rev(p_act, m_act, e_act, p_pred, m_pred, e_pred, sys_idx):
    """
    Shows one actual point and one predicted point in (period, mass, ecc)
    space as three figures, each with the mass (Y) axis reversed.
    Input:
      p_act, m_act, e_act: coordinates of the actual point.
      p_pred, m_pred, e_pred: coordinates of the predicted point.
      sys_idx: identifier interpolated into the figure titles.
    Figures, in order: matplotlib scatter, matplotlib scatter with a gray
    line joining the two points, and an interactive plotly scatter.
    """
    periods = [p_act, p_pred]
    masses = [m_act, m_pred]
    eccs = [e_act, e_pred]

    # (draw connecting line?, title) for the two matplotlib figures.
    matplotlib_views = (
        (False, f"System {sys_idx} => 3D scatter reversed mass"),
        (True, f"System {sys_idx} => wire reversed mass"),
    )
    for join_points, heading in matplotlib_views:
        axes = plt.figure().add_subplot(111, projection='3d')
        if join_points:
            axes.plot(periods, masses, eccs, c='gray')
        axes.scatter([p_act], [m_act], [e_act], c='b', s=50, label='Actual')
        axes.scatter([p_pred], [m_pred], [e_pred], c='r', s=50, label='Predicted')
        axes.set_xlabel("Period")
        axes.set_ylabel("Mass (reversed)")
        axes.set_zlabel("Ecc")
        axes.set_title(heading)
        plt.draw()
        set_3d_axes_and_reverse_mass(axes)
        axes.legend()
        plt.show()

    point_table = pd.DataFrame({
        "Period": periods,
        "Mass": masses,
        "Ecc": eccs,
        "Type": ["Actual", "Predicted"],
    })
    interactive = px.scatter_3d(
        point_table, x="Period", y="Mass", z="Ecc", color="Type",
        symbol="Type", title=f"System {sys_idx} => interactive reversed mass",
    )
    reverse_mass_axis_plotly(interactive)
    interactive.show()

def single_scatter(paramName, dataList):
    """
    Plots predicted-vs-actual values with error bars and a y = x guide line.
    Input:
      paramName: label inserted into the axis labels and the title.
      dataList: sequence of entries indexed as entry[0] (x, labelled
                "Actual"), entry[1] (y, labelled "Pred") and entry[2]
                (y error bar).
    Returns without plotting when dataList is empty or otherwise falsy.
    """
    if not dataList:
        return

    def column(position):
        return [entry[position] for entry in dataList]

    actual = column(0)
    predicted = column(1)
    uncertainty = column(2)

    plt.figure()
    plt.errorbar(actual, predicted, yerr=uncertainty,
                 fmt='o', color='k', ecolor='red', capsize=3)
    # The guide line spans the combined range of both value sets.
    lowest = min(min(actual), min(predicted))
    highest = max(max(actual), max(predicted))
    plt.plot([lowest, highest], [lowest, highest], 'r--')
    plt.xlabel(f"Actual {paramName}")
    plt.ylabel(f"Pred {paramName}")
    plt.title(f"{paramName} => so far: {len(dataList)} systems")
    plt.show()

###############################################################################
# MAIN
###############################################################################
def _mean_error_metrics(error_dicts):
    """Average MAE, RMSE, MAFE and R2 over a non-empty list of per-bin error dicts."""
    return {metric: np.mean([err[metric] for err in error_dicts])
            for metric in ("MAE", "RMSE", "MAFE", "R2")}


def _report_error_metrics(label, error_dicts):
    """Print the averaged metrics under a heading; print nothing for an empty list."""
    if not error_dicts:
        return
    print(f"\nOverall aggregated errors for {label}:")
    print(_mean_error_metrics(error_dicts))


def main(csv_path="E:/SIMS92/ttv_dataset_multiple_params.csv"):
    """
    Runs the full workflow: load and filter the dataset, build ModelA/ModelB
    for every (period bin, eccentricity bin) pair, print aggregated error
    metrics, then run the single-layer pipeline on one hard-coded system and
    print its solutions and lnZ differences.
    Input:
      csv_path: location of the dataset CSV; defaults to the path that was
                previously hard-coded, so main() behaves as before.
    Returns:
      None. Returns early (after printing a message) when no rows survive
      the filtering step.
    """
    # Zero eccentricities are replaced by this floor before taking log10,
    # both in the dataset and in the user-provided system below.
    ecc_floor = 1e-9
    excluded_period_bins = {"p5", "p6"}

    print("=== LOADING & PREPARING MAIN DATA ===")
    df_all = pd.read_csv(csv_path)

    rename_map = {
        "Stellar Mass (Msun)": "star_m",
        "Inner Period (days)": "inn_p",
        "Inner Mass (Mearth)": "inn_m",
        "Inner Eccentricity": "inn_ecc",
        "Inner Inclination": "inn_inc",
        "Inner Omega": "inn_omega",
        "Outer Mass (Mearth)": "out_m",
        "Outer Period (days)": "out_p",
        "Outer Eccentricity": "out_ecc",
        "Outer Inclination": "out_inc",
        "Outer Omega": "out_omega",
        "Amplitude of Dominant Period Test (P1)": "AmpP1",
        "Dominant Period Planet 1": "DomP1"
    }
    df_all.rename(columns=rename_map, inplace=True, errors='ignore')

    # Filter out large masses, clamp ecc, log-transform
    df_all = df_all[df_all["out_m"] <= 50].copy()
    for ecc_col in ("inn_ecc", "out_ecc"):
        clipped = df_all[ecc_col].clip(lower=0, upper=1.0)
        df_all[ecc_col] = np.log10(clipped.replace(0, ecc_floor))

    df_all["ratio"] = df_all["out_p"] / df_all["inn_p"]

    def ratio_in_any_period_bin(r):
        return ((1.9<=r<=1.99) or (2.01<=r<=2.1) or
                (1.4<=r<=1.49) or (1.51<=r<=1.6) or
                (2.2<=r<=2.5) or (2.5<r<=2.8))

    def ecc_in_any_bin(e_log):
        # The three ranges together cover 0 <= e <= 1 on the linear scale.
        e_lin=10**(e_log)
        return ((0.0<=e_lin<=0.01) or (0.01<e_lin<=0.03) or (0.03< e_lin<=1.0))

    mask_ratio = df_all["ratio"].apply(ratio_in_any_period_bin)
    mask_ecc   = df_all["out_ecc"].apply(ecc_in_any_bin)
    df_filtered= df_all[mask_ratio & mask_ecc].copy()

    essential_cols= ["star_m","inn_p","inn_m","inn_ecc","inn_inc","inn_omega",
                     "out_m","out_p","out_ecc","out_inc","out_omega","AmpP1","DomP1","ratio"]
    df_filtered.dropna(subset=essential_cols, inplace=True)
    df_filtered.reset_index(drop=True, inplace=True)
    if len(df_filtered)==0:
        print("No data remains after filtering => abort.")
        return

    df_all= df_filtered
    print(f"Data filtered: {len(df_all)} rows remain.\n")

    # Build the 18 models (6 period bins x 3 eccentricity bins) of each kind
    df_tmp = df_all.copy()
    period_bins = ["p1","p2","p3","p4","p5","p6"]
    ecc_bins    = ["e1","e2","e3"]
    bin_keys = [(pb, eb) for pb in period_bins for eb in ecc_bins]

    print("\n=== Building ModelB for period and eccentricity bins ===")
    modelB_map={}
    for pb, eb in bin_keys:
        rfB, scB, errB = build_modelB_for_bin(df_tmp, pb, eb)
        modelB_map[(pb, eb)] = (rfB, scB, errB)
    print("\n=== Building ModelA for period and eccentricity bins ===")
    modelA_map={}
    for pb, eb in bin_keys:
        rfA, scA, errA = build_modelA_for_bin(df_tmp, pb, eb)
        modelA_map[(pb, eb)] = (rfA, scA, errA)

    def collect_errors(model_map, skip_period_bins=frozenset()):
        # Read errors straight from the map so each one stays paired with its
        # own bin key even when some bins produced no error dict (None).
        return [err for (pb, _eb), (_rf, _sc, err) in model_map.items()
                if pb not in skip_period_bins and err is not None]

    # Print overall aggregated error metrics for ModelA and ModelB (all bins)
    _report_error_metrics("Model A (all bins)", collect_errors(modelA_map))
    _report_error_metrics("Model B (all bins)", collect_errors(modelB_map))

    # Same aggregation, leaving out the p5 and p6 period bins.
    _report_error_metrics("Model A (bins excluding p5 and p6)",
                          collect_errors(modelA_map, excluded_period_bins))
    _report_error_metrics("Model B (bins excluding p5 and p6)",
                          collect_errors(modelB_map, excluded_period_bins))

    # Define a single system dictionary (using user-provided values).
    # Both eccentricities are zero, so they are stored as log10 of the floor.
    log_ecc_floor = np.log10(ecc_floor)
    user_system = {
        "star_m": 0.97,
        "inn_m": 17.3,
        "inn_p": 7.64159,
        "inn_ecc": log_ecc_floor,
        "inn_inc": 87.68,
        "inn_omega": 0.0,
        "AmpP1": 6.887061813,
        "DomP1": 34.48248077,
        "out_m": 16.4,
        "out_p": 14.85888,
        "out_ecc": log_ecc_floor,
        "out_inc": 88.07,
        "out_omega": 0.0
    }
    user_system["ratio"] = user_system["out_p"] / user_system["inn_p"]

    # Run the single-layer pipeline on the user system
    combos_18 = run_singlelayer_pipeline_for_system(user_system, modelB_map, modelA_map)

    print("\n=== RESULTS for the single user-provided system ===")
    bayes_results = {}
    for co in combos_18:
        pbin, ebin = co["bins"]
        sol_list = co["solutions"]
        lnZ = co["lnZ"]
        bayes_results[(pbin, ebin)] = lnZ
        print(f"Bin=({pbin},{ebin}), #solutions={len(sol_list)}, lnZ = {lnZ:.6f}")
        for iSol, sol_ in enumerate(sol_list):
            fd = sol_["final_param_dict"]
            p_val, p_err = sol_["p_mean"]
            m_val, m_err = sol_["mass_mean"]
            e_log, e_std = sol_["e_mean"]
            e_lin = 10**(e_log)
            # Upper one-sigma offset on the linear scale; reported as 0 when
            # the linear eccentricity is effectively zero.
            e_err_lin = 10**(e_log + e_std)- 10**(e_log) if e_lin>1e-12 else 0.0
            a_val = fd["predAmp"]
            f_val = fd["predDomP"]
            sse_  = sol_["SSE"]
            print(f"  solution {iSol}: out_p={p_val:.6f} ± {p_err:.6f}, "
                  f"out_m={m_val:.6f} ± {m_err:.6f}, e_lin={e_lin:.6f} ± {e_err_lin:.6f}, "
                  f"PredAmp={a_val:.6f}, PredDomP={f_val:.6f}, SSE={sse_:.6f}")

    # Bayes factor differences (ΔlnZ) relative to the bin with the largest lnZ
    if bayes_results:
        best_lnZ = max(bayes_results.values())
        print("\n=== Bayes Factor (ΔlnZ) Summary ===")
        for bin_key, lnZ in bayes_results.items():
            delta_lnZ = lnZ - best_lnZ
            print(f"Bin {bin_key}: lnZ = {lnZ:.6f}, ΔlnZ = {delta_lnZ:.6f}")

    print("\nDone. You have the MCMC solutions for the single system.")
    print("No final difference plots are produced since we are only analyzing one system.")
    print("If you need additional plots, you can adapt the code above accordingly.")

if __name__ == "__main__":
    # Run the workflow only when executed as a script, not when imported.
    main()