In [None]:
import vice

In [None]:
import numpy as np

In [None]:
import os
from dataclasses import dataclass

import pandas as pd
import matplotlib.pyplot as plt
import surp

import surp.gce_math as gcem
import arya

In [None]:
import toml

In [None]:
# Called with no YieldParams, presumably resetting surp's yield settings to its
# defaults before any analysis below (find_model later applies model-specific ones).
surp.set_yields()

In [None]:
@dataclass
class MCMCResult:
    """Posterior samples and binned tables written by one MCMC run.

    Instances are normally built with :meth:`from_file`, which reads the
    files stored for a single model inside the MCMC analysis directory.
    """

    # Parsed contents of the run's params.toml.
    params: dict
    # Parameter names: the top-level keys of ``params``, in file order. Used to
    # select columns of ``samples`` and of the binned tables.
    labels: list
    # Posterior samples read from mcmc_samples.csv; columns include every name in ``labels``.
    samples: pd.DataFrame
    # Binned table read from mg_fe_binned.csv (x column ``_x`` is [Mg/Fe]).
    afe: pd.DataFrame
    # Binned table read from mg_h_binned.csv (x column ``_x`` is [Mg/H]).
    ah: pd.DataFrame

    @classmethod
    def from_file(cls, modelname, base_dir="../models/perturbations/mc_analysis/"):
        """Load the MCMC outputs stored for ``modelname``.

        Parameters
        ----------
        modelname : str
            Name of the model's subdirectory inside ``base_dir``.
        base_dir : str, optional
            Directory containing one subdirectory per model. The default is the
            notebook-relative path that was previously hard-coded.

        Returns
        -------
        MCMCResult
        """
        modeldir = os.path.join(base_dir, modelname)

        with open(os.path.join(modeldir, "params.toml"), "r") as f:
            params = toml.load(f)

        samples = pd.read_csv(os.path.join(modeldir, "mcmc_samples.csv"))
        afe = pd.read_csv(os.path.join(modeldir, "mg_fe_binned.csv"))
        ah = pd.read_csv(os.path.join(modeldir, "mg_h_binned.csv"))
        # Parameter order follows the key order of params.toml.
        labels = list(params.keys())

        return cls(params=params, labels=labels, afe=afe, ah=ah, samples=samples)
    

In [None]:
from corner import corner

In [None]:
def plot_corner(result, **kwargs):
    """Draw a corner plot of the posterior samples for every parameter in ``result.labels``.

    Titles are shown and the 16th/50th/84th quantiles are marked. Any extra
    keyword arguments (for example ``fig=``) are forwarded to ``corner``;
    passing one of the fixed options again raises ``TypeError``.
    """
    options = {
        "show_titles": True,
        "quantiles": [0.16, 0.5, 0.84],
        "labels": result.labels,
    }
    corner(result.samples[result.labels], **options, **kwargs)

In [None]:
def plot_samples_caah(mcmc_result, alpha=None, skip=10):
    """Plot [C/Mg] against [Mg/H] for thinned posterior samples and the binned data.

    Each kept row of ``mcmc_result.samples`` gives a model curve
    ``sum(row[label] * ah[label] for label in labels)``. That curve is converted
    with ``gcem.abund_to_brak(..., "c")`` and offset by the [Mg/H] bins (``ah._x``).
    The observed values ``ah.obs`` are overplotted with ``ah.obs_err`` converted to dex.

    Parameters
    ----------
    mcmc_result : MCMCResult
    alpha : float, optional
        Line transparency; by default ``1 / n**(1/3) / 10`` for ``n`` kept samples.
    skip : int
        Keep every ``skip``-th posterior sample.
    """
    binned = mcmc_result.ah
    thinned = mcmc_result.samples[::skip]
    mg_h = binned._x

    if alpha is None:
        alpha = 1 / len(thinned)**(1/3) / 10

    for _, row in thinned.iterrows():
        c_abund = np.sum([row[label] * binned[label] for label in mcmc_result.labels], axis=0)
        plt.plot(mg_h, gcem.abund_to_brak(c_abund, "c") - mg_h, color="black", alpha=alpha)

    # d(log10 x) = dx / (x ln 10): linear error on obs -> error in dex.
    obs_err_dex = binned.obs_err / binned.obs / np.log(10)
    obs_c_mg = gcem.abund_to_brak(binned.obs, "c") - mg_h
    plt.errorbar(mg_h, obs_c_mg, yerr=obs_err_dex, fmt="o", color=arya.COLORS[1])

    plt.xlabel("[Mg/H]")
    plt.ylabel("[C/Mg]")

