In [None]:
using Turing
using StatsBase
using LinearAlgebra: diagm
using CairoMakie
using Arya

In [None]:
using DataFrames, PairPlots

In [None]:
using FillArrays


In [None]:
import FillArrays: I

In [None]:
"""
    make_observations(N=1000, z1=0.25, z2=0.4, z3=-0.3, sx=0.05, sy=0.5, s_int_x=1, x0=0)

Simulate `N` noisy measurements of the quadratic curve `y = z1 + z2*x + z3*x^2`.

The true abscissae are `x0 + s_int_x * ε` with `ε` standard normal. Both coordinates
are then perturbed by independent Gaussian noise scaled by `sx` (on `x`) and `sy`
(on `y`). Draws come from the global RNG.

Returns the tuple `(xo, yo)` of observed vectors, each of length `N`.
"""
function make_observations(N=1000, z1=0.25, z2=0.4, z3=-0.3, sx=0.05, sy=0.5, s_int_x=1, x0=0)
    # Latent values; the order of the three `randn` calls below fixes the random stream.
    x_true = x0 .+ randn(N) * s_int_x
    y_true = @. z1 + z2 * x_true + z3 * x_true^2

    # Measurement noise, drawn first for x and then for y.
    noise_x = randn(N)
    x_obs = x_true .+ sx .* noise_x
    noise_y = randn(N)
    y_obs = y_true .+ sy .* noise_y

    return (x_obs, y_obs)
end

In [None]:
"""
    plot_mcmc_results(chain::Chains; bins::Int = 30)

Plot a trace and a marginal histogram for every parameter of `chain`.

The figure has one row per parameter:

- The left column shows the trace of each chain against its iteration index. A legend
  appears on the first row when there is more than one chain.
- The right column shows a sideways histogram of the draws pooled over all chains. It
  shares the y-axis of the trace next to it.

# Arguments
- `chain::Chains`: sampler output. Only the names listed in
  `chain.name_map.parameters` are plotted.

# Keywords
- `bins::Int = 30`: number of histogram bins.

# Returns
- The `Figure` containing all panels.

# Throws
- `ArgumentError` if `chain` has no parameters (there would be no rows and no
  second column to size).
"""
function plot_mcmc_results(chain::Chains; bins::Int = 30)
    params = chain.name_map.parameters
    nparams = length(params)
    nparams > 0 || throw(ArgumentError("chain has no parameters to plot"))

    nchains = size(chain, 3)

    # One row of (trace, histogram) per parameter.
    fig = Figure(size = (600, 200 * nparams),
                 backgroundcolor = :white)

    for (i, param) in enumerate(params)
        ax_trace = Axis(fig[i, 1],
            xlabel = "Iteration",
            ylabel = "$param",
            xgridvisible = false,
            ygridvisible = false,
        )

        # Index by name rather than by position `i`: the chain's value array can also
        # hold internal (non-parameter) columns, so positions in
        # `name_map.parameters` need not match column positions.
        for c in 1:nchains
            samples = chain[:, param, c]
            lines!(ax_trace, 1:length(samples), samples,
                color = c, colorrange = (0, nchains), label = "Chain $c")
        end

        # A single legend on the top row is enough; it is pointless for one chain.
        if i == 1 && nchains > 1
            axislegend(ax_trace, position = :rt)
        end

        # Sideways histogram whose count axis starts at zero.
        ax_hist = Axis(fig[i, 2],
            limits = (0, nothing, nothing, nothing),
            xgridvisible = false,
            ygridvisible = false,
        )
        hidedecorations!(ax_hist)
        linkyaxes!(ax_trace, ax_hist)

        # Pool the draws of all chains for the marginal.
        combined_samples = vec(chain[:, param, :])
        hist!(ax_hist, combined_samples, direction = :x, bins = bins)

        # Only the bottom trace keeps its x labels; ticks stay so the rows align.
        if i < nparams
            hidexdecorations!(ax_trace, ticks = false)
        end
    end

    colgap!(fig.layout, 1, 0)
    rowgap!(fig.layout, 0)
    colsize!(fig.layout, 2, Relative(0.25))

    return fig
end

In [None]:
# Simulate the default data set (N = 1000 points with make_observations' default settings).
obs_x, obs_y = make_observations()

In [None]:
# Quick visual check of the simulated observations.
scatter(obs_x, obs_y)

