In [None]:
using CairoMakie
using Arya
using DataFrames, CSV
import Base: @kwdef

using DataFrames

In [None]:
using PythonCall
# Python modules accessed through PythonCall; `gcem` is shorthand for surp's `gce_math` submodule.
surp = pyimport("surp")
vice = pyimport("vice")
gcem = surp.gce_math

"Convert the Python object `x` to a Julia `Float64` via `pyconvert`."
function py2f(x)
    return pyconvert(Float64, x)
end

"Convert the Python object `x` to a Julia `Vector{Float64}` via `pyconvert`."
function py2vec(x)
    return pyconvert(Vector{Float64}, x)
end

In [None]:
using Printf

In [None]:
# Call surp's yield setup with no arguments.
surp.yields.set_yields()

In [None]:
# Handle to the `subgiants` object exposed by surp; the fitting cells further down read its columns.
subgiants = surp.subgiants

In [None]:
"""
    χ(t, i)

Return `t^i * exp(-t)`.
"""
function χ(t, i)
    return t^i * exp(-t)
end

"Return `χ(t, 0)`."
function ϵ(t)
    return χ(t, 0)
end

"Return `χ(t, 1)`."
function ξ(t)
    return χ(t, 1)
end

In [None]:
"""
    calc_ϖ(params)

Return `(1 + params.η - params.r) / params.τ_star`.
"""
function calc_ϖ(params)
    numerator = 1 + params.η - params.r
    return numerator / params.τ_star
end

In [None]:
using Revise # so we can change Params
# Numeric type used to annotate the numeric fields of `Params`.
R = Float64

In [None]:
# Keyword-constructible bundle of model parameters shared by the analytic
# formulas and the VICE single-zone runs in this notebook.
@kwdef struct Params
    # Star-formation-history label; the model functions compare it against "const"
    # and get_sfh additionally recognises "exp".
    sfh::String = "const"
    # η, r and τ_star combine into ϖ = (1 + η - r) / τ_star in calc_ϖ;
    # η and τ_star are also forwarded to vice.singlezone as eta / tau_star.
    η::R = 0.5
    r::R = 0.4
    τ_star::R = 2
        
    # AGB delay time (subtracted from t) and timescale (α = 1/τ_agb) in the c_agb_model_* functions.
    t_d_agb::R = 0.1
    τ_agb::R = 0.3
        
    # Prefactor of the AGB carbon term.
    y_c_agb::R = 0.000456
    # Selects the AGB model in c_model_const; only 1 and 2 are accepted there.
    γ::R = 1

    # Yield passed to cc_model_const for carbon and to surp as y0_c_cc.
    y_c_cc::R = 2.28e-03
    # presumably the metallicity dependence of the AGB / CC carbon yields — TODO confirm
    ζ_c_agb::R = 0
    ζ_c_cc::R = 0
        
    # Yields passed to cc_model_const (y_o, y_fe_cc) and used in fe_ia_model_const (y_fe_ia).
    y_o::R = 7.13e-03
    y_fe_cc::R = 4.73e-04
    y_fe_ia::R = 7.70e-04
        
    # τ_ia (ι = 1/τ_ia) and t_d_ia (delay) enter fe_ia_model_const.
    τ_ia::R = 1.5
    # NOTE(review): κ_ia is not read by any function visible in this notebook.
    κ_ia::R = -1
    t_d_ia::R = 0.15
        
    # NOTE(review): not read by the functions visible here; run_singlezone and
    # run_analytic carry their own 13.2 defaults.
    t_end::R = 13.2
end

In [None]:
# Print `params` as a two-column table: a header, a dashed rule, then one
# "name value" line per property.
function Base.show(io::IO, params::Params)
    @printf io "%-12s %s\n" "property" "value"
    rule = "-"^12
    @printf io "%-12s %s\n" rule rule
    foreach(propertynames(params)) do name
        @printf io "%-12s %s\n" name getproperty(params, name)
    end
end

In [None]:
# Parameter set with every field at its default.
params = Params()

