In [None]:
using ValueShapes
using ArraysOfArrays
using StatsBase 
using LinearAlgebra
using Statistics

using Distributions 
using IntervalSets

using Plots
using Colors
using ColorSchemes
using LaTeXStrings

using BAT

# Plot defaults: PyPlot backend with a 750×500 canvas; `tab20b` is kept as a colour palette
# (presumably for line colours in later plots).
pyplot(; size = (750, 500))
line_colors = ColorSchemes.tab20b;

In [None]:
# Gaussian shell density: problem setup.

# Dimension of the parameter space.
N = 2
# Flat prior range shared by every coordinate.
min_v = -25
max_v = 25

# Log-volume of the prior box [min_v, max_v]^N.
lgV = log(max_v - min_v) * N;

# Shell centre λ, shell radius r and radial width σ.
true_param = (λ = zeros(N), r = 5, σ = 2)

"""
    fun(x; true_param=true_param)

Evaluate the Gaussian-shell density at point `x`.

The value is a 1-D normal density in the Euclidean distance between `x` and the centre
`true_param.λ`, with mean `true_param.r` and standard deviation `true_param.σ`.
The `1/√(2πσ²)` factor normalizes that radial profile only, not the density over ℝ^N.
"""
function fun(x; true_param=true_param)
    σ² = true_param.σ^2
    dist = sqrt(sum(abs2, true_param.λ .- x))
    return exp(-(dist - true_param.r)^2 / (2 * σ²)) / sqrt(2 * pi * σ²)
end

"""
    LogTrueIntegral(N; true_param=true_param, rmax=30)

Reference value for the integral of the Gaussian-shell density `fun` over ℝ^N.

Note: despite its name this returns the integral itself, **not** its logarithm; take
`log` of the result before comparing it with log-evidence estimates.

The density depends on `x` only through the distance from the centre, so the
N-dimensional integral reduces to

    (2π^(N/2) / Γ(N/2)) · (1/√(2πσ²)) · ∫₀^rmax ρ^(N-1) exp(-(ρ - r)² / (2σ²)) dρ,

i.e. the unit-(N-1)-sphere surface area times the radial integral, with the constant
simplifying to `√2·π^((N-1)/2) / (Γ(N/2)·σ)`. The radial integral is truncated at `rmax`
(default 30, the previously hard-coded cutoff); raise it for large `r`, `σ` or `N`.

NOTE(review): `gamma` (SpecialFunctions) and `hcubature` (HCubature) are not loaded by the
`using` lines at the top of this notebook — confirm they are in scope before calling.
"""
function LogTrueIntegral(N; true_param=true_param, rmax=30)
    r, σ = true_param.r, true_param.σ
    # hcubature integrates over a 1-D box, so the radius arrives as x[1].
    radial(x) = x[1]^(N - 1) * exp(-(x[1] - r)^2 / (2 * σ^2))
    prefactor = sqrt(2) * pi^((N - 1) / 2) / (gamma(N / 2) * σ)
    # hcubature returns (integral, error estimate); only the integral is used.
    return prefactor * hcubature(radial, [0], [rmax])[1]
end

In [None]:
# MCMC sampler configuration passed to `bat_sample` below.
algorithm = MetropolisHastings()

# Adaptive proposal tuning; `..` builds IntervalSets closed intervals.
tuning = AdaptiveMetropolisTuning(; λ = 0.5, α = 0.15 .. 0.35, β = 1.5, c = 1e-4 .. 1e2)

convergence = BrooksGelmanConvergence(; threshold = 1.1, corrected = false)

init = MCMCInitStrategy(;
    init_tries_per_chain = 8 .. 128,
    max_nsamples_init = 250,
    max_nsteps_init = 250,
    max_time_init = 180
)

burnin = MCMCBurninStrategy(;
    max_nsamples_per_cycle = 1000,
    max_nsteps_per_cycle = 10000,
    max_time_per_cycle = 250,
    max_ncycles = 200
)

