-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
198 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
unique_indices(x) -> (unique, indices) | ||
Return the sorted unique elements of `x` along with a vector of the same length whose | ||
elements are the vectors of indices in `x` at which the corresponding element of `unique` | ||
is found. | ||
""" | ||
function unique_indices(x)
    # The index type follows the index style of `x` (linear or Cartesian), so it is taken
    # from `eachindex(x)` instead of being assumed to be `Int`.
    positions = eachindex(x)
    Idx = eltype(positions)
    # A sorted dictionary keeps the distinct values ordered by key, so the returned
    # values come out sorted rather than in order of first appearance.
    groups = DataStructures.SortedDict{eltype(x),Vector{Idx}}()
    for pos in positions
        # Fetch the index list for this value, creating an empty one on first sight.
        bucket = get!(() -> Idx[], groups, x[pos])
        push!(bucket, pos)
    end
    return collect(keys(groups)), collect(values(groups))
end
|
||
""" | ||
split_chain_indices( | ||
chain_inds::AbstractVector{Int}, | ||
split::Int=2, | ||
) -> AbstractVector{Int} | ||
Split each chain in `chain_inds` into `split` chains. | ||
For each chain in `chain_inds`, all entries are assumed to correspond to draws that have | ||
been ordered by iteration number. The result is a vector of the same length as `chain_inds` | ||
where each entry is the new index of the chain that the corresponding draw belongs to. | ||
""" | ||
function split_chain_indices(c::AbstractVector{Int}, split::Int=2)
    # A non-positive number of splits is meaningless (and `divrem` by zero would throw).
    split ≥ 1 || throw(ArgumentError("split must be positive, got $split"))
    cnew = similar(c)
    if split == 1
        # Nothing to split: keep the original chain labels.
        copyto!(cnew, c)
        return cnew
    end
    _, indices = unique_indices(c)
    # Label of the split chain currently being filled. It is advanced only when a split
    # has received all of its draws, so the new labels are the consecutive integers
    # 1, 2, 3, ... with no gaps between the splits of different original chains.
    chain_ind = 1
    for inds in indices
        ndraws_per_split, rem = divrem(length(inds), split)
        # `Iterators.partition` cannot be used here because it works with a fixed chunk
        # size. E.g. 4 draws cannot be spread over 3 splits that way:
        # `Iterators.partition(1:4, 1)` yields four chunks and `Iterators.partition(1:4, 2)`
        # yields two, whereas we want [[1, 2], [3], [4]]. Instead, the first `rem` splits
        # of each chain receive one extra draw.
        i = j = 0  # i: draws placed in the current split, j: splits completed so far
        ndraws_this_split = ndraws_per_split + (j < rem)
        for ind in inds
            cnew[ind] = chain_ind
            if (i += 1) == ndraws_this_split
                # Current split is full: move on to the next split (and label).
                i = 0
                j += 1
                ndraws_this_split = ndraws_per_split + (j < rem)
                chain_ind += 1
            end
        end
    end
    return cnew
end
|
||
""" | ||
shuffle_split_stratified( | ||
rng::Random.AbstractRNG, | ||
group_ids::AbstractVector, | ||
frac::Real, | ||
) -> (inds1, inds2) | ||
Randomly split the indices of `group_ids` into two sets, such that a fraction `frac` of the | ||
indices of each group (rounded to the nearest integer) is in `inds1` and the remainder is in `inds2`. | ||
This is used, for example, to split data into training and test data while preserving the | ||
class balances. | ||
""" | ||
function shuffle_split_stratified(
    rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real
)
    # `frac` is a proportion; anything outside [0, 1] (or NaN) would request more draws
    # than a group has, or a negative number of them.
    0 ≤ frac ≤ 1 || throw(ArgumentError("frac must be between 0 and 1, got $frac"))
    _, indices = unique_indices(group_ids)
    T = eltype(eltype(indices))
    # Total size of the first output. `init=0` keeps the reduction well defined when
    # `group_ids` is empty and there are no groups to reduce over.
    N1_tot = mapreduce(inds -> round(Int, length(inds) * frac), +, indices; init=0)
    N2_tot = length(group_ids) - N1_tot
    inds1 = Vector{T}(undef, N1_tot)
    inds2 = Vector{T}(undef, N2_tot)
    # Number of entries of each output vector that have been filled so far.
    items_in_1 = items_in_2 = 0
    for inds in indices
        N = length(inds)
        N1 = round(Int, N * frac)
        N2 = N - N1
        # Shuffle the indices of this group in place, then send the first `N1` of them
        # to `inds1` and the remaining `N2` to `inds2`.
        Random.shuffle!(rng, inds)
        copyto!(inds1, items_in_1 + 1, inds, 1, N1)
        copyto!(inds2, items_in_2 + 1, inds, N1 + 1, N2)
        items_in_1 += N1
        items_in_2 += N2
    end
    return inds1, inds2
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
using MCMCDiagnosticTools | ||
using Test | ||
using Random | ||
|
||
@testset "unique_indices" begin
    # Seeded generator so that every run of the test suite exercises the same data.
    rng = Random.MersenneTwister(42)
    # One input with linear indices and one (a lazy transpose) with Cartesian indices.
    @testset "indices=$(eachindex(inds))" for inds in [
        rand(rng, 11:14, 100), transpose(rand(rng, 11:14, 10, 10))
    ]
        unique, indices = @inferred MCMCDiagnosticTools.unique_indices(inds)
        @test unique isa Vector{Int}
        if eachindex(inds) isa CartesianIndices{2}
            @test indices isa Vector{Vector{CartesianIndex{2}}}
        else
            @test indices isa Vector{Vector{Int}}
        end
        @test issorted(unique)
        # Every index of the input appears in some group ...
        @test issetequal(reduce(vcat, indices), eachindex(inds))
        # ... and no index is counted more than once.
        @test sum(length, indices) == length(inds)
        # Each group only contains positions holding the corresponding unique value.
        for i in eachindex(unique, indices)
            @test all(inds[indices[i]] .== unique[i])
        end
    end
end
|
||
@testset "split_chain_indices" begin
    chains = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
    # Draw positions grouped by original chain; the reference for all checks below.
    _, groups = MCMCDiagnosticTools.unique_indices(chains)

    # A single split leaves the chain assignment untouched.
    @test @inferred(MCMCDiagnosticTools.split_chain_indices(chains, 1)) == chains

    # Two splits per chain: consecutive pairs of new groups rebuild each original chain,
    # with any extra draw going to the earlier split.
    halves = @inferred MCMCDiagnosticTools.split_chain_indices(chains, 2)
    _, groups_halves = MCMCDiagnosticTools.unique_indices(halves)
    for (k, start) in enumerate(1:2:7)
        a, b = groups_halves[start], groups_halves[start + 1]
        @test length(a) ≥ length(b)
        @test groups[k] == vcat(a, b)
    end

    # Three splits per chain: same property for consecutive triples of new groups.
    thirds = MCMCDiagnosticTools.split_chain_indices(chains, 3)
    _, groups_thirds = MCMCDiagnosticTools.unique_indices(thirds)
    for (k, start) in enumerate(1:3:11)
        a, b, c3 = groups_thirds[start], groups_thirds[start + 1], groups_thirds[start + 2]
        @test length(a) ≥ length(b) ≥ length(c3)
        @test groups[k] == vcat(a, b, c3)
    end
end
|
||
@testset "shuffle_split_stratified" begin
    # Seeded generator so that both the data and the shuffles are reproducible.
    rng = Random.MersenneTwister(42)
    c = rand(rng, 1:4, 100)
    unique, indices = MCMCDiagnosticTools.unique_indices(c)
    @testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
        inds1, inds2 = @inferred(MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac))
        # Together the two outputs cover every index exactly once.
        @test issetequal(vcat(inds1, inds2), eachindex(c))
        @test length(inds1) + length(inds2) == length(c)
        # Each group contributes its rounded share of indices to the first output.
        for inds in indices
            common_inds = intersect(inds1, inds)
            @test length(common_inds) == round(frac * length(inds))
        end
    end
end