In [None]:
def plot_samples_caafe(mcmc_result, alpha=None, skip=10, mh0=-0.1):
    """Plot [C/Mg] against [Mg/Fe] for thinned posterior samples and the binned data.

    Uses the [Mg/Fe]-binned table ``mcmc_result.afe``. [C/Mg] is obtained by
    subtracting a single fixed [Mg/H] value ``mh0`` from the carbon bracket,
    for both the model curves and the observations.

    Parameters
    ----------
    mcmc_result : MCMCResult
    alpha : float, optional
        Line transparency; by default ``1 / n**(1/3) / 10`` for ``n`` kept samples.
    skip : int
        Keep every ``skip``-th posterior sample.
    mh0 : float
        [Mg/H] assumed for every [Mg/Fe] bin.
    """
    binned = mcmc_result.afe
    thinned = mcmc_result.samples[::skip]
    mg_fe = binned._x

    if alpha is None:
        alpha = 1 / len(thinned)**(1/3) / 10

    for _, row in thinned.iterrows():
        c_abund = np.sum([row[label] * binned[label] for label in mcmc_result.labels], axis=0)
        plt.plot(mg_fe, gcem.abund_to_brak(c_abund, "c") - mh0, color="black", alpha=alpha)

    # d(log10 x) = dx / (x ln 10): linear error on obs -> error in dex.
    obs_err_dex = binned.obs_err / binned.obs / np.log(10)
    obs_c_mg = gcem.abund_to_brak(binned.obs, "c") - mh0
    plt.errorbar(mg_fe, obs_c_mg, yerr=obs_err_dex, fmt="o", color=arya.COLORS[1])

    plt.xlabel("[Mg/Fe]")
    plt.ylabel("[C/Mg]")

In [None]:
def plot_fagb_hist(results, y0):
    """Histogram of the AGB fraction of carbon implied by each posterior sample.

    Per sample, the AGB yield is ``alpha * y0`` and the CC yield is
    ``y0_cc * 1e-3``. Then f_AGB = AGB / (AGB + CC). The title shows the
    median with its 16th/84th-percentile offsets.

    Parameters
    ----------
    results : MCMCResult
        Its ``samples`` must contain ``alpha`` and ``y0_cc`` columns.
    y0 : float
        AGB yield that ``alpha`` multiplies.
    """
    draws = results.samples
    y_agb = draws["alpha"] * y0
    y_total = y_agb + draws["y0_cc"] * 1e-3
    frac = y_agb / y_total

    plt.hist(frac)
    plt.xlabel(r"$f_{\rm AGB}$")
    plt.ylabel("counts")

    q16, q50, q84 = np.quantile(frac, [0.16, 0.5, 0.84])
    lower = q50 - q16
    upper = q84 - q50
    plt.title(f"${q50:0.3f}_{{-{lower:0.3f}}}^{{+{upper:0.3f}}}$")

In [None]:
# NOTE(review): `results` is never assigned in this notebook (fits are stored in
# `result`), so on a fresh kernel this cell raised NameError. It now runs only if
# `results` has been defined. TODO: load the intended model explicitly.
if "results" in globals():
    plot_fagb_hist(results, 3e-4)
else:
    print("skipped: `results` is not defined; load an MCMCResult first")

