# Turing Target Practice

## Model

In [1]:
] activate "."

[32m[1m  Activating[22m[39m project at `~/PhD/MicroCanonicalHMC.jl`


In [2]:
# The statistical inference frame-work we will use
using Turing
using Interpolations
using ForwardDiff
using LinearAlgebra
#using StatsPlots
using PyPlot
using Distributed

using Revise
using MicroCanonicalHMC

In [3]:
fs8_zs = [0.38, 0.51, 0.61, 1.48, 0.44, 0.6, 0.73, 0.6, 0.86, 0.067, 1.4]
fs8_data = [0.49749, 0.457523, 0.436148, 0.462, 0.413, 0.39, 0.437, 0.55, 0.4, 0.423, 0.482]
fs8_cov = [0.00203355 0.000811829 0.000264615 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0;
           0.000811829 0.00142289 0.000662824 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 
           0.000264615 0.000662824 0.00118576 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0;
           0.0 0.0 0.0 0.002025 0.0 0.0 0.0 0.0 0.0 0.0 0.0;
           0.0 0.0 0.0 0.0 0.0064 0.00257 0.0 0.0 0.0 0.0 0.0;
           0.0 0.0 0.0 0.0 0.00257 0.003969 0.00254 0.0 0.0 0.0 0.0;
           0.0 0.0 0.0 0.0 0.0 0.00254 0.005184 0.0 0.0 0.0 0.0;
           0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0144 0.0 0.0 0.0;
           0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0121 0.0 0.0; 
           0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.003025 0.0;
           0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.013456000000000001];

In [4]:
function make_fs8(Ωm, σ8; Ωr=8.24*10^-5)
    # ODE solution for growth factor
    x_Dz = LinRange(0, log(1+1100), 300)
    dx_Dz = x_Dz[2]-x_Dz[1]
    z_Dz = @.(exp(x_Dz) - 1)
    a_Dz = @.(1/(1+z_Dz))
    aa = reverse(a_Dz)
    e =  @.(sqrt.(abs(Ωm)*(1+z_Dz)^3+Ωr*(1+z_Dz)^4+(1-Ωm-Ωr)))
    ee = reverse(e)

    dd = zeros(typeof(Ωm), 300)
    yy = zeros(typeof(Ωm), 300)
    dd[1] = aa[1]
    yy[1] = aa[1]^3*ee[end]

    for i in 1:(300-1)
        A0 = -1.5 * Ωm / (aa[i]*ee[i])
        B0 = -1. / (aa[i]^2*ee[i])
        A1 = -1.5 * Ωm / (aa[i+1]*ee[i+1])
        B1 = -1. / (aa[i+1]^2*ee[i+1])
        yy[i+1] = (1+0.5*dx_Dz^2*A0*B0)*yy[i] + 0.5*(A0+A1)*dx_Dz*dd[i]
        dd[i+1] = 0.5*(B0+B1)*dx_Dz*yy[i] + (1+0.5*dx_Dz^2*A0*B0)*dd[i]
    end

    y = reverse(yy)
    d = reverse(dd)

    Dzi = LinearInterpolation(z_Dz, d./d[1], extrapolation_bc=Line())
    fs8zi = LinearInterpolation(z_Dz, -σ8 .* y./ (a_Dz.^2 .*e.*d[1]),
                                 extrapolation_bc=Line())
    return fs8zi
end

make_fs8 (generic function with 1 method)

In [5]:
@model function model(data; cov = fs8_cov) 
    # Define priors
    #KiDS priors
    Ωm ~ Uniform(0.2, 0.25)#~ Normal(0.3, 0.1)
    σ8 ~ Normal(0.8, 0.3)
    fs8_itp = make_fs8(Ωm, σ8)
    theory = fs8_itp(fs8_zs)
    data ~ MvNormal(theory, cov)
end;

In [6]:
stat_model = model(fs8_data)

DynamicPPL.Model{typeof(model), (:data, :cov), (:cov,), (), Tuple{Vector{Float64}, Matrix{Float64}}, Tuple{Matrix{Float64}}, DynamicPPL.DefaultContext}(model, (data = [0.49749, 0.457523, 0.436148, 0.462, 0.413, 0.39, 0.437, 0.55, 0.4, 0.423, 0.482], cov = [0.00203355 0.000811829 … 0.0 0.0; 0.000811829 0.00142289 … 0.0 0.0; … ; 0.0 0.0 … 0.003025 0.0; 0.0 0.0 … 0.0 0.013456000000000001]), (cov = [0.00203355 0.000811829 … 0.0 0.0; 0.000811829 0.00142289 … 0.0 0.0; … ; 0.0 0.0 … 0.003025 0.0; 0.0 0.0 … 0.0 0.013456000000000001],), DynamicPPL.DefaultContext())

## Sampling

In [None]:
target = TuringTarget(stat_model);

In [None]:
spl = MCHMC(0.0, 0.0, varE_wanted=2.0) #sqrt(target.d)*0.01

In [None]:
spl.hyperparameters

In [None]:
samples_mchmc = Sample(spl, target, 10000;
                       monitor_energy=true, 
                       dialog=true)

In [None]:
#plt.plot(samples_mchmc.E[8000:end])

In [None]:
#mean(samples_mchmc.E[8000:end])

In [None]:
#std(samples_mchmc.E[8000:end])^2/target.d

In [None]:
Wms_mchmc = [sample[1] for sample in samples_mchmc]
s8s_mchmc = [sample[2] for sample in samples_mchmc];

In [None]:
plt.hist2d(Wms_mchmc, s8s_mchmc, bins=100, range=[[0.1, 0.4],[0.6, 1.2]]);
plt.xlabel("Wm")
plt.ylabel("s8")
plt.title("MCHMC - RSD model - eps = 0.001 --> var[E]~0.00003");

## AbstractMCMC

In [None]:
samples = sample(
    stat_model, MCHMC(varE_wanted=2.0), 10000;
    monitor_energy=true, 
    dialog=true)

In [None]:
new_samples = sample(stat_model, MCHMC(varE_wanted=2.0), 10000; 
                     monitor_energy=true, 
                     progress=true, resume_from=samples)

## Parallelization

In [153]:
using Distributed

In [154]:
spl = MCHMC(; nchains=100)
target = TuringTarget(stat_model);

In [155]:
loss, xs, us, ls, gs = MicroCanonicalHMC.Init_burnin(spl, target)

(1132.6716096004593, [-1.491351473618569 1.5372672176687452; 0.17197523066762627 0.6958364506682089; … ; -0.646921856499523 0.6108276211896924; -2.2243512416742592 1.1342337124013249], [-0.002014568281471904 -0.20380080702327352; 0.00035317249577067306 0.05357924754428543; … ; 0.0006641722974846554 0.08176696210844785; -2.126911297950547e-5 -0.07737882187981746], [133.32015010185867 133.32015010185867; -6.403448910335451 -6.403448910335451; … ; 7.352765530683344 7.352765530683344; 6.349495796475535 6.349495796475535], [4.474420595815421 452.6481116452825; -0.7844074105034862 -119.0012227062789; … ; -1.4751479184731893 -181.6070384310777; 0.047239380290794886 171.86083861403077])

In [72]:
samples = sample(
    stat_model, MCHMC(varE_wanted=2.0), MCMCThreads(), 30000, 4;
    monitor_energy=true, 
    dialog=true)

eps: 0.5 --> VarE: 1046.941837907961
eps: 0.25 --> VarE: 5.62000092062071
eps: 0.125 --> VarE: 23.662652912868822
eps: 0.0625 --> VarE: 4.362158101099759
eps: 0.03125 --> VarE: 1.5964085748779815
samples: 100--> ESS: 0.024351540287992274
samples: 243--> ESS: 0.020186428268599264
samples: 447--> ESS: 0.02418849942982405


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTuning eps ⏳
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFound eps: 0.03125 ✅
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTuning L ⏳
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mFound L: 1.4142135623730951 ✅
[32mchain 1:   4%|█▌                                        |  ETA: 0:00:03[39m
[36mchain 4:   4%|█▌                                        |  ETA: 0:00:03[39m
[34mchain 2:   4%|█▌                                        |  ETA: 0:00:03[39m
[35mchain 3:   4%|█▋                                        |  ETA: 0:00:03[39m

[36mchain 4:   7%|███                                       |  ETA: 0:00:03[39m
[32mchain 1:   7%|███                                       |  ETA: 0:00:03[39m
[34mchain 2:   7%|███                                       |  ETA: 0:00:03[39m
[35mchain 3:   8%|███▏                                      |  ETA: 0:00:03[39m
[36mchain 4:  10%|████▎                                     |  ETA: 0:00:03[39m
[

[36mchain 4:  77%|████████████████████████████████▍         |  ETA: 0:00:01[39m
[32mchain 1:  78%|████████████████████████████████▊         |  ETA: 0:00:01[39m
[34mchain 2:  78%|████████████████████████████████▊         |  ETA: 0:00:01[39m
[35mchain 3:  81%|██████████████████████████████████        |  ETA: 0:00:01[39m
[36mchain 4:  80%|█████████████████████████████████▋        |  ETA: 0:00:01[39m
[32mchain 1:  81%|██████████████████████████████████        |  ETA: 0:00:01[39m
[34mchain 2:  81%|██████████████████████████████████        |  ETA: 0:00:01[39m
[35mchain 3:  84%|███████████████████████████████████▎      |  ETA: 0:00:01[39m
[36mchain 4:  83%|██████████████████████████████████▉       |  ETA: 0:00:01[39m

[32mchain 1:  84%|███████████████████████████████████▌      |  ETA: 0:00:01[39m
[34mchain 2:  84%|███████████████████████████████████▎      |  ETA: 0:00:01[39m
[35mchain 3:  87%|████████████████████████████████████▌     |  ETA: 0:00:00[39m
[36mchain 4:  

4-element Vector{Chains}:
 MCMC chain (30000×4×1 Array{Float64, 3})
 MCMC chain (30000×4×1 Array{Float64, 3})
 MCMC chain (30000×4×1 Array{Float64, 3})
 MCMC chain (30000×4×1 Array{Float64, 3})

In [None]:
new_samples = sample(
    stat_model, MCHMC(varE_wanted=2.0), MCMCThreads(), 30000, 4;
    monitor_energy=true, 
    dialog=true, 
    resume_from=samples)

## NUTS

In [None]:
samples_hmc = sample(stat_model, NUTS(500, 0.65),
                     10000, progress=true; save_state=true)

In [None]:
samples_hmc.value

In [None]:
Wms_hmc = vec(samples_hmc["Ωm"])
s8s_hmc = vec(samples_hmc["σ8"]);

In [None]:
plt.hist2d(Wms_hmc, s8s_hmc, bins=100, range=[[0.1, 0.4],[0.6, 1.2]]);
plt.xlabel("Wm")
plt.ylabel("s8")
plt.title("HMC - RSD model");