From fc6cae9e7a3bebc573a5cf0166d1200358cddac1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 21 Nov 2023 19:26:17 +0100 Subject: [PATCH] Make ZygoteRules and ChainRulesCore weak dependencies (#564) * Make ZygoteRules and ChainRulesCore weak dependencies * Fix format * Add another non-differentiable to CRC extension * Perform coverage analysis on all Julia versions --- .github/workflows/CI.yml | 5 ----- Project.toml | 14 ++++++++++---- ext/DynamicPPLChainRulesCoreExt.jl | 27 +++++++++++++++++++++++++++ ext/DynamicPPLZygoteRulesExt.jl | 25 +++++++++++++++++++++++++ src/DynamicPPL.jl | 13 ++++++++----- src/compat/ad.jl | 22 ---------------------- src/utils.jl | 3 --- 7 files changed, 70 insertions(+), 39 deletions(-) create mode 100644 ext/DynamicPPLChainRulesCoreExt.jl create mode 100644 ext/DynamicPPLZygoteRulesExt.jl delete mode 100644 src/compat/ad.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4641e4ab8..a54b850a1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -56,19 +56,14 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest - with: - coverage: ${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1 }} env: GROUP: All JULIA_NUM_THREADS: ${{ matrix.num_threads }} - uses: julia-actions/julia-processcoverage@v1 - if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1 - uses: codecov/codecov-action@v1 - if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1 with: file: lcov.info - uses: coverallsapp/github-action@master - if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1 with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info diff --git a/Project.toml b/Project.toml index 2b5da13d9..46b2bc046 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.2" +version = "0.24.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -23,12 +23,16 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] -DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] AbstractMCMC = "5" @@ -54,5 +58,7 @@ Test = "1.6" julia = "1.6" [extras] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl new file mode 100644 index 000000000..1c6e188fb --- /dev/null +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -0,0 +1,27 @@ +module DynamicPPLChainRulesCoreExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL, BangBang, Distributions + using ChainRulesCore: ChainRulesCore +else + using ..DynamicPPL: DynamicPPL, BangBang, Distributions + using ..ChainRulesCore: ChainRulesCore +end + +# See https://github.com/TuringLang/Turing.jl/issues/1199 +ChainRulesCore.@non_differentiable BangBang.push!!( + vi::DynamicPPL.VarInfo, + vn::DynamicPPL.VarName, + r, + dist::Distributions.Distribution, + gidset::Set{DynamicPPL.Selector}, +) + +ChainRulesCore.@non_differentiable DynamicPPL.updategid!( + vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler +) + +# No need + causes issues for some AD backends, e.g. Zygote. +ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x) + +end # module diff --git a/ext/DynamicPPLZygoteRulesExt.jl b/ext/DynamicPPLZygoteRulesExt.jl new file mode 100644 index 000000000..93b87e2a0 --- /dev/null +++ b/ext/DynamicPPLZygoteRulesExt.jl @@ -0,0 +1,25 @@ +module DynamicPPLZygoteRulesExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL, Distributions + using ZygoteRules: ZygoteRules +else + using ..DynamicPPL: DynamicPPL, Distributions + using ..ZygoteRules: ZygoteRules +end + +# https://github.com/TuringLang/Turing.jl/issues/1595 +ZygoteRules.@adjoint function DynamicPPL.dot_observe( + spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform}, + dists::AbstractArray{<:Distributions.Distribution}, + value::AbstractArray, + vi, +) + function dot_observe_fallback(spl, dists, value, vi) + DynamicPPL.increment_num_produce!(vi) + return sum(map(Distributions.loglikelihood, dists, value)), vi + end + return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi) +end + +end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b5c29a34e..9d7eb6b7d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -9,11 +9,9 @@ using OrderedCollections: OrderedDict using AbstractMCMC: AbstractMCMC using BangBang: BangBang, push!!, empty!!, setindex!! -using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ConstructionBase: ConstructionBase using Setfield: Setfield -using ZygoteRules: ZygoteRules using LogDensityProblems: LogDensityProblems using LinearAlgebra: LinearAlgebra, Cholesky @@ -171,7 +169,6 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("prob_macro.jl") -include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") include("test_utils.jl") @@ -186,12 +183,18 @@ end @static if !isdefined(Base, :get_extension) function __init__() - @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( - "../ext/DynamicPPLMCMCChainsExt.jl" + @require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include( + "../ext/DynamicPPLChainRulesCoreExt.jl" ) @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( "../ext/DynamicPPLEnzymeCoreExt.jl" ) + @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( + "../ext/DynamicPPLMCMCChainsExt.jl" + ) + @require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include( + "../ext/DynamicPPLZygoteRulesExt.jl" + ) end end diff --git a/src/compat/ad.jl b/src/compat/ad.jl deleted file mode 100644 index edcac7874..000000000 --- a/src/compat/ad.jl +++ /dev/null @@ -1,22 +0,0 @@ -# See https://github.com/TuringLang/Turing.jl/issues/1199 -ChainRulesCore.@non_differentiable push!!( - vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} -) - -ChainRulesCore.@non_differentiable updategid!( - vi::AbstractVarInfo, vn::VarName, spl::Sampler -) - -# https://github.com/TuringLang/Turing.jl/issues/1595 -ZygoteRules.@adjoint function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::AbstractArray{<:Distribution}, - value::AbstractArray, - vi, -) - function dot_observe_fallback(spl, dists, value, vi) - increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)), vi - end - return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi) -end diff --git a/src/utils.jl b/src/utils.jl index ae79a7792..b447fed53 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -883,9 +883,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -# No need + causes issues for some AD backends, e.g. Zygote. -ChainRulesCore.@non_differentiable infer_nested_eltype(x) - """ varname_leaves(vn::VarName, val)