In [None]:
def find_model(name, models_dir="../models/"):
    """Load a saved model's star table and apply its yield parameters.

    Reads ``<models_dir>/<name>/yield_params.toml`` into a ``surp.YieldParams``
    and passes it to ``surp.set_yields``. This presumably changes surp's
    global yield configuration as a side effect. The function then returns
    ``<models_dir>/<name>/stars.csv``.

    Parameters
    ----------
    name : str
        Model path relative to ``models_dir`` (e.g. ``"analytic/mc_best"``).
    models_dir : str, optional
        Root directory of saved models; defaults to the previously hard-coded
        notebook-relative path.

    Returns
    -------
    pandas.DataFrame
        Contents of stars.csv, indexed by its first column.
    """
    modeldir = os.path.join(models_dir, name)

    ys = surp.YieldParams.from_file(os.path.join(modeldir, "yield_params.toml"))
    surp.set_yields(ys)

    return pd.read_csv(os.path.join(modeldir, "stars.csv"), index_col=0)

In [None]:
def print_stats(result):
    """Print a table of the posterior median and 16th/84th-percentile offsets.

    One row per name in ``result.labels``. The median is followed by the
    16th-percentile offset (negative) and the 84th-percentile offset (printed with a '+').
    """
    print("parameter\t med\t 16th\t 84th")
    for label in result.labels:
        values = result.samples[label]
        median = np.median(values)
        q16, q84 = np.quantile(values, [0.16, 0.84])
        lower = q16 - median
        upper = q84 - median
        print(f"{label:8s}\t{median:6.3f}\t{lower:6.3f}\t+{upper:5.3f}")

# Body

In [None]:
from scipy.stats import binned_statistic

In [None]:
def plot_yields(result):
    """Scatter each parameter's column of ``result.ah`` against the [Mg/H] bins.

    Each column ``result.ah[label]`` is divided by ``gcem.brak_to_abund(x, "mg")``
    evaluated at the bin positions ``x = result.ah._x``.
    """
    x = result.ah._x
    # Loop-invariant normalisation: depends only on the bin positions.
    mg_abund = gcem.brak_to_abund(x, "mg")
    for label in result.labels:
        plt.scatter(x, result.ah[label] / mg_abund, label=label)

    # `ah` comes from mg_h_binned.csv, so its x axis is [Mg/H] (as in plot_samples_caah),
    # not a generic [M/H].
    plt.xlabel("[Mg/H]")
    plt.ylabel("yield")
    arya.Legend(-1)

In [None]:
def plot_all(filename, y0=None):
    """Load one MCMC run, print its summary and show its diagnostic figures.

    The figures are the binned yields, a 6x6-inch corner plot, and the [C/Mg]
    posterior curves against [Mg/H] and [Mg/Fe], each followed by
    ``plt.show()``. If ``y0`` is given, the f_AGB histogram is then drawn
    without an explicit ``plt.show()``.

    Parameters
    ----------
    filename : str
        Model name passed to ``MCMCResult.from_file``.
    y0 : float, optional
        AGB yield forwarded to ``plot_fagb_hist``; the histogram is skipped when None.

    Returns
    -------
    MCMCResult
    """
    result = MCMCResult.from_file(filename)
    print_stats(result)

    panels = (
        lambda: plot_yields(result),
        lambda: plot_corner(result, fig=plt.figure(figsize=(6, 6))),
        lambda: plot_samples_caah(result),
        lambda: plot_samples_caafe(result),
    )
    for draw in panels:
        draw()
        plt.show()

    if y0 is not None:
        plot_fagb_hist(result, y0)

    return result

In [None]:
# Set the AGB carbon yield to surp's interpolator with mass_factor=0.5
# (global vice setting), then display the resulting AGB yield from calc_y.
vice.yields.agb.settings["c"] = surp.agb_interpolator.interpolator("c", mass_factor=0.5)
y0 = surp.yields.calc_y(kind="agb")
y0

