Skip to content

Commit

Permalink
Redesign of MCSE (#63)
Browse files Browse the repository at this point in the history
* Add mcse_sbm

* Update description of `estimator`

* Add specialized estimators for mean, std, and quantile

* Remove vector methods, defaulting to sbm

* Update docstring

* Fix bugs

* Update docstrings

* Update docstring

* Move helper functions to own file

* Rearrange tests

* Update mcse tests

* Export mcse_sbm

* Increment minor version number with DEV suffix

* Increment docs and tests version numbers

* Add additional citation

* Update diagnostics to use new mcse

* Increase tolerance of mcse tests

* Increase tolerance more

* Add mcse_sbm to docs

* Skip high autocorrelation tests for mcse_sbm

* Note underestimation for SBM

* Update src/mcse.jl

* Don't enforce type

* Document kwargs passed to mcse

* Cross-link mcse and ess_rhat docstrings

* Document derivation of mcse for std

* Test type-inferrability of ess_rhat

* Make sure ess_rhat for quantiles not promoted

* Make sure ess_rhat for median type-inferrable

* Implement specific method for median

* Return missing if any are missing

* Add mcse tests

* Decrease the number of checks

* Make ESS/MCSE for median with with Union{Missing,Real}

* Make _fold_around_median type-inferrable

* Increase tolerance for exhaustive tests

* Fix _fold_around_median

* Fix count of checks

* Increase the number of draws

improves the quality of the estimates and reduces random failures

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Make sure heideldiag and gewekediag preserve input type

* Consistently use first and last for ess_rhat

* Copy comment to _fold_around_median

* Make mcse_sbm an internal function

* Update tests

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
sethaxen and devmotion committed Jan 19, 2023
1 parent 189dba5 commit 9e78f78
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 140 deletions.
2 changes: 1 addition & 1 deletion src/MCMCDiagnosticTools.jl
Expand Up @@ -7,7 +7,7 @@ using Distributions: Distributions
using MLJModelInterface: MLJModelInterface as MMI
using SpecialFunctions: SpecialFunctions
using StatsBase: StatsBase
using StatsFuns: StatsFuns
using StatsFuns: StatsFuns, sqrt2
using Tables: Tables

using LinearAlgebra: LinearAlgebra
Expand Down
19 changes: 13 additions & 6 deletions src/ess.jl
Expand Up @@ -222,7 +222,7 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains`
``\\widehat{R} < 1.01``.[^VehtariGelman2021]
See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref),
[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref)
[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref), [`mcse`](@ref)
## Estimators
Expand Down Expand Up @@ -435,8 +435,8 @@ function ess_tail(
# workaround for https://github.com/JuliaStats/Statistics.jl/issues/136
T = Base.promote_eltype(x, tail_prob)
return min.(
ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)[1],
ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)[1],
first(ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)),
first(ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)),
)
end

Expand Down Expand Up @@ -464,13 +464,20 @@ See also: [`ess_tail`](@ref), [`ess_rhat_bulk`](@ref)
doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221)
arXiv: [1903.08008](https://arxiv.org/abs/1903.08008)
"""
rhat_tail(x; kwargs...) = ess_rhat_bulk(_fold_around_median(x); kwargs...)[2]
rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...))

# Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``.
# If no proxy expectand for `f` is known, `nothing` is returned.
_expectand_proxy(f, x) = nothing
function _expectand_proxy(::typeof(Statistics.median), x)
return x .≤ Statistics.median(x; dims=(1, 2))
y = similar(x)
# avoid using the `dims` keyword for median because it
# - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8)
# - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39)
for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
yi .= xi .≤ Statistics.median(vec(xi))
end
return y
end
function _expectand_proxy(::typeof(Statistics.std), x)
return (x .- Statistics.mean(x; dims=(1, 2))) .^ 2
Expand All @@ -480,7 +487,7 @@ function _expectand_proxy(::typeof(StatsBase.mad), x)
return _expectand_proxy(Statistics.median, x_folded)
end
function _expectand_proxy(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x)
y = similar(x, Bool)
y = similar(x)
# currently quantile does not support a dims keyword argument
for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
yi .= xi .≤ f(vec(xi))
Expand Down
12 changes: 8 additions & 4 deletions src/gewekediag.jl
Expand Up @@ -12,6 +12,8 @@ samples are independent. A non-significant test p-value indicates convergence.
p-values indicate non-convergence and the possible need to discard initial samples as a
burn-in sequence or to simulate additional samples.
`kwargs` are forwarded to [`mcse`](@ref).
[^Geweke1991]: Geweke, J. F. (1991). Evaluating the accuracy of sampling-based approaches to the calculation of posterior moments (No. 148). Federal Reserve Bank of Minneapolis.
"""
function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, kwargs...)
Expand All @@ -22,10 +24,12 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5,
n = length(x)
x1 = x[1:round(Int, first * n)]
x2 = x[round(Int, n - last * n + 1):n]
z =
(Statistics.mean(x1) - Statistics.mean(x2)) /
hypot(mcse(x1; kwargs...), mcse(x2; kwargs...))
p = SpecialFunctions.erfc(abs(z) / sqrt(2))
s = hypot(
Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)),
Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)),
)
z = (Statistics.mean(x1) - Statistics.mean(x2)) / s
p = SpecialFunctions.erfc(abs(z) / sqrt2)

return (zscore=z, pvalue=p)
end
15 changes: 10 additions & 5 deletions src/heideldiag.jl
Expand Up @@ -9,31 +9,36 @@ means are within a target ratio. Stationarity is rejected (0) for significant te
Halfwidth tests are rejected (0) if observed ratios are greater than the target, as is the
case for `s2` and `beta[1]`.
`kwargs` are forwarded to [`mcse`](@ref).
[^Heidelberger1983]: Heidelberger, P., & Welch, P. D. (1983). Simulation run length control in the presence of an initial transient. Operations Research, 31(6), 1109-1144.
"""
function heideldiag(
x::AbstractVector{<:Real}; alpha::Real=0.05, eps::Real=0.1, start::Int=1, kwargs...
x::AbstractVector{<:Real}; alpha::Real=1//20, eps::Real=0.1, start::Int=1, kwargs...
)
n = length(x)
delta = trunc(Int, 0.10 * n)
y = x[trunc(Int, n / 2):end]
S0 = length(y) * mcse(y; kwargs...)^2
i, pvalue, converged, ybar = 1, 1.0, false, NaN
T = typeof(zero(eltype(x)) / 1)
s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))
S0 = length(y) * s^2
i, pvalue, converged, ybar = 1, one(T), false, T(NaN)
while i < n / 2
y = x[i:end]
m = length(y)
ybar = Statistics.mean(y)
B = cumsum(y) - ybar * collect(1:m)
Bsq = (B .* B) ./ (m * S0)
I = sum(Bsq) / m
pvalue = 1.0 - pcramer(I)
pvalue = 1 - T(pcramer(I))
converged = pvalue > alpha
if converged
break
end
i += delta
end
halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * mcse(y; kwargs...)
s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))
halfwidth = sqrt2 * SpecialFunctions.erfcinv(T(alpha)) * s
passed = halfwidth / abs(ybar) <= eps
return (
burnin=i + start - 2,
Expand Down
171 changes: 114 additions & 57 deletions src/mcse.jl
@@ -1,72 +1,129 @@
const normcdf1 = 0.8413447460685429 # StatsFuns.normcdf(1)
const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1)

"""
mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...)
Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of
shape `(draws, chains, parameters)`.
See also: [`ess_rhat`](@ref)
## Estimators
`estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate.
Compute the Monte Carlo standard error (MCSE) of samples `x`.
The optional argument `method` describes how the errors are estimated. Possible options are:
For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate
of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to
`ess_rhat`:
- `Statistics.mean`
- `Statistics.median`
- `Statistics.std`
- `Base.Fix2(Statistics.quantile, p::Real)`
- `:bm` for batch means [^Glynn1991]
- `:imse` initial monotone sequence estimator [^Geyer1992]
- `:ipse` initial positive sequence estimator [^Geyer1992]
For other estimators, the subsampling bootstrap method (SBM)[^FlegalJones2011][^Flegal2012]
is used as a fallback, and the only accepted `kwargs` are `batch_size`, which indicates the
size of the overlapping batches used to estimate the MCSE, defaulting to
`floor(Int, sqrt(draws * chains))`. Note that SBM tends to underestimate the MCSE,
especially for highly autocorrelated chains. One should verify that autocorrelation is low
by checking the bulk- and tail-[`ess_rhat`](@ref) values.
[^Glynn1991]: Glynn, P. W., & Whitt, W. (1991). Estimating the asymptotic variance with batch means. Operations Research Letters, 10(8), 431-435.
[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence.
Handbook of Markov Chain Monte Carlo. pp. 175-97.
[pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf)
[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo.
Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72.
doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18)
[^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483.
"""
function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
return if method === :bm
mcse_bm(x; kwargs...)
elseif method === :imse
mcse_imse(x)
elseif method === :ipse
mcse_ipse(x)
else
throw(ArgumentError("unsupported MCSE method $method"))
mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _mcse_sbm(f, x; kwargs...)
function mcse(
::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
S = first(ess_rhat(Statistics.mean, samples; kwargs...))
return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S)
end
function mcse(
::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy
S = first(ess_rhat(Statistics.mean, x; kwargs...))
# asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]²,
# where μ₄ is the 4th central moment
# by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4,
# See e.g. Chapter 3 of Van der Vaart, AW. (200) Asymptotic statistics. Vol. 3.
mean_var = dropdims(Statistics.mean(x; dims=(1, 2)); dims=(1, 2))
mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2))
return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2
end
function mcse(
f::Base.Fix2{typeof(Statistics.quantile),<:Real},
samples::AbstractArray{<:Union{Missing,Real},3};
kwargs...,
)
p = f.x
S = first(ess_rhat(f, samples; kwargs...))
T = eltype(S)
R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
values = similar(S, R)
for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
values[i] = _mcse_quantile(vec(xi), p, Si)
end
return values
end
function mcse(
::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
)
S = first(ess_rhat(Statistics.median, samples; kwargs...))
T = eltype(S)
R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
values = similar(S, R)
for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
values[i] = _mcse_quantile(vec(xi), 1//2, Si)
end
return values
end

function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x))))
n = length(x)
m = min(div(n, 2), size)
m == size || @warn "batch size was reduced to $m"
mcse = StatsBase.sem(Statistics.mean(@view(x[(i + 1):(i + m)])) for i in 0:m:(n - m))
return mcse
function _mcse_quantile(x, p, Seff)
Seff === missing && return missing
S = length(x)
# quantile error distribution is asymptotically normal; estimate σ (mcse) with 2
# quadrature points: xl and xu, chosen as quantiles so that xu - xl = 2σ
# compute quantiles of error distribution in probability space (i.e. quantiles passed through CDF)
# Beta(α,β) is the approximate error distribution of quantile estimates
α = Seff * p + 1
β = Seff * (1 - p) + 1
prob_x_upper = StatsFuns.betainvcdf(α, β, normcdf1)
prob_x_lower = StatsFuns.betainvcdf(α, β, normcdfn1)
# use inverse ECDF to get quantiles in quantile (x) space
l = max(floor(Int, prob_x_lower * S), 1)
u = min(ceil(Int, prob_x_upper * S), S)
iperm = partialsortperm(x, l:u) # sort as little of x as possible
xl = x[first(iperm)]
xu = x[last(iperm)]
# estimate mcse from quantiles
return (xu - xl) / 2
end

function mcse_imse(x::AbstractVector{<:Real})
n = length(x)
lags = [0, 1]
ghat = StatsBase.autocov(x, lags)
Ghat = sum(ghat)
@inbounds value = Ghat + ghat[2]
@inbounds for i in 2:2:(n - 2)
lags[1] = i
lags[2] = i + 1
StatsBase.autocov!(ghat, x, lags)
Ghat = min(Ghat, sum(ghat))
Ghat > 0 || break
value += 2 * Ghat
function _mcse_sbm(
f,
x::AbstractArray{<:Union{Missing,Real},3};
batch_size::Int=floor(Int, sqrt(size(x, 1) * size(x, 2))),
)
T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1))
values = similar(x, T, (axes(x, 3),))
for (i, xi) in zip(eachindex(values), eachslice(x; dims=3))
values[i] = _mcse_sbm(f, vec(xi), batch_size)
end

mcse = sqrt(value / n)

return mcse
return values
end

function mcse_ipse(x::AbstractVector{<:Real})
function _mcse_sbm(f, x, batch_size)
any(x -> x === missing, x) && return missing
n = length(x)
lags = [0, 1]
ghat = StatsBase.autocov(x, lags)
@inbounds value = ghat[1] + 2 * ghat[2]
@inbounds for i in 2:2:(n - 2)
lags[1] = i
lags[2] = i + 1
StatsBase.autocov!(ghat, x, lags)
Ghat = sum(ghat)
Ghat > 0 || break
value += 2 * Ghat
end

mcse = sqrt(value / n)

return mcse
i1 = firstindex(x)
v = Statistics.var(
f(view(x, i:(i + batch_size - 1))) for i in i1:(i1 + n - batch_size);
corrected=false,
)
return sqrt(v * (batch_size//n))
end
11 changes: 10 additions & 1 deletion src/utils.jl
Expand Up @@ -145,7 +145,16 @@ end
Compute the absolute deviation of `x` from `Statistics.median(x)`.
"""
_fold_around_median(data) = abs.(data .- Statistics.median(data; dims=(1, 2)))
function _fold_around_median(x)
y = similar(x)
# avoid using the `dims` keyword for median because it
# - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8)
# - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39)
for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
yi .= abs.(xi .- Statistics.median(vec(xi)))
end
return y
end

"""
_rank_normalize(x::AbstractArray{<:Any,3})
Expand Down

0 comments on commit 9e78f78

Please sign in to comment.