Skip to content

Commit

Permalink
Make ess_rhat return a NamedTuple (#74)
Browse files Browse the repository at this point in the history
* Return NamedTuple from ess_rhat

* Test NamedTuple is the inferred type

* Return same type from other methods

* Make line slightly more readable

* Update src/ess_rhat.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Add missing NamedTuple

---------

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
sethaxen and devmotion committed Feb 27, 2023
1 parent ee6d5d1 commit fb3f906
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
16 changes: 10 additions & 6 deletions src/ess_rhat.jl
Expand Up @@ -280,7 +280,7 @@ function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs
return _ess(Val(:basic), x; kwargs...)
end
function _ess(kind::Val, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...)
return first(_ess_rhat(kind, samples; kwargs...))
return _ess_rhat(kind, samples; kwargs...).ess
end
function _ess(
::Val{:tail},
Expand Down Expand Up @@ -411,7 +411,11 @@ function _rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs..
end

"""
ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...)
ess_rhat(
samples::AbstractArray{<:Union{Missing,Real},3};
kind::Symbol=:rank,
kwargs...,
) -> NamedTuple{(:ess, :rhat)}
Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape
`(draws, chains, parameters)`.
Expand Down Expand Up @@ -468,7 +472,7 @@ function _ess_rhat(
ess = similar(chains, T, axes_out)
rhat = similar(chains, T, axes_out)

T === Missing && return ess, rhat
T === Missing && return (; ess, rhat)

# define caches for mean and variance
chain_mean = Array{T}(undef, 1, nchains)
Expand Down Expand Up @@ -565,7 +569,7 @@ function _ess_rhat(
ess[i] = min(ntotal / τ, ess_max)
end

return ess, rhat
return (; ess, rhat)
end
function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...)
return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...)
Expand All @@ -578,15 +582,15 @@ function _ess_rhat(
)
S = _ess(kind, x; split_chains=split_chains, kwargs...)
R = _rhat(kind, x; split_chains=split_chains)
return S, R
return (ess=S, rhat=R)
end
function _ess_rhat(
::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs...
)
Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...)
Rtail = _rhat(Val(:tail), x; split_chains=split_chains)
Rrank = map(max, Rtail, Rbulk)
return Sbulk, Rrank
return (ess=Sbulk, rhat=Rrank)
end

# Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``.
Expand Down
6 changes: 4 additions & 2 deletions test/ess_rhat.jl
Expand Up @@ -45,14 +45,16 @@ mymean(x) = mean(x)
TV = Vector{T}
kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV
@test @inferred(rhat(x; kind=kind)) isa TV
@test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV}
@test @inferred(ess_rhat(x; kind=kind)) isa
NamedTuple{(:ess, :rhat),Tuple{TV,TV}}
end
@testset "Int" begin
x = rand(1:10, 100, 4, 2)
TV = Vector{Float64}
kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV
@test @inferred(rhat(x; kind=kind)) isa TV
@test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV}
@test @inferred(ess_rhat(x; kind=kind)) isa
NamedTuple{(:ess, :rhat),Tuple{TV,TV}}
end
end
@testset for kind in (mean, median, mad, std, Base.Fix2(quantile, 0.25))
Expand Down

0 comments on commit fb3f906

Please sign in to comment.