In [None]:
"""
    cc_model_const(t, y, params=Params())

Evaluate `y / τ_star * (1 - ϵ(ϖ t)) / ϖ` element-wise over `t`, where
`ϖ = calc_ϖ(params)` and `τ_star = params.τ_star`.
"""
function cc_model_const(t, y, params = Params())
    rate = calc_ϖ(params)
    approach = 1 .- ϵ.(rate .* t)
    return y ./ params.τ_star .* approach ./ rate
end

In [None]:
"""
    fe_ia_model_const(t, params=Params())

Evaluate the delayed two-exponential term
`y_fe_ia / τ_star / ϖ * s(t)` element-wise over `t`, with

    s = 1 - ϖ/(ϖ - ι) ϵ(ι δt) + ι/(ϖ - ι) ϵ(ϖ δt),   δt = t - t_d_ia,

`ϖ = calc_ϖ(params)` and `ι = 1/τ_ia`. `s` is set to zero wherever `δt <= 0`.

`t` may be a scalar or an array; the result has the same shape.
"""
function fe_ia_model_const(t, params = Params())
    ϖ = calc_ϖ(params)
    ι = 1 / params.τ_ia

    δt = t .- params.t_d_ia

    # NOTE(review): both coefficients divide by (ϖ - ι), so ϖ == ι is singular.
    # Non-mutating updates (instead of `.+=`) keep this usable for scalar `t`.
    s = @. 1 - ϖ/(ϖ - ι) * ϵ(ι * δt)
    s = @. s + ι/(ϖ - ι) * ϵ(ϖ * δt)

    # Mask with a zero of the same element type so the result stays concretely typed.
    s = ifelse.(δt .<= 0, zero.(s), s)

    Zeq = params.y_fe_ia / params.τ_star / ϖ
    return Zeq .* s
end



"""
    fe_model_const(t, params=Params())

Sum of `cc_model_const` evaluated with `params.y_fe_cc` and `fe_ia_model_const`.
"""
function fe_model_const(t, params=Params())
    return cc_model_const(t, params.y_fe_cc, params) + fe_ia_model_const(t, params)
end


In [None]:
"""
    c_agb_model_gamma1_const(t, params=Params())

Evaluate `y_c_agb * s(t) / τ_star` element-wise over `t`, with `α = 1/τ_agb`,
`ϖ = calc_ϖ(params)`, `δt = t - t_d_agb` and

    s = 1/ϖ + ξ(α δt)/(α - ϖ) + c ϵ(α δt) - (1/ϖ + c) ϵ(ϖ δt),
    c = (2α - ϖ)/(α - ϖ)^2.

`s` is set to zero wherever `δt < 0`. `t` may be a scalar or an array.
"""
function c_agb_model_gamma1_const(t, params=Params())
    α = 1/params.τ_agb
    ϖ = calc_ϖ(params)
    δt = t .- params.t_d_agb

    # NOTE(review): the coefficients divide by (α - ϖ), so α == ϖ is singular.
    # Non-mutating updates (instead of `.+=`) keep this usable for scalar `t`.
    s = @. 1/ϖ + 1/(α - ϖ) * ξ(α * δt)
    s = @. s + (2*α - ϖ) / (α - ϖ)^2 * ϵ(α * δt)
    s = @. s - (1/ϖ + (2*α - ϖ) / (α - ϖ)^2) * ϵ(ϖ * δt)

    # Mask with a zero of the same element type so the result stays concretely typed.
    s = ifelse.(δt .< 0, zero.(s), s)

    return @. params.y_c_agb * s / params.τ_star
end