In [None]:
# Print the AGB carbon yield from surp.yields.calc_y(kind="agb") for each study in
# surp.AGB_MODELS. Afterwards the global AGB carbon setting is restored. The loop
# uses its own variable, so this table does not clobber `y0` from the previous cell.
_agb_c_previous = vice.yields.agb.settings["c"]
try:
    for study in surp.AGB_MODELS:
        vice.yields.agb.settings["c"] = surp.agb_interpolator.interpolator("c", study=study)
        y0_study = surp.yields.calc_y(kind="agb")
        print(f"{study:8s}\t{y0_study:6.3e}")
finally:
    vice.yields.agb.settings["c"] = _agb_c_previous


In [None]:
# Fiducial analytic fit; keep the loaded MCMCResult for later cells.
result = plot_all(filename="analytic_quad", y0=1e-3)

In [None]:
# Trailing `;` suppresses the returned MCMCResult repr (params plus four DataFrames).
plot_all("analytic_quad_m0.2", y0=1e-3);

In [None]:
# Trailing `;` suppresses the returned MCMCResult repr (params plus four DataFrames).
plot_all("fruity_quad", y0=3.2e-4);

In [None]:
# Trailing `;` suppresses the returned MCMCResult repr (params plus four DataFrames).
plot_all("fruity_mf0.5_quad", y0=2.9e-4);

In [None]:
# Trailing `;` suppresses the returned MCMCResult repr (params plus four DataFrames).
plot_all("aton_quad", y0=1.8e-4);

In [None]:
plot_all(filename="monash_quad", y0=3.4e-4);

In [None]:
plot_all(filename="nugrid_quad", y0=1e-3);

# Validation

In [None]:
# Load the fiducial analytic fit and the star table from the analytic/mc_best model.
# The validation cells below use both. `mz_stars` was previously only assigned further
# down the notebook, so those cells failed on a fresh kernel. Note that find_model
# also applies that model's yield parameters via surp.set_yields.
result = MCMCResult.from_file("analytic_quad")
mz_stars = find_model("analytic/mc_best")

In [None]:
# Posterior median of each column, keeping the parameter names as the index
# (np.median on the frame returned an unlabeled array).
result.samples.median()

In [None]:

# Low-alpha stars of mz_stars, binned in [Mg/H]: per-bin mean and std of C_MG.
# NOTE(review): `se` holds the per-bin standard deviation (scatter), not the
# standard error of the mean — confirm this is the intended error bar.
df = mz_stars[~mz_stars.high_alpha]

mg_h_bins = np.arange(-0.5, 0.31, 0.1)

bin_mids = 0.5 * (mg_h_bins[:-1] + mg_h_bins[1:])
ss, se = (
    binned_statistic(df.MG_H, df.C_MG, bins=mg_h_bins, statistic=stat).statistic
    for stat in ("mean", "std")
)

In [None]:
plot_samples_caah(result)

# Overlay the low-alpha mz_stars binned in [Mg/H] (per-bin mean ± std of C_MG).
plt.errorbar(bin_mids, ss, yerr=se, fmt="o", label="mz_stars (binned)")
plt.legend();

In [None]:

# Stars in a narrow [Mg/H] slice around -0.1 (the default mh0 of plot_samples_caafe),
# binned in [Mg/Fe]: per-bin mean and std of C_MG.
df = mz_stars[(mz_stars.MG_H > -0.15) & (mz_stars.MG_H < -0.05)]

# These are [Mg/Fe] bins. A distinct name stops this cell from overwriting the
# [Mg/H] bins `mg_h_bins` defined earlier.
mg_fe_bins = np.arange(0, 0.31, 0.05)

bin_mids = (mg_fe_bins[1:] + mg_fe_bins[:-1])/2
ss = binned_statistic(df.MG_FE, df.C_MG, bins=mg_fe_bins, statistic="mean").statistic
se = binned_statistic(df.MG_FE, df.C_MG, bins=mg_fe_bins, statistic="std").statistic

In [None]:
plot_samples_caafe(result)

# Overlay the [Mg/H]-sliced mz_stars binned in [Mg/Fe] (per-bin mean ± std of C_MG).
plt.errorbar(bin_mids, ss, yerr=se, fmt="o", label="mz_stars (binned)")
plt.legend();

