In [None]:
using CairoMakie
using Turing
using CSV, DataFrames

In [None]:
using OrderedCollections

In [None]:
using StatsBase

In [None]:
using PairPlots

In [None]:
import CategoricalArrays: cut

In [None]:
using FillArrays

In [None]:
import NaNMath as nm

In [None]:
using Arya

In [None]:
# Observed subgiant sample; the columns used below include MG_H, MG_FE, C_MG and high_alpha.
subgiants = CSV.read("../data/subgiants.csv", DataFrame)

In [None]:
# Reference (solar) abundances; `c` is the carbon Z_C used by C_H_to_Zc / Zc_to_C_H.
solar_z = (c = 3.39e-3,)

In [None]:
"""
    C_H_to_Zc(ag_h)

Convert [C/H] (scalar or array, in dex) to Z_C by scaling the reference value
`solar_z.c`: `Z_C = solar_z.c * 10^[C/H]`, applied element-wise.
"""
C_H_to_Zc(ag_h) = solar_z.c * (10 .^ ag_h)

In [None]:
"""
    Zc_to_C_H(z_c)

Inverse of `C_H_to_Zc`: return `log10(z_c / solar_z.c)` element-wise.
It uses NaNMath's `log10`, so negative inputs yield `NaN` instead of throwing.
"""
Zc_to_C_H(z_c) = nm.log10.(z_c ./ solar_z.c)

In [None]:
"""
Finds the pickled model with either the given name or the parameters 
and returns the csv summary
"""
function find_model(name)

    
    file_name = "../models/perturbations/$name/stars.csv"
    model =  CSV.read(file_name, DataFrame)
    model[!, "z_c"] = C_H_to_Zc(model.AG_H)
    return model
end

In [None]:
# Scratch check: 10 draws with replacement from [1, 2, 3] (the resampling used by median_se).
rand([1,2,3], 10)

In [None]:
"""
    median_se(values; N=1000)

Bootstrap estimate of the standard error of `median(values)`.

`values` is resampled with replacement `N` times using the global RNG, and the
standard deviation of the `N` resampled medians is returned.
"""
function median_se(values::AbstractVector; N=1000)
    n = length(values)
    boot_medians = map(1:N) do _
        median(rand(values, n))
    end
    return std(boot_medians)
end

In [None]:
"""
    bin_medians(df; x=:MG_H_true, bins=0:0.05:0.35, val=:z_c, n_min=3)

Bin the rows of `df` by column `x` using the edges in `bins`, and summarise
column `val` in every bin.

Returns a `DataFrame` with exactly `length(bins) - 1` rows, ordered by bin
index. Its columns are:
- `x_bin`  — 1-based bin index,
- `med`    — median of `val` in the bin,
- `xmed`   — median of `x` in the bin,
- `err`    — bootstrap standard error of the median (`median_se`),
- `counts` — number of rows in the bin,
- `x`      — midpoint of the bin edges.

Rows whose `x` lies outside `bins` are dropped.

Empty bins get `counts`, `med` and `err` set to 0; their `xmed` stays
`missing`. Note that `err == 0` makes any diagonal covariance built from these
errors singular.

`n_min` is accepted but currently unused.
"""
function bin_medians(df::DataFrame; x::Symbol=:MG_H_true, bins = 0:0.05:0.35, val::Symbol=:z_c, n_min::Int=3)
    Nb = length(bins) - 1

    # `extend=missing` maps values outside the bin edges to `missing` instead of erroring.
    x_bin = cut(df[!, x], bins, extend=missing, labels=1:Nb)

    filt = .!ismissing.(x_bin)
    df_filtered = df[filt, :]  # row indexing already returns a fresh DataFrame
    df_filtered[!, :x_bin] = x_bin[filt]
    grouped = groupby(df_filtered, :x_bin)

    results = combine(grouped,
        val => median => :med,
        x => median => :xmed,
        val => median_se => :err,
        val => length => :counts
    )

    # Left-join onto the full list of bins so that empty bins still get a row.
    full_grid = DataFrame(x_bin = 1:Nb)
    df_result = leftjoin(full_grid, results, on=:x_bin)
    # `leftjoin` does not promise to keep the left table's row order (unmatched
    # rows typically end up last). Results from different data sets are compared
    # element-wise (e.g. `med` vectors), so restore bin order explicitly.
    sort!(df_result, :x_bin)

    x_bin_mids = midpoints(bins)
    df_result.x = getindex.(Ref(x_bin_mids), df_result.x_bin)

    filt_missing = ismissing.(df_result.counts)
    df_result[filt_missing, :counts] .= 0
    df_result[filt_missing, :med] .= 0
    df_result[filt_missing, :err] .= 0

    return df_result
