In [None]:
using Distributions 
using IntervalSets
using ValueShapes
using ArraysOfArrays
using StatsBase 
using LinearAlgebra
using Random123
using HCubature
using LaTeXStrings

In [None]:
import Plots
import PyPlot
# Select the PyPlot backend for Plots and keep a short alias for direct
# matplotlib calls.
Plots.pyplot()

plt = PyPlot

# Font sizes shared by all figures in this notebook.
SMALL_SIZE = 10
MEDIUM_SIZE = 11
BIGGER_SIZE = 12

# Each entry is (rc group, keyword to set); applied in the listed order.
for (group, setting) in (
    ("font",   (size = SMALL_SIZE,)),        # default text sizes
    ("axes",   (titlesize = SMALL_SIZE,)),   # axes title
    ("axes",   (labelsize = MEDIUM_SIZE,)),  # x and y labels
    ("xtick",  (labelsize = SMALL_SIZE,)),   # x tick labels
    ("ytick",  (labelsize = SMALL_SIZE,)),   # y tick labels
    ("legend", (fontsize = SMALL_SIZE,)),    # legend
    ("figure", (titlesize = BIGGER_SIZE,)),  # figure title
)
    plt.rc(group; setting...)
end

# An all-zero row followed by rows 2..10 of the sampled YlOrRd colormap
# (presumably RGBA rows, which would make the first entry transparent).
colors = let ramp = plt.cm.YlOrRd(range(0, stop=1, length=10))
    vcat([0 0 0 0], ramp[2:end, :])
end;

In [None]:
using Revise
using BATPar
using KDTree
using BAT

# Density Function

In [None]:
# # simple Normal Distribution: 

# N = 5
# min_v = -6.
# max_v = 6.

# lgV = N*log(max_v-min_v); 

# f(x::AbstractArray) = prod(pdf.(Normal(0, 1), x))

# LogTrueIntegral(N)=0.0

In [None]:
# Target density: product of N independent standard normals, each truncated
# to [min_v, max_v].

N = 5
min_v = -10.
max_v = 10.

dist_dim = truncated(Normal(0, 1), min_v, max_v)
dist = product_distribution(fill(dist_dim, N))

# Density of `dist` evaluated at the point `x`.
function f(x::AbstractArray)
    return pdf(dist, x)
end

# N times the log of the side length (max_v - min_v), i.e. the log-volume of
# the box [min_v, max_v]^N.
lgV = N * log(max_v - min_v);

# Log of the reference ("true") integral; the argument is not used.
LogTrueIntegral(N) = 0.0

# Serial Sampling

In [None]:
# Log-density wrapper: log of the target density at the parameter vector `a`.
likelihood = function (params)
    return LogDVal(log(f(params.a)))
end
# One interval [min_v, max_v] per dimension for the vector parameter `a`.
prior = NamedTupleDist(a = fill(min_v .. max_v, N),)
posterior = PosteriorDensity(likelihood, prior);

In [None]:
# Small exploratory run whose samples seed the space partitioning below.
nnsamples = 100
nnchains = 5

samples, stats = bat_sample(
    posterior,
    (nnsamples, nnchains),
    MetropolisHastings(),
);

In [None]:
# Flat view of the unshaped sample vectors (indexed below as smpl[dim, :]),
# plus the per-sample log-density and weight columns.
smpl = flatview(unshaped.(samples.v))
weights_LogLik = samples.logd
weights_Histogram = samples.weight;

# Hand copies of all three arrays to the KD-tree input container.
data_kdtree = Data(smpl[:, :], weights_Histogram[:], weights_LogLik[:]);

# Space Partitioning

In [None]:
# Use `cost_f_1` as the total cost function of the KD-tree.
function KDTree.evaluate_total_cost(data::Data)
    return KDTree.cost_f_1(data)
end

# Partition along all N dimensions; the last argument is 10 here (the same
# position is filled by `n_cuts` further down in this notebook).
output, cost_array = DefineKDTree(data_kdtree, collect(1:N), 10);

# Extend the tree bounds using [min_v, max_v] in every dimension.
extend_tree_bounds!(output, fill(min_v, N), fill(max_v, N))

