# 40 — Parameter evolution (per-bin vs joint)

In [1]:
from pathlib import Path
import re, json
import numpy as np
import matplotlib.pyplot as plt

# -------- paths & setup --------
# Per-bin chain files are read from PER_BIN_DIR; figures are written to OUT.
PER_BIN_DIR = Path.home() / "project/results/mcmc/per_bin"
OUT = Path.home() / "project/results/figures/per_bin_fitting"
OUT.mkdir(parents=True, exist_ok=True)

# Canonical parameter name -> candidate keys to look for in each .npz,
# tried in order (the first key present wins).
ALIASES = {
    "alpha":        ["alpha"],
    "beta":         ["beta"],
    "log10_Lstar":  ["log10_Lstar", "e", "Lstar_log10", "logLstar"],
    "log10_Nphi":   ["log10_Nphi", "d", "Nphi_log10", "logNphi"],
    "m0":           ["m0"],
    "b":            ["b"],
}

# discover files named y3_bin<idx>_z<zmin>-<zmax>.npz
pat = re.compile(r"y3_bin(\d+)_z([0-9.]+)-([0-9.]+)\.npz$")
bins = []
for p in sorted(PER_BIN_DIR.glob("y3_bin*_z*.npz")):
    m = pat.match(p.name)
    if m:
        idx, zmin, zmax = int(m.group(1)), float(m.group(2)), float(m.group(3))
        bins.append({"idx": idx, "file": p, "zmin": zmin, "zmax": zmax, "z_center": 0.5*(zmin+zmax)})
# Numeric bin order (the glob sort above is lexicographic, so bin10 < bin2).
bins = sorted(bins, key=lambda b: b["idx"])
if not bins:
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    raise FileNotFoundError(f"No per-bin npz files in {PER_BIN_DIR}")

def ravel(x):
    """Return *x* as an ndarray, flattened to 1-D if it has two or more axes.

    0-d and 1-d inputs are returned as arrays without reshaping.
    """
    arr = np.asarray(x)
    if arr.ndim <= 1:
        return arr
    return arr.reshape(-1)

def q16_50_84(x):
    """Return the 16th, 50th and 84th percentiles of the samples in *x*.

    *x* is flattened and non-finite samples (NaN/inf) are discarded before
    the percentiles are taken, so a few bad draws do not turn the whole
    summary into NaN.

    Returns
    -------
    numpy.ndarray of shape (3,): ``[p16, p50, p84]``, or three NaNs when no
    finite samples remain (instead of raising on empty input).
    """
    samples = np.asarray(x, dtype=float).ravel()
    samples = samples[np.isfinite(samples)]
    if samples.size == 0:
        # np.percentile raises on empty input; NaN lets callers treat the
        # bin as missing.
        return np.full(3, np.nan)
    return np.percentile(samples, [16, 50, 84])

def get_param_dict(npz_dict, name, aliases=None):
    """Return the sample array for parameter *name*, or None if absent.

    Candidate keys come from ``aliases[name]`` (default: the module-level
    ``ALIASES`` table) and are tried in order; the first key present in
    *npz_dict* is returned, passed through ``ravel``. A *name* that has no
    alias entry is looked up under its own name rather than raising
    ``KeyError``, so the "or None" contract holds for every name.

    Parameters
    ----------
    npz_dict : mapping of key -> array-like (e.g. a dict built from an .npz)
    name : canonical parameter name
    aliases : optional mapping of canonical name -> list of candidate keys
    """
    table = ALIASES if aliases is None else aliases
    for key in table.get(name, [name]):
        if key in npz_dict:
            return ravel(npz_dict[key])
    return None

# load all per-bin chains once, as (bin-metadata, {key: array}) pairs
chains = []
for b in bins:
    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # closed; the `with` block releases the handle after dict() has read
    # every array into memory.
    # NOTE(review): allow_pickle=True permits unpickling object arrays —
    # only acceptable because these are trusted, locally produced results.
    with np.load(b["file"], allow_pickle=True) as npz:
        chains.append((b, dict(npz)))

