New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes to return types of ess_rhat #62
Conversation
Pull Request Test Coverage Report for Build 3930744323
💛 - Coveralls |
Codecov ReportBase: 95.21% // Head: 95.63% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## main #62 +/- ##
==========================================
+ Coverage 95.21% 95.63% +0.42%
==========================================
Files 10 10
Lines 710 710
==========================================
+ Hits 676 679 +3
+ Misses 34 31 -3
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
With DimensionalData v0.24.1 and this PR, this now works: julia> using DimensionalData, MCMCDiagnosticTools
julia> x = DimArray(randn(Float32, 1000, 4, 3), (:draw, :chain, Dim{:var}([:a, :b, :c])));
julia> S, R = ess_rhat_bulk(x);
julia> S
3-element DimArray{Float32,1} with dimensions:
Dim{:var} Categorical{Symbol} Symbol[a, b, c] ForwardOrdered
:a 3899.19
:b 3973.59
:c 3797.71
julia> R
3-element DimArray{Float32,1} with dimensions:
Dim{:var} Categorical{Symbol} Symbol[a, b, c] ForwardOrdered
:a 0.999755
:b 1.00018
:c 1.0003 No such luck for any other array types I've tried with named axes/indices, but seems they can opt in if they just implement the documented This PR should be ready for final review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just one minor question. Possibly something for a different PR though.
|
||
# do not compute estimates if there is only one sample or lag | ||
maxlag = min(maxlag, niter - 1) | ||
maxlag > 0 || return fill(missing, nparams), fill(missing, nparams) | ||
if !(maxlag > 0) | ||
return similar(chains, Missing, axes_out), similar(chains, Missing, axes_out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have any tests for this? And any @test_inferred
checks - it seems they should fail due to this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, these lines are tested:
MCMCDiagnosticTools.jl/test/ess.jl
Lines 122 to 134 in b6e8de9
@testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed | |
x = rand(1, 3, 5) | |
for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) | |
# analyze array | |
ess_array, rhat_array = ess_rhat(x; method=method) | |
@test length(ess_array) == size(x, 3) | |
@test all(ismissing, ess_array) # since min(maxlag, niter - 1) = 0 | |
@test length(rhat_array) == size(x, 3) | |
@test all(ismissing, rhat_array) | |
end | |
end |
I added more tests that the array type returned here preserves the final axis and some @inferred
tests specifying the type union.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO this type union is a bit annoying. I wonder if in a breaking release we should change this to an error, similar to what StatsBase does e.g. in this case:
julia> using StatsBase
julia> autocov(rand(3), [4])
ERROR: lags must be less than the sample length.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] check_lags
@ ~/.julia/packages/StatsBase/XgjIN/src/signalcorr.jl:15 [inlined]
[3] autocov!(r::Vector{Float64}, x::Vector{Float64}, lags::Vector{Int64}; demean::Bool)
@ StatsBase ~/.julia/packages/StatsBase/XgjIN/src/signalcorr.jl:68
[4] #autocov#113
@ ~/.julia/packages/StatsBase/XgjIN/src/signalcorr.jl:115 [inlined]
[5] autocov(x::Vector{Float64}, lags::Vector{Int64})
@ StatsBase ~/.julia/packages/StatsBase/XgjIN/src/signalcorr.jl:113
[6] top-level scope
@ REPL[3]:1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, this is a good idea.
This PR makes 2 improvements to
ess_rhat
:Float64
.U
ess
andrhat
preserve structure of last dimension of input array. This is especially useful when the input array is one with named dimensions, likeDimensionalData.DimArray
, since the indices in that dimension will be preserved.