diff --git a/base/sort.jl b/base/sort.jl
index b57c00edb5454..b9497d8136114 100644
--- a/base/sort.jl
+++ b/base/sort.jl
@@ -90,7 +90,15 @@ issorted(itr;
     issorted(itr, ord(lt,by,rev,order))
 
 function partialsort!(v::AbstractVector, k::Union{Integer,OrdinalRange}, o::Ordering)
-    _sort!(v, InitialOptimizations(ScratchQuickSort(k)), o, (;))
+    # TODO move k from `alg` to `kw`
+    # Don't perform InitialOptimizations before Bracketing. The optimizations take O(n)
+    # time and so does the whole sort. But do perform them before recursive calls because
+    # that can cause significant speedups when the target range is large, so the runtime is
+    # dominated by k log k and the optimizations run in O(k) time.
+    _sort!(v, BoolOptimization(
+        Small{12}( # Very small inputs should go straight to insertion sort
+            BracketedSort(k))),
+        o, (;))
     maybeview(v, k)
 end
 
@@ -1138,6 +1146,195 @@ function _sort!(v::AbstractVector, a::ScratchQuickSort, o::Ordering, kw;
 end
 
+
+"""
+    BracketedSort(target[, next::Algorithm]) <: Algorithm
+
+Perform a partial sort for the elements that fall at the indices specified by `target`,
+using BracketedSort with the `next` algorithm for subproblems.
+
+BracketedSort takes a random* sample of the input, estimates the quantiles of the input
+using the quantiles of the sample to find signposts that almost certainly bracket the target
+values, filters the values in the input that fall between the signpost values to the front
+of the input, and then, if that "almost certainly" turned out to be true, finds the target
+within the small chunk of elements that are, by value, between the signposts and, now, by
+position, at the front of the vector. On small inputs, or when the target is close to the
+size of the input, BracketedSort falls back to the `next` algorithm directly. Otherwise,
+BracketedSort uses the `next` algorithm only to compute quantiles of the sample and to find
+the target within the small chunk.
+
+## Performance
+
+If the `next` algorithm has `O(n * log(n))` runtime and the input is not pathological, then
+the runtime of this algorithm is `O(n + k * log(k))` where `n` is the length of the input
+and `k` is `length(target)`. On pathological inputs the asymptotic runtime is the same as
+the runtime of the `next` algorithm.
+
+BracketedSort itself does not allocate. If `next` is in-place then BracketedSort is also
+in-place. If `next` is not in-place and its space usage increases monotonically with input
+length, then BracketedSort's maximum space usage will never be more than the space usage
+of `next` on the input BracketedSort receives. For large nonpathological inputs and targets
+substantially smaller than the size of the input, BracketedSort's maximum memory usage will
+be much less than `next`'s. If the maximum additional space usage of `next` scales linearly,
+then for small `k` the average* maximum additional space usage of BracketedSort will be
+`O(n^(2/3))`.
+
+By default, BracketedSort uses the `O(n)` space and `O(n + k log k)` runtime
+`ScratchQuickSort` algorithm for its recursive calls.
+
+*Sorting is unable to depend on Random.jl because Random.jl depends on sorting.
+ Consequently, we use `hash` as a source of randomness. The average runtime guarantees
+ assume that `hash(x::Int)` produces a random result. However, as this randomization is
+ deterministic, if you try hard enough you can find inputs that consistently reach the
+ worst case bounds. Actually constructing such inputs is an exercise left to the reader.
+ Have fun :).
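+
+As a rough illustration of the bounds above (the numbers are illustrative, not guarantees),
+for an input of length `n = 10^6` and a narrow `target`:
+
+    sample       ≈ n^(2/3) = 10^4 elements, partially sorted to locate the signposts
+    filter pass  = one O(n) pass over all 10^6 elements, with no recursion and no allocation
+    middle chunk ≈ a small multiple of n^(2/3) elements, partially sorted to find the target
+
+so the single `O(n)` filtering pass dominates the runtime, and the recursive calls to `next`
+only ever see `O(n^(2/3))` elements, which is where the space bound above comes from.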
+
+Characteristics:
+  * *unstable*: does not preserve the ordering of elements that compare equal
+    (e.g. "a" and "A" in a sort of letters that ignores case).
+  * *in-place* in memory if the `next` algorithm is in-place.
+  * *estimate-and-filter* strategy
+  * *linear runtime* if `length(target)` is constant and `next` is reasonable
+  * *n + k log k* worst case runtime if `next` has that runtime.
+  * *pathological inputs* can significantly increase constant factors.
+"""
+struct BracketedSort{T, F} <: Algorithm
+    target::T
+    get_next::F
+end
+
+# TODO: this composition between BracketedSort and ScratchQuickSort does not bring me joy
+BracketedSort(k) = BracketedSort(k, k -> InitialOptimizations(ScratchQuickSort(k)))
+
+function bracket_kernel!(v::AbstractVector, lo, hi, lo_signpost, hi_signpost, o)
+    i = 0
+    count_below = 0
+    checkbounds(v, lo:hi)
+    for j in lo:hi
+        x = @inbounds v[j]
+        a = lo_signpost !== nothing && lt(o, x, lo_signpost)
+        b = hi_signpost === nothing || !lt(o, hi_signpost, x)
+        count_below += a
+        # if a != b # This branch is almost never taken, so making it branchless is bad.
+        #     @inbounds v[i], v[j] = v[j], v[i]
+        #     i += 1
+        # end
+        c = a != b # JK, this is faster.
+        k = i * c + j
+        # Invariant: @assert firstindex(v) ≤ lo ≤ i + j ≤ k ≤ j ≤ hi ≤ lastindex(v)
+        @inbounds v[j], v[k] = v[k], v[j]
+        i += c - 1
+    end
+    count_below, i + hi
+end
+
+function move!(v, target, source)
+    # This function never dominates runtime, so only add `@inbounds` if you can demonstrate
+    # a performance improvement. And if you do, also double check behavior when `target`
+    # is out of bounds.
+    @assert length(target) == length(source)
+    if length(target) == 1 || isdisjoint(target, source)
+        for (i, j) in zip(target, source)
+            v[i], v[j] = v[j], v[i]
+        end
+    else
+        @assert minimum(source) <= minimum(target)
+        reverse!(v, minimum(source), maximum(target))
+        reverse!(v, minimum(target), maximum(target))
+    end
+end
+
+function _sort!(v::AbstractVector, a::BracketedSort, o::Ordering, kw)
+    @getkw lo hi scratch
+    # TODO for further optimization: reuse scratch between trials better, from signpost
+    # selection to recursive calls, and from the fallback (but be aware of type stability,
+    # especially when sorting IEEE floats).
+
+    # We don't need to bounds check target because that is done higher up in the stack.
+    # However, we cannot assume the target is inbounds.
+    lo < hi || return scratch
+    ln = hi - lo + 1
+
+    # This is simply a precomputed short-circuit to avoid doing scalar math for small
+    # inputs. It does not change dispatch at all.
+    ln < 260 && return _sort!(v, a.get_next(a.target), o, kw)
+
+    target = a.target
+    k = cbrt(ln)
+    k2 = round(Int, k^2)
+    k2ln = k2/ln
+    offset = .15k*top_set_bit(k2) # TODO for further optimization: tune this
+    lo_signpost_i, hi_signpost_i =
+        (floor(Int, (tar - lo) * k2ln + lo + off) for (tar, off) in
+            ((minimum(target), -offset), (maximum(target), offset)))
+    lastindex_sample = lo+k2-1
+    expected_middle_ln = (min(lastindex_sample, hi_signpost_i) - max(lo, lo_signpost_i) + 1) / k2ln
+    # This heuristic is complicated because it fairly accurately reflects the runtime of
+    # this algorithm, which is necessary to get good dispatch when both the target and the
+    # input are large.
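+    # Worked example (illustrative numbers only): for ln = 10^6 and a narrow target, k2 is
+    # about 10^4 and expected_middle_ln is a few times 10^4, so the right hand side below
+    # is on the order of 10^5 and we keep going with BracketedSort; if instead the target
+    # spans a large fraction of the input, expected_middle_ln approaches ln and we dispatch
+    # straight to the `next` algorithm.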
+    # expected_middle_ln is a float and k2 is significantly below typemax(Int), so this
+    # will not overflow:
+    # TODO move target from alg to kw to avoid this ickiness:
+    ln <= 130 + 2k2 + 2expected_middle_ln && return _sort!(v, a.get_next(a.target), o, kw)
+
+    # We store the random sample in
+    #     sample = view(v, lo:lastindex_sample)
+    # but views are not quite as fast as using the input array directly,
+    # so we don't actually construct this view at runtime.
+
+    # TODO for further optimization: handle lots of duplicates better.
+    # Right now lots of duplicates rounds up when it could use some super fast optimizations
+    # in some cases.
+    # e.g.
+    #
+    # Target:                            |----|
+    # Sorted input: 000000000000000000011111112222223333333333
+    #
+    # Will filter all zeros and ones to the front when it could just take the first few
+    # it encounters. This optimization would be especially potent when `allequal(ans)` and
+    # equal elements are egal.
+
+    # 3 random trials should typically give us 0.99999 reliability; we can assume
+    # the input is pathological and abort to the fallback if we fail three trials.
+    seed = hash(ln, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
+    for attempt in 1:3
+        seed = hash(attempt, seed)
+        for i in lo:lo+k2-1
+            j = mod(hash(i, seed), i:hi) # TODO for further optimization: be sneaky and remove this division
+            v[i], v[j] = v[j], v[i]
+        end
+        count_below, lastindex_middle = if lo_signpost_i <= lo && lastindex_sample <= hi_signpost_i
+            # The heuristics higher up in this function that dispatch to the `next`
+            # algorithm should prevent this from happening.
+            # Specifically, this means that expected_middle_ln == ln, so
+            #     ln <= ... + 2expected_middle_ln && return ...
+            # will trigger.
+            @assert false
+            # But if it does happen, the kernel reduces to
+            0, hi
+        elseif lo_signpost_i <= lo
+            _sort!(v, a.get_next(hi_signpost_i), o, (;kw..., hi=lastindex_sample))
+            bracket_kernel!(v, lo, hi, nothing, v[hi_signpost_i], o)
+        elseif lastindex_sample <= hi_signpost_i
+            _sort!(v, a.get_next(lo_signpost_i), o, (;kw..., hi=lastindex_sample))
+            bracket_kernel!(v, lo, hi, v[lo_signpost_i], nothing, o)
+        else
+            # TODO for further optimization: don't sort the middle elements
+            _sort!(v, a.get_next(lo_signpost_i:hi_signpost_i), o, (;kw..., hi=lastindex_sample))
+            bracket_kernel!(v, lo, hi, v[lo_signpost_i], v[hi_signpost_i], o)
+        end
+        target_in_middle = target .- count_below
+        if lo <= minimum(target_in_middle) && maximum(target_in_middle) <= lastindex_middle
+            scratch = _sort!(v, a.get_next(target_in_middle), o, (;kw..., hi=lastindex_middle))
+            move!(v, target, target_in_middle)
+            return scratch
+        end
+        # This line almost never runs.
+    end
+    # This line only runs on pathological inputs. Make sure it's covered by tests :)
+    _sort!(v, a.get_next(target), o, kw)
+end
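+
+# The function above is easier to follow with a conceptual model in mind. The sketch below
+# is an illustration only (a hypothetical helper that is not called anywhere): it allocates
+# freely, supports only the `isless` ordering, and does not retry or fall back when the
+# bracket misses, all of which the real implementation above handles.
+function _bracketed_select_sketch(v::AbstractVector, k::Integer)
+    n = length(v)
+    m = round(Int, cbrt(n)^2)                              # sample of about n^(2/3) elements
+    sample = sort!([v[mod(hash(i), 1:n)] for i in 1:m])    # hash-based "random" sample
+    slack = ceil(Int, 0.15 * cbrt(n) * top_set_bit(m))     # widen the bracket for safety
+    lo = sample[clamp(floor(Int, k*m/n) - slack, 1, m)]    # lower signpost
+    hi = sample[clamp(ceil(Int, k*m/n) + slack, 1, m)]     # upper signpost
+    below = count(x -> isless(x, lo), v)                   # elements strictly below the bracket
+    middle = filter(x -> !isless(x, lo) && !isless(hi, x), v)
+    partialsort!(middle, k - below)                        # the target lies in the small chunk
+end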
+
+
 """
     StableCheckSorted(next) <: Algorithm
diff --git a/test/sorting.jl b/test/sorting.jl
index edc6ca17cfde0..d88cc43a8e1f4 100644
--- a/test/sorting.jl
+++ b/test/sorting.jl
@@ -721,6 +721,8 @@ end
     for alg in safe_algs
         @test sort(1:n, alg=alg, lt = (i,j) -> v[i]<=v[j]) == perm
     end
+    # This could easily break with minor heuristic adjustments
+    # because partialsort is not even guaranteed to be stable:
     @test partialsort(1:n, 172, lt = (i,j) -> v[i]<=v[j]) == perm[172]
     @test partialsort(1:n, 315:415, lt = (i,j) -> v[i]<=v[j]) == perm[315:415]
 
@@ -1034,6 +1036,41 @@
     @test issorted(sort!(rand(100), Base.Sort.InitialOptimizations(DispatchLoopTestAlg()), Base.Order.Forward))
 end
 
+@testset "partialsort tests added for BracketedSort #52006" begin
+    x = rand(Int, 1000)
+    @test partialsort(x, 1) == minimum(x)
+    @test partialsort(x, 1000) == maximum(x)
+    sx = sort(x)
+    for i in [1, 2, 4, 10, 11, 425, 500, 845, 991, 997, 999, 1000]
+        @test partialsort(x, i) == sx[i]
+    end
+    for i in [1:1, 1:2, 1:5, 1:8, 1:9, 1:11, 1:108, 135:812, 220:586, 363:368, 450:574, 458:597, 469:638, 487:488, 500:501, 584:594, 1000:1000]
+        @test partialsort(x, i) == sx[i]
+    end
+
+    # Semi-pathological input
+    seed = hash(1000, Int === Int64 ? 0x85eb830e0216012d : 0xae6c4e15)
+    seed = hash(1, seed)
+    for i in 1:100
+        j = mod(hash(i, seed), i:1000)
+        x[j] = typemax(Int)
+    end
+    @test partialsort(x, 500) == sort(x)[500]
+
+    # Fully pathological input
+    # It would be too much trouble to actually construct a valid pathological input, so we
+    # construct an invalid pathological input instead.
+    # This test is kind of sketchy because it passes invalid inputs to the function.
+    for i in [1:6, 1:483, 1:957, 77:86, 118:478, 223:227, 231:970, 317:958, 500:501, 500:501, 500:501, 614:620, 632:635, 658:665, 933:940, 937:942, 997:1000, 999:1000]
+        x = rand(1:5, 1000)
+        @test partialsort(x, i, lt=(<=)) == sort(x)[i]
+    end
+    for i in [1, 7, 8, 490, 495, 852, 993, 996, 1000]
+        x = rand(1:5, 1000)
+        @test partialsort(x, i, lt=(<=)) == sort(x)[i]
+    end
+end
+
 # This testset is at the end of the file because it is slow.
 @testset "searchsorted" begin
     numTypes = [ Int8,  Int16,  Int32,  Int64,  Int128,