# MCMC yield perturbation analysis

The goal of this investigation is to fit for the combination of yields which best represents APOGEE trends using our multizone model results.
Because the abundance evolution of each element is *linear in yields*, we can predict the present-day properties for new combinations of yields.

**Abundance Linearity lemma:**
Assume we know the yields $y_X^{i}$ of an element $X$ for several different processes $i$, and we know (or calculate) the resulting abundance evolution $Z_X^{i}(t)$ for each process. For any real coefficients $\alpha_1, \dots, \alpha_p$ combining $p$ processes, if the total yield of the element we are interested in is 

$$
y_X = \sum_{i=1}^{p} \alpha_i y_X^i
$$
Then the resulting abundance evolution is
$$
Z_X(t) = \sum_{i} \alpha_i Z_X^i(t)
$$

Note that if the yields depend on metallicity and the metallicity evolution itself changes, then this does not hold exactly. However, to first order, we can use the above property of chemical evolution to understand the space of yields which best reproduces the present-day APOGEE measurements.


### Modeling linear combinations.

To model each linear combination as accurately as possible, I take our fiducial model and use Ag as a dummy element. Each model uses the same random seed, so it reproduces exactly the same population of stars. 

Additionally, to minimize the impact on the metallicity evolution, I scale each process down by a factor of $10^6$ and scale the solar metallicity for each process down by the same factor. Therefore, for a 1x amount of each process, the reported ratios of \[X/H\] accurately correspond to $Z_{\rm C}^{\rm proc}$. As best practice, all components of a combination should be computed in the same VICE run, so adding more components is as simple as extending the yield correspondence table.

In [None]:
using CairoMakie
using Arya

In [None]:
using CSV, DataFrames
import CategoricalArrays: cut
using StatsBase
using OrderedCollections

In [None]:
using Turing
import NaNMath as nm
using LinearAlgebra: diagm
using PairPlots

In [None]:
# Observed APOGEE subgiant sample (columns used below include MG_H, MG_FE, C_MG, high_alpha)
subgiants = CSV.read("../data/subgiants.csv", DataFrame)

In [None]:
# Solar carbon mass fraction; the reference point for converting Z_C <-> [C/H]
solar_z = (c = 3.39e-03,)

In [None]:
"""
    C_H_to_Zc(ag_h)

Convert a logarithmic abundance ratio into a carbon mass fraction,
`Z_C = solar_z.c * 10^ag_h`, element-wise (scalars or arrays).

The argument is named after the `AG_H` column: Ag is the dummy element that
carries the perturbed carbon yield in the model runs (see `find_model`).
"""
function C_H_to_Zc(ag_h)
    return @. solar_z.c * 10^ag_h
end

In [None]:
"""
    Zc_to_C_H(z_c)

Inverse of `C_H_to_Zc`: convert a carbon mass fraction into [C/H] as
`log10(z_c / solar_z.c)`, element-wise. NaNMath's `log10` is used, so
negative inputs yield `NaN` rather than throwing.
"""
function Zc_to_C_H(z_c)
    ratio = z_c ./ solar_z.c
    return nm.log10.(ratio)
end

In [None]:
"""
Finds the pickled model with either the given name or the parameters 
and returns the csv summary
"""
function find_model(name)
    file_name = "../models/perturbations/$name/stars.csv"
    model =  CSV.read(file_name, DataFrame)

    model[!, "z_c"] = C_H_to_Zc(model[:, "AG_H"])

    return model
end

## Parameters

In [None]:
# Bin edges in [Mg/H] for the [C/Mg]–[Mg/H] trend (low-alpha stars).
# NOTE: the range stops at 0.3 (0.35 is not on the 0.1 grid), so 0.3 is the last edge.
mg_h_bins = -0.5:0.1:0.35
# Bin edges in [Mg/Fe] for the [C/Mg]–[Mg/Fe] trend
mg_fe_bins = 0.0:0.05:0.30
# The [Mg/Fe] trend only uses stars with [Mg/H] in [mg_h_0 - d_mg_h, mg_h_0 + d_mg_h)
mg_h_0 = -0.1
d_mg_h = 0.05

In [None]:
# Overview of the subgiant sample: [Mg/Fe] vs [Mg/H], coloured by [C/Mg].
# (The y axis shows MG_FE and the colour shows C_MG; the labels previously had
# these two swapped.)
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="[Mg/Fe]",
    limits=(-0.6, 0.6, -0.1, 0.5)
    )


p = scatter!(subgiants.MG_H, subgiants.MG_FE, color=subgiants.C_MG, markersize = 2., colorrange=(-0.5, 0.2))

Colorbar(fig[1, 2], p, label="[C/Mg]")

fig

