From 0b564150448e8a0b4ff32282bec85553b39b0e72 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Mar 2024 13:20:21 +0000 Subject: [PATCH] Partial fix for #2095 (#2096) * use immutable link in the initialstep for HMC * bump patch version * added test * Update hmc.jl --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/mcmc/hmc.jl | 4 ++-- test/mcmc/hmc.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index f4f7446a1..a2d70e34b 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -131,13 +131,13 @@ function DynamicPPL.initialstep( rng::AbstractRNG, model::AbstractModel, spl::Sampler{<:Hamiltonian}, - vi::AbstractVarInfo; + vi_original::AbstractVarInfo; initial_params=nothing, nadapts=0, kwargs... ) # Transform the samples to unconstrained space and compute the joint log probability. - vi = link!!(vi, spl, model) + vi = DynamicPPL.link(vi_original, spl, model) # Extract parameters. theta = vi[spl] diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 3ffd8de06..6c61a4249 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -246,4 +246,15 @@ sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) end end + + @turing_testset "(partially) issue: #2095" begin + @model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV} + xs = Vector{TV}(undef, 2) + xs[1] ~ Dirichlet(ones(5)) + xs[2] ~ Dirichlet(ones(5)) + end + model = vector_of_dirichlet() + chain = sample(model, NUTS(), 1000) + @test mean(Array(chain)) ≈ 0.2 + end end