Skip to content

Commit

Permalink
PriorExtractorContext (#496)
Browse files Browse the repository at this point in the history
* first commt

* export context

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* missing comma

* Update src/contexts.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fixed compilation

* extract_priors

* tests

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* tests for dot

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* bug

* Apply suggestions from code review

* added docstring to extract_priors

* fixed and added more tests for extract_priors

* moved prior extraction to a separate file

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* forgot to move a small piece of code

* added extract_priors to docs

* Update docs/src/api.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added qualifiers to docstring of extract_priors

* Revert "added qualifiers to docstring of extract_priors"

This reverts commit cab9f9c.

* "fixed" the doctests as ran in docs making

* make calls to doctest consistent

* Update test/runtests.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
  • Loading branch information
4 people committed Jul 13, 2023
1 parent e6dd4ef commit e8172f0
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 1 deletion.
7 changes: 7 additions & 0 deletions docs/make.jl
@@ -1,6 +1,13 @@
using Documenter
using DynamicPPL
using DynamicPPL: AbstractPPL
# NOTE: This is necessary to ensure that if we print something from
# Distributions.jl in a doctest, then the shown value will not include
# a qualifier; that is, we don't want `Distributions.Normal{Float64}`
# but rather `Normal{Float64}`. The latter is what will then be printed
# in the doctest as run in `test/runtests.jl`, and so we need to stay
# consistent with that.
using Distributions

# Doctest setup
DocMeta.setdocmeta!(
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api.md
Expand Up @@ -102,6 +102,12 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob
pointwise_loglikelihoods
```

Sometimes it can be useful to extract the priors of a model. This is the possible using [`extract_priors`](@ref).

```@docs
extract_priors
```

```@docs
NamedDist
```
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Expand Up @@ -86,6 +86,7 @@ export AbstractVarInfo,
getmissings,
getargnames,
generated_quantities,
extract_priors,
# Samplers
Sampler,
SampleFromPrior,
Expand Down Expand Up @@ -166,5 +167,6 @@ include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("extract_priors.jl")

end # module
118 changes: 118 additions & 0 deletions src/extract_priors.jl
@@ -0,0 +1,118 @@
struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <:
AbstractContext
priors::D
context::Ctx
end

PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context)

NodeTrait(::PriorExtractorContext) = IsParent()
childcontext(context::PriorExtractorContext) = context.context
function setchildcontext(parent::PriorExtractorContext, child)
return PriorExtractorContext(parent.priors, child)
end

function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
return context.priors[vn] = dist
end

function setprior!(
context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution
)
for vn in vns
context.priors[vn] = dist
end
end

function setprior!(
context::PriorExtractorContext,
vns::AbstractArray{<:VarName},
dists::AbstractArray{<:Distribution},
)
for (vn, dist) in zip(vns, dists)
context.priors[vn] = dist
end
end

function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi)
setprior!(context, vn, right)
return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
end

function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
setprior!(context, vn, right)
return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi)
end

"""
extract_priors([rng::Random.AbstractRNG, ]model::Model)
Extract the priors from a model.
This is done by sampling from the model and
recording the distributions that are used to generate the samples.
!!! warning
Because the extraction is done by execution of the model, there
are several caveats:
1. If one variable, say, `y ~ Normal(0, x)`, where `x ~ Normal()`
is also a random variable, then the extracted prior will have
different parameters in every extraction!
2. If the model does _not_ have static support, say,
`n ~ Categorical(1:10); x ~ MvNormmal(zeros(n), I)`, then the
extracted priors themselves will be different between extractions,
not just their parameters.
Both of these caveats are demonstrated below.
# Examples
## Changing parameters
```jldoctest
julia> using Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_dynamic_parameters()
x ~ Normal(0, 1)
y ~ Normal(x, 1)
end;
julia> model = model_dynamic_parameters();
julia> extract_priors(rng, model)[@varname(y)]
Normal{Float64}(μ=-0.6702516921145671, σ=1.0)
julia> extract_priors(rng, model)[@varname(y)]
Normal{Float64}(μ=1.3736306979834252, σ=1.0)
```
## Changing support
```jldoctest
julia> using LinearAlgebra, Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_dynamic_support()
n ~ Categorical(ones(10) ./ 10)
x ~ MvNormal(zeros(n), I)
end;
julia> model = model_dynamic_support();
julia> length(extract_priors(rng, model)[@varname(x)])
6
julia> length(extract_priors(rng, model)[@varname(x)])
9
```
"""
extract_priors(model::Model) = extract_priors(Random.default_rng(), model)
function extract_priors(rng::Random.AbstractRNG, model::Model)
context = PriorExtractorContext(SamplingContext(rng))
evaluate!!(model, VarInfo(), context)
return context.priors
end
31 changes: 31 additions & 0 deletions test/model.jl
Expand Up @@ -12,6 +12,19 @@ struct MyZeroModel end
return x ~ Normal(m, 1)
end

innermost_distribution_type(d::Distribution) = typeof(d)
function innermost_distribution_type(d::Distributions.ReshapedDistribution)
return innermost_distribution_type(d.dist)
end
function innermost_distribution_type(d::Distributions.Product)
dists = map(innermost_distribution_type, d.v)
if any(!=(dists[1]), dists)
error("Cannot extract innermost distribution type from $d")
end

return dists[1]
end

@testset "model.jl" begin
@testset "convenience functions" begin
model = gdemo_default
Expand Down Expand Up @@ -154,4 +167,22 @@ end
@model test_defaults(x, n=length(x)) = x ~ MvNormal(zeros(n), I)
@test length(test_defaults(missing, 2)()) == 2
end

@testset "extract priors" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
priors = extract_priors(model)

# We know that any variable starting with `s` should have `InverseGamma`
# and any variable starting with `m` should have `Normal`.
for (vn, prior) in priors
if DynamicPPL.getsym(vn) == :s
@test innermost_distribution_type(prior) <: InverseGamma
elseif DynamicPPL.getsym(vn) == :m
@test innermost_distribution_type(prior) <: Union{Normal,MvNormal}
else
error("Unexpected variable name: $vn")
end
end
end
end
end
5 changes: 4 additions & 1 deletion test/runtests.jl
Expand Up @@ -60,7 +60,10 @@ include("test_util.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true
DynamicPPL,
:DocTestSetup,
:(using DynamicPPL, Distributions);
recursive=true,
)
doctestfilters = [
# Older versions will show "0 element Array" instead of "Type[]".
Expand Down

0 comments on commit e8172f0

Please sign in to comment.