# In [None]:

"""
Unimetry Roughness Pipeline — ALL-IN-ONE (v6)

Features:
- Robust loader for Pantheon+ (whitespace .dat or CSV), auto-detects columns.
- Options: prefer zHD vs zCMB; z-min cut; optional peculiar-velocity error term.
- Baseline ΛCDM fit (Ωm, M).
- ΛCDM + piecewise-constant Δμ(z) with smoothness prior (2nd-difference); 
  prints bin counts & percentages; saves bin_summary.csv; plots with % labels.
- 1-parameter template model Δμ(z) = A f(z) with several f(z) options;
  prints A ± σ(A), AIC/BIC, Δχ²; plots Δμ(z) and residuals.
- Sky map of SN positions: HEALPix Mollweide (if healpy available) or Matplotlib fallback.

Usage:
  - Set DATA_PATH to your Pantheon+ table (e.g., Pantheon+SH0ES.dat).
  - Run: python unimetry_roughness_pipeline.py
"""

import numpy as np
import pandas as pd
import warnings
from pathlib import Path
import matplotlib.pyplot as plt

# Optional: HEALPix for CMB-style maps
try:
    import healpy as hp
    HAVE_HEALPY = True
except Exception:
    HAVE_HEALPY = False

# ====================== CONFIG ======================
# Module-level settings read by the loader, the fits and the plotting code.
DATA_PATH = Path("./Pantheon+SH0ES.dat")   # path to Pantheon+ table

# Redshift handling
PREFER_ZHD = False      # if True, look for a zHD column before zCMB; otherwise zCMB first
Z_MIN = 0.0             # redshift cut: use only z >= Z_MIN (the loader also drops z > 2.0)
ADD_PEC_ERR = False     # include peculiar-velocity error term using VPECERR if present
PEC_Z_CLIP = 0.003      # z is clipped from below to this value when computing sigma_mu_pec

# Roughness (multi-bin) model
# Bin edges for the piecewise-constant offsets; K = len(Z_BINS) - 1 bins.
Z_BINS = np.array([0.0,0.01,0.02,0.03,0.04,0.96,2.0])
LAMBDA_SMOOTH = 1.0     # smoothness strength (2nd-difference Tikhonov)
# Used as sigma_mu when no error column is found; a provided error column is
# also clipped from below to this value by the loader.
SIGMA_FLOOR = 0.12

# Distances
NINT = 600              # number of grid points in the trapezoid integral from 0 to z
SUPPRESS_WARNINGS = True  # if True, DeprecationWarning is silenced (see below)

# Sky-map config
NSIDE = 16              # HEALPix resolution (12*NSIDE^2 pixels)

# Template (1-parameter) model defaults
TEMPLATE_KIND = "z_over_zplus"   # options: z_over_zplus, ramp_exp, log1pz, power_sat
TEMPLATE_Z0   = 0.30             # shape scale
TEMPLATE_P    = 1.5              # shape power (for ramp_exp)
TEMPLATE_BETA = 1.2              # exponent (for power_sat)

# =================== CONSTANTS ======================
c_km_s = 299792.458     # speed of light in km/s
if SUPPRESS_WARNINGS:
    # Only DeprecationWarning is filtered; other warning categories still show.
    warnings.simplefilter("ignore", category=DeprecationWarning)

# ===================== LOADER =======================
def _read_any(path: Path) -> pd.DataFrame:
    """Read a table that is either whitespace- or comma-delimited.

    The file is first parsed as whitespace-separated text. If that yields a
    single column (no whitespace separators were found) it is re-read as CSV.
    If the whitespace parse raises, CSV is tried, and as a last resort a
    whitespace parse with the default engine. Lines starting with ``#`` are
    treated as comments in every attempt.

    Args:
        path: Location of the table on disk.

    Returns:
        The parsed table as a DataFrame.

    Raises:
        Exception: whatever pandas raises if the final fallback also fails.
    """
    try:
        df = pd.read_table(path, comment='#', sep=r"\s+", engine="python")
        if len(df.columns) == 1:
            # One column means nothing was split on whitespace: assume CSV.
            df = pd.read_csv(path, comment='#')
        return df
    except Exception:
        try:
            return pd.read_csv(path, comment='#')
        except Exception:
            # sep=r"\s+" is the supported spelling of the deprecated
            # delim_whitespace=True option.
            return pd.read_csv(path, sep=r"\s+", comment='#')