In [None]:
# Lower and upper prior limit, shared by all dimensions.
prior_bounds = [min_v, max_v]

extend_tree_bounds!(
    output,
    fill(first(prior_bounds), N),
    fill(last(prior_bounds), N),
)

In [None]:
# Exploratory samples in the first two dimensions with the KD-tree overlaid.
fig, ax = plt.subplots(1, 1; figsize=(7, 5))
ax.set_xlabel(L"\lambda_1")
ax.set_ylabel(L"\lambda_2")
ax.scatter(smpl[1, :], smpl[2, :]; color="k", s=0.4)

plot_tree(output, [1, 2], ax; color="red")

# ax.set_xlim(-11., 11.)
# ax.set_ylim(-11., 11.)

# ax.set_xlim(-4., 4.)
# ax.set_ylim(-4., 4.)

# Sampling of subspaces 

In [None]:
# Burn-in limits used for the sampling of each subspace.
burnin = MCMCBurninStrategy(
    max_nsamples_per_cycle = 5000,
    max_nsteps_per_cycle = 50000,
    max_time_per_cycle = Inf,
    max_ncycles = 150,
)

# Proposal tuning parameters.
tuning = AdaptiveMetropolisTuning(
    λ = 0.5,
    α = 0.05 .. 0.15,
    β = 1.5,
    c = 1e-4 .. 1e2,
)

# AHMI integration settings (positional arguments; see BAT.HMISettings).
AHMI_settingss = BAT.HMISettings(
    BAT.cholesky_partial_whitening!,
    10000,
    2.5,
    0.1,
    true,
    16,
    true,
    Dict("cov. weighted result" => BAT.hm_combineresults_covweighted!),
)

algorithm = MetropolisHastings(ARPWeighting())

nnchains = 3
nnsamples = 10_000;

In [None]:
bounds_part = extract_par_bounds(output)

# Prior for one subspace: row j of `bounds` is read as the (lower, upper)
# limits of dimension j.
function BATPar.make_named_prior(bounds)
    return BAT.NamedTupleDist(
        a = [bounds[j, 1] .. bounds[j, 2] for j in 1:size(bounds, 1)],
    )
end

algorithm = MetropolisHastings();

samples_parallel = bat_sample_parallel(
    likelihood,
    bounds_part,
    (nnsamples, nnchains),
    algorithm;
    tuning=tuning,
    burnin=burnin,
    settings=AHMI_settingss,
);

In [None]:
# Concatenate the per-subspace samples column-wise and pull out the first two
# rows for the 2-D histograms.
smpl_par = hcat(samples_parallel.samples...)
x, y = smpl_par[1, :], smpl_par[2, :]
w_o = samples_parallel.weights_o
w_r = samples_parallel.weights_r;

In [None]:
# Reference value: exp of the log true integral.
@show "Truth", exp(LogTrueIntegral(N))

# Partitioned estimate: sum of the per-subspace integrals, with the
# per-subspace uncertainties added in quadrature.
@show "Int", sum(samples_parallel.integrals), sqrt(sum((samples_parallel.uncertainty).^2));

In [None]:
# bin_range = range(min_v, stop=max_v, length=50)

# Weighted 2-D histograms of the first two dimensions, normalised to a pdf:
# one with the `w_r` weights and one with the `w_o` weights.
histogram_wr = normalize(fit(Histogram, (x, y), weights(w_r); nbins=100); mode=:pdf)
histogram_wo = normalize(fit(Histogram, (x, y), weights(w_o); nbins=100); mode=:pdf);

In [None]:
fig, ax = plt.subplots(1, 3; figsize=(15, 5))
fig.subplots_adjust(wspace=0.05)

