In [None]:
import vice

In [None]:
import numpy as np

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import surp
from dataclasses import dataclass

import surp.gce_math as gcem
import arya

In [None]:
import toml

In [None]:
# Install surp's yield settings into vice (called with default arguments).
surp.set_yields()

In [None]:
@dataclass
class MCMCResult:
    """Posterior samples plus binned tables for a single MCMC run.

    Build instances with ``from_file`` (runs under
    ../models/perturbations/mc_analysis/) or ``from_test_file`` (runs under
    ./mcmc_samples/).
    """

    params: dict  # parsed params.toml ({} for test runs)
    labels: list  # names of the fitted parameters (columns of ``samples``)
    samples: pd.DataFrame  # posterior draws after burn-in, plus derived columns
    afe: pd.DataFrame  # binned table along [Mg/Fe]
    ah: pd.DataFrame  # binned table along [Mg/H]

    @staticmethod
    def _drop_burn_in(samples, burn):
        """Return a copy of ``samples`` without draws from iterations < ``burn``.

        Uses the ``iteration`` column when present. Otherwise it drops the first
        ``burn`` rows (assumes rows are stored in draw order -- TODO confirm for
        test files).
        """
        if "iteration" in samples.columns:
            kept = samples[samples["iteration"] >= burn]
        else:
            kept = samples.iloc[burn:]
        # Copy so that adding derived columns below never writes into a slice.
        return kept.copy()

    @staticmethod
    def _add_derived_columns(samples, y0, y_a, zeta_a, const_col, lin_col):
        """Add AGB-fraction / total-yield columns to ``samples`` in place.

        ``const_col`` and ``lin_col`` name the constant and linear CC
        coefficients, which are stored scaled by 1e3 (hence the 1e-3 factors).

        - y0 given   -> ``f_agb``, ``y_tot``
        - y_a given  -> ``f_agb_a``, ``y_tot_a``; also ``zeta1_a`` when zeta_a is given
        """
        y_cc = samples[const_col] * 1e-3
        if y0 is not None:
            y_agb = samples["alpha"] * y0
            y_tot = y_agb + y_cc
            samples["f_agb"] = y_agb / y_tot
            samples["y_tot"] = y_tot

        if y_a is not None:
            y_agb = samples["alpha"] * y_a
            y_tot = y_agb + y_cc
            samples["f_agb_a"] = y_agb / y_tot
            samples["y_tot_a"] = y_tot
            # zeta1_a needs the AGB metallicity slope; skip it rather than fail on None.
            if zeta_a is not None:
                samples["zeta1_a"] = samples[lin_col] * 1e-3 + samples["alpha"] * zeta_a

    @classmethod
    def from_file(cls, modelname, y0=None, y_a=None, zeta_a=None, burn=0):
        """Load a run from ../models/perturbations/mc_analysis/<modelname>/.

        Reads params.toml, mcmc_samples.csv, mg_fe_binned.csv and mg_h_binned.csv.
        ``labels`` are the params.toml entries whose value is a table (dict).
        ``burn`` drops draws with ``iteration < burn``; ``y0``/``y_a``/``zeta_a``
        control the derived columns (see ``_add_derived_columns``).
        """
        modeldir = "../models/perturbations/mc_analysis/" + modelname + "/"

        with open(modeldir + "params.toml", "r") as f:
            params = toml.load(f)

        samples = cls._drop_burn_in(pd.read_csv(modeldir + "mcmc_samples.csv"), burn)
        print("length of samples = ", len(samples))

        cls._add_derived_columns(samples, y0, y_a, zeta_a, "y0_cc", "zeta_cc")

        afe = pd.read_csv(modeldir + "mg_fe_binned.csv")
        ah = pd.read_csv(modeldir + "mg_h_binned.csv")
        labels = [k for k, v in params.items() if isinstance(v, dict)]

        return cls(params=params, labels=labels, afe=afe, ah=ah, samples=samples)

    @classmethod
    def from_test_file(cls, modelname, y0=None, y_a=1e-3, zeta_a=-1e-3, burn=0):
        """Load a validation run from ./mcmc_samples/<modelname>.csv.

        The binned tables are the shared ./mcmc_samples/ah_binned.csv and
        afe_binned.csv; their ``x`` column is copied to ``_x`` for the plotting
        helpers. Parameters are fixed to alpha, zeta0, zeta1, zeta2. ``burn`` is
        applied like in ``from_file``.
        """
        modeldir = "./mcmc_samples/"

        samples = cls._drop_burn_in(pd.read_csv(modeldir + f"{modelname}.csv"), burn)
        print("length of samples = ", len(samples))

        cls._add_derived_columns(samples, y0, y_a, zeta_a, "zeta0", "zeta1")

        ah = pd.read_csv(modeldir + "ah_binned.csv")
        ah["_x"] = ah.x
        afe = pd.read_csv(modeldir + "afe_binned.csv")
        afe["_x"] = afe.x
        labels = ["alpha", "zeta0", "zeta1", "zeta2"]

        return cls(params={}, labels=labels, afe=afe, ah=ah, samples=samples)

In [None]:
from corner import corner

In [None]:
def plot_corner(result, labels=None, **kwargs):
    """Corner plot of the fitted parameters of ``result``.

    ``labels`` optionally maps each parameter name to an axis label (e.g. LaTeX).
    The 16/50/84 per cent quantiles are marked and shown in the titles. Any other
    keyword arguments go straight to ``corner.corner``.
    """
    names = result.labels
    axis_labels = names if labels is None else [labels[name] for name in names]

    corner(
        result.samples[names],
        show_titles=True,
        quantiles=[0.16, 0.5, 0.84],
        labels=axis_labels,
        **kwargs,
    )
    return None

In [None]:
def plot_samples_caah(mcmc_result, alpha=None, skip=10, color="black", **kwargs):
    """Overplot [C/Mg] vs [Mg/H] curves for every ``skip``-th posterior sample.

    Each curve is sum_k sample[k] * ah[k] over ``mcmc_result.labels``, passed
    through ``gcem.abund_ratio_to_brak(..., "c", "mg")``.

    Parameters
    ----------
    mcmc_result : MCMCResult-like object with ``ah``, ``labels`` and ``samples``.
    alpha : line opacity; defaults to 1 / (10 * N**(1/3)) with N the number of
        plotted samples.
    skip : thinning stride over the samples.
    color : line colour.
    **kwargs : extra arguments for ``plt.plot``, consistent with plot_samples_caafe.
    """
    ah = mcmc_result.ah
    labels = mcmc_result.labels
    samples = mcmc_result.samples[::skip]

    if alpha is None:
        alpha = 1 / len(samples)**(1/3) / 10

    for _, sample in samples.iterrows():
        y = np.sum([sample[label] * ah[label] for label in labels], axis=0)
        plt.plot(ah._x, gcem.abund_ratio_to_brak(y, "c", "mg"),
                 color=color, alpha=alpha, rasterized=True, **kwargs)

    plt.xlabel("[Mg/H]")
    plt.ylabel("[C/Mg]")

