Skip to content

Commit

Permalink
Partial fix for #2095 (#2096)
Browse files Browse the repository at this point in the history
* use immutable link in the initialstep for HMC

* bump patch version

* added test

* Update hmc.jl

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
torfjelde and yebai committed Mar 4, 2024
1 parent 4b5e4d7 commit 0b56415
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/mcmc/hmc.jl
Expand Up @@ -131,13 +131,13 @@ function DynamicPPL.initialstep(
rng::AbstractRNG,
model::AbstractModel,
spl::Sampler{<:Hamiltonian},
vi::AbstractVarInfo;
vi_original::AbstractVarInfo;
initial_params=nothing,
nadapts=0,
kwargs...
)
# Transform the samples to unconstrained space and compute the joint log probability.
vi = link!!(vi, spl, model)
vi = DynamicPPL.link(vi_original, spl, model)

# Extract parameters.
theta = vi[spl]
Expand Down
11 changes: 11 additions & 0 deletions test/mcmc/hmc.jl
Expand Up @@ -246,4 +246,15 @@
sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5)
end
end

@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) 0.2
end
end

0 comments on commit 0b56415

Please sign in to comment.