From f6dd7ecc9c1abb92cdc3be2ee8d3489232719ef6 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 16 Feb 2024 13:58:54 +0000 Subject: [PATCH 01/23] remove `BangBang.possible` --- src/utils.jl | 37 ---------------- test/utils.jl | 115 -------------------------------------------------- 2 files changed, 152 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b447fed53..d515a303d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -539,43 +539,6 @@ function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) whe return child end -# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 -# and https://github.com/JuliaFolds/BangBang.jl/pull/238, https://github.com/JuliaFolds2/BangBang.jl/pull/16. -# This avoids type-instability in `dot_assume` for `SimpleVarInfo`. -# The following code a copy from https://github.com/JuliaFolds2/BangBang.jl/pull/16 authored by torfjelde -# Default implementation for `_setindex!` with `AbstractArray`. -# But this will return `false` even in cases such as -# -# setindex!!([1, 2, 3], [4, 5, 6], :) -# -# because `promote_type(eltype(C), T) <: eltype(C)` is `false`. -# To address this, we specialize on the case where `T<:AbstractArray`. -# In addition, we need to support a wide range of indexing behaviors: -# -# We also need to ensure that the dimensionality of the index is -# valid, i.e. that we're not returning `true` in cases such as -# -# setindex!!([1, 2, 3], [4, 5], 1) -# -# which should return `false`. -_index_dimension(::Any) = 0 -_index_dimension(::Colon) = 1 -_index_dimension(::AbstractVector) = 1 -_index_dimension(indices::Tuple) = sum(map(_index_dimension, indices)) - -function BangBang.possible( - ::typeof(BangBang._setindex!), ::C, ::T, indices::Vararg -) where {M,C<:AbstractArray{<:Real},T<:AbstractArray{<:Real,M}} - return BangBang.implements(setindex!, C) && - promote_type(eltype(C), eltype(T)) <: eltype(C) && - # This will still return `false` for scenarios such as - # - # setindex!!([1, 2, 3], [4, 5, 6], :, 1) - # - # which are in fact valid. However, this cases are rare. - (_index_dimension(indices) == M || _index_dimension(indices) == 1) -end - # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. """ diff --git a/test/utils.jl b/test/utils.jl index a2d6f46fb..1fcf09ef1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,119 +48,4 @@ x = rand(dist) @test vectorize(dist, x) == vec(x.UL) end - - @testset "BangBang.possible" begin - using DynamicPPL.BangBang: setindex!! - - # Some utility methods for testing `setindex!`. - test_linear_index_only(::Tuple, ::AbstractArray) = false - test_linear_index_only(inds::NTuple{1}, ::AbstractArray) = true - test_linear_index_only(inds::NTuple{1}, ::AbstractVector) = false - - function replace_colon_with_axis(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? axes(x, i) : inds[i] - end - end - function replace_colon_with_vector(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? collect(axes(x, i)) : inds[i] - end - end - function replace_colon_with_range(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? (1:size(x, i)) : inds[i] - end - end - function replace_colon_with_booleans(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? trues(size(x, i)) : inds[i] - end - end - - function replace_colon_with_range_linear(inds::NTuple{1}, x::AbstractArray) - return inds[1] isa Colon ? (1:length(x),) : inds - end - - @testset begin - @test setindex!!((1, 2, 3), :two, 2) === (1, :two, 3) - @test setindex!!((a=1, b=2, c=3), :two, :b) === (a=1, b=:two, c=3) - @test setindex!!([1, 2, 3], :two, 2) == [1, :two, 3] - @test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 10, :a) == - Dict(:a => 10, :b => 2) - @test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 3, "c") == - Dict(:a => 1, :b => 2, "c" => 3) - end - - @testset "mutation" begin - @testset "without type expansion" begin - for args in [([1, 2, 3], 20, 2), (Dict(:a => 1, :b => 2), 10, :a)] - @test setindex!!(args...) === args[1] - end - end - - @testset "with type expansion" begin - @test setindex!!([1, 2, 3], [4, 5], 1) == [[4, 5], 2, 3] - @test setindex!!([1, 2, 3], [4, 5, 6], :, 1) == [4, 5, 6] - end - end - - @testset "slices" begin - @testset "$(typeof(x)) with $(src_idx)" for (x, src_idx) in [ - # Vector. - (randn(2), (:,)), - (randn(2), (1:2,)), - # Matrix. - (randn(2, 3), (:,)), - (randn(2, 3), (:, 1)), - (randn(2, 3), (:, 1:3)), - # 3D array. - (randn(2, 3, 4), (:, 1, :)), - (randn(2, 3, 4), (:, 1:3, :)), - (randn(2, 3, 4), (1, 1:3, :)), - ] - # Base case. - @test @inferred(setindex!!(x, x[src_idx...], src_idx...)) === x - - # If we have `Colon` in the index, we replace this with other equivalent indices. - if any(Base.Fix2(isa, Colon), src_idx) - if test_linear_index_only(src_idx, x) - # With range instead of `Colon`. - @test @inferred( - setindex!!( - x, - x[src_idx...], - replace_colon_with_range_linear(src_idx, x)..., - ) - ) === x - else - # With axis instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_axis(src_idx, x)... - ) - ) === x - # With range instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_range(src_idx, x)... - ) - ) === x - # With vectors instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_vector(src_idx, x)... - ) - ) === x - # With boolean index instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_booleans(src_idx, x)... - ) - ) === x - end - end - end - end - end end From 58dec49225749f1f35ef417b9a36f88f5300e34b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 16 Feb 2024 15:26:30 +0000 Subject: [PATCH 02/23] version bumps --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6510e7ea0..9b7ce5cb4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.7" +version = "0.24.8" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -28,7 +28,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ADTypes = "0.2" AbstractMCMC = "5" AbstractPPL = "0.7" -BangBang = "0.3" +BangBang = "0.4.1" Bijectors = "0.13" ChainRulesCore = "1" Compat = "4" From d3cd5e873723d35d69ea4b0b69ec7a0482b647b2 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 29 Feb 2024 08:05:41 +0000 Subject: [PATCH 03/23] remove dep `MLUtils` --- docs/Project.toml | 2 -- docs/src/tutorials/prob-interface.md | 20 ++++++++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 48ebe173c..271976d0e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,7 +5,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -16,6 +15,5 @@ Documenter = "1" FillArrays = "0.13, 1" LogDensityProblems = "2" MCMCChains = "5, 6" -MLUtils = "0.3, 0.4" Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index dc9f50204..cccebdaf3 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -107,12 +107,11 @@ To give an example of the probability interface in use, we can use it to estimat In cross-validation, we split the dataset into several equal parts. Then, we choose one of these sets to serve as the validation set. Here, we measure fit using the cross entropy (Bayes loss).[^1] +(For the sake of simplicity, in the following code, we enforce that `nfolds` ) ```@example probinterface -using MLUtils - function cross_val( - dataset::AbstractVector{<:Real}; + dataset::Vector{<:Real}; nfolds::Int=5, nsamples::Int=1_000, rng::Random.AbstractRNG=Random.default_rng(), @@ -121,7 +120,20 @@ function cross_val( model = gdemo(1) | (x=[first(dataset)],) loss = zero(logjoint(model, rand(rng, model))) - for (train, validation) in kfolds(dataset, nfolds) + # prepare the K-folds + fold_size = div(length(dataset), nfolds) + if length(dataset) % nfolds != 0 + error("The number of folds must divide the number of data points.") + end + splits = Vector{Tuple{SubArray, SubArray}}(undef, nfolds) + + for i in 1:nfolds + start_idx, end_idx = (i-1)*fold_size + 1, i*fold_size + train_set_indices = [1:start_idx-1; end_idx+1:length(dataset)] + splits[i] = (view(dataset, train_set_indices), view(dataset, start_idx:end_idx)) + end + + for (train, validation) in splits # First, we train the model on the training set, i.e., we obtain samples from the posterior. # For normally-distributed data, the posterior can be computed in closed form. # For general models, however, typically samples will be generated using MCMC with Turing. From 4e4a994e9dab5b0877df71ae044c81387de6dd7b Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 29 Feb 2024 08:08:25 +0000 Subject: [PATCH 04/23] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/tutorials/prob-interface.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index cccebdaf3..7ff81cf25 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -125,11 +125,11 @@ function cross_val( if length(dataset) % nfolds != 0 error("The number of folds must divide the number of data points.") end - splits = Vector{Tuple{SubArray, SubArray}}(undef, nfolds) - + splits = Vector{Tuple{SubArray,SubArray}}(undef, nfolds) + for i in 1:nfolds - start_idx, end_idx = (i-1)*fold_size + 1, i*fold_size - train_set_indices = [1:start_idx-1; end_idx+1:length(dataset)] + start_idx, end_idx = (i - 1) * fold_size + 1, i * fold_size + train_set_indices = [1:(start_idx - 1); (end_idx + 1):length(dataset)] splits[i] = (view(dataset, train_set_indices), view(dataset, start_idx:end_idx)) end From ef668a33d2aa562d2c51fa1b1eefa2164b1516d6 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 29 Feb 2024 12:57:46 +0000 Subject: [PATCH 05/23] finish sentence --- docs/src/tutorials/prob-interface.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index 7ff81cf25..5a16b0db1 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -107,7 +107,7 @@ To give an example of the probability interface in use, we can use it to estimat In cross-validation, we split the dataset into several equal parts. Then, we choose one of these sets to serve as the validation set. Here, we measure fit using the cross entropy (Bayes loss).[^1] -(For the sake of simplicity, in the following code, we enforce that `nfolds` ) +(For the sake of simplicity, in the following code, we enforce that `nfolds` must divide the number of data points. For a more competent implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).) ```@example probinterface function cross_val( From c901f820bcf469990df5fc294ab47d793ba8cdda Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 29 Feb 2024 12:58:26 +0000 Subject: [PATCH 06/23] Update docs/src/tutorials/prob-interface.md Co-authored-by: David Widmann --- docs/src/tutorials/prob-interface.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index 5a16b0db1..cc19bf050 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -121,8 +121,8 @@ function cross_val( loss = zero(logjoint(model, rand(rng, model))) # prepare the K-folds - fold_size = div(length(dataset), nfolds) - if length(dataset) % nfolds != 0 + fold_size, remaining = divrem(length(dataset), nfolds) + if remaining != 0 error("The number of folds must divide the number of data points.") end splits = Vector{Tuple{SubArray,SubArray}}(undef, nfolds) From f20dafe95146b5f3dad9589b81fbb01577ddaa97 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 29 Feb 2024 12:59:11 +0000 Subject: [PATCH 07/23] Update docs/src/tutorials/prob-interface.md Co-authored-by: David Widmann --- docs/src/tutorials/prob-interface.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index cc19bf050..2da5c3201 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -125,12 +125,13 @@ function cross_val( if remaining != 0 error("The number of folds must divide the number of data points.") end - splits = Vector{Tuple{SubArray,SubArray}}(undef, nfolds) - - for i in 1:nfolds - start_idx, end_idx = (i - 1) * fold_size + 1, i * fold_size - train_set_indices = [1:(start_idx - 1); (end_idx + 1):length(dataset)] - splits[i] = (view(dataset, train_set_indices), view(dataset, start_idx:end_idx)) + first_idx = firstindex(dataset) + last_idx = lastindex(dataset) + splits = map(0:(nfolds - 1)) do i + start_idx = first_idx + i * fold_size + end_idx = start_idx + fold_size + train_set_indices = [first_idx:start_idx; (end_idx + 1):last_idx] + return (view(dataset, train_set_indices), view(dataset, (start_idx + 1):end_idx)) end for (train, validation) in splits From 9934f99d2c7503d5843e71e893dd36de1bf2ffc2 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 29 Feb 2024 13:12:03 +0000 Subject: [PATCH 08/23] make `kfolds` a function --- docs/src/tutorials/prob-interface.md | 31 +++++++++++++++------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index 2da5c3201..683396550 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -110,17 +110,8 @@ Here, we measure fit using the cross entropy (Bayes loss).[^1] (For the sake of simplicity, in the following code, we enforce that `nfolds` must divide the number of data points. For a more competent implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).) ```@example probinterface -function cross_val( - dataset::Vector{<:Real}; - nfolds::Int=5, - nsamples::Int=1_000, - rng::Random.AbstractRNG=Random.default_rng(), -) - # Initialize `loss` in a way such that the loop below does not change its type - model = gdemo(1) | (x=[first(dataset)],) - loss = zero(logjoint(model, rand(rng, model))) - - # prepare the K-folds +# Calculate the train/validation splits across `nfolds` partitions, assume `length(dataset)` divides `nfolds` +function kfolds(dataset::Array{<:Real}, nfolds::Int) fold_size, remaining = divrem(length(dataset), nfolds) if remaining != 0 error("The number of folds must divide the number of data points.") @@ -130,11 +121,23 @@ function cross_val( splits = map(0:(nfolds - 1)) do i start_idx = first_idx + i * fold_size end_idx = start_idx + fold_size - train_set_indices = [first_idx:start_idx; (end_idx + 1):last_idx] - return (view(dataset, train_set_indices), view(dataset, (start_idx + 1):end_idx)) + train_set_indices = [first_idx:start_idx-1; end_idx:last_idx] + return (view(dataset, train_set_indices), view(dataset, start_idx:end_idx-1)) end + return splits +end + +function cross_val( + dataset::Vector{<:Real}; + nfolds::Int=5, + nsamples::Int=1_000, + rng::Random.AbstractRNG=Random.default_rng(), +) + # Initialize `loss` in a way such that the loop below does not change its type + model = gdemo(1) | (x=[first(dataset)],) + loss = zero(logjoint(model, rand(rng, model))) - for (train, validation) in splits + for (train, validation) in kfolds(dataset, nfolds) # First, we train the model on the training set, i.e., we obtain samples from the posterior. # For normally-distributed data, the posterior can be computed in closed form. # For general models, however, typically samples will be generated using MCMC with Turing. From 00a3f0b015ed0fc2391dc9ebd9ae8a1217a908e0 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:18:47 +0000 Subject: [PATCH 09/23] Update docs/src/tutorials/prob-interface.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/tutorials/prob-interface.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md index 683396550..330fa931a 100644 --- a/docs/src/tutorials/prob-interface.md +++ b/docs/src/tutorials/prob-interface.md @@ -121,8 +121,8 @@ function kfolds(dataset::Array{<:Real}, nfolds::Int) splits = map(0:(nfolds - 1)) do i start_idx = first_idx + i * fold_size end_idx = start_idx + fold_size - train_set_indices = [first_idx:start_idx-1; end_idx:last_idx] - return (view(dataset, train_set_indices), view(dataset, start_idx:end_idx-1)) + train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx] + return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1))) end return splits end From da96d6cc4cafbcd24e839ddf2ebedb802ebd0a48 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 5 Apr 2024 12:00:20 +0100 Subject: [PATCH 10/23] transition to `Accessors` --- Project.toml | 42 +++++------ docs/Project.toml | 4 +- src/DynamicPPL.jl | 2 +- src/abstract_varinfo.jl | 8 +-- src/compiler.jl | 10 +-- src/model.jl | 12 ++-- src/model_utils.jl | 2 +- src/simple_varinfo.jl | 32 ++++----- src/test_utils.jl | 2 +- src/threadsafe.jl | 36 +++++----- src/utils.jl | 154 ++++++++++++++++++++-------------------- src/varinfo.jl | 8 +-- test/Project.toml | 4 +- test/contexts.jl | 2 +- test/runtests.jl | 2 +- 15 files changed, 160 insertions(+), 160 deletions(-) diff --git a/Project.toml b/Project.toml index cbad0d688..e8091435e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.9" +version = "0.25.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -20,15 +21,31 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" + +[extensions] +DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] +DynamicPPLEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLForwardDiffExt = ["ForwardDiff"] +DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLZygoteRulesExt = ["ZygoteRules"] + [compat] ADTypes = "0.2" AbstractMCMC = "5" -AbstractPPL = "0.7" -BangBang = "0.3" +AbstractPPL = "0.8" +Accessors = "0.1" +BangBang = "0.4" Bijectors = "0.13" ChainRulesCore = "1" Compat = "4" @@ -44,29 +61,12 @@ MacroTools = "0.5.6" OrderedCollections = "1" Random = "1.6" Requires = "1" -Setfield = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.6" -[extensions] -DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] -DynamicPPLEnzymeCoreExt = ["EnzymeCore"] -DynamicPPLForwardDiffExt = ["ForwardDiff"] -DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLReverseDiffExt = ["ReverseDiff"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] - [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" - -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/docs/Project.toml b/docs/Project.toml index 48ebe173c..5a605a974 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -6,10 +7,10 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] +Accessors = "0.1" DataStructures = "0.18" Distributions = "0.25" Documenter = "1" @@ -17,5 +18,4 @@ FillArrays = "0.13, 1" LogDensityProblems = "2" MCMCChains = "5, 6" MLUtils = "0.3, 0.4" -Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ce6605250..0ccfbb103 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -12,7 +12,7 @@ using ADTypes: ADTypes using BangBang: BangBang, push!!, empty!!, setindex!! using MacroTools: MacroTools using ConstructionBase: ConstructionBase -using Setfield: Setfield +using Accessors: Accessors using LogDensityProblems: LogDensityProblems using LogDensityProblemsAD: LogDensityProblemsAD diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index bd7e8d8fb..8aedeb09c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -262,7 +262,7 @@ julia> values_as(SimpleVarInfo(data), NamedTuple) (x = 1.0, m = [2.0]) julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries: +OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: x => 1.0 m => [2.0] @@ -312,7 +312,7 @@ julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: +OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 @@ -338,7 +338,7 @@ julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries: +OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: s => 1.0 m => 2.0 @@ -426,7 +426,7 @@ julia> # Extract one with only `m`. julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: +1-element Vector{VarName{:m, typeof(identity)}}: m julia> varinfo_subset1[@varname(m)] diff --git a/src/compiler.jl b/src/compiler.jl index e7c44d16b..3ebeaefff 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -8,7 +8,7 @@ requires a dynamic lens. # Examples -```jldoctest; setup=:(using Setfield) +```jldoctest; setup=:(using Accessors) julia> DynamicPPL.need_concretize(:(x[1, :])) true @@ -19,7 +19,7 @@ julia> DynamicPPL.need_concretize(:(x[1, 1])) false """ function need_concretize(expr) - return Setfield.need_dynamic_lens(expr) || begin + return Accessors.need_dynamic_optic(expr) || begin flag = false MacroTools.postwalk(expr) do ex # Concretise colon by default @@ -226,7 +226,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return AbstractPPL.concretize(vn ∘ Setfield.IndexLens((Colon(), i)), left) + return AbstractPPL.concretize(vn ∘ Accessors.IndexLens((Colon(), i)), left) end return unwrap_right_left_vns(right, left, vns) end @@ -236,7 +236,7 @@ function unwrap_right_left_vns( vn::VarName, ) vns = map(CartesianIndices(left)) do i - return vn ∘ Setfield.IndexLens(Tuple(i)) + return vn ∘ Accessors.IndexLens(Tuple(i)) end return unwrap_right_left_vns(right, left, vns) end @@ -437,7 +437,7 @@ function generate_tilde_assume(left, right, vn) expr = :($left = $value) if left isa Expr expr = AbstractPPL.drop_escape( - Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true) ) end diff --git a/src/model.jl b/src/model.jl index c0cc2f26f..8c10ed36e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -279,11 +279,11 @@ in their trace/`VarInfo`: ```jldoctest condition julia> keys(VarInfo(demo_outer())) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: +1-element Vector{VarName{:m, typeof(identity)}}: m julia> keys(VarInfo(demo_outer_prefix())) -1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}: inner.m ``` @@ -448,7 +448,7 @@ julia> conditioned(cm) julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: a.m julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. @@ -634,11 +634,11 @@ in their trace/`VarInfo`: ```jldoctest fix julia> keys(VarInfo(demo_outer())) -1-element Vector{VarName{:m, Setfield.IdentityLens}}: +1-element Vector{VarName{:m, typeof(identity)}}: m julia> keys(VarInfo(demo_outer_prefix())) -1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}: inner.m ``` @@ -830,7 +830,7 @@ julia> fixed(cm) julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, # `a.m` is treated as a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}: +1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}: a.m julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. diff --git a/src/model_utils.jl b/src/model_utils.jl index ab1acfa05..8d4cb34c5 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -110,7 +110,7 @@ function values_from_chain( for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getlens(vn) - out = Setfield.set( + out = Accessors.set( out, BangBang.prefermutation(l), chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad37130d6..8ff902017 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -259,15 +259,15 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) end function BangBang.empty!!(vi::SimpleVarInfo) - return resetlogp!!(Setfield.@set vi.values = empty!!(vi.values)) + return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values)) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) getlogp(vi::SimpleVarInfo) = vi.logp getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] -setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp +setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp +acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -343,7 +343,7 @@ Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return Setfield.@set vi.values = set!!(vi.values, vn, val) + return Accessors.@set vi.values = set!!(vi.values, vn, val) end function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) @@ -364,11 +364,11 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName dict = values_as(vi) # Attempt to split into `parent` and `child` lenses. parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens + l = lens === nothing ? identity : lens haskey(dict, VarName(vn, l)) end # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent + keylens = parent === nothing ? identity : parent dict_new = if !issuccess # Split doesn't exist ⟹ we're working with a new key. @@ -378,18 +378,18 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName vn_key = VarName(vn, keylens) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end - return Setfield.@set vi.values = dict_new + return Accessors.@set vi.values = dict_new end # `NamedTuple` function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym,Setfield.IdentityLens}, + vn::VarName{sym,typeof(identity)}, value, dist::Distribution, gidset::Set{Selector}, ) where {sym} - return Setfield.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) + return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, @@ -398,7 +398,7 @@ function BangBang.push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - return Setfield.@set vi.values = set!!(vi.values, vn, value) + return Accessors.@set vi.values = set!!(vi.values, vn, value) end # `AbstractDict` @@ -426,7 +426,7 @@ end # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return Setfield.@set varinfo.values = _subset(varinfo.values, vns) + return Accessors.@set varinfo.values = _subset(varinfo.values, vns) end function _subset(x::AbstractDict, vns) @@ -446,7 +446,7 @@ end function _subset(x::NamedTuple, vns) # NOTE: Here we can only handle `vns` that contain the `IdentityLens`. - if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns) + if any(Base.Fix1(!==, identity) ∘ getlens, vns) throw( ArgumentError( "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * @@ -542,10 +542,10 @@ function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) - return Setfield.@set vi.transformation = transformation + return Accessors.@set vi.transformation = transformation end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans) + return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) @@ -675,7 +675,7 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Setfield.@set(vi.values = y), lp_new) + vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) return settrans!!(vi_new, t) end @@ -690,7 +690,7 @@ function invlink!!( y = vi.values x, logjac = with_logabsdet_jacobian(b, y) lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new) + vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new) return settrans!!(vi_new, NoTransformation()) end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6323f4dab..a315f7729 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -8,7 +8,7 @@ using Test using Random: Random using Bijectors: Bijectors -using Setfield: Setfield +using Accessors: Accessors # For backwards compat. using DynamicPPL: varname_leaves diff --git a/src/threadsafe.jl b/src/threadsafe.jl index fb1cc1c0c..c40d38466 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -58,7 +58,7 @@ end function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) + return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -84,25 +84,25 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl function link!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, spl, model) end function invlink!!( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model) end function link( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = link(t, vi.varinfo, spl, model) end function invlink( t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model ) - return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, spl, model) end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. @@ -142,7 +142,7 @@ function maybe_invlink_before_eval!!( # Defer to the wrapped `AbstractVarInfo` object. # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the `getlogp(vi.varinfo)` # hence the log-absdet-jacobian term will correctly be included in the `getlogp(vi)`. - return Setfield.@set vi.varinfo = maybe_invlink_before_eval!!( + return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!( vi.varinfo, context, model ) end @@ -175,20 +175,20 @@ end getindex_raw(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex_raw(vi.varinfo, spl) function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) - return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) + return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) end function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) @@ -197,7 +197,7 @@ end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) + return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) @@ -211,10 +211,10 @@ end # Transformations. function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) + return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) end function settrans!!(vi::ThreadSafeVarInfo, spl::AbstractSampler, dist::Distribution) - return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) + return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, spl, dist) end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) @@ -223,18 +223,18 @@ istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.vari getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn) function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Setfield.@set vi.varinfo = unflatten(vi.varinfo, x) + return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) end function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector) - return Setfield.@set vi.varinfo = unflatten(vi.varinfo, spl, x) + return Accessors.@set vi.varinfo = unflatten(vi.varinfo, spl, x) end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) - return Setfield.@set varinfo.varinfo = subset(varinfo.varinfo, vns) + return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) end function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo) - return Setfield.@set varinfo_left.varinfo = merge( + return Accessors.@set varinfo_left.varinfo = merge( varinfo_left.varinfo, varinfo_right.varinfo ) end diff --git a/src/utils.jl b/src/utils.jl index b447fed53..0cf7c85d0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -347,13 +347,13 @@ collectmaybe(x::Base.AbstractSet) = collect(x) ####################### # BangBang.jl related # ####################### -function set!!(obj, lens::Setfield.Lens, value) +function set!!(obj, lens::AbstractPPL.ALLOWED_OPTICS, value) lensmut = BangBang.prefermutation(lens) - return Setfield.set(obj, lensmut, value) + return Accessors.set(obj, lensmut, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} - lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) - return Setfield.set(obj, lens, value) + lens = BangBang.prefermutation(Accessors.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) + return Accessors.set(obj, lens, value) end ############################# @@ -368,7 +368,7 @@ end Return `true` if `lens` can be used to view `container`, and `false` otherwise. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: canview) +```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) julia> canview(@lens(_.a), (a = 1.0, )) true @@ -383,18 +383,18 @@ false ``` """ canview(lens, container) = false -canview(::Setfield.IdentityLens, _) = true -function canview(lens::Setfield.PropertyLens{field}, x) where {field} +canview(::typeof(identity), _) = true +function canview(lens::Accessors.PropertyLens{field}, x) where {field} return hasproperty(x, field) end # `IndexLens`: only relevant if `x` supports indexing. -canview(lens::Setfield.IndexLens, x) = false -canview(lens::Setfield.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) +canview(lens::Accessors.IndexLens, x) = false +canview(lens::Accessors.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) # `ComposedLens`: check that we can view `.outer` and `.inner`, but using # value extracted using `.outer`. -function canview(lens::Setfield.ComposedLens, x) +function canview(lens::Accessors.ComposedOptic, x) return canview(lens.outer, x) && canview(lens.inner, get(x, lens.outer)) end @@ -417,122 +417,122 @@ x """ function parent(vn::VarName) p = parent(getlens(vn)) - return p === nothing ? VarName(vn, Setfield.IdentityLens()) : VarName(vn, p) + return p === nothing ? VarName(vn, identity) : VarName(vn, p) end """ - parent(lens::Setfield.Lens) + parent(optic) -Return the parent lens. If `lens` doesn't have a parent, +Return the parent optic. If `optic` doesn't have a parent, `nothing` is returned. See also: [`parent_and_child`]. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: parent) -julia> parent(@lens(_.a[1])) -(@lens _.a) +```jldoctest; setup=:(using Accessors; using DynamicPPL: parent) +julia> parent(@o(_.a[1])) +(@o _.a) -julia> # Parent of lens without parents results in `nothing`. - (parent ∘ parent)(@lens(_.a[1])) === nothing +julia> # Parent of optic without parents results in `nothing`. + (parent ⨟ parent)(@o(_.a[1])) === nothing true ``` """ -parent(lens::Setfield.Lens) = first(parent_and_child(lens)) +parent(optic::AbstractPPL.ALLOWED_OPTICS) = first(parent_and_child(optic)) """ - parent_and_child(lens::Setfield.Lens) + parent_and_child(optic) -Return a 2-tuple of lenses `(parent, child)` where `parent` is the -parent lens of `lens` and `child` is the child lens of `lens`. +Return a 2-tuple of optics `(parent, child)` where `parent` is the +parent optic of `optic` and `child` is the child optic of `optic`. -If `lens` does not have a parent, we return `(nothing, lens)`. +If `optic` does not have a parent, we return `(nothing, optic)`. See also: [`parent`]. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: parent_and_child) -julia> parent_and_child(@lens(_.a[1])) -((@lens _.a), (@lens _[1])) +```jldoctest; setup=:(using Accessors; using DynamicPPL: parent_and_child) +julia> parent_and_child(@o(_.a[1])) +((@o _.a), (@o _[1])) -julia> parent_and_child(@lens(_.a)) -(nothing, (@lens _.a)) +julia> parent_and_child(@o(_.a)) +(nothing, (@o _.a)) ``` """ -parent_and_child(lens::Setfield.Lens) = (nothing, lens) -function parent_and_child(lens::Setfield.ComposedLens) - p, child = parent_and_child(lens.inner) - parent = p === nothing ? lens.outer : lens.outer ∘ p +parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) +function parent_and_child(optic::Accessors.ComposedOptic) + p, child = parent_and_child(optic.outer) + parent = p === nothing ? optic.inner : optic.inner ⨟ p return parent, child end """ - splitlens(condition, lens) + splitoptic(condition, optic) Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a lens such that `condition(parent)` is `true` and `parent ∘ child == lens`. +`parent` is a lens such that `condition(parent)` is `true` and `parent ⨟ child == lens`. If `issuccess` is `false`, then no such split could be found. # Examples -```jldoctest; setup=:(using Setfield; using DynamicPPL: splitlens) -julia> p, c, issucesss = splitlens(@lens(_.a[1])) do parent +```jldoctest; setup=:(using Accessors; using DynamicPPL: splitoptic) +julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent # Succeeds! - parent == @lens(_.a) + parent == @o(_.a) end -((@lens _.a), (@lens _[1]), true) +((@o _.a), (@o _[1]), true) -julia> p ∘ c -(@lens _.a[1]) +julia> p ⨟ c +(@o _.a[1]) -julia> splitlens(@lens(_.a[1])) do parent +julia> splitoptic(@o(_.a[1])) do parent # Fails! - parent == @lens(_.b) + parent == @o(_.b) end -(nothing, (@lens _.a[1]), false) +(nothing, (@o _.a[1]), false) ``` """ -function splitlens(condition, lens) - current_parent, current_child = parent_and_child(lens) +function splitoptic(condition, optic) + current_parent, current_child = parent_and_child(optic) # We stop if either a) `condition` is satisfied, or b) we reached the root. while !condition(current_parent) && current_parent !== nothing current_parent, c = parent_and_child(current_parent) - current_child = c ∘ current_child + current_child = c ⨟ current_child end return current_parent, current_child, condition(current_parent) end """ - remove_parent_lens(vn_parent::VarName, vn_child::VarName) + remove_parent_optic(vn_parent::VarName, vn_child::VarName) -Remove the parent lens `vn_parent` from `vn_child`. +Remove the parent optic `vn_parent` from `vn_child`. # Examples -```jldoctest -julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a)) -(@lens _.a) +```jldoctest; setup = :(using Accessors; using DynamicPPL: remove_parent_optic) +julia> remove_parent_optic(@varname(x), @varname(x.a)) +(@o _.a) -julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a[1])) -(@lens _.a[1]) +julia> remove_parent_optic(@varname(x), @varname(x.a[1])) +(@o _.a[1]) -julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1])) -(@lens _[1]) +julia> remove_parent_optic(@varname(x.a), @varname(x.a[1])) +(@o _[1]) -julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1].b)) -(@lens _[1].b) +julia> remove_parent_optic(@varname(x.a), @varname(x.a[1].b)) +(@o _[1].b) -julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a)) +julia> remove_parent_optic(@varname(x.a), @varname(x.a)) ERROR: Could not find x.a in x.a -julia> DynamicPPL.remove_parent_lens(@varname(x.a[2]), @varname(x.a[1])) +julia> remove_parent_optic(@varname(x.a[2]), @varname(x.a[1])) ERROR: Could not find x.a[2] in x.a[1] ``` """ -function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} - _, child, issuccess = splitlens(getlens(vn_child)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - VarName(vn_child, l) == vn_parent +function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} + _, child, issuccess = splitoptic(getoptic(vn_child)) do optic + o = optic === nothing ? identity : optic + VarName(vn_child, o) == vn_parent end issuccess || error("Could not find $vn_parent in $vn_child") @@ -763,11 +763,11 @@ function hasvalue(vals::AbstractDict, vn::VarName) # to split the lens into the key / `parent` and the extraction lens / `child`. # If `issuccess` is `true`, we found such a split, and hence `vn` is present. parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens + l = lens === nothing ? identity : lens haskey(vals, VarName(vn, l)) end # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent + keylens = parent === nothing ? identity : parent # Return early if no such split could be found. issuccess || return false @@ -792,11 +792,11 @@ function nested_getindex(values::AbstractDict, vn::VarName) # Split the lens into the key / `parent` and the extraction lens / `child`. parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens + l = lens === nothing ? identity : lens haskey(values, VarName(vn, l)) end # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent + keylens = parent === nothing ? identity : parent # If we found a valid split, then we can extract the value. if !issuccess @@ -911,19 +911,19 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + VarName(vn, getlens(vn) ∘ Accessors.IndexLens(Tuple(I))) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for + varname_leaves(VarName(vn, getlens(vn) ∘ Accessors.IndexLens(Tuple(I))), val[I]) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = Setfield.PropertyLens{sym}() + lens = Accessors.PropertyLens{sym}() varname_leaves(vn ∘ lens, get(val, lens)) end return Iterators.flatten(iter) @@ -1033,7 +1033,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), val[I], ) for I in CartesianIndices(val) ) @@ -1042,14 +1042,14 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), val[I], ) for I in CartesianIndices(val) ) end function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = DynamicPPL.Setfield.PropertyLens{sym}() + lens = DynamicPPL.Accessors.PropertyLens{sym}() varname_and_value_leaves_inner(vn ∘ lens, get(val, lens)) end @@ -1059,15 +1059,15 @@ end function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) # TODO: Or do we use `PDMat` here? return if x.uplo == 'L' - varname_and_value_leaves_inner(vn ∘ Setfield.PropertyLens{:L}(), x.L) + varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:L}(), x.L) else - varname_and_value_leaves_inner(vn ∘ Setfield.PropertyLens{:U}(), x.U) + varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:U}(), x.U) end end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), x[I], ) # Iteration over the lower-triangular indices. @@ -1077,7 +1077,7 @@ end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Setfield.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), x[I], ) # Iteration over the upper-triangular indices. diff --git a/src/varinfo.jl b/src/varinfo.jl index c8c46ee27..ce0a5ff73 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -965,7 +965,7 @@ function link!!( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) end """ @@ -1051,7 +1051,7 @@ function invlink!!( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, spl, model) end function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, model::Model) @@ -1166,7 +1166,7 @@ function link( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) @@ -1261,7 +1261,7 @@ function invlink( ) # By default this will simply evaluate the model with `DynamicTransformationContext`, and so # we need to specialize to avoid this. - return Setfield.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) + return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) diff --git a/test/Project.toml b/test/Project.toml index 93cd7ecd1..a95abef1e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -19,13 +20,13 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +Accessors = "0.1" ADTypes = "0.2" AbstractMCMC = "5" AbstractPPL = "0.7" @@ -41,7 +42,6 @@ LogDensityProblemsAD = "1" MCMCChains = "6.0.4" MacroTools = "0.5.5" ReverseDiff = "1" -Setfield = "1" StableRNGs = "1" Tracker = "0.2.23" Zygote = "0.6" diff --git a/test/contexts.jl b/test/contexts.jl index d04aecb52..994d98194 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL, Setfield +using Test, DynamicPPL, Accessors using DynamicPPL: leafcontext, setleafcontext, diff --git a/test/runtests.jl b/test/runtests.jl index 9e11e2ef4..f18167d08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using Accessors using ADTypes using DynamicPPL using AbstractMCMC @@ -13,7 +14,6 @@ using MCMCChains using Tracker using ReverseDiff using Zygote -using Setfield using Compat using Distributed From 0f5de5397f4fb37cdcd05f60cb55b5cfc33bd44d Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 5 Apr 2024 12:42:22 +0100 Subject: [PATCH 11/23] more updates --- src/compiler.jl | 4 +- src/contexts.jl | 4 +- src/model_utils.jl | 16 ++++---- src/simple_varinfo.jl | 20 ++++----- src/utils.jl | 94 +++++++++++++++++++++---------------------- src/varinfo.jl | 6 +-- test/Project.toml | 2 +- test/contexts.jl | 10 ++--- test/varinfo.jl | 4 +- 9 files changed, 80 insertions(+), 80 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3ebeaefff..f6e3e3982 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4,7 +4,7 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) need_concretize(expr) Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or -requires a dynamic lens. +requires a dynamic optic. # Examples @@ -226,7 +226,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return AbstractPPL.concretize(vn ∘ Accessors.IndexLens((Colon(), i)), left) + return AbstractPPL.concretize(vn ⨟ Accessors.IndexLens((Colon(), i)), left) end return unwrap_right_left_vns(right, left, vns) end diff --git a/src/contexts.jl b/src/contexts.jl index 83da5d929..2018b9155 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -288,9 +288,9 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn))) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn))) else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn)) + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) end end diff --git a/src/model_utils.jl b/src/model_utils.jl index 8d4cb34c5..9232b0c88 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -78,13 +78,13 @@ end function varname_in_chain!( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out ) where {sym} - # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. - # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. + # This way we can use `getoptic(vn)` to extract the value from `x` and use `vn_parent ⨟ getoptic(vn)` # to extract the value from the `chain`. for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. - l = AbstractPPL.getlens(vn) - varname_in_chain!(x, vn_parent ∘ l, chain, chain_idx, iteration_idx, out) + l = AbstractPPL.getoptic(vn) + varname_in_chain!(x, vn_parent ⨟ l, chain, chain_idx, iteration_idx, out) end return out end @@ -103,17 +103,17 @@ end function values_from_chain( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx ) where {sym} - # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. - # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. + # This way we can use `getoptic(vn)` to extract the value from `x` and use `vn_parent ⨟ getoptic(vn)` # to extract the value from the `chain`. out = similar(x) for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. - l = AbstractPPL.getlens(vn) + l = AbstractPPL.getoptic(vn) out = Accessors.set( out, BangBang.prefermutation(l), - chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], + chain[iteration_idx, Symbol(vn_parent ⨟ l), chain_idx], ) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 8ff902017..6704e03fa 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -362,20 +362,20 @@ end function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) # For dictlike objects, we treat the entire `vn` as a _key_ to set. dict = values_as(vi) - # Attempt to split into `parent` and `child` lenses. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? identity : lens - haskey(dict, VarName(vn, l)) + # Attempt to split into `parent` and `child` optic. + parent, child, issuccess = splitoptic(getoptic(vn)) do optic + o = optic === nothing ? identity : optic + haskey(dict, VarName(vn, o)) end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? identity : parent + # When combined with `VarInfo`, `nothing` is equivalent to `identity`. + keyoptic = parent === nothing ? identity : parent dict_new = if !issuccess # Split doesn't exist ⟹ we're working with a new key. BangBang.setindex!!(dict, val, vn) else # Split exists ⟹ trying to set an existing key. - vn_key = VarName(vn, keylens) + vn_key = VarName(vn, keyoptic) BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end return Accessors.@set vi.values = dict_new @@ -445,11 +445,11 @@ function _subset(x::AbstractDict, vns) end function _subset(x::NamedTuple, vns) - # NOTE: Here we can only handle `vns` that contain the `IdentityLens`. - if any(Base.Fix1(!==, identity) ∘ getlens, vns) + # NOTE: Here we can only handle `vns` that contain `identity` as optic. + if any(Base.Fix1(!==, identity) ∘ getoptic, vns) throw( ArgumentError( - "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " * + "Cannot subset `NamedTuple` with non-`identity` `VarName`. " * "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", ), ) diff --git a/src/utils.jl b/src/utils.jl index 0cf7c85d0..ac6d86dbc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -347,13 +347,13 @@ collectmaybe(x::Base.AbstractSet) = collect(x) ####################### # BangBang.jl related # ####################### -function set!!(obj, lens::AbstractPPL.ALLOWED_OPTICS, value) - lensmut = BangBang.prefermutation(lens) - return Accessors.set(obj, lensmut, value) +function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) + opticmut = BangBang.prefermutation(optic) + return Accessors.set(obj, opticmut, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} - lens = BangBang.prefermutation(Accessors.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) - return Accessors.set(obj, lens, value) + optic = BangBang.prefermutation(Accessors.PropertyLens{sym}() ∘ AbstractPPL.getoptic(vn)) + return Accessors.set(obj, optic, value) end ############################# @@ -363,39 +363,39 @@ end # we're more likely to specialize on the key in these settings rather than the container. # TODO: I'm not sure about this name. """ - canview(lens, container) + canview(optic, container) -Return `true` if `lens` can be used to view `container`, and `false` otherwise. +Return `true` if `optic` can be used to view `container`, and `false` otherwise. # Examples ```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) -julia> canview(@lens(_.a), (a = 1.0, )) +julia> canview(@o(_.a), (a = 1.0, )) true -julia> canview(@lens(_.a), (b = 1.0, )) # property `a` does not exist +julia> canview(@o(_.a), (b = 1.0, )) # property `a` does not exist false -julia> canview(@lens(_.a[1]), (a = [1.0, 2.0], )) +julia> canview(@o(_.a[1]), (a = [1.0, 2.0], )) true -julia> canview(@lens(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +julia> canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds false ``` """ -canview(lens, container) = false +canview(optic, container) = false canview(::typeof(identity), _) = true -function canview(lens::Accessors.PropertyLens{field}, x) where {field} +function canview(optic::Accessors.PropertyLens{field}, x) where {field} return hasproperty(x, field) end # `IndexLens`: only relevant if `x` supports indexing. -canview(lens::Accessors.IndexLens, x) = false -canview(lens::Accessors.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) +canview(optic::Accessors.IndexLens, x) = false +canview(optic::Accessors.IndexLens, x::AbstractArray) = checkbounds(Bool, x, optic.indices...) -# `ComposedLens`: check that we can view `.outer` and `.inner`, but using -# value extracted using `.outer`. -function canview(lens::Accessors.ComposedOptic, x) - return canview(lens.outer, x) && canview(lens.inner, get(x, lens.outer)) +# `ComposedOptic`: check that we can view `.inner` and `.outer`, but using +# value extracted using `.inner`. +function canview(optic::Accessors.ComposedOptic, x) + return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) end """ @@ -416,7 +416,7 @@ x ``` """ function parent(vn::VarName) - p = parent(getlens(vn)) + p = parent(getoptic(vn)) return p === nothing ? VarName(vn, identity) : VarName(vn, p) end @@ -470,7 +470,7 @@ end splitoptic(condition, optic) Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a lens such that `condition(parent)` is `true` and `parent ⨟ child == lens`. +`parent` is a optic such that `condition(parent)` is `true` and `parent ⨟ child == optic`. If `issuccess` is `false`, then no such split could be found. @@ -749,8 +749,8 @@ false """ function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} # LHS: Ensure that `nt` indeed has the property we want. - # RHS: Ensure that the lens can view into `nt`. - return haskey(vals, sym) && canview(getlens(vn), getproperty(vals, sym)) + # RHS: Ensure that the optic can view into `nt`. + return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) end # For `dictlike` we need to check wether `vn` is "immediately" present, or @@ -760,20 +760,20 @@ function hasvalue(vals::AbstractDict, vn::VarName) haskey(vals, vn) && return true # If `vn` is not present, we check any parent-varnames by attempting - # to split the lens into the key / `parent` and the extraction lens / `child`. + # to split the optic into the key / `parent` and the extraction optic / `child`. # If `issuccess` is `true`, we found such a split, and hence `vn` is present. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? identity : lens - haskey(vals, VarName(vn, l)) + parent, child, issuccess = splitoptic(getoptic(vn)) do optic + o = optic === nothing ? identity : optic + haskey(vals, VarName(vn, o)) end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? identity : parent + # When combined with `VarInfo`, `nothing` is equivalent to `identity`. + keyoptic = parent === nothing ? identity : parent # Return early if no such split could be found. issuccess || return false # At this point we just need to check that we `canview` the value. - value = vals[VarName(vn, keylens)] + value = vals[VarName(vn, keyoptic)] return canview(child, value) end @@ -790,13 +790,13 @@ function nested_getindex(values::AbstractDict, vn::VarName) return maybeval end - # Split the lens into the key / `parent` and the extraction lens / `child`. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? identity : lens - haskey(values, VarName(vn, l)) + # Split the optic into the key / `parent` and the extraction optic / `child`. + parent, child, issuccess = splitoptic(getoptic(vn)) do optic + o = optic === nothing ? identity : optic + haskey(values, VarName(vn, o)) end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? identity : parent + # When combined with `VarInfo`, `nothing` is equivalent to `identity`. + keyoptic = parent === nothing ? identity : parent # If we found a valid split, then we can extract the value. if !issuccess @@ -806,7 +806,7 @@ function nested_getindex(values::AbstractDict, vn::VarName) # TODO: Should we also check that we `canview` the extracted `value` # rather than just let it fail upon `get` call? - value = values[VarName(vn, keylens)] + value = values[VarName(vn, keyoptic)] return get(value, child) end @@ -911,20 +911,20 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, getlens(vn) ∘ Accessors.IndexLens(Tuple(I))) for + VarName(vn, getoptic(vn) ⨟ Accessors.IndexLens(Tuple(I))) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves(VarName(vn, getlens(vn) ∘ Accessors.IndexLens(Tuple(I))), val[I]) for + varname_leaves(VarName(vn, getoptic(vn) ⨟ Accessors.IndexLens(Tuple(I))), val[I]) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = Accessors.PropertyLens{sym}() - varname_leaves(vn ∘ lens, get(val, lens)) + optic = Accessors.PropertyLens{sym}() + varname_leaves(vn ∘ optic, optic(val)) end return Iterators.flatten(iter) end @@ -1033,7 +1033,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getoptic(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), val[I], ) for I in CartesianIndices(val) ) @@ -1042,15 +1042,15 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getoptic(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), val[I], ) for I in CartesianIndices(val) ) end function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = DynamicPPL.Accessors.PropertyLens{sym}() - varname_and_value_leaves_inner(vn ∘ lens, get(val, lens)) + optic = DynamicPPL.Accessors.PropertyLens{sym}() + varname_and_value_leaves_inner(VarName{getsym(vn)}(getoptic(vn) ⨟ optic), optic(val)) end return Iterators.flatten(iter) @@ -1067,7 +1067,7 @@ end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), x[I], ) # Iteration over the lower-triangular indices. @@ -1077,7 +1077,7 @@ end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getlens(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), x[I], ) # Iteration over the upper-triangular indices. diff --git a/src/varinfo.jl b/src/varinfo.jl index ce0a5ff73..de4e3196f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1397,10 +1397,10 @@ function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) vn_parent = vns[i] dist = getdist(md, vn_parent) val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail lens. - lens = remove_parent_lens(vn_parent, vn) + # Split the varname into its tail optic. + optic = remove_parent_optic(vn_parent, vn) # Update the value for the parent. - val_parent_updated = set!!(val_parent, lens, val) + val_parent_updated = set!!(val_parent, optic, val) setindex!(vi, val_parent_updated, vn_parent) return vn_parent end diff --git a/test/Project.toml b/test/Project.toml index a95abef1e..1b3f53b33 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Accessors = "0.1" ADTypes = "0.2" AbstractMCMC = "5" -AbstractPPL = "0.7" +AbstractPPL = "0.8" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" diff --git a/test/contexts.jl b/test/contexts.jl index 994d98194..11e2c99b7 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -55,7 +55,7 @@ Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - getlens(vn) + getoptic(vn) ) end @@ -169,7 +169,7 @@ end # Let's check elementwise. for vn_child in DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if get(val, getlens(vn_child)) === missing + if getoptic(vn_child)(val) === missing @test contextual_isassumption(context, vn_child) else @test !contextual_isassumption(context, vn_child) @@ -206,7 +206,7 @@ end @test hasconditioned_nested(context, vn_child) # Value should be the same as extracted above. @test getconditioned_nested(context, vn_child) === - get(val, getlens(vn_child)) + getoptic(vn_child)(val) end end end @@ -233,12 +233,12 @@ end vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getlens(vn_prefixed) === getlens(vn) + @test getoptic(vn_prefixed) === getoptic(vn) vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test getlens(vn_prefixed) === getlens(vn) + @test getoptic(vn_prefixed) === getoptic(vn) end @testset "SamplingContext" begin diff --git a/test/varinfo.jl b/test/varinfo.jl index 71e341767..a9e734575 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,7 +1,7 @@ function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `IdentityLens`. + # since `keys(varinfo_merged)` only contains `VarName` with `identity`. # So we just check that the original keys are present. for vn in vns # Should have all the original keys. @@ -519,7 +519,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end # For certain varinfos we should have errors. - # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `IdentityLens`. + # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] @testset "$(short_varinfo_name(varinfo)): failure cases" begin @test_throws ArgumentError subset( From 8594fddb97d63257f8f3872700f21b98ddf9920f Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Fri, 5 Apr 2024 12:43:52 +0100 Subject: [PATCH 12/23] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index ac6d86dbc..aad36bb3b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1050,7 +1050,9 @@ end function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym optic = DynamicPPL.Accessors.PropertyLens{sym}() - varname_and_value_leaves_inner(VarName{getsym(vn)}(getoptic(vn) ⨟ optic), optic(val)) + varname_and_value_leaves_inner( + VarName{getsym(vn)}(getoptic(vn) ⨟ optic), optic(val) + ) end return Iterators.flatten(iter) From 5fa0e63a81f7ef1e5178def8531b98d97034a7db Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Fri, 5 Apr 2024 12:43:57 +0100 Subject: [PATCH 13/23] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index aad36bb3b..6111c192c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -917,8 +917,8 @@ function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves(VarName(vn, getoptic(vn) ⨟ Accessors.IndexLens(Tuple(I))), val[I]) for - I in CartesianIndices(val) + varname_leaves(VarName(vn, getoptic(vn) ⨟ Accessors.IndexLens(Tuple(I))), val[I]) + for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) From 9573d03737ed6ca7accf55d1e5c432841cb268ef Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Fri, 5 Apr 2024 12:44:01 +0100 Subject: [PATCH 14/23] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 6111c192c..9566e59bd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -390,7 +390,9 @@ end # `IndexLens`: only relevant if `x` supports indexing. canview(optic::Accessors.IndexLens, x) = false -canview(optic::Accessors.IndexLens, x::AbstractArray) = checkbounds(Bool, x, optic.indices...) +function canview(optic::Accessors.IndexLens, x::AbstractArray) + return checkbounds(Bool, x, optic.indices...) +end # `ComposedOptic`: check that we can view `.inner` and `.outer`, but using # value extracted using `.inner`. From e126e8699c0495375354a29237b1f31cd7f1c9c9 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Fri, 5 Apr 2024 12:44:05 +0100 Subject: [PATCH 15/23] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 9566e59bd..5d5892f2d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -352,7 +352,9 @@ function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) return Accessors.set(obj, opticmut, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} - optic = BangBang.prefermutation(Accessors.PropertyLens{sym}() ∘ AbstractPPL.getoptic(vn)) + optic = BangBang.prefermutation( + Accessors.PropertyLens{sym}() ∘ AbstractPPL.getoptic(vn) + ) return Accessors.set(obj, optic, value) end From 0787ed2d5fb565c38186b8ec28ee85d5f2aad4e0 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sun, 7 Apr 2024 12:41:36 +0100 Subject: [PATCH 16/23] use fixed AbstractPPL for tests --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index f18167d08..25e839da2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +using Pkg +Pkg.add(url="https://github.com/TuringLang/AbstractPPL.jl", rev="sunxd/fix-get-function") + using Accessors using ADTypes using DynamicPPL From ccb011511e581a9ab2a1d5cbea89a0f4ad6b3dfd Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 9 Apr 2024 09:18:52 +0100 Subject: [PATCH 17/23] adjust some util code related to compositing varname and optic --- src/compiler.jl | 4 ++-- src/utils.jl | 22 +++++++++++----------- test/runtests.jl | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f6e3e3982..65efe00a7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -226,7 +226,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return AbstractPPL.concretize(vn ⨟ Accessors.IndexLens((Colon(), i)), left) + return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) ∘ vn, left) end return unwrap_right_left_vns(right, left, vns) end @@ -236,7 +236,7 @@ function unwrap_right_left_vns( vn::VarName, ) vns = map(CartesianIndices(left)) do i - return vn ∘ Accessors.IndexLens(Tuple(i)) + return Accessors.IndexLens(Tuple(i)) ∘ vn end return unwrap_right_left_vns(right, left, vns) end diff --git a/src/utils.jl b/src/utils.jl index 5d5892f2d..28fa1e814 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -438,7 +438,7 @@ julia> parent(@o(_.a[1])) (@o _.a) julia> # Parent of optic without parents results in `nothing`. - (parent ⨟ parent)(@o(_.a[1])) === nothing + (parent ∘ parent)(@o(_.a[1])) === nothing true ``` """ @@ -466,7 +466,7 @@ julia> parent_and_child(@o(_.a)) parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) function parent_and_child(optic::Accessors.ComposedOptic) p, child = parent_and_child(optic.outer) - parent = p === nothing ? optic.inner : optic.inner ⨟ p + parent = p === nothing ? optic.inner : p ∘ optic.inner return parent, child end @@ -474,7 +474,7 @@ end splitoptic(condition, optic) Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a optic such that `condition(parent)` is `true` and `parent ⨟ child == optic`. +`parent` is a optic such that `condition(parent)` is `true` and `child ∘ parent == optic`. If `issuccess` is `false`, then no such split could be found. @@ -486,7 +486,7 @@ julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent end ((@o _.a), (@o _[1]), true) -julia> p ⨟ c +julia> c ∘ p (@o _.a[1]) julia> splitoptic(@o(_.a[1])) do parent @@ -501,7 +501,7 @@ function splitoptic(condition, optic) # We stop if either a) `condition` is satisfied, or b) we reached the root. while !condition(current_parent) && current_parent !== nothing current_parent, c = parent_and_child(current_parent) - current_child = c ⨟ current_child + current_child = current_child ∘ c end return current_parent, current_child, condition(current_parent) @@ -915,20 +915,20 @@ x.z[2][1] varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, getoptic(vn) ⨟ Accessors.IndexLens(Tuple(I))) for + VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves(VarName(vn, getoptic(vn) ⨟ Accessors.IndexLens(Tuple(I))), val[I]) + varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym optic = Accessors.PropertyLens{sym}() - varname_leaves(vn ∘ optic, optic(val)) + varname_leaves(VarName(vn, optic ∘ getoptic(vn)), optic(val)) end return Iterators.flatten(iter) end @@ -1037,7 +1037,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName(vn, DynamicPPL.getoptic(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), val[I], ) for I in CartesianIndices(val) ) @@ -1046,7 +1046,7 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName(vn, DynamicPPL.getoptic(vn) ∘ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), val[I], ) for I in CartesianIndices(val) ) @@ -1064,7 +1064,7 @@ end # Special types. function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) # TODO: Or do we use `PDMat` here? - return if x.uplo == 'L' + return if x.uplo == 'L' varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:L}(), x.L) else varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:U}(), x.U) diff --git a/test/runtests.jl b/test/runtests.jl index 25e839da2..f961b2578 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(url="https://github.com/TuringLang/AbstractPPL.jl", rev="sunxd/fix-get-function") +Pkg.add(url="https://github.com/TuringLang/AbstractPPL.jl", rev="sunxd/add_composite_function") using Accessors using ADTypes From 835b4574d053c0467aef14dd464ef4c393de3f39 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 9 Apr 2024 09:30:49 +0100 Subject: [PATCH 18/23] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 28fa1e814..759dfde8e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1064,7 +1064,7 @@ end # Special types. function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) # TODO: Or do we use `PDMat` here? - return if x.uplo == 'L' + return if x.uplo == 'L' varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:L}(), x.L) else varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:U}(), x.U) From 62de1f5bda807ba6f5cbc3208ec3641d102eb9be Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 9 Apr 2024 16:15:10 +0100 Subject: [PATCH 19/23] update with recent AbstractPPL merge --- test/Project.toml | 2 +- test/runtests.jl | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 1b3f53b33..d345d30d5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Accessors = "0.1" ADTypes = "0.2" AbstractMCMC = "5" -AbstractPPL = "0.8" +AbstractPPL = "0.8.2" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" diff --git a/test/runtests.jl b/test/runtests.jl index f961b2578..f18167d08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -using Pkg -Pkg.add(url="https://github.com/TuringLang/AbstractPPL.jl", rev="sunxd/add_composite_function") - using Accessors using ADTypes using DynamicPPL From 8e5ea7c9a5f66d8fd8fca3c8a8ec5e058cc54ded Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 10 Apr 2024 15:08:12 +0100 Subject: [PATCH 20/23] test new APPL fix --- test/runtests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index f18167d08..71b016799 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,6 @@ +using Pkg +Pkg.add(url="https://github.com/TuringLang/AbstractPPL.jl", rev="sunxd/patch_compose") + using Accessors using ADTypes using DynamicPPL From 16b63242a7204c028347c4d4de71710aba77f02b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 10 Apr 2024 15:44:39 +0100 Subject: [PATCH 21/23] remove the `Pkg.add`, causing issue with env resolution --- test/runtests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 71b016799..f18167d08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,3 @@ -using Pkg -Pkg.add(url="https://github.com/TuringLang/AbstractPPL.jl", rev="sunxd/patch_compose") - using Accessors using ADTypes using DynamicPPL From c2bad4ab1b705127a2dfa1ced3e4a66ba9fe741a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 11 Apr 2024 17:35:27 +0100 Subject: [PATCH 22/23] use APPL pending fix for testing; fix more errors --- src/model_utils.jl | 8 ++++---- src/utils.jl | 18 +++++++++--------- test/runtests.jl | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 9232b0c88..ac4ec7022 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -79,12 +79,12 @@ function varname_in_chain!( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out ) where {sym} # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. - # This way we can use `getoptic(vn)` to extract the value from `x` and use `vn_parent ⨟ getoptic(vn)` + # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) - varname_in_chain!(x, vn_parent ⨟ l, chain, chain_idx, iteration_idx, out) + varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out) end return out end @@ -104,7 +104,7 @@ function values_from_chain( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx ) where {sym} # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. - # This way we can use `getoptic(vn)` to extract the value from `x` and use `vn_parent ⨟ getoptic(vn)` + # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. out = similar(x) for vn in varname_leaves(VarName{sym}(), x) @@ -113,7 +113,7 @@ function values_from_chain( out = Accessors.set( out, BangBang.prefermutation(l), - chain[iteration_idx, Symbol(vn_parent ⨟ l), chain_idx], + chain[iteration_idx, Symbol(l ∘ vn_parent), chain_idx], ) end diff --git a/src/utils.jl b/src/utils.jl index 759dfde8e..4aa5d910b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -353,7 +353,7 @@ function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) end function set!!(obj, vn::VarName{sym}, value) where {sym} optic = BangBang.prefermutation( - Accessors.PropertyLens{sym}() ∘ AbstractPPL.getoptic(vn) + AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() ) return Accessors.set(obj, optic, value) end @@ -811,7 +811,7 @@ function nested_getindex(values::AbstractDict, vn::VarName) # TODO: Should we also check that we `canview` the extracted `value` # rather than just let it fail upon `get` call? value = values[VarName(vn, keyoptic)] - return get(value, child) + return child(value) end """ @@ -1037,7 +1037,7 @@ function varname_and_value_leaves_inner( ) return ( Leaf( - VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) @@ -1046,7 +1046,7 @@ end function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_and_value_leaves_inner( - VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), val[I], ) for I in CartesianIndices(val) ) @@ -1055,7 +1055,7 @@ function varname_and_value_leaves_inner(vn::DynamicPPL.VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym optic = DynamicPPL.Accessors.PropertyLens{sym}() varname_and_value_leaves_inner( - VarName{getsym(vn)}(getoptic(vn) ⨟ optic), optic(val) + VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) ) end @@ -1065,15 +1065,15 @@ end function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) # TODO: Or do we use `PDMat` here? return if x.uplo == 'L' - varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:L}(), x.L) + varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) else - varname_and_value_leaves_inner(vn ∘ Accessors.PropertyLens{:U}(), x.U) + varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) end end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), x[I], ) # Iteration over the lower-triangular indices. @@ -1083,7 +1083,7 @@ end function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) return ( Leaf( - VarName(vn, DynamicPPL.getoptic(vn) ⨟ DynamicPPL.Accessors.IndexLens(Tuple(I))), + VarName(vn, DynamicPPL.Accessors.IndexLens(Tuple(I)) ∘ DynamicPPL.getoptic(vn)), x[I], ) # Iteration over the upper-triangular indices. diff --git a/test/runtests.jl b/test/runtests.jl index f18167d08..290d2e10f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,20 @@ using Test using DynamicPPL: getargs_dottilde, getargs_tilde, Selector +# TODO: temporarily overwrite for testing +using AbstractPPL: ALLOWED_OPTICS, VarName +# Allow compositions with optic. +function Base.:∘(optic::ALLOWED_OPTICS, vn::VarName{sym,<:ALLOWED_OPTICS}) where {sym} + vn_optic = getoptic(vn) + if vn_optic == identity + return VarName{sym}(optic) + elseif optic == identity + return vn + else + return VarName{sym}(optic ∘ vn_optic) + end +end + const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") const GROUP = get(ENV, "GROUP", "All") From 0db208384a4893ad9ca5dcf49d6924cc8f851cab Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 12 Apr 2024 08:59:12 +0100 Subject: [PATCH 23/23] fix more errors --- src/compiler.jl | 6 +++--- src/utils.jl | 24 ++++++++++++------------ test/runtests.jl | 2 +- test/turing/varinfo.jl | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 65efe00a7..f8a04a557 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -202,13 +202,13 @@ variables. # Example ```jldoctest; setup=:(using Distributions, LinearAlgebra) julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] -x[:,2] +x[:, 2] julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] -x[1,2] +x[1, 2] julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end] -x[:][1,2] +x[:][1, 2] julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end] x[1][3] diff --git a/src/utils.jl b/src/utils.jl index 4aa5d910b..06528a72d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -967,27 +967,27 @@ julia> x = reshape(1:4, 2, 2); julia> # `LowerTriangular` foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) -(x[1,1], 1) -(x[2,1], 2) -(x[2,2], 4) +(x[1, 1], 1) +(x[2, 1], 2) +(x[2, 2], 4) julia> # `UpperTriangular` foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) -(x[1,1], 1) -(x[1,2], 3) -(x[2,2], 4) +(x[1, 1], 1) +(x[1, 2], 3) +(x[2, 2], 4) julia> # `Cholesky` with lower-triangular foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) -(x.L[1,1], 1.0) -(x.L[2,1], 0.0) -(x.L[2,2], 1.0) +(x.L[1, 1], 1.0) +(x.L[2, 1], 0.0) +(x.L[2, 2], 1.0) julia> # `Cholesky` with upper-triangular foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) -(x.U[1,1], 1.0) -(x.U[1,2], 0.0) -(x.U[2,2], 1.0) +(x.U[1, 1], 1.0) +(x.U[1, 2], 0.0) +(x.U[2, 2], 1.0) ``` """ function varname_and_value_leaves(vn::VarName, x) diff --git a/test/runtests.jl b/test/runtests.jl index 290d2e10f..efa595516 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,7 +90,7 @@ include("test_util.jl") DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, - :(using DynamicPPL, Distributions); + :(using DynamicPPL, Distributions, Accessors); recursive=true, ) doctestfilters = [ diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index c4e3fa87b..30408e598 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -192,7 +192,7 @@ return p end chain = sample(mat_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1,1]"], [0]; atol=0.25) + check_numerical(chain, ["p[1, 1]"], [0]; atol=0.25) @model function marr_name_test() p = Array{Array{Any}}(undef, 2)