# ------------- (A) per-bin evolution plots (error bars vs z) ----------------
def plot_evolution(param, ylabel=None):
    """Plot the per-bin median of *param* against bin-centre redshift.

    Error bars span each bin's 16th-84th percentile interval. Bins whose
    chain lacks *param* are recorded as NaN and listed in a warning; if no
    bin yields a finite median, nothing is plotted. The figure is saved to
    ``OUT / f"per_bin_evolution__{param}.png"``.

    Parameters
    ----------
    param : canonical parameter name (looked up via get_param_dict)
    ylabel : optional y-axis label; defaults to *param*
    """
    z = np.array([b["z_center"] for b, _ in chains])
    meds, elo, ehi = [], [], []
    missing = []  # bin indices whose chain file has no samples for param
    for b, D in chains:
        arr = get_param_dict(D, param)
        if arr is None:
            missing.append(b["idx"])
            meds.append(np.nan); elo.append(np.nan); ehi.append(np.nan)
            continue
        lo, med, hi = q16_50_84(arr)
        # errorbar wants distances below/above the median, not absolute bounds
        meds.append(med); elo.append(med - lo); ehi.append(hi - med)
    if not any(np.isfinite(meds)):
        print(f"[skip] '{param}' not found in any per-bin files")
        return
    if missing:
        print(f"[warn] '{param}' missing in bins {missing}")
    fn = OUT / f"per_bin_evolution__{param}.png"
    fig = plt.figure(figsize=(5,4))
    try:
        plt.errorbar(z, meds, yerr=[elo, ehi], fmt="o", capsize=3)
        plt.xlabel("z (bin center)")
        plt.ylabel(ylabel or param)
        plt.title(f"Per bin fitting — {param}")
        plt.tight_layout()
        plt.savefig(fn, dpi=200)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print("Saved", fn)

# ------------- (B) overlapped posterior plots (all bins on one axis) --------
def kde_1d(x, nbins=256, widen=0.1, smooth_sigma=2.0):
    """Return a lightly smoothed, density-normalised histogram of *x*.

    Not a true KDE, but adequate for overlaying posteriors. Non-finite
    samples are dropped. The histogram range is the 0.5th-99.5th percentile
    interval of the samples, padded on each side by ``widen`` times its
    width (or by 1.0 when that width is zero); samples outside this range
    are ignored. The histogram is then smoothed with a Gaussian of
    ``smooth_sigma`` bins; ``smooth_sigma <= 0`` disables smoothing.

    Returns
    -------
    (centres, pdf) : arrays of length *nbins* (bin centres and smoothed
        density), or ``(None, None)`` when fewer than 10 finite samples.
    """
    samples = np.asarray(x, dtype=float).ravel()
    samples = samples[np.isfinite(samples)]
    if samples.size < 10:
        return None, None
    lo, hi = np.percentile(samples, [0.5, 99.5])
    pad = (hi - lo) * widen if hi > lo else 1.0
    hist, edges = np.histogram(samples, bins=nbins, range=(lo - pad, hi + pad), density=True)
    centres = 0.5 * (edges[1:] + edges[:-1])
    if smooth_sigma > 0:
        # Local import keeps scipy optional for callers that never smooth.
        from scipy.ndimage import gaussian_filter1d
        hist = gaussian_filter1d(hist, sigma=smooth_sigma)
    return centres, hist

def plot_overlap(param, xlabel=None):
    """Overlay the smoothed 1-D posterior of *param* for every bin on one axis.

    A bin is skipped when its chain has no samples for *param* or when
    kde_1d cannot build a curve from them. If no curve is drawn at all, the
    figure is closed and a skip message printed; otherwise the plot is saved
    to ``OUT / f"per_bin_overlap__{param}.png"``.
    """
    plt.figure(figsize=(6,4))
    n_drawn = 0
    for meta, data in chains:
        samples = get_param_dict(data, param)
        if samples is None:
            continue
        xs, dens = kde_1d(samples)
        if xs is None:
            continue
        curve_label = f"bin{meta['idx']} z{meta['zmin']}-{meta['zmax']}"
        plt.plot(xs, dens, label=curve_label)
        n_drawn += 1

    if n_drawn == 0:
        print(f"[skip] '{param}' not found anywhere for overlap")
        plt.close()
        return

    plt.xlabel(xlabel or param)
    plt.ylabel("density (arb.)")
    plt.title(f"Per bin fitting — overlapped posteriors: {param}")
    plt.legend(frameon=False, fontsize=9)
    plt.tight_layout()
    target = OUT / f"per_bin_overlap__{param}.png"
    plt.savefig(target, dpi=200)
    plt.close()
    print("Saved", target)

# --------- run for key parameters (modify list as needed) ----------
# Each entry is (canonical name as used in ALIASES, LaTeX axis label).
PARAMS = [
    ("alpha",       r"$\alpha$"),
    ("beta",        r"$\beta$"),
    ("log10_Lstar", r"$\log_{10} L_\ast$"),
    ("log10_Nphi",  r"$\log_{10} \Phi_\ast$"),
    # m0 / b: keep only if they were sampled in the *individual* per-bin fits
    ("m0",          r"$m_0$"),
    ("b",           r"$b$"),
]

# For each parameter: redshift-evolution plot first, then the overlay.
for entry in PARAMS:
    name, tex_label = entry
    plot_evolution(name, ylabel=tex_label)
    plot_overlap(name, xlabel=tex_label)


Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_evolution__alpha.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_overlap__alpha.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_evolution__beta.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_overlap__beta.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_evolution__log10_Lstar.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_overlap__log10_Lstar.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_evolution__log10_Nphi.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_overlap__log10_Nphi.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_evolution__m0.png
Saved /global/homes/z/zhaozhon/project/results/figures/per_bin_fitting/per_bin_overlap__m0.