def load_sn_table(path: Path):
    """Load a supernova table and reduce it to the columns used by the fits.

    Column names are matched case-insensitively against a list of known
    aliases. The redshift column preference is controlled by PREFER_ZHD,
    the redshift cut by Z_MIN (plus a fixed upper cut at z = 2.0), and the
    optional peculiar-velocity term by ADD_PEC_ERR / PEC_Z_CLIP.

    Args:
        path: Table location; a ``Path`` or anything ``Path()`` accepts
            (e.g. a plain string).

    Returns:
        DataFrame sorted by redshift with columns ``z``, ``mu_like``,
        ``sigma_mu``, ``ra`` and ``dec`` (``ra``/``dec`` are NaN when the
        table has no such columns).

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if no redshift or no distance-modulus-like column is found.
    """
    # Coerce so that string paths work too (.exists() / .name are used below).
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    df = _read_any(path)

    # Defensive second attempt: a single column means nothing was split.
    if len(df.columns) == 1:
        df = pd.read_table(path, comment='#', sep=r"\s+", engine="python")

    # Map normalised (lower-case, stripped) names back to the real headers.
    cols_lower = {c.lower().strip(): c for c in df.columns}

    def pick(*names):
        """Return the real header of the first alias present, else None."""
        for n in names:
            key = n.lower()
            if key in cols_lower:
                return cols_lower[key]
        return None

    # choose z column by preference
    if PREFER_ZHD:
        zcol = pick('zhd','zcmb','z')
    else:
        zcol = pick('zcmb','zhd','z')

    mucol  = pick('mu', 'mu_sh0es', 'distmod', 'm_b_corr', 'mb_corr', 'mb')
    sigcol = pick('mu_err','sigma_mu','mu_sh0es_err_diag','m_b_corr_err_diag','dmu','dmb','merr')
    vperr_col = pick('vpecerr','vpec_err','sigma_vpec')

    if zcol is None or mucol is None:
        raise ValueError(f"Could not find essential columns.\n"
                         f"Have columns: {list(df.columns)}\n"
                         f"Need some of zCMB/zHD/z and mu or MU_SH0ES/m_b_corr.")

    out = pd.DataFrame()
    out['z'] = df[zcol].astype(float)
    out['mu_like'] = df[mucol].astype(float)
    if sigcol is not None:
        # Provided errors are never allowed below SIGMA_FLOOR.
        out['sigma_mu'] = df[sigcol].astype(float).clip(lower=SIGMA_FLOOR)
        sigma_src = sigcol
    else:
        # No error column: every SN gets the same uncertainty.
        out['sigma_mu'] = SIGMA_FLOOR
        sigma_src = f"floor={SIGMA_FLOOR}"

    # optional peculiar-velocity error contribution
    if ADD_PEC_ERR and (vperr_col is not None):
        # Clip z from below so the 1/z factor stays bounded near z = 0.
        zv = np.clip(out['z'].values, PEC_Z_CLIP, None)
        sigma_v = df[vperr_col].astype(float).values
        # sigma_mu = (5/ln10) * sigma_v / (c z); presumably sigma_v is in
        # km/s to match c_km_s -- confirm against the table documentation.
        sigma_mu_pec = (5.0/np.log(10.0)) * (sigma_v / (c_km_s * zv))
        # Added in quadrature to the (floored) catalogue error.
        out['sigma_mu'] = np.sqrt(out['sigma_mu'].values**2 + sigma_mu_pec**2)

    racol = pick('ra','ra_deg')
    decol = pick('dec','dec_deg')
    out['ra']  = df[racol].astype(float) if racol else np.nan
    out['dec'] = df[decol].astype(float) if decol else np.nan

    # clean & cut: drop rows with non-finite z / mu / sigma, then keep
    # Z_MIN <= z <= 2.0 and sort by redshift
    out = out.replace([np.inf,-np.inf], np.nan).dropna(subset=['z','mu_like','sigma_mu'])
    out = out[(out['z']>=Z_MIN) & (out['z']<=2.0)].sort_values('z').reset_index(drop=True)

    print(f"Loaded {len(out)} SNe from {path.name}")
    print(f"Using z from '{zcol}', mu-like from '{mucol}', sigma from '{sigma_src}'")
    if ADD_PEC_ERR and (vperr_col is not None):
        print(f"Including peculiar-velocity error term from '{vperr_col}' (z clip {PEC_Z_CLIP})")
    return out

