Skip to content

Commit

Permalink
Work around Julia's Base.Sort.MissingOptimization bugs (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithHafner committed Jun 15, 2023
1 parent 4f1b96e commit ff693e2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
29 changes: 25 additions & 4 deletions src/SortingAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ struct TimSortAlg <: Algorithm end
struct RadixSortAlg <: Algorithm end
struct CombSortAlg <: Algorithm end

function maybe_optimize(x::Algorithm)
function maybe_optimize(x::Algorithm)
isdefined(Base.Sort, :InitialOptimizations) ? Base.Sort.InitialOptimizations(x) : x
end
end
const HeapSort = maybe_optimize(HeapSortAlg())
const TimSort = maybe_optimize(TimSortAlg())
# Whenever InitialOptimizations is defined, RadixSort falls
# Whenever InitialOptimizations is defined, RadixSort falls
# back to Base.DEFAULT_STABLE which already includes them.
const RadixSort = RadixSortAlg()

Expand Down Expand Up @@ -79,6 +79,27 @@ end
#
# Original author: @kmsquire

@static if v"1.9.0-alpha" <= VERSION <= v"1.9.1"
function Base.getindex(v::Base.Sort.WithoutMissingVector, i::UnitRange)
out = Vector{eltype(v)}(undef, length(i))
out .= v.data[i]
out
end

# skip MissingOptimization due to JuliaLang/julia#50171
const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE.next

# Explicitly define conversion from _sort!(v, alg, order, kw) to sort!(v, lo, hi, alg, order)
# To avoid excessively strict dispatch loop detection
function Base.Sort._sort!(v::AbstractVector, a::Union{HeapSortAlg, TimSortAlg, RadixSortAlg, CombSortAlg}, o::Base.Order.Ordering, kw)
Base.Sort.@getkw lo hi scratch
sort!(v, lo, hi, a, o)
scratch
end
else
const _FIVE_ARG_SAFE_DEFAULT_STABLE = Base.DEFAULT_STABLE
end

const Run = UnitRange{Int}

const MIN_GALLOP = 7
Expand Down Expand Up @@ -490,7 +511,7 @@ function sort!(v::AbstractVector, lo::Int, hi::Int, ::TimSortAlg, o::Ordering)
# Make a run of length minrun
count = min(minrun, hi-i+1)
run_range = i:i+count-1
sort!(v, i, i+count-1, DEFAULT_STABLE, o)
sort!(v, i, i+count-1, _FIVE_ARG_SAFE_DEFAULT_STABLE, o)
else
if !issorted(run_range)
run_range = last(run_range):first(run_range)
Expand Down
29 changes: 25 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ using StatsBase
using Random

a = rand(1:10000, 1000)
am = [rand() < .9 ? i : missing for i in a]

for alg in [TimSort, HeapSort, RadixSort, CombSort]
for alg in [TimSort, HeapSort, RadixSort, CombSort, SortingAlgorithms.TimSortAlg()]
b = sort(a, alg=alg)
@test issorted(b)
ix = sortperm(a, alg=alg)
b = a[ix]
@test issorted(b)
@test a[ix] == b

# legacy 3-argument calling convention
@test b == sort!(copy(a), alg, Base.Order.Forward)

b = sort(a, alg=alg, rev=true)
@test issorted(b, rev=true)
ix = sortperm(a, alg=alg, rev=true)
Expand All @@ -34,9 +38,26 @@ for alg in [TimSort, HeapSort, RadixSort, CombSort]
invpermute!(c, ix)
@test c == a

if alg != RadixSort # RadixSort does not work with Lt orderings
if alg != RadixSort # RadixSort does not work with Lt orderings or missing
c = sort(a, alg=alg, lt=(>))
@test b == c

# Issue https://github.com/JuliaData/DataFrames.jl/issues/3340
bm1 = sort(am, alg=alg)
@test issorted(bm1)
@test count(ismissing, bm1) == count(ismissing, am)

bm2 = am[sortperm(am, alg=alg)]
@test issorted(bm2)
@test count(ismissing, bm2) == count(ismissing, am)

bm3 = am[sortperm!(collect(eachindex(am)), am, alg=alg)]
@test issorted(bm3)
@test count(ismissing, bm3) == count(ismissing, am)

if alg == TimSort # Stable
@test all(bm1 .=== bm2 .=== bm3)
end
end

c = sort(a, alg=alg, by=x->1/x)
Expand Down Expand Up @@ -103,8 +124,8 @@ for n in [0:10..., 100, 101, 1000, 1001]
# test float sorting with NaNs
s = sort(v, alg=alg, order=ord)
@test issorted(s, order=ord)
# This tests that NaNs (which compare equivalent) are treated stably

# This tests that NaNs (which compare equivalent) are treated stably
# even when the underlying algorithm is unstable. That it happens to
# pass is not a part of the public API:
@test reinterpret(UInt64, v[map(isnan, v)]) == reinterpret(UInt64, s[map(isnan, s)])
Expand Down

0 comments on commit ff693e2

Please sign in to comment.