From 334229caefca7387310719c5910919f314402d0a Mon Sep 17 00:00:00 2001 From: Simon Kornblith Date: Thu, 10 Jul 2014 15:07:15 -0400 Subject: [PATCH] Sort NAs to last position for PooledDataArrays as well --- src/grouping.jl | 17 ++++++++++++----- src/pooleddataarray.jl | 2 +- test/sort.jl | 23 ++++++++++++++++++----- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/grouping.jl b/src/grouping.jl index c5199cb..3d8875d 100644 --- a/src/grouping.jl +++ b/src/grouping.jl @@ -1,4 +1,4 @@ -function groupsort_indexer(x::AbstractVector, ngroups::Integer) +function groupsort_indexer(x::AbstractVector, ngroups::Integer, nalast::Bool=false) # translated from Wes McKinney's groupsort_indexer in pandas (file: src/groupby.pyx). # count group sizes, location 0 for NA @@ -11,10 +11,17 @@ function groupsort_indexer(x::AbstractVector, ngroups::Integer) # mark the start of each contiguous group of like-indexed data where = fill(1, ngroups + 1) - for i = 2:ngroups+1 - where[i] = where[i - 1] + counts[i - 1] + if nalast + for i = 3:ngroups+1 + where[i] = where[i - 1] + counts[i - 1] + end + where[1] = where[end] + counts[end] + else + for i = 2:ngroups+1 + where[i] = where[i - 1] + counts[i - 1] + end end - + # this is our indexer result = fill(0, n) for i = 1:n @@ -25,4 +32,4 @@ function groupsort_indexer(x::AbstractVector, ngroups::Integer) result, where, counts end -groupsort_indexer(pv::PooledDataVector) = groupsort_indexer(pv.refs, length(pv.pool)) +groupsort_indexer(pv::PooledDataVector, nalast::Bool=false) = groupsort_indexer(pv.refs, length(pv.pool), nalast) diff --git a/src/pooleddataarray.jl b/src/pooleddataarray.jl index c263558..203537e 100644 --- a/src/pooleddataarray.jl +++ b/src/pooleddataarray.jl @@ -777,7 +777,7 @@ end Base.sortperm(pda::PooledDataArray) = groupsort_indexer(pda)[1] function Base.sortperm(pda::PooledDataArray) if issorted(pda.pool) - return groupsort_indexer(pda)[1] + return groupsort_indexer(pda, true)[1] else return sortperm(reorder!(copy(pda))) end diff --git a/test/sort.jl b/test/sort.jl index 4fa01f7..842bd64 100644 --- a/test/sort.jl +++ b/test/sort.jl @@ -1,6 +1,16 @@ module TestSort using DataArrays, Base.Test +dv1 = @data([9, 1, 8, NA, 3, 3, 7, NA]) +dv2 = 1.0 * dv1 +dv3 = DataArray([1:8]) +pdv1 = convert(PooledDataArray, dv1) + +@test sortperm(dv1) == sortperm(dv2) +@test sortperm(dv1) == sortperm(pdv1) +@test isequal(sort(dv1), convert(DataArray, sort(dv1))) +@test isequal(sort(dv1), convert(DataArray, sort(pdv1))) + for T in (Float64, BigFloat) n = 1000 na = randbool(n) @@ -8,10 +18,13 @@ for T in (Float64, BigFloat) a = Array(T, n) ra = randn(n-nna) a[!na] = ra - da = DataArray(a, na) - @test isequal(sort(da), [DataArray(sort(dropna(da))), DataArray(T, nna)]) - @test isequal(da[sortperm(da)], [DataArray(sort(dropna(da))), DataArray(T, nna)]) - @test isequal(sort(da, rev=true), [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))]) - @test isequal(da[sortperm(da, rev=true)], [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))]) + for da in (DataArray(a, na), PooledDataArray(a, na)) + @test isequal(sort(da), [DataArray(sort(dropna(da))), DataArray(T, nna)]) + @test isequal(da[sortperm(da)], [DataArray(sort(dropna(da))), DataArray(T, nna)]) + if isa(da, DataArray) + @test isequal(sort(da, rev=true), [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))]) + @test isequal(da[sortperm(da, rev=true)], [DataArray(T, nna), DataArray(sort(dropna(da), rev=true))]) + end + end end end \ No newline at end of file