-
Notifications
You must be signed in to change notification settings - Fork 25
/
extract_priors.jl
118 lines (90 loc) · 3.27 KB
/
extract_priors.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
    PriorExtractorContext(priors, context)
    PriorExtractorContext(context)

Parent context that records, for each variable passed through `tilde_assume` /
`dot_tilde_assume`, the distribution on the right-hand side of its `~` statement,
and then forwards the call to the wrapped `context`.

# Fields
- `priors`: `OrderedDict{VarName,Any}` that is filled with `VarName => distribution`
  entries in the order they are recorded.
- `context`: the wrapped child context.

The single-argument constructor starts from a new, empty `OrderedDict{VarName,Any}`.
"""
struct PriorExtractorContext{TPriors<:OrderedDict{VarName,Any},TCtx<:AbstractContext} <:
       AbstractContext
    priors::TPriors
    context::TCtx
end

function PriorExtractorContext(context)
    empty_priors = OrderedDict{VarName,Any}()
    return PriorExtractorContext(empty_priors, context)
end
# A `PriorExtractorContext` always wraps another context, so it is a parent node.
NodeTrait(::PriorExtractorContext) = IsParent()

# The wrapped (child) context.
function childcontext(context::PriorExtractorContext)
    return context.context
end

# Rebuild around a new child while reusing the *same* `priors` dictionary (no copy).
# Distributions recorded through the rebuilt context therefore land in the original dict.
function setchildcontext(parent::PriorExtractorContext, child)
    recorded = parent.priors
    return PriorExtractorContext(recorded, child)
end
# Store `dist` as the prior of the single variable `vn`, overwriting any earlier
# entry for the same key, and return `dist`.
function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
    setindex!(context.priors, dist, vn)
    return dist
end
# Store one shared distribution `dist` as the prior of every variable name in `vns`.
# Entries are inserted in the iteration order of `vns`. Returns `nothing`.
function setprior!(
    context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution
)
    foreach(vn -> setindex!(context.priors, dist, vn), vns)
    return nothing
end
# Store an element-wise prior for every variable name in `vns`.
# Names and distributions are paired with Julia broadcasting rules, which `.~`
# statements follow. Returns `nothing`.
function setprior!(
    context::PriorExtractorContext,
    vns::AbstractArray{<:VarName},
    dists::AbstractArray{<:Distribution},
)
    # `zip` stops silently at the shorter argument. Whenever `dists` had a smaller but
    # broadcastable shape (e.g. a length-1 array), priors were dropped for the remaining
    # names. `broadcast` expands such shapes, and it throws `DimensionMismatch` for
    # shapes that cannot be combined instead of truncating.
    # For arrays of equal shape, the pairing and order are the same as with `zip`.
    for (vn, dist) in broadcast(tuple, vns, dists)
        context.priors[vn] = dist
    end
    return nothing
end
# Intercept a `~` statement: record `right` as the prior of `vn`, then let the wrapped
# context perform the actual assume step and return whatever it returns.
function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi)
    setprior!(context, vn, right)
    inner = childcontext(context)
    return DynamicPPL.tilde_assume(inner, right, vn, vi)
end
# Intercept a `.~` statement: record the prior(s) for `vn`, presumably an array of
# `VarName`s, then delegate to the wrapped context and return its result.
# `setprior!` must have a method for the given `(vn, right)` combination.
function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
    setprior!(context, vn, right)
    inner = childcontext(context)
    return DynamicPPL.dot_tilde_assume(inner, right, left, vn, vi)
end
"""
    extract_priors([rng::Random.AbstractRNG, ]model::Model)

Extract the priors from a model.

This is done by sampling from the model and
recording the distributions that are used to generate the samples.

Returns an `OrderedDict{VarName,Any}` that maps the `VarName` of each variable sampled
during the run to the distribution on the right-hand side of its `~` statement. For `.~`
statements, each element gets its own entry.

!!! warning
    Because the extraction is done by execution of the model, there
    are several caveats:

    1. If the distribution of one variable depends on another random variable, say,
       `y ~ Normal(x, 1)` where `x ~ Normal()`, then the extracted prior of `y`
       will have different parameters in every extraction!
    2. If the model does _not_ have static support, say,
       `n ~ Categorical(ones(10) ./ 10); x ~ MvNormal(zeros(n), I)`, then the
       extracted priors themselves will be different between extractions,
       not just their parameters.

    Both of these caveats are demonstrated below.

# Examples

## Changing parameters

```jldoctest
julia> using Distributions, StableRNGs

julia> rng = StableRNG(42);

julia> @model function model_dynamic_parameters()
           x ~ Normal(0, 1)
           y ~ Normal(x, 1)
       end;

julia> model = model_dynamic_parameters();

julia> extract_priors(rng, model)[@varname(y)]
Normal{Float64}(μ=-0.6702516921145671, σ=1.0)

julia> extract_priors(rng, model)[@varname(y)]
Normal{Float64}(μ=1.3736306979834252, σ=1.0)
```

## Changing support

```jldoctest
julia> using LinearAlgebra, Distributions, StableRNGs

julia> rng = StableRNG(42);

julia> @model function model_dynamic_support()
           n ~ Categorical(ones(10) ./ 10)
           x ~ MvNormal(zeros(n), I)
       end;

julia> model = model_dynamic_support();

julia> length(extract_priors(rng, model)[@varname(x)])
6

julia> length(extract_priors(rng, model)[@varname(x)])
9
```
"""
extract_priors(model::Model) = extract_priors(Random.default_rng(), model)

function extract_priors(rng::Random.AbstractRNG, model::Model)
    # The inner `SamplingContext(rng)` drives sampling from `rng`. The wrapping
    # `PriorExtractorContext` records the distribution used at each `~` along the way.
    context = PriorExtractorContext(SamplingContext(rng))
    # Only the recorded priors are wanted. The evaluation result and the fresh
    # `VarInfo` are discarded.
    evaluate!!(model, VarInfo(), context)
    return context.priors
end