/
loo_pit.jl
124 lines (106 loc) · 5.6 KB
/
loo_pit.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
@doc """
    loo_pit(idata::InferenceData, log_weights; kwargs...) -> DimArray

Compute LOO-PIT values from `idata` using already available normalized log LOO
importance weights.

The observed data and the posterior predictive draws are taken from `idata`; the returned
array is named `loo_pit_<y_name>`, where `<y_name>` is the resolved observed variable name.

# Keywords

  - `y_name`: Name of observed data variable in `idata.observed_data`. If not provided,
    then the only observed data variable is used.
  - `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
    If not provided, then `y_name` is used.
  - `kwargs`: Remaining keywords are forwarded to the base method of `loo_pit`.

# Examples

Calculate LOO-PIT values using already computed log weights.

```jldoctest
julia> using ArviZExampleData, PosteriorStats

julia> idata = load_example_data("centered_eight");

julia> loo_result = loo(idata; var_name=:obs);

julia> loo_pit(idata, loo_result.psis_result.log_weights; y_name=:obs)
╭───────────────────────────────────────────╮
│ 8-element DimArray{Float64,1} loo_pit_obs │
├───────────────────────────────────────────┴──────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
└──────────────────────────────────────────────────────────────────────────────┘
"Choate" 0.943511
"Deerfield" 0.63797
"Phillips Andover" 0.316697
"Phillips Exeter" 0.582252
"Hotchkiss" 0.295321
"Lawrenceville" 0.403318
"St. Paul's" 0.902508
"Mt. Hermon" 0.655275
```
"""
function PosteriorStats.loo_pit(
    idata::InferenceObjects.InferenceData,
    log_weights::AbstractArray;
    y_name::Union{Symbol,Nothing}=nothing,
    y_pred_name::Union{Symbol,Nothing}=nothing,
    kwargs...,
)
    # Resolve which observed / predictive variables to use (name, values) pairs.
    observed, predicted = observations_and_predictions(idata, y_name, y_pred_name)
    obs_name, y = observed
    _, pred_draws = predicted

    # Rearrange the predictive draws with the package helper before calling the
    # array-based method (presumably into a draw × chain × params layout — see helper).
    y_pred = _draw_chains_params_array(pred_draws)
    pit = PosteriorStats.loo_pit(y, y_pred, log_weights; kwargs...)

    # Name the result after the observed variable, e.g. `loo_pit_obs`.
    return DimensionalData.rebuild(pit; name=Symbol("loo_pit_", obs_name))
end
@doc """
    loo_pit(idata::InferenceData; kwargs...) -> DimArray

Compute LOO-PIT values from groups in `idata`, using PSIS-LOO importance weights computed
from the log-likelihood values stored in `idata`.

# Keywords

  - `y_name`: Name of observed data variable in `idata.observed_data`. If not provided,
    then the only observed data variable is used.
  - `y_pred_name`: Name of posterior predictive variable in `idata.posterior_predictive`.
    If not provided, then `y_name` is used.
  - `log_likelihood_name`: Name of log-likelihood variable in `idata.log_likelihood`.
    If not provided and `idata` has a `log_likelihood` group, then the (resolved)
    `y_name` is used. If `idata` has no `log_likelihood` group, then
    `idata.sample_stats.log_likelihood` is used when present.
  - `reff::Union{Real,AbstractArray{<:Real}}`: The relative effective sample size(s) of
    the _likelihood_ values. If an array, it must have the same data dimensions as the
    corresponding log-likelihood variable. If not provided, then this is estimated using
    `ess`.
  - `kwargs`: Remaining keywords are forwarded to the base method of `loo_pit`.

# Throws

  - `ArgumentError`: if `idata` has neither a `log_likelihood` group nor a
    `sample_stats.log_likelihood` variable, or if `log_likelihood_name` is given but
    `idata` has no `log_likelihood` group.

# Examples

Calculate LOO-PIT values using as test quantity the observed values themselves.

```jldoctest
julia> using ArviZExampleData, PosteriorStats

julia> idata = load_example_data("centered_eight");

julia> loo_pit(idata; y_name=:obs)
╭───────────────────────────────────────────╮
│ 8-element DimArray{Float64,1} loo_pit_obs │
├───────────────────────────────────────────┴──────────────────────────── dims ┐
↓ school Categorical{String} [Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
└──────────────────────────────────────────────────────────────────────────────┘
"Choate" 0.943511
"Deerfield" 0.63797
"Phillips Andover" 0.316697
"Phillips Exeter" 0.582252
"Hotchkiss" 0.295321
"Lawrenceville" 0.403318
"St. Paul's" 0.902508
"Mt. Hermon" 0.655275
```
"""
function PosteriorStats.loo_pit(
    idata::InferenceObjects.InferenceData;
    y_name::Union{Symbol,Nothing}=nothing,
    y_pred_name::Union{Symbol,Nothing}=nothing,
    log_likelihood_name::Union{Symbol,Nothing}=nothing,
    reff=nothing,
    kwargs...,
)
    (_y_name, y), (_, _y_pred) = observations_and_predictions(idata, y_name, y_pred_name)
    y_pred = _draw_chains_params_array(_y_pred)

    _log_like = _loo_pit_log_likelihood(idata, _y_name, log_likelihood_name)
    log_like = _draw_chains_params_array(_log_like)

    # `loo` runs PSIS; its `log_weights` are the normalized log LOO importance weights
    # consumed by the array-based `loo_pit` method.
    psis_result = PosteriorStats.loo(log_like; reff).psis_result
    pitvals = PosteriorStats.loo_pit(y, y_pred, psis_result.log_weights; kwargs...)
    return DimensionalData.rebuild(pitvals; name=Symbol("loo_pit_$(_y_name)"))
end

# Select the log-likelihood values in `idata` to use for the observations named `y_name`.
# Lookup order: explicit `log_likelihood_name` in the `log_likelihood` group; otherwise
# `y_name` in the `log_likelihood` group; otherwise `sample_stats.log_likelihood`.
function _loo_pit_log_likelihood(idata, y_name, log_likelihood_name)
    has_log_like_group = haskey(idata, :log_likelihood)
    if log_likelihood_name !== nothing
        # An explicit name can only be looked up in a `log_likelihood` group; fail early
        # with an informative error instead of on the missing group property.
        if !has_log_like_group
            msg =
                "`log_likelihood_name=$(log_likelihood_name)` was given, but there is " *
                "no `log_likelihood` group in `idata`"
            throw(ArgumentError(msg))
        end
        return log_likelihood(idata.log_likelihood, log_likelihood_name)
    end
    has_log_like_group && return log_likelihood(idata.log_likelihood, y_name)
    if haskey(idata, :sample_stats) && haskey(idata.sample_stats, :log_likelihood)
        return idata.sample_stats.log_likelihood
    end
    throw(ArgumentError("There must be a `log_likelihood` group in `idata`"))
end