In [None]:
"""
    bin_means(df; x=:MG_H_true, bins=mg_h_bins, val=:z_c, n_min=3)

Bin the rows of `df` by column `x` using the edges `bins` and summarise column
`val` per bin.

Returns a `DataFrame` with exactly `length(bins) - 1` rows, ordered by bin
index, with columns
- `x_bin`  : 1-based bin index,
- `med`    : mean of `val` in the bin (a mean, despite the name),
- `xmed`   : mean of `x` in the bin (`missing` for empty bins),
- `err`    : standard error of the mean of `val` (`sem`),
- `counts` : number of rows in the bin,
- `x`      : bin midpoint.

Rows whose `x` falls outside the edges are dropped. Empty bins get
`counts = 0`, `med = 0` and `err = 0`.

`n_min` is accepted but currently unused: sparsely populated bins are not
masked.
"""
function bin_means(df::DataFrame; x::Symbol=:MG_H_true, bins = mg_h_bins, val::Symbol=:z_c, n_min::Int=3)
    Nb = length(bins) - 1

    # bin index of every row; values outside the edges become `missing`
    x_bin = cut(df[!, x], bins, extend=missing, labels=1:Nb)

    filt = .!ismissing.(x_bin)
    df_filtered = df[filt, :]   # row indexing already returns a copy
    df_filtered[!, :x_bin] = x_bin[filt]
    grouped = groupby(df_filtered, :x_bin)

    results = combine(grouped,
        val => mean => :med,
        x => mean => :xmed,
        val => sem => :err,
        val => length => :counts
    )

    # join onto the full grid so that empty bins still get a row
    full_grid = DataFrame(x_bin = 1:Nb)
    df_result = leftjoin(full_grid, results, on=:x_bin)
    # leftjoin does not promise to keep the row order of `full_grid` (unmatched,
    # i.e. empty, bins can end up last). Restore bin order so rows line up
    # across data sets (observations vs. models) and plot in increasing x.
    sort!(df_result, :x_bin)

    x_bin_mids = midpoints(bins)
    df_result.x = getindex.(Ref(x_bin_mids), df_result.x_bin)

    filt_missing = ismissing.(df_result.counts)
    df_result[filt_missing, :counts] .= 0
    df_result[filt_missing, :med] .= 0
    df_result[filt_missing, :err] .= 0

    return df_result
end

In [None]:
"""
    bin_caah(df; bins=mg_h_bins, kwargs...)

Run `bin_means` on the low-alpha stars of `df` (rows where `high_alpha` is
false). Remaining keywords are forwarded to `bin_means`.
"""
function bin_caah(df::DataFrame; bins=mg_h_bins, kwargs...)
    low_alpha_rows = df[.!df.high_alpha, :]
    return bin_means(low_alpha_rows; bins=bins, kwargs...)
end