In [None]:
def plot_sample(sample, ah, labels, **kwargs):
    """Plot the [C/Mg] curve for one parameter set against ``ah._x``.

    ``sample`` maps each label to its coefficient. The curve is
    sum_k sample[k] * ah[k], passed through gcem.abund_ratio_to_brak. Extra
    keyword arguments go to ``plt.plot``.
    """
    components = [sample[name] * ah[name] for name in labels]
    combined = np.sum(components, axis=0)
    ratio = gcem.abund_ratio_to_brak(combined, "c", "mg")
    plt.plot(ah._x, ratio, **kwargs)


In [None]:
def plot_samples_caah_mean(mcmc_result,plot_obs=True, **kwargs):
    """Plot [C/Mg] vs [Mg/H] for the posterior-mean parameters.

    NOTE: ``plot_obs`` is accepted but ignored. Draw observations separately
    (e.g. plot_obs_caah). Extra keyword arguments go to ``plt.plot``.
    """
    binned = mcmc_result.ah
    names = mcmc_result.labels
    mean_params = np.mean(mcmc_result.samples, axis=0)

    total = np.sum([mean_params[name] * binned[name] for name in names], axis=0)
    plt.plot(binned._x, gcem.abund_ratio_to_brak(total, "c", "mg"), **kwargs)

    plt.xlabel("[Mg/H]")
    plt.ylabel("[C/Mg]")

In [None]:
def plot_obs_caah(mcmc_result, **kwargs):
    """Error-bar plot of the binned observed [C/Mg] vs [Mg/H].

    The error is obs_err / sqrt(obs_counts), converted to dex by dividing by
    obs * ln(10). Extra keyword arguments go to ``plt.errorbar``.
    """
    binned = mcmc_result.ah
    err_dex = binned.obs_err / binned.obs / np.log(10) / np.sqrt(binned.obs_counts)
    ratio = gcem.abund_ratio_to_brak(binned.obs, "c", "mg")
    plt.errorbar(binned._x, ratio, yerr=err_dex, fmt="o", **kwargs)


In [None]:
def plot_samples_caafe(mcmc_result, alpha=None, skip=10, color="black", **kwargs):
    """Overplot [C/Mg] vs [Mg/Fe] curves for every ``skip``-th posterior sample.

    Curves are built from the ``afe`` table like in plot_samples_caah. ``alpha``
    defaults to 1 / (10 * N**(1/3)) with N the number of plotted samples. Extra
    keyword arguments go to ``plt.plot``.
    """
    binned = mcmc_result.afe
    names = mcmc_result.labels
    thinned = mcmc_result.samples[::skip]

    if alpha is None:
        alpha = 1 / len(thinned)**(1/3) / 10

    for _, draw in thinned.iterrows():
        total = np.sum([draw[name] * binned[name] for name in names], axis=0)
        ratio = gcem.abund_ratio_to_brak(total, "c", "mg")
        plt.plot(binned._x, ratio, color=color, alpha=alpha, rasterized=True, **kwargs)

    plt.xlabel("[Mg/Fe]")
    plt.ylabel("[C/Mg]")

In [None]:
def plot_obs_caafe(mcmc_result, **kwargs):
    """Error-bar plot of the binned observed [C/Mg] vs [Mg/Fe].

    Uses the same error propagation as plot_obs_caah, applied to the ``afe`` table.
    """
    binned = mcmc_result.afe
    err_dex = binned.obs_err / binned.obs / np.log(10) / np.sqrt(binned.obs_counts)
    ratio = gcem.abund_ratio_to_brak(binned.obs, "c", "mg")
    plt.errorbar(binned._x, ratio, yerr=err_dex, fmt="o", **kwargs)


In [None]:
def plot_samples_caafe_mean(mcmc_result, plot_obs=True, **kwargs):
    """Plot [C/Mg] vs [Mg/Fe] for the posterior-mean parameters.

    NOTE: ``plot_obs`` is accepted but ignored. Extra keyword arguments go to
    ``plt.plot``.
    """
    binned = mcmc_result.afe
    names = mcmc_result.labels
    mean_params = np.mean(mcmc_result.samples, axis=0)

    total = np.sum([mean_params[name] * binned[name] for name in names], axis=0)
    plt.plot(binned._x, gcem.abund_ratio_to_brak(total, "c", "mg"), **kwargs)

    plt.xlabel("[Mg/Fe]")
    plt.ylabel("[C/Mg]")

In [None]:
def plot_fagb_hist(results):
    """Histogram the ``f_agb`` samples.

    The title shows the median with offsets to the 16th and 84th percentiles.
    """
    f_agb = results.samples["f_agb"]
    plt.hist(f_agb)
    plt.xlabel(r"$f_{\rm AGB}$")
    plt.ylabel("counts")

    lo, med, hi = np.quantile(f_agb, [0.16, 0.5, 0.84])
    minus, plus = med - lo, hi - med
    plt.title(f"${med:0.3f}_{{-{minus:0.3f}}}^{{+{plus:0.3f}}}$")

In [None]:
def find_model(name):
    """Activate the yields of model ``name`` and return its star table.

    Loads ../models/<name>/yield_params.toml with surp.YieldParams.from_file and
    installs it via surp.set_yields (a global side effect). Then returns
    ../models/<name>/stars.csv as a DataFrame indexed by its first column.
    """
    model_dir = "../models/" + name
    yield_params = surp.YieldParams.from_file(model_dir + "/yield_params.toml")
    surp.set_yields(yield_params)

    return pd.read_csv(model_dir + "/stars.csv", index_col=0)

In [None]:
def print_stats(result):
    """Print the median and 16th/84th-percentile offsets of each fitted parameter.

    Prints one tab-separated row per name in ``result.labels``.
    """
    print("parameter\t med\t 16th\t 84th")
    for name in result.labels:
        values = result.samples[name]
        med = np.median(values)
        lo, hi = np.quantile(values, [0.16, 0.84])
        below, above = lo - med, hi - med
        print(f"{name:8s}\t{med:6.3f}\t{below:6.3f}\t+{above:5.3f}")

In [None]:
from scipy.stats import binned_statistic

In [None]:
def plot_yields(result):
    """Scatter each fitted component's binned column of ``result.ah`` against ``ah._x``."""
    binned = result.ah
    for name in result.labels:
        plt.scatter(binned._x, binned[name], label=name)

    plt.xlabel("[M/H]")
    plt.ylabel("yield")
    arya.Legend(-1)

