In [None]:
using Turing
using StatsBase
using LinearAlgebra: diagm
using CairoMakie
using Arya

In [None]:
using DataFrames, PairPlots

In [None]:
using FillArrays


In [None]:
import FillArrays: I

In [None]:
# Reference ("true") generating parameters; same values as the defaults of `make_observations`.
true_params = (z1 = 0.25, z2 = 0.4, z3 = -0.3, sx = 0.05, sy = 0.5)

In [None]:
"""
    make_observations(N=1000, z1=0.25, z2=0.4, z3=-0.3, sx=0.05, sy=0.5, s_int_x=1, x0=0)

Simulate `N` noisy samples of the quadratic `y = z1 + z2*x + z3*x^2`.

Latent abscissae are drawn as `x0 + s_int_x * randn()`. Independent Gaussian noise
with standard deviation `sx` is added to `x`, and with standard deviation `sy` to `y`.
Returns the tuple `(xo, yo)` of noisy abscissae and ordinates.
"""
function make_observations(N=1000, z1=0.25, z2=0.4, z3=-0.3, sx=0.05, sy=0.5, s_int_x=1, x0=0)
    quadratic(t) = z1 + z2 * t + z3 * t^2

    # Random draws are consumed in a fixed order (latent x, x noise, y noise),
    # so a seeded global RNG always reproduces the same data set.
    latent_draws = randn(N)
    x_true = x0 .+ latent_draws .* s_int_x
    y_true = quadratic.(x_true)

    x_noise = randn(N)
    y_noise = randn(N)

    return (x_true .+ sx .* x_noise, y_true .+ sy .* y_noise)
end

In [None]:
"""
    plot_mcmc_results(chain::Chains; bins::Int = 30)

Plot a trace and a marginal histogram for every parameter of `chain`.

The figure has one row per entry of `chain.name_map.parameters` (internal columns are
not plotted). The left column shows the trace of each chain against its draw index.
The right column shows a horizontal histogram of the draws pooled over all chains,
sharing the y-axis of the trace next to it.

# Keywords
- `bins::Int = 30`: number of histogram bins.

# Returns
The Makie `Figure`. If `chain` has no parameters, an empty figure is returned.
"""
function plot_mcmc_results(chain::Chains; bins::Int = 30)
    params = chain.name_map.parameters
    nparams = length(params)
    nchains = size(chain, 3)

    # 200 px per parameter row (at least one row so the figure is never zero-height).
    fig = Figure(size = (600, 200 * max(nparams, 1)), backgroundcolor = :white)

    # Nothing to draw; also avoids colgap!/colsize! on a layout without columns.
    nparams == 0 && return fig

    for (row, param) in enumerate(params)
        ax_trace = Axis(fig[row, 1];
            xlabel = "Iteration",
            ylabel = "$param",
            xgridvisible = false,
            ygridvisible = false,
        )

        # Select draws by parameter name: the positional column `row` only matches
        # `param` when every parameter column precedes the internal columns.
        for c in 1:nchains
            draws = chain[:, param, c]
            lines!(ax_trace, 1:length(draws), draws;
                color = c, colorrange = (0, nchains), label = "Chain $c")
        end

        # A single legend on the top row is enough; none is needed for one chain.
        if row == 1 && nchains > 1
            axislegend(ax_trace; position = :rt)
        end

        # Marginal histogram, drawn sideways so it shares the trace's y-axis.
        ax_hist = Axis(fig[row, 2];
            limits = (0, nothing, nothing, nothing),
            xgridvisible = false,
            ygridvisible = false,
        )
        hidedecorations!(ax_hist)
        linkyaxes!(ax_trace, ax_hist)

        # Pool the draws of all chains for the histogram.
        pooled = vec(chain[:, param, :])
        hist!(ax_hist, pooled; direction = :x, bins = bins)

        # Hide x labels on all but the bottom trace (tick marks are kept).
        if row < nparams
            hidexdecorations!(ax_trace; ticks = false)
        end
    end

    # Butt the histograms against the traces and stack the rows tightly.
    colgap!(fig.layout, 1, 0)
    rowgap!(fig.layout, 0)
    colsize!(fig.layout, 2, Relative(0.25))

    return fig
end

In [None]:
# Simulated data set using the default parameters of `make_observations`.
# NOTE: no RNG seed is set, so re-running this cell draws a new data set.
obs_x, obs_y = make_observations()

In [None]:
# Quick look at the noisy (x, y) observations.
scatter(obs_x, obs_y)

In [None]:
# Quadratic regression model: obs_y ≈ z1 + z2*obs_x + z3*obs_x^2 with Gaussian noise
# of variance `sigma2`. `obs_x` is treated as exact; only `obs_y` is observed with noise.
@model function analytic_model(obs_x, obs_y)
    # Standard-normal priors on the quadratic coefficients.
    z1 ~ Normal(0, 1)
    z2 ~ Normal(0, 1)
    z3 ~ Normal(0, 1)

    # `sigma2` scales the identity covariance below, i.e. it is a *variance*;
    # the truncation keeps it non-negative.
    sigma2 ~ truncated(Normal(0, 1); lower=0)

    # Fused broadcast: one pass over obs_x, no intermediate arrays.
    y_pred = @. z1 + z2 * obs_x + z3 * obs_x^2

    # Independent Gaussian noise with a common variance on every observation.
    # NOTE(review): noise on obs_x is not modelled, and `sigma2` is a variance, so it is
    # not directly comparable with the standard deviation `sy` used to simulate the data.
    obs_y ~ MvNormal(y_pred, sigma2 * I)

    return
end

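## MH (Metropolis–Hastings)
The Metropolis–Hastings algorithm is fairly straightforward. In each step, a candidate is drawn from a proposal distribution; for `MH()` without arguments, Turing proposes from the prior. With prior proposals, the acceptance rule reduces to a likelihood comparison. If the candidate has a higher likelihood than the current state, it is always accepted. Otherwise, it is still accepted with probability equal to the likelihood ratio (candidate / current). When a candidate is rejected, the chain repeats the current state.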

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# 100_000 Metropolis–Hastings iterations with Turing's argument-free `MH()`
# (presumably proposing from the prior — see the note above).
chain = sample(model, MH(), 100_000)

In [None]:
plot_mcmc_results(chain)

In [None]:
pairplot(chain)

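## IS (importance sampling)
Importance sampling draws samples from the prior and weights each one by its likelihood. The draws themselves are unweighted prior samples, so the weights must be taken into account when summarizing or plotting them. This makes nice visualizations much harder.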

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# Importance sampling from the prior; each draw presumably carries its log
# importance weight in the chain's `lp` column — confirm.
chain = sample(model, IS(), 100_000)

In [None]:
samples = DataFrame(chain)

In [None]:
# Row with the largest `lp`, i.e. the highest-weight draw (given the assumption above).
samples[argmax(samples.lp), :]

In [None]:
# NOTE(review): no weights are passed here, so this shows the raw prior draws,
# not the importance-weighted posterior.
pairplot(chain)

In [None]:
# Gaussian model with unknown mean `m` and variance `s²` shared by every entry of `x`.
@model function gdemo(x)
    s² ~ InverseGamma(2,3)
    # `s²` is a scalar, so plain `sqrt` suffices (no broadcast needed).
    m ~ Normal(0, sqrt(s²))
    # One likelihood term per observation, so data vectors of any length are accepted.
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
    return s², m
end

chain = sample(gdemo([-10, -2]), IS(), 1000)

In [None]:
plot_mcmc_results(chain)

In [None]:
# NOTE(review): `chain.logevidence` is, as far as I know, a single scalar for the
# whole chain rather than one weight per draw — confirm that this `weights`
# argument does what is intended.
pairplot(chain, weights=chain.logevidence)

## SMC

In [None]:
import AdvancedPS

In [None]:
?AdvancedPS.ResampleWithESSThreshold

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# SMC sampler using multinomial resampling, triggered by an effective-sample-size threshold.
# NOTE(review): the threshold is presumably a fraction of the particle count (resample
# when ESS < threshold * N); with 1 that means resampling at (nearly) every step —
# confirm against the docstring shown above.
sam = SMC(AdvancedPS.ResampleWithESSThreshold(
        AdvancedPS.resample_multinomial, 
        1
        ))


In [None]:

chain = sample(model, sam, 1000)

In [None]:
# Number of distinct values in the chain's second stored column versus the number
# of draws — presumably a quick check for particle degeneracy (many repeated values).
unique(chain[:, 2, :]), length(chain[:, 2, :])

In [None]:
plot_mcmc_results(chain)

In [None]:
pairplot(chain)

## PG

In [None]:
import AdvancedPS

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# NOTE(review): `sam` is never used below — the sampling cell uses Turing's own `PG(30)`.
sam  = AdvancedPS.PG(14, 1.0) |> externalsampler

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# Particle Gibbs via Turing's `PG` (the argument is presumably the particle count).
chain = sample(model, PG(30), 1000)

In [None]:
DataFrame(chain)

In [None]:
pairplot(chain)

In [None]:
plot_mcmc_results(chain)

## SG models

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
?SGLD

In [None]:
# Stochastic-gradient Langevin dynamics with Turing's default settings.
# NOTE(review): the default step-size schedule may be too large for 1000 observations;
# if the chain diverges, pass a smaller step size (see `?SGLD` for the keyword).
chain = sample(model, SGLD(), 1000)

In [None]:
DataFrame(chain)

In [None]:
pairplot(chain)

In [None]:
plot_mcmc_results(chain)

## HMC

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# Static HMC; with Turing's `HMC(ϵ, n_leapfrog)` signature this is step size 0.01
# and 3 leapfrog steps per iteration.
chain = sample(model, HMC(0.01, 3), 10_000)

In [None]:
# `chain[100:end]` drops the first 99 draws as burn-in.
plot_mcmc_results(chain[100:end])

In [None]:
pairplot(chain[100:end])

## NUTS

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
# No-U-Turn sampler; 0.65 is Turing's `NUTS(δ)` target acceptance rate for step-size adaptation.
chain = sample(model, NUTS(0.65), 10_000)

In [None]:
pairplot(chain)

## RMH

In [None]:
using AdvancedMH: RWMH
using AdvancedMH

In [None]:
# Random-walk Metropolis proposal over all four parameters (z1, z2, z3, sigma2).
# `s_walk * I` is the proposal *covariance*: each coordinate moves with standard
# deviation sqrt(s_walk) ≈ 0.055.
s_walk = 0.003
rw_prop = RandomWalkProposal(MvNormal(zeros(4), s_walk * I))


In [None]:
# Rebuild the model so this section does not depend on whichever section ran last.
model = analytic_model(obs_x, obs_y)
# 8 chains via MCMCThreads(); they only run in parallel if Julia was started with
# several threads (e.g. `julia --threads=auto`).
chain = sample(model, MH(rw_prop), MCMCThreads(), 10_000, 8)

In [None]:
# Drop the first 999 draws of every chain as burn-in.
pairplot(chain[1000:end])

In [None]:
# Short window of the traces, to inspect mixing of the random walk.
plot_mcmc_results(chain[1000:1200])

# VI

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
using Turing: Variational


In [None]:
?Variational.ADVI

In [None]:
?Variational.DecayedADAGrad

In [None]:
?ADVI

In [None]:
# ADVI configuration; presumably `ADVI(samples_per_step, max_iters)` — confirm
# against the help output above.
vi_type =  ADVI(100, 1000)

# Fit the variational approximation to the posterior of `model`.
vi_result = vi(model, vi_type) 

In [None]:
# 100_000 draws from the fitted approximation; the df-building cell below indexes
# rows as parameters and columns as draws.
z = rand(vi_result, 100_000)

In [None]:
# Tabulate the VI draws in `z`, one DataFrame column per scalar parameter.
df = DataFrame()
model_labels = bijector(model, Val(true))[2]

for (label, ranges) in pairs(model_labels)
    # Only entries described by a single index range are tabulated.
    length(ranges) == 1 || continue
    row = first(first(ranges))
    df[!, label] = z[row, :]
end
df

In [None]:
# Reference values for comparison with the VI marginals below.
# NOTE(review): `sigma2` in the model is a variance, while `true_params.sy` is a standard deviation.
true_params

In [None]:
pairplot(df)

# Emcee

In [None]:
using PythonCall
# Python `emcee` module, loaded through PythonCall.
emcee = pyimport("emcee")

In [None]:
?bijector

In [None]:
model_labels = bijector(model, Val(true))[2]


In [None]:
# Scratch check: evaluates to the range 3:3.
[3:3][1]

In [None]:
# Parameter order shared by `py_log_prob`, the initial walker positions and the output `Chains`.
variables = (:z1, :z2, :z3, :sigma2)

In [None]:
model = analytic_model(obs_x, obs_y)

In [None]:
using OrderedCollections

In [None]:
"""
    py_log_prob(theta)

Log joint density of the global `model` at `theta`, used as emcee's log-probability callback.

`theta` holds `(z1, z2, z3, sigma2)` in that order. It is indexed from 1, which
assumes PythonCall hands emcee's array over as a 1-based Julia array — confirm.
Returns `-Inf` for non-positive `sigma2`, where the likelihood covariance
`sigma2 * I` is not positive definite.
"""
function py_log_prob(theta)
    vars = (z1=theta[1], z2=theta[2], z3=theta[3], sigma2=theta[4])

    # sigma2 == 0 makes the covariance singular (non-finite log density) and
    # sigma2 < 0 makes it invalid; reject both so emcee only ever sees -Inf.
    if vars.sigma2 <= 0
        return -Inf
    end

    return logjoint(model, vars)
end

In [None]:
# Problem dimension and number of emcee walkers.
ndim = length(variables)
nwalkers = 16

# Initial walker positions: one prior draw per walker, as an nwalkers × ndim matrix.
# NOTE(review): assumes the first `ndim` columns of the prior chain are z1, z2, z3, sigma2
# in the same order as `variables` / `py_log_prob` — confirm.
priors = sample(model, Prior(),  nwalkers)
p0 = priors.value[:, 1:ndim, 1].data

In [None]:
# Ensemble sampler that calls the Julia function `py_log_prob` for every walker position.
sammy = emcee.EnsembleSampler(nwalkers, ndim, py_log_prob)

In [None]:
# 1000 steps starting from the prior draws in `p0`.
sammy.run_mcmc(p0, 1000, progress=true)

In [None]:
# NOTE(review): emcee documents get_chain() as (steps, walkers, ndim); the permutation
# below reorders it to the (iterations, parameters, chains) layout used by `Chains` — confirm.
chain = pyconvert(Array{Float64}, sammy.get_chain())
chain = permutedims(chain, (1, 3,2))

In [None]:
# Wrap as MCMCChains `Chains`; each walker becomes one chain.
chain = Chains(chain, [variables...])

In [None]:
# `chain[burn:end]` below drops the first burn - 1 steps as burn-in.
burn = 200

In [None]:
plot_mcmc_results(chain[burn:end])

In [None]:
pairplot(chain[burn:end])