# Hand-picked AHMI integration settings, used as `settings` in the integration cells below.
# `BAT.HMISettings` is called positionally, so the argument order must match its field order.
# NOTE(review): the per-argument labels are guesses — they are not shown in this notebook;
# confirm against the `BAT.HMISettings` definition.
HMI_Manual_Settings = BAT.HMISettings(
    BAT.cholesky_partial_whitening!,  # presumably the whitening transform (cf. hm_whiteningtransformation!)
    10000,                            # presumably max. number of starting IDs (seed samples)
    2.5,                              # presumably max. starting-ID fraction
    0.1,                              # presumably hyper-rectangle growth step
    true,                             # presumably "use all points"
    16,                               # presumably min. starting IDs before a warning
    true,                             # presumably "do trimming"
    # Name => estimator map; the integration cell reads it back as
    # `settings.uncertainty_estimators` and calls each estimator on `result`.
    Dict("cov. weighted result" => BAT.hm_combineresults_covweighted!),
)

# Log-likelihood of the Gaussian shell, evaluated directly in log space. It is the log of
# the same density as `fun` (keep the two formulas in sync). `log(fun(x))` would underflow
# to -Inf once (|x - λ| - r)^2 / (2σ^2) exceeds ~745 (Float64 `exp` underflow), which the
# box prior allows for larger N or wider boxes.
log_likelihood = params -> begin
    dist = sqrt(sum(abs2, true_param.λ .- params.a))
    logdens = -(dist - true_param.r)^2 / (2 * true_param.σ^2) - log(2 * pi * true_param.σ^2) / 2
    return LogDVal(logdens)
end

# Flat prior: each of the N coordinates of `a` ranges over min_v .. max_v.
prior = NamedTupleDist(a = fill(min_v .. max_v, N))

posterior = PosteriorDensity(log_likelihood, prior);

In [None]:
# Sampling run: 10 chains, nsamples_ = 10^4, max_time = 150 (presumably seconds).
nchains_ = 10
nsamples_ = 10_000
max_time = 150

# Step budget: ten times the requested number of samples.
max_nsteps = nsamples_ * 10

@time samples, chains = bat_sample(
    posterior, (nsamples_, nchains_), algorithm;
    max_nsteps = max_nsteps, max_time = max_time,
    tuning = tuning, init = init, burnin = burnin, convergence = convergence,
    strict = false, filter = true,
);

# Test: step through the AHMI integration by hand in the cells below.

In [None]:
# Build the AHMI data container from the samples; `unshaped` presumably flattens each
# sample's named-tuple parameters into a plain vector.
result = BAT.HMIData(broadcast(unshaped, samples))
settings = HMI_Manual_Settings

In [None]:
# Run the AHMI integration pipeline on `result` one step at a time.
BAT.hm_init(result, settings)
BAT.hm_whiteningtransformation!(result, settings)
BAT.hm_createpartitioningtree!(result)

# Kept as a global so it can be inspected after the cell runs.
notsinglemode = BAT.hm_findseeds!(result, settings)

BAT.hm_determinetolerance!(result, settings)  # author's note: tolerance is never Inf (unverified here)
BAT.hm_create_integrationvolumes!(result, settings)
BAT.hm_integrate_integrationvolumes!(result, settings)

# Run every configured uncertainty estimator and store its output under its name.
for (estimator_name, estimator) in settings.uncertainty_estimators
    @info "Estimating Uncertainty ($(estimator_name))"
    result.integralestimates[estimator_name] = estimator(result)
end

In [None]:
# Panel (a): the samples alone, without seeds or integration rectangles.
p_4 = plot(result, 1, 2;
    rscale = 0.7,
    plot_seedsamples = false,
    plot_seedcubes = false,
    plot_samples = true,
    plot_acceptedrects = false,
    plot_rejectedrects = false,
    plot_datasets = 1,
    font_scale = 1,
    markercolor = :gray,
    markersize = 1,
)

# Square 600×600 view of (-3, 3) × (-3, 3); grid, legend and axis decorations disabled.
p_4 = plot!(; xlim = (-3, 3), ylim = (-3, 3), frame = true, size = (600, 600),
    grid = false, legend = false, xaxis = nothing, yaxis = nothing)

# savefig(p_4, "../../AHMI_publication/GaussShellDistributionData/ahmi_example-a.png")

In [None]:
# Panel (b): the space-partitioning tree's cuts drawn over the samples.

# Named `ncuts` rather than `N` so the problem dimension `N` defined at the top of the
# notebook is not overwritten (re-running the prior cell would otherwise build a prior of
# the wrong dimension).
ncuts = result.dataset1.partitioningtree.cuts
# `copy` suffices: `deleteat!` below mutates only the vector, never its elements.
data_tree = copy(result.dataset1.partitioningtree.cutlist)

