Skip to content

Commit

Permalink
Merge e62fba1 into 9deef5e
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Apr 18, 2024
2 parents 9deef5e + e62fba1 commit 13f34c7
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ Sometimes it can be useful to extract the priors of a model. This is the possibl
extract_priors
```

Safe extraction of realizations from a given [`AbstractVarInfo`](@ref) can be done using [`extract_realizations`](@ref).

```@docs
extract_realizations
```

```@docs
NamedDist
```
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export AbstractVarInfo,
getargnames,
generated_quantities,
extract_priors,
extract_realizations,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
173 changes: 173 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,176 @@ function fixed(context::FixedContext)
# precedence over decendants of `context`.
return merge(context.values, fixed(childcontext(context)))
end

"""
RealizationExtractorContext
A context that is used to extract realizations from a model.
This is particularly useful when working in unconstrained space, but one
wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct RealizationExtractorContext{T,C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
"child context"
context::C
end

RealizationExtractorContext(values) = RealizationExtractorContext(values, DefaultContext())
function RealizationExtractorContext(context::AbstractContext)
return RealizationExtractorContext(OrderedDict(), context)
end

NodeTrait(::RealizationExtractorContext) = IsParent()
childcontext(context::RealizationExtractorContext) = context.context
function setchildcontext(context::RealizationExtractorContext, child)
return RealizationExtractorContext(context.values, child)
end

function Base.push!(context::RealizationExtractorContext, vn::VarName, value)
return setindex!(context.values, value, vn)
end

function broadcast_push!(context::RealizationExtractorContext, vns, dists, values)
return push!.((context,), vns, values)
end

# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`.
function broadcast_push!(
context::RealizationExtractorContext, vns::AbstractVector, values::AbstractMatrix
)
for (vn, col) in zip(vns, eachcol(values))
push!(context, vn, col)
end
end

# `tilde_asssume`
function tilde_assume(context::RealizationExtractorContext, right, vn, vi)
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
# Save the value.
push!(context, vn, value)
# Pass on.
return value, logp, vi
end
function tilde_assume(
rng::Random.AbstractRNG, context::RealizationExtractorContext, sampler, right, vn, vi
)
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
# Save the value.
push!(context, vn, value)
# Pass on.
return value, logp, vi
end

# `dot_tilde_assume`
function dot_tilde_assume(context::RealizationExtractorContext, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(childcontext(context), right, left, vn, vi)

# Save the value.
# FIXME: This is not going to work for arbitrary broadcasting.
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
end
function dot_tilde_assume(
rng::Random.AbstractRNG,
context::RealizationExtractorContext,
sampler,
right,
left,
vn,
vi,
)
value, logp, vi = dot_tilde_assume(
rng, childcontext(context), sampler, right, left, vn, vi
)
# Save the value.
# FIXME: This is not going to work for arbitrary broadcasting.
_right, _left, _vns = unwrap_right_left_vns(right, left, vn)
broadcast_push!(context, _vns, value)

return value, logp, vi
end

"""
extract_realizations([rng::Random.AbstractRNG, ]model::Model[, varinfo::AbstractVarInfo])
Extract realizations from the `model` for a given `varinfo` through a evaluation of the model.
If no `varinfo` is provided, then this is effectively the same as
[`Base.rand(rng::Random.AbstractRNG, model::Model)`].
More specifically, this method attempts to extract the realization _as seen in the model_.
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.
# Examples
## When `VarInfo` fails
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
```jldoctest
julia> using Distributions, StableRNGs
julia> rng = StableRNG(42);
julia> @model function model_changing_support()
x ~ Bernoulli(0.5)
y ~ x == 1 ? Uniform(0, 1) : Uniform(11, 12)
end;
julia> model = model_changing_support();
julia> # Construct initial type-stable `VarInfo`.
varinfo = VarInfo(rng, model);
julia> # Link it so it works in unconstrained space.
varinfo_linked = DynamicPPL.link(varinfo, model);
julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`.
# Flip `x` so we hit the other support of `y`.
θ = [!varinfo[@varname(x)], rand(rng)];
julia> # Update the `VarInfo` with the new values.
varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ);
julia> # Determine the expected support of `y`.
lb, ub = θ[1] == 1 ? (0, 1) : (11, 12)
(0, 1)
julia> # Approach 1: Convert back to constrained space using `invlink` and extract.
varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model);
julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions
# used in the very first model evaluation, hence the support of `y`
# is not updated even though `x` has changed.
lb ≤ varinfo_invlinked[@varname(y)] ≤ ub
false
julia> # Approach 2: Extract realizations using `extract_realizations`.
# (✓) `extract_realizations` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ extract_realizations(model, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function extract_realizations(model::Model, varinfo::AbstractVarInfo=VarInfo())
return extract_realizations(Random.default_rng(), model, varinfo)
end
function extract_realizations(
rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo=VarInfo()
)
context = RealizationExtractorContext(DefaultContext())
evaluate!!(model, varinfo, context)
return context.values
end

0 comments on commit 13f34c7

Please sign in to comment.