In [1]:
# ------------------------------------------------------------
# Run imports and aesthetics
# ------------------------------------------------------------
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os, warnings
import numpy as np
import matplotlib.pyplot as plt
import corner
import arviz as az
import pandas as pd
from IPython.display import display
from getdist import plots, MCSamples
import emcee
import matplotlib.cm as cm

# for exporting tables nicely
import dataframe_image as dfi  

# aesthetics
plt.rcParams['figure.dpi'] = 120

OUTDIR = "output"
PLOT_DIR = "plots"   # NEW folder for plots/tables
os.makedirs(PLOT_DIR, exist_ok=True)


In [2]:

# ------------------------------------------------------------
# Run configs and colors
# ------------------------------------------------------------
runs = {
    "pantheon_only_fixedM":      ["Omega_m", "h"],
    "pantheon_plus_bao_fixedM":  ["Omega_m", "h"],
    "joint_bao_h0_varyM":        ["Omega_m", "h", "M"],
}

# ONLY used for corner plots (walker colors handled separately)
corner_colors = {
    "pantheon_only_fixedM":     "royalblue",
    "pantheon_plus_bao_fixedM": "darkorange",
    "joint_bao_h0_varyM":       "seagreen",
}



In [3]:
# ------------------------------------------------------------
# Safe loaders
# ------------------------------------------------------------
def load_chain_raw(label):
    return np.load(os.path.join(OUTDIR, f"{label}_chain_raw.npy"))

def load_chain_flat(label):
    return np.load(os.path.join(OUTDIR, f"{label}_chain_flat.npy"))

def load_acceptance(label):
    return np.load(os.path.join(OUTDIR, f"{label}_acceptance.npy"))

# guard for plotting huge flats
MAX_PLOT_SAMPLES = 200000
def thin_for_plot(flat):
    n = flat.shape[0]
    if n > MAX_PLOT_SAMPLES:
        idx = np.linspace(0, n-1, MAX_PLOT_SAMPLES).astype(int)
        return flat[idx]
    return flat

In [4]:
# ============================================================
# Diagnostics 
# ============================================================

def to_arviz_idata(chain_raw, params):
    nsteps, nwalkers, ndim = chain_raw.shape
    chains = chain_raw.transpose(1, 0, 2)  # (chains, draws, ndim)
    post = {p: chains[:, :, i] for i, p in enumerate(params)}
    return az.from_dict(posterior=post), nwalkers, nsteps

def diagnostics_from_raw(chain_raw, params):
    idata, nchains, ndraws = to_arviz_idata(chain_raw, params)
    total_draws = nchains * ndraws
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        rhat_ds = az.rhat(idata)
        ess_ds  = az.ess(idata, method="bulk")
    rhat_out, ess_out, tau_out = {}, {}, {}
    for p in params:
        rhat_out[p] = float(rhat_ds[p].values)
        ess_out[p]  = float(ess_ds[p].values)
        tau_out[p]  = float(total_draws / ess_out[p]) if ess_out[p] > 0 else np.nan
    return rhat_out, ess_out, tau_out

def posterior_summaries(flat, params):
    out = {}
    for i, p in enumerate(params):
        q16, q50, q84 = np.percentile(flat[:, i], [16, 50, 84])
        out[p] = (float(q16), float(q50), float(q84))
    return out

# ------------------------------------------------------------
# Trace plots with per-walker colors + y-zoom
# ------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import os

def make_trace_plot(chain_raw, label, params, max_steps=None, zoom_factor=0.0):
    """
    Make trace plots like in the emcee example:
    - all parameters stacked in subplots (one figure per run)
    - all walkers overlaid with transparency
    - optional y-axis zoom (quantile-based)
    - saved to PLOT_DIR as a single PNG per run
    """
    nsteps, nwalkers, ndim = chain_raw.shape

    # thin steps if requested
    if max_steps is not None and nsteps > max_steps:
        step_idx = np.linspace(0, nsteps - 1, max_steps).astype(int)
        chain_raw = chain_raw[step_idx]

    # wider, more spacious figure
    fig, axes = plt.subplots(ndim, 1, figsize=(16, 4 * ndim), sharex=True)

    if ndim == 1:
        axes = [axes]  # keep iterable

    # distinct colors for each walker
    cmap = plt.get_cmap("tab20")
    colors = [cmap(j % 20) for j in range(nwalkers)]

    for i, pname in enumerate(params):
        ax = axes[i]
        for j in range(nwalkers):
            ax.plot(chain_raw[:, j, i], color=colors[j], alpha=0.6, lw=1)

        ax.set_ylabel(pname, fontsize=12)
        ax.grid(alpha=0.3)

        # zoom if requested
        if zoom_factor > 0:
            qlow, qhigh = np.quantile(chain_raw[:, :, i], [zoom_factor / 2, 1 - zoom_factor / 2])
            ax.set_ylim(qlow, qhigh)

    axes[-1].set_xlabel("Step", fontsize=12)
    fig.suptitle(f"Trace plots: {label}", fontsize=16)

    outpath = os.path.join(PLOT_DIR, f"{label}_trace.png")
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(outpath, dpi=150, bbox_inches="tight")
    plt.close(fig)



