Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper support for distributions with embedded support #462

Merged
merged 68 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from 67 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
fc16d3e
compat with new Bijectors.jl
torfjelde Jan 31, 2023
a0de0ac
bump compat bounds for Bijectors and make it a breaking change
torfjelde Jan 31, 2023
4fe5eea
remove mentioning of Exp and Identity in test_utils.jl
torfjelde Feb 3, 2023
dfaf7be
added mistakenly commented out tests
torfjelde Feb 3, 2023
8d77d78
fixed test_utils
torfjelde Feb 3, 2023
16ac9e1
bump bijectors version
torfjelde Feb 3, 2023
2e3b006
Merge branch 'master' into torfjelde/Bijectors-compat
torfjelde Feb 3, 2023
425ca5d
added no-op impls for reconstruct
torfjelde Feb 6, 2023
c1f0b3b
added a bunch of convenience methods for working with Metadata instead
torfjelde Feb 6, 2023
73f4bd2
added usage of _inner_transform! in link, in addition to additional
torfjelde Feb 6, 2023
1b3c581
updated getall to not assume we want all the values in metadata
torfjelde Feb 6, 2023
c2fbded
added FIXME comment
torfjelde Feb 6, 2023
3b156db
fixed typo in comment
torfjelde Feb 6, 2023
6070e3f
Apply suggestions from code review
torfjelde Feb 6, 2023
613eb1b
sligh simplification of the linking stuff
torfjelde Feb 6, 2023
0af6e29
Merge remote-tracking branch 'origin/torfjelde/support-embedded-suppo…
torfjelde Feb 6, 2023
0fcd481
formatting
torfjelde Feb 6, 2023
29faba0
lower bound test compat entry for Tracker
torfjelde Feb 9, 2023
cc1bb7b
Merge branch 'torfjelde/Bijectors-compat' into torfjelde/support-embe…
torfjelde Feb 9, 2023
2501510
Merge branch 'master' into torfjelde/support-embedded-support
yebai Feb 9, 2023
6957e2e
move link-related functions to abstract_varinfo.jl and renamed methods
torfjelde Feb 11, 2023
90a3edb
fixed invlink!! for VarInfo
torfjelde Feb 12, 2023
7817920
fixed link and invlink tests
torfjelde Feb 12, 2023
ade35c8
added specialized mapreduce for (named)tuples to improve type-inference
torfjelde Feb 12, 2023
d7841e5
Merge remote-tracking branch 'origin/torfjelde/support-embedded-suppo…
torfjelde Feb 12, 2023
7a6ef1b
Apply suggestions from code review
torfjelde Feb 12, 2023
360283f
added missing docstring
torfjelde Feb 12, 2023
d958d84
added minor TODO comment for the future
torfjelde Feb 12, 2023
3cf6e07
added `link_transform` and `invlink_transform`, basically equivalent
torfjelde Feb 12, 2023
a25891d
Apply suggestions from code review
torfjelde Feb 13, 2023
02dd8bf
Update src/utils.jl
torfjelde Feb 13, 2023
4765ea9
added some docstrings
torfjelde Feb 13, 2023
f02fdd9
renamed link_and_reconstruct to the more accurate reconstruct_and_link
torfjelde Feb 13, 2023
603e027
removed unnecessary definition of inlink_transform
torfjelde Feb 13, 2023
de47598
fixed bug in newmetadata
torfjelde Feb 13, 2023
8cd2610
removed mapreduce_tuple in favor of reduce and map
torfjelde Feb 13, 2023
752e40b
Update src/utils.jl
torfjelde Feb 22, 2023
b0a67a9
Merge branch 'master' into torfjelde/support-embedded-support
torfjelde Mar 24, 2023
2765b08
introduce _logpdf_with_trans as a placeholder while we migrate away
torfjelde Mar 27, 2023
ed03864
reconstruct now takes into account the transformation to be used
torfjelde Mar 27, 2023
96c0690
replaced more references to bijector with link_transform
torfjelde Mar 27, 2023
7b5521d
added docstring for invlink_with_logpdf
torfjelde Mar 27, 2023
595d9ee
fixed bug in assume introduced hacky getval for SimpleVarInfo
torfjelde Mar 27, 2023
d81217e
added tests for linking
torfjelde Mar 27, 2023
9b516b8
Apply suggestions from code review
torfjelde Mar 27, 2023
61e832b
rename maybe_link_and_reconstruct to maybe_reconstruct_and_link
torfjelde Mar 27, 2023
d06cc8a
added reconstruct to the API docs
torfjelde Mar 27, 2023
dbecece
Update docs/src/api.md
torfjelde Mar 27, 2023
aa76c08
removed unnecessary comment
torfjelde Mar 27, 2023
f756e44
removed _logpdf_with_trans in favour of just using Bijectors.jl's for…
torfjelde Apr 17, 2023
9785594
Apply suggestions from code review
torfjelde Apr 17, 2023
78e6332
added warning regarding overloading to link_transform and invlink_tra…
torfjelde Apr 17, 2023
cec5bd3
added missing getval for ThreadSafeVarInfo
torfjelde Apr 18, 2023
b6dc3ec
added a minor additional test to linking
torfjelde Apr 23, 2023
f126448
Apply suggestions from code review
torfjelde Apr 27, 2023
1e4d688
Merge remote-tracking branch 'origin/torfjelde/support-embedded-suppo…
torfjelde Apr 27, 2023
b9a8c16
reverted chagnes from previous commit
torfjelde Apr 27, 2023
e3ce20d
fixed usage of deprecated link
torfjelde May 2, 2023
e5d19a8
Merge branch 'master' into torfjelde/support-embedded-support
yebai May 2, 2023
ba7c24c
Merge branch 'torfjelde/support-embedded-support' of github.com:Turin…
torfjelde Jun 5, 2023
bf73961
Update test/linking.jl
torfjelde Jun 5, 2023
eec87ee
Update src/DynamicPPL.jl
torfjelde Jun 5, 2023
42ed9df
Merge branch 'master' into torfjelde/support-embedded-support
torfjelde Jun 6, 2023
6fa0f72
fixed tests
torfjelde Jun 6, 2023
3d5ead3
added copy to tonamedtuple to avoid mutating chain samples
torfjelde Jun 6, 2023
64c3a07
improved testing for setval! and generated_quantities
torfjelde Jun 6, 2023
3fd55d4
bumped the version in turing tests
torfjelde Jun 6, 2023
e683cce
Apply suggestions from code review
torfjelde Jun 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```
DynamicPPL.reconstruct
```

