# MCMC yield perturbation analysis

The goal of this investigation is to fit for the combination of yields that best reproduces the APOGEE trends, using our multizone model results.
Because the abundance evolution of each element is *linear in the yields*, we can predict the present-day abundance trends for new combinations of yields without rerunning the models. This notebook focuses on the MCMC methods (likelihoods, binning, etc.) to understand what works well and how sensitive the results are to these choices. The scientific comparisons and models are instead in the notebook `perturb_results` and in the models `pertubations/mcmc_analysis`.



## Background

**Abundance linearity lemma:**
Assume we know the yields of an element $X$ for several different processes, $y_X^{i}$, and we know (or calculate) the resulting abundance evolution for each process, $Z_X^{i}(t)$. For any real coefficients $\alpha_1, \dots, \alpha_p$ over $p$ processes, if the total yield of $X$ is

$$
y_X = \sum_{i=1}^{p} \alpha_i y_X^i
$$
then the resulting abundance evolution is
$$
Z_X(t) = \sum_{i=1}^{p} \alpha_i Z_X^i(t)
$$

Note that this does not hold exactly if the yields depend on metallicity and the metallicity evolution itself changes. To first order, however, we can use this property of chemical evolution to map out the space of yields that best reproduces the present-day APOGEE measurements.
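
As a quick sanity check of the lemma, here is a minimal sketch using a *hypothetical* toy one-zone model (not VICE): $\dot Z = y\,s(t) - k\,s(t)\,Z$ with an exponentially declining star formation rate $s(t)$ and a metallicity-independent yield $y$. The function name, its parameters (`k`, `τ`, `dt`) and the forward-Euler integration are illustrative assumptions. Because each step is linear in both $Z$ and $y$, integrating with the combined yield gives the same curve as combining the individual curves.

```julia
# Hypothetical toy model (not VICE): dZ/dt = y * s(t) - k * s(t) * Z,
# with star formation rate s(t) = exp(-t / τ) and a metallicity-independent yield y.
# Integrated with forward Euler, so every step is linear in (Z, y).
function toy_abundance(y; tmax=10.0, dt=0.01, k=1.0, τ=5.0)
    n = round(Int, tmax / dt)
    Z = zeros(n + 1)                 # Z = 0 at t = 0
    for j in 1:n
        t = (j - 1) * dt
        s = exp(-t / τ)
        Z[j + 1] = Z[j] + dt * (y * s - k * s * Z[j])
    end
    return Z
end

# Two "processes" with different yields, combined with arbitrary real coefficients
y1, y2 = 2.0e-3, 5.0e-4
α1, α2 = 1.5, -0.3

Z_from_combined_yield = toy_abundance(α1 * y1 + α2 * y2)
Z_from_combined_curves = α1 .* toy_abundance(y1) .+ α2 .* toy_abundance(y2)

# The difference is at the level of floating-point rounding
max_diff = maximum(abs.(Z_from_combined_yield .- Z_from_combined_curves))
```

In general, letting the yield depend on $Z$ (a metallicity-dependent yield) breaks this equality, which is the caveat noted above.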


### Modeling linear combinations

To model each linear combination as cleanly as possible, I take our fiducial model and use Ag as a dummy element. Each model uses the same random seed, so it reproduces exactly the same population of stars.

Additionally, to minimize the impact on the metallicity evolution, I scale each process down by a factor of $10^6$ and scale the solar abundance used for each process down by the same factor. Therefore, for a 1× amount of each process, the reported \[X/H\] ratios correspond to $Z_{\rm C}^{\rm proc}$. As best practice, all models in a combination should be run within the same VICE run, so adding more components is as simple as extending the yield correspondence table above.
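
Spelled out, assuming (as described above) that the dummy element's yields and its solar abundance are both reduced by the same factor of $10^6$, linearity gives $Z_{\rm Ag} = 10^{-6}\,Z_{\rm C}^{\rm proc}$, and the common factor cancels in the reported abundance ratio:

$$
[\mathrm{Ag/H}] = \log_{10}\frac{Z_{\rm Ag}}{Z_{\rm Ag,\odot}} = \log_{10}\frac{10^{-6}\,Z_{\rm C}^{\rm proc}}{10^{-6}\,Z_{\rm C,\odot}} = \log_{10}\frac{Z_{\rm C}^{\rm proc}}{Z_{\rm C,\odot}},
$$

while the dummy element's own contribution to the total metallicity is suppressed by a factor of $10^{6}$.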

In [None]:
using CairoMakie
using Arya

In [None]:
using CSV, DataFrames
import CategoricalArrays: cut
using StatsBase
using OrderedCollections

In [None]:
using Turing
import NaNMath as nm
using LinearAlgebra: diagm
using PairPlots

In [None]:
using AdvancedMH

In [None]:
import FillArrays: I

In [None]:
# Solar mass fractions of C and Mg
solar_z = (;
    c = 3.39e-03,
    mg = 6.71e-04
    )

In [None]:
# [C/Mg] (or the dummy element's [Ag/Mg]) -> mass-fraction ratio Z_C / Z_Mg
function C_Mg_to_Z(ag_mg)
    return solar_z.c / solar_z.mg * 10 .^ ag_mg
end

In [None]:
# Relative error of Z_C / Z_Mg from the [C/Mg] uncertainty (first-order propagation)
function C_Mg_to_Z_relerr(ag_mg_err)
    return ag_mg_err * log(10)
end

In [None]:
using Measurements

In [None]:
# NOTE: despite the name, this returns Z_Mg = Z_Mg,sun * 10^[Mg/H]
function Mg_H_to_Zc(mg_h)
    return solar_z.mg * 10 .^ mg_h
end

In [None]:
# Inverse of `C_Mg_to_Z`: Z_C / Z_Mg -> [C/Mg] (NaNMath returns NaN for negative input)
function Zc_to_C_Mg(z_c_mg)
    return @. nm.log10(z_c_mg / solar_z.c * solar_z.mg)
end
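
The conversions above are the standard abundance-ratio relations: $Z_{\rm C}/Z_{\rm Mg} = (Z_{\rm C,\odot}/Z_{\rm Mg,\odot})\,10^{[\mathrm{C/Mg}]}$, its inverse $[\mathrm{C/Mg}] = \log_{10}\big[(Z_{\rm C}/Z_{\rm Mg})/(Z_{\rm C,\odot}/Z_{\rm Mg,\odot})\big]$, and the first-order error propagation $\sigma_{Z}/Z = \ln(10)\,\sigma_{[\mathrm{C/Mg}]}$. Below is a self-contained sketch that re-implements them with plain `log10` (instead of `NaNMath`) and checks the round trip; the `_sketch` function names are hypothetical.

```julia
# Solar mass fractions, copied from `solar_z` above
z_c_sun = 3.39e-03
z_mg_sun = 6.71e-04

# [C/Mg] -> Z_C / Z_Mg
c_mg_to_ratio_sketch(c_mg) = z_c_sun / z_mg_sun * 10.0^c_mg

# Z_C / Z_Mg -> [C/Mg] (the inverse)
ratio_to_c_mg_sketch(r) = log10(r / z_c_sun * z_mg_sun)

# Relative error σ(Z_C/Z_Mg) / (Z_C/Z_Mg) from σ([C/Mg]), to first order
relerr_sketch(c_mg_err) = c_mg_err * log(10)

c_mg = [-0.3, -0.1, 0.0, 0.2]
ratios = c_mg_to_ratio_sketch.(c_mg)
roundtrip = ratio_to_c_mg_sketch.(ratios)
```

At solar [C/Mg] = 0 the ratio is simply $Z_{\rm C,\odot}/Z_{\rm Mg,\odot} \approx 5.05$.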

In [None]:
"""
Finds the pickled model with either the given name or the parameters 
and returns the csv summary
"""
function find_model(name)
    file_name = "../models/perturbations/$name/stars.csv"
    model =  CSV.read(file_name, DataFrame)

    model[!, "z_c"] = C_Mg_to_Z(model[:, "AG_MG"])

    return model
end

In [None]:
function plot_mcmc_results(chain::Chains; bins::Int = 30)
    # Extract parameter names
    params = chain.name_map.parameters
    nparams = length(params)
    
    # One row of panels per parameter
    nrows = nparams
    
    # Initialize the figure
    fig = Figure(size = (600, 200 * nrows), 
                backgroundcolor = :white)
    
    # Determine number of chains
    nchains = size(chain, 3)
        
    # Iterate over each parameter
    for (i, param) in enumerate(params)

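        # Approximate acceptance rate: Metropolis-type samplers repeat the previous
        # value after a rejection, so the fraction of distinct values estimates it
        acc_rate = length(unique(chain[:, i, :])) / length(chain[:, i, :])
        @info "Param $param, acc rate $acc_rate"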
        # Trace Plot
        ax_trace = Axis(fig[i, 1],
            xlabel = "Iteration",
            ylabel = "$param",
            xgridvisible=false,
            ygridvisible=false
            )
        
        # Plot each chain's trace
        for c in 1:nchains
            samples = chain[:, i, c]
            lines!(ax_trace, collect(1:length(samples)), samples, color = c, colorrange=(0, nchains), label = "Chain $c")
        end
        
        # Add legend only once
        if i == 1 && nchains > 1
            axislegend(ax_trace, position = :rt)
        end
        
        # Histogram Plot
        ax_hist = Axis(fig[i, 2],
            limits=(0, nothing, nothing, nothing),
            xgridvisible=false,
            ygridvisible=false
        )
        hidedecorations!(ax_hist)
        linkyaxes!(ax_trace, ax_hist)

        
        # Combine samples from all chains for histogram
        combined_samples = vec(chain[:, i, :])
        hist!(ax_hist, direction=:x, combined_samples, bins = bins)
        
        if i < nparams
            hidexdecorations!(ax_trace, ticks=false)
        end
    end

    colgap!(fig.layout, 1, 0)
    rowgap!(fig.layout, 0)

    colsize!(fig.layout, 2, Relative(0.25))

    return fig
end
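
`plot_mcmc_results` estimates the acceptance rate as the fraction of distinct values in each parameter's trace. This works for Metropolis-type samplers (including emcee's ensemble moves), where a rejected proposal repeats the previous value; for NUTS, the sampler's own acceptance statistic recorded in the chain is the more relevant diagnostic. A minimal plain-Julia sketch of a random-walk Metropolis sampler for a standard normal target (the function name, step size and seed are illustrative assumptions) shows that the number of distinct values is one plus the number of accepted proposals:

```julia
using Random

# Random-walk Metropolis for a standard normal target, returning the full trace
# (initial point included) and the number of accepted proposals.
function rwm_trace(n; step=1.0, rng=MersenneTwister(42))
    logp(z) = -z^2 / 2
    trace = zeros(n + 1)
    x = 0.0
    trace[1] = x
    n_accept = 0
    for k in 1:n
        proposal = x + step * randn(rng)
        if log(rand(rng)) < logp(proposal) - logp(x)
            x = proposal
            n_accept += 1
        end
        trace[k + 1] = x   # a rejection repeats the previous value
    end
    return trace, n_accept
end

trace, n_accept = rwm_trace(5_000)
acc_true = n_accept / 5_000
acc_unique = length(unique(trace)) / length(trace)   # the estimate used in plot_mcmc_results
```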

# Data loading

In [None]:
subgiants = CSV.read("../data/subgiants.csv", DataFrame)

In [None]:
subgiants[!, :z_c] = C_Mg_to_Z(subgiants.C_MG)
subgiants[!, :z_c_err] =  C_Mg_to_Z_relerr(subgiants.C_MG_ERR) .* subgiants.z_c

In [None]:
fig = Figure()
ax = Axis(fig[1, 1], 
    xlabel="[Mg/H]",
    ylabel="[Mg/Fe]",
    limits=(-0.6, 0.6, -0.1, 0.5)
    )


p = scatter!(subgiants.MG_H, subgiants.MG_FE, color=subgiants.C_MG, markersize = 2., colorrange=(-0.5, 0.2))

Colorbar(fig[1, 2], p, label="[C/Mg]")

fig

The scatter plot above shows the region of the \[Mg/H\]-\[Mg/Fe\] plane occupied by the subgiant sample. The low-alpha sequence (low \[Mg/Fe\]) spans \[Mg/H\] between about -0.4 and 0.4, and almost all stars have \[Mg/Fe\] between 0 and 0.4. The \[Mg/Fe\] bins are computed in a narrow vertical slice of \[Mg/H\] centered at -0.1 (`mg_h_0`, with half-width `d_mg_h` = 0.05), which also spans this entire \[Mg/Fe\] range.
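
The slice selection used by `bin_caafe` and `bin_zc_afe` keeps stars with $m_0 - \Delta \le [\mathrm{Mg/H}] < m_0 + \Delta$, where $m_0$ is `mg_h_0` and $\Delta$ is `d_mg_h`. A minimal sketch of that half-open window; the helper name `in_slice` is hypothetical and its defaults are the notebook's values.

```julia
# Half-open [Mg/H] window, as in `bin_zc_afe`: mg_h_0 - d ≤ [Mg/H] < mg_h_0 + d
in_slice(mg_h; mg_h_0=-0.1, d=0.05) = (mg_h_0 - d <= mg_h) && (mg_h < mg_h_0 + d)

mg_h = [-0.20, -0.12, -0.08, 0.00]
mask = in_slice.(mg_h)     # only the two stars within 0.05 dex of -0.1 are kept
```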

In [None]:
model_y0_cc = find_model("const_cc")
model_A_cc = find_model("quadratic")
model_zeta = find_model("piecelin_m2")
model_alpha = find_model("analytic")
model_zeta.z_c .-= model_y0_cc.z_c .* 2 # corrects offset

model_y0_cc

In [None]:
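mg_h_bins = -0.5:0.1:0.35   # [Mg/H] bin edges (the range stops at 0.3)
mg_fe_bins = 0.0:0.05:0.30  # [Mg/Fe] bin edges
mg_h_0 = -0.1               # center of the [Mg/H] slice used for the [Mg/Fe] bins
d_mg_h = 0.05               # half-width of that slice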

In [None]:
# OrderedDict keeps the component order (and hence the meaning of params[i]) fixed
models = OrderedDict(
    :alpha => model_alpha,
    :zeta0 => model_y0_cc,
    :zeta1 => model_zeta,
    :zeta2 => model_A_cc,
    )

In [None]:
fig = Figure()
ax = Axis(fig[1, 1], 
    xlabel="[Mg/H]",
    ylabel="[Mg/Fe]",
    limits=(-0.6, 0.6, -0.1, 0.5)
    )


p = scatter!(model_alpha.MG_H, model_alpha.MG_FE, color=model_alpha.C_MG, markersize = 2., colorrange=(-0.5, 0.2))

Colorbar(fig[1, 2], p, label="[C/Mg]")

fig

In [None]:
hist(model_y0_cc.MG_H_true[.!model_y0_cc.high_alpha])

In [None]:
hist(model_y0_cc.MG_FE_true[ mg_h_0 - d_mg_h .< model_y0_cc.MG_H_true .< mg_h_0 + d_mg_h])

# Binning utils

In [None]:
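"""
    bin_means(df; x, bins, val, n_min)

Bin `df` on column `x` using the edges `bins` and return, for each bin, the mean
of `val` (`:med`), the mean of `x` (`:xmed`), the standard deviation of `val`
(`:err`), the number of rows (`:counts`) and the bin midpoint (`:x`).
Empty bins are kept with `med = err = counts = 0`. `n_min` is currently unused.
"""
function bin_means(df::DataFrame; x::Symbol=:MG_H_true, bins = mg_h_bins, val::Symbol=:z_c, n_min::Int=3)
    Nb = length(bins) - 1

    # Bin index of each row (missing if outside the bin edges)
    x_bin = cut(df[!, x], bins, extend=missing, labels=1:Nb)

    filt = .!ismissing.(x_bin)
    df_filtered = copy(df[filt, :])
    df_filtered[!, :x_bin] = x_bin[filt]
    grouped = groupby(df_filtered, :x_bin)

    # Note: `:med` holds the *mean* of `val`, not the median
    results = combine(grouped,
        val => mean => :med,
        x => mean => :xmed,
        val => std => :err,
        val => length => :counts
    )

    # Re-attach empty bins so that every bin appears exactly once
    full_grid = DataFrame(x_bin = 1:Nb)
    df_result = leftjoin(full_grid, results, on=:x_bin)
    # leftjoin does not guarantee row order (unmatched rows can come last), so restore bin order
    sort!(df_result, :x_bin)

    x_bin_mids = midpoints(bins)
    df_result.x = getindex.(Ref(x_bin_mids), df_result.x_bin)

    filt_missing = ismissing.(df_result.counts)
    df_result[filt_missing, :counts] .= 0
    df_result[filt_missing, :med] .= 0
    df_result[filt_missing, :err] .= 0
    # Empty bins have no mean x; use the bin midpoint so disallowmissing! succeeds
    df_result[filt_missing, :xmed] .= df_result.x[filt_missing]
    #DataFrames.transform!(df_result, :counts => ByRow(count -> count < n_min ? NaN : count) => :counts)
    disallowmissing!(df_result)
    return df_result
end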
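
For reference, a dependency-free sketch of what `bin_means` computes per bin (mean, standard deviation and count of `val`), using half-open bins $[e_k, e_{k+1})$ assigned with `searchsortedlast`. It is an illustration, not a drop-in replacement: `cut` may treat the upper edge of the last bin differently, and here single-value bins get a standard deviation of 0 instead of `NaN`. The helper name `simple_bin_stats` is hypothetical.

```julia
using Statistics

# Hypothetical helper: per-bin mean, standard deviation and count of `v`,
# binned on `x` with half-open bins [edges[k], edges[k+1]).
# Empty bins get mean = std = count = 0, as in `bin_means`;
# single-value bins get std = 0 here (Statistics.std would return NaN).
function simple_bin_stats(x::AbstractVector, v::AbstractVector, edges::AbstractVector)
    nb = length(edges) - 1
    groups = [Float64[] for _ in 1:nb]
    for (xi, vi) in zip(x, v)
        k = searchsortedlast(edges, xi)   # index of the last edge ≤ xi
        if 1 <= k <= nb                   # drop values outside the edges
            push!(groups[k], vi)
        end
    end
    means = [isempty(g) ? 0.0 : mean(g) for g in groups]
    stds = [length(g) < 2 ? 0.0 : std(g) for g in groups]
    counts = length.(groups)
    return means, stds, counts
end

x = [0.05, 0.15, 0.12, 0.31, 0.95]
v = [1.0, 2.0, 4.0, 3.0, 9.0]
means, stds, counts = simple_bin_stats(x, v, 0.0:0.1:0.4)
```

`searchsortedlast` returns 0 below the first edge and `nb + 1` at or above the last edge, which is how out-of-range values (here 0.95) are dropped.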

In [None]:
function bin_caah(df::DataFrame; bins=mg_h_bins, val=:C_MG, x=:MG_H, kwargs...)
    filt = .!df.high_alpha
    return bin_means(df[filt, :]; x=x, val=val, bins=bins, kwargs...)
end

In [None]:
function bin_caafe(df::DataFrame; m_h=:MG_H, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE, val=:C_MG, bins=mg_fe_bins, kwargs...)
    filt = df[:, m_h] .>= m_h_0 - d_m_h
    filt .&= df[:, m_h] .< m_h_0 + d_m_h
    return bin_means(df[filt, :]; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
function bin_zc_afe(df::DataFrame; m_h=:MG_H_true, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE_true, val=:z_c, bins=mg_fe_bins, kwargs...)
    filt = df[:, m_h] .>= m_h_0 - d_m_h
    filt .&= df[:, m_h] .< m_h_0 + d_m_h
    return bin_means(df[filt, :]; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
function bin_zc_ah(df::DataFrame;x=:MG_H_true, val=:z_c, bins=mg_h_bins, kwargs...)
    filt = .!df.high_alpha
    return bin_means(df[filt, :]; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
function bin_all(models, observations=subgiants;
        mg_h_bins=mg_h_bins, mg_fe_bins=mg_fe_bins, 
        mg_h_0=mg_h_0, d_mg_h=d_mg_h
    )

    labels = keys(models)

    ah_dfs = Dict(label => bin_zc_ah(model, bins=mg_h_bins) 
        for (label, model) in models)
    
    afe_dfs = Dict(label => bin_zc_afe(model, bins=mg_fe_bins, 
            m_h_0=mg_h_0, d_m_h=d_mg_h) 
        for (label, model) in models)

    ah_obs = bin_zc_ah(observations, bins=mg_h_bins, x=:MG_H, val=:z_c)
    afe_obs = bin_zc_afe(observations, bins=mg_fe_bins, 
        val=:z_c, x=:MG_FE, m_h=:MG_H,
            m_h_0=mg_h_0, d_m_h=d_mg_h)

    

    ah_binned = DataFrame([label => df.med for (label, df) in ah_dfs]...)
    afe_binned = DataFrame([label => df.med for (label, df) in afe_dfs]...)

    for label in labels
        ah_binned[:, ("$(label)_err")] = ah_dfs[label].err
        afe_binned[:, ("$(label)_err")] = afe_dfs[label].err
        ah_binned[:, ("$(label)_counts")] = ah_dfs[label].counts
        afe_binned[:, ("$(label)_counts")] = afe_dfs[label].counts
    end

    ah_binned[!, :obs] = ah_obs.med
    ah_binned[!, :obs_err] = ah_obs.err
    ah_binned[!, :obs_counts] = ah_obs.counts
    ah_binned[!, :x] = midpoints(mg_h_bins)
    
    afe_binned[!, :obs] = afe_obs.med
    afe_binned[!, :obs_err] = afe_obs.err
    afe_binned[!, :obs_counts] = afe_obs.counts
    afe_binned[!, :x] = midpoints(mg_fe_bins)


    return ah_binned, afe_binned, labels
end

In [None]:
function plot_ah_models(ah, labels)
    fig, ax = FigAxis(
        xlabel="[Mg/H]",
        ylabel="Zc / Zmg / Zmg(sun) (for process)",
    )
    
    for label in labels
        lines!(ah.x, ah[:, label] .* solar_z[:mg], label=string(label))
    end

    errscatter!(ah.x, ah.obs .* solar_z[:mg], yerr=ah.obs_err .* solar_z[:mg])
    axislegend()
    fig

end

In [None]:
function plot_afe_models(afe, labels)
    fig, ax = FigAxis(
        xlabel="[Mg/Fe]",
        ylabel="Zc / Zmg /Zmg(sun) (for process)",
    )
    
    for label in labels
        lines!(afe.x, afe[:, label] .* solar_z[:mg], label=string(label))
    end
    
    errscatter!(afe.x, afe.obs .* solar_z[:mg], yerr=afe.obs_err .* solar_z[:mg])
    
    axislegend()
    fig

end

In [None]:
function plot_samples_ah!(afe, samples, labels;
        thin=10, color=:black, alpha=nothing, kwargs...)
    
    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    for sample in eachrow(samples)[1:thin:end]
        y = sum(sample["params[$i]"] * afe[:, label]  for (i, label) in enumerate(labels))
        
        lines!(afe.x, y, color=color, alpha=alpha)
    end

    errscatter!(afe.x, afe.obs, yerr=afe.obs_err, color=COLORS[2])
end

In [None]:
function plot_samples_caah!(afe, samples, labels;
        thin=10, color=:black, alpha=nothing, kwargs...)
    
    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    for sample in eachrow(samples)[1:thin:end]
        y = sum(sample["params[$i]"] * afe[:, label]  for (i, label) in enumerate(labels))
        
        lines!(afe.x, Zc_to_C_Mg.(y), color=color, alpha=alpha)
    end

    errscatter!(afe.x, Zc_to_C_Mg.(afe.obs), yerr=afe.obs_err ./ afe.obs ./ sqrt.(afe.obs_counts) ./ log(10), color=COLORS[2])
end

In [None]:
function plot_samples_afe!(afe, samples, labels;
        thin=10, color=:black, alpha=nothing, kwargs...)
    
    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end
    
    for sample in eachrow(samples)[1:thin:end]
        y = sum(sample["params[$i]"] * afe[:, label]  for (i, label) in enumerate(labels))
        
        lines!(afe.x, y, color=color, alpha=alpha)
    end

    errscatter!(afe.x, afe.obs, yerr=afe.obs_err, color=COLORS[2])
end

In [None]:
ah, afe, labels = bin_all(models, mg_h_bins=mg_h_bins, mg_fe_bins=mg_fe_bins)

In [None]:
plot_ah_models(ah, labels)

In [None]:
plot_afe_models(afe, labels)

# N component model

In [None]:

# Helper function to compute the model contribution: Σ_i params[i] * model[key_i]
function compute_model_contribution(params, model)
    return sum(p * model[key] for (p, key) in zip(params, keys(model)))
end

# Draft N-component model; the observed values and their errors are passed explicitly
@model function n_component_model(y1, y1_err, y2, y2_err, models, priors)
    # Fall back to standard-normal priors if none are specified
    if isempty(priors)
        priors = [Normal(0, 1) for _ in 1:length(models)]
    end

    # Create parameters based on the specified priors
    params ~ arraydist(priors)

    # Compute model contributions for each dataset
    mu1 = compute_model_contribution(params, models)
    mu2 = compute_model_contribution(params, models)  # This could be adjusted if different model terms are needed

    # Data likelihoods (y1 and y2 are model arguments, so Turing treats them as observed)
    y1 ~ MvNormal(mu1, diagm(y1_err .^ 2))
    y2 ~ MvNormal(mu2, diagm(y2_err .^ 2))
end

In [None]:
@model function fit_3_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=mg_h_0)
    y0_cc ~ Normal(2, 1)
    α ~ Normal(2, 1)
    ζ ~ Normal(0, 1)

    Zc = y0_cc * models[:y0_cc] .+ α * models[:alpha] .+ ζ * models[:zeta_cc]

    mu = Zc_to_C_H.(Zc) .- x
    y1_pred = mu

    Zc2 = y0_cc * models2[:y0_cc] .+ α * models2[:alpha] .+ ζ * models2[:zeta_cc]
    mu2 = Zc_to_C_H.(Zc2) .- m_h_0

    
    y ~ MvNormal(mu, diagm(y_e .^ 2))
    y2 ~ MvNormal(mu2,  diagm(y_e2 .^ 2))
end

In [None]:
function plot_samples!(samples, x, models;
        thin=10, color=:black, alpha=nothing, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end
    
    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:alpha] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:zeta_cc]
        lines!(x, Zc_to_C_H(y) .- x, color=color, alpha=alpha)
    end
end

In [None]:
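# NOTE: same name and number of positional arguments as `plot_samples_afe!(afe, samples, labels)`,
# so this definition replaces that method (it is defined again in the Chi2 model section)
function plot_samples_afe!(samples, x, models;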
        thin=10, color=:black, alpha=nothing, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end
    
    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:alpha] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:zeta_cc]
        lines!(x, Zc_to_C_H(y) .- mg_h_0, color=color, alpha=alpha)
    end
end

In [None]:
y = disallowmissing(subgiants_binned_ah.med)
yerr = disallowmissing(subgiants_binned_ah.err)# ./ sqrt.(subgiants_binned.counts)
x = disallowmissing(subgiants_binned_ah.x)

y2 = disallowmissing(subgiants_binned_afe.med)
yerr2 = disallowmissing(subgiants_binned_afe.err)# ./ sqrt.(subgiants_binned2.counts)
x2 = disallowmissing(subgiants_binned_afe.x)

In [None]:
yerr

In [None]:
model_med_ah = Dict(label => disallowmissing(model.med) for (label, model) in models_zc_ah)
model_med_afe = Dict(label => disallowmissing(model.med) for (label, model) in models_zc_afe)

In [None]:
model = fit_3_comp_model(x, y, yerr, x2, y2, yerr2, model_med_ah, model_med_afe)

In [None]:
chain = sample(model, NUTS(0.65), 5_000)

In [None]:
samples = DataFrame(chain)

In [None]:
pairplot(chain)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    #limits=(-0.5, 0.3, -0.2, 0.0),
    xgridvisible=false,
    ygridvisible=false,
    xlabel="[Mg/H]",
    ylabel="[C/Mg] (low alpha)",
)


plot_samples!(samples, x, model_med_ah, alpha=0.008)
errscatter!(x, y, yerr=yerr)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel=L"\textrm{mean([C/Mg]) | [Mg/H] }\in [%$(round(mg_h_0-d_mg_h, digits=2)), %$(round(mg_h_0+d_mg_h, digits=2))]",
    xgridvisible=false,
    ygridvisible=false,
)


plot_samples_afe!(samples, x2, model_med_afe, alpha=0.008)
errscatter!(x2, y2, yerr=yerr2, label="mean subgiants")

axislegend()
fig

In [None]:
@model function fit_4_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=mg_h_0)
    y0_cc ~ Normal(2, 1)
    α ~ Normal(2, 1)
    ζ ~ Normal(0, 1)
    A ~ Normal(0, 1)

    Zc = y0_cc * models[:y0_cc] .+ α * models[:alpha] .+ ζ * models[:zeta_cc] .+ A * models[:A_cc]

    mu = Zc_to_C_H.(Zc) .- x
    y1_pred = mu

    Zc2 = y0_cc * models2[:y0_cc] .+ α * models2[:alpha] .+ ζ * models2[:zeta_cc] .+ A * models2[:A_cc]
    mu2 = Zc_to_C_H.(Zc2) .- m_h_0

    
    y ~ MvNormal(mu, diagm(y_e .^ 2))
    y2 ~ MvNormal(mu2,  diagm(y_e2 .^ 2))
end

In [None]:
function plot_samples!(samples, x, models;
        thin=10, color=:black, alpha=nothing, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end
    
    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:alpha] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:zeta_cc] + sample.A * models[:A_cc]
        lines!(x, Zc_to_C_H(y) .- x, color=color, alpha=alpha)
    end
end

In [None]:
scatter(subgiants.MG_H, subgiants.MG_H_ERR)

In [None]:
function plot_samples_afe!(samples, x, models;
        thin=10, color=:black, alpha=nothing, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end
    
    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:alpha] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:zeta_cc] + sample.A * models[:A_cc]
        lines!(x, Zc_to_C_H(y) .- mg_h_0, color=color, alpha=alpha)
    end
end

In [None]:
model = fit_4_comp_model(x, y, yerr, x2, y2, yerr2, model_med_ah, model_med_afe)

In [None]:
chain = sample(model, NUTS(0.65), 50_000)

In [None]:
samples = DataFrame(chain)

In [None]:
CSV.write("mcmc_samples/samples_analytic.csv", samples)

In [None]:
fig = pairplot(chain, labels=Dict(
        :y0_cc=>L"y_0^\textrm{CC}",
        :α => L"\alpha_\textrm{C}^\textrm{AGB}",
        :ζ => L"\zeta_\textrm{C}^\textrm{CC}",
        :A => L"A_\textrm{C}^\textrm{CC}",
    )
)

save("figures/mcmc_multizone_corner.pdf", fig)

fig

In [None]:
fig = Figure(size=(600, 300))
ax = Axis(fig[1, 1],
    #limits=(-0.5, 0.3, -0.2, 0.0),
    xgridvisible=false,
    ygridvisible=false,
    xlabel="[Mg/H]",
    ylabel="[C/Mg]",
)


plot_samples!(samples[1:10:end, :], x, model_med_ah, alpha=0.008)
errscatter!(x, y, yerr=yerr, color=COLORS[2])


ax = Axis(fig[1, 2],
    xlabel="[Mg/Fe]",
    ylabel=L"\textrm{[C/Mg]}",
    xgridvisible=false,
    ygridvisible=false,
)


plot_samples_afe!(samples[1:10:end, :], x2, model_med_afe, alpha=0.008)
errscatter!(x2, y2, yerr=yerr2, color=COLORS[2])


fig

# Chi2 model

In [None]:
@model function n_component_model(models_ah, models_afe, labels, priors)
    # Create parameters based on the specified priors
    params ~ arraydist(priors)
    
    # Compute model contributions for each dataset
    mu_ah = sum(p * models_ah[:, key] for (p, key) in zip(params, labels))
    mu_afe = sum(p * models_afe[:, key] for (p, key) in zip(params, labels))
    
    # Data likelihoods

    sigma_ah = models_ah.obs_err ./ sqrt.(models_ah.obs_counts)
    sigma_afe = models_afe.obs_err ./ sqrt.(models_afe.obs_counts)
    
    models_ah.obs ~ MvNormal(mu_ah, diagm(sigma_ah .^ 2))
    models_afe.obs ~ MvNormal(mu_afe, diagm(sigma_afe .^ 2))
end
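
In this model each bin contributes an independent Gaussian term whose width is the standard error of the observed mean, $\sigma_k = s_k/\sqrt{N_k}$, where $s_k$ is `obs_err` (the standard deviation within the bin) and $N_k$ is `obs_counts`. The log-likelihood is then $-\chi^2/2$ up to a parameter-independent constant, with $\chi^2 = \sum_k \big(\bar z_k - \sum_i p_i \mu_{ik}\big)^2/\sigma_k^2$. A small plain-Julia sketch, with hypothetical helper names and the model components stored as the columns of a matrix `M`:

```julia
# Standard error of the mean in each bin: s_k / sqrt(N_k)
sem_per_bin(s, n) = s ./ sqrt.(n)

# χ² of the observed bin means against a linear combination of model components
function chi2(obs, M, params, sigma)
    mu = M * params                          # Σ_i params[i] * M[:, i]
    return sum(((obs .- mu) ./ sigma) .^ 2)
end

# Full diagonal-Gaussian log-likelihood: -χ²/2 - Σ_k log σ_k - (K/2) log(2π)
function gaussian_loglike(obs, M, params, sigma)
    K = length(obs)
    return -chi2(obs, M, params, sigma) / 2 - sum(log.(sigma)) - K / 2 * log(2π)
end

obs = [1.0, 2.0]
M = [1.0 0.0; 0.0 1.0]                   # two bins, two components
sigma = sem_per_bin([2.0, 3.0], [4, 9])  # -> [1.0, 1.0]
```

The `sigma_ah`/`sigma_afe` lines in the model correspond to `sem_per_bin`, and the `MvNormal(mu, diagm(sigma .^ 2))` terms evaluate the same quantity as `gaussian_loglike`.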

In [None]:
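# Re-define the (afe, samples, labels) method, which the 3-argument
# `plot_samples_afe!(samples, x, models)` definitions above replaced
function plot_samples_afe!(afe, samples, labels;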
        thin=10, color=:black, alpha=nothing, kwargs...)
    
    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end
    
    for sample in eachrow(samples)[1:thin:end]
        y = sum(sample["params[$i]"] * afe[:, label]  for (i, label) in enumerate(labels))
        
        lines!(afe.x, y, color=color, alpha=alpha)
    end

    errscatter!(afe.x, afe.obs, yerr=afe.obs_err, color=COLORS[2])
end

In [None]:
priors = Dict(
    :alpha => Normal(1, 6),
    :zeta0 => Normal(2, 6),
    :zeta1 => Normal(0, 6),
    :zeta2 => Normal(0, 6)
    )

In [None]:
priors = [priors[label] for label in labels]

In [None]:
model = n_component_model(ah, afe, labels, priors)

In [None]:
s = externalsampler(AdvancedMH.RWMH(4))

In [None]:
chain = sample(model, s, 30_000)

In [None]:
samples = DataFrame(chain)

In [None]:
fig = pairplot(chain[2000:end], labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
model = n_component_model(ah, afe, labels, priors)

In [None]:
chain = sample(model, NUTS(0.65), 10_000)

In [None]:
samples = DataFrame(chain)

In [None]:
fig = pairplot(chain, labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_caah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=100)

fig

## Binning choices

## Both uncertainties (observational and model)

In [None]:
@model function n_component_model_both(models_ah, models_afe, labels, priors)
    # Create parameters based on the specified priors
    params ~ arraydist(priors)
    
    # Compute model contributions for each dataset
    mu_ah = sum(p * models_ah[:, key] for (p, key) in zip(params, labels))
    mu_afe = sum(p * models_afe[:, key] for (p, key) in zip(params, labels))
    sem_ah = sum(p * models_ah[:, Symbol("$(key)_err")] ./ sqrt.(models_ah[:, Symbol("$(key)_counts")])  for (p, key) in zip(params, labels))
    sem_afe = sum(p * models_afe[:, Symbol("$(key)_err")] ./ sqrt.(models_afe[:, Symbol("$(key)_counts")]) for (p, key) in zip(params, labels))
    
    # Data likelihoods

    sigma_ah = models_ah.obs_err ./ sqrt.(models_ah.obs_counts)
    sigma_afe = models_afe.obs_err ./ sqrt.(models_afe.obs_counts)

    sigma2_ah = @. sigma_ah^2 + sem_ah^2 
    sigma2_afe = @. sigma_afe^2 + sem_afe^2 
    
    models_ah.obs ~ MvNormal(mu_ah, diagm(sigma2_ah))
    models_afe.obs ~ MvNormal(mu_afe, diagm(sigma2_afe))
end
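
In `n_component_model_both`, the variance in each bin adds the observational standard error in quadrature to a model term. The model term combines the component standard errors *linearly*, $\sum_i p_i s_i/\sqrt{N_i}$, which is the propagation appropriate for perfectly correlated component scatter (plausible here, since every component model uses the same seed and hence the same stars); for independent components one would instead add $\sum_i p_i^2 s_i^2/N_i$. A minimal sketch of both options, with hypothetical helper names:

```julia
# Hypothetical helper mirroring the variance in `n_component_model_both`:
#   σ² = (s_obs / √N_obs)² + (Σ_i p_i s_i / √N_i)²
# The model SEMs are combined linearly (fully correlated components).
function combined_variance(obs_err, obs_counts, comp_err, comp_counts, params)
    σ_obs = obs_err ./ sqrt.(obs_counts)
    sem_model = sum(params[i] .* comp_err[i] ./ sqrt.(comp_counts[i]) for i in eachindex(params))
    return σ_obs .^ 2 .+ sem_model .^ 2
end

# Alternative for independent components: Σ_i p_i² s_i² / N_i
function combined_variance_independent(obs_err, obs_counts, comp_err, comp_counts, params)
    σ_obs = obs_err ./ sqrt.(obs_counts)
    var_model = sum(params[i]^2 .* comp_err[i] .^ 2 ./ comp_counts[i] for i in eachindex(params))
    return σ_obs .^ 2 .+ var_model
end

# One bin, two components
v_corr = combined_variance([2.0], [4], [[1.0], [3.0]], [[1], [9]], [2.0, 1.0])
v_ind = combined_variance_independent([2.0], [4], [[1.0], [3.0]], [[1], [9]], [2.0, 1.0])
```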

In [None]:
priors = Dict(
    :alpha => Normal(1, 1),
    :zeta0 => Normal(2, 1),
    :zeta1 => Normal(0, 1),
    :zeta2 => Normal(0, 2)
    )

In [None]:
priors = [priors[label] for label in labels]

In [None]:
model = n_component_model_both(ah, afe, labels, priors)

In [None]:
chain = sample(model, NUTS(0.65), 10_000)

In [None]:
samples = DataFrame(chain)

In [None]:
fig = pairplot(chain, labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_caah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=100)

fig

# Divergence methods

This section explores more sophisticated, but much more computationally expensive, models whose likelihoods are based on statistical divergences (or test statistics) between the observed and predicted distributions in each bin.

## Utilities & friends

In [None]:
using PythonCall
emcee = pyimport("emcee")

In [None]:
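function get_simple_variables(model)
    # Map each variable name to its index range(s) in the flattened parameter vector
    model_labels = bijector(model, Val(true))[2]

    labels = Vector{Symbol}(undef, length(model_labels))
    for label in keys(model_labels)
        idx = model_labels[label]
        # Only scalar variables are supported
        @assert length(idx) == 1
        @assert length(idx[1]) == 1

        labels[idx[1][1]] = label
    end

    return labels
end
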
In [None]:
function sample_emcee(model, steps; nwalkers=nothing, progress=true, kwargs...)
    model_labels = bijector(model, Val(true))[2]

    # `model_labels` is a NamedTuple mapping each variable to its index ranges
    variables = keys(model_labels) |> collect
    # Total dimensionality = number of indices across all variables
    ndim = sum(length(vcat(idxs...)) for idxs in model_labels)
    
    if nwalkers === nothing
        nwalkers = 2*ndim
    end
    
    function py_log_prob(theta)
        vars = (; [k => theta[model_labels[k]...] for k in variables]...)
        lp = logprior(model, vars)
        if isfinite(lp)
            lp += loglikelihood(model, vars)
        end
            
        return lp
    end

    # sample from prior for initial conditions
    samples_prior = sample(model, Prior(), nwalkers)
    p0 = samples_prior.value[:, 1:ndim, 1].data

    # python run
    sammy = emcee.EnsembleSampler(nwalkers, ndim, py_log_prob)
    sammy.run_mcmc(p0, steps; progress=progress, kwargs...)

    # Convert back to a Julia array
    chain = pyconvert(Array{Float64}, sammy.get_chain())
    # emcee returns (steps, walkers, dims); MCMCChains expects (iterations, params, chains)
    chain = permutedims(chain, (1, 3, 2))
    varnames = samples_prior.name_map.parameters
    chain = Chains(chain, varnames)

    return chain
end

In [None]:
function bin_values_to_list(df::DataFrame; x::Symbol=:MG_H_true, bins = mg_h_bins, val::Symbol=:z_c)
    Nb = length(bins) - 1

    # Create bin identifiers for the x dimension
    x_bin = cut(df[!, x], bins, extend=missing, labels=1:Nb)

    # Filter out any rows with missing bin assignments
    filt = .!ismissing.(x_bin)
    df_filtered = copy(df[filt, :])
    df_filtered[!, :x_bin] = x_bin[filt]

    # Initialize a dictionary to store values for each bin
    bin_dict = Dict{Int, Vector{eltype(df[!, val])}}()
    for i in 1:Nb
        bin_dict[i] = Vector{eltype(df[!, val])}()
    end

    # Group the DataFrame by bins and populate the dictionary
    grouped = groupby(df_filtered, :x_bin)

    for group in grouped
        bin_number = first(group.x_bin)
        bin_dict[bin_number] = group[!, val] |> collect
    end

    return [bin_dict[i] for i in 1:Nb]
end

In [None]:
function bin_zc_afe_list(df::DataFrame; m_h=:MG_H_true, m_h_0=mg_h_0, d_m_h=d_mg_h, x=:MG_FE_true, val=:z_c, bins=mg_fe_bins, kwargs...)
    filt = df[:, m_h] .>= m_h_0 - d_m_h
    filt .&= df[:, m_h] .< m_h_0 + d_m_h
    return bin_values_to_list(df[filt, :]; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
function bin_zc_ah_list(df::DataFrame;x=:MG_H_true, val=:z_c, bins=mg_h_bins, kwargs...)
    filt = .!df.high_alpha
    return bin_values_to_list(df[filt, :]; bins=bins, x=x, val=val, kwargs...)
end

In [None]:
function bin_all_to_list(models, observations=subgiants;
        mg_h_bins=mg_h_bins, mg_fe_bins=mg_fe_bins, 
        mg_h_0=mg_h_0, d_mg_h=d_mg_h
    )

    labels = keys(models) |> collect

    ah_dfs = Dict(label => bin_zc_ah_list(model, bins=mg_h_bins) 
        for (label, model) in models)
    
    afe_dfs = Dict(label => bin_zc_afe_list(model, bins=mg_fe_bins, 
            m_h_0=mg_h_0, d_m_h=d_mg_h) 
        for (label, model) in models)

    ah_dfs[:obs] = bin_zc_ah_list(observations, bins=mg_h_bins, x=:MG_H, val=:z_c) .|> sort
    afe_dfs[:obs] = bin_zc_afe_list(observations, bins=mg_fe_bins, 
        val=:z_c, x=:MG_FE, m_h=:MG_H,
            m_h_0=mg_h_0, d_m_h=d_mg_h) .|> sort

    

    idx_ah = [sortperm(df) for df in ah_dfs[labels[1]]]
    idx_afe = [sortperm(df) for df in afe_dfs[labels[1]]]

    for label in labels
        ah_dfs[label] = [df[idx] for (df, idx) in zip(ah_dfs[label], idx_ah)]
        afe_dfs[label] = [df[idx] for (df, idx) in zip(afe_dfs[label], idx_afe)]

    end

    
    return ah_dfs, afe_dfs
end
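
`bin_all_to_list` sorts the observed values in each bin, but permutes every model component with the *same* permutation (`sortperm` of the first component). Because all component models contain the same stars in the same order, this keeps star $j$ of one component aligned with star $j$ of the others, so `sum(p * models_ah[key][i] ...)` in `divergence_model` forms per-star linear combinations. A small sketch with toy values:

```julia
# Toy per-star values of two components in one bin (same stars, same order)
a = [0.3, 0.1, 0.2]
b = [1.0, 3.0, 2.0]
p = (2.0, 0.5)

idx = sortperm(a)               # permutation defined by the first component
a_s, b_s = a[idx], b[idx]       # the same permutation applied to every component

aligned = p[1] .* a_s .+ p[2] .* b_s
# Same per-star combinations as combining before sorting
same_as_unsorted = sort(aligned) == sort(p[1] .* a .+ p[2] .* b)

# Sorting each component independently would mix up stars
misaligned = p[1] .* sort(a) .+ p[2] .* sort(b)
differs = sort(misaligned) != sort(p[1] .* a .+ p[2] .* b)
```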

In [None]:
ah_list, afe_list = bin_all_to_list(models);

In [None]:
fig = Figure()
ax = Axis(fig[1,1])

for label in [collect(labels); :obs]
    ys = []
    N = length(ah_list[label])
    
    for i in 1:N
        push!(ys, mean(ah_list[label][i]))
    end

    scatter!(1:N, ys, label=string(label))
        
end

axislegend()
fig

In [None]:
fig = Figure()
ax = Axis(fig[1,1])

i = 3
for label in labels
    stephist!(ah_list[label][i], label=string(label))
end

axislegend()
fig

In [None]:
fig = Figure()
ax = Axis(fig[1,1])

i = 6
for label in labels
    stephist!(afe_list[label][i], label=string(label))
end

axislegend()
fig

In [None]:
using HypothesisTests

In [None]:
@model function divergence_model(models_ah, models_afe, labels, priors, p_func)
    # Create parameters based on the specified priors
    params ~ arraydist(priors)

    # Observed per-star values in each bin
    ah_obs = models_ah[:obs]
    afe_obs = models_afe[:obs]

    # [Mg/H] bins: compare observed values with the per-star linear combination
    for i in eachindex(ah_obs)
        ah_pred = sum(p * models_ah[key][i] for (p, key) in zip(params, labels))
        L = p_func(ah_obs[i], ah_pred)
        @Turing.addlogprob!(L)
    end

    # [Mg/Fe] bins within the [Mg/H] slice
    for i in eachindex(afe_obs)
        afe_pred = sum(p * models_afe[key][i] for (p, key) in zip(params, labels))
        L = p_func(afe_obs[i], afe_pred)
        @Turing.addlogprob!(L)
    end
end

In [None]:
priors = Dict(
    :alpha => Normal(1, 1),
    :zeta0 => Normal(2, 1),
    :zeta1 => Normal(1, 0.5),
    :zeta2 => Normal(2, 0.5)
    )

In [None]:
labels = keys(models) |> collect

In [None]:
priors = [priors[label] for label in labels]

### Mean

In [None]:
metric(x, y) = -(mean(x) - mean(y))^2 / (sem(y))^2
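
This `metric` is minus the squared z-score of the difference between the observed and predicted means, using the standard error of the *second* argument, which `divergence_model` fills with the model prediction. It is therefore $-\chi^2$ for one bin rather than the $-\chi^2/2$ of a Gaussian log-likelihood, and it omits the $-\log\sigma$ normalization even though $\sigma = \mathrm{SEM}(y)$ depends on the parameters here. A plain-Julia check with toy values; the names are hypothetical and `sem_sketch` uses the same definition as `StatsBase.sem`, $s/\sqrt{n}$:

```julia
using Statistics

# Standard error of the mean, s / sqrt(n) (same definition as StatsBase.sem)
sem_sketch(y) = std(y) / sqrt(length(y))

# The "mean" metric: -(difference of means)^2 / SEM(y)^2, i.e. -χ² for one bin
mean_metric(x, y) = -(mean(x) - mean(y))^2 / sem_sketch(y)^2

# The -χ²/2 form used by a Gaussian log-likelihood (normalization term omitted)
half_chi2_metric(x, y) = -0.5 * (mean(x) - mean(y))^2 / sem_sketch(y)^2

x = [1.0, 2.0, 3.0]   # "observed" values in a bin: mean 2
y = [0.0, 2.0]        # "predicted" values: mean 1, SEM 1
```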

In [None]:
priors = Dict(
    :alpha => Normal(0.5, 0.25),
    :zeta0 => Normal(2, 0.25),
    :zeta1 => Normal(2, 0.5),
    :zeta2 => Normal(2, 0.5)
    )

priors = [priors[label] for label in labels]

In [None]:
model = divergence_model(ah_list, afe_list, labels, priors,  metric);

In [None]:
chain = sample_emcee(model, 1_000, nwalkers=16)

In [None]:
burn = 300

In [None]:
plot_mcmc_results(chain[burn:end])

In [None]:
fig = pairplot(chain[burn:end], labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
samples = DataFrame(chain[burn:end])

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=10)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=10)

fig

### One-sample t-test

In [None]:
metric(x, y) = log(pvalue(OneSampleTTest(x, median(y))))
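
`OneSampleTTest(x, median(y))` tests whether the observed stars `x` in a bin have a mean consistent with the median of the predicted distribution `y`. The test is built on the t statistic sketched below (hypothetical helper name). For a fixed number of stars the two-sided p-value decreases as |t| grows, so the log p-value penalizes bins whose observed mean drifts away from the model median.

```julia
using Statistics

# t statistic: distance of the sample mean of x from μ0 in units of its standard error
t_statistic_sketch(x, μ0) = (mean(x) - μ0) / (std(x) / sqrt(length(x)))

t_statistic_sketch([1.0, 2.0, 3.0], 2.0)   # 0.0: mean equals μ0
t_statistic_sketch([1.0, 2.0, 3.0], 1.0)   # √3 ≈ 1.732
```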


In [None]:
model = divergence_model(ah_list, afe_list, labels, priors, metric);

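Gradient-based samplers (at least NUTS) perform poorly with this model, likely because the metric depends on sorted samples (here through the median), which makes the log density non-smooth in the parameters. Turing's gradient-free samplers include particle Gibbs (PG) and Metropolis-Hastings (MH), alone or combined with Gibbs. The cells below instead call `sample_emcee`, an ensemble sampler that likewise needs no gradients.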
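
`sample_emcee` is defined elsewhere in the notebook; judging from its `nwalkers` keyword, it is assumed here to be an emcee-style affine-invariant ensemble sampler. The sketch below is a minimal plain-Julia version of the Goodman & Weare stretch move that such samplers use, to show that only log-density evaluations are needed. It is an illustration under that assumption, not the notebook's implementation; `stretch_move_sketch` and its arguments are hypothetical.

```julia
using Random, Statistics

# Hypothetical sketch of an emcee-style ensemble sampler (stretch move).
#   logp : function returning the log density of a parameter vector
#   x0   : d × nwalkers matrix of starting positions
# Returns a d × nwalkers × nsteps array of walker positions.
function stretch_move_sketch(logp, x0::AbstractMatrix, nsteps::Integer;
                             a::Real=2.0, rng::AbstractRNG=Random.default_rng())
    d, nw = size(x0)
    nw >= 2 || throw(ArgumentError("need at least two walkers"))
    x = Matrix{Float64}(x0)                 # current walker positions (a copy)
    lp = [logp(x[:, k]) for k in 1:nw]      # current log densities
    out = Array{Float64}(undef, d, nw, nsteps)
    for s in 1:nsteps
        for k in 1:nw
            # pick a partner walker j != k uniformly
            j = rand(rng, 1:nw-1)
            j >= k && (j += 1)
            # stretch factor z ~ g(z) ∝ 1/sqrt(z) on [1/a, a], via the inverse CDF
            z = (1 + (a - 1) * rand(rng))^2 / a
            y = x[:, j] .+ z .* (x[:, k] .- x[:, j])
            lpy = logp(y)
            # accept with probability min(1, z^(d-1) p(y) / p(x_k))
            if log(rand(rng)) < (d - 1) * log(z) + lpy - lp[k]
                x[:, k] = y
                lp[k] = lpy
            end
        end
        out[:, :, s] = x
    end
    return out
end

# Usage: sample a 2D standard normal with 16 walkers
rng_demo = MersenneTwister(42)
std_normal_logp(v) = -0.5 * sum(abs2, v)
draws = stretch_move_sketch(std_normal_logp, randn(rng_demo, 2, 16), 5_000; rng=rng_demo)
post = draws[:, :, 1_001:end]   # discard the first 1000 steps as burn-in
mean(post), std(post)           # close to 0 and 1
```

Because the proposal only stretches one walker along the line to another, the sampler adapts to correlated, badly scaled posteriors (such as degenerate yield weights) without any tuning.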

In [None]:
chain = sample_emcee(model, 3_000, nwalkers=16)

In [None]:
burn = 500

In [None]:
# keep only walkers whose mean value of the first chain variable is positive
chain_filt = dropdims(mean(chain[:, 1, :], dims=1).data .> 0, dims=1)
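
The `chain_filt` cells reduce each walker to the mean of the first chain variable over all iterations and keep the walkers that pass a threshold (`> 0` here, `< 1` in the U-test section). On a plain array laid out like MCMCChains data (iterations × parameters × walkers), the idiom looks like this; `walker_mask` is a hypothetical helper.

```julia
using Statistics

# Hypothetical helper mirroring the chain_filt idiom.
# draws: iterations × parameters × walkers; keep: predicate applied to each walker's mean
function walker_mask(draws::AbstractArray{<:Real,3}, keep; param::Int=1)
    # mean of the chosen parameter for each walker (length = number of walkers)
    walker_means = dropdims(mean(draws[:, param, :], dims=1), dims=1)
    return keep.(walker_means)
end

demo_draws = zeros(10, 2, 3)
demo_draws[:, 1, 2] .= -1.0                # walker 2 is stuck at a negative value
walker_mask(demo_draws, m -> m > -0.5)     # [true, false, true]
```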

In [None]:
samples = DataFrame(chain[burn:end, :, chain_filt])

In [None]:
fig = pairplot(chain[burn:end, :, chain_filt], labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
plot_mcmc_results(chain[burn:end, :, chain_filt])

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=10)

fig

### Mann-Whitney U test

In [None]:
metric(x, y) = log(pvalue(MannWhitneyUTest(x, y)))
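
The Mann-Whitney U statistic counts how often a star from one sample exceeds a star from the other. Because it depends only on the ordering of observed and predicted values, it (and hence the log p-value) is piecewise constant in the yield weights: a small change that swaps no pairs leaves it unchanged, so its gradient is zero almost everywhere and gradient-based samplers have nothing to follow. A pairwise sketch (hypothetical helper; rank-based implementations compute the same quantity more efficiently):

```julia
# U for sample x against y: number of pairs (x_i, y_j) with x_i > y_j,
# counting ties as one half. O(nx * ny), for illustration only.
function mann_whitney_u_sketch(x, y)
    u = 0.0
    for xi in x, yj in y
        u += xi > yj ? 1.0 : (xi == yj ? 0.5 : 0.0)
    end
    return u
end

mann_whitney_u_sketch([4.0, 5.0], [1.0, 2.0, 3.0])           # 6.0, every x beats every y
mann_whitney_u_sketch([1.0], [1.0])                          # 0.5, a single tie
mann_whitney_u_sketch([4.0, 5.0] .+ 1e-3, [1.0, 2.0, 3.0])   # still 6.0 after a small shift
```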


In [None]:
ah_list

In [None]:
@time metric(ah_list[:alpha][1], ah_list[:zeta1][1])

In [None]:
model = divergence_model(ah_list, afe_list, labels, priors, metric);

As with the t-test, NUTS struggles here: the U test depends only on the ranks of the stars, so the log density is piecewise constant in the parameters. We again use the gradient-free `sample_emcee`.

In [None]:
chain = sample_emcee(model, 1000, nwalkers=8)

In [None]:
burn = 300

In [None]:
chain_filt = dropdims(mean(chain[:, 1, :], dims=1).data .< 1, dims=1)

In [None]:
samples = DataFrame(chain[burn:end, :, chain_filt])

In [None]:
plot_mcmc_results(chain[burn:end, :, :])

In [None]:
plot_mcmc_results(chain[burn:end, :, chain_filt])

In [None]:
fig = pairplot(chain[burn:end, :, chain_filt], labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=10)

fig

### Two-sample KS test (HypothesisTests)

In [None]:
metric(x, y) = log(pvalue(HypothesisTests.ApproximateTwoSampleKSTest(x, y)))
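
`ApproximateTwoSampleKSTest` is based on the two-sample Kolmogorov-Smirnov distance, the largest gap between the two empirical CDFs, with the p-value taken from its asymptotic distribution. A sketch of the distance itself (hypothetical helper name) shows that it, too, only needs the sorted samples.

```julia
# Two-sample KS distance D = max_v |F_x(v) - F_y(v)| between empirical CDFs
function ks_statistic_sketch(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    xs, ys = sort(x), sort(y)
    # empirical CDF: fraction of the sorted sample that is <= v
    Fx(v) = searchsortedlast(xs, v) / length(xs)
    Fy(v) = searchsortedlast(ys, v) / length(ys)
    # the supremum is attained at one of the sample points
    return maximum(abs(Fx(v) - Fy(v)) for v in Iterators.flatten((xs, ys)))
end

ks_statistic_sketch([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])   # 1.0, disjoint samples
ks_statistic_sketch([1.0, 2.0], [2.0, 3.0])             # 0.5
```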


In [None]:
model = divergence_model(ah_list, afe_list, labels, priors, metric);

The KS statistic is likewise computed from sorted samples, so we again use the gradient-free `sample_emcee` rather than NUTS.

In [None]:
chain = sample_emcee(model, 10_000, nwalkers=8)

In [None]:
chain_filt = dropdims(mean(chain[:, 1, :], dims=1).data .> 0, dims=1)

In [None]:
samples = DataFrame(chain[burn:end, :, chain_filt])

In [None]:
samples_all = DataFrame(chain)
rename!(samples_all, collect("params[$i]" => labels[i] for i in eachindex(labels))...)


In [None]:
CSV.write("mcmc_samples/kstest_samples.csv", samples_all)

In [None]:
plot_mcmc_results(chain[burn:end, :, chain_filt])

In [None]:
fig = pairplot(chain[burn:end, :, chain_filt], labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=10)

fig

# 2D methods (binned)

In [None]:
function bin_means_2d(df::DataFrame; x::Symbol=:MG_H_true, y::Symbol=:MG_FE_true, x_bins = mg_h_bins, y_bins = mg_fe_bins, val::Symbol=:z_c, n_min::Int=3)
    # Determine the number of bins in each dimension
    Nb_x = length(x_bins) - 1
    Nb_y = length(y_bins) - 1

    # Create bin identifiers for x and y dimensions
    x_bin = cut(df[!, x], x_bins, extend=missing, labels=1:Nb_x)
    y_bin = cut(df[!, y], y_bins, extend=missing, labels=1:Nb_y)

    # Filter out any rows with missing bin assignments
    filt = .!ismissing.(x_bin) .&& .!ismissing.(y_bin)
    df_filtered = copy(df[filt, :])
    df_filtered[!, :x_bin] = x_bin[filt]
    df_filtered[!, :y_bin] = y_bin[filt]

    # Group by the 2D bins
    grouped = groupby(df_filtered, [:x_bin, :y_bin])

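    # Calculate the statistics within each 2D bin. Despite their names, :med, :xmed
    # and :ymed hold bin means; downstream code reads these column names.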
    results = combine(grouped,
        val => mean => :med,
        x => mean => :xmed,
        y => mean => :ymed,
        val => std => :err,
        val => length => :counts
    )

    # Create the full 2D grid of bins
    full_grid = DataFrame(x_bin = repeat(1:Nb_x, Nb_y), y_bin = repeat(1:Nb_y, inner=Nb_x))

    # Join results with the full grid to fill missing bins with NaNs
    df_result = leftjoin(full_grid, results, on=[:x_bin, :y_bin], order=:left)
    
    # Fill in the midpoints for the bin centers
    x_bin_mids = midpoints(x_bins)
    y_bin_mids = midpoints(y_bins)
    df_result.x = getindex.(Ref(x_bin_mids), df_result.x_bin)
    df_result.y = getindex.(Ref(y_bin_mids), df_result.y_bin)

    
    # Handle missing values and replace with defaults where necessary
    filt_missing = ismissing.(df_result.counts) .|| df_result.counts .< n_min
    df_result[filt_missing, :counts] .= 0
    df_result[filt_missing, :med] .= 0
    df_result[filt_missing, :err] .= 0
    df_result[filt_missing, :xmed] .= NaN
    df_result[filt_missing, :ymed] .= NaN

    # Ensure there are no missing values
    disallowmissing!(df_result)

    return df_result
end
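
The core of `bin_means_2d` can be written without DataFrames: assign each star to a 2D bin, average the value per bin, and blank out bins with fewer than `n_min` stars. The stdlib-only sketch below is a hypothetical illustration. It assumes left-closed bins whose last bin is also closed on the right (the convention of CategoricalArrays' `cut`) and drops stars outside the grid, as `extend=missing` does.

```julia
# Hypothetical sketch of the binning in bin_means_2d.
# Returns (means, counts), both nx × ny; sparse bins get mean 0 and count 0.
function bin_means_2d_sketch(x, y, val, x_edges, y_edges; n_min::Int=3)
    nx, ny = length(x_edges) - 1, length(y_edges) - 1
    # bin index; a value equal to the last edge goes into the last bin
    binindex(edges, v) = v == edges[end] ? length(edges) - 1 : searchsortedlast(edges, v)
    sums = zeros(nx, ny)
    counts = zeros(Int, nx, ny)
    for (xi, yi, vi) in zip(x, y, val)
        ix, iy = binindex(x_edges, xi), binindex(y_edges, yi)
        (1 <= ix <= nx && 1 <= iy <= ny) || continue   # outside the grid
        sums[ix, iy] += vi
        counts[ix, iy] += 1
    end
    # same convention as bin_means_2d for bins below n_min
    sparse = counts .< n_min
    means = ifelse.(sparse, 0.0, sums ./ max.(counts, 1))
    counts[sparse] .= 0
    return means, counts
end

means_demo, counts_demo = bin_means_2d_sketch(
    [0.1, 0.2, 0.6, 0.7, 1.0, 1.5],    # x: the last star lies outside the grid
    fill(0.5, 6),                      # y
    [1.0, 2.0, 3.0, 4.0, 5.0, 100.0],  # value to average
    [0.0, 0.5, 1.0], [0.0, 1.0],
)
# means_demo[:, 1] == [0.0, 4.0] (the first bin has only 2 stars); counts_demo[:, 1] == [0, 3]
```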

In [None]:
df = bin_means_2d(model_alpha, x=:MG_H, y=:MG_FE)

In [None]:
df = bin_means_2d(subgiants, x=:MG_H, y=:MG_FE)

In [None]:
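function bin_all_2d(models, observations=subgiants;
        mg_h_bins=mg_h_bins, mg_fe_bins=mg_fe_bins, n_min=3
    )

    labels = keys(models) |> collect

    # Bin each component model (on the default :MG_H_true / :MG_FE_true columns)
    ahfe_dfs = Dict(label => bin_means_2d(model, x_bins=mg_h_bins, y_bins=mg_fe_bins, n_min=n_min)
        for (label, model) in models)

    # Bin the observations on their measured abundances
    ahfe_obs = bin_means_2d(observations, x_bins=mg_h_bins, y_bins=mg_fe_bins,
        x=:MG_H, y=:MG_FE, val=:z_c, n_min=n_min)

    # One column of bin means per component, plus its scatter and counts
    ahfe_binned = DataFrame([label => df.med for (label, df) in ahfe_dfs]...)

    for label in labels
        ahfe_binned[:, "$(label)_err"] = ahfe_dfs[label].err
        ahfe_binned[:, "$(label)_counts"] = ahfe_dfs[label].counts

        # every table must use the same bin ordering
        @assert ahfe_obs.x_bin == ahfe_dfs[label].x_bin
        @assert ahfe_obs.y_bin == ahfe_dfs[label].y_bin
    end

    ahfe_binned[!, :obs] = ahfe_obs.med
    ahfe_binned[!, :obs_err] = ahfe_obs.err
    ahfe_binned[!, :obs_counts] = ahfe_obs.counts
    ahfe_binned[!, :x] = ahfe_obs.x
    ahfe_binned[!, :y] = ahfe_obs.y

    # Keep bins populated in both the observations and the models. Empty observed
    # bins have obs_counts = 0, which would give a NaN variance in model_bin2d.
    # All components share the same stars (same seed), so one component's counts suffice.
    filt = (ahfe_binned.obs_counts .>= n_min) .& (ahfe_binned[:, "$(labels[1])_counts"] .>= n_min)

    return ahfe_binned[filt, :], labels
end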

In [None]:
fig = Figure()
ax = Axis(fig[1, 1], 
    xlabel="[Mg/H]",
    ylabel="[Mg/Fe]",
    limits=(-0.6, 0.6, -0.1, 0.5)
    )


p = scatter!(subgiants.MG_H, subgiants.MG_FE, color=subgiants.C_MG, markersize = 2., colorrange=(-0.5, 0.2))

Colorbar(fig[1, 2], p, label="[C/Mg]")

fig

In [None]:
ahfe, labels = bin_all_2d(models)

In [None]:
scatter(df.x, df.y, color=df.med, markersize=df.counts ./ 10)

In [None]:
minimum(ahfe.obs_counts), minimum(ahfe.alpha_counts)

In [None]:
scatter(ahfe.x, ahfe.y, color=ahfe.alpha, markersize=ahfe.alpha_counts ./ 20 .+ 3)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1], 
    xlabel="[Mg/H]",
    ylabel="[Mg/Fe]",
    limits=(-0.4, 0.25, -0.0, 0.3),
    )


p = scatter!(model_alpha.MG_H, model_alpha.MG_FE, color=model_alpha.C_MG, markersize = 2., colorrange=(-0.5, 0.2))

Colorbar(fig[1, 2], p, label="[C/Mg]")

fig

In [None]:
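@model function model_bin2d(models_2d, labels, priors)
    # One weight per component, drawn from the specified priors
    params ~ arraydist(priors)

    # Predicted mean in each 2D bin: linear combination of the component means
    mu = sum(p * models_2d[:, key] for (p, key) in zip(params, labels))

    # Model standard error of the mean, summed linearly (not in quadrature) across
    # components, i.e. treating their sampling errors as fully correlated, which is
    # reasonable since all components share the same stars
    sigma_model = sum(p * models_2d[:, "$(key)_err"] ./ sqrt.(models_2d[:, "$(key)_counts"]) for (p, key) in zip(params, labels))
    # Total variance per bin: observed standard error of the mean plus model error
    sigma2 = @. models_2d.obs_err^2 ./ models_2d.obs_counts .+ sigma_model .^ 2

    # Independent Gaussian likelihood per bin (diagonal covariance)
    models_2d.obs ~ MvNormal(mu, Diagonal(sigma2))
end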
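
Written out, `model_bin2d` treats each bin $b$ as an independent Gaussian:

$$
\mu_b = \sum_k \alpha_k \bar Z_{k,b}, \qquad
\sigma_b^2 = \frac{s_{{\rm obs},b}^2}{N_{{\rm obs},b}} + \Big(\sum_k \alpha_k \frac{s_{k,b}}{\sqrt{N_{k,b}}}\Big)^2,
$$
$$
\ln\mathcal{L} = -\frac{1}{2}\sum_b\left[\frac{(\bar Z_{{\rm obs},b}-\mu_b)^2}{\sigma_b^2} + \ln\left(2\pi\sigma_b^2\right)\right],
$$

where $\bar Z$, $s$ and $N$ are the bin mean, standard deviation and star count (`med`, `err` and `counts` in the binned table). This is the log density of an `MvNormal` with diagonal covariance `sigma2`. A stdlib-only sketch (hypothetical function; rows are bins, columns are components) that can be used to cross-check the Turing model:

```julia
# Hypothetical stdlib-only version of model_bin2d's log-likelihood.
# comp_* are (n_bins × n_components) matrices; obs_* are length-n_bins vectors.
function bin2d_loglike_sketch(weights, comp_mean, comp_err, comp_counts,
                              obs_mean, obs_err, obs_counts)
    mu = comp_mean * weights                                  # predicted bin means
    sigma_model = (comp_err ./ sqrt.(comp_counts)) * weights  # linear sum of model SEMs
    sigma2 = obs_err .^ 2 ./ obs_counts .+ sigma_model .^ 2   # total variance per bin
    r = obs_mean .- mu
    return -0.5 * sum(r .^ 2 ./ sigma2 .+ log.(2π .* sigma2))
end

# One bin, one component: the prediction matches exactly and σ² = 1
one_by_one(v) = reshape([v], 1, 1)
bin2d_loglike_sketch([2.0], one_by_one(1.0), one_by_one(0.0), one_by_one(1.0),
                     [2.0], [1.0], [1.0])   # -0.5 * log(2π)
```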

In [None]:
function plot_mean_ahfe(afe, samples, labels; mode=:mean, kwargs...)
    fig = Figure()
    ax = Axis(fig[1,1],
        xlabel="[Mg/H]",
        ylabel="[C/Mg]",
        )


    p = plot_mean_ahfe!(afe, samples, labels; mode=mode, kwargs...)

    if mode == :residual
        label  = "Δ [C/Mg]"
    elseif mode == :zscore
        label = "z score"
    else
        label = "[C/Mg]"
    end
    
    Colorbar(fig[1,2], p, label=label)

    fig
end

In [None]:
function plot_mean_ahfe!(afeh, samples, labels;
        thin=10, color=:black, alpha=nothing, colorrange=nothing, mode = :mean,
        kwargs...)
    
    alpha_mean = [mean(samples[:, "params[$i]"]) for (i, label) in enumerate(labels)]

    y = sum(alpha_mean[i] * afeh[:, label]  for (i, label) in enumerate(labels))

 

    if colorrange === nothing
        cmin = min(minimum(y), minimum(afeh.obs))
        cmax = max(maximum(y), maximum(afeh.obs))
        colorrange = (cmin, cmax)
    end

    if mode == :residual
        res = afeh.obs .- y
        resmax = maximum(abs.(res))
        
        p = scatter!(afeh.x, afeh.y, markersize=20, color=res, colorrange=(-resmax, resmax), colormap=:RdBu)

    elseif mode == :zscore
        sy = afeh.obs_err ./ sqrt.(afeh.obs_counts)
        sy_model = sum(alpha_mean[i] * afeh[:, "$(key)_err"] ./ sqrt.(afeh[:, "$(key)_counts"]) for (i, key) in enumerate(labels))

        sy = @. sqrt(sy^2 + sy_model^2)
        res = (afeh.obs .- y) ./ sy
        resmax = maximum(abs.(res))
        
        p = scatter!(afeh.x, afeh.y, markersize=20, color=res, colorrange=(-resmax, resmax), colormap=:RdBu)

    elseif mode == :mean
    
        scatter!(afeh.x, afeh.y, color=afeh.obs, markersize=20, colorrange=colorrange)
        scatter!(afeh.x, afeh.y, color=:white, markersize=12)
    
        p = scatter!(afeh.x, afeh.y, color=y, markersize=10, colorrange=colorrange)

    else
        error("unknown mode $(repr(mode)); expected :mean, :residual, or :zscore")
    end
    
    return p
end

In [None]:
function plot_samples_ahfe!(afe, samples, labels;
        thin=10, color=:black, alpha=nothing, kwargs...)
    
    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    xs = afe.x |> unique |> sort
    dx = xs[2] - xs[1]
    
    ys = afe.y |> unique |> sort
    dy = ys[2] - ys[1]
    
    for sample in eachrow(samples)[1:thin:end]
        z = sum(sample["params[$i]"] * afe[:, label]  for (i, label) in enumerate(labels))

        # jitter each point uniformly within its bin so samples do not overlap exactly
        Np = length(afe.x)
        x = afe.x .+ (0.5 .- rand(Np)) * dx
        y = afe.y .+ (0.5 .- rand(Np)) * dy
        scatter!(x, y, color=z)
    end

end

In [None]:
priors = Dict(
    :alpha => Normal(1, 1),
    :zeta0 => Normal(2, 1),
    :zeta1 => Normal(0, 1),
    :zeta2 => Normal(0, 2)
    )

In [None]:
priors = [priors[label] for label in labels]

In [None]:
model = model_bin2d(ahfe, labels, priors)

In [None]:
chain = sample(model, NUTS(0.65), 10_000)

In [None]:
samples = DataFrame(chain)

In [None]:
fig = pairplot(chain, labels=Dict(
        Symbol("params[$i]") => string(label) for (i, label) in enumerate(labels)
    )
)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_ah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/H]",
    ylabel="mean Z_C / Zmg",
)


plot_samples_caah!(ah, samples, labels, thin=100)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    ylabel="mean Z_C",
)


plot_samples_afe!(afe, samples, labels, thin=100)

fig

In [None]:
plot_mean_ahfe(ahfe, samples, labels)

In [None]:
plot_mean_ahfe(ahfe, samples, labels, mode=:residual)

In [None]:
plot_mean_ahfe(ahfe, samples, labels, mode=:zscore)