end

In [None]:
# Centre of the [Mg/H] slice used for the [Mg/Fe] trend; default `m_h_0` of bin_caafe and fit_3_comp_model.
o_h_0 = -0.1

In [None]:
"""
    bin_caah(df; bins=-0.5:0.1:0.35, kwargs...)

Run `bin_medians` on the rows of `df` with `high_alpha == false`.

Any other keyword arguments are passed through to `bin_medians`.

NOTE(review): the default range `-0.5:0.1:0.35` ends at 0.3, because 0.35 is not
reached in steps of 0.1.
"""
function bin_caah(df::DataFrame; bins=-0.5:0.1:0.35, kwargs...)
    low_alpha = df[.!df.high_alpha, :]
    return bin_medians(low_alpha; bins, kwargs...)
end

In [None]:
"""
    bin_caafe(df; m_h=:MG_H_true, m_h_0=o_h_0, d_m_h=0.05, x=:MG_FE_true, val=:z_c,
              bins=0.0:0.05:0.30, kwargs...)

Keep the rows of `df` whose `m_h` column lies strictly inside
`(m_h_0 - d_m_h, m_h_0 + d_m_h)`.

Those rows are then median-binned with `bin_medians` by column `x`, summarising
`val`. Extra keyword arguments are forwarded to `bin_medians`.
"""
function bin_caafe(df::DataFrame; m_h=:MG_H_true, m_h_0=o_h_0, d_m_h=0.05, x=:MG_FE_true, val=:z_c, bins=0.0:0.05:0.30, kwargs...)
    lo = m_h_0 - d_m_h
    hi = m_h_0 + d_m_h
    mh = df[:, m_h]
    in_slice = (mh .> lo) .& (mh .< hi)
    return bin_medians(df[in_slice, :]; bins, x, val, kwargs...)
end

In [None]:
"""
    mz_result(zcaah, zcaafe)

Binned summary of one model.

As built by `bin_model`, `zcaah` is the output of `bin_caah` and `zcaafe` is the
output of `bin_caafe`.
"""
struct mz_result
    zcaah::DataFrame   # table from bin_caah
    zcaafe::DataFrame  # table from bin_caafe
end

In [None]:
"""
    binned_trends(caah, caafe)

Pair of binned trend tables. NOTE(review): not used anywhere in the visible code.
"""
struct binned_trends
    caah::DataFrame   # trend binned like bin_caah (presumably)
    caafe::DataFrame  # trend binned like bin_caafe (presumably)
end

In [None]:
"""
    BinnedZc(x, y)

Plain pair of `Float64` vectors (`x`, `y`) for a binned Z_C curve.
NOTE(review): not used anywhere in the visible code.
"""
struct BinnedZc
    x::Vector{Float64}  # bin coordinates
    y::Vector{Float64}  # binned values
end

In [None]:
"""
    bin_model(name)

Load model `name` with `find_model` and bin it with both `bin_caah` and
`bin_caafe`, returning the two tables as an `mz_result`.
"""
function bin_model(name)
    stars = find_model(name)
    by_mg_h = bin_caah(stars)
    by_mg_fe = bin_caafe(stars)
    return mz_result(by_mg_h, by_mg_fe)
