diff --git a/Project.toml b/Project.toml index cbad0d688..0dad5cd48 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.9" +version = "0.25.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -20,15 +21,31 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" + +[extensions] +DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] +DynamicPPLEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLForwardDiffExt = ["ForwardDiff"] +DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLZygoteRulesExt = ["ZygoteRules"] + [compat] ADTypes = "0.2" AbstractMCMC = "5" -AbstractPPL = "0.7" -BangBang = "0.3" +AbstractPPL = "0.8.2" +Accessors = "0.1" +BangBang = "0.4" Bijectors = "0.13" ChainRulesCore = "1" Compat = "4" @@ -44,29 +61,12 @@ MacroTools = "0.5.6" OrderedCollections = "1" Random = "1.6" Requires = "1" -Setfield = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.6" -[extensions] -DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] -DynamicPPLEnzymeCoreExt = ["EnzymeCore"] -DynamicPPLForwardDiffExt = ["ForwardDiff"] -DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLReverseDiffExt = ["ReverseDiff"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] - [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/docs/Project.toml b/docs/Project.toml index 48ebe173c..0746a3b5d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,21 +1,19 @@ [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] +Accessors = "0.1" DataStructures = "0.18" Distributions = "0.25" Documenter = "1" FillArrays = "0.13, 1" LogDensityProblems = "2" MCMCChains = "5, 6" -MLUtils = "0.3, 0.4" -Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index dc9f50204..330fa931a 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -107,12 +107,28 @@ To give an example of the probability interface in use, we can use it to estimat In cross-validation, we split the dataset into several equal parts. Then, we choose one of these sets to serve as the validation set. Here, we measure fit using the cross entropy (Bayes loss).[^1] +(For the sake of simplicity, in the following code, we enforce that `nfolds` must divide the number of data points. For a more competent implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).) ```@example probinterface -using MLUtils +# Calculate the train/validation splits across `nfolds` partitions, assume `length(dataset)` divides `nfolds` +function kfolds(dataset::Array{<:Real}, nfolds::Int) + fold_size, remaining = divrem(length(dataset), nfolds) + if remaining != 0 + error("The number of folds must divide the number of data points.") + end + first_idx = firstindex(dataset) + last_idx = lastindex(dataset) + splits = map(0:(nfolds - 1)) do i + start_idx = first_idx + i * fold_size + end_idx = start_idx + fold_size + train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx] + return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1))) + end + return splits +end function cross_val( - dataset::AbstractVector{<:Real}; + dataset::Vector{<:Real}; nfolds::Int=5, nsamples::Int=1_000, rng::Random.AbstractRNG=Random.default_rng(), diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ce6605250..0ccfbb103 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -12,7 +12,7 @@ using ADTypes: ADTypes using BangBang: BangBang, push!!, empty!!, setindex!! using MacroTools: MacroTools using ConstructionBase: ConstructionBase -using Setfield: Setfield +using Accessors: Accessors using LogDensityProblems: LogDensityProblems using LogDensityProblemsAD: LogDensityProblemsAD diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index bd7e8d8fb..8aedeb09c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -262,7 +262,7 @@ julia> values_as(SimpleVarInfo(data), NamedTuple) (x = 1.0, m = [2.0]) julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: +OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: x => 1.0 m => [2.0] @@ -312,7 +312,7 @@ julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: +OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 @@ -338,7 +338,7 @@ julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: +OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 @@ -426,7 +426,7 @@ julia> # Extract one with only `m`. julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: +1-element Vector{VarName{:m, typeof(identity)}}: m julia> varinfo_subset1[@varname(m)] diff --git a/src/compiler.jl b/src/compiler.jl index e7c44d16b..f8a04a557 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4,11 +4,11 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) need_concretize(expr) Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or -requires a dynamic lens. +requires a dynamic optic. # Examples -```jldoctest; setup=:(using Setfield) +```jldoctest; setup=:(using Accessors) julia> DynamicPPL.need_concretize(:(x[1, :])) true @@ -19,7 +19,7 @@ julia> DynamicPPL.need_concretize(:(x[1, 1])) false """ function need_concretize(expr) - return Setfield.need_dynamic_lens(expr) || begin + return Accessors.need_dynamic_optic(expr) || begin flag = false MacroTools.postwalk(expr) do ex # Concretise colon by default @@ -202,13 +202,13 @@ variables. # Example ```jldoctest; setup=:(using Distributions, LinearAlgebra) julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] -x[:,2] +x[:, 2] julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] -x[1,2] +x[1, 2] julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end] -x[:][1,2] +x[:][1, 2] julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end] x[1][3] @@ -226,7 +226,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return AbstractPPL.concretize(vn ∘ Setfield.IndexLens((Colon(), i)), left) + return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) ∘ vn, left) end return unwrap_right_left_vns(right, left, vns) end @@ -236,7 +236,7 @@ function unwrap_right_left_vns( vn::VarName, ) vns = map(CartesianIndices(left)) do i - return vn ∘ Setfield.IndexLens(Tuple(i)) + return Accessors.IndexLens(Tuple(i)) ∘ vn end return unwrap_right_left_vns(right, left, vns) end @@ -437,7 +437,7 @@ function generate_tilde_assume(left, right, vn) expr = :($left = $value) if left isa Expr expr = AbstractPPL.drop_escape( - Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true) ) end diff --git a/src/contexts.jl b/src/contexts.jl index 83da5d929..2018b9155 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -288,9 +288,9 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn))) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn))) else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn)) + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) end end diff --git a/src/model.jl b/src/model.jl index c0cc2f26f..8c10ed36e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -279,11 +279,11 @@ in their trace/`VarInfo`: ```jldoctest condition julia> keys(VarInfo(demo_outer())) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: +1-element Vector{VarName{:m, typeof(identity)}}: m julia> keys(VarInfo(demo_outer_prefix())) -1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}: inner.m ``` @@ -448,7 +448,7 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: a.m julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. @@ -634,11 +634,11 @@ in their trace/`VarInfo`: ```jldoctest fix julia> keys(VarInfo(demo_outer())) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: +1-element Vector{VarName{:m, typeof(identity)}}: m julia> keys(VarInfo(demo_outer_prefix())) -1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}: inner.m ``` @@ -830,7 +830,7 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: a.m julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. diff --git a/src/model_utils.jl b/src/model_utils.jl index ab1acfa05..ac4ec7022 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -78,13 +78,13 @@ end function varname_in_chain!( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out ) where {sym} - # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. - # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. + # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. - l = AbstractPPL.getlens(vn) - varname_in_chain!(x, vn_parent ∘ l, chain, chain_idx, iteration_idx, out) + l = AbstractPPL.getoptic(vn) + varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out) end return out end @@ -103,17 +103,17 @@ end function values_from_chain( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx ) where {sym} - # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. - # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. + # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. out = similar(x) for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. - l = AbstractPPL.getlens(vn) - out = Setfield.set( + l = AbstractPPL.getoptic(vn) + out = Accessors.set( out, BangBang.prefermutation(l), - chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], + chain[iteration_idx, Symbol(l ∘ vn_parent), chain_idx], ) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad37130d6..6704e03fa 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -259,15 +259,15 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) end function BangBang.empty!!(vi::SimpleVarInfo) - return resetlogp!!(Setfield.@set vi.values = empty!!(vi.values)) + return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values)) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) getlogp(vi::SimpleVarInfo) = vi.logp getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] -setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp +setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp +acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -343,7 +343,7 @@ Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return Setfield.@set vi.values = set!!(vi.values, vn, val) + return Accessors.@set vi.values = set!!(vi.values, vn, val) end function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) @@ -362,34 +362,34 @@ end function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) # For dictlike objects, we treat the entire `vn` as a _key_ to set. dict = values_as(vi) - # Attempt to split into `parent` and `child` lenses. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(dict, VarName(vn, l)) + # Attempt to split into `parent` and `child` optic. + parent, child, issuccess = splitoptic(getoptic(vn)) do optic + o = optic === nothing ? identity : optic + haskey(dict, VarName(vn, o)) end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent + # When combined with `VarInfo`, `nothing` is equivalent to `identity`. + keyoptic = parent === nothing ? identity : parent dict_new = if !issuccess # Split doesn't exist ⟹ we're working with a new key. BangBang.setindex!!(dict, val, vn) else # Split exists ⟹ trying to set an existing key. - vn_key = VarName(vn, keylens) + vn_key = VarName(vn, keyoptic) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end - return Setfield.@set vi.values = dict_new + return Accessors.@set vi.values = dict_new end # `NamedTuple` function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym,Setfield.IdentityLens}, + vn::VarName{sym,typeof(identity)}, value, dist::Distribution, gidset::Set{Selector}, ) where {sym} - return Setfield.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) + return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, @@ -398,7 +398,7 @@ function BangBang.push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - return Setfield.@set vi.values = set!!(vi.values, vn, value) + return Accessors.@set vi.values = set!!(vi.values, vn, value) end # `AbstractDict` @@ -426,7 +426,7 @@ end # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return Setfield.@set varinfo.values = _subset(varinfo.values, vns) + return Accessors.@set varinfo.values = _subset(varinfo.values, vns) end function _subset(x::AbstractDict, vns) @@ -445,11 +445,11 @@ function _subset(x::AbstractDict, vns) end function _subset(x::NamedTuple, vns) - # NOTE: Here we can only handle `vns` that contain the `IdentityLens`. - if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns) + # NOTE: Here we can only handle `vns` that contain `identity` as optic. + if any(Base.Fix1(!==, identity) ∘ getoptic, vns) throw( ArgumentError( - "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * + "Cannot subset `NamedTuple` with non-`identity` `VarName`. " * "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", ), ) @@ -542,10 +542,10 @@ function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) - return Setfield.@set vi.transformation = transformation + return Accessors.@set vi.transformation = transformation end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans) + return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) @@ -675,7 +675,7 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Setfield.@set(vi.values = y), lp_new) + vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) return settrans!!(vi_new, t) end @@ -690,7 +690,7 @@ function invlink!!( y = vi.values x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new) + vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new) return settrans!!(vi_new, NoTransformation()) end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6323f4dab..a315f7729 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -8,7 +8,7 @@ using Test using Random: Random using Bijectors: Bijectors -using Setfield: Setfield +using Accessors: Accessors # For backwards compat. using DynamicPPL: varname_leaves diff --git a/src/threadsafe.jl b/src/threadsafe.jl index fb1cc1c0c..c40d38466 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -58,7 +58,7 @@ end function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) + return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -84,25 +84,25 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl function link!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) end function invlink!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) end function link( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model) end function invlink( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. @@ -142,7 +142,7 @@ function maybe_invlink_before_eval!!( # Defer to the wrapped `AbstractVarInfo` object. # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. - return Setfield.@set vi.varinfo = maybe_invlink_before_eval!!( + return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!( vi.varinfo, context, model ) end @@ -175,20 +175,20 @@ end getindex_raw(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex_raw(vi.varinfo, spl) function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) end function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) @@ -197,7 +197,7 @@ end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) + return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) @@ -211,10 +211,10 @@ end # Transformations. function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) + return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) end function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) - return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) + return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) @@ -223,18 +223,18 @@ istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.vari getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn) function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x) + return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) end function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector) - return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x) + return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x) end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) - return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns) + return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) end function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo) - return Setfield.@set varinfo_left.varinfo = merge( + return Accessors.@set varinfo_left.varinfo = merge( varinfo_left.varinfo, varinfo_right.varinfo ) end diff --git a/src/utils.jl b/src/utils.jl index b447fed53..06528a72d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -347,13 +347,15 @@ collectmaybe(x::Base.AbstractSet) = collect(x) ####################### # BangBang.jl related # ####################### -function set!!(obj, lens::Setfield.Lens, value) - lensmut = BangBang.prefermutation(lens) - return Setfield.set(obj, lensmut, value) +function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) + opticmut = BangBang.prefermutation(optic) + return Accessors.set(obj, opticmut, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} - lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) - return Setfield.set(obj, lens, value) + optic = BangBang.prefermutation( + AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() + ) + return Accessors.set(obj, optic, value) end ############################# @@ -363,39 +365,41 @@ end # we're more likely to specialize on the key in these settings rather than the container. # TODO: I'm not sure about this name. """ - canview(lens, container) + canview(optic, container) -Return `true` if `lens` can be used to view `container`, and `false` otherwise. +Return `true` if `optic` can be used to view `container`, and `false` otherwise. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: canview) -julia> canview(@lens(_.a), (a = 1.0, )) +```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) +julia> canview(@o(_.a), (a = 1.0, )) true -julia> canview(@lens(_.a), (b = 1.0, )) # property `a` does not exist +julia> canview(@o(_.a), (b = 1.0, )) # property `a` does not exist false -julia> canview(@lens(_.a[1]), (a = [1.0, 2.0], )) +julia> canview(@o(_.a[1]), (a = [1.0, 2.0], )) true -julia> canview(@lens(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +julia> canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds false ``` """ -canview(lens, container) = false -canview(::Setfield.IdentityLens, _) = true -function canview(lens::Setfield.PropertyLens{field}, x) where {field} +canview(optic, container) = false +canview(::typeof(identity), _) = true +function canview(optic::Accessors.PropertyLens{field}, x) where {field} return hasproperty(x, field) end # `IndexLens`: only relevant if `x` supports indexing. -canview(lens::Setfield.IndexLens, x) = false -canview(lens::Setfield.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) +canview(optic::Accessors.IndexLens, x) = false +function canview(optic::Accessors.IndexLens, x::AbstractArray) + return checkbounds(Bool, x, optic.indices...) +end -# `ComposedLens`: check that we can view `.outer` and `.inner`, but using -# value extracted using `.outer`. -function canview(lens::Setfield.ComposedLens, x) - return canview(lens.outer, x) && canview(lens.inner, get(x, lens.outer)) +# `ComposedOptic`: check that we can view `.inner` and `.outer`, but using +# value extracted using `.inner`. +function canview(optic::Accessors.ComposedOptic, x) + return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) end """ @@ -416,123 +420,123 @@ x ``` """ function parent(vn::VarName) - p = parent(getlens(vn)) - return p === nothing ? VarName(vn, Setfield.IdentityLens()) : VarName(vn, p) + p = parent(getoptic(vn)) + return p === nothing ? VarName(vn, identity) : VarName(vn, p) end """ - parent(lens::Setfield.Lens) + parent(optic) -Return the parent lens. If `lens` doesn't have a parent, +Return the parent optic. If `optic` doesn't have a parent, `nothing` is returned. See also: [`parent_and_child`]. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: parent) -julia> parent(@lens(_.a[1])) -(@lens _.a) +```jldoctest; setup=:(using Accessors; using DynamicPPL: parent) +julia> parent(@o(_.a[1])) +(@o _.a) -julia> # Parent of lens without parents results in `nothing`. - (parent ∘ parent)(@lens(_.a[1])) === nothing +julia> # Parent of optic without parents results in `nothing`. + (parent ∘ parent)(@o(_.a[1])) === nothing true ``` """ -parent(lens::Setfield.Lens) = first(parent_and_child(lens)) +parent(optic::AbstractPPL.ALLOWED_OPTICS) = first(parent_and_child(optic)) """ - parent_and_child(lens::Setfield.Lens) + parent_and_child(optic) -Return a 2-tuple of lenses `(parent, child)` where `parent` is the -parent lens of `lens` and `child` is the child lens of `lens`. +Return a 2-tuple of optics `(parent, child)` where `parent` is the +parent optic of `optic` and `child` is the child optic of `optic`. -If `lens` does not have a parent, we return `(nothing, lens)`. +If `optic` does not have a parent, we return `(nothing, optic)`. See also: [`parent`]. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: parent_and_child) -julia> parent_and_child(@lens(_.a[1])) -((@lens _.a), (@lens _[1])) +```jldoctest; setup=:(using Accessors; using DynamicPPL: parent_and_child) +julia> parent_and_child(@o(_.a[1])) +((@o _.a), (@o _[1])) -julia> parent_and_child(@lens(_.a)) -(nothing, (@lens _.a)) +julia> parent_and_child(@o(_.a)) +(nothing, (@o _.a)) ``` """ -parent_and_child(lens::Setfield.Lens) = (nothing, lens) -function parent_and_child(lens::Setfield.ComposedLens) - p, child = parent_and_child(lens.inner) - parent = p === nothing ? lens.outer : lens.outer ∘ p +parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) +function parent_and_child(optic::Accessors.ComposedOptic) + p, child = parent_and_child(optic.outer) + parent = p === nothing ? optic.inner : p ∘ optic.inner return parent, child end """ - splitlens(condition, lens) + splitoptic(condition, optic) Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a lens such that `condition(parent)` is `true` and `parent ∘ child == lens`. +`parent` is a optic such that `condition(parent)` is `true` and `child ∘ parent == optic`. If `issuccess` is `false`, then no such split could be found. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: splitlens) -julia> p, c, issucesss = splitlens(@lens(_.a[1])) do parent +```jldoctest; setup=:(using Accessors; using DynamicPPL: splitoptic) +julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent # Succeeds! - parent == @lens(_.a) + parent == @o(_.a) end -((@lens _.a), (@lens _[1]), true) +((@o _.a), (@o _[1]), true) -julia> p ∘ c -(@lens _.a[1]) +julia> c ∘ p +(@o _.a[1]) -julia> splitlens(@lens(_.a[1])) do parent +julia> splitoptic(@o(_.a[1])) do parent # Fails! - parent == @lens(_.b) + parent == @o(_.b) end -(nothing, (@lens _.a[1]), false) +(nothing, (@o _.a[1]), false) ``` """ -function splitlens(condition, lens) - current_parent, current_child = parent_and_child(lens) +function splitoptic(condition, optic) + current_parent, current_child = parent_and_child(optic) # We stop if either a) `condition` is satisfied, or b) we reached the root. while !condition(current_parent) && current_parent !== nothing current_parent, c = parent_and_child(current_parent) - current_child = c ∘ current_child + current_child = current_child ∘ c end return current_parent, current_child, condition(current_parent) end """ - remove_parent_lens(vn_parent::VarName, vn_child::VarName) + remove_parent_optic(vn_parent::VarName, vn_child::VarName) -Remove the parent lens `vn_parent` from `vn_child`. +Remove the parent optic `vn_parent` from `vn_child`. # Examples -```jldoctest -julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a)) -(@lens _.a) +```jldoctest; setup = :(using Accessors; using DynamicPPL: remove_parent_optic) +julia> remove_parent_optic(@varname(x), @varname(x.a)) +(@o _.a) -julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a[1])) -(@lens _.a[1]) +julia> remove_parent_optic(@varname(x), @varname(x.a[1])) +(@o _.a[1]) -julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1])) -(@lens _[1]) +julia> remove_parent_optic(@varname(x.a), @varname(x.a[1])) +(@o _[1]) -julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1].b)) -(@lens _[1].b) +julia> remove_parent_optic(@varname(x.a), @varname(x.a[1].b)) +(@o _[1].b) -julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a)) +julia> remove_parent_optic(@varname(x.a), @varname(x.a)) ERROR: Could not find x.a in x.a -julia> DynamicPPL.remove_parent_lens(@varname(x.a[2]), @varname(x.a[1])) +julia> remove_parent_optic(@varname(x.a[2]), @varname(x.a[1])) ERROR: Could not find x.a[2] in x.a[1] ``` """ -function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} - _, child, issuccess = splitlens(getlens(vn_child)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - VarName(vn_child, l) == vn_parent +function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} + _, child, issuccess = splitoptic(getoptic(vn_child)) do optic + o = optic === nothing ? identity : optic + VarName(vn_child, o) == vn_parent end issuccess || error("Could not find $vn_parent in $vn_child") @@ -749,8 +753,8 @@ false """ function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} # LHS: Ensure that `nt` indeed has the property we want. - # RHS: Ensure that the lens can view into `nt`. - return haskey(vals, sym) && canview(getlens(vn), getproperty(vals, sym)) + # RHS: Ensure that the optic can view into `nt`. + return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) end # For `dictlike` we need to check wether `vn` is "immediately" present, or @@ -760,20 +764,20 @@ function hasvalue(vals::AbstractDict, vn::VarName) haskey(vals, vn) && return true # If `vn` is not present, we check any parent-varnames by attempting - # to split the lens into the key / `parent` and the extraction lens / `child`. + # to split the optic into the key / `parent` and the extraction optic / `child`. # If `issuccess` is `true`, we found such a split, and hence `vn` is present. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(vals, VarName(vn, l)) + parent, child, issuccess = splitoptic(getoptic(vn)) do optic + o = optic === nothing ? identity : optic + haskey(vals, VarName(vn, o)) end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent + # When combined with `VarInfo`, `nothing` is equivalent to `identity`. + keyoptic = parent === nothing ? identity : parent # Return early if no such split could be found. issuccess || return false # At this point we just need to check that we `canview` the value. - value = vals[VarName(vn, keylens)] + value = vals[VarName(vn, keyoptic)] return canview(child, value) end @@ -790,13 +794,13 @@ function nested_getindex(values::AbstractDict, vn::VarName) return maybeval end - # Split the lens into the key / `parent` and the extraction lens / `child`. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(values, VarName(vn, l)) + # Split the optic into the key / `parent` and the extraction optic / `child`. + parent, child, issuccess = splitoptic(getoptic(vn)) do optic + o = optic === nothing ? identity : optic + haskey(values, VarName(vn, o)) end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent + # When combined with `VarInfo`, `nothing` is equivalent to `identity`. + keyoptic = parent === nothing ? identity : parent # If we found a valid split, then we can extract the value. if !issuccess @@ -806,8 +810,8 @@ function nested_getindex(values::AbstractDict, vn::VarName) # TODO: Should we also check that we `canview` the extracted `value` # rather than just let it fail upon `get` call? - value = values[VarName(vn, keylens)] - return get(value, child) + value = values[VarName(vn, keyoptic)] + return child(value) end """ @@ -911,20 +915,20 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for - I in CartesianIndices(val) + varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]) + for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = Setfield.PropertyLens{sym}() - varname_leaves(vn ∘ lens, get(val, lens)) + optic = Accessors.PropertyLens{sym}() + varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val)) end return Iterators.flatten(iter) end @@ -963,27 +967,27 @@ julia> x = reshape(1:4, 2, 2); julia> # `LowerTriangular` foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) -(x[1,1], 1) -(x[2,1], 2) -(x[2,2], 4) +(x[1, 1], 1) +(x[2, 1], 2) +(x[2, 2], 4) julia> # `UpperTriangular` foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) -(x[1,1], 1) -(x[1,2], 3) -(x[2,2], 4) +(x[1, 1], 1) +(x[1, 2], 3) +(x[2, 2], 4) julia> # `Cholesky` with lower-triangular foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) -(x.L[1,1], 1.0) -(x.L[2,1], 0.0) -(x.L[2,2], 1.0) +(x.L[1, 1], 1.0) +(x.L[2, 1], 0.0) +(x.L[2, 2], 1.0) julia> # `Cholesky` with upper-triangular foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) -(x.U[1,1], 1.0) -(x.U[1,2], 0.0) -(x.U[2,2], 1.0) +(x.U[1, 1], 1.0) +(x.U[1, 2], 0.0) +(x.U[2, 2], 1.0) ``` """ function varname_and_value_leaves(vn::VarName, x) @@ -1033,7 +1037,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) @@ -1042,15 +1046,17 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) end function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = DynamicPPL.Setfield.PropertyLens{sym}() - varname_and_value_leaves_inner(vn ∘ lens, get(val, lens)) + optic = DynamicPPL.Accessors.PropertyLens{sym}() + varname_and_value_leaves_inner( + VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) + ) end return Iterators.flatten(iter) @@ -1059,15 +1065,15 @@ end function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) # TODO: Or do we use `PDMat` here? return if x.uplo == 'L' - varname_and_value_leaves_inner(vn ∘ Setfield.PropertyLens{:L}(), x.L) + varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) else - varname_and_value_leaves_inner(vn ∘ Setfield.PropertyLens{:U}(), x.U) + varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) end end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), x[I], ) # Iteration over the lower-triangular indices. @@ -1077,7 +1083,7 @@ end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), x[I], ) # Iteration over the upper-triangular indices. diff --git a/src/varinfo.jl b/src/varinfo.jl index c8c46ee27..de4e3196f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -965,7 +965,7 @@ function link!!( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) end """ @@ -1051,7 +1051,7 @@ function invlink!!( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) end function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) @@ -1166,7 +1166,7 @@ function link( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) @@ -1261,7 +1261,7 @@ function invlink( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) @@ -1397,10 +1397,10 @@ function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) vn_parent = vns[i] dist = getdist(md, vn_parent) val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail lens. - lens = remove_parent_lens(vn_parent, vn) + # Split the varname into its tail optic. + optic = remove_parent_optic(vn_parent, vn) # Update the value for the parent. - val_parent_updated = set!!(val_parent, lens, val) + val_parent_updated = set!!(val_parent, optic, val) setindex!(vi, val_parent_updated, vn_parent) return vn_parent end diff --git a/test/Project.toml b/test/Project.toml index 93cd7ecd1..d345d30d5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -19,16 +20,16 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +Accessors = "0.1" ADTypes = "0.2" AbstractMCMC = "5" -AbstractPPL = "0.7" +AbstractPPL = "0.8.2" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" @@ -41,7 +42,6 @@ LogDensityProblemsAD = "1" MCMCChains = "6.0.4" MacroTools = "0.5.5" ReverseDiff = "1" -Setfield = "1" StableRNGs = "1" Tracker = "0.2.23" Zygote = "0.6" diff --git a/test/contexts.jl b/test/contexts.jl index d04aecb52..11e2c99b7 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, Setfield +using Test, DynamicPPL, Accessors using DynamicPPL: leafcontext, setleafcontext, @@ -55,7 +55,7 @@ Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - getlens(vn) + getoptic(vn) ) end @@ -169,7 +169,7 @@ end # Let's check elementwise. for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if get(val, getlens(vn_child)) === missing + if getoptic(vn_child)(val) === missing @test contextual_isassumption(context, vn_child) else @test !contextual_isassumption(context, vn_child) @@ -206,7 +206,7 @@ end @test hasconditioned_nested(context, vn_child) # Value should be the same as extracted above. @test getconditioned_nested(context, vn_child) === - get(val, getlens(vn_child)) + getoptic(vn_child)(val) end end end @@ -233,12 +233,12 @@ end vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getlens(vn_prefixed) === getlens(vn) + @test getoptic(vn_prefixed) === getoptic(vn) vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getlens(vn_prefixed) === getlens(vn) + @test getoptic(vn_prefixed) === getoptic(vn) end @testset "SamplingContext" begin diff --git a/test/runtests.jl b/test/runtests.jl index 9e11e2ef4..efa595516 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using Accessors using ADTypes using DynamicPPL using AbstractMCMC @@ -13,7 +14,6 @@ using MCMCChains using Tracker using ReverseDiff using Zygote -using Setfield using Compat using Distributed @@ -25,6 +25,20 @@ using Test using DynamicPPL: getargs_dottilde, getargs_tilde, Selector +# TODO: temporarily overwrite for testing +using AbstractPPL: ALLOWED_OPTICS, VarName +# Allow compositions with optic. +function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym} + vn_optic = getoptic(vn) + if vn_optic == identity + return VarName{sym}(optic) + elseif optic == identity + return vn + else + return VarName{sym}(optic ∘ vn_optic) + end +end + const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") const GROUP = get(ENV, "GROUP", "All") @@ -76,7 +90,7 @@ include("test_util.jl") DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, - :(using DynamicPPL, Distributions); + :(using DynamicPPL, Distributions, Accessors); recursive=true, ) doctestfilters = [ diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index c4e3fa87b..30408e598 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -192,7 +192,7 @@ return p end chain = sample(mat_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1,1]"], [0]; atol=0.25) + check_numerical(chain, ["p[1, 1]"], [0]; atol=0.25) @model function marr_name_test() p = Array{Array{Any}}(undef, 2) diff --git a/test/varinfo.jl b/test/varinfo.jl index 71e341767..a9e734575 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,7 +1,7 @@ function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `IdentityLens`. + # since `keys(varinfo_merged)` only contains `VarName` with `identity`. # So we just check that the original keys are present. for vn in vns # Should have all the original keys. @@ -519,7 +519,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `IdentityLens`. + # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] @testset "$(short_varinfo_name(varinfo)): failure cases" begin @test_throws ArgumentError subset(