# ------------------------------------------------------------
# Corner plots (styled, saved)
# ------------------------------------------------------------
def make_corner_plot(flat, label, params, posts=None, truths=None, title=None):
    flat_plot = flat

    # Default truths = posterior medians if not provided
    if truths is None and posts is not None:
        truths = [posts[p][1] for p in params]  # q50 for each param

    # Prettier LaTeX labels
    label_map = {
        "Omega_m": r"$\Omega_m$",
        "h":       r"$h$",
        "M":       r"$M$",
    }
    nice_labels = [label_map.get(p, p) for p in params]

    # Pick color from config (fallback = black)
    run_color = corner_colors.get(label, "black")

    # Apply burn-in and thin before plotting
    burn_in = 500  # or whatever you decide
    thin    = 10
    flat_plot = flat[burn_in::thin]

    # Auto-set nice ranges based on posterior percentiles
    ranges = []
    for i in range(flat_plot.shape[1]):
        low, high = np.percentile(flat_plot[:, i], [0.5, 99.5])
        ranges.append((low, high))

    fig = corner.corner(
        flat_plot,
        labels=nice_labels,
        truths=truths,
        show_titles=True,
        color=run_color,
        bins=30,
        smooth=0.8,
        smooth1d=0.8,
        fill_contours=True,
        levels=(0.68, 0.95),
        range=ranges,   # << zooms to main posterior
    )


    if title:
        plt.suptitle(title, fontsize=14)

    outpath = os.path.join(PLOT_DIR, f"{label}_corner.png")
    fig.savefig(outpath, dpi=200, bbox_inches="tight")
    plt.close(fig)

In [5]:

#https://getdist.readthedocs.io/en/latest/plot_gallery.html

In [6]:
# ============================================================
# Cell 3: Run per-run diagnostics and plots
# ============================================================

results = {}   # store results for later table
for label, params in runs.items():
    print("\n" + "="*60)
    print(f"Run: {label}")
    print("="*60)

    chain_raw = load_chain_raw(label)
    flat = load_chain_flat(label)
    acc  = load_acceptance(label)

    # quick shape check
    nsteps, nwalkers, ndim = chain_raw.shape
    assert ndim == len(params), f"ndim mismatch for {label}"

    # diagnostics
    rhat, ess, tau = diagnostics_from_raw(chain_raw, params)
    acc_mean = float(np.mean(acc))
    posts = posterior_summaries(flat, params)

    # print
    print(f"Acceptance mean: {acc_mean:.4f}, walkers: {nwalkers}, steps: {nsteps}")
    for p in params:
        print(f"{p}: Rhat={rhat[p]:.4f}, ESS={ess[p]:.0f}, tau~={tau[p]:.1f}")
    for p in params:
        q16,q50,q84 = posts[p]
        print(f"{p:8s} = {q50:.4f} (+{q84-q50:.4f}, -{q50-q16:.4f})")

    # save for consolidated table
    results[label] = {"rhat": rhat, "ess": ess, "tau": tau, "posts": posts, "acc_mean": acc_mean}

    # plots
    make_trace_plot(chain_raw, label, params, max_steps=5000)
    make_corner_plot(flat, label, params, posts=posts,
                 title=f"Posterior for {label}")




Run: pantheon_only_fixedM
Acceptance mean: 0.5309, walkers: 30, steps: 35000
Omega_m: Rhat=1.0004, ESS=78104, tau~=13.4
h: Rhat=1.0003, ESS=80499, tau~=13.0
Omega_m  = 0.3308 (+0.0195, -0.0189)
h        = 0.7085 (+0.0025, -0.0026)

Run: pantheon_plus_bao_fixedM
Acceptance mean: 0.4895, walkers: 30, steps: 35000
Omega_m: Rhat=1.0937, ESS=189, tau~=5566.6
h: Rhat=1.0939, ESS=188, tau~=5581.6
Omega_m  = 0.2670 (+0.0084, -0.0078)
h        = 0.7150 (+0.0018, -0.0019)

Run: joint_bao_h0_varyM
Acceptance mean: 0.4404, walkers: 27, steps: 33000
Omega_m: Rhat=1.1009, ESS=162, tau~=5509.2
h: Rhat=1.1012, ESS=160, tau~=5553.0
M: Rhat=1.1012, ESS=161, tau~=5547.7
Omega_m  = 0.2676 (+0.0084, -0.0088)
h        = 0.7176 (+0.0077, -0.0081)
M        = -19.2902 (+0.0263, -0.0243)


