Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sortperm with dims arg for AbstractArray, fixes #16273 #45211

Merged
merged 19 commits into from Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 62 additions & 29 deletions base/sort.jl
Expand Up @@ -11,7 +11,7 @@ using .Base: copymutable, LinearIndices, length, (:), iterate, OneTo,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
extrema, sub_with_overflow, add_with_overflow, oneunit, div, getindex, setindex!,
length, resize!, fill, Missing, require_one_based_indexing, keytype, UnitRange,
min, max, reinterpret, signed, unsigned, Signed, Unsigned, typemin, xor, Type, BitSigned
min, max, reinterpret, signed, unsigned, Signed, Unsigned, typemin, xor, Type, BitSigned, Val

using .Base: >>>, !==

Expand Down Expand Up @@ -1091,14 +1091,16 @@ end
## sortperm: the permutation to sort an array ##

"""
sortperm(v; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward)
sortperm(A; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward, [dims::Integer])

Return a permutation vector `I` that puts `v[I]` in sorted order. The order is specified
Return a permutation vector or array `I` that puts `A[I]` in sorted order along the given dimension.
If `A` has more than one dimension, then the `dims` keyword argument must be specified. The order is specified
using the same keywords as [`sort!`](@ref). The permutation is guaranteed to be stable even
if the sorting algorithm is unstable, meaning that indices of equal elements appear in
ascending order.

See also [`sortperm!`](@ref), [`partialsortperm`](@ref), [`invperm`](@ref), [`indexin`](@ref).
To sort slices of an array, refer to [`sortslices`](@ref).

# Examples
```jldoctest
Expand All @@ -1115,37 +1117,53 @@ julia> v[p]
1
2
3

julia> A = [8 7; 5 6]
2×2 Matrix{Int64}:
8 7
5 6

julia> sortperm(A, dims = 1)
2×2 Matrix{Int64}:
2 4
1 3

julia> sortperm(A, dims = 2)
2×2 Matrix{Int64}:
3 1
2 4
```
"""
function sortperm(v::AbstractVector;
function sortperm(A::AbstractArray;
alg::Algorithm=DEFAULT_UNSTABLE,
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
workspace::Union{AbstractVector{<:Integer}, Nothing}=nothing)
workspace::Union{AbstractVector{<:Integer}, Nothing}=nothing,
dims...) #to optionally specify dims argument
ordr = ord(lt,by,rev,order)
if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer
n = length(v)
if ordr === Forward && isa(A,Vector) && eltype(A)<:Integer
n = length(A)
if n > 1
min, max = extrema(v)
min, max = extrema(A)
(diff, o1) = sub_with_overflow(max, min)
(rangelen, o2) = add_with_overflow(diff, oneunit(diff))
if !o1 && !o2 && rangelen < div(n,2)
return sortperm_int_range(v, rangelen, min)
return sortperm_int_range(A, rangelen, min)
end
end
end
p = copymutable(eachindex(v))
sort!(p, alg, Perm(ordr,v), workspace)
ix = copymutable(LinearIndices(A))
sort!(ix; alg, order = Perm(ordr, vec(A)), workspace, dims...)
end


"""
sortperm!(ix, v; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward, initialized::Bool=false)
sortperm!(ix, A; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward, initialized::Bool=false, [dims::Integer])

Like [`sortperm`](@ref), but accepts a preallocated index vector `ix`. If `initialized` is `false`
(the default), `ix` is initialized to contain the values `1:length(v)`.
Like [`sortperm`](@ref), but accepts a preallocated index vector or array `ix` with the same `axes` as `A`. If `initialized` is `false`
(the default), `ix` is initialized to contain the values `LinearIndices(A)`.

# Examples
```jldoctest
Expand All @@ -1162,25 +1180,36 @@ julia> v[p]
1
2
3

julia> A = [8 7; 5 6]; p = zeros(Int,2, 2);

julia> sortperm!(p, A; dims=1); p
2×2 Matrix{Int64}:
2 4
1 3

julia> sortperm!(p, A; dims=2); p
2×2 Matrix{Int64}:
3 1
2 4
```
"""
function sortperm!(x::AbstractVector{T}, v::AbstractVector;
function sortperm!(ix::AbstractArray{T}, A::AbstractArray;
alg::Algorithm=DEFAULT_UNSTABLE,
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
initialized::Bool=false,
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T <: Integer
if axes(x,1) != axes(v,1)
throw(ArgumentError("index vector must have the same length/indices as the source vector, $(axes(x,1)) != $(axes(v,1))"))
end
workspace::Union{AbstractVector{T}, Nothing}=nothing,
dims...) where T <: Integer #to optionally specify dims argument
(typeof(A) <: AbstractVector) == (:dims in keys(dims)) && throw(ArgumentError("Dims argument incorrect for type $(typeof(A))"))
axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))