"""
    c_agb_model_gamma2_const(t, params=Params())

Evaluate `y_c_agb * s(t) / τ_star / ϖ` element-wise over `t`, where `s` is one
plus a sum of `χ(α δt, k)` (k = 0, 1, 2) and `χ(ϖ δt, 0)` terms with the
coefficients written out below, `α = 1/τ_agb`, `ϖ = calc_ϖ(params)` and
`δt = t - t_d_agb`. `s` is set to zero wherever `δt < 0`.

`t` may be a scalar or an array, and the element type follows the inputs
(the accumulator is no longer pinned to `Float64`).
"""
function c_agb_model_gamma2_const(t, params=Params())
    α = 1/params.τ_agb
    ϖ = calc_ϖ(params)
    δt = t .- params.t_d_agb

    # NOTE(review): at δt = 0 the χ(·, 1) and χ(·, 2) terms vanish and the two
    # remaining coefficients differ only in their ϖ² term
    # (2ϖ²/(α(α-ϖ)) below vs 2ϖ²/(α-ϖ)² on the next line), so
    # s(0) = -2ϖ³/(α(α-ϖ)²) rather than 0 — confirm against the derivation.
    s = @. 1 + ( 2ϖ^3/(α-ϖ)^3 + 2ϖ^2/(α*(α-ϖ)) + 2ϖ/(α-ϖ) ) * χ(α*δt, 0)
    s = @. s - ( 1 + 2ϖ^3/(α-ϖ)^3 + 2ϖ^2/(α-ϖ)^2 + 2ϖ/(α-ϖ) ) * χ(ϖ*δt, 0)
    s = @. s + ( 2ϖ^3 / (α*(α-ϖ)^2) + 2ϖ^2 / (α*(α-ϖ)) ) * χ(α*δt, 1)
    s = @. s + ( ϖ^3/(α^2*(α-ϖ)) ) * χ(α*δt, 2)

    # Mask with a zero of the same element type so the result stays concretely typed.
    s = ifelse.(δt .< 0, zero.(s), s)

    return @. params.y_c_agb * s / params.τ_star / ϖ
end


In [None]:

"""
    c_model(t, params=Params())

Dispatch on `params.sfh` to the matching analytic carbon model. Only
`"const"` is implemented (`c_model_const`).

Throws `ArgumentError` for any other `sfh` instead of silently returning `nothing`.
"""
function c_model(t, params=Params())
    if params.sfh == "const"
        return c_model_const(t, params)
    end
    throw(ArgumentError(
        "analytic carbon model is only implemented for sfh = \"const\", got \"$(params.sfh)\""))
end


"""
    c_model_const(t, params=Params())

Sum of `cc_model_const` evaluated with `params.y_c_cc` and the AGB term chosen
by `params.γ` (1 or 2); any other `γ` raises an error.
"""
function c_model_const(t, params=Params())
    Zcc = cc_model_const(t, params.y_c_cc, params)

    Zagb = if params.γ == 1
        c_agb_model_gamma1_const(t, params)
    elseif params.γ == 2
        c_agb_model_gamma2_const(t, params)
    else
        error("gamma must be 1 or 2")
    end

    return Zcc .+ Zagb
end


"""
    o_model(t, params=Params())

Return `cc_model_const(t, params.y_o, params)` when `params.sfh == "const"`.

Throws `ArgumentError` for any other `sfh` (previously an `UndefVarError`
on the unassigned local).
"""
function o_model(t, params=Params())
    params.sfh == "const" || throw(ArgumentError(
        "analytic oxygen model is only implemented for sfh = \"const\", got \"$(params.sfh)\""))
    return cc_model_const(t, params.y_o, params)
end

"""
    fe_model(t, params=Params())

Return `fe_model_const(t, params)` when `params.sfh == "const"`.

Throws `ArgumentError` for any other `sfh` (previously an `UndefVarError`
on the unassigned local).
"""
function fe_model(t, params=Params())
    params.sfh == "const" || throw(ArgumentError(
        "analytic iron model is only implemented for sfh = \"const\", got \"$(params.sfh)\""))
    return fe_model_const(t, params)
end

# Comparison with VICE models

In [None]:
"Python-callable `t -> 1 + 0*t`."
function sfh_const()
    return pyfunc(t -> 1 + 0*t)
end

"Python-callable `t -> exp(-t / tau)`."
function sfh_exp(tau)
    return pyfunc(t -> exp(-t / tau))
end

