# Partitioned Sampling Tests

In [None]:
# using Distributed 
# addprocs(3)

In [None]:
using IntervalSets
using Distributions
using Random, LinearAlgebra, Statistics, Distributions, StatsBase, ArraysOfArrays
using JLD2
using ValueShapes
using TypedTables
using Measurements
using TypedTables
using HypothesisTests

In [None]:
# using PyPlot
using Plots
# Select the PyPlot backend for Plots.jl.
pyplot()

In [None]:
using Revise
using BAT

## 1. BAT model definition: Gaussian mixture

Let us use a mixture of 4 normal distributions with random covariance matrices as the model:

In [None]:
# Load the mixture parameters from disk: `means[i, :]` is the mean vector and
# `cov_m[i, :, :]` the covariance matrix of component `i`; `n_clusters` is stored
# alongside them.
JLD2.@load "../data/mixture-5D.jld" means cov_m n_clusters

# Number of mixture components used to build the model. Only the first component
# is used here; set this to `n_clusters` to use the full mixture from the file.
n_components = 1
1 <= n_components <= size(means, 1) ||
    throw(ArgumentError("n_components must be in 1:$(size(means, 1)), got $n_components"))

# `Matrix(Hermitian(...))` rebuilds each covariance from its upper triangle so it is
# exactly symmetric (presumably so that `MvNormal` accepts it).
mixture_model = MixtureModel(MvNormal[
    MvNormal(means[i, :], Matrix(Hermitian(cov_m[i, :, :]))) for i in 1:n_components
]);

In [None]:
# Independent Normal(0, 20) prior on every coordinate of `a`; there is one
# coordinate per column of `means`.
prior = NamedTupleDist(a = fill(Normal(0, 20), size(means, 2)))

# Log-likelihood: log-density of the mixture model at `params.a`. The `let` binds
# the model to a local so the closure does not read a non-constant global.
# Alternative return values kept for debugging:
#   LogDVal(sum([logpdf(Normal(), params.a[i]) for i in 1:length(params.a)]))
#   LogDVal(0.0)
likelihood = let mix = mixture_model
    function (params)
        logd = logpdf(mix, params.a)
        return LogDVal(logd)
    end
end

posterior = PosteriorDensity(likelihood, prior);
log_volume = posterior.parbounds |> BAT.spatialvolume |> BAT.log_volume

## 2. Sampling: 

In [None]:
# Sampler used for the per-subspace runs (AHMC() is a possible alternative).
sampler = MetropolisHastings()

# Burn-in limits for the per-subspace runs.
burnin_1 = MCMCBurninStrategy(
    max_ncycles = 30,
    max_time_per_cycle = 25,
    max_nsteps_per_cycle = 5000,
    max_nsamples_per_cycle = 5000,
)

# Adaptive Metropolis proposal tuning; α and c are closed intervals.
tuning = AdaptiveMetropolisTuning(
    r = 0.5,
    λ = 0.5,
    β = 1.5,
    α = ClosedInterval(0.05, 0.15),
    c = ClosedInterval(1e-4, 1e2),
)

sampling_kwargs = (burnin = burnin_1, tuning = tuning);

The same settings for the exploration sampler:

In [None]:
# Exploration run settings, passed to `PartitionedSampling` as
# `exploration_sampler` / `exploration_kwargs` / `n_exploration`.
exploration_sampler = MetropolisHastings()

burnin_2 = MCMCBurninStrategy(
    max_ncycles = 20,
    max_time_per_cycle = 25,
    max_nsteps_per_cycle = 6000,
    max_nsamples_per_cycle = 6000,
)

exploration_kwargs = (burnin = burnin_2,)
# NOTE(review): presumably (samples, chains) for the exploration run — confirm.
n_exploration = (100, 40);

Space partitioning can be done using the `BAT.KDTreePartitioning` algorithm:

In [None]:
# k-d-tree partitioner; only dimension 1 is considered for splitting.
# Setting extend_bounds to false is handy for debugging (very fast tuning/convergence).
partitioner = KDTreePartitioning(
    extend_bounds = true,
    partition_dims = [1],
);

To integrate subspaces, any `BAT.IntegrationAlgorithm` can be used: 

In [None]:
# AHMI integrator applied to each subspace.
integrator = AHMIntegration(
    whitening = CholeskyPartialWhitening(),
    autocorlen = GeyerAutocorLen(),
    volumetype = :HyperRectangle,
    uncertainty = [:cov],
    dotrimming = true,
    rect_increase = 0.1,
    max_startingIDs = 10_000,
    max_startingIDs_fraction = 2.5,
    warning_minstartingids = 16,
);

In [None]:
# Assemble the partitioned-sampling algorithm from the components defined above.
algorithm = PartitionedSampling(
    partitioner = partitioner,
    integrator = integrator,
    sampler = sampler,
    sampling_kwargs = sampling_kwargs,
    exploration_sampler = exploration_sampler,
    exploration_kwargs = exploration_kwargs,
    n_exploration = n_exploration,
);

In [None]:
n_samples = 10_000  # samples per subspace
n_chains = 10       # chains per subspace
n_subspaces = 3

# Run partitioned sampling and keep the merged samples.
output_sp_ms = bat_sample(posterior, (n_samples, n_chains, n_subspaces), algorithm);
samples_3 = output_sp_ms.result;

In [None]:
# Negative log of the summed per-subspace integrals.
# NOTE(review): the sign suggests this is meant for comparison with `log_volume` — confirm.
posterior_integral = -(output_sp_ms.info.density_integral |> sum |> log)

