diff --git a/Project.toml b/Project.toml index 06cba77ab..7ee7d2f97 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.28.1" +version = "0.28.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 8935edc12..9e86590fa 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -76,6 +76,46 @@ function getcontext(f::LogDensityFunction) return f.context === nothing ? leafcontext(f.model.context) : f.context end +""" + getmodel(f) + +Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. +""" +getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = + getmodel(LogDensityProblemsAD.parent(f)) +getmodel(f::DynamicPPL.LogDensityFunction) = f.model + +""" + setmodel(f, model[, adtype]) + +Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. + +!!! warning + Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a + `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f` + might require recompilation of the gradient tape, depending on the AD backend. +""" +function setmodel( + f::LogDensityProblemsAD.ADGradientWrapper, + model::DynamicPPL.Model, + adtype::ADTypes.AbstractADType, +) + # TODO: Should we handle `SciMLBase.NoAD`? + # For an `ADGradientWrapper` we do the following: + # 1. Update the `Model` in the underlying `LogDensityFunction`. + # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype` + # to ensure that the recompilation of gradient tapes, etc. also occur. For example, + # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just + # replacing the corresponding field with the new model won't be sufficient to obtain + # the correct gradients. + return LogDensityProblemsAD.ADgradient( + adtype, setmodel(LogDensityProblemsAD.parent(f), model) + ) +end +function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) + return Accessors.@set f.model = model +end + # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index ea70ace29..beda767e6 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,4 +1,24 @@ -using Test, DynamicPPL, LogDensityProblems +using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff + +@testset "`getmodel` and `setmodel`" begin + @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + ℓ = DynamicPPL.LogDensityFunction(model) + @test DynamicPPL.getmodel(ℓ) == model + @test DynamicPPL.setmodel(ℓ, model).model == model + + # ReverseDiff related + ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) + @test DynamicPPL.getmodel(∇ℓ) == model + @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) == + model + ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) + new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff()) + @test DynamicPPL.getmodel(new_∇ℓ) == model + # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape` + @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape + end +end @testset "LogDensityFunction" begin @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 501359253..ed2b08ce5 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -1,6 +1,7 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"