From 7d312cd593f8b0b6daab3eeed75e39f6ee8adc00 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 13 Jul 2023 20:14:49 +0100 Subject: [PATCH] Add `fix` and `unfix` (#488) * aded FixedContext and everything that goes with it * initial work on making fix compat with compiler * added support for dot tilde * exported and added testing for fix and unfix * added equivalent support to condition * fixed some docstrings * added lots of documentation plus some doctests for fixing * added docs on fix * bump patch version * renamd getvalue and hasvalue for contexts to more descriptive get_conditioned_value and has_conditioned_value * formatting * Update src/model.jl * fixeed typo in docstring * fixed docs * Apply suggestions from code review Co-authored-by: David Widmann * Update src/model.jl * Apply suggestions from code review * Update Project.toml * Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: David Widmann Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 2 +- docs/src/api.md | 29 +++- src/DynamicPPL.jl | 2 + src/compiler.jl | 52 ++++++- src/contexts.jl | 230 +++++++++++++++++++++++---- src/model.jl | 387 +++++++++++++++++++++++++++++++++++++++++++++- test/contexts.jl | 51 ++++-- 7 files changed, 705 insertions(+), 48 deletions(-) diff --git a/Project.toml b/Project.toml index eb03e7c7c..f82d47148 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.1" +version = "0.23.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/docs/src/api.md b/docs/src/api.md index b30c52688..7b516c3e3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -82,6 +82,34 @@ Similarly, one can specify with [`AbstractPPL.decondition`](@ref) that certain, decondition ``` +## Fixing and unfixing + +We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref). + +This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings, +but they are indeed different operations: + + - `condition`ed variables are considered to be _observations_, and are thus + included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref), + but not in [`logprior`](@ref). + - `fix`ed variables are considered to be _constant_, and are thus not included + in any log-probability computations. + +The differences are more clearly spelled out in the docstring of [`fix`](@ref) below. + +```@docs +fix +DynamicPPL.fixed +``` + +The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above. + +Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning: + +```@docs +unfix +``` + ## Utilities It is possible to manually increase (or decrease) the accumulated log density from within a model function. @@ -327,4 +355,3 @@ dot_tilde_assume tilde_observe dot_tilde_observe ``` - diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index be65c37b7..9ab1c51f5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -119,6 +119,8 @@ export AbstractVarInfo, pointwise_loglikelihoods, condition, decondition, + fix, + unfix, # Convenience macros @addlogprob!, @submodel diff --git a/src/compiler.jl b/src/compiler.jl index bdd413630..ffdcd4755 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -66,8 +66,8 @@ function contextual_isassumption(context::AbstractContext, vn) return contextual_isassumption(NodeTrait(context), context, vn) end function contextual_isassumption(context::ConditionContext, vn) - if hasvalue(context, vn) - val = getvalue(context, vn) + if hasconditioned(context, vn) + val = getconditioned(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return true @@ -76,7 +76,7 @@ function contextual_isassumption(context::ConditionContext, vn) end end - # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` + # We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}` # so we defer to `childcontext` if we haven't concluded that anything yet. return contextual_isassumption(childcontext(context), vn) end @@ -84,6 +84,40 @@ function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) end +isfixed(expr, vn) = false +isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn)) + +""" + contextual_isfixed(context, vn) + +Return `true` if `vn` is considered fixed by `context`. +""" +contextual_isfixed(::IsLeaf, context, vn) = false +function contextual_isfixed(::IsParent, context, vn) + return contextual_isfixed(childcontext(context), vn) +end +function contextual_isfixed(context::AbstractContext, vn) + return contextual_isfixed(NodeTrait(context), context, vn) +end +function contextual_isfixed(context::PrefixContext, vn) + return contextual_isfixed(childcontext(context), prefix(context, vn)) +end +function contextual_isfixed(context::FixedContext, vn) + if hasfixed(context, vn) + val = getfixed(context, vn) + # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? + if eltype(val) >: Missing && val === missing + return false + else + return true + end + end + + # We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}` + # so we defer to `childcontext` if we haven't concluded that anything yet. + return contextual_isfixed(childcontext(context), vn) +end + # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x maybe_view(x::Expr) = :(@views($x)) @@ -341,12 +375,14 @@ function generate_tilde(left, right) $(AbstractPPL.drop_escape(varname(left))), $dist ) $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $isassumption + if $(DynamicPPL.isfixed(left, vn)) + $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) + elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) + $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( @@ -400,12 +436,14 @@ function generate_dot_tilde(left, right) $(AbstractPPL.drop_escape(varname(left))), $right ) $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $isassumption + if $(DynamicPPL.isfixed(left, vn)) + $left .= $(DynamicPPL.getfixed_nested)(__context__, $vn) + elseif $isassumption $(generate_dot_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) + $left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn) end $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( diff --git a/src/contexts.jl b/src/contexts.jl index 36b86add0..83da5d929 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -320,67 +320,67 @@ childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ - hasvalue(context::AbstractContext, vn::VarName) + hasconditioned(context::AbstractContext, vn::VarName) Return `true` if `vn` is found in `context`. """ -hasvalue(context::AbstractContext, vn::VarName) = false -hasvalue(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) -function hasvalue(context::ConditionContext, vns::AbstractArray{<:VarName}) +hasconditioned(context::AbstractContext, vn::VarName) = false +hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) +function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) return all(Base.Fix1(hasvalue, context.values), vns) end """ - getvalue(context::AbstractContext, vn::VarName) + getconditioned(context::AbstractContext, vn::VarName) Return value of `vn` in `context`. """ -function getvalue(context::AbstractContext, vn::VarName) +function getconditioned(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getvalue(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) +getconditioned(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) """ - hasvalue_nested(context, vn) + hasconditioned_nested(context, vn) Return `true` if `vn` is found in `context` or any of its descendants. -This is contrast to [`hasvalue(::AbstractContext, ::VarName)`](@ref) which only checks +This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. """ -function hasvalue_nested(context::AbstractContext, vn) - return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) +function hasconditioned_nested(context::AbstractContext, vn) + return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) end -hasvalue_nested(::IsLeaf, context, vn) = hasvalue(context, vn) -function hasvalue_nested(::IsParent, context, vn) - return hasvalue(context, vn) || hasvalue_nested(childcontext(context), vn) +hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) +function hasconditioned_nested(::IsParent, context, vn) + return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasvalue_nested(context::PrefixContext, vn) - return hasvalue_nested(childcontext(context), prefix(context, vn)) +function hasconditioned_nested(context::PrefixContext, vn) + return hasconditioned_nested(childcontext(context), prefix(context, vn)) end """ - getvalue_nested(context, vn) + getconditioned_nested(context, vn) Return the value of the parameter corresponding to `vn` from `context` or its descendants. -This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`, +This is contrast to [`getconditioned`](@ref) which only returns the value `vn` in `context`, not recursively looking into its descendants. """ -function getvalue_nested(context::AbstractContext, vn) - return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn) +function getconditioned_nested(context::AbstractContext, vn) + return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) end -function getvalue_nested(::IsLeaf, context, vn) +function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getvalue_nested(context::PrefixContext, vn) - return getvalue_nested(childcontext(context), prefix(context, vn)) +function getconditioned_nested(context::PrefixContext, vn) + return getconditioned_nested(childcontext(context), prefix(context, vn)) end -function getvalue_nested(::IsParent, context, vn) - return if hasvalue(context, vn) - getvalue(context, vn) +function getconditioned_nested(::IsParent, context, vn) + return if hasconditioned(context, vn) + getconditioned(context, vn) else - getvalue_nested(childcontext(context), vn) + getconditioned_nested(childcontext(context), vn) end end @@ -488,3 +488,179 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return merge(context.values, conditioned(childcontext(context))) end + +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}} +const DictFixedContext = FixedContext{<:AbstractDict} + +FixedContext(values) = FixedContext(values, DefaultContext()) + +# Try to avoid nested `FixedContext`. +function FixedContext(values::NamedTuple, context::NamedFixedContext) + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `FixedContext`. + return FixedContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::FixedContext) + return print(io, "FixedContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(::FixedContext) = IsParent() +childcontext(context::FixedContext) = context.context +setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) + +""" + hasfixed(context::AbstractContext, vn::VarName) + +Return `true` if a fixed value for `vn` is found in `context`. +""" +hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) +function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(hasvalue, context.values), vns) +end + +""" + getfixed(context::AbstractContext, vn::VarName) + +Return the fixed value of `vn` in `context`. +""" +function getfixed(context::AbstractContext, vn::VarName) + return error("context $(context) does not contain value for $vn") +end +getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn) + +""" + hasfixed_nested(context, vn) + +Return `true` if a fixed value for `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. +""" +function hasfixed_nested(context::AbstractContext, vn) + return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) +end +hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) +function hasfixed_nested(::IsParent, context, vn) + return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) +end +function hasfixed_nested(context::PrefixContext, vn) + return hasfixed_nested(childcontext(context), prefix(context, vn)) +end + +""" + getfixed_nested(context, vn) + +Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getfixed_nested(context::AbstractContext, vn) + return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) +end +function getfixed_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getfixed_nested(context::PrefixContext, vn) + return getfixed_nested(childcontext(context), prefix(context, vn)) +end +function getfixed_nested(::IsParent, context, vn) + return if hasfixed(context, vn) + getfixed(context, vn) + else + getfixed_nested(childcontext(context), vn) + end +end + +""" + fix([context::AbstractContext,] values::NamedTuple) + fix([context::AbstractContext]; values...) + +Return `FixedContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`unfix`](@ref) +""" +fix(; values...) = fix(NamedTuple(values)) +fix(values::NamedTuple) = fix(DefaultContext(), values) +function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix((value, values...)) +end +function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) + return fix(DefaultContext(), values) +end +fix(context::AbstractContext, values::NamedTuple{()}) = context +function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) + return FixedContext(values, context) +end +function fix(context::AbstractContext; values...) + return fix(context, NamedTuple(values)) +end +function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix(context, (value, values...)) +end +function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) + return fix(context, Dict(values)) +end + +""" + unfix(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer fixed. + +Note that this recursively traverses contexts, unfixing all along the way. + +See also: [`fix`](@ref) +""" +unfix(::IsLeaf, context, args...) = context +function unfix(::IsParent, context, args...) + return setchildcontext(context, unfix(childcontext(context), args...)) +end +function unfix(context, args...) + return unfix(NodeTrait(context), context, args...) +end +function unfix(context::FixedContext) + return unfix(childcontext(context)) +end +function unfix(context::FixedContext, sym) + return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, sym, syms...) + return unfix( + fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), + syms..., + ) +end + +function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, vn::VarName) + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) +end + +""" + fixed(context::AbstractContext) + +Return the values that are fixed under `context`. + +Note that this will recursively traverse the context stack and return +a merged version of the fix values. +""" +fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) +fixed(::IsLeaf, context) = () +fixed(::IsParent, context) = fixed(childcontext(context)) +function fixed(context::FixedContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `fixed` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return merge(context.values, fixed(childcontext(context))) +end diff --git a/src/model.jl b/src/model.jl index f6f807652..1613efab6 100644 --- a/src/model.jl +++ b/src/model.jl @@ -308,7 +308,7 @@ This is essentially the inverse of [`condition`](@ref). This also means that it suffers from the same limitiations. Note that currently we only support `variables` to take on explicit values -provided to `condition. +provided to `condition`. # Examples ```jldoctest decondition @@ -404,7 +404,6 @@ julia> # (✓) this works though julia> m = deconditioned_model_2(); (m[1] ≠ 1.0 && m[2] == 2.0) true ``` - """ function AbstractPPL.decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) @@ -420,7 +419,7 @@ observations(model::Model) = conditioned(model) """ conditioned(model::Model) -Return `NamedTuple` of values that are conditioned on under `model`. +Return the conditioned values in `model`. # Examples ```jldoctest @@ -467,6 +466,388 @@ VarName[] """ conditioned(model::Model) = conditioned(model.context) +""" + fix(model::Model; values...) + fix(model::Model, values::NamedTuple) + +Return a `Model` which now treats the variables in `values` as fixed. + +See also: [`unfix`](@ref), [`fixed`](@ref) + +# Examples +## Simple univariate model +```jldoctest fix +julia> using Distributions + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + return (; m=m, x=x) + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> m, x = model(); (m ≠ 1.0 && x ≠ 100.0) +true + +julia> # Create a new instance which treats `x` as observed + # with value `100.0`, and similarly for `m=1.0`. + fixed_model = fix(model, x=100.0, m=1.0); + +julia> m, x = fixed_model(); (m == 1.0 && x == 100.0) +true + +julia> # Let's only fix on `x = 100.0`. + fixed_model = fix(model, x = 100.0); + +julia> m, x = fixed_model(); (m ≠ 1.0 && x == 100.0) +true +``` + +The above uses a `NamedTuple` to hold the fixed variables, which allows us to perform some +additional optimizations; in many cases, the above has zero runtime-overhead. + +But we can also use a `Dict`, which offers more flexibility in the fixing +(see examples further below) but generally has worse performance than the `NamedTuple` +approach: + +```jldoctest fix +julia> fixed_model_dict = fix(model, Dict(@varname(x) => 100.0)); + +julia> m, x = fixed_model_dict(); (m ≠ 1.0 && x == 100.0) +true + +julia> # Alternative: pass `Pair{<:VarName}` as positional argument. + fixed_model_dict = fix(model, @varname(x) => 100.0, ); + +julia> m, x = fixed_model_dict(); (m ≠ 1.0 && x == 100.0) +true +``` + +## Fix only a part of a multivariate variable + +We can not only fix multivariate random variables, but +we can also use the standard mechanism of setting something to `missing` +in the call to `fix` to only fix a part of the variable. + +```jldoctest fix +julia> @model function demo_mv(::Type{TV}=Float64) where {TV} + m = Vector{TV}(undef, 2) + m[1] ~ Normal() + m[2] ~ Normal() + return m + end +demo_mv (generic function with 4 methods) + +julia> model = demo_mv(); + +julia> fixed_model = fix(model, m = [missing, 1.0]); + +julia> # (✓) `m[1]` sampled while `m[2]` is fixed + m = fixed_model(); (m[1] ≠ 1.0 && m[2] == 1.0) +true +``` + +Intuitively one might also expect to be able to write something like `fix(model, var\"m[1]\" = 1.0, )`. +Unfortunately this is not supported as it has the potential of increasing compilation +times but without offering any benefit with respect to runtime: + +```jldoctest fix +julia> # (×) `m[2]` is not set to 1.0. + m = fix(model, var"m[2]" = 1.0)(); m[2] == 1.0 +false +``` + +But you _can_ do this if you use a `Dict` as the underlying storage instead: + +```jldoctest fix +julia> # Alternative: `fix(model, Dict(@varname(m[2] => 1.0)))` + # (✓) `m[2]` is set to 1.0. + m = fix(model, @varname(m[2]) => 1.0)(); (m[1] ≠ 1.0 && m[2] == 1.0) +true +``` + +## Nested models + +`fix` of course also supports the use of nested models through +the use of [`@submodel`](@ref). + +```jldoctest fix +julia> @model demo_inner() = m ~ Normal() +demo_inner (generic function with 2 methods) + +julia> @model function demo_outer() + @submodel m = demo_inner() + return m + end +demo_outer (generic function with 2 methods) + +julia> model = demo_outer(); + +julia> model() ≠ 1.0 +true + +julia> fixed_model = model | (m = 1.0, ); + +julia> fixed_model() +1.0 +``` + +But one needs to be careful when prefixing variables in the nested models: + +```jldoctest fix +julia> @model function demo_outer_prefix() + @submodel prefix="inner" m = demo_inner() + return m + end +demo_outer_prefix (generic function with 2 methods) + +julia> # (×) This doesn't work now! + fixed_model = demo_outer_prefix() | (m = 1.0, ); + +julia> fixed_model() == 1.0 +false + +julia> # (✓) `m` in `demo_inner` is referred to as `inner.m` internally, so we do: + fixed_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); + +julia> fixed_model() +1.0 + +julia> # Note that the above `var"..."` is just standard Julia syntax: + keys((var"inner.m" = 1.0, )) +(Symbol("inner.m"),) +``` + +And similarly when using `Dict`: + +```jldoctest fix +julia> fixed_model_dict = demo_outer_prefix() | (@varname(var"inner.m") => 1.0); + +julia> fixed_model_dict() +1.0 +``` + +The difference is maybe more obvious once we look at how these different +in their trace/`VarInfo`: + +```jldoctest fix +julia> keys(VarInfo(demo_outer())) +1-element Vector{VarName{:m, Setfield.IdentityLens}}: + m + +julia> keys(VarInfo(demo_outer_prefix())) +1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}: + inner.m +``` + +From this we can tell what the correct way to fix `m` within `demo_inner` +is in the two different models. + +## Difference from `condition` + +A very similar functionality is also provided by [`condition`](@ref) which, +not surprisingly, _conditions_ variables instead of fixing them. The only +difference between fixing and conditioning is as follows: +- `condition`ed variables are considered to be observations, and are thus + included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref), + but not in [`logprior`](@ref). +- `fix`ed variables are considered to be constant, and are thus not included + in any log-probability computations. + +```juliadoctest fix +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + return (; m=m, x=x) + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> model_fixed = fix(model, m = 1.0); + +julia> model_conditioned = condition(model, m = 1.0); + +julia> logjoint(model_fixed, (x=1.0,)) +-0.9189385332046728 + +julia> # Different! + logjoint(model_conditioned, (x=1.0,)) +-2.3378770664093453 + +julia> # And the difference is the missing log-probability of `m`: + logjoint(model_fixed, (x=1.0,)) + logpdf(Normal(), 1.0) == logjoint(model_conditioned, (x=1.0,)) +true +``` +""" +fix(model::Model; values...) = contextualize(model, fix(model.context; values...)) +function fix(model::Model, value, values...) + return contextualize(model, fix(model.context, value, values...)) +end + +""" + unfix(model::Model) + unfix(model::Model, variables...) + +Return a `Model` for which `variables...` are _not_ considered fixed. +If no `variables` are provided, then all variables currently considered fixed +will no longer be. + +This is essentially the inverse of [`fix`](@ref). This also means that +it suffers from the same limitiations. + +Note that currently we only support `variables` to take on explicit values +provided to `fix`. + +# Examples +```jldoctest unfix +julia> using Distributions + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + return (; m=m, x=x) + end +demo (generic function with 2 methods) + +julia> fixed_model = fix(demo(), m = 1.0, x = 10.0); + +julia> fixed_model() +(m = 1.0, x = 10.0) + +julia> # By specifying the `VarName` to `unfix`. + model = unfix(fixed_model, @varname(m)); + +julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0) +true + +julia> # When `NamedTuple` is used as the underlying, you can also provide + # the symbol directly (though the `@varname` approach is preferable if + # if the variable is known at compile-time). + model = unfix(fixed_model, :m); + +julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0) +true + +julia> # `unfix` multiple at once: + (m, x) = unfix(model, :m, :x)(); (m ≠ 1.0 && x ≠ 10.0) +true + +julia> # `unfix` without any symbols will `unfix` all variables. + (m, x) = unfix(model)(); (m ≠ 1.0 && x ≠ 10.0) +true + +julia> # Usage of `Val` to perform `unfix` at compile-time if possible + # is also supported. + model = unfix(fixed_model, Val{:m}()); + +julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0) +true +``` + +Similarly when using a `Dict`: + +```jldoctest unfix +julia> fixed_model_dict = fix(demo(), @varname(m) => 1.0, @varname(x) => 10.0); + +julia> fixed_model_dict() +(m = 1.0, x = 10.0) + +julia> unfixed_model_dict = unfix(fixed_model_dict, @varname(m)); + +julia> (m, x) = unfixed_model_dict(); m ≠ 1.0 && x == 10.0 +true +``` + +But, as mentioned, `unfix` is only supported for variables explicitly +provided to `fix` earlier: + +```jldoctest unfix +julia> @model function demo_mv(::Type{TV}=Float64) where {TV} + m = Vector{TV}(undef, 2) + m[1] ~ Normal() + m[2] ~ Normal() + return m + end +demo_mv (generic function with 4 methods) + +julia> model = demo_mv(); + +julia> fixed_model = fix(model, @varname(m) => [1.0, 2.0]); + +julia> fixed_model() +2-element Vector{Float64}: + 1.0 + 2.0 + +julia> unfixed_model = unfix(fixed_model, @varname(m[1])); + +julia> unfixed_model() # (×) `m[1]` is still fixed +2-element Vector{Float64}: + 1.0 + 2.0 + +julia> # (✓) this works though + unfixed_model_2 = fix(unfixed_model, @varname(m[1]) => missing); + +julia> m = unfixed_model_2(); (m[1] ≠ 1.0 && m[2] == 2.0) +true +``` +""" +unfix(model::Model, syms...) = contextualize(model, unfix(model.context, syms...)) + +""" + fixed(model::Model) + +Return the fixed values in `model`. + +# Examples +```jldoctest +julia> using Distributions + +julia> using DynamicPPL: fixed, contextualize + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + end +demo (generic function with 2 methods) + +julia> m = demo(); + +julia> # Returns all the variables we have fixed on + their values. + fixed(fix(m, x=100.0, m=1.0)) +(x = 100.0, m = 1.0) + +julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). + cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); + +julia> fixed(cm) +(x = 100.0, m = 1.0) + +julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, + # `a.m` is treated as a random variable. + keys(VarInfo(cm)) +1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}: + a.m + +julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. + cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0); + +julia> fixed(cm).x +100.0 + +julia> fixed(cm).var"a.m" +1.0 + +julia> keys(VarInfo(cm)) # <= no variables are sampled +VarName[] +``` +""" +fixed(model::Model) = fixed(model.context) + """ (model::Model)([rng, varinfo, sampler, context]) diff --git a/test/contexts.jl b/test/contexts.jl index 5162dc61c..9b0427cd0 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -11,10 +11,10 @@ using DynamicPPL: PointwiseLikelihoodContext, contextual_isassumption, ConditionContext, - hasvalue, - getvalue, - hasvalue_nested, - getvalue_nested + hasconditioned, + getconditioned, + hasconditioned_nested, + getconditioned_nested # Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext @@ -178,11 +178,11 @@ end end end - @testset "getvalue_nested & hasvalue_nested" begin + @testset "getconditioned_nested & hasconditioned_nested" begin @testset "$context" for context in contexts fake_vn = VarName{gensym(:x)}() - @test !hasvalue_nested(context, fake_vn) - @test_throws ErrorException getvalue_nested(context, fake_vn) + @test !hasconditioned_nested(context, fake_vn) + @test_throws ErrorException getconditioned_nested(context, fake_vn) if any(Base.Fix2(isa, ConditionContext), context) # `ConditionContext` specific. @@ -201,9 +201,9 @@ end for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # `vn_child` should be in `context`. - @test hasvalue_nested(context, vn_child) + @test hasconditioned_nested(context, vn_child) # Value should be the same as extracted above. - @test getvalue_nested(context, vn_child) === + @test getconditioned_nested(context, vn_child) === get(val, getlens(vn_child)) end end @@ -253,4 +253,37 @@ end @test SamplingContext(SampleFromPrior(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context end + + @testset "FixedContext" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + retval = model() + s, m = retval.s, retval.m + + # Keword approach. + model_fixed = fix(model; s=s) + @test model_fixed().s == s + @test model_fixed().m != m + # A fixed variable should not contribute at all to the logjoint. + # Assuming `condition` is correctly implemented, the following should hold. + @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) + + # Positional approach. + model_fixed = fix(model, (; s)) + @test model_fixed().s == s + @test model_fixed().m != m + @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) + + # Pairs approach. + model_fixed = fix(model, @varname(s) => s) + @test model_fixed().s == s + @test model_fixed().m != m + @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) + + # Dictionary approach. + model_fixed = fix(model, Dict(@varname(s) => s)) + @test model_fixed().s == s + @test model_fixed().m != m + @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) + end + end end