In [7]:
rows = []
for label, params in runs.items():
    meta = results[label]
    row = {"Run": label, "Acceptance": meta["acc_mean"]}
    for p in params:
        row[f"Rhat({p})"]  = meta["rhat"][p]
        row[f"ESS({p})"]   = meta["ess"][p]
        row[f"tau({p})"]   = meta["tau"][p]
        q16,q50,q84 = meta["posts"][p]
        row[f"{p}_median"] = q50
        row[f"{p}_+err"]   = q84 - q50
        row[f"{p}_-err"]   = q50 - q16
    rows.append(row)

df = pd.DataFrame(rows)
display(df)

Unnamed: 0,Run,Acceptance,Rhat(Omega_m),ESS(Omega_m),tau(Omega_m),Omega_m_median,Omega_m_+err,Omega_m_-err,Rhat(h),ESS(h),tau(h),h_median,h_+err,h_-err,Rhat(M),ESS(M),tau(M),M_median,M_+err,M_-err
0,pantheon_only_fixedM,0.530899,1.000381,78104.051852,13.443605,0.330782,0.019511,0.018916,1.000336,80498.577968,13.043709,0.708539,0.002537,0.002558,,,,,,
1,pantheon_plus_bao_fixedM,0.489467,1.093744,188.624146,5566.625607,0.266962,0.008426,0.007776,1.093909,188.116904,5581.635568,0.71496,0.00181,0.001908,,,,,,
2,joint_bao_h0_varyM,0.440435,1.100902,161.729623,5509.19481,0.267585,0.008418,0.008786,1.101249,160.452998,5553.028071,0.71756,0.007712,0.00807,1.101167,160.605673,5547.749241,-19.290246,0.026279,0.024284


In [8]:
# ============================================================
# Cell 4: Build summary table, save nicely
# ============================================================

rows = []
for label, params in runs.items():
    meta = results[label]
    row = {"Run": label, "Acceptance": meta["acc_mean"]}
    for p in params:
        row[f"Rhat({p})"]  = meta["rhat"][p]
        row[f"ESS({p})"]   = meta["ess"][p]
        row[f"tau({p})"]   = meta["tau"][p]
        q16,q50,q84 = meta["posts"][p]
        row[f"{p}_median"] = q50
        row[f"{p}_+err"]   = q84 - q50
        row[f"{p}_-err"]   = q50 - q16
    rows.append(row)

df = pd.DataFrame(rows)

# format numbers to .2f
df_fmt = df.copy()
for col in df_fmt.columns:
    if col != "Run":
        df_fmt[col] = df_fmt[col].map(lambda x: f"{x:.2f}")

display(df_fmt)

# save to multiple formats
df_fmt.to_csv(os.path.join(PLOT_DIR, "summary_table.csv"), index=False)

html = df_fmt.to_html(index=False, escape=False, justify="center")
with open(os.path.join(PLOT_DIR, "summary_table.html"), "w") as f:
    f.write(html)

# screenshot-friendly exports
dfi.export(df_fmt, os.path.join(PLOT_DIR, "summary_table.png"))
dfi.export(df_fmt, os.path.join(PLOT_DIR, "summary_table.pdf"))


Unnamed: 0,Run,Acceptance,Rhat(Omega_m),ESS(Omega_m),tau(Omega_m),Omega_m_median,Omega_m_+err,Omega_m_-err,Rhat(h),ESS(h),tau(h),h_median,h_+err,h_-err,Rhat(M),ESS(M),tau(M),M_median,M_+err,M_-err
0,pantheon_only_fixedM,0.53,1.0,78104.05,13.44,0.33,0.02,0.02,1.0,80498.58,13.04,0.71,0.0,0.0,,,,,,
1,pantheon_plus_bao_fixedM,0.49,1.09,188.62,5566.63,0.27,0.01,0.01,1.09,188.12,5581.64,0.71,0.0,0.0,,,,,,
2,joint_bao_h0_varyM,0.44,1.1,161.73,5509.19,0.27,0.01,0.01,1.1,160.45,5553.03,0.72,0.01,0.01,1.1,160.61,5547.75,-19.29,0.03,0.02


Error: It looks like you are using Playwright Sync API inside the asyncio loop.
Please use the Async API instead.

In [None]:
# ============================================================
# Cell 5: Combined GetDist comparison
# ============================================================

gd_samples = {}
for label, params in runs.items():
    flat = load_chain_flat(label)
    gd_samples[label] = MCSamples(samples=flat, names=params, labels=params, label=label)

# only include parameters common across runs
common_params = ["Omega_m", "h"]

g = plots.get_subplot_plotter()
g.triangle_plot(
    [gd_samples["pantheon_only_fixedM"],
     gd_samples["pantheon_plus_bao_fixedM"],
     gd_samples["joint_bao_h0_varyM"]],
    params=common_params,
    filled=True,
    legend_labels=["Pantheon only (fixed M)",
                   "Pantheon+BAO (fixed M)",
                   "Joint BAO+H0 (vary M)"],
)

g.export(os.path.join(PLOT_DIR, "combined_getdist_corner.png"))
