In [None]:
using Revise 
using EuclidianNormalizingFlows

using BAT
using Distributions
using Optimisers
using FunctionChains
using ArraysOfArrays
using LinearAlgebra
using ValueShapes
using StatsBase
using FileIO
using JLD2
using CUDA
using CUDAKernels
using KernelAbstractions
using Flux
using PyPlot

# Problem size: number of IID samples, dimensionality, and mixture components.
n_smpls = 10^5
n_dims = 200
n_modes = 10

# Target density: a mixture of `n_modes` Gaussians. Each component has a mean
# drawn uniformly from [0, 10)^n_dims and a covariance of the form `s * I`,
# where the scalar `s = 0.5 * |z|` uses one standard-normal draw `z` per component.
mvns = [MvNormal(10 .* rand(n_dims), 0.5 * abs(randn()) .* I(n_dims)) for _ in 1:n_modes]
d = MixtureModel(mvns)

# Standard normal in `n_dims` dimensions, used later as the importance density.
importance_density = MvNormal(zeros(n_dims), I)

# Request the GPU, but fall back to the CPU instead of failing inside
# `CUDA.rand` when no working CUDA installation/device is available.
wanna_use_GPU = true
if wanna_use_GPU && !CUDA.functional()
    @warn "CUDA is not functional on this machine; falling back to the CPU."
    wanna_use_GPU = false
end
_device = wanna_use_GPU ? KernelAbstractions.get_device(CUDA.rand(10)) : KernelAbstractions.get_device(rand(10))

# Draw IID samples from the target and flatten them into a matrix view.
samples = bat_sample(d, BAT.IIDSampling(nsamples=n_smpls)).result;
smpls_flat = flatview(unshaped.(samples.v));

# NOTE(review): only `samples_nested` is moved with `gpu(...)`; `smpls_flat`
# itself stays where `flatview` produced it — confirm downstream consumers
# of `smpls_flat` expect that.
samples_nested = nestedview(wanna_use_GPU ? gpu(smpls_flat) : smpls_flat);

In [None]:
# Training hyper-parameters passed to `optimize_whitening` for every block.
nbatches = 10
nepochs = 100
# Forwarded to `get_flow_musketeer`; its meaning is defined there.
K = 40

# Untrained flow; its components are accessed through `blocks.fs`.
blocks = get_flow_musketeer(n_dims, _device, K)

# One learning rate per entry of `blocks.fs`. Constant for now; a decaying
# schedule such as `range(5f-3, 5f-4, length=length(blocks.fs))` fits the
# same slot.
lr = fill(3f-3, length(blocks.fs))

# Trained transformations are accumulated here, starting with a ScaleShiftNorm.
trained_blocks = Function[ScaleShiftNorm(_device)]
# One negative-log-likelihood history per trained block.
hists = Vector[]

# Block-by-block training: each round fits one entry of `blocks.fs` on the
# output of all previously trained transformations, then pushes the samples
# through the newly trained transformation to produce the next round's input.
#
# Fail before the (long) training starts rather than with a BoundsError in
# the middle of it: the loop reads `blocks.fs[1 + i]` and `lr[i]` for
# every `i in 1:n_dims`.
length(blocks.fs) >= n_dims + 1 || throw(ArgumentError(
    "need at least $(n_dims + 1) entries in blocks.fs, got $(length(blocks.fs))"))
length(lr) >= n_dims || throw(ArgumentError(
    "need at least $(n_dims) learning rates, got $(length(lr))"))

@time begin
    # Input of the first round: the samples after the initial ScaleShiftNorm.
    smpls_train = nestedview(trained_blocks[1](smpls_flat))

    for i in 1:n_dims
        # Rebinding a global inside a top-level loop only works implicitly in
        # interactive (notebook/REPL) sessions; the declaration makes the loop
        # behave the same when the cell is run as a script.
        global smpls_train

        if i % 50 == 0
            println("+++ Starting round $i")
        end

        # `blocks.fs[1 + i]` skips the first entry — presumably because it
        # corresponds to the ScaleShiftNorm already in `trained_blocks`;
        # TODO confirm against `get_flow_musketeer`.
        r = optimize_whitening(smpls_train,
            blocks.fs[1 + i],
            Optimisers.Adam(lr[i]),
            nbatches=nbatches,
            nepochs=nepochs,
            shuffle_samples=false)

        trained_trafo = r.result
        push!(trained_blocks, trained_trafo)
        push!(hists, r.negll_history)

        # Feed the transformed samples into the next round.
        smpls_train = nestedview(trained_trafo(flatview(smpls_train)))
    end
