Skip to content

Commit

Permalink
Merge 0db2083 into 6a2454f
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Apr 12, 2024
2 parents 6a2454f + 0db2083 commit a807d5e
Show file tree
Hide file tree
Showing 19 changed files with 272 additions and 238 deletions.
42 changes: 21 additions & 21 deletions Project.toml
@@ -1,11 +1,12 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.24.9"
version = "0.25.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -20,15 +21,31 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
ADTypes = "0.2"
AbstractMCMC = "5"
AbstractPPL = "0.7"
BangBang = "0.3"
AbstractPPL = "0.8.2"
Accessors = "0.1"
BangBang = "0.4"
Bijectors = "0.13"
ChainRulesCore = "1"
Compat = "4"
Expand All @@ -44,29 +61,12 @@ MacroTools = "0.5.6"
OrderedCollections = "1"
Random = "1.6"
Requires = "1"
Setfield = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
6 changes: 2 additions & 4 deletions docs/Project.toml
@@ -1,21 +1,19 @@
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
Documenter = "1"
FillArrays = "0.13, 1"
LogDensityProblems = "2"
MCMCChains = "5, 6"
MLUtils = "0.3, 0.4"
Setfield = "0.7.1, 0.8, 1"
StableRNGs = "1"
20 changes: 18 additions & 2 deletions docs/src/tutorials/prob-interface.md
Expand Up @@ -107,12 +107,28 @@ To give an example of the probability interface in use, we can use it to estimat
In cross-validation, we split the dataset into several equal parts.
Then, we choose one of these sets to serve as the validation set.
Here, we measure fit using the cross entropy (Bayes loss).[^1]
(For the sake of simplicity, in the following code, we enforce that `nfolds` must divide the number of data points. For a more competent implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).)

```@example probinterface
using MLUtils
# Calculate the train/validation splits across `nfolds` partitions, assume `length(dataset)` divides `nfolds`
function kfolds(dataset::Array{<:Real}, nfolds::Int)
fold_size, remaining = divrem(length(dataset), nfolds)
if remaining != 0
error("The number of folds must divide the number of data points.")
end
first_idx = firstindex(dataset)
last_idx = lastindex(dataset)
splits = map(0:(nfolds - 1)) do i
start_idx = first_idx + i * fold_size
end_idx = start_idx + fold_size
train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx]
return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1)))
end
return splits
end
function cross_val(
dataset::AbstractVector{<:Real};
dataset::Vector{<:Real};
nfolds::Int=5,
nsamples::Int=1_000,
rng::Random.AbstractRNG=Random.default_rng(),
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Expand Up @@ -12,7 +12,7 @@ using ADTypes: ADTypes
using BangBang: BangBang, push!!, empty!!, setindex!!
using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Setfield: Setfield
using Accessors: Accessors
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

Expand Down
8 changes: 4 additions & 4 deletions src/abstract_varinfo.jl
Expand Up @@ -262,7 +262,7 @@ julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries:
OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries:
x => 1.0
m => [2.0]
Expand Down Expand Up @@ -312,7 +312,7 @@ julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)
julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0
Expand All @@ -338,7 +338,7 @@ julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)
julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0
Expand Down Expand Up @@ -426,7 +426,7 @@ julia> # Extract one with only `m`.
julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
1-element Vector{VarName{:m, typeof(identity)}}:
m
julia> varinfo_subset1[@varname(m)]
Expand Down
18 changes: 9 additions & 9 deletions src/compiler.jl
Expand Up @@ -4,11 +4,11 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
need_concretize(expr)
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
requires a dynamic lens.
requires a dynamic optic.
# Examples
```jldoctest; setup=:(using Setfield)
```jldoctest; setup=:(using Accessors)
julia> DynamicPPL.need_concretize(:(x[1, :]))
true
Expand All @@ -19,7 +19,7 @@ julia> DynamicPPL.need_concretize(:(x[1, 1]))
false
"""
function need_concretize(expr)
return Setfield.need_dynamic_lens(expr) || begin
return Accessors.need_dynamic_optic(expr) || begin
flag = false
MacroTools.postwalk(expr) do ex
# Concretise colon by default
Expand Down Expand Up @@ -202,13 +202,13 @@ variables.
# Example
```jldoctest; setup=:(using Distributions, LinearAlgebra)
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end]
x[:,2]
x[:, 2]
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end]
x[1,2]
x[1, 2]
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end]
x[:][1,2]
x[:][1, 2]
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end]
x[1][3]
Expand All @@ -226,7 +226,7 @@ function unwrap_right_left_vns(
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return AbstractPPL.concretize(vn Setfield.IndexLens((Colon(), i)), left)
return AbstractPPL.concretize(Accessors.IndexLens((Colon(), i)) vn, left)
end
return unwrap_right_left_vns(right, left, vns)
end
Expand All @@ -236,7 +236,7 @@ function unwrap_right_left_vns(
vn::VarName,
)
vns = map(CartesianIndices(left)) do i
return vn Setfield.IndexLens(Tuple(i))
return Accessors.IndexLens(Tuple(i)) vn
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down Expand Up @@ -437,7 +437,7 @@ function generate_tilde_assume(left, right, vn)
expr = :($left = $value)
if left isa Expr
expr = AbstractPPL.drop_escape(
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
Accessors.setmacro(BangBang.prefermutation, expr; overwrite=true)
)
end

Expand Down
4 changes: 2 additions & 2 deletions src/contexts.jl
Expand Up @@ -288,9 +288,9 @@ end

function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
if @generated
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn)))
return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn)))
else
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn))
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn))
end
end

