Skip to content

Commit

Permalink
use APPL pending fix for testing; fix more errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Apr 11, 2024
1 parent 16b6324 commit c2bad4a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ function varname_in_chain!(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
# This way we can use `getoptic(vn)` to extract the value from `x` and use `vn_parent ⨟ getoptic(vn)`
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getoptic(vn)
varname_in_chain!(x, vn_parent l, chain, chain_idx, iteration_idx, out)
varname_in_chain!(x, l vn_parent, chain, chain_idx, iteration_idx, out)
end
return out
end
Expand All @@ -104,7 +104,7 @@ function values_from_chain(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
# This way we can use `getoptic(vn)` to extract the value from `x` and use `vn_parent ⨟ getoptic(vn)`
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
# to extract the value from the `chain`.
out = similar(x)
for vn in varname_leaves(VarName{sym}(), x)
Expand All @@ -113,7 +113,7 @@ function values_from_chain(
out = Accessors.set(
out,
BangBang.prefermutation(l),
chain[iteration_idx, Symbol(vn_parent l), chain_idx],
chain[iteration_idx, Symbol(l vn_parent), chain_idx],
)
end

Expand Down
18 changes: 9 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value)
end
function set!!(obj, vn::VarName{sym}, value) where {sym}
optic = BangBang.prefermutation(
Accessors.PropertyLens{sym}() AbstractPPL.getoptic(vn)
AbstractPPL.getoptic(vn) Accessors.PropertyLens{sym}()
)
return Accessors.set(obj, optic, value)
end
Expand Down Expand Up @@ -811,7 +811,7 @@ function nested_getindex(values::AbstractDict, vn::VarName)
# TODO: Should we also check that we `canview` the extracted `value`
# rather than just let it fail upon `get` call?
value = values[VarName(vn, keyoptic)]
return get(value, child)
return child(value)
end

"""
Expand Down Expand Up @@ -1037,7 +1037,7 @@ function varname_and_value_leaves_inner(
)
return (
Leaf(
VarName(vn, DynamicPPL.getoptic(vn) DynamicPPL.Accessors.IndexLens(Tuple(I))),
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
Expand All @@ -1046,7 +1046,7 @@ end
function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_and_value_leaves_inner(
VarName(vn, DynamicPPL.getoptic(vn) DynamicPPL.Accessors.IndexLens(Tuple(I))),
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
val[I],
) for I in CartesianIndices(val)
)
Expand All @@ -1055,7 +1055,7 @@ function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = DynamicPPL.Accessors.PropertyLens{sym}()
varname_and_value_leaves_inner(
VarName{getsym(vn)}(getoptic(vn) optic), optic(val)
VarName{getsym(vn)}(optic getoptic(vn)), optic(val)
)
end

Expand All @@ -1065,15 +1065,15 @@ end
function varname_and_value_leaves_inner(vn::VarName, x::Cholesky)
# TODO: Or do we use `PDMat` here?
return if x.uplo == 'L'
varname_and_value_leaves_inner(vn Accessors.PropertyLens{:L}(), x.L)
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() vn, x.L)
else
varname_and_value_leaves_inner(vn Accessors.PropertyLens{:U}(), x.U)
varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() vn, x.U)
end
end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getoptic(vn) DynamicPPL.Accessors.IndexLens(Tuple(I))),
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
x[I],
)
# Iteration over the lower-triangular indices.
Expand All @@ -1083,7 +1083,7 @@ end
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular)
return (
Leaf(
VarName(vn, DynamicPPL.getoptic(vn) DynamicPPL.Accessors.IndexLens(Tuple(I))),
VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) DynamicPPL.getoptic(vn)),
x[I],
)
# Iteration over the upper-triangular indices.
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ using Test

using DynamicPPL: getargs_dottilde, getargs_tilde, Selector

# TODO: temporarily overwrite for testing
using AbstractPPL: ALLOWED_OPTICS, VarName
# Allow compositions with optic.
function Base.:(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym}
vn_optic = getoptic(vn)
if vn_optic == identity
return VarName{sym}(optic)
elseif optic == identity
return vn
else
return VarName{sym}(optic vn_optic)
end
end

const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL)))
const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing")
const GROUP = get(ENV, "GROUP", "All")
Expand Down

0 comments on commit c2bad4a

Please sign in to comment.