diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 498adf60a..6dc0ac022 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -1,103 +1,3 @@ -# Basically like a `DynamicPPL.FixedContext` but -# 1. Hijacks the tilde pipeline to fix variables. -# 2. Computes the log-probability of the fixed variables. -# -# Purpose: avoid triggering resampling of variables we're conditioning on. -# - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. -# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to -# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable -# rather than only for the "true" observations. -# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline -# rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values - context::Ctx -end - -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) - -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::GibbsContext) = context.context -DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) - -# has and get -has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) -function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) -end - -get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) -function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(get_conditioned_gibbs, context), vns) -end - -# Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) -end - -function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) -end - -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, - x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - -function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, values), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi -) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return values, broadcast_logpdf(right, values), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) -end - - preferred_value_type(::AbstractVarInfo) = OrderedDict preferred_value_type(::SimpleVarInfo{<:NamedTuple}) = NamedTuple function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) @@ -108,28 +8,10 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : OrderedDict end -# No-op if no values are provided. -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. -function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs( - condition_gibbs(context, value), - values... - ) -end -# For `AbstractVarInfo` we just extract the values. -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) +function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end -# Allow calling this on a `Model` directly. -function condition_gibbs(model::Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end """ @@ -162,6 +44,7 @@ true function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. return condition_gibbs( + return condition( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... )