# Sampling with Space Partitioning

## Tutorial Notebook 

    Structure: 
    
    1. BAT model definition: Gaussian mixture
    2. Single-click run: Default settings  
    3. Detailed sampling summary
       -- @elapsed/@CPUelapsed time 
       -- Posterior integrals
       -- Plotting
    4. Fine-grained control: Samplers, integrators, partitioner 
       -- HMC/Sobol


In [None]:
using IntervalSets
using Random, LinearAlgebra, Statistics, Distributions, StatsBase, ArraysOfArrays
using JLD2
using ValueShapes
using TypedTables
using Measurements
using HypothesisTests

In [None]:
using Plots
pyplot()

In [None]:
using BAT

## 1. BAT model definition: Gaussian mixture

Let us use a mixture of four 3-dimensional normal distributions with random covariance matrices as the model:

In [None]:
σ_1 = [3.426818298733095 12.378238116671048 -9.632531611142454; 3.426818298733095 4.916266580684483 -5.029942800849483; 3.426818298733095 -6.737309268887753 5.4343957706004415; 3.426818298733095 -3.9729587574454333 3.379361860370276]
σ_2 = [12.378238116671048 69.83909693165143 -43.478993858310886; 4.916266580684482 69.83909693165143 -64.18784570966332; -6.737309268887753 69.83909693165143 32.734776615550174; -3.972958757445433 69.83909693165143 53.74871853095418]
σ_3 = [-9.632531611142456 -43.478993858310886 60.0626256206892; -5.029942800849484 -64.18784570966332 60.0626256206892; 5.4343957706004415 32.734776615550174 60.0626256206892; 3.3793618603702757 53.74871853095418 60.0626256206892]
σ = cat(σ_1,σ_2,σ_3, dims=3)

μ = [8.959570984309234 -9.021529871694005 -5.007789383392622; -1.446445514344754 -7.8327010768703875 -3.2653263028963986; -6.40954093270941 0.25815094665222027 5.830143596540282; 5.076504517881521 -0.8952973253675331 9.16356325348496]

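# σ[i,:,:] is the 3×3 covariance matrix of component i; Hermitian(...) enforces
# exact symmetry, since the off-diagonal entries differ in their last digits.
mixture_model = MixtureModel(MvNormal[MvNormal(μ[i,:], Matrix(Hermitian(σ[i,:,:])) ) for i in 1:4]);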
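
The array layout used above is easy to misread, so here is a minimal sketch of the same construction on a toy input. The matrices `a_1`, `a_2` and all other names are hypothetical and only illustrate how `cat(..., dims=3)` and the slice `[i, :, :]` produce one covariance matrix per component; the sketch uses only the standard library.

```julia
using LinearAlgebra

# Hypothetical toy input: two 2-d components; a_k[i, :] is column k of
# the covariance matrix of component i.
a_1 = [2.0 0.5; 1.0 -0.3]
a_2 = [0.5 1.0; -0.3 3.0]
stack3 = cat(a_1, a_2, dims=3)   # size (2, 2, 2): component × row × column

cov_1 = Matrix(Hermitian(stack3[1, :, :]))   # [2.0 0.5; 0.5 1.0]
cov_2 = Matrix(Hermitian(stack3[2, :, :]))   # [1.0 -0.3; -0.3 3.0]

# A multivariate normal needs symmetric positive definite covariances.
all_valid = all(isposdef, (cov_1, cov_2))
```

Running the same `isposdef` check on each `σ[i,:,:]` before building the mixture is a cheap way to catch a malformed covariance early.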


In [None]:
prior = NamedTupleDist(a = [Uniform(-50,50), Uniform(-50,50), Uniform(-50,50)])

likelihood = let model = mixture_model
    params -> LogDVal(logpdf(model, params.a))
end

posterior = PosteriorDensity(likelihood, prior);

log_volume = BAT.log_volume(BAT.spatialvolume(posterior.parbounds))
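
For a rectangular prior the log-volume can also be computed by hand, which gives a reference value for the call above. The sketch below assumes that `BAT.log_volume` returns the natural logarithm of the volume of the prior box; the names `lower`, `upper` and `box_log_volume` are illustrative.

```julia
# Assumed bounds, copied from the three Uniform(-50, 50) priors above.
lower = [-50.0, -50.0, -50.0]
upper = [50.0, 50.0, 50.0]

# The log-volume of a box is the sum of the logs of its side lengths: log(100^3).
box_log_volume = sum(log.(upper .- lower))
```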

## 2. Single-click run: Default settings 

Default parameters of the `PartitionedSampling()` algorithm: 

1) Exploration samples: `MetropolisHastings()` sampler (20 chains * 10^2 samples)

2) Sampling: `MetropolisHastings()` sampler

3) Space partitioning: `KDTreePartitioning()`

4) Reweighting: `AHMIntegration()`
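
The reweighting step is what allows independently sampled subspaces to be stitched back together. The following is a minimal sketch of that idea on hypothetical 1-d data, not BAT's implementation: each subspace contributes its samples with a weight proportional to the estimated integral of the density over that subspace, divided by the number of samples drawn there.

```julia
# Hypothetical toy input: samples drawn independently inside two subspaces,
# plus an estimate of the integral of the density over each subspace.
subspace_samples   = [[-2.1, -1.9, -2.0, -2.2], [3.0, 3.1]]
subspace_integrals = [0.25, 0.75]

function combine_subspaces(samples, integrals)
    total = sum(integrals)
    xs = Float64[]
    ws = Float64[]
    for (s, mass) in zip(samples, integrals)
        append!(xs, s)
        # All samples of a subspace share that subspace's probability mass equally.
        append!(ws, fill(mass / total / length(s), length(s)))
    end
    return xs, ws
end

xs, ws = combine_subspaces(subspace_samples, subspace_integrals)
weighted_mean = sum(xs .* ws)   # 0.25 * (-2.05) + 0.75 * 3.05
```

The weights sum to one, and the second subspace dominates the mean even though it holds fewer samples, because it carries three quarters of the probability mass.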


In [None]:
n_chains = 5 # chains per subspace 
n_samples = 10^4 # samples per subspace 
n_subspaces = 5;

To generate samples with the default settings:

In [None]:
algorithm_1 = PartitionedSampling()

output_sp = bat_sample(posterior, (n_samples, n_chains, n_subspaces), algorithm_1)

samples_1 = output_sp.result;

For comparison, let us also generate samples with the plain `MetropolisHastings()` algorithm:

In [None]:
algorithm_2 = MetropolisHastings()

output_mcmc = bat_sample(posterior, (n_samples*n_subspaces, n_chains), algorithm_2)

samples_2 = output_mcmc.result;

Standard BAT statistics can be computed directly on each resulting `DensitySampleVector`:

In [None]:
println("Mode (partitioned sampling): $(mode(samples_1))")
println("Mode (Metropolis-Hastings): $(mode(samples_2))")

println("Mean (partitioned sampling): $(mean(samples_1))")
println("Mean (Metropolis-Hastings): $(mean(samples_2))")

Standard BAT plotting recipes: 

In [None]:
plot(samples_1, size=(700,700), upper=Dict("mean"=>false, "globalmode"=>false, "localmode"=>false))

In [None]:
plot(samples_2, size=(700,700), upper=Dict("mean"=>false, "globalmode"=>false, "localmode"=>false))

## 3. Detailed sampling summary

Run information:

In [None]:
columnnames(output_sp.info)

In [None]:
output_sp.info

In [None]:
posterior_integral = sum(output_sp.info.density_integral)
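
If the per-subspace estimates carry uncertainties (the notebook loads `Measurements`, whose `±` values propagate errors automatically), the uncertainty of the sum follows from adding independent errors in quadrature. The sketch below spells this out with plain numbers; the values are hypothetical and independence between subspaces is an assumption.

```julia
# Hypothetical per-subspace integral estimates and their 1-sigma uncertainties.
integrals = [2.0e-7, 5.0e-7, 3.0e-7]
errors    = [1.0e-8, 2.0e-8, 2.0e-8]

total = sum(integrals)
# Independent errors add in quadrature.
total_error = sqrt(sum(abs2, errors))
```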

In [None]:
log(posterior_integral)
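
As a rough cross-check on this value, the expected result can be worked out by hand, assuming that `density_integral` estimates the integral of likelihood times prior over each subspace. The likelihood is a normalized mixture density whose mass lies almost entirely inside the prior box, and the prior is uniform over a box of volume $V = 100^3$:

```latex
Z = \int_{\text{box}} p_{\text{mix}}(\theta) \, \frac{1}{V} \, d\theta
  \approx \frac{1}{V},
\qquad
\log Z \approx -\log V = -3 \log 100 \approx -13.82
```

Under these assumptions, `log(posterior_integral)` should come out close to `-log_volume`.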

In [None]:
total_cpu_time = sum(output_sp.info.sampling_cpu_time) + sum(output_sp.info.integration_cpu_time)

In [None]:
total_wc_time = (output_sp.info.sampling_wc[end][end] - output_sp.info.sampling_wc[1][1])*1e-9
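
The expression above takes the end of the last entry minus the start of the first, which equals the total wall-clock time only if the first subspace started first and the last one finished last. A more defensive variant is sketched below on hypothetical data; it assumes that each entry of `sampling_wc` holds a start and a stop timestamp in nanoseconds, which is inferred from the `1e-9` factor rather than from documentation.

```julia
# Hypothetical (start, stop) timestamps in nanoseconds, one pair per subspace.
# The subspaces overlap in time, as they would when sampled in parallel.
sampling_wc = [(1_000_000_000, 4_000_000_000),
               (1_500_000_000, 6_000_000_000),
               (2_000_000_000, 5_000_000_000)]

t_start = minimum(first.(sampling_wc))   # earliest start over all subspaces
t_stop  = maximum(last.(sampling_wc))    # latest stop over all subspaces
total_wc_seconds = (t_stop - t_start) * 1e-9

# For comparison: last entry's stop minus first entry's start.
naive_wc_seconds = (sampling_wc[end][end] - sampling_wc[1][1]) * 1e-9
```

In this toy example the second subspace finishes last, so the naive estimate is about one second too short.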

Partition tree: 

In [None]:
plot(samples_1, size=(700,700), upper=Dict("partition_tree"=>output_sp.part_tree, "mean"=>false, "globalmode"=>false, "localmode"=>false))

Exploration samples: 

In [None]:
n_exploration_samples = length(output_sp.exp_samples)

In [None]:
plot(output_sp.exp_samples, size=(700,700), 
    upper = Dict("seriestype"=>:scatter, "colors"=>:red, "partition_tree"=>output_sp.part_tree,), 
    lower = Dict("seriestype"=>:hist, "bins"=>30),
    diagonal = Dict("bins"=>30),)

## 4. Fine-grained control: Samplers, integrators, partitioner 

The subspace sampler can be any `BAT.AbstractSamplingAlgorithm`:

In [None]:
sampler = MetropolisHastings() # AHMC()

burnin_1 = MCMCBurninStrategy(
        max_nsamples_per_cycle = 400,
        max_nsteps_per_cycle = 400,
        max_time_per_cycle = 25,
        max_ncycles = 20
    )

sampling_kwargs = (burnin = burnin_1,);

The exploration sampler is configured in the same way:

In [None]:
# Sobol Sampler: 
# exploration_sampler = BAT.SobolSampler()
# exploration_kwargs = NamedTuple()
# n_exploration = (10^2, 40);

# MetropolisHastings: 
exploration_sampler = MetropolisHastings()
burnin_2 = MCMCBurninStrategy(
        max_nsamples_per_cycle = 400,
        max_nsteps_per_cycle = 400,
        max_time_per_cycle = 25,
        max_ncycles = 5
    )
exploration_kwargs = (burnin = burnin_2,)
n_exploration = (10^2, 40);

Space partitioning is done with the `BAT.KDTreePartitioning` algorithm:

In [None]:
partitioner = KDTreePartitioning(
        partition_dims = [1,2,], # indices of the dimensions along which the space may be partitioned
        extend_bounds = false # "false" is appropriate for debugging and for very fast tuning/convergence
    );
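
To make the role of `partition_dims` concrete, here is a minimal sketch of a k-d-tree-style partition built from exploration samples. It is a simplified stand-in, not BAT's `KDTreePartitioning`: it always splits the most populated box at the sample median along the allowed dimension with the largest spread, whereas the real algorithm may choose its cuts differently. All function and variable names are hypothetical.

```julia
using Random, Statistics

# Split one box at the median of its points along the allowed dimension
# with the largest sample spread.
function split_box(box, points, dims)
    spreads = [maximum(p[d] for p in points) - minimum(p[d] for p in points) for d in dims]
    d = dims[argmax(spreads)]
    cut = median([p[d] for p in points])
    left_box = copy(box);  left_box[d] = (box[d][1], cut)
    right_box = copy(box); right_box[d] = (cut, box[d][2])
    left_pts = filter(p -> p[d] <= cut, points)
    right_pts = filter(p -> p[d] > cut, points)
    return (left_box, left_pts), (right_box, right_pts)
end

# Keep splitting the most populated leaf until n_subspaces leaves exist.
function partition(box, points, n_subspaces; dims = collect(1:length(box)))
    leaves = [(box, points)]
    while length(leaves) < n_subspaces
        i = argmax([length(pts) for (_, pts) in leaves])
        a, b = split_box(leaves[i]..., dims)
        splice!(leaves, i, [a, b])
    end
    return leaves
end

rng = MersenneTwister(1)
points = [100 .* rand(rng, 3) .- 50 for _ in 1:200]   # hypothetical exploration samples
box = [(-50.0, 50.0) for _ in 1:3]
leaves = partition(box, points, 5; dims = [1, 2])     # never cut along dimension 3
```

Each leaf is a pair of a box and the exploration samples inside it. The boxes tile the original volume, and with `dims = [1, 2]` every box keeps the full range in the third dimension.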

To integrate subspaces, any `BAT.IntegrationAlgorithm` can be used: 

In [None]:
integrator = AHMIntegration(
        whitening = CholeskyPartialWhitening(),
        autocorlen = GeyerAutocorLen(),
        volumetype = :HyperRectangle,
        max_startingIDs = 10000,
        max_startingIDs_fraction = 2.5,
        rect_increase = 0.1,
        warning_minstartingids = 16,
        dotrimming = true,
        uncertainty = [:cov]
    );

Finally, `PartitionedSampling` can be constructed from these settings:

In [None]:
algorithm = PartitionedSampling(
        sampler = sampler,
        exploration_sampler = exploration_sampler,
        partitioner = partitioner,
        integrator = integrator,
        exploration_kwargs = exploration_kwargs,
        sampling_kwargs = sampling_kwargs,
        n_exploration = n_exploration
    );

In [None]:
output_sp_ms = bat_sample(posterior, (n_samples, n_chains, 40), algorithm);

samples_3 = output_sp_ms.result;

In [None]:
plot(samples_3, vsel=[1,2,3], size=(700,700), globalmode=true, localmode=true, 
    upper=Dict("partition_tree"=>output_sp_ms.part_tree, 
        "mean"=>false, "globalmode"=>false, "localmode"=>false))

In [None]:
posterior_integral = sum(output_sp_ms.info.density_integral)

In [None]:
plot(output_sp_ms.exp_samples, size=(700,700), 
    upper = Dict("seriestype"=>:scatter, "partition_tree"=>output_sp_ms.part_tree,), 
    lower = Dict("seriestype"=>:hist, "bins"=>30),
    diagonal = Dict("bins"=>30),)