In [None]:
using CairoMakie
using Turing
using CSV, DataFrames
using LinearAlgebra

In [None]:
using PairPlots

In [None]:
# Load the subgiant catalogue. Later cells read its MG_H, C_MG, MG_H_ERR,
# C_MG_ERR and high_alpha columns.
subgiants = CSV.read("../data/subgiants.csv", DataFrame)

In [None]:
using FillArrays

In [None]:
# Scratch cell: only displays whatever `FillArrays.I` refers to; the value is
# not assigned, so it has no effect on the analysis.
FillArrays.I

In [None]:
# Sample cut: keep stars with [Mg/H] > -1 that are not flagged high-α.
idx = (subgiants.MG_H .> -1) .& .!subgiants.high_alpha

x, y, x_e, y_e = let rows = subgiants[idx, :]
    rows.MG_H, rows.C_MG, rows.MG_H_ERR, rows.C_MG_ERR
end;

In [None]:
# Number of stars kept vs. rejected by the cut.
count(idx), count(!, idx)

In [None]:
scatter(x, y)

# Linear Model

In [None]:
# Turing model: straight line y = a + b*x with Gaussian intrinsic scatter
# exp(log_σ) plus per-point measurement errors.
#
#   x, y     : data vectors
#   x_e, y_e : 1σ uncertainties on x and y; the x error is propagated into y
#              through the slope (dμ/dx = b).
@model function linear_regression(x, y, x_e, y_e)
    log_σ ~ Normal(-2, 0.8)
    a ~ Normal(0, 0.4)
    b ~ Normal(0, 0.5)

    mu = @. a + b * x

    # log_σ is the log of the intrinsic-scatter standard deviation, so it
    # enters the variance as exp(log_σ)^2 (exp(2log_σ)^2 would be σ^4).
    s_int = exp(log_σ)
    s_x = @. b * x_e
    s_y = y_e

    # Independent error sources add in quadrature. The per-point variances
    # form a diagonal covariance (MvNormal(μ, σ::Vector) is deprecated).
    total_var = @. s_int^2 + s_x^2 + s_y^2
    return y ~ MvNormal(mu, Diagonal(total_var))
end

