In [None]:
using IntervalSets
using Distributions
using Random, LinearAlgebra, Statistics, Distributions, StatsBase, ArraysOfArrays
using JLD2
using ValueShapes
using TypedTables
using Measurements
using TypedTables
using HypothesisTests
using BenchmarkTools

In [None]:
# Plotting: use Plots.jl and select its PyPlot backend via `pyplot()`.
# The direct `using PyPlot` import below is kept commented out for reference.
# using PyPlot
using Plots
pyplot()

In [None]:
using Revise
using BAT

In [None]:
# Load the stored mixture parameters: `means` (row k = mean of component k),
# `cov_m` (cov_m[k, :, :] = covariance of component k) and `n_clusters`.
JLD2.@load "../data/mixture-9D-nc.jld" means cov_m n_clusters

# Equal-weight Gaussian mixture. `Matrix(Hermitian(...))` rebuilds each covariance
# from its upper triangle, so MvNormal receives an exactly symmetric matrix.
mixture_model = MixtureModel(MvNormal[
    MvNormal(means[k, :], Matrix(Hermitian(cov_m[k, :, :]))) for k in 1:n_clusters
]);

# Log-likelihood: mixture log-density at the parameter vector `params.a`,
# wrapped in BAT's `LogDVal`.
likelihood = let mm = mixture_model
    params -> LogDVal(logpdf(mm, params.a))
end

# One Uniform(-100, 100) prior per coordinate (size(means, 2) coordinates).
prior = NamedTupleDist(a = [Uniform(-100, 100) for _ in 1:size(means, 2)])
posterior = PosteriorDensity(likelihood, prior);
# Log of the volume spanned by the posterior's parameter bounds (cell output).
log_volume = BAT.log_volume(BAT.spatialvolume(posterior.parbounds))

In [None]:
# Burn-in: limits per burn-in cycle (samples, steps, wall time) and on the
# total number of burn-in cycles.
burnin = MCMCBurninStrategy(
    max_ncycles = 50,
    max_nsamples_per_cycle = 1000,
    max_nsteps_per_cycle = 10_000,
    max_time_per_cycle = Inf,
)

# Adaptive Metropolis proposal tuning.
# NOTE(review): α looks like the target acceptance-rate window and c the allowed
# range of the proposal scale factor — confirm against the BAT docs for this version.
tuning = AdaptiveMetropolisTuning(
    λ = 0.5,
    α = 0.05..0.15,
    β = 1.5,
    c = 1e-4..1e2,
    r = 0.5,
)

# Chain initialisation: init tries per chain in the range 8..128, each limited in
# samples and steps, with no time cap.
init = MCMCInitStrategy(
    init_tries_per_chain = 8..128,
    max_nsamples_init = 25,
    max_nsteps_init = 250,
    max_time_init = Inf,
)

n_chains = 10
n_samples = 300_000          # number of samples per chain
max_nsteps = 10_000_000_000  # overall MCMC step cap (presumably summed over all chains — confirm in BAT docs)
max_time = 200               # [seconds] wall-clock budget for generating samples

In [None]:
# Main Metropolis-Hastings run with the settings defined above;
# `@time` reports wall time and allocations of the call.
s_result = @time bat_sample(
    posterior, (n_samples, n_chains), MetropolisHastings();
    init = init, burnin = burnin, tuning = tuning,
    max_nsteps = max_nsteps, max_time = max_time,
);

In [None]:
# s_result = @btime bat_sample($posterior, ($n_samples, $n_chains), MetropolisHastings(), init=$init, burnin=$burnin, tuning=$tuning, max_nsteps=$max_nsteps, max_time=$max_time);

In [None]:
# Number of entries in `s_result.result` (presumably the samples of all chains — confirm).
s_result.result |> length

In [None]:
# Sample count recorded for chain 4 (hard-coded index: requires n_chains >= 4).
let chain = s_result.chains[4]
    chain.nsamples
end

In [None]:
# Sample count of every chain, in the order of `s_result.chains`.
map(chain -> chain.nsamples, s_result.chains)

In [None]:
# Scan wall-clock budgets: rerun the sampler with `max_time = t` seconds for each
# t in `t_array` and record how many samples each run produced.
n_samples_array = Int[]  # typed, so the counts form a concrete Vector{Int}
t_array = [2, 5, 10, 15, 40, 50]

for t in t_array
    # Loop-local name: assigning the global `s_result` here would either overwrite
    # the main run's result (interactive soft scope) or silently create a local
    # with an "ambiguous soft scope" warning (non-interactive file).
    timed_result = @time bat_sample(posterior, (n_samples, n_chains), MetropolisHastings(),
        init=init, burnin=burnin, tuning=tuning, max_nsteps=max_nsteps, max_time=t)
    push!(n_samples_array, length(timed_result.result))
end

In [None]:
# Samples obtained versus the wall-clock budget given to each run in the scan above.
scatter(t_array, n_samples_array;
    xlabel = "max_time [s]", ylabel = "number of samples", legend = false)