In [None]:
using Distributions 
using IntervalSets
using ValueShapes
using Plots
using ArraysOfArrays
using StatsBase 
using LinearAlgebra
using Random123
using HCubature

using KDTree
using BAT
using BATPar

# Select the PyPlot backend for Plots.jl and set the default plot size to 750×500.
pyplot(size = (750, 500))

# Target density: four-component Gaussian mixture

In [None]:
# Covariance matrices (2×2, symmetric) of the Gaussian mixture components used by `g`.
sigma_1 = [0.32716446841097613 0.17276467616026275; 0.17276467616026275 0.33056237691918017]
sigma_2 = [0.15152028283087893 -0.11598742336072133; -0.11598742336072133 0.1616797732933265] # alternative: [0.1572026439007445 -0.1044956058704804; -0.1044956058704804 0.13445945463874312]
sigma_3 = [0.01942201849281335 -0.003187584896683795; -0.003187584896683795 0.017175237584791444]

N = 2          # dimensionality of the parameter space
min_v = -10    # lower edge of the box [min_v, max_v]^N in every dimension
max_v = 10     # upper edge of the box [min_v, max_v]^N in every dimension
lgV = N*log(max_v-min_v); # log-volume of the box [min_v, max_v]^N

"""
    f(x; μ=[0, 0], sigma=Matrix(1.0I, length(μ), length(μ)))

Return the density of the multivariate normal `MvNormal(μ, sigma)` evaluated at `x`.
`sigma` defaults to the identity covariance with the same dimension as `μ`.
"""
# The former default `sigma=sigma` referred to itself (and no global `sigma` exists),
# so calling `f` without an explicit covariance threw an UndefVarError.
f(x; μ=[0, 0], sigma=Matrix(1.0I, length(μ), length(μ))) = pdf(MvNormal(μ, sigma), x)

"""
    g(x)

Unnormalized four-component Gaussian mixture evaluated at the 2-vector `x`:

- weight 1, mean (1, 1), covariance `sigma_1`
- weight 1, mean (-1, -1), covariance `sigma_1`
- weight 0.1, mean (1.5, -1.5), covariance `sigma_3`
- weight 0.9, mean (-1.3, 1.3), covariance `sigma_2`

The weights sum to 3, so `g` integrates to 3 over ℝ².
"""
g(x) = f(x, μ=[1, 1], sigma=sigma_1) + f(x, μ=[-1, -1], sigma=sigma_1) +
       0.1*f(x, μ=[1.5, -1.5], sigma=sigma_3) + 0.9*f(x, μ=[-1.3, 1.3], sigma=sigma_2)

# Seed samples: a short MCMC run whose samples are used to build the space partition

In [None]:
nnsamples = 10^2   # samples per chain for the seed run
nnchains = 10      # number of chains for the seed run

# Log-likelihood: log of the mixture density `g` at the parameter vector `params.a`.
likelihood = params -> LogDVal(log(g(params.a)))

# Prior: one interval [min_v, max_v] per dimension (BAT interprets intervals in a
# NamedTupleDist as uniform distributions). The comprehension is used directly; the
# previous `[[...]...]` built it and then splatted it into a second vector literal.
prior = NamedTupleDist(a = [min_v .. max_v for _ in 1:N])

posterior = PosteriorDensity(likelihood, prior)
samples, stats = bat_sample(posterior, (nnsamples, nnchains), MetropolisHastings());

In [None]:
# Seed samples as a matrix with one row per parameter and one column per sample,
# plus the per-sample log-density values and sample weights.
smpl = flatview(unshaped.(samples.v))
weights_LogLik = samples.logd
weights_Histogram = samples.weight;

# Indexing with `:` takes copies, so the tree input does not alias `samples`.
data_kdtree = Data(smpl[:, :], weights_Histogram[:], weights_LogLik[:]);

In [None]:
x_range = range(-3, stop=3, length=100)
y_range = range(-3, stop=3, length=100)

# Density on the grid laid out as z[jy, ix] = g([x_range[ix], y_range[jy]]) (rows ↔ y).
# This is the orientation `contour(x, y, z)` expects. It is built directly, instead of
# via a lazy adjoint of the x-major comprehension.
z = [g([i, j]) for j in y_range, i in x_range];

# Contour levels are placed at quantiles of the gridded density values.
levels_quantiles = [0.2, 0.3, 0.4, 0.5, 0.7, 0.8, 0.85, 0.9, 0.95, 0.99, 0.999, 1]
# `vec(z)` feeds the grid values to `quantile` without splatting all
# length(x_range)*length(y_range) elements into a vararg call (the former `[z...]`).
levels = quantile(vec(z), levels_quantiles)

