Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.38.3

Added a new exported struct, `DynamicPPL.ParamsWithStats`, and a corresponding function `DynamicPPL.to_chains`, which automatically converts a collection of `ParamsWithStats` to a given Chains type.

## 0.38.2

Added a compatibility entry for JET@0.11.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.38.2"
version = "0.38.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
18 changes: 18 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,21 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### Converting VarInfos to chains

It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis.
This can be accomplished with the following:

```@docs
DynamicPPL.ParamsWithStats
DynamicPPL.to_chains
```

Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with:

```@docs
DynamicPPL.from_chains
```

This is useful if you want to use the result of a chain in further model evaluations.
165 changes: 109 additions & 56 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,113 @@ function chain_sample_to_varname_dict(
return d
end

"""
DynamicPPL.to_chains(
::Type{MCMCChains.Chains},
params_and_stats::AbstractArray{<:ParamsWithStats}
)

Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
"""
function DynamicPPL.to_chains(
::Type{MCMCChains.Chains},
params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
)
# Handle parameters
all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
split_dicts = map(params_and_stats) do ps
# Separate into individual VarNames.
vn_leaves_and_vals = if isempty(ps.params)
Tuple{DynamicPPL.VarName,Any}[]
else
iters = map(
AbstractPPL.varname_and_value_leaves,
keys(ps.params),
values(ps.params),
)
mapreduce(collect, vcat, iters)
end
vn_leaves = map(first, vn_leaves_and_vals)
vals = map(last, vn_leaves_and_vals)
for vn_leaf in vn_leaves
push!(all_vn_leaves, vn_leaf)
end
DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
end
vn_leaves = collect(all_vn_leaves)
param_vals = [
get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)),
key in vn_leaves, j in eachindex(axes(split_dicts, 2))
]
param_symbols = map(Symbol, vn_leaves)
# Handle statistics
stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
for ps in params_and_stats
for k in keys(ps.stats)
push!(stat_keys, k)
end
end
stat_keys = collect(stat_keys)
stat_vals = [
get(params_and_stats[i, j].stats, key, missing) for
i in eachindex(axes(params_and_stats, 1)), key in stat_keys,
j in eachindex(axes(params_and_stats, 2))
]
# Construct name map and info
name_map = (internals=stat_keys,)
info = (
varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
zip(all_vn_leaves, param_symbols)
),
)
# Concatenate parameter and statistic values
vals = cat(param_vals, stat_vals; dims=2)
symbols = vcat(param_symbols, stat_keys)
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
end
function DynamicPPL.to_chains(
::Type{MCMCChains.Chains}, ps::AbstractVector{<:DynamicPPL.ParamsWithStats}
)
return DynamicPPL.to_chains(MCMCChains.Chains, hcat(ps))
end

function DynamicPPL.from_chains(
::Type{T}, chain::MCMCChains.Chains
) where {T<:AbstractDict{<:DynamicPPL.VarName}}
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
matrix = map(idxs) do (sample_idx, chain_idx)
d = T()
for vn in DynamicPPL.varnames(chain)
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
end
d
end
return matrix
end
function DynamicPPL.from_chains(::Type{NamedTuple}, chain::MCMCChains.Chains)
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
matrix = map(idxs) do (sample_idx, chain_idx)
get(chain[sample_idx, :, chain_idx], keys(chain); flatten=true)
end
return matrix
end
function DynamicPPL.from_chains(
::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains
)
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
internals_chain = MCMCChains.get_sections(chain, :internals)
params = DynamicPPL.from_chains(
DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,eltype(chain.value)},
chain,
)
stats = DynamicPPL.from_chains(NamedTuple, internals_chain)
return map(idxs) do (sample_idx, chain_idx)
DynamicPPL.ParamsWithStats(
params[sample_idx, chain_idx], stats[sample_idx, chain_idx]
)
end
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -110,7 +217,6 @@ function DynamicPPL.predict(
DynamicPPL.VarInfo(),
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogJacobianAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(false),
),
Expand All @@ -129,23 +235,9 @@ function DynamicPPL.predict(
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
varname_vals = mapreduce(
collect,
vcat,
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
DynamicPPL.ParamsWithStats(varinfo, nothing)
end

chain_result = reduce(
MCMCChains.chainscat,
[
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
chain_result = DynamicPPL.to_chains(MCMCChains.Chains, predictive_samples)
parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)
else
Expand All @@ -164,45 +256,6 @@ function DynamicPPL.predict(
)
end

function _predictive_samples_to_arrays(predictive_samples)
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

sample_dicts = map(predictive_samples) do sample
varname_value_pairs = sample.varname_and_values
varnames = map(first, varname_value_pairs)
values = map(last, varname_value_pairs)
for varname in varnames
push!(variable_names_set, varname)
end

return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
end

variable_names = collect(variable_names_set)
variable_values = [
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
key in variable_names
]

return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
variable_names_symbols = map(Symbol, variable_names)

internal_parameters = [:lp]
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

parameter_names = [variable_names_symbols; internal_parameters]
parameter_values = hcat(variable_values, log_probabilities)
parameter_values = MCMCChains.concretize(parameter_values)

return MCMCChains.Chains(
parameter_values, parameter_names, (internals=internal_parameters,)
)
end

"""
returned(model::Model, chain::MCMCChains.Chains)

Expand Down
4 changes: 4 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ export AbstractVarInfo,
prefix,
returned,
to_submodel,
# Chain construction
ParamsWithStats,
to_chains,
# Convenience macros
@addlogprob!,
value_iterator_from_chain,
Expand Down Expand Up @@ -194,6 +197,7 @@ include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("bijector.jl")
include("to_chains.jl")

include("debug_utils.jl")
using .DebugUtils
Expand Down
Loading
Loading