diff --git a/.gitignore b/.gitignore index 7c61f44a5..198907c73 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem .DS_Store Manifest.toml +**.~undo-tree~ diff --git a/Project.toml b/Project.toml index f5ddb0886..e85f3b59f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.11" +version = "0.23.12" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -20,6 +20,15 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + +[extensions] +DynamicPPLMCMCChainsExt = ["MCMCChains"] + +[extras] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + [compat] AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.6" @@ -31,6 +40,7 @@ Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" LogDensityProblems = "2" MacroTools = "0.5.6" +MCMCChains = "6" OrderedCollections = "1" Setfield = "0.7.1, 0.8, 1" ZygoteRules = "0.2" diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl new file mode 100644 index 000000000..de77f58f2 --- /dev/null +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -0,0 +1,16 @@ +module DynamicPPLMCMCChainsExt + +using DynamicPPL: DynamicPPL +using MCMCChains: MCMCChains + +function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCChains.Chains) + chain_parameters = MCMCChains.get_sections(chain, :parameters) + varinfo = DynamicPPL.VarInfo(model) + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + return map(iters) do (sample_idx, chain_idx) + DynamicPPL.setval_and_resample!(varinfo, chain_parameters, sample_idx, chain_idx) + model(varinfo) + end +end + +end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl new file mode 100644 index 000000000..c19bf6f2d --- /dev/null +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -0,0 +1,9 @@ +@testset "DynamicPPLMCMCChainsExt" begin + @model demo() = x ~ Normal() + model = demo() + + chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain_generated = @test_nowarn generated_quantities(model, chain) + @test size(chain_generated) == (1000, 1) + @test mean(chain_generated) ≈ 0 atol = 0.1 +end diff --git a/test/runtests.jl b/test/runtests.jl index 2d1b521ce..74cabc272 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,10 @@ include("test_util.jl") include(joinpath("compat", "ad.jl")) end + @testset "extensions" begin + include("ext/DynamicPPLMCMCChainsExt.jl") + end + @testset "doctests" begin DocMeta.setdocmeta!( DynamicPPL,