In [None]:
def plot_all(filename, y0=None, y_a=None, zeta_a=None, test=False, burn=0, skip=10):
    """Load one MCMC run, print its summary and draw the standard diagnostics.

    Parameters
    ----------
    filename : model directory name (``test=False``) or sample-file stem
        under ./mcmc_samples/ (``test=True``).
    y0, y_a, zeta_a : AGB yield values forwarded to the MCMCResult loader to
        derive the f_agb / y_tot / *_a columns.
    test : load with MCMCResult.from_test_file instead of from_file.
    burn : burn-in forwarded to the loader.
    skip : thinning stride for the sample-curve plots.

    Returns
    -------
    The loaded MCMCResult.
    """
    if test:
        # from_test_file has its own non-None defaults for y_a / zeta_a, so only
        # forward the values the caller actually supplied.
        overrides = {name: value for name, value in (("y_a", y_a), ("zeta_a", zeta_a))
                     if value is not None}
        result = MCMCResult.from_test_file(filename, y0=y0, burn=burn, **overrides)
    else:
        result = MCMCResult.from_file(filename, y0=y0, burn=burn, y_a=y_a, zeta_a=zeta_a)

    print_stats(result)
    plot_yields(result)
    plt.show()

    fig = plt.figure(figsize=(6, 6))
    plot_corner(result, fig=fig)
    plt.show()

    plot_samples_caah(result, skip=skip)
    plot_obs_caah(result, color=arya.COLORS[1])
    plt.show()

    plot_samples_caafe(result, skip=skip)
    plot_obs_caafe(result, color=arya.COLORS[1])
    plt.show()

    # f_agb only exists when y0 was supplied.
    if y0 is not None:
        plot_fagb_hist(result)

    return result

In [None]:
# Use the default AGB carbon interpolator with mass_factor=0.5 and display the
# AGB carbon yield that surp.yields.calc_y(kind="agb") reports for it.
vice.yields.agb.settings["c"] = surp.agb_interpolator.interpolator("c", mass_factor=0.5)
y0 = surp.yields.calc_y(kind="agb")
y0

In [None]:
# AGB carbon yield at Z = 0.016 for every AGB model in surp, printed in units of 1e-4.
# NOTE: leaves vice's AGB carbon setting on the last model in the loop.
for study in surp.AGB_MODELS:
    interp = surp.agb_interpolator.interpolator("c", study=study)
    vice.yields.agb.settings["c"] = interp
    y0 = surp.yields.calc_y(Z=0.016, kind="agb")
    in_1e4 = y0 / 1e-4
    print(f"{study:8s}\t{in_1e4:6.3f}")


# Body

## Main comparisons

In [None]:
# Main results keyed by short model name; the comparison cells below read from this dict.
results = {}

