Skip to content

Commit

Permalink
Mark Sampling context as not needing derivatives (#556)
Browse files Browse the repository at this point in the history
* Mark Sampling context as not needing derivatives

* Mark Sampling context as not needing derivatives

* Fix format

* Fix Project.toml

* Qualify SamplingContext

---------

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
wsmoses and devmotion authored Nov 21, 2023
1 parent 34be85c commit 03e4ba2
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]

[compat]
AbstractMCMC = "5"
Expand All @@ -38,6 +40,7 @@ Compat = "4"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6"
LogDensityProblems = "2"
MCMCChains = "6"
MacroTools = "0.5.6"
Expand All @@ -52,3 +55,4 @@ julia = "1.6"

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
13 changes: 13 additions & 0 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DynamicPPLEnzymeCoreExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using EnzymeCore
else
using ..DynamicPPL: DynamicPPL
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

end
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ end
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
end
end

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Expand Down
3 changes: 3 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ using DynamicPPL:
hasconditioned_nested,
getconditioned_nested

using EnzymeCore

# Dummy context to test nested behaviors.
struct ParentContext{C<:AbstractContext} <: AbstractContext
context::C
Expand Down Expand Up @@ -252,6 +254,7 @@ end
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "FixedContext" begin
Expand Down

0 comments on commit 03e4ba2

Please sign in to comment.