end

# Close the chain with a final ScaleShiftNorm and compose every trained
# transformation into a single flow (`push!` returns the extended vector).
trained_flow = fchain(push!(trained_blocks, ScaleShiftNorm(_device)))

# Apply the full flow to the original samples; the second return value is
# the log-abs-det-Jacobian of the transformation.
smpls_transformed, ladj_trafo = EuclidianNormalizingFlows.with_logabsdet_jacobian(trained_flow, smpls_flat)

# Host copies used by the integration below and by the plots that follow.
smpls_flat_cpu = cpu(smpls_flat)
smpls_transformed = cpu(smpls_transformed)

# Integral estimate and its variance from the transformed samples.
integral, variance = ghm_integration(
    smpls_transformed,
    samples.logd,
    vec(cpu(ladj_trafo)),
    importance_density,
)

@show integral variance


# Side-by-side 2D histograms of the first two coordinates:
# left = original samples, right = samples after the trained flow.
fig, ax = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: axis limits span exactly the sample range in each coordinate.
let xs = smpls_flat_cpu[1, :], ys = smpls_flat_cpu[2, :]
    ax[1].hist2d(xs, ys, [100, 100], cmap="inferno")
    ax[1].set_xlim(collect(extrema(xs)))
    ax[1].set_ylim(collect(extrema(ys)))
end

# Right panel: fixed window of [-3, 3] on both axes.
ax[2].hist2d(smpls_transformed[1, :], smpls_transformed[2, :], [100, 100], cmap="inferno")
for set_limits in (ax[2].set_xlim, ax[2].set_ylim)
    set_limits([-3, 3])
end



# 1D marginals for four dimensions spread evenly over the coordinate range
# (first and last included): left = target samples, right = transformed
# samples overlaid with fresh standard-normal draws for comparison.
#
# The host copies `smpls_flat_cpu` and `smpls_transformed` are reused instead
# of calling `cpu(...)` again on every iteration, and each row is sliced
# (copied) only once.
for i in round.(Integer, range(1, size(smpls_flat_cpu, 1), length=4))
    fig, ax = plt.subplots(1, 2, figsize=(16, 4))

    # Left panel: marginal of the original samples, padded by 1 on both sides.
    target_marginal = smpls_flat_cpu[i, :]
    bins = range(minimum(target_marginal) - 1, maximum(target_marginal) + 1, length=110)
    ax[1].hist(target_marginal, weights=samples.weight, bins=bins, alpha=0.3, label="Target Marginal")
    ax[1].legend()
    ax[1].set_xlabel("$i")

    # Right panel: marginal after the flow, with the same binning applied to
    # an unweighted reference sample from a standard normal.
    transformed_marginal = smpls_transformed[i, :]
    bins = range(minimum(transformed_marginal) - 1, maximum(transformed_marginal) + 1, length=110)
    ax[2].hist(transformed_marginal, weights=samples.weight, bins=bins, alpha=0.3, label="Transformed Marginal")
    ax[2].hist(rand(Normal(), n_smpls), bins=bins, alpha=0.3, label="Gaussian")
    ax[2].legend()
    ax[2].set_xlabel("$i")
end

In [None]:
# Persist the results of this run to a JLD2 file. The dimensionality in the
# file name follows `n_dims`, so changing the problem size no longer writes
# to a file labelled with the wrong dimension (for n_dims = 200 the name is
# unchanged).
output_file = "Nice_MvN_$(n_dims)D_musketeer_23mar21st.jld2"

save(output_file,
    Dict(
    # Integral estimate and its variance.
    "int" => integral,
    "int_var" => variance,
    # Trained flow and the samples it was trained on, both as host copies.
    "flow" => cpu(trained_flow),
    "samples" => cpu(samples),
    "target_dist" => d,
    # Hyper-parameters of the run.
    "nbatches" => nbatches,
    "nepochs" => nepochs,
    "K" => K,
    "lr" => lr,
    # Per-block negative-log-likelihood training histories.
    "neg_ll_hists" => hists,
    )
)