Skip to content

Commit

Permalink
Added iterators EqualSumSubsets and TwoStepRange
Browse files Browse the repository at this point in the history
  • Loading branch information
PGS62 committed Apr 4, 2024
1 parent baf193b commit a1d9ae9
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 41 deletions.
139 changes: 103 additions & 36 deletions src/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function _pairwise!(::Val{:none}, f, dest::AbstractMatrix{V}, x, y,
#cov(x) is faster than cov(x, x)
(f == cov) && (f = ((x, y) -> x === y ? cov(x) : cov(x, y)))

Threads.@threads for subset in equal_sum_subsets(nc, Threads.nthreads())
Threads.@threads for subset in EqualSumSubsets(nc, Threads.nthreads())
for j in subset
for i = (symmetric ? j : 1):nr
# For performance, diagonal is special-cased
Expand Down Expand Up @@ -96,7 +96,7 @@ function _pairwise!(::Val{:pairwise}, f, dest::AbstractMatrix{V}, x, y, symmetri
nmtx = promoted_nmtype(x)[]
nmty = promoted_nmtype(y)[]

Threads.@threads for subset in equal_sum_subsets(nc, Threads.nthreads())
Threads.@threads for subset in EqualSumSubsets(nc, Threads.nthreads())
scratch_fx = task_local_vector(:scratch_fx, nmtx, m)
scratch_fy = task_local_vector(:scratch_fy, nmty, m)
for j in subset
Expand Down Expand Up @@ -393,51 +393,118 @@ function handle_pairwise(x::AbstractVector, y::AbstractVector;
return view(scratch_fx, lb:(j-1)), view(scratch_fy, lb:(j-1))
end

#=Condition a) makes equal_sum_subsets useful for load-balancing the multi-threaded section
of _pairwise! in the non-symmetric case, and condition b) for the symmetric case.=#
"""
equal_sum_subsets(n::Int, num_subsets::Int)::Vector{Vector{Int}}
task_local_vector(key::Symbol, similarto::AbstractArray{V},
length::Int)::Vector{V} where {V}
Retrieve from task local storage a vector of length `length` and matching the element
type of `similarto`, with initialisation on first call during a task.
"""
function task_local_vector(key::Symbol, similarto::AbstractArray{V},
length::Int)::Vector{V} where {V}
haskey(task_local_storage(), key) || task_local_storage(key, similar(similarto, length))
return task_local_storage(key)
end

Divide the integers 1:n into a number of subsets such that a) each subset has
(approximately) the same number of elements; and b) the sum of the elements in each subset
is nearly equal. If `n` is a multiple of `2 * num_subsets` both conditions hold exactly.
"""
EqualSumSubsets
An iterator that partitions the integers 1 to n into `num_subsets` "subsets" (of type
TwoStepRange) such that a) each subset is of nearly equal length; and b) the sum of the
elements in each subset is nearly equal. If `n` is a multiple of `2 * num_subsets` both
conditions hold exactly.
## Example
```julia-repl
julia> StatsBase.equal_sum_subsets(30,5)
5-element Vector{Vector{Int64}}:
[30, 21, 20, 11, 10, 1]
[29, 22, 19, 12, 9, 2]
[28, 23, 18, 13, 8, 3]
[27, 24, 17, 14, 7, 4]
[26, 25, 16, 15, 6, 5]
julia> for s in StatsBase.EqualSumSubsets(30,5)
println((collect(s), sum(s)))
end
([30, 21, 20, 11, 10, 1], 93)
([29, 22, 19, 12, 9, 2], 93)
([28, 23, 18, 13, 8, 3], 93)
([27, 24, 17, 14, 7, 4], 93)
([26, 25, 16, 15, 6, 5], 93)
#Check for correct partitioning, in this case of integers 1:1000 into 17 subsets.
julia> sort(vcat([collect(s) for s in StatsBase.EqualSumSubsets(1000,17)]...))==1:1000
true
```
"""
function equal_sum_subsets(n::Int, num_subsets::Int)::Vector{Vector{Int}}
subsets = [Int[] for _ in 1:min(n, num_subsets)]
writeto, scanup = 1, true
for i = n:-1:1
push!(subsets[writeto], i)
if scanup && writeto == num_subsets
scanup = false
elseif (!scanup) && writeto == 1
scanup = true
else
writeto += scanup ? 1 : -1
end
struct EqualSumSubsets
n::Int64
num_subsets::Int64

function EqualSumSubsets(n, num_subsets)
n >= 0 || throw("n must not be negative, but got $n")
num_subsets > 0 || throw("num_subsets must be positive, but got $num_subsets")
return new(n, num_subsets)
end
return subsets

end

Base.eltype(::EqualSumSubsets) = TwoStepRange

Check warning on line 446 in src/pairwise.jl

View check run for this annotation

Codecov / codecov/patch

src/pairwise.jl#L446

Added line #L446 was not covered by tests
Base.length(x::EqualSumSubsets) = min(x.n, x.num_subsets)
Base.firstindex(::EqualSumSubsets) = 1