In [None]:
# Fiducial run: a single VICE simulation (see the note below).
results["fiducial"] = plot_all("fiducial", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
# Analytic run: composites several simulations; should agree with the fiducial run.
results["analytic"] = plot_all("analytic_quad", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
ana = results["analytic"]

# Hand-picked parameter sets around the analytic fit: alpha and y0_cc are fixed,
# zeta_cc (z1) and A_cc (z2) are varied. Columns: label, zeta_cc, A_cc, extra plot kwargs.
curves = [
    ("mean", 1.50, 2.65, {}),
    ("low z1", 1.36, 2.65, {"linestyle": "--"}),
    ("low z2", 1.50, 2.25, {"linestyle": "--"}),
    ("low z1 z2", 1.36, 2.25, {}),
    ("high z1 z2", 1.60, 3.0, {}),
]
for name, z1, z2, style in curves:
    pars = {"alpha": 0.49, "y0_cc": 2.28, "zeta_cc": z1, "A_cc": z2}
    plot_sample(pars, ana.ah, ana.labels, label=name, **style)

plot_obs_caah(ana, color="black")
arya.Legend(-1, color_only=True)

plt.xlabel("[Mg/H]")
plt.ylabel("[C/Mg]")

In [None]:
ana = results["analytic"]

# Same parameter sets as above (subset), now along [Mg/Fe].
for name, z1, z2 in [("mean", 1.50, 2.65), ("low z1 z2", 1.36, 2.25), ("high z1 z2", 1.60, 3.0)]:
    pars = {"alpha": 0.49, "y0_cc": 2.28, "zeta_cc": z1, "A_cc": z2}
    plot_sample(pars, ana.afe, ana.labels, label=name)

plot_obs_caafe(ana, color="black")
arya.Legend(-1, color_only=True)

plt.xlabel("[Mg/Fe]")
plt.ylabel("[C/Mg]")

The fiducial model should be exactly the same as the analytic model. The only difference is that the fiducial model uses a single VICE simulation, whereas the analytic model composites several different simulations. The analytic formulation makes it easier to add additional yield models, but the fiducial formulation makes exploring GCE uncertainties more efficient. In this notebook, all methods are read in and analyzed identically.

What happens if we remove the \[alpha/Fe\] component of the likelihood, i.e. fit only the \[C/Mg\]–\[Mg/H\] trend?

In [None]:
# Run fit without the [Mg/Fe] term (presumably, per the name "caah_only"); kept out of
# `results` so it does not enter the comparisons. Drops draws with iteration < 1000.
res_bad = plot_all("fiducial_caah_only", y0=1e-3, burn=1000, skip=100)

In [None]:
results["eta2"] = plot_all("eta2", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
results["lateburst"] = plot_all("lateburst", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
results["twoinfall"] = plot_all("twoinfall", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
# Diagnostic plots only; the result is not stored. The trailing ";" suppresses
# the (very large) MCMCResult repr as cell output.
plot_all("analytic_quad_m0.2", y0=1e-3, y_a=1e-3, zeta_a=-1e-3);

## AGB models

In [None]:
# Each AGB model gets its own y0 / y_a / zeta_a normalisation for the derived columns.
results["fruity"] = plot_all("fruity_quad", y0=3.229e-4, y_a=3.82e-4, zeta_a=-3.5e-4)

In [None]:
# No y_a / zeta_a here, so this result has no f_agb_a, y_tot_a or zeta1_a columns.
results["fruity_m0.5"] = plot_all("fruity_mf0.5_quad", y0=2.9e-4)

In [None]:
# Distribution of the `lp` column of the analytic run (`ana` is set in Main comparisons).
plt.hist(ana.samples.lp)

In [None]:
results["aton"] = plot_all("aton_quad", y0=0.285e-4, y_a=1.85e-4, zeta_a=-9.4e-5)

In [None]:
results["monash"] = plot_all("monash_quad", y0=3.444e-4, y_a=2.8e-4, zeta_a=-10.1e-4);

In [None]:
results["nugrid"] = plot_all("nugrid_quad", y0=10.95e-4, y_a=5.9e-4, zeta_a=-5.7e-4);

## Surveys

In [None]:
# Same fiducial setup fit to different survey samples.
results["v21"] = plot_all("fiducial_vincenzo21", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
results["gso"] = plot_all("fiducial_gso", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

In [None]:
results["galah"] = plot_all("fiducial_galah", y0=1e-3, y_a=1e-3, zeta_a=-1e-3)

## Test Cases

In [None]:
test_results = {}

# Validation runs: results key -> sample-file stem under ./mcmc_samples/.
# Each run is loaded and plotted exactly once.
TEST_RUNS = {
    "NUTS": "NUTS",
    "HMC": "HMC",
    "RWMH": "RWMH",
    "both_sigma": "both_sigma",
    "fine_bins": "fine_bins",
    "superfine_bins": "superfine_bins",
    "equal_num_bins": "equal_num_bins",
    "add_rand_scatter": "add_rand_scatter",
    "equalnum_fine": "equalnum_fine",
    "t_test": "t_test",
    "kstest_samples": "ks_test",
}
for key, run_name in TEST_RUNS.items():
    test_results[key] = plot_all(run_name, test=True)

# "2s" is the both_sigma run under the key used by the test comparison plots;
# reuse the already-loaded result.
test_results["2s"] = test_results["both_sigma"]

# Comparisons

## Main comparisons

In [None]:
corner_labels = {
    "alpha": r"$\alpha$",
    "y0_cc": r"$\zeta^{(0)}$",
    "zeta_cc": r"$\zeta^{(1)}$",
    "A_cc": r"$\zeta^{(2)}$",
}

fig = plt.figure(figsize=(3.3, 3.3))
plot_corner(results["analytic"], fig=fig, labels=corner_labels, labelpad=0.1)

plt.savefig("figures/mcmc_corner.pdf")

In [None]:
corner_labels = {
    "alpha": r"$\alpha$",
    "y0_cc": r"$\zeta^{(0)}$",
    "zeta_cc": r"$\zeta^{(1)}$",
    "A_cc": r"$\zeta^{(2)}$",
}

# Corner plot of the run without the [Mg/Fe] term, for comparison with the one above.
fig = plt.figure(figsize=(3.3, 3.3))
plot_corner(res_bad, fig=fig, labels=corner_labels, labelpad=0.1)

plt.savefig("figures/mcmc_corner_bad.pdf")

In [None]:
# Short alias for surp's AGB yield interpolator constructor, used to build Y_agbs below.
agb_interpolator = surp.agb_interpolator.interpolator

In [None]:
# Metallicity basis functions for the CC carbon yield. Posterior coefficients
# y0_cc, zeta_cc and A_cc multiply y_z0, y_z1 and y_z2 in plot_y_tot below.
y_z0 = lambda z: 1e-3
# Build the surp yield model once instead of on every vectorized call.
_bilog_cc = surp.yield_models.BiLogLin_CC(y0=0.002, zeta=0.001, y1=0)
y_z1 = np.vectorize(lambda z: -2*y_z0(z) + _bilog_cc(z))
y_z2 = np.vectorize(surp.yield_models.Quadratic_CC(y0=0, zeta=0, A=0.001, Z1=0.0016))

# AGB carbon yield models, keyed like `results`.
Y_agbs = {
    "fruity": agb_interpolator("c"),
    "fruity_m0.5": agb_interpolator("c", mass_factor=0.5),
    "aton": agb_interpolator("c", study="ventura13"),
    "monash": agb_interpolator("c", study="karakas16"),
    "nugrid": agb_interpolator("c", study="pignatari16"),
    "analytic": surp.yield_models.C_AGB_Model(y0=1e-3, zeta=1e-3, tau_agb=1, t_D=0.15)
}

In [None]:
# Inspect the FRUITY posterior samples, including the derived f_agb / y_tot columns.
results["fruity"].samples

In [None]:
# Metallicity grid shared by the total-yield plots below.
M_H = np.linspace(-0.5, 0.5, 1000)
Z = gcem.MH_to_Z(M_H)

surp.set_yields(verbose=False)
ys_fiducial = surp.yields.calc_y(Z)

# AGB-only carbon yield of each model on the same grid. This temporarily swaps
# vice's global AGB carbon setting, so restore the one set_yields installed.
_fiducial_agb_c = vice.yields.agb.settings["c"]
y_agbs = {}
try:
    for key, Y_agb in Y_agbs.items():
        print(Y_agb)
        vice.yields.agb.settings["c"] = Y_agb
        y_agbs[key] = surp.yields.calc_y(Z, kind="agb")
finally:
    vice.yields.agb.settings["c"] = _fiducial_agb_c


In [None]:
def plot_y_tot_mean(result, ys_a, M_H=M_H, **kwargs):
    """Plot the total C yield of the posterior median, relative to the CC Mg yield.

    y_C = y0_cc*y_z0 + zeta_cc*y_z1 + A_cc*y_z2 + alpha*ys_a, evaluated on the
    ``M_H`` grid. ``ys_a`` is the AGB yield on that grid. Extra keyword arguments
    go to ``plt.plot``.
    """
    Z = gcem.MH_to_Z(M_H)
    basis0, basis1, basis2 = y_z0(Z), y_z1(Z), y_z2(Z)
    y_mg = vice.yields.ccsne.settings["mg"]

    med = result.samples.median()
    y_c = med.y0_cc * basis0 + med.zeta_cc * basis1 + med.A_cc * basis2 + med.alpha * ys_a
    plt.plot(M_H, y_c / y_mg, **kwargs)


In [None]:
# Posterior medians of every sample column of the analytic run.
results["analytic"].samples.median()

In [None]:
def plot_y_tot(result, ys_a, thin=10, M_H=M_H, color="black", alpha=None):
    """Overplot the total C yield (relative to the CC Mg yield) for thinned samples.

    Uses the same yield decomposition as plot_y_tot_mean, for every ``thin``-th
    draw. ``alpha`` defaults to 1 / (10 * N**(1/3)) with N the total (unthinned)
    number of samples.
    """
    draws = result.samples
    Z = gcem.MH_to_Z(M_H)
    if alpha is None:
        alpha = 1 / len(draws)**(1/3) / 10

    basis0, basis1, basis2 = y_z0(Z), y_z1(Z), y_z2(Z)
    y_mg = vice.yields.ccsne.settings["mg"]

    for _, s in draws[::thin].iterrows():
        y_c = s.y0_cc * basis0 + s.zeta_cc * basis1 + s.A_cc * basis2 + s.alpha * ys_a
        plt.plot(M_H, y_c / y_mg, color=color, alpha=alpha, rasterized=True)
    





In [None]:
# Quick look: total-yield curves for the Monash run.
plot_y_tot(results["monash"], y_agbs["monash"])

In [None]:
# AGB-model comparison; the eta2 / lateburst / twoinfall variants are compared separately below.
plot_labels = {name: name for name in ("fruity", "monash", "nugrid", "aton", "analytic")}

In [None]:

linestyles = ["-", ":", "--", "-."]

plt.figure()
# Faint thinned FRUITY curves underneath the per-model median curves.
plot_y_tot(results["fruity"], y_agbs["fruity"], thin=30, alpha=0.01, color=arya.COLORS[0])

for i, (key, label) in enumerate(plot_labels.items()):
    result = results[key]
    if key not in y_agbs:
        print("warning, no agb for ", key)
        y_agb = y_agbs["analytic"]
    else:
        y_agb = y_agbs[key]
    plot_y_tot_mean(result, y_agb, color=arya.COLORS[i], label=label,
                    ls=linestyles[i % len(linestyles)])

plt.xlabel(r"$\log Z / Z_\odot$")
plt.ylabel(r"$y_{\rm C}$")

plt.legend()
plt.savefig("figures/mcmc_y_tot.pdf")

In [None]:
# Model variants compared against the analytic baseline.
plot_labels = dict(
    analytic="analytic",
    eta2=r"$y\rightarrow 2y$",
    lateburst="lateburst",
    twoinfall="twoinfall",
)

In [None]:

linestyles = ["-", ":", "--", "-."]

plt.figure()
# Faint thinned analytic curves underneath the per-model median curves.
plot_y_tot(results["analytic"], y_agbs["analytic"], thin=30, alpha=0.01, color=arya.COLORS[0])

for i, (key, label) in enumerate(plot_labels.items()):
    result = results[key]
    if key not in y_agbs:
        print("warning, no agb for ", key)
        y_agb = y_agbs["analytic"]
    else:
        y_agb = y_agbs[key]
    plot_y_tot_mean(result, y_agb, color=arya.COLORS[i], label=label,
                    ls=linestyles[i % len(linestyles)])

plt.xlabel(r"$\log Z / Z_\odot$")
plt.ylabel(r"$y_{\rm C}$")

plt.legend()


In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 1, figsize=(3, 6), sharex="col", sharey=True, gridspec_kw={"hspace": 0})

# One panel per model showing its thinned total-yield curves.
for i, (key, label) in enumerate(plot_labels.items()):
    result = results[key]
    plt.sca(axs[i])
    if key not in y_agbs:
        print("warning, no agb for ", key)
        y_agb = y_agbs["analytic"]
    else:
        y_agb = y_agbs[key]

    plot_y_tot(result, y_agb, thin=10, alpha=0.01)
    plt.ylabel(label)

fig.supylabel(r"$y_{\rm C} / y_{\rm Mg}$")

plt.xlabel(r"$\log Z / Z_\odot$")
plt.ylim(3, 8)
plt.tight_layout()
plt.savefig("figures/mcmc_ytot.pdf")

In [None]:
# AGB models first (rows 0-4), then the model variants (rows 5-7).
plot_labels = dict(
    fruity="fruity",
    monash="monash",
    nugrid="nugrid",
    aton="aton",
    analytic="analytic",
    eta2=r"$y\rightarrow 2y$",
    lateburst="lateburst",
    twoinfall="twoinfall",
)

In [None]:
def compare_param_hists(results, plot_labels, var):
    """Stack one step-histogram per model for the sample column ``var``.

    Parameters
    ----------
    results : dict mapping key -> object with a ``samples`` DataFrame.
    plot_labels : dict mapping key -> row label, in plotting order. The special
        key "hline" draws a separator row whose value is used as the line colour.
    var : sample column to histogram.

    Returns
    -------
    (fig, axs) with ``axs`` a 1-D array of the row Axes.
    """
    Nr = len(plot_labels)
    # squeeze=False always returns a 2-D array, so indexing also works for one model.
    fig, axs = plt.subplots(Nr, 1, figsize=(3, 2), sharex="col",
                            gridspec_kw={"hspace": 0}, squeeze=False)
    axs = axs[:, 0]

    for i, (key, label) in enumerate(plot_labels.items()):
        ax = axs[i]
        plt.sca(ax)

        if key == "hline":
            # Separator row: dotted line, no axis decoration.
            ax.spines[['bottom', 'top']].set_visible(False)
            plt.axhline(0.5, color=label, linestyle=":")
            ax.xaxis.set_visible(False)
            ax.set_yticks([])
            ax.set_yticks([], minor=True)
            continue

        color = arya.COLORS[i]
        result = results[key]
        plt.hist(result.samples[var], histtype="step", color=color, ls="-")
        plt.ylabel(label, rotation=0, ha="right", va="center")

        # Merge stacked rows by hiding spines/ticks between them; a lone panel
        # keeps its full frame and x axis.
        if Nr > 1:
            if i == 0:
                ax.spines[['bottom']].set_visible(False)
                ax.tick_params(axis='x', bottom=False, which="both")
            elif i == Nr - 1:
                ax.spines[['top']].set_visible(False)
                ax.tick_params(axis='x', top=False, which="both")
            else:
                ax.spines[['bottom', 'top']].set_visible(False)
                ax.xaxis.set_visible(False)

        ax.set_yticks([])
        ax.set_yticks([], minor=True)

    plt.sca(axs[-1])
    plt.xlabel(var)
    return fig, axs

In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 1, figsize=(3, 2), sharex="col", gridspec_kw={"hspace": 0})

for i, (key, label) in enumerate(plot_labels.items()):
    ax = axs[i]
    plt.sca(ax)

    if key == "hline":
        # Separator row: dotted line coloured by the dict value.
        ax.spines[['bottom', 'top']].set_visible(False)
        plt.axhline(0.5, color=label, linestyle=":")
        ax.xaxis.set_visible(False)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        continue

    # Rows 0-3: arya palette; row 4: solid black; later rows: dashed black.
    color = arya.COLORS[i] if i < 4 else "k"
    ls = "--" if i > 4 else "-"

    result = results[key]
    plt.hist(result.samples.f_agb, histtype="step", color=color, ls=ls)
    plt.ylabel(label, rotation=0, ha="right", va="center")

    # Merge the stacked rows by hiding the inner spines and ticks.
    if i == 0:
        ax.spines[['bottom']].set_visible(False)
        ax.tick_params(axis='x', bottom=False, which="both")
    elif i == Nr - 1:
        ax.spines[['top']].set_visible(False)
        ax.tick_params(axis='x', top=False, which="both")
    else:
        ax.spines[['bottom', 'top']].set_visible(False)
        ax.xaxis.set_visible(False)

    ax.set_yticks([])
    ax.set_yticks([], minor=True)

plt.sca(axs[-1])
plt.xlabel("f agb")
plt.xlim(0, 0.5)

plt.tight_layout()
plt.savefig("figures/mcmc_fagb.pdf")

In [None]:
def plot_all_params(plot_labels, source=None, params=("f_agb", "y0_cc", "zeta_cc", "A_cc")):
    """Show one compare_param_hists figure per parameter.

    Parameters
    ----------
    plot_labels : dict of results key -> row label (see compare_param_hists).
    source : dict of results to draw from; defaults to the notebook-level ``results``.
    params : sample columns to compare, one figure each.
    """
    if source is None:
        source = results
    for param in params:
        compare_param_hists(source, plot_labels, param)
        plt.show()

In [None]:
# Survey comparison ("gso" is left out of this figure).
plot_labels = dict(fiducial="subgiants", galah="galah", v21="v21")
plot_all_params(plot_labels)

In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 1, figsize=(3, 2), sharex="col", gridspec_kw={"hspace": 0})

for i, (key, label) in enumerate(plot_labels.items()):
    ax = axs[i]
    plt.sca(ax)

    if key == "hline":
        # Separator row: dotted line coloured by the dict value.
        ax.spines[['bottom', 'top']].set_visible(False)
        plt.axhline(0.5, color=label, linestyle=":")
        ax.xaxis.set_visible(False)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        continue

    # Rows 0-3: arya palette; row 4: solid black; later rows: dashed black.
    color = arya.COLORS[i] if i < 4 else "k"
    ls = "--" if i > 4 else "-"

    result = results[key]
    plt.hist(result.samples.f_agb_a, histtype="step", color=color, ls=ls)
    plt.ylabel(label, rotation=0, ha="right", va="center")

    # Merge the stacked rows by hiding the inner spines and ticks.
    if i == 0:
        ax.spines[['bottom']].set_visible(False)
        ax.tick_params(axis='x', bottom=False, which="both")
    elif i == Nr - 1:
        ax.spines[['top']].set_visible(False)
        ax.tick_params(axis='x', top=False, which="both")
    else:
        ax.spines[['bottom', 'top']].set_visible(False)
        ax.xaxis.set_visible(False)

    ax.set_yticks([])
    ax.set_yticks([], minor=True)

plt.sca(axs[-1])
plt.xlabel("f agb")
plt.xlim(0, 0.5)

plt.tight_layout()

In [None]:
# Inspect the analytic run's samples, including the derived *_a columns.
results["analytic"].samples

In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 1, figsize=(3, 2), sharex="col", gridspec_kw={"hspace": 0})

for i, (key, label) in enumerate(plot_labels.items()):
    ax = axs[i]
    plt.sca(ax)

    if key == "hline":
        # Separator row: dotted line coloured by the dict value.
        ax.spines[['bottom', 'top']].set_visible(False)
        plt.axhline(0.5, color=label, linestyle=":")
        ax.xaxis.set_visible(False)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        continue

    # Rows 0-3: arya palette; row 4: solid black; later rows: dashed black.
    color = arya.COLORS[i] if i < 4 else "k"
    ls = "--" if i > 4 else "-"

    result = results[key]
    plt.hist(result.samples.zeta1_a, histtype="step", color=color, ls=ls)
    # Row label drawn in the histogram colour.
    plt.ylabel(label, rotation=0, ha="right", va="center", color=color)

    # Merge the stacked rows by hiding the inner spines and ticks.
    if i == 0:
        ax.spines[['bottom']].set_visible(False)
        ax.tick_params(axis='x', bottom=False, which="both")
    elif i == Nr - 1:
        ax.spines[['top']].set_visible(False)
        ax.tick_params(axis='x', top=False, which="both")
    else:
        ax.spines[['bottom', 'top']].set_visible(False)
        ax.xaxis.set_visible(False)

    ax.set_yticks([])
    ax.set_yticks([], minor=True)

plt.sca(axs[-1])
plt.xlabel("zeta1")

plt.tight_layout()

In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 1, figsize=(3, 2), sharex="col", gridspec_kw={"hspace": 0})

for i, (key, label) in enumerate(plot_labels.items()):
    ax = axs[i]
    plt.sca(ax)

    if key == "hline":
        # Separator row: dotted line coloured by the dict value.
        ax.spines[['bottom', 'top']].set_visible(False)
        plt.axhline(0.5, color=label, linestyle=":")
        ax.xaxis.set_visible(False)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        continue

    # Rows 0-3: arya palette; row 4: solid black; later rows: dashed black.
    color = arya.COLORS[i] if i < 4 else "k"
    ls = "--" if i > 4 else "-"

    result = results[key]
    y_mg = vice.yields.ccsne.settings["mg"]
    # y_tot expressed in units of the CC Mg yield.
    plt.hist(result.samples.y_tot / y_mg, histtype="step", color=color, ls=ls)
    plt.ylabel(label, rotation=0, ha="right", va="center")

    # Merge the stacked rows by hiding the inner spines and ticks.
    if i == 0:
        ax.spines[['bottom']].set_visible(False)
        ax.tick_params(axis='x', bottom=False, which="both")
    elif i == Nr - 1:
        ax.spines[['top']].set_visible(False)
        ax.tick_params(axis='x', top=False, which="both")
    else:
        ax.spines[['bottom', 'top']].set_visible(False)
        ax.xaxis.set_visible(False)

    ax.set_yticks([])
    ax.set_yticks([], minor=True)

plt.sca(axs[-1])
plt.xlabel("y0 / ymg")

plt.tight_layout()


In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 1, figsize=(3, 2), sharex="col", gridspec_kw={"hspace": 0})

for i, (key, label) in enumerate(plot_labels.items()):
    ax = axs[i]
    plt.sca(ax)

    if key == "hline":
        # Separator row: dotted line coloured by the dict value.
        ax.spines[['bottom', 'top']].set_visible(False)
        plt.axhline(0.5, color=label, linestyle=":")
        ax.xaxis.set_visible(False)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        continue

    # Rows 0-3: arya palette; row 4: solid black; later rows: dashed black.
    color = arya.COLORS[i] if i < 4 else "k"
    ls = "--" if i > 4 else "-"

    result = results[key]
    y_mg = vice.yields.ccsne.settings["mg"]
    # y_tot_a expressed in units of the CC Mg yield.
    plt.hist(result.samples.y_tot_a / y_mg, histtype="step", color=color, ls=ls)
    plt.ylabel(label, rotation=0, ha="right", va="center")

    # Merge the stacked rows by hiding the inner spines and ticks.
    if i == 0:
        ax.spines[['bottom']].set_visible(False)
        ax.tick_params(axis='x', bottom=False, which="both")
    elif i == Nr - 1:
        ax.spines[['top']].set_visible(False)
        ax.tick_params(axis='x', top=False, which="both")
    else:
        ax.spines[['bottom', 'top']].set_visible(False)
        ax.xaxis.set_visible(False)

    ax.set_yticks([])
    ax.set_yticks([], minor=True)

plt.sca(axs[-1])
plt.xlabel("y0 / ymg")

plt.tight_layout()

In [None]:
# AGB-model comparison for the predicted-trend figure below.
plot_labels = {name: name for name in ("fruity", "monash", "nugrid", "aton", "analytic")}

In [None]:

fig, axs = plt.subplots(1, 2, figsize=(7, 3), sharex="col", sharey=True,  gridspec_kw={"wspace": 0, "hspace": 0})

# Left: [C/Mg] vs [Mg/H]. Faint FRUITY samples, per-model mean curves, observations.
plt.sca(axs[0])
plot_samples_caah(results["fruity"], color=arya.COLORS[0], alpha=0.01, skip=30)
for i, (key, label) in enumerate(plot_labels.items()):
    plot_samples_caah_mean(results[key], label=label, color=arya.COLORS[i])

plot_obs_caah(results["analytic"], color="black")

# Right: [C/Mg] vs [Mg/Fe]. Use the same colour per model as the left panel,
# since the single legend (drawn here) must describe both panels.
plt.sca(axs[1])
plot_samples_caafe(results["fruity"], color=arya.COLORS[0], alpha=0.01, skip=30)
for i, (key, label) in enumerate(plot_labels.items()):
    plot_samples_caafe_mean(results[key], label=label, color=arya.COLORS[i])
plot_obs_caafe(results["analytic"], color="k")

plt.ylabel("")
plt.legend()
plt.savefig("figures/mcmc_caahfe_predicted.pdf")

In [None]:
# Inspect the binned [Mg/Fe] table of the GALAH run.
results["galah"].afe

In [None]:
plot_labels = {
    "fiducial": "subgiants",
    "galah": "galah",
    "v21": "v21",
}


fig, axs = plt.subplots(1, 2, figsize=(7, 3), sharex="col", sharey=True,  gridspec_kw={"wspace": 0, "hspace": 0})

# Left: [C/Mg] vs [Mg/H]. Faint fiducial samples plus the per-survey mean curves.
plt.sca(axs[0])
plot_samples_caah(results["fiducial"], color=arya.COLORS[0], alpha=0.01, skip=30)
for i, (key, label) in enumerate(plot_labels.items()):
    plot_samples_caah_mean(results[key], label=label, color=arya.COLORS[i])

# Right: [C/Mg] vs [Mg/Fe], colour-matched to the left panel so the legend applies to both.
plt.sca(axs[1])
plot_samples_caafe(results["fiducial"], color=arya.COLORS[0], alpha=0.01, skip=30)
for i, (key, label) in enumerate(plot_labels.items()):
    plot_samples_caafe_mean(results[key], label=label, color=arya.COLORS[i])
#plot_obs_caafe(results["analytic"], color="k")
plt.ylabel("")
plt.legend()


In [None]:
Nr = len(plot_labels)
fig, axs = plt.subplots(Nr, 2, figsize=(6, 9), sharex="col", sharey=True, gridspec_kw={"hspace": 0, "wspace": 0})

# One row per model: [C/Mg] vs [Mg/H] on the left, vs [Mg/Fe] on the right.
for i, (key, label) in enumerate(plot_labels.items()):
    left, right = axs[i]
    plt.sca(left)
    if key == "hline":
        continue

    result = results[key]
    plot_samples_caah(result)
    plt.ylabel(label, rotation=0, ha="right", va="center")

    plt.sca(right)
    plot_samples_caafe(result)
    plt.ylabel("")

plt.yticks([-0.05, -0.10, -0.15, -0.20])
plt.ylim(-0.2, 0)

fig.supylabel("[C/Mg]")
plt.tight_layout()


In [None]:
# Names of the fitted parameters of the analytic run.
results["analytic"].labels

## Tabulated properties

In [None]:
def calc_χ2(result, median=False, normalized=False):
    """Total χ² of ``result`` against both binned tables (``ah`` and ``afe``).

    With ``median=True`` the model is evaluated only at the posterior median
    (one value). Otherwise it is evaluated per posterior sample. With
    ``normalized=True`` the total is divided by the degrees of freedom:
    (#ah bins + #afe bins) - #fitted parameters.
    """
    params = result.samples.median() if median else result.samples

    total = (calc_χ2_binned(result.ah, params, result.labels)
             + calc_χ2_binned(result.afe, params, result.labels))

    if not normalized:
        return total
    dof = (len(result.ah) + len(result.afe)) - len(result.labels)
    return total / dof
        

In [None]:
# Per-sample χ² of the analytic run against its [Mg/H] table. Use `ana`'s own
# labels rather than a leftover loop variable.
# NOTE: calc_χ2_binned is defined in the next cell; run that first on a fresh kernel.
calc_χ2_binned(ana.ah, ana.samples, ana.labels)

In [None]:
def calc_χ2_binned(binned_data, samples, labels):
    """χ² of a linear-combination model against one binned table.

    Parameters
    ----------
    binned_data : DataFrame with columns ``obs``, ``obs_err``, ``obs_counts`` and,
        for every label, ``<label>`` and ``<label>_err``, plus ``_counts``.
    samples : Series (one parameter set -> scalar χ²) or DataFrame (one χ² per
        row) giving the coefficient of each label.
    labels : component names to combine.

    In bin i the model is sum_k c_k * m_k[i]. Its variance is
    sum_k c_k**2 * err_k[i]**2 / _counts[i]: scaling a quantity by c scales its
    variance by c**2. The observational variance is obs_err[i]**2 / obs_counts[i].
    """
    y_obs = binned_data["obs"].to_numpy()
    var_obs = (binned_data["obs_err"] ** 2 / binned_data["obs_counts"]).to_numpy()
    model_counts = binned_data["_counts"].to_numpy()

    # Positional arrays, so the table may have any index.
    components = {label: binned_data[label].to_numpy() for label in labels}
    comp_var = {label: binned_data[f"{label}_err"].to_numpy() ** 2 / model_counts
                for label in labels}
    coeffs = {label: np.asarray(samples[label], dtype=float) for label in labels}

    χ2 = 0
    for i in range(len(y_obs)):
        y_model = sum(components[label][i] * coeffs[label] for label in labels)
        var_model = sum(comp_var[label][i] * coeffs[label] ** 2 for label in labels)
        χ2 += (y_obs[i] - y_model) ** 2 / (var_obs[i] + var_model)

    return χ2
    

In [None]:
# Un-normalised total χ² of the analytic run at its posterior median.
calc_χ2(ana, median=True)

In [None]:

# Rough consistency check: rescaled median χ² (-χ²/5 + 10) against the 20th
# percentile of each run's `lp` samples.
chi2s = [-calc_χ2(res, median=True, normalized=False) / 5 + 10 for res in results.values()]
lps = [np.quantile(res.samples.lp, 0.2) for res in results.values()]

plt.scatter(chi2s, np.array(lps))


In [None]:

# Reduced χ²: value at the median parameters, then median and 16/84 per cent
# offsets of the per-sample distribution (also histogrammed).
for label, result in results.items():
    χ2 = calc_χ2(result, median=False, normalized=True)
    χ2_median = calc_χ2(result, median=True, normalized=True)

    med = np.median(χ2)
    lo, hi = np.quantile(χ2, [0.16, 0.84])
    print(f"{label:16}{χ2_median:8.2f} {med:8.2f} -{med-lo:8.2f} +{hi-med:8.2f}")

    plt.hist(χ2, label=label, histtype="step")

plt.xlabel(r"$\bar\chi^2$")
arya.Legend(-1)

In [None]:
keys = ana.labels + ["f_agb_a", "y_tot_a", "zeta1_a"]
y_mg = vice.yields.ccsne.settings["mg"]
latex_table = ""

# LaTeX rejects a bare "_" outside math mode, and both model and column names contain it.
header_cols = [key.replace("_", r"\_") for key in keys]
print(f"{'model':16} & $\\chi^2$  & $\\log p$ & " + " & ".join(header_cols) + r"\\")
print("\\hline\\\\")

for label, result in results.items():
    χ2 = calc_χ2(result, median=True, normalized=True)
    lp = np.max(result.samples.lp)

    # Row start: model name, reduced χ² at the median, maximum lp.
    row_name = label.replace("_", r"\_")
    latex_table += f"{row_name:16} & {χ2:8.1f} & {lp:8.2f} & "

    # Median with asymmetric 16/84 per cent errors for each parameter.
    parameter_cells = []
    for key in keys:
        if key not in result.samples.columns:
            # Runs loaded without y_a / zeta_a (e.g. "fruity_m0.5") lack the *_a columns.
            parameter_cells.append("--")
            continue
        x = result.samples[key]
        if key in ["y_tot_a", "zeta1_a"]:
            # Quote these yields relative to the CC Mg yield.
            x = x / y_mg
        median = np.median(x)
        lower, upper = np.quantile(x, [0.16, 0.84])
        parameter_cells.append(f"${median:.2f}^{{+{upper - median:.2f}}}_{{-{median - lower:.2f}}}$")

    latex_table += "  &  ".join(parameter_cells)
    latex_table += "\\\\ \n"

print(latex_table)

## Test comparisons

In [None]:
def plot_all_params(plot_labels, source=None, params=("alpha", "zeta0", "zeta1", "zeta2")):
    """Show one compare_param_hists figure per parameter of the validation runs.

    This redefinition replaces the earlier plot_all_params for the cells below.

    Parameters
    ----------
    plot_labels : dict of test_results key -> row label (see compare_param_hists).
    source : dict of results to draw from; defaults to the notebook-level ``test_results``.
    params : sample columns to compare, one figure each.
    """
    if source is None:
        source = test_results
    for param in params:
        compare_param_hists(source, plot_labels, param)
        plt.show()

In [None]:
# Sampler comparison (plus the synthetic-scatter run).
plot_labels = {
    "NUTS": "NUTS", 
    "HMC": "HMC",
    "RWMH": "RWMH",
    "add_rand_scatter": "syn scatter",
}

In [None]:
plot_all_params(plot_labels)

In [None]:
# Binning choices, relative to the NUTS baseline.
plot_labels = {
    "NUTS": "fiducial", 
    "fine_bins": "fine_bins",
    "superfine_bins": "superfine_bins",
    "equal_num_bins": "equal number",
    "equalnum_fine": "equal number fine",
}

In [None]:
plot_all_params(plot_labels)

In [None]:
# Likelihood / uncertainty treatments, relative to the NUTS baseline.
plot_labels = {
    "NUTS": "fiducial", 
    "t_test": "t test",
    "2s": "both uncertanties",
}

In [None]:
plot_all_params(plot_labels)

# Validation

In [None]:
# Reload the analytic run (no burn-in, no derived columns) for the validation plots below.
# Overwrites the notebook-level `result`.
result = MCMCResult.from_file("analytic_quad")

In [None]:
# Column-wise medians as a bare array (column names are lost; .median() keeps them).
np.median(result.samples, axis=0)

In [None]:

# NOTE(review): `mz_stars` is not defined anywhere in this notebook; this only runs
# with leftover kernel state. TODO: load the star catalogue explicitly.
# Low-alpha stars only (assumes `high_alpha` is a boolean column -- confirm).
df = mz_stars[~mz_stars.high_alpha]

mg_h_bins = np.arange(-0.5, 0.31, 0.1)

bin_mids = (mg_h_bins[1:] + mg_h_bins[:-1])/2
# Per-bin mean and standard deviation (scatter, not error of the mean) of [C/Mg].
ss = binned_statistic(df.MG_H, df.C_MG, bins=mg_h_bins, statistic="mean").statistic
se = binned_statistic(df.MG_H, df.C_MG, bins=mg_h_bins, statistic="std").statistic

In [None]:
# Model [C/Mg] vs [Mg/H] samples against the directly binned stars (error bar = scatter).
plot_samples_caah(result)

plt.errorbar(bin_mids, ss, yerr=se, fmt="o")

In [None]:

# NOTE(review): `mz_stars` is undefined in this notebook (see above).
# Stars with -0.15 < [Mg/H] < -0.05, binned in [Mg/Fe].
df = mz_stars[(mz_stars.MG_H > -0.15 ) & (mz_stars.MG_H < -0.05)]

# Despite the name, these edges bin [Mg/Fe] (MG_FE), not [Mg/H].
mg_h_bins = np.arange(0, 0.31, 0.05)

bin_mids = (mg_h_bins[1:] + mg_h_bins[:-1])/2
ss = binned_statistic(df.MG_FE, df.C_MG, bins=mg_h_bins, statistic="mean").statistic
se = binned_statistic(df.MG_FE, df.C_MG, bins=mg_h_bins, statistic="std").statistic

In [None]:
# Model [C/Mg] vs [Mg/Fe] samples against the directly binned stars.
plot_samples_caafe(result)

plt.errorbar(bin_mids, ss, yerr=se, fmt="o")