Skip to content

Commit

Permalink
Use biased autocov for all lags (#61)
Browse files Browse the repository at this point in the history
* Use biased autocov for ESSMethod and FFTESSMethod

* Increment patch version

* Run formatter

* Add test for autocov methods being equivalent to StatsBase

* Add StatsBase as a test dependency

* Increment patch number
  • Loading branch information
sethaxen committed Jan 12, 2023
1 parent 4070993 commit b6e8de9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.2"
version = "0.2.3"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
15 changes: 6 additions & 9 deletions src/ess.jl
Expand Up @@ -9,8 +9,6 @@ effective sample size of MCMC chains.
It is is based on the discussion by [^VehtariGelman2021] and uses the
biased estimator of the autocovariance, as discussed by [^Geyer1992].
In contrast to Geyer, the divisor `n - 1` is used in the estimation of
the autocovariance to obtain the unbiased estimator of the variance for lag 0.
[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021).
Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for
Expand Down Expand Up @@ -157,12 +155,9 @@ function mean_autocov(k::Int, cache::ESSCache)
)
end

# normalize autocovariance estimators by `niter - 1` instead
# of `niter - k` to obtain
# - unbiased estimators of the variance for lag 0
# - biased but more stable estimators for all other lags as discussed by
# Geyer (1992)
return s / (niter - 1)
# normalize autocovariance estimators by `niter` instead of `niter - k` to obtain biased
# but more stable estimators for all lags as discussed by Geyer (1992)
return s / niter
end

function mean_autocov(k::Int, cache::FFTESSCache)
Expand All @@ -174,9 +169,11 @@ function mean_autocov(k::Int, cache::FFTESSCache)
# we use biased but more stable estimators as discussed by Geyer (1992)
samples_cache = cache.samples_cache
chain_var = cache.chain_var
return Statistics.mean(1:nchains) do i
uncorrection_factor = (niter - 1)//niter # undo corrected=true for chain_var
result = Statistics.mean(1:nchains) do i
@inbounds(real(samples_cache[k + 1, i]) / real(samples_cache[1, i])) * chain_var[i]
end
return result * uncorrection_factor
end

function mean_autocov(k::Int, cache::BDAESSCache)
Expand Down
20 changes: 20 additions & 0 deletions test/ess.jl
Expand Up @@ -9,6 +9,18 @@ using Statistics
using StatsBase
using Test

struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end
struct ExplicitESSCache{S}
samples::S
end
function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector)
return ExplicitESSCache(samples)
end
MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing
function MCMCDiagnosticTools.mean_autocov(k::Int, cache::ExplicitESSCache)
return mean(autocov(cache.samples, k:k; demean=true))
end

struct CauchyProblem end
LogDensityProblems.logdensity(p::CauchyProblem, θ) = -sum(log1psq, θ)
function LogDensityProblems.logdensity_and_gradient(p::CauchyProblem, θ)
Expand Down Expand Up @@ -121,6 +133,14 @@ end
end
end

@testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin
x = randn(1_000, 10, 40)
ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1]
@testset "$method" for method in [FFTESSMethod(), ESSMethod()]
@test ess_rhat(x; method=method)[1] ess_exp
end
end

@testset "ESS and R̂ for chains with 2 epochs that have not mixed" begin
# checks that splitting yields lower ESS estimates and higher Rhat estimates
x = randn(1000, 4, 10) .+ repeat([0, 10]; inner=(500, 1, 1))
Expand Down

2 comments on commit b6e8de9

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75603

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.3 -m "<description of version>" b6e8de9a6ac330c4f7f6cc45185d1e509127fa1f
git push origin v0.2.3

Please sign in to comment.