In [None]:
"""
    bin_caafe(df; m_h=:MG_H, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE, val=:C_MG, bins=mg_fe_bins, kwargs...)

Bin `val` against `x` (default: observed [C/Mg] vs [Mg/Fe]) using only rows
whose `m_h` column lies in `[m_h_0 - d_m_h, m_h_0 + d_m_h)`. Remaining
keywords are forwarded to `bin_means`.
"""
function bin_caafe(df::DataFrame; m_h=:MG_H, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE, val=:C_MG, bins=mg_fe_bins, kwargs...)
    lower, upper = m_h_0 - d_m_h, m_h_0 + d_m_h
    metallicity = df[:, m_h]
    in_window = (metallicity .>= lower) .& (metallicity .< upper)
    return bin_means(df[in_window, :]; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
"""
    bin_zc_afe(df; m_h=:MG_H_true, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE_true, val=:z_c, bins=mg_fe_bins, kwargs...)

Model-side counterpart of `bin_caafe`: bin `val` (default `z_c`) against the
true [Mg/Fe] for rows whose true [Mg/H] lies in `[m_h_0 - d_m_h, m_h_0 + d_m_h)`.
Remaining keywords are forwarded to `bin_means`.
"""
function bin_zc_afe(df::DataFrame; m_h=:MG_H_true, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE_true, val=:z_c, bins=mg_fe_bins, kwargs...)
    window_lo = m_h_0 - d_m_h
    window_hi = m_h_0 + d_m_h
    m_h_values = df[:, m_h]
    keep = (m_h_values .>= window_lo) .& (m_h_values .< window_hi)
    selected = df[keep, :]
    return bin_means(selected; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
"""
    bin_zc_ah(df; x=:MG_H_true, val=:z_c, bins=mg_h_bins, kwargs...)

Model-side counterpart of `bin_caah`: bin `val` (default `z_c`) against the
true [Mg/H] for the low-alpha rows (`high_alpha == false`). Remaining keywords
are forwarded to `bin_means`.
"""
function bin_zc_ah(df::DataFrame;x=:MG_H_true, val=:z_c, bins=mg_h_bins, kwargs...)
    is_low_alpha = .!df.high_alpha
    low_alpha_stars = df[is_low_alpha, :]
    return bin_means(low_alpha_stars; x=x, val=val, bins=bins, kwargs...)
end

In [None]:
mkpath("mcmc_samples")

# Data loading

In [None]:
# Observed mean [C/Mg] in [Mg/Fe] bins, restricted to stars with [Mg/H] in the window around mg_h_0
subgiants_binned_afe = bin_caafe(subgiants, x=:MG_FE, m_h=:MG_H, val=:C_MG)

In [None]:
# Observed mean [C/Mg] in [Mg/H] bins for the low-alpha subgiants
subgiants_binned_ah = bin_caah(subgiants, x=:MG_H, val=:C_MG)

In [None]:
# `model_alpha` is only defined by the next cell (`find_model("analytic")`), so
# displaying it here raised an UndefVarError on a fresh top-to-bottom run.
# Inspect it after the models have been loaded instead.

In [None]:
# One run per yield component; each carries that component in the dummy element
# Ag, and find_model adds the corresponding Z_C as the `z_c` column.
model_y0_cc = find_model("const_cc")
model_A_cc = find_model("quadratic")
model_zeta = find_model("piecelin_m2")
model_alpha = find_model("analytic")

model_y0_cc

In [None]:
# Component name => star table. These symbols are the keys the fit models
# below use to look up each component (e.g. models[:y0_cc]).
models = Dict(
    :alpha => model_alpha,
    :y0_cc => model_y0_cc,
    :A_cc => model_A_cc,
    :zeta_cc => model_zeta
    )

In [None]:
model_y0_cc

In [None]:
# Per-component mean Z_C in true-[Mg/H] bins for low-alpha stars. bin_caah only
# sets `bins`, so bin_means' defaults (x=:MG_H_true, val=:z_c) apply and this
# matches models_zc_ah below.
models_caah = Dict(component => bin_caah(model) for (component, model) in models)

In [None]:
models_caah[:alpha]

In [None]:
# NOTE(review): bin_caafe defaults to the observed-sample columns (:MG_H, :MG_FE,
# :C_MG); confirm the model tables provide them. The fits use models_zc_afe.
models_caafe = Dict(component => bin_caafe(model) for (component, model) in models)

In [None]:
# Per-component mean Z_C vs true [Mg/H] (low-alpha stars only)
models_zc_ah = Dict(component => bin_zc_ah(model) for (component, model) in models)

In [None]:
# Per-component mean Z_C vs true [Mg/Fe], for stars with true [Mg/H] in
# [mg_h_0 - d_mg_h, mg_h_0 + d_mg_h)
models_zc_afe = Dict(component => bin_zc_afe(model) for (component, model) in models)

In [None]:
models_zc_afe[:alpha]

In [None]:
models_zc_ah[:alpha]

In [None]:
# Sanity check: true [Mg/H] distribution of the low-alpha stars
hist(model_y0_cc.MG_H_true[.!model_y0_cc.high_alpha])

In [None]:
# Sanity check: true [Mg/Fe] distribution inside the [Mg/H] window (both ends
# open here, whereas bin_zc_afe includes the lower edge)
hist(model_y0_cc.MG_FE_true[ mg_h_0 - d_mg_h .< model_y0_cc.MG_H_true .< mg_h_0 + d_mg_h])

In [None]:
# Z_C of each component scaled by 10^-[Mg/H] (bin centre), i.e. proportional
# to Z_C / Z_Mg, vs [Mg/H] for low-alpha stars
fig, ax = FigAxis(
    xlabel="[Mg/H]",
    ylabel="Zc  (for process)",
)

for (label, model) in models_zc_ah
    lines!(model.x, model.med ./ 10 .^ model.x, label=string(label))
end

axislegend()
fig

In [None]:
# [C/Mg] each component would give on its own: [C/H] - [Mg/H], taking [Mg/H]
# as the bin centre
fig, ax = FigAxis(
    xlabel="[Mg/H]",
    ylabel="[C/Mg] (for process)",
)

for (label, model) in models_zc_ah
    scatter!(model.x, Zc_to_C_H(model.med) - model.x, label=string(label))
end

axislegend()
fig

In [None]:
# [C/Mg] vs [Mg/Fe] per component, taking [Mg/H] = mg_h_0 for the whole slice
fig, ax = FigAxis(
    xlabel = "[Mg/Fe]",
    ylabel = "[C/Mg] (for process)"
)

for (label, model) in models_zc_afe
    scatter!(model.x, Zc_to_C_H(model.med) .- mg_h_0, label=string(label))
end

axislegend()
fig

In [None]:
# Observed binned [C/Mg]–[Mg/H] trend of the low-alpha subgiants
errscatter(subgiants_binned_ah.x, subgiants_binned_ah.med, yerr=subgiants_binned_ah.err,
    axis=(;xlabel="[Mg/H]", ylabel="[C/Mg]", title="binned subgiants")
)

In [None]:
# Observed binned [C/Mg]–[Mg/Fe] trend for subgiants in the [Mg/H] window used
# by bin_caafe. The window is on [Mg/H] (not [M/H]); bounds are rounded so the
# title does not show floating-point noise such as -0.15000000000000002.
errscatter(subgiants_binned_afe.x, subgiants_binned_afe.med, yerr=subgiants_binned_afe.err,
    axis=(;xlabel="[Mg/Fe]", ylabel="[C/Mg]",
        title="binned subgiants (with [Mg/H] in $(round(mg_h_0 - d_mg_h, digits=2)), $(round(mg_h_0 + d_mg_h, digits=2)))")
)

# Linear Model

In [None]:

# Helper function to compute model contribution
# Returns sum_i params[i] * model[key_i], pairing `params` with `keys(model)`
# in the container's iteration order.
# NOTE(review): for a Dict that order is not the insertion order, so which
# coefficient multiplies which component is not obvious to callers — confirm.
function compute_model_contribution(params, model)
    return sum(p * model[key] for (p, key) in zip(params, keys(model)))
end

# Draft N-component Turing model.
# NOTE(review): a later `n_component_model(models_ah, models_afe, labels, priors)`
# has the same arity and replaces this method. As written, this draft also
# looks broken and should not be relied on:
#  - `y` and `y2` are not model arguments, so Turing treats them as latent
#    variables to sample rather than as observed data;
#  - mu1 and mu2 are computed identically from the same `models`;
#  - only `data1[:errors]` / `data2[:errors]` are read from the data containers.
@model function n_component_model(data1, data2, models, priors)
    # Check if priors are provided, and set defaults if not
    if isempty(priors)
        n_params = length(keys(models))
        priors = [Normal(0, 1) for _ in 1:n_params]  # Default priors if none are specified
    end
    
    # Create parameters based on the specified priors
    params ~ arraydist(priors)
    
    # Compute model contributions for each dataset
    mu1 = compute_model_contribution(params, models)
    mu2 = compute_model_contribution(params, models)  # This could be adjusted if different model terms are needed
    
    # Data likelihoods
    y ~ MvNormal(mu1, diagm(data1[:errors] .^ 2))
    y2 ~ MvNormal(mu2, diagm(data2[:errors] .^ 2))
end

In [None]:
# fit_3_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=mg_h_0)
#
# Turing model fitting both binned subgiant trends with a linear combination of
# three yield components (:y0_cc, :alpha, :zeta_cc). Z_C is linear in the
# yields, so the predicted Z_C is the coefficient-weighted sum of the
# per-component binned Z_C.
#
# - x, y, y_e       : [Mg/H] bin centres, observed mean [C/Mg] and its errors
# - x2, y2, y_e2    : [Mg/Fe] bin centres (unused by the likelihood), observed
#                     mean [C/Mg] and its errors at fixed [Mg/H]
# - models, models2 : per-component binned Z_C on the bins of y and y2
# - m_h_0           : [Mg/H] of the [Mg/Fe] slice, used to turn [C/H] into
#                     [C/Mg] for the second trend. It was previously ignored in
#                     favour of the global mg_h_0, which remains its default.
@model function fit_3_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=mg_h_0)
    y0_cc ~ Normal(2, 1)
    α ~ Normal(2, 1)
    ζ ~ Normal(0, 1)

    # [C/Mg] = [C/H] - [Mg/H], taking [Mg/H] at the bin centres
    Zc = y0_cc * models[:y0_cc] .+ α * models[:alpha] .+ ζ * models[:zeta_cc]
    mu = Zc_to_C_H.(Zc) .- x

    # [Mg/Fe] trend: every star sits at [Mg/H] ≈ m_h_0
    Zc2 = y0_cc * models2[:y0_cc] .+ α * models2[:alpha] .+ ζ * models2[:zeta_cc]
    mu2 = Zc_to_C_H.(Zc2) .- m_h_0

    # independent Gaussian errors per bin
    y ~ MvNormal(mu, diagm(y_e .^ 2))
    y2 ~ MvNormal(mu2, diagm(y_e2 .^ 2))
end

In [None]:
"""
    plot_samples!(samples, x, models; thin=10, color=:black, alpha=nothing, kwargs...)

Overplot one [C/Mg]–[Mg/H] curve per retained posterior draw of the
3-component model on the current axis.

Every `thin`-th row of `samples` (columns `α`, `y0_cc`, `ζ`) gives
Z_C = α·models[:alpha] + y0_cc·models[:y0_cc] + ζ·models[:zeta_cc], which is
converted to [C/H] and shifted by `-x` to give [C/Mg]. With `alpha = nothing`
the line opacity is `1 / n^(1/3)`, `n = size(samples, 1)`. Extra keywords are
accepted but ignored.
"""
function plot_samples!(samples, x, models;
        thin=10, color=:black, alpha=nothing, kwargs...)
    opacity = alpha === nothing ? 1/size(samples, 1)^(1/3) : alpha

    draws = eachrow(samples)
    for k in 1:thin:length(draws)
        draw = draws[k]
        zc = @. draw.α * models[:alpha] + draw.y0_cc * models[:y0_cc] + draw.ζ * models[:zeta_cc]
        lines!(x, Zc_to_C_H(zc) .- x, color=color, alpha=opacity)
    end
end

In [None]:
"""
    plot_samples_afe!(samples, x, models; thin=10, color=:black, alpha=nothing, m_h_0=mg_h_0, kwargs...)

Overplot one [C/Mg]–[Mg/Fe] curve per retained posterior draw of the
3-component model (`α`, `y0_cc`, `ζ`) on the current axis.

Z_C is converted to [C/H] and shifted by `-m_h_0` to give [C/Mg], because the
[Mg/Fe] trend is taken at fixed [Mg/H] ≈ `m_h_0`. `m_h_0` defaults to the
global `mg_h_0`, matching `fit_3_comp_model`. Every `thin`-th row of `samples`
is drawn. With `alpha = nothing` the opacity is `1 / n^(1/3)`,
`n = size(samples, 1)`. Other keywords (e.g. `label`) are accepted but ignored.
"""
function plot_samples_afe!(samples, x, models;
        thin=10, color=:black, alpha=nothing, m_h_0=mg_h_0, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:alpha] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:zeta_cc]
        lines!(x, Zc_to_C_H(y) .- m_h_0, color=color, alpha=alpha)
    end