In [None]:
# BiLogLin_CC yield model with y0 = 0, evaluated over metallicity in a later cell.
yc = surp.yield_models.BiLogLin_CC(zeta=0.001, y0=0.0, y1=-0.002)

In [None]:
# Preview the star table (loaded with find_model) instead of displaying the whole frame.
mz_stars.head()

In [None]:
# Star table for the analytic/mc_best model; find_model also applies its yield params.
mz_stars = find_model(name="analytic/mc_best")

In [None]:
# Evaluate the BiLogLin_CC model `yc` at Z = 0.016 * 10**x for x in [-3, 2].
x = np.linspace(-3, 2, 10000)

plt.plot(x, [yc(0.016 * 10**xx) for xx in x])
plt.xlabel(r"$\log_{10}(Z / 0.016)$")
plt.ylabel("yc(Z)");

In [None]:
import sys

# Make modules under ../models/perturbations/ importable. The membership check keeps
# the cell idempotent: re-running it no longer appends duplicate sys.path entries.
PERTURBATIONS_DIR = "../models/perturbations/"
if PERTURBATIONS_DIR not in sys.path:
    sys.path.append(PERTURBATIONS_DIR)



In [None]:
y0_cc = 1e-3
# Build the BiLogLin_CC model once instead of on every zeta() call.
_zeta_model = surp.yield_models.BiLogLin_CC(y0=0.002, zeta=0.001, y1=0)


def zeta(z):
    """BiLogLin_CC(y0=0.002, zeta=0.001, y1=0) evaluated at ``z``, minus 2e-3."""
    return _zeta_model(z) - 2e-3


quad = surp.yield_models.Quadratic_CC(y0=0, zeta=0, A=0.001, Z1=0.0016)
agb = surp.yield_models.C_AGB_Model(y0=0.001, zeta=-0.001, tau_agb=10)

@np.vectorize
def cc_tot(z):
    """Weighted sum of y0_cc, zeta(z) and quad(z), applied element-wise to ``z``."""
    # NOTE(review): the weights 1.57, 0.505 and 5.75 are unexplained here —
    # document where they come from (e.g. a fit) next to this cell.
    return 1.57*y0_cc + 0.505 * zeta(z) + 5.75 * quad(z)

In [None]:
# Inspect the parameter names (the params.toml keys) of the loaded fit.
result.labels

In [None]:
# Spot-check the offset BiLogLin_CC term `zeta` at Z = 0.014.
zeta(0.014)

In [None]:
plot_yields(result)

# Switch the AGB carbon yield to the C_AGB_Model `agb` defined above (global vice
# setting; a later cell replaces it with ZeroAGB).
vice.yields.agb.settings["c"] = agb
mh = np.linspace(-0.4, 0.3, 100)

ymg = vice.solar_z("mg")

z = gcem.MH_to_Z(mh)
# Labels are needed so that plt.legend() below also lists these model curves.
plt.plot(mh, surp.yields.calc_y(z, kind="agb") / ymg, label="AGB (C_AGB_Model)")

plt.axhline(y0_cc / ymg, label="y0_cc")

plt.plot(mh, np.array([zeta(zz) for zz in z]) / ymg, label="zeta")
plt.legend();

In [None]:
# Replace the AGB carbon yield with surp's ZeroAGB model (presumably no AGB
# contribution) before comparing cc_tot with calc_y in the next cell.
vice.yields.agb.settings["c"] = surp.yield_models.ZeroAGB()

In [None]:
# Compare the analytic cc_tot approximation with surp.yields.calc_y over [M/H] in [-0.5, 0.5].
x = np.linspace(-0.5, 0.5, 1000)
z = gcem.MH_to_Z(x)

plt.plot(x, cc_tot(z), label="cc_tot")
plt.plot(x, surp.yields.calc_y(z), label="surp.yields.calc_y")
plt.xlabel("[M/H]")
plt.ylabel("yield")
plt.legend();

In [None]:
def get_ytot(Z):
    return surp.yields.calc_y(