# (panel, histogram) pairs in drawing order; empty bins are replaced by NaN
# before plotting.
for (panel, hist) in ((1, histogram_wr), (3, histogram_wr), (2, histogram_wo))
    ax[panel].pcolormesh(
        midpoints(hist.edges[1]),
        midpoints(hist.edges[2]),
        replace(hist.weights', 0=>NaN),
        cmap="RdYlBu_r",
    )
end

# Overlay the partition on the middle panel.
plot_tree(output, [1, 2], ax[2]; linewidth=0.8, color="black", alpha=0.4)

# Left panel carries both axis labels and is limited to the prior range.
ax[1].set_xlabel(L"\lambda_1")
ax[1].set_ylabel(L"\lambda_2")
ax[1].set_xlim(min_v, max_v)
ax[1].set_ylim(min_v, max_v)

ax[3].set_xlabel(L"\lambda_1")

# Hide the y axis on the right and middle panels.
for panel in (3, 2)
    ax[panel].get_yaxis().set_visible(false)
end

# Underestimation/bias 

In [None]:
# Settings for the repeated-run bias study.
n_cuts = 3
n_smpl_cut = 10_000
n_chains = 10

# AHMI integration settings (positional arguments; see BAT.HMISettings).
AHMI_settings = BAT.HMISettings(
    BAT.cholesky_partial_whitening!,
    1000,
    1.0,
    0.1,
    true,
    16,
    true,
    Dict("cov. weighted result" => BAT.hm_combineresults_covweighted!),
)

algorithm = MetropolisHastings()

# Default burn-in strategy.
burnin = BAT.MCMCBurninStrategy()

tuning = AdaptiveMetropolisTuning(
    λ = 0.5,
    α = 0.15 .. 0.2,
    β = 1.5,
    c = 1e-4 .. 1e2,
)


In [None]:
"""
    generate_data(n_runs, n_cuts, n_smpl_cut, n_chains; AHMI_settings=AHMI_settings)

Repeat the partitioned and the serial integration `n_runs` times and collect
the results of every run.

Each run
1. draws a seed sample from the global `posterior` and builds a KD-tree from
   it, passing `n_cuts` to `DefineKDTree`;
2. samples and integrates the resulting subspaces with `bat_sample_parallel`
   (request `(n_smpl_cut, n_chains)`), summing the subspace integrals and
   adding their uncertainties in quadrature;
3. samples the full `posterior` with request
   `((1 + n_cuts) * n_smpl_cut, n_chains)`, integrates it with AHMI and scales
   estimate and uncertainty by `exp(lgV)`.

The function reads the notebook globals `posterior`, `likelihood`, `N`,
`min_v`, `max_v`, `lgV`, `algorithm`, `tuning` and `burnin`.

# Keywords
- `AHMI_settings`: settings passed to both `bat_sample_parallel` and
  `BAT.hm_integrate!`.

# Returns
A tuple of `Vector{Float64}`s with one entry per run:
`(integrals_true, log_v_true, integrals_partsmpl, uns_partsmpl, log_v_part,
integrals_ahmi, uns_ahmi)`.
`log_v_true` holds `lgV`, and `log_v_part` holds the log of the summed volume
of the subspaces.
"""
function generate_data(n_runs, n_cuts, n_smpl_cut, n_chains; AHMI_settings=AHMI_settings)

    integrals_true = Float64[]
    log_v_true = Float64[]

    integrals_partsmpl = Float64[]
    uns_partsmpl = Float64[]
    log_v_part = Float64[]

    integrals_ahmi = Float64[]
    uns_ahmi = Float64[]

    # Loop-invariant scale factor applied to the serial AHMI result.
    # NOTE(review): presumably this undoes the normalisation of the flat prior
    # over the box of log-volume `lgV` — confirm.
    prior_volume = exp(lgV)

    for i in 1:n_runs

        @show i

        # --- partitioned sampling -------------------------------------------
        seeds, _ = bat_sample(posterior, (100, 10), MetropolisHastings())
        kd_data = Data(flatview(unshaped.(seeds.v))[:,1:end], seeds.weight[1:end], seeds.logd[1:end])
        kd_output, _ = DefineKDTree(kd_data, collect(1:N), n_cuts)
        extend_tree_bounds!(kd_output, repeat([min_v], N), repeat([max_v], N)) # try changing this
        kd_bounds = extract_par_bounds(kd_output)
        samples_par = bat_sample_parallel(likelihood, kd_bounds, (n_smpl_cut, n_chains), algorithm, tuning=tuning, burnin=burnin, settings=AHMI_settings)

        par_integral = sum(samples_par.integrals)
        par_uncertainty = sqrt(sum((samples_par.uncertainty).^2))

        # Each element of `kd_bounds` is a (dimension x 2) bounds matrix, so
        # the product of the column differences is that subspace's volume.
        total_volume = sum(prod(diff(b, dims=2)) for b in kd_bounds)

        # --- serial sampling + AHMI -----------------------------------------
        samples_serial, _ = bat_sample(posterior, ((1+n_cuts)*n_smpl_cut, n_chains), MetropolisHastings(), tuning=tuning, burnin=burnin)
        hmi_data = BAT.HMIData(unshaped.(samples_serial))
        BAT.hm_integrate!(hmi_data, settings=AHMI_settings)

        ahmi_result = hmi_data.integralestimates["cov. weighted result"].final
        ahmi_integral = ahmi_result.estimate * prior_volume
        ahmi_uncertainty = ahmi_result.uncertainty * prior_volume

        # --- bookkeeping ----------------------------------------------------
        # NOTE(review): 0.0 equals LogTrueIntegral(N), i.e. this column looks
        # like the *log* of the true integral — confirm before plotting it.
        push!(integrals_true, 0.0)
        push!(integrals_partsmpl, par_integral)
        push!(uns_partsmpl, par_uncertainty)
        push!(integrals_ahmi, ahmi_integral)
        push!(uns_ahmi, ahmi_uncertainty)
        push!(log_v_true, lgV)
        # Store the log-volume so this column is comparable with `log_v_true`.
        push!(log_v_part, log(total_volume))
    end

    return (integrals_true, log_v_true, integrals_partsmpl, uns_partsmpl, log_v_part, integrals_ahmi, uns_ahmi)
end

In [None]:
# 30 independent repetitions of the partitioned-vs-serial comparison.
int_true, lgV_true, int_part, uns_part, lgV_part, int_ahmi, uns_ahmi =
    generate_data(30, n_cuts, n_smpl_cut, n_chains)

In [None]:
# Aliases used by the plots below.
integrals_partition = int_part
integral_ahmi = int_ahmi
# Reference value 1.0 for every run.
integrals_true = fill(1.0, length(integral_ahmi))

unsert_partition = uns_part
unsert_ahmi = uns_ahmi;

In [None]:
# Distribution of the integral estimates over all runs, with and without
# space partitioning.
fig, ax = plt.subplots(1,1, figsize=(7, 5))

# Vertical markers: reference value 1 and the mean of each method.
ax.axvline(1, c="red", label="Truth", alpha=0.5)
ax.axvline(mean(integrals_partition), c="C0", alpha=0.5)
ax.axvline(mean(integral_ahmi), c="C1", alpha=0.5)

ax.hist(integrals_partition, bins=10, density=true, color="C0", alpha=0.5, label="w/ partition")
ax.hist(integral_ahmi, bins=10, density=true, color="C1", alpha=0.5, label="w/o partition")

ax.legend(loc="upper left", frameon=true, framealpha=0.8, ncol=1)

ax.set_xlim(0.95, 1.05)

ax.set_xlabel("I")
# `density=true` normalises the histograms, so the y axis is a density rather
# than a count.
ax.set_ylabel("probability density")

In [None]:
# Distribution of the partitioned integral estimates only.
fig, ax = plt.subplots(1,1, figsize=(7, 5))

# Vertical markers: reference value 1 and the mean of the estimates.
ax.axvline(1, c="red", label="Truth", alpha=0.5)
ax.axvline(mean(integrals_partition), c="C0", alpha=0.5)

ax.hist(integrals_partition, bins=10, density=true, color="C0", alpha=0.5, label="w/ partition")
ax.legend(loc="upper left", frameon=true, framealpha=0.8, ncol=1)

ax.set_xlim(0.95, 1.05)

ax.set_xlabel("I")
# `density=true` normalises the histogram, so the y axis is a density rather
# than a count.
ax.set_ylabel("probability density")