# The cut list is read in groups of `ncuts + 1` entries. The first entry of each group is
# taken as an x-cut (skipping the very first group's), and the remaining `ncuts` entries of
# group i become column i of `vcuts`, i.e. the y-cuts inside x-slab i.
hcuts = data_tree[1:ncuts+1:end][2:end]
vcuts = reshape(deleteat!(data_tree, 1:ncuts+1:length(data_tree)), ncuts, ncuts);

p_1 = plot(result, 1, 2,
        rscale = 0.7,
        plot_seedsamples = false,
        plot_seedcubes = false,
        plot_samples = true,
        plot_acceptedrects = false,
        plot_rejectedrects = false,
        plot_datasets = 1,
        font_scale = 1,
        markercolor=:gray,
        markersize=1,
)

# Vertical lines at the x-cuts.
p_1 = vline!([hcuts], legend=false, linecolor=:red, lw=1.0)

# Horizontal y-cut segments, each spanning its own x-slab. ±10 stands in for the outer slab
# edges; the view is clipped to (-3, 3) below.
for i in 1:ncuts
    left = i == 1 ? -10 : hcuts[i-1]
    right = i <= length(hcuts) ? hcuts[i] : 10

    # Row 1 is skipped (presumably the slab's lower boundary rather than a cut — confirm).
    for j in 2:ncuts
        plot!([left, right], [vcuts[j, i], vcuts[j, i]], seriestype=:path, linecolor=:red, lw=1.0)
    end
end

p_1 = plot!(xlim=(-3, 3), ylim=(-3, 3), frame=true, size=(600,600), grid=false, xaxis=nothing, yaxis=nothing, legend=false)

# savefig(p_1, "../../AHMI_publication/GaussShellDistributionData/ahmi_example-b.png")

In [None]:
# Panel (c): seed samples and the initial hyper-rectangle built around each of them.

# Seed points (columns of the data at the starting IDs) and their initial hypercubes.
# `create_initialhypercube` returns (cube, volume); only the cube is kept.
modes = [result.dataset1.data[:, id] for id in result.dataset1.startingIDs]
cubes = [
    first(BAT.create_initialhypercube(
        mode, result.dataset1, result.whiteningresult.targetprobfactor))
    for mode in modes
]

# Each outline from `BAT.create_rectangle` fills 6 consecutive buffer entries. The slot
# range is derived from the loop index instead of a global counter: mutating a global inside
# a top-level loop is an ambiguous soft-scope assignment that fails outside IJulia.
cubes_x = zeros(Float64, length(cubes) * 6)
cubes_y = zeros(Float64, length(cubes) * 6)
for (k, cube) in enumerate(cubes)
    slots = (6k - 5):(6k)
    cubes_x[slots], cubes_y[slots] = BAT.create_rectangle(cube, 1, 2)
end

p_4 = plot(result, 1, 2,
        rscale = 0.7,
        plot_seedsamples = true,
        plot_seedcubes = false,
        plot_samples = true,
        plot_acceptedrects = false,
        plot_rejectedrects = false,
        plot_datasets = 1,
        font_scale = 1,
        markercolor=:gray,
        markersize=1,
)

# Drop the final buffer entry — presumably the separator after the last rectangle
# (confirm against BAT.create_rectangle).
p_4 = plot!(cubes_x[1:end-1], cubes_y[1:end-1], seriestype=:path, color=:red)

p_4 = plot!(xlim=(-3, 3), ylim=(-3, 3), frame=true, size=(600,600), grid=false, legend=false, xaxis=nothing, yaxis=nothing)

# savefig(p_4, "../../AHMI_publication/GaussShellDistributionData/ahmi_example-c.png")

In [None]:
# Panel (d): samples with both accepted and rejected integration rectangles.
p_2 = plot(result, 1, 2;
    rscale = 0.7,
    plot_seedsamples = false,
    plot_seedcubes = false,
    plot_samples = true,
    plot_acceptedrects = true,
    plot_rejectedrects = true,
    plot_datasets = 1,
    font_scale = 1,
    markercolor = :gray,
    markersize = 1,
)

# Square 600×600 view of (-3, 3) × (-3, 3); grid, legend and axis decorations disabled.
p_2 = plot!(; xlim = (-3, 3), ylim = (-3, 3), frame = true, size = (600, 600),
    grid = false, legend = false, xaxis = nothing, yaxis = nothing)

# savefig(p_2, "../../AHMI_publication/GaussShellDistributionData/ahmi_example-d.png")