"""
    get_sfh(params::Params)

Return the Python-callable star formation history selected by `params.sfh`:
`sfh_const()` for `"const"`, `sfh_exp(params.τ_sfh)` for `"exp"`.

Throws `ArgumentError` for an unknown `sfh`, or when `"exp"` is requested but
`params` carries no `τ_sfh` property.
"""
function get_sfh(params::Params)
    if params.sfh == "const"
        return sfh_const()
    elseif params.sfh == "exp"
        # The exponential history needs a timescale; report a missing one explicitly
        # rather than failing on the field access.
        hasproperty(params, :τ_sfh) || throw(ArgumentError(
            "sfh = \"exp\" requires a `τ_sfh` field on Params"))
        return sfh_exp(params.τ_sfh)
    else
        throw(ArgumentError(
            "unknown sfh \"$(params.sfh)\"; expected \"const\" or \"exp\""))
    end
end

In [None]:
# Quick check: build the star formation history for the current `params`.
get_sfh(params)

In [None]:
"""
    set_yields(params; kwargs...)

Configure surp's yields from `params` by calling `surp.yields.set_yields`.
Carbon CC uses the `"Lin"` model with `params.ζ_c_cc` / `params.y_c_cc`, AGB
carbon uses model `"A"` with `params.τ_agb`, `params.t_d_agb`, `params.ζ_c_agb`
and `params.γ`, and the O / Fe yields are taken from `params`. Extra keyword
arguments are forwarded unchanged.
"""
function set_yields(params; kwargs...)
    surp.yields.set_yields(;
        y_c_cc="Lin",
        zeta_c_cc=params.ζ_c_cc, 
        y0_c_cc=params.y_c_cc,
        Y_c_agb="A", 
        y_fe_ia=params.y_fe_ia,
        y_fe_cc=params.y_fe_cc,
        y_o_cc= params.y_o,

        kwargs_c_agb=pydict(
            tau_agb = params.τ_agb,
            t_D = params.t_d_agb,
            # Forward the AGB metallicity term from Params (default 0) instead of a literal 0,
            # mirroring how ζ_c_cc is forwarded above.
            zeta = params.ζ_c_agb,
            gamma = params.γ,
        ),
        kwargs...
    )
end
        

In [None]:
# Apply the yields for the current `params`; `verbose=true` is forwarded to surp.
set_yields(params, verbose=true)

In [None]:
"""
    run_singlezone(params; dt=0.01, t_end=13.2, mode="sfr", RIa="exp", kwargs...)

Set the yields from `params` (forwarding `kwargs`), run a `vice.singlezone`
model on the output times `0:dt:t_end`, and return its history as a
`DataFrame` with the columns added by `add_abund_columns!`.
"""
function run_singlezone(params; dt=0.01, t_end=13.2, mode="sfr", RIa="exp", kwargs...)
    set_yields(params; kwargs...)

    sfh = get_sfh(params)
    zone = vice.singlezone(
        elements=pylist(["o", "mg", "c", "fe"]),
        func=sfh,
        mode=mode,
        dt=dt,
        eta=params.η,
        tau_star=params.τ_star,
        Mg0=1,
    )
    zone.RIa = RIa

    output_times = pylist(0:dt:t_end)
    result = zone.run(output_times, capture=true, overwrite=true)

    columns = pyconvert(Dict{String, Vector{Float64}}, result.history.todict())
    history = DataFrame(columns)

    add_abund_columns!(history)
    return history
end

In [None]:
"""
    add_abund_columns!(h)

Add the columns `"MG_H"`, `"C_MG"` and `"MG_FE"` to `h` and return `h`.

The values are computed from the `"z(o)"`, `"z(c)"` and `"z(fe)"` columns
through surp's `gce_math` helpers; note that the "MG" columns are built from
the oxygen abundance (element `"o"`), not from a magnesium column.
"""
function add_abund_columns!(h)
    zo = h[:, "z(o)"]
    zfe = h[:, "z(fe)"]
    zc = h[:, "z(c)"]

    h[:, "MG_H"] = gcem.abund_to_brak(pylist(zo), "o") |> py2vec
    h[:, "C_MG"] = gcem.abund_ratio_to_brak(pylist(zc ./ zo), "c", "o") |> py2vec
    h[:, "MG_FE"] = gcem.abund_ratio_to_brak(pylist(zo ./ zfe), "o", "fe") |> py2vec

    return h
end