if !initialized
@inbounds for i in eachindex(v)
x[i] = i
end
ix .= LinearIndices(A)
end
sort!(x, alg, Perm(ord(lt,by,rev,order),v), workspace)
sort!(ix; alg, order = Perm(ord(lt, by, rev, order), vec(A)), workspace, dims...)
end

# sortperm for vectors of few unique integers
Expand Down Expand Up @@ -1307,16 +1336,20 @@ function sort!(A::AbstractArray{T};
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
workspace::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
ordr = ord(lt, by, rev, order)
_sort!(A, Val(dims), alg, ord(lt, by, rev, order), workspace)
end
function _sort!(A::AbstractArray{T}, ::Val{K},
alg::Algorithm,
order::Ordering,
workspace::Union{AbstractVector{T}, Nothing}) where {K,T}
nd = ndims(A)
k = dims

1 <= k <= nd || throw(ArgumentError("dimension out of range"))
1 <= K <= nd || throw(ArgumentError("dimension out of range"))

remdims = ntuple(i -> i == k ? 1 : axes(A, i), nd)
remdims = ntuple(i -> i == K ? 1 : axes(A, i), nd)
for idx in CartesianIndices(remdims)
Av = view(A, ntuple(i -> i == k ? Colon() : idx[i], nd)...)
sort!(Av, alg, ordr, workspace)
Av = view(A, ntuple(i -> i == K ? Colon() : idx[i], nd)...)
sort!(Av, alg, order, workspace)
pcjentsch marked this conversation as resolved.
Show resolved Hide resolved
end
A
end
Expand Down
22 changes: 19 additions & 3 deletions test/sorting.jl
Expand Up @@ -47,9 +47,25 @@ end
@test r == [3,1,2]
@test r === s
end
@test_throws ArgumentError sortperm!(view([1,2,3,4], 1:4), [2,3,1])
@test sortperm(OffsetVector([8.0,-2.0,0.5], -4)) == OffsetVector([-2, -1, -3], -4)
@test sortperm!(Int32[1,2], [2.0, 1.0]) == Int32[2, 1]
@test_throws ArgumentError sortperm!(view([1, 2, 3, 4], 1:4), [2, 3, 1])
@test sortperm(OffsetVector([8.0, -2.0, 0.5], -4)) == OffsetVector([-2, -1, -3], -4)
@test sortperm!(Int32[1, 2], [2.0, 1.0]) == Int32[2, 1]
@test_throws ArgumentError sortperm!(Int32[1, 2], [2.0, 1.0]; dims=1)
let A = rand(4, 4, 4)
for dims = 1:3
perm = sortperm(A; dims)
sorted = sort(A; dims)
@test A[perm] == sorted

perm_idx = similar(Array{Int}, axes(A))
sortperm!(perm_idx, A; dims)
@test perm_idx == perm
end
end
@test_throws ArgumentError sortperm!(zeros(Int, 3, 3), rand(3, 3);)
@test_throws ArgumentError sortperm!(zeros(Int, 3, 3), rand(3, 3); dims=3)
@test_throws ArgumentError sortperm!(zeros(Int, 3, 4), rand(4, 4); dims=1)
@test_throws ArgumentError sortperm!(OffsetArray(zeros(Int, 4, 4), -4:-1, 1:4), rand(4, 4); dims=1)
end

@testset "misc sorting" begin
Expand Down