In [None]:
# Overlay posterior-draw lines y = a + b*x on the current Makie axis.
#
#   samples : table of draws with columns `a` and `b` (e.g. `DataFrame(chain)`)
#   x       : grid on which each line is evaluated
#   thin    : plot every `thin`-th draw
#   color   : line colour
#   alpha   : line opacity; `nothing` means 1/cbrt(number of draws)
#   kwargs  : forwarded to `lines!`
function plot_samples!(samples, x;
        thin=10, color=:black, alpha=nothing, kwargs...)

    # Honour a caller-supplied alpha; only fall back to the density-based default.
    line_alpha = isnothing(alpha) ? 1 / size(samples, 1)^(1 / 3) : alpha

    for draw in eachrow(samples)[1:thin:end]
        y = @. draw.a + draw.b * x
        lines!(x, y; color=color, alpha=line_alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Condition the linear model on the selected stars (x = [Mg/H], y = C_MG).
model = linear_regression(x, y, x_e, y_e)

In [None]:
# 5,000 NUTS draws. No RNG/seed is passed, so results presumably differ
# between runs.
chain = sample(model, NUTS(), 5_000)

In [None]:
# Tabulate the chain; plot_samples! iterates its rows as individual draws.
samples = DataFrame(chain)

In [None]:
# Pair (corner) plot of the sampled parameters.
pairplot(chain)

In [None]:
# Zoomed view of the data with thinned posterior lines from the linear fit.
fig = Figure()
ax = Axis(fig[1, 1]; limits=(-0.5, 0.5, -0.6, 0.2))
scatter!(x, y; markersize=3, alpha=0.3)
plot_samples!(samples, LinRange(-0.5, 0.5, 100))
fig

# Log-Lin Models (x = 10^[Mg/H] on a linear scale, y = [C/Mg] in dex)

In [None]:

# x = 10^[Mg/H] on a linear scale; y stays C_MG as stored.
x = 10 .^ subgiants[idx, :MG_H]
y = subgiants[idx, :C_MG];
# Propagate the [Mg/H] error through 10^u: σ_x = ln(10) * 10^u * σ_u.
x_e = @. x * log(10) * subgiants[idx, :MG_H_ERR];
y_e = subgiants[idx, :C_MG_ERR];

In [None]:
# Create a new figure with one axis and scatter the log-lin data on it.
#
# xs, ys default to the notebook's current global `x` and `y` (looked up at
# call time), so `plot_data()` keeps working. Returns the Figure.
# x is 10^[Mg/H] here, so the axis is labelled in linear units.
function plot_data(xs=x, ys=y)
    fig = Figure()
    Axis(fig[1, 1], xlabel="10^[Mg/H]", ylabel="[C/Mg]")

    scatter!(xs, ys, markersize=3, alpha=0.3)

    return fig
end

In [None]:
# Scatter the log-lin data set (x = 10^[Mg/H], y = C_MG).
plot_data()

## Linear model

In [None]:
# Turing model: straight line y = a + b*x with Gaussian intrinsic scatter
# exp(log_σ) plus per-point measurement errors.
#
#   x, y     : data vectors
#   x_e, y_e : 1σ uncertainties on x and y; the x error is propagated into y
#              through the slope (dμ/dx = b).
@model function linear_regression(x, y, x_e, y_e)
    log_σ ~ Normal(-2, 0.8)
    a ~ Normal(0, 0.4)
    b ~ Normal(0, 0.5)

    mu = @. a + b * x

    # log_σ is the log of the intrinsic-scatter standard deviation, so it
    # enters the variance as exp(log_σ)^2 (exp(2log_σ)^2 would be σ^4).
    s_int = exp(log_σ)
    s_x = @. b * x_e
    s_y = y_e

    # Independent error sources add in quadrature. The per-point variances
    # form a diagonal covariance (MvNormal(μ, σ::Vector) is deprecated).
    total_var = @. s_int^2 + s_x^2 + s_y^2
    return y ~ MvNormal(mu, Diagonal(total_var))
end

In [None]:
# Condition the linear model on the log-lin data (x = 10^[Mg/H]).
model = linear_regression(x, y, x_e, y_e)

In [None]:
# 5,000 NUTS draws. No RNG/seed is passed, so results presumably differ
# between runs.
chain = sample(model, NUTS(), 5_000)

In [None]:
# Tabulate the chain; plot_samples! iterates its rows as individual draws.
samples = DataFrame(chain);

In [None]:
# Pair (corner) plot of the sampled parameters.
pairplot(chain)

In [None]:
# Overlay posterior-draw lines y = a + b*x on the current Makie axis.
#
#   samples : table of draws with columns `a` and `b` (e.g. `DataFrame(chain)`)
#   x       : grid on which each line is evaluated
#   thin    : plot every `thin`-th draw
#   color   : line colour
#   alpha   : line opacity; `nothing` means 1/cbrt(number of draws)
#   kwargs  : forwarded to `lines!`
function plot_samples!(samples, x;
        thin=10, color=:black, alpha=nothing, kwargs...)

    # Honour a caller-supplied alpha; only fall back to the density-based default.
    line_alpha = isnothing(alpha) ? 1 / size(samples, 1)^(1 / 3) : alpha

    for draw in eachrow(samples)[1:thin:end]
        y = @. draw.a + draw.b * x
        lines!(x, y; color=color, alpha=line_alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Data plus thinned posterior lines evaluated on x ∈ [0, 4].
fig = let f = plot_data()
    plot_samples!(samples, LinRange(0, 4, 100))
    f
end

## Exponential

In [None]:
# Turing model: y = a + b*exp(x/tau) with Gaussian intrinsic scatter
# exp(log_σ) plus per-point measurement errors.
#
#   x, y     : data vectors
#   x_e, y_e : 1σ uncertainties on x and y; the x error is propagated into y
#              through the local slope dμ/dx = (b/tau)*exp(x/tau).
@model function exp_regression(x, y, x_e, y_e)
    log_σ ~ Normal(-2, 0.8)
    a ~ Normal(0, 0.4)
    b ~ Normal(0, 0.5)
    tau ~ Exponential(6)

    growth = @. exp(x / tau)
    mu = @. a + b * growth
    mu_p = @. b / tau * growth    # dμ/dx

    # log_σ is the log of the intrinsic-scatter standard deviation.
    s_int = exp(log_σ)
    # Absolute propagated error |dμ/dx|·σ_x, matching `b * x_e` in the
    # linear model. Dividing by μ would make it relative and diverge as μ→0.
    s_x = @. mu_p * x_e
    s_y = y_e

    # Independent error sources add in quadrature; diagonal covariance.
    total_var = @. s_int^2 + s_x^2 + s_y^2
    return y ~ MvNormal(mu, Diagonal(total_var))
end

In [None]:
# Fit only the first 100 selected stars (file order, not a random subset).
# NOTE(review): presumably done to keep NUTS fast while exploring; the plot
# below still shows all points. Confirm this is intended.
model = exp_regression(x[1:100], y[1:100], x_e[1:100], y_e[1:100])

In [None]:
# 5,000 NUTS draws. No RNG/seed is passed, so results presumably differ
# between runs.
chain = sample(model, NUTS(), 5_000)

In [None]:
# Tabulate the chain; plot_samples! reads columns a, b and tau from each row.
samples = DataFrame(chain);

In [None]:
# Pair (corner) plot of the sampled parameters.
pairplot(chain)

In [None]:
# Overlay posterior-draw curves y = a + b*exp(x/tau) on the current Makie axis.
#
#   samples : table of draws with columns `a`, `b`, `tau`
#   x       : grid on which each curve is evaluated
#   thin    : plot every `thin`-th draw
#   color   : line colour
#   alpha   : line opacity; `nothing` means 1/cbrt(number of draws)
#   kwargs  : forwarded to `lines!`
function plot_samples!(samples, x;
        thin=10, color=:black, alpha=nothing, kwargs...)

    # Honour a caller-supplied alpha; only fall back to the density-based default.
    line_alpha = isnothing(alpha) ? 1 / size(samples, 1)^(1 / 3) : alpha

    for draw in eachrow(samples)[1:thin:end]
        y = @. draw.a + draw.b * exp(x / draw.tau)
        lines!(x, y; color=color, alpha=line_alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Data plus thinned exponential-model curves evaluated on x ∈ [0, 4].
fig = let f = plot_data()
    plot_samples!(samples, LinRange(0, 4, 100))
    f
end

# Lin-Lin Models (x = 10^[Mg/H], y = 10^[C/Mg], both on a linear scale)

In [None]:

# Both axes on a linear scale: x = 10^[Mg/H], y = 10^C_MG.
x = 10 .^ subgiants[idx, :MG_H]
y = 10 .^ subgiants[idx, :C_MG];
# First-order propagation through 10^u: σ_lin = ln(10) * 10^u * σ_u.
x_e = @. x * log(10) * subgiants[idx, :MG_H_ERR];
y_e = @. y * log(10) * subgiants[idx, :C_MG_ERR];

In [None]:
# Create a new figure with one axis and scatter the lin-lin data on it.
#
# xs, ys default to the notebook's current global `x` and `y` (looked up at
# call time), so `plot_data()` keeps working. Returns the Figure.
# Both plotted quantities are 10^(stored log value); the labels say so.
function plot_data(xs=x, ys=y)
    fig = Figure()
    Axis(fig[1, 1], xlabel="10^[Mg/H]", ylabel="10^[C/Mg]")

    scatter!(xs, ys, markersize=3, alpha=0.3)

    return fig
end

In [None]:
# Scatter the lin-lin data set (x = 10^[Mg/H], y = 10^C_MG).
plot_data()

## Linear

In [None]:
# Turing model: straight line y = a + b*x with Gaussian intrinsic scatter
# exp(log_σ) plus per-point measurement errors.
#
#   x, y     : data vectors
#   x_e, y_e : 1σ uncertainties on x and y; the x error is propagated into y
#              through the slope (dμ/dx = b).
@model function linear_regression(x, y, x_e, y_e)
    log_σ ~ Normal(-2, 0.8)
    a ~ Normal(0, 0.4)
    b ~ Normal(0, 0.5)

    mu = @. a + b * x

    # log_σ is the log of the intrinsic-scatter standard deviation, so it
    # enters the variance as exp(log_σ)^2 (exp(2log_σ)^2 would be σ^4).
    s_int = exp(log_σ)
    s_x = @. b * x_e
    s_y = y_e

    # Independent error sources add in quadrature. The per-point variances
    # form a diagonal covariance (MvNormal(μ, σ::Vector) is deprecated).
    total_var = @. s_int^2 + s_x^2 + s_y^2
    return y ~ MvNormal(mu, Diagonal(total_var))
end

In [None]:
# Condition the linear model on the lin-lin data (x = 10^[Mg/H], y = 10^C_MG).
model = linear_regression(x, y, x_e, y_e)

In [None]:
# 5,000 NUTS draws. No RNG/seed is passed, so results presumably differ
# between runs.
chain = sample(model, NUTS(), 5_000)

In [None]:
# Tabulate the chain; plot_samples! iterates its rows as individual draws.
samples = DataFrame(chain);

In [None]:
# Pair (corner) plot of the sampled parameters.
pairplot(chain)

In [None]:
# Overlay posterior-draw lines y = a + b*x on the current Makie axis.
#
#   samples : table of draws with columns `a` and `b` (e.g. `DataFrame(chain)`)
#   x       : grid on which each line is evaluated
#   thin    : plot every `thin`-th draw
#   color   : line colour
#   alpha   : line opacity; `nothing` means 1/cbrt(number of draws)
#   kwargs  : forwarded to `lines!`
function plot_samples!(samples, x;
        thin=10, color=:black, alpha=nothing, kwargs...)

    # Honour a caller-supplied alpha; only fall back to the density-based default.
    line_alpha = isnothing(alpha) ? 1 / size(samples, 1)^(1 / 3) : alpha

    for draw in eachrow(samples)[1:thin:end]
        y = @. draw.a + draw.b * x
        lines!(x, y; color=color, alpha=line_alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Data plus thinned posterior lines evaluated on x ∈ [0, 4].
fig = let f = plot_data()
    plot_samples!(samples, LinRange(0, 4, 100))
    f
end

## Exponential

In [None]:
# Turing model: y = a + b*exp(x/tau) with Gaussian intrinsic scatter
# exp(log_σ) plus per-point measurement errors.
#
#   x, y     : data vectors
#   x_e, y_e : 1σ uncertainties on x and y; the x error is propagated into y
#              through the local slope dμ/dx = (b/tau)*exp(x/tau).
@model function exp_regression(x, y, x_e, y_e)
    log_σ ~ Normal(-2, 0.8)
    a ~ Normal(0, 0.4)
    b ~ Normal(0, 0.5)
    tau ~ Exponential(6)

    growth = @. exp(x / tau)
    mu = @. a + b * growth
    mu_p = @. b / tau * growth    # dμ/dx

    # log_σ is the log of the intrinsic-scatter standard deviation.
    s_int = exp(log_σ)
    # Absolute propagated error |dμ/dx|·σ_x, matching `b * x_e` in the
    # linear model. Dividing by μ would make it relative and diverge as μ→0.
    s_x = @. mu_p * x_e
    s_y = y_e

    # Independent error sources add in quadrature; diagonal covariance.
    total_var = @. s_int^2 + s_x^2 + s_y^2
    return y ~ MvNormal(mu, Diagonal(total_var))
end

In [None]:
# Fit only the first 100 selected stars (file order, not a random subset).
# NOTE(review): presumably done to keep NUTS fast while exploring; the plot
# below still shows all points. Confirm this is intended.
model = exp_regression(x[1:100], y[1:100], x_e[1:100], y_e[1:100])

In [None]:
# 5,000 NUTS draws. No RNG/seed is passed, so results presumably differ
# between runs.
chain = sample(model, NUTS(), 5_000)

In [None]:
# Tabulate the chain; plot_samples! reads columns a, b and tau from each row.
samples = DataFrame(chain);

In [None]:
# Pair (corner) plot of the sampled parameters.
pairplot(chain)

In [None]:
# Overlay posterior-draw curves y = a + b*exp(x/tau) on the current Makie axis.
#
#   samples : table of draws with columns `a`, `b`, `tau`
#   x       : grid on which each curve is evaluated
#   thin    : plot every `thin`-th draw
#   color   : line colour
#   alpha   : line opacity; `nothing` means 1/cbrt(number of draws)
#   kwargs  : forwarded to `lines!`
function plot_samples!(samples, x;
        thin=10, color=:black, alpha=nothing, kwargs...)

    # Honour a caller-supplied alpha; only fall back to the density-based default.
    line_alpha = isnothing(alpha) ? 1 / size(samples, 1)^(1 / 3) : alpha

    for draw in eachrow(samples)[1:thin:end]
        y = @. draw.a + draw.b * exp(x / draw.tau)
        lines!(x, y; color=color, alpha=line_alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Data plus thinned exponential-model curves evaluated on x ∈ [0, 4].
fig = let f = plot_data()
    plot_samples!(samples, LinRange(0, 4, 100))
    f
end

## Quadratic

In [None]:
# Turing model: quadratic y = a + b*x + c*x^2 with Gaussian intrinsic scatter
# exp(log_σ) plus per-point measurement errors.
#
#   x, y     : data vectors
#   x_e, y_e : 1σ uncertainties on x and y; the x error is propagated into y
#              through the local slope dμ/dx = b + 2c*x.
#
# It has its own name so that defining it does not overwrite `exp_regression`.
@model function quad_regression(x, y, x_e, y_e)
    log_σ ~ Normal(-2, 0.8)
    a ~ Normal(0, 0.4)
    b ~ Normal(0, 0.5)
    c ~ Normal(0, 0.3)

    mu = @. a + b * x + c * x^2
    mu_p = @. b + 2 * c * x    # dμ/dx

    # log_σ is the log of the intrinsic-scatter standard deviation.
    s_int = exp(log_σ)
    # Absolute propagated error |dμ/dx|·σ_x; dividing by μ would make it
    # relative and diverge as μ→0.
    s_x = @. mu_p * x_e
    s_y = y_e

    # Independent error sources add in quadrature; diagonal covariance.
    total_var = @. s_int^2 + s_x^2 + s_y^2
    return y ~ MvNormal(mu, Diagonal(total_var))
end

In [None]:
model = quad_regression(x, y, x_e, y_e)

In [None]:
# 5,000 NUTS draws of the quadratic model on all lin-lin points. No RNG/seed
# is passed, so results presumably differ between runs.
chain = sample(model, NUTS(), 5_000)

In [None]:
# Tabulate the chain; plot_samples! reads columns a, b and c from each row.
samples = DataFrame(chain);

In [None]:
# Pair (corner) plot of the sampled parameters.
pairplot(chain)

In [None]:
# Overlay posterior-draw curves y = a + b*x + c*x^2 on the current Makie axis.
#
#   samples : table of draws with columns `a`, `b`, `c`
#   x       : grid on which each curve is evaluated
#   thin    : plot every `thin`-th draw
#   color   : line colour
#   alpha   : line opacity; `nothing` means 1/cbrt(number of draws)
#   kwargs  : forwarded to `lines!`
function plot_samples!(samples, x;
        thin=10, color=:black, alpha=nothing, kwargs...)

    # Honour a caller-supplied alpha; only fall back to the density-based default.
    line_alpha = isnothing(alpha) ? 1 / size(samples, 1)^(1 / 3) : alpha

    for draw in eachrow(samples)[1:thin:end]
        y = @. draw.a + draw.b * x + draw.c * x^2
        lines!(x, y; color=color, alpha=line_alpha, kwargs...)
    end
    return nothing
end

In [None]:
# Data plus thinned quadratic-model curves evaluated on x ∈ [0, 4].
fig = let f = plot_data()
    plot_samples!(samples, LinRange(0, 4, 100))
    f
end