Skip to content
This repository has been archived by the owner on May 4, 2019. It is now read-only.

Commit

Permalink
Sort NAs to last position for PooledDataArrays as well
Browse files Browse the repository at this point in the history
  • Loading branch information
simonster committed Jul 10, 2014
1 parent 1359a53 commit 334229c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
17 changes: 12 additions & 5 deletions src/grouping.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function groupsort_indexer(x::AbstractVector, ngroups::Integer)
function groupsort_indexer(x::AbstractVector, ngroups::Integer, nalast::Bool=false)
# translated from Wes McKinney's groupsort_indexer in pandas (file: src/groupby.pyx).

# count group sizes, location 0 for NA
Expand All @@ -11,10 +11,17 @@ function groupsort_indexer(x::AbstractVector, ngroups::Integer)

# mark the start of each contiguous group of like-indexed data
where = fill(1, ngroups + 1)
for i = 2:ngroups+1
where[i] = where[i - 1] + counts[i - 1]
if nalast
for i = 3:ngroups+1
where[i] = where[i - 1] + counts[i - 1]
end
where[1] = where[end] + counts[end]
else
for i = 2:ngroups+1
where[i] = where[i - 1] + counts[i - 1]
end
end

# this is our indexer
result = fill(0, n)
for i = 1:n
Expand All @@ -25,4 +32,4 @@ function groupsort_indexer(x::AbstractVector, ngroups::Integer)
result, where, counts
end

# Convenience methods for pooled vectors: run the generic groupsort_indexer on
# the vector's reference codes (`pv.refs`), using one group per pool entry
# (`length(pv.pool)`).
# NOTE(review): this listing comes from a rendered diff, so the two lines below
# appear to be the before/after forms of the same definition; the second one
# additionally forwards `nalast` (default `false`) to the generic method.
# Confirm against the actual file which form is current.
groupsort_indexer(pv::PooledDataVector) = groupsort_indexer(pv.refs, length(pv.pool))
groupsort_indexer(pv::PooledDataVector, nalast::Bool=false) = groupsort_indexer(pv.refs, length(pv.pool), nalast)
2 changes: 1 addition & 1 deletion src/pooleddataarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ end
Base.sortperm(pda::PooledDataArray) = groupsort_indexer(pda)[1]
function Base.sortperm(pda::PooledDataArray)
if issorted(pda.pool)
return groupsort_indexer(pda)[1]
return groupsort_indexer(pda, true)[1]
else
return sortperm(reorder!(copy(pda)))
end
Expand Down
23 changes: 18 additions & 5 deletions test/sort.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
# Tests for `sort` / `sortperm` on DataArray and PooledDataArray.
# The checks below expect NA entries to come after all non-NA values in the
# default order, and before them when `rev=true`.
module TestSort
using DataArrays, Base.Test

# Small fixture with two NAs and a repeated value (3).
dv1 = @data([9, 1, 8, NA, 3, 3, 7, NA])
# Same data multiplied by 1.0 (presumably a floating-point copy of dv1).
dv2 = 1.0 * dv1
# NOTE(review): dv3 is not referenced by any test visible in this module.
dv3 = DataArray([1:8])
# Pooled version of the same fixture.
pdv1 = convert(PooledDataArray, dv1)

# The sort permutation must agree across the two element types and between
# the DataArray and PooledDataArray representations of the same data.
@test sortperm(dv1) == sortperm(dv2)
@test sortperm(dv1) == sortperm(pdv1)
# Sorted results must compare equal after conversion to DataArray.
@test isequal(sort(dv1), convert(DataArray, sort(dv1)))
@test isequal(sort(dv1), convert(DataArray, sort(pdv1)))

# Randomized check, run once per element type.
for T in (Float64, BigFloat)
n = 1000
# Random Bool vector, passed below as the second argument of the array
# constructors (presumably the NA mask); nna counts its `true` entries.
na = randbool(n)
nna = sum(na)
# Backing array; only the positions where `na` is false are assigned.
a = Array(T, n)
ra = randn(n-nna)
a[!na] = ra
# NOTE(review): this listing comes from a rendered diff. The assignment to
# `da` and the four @test lines right below appear to be the pre-commit form of
# the checks that the `for da in ...` loop further down repeats for both array
# types. Confirm against the actual file.
da = DataArray(a, na)
# Expected result: sorted non-NA values followed by `DataArray(T, nna)`
# (presumably a block of nna NAs); with rev=true that block comes first.
@test isequal(sort(da), [DataArray(sort(dropna(da))), DataArray(T, nna)])
@test isequal(da[sortperm(da)], [DataArray(sort(dropna(da))), DataArray(T, nna)])
@test isequal(sort(da, rev=true), [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))])
@test isequal(da[sortperm(da, rev=true)], [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))])
# Same expectations, checked for both the plain and the pooled representation.
for da in (DataArray(a, na), PooledDataArray(a, na))
@test isequal(sort(da), [DataArray(sort(dropna(da))), DataArray(T, nna)])
@test isequal(da[sortperm(da)], [DataArray(sort(dropna(da))), DataArray(T, nna)])
# The rev=true checks run only for DataArray; the reason PooledDataArray is
# excluded is not visible here.
if isa(da, DataArray)
@test isequal(sort(da, rev=true), [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))])
@test isequal(da[sortperm(da, rev=true)], [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))])
end
end
end
end

0 comments on commit 334229c

Please sign in to comment.