In [1]:
using MuseInference
using Zygote

In [32]:
# 512-dimensional noisy funnel
prob = SimpleMuseProblem(
    rand(10),
    function sample_x_z(rng, θ)
        z = rand(rng, MvNormal(zeros(512), exp(θ)*I))
        x = rand(rng, MvNormal(z, I))
        (;x, z)
    end,
    function logLike(x, z, θ)
        -(1//2) * (sum((x .- z).^2) + sum(z.^2) / exp(θ) + 512*θ)
    end, 
    function logPrior(θ)
        -θ^2/(2*3^2)
    end;
    autodiff = AD.ForwardDiffBackend()
)

# get solution
muse(prob, (θ=1,))

LoadError: MethodError: no method matching exp(::NamedTuple{(:θ,), Tuple{Int64}})
[0mClosest candidates are:
[0m  exp([91m::Union{Float16, Float32, Float64}[39m) at special/exp.jl:326
[0m  exp([91m::StridedMatrix{var"#s886"} where var"#s886"<:Union{Float32, Float64, ComplexF32, ComplexF64}[39m) at C:\APPs\Julia-1.8.4\share\julia\stdlib\v1.8\LinearAlgebra\src\dense.jl:569
[0m  exp([91m::StridedMatrix{var"#s886"} where var"#s886"<:Union{Integer, Complex{<:Integer}}[39m) at C:\APPs\Julia-1.8.4\share\julia\stdlib\v1.8\LinearAlgebra\src\dense.jl:570
[0m  ...

In [9]:
MuseInference.ZygoteBackend()

LoadError: UndefVarError: prob not defined

In [6]:
using AdvancedHMC, ForwardDiff
using LogDensityProblems
using LinearAlgebra

# Define the target distribution using the `LogDensityProblem` interface
struct LogTargetDensity
    dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2  # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()

# Choose parameter dimensionality and initial parameter value
D = 100; initial_θ = rand(D)
ℓπ = LogTargetDensity(D)

# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 2_000, 1_000

# Define a Hamiltonian system
metric = DiagEuclideanMetric(D)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

# Define a leapfrog solver, with initial step size chosen heuristically
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = Leapfrog(initial_ϵ)

# Define an HMC sampler, with the following components
#   - multinomial sampling scheme,
#   - generalised No-U-Turn criteria, and
#   - windowed adaption for step-size and diagonal mass matrix
proposal = NUTS{MultinomialTS, GeneralisedNoUTurn}(integrator)
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator))

# Run the sampler to draw samples from the specified Gaussian, where
#   - `samples` will store the samples
#   - `stats` will store diagnostic statistics for each sample
samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor, n_adapts; progress=true)

[32mSampling   0%|█                              |  ETA: 0:02:05[39m[K
[34m  iterations:                                   2[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.0[39m[K
[34m  n_steps:                                      1[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.0[39m[K
[34m  log_density:                                  -17.820112182773414[39m[K
[34m  hamiltonian_energy:                           65.91424991068507[39m[K
[34m  hamiltonian_energy_error:                     0.0[39m[K
[34m  max_hamiltonian_energy_error:                 793.8955719679075[39m[K
[34m  tree_depth:                                   1[39m[K
[34m  numerical_error:                              false[39m[K
[34m  step_size:                                    2.5626941406301733[39m[K
[34m  nom


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling   5%|██                             |  ETA: 0:00:05[39m[K
[34m  iterations:                                   102[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.01[39m[K
[34m  n_steps:                                      1[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.0[39m[K
[34m  log_density:                                  -49.990856995580245[39m[K
[34m  hamiltonian_energy:                           87.27101269785999[39m[K
[34m  hamiltonian_energy_error:                     0.0[39m[K
[34m  max_hamiltonian_energy_error:                 1.1913249420806959e6[39m[K
[34m  tree_depth:                                   0[39m[K
[34m  numerical_


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling   9%|███                            |  ETA: 0:00:04[39m[K
[34m  iterations:                                   178[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.01[39m[K
[34m  n_steps:                                      15[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.9313776807119758[39m[K
[34m  log_density:                                  -62.70741191286972[39m[K
[34m  hamiltonian_energy:                           116.04703045259852[39m[K
[34m  hamiltonian_energy_error:                     0.10286619060742908[39m[K
[34m  max_hamiltonian_energy_error:                 0.1369784434261021[39m[K
[34m  tree_depth:                                 


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  13%|████                           |  ETA: 0:00:03[39m[K
[34m  iterations:                                   254[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.01[39m[K
[34m  n_steps:                                      15[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.9853378564216276[39m[K
[34m  log_density:                                  -67.90110922606368[39m[K
[34m  hamiltonian_energy:                           117.79112321097331[39m[K
[34m  hamiltonian_energy_error:                     0.009992909863143495[39m[K
[34m  max_hamiltonian_energy_error:                 -0.3245838165601498[39m[K
[34m  tree_depth:                               


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  19%|██████                         |  ETA: 0:00:03[39m[K
[34m  iterations:                                   373[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.01[39m[K
[34m  n_steps:                                      7[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.7517768898765915[39m[K
[34m  log_density:                                  -45.92706236344347[39m[K
[34m  hamiltonian_energy:                           88.86525068207399[39m[K
[34m  hamiltonian_energy_error:                     0.2891422096945746[39m[K
[34m  max_hamiltonian_energy_error:                 0.3475206413342562[39m[K
[34m  tree_depth:                                   3


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  28%|█████████                      |  ETA: 0:00:02[39m[K
[34m  iterations:                                   556[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.01[39m[K
[34m  n_steps:                                      7[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.7225062074144152[39m[K
[34m  log_density:                                  -58.61848542710099[39m[K
[34m  hamiltonian_energy:                           107.75215538527351[39m[K
[34m  hamiltonian_energy_error:                     0.37655846666373805[39m[K
[34m  max_hamiltonian_energy_error:                 0.5537373272815387[39m[K
[34m  tree_depth:                                  


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  35%|███████████                    |  ETA: 0:00:02[39m[K
[34m  iterations:                                   692[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.01[39m[K
[34m  n_steps:                                      7[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.7429871690136876[39m[K
[34m  log_density:                                  -43.05296425283181[39m[K
[34m  hamiltonian_energy:                           88.72942315608961[39m[K
[34m  hamiltonian_energy_error:                     0.13445165063083664[39m[K
[34m  max_hamiltonian_energy_error:                 0.6548325744882106[39m[K
[34m  tree_depth:                                   


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  43%|██████████████                 |  ETA: 0:00:01[39m[K
[34m  iterations:                                   854[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.0[39m[K
[34m  n_steps:                                      111[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.9646583345813848[39m[K
[34m  log_density:                                  -39.18868741500303[39m[K
[34m  hamiltonian_energy:                           91.44289646560135[39m[K
[34m  hamiltonian_energy_error:                     -0.3584119013061269[39m[K
[34m  max_hamiltonian_energy_error:                 -0.3763913488099604[39m[K
[34m  tree_depth:                                 


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  51%|████████████████               |  ETA: 0:00:01[39m[K
[34m  iterations:                                   1016[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.0[39m[K
[34m  n_steps:                                      7[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.4733474001919423[39m[K
[34m  log_density:                                  -50.4558922733567[39m[K
[34m  hamiltonian_energy:                           112.40188046511668[39m[K
[34m  hamiltonian_energy_error:                     0.33946243301730306[39m[K
[34m  max_hamiltonian_energy_error:                 1.3732398707313678[39m[K
[34m  tree_depth:                                   


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling  83%|██████████████████████████     |  ETA: 0:00:00[39m[K
[34m  iterations:                                   1654[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.0[39m[K
[34m  n_steps:                                      7[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.6252576611439778[39m[K
[34m  log_density:                                  -54.05340005100106[39m[K
[34m  hamiltonian_energy:                           112.53336852286688[39m[K
[34m  hamiltonian_energy_error:                     0.20035730303752075[39m[K
[34m  max_hamiltonian_energy_error:                 0.8564049478486595[39m[K
[34m  tree_depth:                                  


















[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[K[A[32mSampling 100%|███████████████████████████████| Time: 0:00:01[39m[K
[34m  iterations:                                   2000[39m[K
[34m  ratio_divergent_transitions:                  0.0[39m[K
[34m  ratio_divergent_transitions_during_adaption:  0.0[39m[K
[34m  n_steps:                                      7[39m[K
[34m  is_accept:                                    true[39m[K
[34m  acceptance_rate:                              0.9862851278539819[39m[K
[34m  log_density:                                  -41.223408567734396[39m[K
[34m  hamiltonian_energy:                           85.83642012978709[39m[K
[34m  hamiltonian_energy_error:                     -0.3030877499782605[39m[K
[34m  max_hamiltonian_energy_error:                 -0.35867705500186275[39m[K
[34m  tree_depth:                                

┌ Info: Finished 2000 sampling steps for 1 chains in 1.4481759 (s)
│   h = Hamiltonian(metric=DiagEuclideanMetric([0.9495360978812399, 0.9225 ...]), kinetic=GaussianKinetic())
│   κ = HMCKernel{AdvancedHMC.FullMomentumRefreshment, Trajectory{MultinomialTS, Leapfrog{Float64}, GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{MultinomialTS}(integrator=Leapfrog(ϵ=0.47), tc=GeneralisedNoUTurn{Float64}(10, 1000.0)))
│   EBFMI_est = 1.0570949117442243
│   average_acceptance_rate = 0.8206590633062693
└ @ AdvancedHMC C:\Users\clock\.julia\packages\AdvancedHMC\9L3Qc\src\sampler.jl:246


([[-0.18201294085766995, -0.9818942168152818, -0.49301630431904075, -0.07524372355807718, -0.9224662895496503, -0.028535454686621442, -0.8654930012891681, -0.8628378612968225, -0.09812532505180932, -0.8511256962490754  …  -0.6975989957790674, -0.1267373900824793, -0.7680611239846544, 0.05434697196972921, -0.37228696540252937, -0.27322167175375744, -0.21843247887394646, 0.13146643630592514, -0.6346162034644767, -0.4115063354806173], [-0.18201294085766995, -0.9818942168152818, -0.49301630431904075, -0.07524372355807718, -0.9224662895496503, -0.028535454686621442, -0.8654930012891681, -0.8628378612968225, -0.09812532505180932, -0.8511256962490754  …  -0.6975989957790674, -0.1267373900824793, -0.7680611239846544, 0.05434697196972921, -0.37228696540252937, -0.27322167175375744, -0.21843247887394646, 0.13146643630592514, -0.6346162034644767, -0.4115063354806173], [0.09061716549869571, 0.12513590134681052, -0.1720413122925467, -0.12924852337286047, 1.0985205507068339, 0.03453669117810146, 0.5

In [9]:
stats

2000-element Vector{NamedTuple}:
 (n_steps = 27, is_accept = true, acceptance_rate = 0.1738896591616117, log_density = -17.820112182773414, hamiltonian_energy = 74.53054895881289, hamiltonian_energy_error = 0.3107484262755946, max_hamiltonian_energy_error = 8.583434389930119, tree_depth = 4, numerical_error = false, step_size = 0.8, nom_step_size = 0.8, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -17.820112182773414, hamiltonian_energy = 65.91424991068507, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 793.8955719679075, tree_depth = 1, numerical_error = false, step_size = 2.5626941406301733, nom_step_size = 2.5626941406301733, is_adapt = true)
 (n_steps = 15, is_accept = true, acceptance_rate = 0.7576319610651432, log_density = -24.759148437505758, hamiltonian_energy = 69.36350357941645, hamiltonian_energy_error = 0.1335877118527975, max_hamiltonian_energy_error = 0.6751823219289292, tree_depth = 4, numerical_error = false, st