In [None]:
# Turing model for quadratic regression, y = z1 + z2*x + z3*x^2, with i.i.d.
# Gaussian residuals.
#   - Priors: z1, z2, z3 ~ Normal(0, 1); sigma2 ~ Normal(0, 1) truncated to [0, ∞).
#   - sigma2 multiplies the identity covariance of the MvNormal likelihood, so it is
#     the residual *variance*, not a standard deviation.
#   - obs_x enters the mean as exact values; no noise model is placed on it.
# The tilde statements are kept in this order (z1, z2, z3, sigma2) because code
# elsewhere may rely on that parameter ordering.
@model function analytic_model(obs_x, obs_y)
    # Polynomial coefficients: intercept, linear, quadratic.
    z1 ~ Normal(0, 1)
    z2 ~ Normal(0, 1)
    z3 ~ Normal(0, 1)

    # Residual variance with a half-normal prior.
    sigma2 ~ truncated(Normal(0, 1); lower=0)

    # Expected response at each observed x.
    mu = @. z1 + z2 * obs_x + z3 * obs_x^2

    # Likelihood of the observed responses.
    obs_y ~ MvNormal(mu, sigma2 * I)

    return nothing
end

## MH — Metropolis–Hastings

In [None]:
# Condition the model on the simulated data.
model = analytic_model(obs_x, obs_y)

In [None]:
# Metropolis–Hastings with MH() default settings: 100,000 iterations, one chain.
chain = sample(model, MH(), 100_000)

In [None]:
# Tabulate the draws as a DataFrame for inspection.
samples = DataFrame(chain)

In [None]:
# Corner plot of the draws.
pairplot(chain)

## IS — importance sampling

In [None]:
# Condition the model on the simulated data.
model = analytic_model(obs_x, obs_y)

In [None]:
# Importance sampling: 10,000 draws.
# NOTE(review): IS draws are presumably importance-weighted; the unweighted corner plot
# below may not represent the posterior. Confirm how the weights are stored in `chain`.
chain = sample(model, IS(), 10_000)

In [None]:
# Tabulate the draws as a DataFrame for inspection.
samples = DataFrame(chain)

In [None]:
# Corner plot of the draws.
pairplot(chain)

## PG — particle Gibbs

In [None]:
# Condition the model on the simulated data.
model = analytic_model(obs_x, obs_y)

In [None]:
# Particle Gibbs with 30 particles: 8 chains of 100 iterations each, distributed with
# MCMCThreads() (actual parallelism needs Julia started with several threads).
chain = sample(model, PG(30), MCMCThreads(), 100, 8)

In [None]:
# Display the draws as a DataFrame (not stored in `samples`, unlike other sections).
DataFrame(chain)

In [None]:
# Corner plot of the draws pooled over chains.
pairplot(chain)

## SMC — sequential Monte Carlo

In [None]:
# Condition the model on the simulated data.
model = analytic_model(obs_x, obs_y)

In [None]:
# Sequential Monte Carlo with SMC() default settings: 8 runs of 100 draws each,
# distributed with MCMCThreads().
chain = sample(model, SMC(), MCMCThreads(), 100, 8)

In [None]:
# Tabulate the draws as a DataFrame for inspection.
samples = DataFrame(chain)

In [None]:
# Corner plot of the draws pooled over runs.
pairplot(chain)

## HMC — Hamiltonian Monte Carlo

In [None]:
# Condition the model on the simulated data.
model = analytic_model(obs_x, obs_y)

In [None]:
# Static HMC: leapfrog step size 0.001, 3 leapfrog steps per iteration, 10,000 iterations.
# NOTE(review): such a small step moves slowly through parameter space; check the trace
# for poor mixing before trusting the marginals.
chain = sample(model, HMC(0.001, 3), 10_000)

In [None]:
# Tabulate the draws as a DataFrame for inspection.
samples = DataFrame(chain)

In [None]:
# Corner plot of the draws.
pairplot(chain)

## NUTS — No-U-Turn Sampler

In [None]:
# Condition the model on the simulated data (this `model` is also reused by the RMH cells).
model = analytic_model(obs_x, obs_y)

In [None]:
# No-U-Turn Sampler with target acceptance rate 0.65, 10,000 iterations.
chain = sample(model, NUTS(0.65), 10_000)

In [None]:
# Corner plot of the draws.
pairplot(chain)

## RMH — random-walk Metropolis–Hastings

In [None]:
using AdvancedMH: RWMH  # NOTE(review): RWMH is not used in the visible cells
using AdvancedMH

In [None]:
# Random-walk proposal: zero-mean Gaussian steps with covariance s_walk * I.
# s_walk is therefore a variance; the per-coordinate step std is sqrt(1e-3) ≈ 0.032.
# The length-4 mean must match the number of model parameters (z1, z2, z3, sigma2).
# Proposals with sigma2 < 0 fall outside its truncated prior.
s_walk = 1e-3
rw_prop = RandomWalkProposal(MvNormal([0,0,0,0], s_walk * I))


In [None]:
# Reuses `model` from the NUTS section. 30,000 iterations; num_warmup=1000 presumably
# adds 1,000 warm-up steps that are discarded. Confirm against the AbstractMCMC version.
chain = sample(model, MH(rw_prop), 30_000, num_warmup=1000)

In [None]:
# Corner plot of the random-walk draws.
pairplot(chain)

In [None]:
# Trace plots and marginal histograms per parameter.
plot_mcmc_results(chain)