# =================== COSMOLOGY ======================
def E_z(z, Om):
    """Return sqrt(Om*(1+z)^3 + (1-Om)), the flat matter+Λ expansion rate H(z)/H0."""
    matter_term = Om*(1+z)**3
    lambda_term = 1-Om
    return np.sqrt(matter_term + lambda_term)

def comoving_distance(z, Om, H0=70.0):
    """Integrate c/H0 * dz'/E(z') from 0 to ``z`` with the trapezoid rule.

    Args:
        z: Single redshift (converted with ``float``; arrays are not accepted).
        Om: Matter density parameter passed to ``E_z``.
        H0: Hubble constant; with ``c_km_s`` in km/s and H0 in km/s/Mpc the
            result is in Mpc.

    Returns:
        The integral evaluated on ``NINT`` equally spaced points.
    """
    z = float(z)
    zg = np.linspace(0.0, z, NINT)
    inv_Ez = 1.0/E_z(zg, Om)
    # np.trapezoid only exists in NumPy >= 2.0; older releases call it trapz.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    return (c_km_s/H0) * integrate(inv_Ez, zg)

def lum_distance(z, Om, H0=70.0):
    """Return (1+z) times the comoving distance at redshift ``z``."""
    d_comoving = comoving_distance(z, Om, H0)
    return (1+z)*d_comoving

def mu_theory(z, Om, M):
    """Model distance modulus 5*log10(D_L) + M for scalar or 1-D ``z``.

    Distances are clipped from below at 1e-6 before the logarithm so that
    z = 0 does not produce -inf. Always returns a 1-D array.
    """
    redshifts = np.atleast_1d(z)
    distances = [lum_distance(zi, Om) for zi in redshifts]
    Dl = np.array(distances)
    safe_Dl = np.clip(Dl, 1e-6, None)
    return 5.0*np.log10(safe_Dl) + M

# ==================== BASELINE ======================
def fit_lcdm(sn_df, Om_grid=np.linspace(0.05,0.6,200)):
    """Grid-search Ωm for the baseline model mu = 5*log10(D_L) + M.

    For each Ωm on the grid the offset M is set to the inverse-variance
    weighted mean of (mu_like - 5*log10(D_L)), which minimises chi² for that
    Ωm. The grid point with the lowest chi² wins.

    Args:
        sn_df: DataFrame with columns ``z``, ``mu_like`` and ``sigma_mu``.
        Om_grid: Ωm values to try.

    Returns:
        Dict with keys ``Om``, ``M``, ``chi2``, ``dof``, ``chi2_red``,
        ``AIC`` and ``BIC`` (two fitted parameters are counted).
    """
    redshifts = sn_df['z'].values
    y = sn_df['mu_like'].values
    w = 1.0/(sn_df['sigma_mu'].values**2)

    best = None
    for Om in Om_grid:
        distances = np.array([lum_distance(zi, Om) for zi in redshifts])
        mu0 = 5.0*np.log10(distances)
        M = np.sum(w*(y-mu0))/np.sum(w)            # optimal M
        r = y - (mu0 + M)
        chi2 = np.sum(w*r*r)
        # Keep the current best unless this grid point is strictly better.
        if best is not None and not (chi2 < best['chi2']):
            continue
        best = {'Om': float(Om), 'M': float(M), 'chi2': float(chi2)}

    n_params = 2
    n_sn = len(y)
    dof = int(n_sn - n_params)
    best['dof'] = dof
    best['chi2_red'] = best['chi2']/max(1, dof)
    best['AIC'] = best['chi2'] + 2*n_params
    best['BIC'] = best['chi2'] + n_params*np.log(n_sn)
    return best