In [None]:
"""
    run_analytic(params, t=LinRange(0, 13.2, 10000))

Evaluate the analytic C, O and Fe models on `t` and return a `DataFrame` with
`"time"`, `"z(c)"`, `"z(o)"`, `"z(fe)"` and the columns added by
`add_abund_columns!`.
"""
function run_analytic(params, t=LinRange(0, 13.2, 10000))
    table = DataFrame()
    table[:, "time"] = t
    for (column, model) in ("z(c)" => c_model, "z(o)" => o_model, "z(fe)" => fe_model)
        table[:, column] = model(t, params)
    end

    add_abund_columns!(table)
    return table
end

In [None]:
"""
    compare_z_t(sz, ana)

Build a 3×2 figure comparing `sz` and `ana` over time: the left column shows
`z(fe)`, `z(o)`, `z(c)`; the right column shows `MG_H`, `MG_FE`, `C_MG`.
Only the bottom row keeps its x decorations and the top-left panel carries the
legend. Returns the figure.
"""
function compare_z_t(sz, ana)
    fig = Figure()

    elements = ["fe", "o", "c"]
    for (row, ele) in enumerate(elements)
        ax = abund_axis!(fig[row, 1], "Z_$ele")
        plot_abund_time!(ax, sz, ana, "z($(lowercase(ele)))")
        row < 3 && hidexdecorations!(ax, grid=false)
        row == 1 && axislegend(ax, position=:lt)
    end

    ratios = ["MG_H", "MG_FE", "C_MG"]
    for (row, ratio) in enumerate(ratios)
        ax = abund_axis!(fig[row, 2], ratio)
        plot_abund_time!(ax, sz, ana, ratio)
        row < 3 && hidexdecorations!(ax, grid=false)
    end

    return fig
end
    

In [None]:
"""
    abund_axis!(gs, abund)

Create and return an `Axis` at grid position `gs` with a log10 time axis
limited to `(1e-3, 15)` and `abund` as the y label.
"""
function abund_axis!(gs, abund)
    return Axis(
        gs;
        xlabel="time",
        ylabel=abund,
        xscale=log10,
        limits=(1e-3, 15, nothing, nothing),
    )
end

In [None]:
"""
    plot_abund_time!(ax, sz, ana, abund)

Draw column `abund` against `time` on `ax` for `sz` (label "VICE") and `ana`
(label "analytic"), skipping the first row of each table.
"""
function plot_abund_time!(ax, sz, ana, abund)
    rows = 2:lastindex(sz.time)
    lines!(ax, sz.time[rows], sz[rows, abund], label="VICE")

    rows = 2:lastindex(ana.time)
    return lines!(ax, ana.time[rows], ana[rows, abund], label="analytic")
end

In [None]:
"""
    plot_abund_time(sz, ana, abund)

Return a new figure showing column `abund` against `time` for `sz` ("VICE")
and `ana` ("analytic") on a log10 time axis, with a legend in the second
grid column. The first row of each table is skipped.
"""
function plot_abund_time(sz, ana, abund)
    fig = Figure()
    ax = Axis(
        fig[1,1];
        xlabel="time",
        ylabel=abund,
        xscale=log10,
        limits=(1e-3, 15, nothing, nothing),
    )

    for (table, name) in ((sz, "VICE"), (ana, "analytic"))
        lines!(table.time[2:end], table[2:end, abund], label=name)
    end

    Legend(fig[1,2], ax)
    return fig
end

In [None]:
"""
    compare_abund(sz, ana)

Display two figures comparing `sz` ("VICE") and `ana` ("analytic") over time:
one for the `MG_FE` column and one for the `C_MG` column. The first row of
each table is skipped. Returns the result of the last `display` call.
"""
function compare_abund(sz, ana)
    shown = nothing
    for abund in ("MG_FE", "C_MG")
        fig = Figure()
        ax = Axis(fig[1,1],
            xlabel="time",
            ylabel=abund,
            xscale=log10,
            limits=(1e-3, 13.2, nothing, nothing),
        )
        # Plot the column named on the axis; the C_MG panel used to replot MG_FE.
        lines!(ax, sz.time[2:end], sz[2:end, abund], label="VICE")
        lines!(ax, ana.time[2:end], ana[2:end, abund], label="analytic")

        shown = display(fig)
    end
    return shown
