Skip to content

Commit

Permalink
Add getmodel and setmodel from/to LogDensityFunction (#626)
Browse files Browse the repository at this point in the history
* initial copy and paste

* add some test

* Update ext/DynamicPPLReverseDiffExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update ext/DynamicPPLReverseDiffExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* update to the new implementation according to Turing

* Update src/logdensityfunction.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* error fixes

* Update src/logdensityfunction.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/logdensityfunction.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add `HypothesisTests` to turing test dep

* version bump

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sunxd3 and github-actions[bot] committed Jul 22, 2024
1 parent dfdc155 commit 36008f9
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.28.1"
version = "0.28.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
40 changes: 40 additions & 0 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,46 @@ function getcontext(f::LogDensityFunction)
return f.context === nothing ? leafcontext(f.model.context) : f.context
end

"""
getmodel(f)
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) =
getmodel(LogDensityProblemsAD.parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

"""
setmodel(f, model[, adtype])
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(
f::LogDensityProblemsAD.ADGradientWrapper,
model::DynamicPPL.Model,
adtype::ADTypes.AbstractADType,
)
# TODO: Should we handle `SciMLBase.NoAD`?
# For an `ADGradientWrapper` we do the following:
# 1. Update the `Model` in the underlying `LogDensityFunction`.
# 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
# to ensure that the recompilation of gradient tapes, etc. also occur. For example,
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
# replacing the corresponding field with the new model won't be sufficient to obtain
# the correct gradients.
return LogDensityProblemsAD.ADgradient(
adtype, setmodel(LogDensityProblemsAD.parent(f), model)
)
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
end

# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
# we need to define these annoying methods to ensure that we stay compatible with everything.
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
Expand Down
22 changes: 21 additions & 1 deletion test/logdensityfunction.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
using Test, DynamicPPL, LogDensityProblems
using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff

@testset "`getmodel` and `setmodel`" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
= DynamicPPL.LogDensityFunction(model)
@test DynamicPPL.getmodel(ℓ) == model
@test DynamicPPL.setmodel(ℓ, model).model == model

# ReverseDiff related
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false))
@test DynamicPPL.getmodel(∇ℓ) == model
@test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) ==
model
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true))
new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())
@test DynamicPPL.getmodel(new_∇ℓ) == model
# HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape`
@test new_∇ℓ.compiledtape != ∇ℓ.compiledtape
end
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
Expand Down
1 change: 1 addition & 0 deletions test/turing/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down

2 comments on commit 36008f9

@sunxd3
Copy link
Collaborator Author

@sunxd3 sunxd3 commented on 36008f9 Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

This is a maintenance release, we added some internal utility functions.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/111494

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.28.2 -m "<description of version>" 36008f9ef109de15daecaa93a48abf4b4905e6ae
git push origin v0.28.2

Please sign in to comment.