end

In [None]:
import Base: +, *

In [None]:
"""
    a::mz_result + b::mz_result

Return a new `mz_result` whose bin tables are `a`'s tables with `b`'s `med`,
`counts` and `err` added element-wise.

Both `zcaah` and `zcaafe` are combined; the original version only combined
`zcaah` and silently kept `a.zcaafe`. Both operands must have been binned on
the same grids.

NOTE(review): errors are summed linearly, not in quadrature.
"""
function (+)(a::mz_result, b::mz_result)
    c = deepcopy(a)
    for (dst, src) in ((c.zcaah, b.zcaah), (c.zcaafe, b.zcaafe))
        dst.med .+= src.med
        dst.counts .+= src.counts
        dst.err .+= src.err
    end
    return c
end

In [None]:
"""
    a::Real * b::mz_result

Return a copy of `b` with `med` scaled by `a` in both the `zcaah` and `zcaafe`
tables.

The errors `err` are scaled by `abs(a)`, since a standard error cannot be
negative. `counts` is left unchanged.
"""
function (*)(a::Real, b::mz_result)
    c = deepcopy(b)
    for tbl in (c.zcaah, c.zcaafe)
        tbl.med .*= a
        tbl.err .*= abs(a)
    end
    return c
end

In [None]:
"""
    add_models(models, coeffs)

Linear combination `sum(coeffs[name] * model for (name, model) in models)` of
binned model results.

`models` is any iterable of `name => mz_result` pairs (a `Dict`, an
`OrderedDict` or a vector of pairs). `coeffs` maps each name to a `Real`
coefficient.

Throws `ArgumentError` if `models` is empty.
"""
function add_models(models, coeffs)
    isempty(models) && throw(ArgumentError("add_models needs at least one model"))

    # `first` works for Dicts and vectors of pairs alike; `models[1]` only worked
    # for integer-indexed collections and threw a KeyError for Symbol-keyed Dicts.
    total = 0 * first(models).second
    # 0 * missing is still missing, so clear those entries before accumulating.
    for tbl in (total.zcaah, total.zcaafe)
        tbl.med[ismissing.(tbl.med)] .= 0
    end

    for (name, model) in models
        total = total + coeffs[name] * model
    end

    return total
end

# Data loading

In [None]:
# Observed [C/Mg] binned in [Mg/Fe], within bin_caafe's default [Mg/H] slice around o_h_0.
subgiants_binned_afe = bin_caafe(subgiants; m_h=:MG_H, x=:MG_FE, val=:C_MG)

In [None]:
# Observed [C/Mg] binned in [Mg/H] for the low-alpha subgiants.
subgiants_binned_ah = bin_caah(subgiants; val=:C_MG, x=:MG_H)

In [None]:
# Binned template models, keyed by the coefficient that multiplies each one in fit_3_comp_model.
models = Dict(
    coef => bin_model(model_name)
    for (coef, model_name) in (:α => "analytic", :y0_cc => "const_cc", :ζ => "piecelin_m0.2")
)

In [None]:
fig, ax = FigAxis()

# One scatter series per template: binned Z_C medians against the [Mg/H]-bin centres.
foreach(models) do (label, model)
    scatter!(model.zcaah.x, model.zcaah.med, label=string(label))
end

axislegend()
fig

In [None]:
fig, ax = FigAxis()

# One scatter series per template: binned Z_C medians against the [Mg/Fe]-bin centres.
foreach(models) do (label, model)
    scatter!(model.zcaafe.x, model.zcaafe.med, label=string(label))
end

axislegend()
fig

In [None]:
# Plain Float64 template vectors on the [Mg/H] grid, as consumed by fit_3_comp_model.
models_zc_ah = Dict(coef => disallowmissing(res.zcaah.med) for (coef, res) in pairs(models))

