/
from_namedtuple.jl
184 lines (161 loc) · 6.58 KB
/
from_namedtuple.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
from_namedtuple(posterior::NamedTuple; kwargs...) -> InferenceData
from_namedtuple(posterior::Vector{Vector{<:NamedTuple}}; kwargs...) -> InferenceData
from_namedtuple(
posterior::NamedTuple,
posterior_predictive::Any,
sample_stats::Any,
predictions::Any,
log_likelihood::Any;
kwargs...
) -> InferenceData
Convert a `NamedTuple` or container of `NamedTuple`s to an `InferenceData`.
If containers are passed, they are flattened into a single `NamedTuple` with array elements
whose first dimensions correspond to the dimensions of the containers.
# Arguments
- `posterior`: The data to be converted. It may be of the following types:
+ `::NamedTuple`: The keys are the variable names and the values are arrays with
dimensions `(ndraws, nchains[, sizes...])`.
+ `::Vector{Vector{<:NamedTuple}}`: A vector of length `nchains` whose elements have
length `ndraws`.
# Keywords
- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution
- `sample_stats::Any=nothing`: Statistics of the posterior sampling process
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior.
- `prior=nothing`: Draws from the prior. Accepts the same types as `posterior`.
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution
- `sample_stats_prior::Any=nothing`: Statistics of the prior sampling process
- `observed_data::NamedTuple`: Observed data on which the `posterior` is
conditional. It should only contain data which is modeled as a random variable. Keys
are parameter names and values.
- `constant_data::NamedTuple`: Model constants, data included in the model
which is not modeled as a random variable. Keys are parameter names and values.
- `predictions_constant_data::NamedTuple`: Constants relevant to the model
predictions (i.e. new `x` values in a linear regression).
- `log_likelihood`: Pointwise log-likelihood for the data. It is recommended
to use this argument as a `NamedTuple` whose keys are observed variable names and whose
values are log likelihood arrays.
- `library`: Name of library that generated the draws
- `coords`: Map from named dimension to named indices
- `dims`: Map from variable name to names of its dimensions
# Returns
- `InferenceData`: The data with groups corresponding to the provided data
!!! note
If a `NamedTuple` is provided for `observed_data`, `constant_data`, or
`predictions_constant_data`, any non-array values (e.g. integers) are converted to
0-dimensional arrays.
# Examples
```@example
using InferenceObjects
nchains = 2
ndraws = 100
data1 = (
x=rand(ndraws, nchains), y=randn(ndraws, nchains, 2), z=randn(ndraws, nchains, 3, 2)
)
idata1 = from_namedtuple(data1)
data2 = [[(x=rand(), y=randn(2), z=randn(3, 2)) for _ in 1:ndraws] for _ in 1:nchains];
idata2 = from_namedtuple(data2)
```
"""
from_namedtuple
# Positional core method: assembles the posterior-derived groups of an `InferenceData`.
#
# `posterior` is passed through `namedtuple_of_arrays` unless it is `nothing`. Each of
# `posterior_predictive`, `sample_stats`, `predictions` and `log_likelihood` is skipped
# when `nothing`. When a posterior is available, such a group may instead be given as a
# `Symbol` or a collection of `Symbol`s: the named variables are then taken from the
# posterior to form that group and removed from what remains of the posterior.
# All `kwargs` are forwarded to every `convert_to_dataset` call.
function from_namedtuple(
    posterior, posterior_predictive, sample_stats, predictions, log_likelihood; kwargs...
)
    idata = InferenceData()
    if posterior === nothing
        remaining = nothing
    else
        remaining = namedtuple_of_arrays(posterior)
    end
    derived_groups = (
        posterior_predictive=posterior_predictive,
        sample_stats=sample_stats,
        predictions=predictions,
        log_likelihood=log_likelihood,
    )
    for (name, contents) in pairs(derived_groups)
        contents === nothing && continue
        if remaining !== nothing
            # A lone variable name is treated as a one-element collection of names.
            if contents isa Symbol
                contents = (Symbol(contents),)
            end
            if all(x -> x isa Symbol, contents)
                # Pull the named variables out of the posterior into this group ...
                contents = NamedTuple{Tuple(contents)}(remaining)
                # ... and keep only the variables that were not selected.
                kept = Tuple(setdiff(keys(remaining), keys(contents)))
                remaining = NamedTuple{kept}(remaining)
            end
            isempty(contents) && continue
        end
        dataset = convert_to_dataset(contents; kwargs...)
        idata = merge(idata, InferenceData(; name => dataset))
    end
    # Nothing left for the posterior group itself: return what has been collected.
    if remaining === nothing || isempty(remaining)
        return idata
    end
    posterior_dataset = convert_to_dataset(remaining; kwargs...)
    return merge(idata, InferenceData(; posterior=posterior_dataset))
end
# Keyword front-end: builds the posterior groups, then the prior groups, then the
# data groups, merging each into a single `InferenceData`.
#
# The prior groups are produced by calling `from_namedtuple` again with the prior data in
# the posterior slots and renaming the resulting groups with `rekey`. The data groups
# (`observed_data`, `constant_data`, `predictions_constant_data`) are converted with
# `default_dims=()`. `library` and the remaining `kwargs` are forwarded to each conversion.
function from_namedtuple(
    posterior::Union{NamedTuple,Nothing}=nothing;
    posterior_predictive=nothing,
    sample_stats=nothing,
    predictions=nothing,
    prior=nothing,
    prior_predictive=nothing,
    sample_stats_prior=nothing,
    observed_data=nothing,
    constant_data=nothing,
    predictions_constant_data=nothing,
    log_likelihood=nothing,
    library=nothing,
    kwargs...,
)
    idata = from_namedtuple(
        posterior,
        posterior_predictive,
        sample_stats,
        predictions,
        log_likelihood;
        library=library,
        kwargs...,
    )

    has_prior_groups =
        prior !== nothing || prior_predictive !== nothing || sample_stats_prior !== nothing
    if has_prior_groups
        # Reuse the posterior machinery, then rename the groups to their prior names.
        prior_as_posterior = from_namedtuple(
            prior;
            posterior_predictive=prior_predictive,
            sample_stats=sample_stats_prior,
            library=library,
            kwargs...,
        )
        renames = (
            posterior=:prior,
            posterior_predictive=:prior_predictive,
            sample_stats=:sample_stats_prior,
        )
        idata = merge(idata, rekey(prior_as_posterior, renames))
    end

    data_groups = (
        observed_data=observed_data,
        constant_data=constant_data,
        predictions_constant_data=predictions_constant_data,
    )
    for (name, contents) in pairs(data_groups)
        contents === nothing && continue
        dataset = convert_to_dataset(contents; library=library, default_dims=(), kwargs...)
        idata = merge(idata, InferenceData(; name => dataset))
    end
    return idata
end
# Nested-vector input: apply `stack_draws` to each inner vector, combine the results
# with `stack_chains`, and hand the result to `from_namedtuple` with the same keywords.
function from_namedtuple(data::AbstractVector{<:AbstractVector{<:NamedTuple}}; kwargs...)
    per_chain = map(stack_draws, data)
    stacked = stack_chains(per_chain)
    return from_namedtuple(stacked; kwargs...)
end
"""
    convert_to_inference_data(obj::NamedTuple; kwargs...) -> InferenceData
    convert_to_inference_data(obj::Vector{Vector{<:NamedTuple}}; kwargs...) -> InferenceData

Convert `obj` to an [`InferenceData`](@ref). See [`from_namedtuple`](@ref) for a description
of `obj` possibilities and `kwargs`.

# Keywords

  - `group=:posterior`: Name of the group that `obj` is stored in. It is converted with
    `Symbol`. For `:posterior`, `obj` is passed to `from_namedtuple` as the positional
    argument; for any other name it is passed as the keyword argument of that name.
"""
function convert_to_inference_data(
    data::T; group=:posterior, kwargs...
) where {T<:Union{NamedTuple,AbstractVector{<:AbstractVector{<:NamedTuple}}}}
    target = Symbol(group)
    if target === :posterior
        return from_namedtuple(data; kwargs...)
    end
    return from_namedtuple(; target => data, kwargs...)
end