Skip to content

Commit

Permalink
Fix-ups for sorting workspace/buffer (#45330) (#45570)
Browse files Browse the repository at this point in the history
* Fix and test sort!(OffsetArray(rand(200), -10))

* Convert to 1-based indexing rather than generalize to arbitrary indexing

* avoid overhead of views where reasonable

* style

* handle edge cases better, making the workspace function unhelpful. Also minor style changes and fixups from #45596 and local review.

* move comments in tests for discoverability

Co-authored-by: Lilith Hafner <Lilith.Hafner@gmail.com>
  • Loading branch information
LilithHafner and Lilith Hafner committed Jun 16, 2022
1 parent fa2f304 commit 6e79796
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 28 deletions.
42 changes: 26 additions & 16 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Sort
import ..@__MODULE__, ..parentmodule
const Base = parentmodule(@__MODULE__)
using .Base.Order
using .Base: copymutable, LinearIndices, length, (:), iterate, elsize,
using .Base: copymutable, LinearIndices, length, (:), iterate, OneTo,
eachindex, axes, first, last, similar, zip, OrdinalRange, firstindex, lastindex,
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
Expand Down Expand Up @@ -605,7 +605,10 @@ function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::MergeSortAlg,
hi-lo <= SMALL_THRESHOLD && return sort!(v, lo, hi, SMALL_ALGORITHM, o)

m = midpoint(lo, hi)
t = workspace(v, t0, m-lo+1)

t = t0 === nothing ? similar(v, m-lo+1) : t0
length(t) < m-lo+1 && resize!(t, m-lo+1)
Base.require_one_based_indexing(t)

sort!(v, lo, m, a, o, t)
sort!(v, m+1, hi, a, o, t)
Expand Down Expand Up @@ -683,7 +686,7 @@ function radix_sort!(v::AbstractVector{U}, lo::Integer, hi::Integer, bits::Unsig
t::AbstractVector{U}, chunk_size=radix_chunk_size_heuristic(lo, hi, bits)) where U <: Unsigned
# bits is unsigned for performance reasons.
mask = UInt(1) << chunk_size - 1
counts = Vector{UInt}(undef, mask+2)
counts = Vector{Int}(undef, mask+2)

@inbounds for shift in 0:chunk_size:bits-1

Expand Down Expand Up @@ -732,6 +735,7 @@ end

# For AbstractVector{Bool}, counting sort is always best.
# This is an implementation of counting sort specialized for Bools.
# Accepts unused workspace to avoid method ambiguity.
function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort, o::Ordering,
t::Union{AbstractVector{B}, Nothing}=nothing) where {B <: Bool}
first = lt(o, false, true) ? false : lt(o, true, false) ? true : return v
Expand All @@ -746,10 +750,6 @@ function sort!(v::AbstractVector{B}, lo::Integer, hi::Integer, a::AdaptiveSort,
v
end

workspace(v::AbstractVector, ::Nothing, len::Integer) = similar(v, len)
function workspace(v::AbstractVector{T}, t::AbstractVector{T}, len::Integer) where T
length(t) < len ? resize!(t, len) : t
end
maybe_unsigned(x::Integer) = x # this is necessary to avoid calling unsigned on BigInt
maybe_unsigned(x::BitSigned) = unsigned(x)
function _extrema(v::AbstractVector, lo::Integer, hi::Integer, o::Ordering)
Expand Down Expand Up @@ -856,8 +856,18 @@ function sort!(v::AbstractVector{T}, lo::Integer, hi::Integer, a::AdaptiveSort,
u[i] -= u_min
end

u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, workspace(v, t, hi)))
uint_unmap!(v, u2, lo, hi, o, u_min)
if t !== nothing && checkbounds(Bool, t, lo:hi) # Fully preallocated and aligned workspace
u2 = radix_sort!(u, lo, hi, bits, reinterpret(U, t))
uint_unmap!(v, u2, lo, hi, o, u_min)
elseif t !== nothing && (applicable(resize!, t) || length(t) >= hi-lo+1) # Viable workspace
length(t) >= hi-lo+1 || resize!(t, hi-lo+1)
t1 = axes(t, 1) isa OneTo ? t : view(t, firstindex(t):lastindex(t))
u2 = radix_sort!(view(u, lo:hi), 1, hi-lo+1, bits, reinterpret(U, t1))
uint_unmap!(view(v, lo:hi), u2, 1, hi-lo+1, o, u_min)
else # No viable workspace
u2 = radix_sort!(u, lo, hi, bits, similar(u))
uint_unmap!(v, u2, lo, hi, o, u_min)
end
end