In [None]:
# Plain Float64 template vectors on the [Mg/Fe] grid, as consumed by fit_3_comp_model.
models_zc_afe = Dict(coef => disallowmissing(res.zcaafe.med) for (coef, res) in pairs(models))

In [None]:
# Inspect the α template on the [Mg/Fe] grid.
models_zc_afe[:α]

In [None]:
# Binned observed [C/Mg] vs [Mg/H] with bootstrap median errors.
let b = subgiants_binned_ah
    errscatter(b.x, b.med; yerr=b.err,
        axis=(; xlabel="[Mg/H]", ylabel="[C/Mg]", title="binned subgiants"),
    )
end

In [None]:
# Binned observed [C/Mg] vs [Mg/Fe]. The [Mg/H] slice is bin_caafe's default
# o_h_0 ± 0.05, i.e. (-0.15, -0.05); the title previously held a "..." placeholder.
errscatter(subgiants_binned_afe.x, subgiants_binned_afe.med, yerr=subgiants_binned_afe.err,
    axis=(;xlabel="[Mg/Fe]", ylabel="[C/Mg]", title="binned subgiants ([Mg/H] in [-0.15, -0.05])")
)

# Linear Model

In [None]:
using LinearAlgebra: diagm

In [None]:
# Turing model: fit the binned subgiant [C/Mg] trends as a linear combination of
# three Z_C templates (`:y0_cc`, `:α`, `:ζ`), with coefficients of the same names.
#   x, y, y_e     - [Mg/H] bin centres, binned [C/Mg] and its errors; templates in `models`.
#   x2, y2, y_e2  - the [Mg/Fe] trend in the [Mg/H] slice around `m_h_0`; templates in
#                   `models2`. `x2` is not used by the likelihood.
# Predicted [C/Mg] = Zc_to_C_H(Zc) - [Mg/H], where [Mg/H] is `x` for the first
# trend and `m_h_0` for the second. Both likelihoods are Gaussian with
# diagonal covariance.
@model function fit_3_comp_model(x, y, y_e, x2, y2, y_e2, models, models2; m_h_0=o_h_0)
    y0_cc ~ Normal(2, 1)
    α ~ Normal(2, 1)
    ζ ~ Normal(0, 1)

    Zc = y0_cc * models[:y0_cc] .+ α * models[:α] .+ ζ * models[:ζ]
    Zc2 = y0_cc * models2[:y0_cc] .+ α * models2[:α] .+ ζ * models2[:ζ]

    # log10 of a non-positive Z_C is undefined (NaNMath gives NaN or -Inf), which
    # would poison the log density and its gradient; reject such draws explicitly.
    if any(<=(0), Zc) || any(<=(0), Zc2)
        Turing.@addlogprob! -Inf
        return nothing
    end

    mu = Zc_to_C_H.(Zc) .- x
    mu2 = Zc_to_C_H.(Zc2) .- m_h_0

    # NOTE(review): any zero error (e.g. an empty bin filled with 0 by bin_medians)
    # makes these covariances singular.
    y ~ MvNormal(mu, diagm(y_e .^ 2))
    y2 ~ MvNormal(mu2, diagm(y_e2 .^ 2))
end

In [None]:
"""
    plot_samples!(samples, x, models; thin=10, color=:black, alpha=nothing, kwargs...)

Overplot posterior draws of the three-component model on the current axis as
[C/Mg] versus `x`.

For every `thin`-th row of `samples` (columns `α`, `y0_cc`, `ζ`), the template
vectors `models[:α]`, `models[:y0_cc]` and `models[:ζ]` are combined into Z_C.
The result is converted with `Zc_to_C_H`, and `x` is subtracted before drawing
a line.

`alpha` defaults to `1 / nrows^(1/3)`. Remaining keyword arguments are forwarded
to `lines!`; previously they were accepted and silently ignored.
"""
function plot_samples!(samples, x, models;
        thin=10, color=:black, alpha=nothing, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:α] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:ζ]
        lines!(x, Zc_to_C_H(y) .- x; color=color, alpha=alpha, kwargs...)
    end
    return nothing