# ================= MULTI-BIN ROUGHNESS ==============
def design_matrix_bins(z, edges):
    """Build the 0/1 bin-membership matrix for redshifts ``z``.

    Row i has a 1 in column j when edges[j] <= z[i] < edges[j+1]. Values at
    or above the last edge (within 1e-10) are assigned to the last bin;
    values below the first edge get an all-zero row.

    Returns:
        Float array of shape (len(z), len(edges) - 1).
    """
    n_bins = len(edges) - 1
    B = np.zeros((len(z), n_bins))
    for col, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        inside = (z >= lo) & (z < hi)
        B[inside, col] = 1.0
    # The half-open bins exclude the top edge; fold it (and anything above)
    # into the last bin.
    at_or_above_top = z >= edges[-1] - 1e-10
    B[at_or_above_top, -1] = 1.0
    return B

def fit_lcdm_with_roughness(sn_df, edges, lam=1.0, Om_grid=np.linspace(0.05,0.6,120)):
    """Fit mu = 5*log10(D_L) + M + Δμ_bin(z) with a second-difference prior.

    For every Ωm on the grid the linear parameters (M and one offset per
    redshift bin) are solved from the penalised normal equations
    (XᵀWX + R) θ = XᵀW (y - mu0), where R = lam * D2ᵀD2 acts on the bin
    offsets only. The grid point with the lowest penalised chi² is returned.

    Gauge handling: when every SN lies in some bin, the constant column is
    the sum of the bin columns and the second-difference prior does not
    penalise a uniform shift, so (M, Δμ) → (M + c, Δμ - c) leaves the
    objective unchanged and the normal matrix is exactly singular. The
    system is therefore solved with a pseudo-inverse and the offsets are
    reported in the gauge where their inverse-variance-weighted mean over
    the SNe is zero (M absorbs the overall level). ``delta_err`` is the
    covariance projected into that gauge. When some SN lies outside all
    bins the matrix is regular and the pseudo-inverse equals the inverse.

    Args:
        sn_df: DataFrame with columns ``z``, ``mu_like`` and ``sigma_mu``.
        edges: Bin edges (at least two); K = len(edges) - 1 bins.
        lam: Strength of the second-difference penalty. With fewer than
            three bins there is no second difference and no penalty.
        Om_grid: Ωm values to try.

    Returns:
        Dict with keys ``Om``, ``M``, ``delta``, ``bin_edges``, ``chi2``
        (data term plus penalty), ``dof``, ``chi2_red``, ``delta_err``,
        ``AIC`` and ``BIC``. The parameter count is 2 + K.
    """
    z = sn_df['z'].values
    y = sn_df['mu_like'].values
    w = 1.0/(sn_df['sigma_mu'].values**2)
    B = design_matrix_bins(z, edges)
    K = B.shape[1]

    # Precompute normal matrix (Om-independent) and prior.
    X = np.column_stack([np.ones_like(y), B])
    # Row scaling by w is Xᵀ diag(w) without building the dense N×N matrix.
    Xw = X * w[:, None]
    A = X.T @ Xw

    # Second-difference operator; empty when K < 3.
    n_pen = max(K-2, 0)
    D2 = np.zeros((n_pen, K))
    for i in range(n_pen):
        D2[i, i:i+3] = [1, -2, 1]
    RtR = D2.T @ D2

    R = np.zeros((K+1, K+1))
    R[1:, 1:] = lam*RtR

    # Pseudo-inverse: identical to the inverse for a regular matrix, and
    # well defined when the M/Δμ shift direction makes it singular.
    cov_theta = np.linalg.pinv(A + R, rcond=1e-12, hermitian=True)

    # The shift degeneracy is present exactly when every SN is in a bin.
    gauge_free = len(y) > 0 and bool(np.all(B.sum(axis=1) > 0))
    if gauge_free:
        # Fraction of the total weight in each bin (sums to 1).
        occupancy = (w @ B) / np.sum(w)

    best = None
    for Om in Om_grid:
        mu0 = 5.0*np.log10(np.array([lum_distance(zi, Om) for zi in z]))
        theta = cov_theta @ (Xw.T @ (y - mu0))
        M = theta[0]
        delta = theta[1:]
        if gauge_free:
            # Move the weighted-mean offset into M; the model prediction and
            # the penalty are unchanged by this shift.
            shift = float(occupancy @ delta)
            M = M + shift
            delta = delta - shift
        r = y - (mu0 + M + B@delta)
        chi2 = float(np.sum(w*r*r) + lam*float(delta@RtR@delta))
        if best is None or chi2 < best['chi2']:
            best = {'Om': float(Om), 'M': float(M), 'delta': delta,
                    'bin_edges': edges, 'chi2': chi2}

    N = len(y)
    k = 2 + len(best['delta'])
    best['dof'] = int(N-k)
    best['chi2_red'] = best['chi2']/max(1, best['dof'])

    cov_delta = cov_theta[1:, 1:]
    if gauge_free:
        # Same linear map as applied to delta above: δ' = (I - 1·occᵀ) δ.
        T = np.eye(K) - np.outer(np.ones(K), occupancy)
        cov_delta = T @ cov_delta @ T.T
    # Clip tiny negative roundoff before the square root.
    best['delta_err'] = np.sqrt(np.clip(np.diag(cov_delta), 0.0, None))
    best['AIC'] = best['chi2'] + 2*k
    best['BIC'] = best['chi2'] + k*np.log(N)
    return best