end

    

In [None]:
# Baseline comparison with every parameter at its default (γ = 1).
params = Params()

In [None]:
# Numerical single-zone run through VICE.
sz = run_singlezone(params, mode="sfr");

In [None]:
# Analytic solution on the default time grid.
ana = run_analytic(params);

In [None]:
compare_z_t(sz, ana)

In [None]:
# Repeat the comparison with γ = 2, which selects c_agb_model_gamma2_const.
params = Params(
    γ = 2,
    )


In [None]:
sz = run_singlezone(params, mode="sfr");

In [None]:
ana = run_analytic(params);

In [None]:
compare_z_t(sz, ana)

# MCMC fit to CAAH

In [None]:
using Optim
using Turing
using Distributions


In [None]:
# Straight-line model y[i] ~ Normal(α + β x[i], σ_y[i]) with Normal(0, 1) and
# Normal(0, 0.5) priors on the intercept α and slope β.
@model function linear_regression(x, y, σ_y)
    α ~ Normal(0, 1)
    β ~ Normal(0, 0.5)

    for i in eachindex(y)
        μ = α + β * x[i]
        y[i] ~ Normal(μ, σ_y[i])
    end
end

In [None]:
# Show VICE's current CCSNe setting for "mg"; the next cell multiplies the C/Mg ratio by it.
vice.yields.ccsne.settings("mg")

In [None]:
# Observed sample converted from bracket notation to linear quantities.
x = surp.gce_math.MH_to_Z(subgiants.MG_H) |> py2vec
y = surp.gce_math.brak_to_abund_ratio(subgiants.C_MG, "C", "mg") * vice.yields.ccsne.settings("mg") |> py2vec

# First-order error propagation for y ∝ 10^[C/Mg]: dy = ln(10) · y · d[C/Mg].
# (Assumes brak_to_abund_ratio is 10^bracket times a constant — TODO confirm.)
# The previous expression divided by ln(10) and broadcast py2vec element-wise.
σ_y = py2vec(subgiants.C_MG_ERR.values) .* y .* log(10)

# Keep rows where all three quantities are non-NaN and the high_alpha flag is false.
filt = (!).(isnan.(x))
filt .&= (!).(isnan.(y))
filt .&= (!).(isnan.(σ_y))
# The flag must be converted to Bool: `!` is not defined for Float64.
filt .&= (!).(pyconvert(Vector{Bool}, subgiants.high_alpha.values))

x = x[filt]
y = y[filt]
σ_y = σ_y[filt];

In [None]:
# Condition the straight-line model on the filtered sample and draw 2000 samples with NUTS(0.65).
model = linear_regression(x, y, σ_y)
chain = sample(model, NUTS(0.65), 2000)

In [None]:
# Plot the sampled chain.
plot(chain)

In [None]:
# Overlay the line α + β·xs for every 10th index of the chain, then the data points.
# NOTE(review): `plot()`, `plot!`, and the `ms` / `msw` / `lw` keywords look like the
# Plots.jl API, while only CairoMakie is loaded in the visible cells — confirm which
# plotting package is meant to be active here.
p = plot()
xs = LinRange(0, 2, 1000)

for i in 1:10:length(chain)
    α = chain[:α].data[i]
    β = chain[:β].data[i]
    
    ys = @. α + xs * β
    plot!(xs, ys, color="black", alpha=0.05, lw=1, legend=false)
end

scatter!(x, y, ms=1, msw=0, alpha=0.5)

# Zoom in on the plotted range.
xlims!(0, 0.05)
ylims!(0, 1e-2)
p

# CAAFE Regression

In [None]:
# Solar abundances copied from `vice.solar_z` into a native Julia table.
# Keys and values are converted out of Python objects so that lookups with Julia
# strings (e.g. solar_z["o"]) and Julia arithmetic on the values work.
const solar_z = Dict(
    pyconvert(String, key) => pyconvert(Float64, vice.solar_z(key))
    for key in vice.solar_z.keys()
)