end

In [None]:
# Observed binned trends as plain (non-missing) vectors for the Turing model.
# NOTE(review): bin_means fills empty bins with err = 0 (and `sem` of a single
# star is presumably NaN); either makes the diagonal MvNormal covariance in the
# fit degenerate — confirm every bin is populated.
y = disallowmissing(subgiants_binned_ah.med)
yerr = disallowmissing(subgiants_binned_ah.err)# ./ sqrt.(subgiants_binned.counts)
x = disallowmissing(subgiants_binned_ah.x)

y2 = disallowmissing(subgiants_binned_afe.med)
yerr2 = disallowmissing(subgiants_binned_afe.err)# ./ sqrt.(subgiants_binned2.counts)
x2 = disallowmissing(subgiants_binned_afe.x)

In [None]:
yerr

In [None]:
# Per-component mean Z_C on the same bins as the observations (inputs to the fit)
model_med_ah = Dict(label => disallowmissing(model.med) for (label, model) in models_zc_ah)
model_med_afe = Dict(label => disallowmissing(model.med) for (label, model) in models_zc_afe)

In [None]:
model = fit_3_comp_model(x, y, yerr, x2, y2, yerr2, model_med_ah, model_med_afe)

In [None]:
# NUTS with a target acceptance rate of 0.65, 5_000 draws
chain = sample(model, NUTS(0.65), 5_000)

