diff --git a/docs/src/api.md b/docs/src/api.md index 2dfda9119..0e4012e02 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -206,7 +206,8 @@ DynamicPPL.link!! DynamicPPL.invlink!! DynamicPPL.default_transformation DynamicPPL.maybe_invlink_before_eval!! -``` +DynamicPPL.reconstruct +``` #### Utils diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 594084d66..8904cfe81 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -43,7 +43,6 @@ export AbstractVarInfo, push!!, empty!!, getlogp, - resetlogp!, setlogp!!, acclogp!!, resetlogp!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index acd51e288..116890d7b 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -553,6 +553,99 @@ variables `x` would return """ function tonamedtuple end +# TODO: Clean up all this linking stuff once and for all! +""" + with_logabsdet_jacobian_and_reconstruct([f, ]dist, x) + +Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting +value is reconstructed to the correct type and shape according to `dist`. +""" +function with_logabsdet_jacobian_and_reconstruct(f, dist, x) + x_recon = reconstruct(f, dist, x) + return with_logabsdet_jacobian(f, x_recon) +end + +# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can +# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden. +# NOTE: `reconstruct` is no-op if `val` is already of correct shape. +""" + reconstruct_and_link(dist, val) + reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) + +Return linked `val` but reconstruct before linking, if necessary. + +Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily +return a reconstructed value, i.e. a value of the same type and shape as expected +by `dist`. + +See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref). 
+""" +reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val)) +reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val) +function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val) + return reconstruct_and_link(dist, val) +end + +""" + invlink_and_reconstruct(dist, val) + invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) + +Return invlinked and reconstructed `val`. + +See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref). +""" +invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val)) +function invlink_and_reconstruct(dist, val) + return invlink_and_reconstruct(invlink_transform(dist), dist, val) +end +function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val) + return invlink_and_reconstruct(dist, val) +end + +""" + maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) + +Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`. +""" +function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) + return if istrans(vi, vn) + reconstruct_and_link(vi, vn, dist, val) + else + reconstruct(dist, val) + end +end + +""" + maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) + +Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`. +""" +function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) + return if istrans(vi, vn) + invlink_and_reconstruct(vi, vn, dist, val) + else + reconstruct(dist, val) + end +end + +""" + invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x]) + +Invlink `x` and compute the logpdf under `dist` including correction from +the invlink-transformation. + +If `x` is not provided, `getval(vi, vn)` will be used. 
+""" +function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist) + return invlink_with_logpdf(vi, vn, dist, getval(vi, vn)) +end +function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) + # NOTE: Will this cause type-instabilities or will union-splitting save us? + f = istrans(vi, vn) ? invlink_transform(dist) : identity + x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y) + return x, logpdf(dist, x) + logjac +end + # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. increment_num_produce!(::AbstractVarInfo) = nothing diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 1078a0e18..1f0641007 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -194,8 +194,8 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r = vi[vn, dist] - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi + r, logp = invlink_with_logpdf(vi, vn, dist) + return r, logp, vi end # SampleFromPrior and SampleFromUniform @@ -211,7 +211,9 @@ function assume( if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn) + BangBang.setindex!!( + vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn + ) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -220,7 +222,7 @@ function assume( else r = init(rng, dist, sampler) if istrans(vi) - push!!(vi, vn, link(dist, r), dist, sampler) + push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) else @@ -228,7 +230,9 @@ function assume( end end - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi + # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. 
+ logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) + return r, logpdf(dist, r) - logjac, vi end # default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) @@ -470,7 +474,11 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn) + setindex!!( + vi, + vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])), + vn, + ) setorder!(vi, vn, get_num_produce(vi)) end else @@ -508,13 +516,17 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn) + setindex!!( + vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn + ) setorder!(vi, vn, get_num_produce(vi)) end else # r = reshape(vi[vec(vns)], size(vns)) + # FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)` + # and fix the lines below. r_raw = getindex_raw(vi, vec(vns)) - r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns))) + r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) @@ -525,7 +537,7 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) - push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) + push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. 
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a445bf87a..68b3d0ae2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -290,7 +290,7 @@ end # `NamedTuple` function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return maybe_invlink(vi, vn, dist, getindex(vi, vn)) + return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn)) end function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) vals_linked = mapreduce(vcat, vns) do vn @@ -329,6 +329,9 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut return reconstruct(dist, vals, length(vns)) end +# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`. +getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn) + Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) @@ -426,7 +429,7 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = maybe_link(vi, vn, dist, value) + value_raw = maybe_reconstruct_and_link(vi, vn, dist, value) vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end @@ -444,9 +447,9 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - maybe_link.((vi,), vns, (dists,), value) + maybe_reconstruct_and_link.((vi,), vns, (dists,), value) else - maybe_link.((vi,), vns, dists, value) + maybe_reconstruct_and_link.((vi,), vns, dists, value) end # Update `vi` @@ -473,7 +476,7 @@ function dot_assume( # Update `vi`. 
for (vn, val) in zip(vns, eachcol(value)) - val_linked = maybe_link(vi, vn, dist, val) + val_linked = maybe_reconstruct_and_link(vi, vn, dist, val) vi = BangBang.setindex!!(vi, val_linked, vn) end @@ -488,7 +491,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where { nt_vals = map(keys(vi)) do vn val = vi[vn] vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) - vals = map(Base.Fix1(getindex, vi), vns) + vals = map(copy ∘ Base.Fix1(getindex, vi), vns) (vals, map(string, vns)) end @@ -501,7 +504,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) # Extract the leaf varnames and values. val = vi[vn] vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) - vals = map(Base.Fix1(getindex, vi), vns) + vals = map(copy ∘ Base.Fix1(getindex, vi), vns) # Determine the corresponding symbol. sym = only(unique(map(getsym, vns))) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 85ad0e23e..dc8720e0a 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -178,3 +178,5 @@ end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) + +getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn) diff --git a/src/transforming.jl b/src/transforming.jl index f4b50b057..bb8abddd6 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -15,7 +15,7 @@ function tilde_assume( # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : bijector(right)(r) + r_transformed = isinverse ? 
r : link_transform(right)(r) return r, lp, setindex!!(vi, r_transformed, vn) end @@ -27,7 +27,7 @@ function dot_tilde_assume( vi, ) where {isinverse} r = getindex.((vi,), vns, (dist,)) - b = bijector(dist) + b = link_transform(dist) is_trans_uniques = unique(istrans.((vi,), vns)) @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" @@ -70,7 +70,7 @@ function dot_tilde_assume( @assert !isinverse "Trying to invlink non-transformed variables" end - b = bijector(dist) + b = link_transform(dist) for (vn, ri) in zip(vns, eachcol(r)) # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. diff --git a/src/utils.jl b/src/utils.jl index f0bba2071..8f3a0d101 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -177,10 +177,39 @@ function to_namedtuple_expr(syms, vals) return :(NamedTuple{$names_expr}($vals_expr)) end +""" + link_transform(dist) + +Return the constrained-to-unconstrained bijector for distribution `dist`. + +By default, this is just `Bijectors.bijector(dist)`. + +!!! warning + Note that currently this is not used by `Bijectors.logpdf_with_trans`, + hence that needs to be overloaded separately if the intention is + to change behavior of an existing distribution. +""" +link_transform(dist) = bijector(dist) + +""" + invlink_transform(dist) + +Return the unconstrained-to-constrained bijector for distribution `dist`. + +By default, this is just `inverse(link_transform(dist))`. + +!!! warning + Note that currently this is not used by `Bijectors.logpdf_with_trans`, + hence that needs to be overloaded separately if the intention is + to change behavior of an existing distribution. 
+""" +invlink_transform(dist) = inverse(link_transform(dist)) + ##################################################### # Helper functions for vectorize/reconstruct values # ##################################################### +vectorize(d, r) = vec(r) vectorize(d::UnivariateDistribution, r::Real) = [r] vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r) vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) @@ -191,7 +220,23 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). -reconstruct(d::UnivariateDistribution, val::Real) = val + +""" + reconstruct([f, ]dist, val) + +Reconstruct `val` so that it's compatible with `dist`. + +If `f` is also provided, the reconstructed value will be +such that `f(reconstructed_val)` is compatible with `dist`. +""" +reconstruct(f, dist, val) = reconstruct(dist, val) + +# No-op versions. +reconstruct(::UnivariateDistribution, val::Real) = val +reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) +reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) +# TODO: Implement no-op `reconstruct` for general array variates. 
+ reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) reconstruct(::Tuple{}, val::AbstractVector) = val[1] reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) diff --git a/src/varinfo.jl b/src/varinfo.jl index 17df5a97e..a30c9ea24 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -155,7 +155,7 @@ end for f in names mdf = :(metadata.$f) if inspace(f, space) || length(space) == 0 - len = :(length($mdf.vals)) + len = :(sum(length, $mdf.ranges)) push!( exprs, :( @@ -271,14 +271,24 @@ getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) Return the index of `vn` in the metadata of `vi` corresponding to `vn`. """ -getidx(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).idcs[vn] +getidx(vi::VarInfo, vn::VarName) = getidx(getmetadata(vi, vn), vn) +getidx(md::Metadata, vn::VarName) = md.idcs[vn] """ getrange(vi::VarInfo, vn::VarName) Return the index range of `vn` in the metadata of `vi`. """ -getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] +getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) +getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)] + +""" + setrange!(vi::VarInfo, vn::VarName, range) + +Set the index range of `vn` in the metadata of `vi` to `range`. +""" +setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range) +setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range """ getranges(vi::VarInfo, vns::Vector{<:VarName}) @@ -294,7 +304,8 @@ end Return the distribution from which `vn` was sampled in `vi`. """ -getdist(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] +getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) +getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] """ getval(vi::VarInfo, vn::VarName) @@ -303,7 +314,8 @@ Return the value(s) of `vn`. The values may or may not be transformed to Euclidean space. 
""" -getval(vi::VarInfo, vn::VarName) = view(getmetadata(vi, vn).vals, getrange(vi, vn)) +getval(vi::VarInfo, vn::VarName) = getval(getmetadata(vi, vn), vn) +getval(md::Metadata, vn::VarName) = view(md.vals, getrange(md, vn)) """ setval!(vi::VarInfo, val, vn::VarName) @@ -312,7 +324,8 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = [val;] +setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) +setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;] """ getval(vi::VarInfo, vns::Vector{<:VarName}) @@ -321,9 +334,7 @@ Return the value(s) of `vns`. The values may or may not be transformed to Euclidean space. """ -function getval(vi::VarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getval(vi, vn), vcat, vns) -end +getval(vi::VarInfo, vns::Vector{<:VarName}) = mapreduce(Base.Fix1(getval, vi), vcat, vns) """ getall(vi::VarInfo) @@ -332,14 +343,12 @@ Return the values of all the variables in `vi`. The values may or may not be transformed to Euclidean space. """ -getall(vi::UntypedVarInfo) = vi.metadata.vals -getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) -@generated function _getall(metadata::NamedTuple{names}) where {names} - exprs = [] - for f in names - push!(exprs, :(metadata.$f.vals)) - end - return :($(exprs...),) +getall(vi::UntypedVarInfo) = getall(vi.metadata) +# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. +# See for example https://github.com/JuliaLang/julia/pull/46381. 
+getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata)) +function getall(md::Metadata) + return mapreduce(Base.Fix1(getval, md), vcat, md.vns; init=similar(md.vals, 0)) end """ @@ -739,19 +748,13 @@ function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) ) return _link!(vi, spl, spaceval) end -function _link!(vi::UntypedVarInfo, spl::Sampler) +function _link!(vi::UntypedVarInfo, spl::AbstractSampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns - @debug "X -> ℝ for $(vn)..." dist = getdist(vi, vn) - # TODO: Use inplace versions to avoid allocations - b = bijector(dist) - x = reconstruct(dist, getval(vi, vn)) - y, logjac = with_logabsdet_jacobian(b, x) - setval!(vi, vectorize(dist, y), vn) - acclogp!!(vi, -logjac) + _inner_transform!(vi, vn, dist, link_transform(dist)) settrans!!(vi, true, vn) end else @@ -778,13 +781,8 @@ end if ~istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - @debug "X -> R for $(vn)..." dist = getdist(vi, vn) - x = reconstruct(dist, getval(vi, vn)) - b = bijector(dist) - y, logjac = with_logabsdet_jacobian(b, x) - setval!(vi, vectorize(dist, y), vn) - acclogp!!(vi, -logjac) + _inner_transform!(vi, vn, dist, link_transform(dist)) settrans!!(vi, true, vn) end else @@ -839,13 +837,8 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns - @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - y = reconstruct(dist, getval(vi, vn)) - b = inverse(bijector(dist)) - x, logjac = with_logabsdet_jacobian(b, y) - setval!(vi, vectorize(dist, x), vn) - acclogp!!(vi, -logjac) + _inner_transform!(vi, vn, dist, invlink_transform(dist)) settrans!!(vi, false, vn) end else @@ -872,13 +865,8 @@ end if istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - @debug "ℝ -> X for $(vn)..." 
dist = getdist(vi, vn) - y = reconstruct(dist, getval(vi, vn)) - b = inverse(bijector(dist)) - x, logjac = with_logabsdet_jacobian(b, y) - setval!(vi, vectorize(dist, x), vn) - acclogp!!(vi, -logjac) + _inner_transform!(vi, vn, dist, invlink_transform(dist)) settrans!!(vi, false, vn) end else @@ -891,10 +879,20 @@ end return expr end -link(vi, vn, dist, val) = Bijectors.link(dist, val) -invlink(vi, vn, dist, val) = Bijectors.invlink(dist, val) -maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? link(vi, vn, dist, val) : val -maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? invlink(vi, vn, dist, val) : val +function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) + @debug "X -> ℝ for $(vn)..." + # TODO: Use inplace versions to avoid allocations + y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, getval(vi, vn)) + yvec = vectorize(dist, y) + # Determine the new range. + start = first(getrange(vi, vn)) + # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. + setrange!(vi, vn, start:(start + length(yvec) - 1)) + # Set the new value. 
+ setval!(vi, yvec, vn) + acclogp!!(vi, -logjac) + return vi +end """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -927,8 +925,8 @@ end getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getindex_raw(vi, vn, dist) - return maybe_invlink(vi, vn, dist, val) + val = getval(vi, vn) + return maybe_invlink_and_reconstruct(vi, vn, dist, val) end function getindex(vi::VarInfo, vns::Vector{<:VarName}) # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases @@ -1029,19 +1027,20 @@ end return expr end -function tonamedtuple(vi::VarInfo) - return tonamedtuple(vi.metadata, vi) -end -@generated function tonamedtuple(metadata::NamedTuple{names}, vi::VarInfo) where {names} - length(names) === 0 && return :(NamedTuple()) - expr = Expr(:tuple) - map(names) do f - push!( - expr.args, - Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns))), - ) +# TODO: Remove this completely. +tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo) +function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names} + length(names) === 0 && return NamedTuple() + + vals_tuple = map(values(metadata)) do x + # NOTE: `tonamedtuple` is really only used in Turing.jl to convert to + # a "transition". This means that we really don't want mutations of the values + # in `varinfo` to propagate to the previous samples. 
Hence we `copy.` + vals = map(copy ∘ Base.Fix1(getindex, varinfo), x.vns) + return vals, map(string, x.vns) end - return expr + + return NamedTuple{names}(vals_tuple) end @inline function findvns(vi, f_vns) diff --git a/test/linking.jl b/test/linking.jl new file mode 100644 index 000000000..f81895788 --- /dev/null +++ b/test/linking.jl @@ -0,0 +1,86 @@ +using Bijectors + +# Simple transformations which alters the "dimension" of the variable. +struct TrilToVec{S} + size::S +end + +struct TrilFromVec{S} + size::S +end + +Bijectors.inverse(f::TrilToVec) = TrilFromVec(f.size) +Bijectors.inverse(f::TrilFromVec) = TrilToVec(f.size) + +function (v::TrilToVec)(x) + mask = tril(trues(v.size)) + return vec(x[mask]) +end +function (v::TrilFromVec)(y) + mask = tril(trues(v.size)) + x = similar(y, v.size) + x[mask] .= y + return LowerTriangular(x) +end + +# Just some dummy values so we can make sure that the log-prob computation +# has been altered correctly. +Bijectors.with_logabsdet_jacobian(f::TrilToVec, x) = (f(x), log(eltype(x)(2))) +Bijectors.with_logabsdet_jacobian(f::TrilFromVec, x) = (f(x), -eltype(x)(log(2))) + +# Dummy example. +struct MyMatrixDistribution <: ContinuousMatrixDistribution + dim::Int +end + +Base.size(d::MyMatrixDistribution) = (d.dim, d.dim) +function Distributions._rand!( + rng::AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real} +) + return randn!(rng, x) +end +function Distributions._logpdf(::MyMatrixDistribution, x::AbstractMatrix{<:Real}) + return -sum(abs2, LowerTriangular(x)) / 2 +end + +# Skip reconstruction in the inverse-map since it's no longer needed. +DynamicPPL.reconstruct(::TrilFromVec, ::MyMatrixDistribution, x::AbstractVector{<:Real}) = x + +# Specify the link-transform to use. 
+Bijectors.bijector(dist::MyMatrixDistribution) = TrilToVec((dist.dim, dist.dim)) +function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Bool) + lp = logpdf(dist, x) + if istrans + lp = lp - logabsdetjac(bijector(dist), x) + end + + return lp +end + +@testset "Linking" begin + # Just making sure the transformations are okay. + x = randn(3, 3) + f = TrilToVec((3, 3)) + f_inv = inverse(f) + y = f(x) + @test y isa AbstractVector + @test f_inv(f(x)) == LowerTriangular(x) + + # Within a model. + dist = MyMatrixDistribution(3) + @model demo() = m ~ dist + model = demo() + + vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),)) + @testset "$(short_varinfo_name(vi))" for vi in vis + # Evaluate once to ensure we have `logp` value. + vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + # Difference should just be the log-absdet-jacobian "correction". + @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) + @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) + # Linked one should be working with a lower-dimensional representation. + @test length(vi_linked[:]) < length(vi[:]) + @test length(vi_linked[:]) == 3 + end +end diff --git a/test/model.jl b/test/model.jl index 7fb8bcf0b..d78133f5e 100644 --- a/test/model.jl +++ b/test/model.jl @@ -116,7 +116,7 @@ end model = DynamicPPL.TestUtils.demo_dynamic_constraint() vi = VarInfo(model) spl = SampleFromPrior() - link!(vi, spl) + link!!(vi, spl, model) for i in 1:10 # Sample with large variations. 
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index a5b57f5f6..9a18c439d 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -64,7 +64,6 @@ @testset "$(typeof(vi))" for vi in ( SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) ) - vi = SimpleVarInfo(values_constrained) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index a58d9e23d..a822008e3 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.20, 0.21, 0.22" +DynamicPPL = "0.20, 0.21, 0.22, 0.23" Turing = "0.21, 0.22, 0.23, 0.24, 0.25" julia = "1.6" diff --git a/test/turing/model.jl b/test/turing/model.jl index e27b177eb..fcbdd88a3 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -1,95 +1,12 @@ @testset "model.jl" begin @testset "setval! & generated_quantities" begin - @model function demo1(xs, ::Type{TV}=Vector{Float64}) where {TV} - m = TV(undef, 2) - for i in 1:2 - m[i] ~ Normal(0, 1) - end - - for i in eachindex(xs) - xs[i] ~ Normal(m[1], 1.0) - end - - return (m,) + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = sample(model, Prior(), 10) + # A simple way of checking that the computation is deterministic: run twice and compare. 
+ res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) + res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) + @test all(res1 .== res2) + test_setval!(model, MCMCChains.get_sections(chain, :parameters)) end - - @model function demo2(xs) - m ~ MvNormal(zeros(2), I) - - for i in eachindex(xs) - xs[i] ~ Normal(m[1], 1.0) - end - - return (m,) - end - - xs = randn(3) - model1 = demo1(xs) - model2 = demo2(xs) - - chain1 = sample(model1, MH(), 100) - chain2 = sample(model2, MH(), 100) - - res11 = generated_quantities(model1, MCMCChains.get_sections(chain1, :parameters)) - res21 = generated_quantities(model2, MCMCChains.get_sections(chain1, :parameters)) - - res12 = generated_quantities(model1, MCMCChains.get_sections(chain2, :parameters)) - res22 = generated_quantities(model2, MCMCChains.get_sections(chain2, :parameters)) - - # Check that the two different models produce the same values for - # the same chains. - @test all(res11 .== res21) - @test all(res12 .== res22) - # Ensure that they're not all the same (some can be, because rejected samples) - @test any(res12[1:(end - 1)] .!= res12[2:end]) - - test_setval!(model1, MCMCChains.get_sections(chain1, :parameters)) - test_setval!(model2, MCMCChains.get_sections(chain2, :parameters)) - - # Next level - @model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV} - m = Vector{TV}(undef, 2) - for i in 1:length(m) - m[i] ~ MvNormal(zeros(2), I) - end - - for i in eachindex(xs) - xs[i] ~ Normal(m[1][1], 1.0) - end - - return (m,) - end - - @model function demo4(xs, ::Type{TV}=Vector{Vector{Float64}}) where {TV} - m = TV(undef, 2) - for i in 1:length(m) - m[i] ~ MvNormal(zeros(2), I) - end - - for i in eachindex(xs) - xs[i] ~ Normal(m[1][1], 1.0) - end - - return (m,) - end - - model3 = demo3(xs) - model4 = demo4(xs) - - chain3 = sample(model3, MH(), 100) - chain4 = sample(model4, MH(), 100) - - res33 = generated_quantities(model3, MCMCChains.get_sections(chain3, 
:parameters)) - res43 = generated_quantities(model4, MCMCChains.get_sections(chain3, :parameters)) - - res34 = generated_quantities(model3, MCMCChains.get_sections(chain4, :parameters)) - res44 = generated_quantities(model4, MCMCChains.get_sections(chain4, :parameters)) - - # Check that the two different models produce the same values for - # the same chains. - @test all(res33 .== res43) - @test all(res34 .== res44) - # Ensure that they're not all the same (some can be, because rejected samples) - @test any(res34[1:(end - 1)] .!= res34[2:end]) end end