In [None]:
"""
    abund_to_brak(abundances, ele, ele2="h")

Convert `abundances` to bracket notation using the `solar_z` table.

With `ele2 == "h"` the result is `log10(abundances / solar_z[ele])`; otherwise
`abundances` is treated as a ratio and the result is
`log10(abundances) - log10(solar_z[ele] / solar_z[ele2])`.
"""
function abund_to_brak(abundances, ele, ele2="h")
    ele2 == "h" && return log10.(abundances ./ solar_z[ele])

    solar_ratio = solar_z[ele] / solar_z[ele2]
    return log10.(abundances) .- log10.(solar_ratio)
end

In [None]:
"""
    calc_model(params, t=vec(LinRange(0.01, 13, 1000)))

Evaluate the analytic O, Fe and C models on `t` and return the tuple
`(o_fe, c_o, w)`: the bracket-notation O/Fe and C/O tracks from
`abund_to_brak`, and a vector of ones of length `length(t)` used as weights.
"""
function calc_model(params, t=vec(LinRange(0.01, 13, 1000)))
    zo = o_model(t, params)
    zfe = fe_model(t, params)
    zc = c_model(t, params)

    # The [O/H] track was computed here but never used; only the two ratios are returned.
    c_o = abund_to_brak(zc ./ zo, "c", "o")
    o_fe = abund_to_brak(zo ./ zfe, "o", "fe")

    return o_fe, c_o, ones(length(t))
end

In [None]:
"""
    mvn_prob(x, y, xerr, yerr, xm, ym)

Return `exp(-χ²/2)` with `χ² = (y - ym)²/yerr² + (x - xm)²/xerr²`; no
normalisation prefactor is applied.
"""
function mvn_prob(x, y, xerr, yerr, xm, ym)
    chi2 = (y - ym)^2/yerr^2 + (x-xm)^2/xerr^2
    return exp(-1/2 * chi2)
end

In [None]:
# Weighted sum of `mvn_prob` between one data point and every model point.
function _point_likelihood(xi, yi, δxi, δyi, x_pred, y_pred, w)
    return sum(@. w * mvn_prob(xi, yi, δxi, δyi, x_pred, y_pred))
end

"""
    log_L(x, y, δx, δy, x_pred, y_pred, w)

Log-likelihood of the data `(x, y)` given a model track `(x_pred, y_pred)` with
weights `w`: for each data point the weighted `mvn_prob` values against all
model points are summed, and the logs of those sums are added up.

`δx` and `δy` are either per-point vectors (indexed alongside `x`, `y`) or
scalars shared by every point. Returns `0.0` for empty data.
"""
function log_L(x, y, δx::AbstractVector, δy::AbstractVector, x_pred, y_pred, w)
    # Accumulate directly instead of filling an untyped Vector{Any} buffer;
    # eachindex also checks that all four inputs share the same indices.
    return sum(eachindex(x, y, δx, δy); init=0.0) do i
        log(_point_likelihood(x[i], y[i], δx[i], δy[i], x_pred, y_pred, w))
    end
end

function log_L(x, y, δx::Real, δy::Real, x_pred, y_pred, w)
    return sum(eachindex(x, y); init=0.0) do i
        log(_point_likelihood(x[i], y[i], δx, δy, x_pred, y_pred, w))
    end
end

In [None]:
# Turing model for the single-zone track: priors on the AGB delay time, AGB
# fraction, AGB timescale, η and the two scatter terms. Unless the leaf context
# is `PriorContext`, the analytic track is evaluated via `to_params` /
# `calc_model` and `log_L` of the data `(x, y)` is added to the log probability.
@model function singlezone_regression(x, y)
    # Priors
    t_d_agb ~ LogNormal(log(0.1), 0.4)
    f_agb ~ Beta(1, 1.5)
    τ_agb ~ LogNormal(log(0.3), 0.1)
    η ~ Exponential(1)
    σ_x ~ LogNormal(log(0.2), 0.5)
    σ_y ~ LogNormal(log(0.5), 0.5)

    # Skip the likelihood when only the prior is being sampled.
    if DynamicPPL.leafcontext(__context__) !== Turing.PriorContext()
        params = to_params(f_agb=f_agb, τ_agb=τ_agb, t_d_agb=t_d_agb, η=η)
        x_pred, y_pred, w = calc_model(params)
        Turing.@addlogprob! log_L(x, y, σ_x, σ_y, x_pred, y_pred, w)
    end