# Filled density contours with the seed samples on top (marker size scaled by weight).
contour(x_range, y_range, z; fill=true, levels=levels, fillalpha=0.2, color=:blues)
scatter!(smpl[1,:], smpl[2,:], alpha=1, markerstrokewidth=0, markersize=0.2.*weights_Histogram, color=:black, label="")
plot!(fillalpha=0.2, size=(600, 600), colorbar=false, frame=false, grid=false, xaxis=false, yaxis=false)

# Space partitioning: KD-tree built from the seed samples

In [None]:
# Choose the partitioning cost: KDTree's total-cost hook is routed to `cost_f_1`.
# NOTE(review): this redefines a KDTree function for a KDTree type from outside the
# package (method overwrite / type piracy). It affects every later call of
# `KDTree.evaluate_total_cost` on `Data` in this session.
KDTree.evaluate_total_cost(d::Data) = KDTree.cost_f_1(d)

# Build the partition; `[1, 2]` presumably selects the dimensions to split along.
# NOTE(review): confirm the meaning of `30` (likely a cap on splits/leaves) and of
# `cost_array` against the KDTree package.
output, cost_array = DefineKDTree(data_kdtree, [1, 2], 30);

In [None]:
# Seed samples over the density contours, with the KD-tree partition drawn on top.
contour(x_range, y_range, z; fill = true, levels = levels, fillalpha = 0.2, color = :blues)
scatter!(
    smpl[1, :], smpl[2, :];
    alpha = 1, markerstrokewidth = 0, markersize = 0.1 .* weights_Histogram,
    color = :black, label = "",
)
plot!(fillalpha = 0.2, size = (600, 600), colorbar = false, frame = true, grid = false)

# Partition boundaries projected onto dimensions 1 and 2.
plot_tree(output, [1, 2]; linealpha = 1, lw = 0.4, linecolor = :red)

plot!(xaxis = false, yaxis = false)

# Sampling: MCMC run separately within each partition subspace

In [None]:
# Bounds of the partition subspaces; presumably one (dimension × 2) matrix per subspace.
bounds_part = extract_par_bounds(output)

# Prior for one subspace: for each row k of the bounds matrix `i`, the interval
# i[k, 1]..i[k, 2] (column 1 = lower edge, column 2 = upper edge).
# Iterating over all rows generalizes the former hard-coded 2-D version; for a 2-row
# bounds matrix the resulting interval vector is identical.
# NOTE(review): this adds a method to a BATPar function for an untyped argument from
# outside BATPar (presumably an intended hook, but technically type piracy).
BATPar.make_named_prior(i) = BAT.NamedTupleDist(a = [i[k, 1]..i[k, 2] for k in axes(i, 1)])

# Sampling settings used in every partition subspace.
nnsamples = 10^3
nnchains = 5

# Adaptive Metropolis tuning (BAT AdaptiveMetropolisTuning):
# λ weights new covariance information, α is the target acceptance-ratio range,
# β controls how strongly the proposal spread is adjusted, c bounds that spread.
tuning = AdaptiveMetropolisTuning(λ = 0.5, α = 0.15..0.45, β = 1.5, c = 1e-4..1e2)

algorithm = MetropolisHastings();

# Sample the likelihood over the subspaces in `bounds_part` (presumably one run per
# subspace, each with the prior from `BATPar.make_named_prior`).
samples_parallel = bat_sample_parallel(
    likelihood, bounds_part, (nnsamples, nnchains), algorithm;
    tuning = tuning,
);

In [None]:
# Join the sample matrices from all subspaces column-wise; rows are parameters
# (row 1 → x, row 2 → y below). `reduce(hcat, ...)` uses Base's specialized method
# instead of splatting an arbitrarily long collection into a vararg `hcat` call.
smpl_par = reduce(hcat, samples_parallel.samples)
x = smpl_par[1,:]
y = smpl_par[2,:]
# NOTE(review): the distinction between `weights_o` and `weights_r` is not visible
# here — confirm against BATPar.
w_o = samples_parallel.weights_o
w_r =  samples_parallel.weights_r;

In [None]:
# 2-D histogram of the combined parallel samples, weighted by `w_o`.
# NOTE(review): the title says "Unweighted" although `w_o` weights are applied;
# presumably it means "without the `w_r` reweighting" — confirm.
histogram2d(x, y; weights = w_o, bins = 300, color = :balance)
plot!(title = "Unweighted MCMC Samples", xlabel = "x", ylabel = "y")
# To overlay the partition boundaries, uncomment:
# plot_tree(output, [1,2], linecolor=:red, linealpha=0.3)