#### Utils

Expand Down
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ export AbstractVarInfo,
push!!,
empty!!,
getlogp,
resetlogp!,
setlogp!!,
acclogp!!,
resetlogp!!,
Expand Down
93 changes: 93 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,99 @@ variables `x` would return
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)

Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
value is reconstructed to the correct type and shape according to `dist`.
"""
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
x_recon = reconstruct(f, dist, x)
return with_logabsdet_jacobian(f, x_recon)
end

# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)

Return linked `val` but reconstruct before linking, if necessary.

Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
return a reconstructed value, i.e. a value of the same type and shape as expected
by `dist`.

See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
"""
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
return reconstruct_and_link(dist, val)
end

"""
invlink_and_reconstruct(dist, val)
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)

Return invlinked and reconstructed `val`.

See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
"""
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
function invlink_and_reconstruct(dist, val)
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
end
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
return invlink_and_reconstruct(dist, val)
end

"""
maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)

Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
"""
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
reconstruct_and_link(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)

Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
"""
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
invlink_and_reconstruct(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])

Invlink `x` and compute the logpdf under `dist` including correction from
the invlink-transformation.

If `x` is not provided, `getval(vi, vn)` will be used.
"""
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
end
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
# NOTE: Will this cause type-instabilities or will union-splitting save us?
f = istrans(vi, vn) ? invlink_transform(dist) : identity
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
return x, logpdf(dist, x) + logjac
end

# Legacy code that is currently overloaded for the sake of simplicity.
# TODO: Remove when possible.
increment_num_produce!(::AbstractVarInfo) = nothing
Expand Down
30 changes: 21 additions & 9 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn, dist]
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
r, logp = invlink_with_logpdf(vi, vn, dist)
return r, logp, vi
end

# SampleFromPrior and SampleFromUniform
Expand All @@ -211,7 +211,9 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn)
BangBang.setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
Expand All @@ -220,15 +222,17 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, that's confusing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really want to remove this entire function.

logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
return r, logpdf(dist, r) - logjac, vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
Expand Down Expand Up @@ -470,7 +474,11 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn)
setindex!!(
vi,
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])),
vn,
)
setorder!(vi, vn, get_num_produce(vi))
end
else
Expand Down Expand Up @@ -508,13 +516,17 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn)
setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn
)
setorder!(vi, vn, get_num_produce(vi))
end
else
# r = reshape(vi[vec(vns)], size(vns))
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
# and fix the lines below.
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -525,7 +537,7 @@ function get_and_set_val!(
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
if istrans(vi)
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
Expand Down
17 changes: 10 additions & 7 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ end

# `NamedTuple`
function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
return maybe_invlink(vi, vn, dist, getindex(vi, vn))
return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn))
end
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
vals_linked = mapreduce(vcat, vns) do vn
Expand Down Expand Up @@ -329,6 +329,9 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
return reconstruct(dist, vals, length(vns))
end

# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)

Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)

function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
Expand Down Expand Up @@ -426,7 +429,7 @@ function assume(
)
value = init(rng, dist, sampler)
# Transform if we're working in unconstrained space.
value_raw = maybe_link(vi, vn, dist, value)
value_raw = maybe_reconstruct_and_link(vi, vn, dist, value)
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
end
Expand All @@ -444,9 +447,9 @@ function dot_assume(

# Transform if we're working in transformed space.
value_raw = if dists isa Distribution
maybe_link.((vi,), vns, (dists,), value)
maybe_reconstruct_and_link.((vi,), vns, (dists,), value)
else
maybe_link.((vi,), vns, dists, value)
maybe_reconstruct_and_link.((vi,), vns, dists, value)
end

# Update `vi`
Expand All @@ -473,7 +476,7 @@ function dot_assume(

# Update `vi`.
for (vn, val) in zip(vns, eachcol(value))
val_linked = maybe_link(vi, vn, dist, val)
val_linked = maybe_reconstruct_and_link(vi, vn, dist, val)
vi = BangBang.setindex!!(vi, val_linked, vn)
end

Expand All @@ -488,7 +491,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {
nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(Base.Fix1(getindex, vi), vns)
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end

Expand All @@ -501,7 +504,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(Base.Fix1(getindex, vi), vns)
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
sym = only(unique(map(getsym, vns)))
Expand Down
2 changes: 2 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,5 @@ end

istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
6 changes: 3 additions & 3 deletions src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function tilde_assume(

# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
r_transformed = isinverse ? r : bijector(right)(r)
r_transformed = isinverse ? r : link_transform(right)(r)
return r, lp, setindex!!(vi, r_transformed, vn)
end

Expand All @@ -27,7 +27,7 @@ function dot_tilde_assume(
vi,
) where {isinverse}
r = getindex.((vi,), vns, (dist,))
b = bijector(dist)
b = link_transform(dist)

is_trans_uniques = unique(istrans.((vi,), vns))
@assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables"
Expand Down Expand Up @@ -70,7 +70,7 @@ function dot_tilde_assume(
@assert !isinverse "Trying to invlink non-transformed variables"
end

b = bijector(dist)
b = link_transform(dist)
for (vn, ri) in zip(vns, eachcol(r))
# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
Expand Down
47 changes: 46 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,39 @@ function to_namedtuple_expr(syms, vals)
return :(NamedTuple{$names_expr}($vals_expr))
end

"""
link_transform(dist)

Return the constrained-to-unconstrained bijector for distribution `dist`.

By default, this is just `Bijectors.bijector(dist)`.

!!! warning
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
link_transform(dist) = bijector(dist)

"""
invlink_transform(dist)

Return the unconstrained-to-constrained bijector for distribution `dist`.

By default, this is just `inverse(link_transform(dist))`.

!!! warning
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
invlink_transform(dist) = inverse(link_transform(dist))

#####################################################
# Helper functions for vectorize/reconstruct values #
#####################################################

vectorize(d, r) = vec(r)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this is untested and causes method ambiguity issues (when eg tested with test_method_ambiguities).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know of this test_method_ambiguities; can you elaborate or point me somewhere?

vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
Expand All @@ -191,7 +220,23 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
# otherwise we will have error for MatrixDistribution.
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
# support for some types related to matrices (like PDMat).
reconstruct(d::UnivariateDistribution, val::Real) = val

"""
reconstruct([f, ]dist, val)

Reconstruct `val` so that it's compatible with `dist`.

If `f` is also provided, the reconstruct value will be
such that `f(reconstruct_val)` is compatible with `dist`.
"""
reconstruct(f, dist, val) = reconstruct(dist, val)

# No-op versions.
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = val
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
reconstruct(::Tuple{}, val::AbstractVector) = val[1]
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)
Expand Down
Loading
Loading