Commit

Merge efe0502 into 8d74357
sethaxen committed Dec 13, 2022
2 parents 8d74357 + efe0502 commit 1ff24d7
Showing 7 changed files with 198 additions and 16 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,11 +1,12 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -18,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
AbstractFFTs = "0.5, 1"
DataAPI = "1.6"
DataStructures = "0.18.3"
Distributions = "0.25"
MLJModelInterface = "1.6"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
2 changes: 2 additions & 0 deletions src/MCMCDiagnosticTools.jl
@@ -2,6 +2,7 @@ module MCMCDiagnosticTools

using AbstractFFTs: AbstractFFTs
using DataAPI: DataAPI
+using DataStructures: DataStructures
using Distributions: Distributions
using MLJModelInterface: MLJModelInterface
using SpecialFunctions: SpecialFunctions
@@ -22,6 +23,7 @@ export mcse
export rafterydiag
export rstar

include("utils.jl")
include("bfmi.jl")
include("discretediag.jl")
include("ess.jl")
27 changes: 15 additions & 12 deletions src/rstar.jl
@@ -4,7 +4,8 @@
classifier::MLJModelInterface.Supervised,
samples,
chain_indices::AbstractVector{Int};
-        subset::Real=0.8,
+        subset::Real=0.7,
+        split_chains::Int=2,
verbosity::Int=0,
)
@@ -23,26 +24,25 @@ function rstar(
classifier::MLJModelInterface.Supervised,
x,
y::AbstractVector{Int};
-    subset::Real=0.8,
+    subset::Real=0.7,
+    split_chains::Int=2,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

+    ysplit = split_chain_indices(y, split_chains)
+
# randomly sub-select training and testing set
-    N = length(y)
-    Ntrain = round(Int, N * subset)
-    0 < Ntrain < N ||
+    train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
+    0 < length(train_ids) < length(y) ||
throw(ArgumentError("training and test data subsets must not be empty"))
-    ids = Random.randperm(rng, N)
-    train_ids = view(ids, 1:Ntrain)
-    test_ids = view(ids, (Ntrain + 1):N)

xtable = _astable(x)

# train classifier on training data
-    ycategorical = MLJModelInterface.categorical(y)
+    ycategorical = MLJModelInterface.categorical(ysplit)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
@@ -79,7 +79,8 @@ end
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples::AbstractArray{<:Real,3};
-        subset::Real=0.8,
+        subset::Real=0.7,
+        split_chains::Int=2,
verbosity::Int=0,
)
@@ -91,8 +92,10 @@ This implementation is an adaption of algorithms 1 and 2 described by Lambert an
The `classifier` has to be a supervised classifier of the MLJ framework (see the
[MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list)
-for a list of supported models). It is trained with a `subset` of the samples. The training
-of the classifier can be inspected by adjusting the `verbosity` level.
+for a list of supported models). It is trained with a `subset` of the samples from each
+chain. Each chain is split into `split_chains` separate chains to additionally check for
+within-chain convergence. The training of the classifier can be inspected by adjusting the
+`verbosity` level.
If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*``
statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs
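As a usage sketch (not part of the diff; the data are hypothetical, with
MLJXGBoostInterface assumed installed as in the test setup below), the new keyword can be
exercised through the matrix-plus-chain-indices method:

using MCMCDiagnosticTools
using Statistics: mean
using MLJXGBoostInterface: XGBoostClassifier

samples = randn(4_000, 2)                 # hypothetical: 1_000 draws of 2 parameters per chain, stacked
chain_indices = repeat(1:4; inner=1_000)  # chain label for each row of `samples`
# each chain is split in two before training, so within-chain drift also raises R*
dist = rstar(XGBoostClassifier(), samples, chain_indices; subset=0.7, split_chains=2)
mean(dist)  # ≈ 1 for well-mixed chains; larger values indicate mixing problems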
100 changes: 100 additions & 0 deletions src/utils.jl
@@ -0,0 +1,100 @@
"""
unique_indices(x) -> (unique, indices)
Return the results of `unique(collect(x))` along with a vector of the same length whose
elements are the indices in `x` at which the corresponding unique element in `unique` is
found.
"""
function unique_indices(x)
inds = eachindex(x)
T = eltype(inds)
ind_map = DataStructures.SortedDict{eltype(x),Vector{T}}()
for i in inds
xi = x[i]
inds_xi = get!(ind_map, xi) do
return T[]
end
push!(inds_xi, i)
end
unique = collect(keys(ind_map))
indices = collect(values(ind_map))
return unique, indices
end
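A worked example (illustrative, not part of the diff): the `SortedDict` orders the unique
values, and each value maps to the positions at which it occurs, so

MCMCDiagnosticTools.unique_indices([2, 1, 2, 1, 3]) == ([1, 2, 3], [[2, 4], [1, 3], [5]])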

"""
split_chain_indices(
chain_inds::AbstractVector{Int},
split::Int=2,
) -> AbstractVector{Int}
Split each chain in `chain_inds` into `split` chains.
For each chain in `chain_inds`, all entries are assumed to correspond to draws that have
been ordered by iteration number. The result is a vector of the same length as `chain_inds`
where each entry is the new index of the chain that the corresponding draw belongs to.
"""
function split_chain_indices(c::AbstractVector{Int}, split::Int=2)
cnew = similar(c)
if split == 1
copyto!(cnew, c)
return cnew
end
_, indices = unique_indices(c)
chain_ind = 0
for inds in indices
ndraws_per_split, rem = divrem(length(inds), split)
        # here we can't use Iterators.partition because it's greedy. e.g. we can't partition
        # 4 items across 3 partitions because Iterators.partition(1:4, 1) == [[1], [2], [3], [4]]
        # and Iterators.partition(1:4, 2) == [[1, 2], [3, 4]]. But we would want
        # [[1, 2], [3], [4]].
i = j = 0
ndraws_this_split = ndraws_per_split + (j < rem)
chain_ind += 1
for ind in inds
cnew[ind] = chain_ind
if (i += 1) == ndraws_this_split
i = 0
j += 1
ndraws_this_split = ndraws_per_split + (j < rem)
chain_ind += 1
end
end
end
return cnew
end
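For example (illustrative, not part of the diff), splitting a single chain of five draws
in two assigns the remainder draw to the first of the new chains, per the `divrem`
handling above:

MCMCDiagnosticTools.split_chain_indices([1, 1, 1, 1, 1], 2) == [1, 1, 1, 2, 2]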

"""
shuffle_split_stratified(
rng::Random.AbstractRNG,
group_ids::AbstractVector,
frac::Real,
) -> (inds1, inds2)
Randomly split the indices of `group_ids` into two groups, where `frac` indices from each
group are in `inds1` and the remainder are in `inds2`.
This is used, for example, to split data into training and test data while preserving the
class balances.
"""
function shuffle_split_stratified(
rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real
)
_, indices = unique_indices(group_ids)
T = eltype(eltype(indices))
N1_tot = sum(x -> round(Int, length(x) * frac), indices)
N2_tot = length(group_ids) - N1_tot
inds1 = Vector{T}(undef, N1_tot)
inds2 = Vector{T}(undef, N2_tot)
items_in_1 = items_in_2 = 0
for inds in indices
N = length(inds)
N1 = round(Int, N * frac)
N2 = N - N1
Random.shuffle!(rng, inds)
copyto!(inds1, items_in_1 + 1, inds, 1, N1)
copyto!(inds2, items_in_2 + 1, inds, N1 + 1, N2)
items_in_1 += N1
items_in_2 += N2
end
return inds1, inds2
end
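A quick sketch of the stratification (hypothetical sizes, not from the diff): with two
groups of ten indices each and `frac = 0.7`, `round(Int, 10 * 0.7) = 7` indices per group
land in `inds1`, and the remaining three per group in `inds2`:

using Random
group_ids = repeat(1:2; inner=10)
inds1, inds2 = MCMCDiagnosticTools.shuffle_split_stratified(
    Random.default_rng(), group_ids, 0.7
)
(length(inds1), length(inds2)) == (14, 6)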
17 changes: 14 additions & 3 deletions test/rstar.jl
@@ -30,7 +30,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test dist isa LocationScale
@test dist.ρ isa PoissonBinomial
@test minimum(dist) == 0
-        @test maximum(dist) == 3
+        @test maximum(dist) == 6
end
        @test mean(dist) ≈ 1 rtol = 0.2
wrapper === Vector && break
Expand All @@ -48,7 +48,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test dist isa LocationScale
@test dist.ρ isa PoissonBinomial
@test minimum(dist) == 0
-        @test maximum(dist) == 4
+        @test maximum(dist) == 8
end
        @test mean(dist) ≈ 1 rtol = 0.15

@@ -58,7 +58,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
100 .* cos.(1:N) 100 .* sin.(1:N)
])
chain_indices = repeat(1:2; inner=N)
-        dist = rstar(classifier, samples, chain_indices)
+        dist = rstar(classifier, samples, chain_indices; split_chains=1)

# Mean of the statistic should be close to 2, i.e., the classifier should be able to
# learn an almost perfect decision boundary between chains.
@@ -71,6 +71,17 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test maximum(dist) == 2
end
        @test mean(dist) ≈ 2 rtol = 0.15

+        # Compute the R⋆ statistic for identical chains that individually have not mixed.
+        samples = ones(sz)
+        samples[div(N, 2):end, :] .= 2
+        chain_indices = repeat(1:4; outer=div(N, 4))
+        dist = rstar(classifier, samples, chain_indices; split_chains=1)
+        # without split chains cannot distinguish between chains
+        @test mean(dist) ≈ 1 rtol = 0.15
+        dist = rstar(classifier, samples, chain_indices)
+        # with split chains can learn almost perfect decision boundary
+        @test mean(dist) ≈ 2 rtol = 0.15
end
wrapper === Vector && continue

4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -10,6 +10,10 @@ using Test
Random.seed!(1)

@testset "MCMCDiagnosticTools.jl" begin
@testset "utils" begin
include("utils.jl")
end

@testset "Bayesian fraction of missing information" begin
include("bfmi.jl")
end
60 changes: 60 additions & 0 deletions test/utils.jl
@@ -0,0 +1,60 @@
using MCMCDiagnosticTools
using Test
using Random

@testset "unique_indices" begin
@testset "indices=$(eachindex(inds))" for inds in [
rand(11:14, 100), transpose(rand(11:14, 10, 10))
]
unique, indices = @inferred MCMCDiagnosticTools.unique_indices(inds)
@test unique isa Vector{Int}
if eachindex(inds) isa CartesianIndices{2}
@test indices isa Vector{Vector{CartesianIndex{2}}}
else
@test indices isa Vector{Vector{Int}}
end
@test issorted(unique)
@test issetequal(union(indices...), eachindex(inds))
for i in eachindex(unique, indices)
@test all(inds[indices[i]] .== unique[i])
end
end
end

@testset "split_chain_indices" begin
c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
@test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c

cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2)
unique, indices = MCMCDiagnosticTools.unique_indices(c)
uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
for (i, inew) in enumerate(1:2:7)
        @test length(indicesnew[inew]) ≥ length(indicesnew[inew + 1])
@test indices[i] == vcat(indicesnew[inew], indicesnew[inew + 1])
end

cnew = MCMCDiagnosticTools.split_chain_indices(c, 3)
unique, indices = MCMCDiagnosticTools.unique_indices(c)
uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
for (i, inew) in enumerate(1:3:11)
        @test length(indicesnew[inew]) ≥
            length(indicesnew[inew + 1]) ≥
            length(indicesnew[inew + 2])
@test indices[i] ==
vcat(indicesnew[inew], indicesnew[inew + 1], indicesnew[inew + 2])
end
end

@testset "shuffle_split_stratified" begin
rng = Random.default_rng()
c = rand(1:4, 100)
unique, indices = MCMCDiagnosticTools.unique_indices(c)
@testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
inds1, inds2 = @inferred(MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac))
@test issetequal(vcat(inds1, inds2), eachindex(c))
for inds in indices
common_inds = intersect(inds1, inds)
@test length(common_inds) == round(frac * length(inds))
end
end
end