# ============== 1-PARAM TEMPLATE ROUGHNESS ==========
def f_template(z, kind="z_over_zplus", z0=0.25, p=1.5, beta=1.0):
    """Evaluate the redshift shape f(z) used by the one-parameter template.

    Args:
        z: Redshift(s); converted to a float array.
        kind: One of
            ``"z_over_zplus"``: z / (z + z0)
            ``"ramp_exp"``:     1 - exp(-(z / z0)**p)
            ``"log1pz"``:       ln(1 + z) / ln(1 + z0), so that f(z0) = 1
            ``"power_sat"``:    (z / (z + z0))**beta
        z0: Shape scale.
        p: Exponent used by ``ramp_exp`` only.
        beta: Exponent used by ``power_sat`` only.

    Returns:
        Array of f(z) with the shape of ``z``.

    Raises:
        ValueError: for an unknown ``kind``.
    """
    z = np.asarray(z, float)
    if kind == "z_over_zplus":
        return z / (z + z0)
    if kind == "ramp_exp":
        return 1.0 - np.exp(-(z / z0)**p)
    if kind == "log1pz":
        # Normalise by ln(1 + z0) so the template is 1 at z = z0.
        # log1p(x) already adds the 1, so the argument is z0, not 1 + z0.
        return np.log1p(z) / np.log1p(z0)
    if kind == "power_sat":
        return (z / (z + z0))**beta
    raise ValueError(f"Unknown template kind: {kind}")

def fit_lcdm_with_template(sn_df,
                           kind="z_over_zplus",
                           z0=0.25, p=1.5, beta=1.0,
                           Om_grid=np.linspace(0.05, 0.6, 200)):
    """Fit ΛCDM + A f(z) with ONE amplitude parameter A (plus M, Om).

    f is evaluated with ``f_template`` and then shifted to zero
    inverse-variance-weighted mean, so the constant offset M and the
    amplitude A are orthogonal under the fit weights. For each Ωm on the
    grid, M and A follow from the 2×2 weighted normal equations; the grid
    point with the lowest chi² is returned.

    Args:
        sn_df: DataFrame with columns ``z``, ``mu_like`` and ``sigma_mu``.
        kind, z0, p, beta: Forwarded to ``f_template``.
        Om_grid: Ωm values to try.

    Returns:
        Dict with keys ``Om``, ``M``, ``A``, ``sigma_A``, ``chi2``, ``f``
        (the centred template at the SN redshifts), ``kind``, ``z0``,
        ``p``, ``beta``, ``dof``, ``chi2_red``, ``AIC`` and ``BIC``
        (three fitted parameters are counted).
    """
    z = sn_df['z'].values
    y = sn_df['mu_like'].values
    w = 1.0 / (sn_df['sigma_mu'].values**2)

    f = f_template(z, kind=kind, z0=z0, p=p, beta=beta)
    f = f - np.average(f, weights=w)  # weighted centering

    # The design matrix, normal matrix and its inverse do not depend on Om,
    # so build them once. Row scaling by w replaces the dense diag(w).
    X = np.column_stack([np.ones_like(z), f])  # [M, A]
    Xw = X * w[:, None]
    A_mat = X.T @ Xw
    cov = np.linalg.inv(A_mat)
    sigA = float(np.sqrt(cov[1, 1]))

    best = None
    for Om in Om_grid:
        mu0 = 5.0*np.log10(np.array([lum_distance(zi, Om) for zi in z]))
        b_vec = Xw.T @ (y - mu0)
        theta = np.linalg.solve(A_mat, b_vec)
        M_hat, A_hat = float(theta[0]), float(theta[1])

        resid = y - (mu0 + M_hat + A_hat * f)
        chi2 = float(np.sum(w * resid**2))

        if (best is None) or (chi2 < best['chi2']):
            best = {'Om': float(Om), 'M': M_hat, 'A': A_hat, 'sigma_A': sigA,
                    'chi2': chi2, 'f': f, 'kind': kind, 'z0': z0, 'p': p, 'beta': beta}

    N = len(y)
    k = 3
    best['dof'] = int(N - k)
    best['chi2_red'] = best['chi2']/max(1, best['dof'])
    best['AIC'] = best['chi2'] + 2*k
    best['BIC'] = best['chi2'] + k*np.log(N)
    return best