end


# Names of the sampled quantities that map onto `to_params` keywords.
model_params = [:f_agb, :τ_agb, :t_d_agb, :η]

"""
    to_params(; f_agb, kwargs...)

Build a `Params` in which a fixed total carbon yield of `2.3e-3` is split into
`y_c_agb = f_agb * total` and `y_c_cc = (1 - f_agb) * total`; all other
keywords are passed through to `Params`.
"""
function to_params(; f_agb, kwargs...)
    y_c_tot = 2.3e-3
    return Params(;
        y_c_cc=(1-f_agb) * y_c_tot,
        y_c_agb=f_agb * y_c_tot,
        kwargs...,
    )
end

In [None]:
"""
    to_params(row::DataFrameRow)

Build a `Params` from the columns of `row` whose names appear in
`model_params`; every other column is ignored.
"""
function to_params(row::DataFrameRow)
    selected = Dict{Symbol,Any}()
    for column in Symbol.(names(row))
        column in model_params || continue
        selected[column] = row[column]
    end

    return to_params(; selected...)
end

In [None]:
"""
    plot_chain(chain; color="black", legend=false, alpha=0.1, kwargs...)

Plot the model track returned by `calc_model` for every row of
`DataFrame(chain)` and return the plot. `color`, `legend` and `alpha` apply to
every track; remaining keywords are forwarded to `plot!`.
"""
function plot_chain(chain; color="black", legend=false, alpha=0.1, kwargs...)
    df = DataFrame(chain)
    p = plot()
    for row in eachrow(df)
        x_pred, y_pred, _ = calc_model(to_params(row))
        # Pass the caller's `legend` through; it used to be overridden by a literal `false`.
        plot!(x_pred, y_pred;
            color=color, label="model", legend=legend, alpha=alpha,
            kwargs...)
    end
    return p
end

## Sample the prior

In [None]:
# Draw 1000 samples from the prior; no data are needed, so `missing` is passed for x and y.
chain = sample(singlezone_regression(missing, missing), Prior(), 1000)

In [None]:
plot(chain)

In [None]:
# Model tracks implied by the prior draws.
plot_chain(chain)

In [None]:
# Convert the Python arrays to Julia vectors up front: the filtering below and the
# 1-based indexing in log_L operate on Julia arrays, whereas indexing a Python
# object directly follows Python's 0-based convention.
x = py2vec(subgiants.MG_FE.values)
y = py2vec(subgiants.C_MG.values)
σ_y = py2vec(subgiants.C_MG_ERR.values)
σ_x = py2vec(subgiants.MG_FE_ERR.values)

# Keep rows where nothing is NaN ...
filt = @. !isnan(x)
@. filt &= !isnan(y)
@. filt &= !isnan(σ_y)
@. filt &= !isnan(σ_x)

# ... and the reported uncertainties are below the chosen cuts.
filt .&= σ_x .< 0.05
filt .&= σ_y .< 0.1

x = x[filt]
y = y[filt]
σ_y = σ_y[filt];
σ_x = σ_x[filt];

In [None]:
# Condition the single-zone model on the filtered sample.
model = singlezone_regression(x, y)

In [None]:
# Maximum a posteriori estimate, used below as the sampler's starting point.
map_estimate = optimize(model, MAP())

In [None]:
# Profile a short NUTS run.
# NOTE(review): `@profilehtml` is not provided by any package loaded in the visible cells — confirm its source.
@profilehtml sample(model, NUTS(0.65), 20, initial_params=map_estimate.values)

In [None]:
# Short HMC run (100 samples) started from the MAP estimate.
chain = sample(model, HMC(0.1, 5), 100, initial_params=map_estimate.values, ϵ=0.2)

In [None]:
plot(chain)

In [None]:
# Posterior model tracks with the data points overlaid.
plot_chain(chain)
scatter!(x, y, ms=2, msw=0, alpha=0.1)
xlims!(-0.1, 0.6)
ylims!(-0.5, 0.2)