function Base.iterate(ess::EqualSumSubsets, state::Int64=1)
state > length(ess) && return nothing
return getindex(ess, state), state + 1
end

function Base.getindex(ess::EqualSumSubsets, i::Int64=1)
n, nss = ess.n, ess.num_subsets
step1 = 2i - 2nss - 1
step2 = 1 - 2i
return TwoStepRange(n - i + 1, step1, step2)
end

"""
task_local_vector(key::Symbol, similarto::AbstractArray{V},
length::Int)::Vector{V} where {V}
TwoStepRange
Retrieve from task local storage a vector of length `length` and matching the element
type of `similarto`, with initialisation on first call during a task.
Range with a starting value of `start`, stop value of `1` and a step that alternates
between `step1` and `step2`. `start` must be positive and `step1` and `step2` must be
negative.
# Examples
```jldoctest

Check failure on line 470 in src/pairwise.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/StatsBase.jl/StatsBase.jl/src/pairwise.jl:470-479 ```jldoctest julia> collect(StatsBase.TwoStepRange(30,-7,-3)) 6-element Vector{Int64}: 30 23 20 13 10 3 ``` Subexpression: collect(StatsBase.TwoStepRange(30,-7,-3)) Evaluated output: ERROR: UndefVarError: `StatsBase` not defined in `Main` Suggestion: check for spelling errors or missing imports. Hint: a global variable of this name also exists in StatsBase. Stacktrace: [1] top-level scope @ none:1 Expected output: 6-element Vector{Int64}: 30 23 20 13 10 3 diff = Warning: Diff output requires color. 6-element Vector{Int64}: 30 23 20 13 10 3ERROR: UndefVarError: `StatsBase` not defined in `Main` Suggestion: check for spelling errors or missing imports. Hint: a global variable of this name also exists in StatsBase. Stacktrace: [1] top-level scope @ none:1
julia> collect(StatsBase.TwoStepRange(30,-7,-3))
6-element Vector{Int64}:
30
23
20
13
10
3
```
"""
function task_local_vector(key::Symbol, similarto::AbstractArray{V},
length::Int)::Vector{V} where {V}
haskey(task_local_storage(), key) || task_local_storage(key, similar(similarto, length))
return task_local_storage(key)
struct TwoStepRange
start::Int64
step1::Int64
step2::Int64

function TwoStepRange(start, step1, step2)
start > 0 || throw("start must be positive, but got $start")
step1 < 0 || throw("step1 must be negative, but got $step1")
step2 < 0 || throw("step2 must be negative, but got $step2")
return new(start, step1, step2)
end
end

Base.eltype(::TwoStepRange) = Int64

function Base.length(tsr::TwoStepRange)
return length((tsr.start):(tsr.step1+tsr.step2):1) +
length((tsr.start+tsr.step1):(tsr.step1+tsr.step2):1)
end

function Base.iterate(tsr::TwoStepRange, i::Int64=1)
(i > length(tsr)) && return nothing
return getindex(tsr, i), i + 1
end

function Base.getindex(tsr::TwoStepRange, i::Int64=1)::Int64
a, b = divrem(i - 1, 2)
return tsr.start + (tsr.step1 + tsr.step2) * a + b * tsr.step1
end
4 changes: 2 additions & 2 deletions src/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function _pairwise!(::Val{:pairwise}, f::typeof(corspearman),
fl64 = Float64[]
nmtx = promoted_nmtype(x)[]
nmty = promoted_nmtype(y)[]
Threads.@threads for subset in equal_sum_subsets(nr, Threads.nthreads())
Threads.@threads for subset in EqualSumSubsets(nr, Threads.nthreads())

for i in subset

Expand Down Expand Up @@ -482,7 +482,7 @@ function corkendall_loop!(skipmissing::Symbol, f::typeof(corkendall), dest::Abst

symmetric = x === y

Threads.@threads for subset in equal_sum_subsets(nr, Threads.nthreads())
Threads.@threads for subset in EqualSumSubsets(nr, Threads.nthreads())

for i in subset

Expand Down
13 changes: 10 additions & 3 deletions test/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,16 @@ using Test
@test StatsBase.midpoint(1, 10) == 5
@test StatsBase.midpoint(1, widen(10)) == 5

@test StatsBase.equal_sum_subsets(0, 1) == Vector{Int64}[]
@test sum.(StatsBase.equal_sum_subsets(100, 5)) == repeat([1010], 5)
@test sort(vcat(StatsBase.equal_sum_subsets(500, 7)...)) == collect(1:500)
for n in 1:200, nss in 1:7
#check is a partition
@test sort(vcat([collect(s) for s in StatsBase.EqualSumSubsets(n, nss)]...)) == 1:n
#check near-equal lengths
lengths = [length(s) for s in StatsBase.EqualSumSubsets(n, nss)]
@test (maximum(lengths) - minimum(lengths)) <= 1
#check near-equal sums
sums = [sum(collect(s)) for s in StatsBase.EqualSumSubsets(n, nss)]
@test (maximum(sums) - minimum(sums)) < nss
end

end

Expand Down

0 comments on commit a1d9ae9

Please sign in to comment.