def compare_models_print(lcdm, rough, templ):
    """Print a side-by-side fit summary for the three models.

    Args:
        lcdm: Result dict of the baseline fit.
        rough: Result dict of the multi-bin fit (its ``delta`` length is
            reported as the number of extra parameters).
        templ: Result dict of the one-parameter template fit.

    Each dict must provide ``chi2``, ``dof``, ``chi2_red``, ``AIC`` and
    ``BIC``. Nothing is returned.
    """
    def summary(label, res):
        # One shared format so the three rows cannot drift apart.
        return (f"{label}χ²={res['chi2']:.1f}, dof={res['dof']}, "
                f"χ²_red={res['chi2_red']:.3f}, AIC={res['AIC']:.1f}, BIC={res['BIC']:.1f}")

    n_bins = len(rough['delta'])
    n_template_extra = 1  # the amplitude A

    print("\n=== MODEL COMPARISON ===")
    print(summary("LCDM:    ", lcdm))
    print(summary(f"Rough(K={n_bins}): ", rough))
    print(summary("Template(1p): ", templ))
    print(f"Δχ² (LCDM→Rough)    = {lcdm['chi2']-rough['chi2']:.2f} for {n_bins} extra params")
    print(f"Δχ² (LCDM→Template) = {lcdm['chi2']-templ['chi2']:.2f} for {n_template_extra} extra params")