In [None]:
# Bounds of every subspace of the partition tree.
flat_bounds = output_sp_ms.part_tree |> BAT.get_tree_par_bounds

In [None]:
# Compare the log of the summed volumes of all subspace rectangles (each `rec_bound`
# holds lower bounds in column 1 and upper bounds in column 2) with `log_volume`.
# NOTE: `flat_bounds` must still be the vector of all subspace bounds; the
# "Resample Subspace" cell reassigns it to a single subspace's bounds.
@show log(sum([prod(rec_bound[:,2] .- rec_bound[:,1]) for rec_bound in flat_bounds]))
@show log_volume;

In [None]:
# Plot the first three parameters; the partition tree is drawn through the `upper`
# options, where mean and mode markers are switched off.
plot(
    samples_3,
    vsel = [1, 2, 3],
    size = (700, 500),
    globalmode = true,
    localmode = true,
    upper = Dict(
        "partition_tree" => output_sp_ms.part_tree,
        "mean" => false,
        "globalmode" => false,
        "localmode" => false,
    ),
)

## Test a Single Subspace

In [None]:
# Extract the samples and bookkeeping of subspace `subs_ind` from the
# partitioned-sampling output.
subs_ind = 2
smpl_ind = output_sp_ms.info.samples_ind[subs_ind]
smpl_tot_weight = output_sp_ms.info.sum_weights[subs_ind]
smpl_int = output_sp_ms.info.density_integral[subs_ind].val
smpl_trunc = samples_3[smpl_ind]

# Rebuild the subspace samples with weights rescaled by sum_weights / integral and
# rounded to integers.
# NOTE(review): weights that round to 0 effectively drop those samples.
samples_tmp = DensitySampleVector((
    smpl_trunc.v,
    smpl_trunc.logd,
    @.(round(Integer, smpl_tot_weight * smpl_trunc.weight / smpl_int)),
    smpl_trunc.info,
    smpl_trunc.aux,
));

# Re-integrate the rebuilt subspace samples with the same integrator.
integral_val, hmi_data = bat_integrate(samples_tmp, integrator)

# Compare the subspace integral reported by partitioned sampling with the one
# recomputed from the rebuilt samples.
@show smpl_int, integral_val

In [None]:
# Recomputed subspace integral scaled by exp(log_volume).
exp(log_volume + log(integral_val))

In [None]:
# Plot the rebuilt subspace samples together with the partition tree.
plot(
    samples_tmp,
    upper = Dict(
        "partition_tree" => output_sp_ms.part_tree,
        "mean" => false,
        "globalmode" => false,
        "localmode" => false,
    ),
)

In [None]:
# plot(hmi_data, dim1 = 1, dim2 = 3, size=(900,450), plot_seedcubes=false, plot_rejectedrects = false, plot_acceptedrects = true, legend=false)

## Resample the Subspace with IID Samples

In [None]:
# Bounds of subspace `subs_ind` only. NOTE: this overwrites the global
# `flat_bounds` (previously the vector of all subspace bounds).
flat_bounds = BAT.get_tree_par_bounds(output_sp_ms.part_tree)[subs_ind]

# Subspace integral kept without `.val` this time (unlike the single-subspace cell).
smpl_int = output_sp_ms.info.density_integral[subs_ind]

# Independent samples drawn directly from the mixture model.
iid_samples = bat_sample(NamedTupleDist(a = mixture_model), 1_000_000).result;

In [None]:
# Flag the iid samples that lie inside the rectangle of subspace `subs_ind`.
# `flat_bounds` holds lower bounds in column 1 and upper bounds in column 2, one row
# per dimension. `s.v[1]` is the parameter vector `a` of a sample
# (NOTE(review): assumes the samples carry the shape `(a = ...,)`). Each coordinate
# is checked against its own dimension's bounds; indexing `s.v[1][1]` would compare
# only the first coordinate against the bounds of every dimension.
mask_iid = [all(flat_bounds[:, 1] .< s.v[1] .< flat_bounds[:, 2]) for s in iid_samples]

@show sum(mask_iid)

In [None]:
# Integrate the iid samples that fall inside the subspace with the same AHMI integrator.
integral_val_iid, hmi_data_iid = bat_integrate(iid_samples[mask_iid], integrator);

In [None]:
# Display the iid-based estimate of the subspace integral.
integral_val_iid

In [None]:
# Partitioned-sampling subspace integral scaled by exp(log_volume).
exp(log_volume + log(smpl_int))

In [None]:
# Plot the iid samples inside the subspace together with the partition tree.
plot(
    iid_samples[mask_iid],
    upper = Dict(
        "partition_tree" => output_sp_ms.part_tree,
        "mean" => false,
        "globalmode" => false,
        "localmode" => false,
    ),
)

## MCMC samples: 

In [None]:
# Reference run: plain Metropolis-Hastings on the full posterior
# (NOTE(review): presumably (samples per chain, chains) = (10^5, 5)).
output_mcmc = bat_sample(posterior, (100_000, 5), MetropolisHastings()).result;

In [None]:
# Plot all five parameters of the reference MCMC run.
plot(
    output_mcmc,
    vsel = [1, 2, 3, 4, 5],
    size = (700, 700),
    globalmode = true,
    localmode = true,
    upper = Dict("mean" => false, "globalmode" => false, "localmode" => false),
)

In [None]:
# Log of the integral estimated from the reference MCMC samples.
bat_integrate(output_mcmc).result |> log

In [None]:
# Draw 10^4 samples from the posterior with `RandSampling()`.
bat_sample(posterior, 10_000, RandSampling())