Commit

Merge 8cfee33 into 3e71a76
torfjelde committed Apr 19, 2023
2 parents 3e71a76 + 8cfee33 commit c294f2d
Showing 5 changed files with 55 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.24.3"
version = "0.24.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
22 changes: 19 additions & 3 deletions src/inference/hmc.jl
@@ -159,7 +159,14 @@ function DynamicPPL.initialstep(
metricT = getmetricT(spl.alg)
metric = metricT(length(theta))
ℓ = LogDensityProblemsAD.ADgradient(
- Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
+ Turing.LogDensityFunction(
+     vi,
+     model,
+     # Use the leaf-context from the `model` in case the user has
+     # contextualized the model with something like `PriorContext`
+     # to sample from the prior.
+     DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+ )
)
logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x)
@@ -265,7 +272,11 @@ end
function get_hamiltonian(model, spl, vi, state, n)
metric = gen_metric(n, spl, state)
ℓ = LogDensityProblemsAD.ADgradient(
- Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
+ Turing.LogDensityFunction(
+     vi,
+     model,
+     DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context))
+ )
)
ℓπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
∂ℓπ∂θ = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ)
@@ -538,7 +549,12 @@ function HMCState(

# Get the initial log pdf and gradient functions.
∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
- logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
+ logπ = Turing.LogDensityFunction(
+     vi,
+     model,
+     DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+ )


# Get the metric type.
metricT = getmetricT(spl.alg)
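The comment added in the first hunk spells out the intent of this change: the HMC code now builds its `LogDensityFunction` from the model's own leaf context instead of a hard-coded `DynamicPPL.DefaultContext()`, so a model that has been contextualized with `PriorContext` is sampled from its prior. A minimal usage sketch of that behaviour (the `demo` model and its data are illustrative stand-ins for the test suite's `gdemo_default`, not part of this commit):

using Turing
using DynamicPPL

# Hypothetical model, analogous to the test suite's `gdemo_default`.
@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo([1.5, 2.0])

# Swap in `PriorContext` so evaluation ignores the likelihood; with this
# commit, NUTS picks up that leaf context and hence samples from the prior.
prior_model = DynamicPPL.contextualize(model, DynamicPPL.PriorContext())
chain = sample(prior_model, NUTS(1000, 0.8), 10_000)

With the patch, `s` should concentrate around its prior mean, mean(InverseGamma(2, 3)) == 3.0, and `m` around 0, which is what the new test in test/inference/hmc.jl below checks.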
18 changes: 16 additions & 2 deletions src/inference/mh.jl
@@ -375,7 +375,14 @@ function propose!!(

# Make a new transition.
densitymodel = AMH.DensityModel(
- Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+ Base.Fix1(
+     LogDensityProblems.logdensity,
+     Turing.LogDensityFunction(
+         vi,
+         model,
+         DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+     )
+ )
)
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

@@ -403,7 +410,14 @@ function propose!!(

# Make a new transition.
densitymodel = AMH.DensityModel(
- Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+ Base.Fix1(
+     LogDensityProblems.logdensity,
+     Turing.LogDensityFunction(
+         vi,
+         model,
+         DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context))
+     )
+ )
)
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

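Both `propose!!` methods in mh.jl get the same treatment: the `SamplingContext` passed to the `LogDensityFunction` now carries the model's leaf context, so `MH()` also respects `PriorContext`. A hedged sketch reusing the hypothetical `demo` model from above; the burn-in and thinning mirror the new test in test/inference/mh.jl, which needs them because plain MH mixes slowly on this target:

using StableRNGs  # only used to pin the seed, as the new test does

prior_model = DynamicPPL.contextualize(demo([1.5, 2.0]), DynamicPPL.PriorContext())
chain = sample(StableRNG(10), prior_model, MH(), 10_000;
               discard_initial = 10_000, thinning = 10)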
7 changes: 7 additions & 0 deletions test/inference/hmc.jl
@@ -216,4 +216,11 @@
res3 = sample(StableRNG(123), gdemo_default, alg, 1000)
@test Array(res1) == Array(res2) == Array(res3)
end

+ @turing_testset "prior" begin
+     alg = NUTS(1000, 0.8)
+     gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
+     chain = sample(gdemo_default_prior, alg, 10_000)
+     check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.2)
+ end
end
12 changes: 12 additions & 0 deletions test/inference/mh.jl
@@ -216,4 +216,16 @@
vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default)
@test !DynamicPPL.islinked(vi, spl)
end

+ @turing_testset "prior" begin
+     # HACK: MH can be so bad for this prior model for some reason that it's difficult to
+     # find a non-trivial `atol` where the tests will pass for all seeds. Hence we fix it :/
+     rng = StableRNG(10)
+     alg = MH()
+     gdemo_default_prior = DynamicPPL.contextualize(gdemo_default, DynamicPPL.PriorContext())
+     burnin = 10_000
+     n = 10_000
+     chain = sample(rng, gdemo_default_prior, alg, n; discard_initial = burnin, thinning=10)
+     check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.3)
+ end
end
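For completeness, a sketch of why propagating the leaf context is sufficient (assuming standard DynamicPPL behaviour; `demo` is the hypothetical model from above): `contextualize` replaces the model's evaluation context, and `leafcontext` unwraps any nested contexts down to the innermost one, which is what the patched samplers now feed into `SamplingContext` instead of `DefaultContext()`.

using DynamicPPL, Random

prior_model = DynamicPPL.contextualize(demo([1.5, 2.0]), DynamicPPL.PriorContext())

DynamicPPL.leafcontext(prior_model.context)  # PriorContext(), not DefaultContext()

# Roughly the context the patched samplers construct internally:
ctx = DynamicPPL.SamplingContext(
    Random.default_rng(),
    DynamicPPL.SampleFromPrior(),
    DynamicPPL.leafcontext(prior_model.context),
)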