In [None]:
samples = DataFrame(chain)

In [None]:
pairplot(chain)

In [None]:
# Posterior draws of the 3-component fit over the binned low-alpha [C/Mg]–[Mg/H] trend
fig = Figure()
ax = Axis(fig[1, 1],
    #limits=(-0.5, 0.3, -0.2, 0.0),
    xgridvisible=false,
    ygridvisible=false,
    xlabel="[Mg/H]",
    ylabel="[C/Mg] (low alpha)",
)


plot_samples!(samples, x, model_med_ah, alpha=0.008)
errscatter!(x, y, yerr=yerr)

fig

In [None]:
# Posterior draws of the 3-component fit over the binned subgiant [C/Mg]–[Mg/Fe]
# trend. The error bars are the subgiant means. Their legend label was
# previously "mean samples", while "mean subgiants" was passed to
# plot_samples_afe!, which ignores `label`, so that label never appeared.
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel=L"\textrm{mean([C/Mg]) | [Mg/H] }\in [%$(round(mg_h_0-d_mg_h, digits=2)), %$(round(mg_h_0+d_mg_h, digits=2))]",
    xgridvisible=false,
    ygridvisible=false,
)


plot_samples_afe!(samples, x2, model_med_afe, alpha=0.008)
errscatter!(x2, y2, yerr=yerr2, label="mean subgiants")