end

In [None]:
"""
    plot_samples_afe!(samples, x, models; thin=10, color=:black, alpha=nothing,
                      m_h_0=o_h_0, kwargs...)

Overplot posterior draws of the three-component model on the current axis as
[C/Mg] versus `x` (the [Mg/Fe] bin centres).

This is like `plot_samples!`, but it subtracts the fixed slice centre `m_h_0`
instead of `x`. `m_h_0` defaults to the global `o_h_0`, matching `bin_caafe`
and `fit_3_comp_model`.

`alpha` defaults to `1 / nrows^(1/3)`. Remaining keyword arguments are forwarded
to `lines!`; previously they were accepted and silently ignored.
"""
function plot_samples_afe!(samples, x, models;
        thin=10, color=:black, alpha=nothing, m_h_0=o_h_0, kwargs...)

    if alpha === nothing
        alpha = 1/size(samples, 1)^(1/3)
    end

    for sample in eachrow(samples)[1:thin:end]
        y = @. sample.α * models[:α] + sample.y0_cc * models[:y0_cc] + sample.ζ * models[:ζ]
        lines!(x, Zc_to_C_H(y) .- m_h_0; color=color, alpha=alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Observed trends as plain vectors (Missing stripped from the element type) for the fit.
# NOTE(review): empty bins carry err == 0 (filled by bin_medians), which would make
# the likelihood covariance singular.

# [Mg/H] trend
x = disallowmissing(subgiants_binned_ah.x)
y = disallowmissing(subgiants_binned_ah.med)
yerr = disallowmissing(subgiants_binned_ah.err)

# [Mg/Fe] trend in the [Mg/H] slice
y2 = disallowmissing(subgiants_binned_afe.med)
yerr2 = disallowmissing(subgiants_binned_afe.err)
x2 = disallowmissing(subgiants_binned_afe.x)

In [None]:
# Inspect the [Mg/H]-trend errors; a zero entry (empty bin) makes the fit's covariance singular.
yerr

In [None]:
# Inspect the α template on the [Mg/Fe] grid.
models_zc_afe[:α]

In [None]:
# Condition the three-component model on both observed trends.
model = fit_3_comp_model(x, y, yerr, x2, y2, yerr2, models_zc_ah, models_zc_afe; m_h_0=o_h_0)

In [None]:
# 5000 NUTS draws with target acceptance 0.65; no RNG is passed, so results differ between runs.
chain = sample(model, NUTS(0.65), 5_000)

In [None]:
# Posterior draws as a table; plot_samples!/plot_samples_afe! read its α, y0_cc and ζ columns.
samples = DataFrame(chain)

In [None]:
# Corner plot of the posterior (PairPlots).
pairplot(chain)

In [None]:
fig = Figure()
ax = Axis(fig[1, 1];
    xlabel="[Mg/H]",
    ylabel="[C/Mg]",
    xgridvisible=false,
    ygridvisible=false,
)

# Posterior draws first, then the binned observations with their errors.
plot_samples!(samples, x, models_zc_ah; alpha=0.008)
errscatter!(x, y; yerr=yerr)

fig

In [None]:
fig = Figure()
ax = Axis(fig[1, 1],
    xlabel="[Mg/Fe]",
    # The binned values come from bin_medians, i.e. they are medians, not means.
    ylabel=L"\textrm{median([C/Mg]) | [Mg/H] }\in [-0.15, -0.05]",
    xgridvisible=false,
    ygridvisible=false,
)

# Posterior draws first, then the binned observations with their errors.
plot_samples_afe!(samples, x2, models_zc_afe, alpha=0.008)
errscatter!(x2, y2, yerr=yerr2)

fig