# =================== PLOTTING =======================
def plot_baseline_and_roughness(sn, lcdm, rough):
    """Plot baseline and multi-bin results and write the per-bin summary.

    Writes, in the current working directory:
      - ``residuals_baseline.png``: mu_like minus the baseline model
      - ``roughness_dmu_bins.png``: step plot of the bin offsets with error
        bars and the percentage of SNe in each bin
      - ``residuals_after.png``: mu_like minus the multi-bin model
      - ``bin_summary.csv``: edges, counts, percentages, offsets and errors

    Args:
        sn: DataFrame with columns ``z`` and ``mu_like``.
        lcdm: Baseline result dict (``Om``, ``M``).
        rough: Multi-bin result dict (``Om``, ``M``, ``delta``,
            ``delta_err``, ``bin_edges``).

    The figures are left open after saving.
    """
    z = sn['z'].values
    y = sn['mu_like'].values
    be = np.asarray(rough['bin_edges'], float)
    dm = rough['delta']
    de = rough['delta_err']
    B = design_matrix_bins(z, be)
    muL = mu_theory(z, lcdm['Om'], lcdm['M'])
    muR = mu_theory(z, rough['Om'], rough['M']) + B @ dm

    # Residuals (baseline)
    plt.figure(figsize=(8,5))
    plt.scatter(z, y-muL, s=6, alpha=0.5)
    plt.xlabel('z'); plt.ylabel('μ_obs - μ_LCDM'); plt.title('Residuals (baseline)')
    plt.tight_layout(); plt.savefig("residuals_baseline.png", dpi=160)

    # Δμ(z) step with % annotations and error bars
    counts = B.sum(axis=0).astype(int)
    total = len(z)
    perc = counts/total*100.0
    zc = 0.5*(be[:-1] + be[1:])

    plt.figure(figsize=(9,4.8))
    # A 'post' step only draws a segment between consecutive x values, so
    # pass all K+1 edges and repeat the last offset; otherwise the final
    # bin has no horizontal line.
    plt.step(be, np.append(dm, dm[-1]), where='post')
    plt.errorbar(zc, dm, yerr=de, fmt='o', ms=4)
    for j in range(len(counts)):
        # Label just above the point; a non-finite offset falls back to a
        # fixed height so the label is still placed.
        ytext = (dm[j] + 0.02) if np.isfinite(dm[j]) else 0.02
        plt.text(zc[j], ytext, f"{perc[j]:.1f}%", ha='center', va='bottom')
    plt.xlabel('z'); plt.ylabel('Δμ bin (mag)'); plt.title('Inferred roughness Δμ(z)')
    plt.tight_layout(); plt.savefig("roughness_dmu_bins.png", dpi=160)

    # Residuals (after roughness)
    plt.figure(figsize=(8,5))
    plt.scatter(z, y-muR, s=6, alpha=0.5)
    plt.xlabel('z'); plt.ylabel('μ_obs - μ_model'); plt.title('Residuals (after roughness)')
    plt.tight_layout(); plt.savefig("residuals_after.png", dpi=160)

    # Bin summary CSV
    df_out = pd.DataFrame({
        "z_left":  be[:-1],
        "z_right": be[1:],
        "count":   counts,
        "percent": perc,
        "delta_mu": dm,
        "sigma_delta_mu": de
    })
    df_out.to_csv("bin_summary.csv", index=False)
    print("Saved bin summary -> bin_summary.csv")

def plot_template(sn, lcdm, templ, prefix="template"):
    """Save the template offset curve and the residuals after the template fit.

    Writes ``{prefix}_dmu.png`` (A·f(z) against z) and ``{prefix}_resid.png``
    (mu_like minus the fitted model). ``lcdm`` is accepted for signature
    symmetry with the other plotting helper but is not used here.
    """
    z = sn['z'].values
    y = sn['mu_like'].values
    # A times the centred template stored by the fit.
    dmu = templ['A'] * templ['f']
    muT = mu_theory(z, templ['Om'], templ['M']) + dmu

    # Δμ(z) = A f(z), drawn in increasing-z order
    order = np.argsort(z)
    plt.figure(figsize=(8.2,4.6))
    plt.plot(z[order], dmu[order])
    plt.xlabel("z")
    plt.ylabel("Δμ(z) [mag]")
    plt.title(f"Template: {templ['kind']}  (A = {templ['A']:+.4f} ± {templ['sigma_A']:.4f})")
    plt.tight_layout()
    plt.savefig(f"{prefix}_dmu.png", dpi=160)

    # Residuals after template
    plt.figure(figsize=(8.2,5))
    plt.scatter(z, y - muT, s=6, alpha=0.5)
    plt.xlabel("z")
    plt.ylabel("μ_obs - μ_model")
    plt.title("Residuals after ΛCDM + A f(z)")
    plt.tight_layout()
    plt.savefig(f"{prefix}_resid.png", dpi=160)

# =================== SKY MAPS =======================
def save_sky_map_healpy(sn_df, nside=NSIDE, out_png="sn_sky_healpy.png"):
    """Save a HEALPix Mollweide map of SN counts per pixel.

    Rows with non-finite ``ra`` or ``dec`` are ignored.

    Returns:
        True if a map was written, False if no row had usable coordinates.
    """
    ra = sn_df['ra'].values
    dec = sn_df['dec'].values
    usable = np.isfinite(ra) & np.isfinite(dec)
    if not usable.any():
        print("No RA/Dec available; skipping HEALPix map.")
        return False

    # Colatitude = 90° - dec and longitude = ra, both in radians.
    colat = np.deg2rad(90.0 - dec[usable])
    lon = np.deg2rad(ra[usable])

    counts = np.zeros(hp.nside2npix(nside), dtype=float)
    pixels = hp.ang2pix(nside, colat, lon)
    # Accumulate one count per SN; several SNe may share a pixel.
    for pix in pixels:
        counts[pix] += 1.0

    hp.mollview(counts, title=f"SN sky distribution (HEALPix NSIDE={nside})")
    plt.savefig(out_png, dpi=160, bbox_inches="tight")
    plt.close()
    print(f"Saved HEALPix sky map -> {out_png}")
    return True