axislegend()
fig

# 4-component (quadratic) model

In [None]:
# fit_4_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=mg_h_0)
#
# Turing model: like fit_3_comp_model, plus a fourth yield component :A_cc with
# coefficient `A`. Z_C is linear in the yields, so the predicted Z_C is the
# coefficient-weighted sum of the per-component binned Z_C.
#
# - x, y, y_e       : [Mg/H] bin centres, observed mean [C/Mg] and its errors
# - x2, y2, y_e2    : [Mg/Fe] bin centres (unused by the likelihood), observed
#                     mean [C/Mg] and its errors at fixed [Mg/H]
# - models, models2 : per-component binned Z_C on the bins of y and y2, indexed
#                     by :y0_cc, :alpha, :zeta_cc, :A_cc
# - m_h_0           : [Mg/H] of the [Mg/Fe] slice, used to turn [C/H] into
#                     [C/Mg] for the second trend. It was previously ignored in
#                     favour of the global mg_h_0, which remains its default.
@model function fit_4_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=mg_h_0)
    y0_cc ~ Normal(2, 1)
    α ~ Normal(2, 1)
    ζ ~ Normal(0, 1)
    A ~ Normal(0, 1)

    # [C/Mg] = [C/H] - [Mg/H], taking [Mg/H] at the bin centres
    Zc = y0_cc * models[:y0_cc] .+ α * models[:alpha] .+ ζ * models[:zeta_cc] .+ A * models[:A_cc]
    mu = Zc_to_C_H.(Zc) .- x

    # [Mg/Fe] trend: every star sits at [Mg/H] ≈ m_h_0
    Zc2 = y0_cc * models2[:y0_cc] .+ α * models2[:alpha] .+ ζ * models2[:zeta_cc] .+ A * models2[:A_cc]
    mu2 = Zc_to_C_H.(Zc2) .- m_h_0

    # independent Gaussian errors per bin
    y ~ MvNormal(mu, diagm(y_e .^ 2))
    y2 ~ MvNormal(mu2, diagm(y_e2 .^ 2))
end

In [None]:
"""
    plot_samples!(samples, x, models; thin=10, color=:black, alpha=nothing, kwargs...)

4-component version: overplot one [C/Mg]–[Mg/H] curve per retained posterior
draw on the current axis.

Every `thin`-th row of `samples` (columns `α`, `y0_cc`, `ζ`, `A`) gives
Z_C = α·models[:alpha] + y0_cc·models[:y0_cc] + ζ·models[:zeta_cc] + A·models[:A_cc],
converted to [C/H] and shifted by `-x` to give [C/Mg]. With `alpha = nothing`
the opacity is `1 / n^(1/3)`, `n = size(samples, 1)`. Extra keywords are
accepted but ignored.
"""
function plot_samples!(samples, x, models;
        thin=10, color=:black, alpha=nothing, kwargs...)
    opacity = alpha === nothing ? 1/size(samples, 1)^(1/3) : alpha

    draws = eachrow(samples)
    for k in 1:thin:length(draws)
        draw = draws[k]
        zc = @. draw.α * models[:alpha] + draw.y0_cc * models[:y0_cc] + draw.ζ * models[:zeta_cc] + draw.A * models[:A_cc]
        lines!(x, Zc_to_C_H(zc) .- x, color=color, alpha=opacity)
    end
end

In [None]:
# Quick look at the [Mg/H] uncertainties across the subgiant sample
scatter(subgiants.MG_H, subgiants.MG_H_ERR)

In [None]:
"""
    plot_samples_afe!(samples, x, models; thin=10, color=:black, alpha=nothing, m_h_0=mg_h_0, kwargs...)

4-component version: overplot one [C/Mg]–[Mg/Fe] curve per retained posterior
draw (`α`, `y0_cc`, `ζ`, `A`) on the current axis.

Z_C is converted to [C/H] and shifted by `-m_h_0` to give [C/Mg], because the
[Mg/Fe] trend is taken at fixed [Mg/H] ≈ `m_h_0`. `m_h_0` defaults to the
global `mg_h_0`, matching `fit_4_comp_model`. Every `thin`-th row of `samples`
is drawn. With `alpha = nothing` the opacity is `1 / n^(1/3)`,
`n = size(samples, 1)`. Other keywords are accepted but ignored.
"""
function plot_samples_afe!(samples, x, models;
        thin=10, color=:black, alpha=nothing, m_h_0=mg_h_0, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:alpha] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:zeta_cc] + sample.A * models[:A_cc]
        lines!(x, Zc_to_C_H(y) .- m_h_0, color=color, alpha=alpha)
    end
