Skip to content

Commit

Permalink
Raise an error if maxlag too low (#64)
Browse files Browse the repository at this point in the history
* Throw error if maxlag too small

* Return if Missing

* Update tests

* Increment version number

* Remove type union

* Document also constraints on number of iterations

* Rewrite docs

* Make errors clearer

* Update tests
  • Loading branch information
sethaxen committed Jan 18, 2023
1 parent 2e07d21 commit 189dba5
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.6"
version = "0.3.0-DEV"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Expand Up @@ -10,7 +10,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
Documenter = "0.27"
EvoTrees = "0.14.7"
MCMCDiagnosticTools = "0.2"
MCMCDiagnosticTools = "0.3"
MLJBase = "0.19, 0.20, 0.21"
MLJIteration = "0.5"
julia = "1.6"
25 changes: 15 additions & 10 deletions src/ess.jl
Expand Up @@ -206,15 +206,17 @@ end
Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape
`(draws, chains, parameters)` with the `method`.
`maxlag` indicates the maximum lag for which autocovariance is computed.
By default, the computed ESS and ``\\widehat{R}`` values correspond to the estimator `mean`.
Other estimators can be specified by passing a function `estimator` (see below).
`split_chains` indicates the number of chains each chain is split into.
When `split_chains > 1`, then the diagnostics check for within-chain convergence. When
`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw
is discarded after each of the first `d` splits within each chain.
is discarded after each of the first `d` splits within each chain. There must be at least
3 draws in each chain after splitting.
`maxlag` indicates the maximum lag for which autocovariance is computed and must be greater
than 0.
For a given estimand, it is recommended that the ESS is at least `100 * chains` and that
``\\widehat{R} < 1.01``.[^VehtariGelman2021]
Expand Down Expand Up @@ -266,10 +268,17 @@ function ess_rhat(
# when chains have mixed poorly anyways.
# leave the last even autocorrelation as a bias term that reduces variance for
# case of antithetical chains, see below
maxlag = min(maxlag, niter - 4)
if !(maxlag > 0) || T === Missing
return similar(chains, Missing, axes_out), similar(chains, Missing, axes_out)
if !(niter > 4)
throw(ArgumentError("number of draws after splitting must >4 but is $niter."))
end
maxlag > 0 || throw(DomainError(maxlag, "maxlag must be >0."))
maxlag = min(maxlag, niter - 4)

# define output arrays
ess = similar(chains, T, axes_out)
rhat = similar(chains, T, axes_out)

T === Missing && return ess, rhat

# define caches for mean and variance
chain_mean = Array{T}(undef, 1, nchains)
Expand All @@ -282,10 +291,6 @@ function ess_rhat(
# define cache for the computation of the autocorrelation
esscache = build_cache(method, samples, chain_var)

# define output arrays
ess = similar(chains, T, axes_out)
rhat = similar(chains, T, axes_out)

# set maximum ess for antithetic chains, see below
ess_max = ntotal * log10(oftype(one(T), ntotal))

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Expand Up @@ -25,7 +25,7 @@ EvoTrees = "0.14.7"
FFTW = "1.1"
LogDensityProblems = "0.12, 1, 2"
LogExpFunctions = "0.3"
MCMCDiagnosticTools = "0.2"
MCMCDiagnosticTools = "0.3"
MLJBase = "0.19, 0.20, 0.21"
MLJIteration = "0.5"
MLJLIBSVMInterface = "0.2"
Expand Down
29 changes: 10 additions & 19 deletions test/ess.jl
Expand Up @@ -104,16 +104,15 @@ end
end

@testset "ESS and R̂ only promote eltype when necessary" begin
TM = Vector{Missing}
@testset for T in (Float32, Float64)
x = rand(T, 100, 4, 2)
TV = Vector{T}
@inferred Union{Tuple{TV,TV},Tuple{TM,TM}} ess_rhat(x)
@inferred Tuple{TV,TV} ess_rhat(x)
end
@testset "Int" begin
x = rand(1:10, 100, 4, 2)
TV = Vector{Float64}
@inferred Union{Tuple{TV,TV},Tuple{TM,TM}} ess_rhat(x)
@inferred Tuple{TV,TV} ess_rhat(x)
end
end

Expand All @@ -135,11 +134,6 @@ end
@test axes(S3, 1) == axes(y, 3)
@test R3 isa OffsetVector{Missing}
@test axes(R3, 1) == axes(y, 3)
S4, R4 = ess_rhat(y; maxlag=0) # return eltype should be Missing
@test S4 isa OffsetVector{Missing}
@test axes(S4, 1) == axes(y, 3)
@test R4 isa OffsetVector{Missing}
@test axes(R4, 1) == axes(y, 3)
end

@testset "ESS and R̂ (identical samples)" begin
Expand All @@ -159,18 +153,15 @@ end
end
end

@testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed
@testset "ESS and R̂ errors" begin # check that issue #137 is fixed
x = rand(4, 3, 5)

for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod())
# analyze array
ess_array, rhat_array = ess_rhat(x; method=method, split_chains=1)

@test length(ess_array) == size(x, 3)
@test all(ismissing, ess_array) # since min(maxlag, niter - 4) = 0
@test length(rhat_array) == size(x, 3)
@test all(ismissing, rhat_array)
end
x2 = rand(5, 3, 5)
@test_throws ArgumentError ess_rhat(x; split_chains=1)
ess_rhat(x2; split_chains=1)
@test_throws ArgumentError ess_rhat(x2; split_chains=2)
x3 = rand(100, 3, 5)
ess_rhat(x3; maxlag=1)
@test_throws DomainError ess_rhat(x3; maxlag=0)
end

@testset "ESS and R̂ with Union{Missing,Float64} eltype" begin
Expand Down

0 comments on commit 189dba5

Please sign in to comment.