def save_sky_map_mollweide_scatter(sn_df, out_png="sn_sky_mollweide_scatter.png"):
    """Save a Matplotlib Mollweide scatter plot of the SN positions.

    Rows with non-finite ``ra`` or ``dec`` are ignored.

    Returns:
        True if a plot was written, False if no row had usable coordinates.
    """
    ra = sn_df['ra'].values
    dec = sn_df['dec'].values
    usable = np.isfinite(ra) & np.isfinite(dec)
    if not usable.any():
        print("No RA/Dec available; skipping Mollweide scatter.")
        return False

    # Wrap RA into (-π, π] and flip the sign so the horizontal axis is
    # mirrored.
    lon = np.deg2rad(ra[usable] % 360.0)
    lon = lon - 2.0*np.pi*(lon > np.pi)
    lon = -lon
    lat = np.deg2rad(dec[usable])

    fig = plt.figure(figsize=(9,5))
    ax = fig.add_subplot(111, projection='mollweide')
    ax.scatter(lon, lat, s=6, alpha=0.7)
    ax.grid(True)
    ax.set_title("SN sky distribution (Mollweide scatter)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved Mollweide scatter -> {out_png}")
    return True

# ====================== MAIN ========================
def main():
    """Run the full pipeline: load, fit three models, print, plot, map."""
    sn = load_sn_table(DATA_PATH)

    # Baseline ΛCDM and multi-bin roughness fits
    lcdm = fit_lcdm(sn)
    rough = fit_lcdm_with_roughness(sn, Z_BINS, LAMBDA_SMOOTH)
    n_extra_bins = len(rough['delta'])

    # Print results
    print("\n=== RESULTS (BASELINE & ROUGHNESS) ===")
    print(f"LCDM:  chi2={lcdm['chi2']:.1f}, dof={lcdm['dof']}, chi2_red={lcdm['chi2_red']:.3f}, AIC={lcdm['AIC']:.1f}, BIC={lcdm['BIC']:.1f}")
    print(f"Rough: chi2={rough['chi2']:.1f}, dof={rough['dof']}, chi2_red={rough['chi2_red']:.3f}, AIC={rough['AIC']:.1f}, BIC={rough['BIC']:.1f}")
    print(f"Δχ² = {lcdm['chi2'] - rough['chi2']:.2f} for {n_extra_bins} extra bins")

    # Plots & bin summary
    plot_baseline_and_roughness(sn, lcdm, rough)

    # Sky map: HEALPix when available, otherwise (or if it produced nothing)
    # the Matplotlib scatter fallback
    healpix_done = (
        save_sky_map_healpy(sn, nside=NSIDE, out_png="sn_sky_healpy.png")
        if HAVE_HEALPY else False
    )
    if not healpix_done:
        save_sky_map_mollweide_scatter(sn, out_png="sn_sky_mollweide_scatter.png")

    # 1-parameter template fit
    templ = fit_lcdm_with_template(
        sn,
        kind=TEMPLATE_KIND,
        z0=TEMPLATE_Z0,
        p=TEMPLATE_P,
        beta=TEMPLATE_BETA,
    )
    print("\n=== 1-PARAM TEMPLATE ===")
    print(f"Template: {templ['kind']}  (z0={templ['z0']}, p={templ['p']}, beta={templ['beta']})")
    print(f"Om={templ['Om']:.3f}, A={templ['A']:+.4f} ± {templ['sigma_A']:.4f} mag")
    print(f"chi2={templ['chi2']:.1f}, dof={templ['dof']}, chi2_red={templ['chi2_red']:.3f}, AIC={templ['AIC']:.1f}, BIC={templ['BIC']:.1f}")

    # Compare all three
    compare_models_print(lcdm, rough, templ)
    plot_template(sn, lcdm, templ, prefix=f"template_{templ['kind']}")


if __name__ == "__main__":
    main()