end

In [None]:
# Same observed vectors as in the 3-component section (identical expressions)
y = disallowmissing(subgiants_binned_ah.med)
yerr = disallowmissing(subgiants_binned_ah.err)# ./ sqrt.(subgiants_binned.counts)
x = disallowmissing(subgiants_binned_ah.x)

y2 = disallowmissing(subgiants_binned_afe.med)
yerr2 = disallowmissing(subgiants_binned_afe.err)# ./ sqrt.(subgiants_binned2.counts)
x2 = disallowmissing(subgiants_binned_afe.x)

In [None]:
yerr

In [None]:
model = fit_4_comp_model(x, y, yerr, x2, y2, yerr2, model_med_ah, model_med_afe)

In [None]:
# NUTS with a target acceptance rate of 0.65, 50_000 draws
chain = sample(model, NUTS(0.65), 50_000)

In [None]:
samples = DataFrame(chain)

In [None]:
# Persist the 4-component posterior (directory created by the mkpath cell near the top)
CSV.write("mcmc_samples/samples_analytic.csv", samples)

In [None]:
# Corner plot of the 4-component posterior with LaTeX parameter names
fig = pairplot(chain, labels=Dict(
        :y0_cc=>L"y_0^\textrm{CC}",
        :α => L"\alpha_\textrm{C}^\textrm{AGB}",
        :ζ => L"\zeta_\textrm{C}^\textrm{CC}",
        :A => L"A_\textrm{C}^\textrm{CC}",
    )
)

# The output directory is not created anywhere else in the notebook; make sure
# it exists before saving.
mkpath("figures")
save("figures/mcmc_multizone_corner.pdf", fig)

fig

In [None]:
# Posterior draws of the 4-component fit over both observed trends. Every 10th
# row is passed in, and plot_samples!/plot_samples_afe! thin again by their
# default `thin=10`.
fig = Figure(size=(600, 300))
ax = Axis(fig[1, 1],
    #limits=(-0.5, 0.3, -0.2, 0.0),
    xgridvisible=false,
    ygridvisible=false,
    xlabel="[Mg/H]",
    ylabel="[C/Mg]",
)


plot_samples!(samples[1:10:end, :], x, model_med_ah, alpha=0.008)
errscatter!(x, y, yerr=yerr, color=COLORS[2])


# right panel: [C/Mg] vs [Mg/Fe] at fixed [Mg/H]
ax = Axis(fig[1, 2],
    xlabel="[Mg/Fe]",
    ylabel=L"\textrm{[C/Mg]}",
    xgridvisible=false,
    ygridvisible=false,
)


plot_samples_afe!(samples[1:10:end, :], x2, model_med_afe, alpha=0.008)
errscatter!(x2, y2, yerr=yerr2, color=COLORS[2])


fig

# N component fit

In [None]:
"""
    load_binned_models(modelname)

Read the pre-binned tables for `modelname` from
`../models/perturbations/mc_analysis/<modelname>/` and return `(ah, afe)`:
`mg_h_binned.csv` first, then `mg_fe_binned.csv`, each as a `DataFrame`.
"""
function load_binned_models(modelname)
    base = "../models/perturbations/mc_analysis/$modelname"
    read_table(fname) = CSV.read("$base/$fname", DataFrame)

    afe = read_table("mg_fe_binned.csv")
    ah = read_table("mg_h_binned.csv")
    return ah, afe
end

In [None]:

"""
    compute_model_contribution(params, model)

Return `sum_i params[i] * model[key_i]`, pairing `params` with `keys(model)` in
the container's iteration order. (Not used by `n_component_model` below, which
pairs coefficients with explicit `labels`.)
"""
function compute_model_contribution(params, model)
    return sum(p * model[key] for (p, key) in zip(params, keys(model)))
end

# n_component_model(models_ah, models_afe, labels, priors)
#
# Turing model for an N-component linear yield fit on the pre-binned tables.
# - models_ah, models_afe : tables with one column per component label plus
#                           the observations `obs` and their errors `obs_err`
# - labels                : component column names; coefficient params[i]
#                           multiplies column labels[i]
# - priors                : one prior per label, in the same order
# Both trends are fitted jointly with independent Gaussian errors per bin.
@model function n_component_model(models_ah, models_afe, labels, priors)
    # zip() below would silently truncate to the shorter of params/labels and
    # drop components without warning, so insist on one prior per label.
    length(priors) == length(labels) || throw(ArgumentError(
        "need one prior per component: got $(length(priors)) priors for $(length(labels)) labels"))

    # Create parameters based on the specified priors
    params ~ arraydist(priors)

    # Compute model contributions for each dataset
    mu_ah = sum(p * models_ah[:, key] for (p, key) in zip(params, labels))
    mu_afe = sum(p * models_afe[:, key] for (p, key) in zip(params, labels))

    # Data likelihoods
    models_ah.obs ~ MvNormal(mu_ah, diagm(models_ah.obs_err .^ 2))
    models_afe.obs ~ MvNormal(mu_afe, diagm(models_afe.obs_err .^ 2))