## generic sorting methods ##
Expand Down Expand Up @@ -1113,7 +1123,7 @@ function sortperm(v::AbstractVector;
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
workspace::Union{AbstractVector, Nothing}=nothing)
workspace::Union{AbstractVector{<:Integer}, Nothing}=nothing)
ordr = ord(lt,by,rev,order)
if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer
n = length(v)
Expand Down Expand Up @@ -1235,7 +1245,7 @@ function sort(A::AbstractArray{T};
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=similar(A, 0)) where T
workspace::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
dim = dims
order = ord(lt,by,rev,order)
n = length(axes(A, dim))
Expand Down Expand Up @@ -1296,7 +1306,7 @@ function sort!(A::AbstractArray{T};
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T
workspace::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
ordr = ord(lt, by, rev, order)
nd = ndims(A)
k = dims
Expand Down Expand Up @@ -1523,8 +1533,8 @@ issignleft(o::ForwardOrdering, x::Floats) = lt(o, x, zero(x))
issignleft(o::ReverseOrdering, x::Floats) = lt(o, x, -zero(x))
issignleft(o::Perm, i::Integer) = issignleft(o.order, o.data[i])

function fpsort!(v::AbstractVector, a::Algorithm, o::Ordering,
t::Union{AbstractVector, Nothing}=nothing)
function fpsort!(v::AbstractVector{T}, a::Algorithm, o::Ordering,
t::Union{AbstractVector{T}, Nothing}=nothing) where T
# fpsort!'s optimizations speed up comparisons, of which there are O(nlogn).
# The overhead is O(n). For n < 10, it's not worth it.
length(v) < 10 && return sort!(v, firstindex(v), lastindex(v), SMALL_ALGORITHM, o, t)
Expand All @@ -1550,8 +1560,8 @@ function sort!(v::FPSortable, a::Algorithm, o::DirectOrdering,
t::Union{FPSortable, Nothing}=nothing)
fpsort!(v, a, o, t)
end
function sort!(v::AbstractVector{<:Union{Signed, Unsigned}}, a::Algorithm,
o::Perm{<:DirectOrdering,<:FPSortable}, t::Union{AbstractVector, Nothing}=nothing)
function sort!(v::AbstractVector{T}, a::Algorithm, o::Perm{<:DirectOrdering,<:FPSortable},
t::Union{AbstractVector{T}, Nothing}=nothing) where T <: Union{Signed, Unsigned}
fpsort!(v, a, o, t)
end

Expand Down
24 changes: 12 additions & 12 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,16 @@ end
@test issorted(a)
end

@testset "sort!(::OffsetVector)" begin
for length in vcat(0:5, [10, 300, 500, 1000])
for offset in [-100000, -10, -1, 0, 1, 17, 1729]
x = OffsetVector(rand(length), offset)
sort!(x)
@test issorted(x)
end
end
end

@testset "sort!(::OffsetMatrix; dims)" begin
x = OffsetMatrix(rand(5,5), 5, -5)
sort!(x; dims=1)
Expand Down Expand Up @@ -654,17 +664,6 @@ end
end
end

@testset "workspace()" begin
for v in [[1, 2, 3], [0.0]]
for t0 in vcat([nothing], [similar(v,i) for i in 1:5]), len in 0:5
t = Base.Sort.workspace(v, t0, len)
@test eltype(t) == eltype(v)
@test length(t) >= len
@test firstindex(t) == 1
end
end
end

@testset "sort(x; workspace=w) " begin
for n in [1,10,100,1000]
v = rand(n)
Expand All @@ -681,7 +680,7 @@ end
end
end


# This testset is at the end of the file because it is slow.
@testset "searchsorted" begin
numTypes = [ Int8, Int16, Int32, Int64, Int128,
UInt8, UInt16, UInt32, UInt64, UInt128,
Expand Down Expand Up @@ -842,5 +841,6 @@ end
end
end
end
# The "searchsorted" testset is at the end of the file because it is slow.

end

0 comments on commit 6e79796

Please sign in to comment.