Expand Down
12 changes: 6 additions & 6 deletions src/model.jl
Expand Up @@ -279,11 +279,11 @@ in their trace/`VarInfo`:
```jldoctest condition
julia> keys(VarInfo(demo_outer()))
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
1-element Vector{VarName{:m, typeof(identity)}}:
m
julia> keys(VarInfo(demo_outer_prefix()))
1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}:
1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}:
inner.m
```
Expand Down Expand Up @@ -448,7 +448,7 @@ julia> conditioned(cm)
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
# `a.m` is treated as a random variable.
keys(VarInfo(cm))
1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}:
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
a.m
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
Expand Down Expand Up @@ -634,11 +634,11 @@ in their trace/`VarInfo`:
```jldoctest fix
julia> keys(VarInfo(demo_outer()))
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
1-element Vector{VarName{:m, typeof(identity)}}:
m
julia> keys(VarInfo(demo_outer_prefix()))
1-element Vector{VarName{Symbol("inner.m"), Setfield.IdentityLens}}:
1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}:
inner.m
```
Expand Down Expand Up @@ -830,7 +830,7 @@ julia> fixed(cm)
julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
# `a.m` is treated as a random variable.
keys(VarInfo(cm))
1-element Vector{VarName{Symbol("a.m"), Setfield.IdentityLens}}:
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
a.m
julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation.
Expand Down
18 changes: 9 additions & 9 deletions src/model_utils.jl
Expand Up @@ -78,13 +78,13 @@ end
function varname_in_chain!(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
varname_in_chain!(x, vn_parent l, chain, chain_idx, iteration_idx, out)
l = AbstractPPL.getoptic(vn)
varname_in_chain!(x, l vn_parent, chain, chain_idx, iteration_idx, out)
end
return out
end
Expand All @@ -103,17 +103,17 @@ end
function values_from_chain(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic.
# This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent`
# to extract the value from the `chain`.
out = similar(x)
for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
out = Setfield.set(
l = AbstractPPL.getoptic(vn)
out = Accessors.set(
out,
BangBang.prefermutation(l),
chain[iteration_idx, Symbol(vn_parent l), chain_idx],
chain[iteration_idx, Symbol(l vn_parent), chain_idx],
)
end

Expand Down

0 comments on commit a807d5e

Please sign in to comment.