end

In [None]:
# Pre-binned observations + per-component model columns for the "analytic_quad" set
ah, afe = load_binned_models("analytic_quad")

In [None]:
# Component column names; params[i] (and the i-th prior) belongs to labels[i]
labels = ["alpha", "y0_cc", "zeta_cc", "A_cc"]
model = n_component_model(ah, afe, labels, [
        Normal(1, 1), Normal(2, 1), Normal(0, 1), Normal(0, 1)]
)

In [None]:
chain = sample(model, NUTS(0.65), 10_000)

In [None]:
samples = DataFrame(chain)

In [None]:
# Corner plot with params[i] relabelled by its component name
fig = pairplot(chain, labels=Dict(
        Symbol("params[$i]") => label for (i, label) in enumerate(labels)
    )
)


In [None]:
"""
    plot_samples_ah!(ah, samples, labels; thin=10, color=:black, alpha=nothing, z_mg_solar=6.71e-04, kwargs...)

Overplot posterior-predictive curves of mean Z_C / Z_Mg versus [Mg/H] for the
N-component fit, then the observed binned values on top.

- `ah`: binned [Mg/H] table with bin centres `_x`, one column per entry of
  `labels`, and the observations `obs` / `obs_err` (e.g. the first table
  returned by `load_binned_models`).
- `samples`: posterior draws with columns `"params[i]"`; `params[i]`
  multiplies column `labels[i]`.
- `z_mg_solar`: Mg mass fraction at [Mg/H] = 0, used to compute
  `Z_Mg = z_mg_solar * 10^[Mg/H]`. The default `6.71e-04` is the value that
  was previously hard-coded (presumably solar — confirm).

Every `thin`-th draw is plotted. With `alpha = nothing` the opacity is
`1 / n^(1/3)`, `n = size(samples, 1)`. Other keywords are ignored.
"""
function plot_samples_ah!(ah, samples, labels;
        thin=10, color=:black, alpha=nothing, z_mg_solar=6.71e-04, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    # Z_Mg at each bin centre, so Z_C can be shown relative to Mg
    Zmg = z_mg_solar * 10 .^ ah._x
    for sample in eachrow(samples)[1:thin:end]
        y = sum(sample["params[$i]"] * ah[:, label] for (i, label) in enumerate(labels))

        lines!(ah._x, y ./ Zmg, color=color, alpha=alpha)
    end

    errscatter!(ah._x, ah.obs ./ Zmg, yerr=ah.obs_err ./ Zmg, color=COLORS[2])
end

In [None]:
"""
    plot_samples_afe!(afe, samples, labels; thin=10, color=:black, alpha=nothing, kwargs...)

N-component version: overplot posterior-predictive curves of mean Z_C versus
[Mg/Fe], then the observed binned values (`afe.obs ± afe.obs_err`) on top.

`afe` holds bin centres `_x` and one column per entry of `labels`. Draw
coefficients are read from the `"params[i]"` columns of `samples`, with
`params[i]` multiplying column `labels[i]`. Every `thin`-th draw is plotted.
With `alpha = nothing` the opacity is `1 / n^(1/3)`, `n = size(samples, 1)`.
Other keywords are ignored.
"""
function plot_samples_afe!(afe, samples, labels;
        thin=10, color=:black, alpha=nothing, kwargs...)
    line_alpha = alpha === nothing ? 1/size(samples, 1)^(1/3) : alpha
    xs = afe._x

    draws = eachrow(samples)
    for k in 1:thin:length(draws)
        draw = draws[k]
        zc = sum(draw["params[$i]"] * afe[:, lab] for (i, lab) in enumerate(labels))
        lines!(xs, zc, color=color, alpha=line_alpha)
    end

    errscatter!(xs, afe.obs, yerr=afe.obs_err, color=COLORS[2])
end

In [None]:
# Posterior draws (every 100th) of mean Z_C / Z_Mg vs [Mg/H] for the N-component fit
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
# Posterior draws (every 100th) of mean Z_C vs [Mg/Fe] for the N-component fit
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=100)

fig