Skip to content

Commit

Permalink
went back to using DynamicPPL.condition rather than using custom
Browse files Browse the repository at this point in the history
`GibbsContext` while we wait for
TuringLang/DynamicPPL.jl#563 to be merged
  • Loading branch information
torfjelde committed Nov 19, 2023
1 parent ba8c6e1 commit 53bd707
Showing 1 changed file with 2 additions and 119 deletions.
121 changes: 2 additions & 119 deletions src/mcmc/gibbs_new.jl
@@ -1,103 +1,3 @@
# Basically like a `DynamicPPL.FixedContext` but
# 1. Hijacks the tilde pipeline to fix variables.
# 2. Computes the log-probability of the fixed variables.
#
# Purpose: avoid triggering resampling of variables we're conditioning on.
# - Using standard `DynamicPPL.condition` results in conditioned variables being treated
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
# rather than only for the "true" observations.
# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline
# rather than the `observe` pipeline for the conditioned variables.
struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
values::Values
context::Ctx
end

Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext())

DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::GibbsContext) = context.context
DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext)

# has and get
has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn)
function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName})
return all(Base.Fix1(has_conditioned_gibbs, context), vns)
end

get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn)
function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName})
return map(Base.Fix1(get_conditioned_gibbs, context), vns)
end

# Tilde pipeline
function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vn)
value = get_conditioned_gibbs(context, vn)
return value, logpdf(right, value), vi
end

# Otherwise, falls back to the default behavior.
return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi)
end

function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vn)
value = get_conditioned_gibbs(context, vn)
return value, logpdf(right, value), vi
end

# Otherwise, falls back to the default behavior.
return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi)
end

# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline.
make_broadcastable(x) = x
make_broadcastable(dist::Distribution) = tuple(dist)

# Need the following two methods to properly support broadcasting over columns.
broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x))
function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix)
return loglikelihood(dist, x)
end

reconstruct_getvalue(dist, x) = x
function reconstruct_getvalue(
dist::MultivariateDistribution,
x::AbstractVector{<:AbstractVector{<:Real}}
)
return reduce(hcat, x[2:end]; init=x[1])
end

function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return value, broadcast_logpdf(right, values), vi
end

# Otherwise, falls back to the default behavior.
return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi)
end

function DynamicPPL.dot_tilde_assume(
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi
)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return values, broadcast_logpdf(right, values), vi
end

# Otherwise, falls back to the default behavior.
return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi)
end


preferred_value_type(::AbstractVarInfo) = OrderedDict
preferred_value_type(::SimpleVarInfo{<:NamedTuple}) = NamedTuple
function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo)
Expand All @@ -108,28 +8,10 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo)
return namedtuple_compatible ? NamedTuple : OrderedDict
end

# No-op if no values are provided.
condition_gibbs(context::DynamicPPL.AbstractContext) = context
# For `NamedTuple` and `AbstractDict` we just construct the context.
function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict})
return GibbsContext(values, context)
end
# If we get more than one argument, we just recurse.
function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...)
return condition_gibbs(
condition_gibbs(context, value),
values...
)
end
# For `AbstractVarInfo` we just extract the values.
function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo)
function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo)
# TODO: Determine when it's okay to use `NamedTuple` and use that instead.
return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
end
# Allow calling this on a `Model` directly.
function condition_gibbs(model::Model, values...)
return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...))
end


"""
Expand Down Expand Up @@ -162,6 +44,7 @@ true
function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos)
# TODO: Check if this is known at compile-time if `varinfos isa Tuple`.
return condition_gibbs(
return condition(
model,
filter(Base.Fix1(!==, target_varinfo), varinfos)...
)
Expand Down

0 comments on commit 53bd707

Please sign in to comment.