Skip to content

Commit

Permalink
Replacing tonamedtuple (#526)
Browse files Browse the repository at this point in the history
* added impl of varname_and_value_leaves

* added examples with cholesky to varname_and_value_leaves doctests

* added more descriptive docstring of iterate for Leaf

* added concrete example in comment of iterate for Leaf

* added small docstring to Leaf

* Update src/utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] committed Sep 1, 2023
1 parent e2178c6 commit 866eb6f
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ DynamicPPL.reconstruct
```@docs
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```

#### `SimpleVarInfo`
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky
using LinearAlgebra: LinearAlgebra, Cholesky

using DocStringExtensions

Expand Down
152 changes: 152 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,155 @@ function varname_leaves(vn::VarName, val::NamedTuple)
end
return Iterators.flatten(iter)
end

"""
varname_and_value_leaves(vn::VarName, val)
Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
# Examples
```jldoctest varname-and-value-leaves
julia> using DynamicPPL: varname_and_value_leaves
julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
(x[1], 1)
(x[2], 2)
julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
(x[1:2][1], 1)
(x[1:2][2], 2)
julia> x = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_and_value_leaves(@varname(x), x))
(x.y, 1)
(x.z[1][1], 2.0)
(x.z[2][1], 3.0)
```
There are also some special handling for certain types:
```jldoctest varname-and-value-leaves
julia> using LinearAlgebra
julia> x = reshape(1:4, 2, 2);
julia> # `LowerTriangular`
foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
(x[1,1], 1)
(x[2,1], 2)
(x[2,2], 4)
julia> # `UpperTriangular`
foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
(x[1,1], 1)
(x[1,2], 3)
(x[2,2], 4)
julia> # `Cholesky` with lower-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
(x[1,1], 1.0)
(x[2,1], 0.0)
(x[2,2], 1.0)
julia> # `Cholesky` with upper-triangular
foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
(x[1,1], 1.0)
(x[1,2], 0.0)
(x[2,2], 1.0)
```
"""
function varname_and_value_leaves(vn::VarName, x)
return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x)))
end

"""
Leaf{T}
A container that represents the leaf of a nested structure, implementing
`iterate` to return itself.
This is particularly useful in conjunction with `Iterators.flatten` to
prevent flattening of nested structures.
"""
struct Leaf{T}
value::T
end

Leaf(xs...) = Leaf(xs)

# Allow us to treat `Leaf` as an iterator containing a single element.
# Something like an `[x]` would also be an iterator with a single element,
# but when we call `flatten` on this, it would also iterate over `x`,
# unflattening that too. By making `Leaf` a single-element iterator, which
# returns itself, we can call `iterate` on this as many times as we like
# without causing any change. The result is that `Iterators.flatten`
# will _not_ unflatten `Leaf`s.
# Note that this is similar to how `Base.iterate` is implemented for `Real`::
#
# julia> iterate(1)
# (1, nothing)
#
# One immediate example where this becomes in our scenario is that we might
# have `missing` values in our data, which does _not_ have an `iterate`
# implemented. Calling `Iterators.flatten` on this would cause an error.
Base.iterate(leaf::Leaf) = leaf, nothing
Base.iterate(::Leaf, _) = nothing

# Convenience.
value(leaf::Leaf) = leaf.value

# Leaf-types.
varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)]
function varname_and_value_leaves_inner(
vn::VarName, val::AbstractArray{<:Union{Real,Missing}}
)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
val[I],
) for I in CartesianIndices(val)
)
end
# Containers.
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
val[I],
) for I in CartesianIndices(val)
)
end
function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
lens = DynamicPPL.Setfield.PropertyLens{sym}()
varname_and_value_leaves_inner(vn lens, get(val, lens))
end

return Iterators.flatten(iter)
end
# Special types.
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
# TODO: Or do we use `PDMat` here?
return varname_and_value_leaves_inner(vn, x.UL)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
x[I],
)
# Iteration over the lower-triangular indices.
for I in CartesianIndices(x) if I[1] >= I[2]
)
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getlens(vn) DynamicPPL.Setfield.IndexLens(Tuple(I))),
x[I],
)
# Iteration over the upper-triangular indices.
for I in CartesianIndices(x) if I[1] <= I[2]
)
end

0 comments on commit 866eb6f

Please sign in to comment.