Skip to content

Commit

Permalink
Merge ebb9caf into c361069
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Sep 18, 2020
2 parents c361069 + ebb9caf commit 49e822c
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 64 deletions.
114 changes: 68 additions & 46 deletions src/extras.jl
@@ -1,26 +1,7 @@
using Statistics

function fill_refs!(refs::AbstractArray, X::AbstractArray,
breaks::AbstractVector, extend::Bool, allowmissing::Bool)
n = length(breaks)
lower = first(breaks)
upper = last(breaks)

@inbounds for i in eachindex(X)
x = X[i]

if extend && x == upper
refs[i] = n-1
elseif !extend && !(lower <= x < upper)
throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: adapt them manually, or pass extend=true"))
else
refs[i] = searchsortedlast(breaks, x)
end
end
end

function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},
breaks::AbstractVector, extend::Bool, allowmissing::Bool)
breaks::AbstractVector, extend::Union{Bool, Missing})
n = length(breaks)
lower = first(breaks)
upper = last(breaks)
Expand All @@ -30,10 +11,12 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray{>: Missing},

if ismissing(x)
refs[i] = 0
elseif extend && x == upper
elseif extend === true && x == upper
refs[i] = n-1
elseif !extend && !(lower <= x < upper)
allowmissing || throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: adapt them manually, or pass extend=true or allowmissing=true"))
elseif extend !== true && !(lower <= x < upper)
extend === missing ||
throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: " *
"adapt them manually, or pass extend=true or extend=missing"))
refs[i] = 0
else
refs[i] = searchsortedlast(breaks, x)
Expand All @@ -52,25 +35,26 @@ default_formatter(from, to, i; leftclosed, rightclosed) =
@doc raw"""
cut(x::AbstractArray, breaks::AbstractVector;
labels::Union{AbstractVector{<:AbstractString},Function},
extend::Bool=false, allowmissing::Bool=false, allowempty::Bool=false)
extend::Union{Bool,Missing}=false, allowempty::Bool=false)
Cut a numeric array into intervals and return an ordered `CategoricalArray` indicating
Cut a numeric array into intervals at values `breaks`
and return an ordered `CategoricalArray` indicating
the interval into which each entry falls. Intervals are of the form `[lower, upper)`,
i.e. the lower bound is included and the upper bound is excluded.
i.e. the lower bound is included and the upper bound is excluded, except for the last
interval which is closed on both ends, i.e. `[lower, upper]`.
If `x` accepts missing values (i.e. `eltype(x) >: Missing`) the returned array will
also accept them.
# Keyword arguments
* `extend::Bool=false`: when `false`, an error is raised if some values in `x` fall
outside of the breaks; when `true`, breaks are automatically added to include all
values in `x`, and the upper bound is included in the last interval.
* `extend::Union{Bool, Missing}=false`: when `false`, an error is raised if some values
in `x` fall outside of the breaks; when `true`, breaks are automatically added to include
all values in `x`, and the upper bound is included in the last interval; when `missing`,
values outside of the breaks generate `missing` entries.
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
* `allowmissing::Bool=true`: when `true`, values outside of breaks result in missing values.
only supported when `x` accepts missing values.
* `allowempty::Bool=false`: when `false`, an error is raised if some breaks appear
multiple times, generating empty intervals; when `true`, duplicate breaks are allowed
and the intervals they generate are kept as unused levels
Expand Down Expand Up @@ -116,17 +100,30 @@ julia> cut(-1:0.5:1, 3, labels=fmt)
"grp 3 (0.333333//1.0)"
```
"""
function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
extend::Bool=false,
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter,
allowmissing::Bool=false,
allow_missing::Union{Bool, Nothing}=nothing,
allowempty::Bool=false) where {T, N}
@inline function cut(x::AbstractArray, breaks::AbstractVector;
extend::Union{Bool, Missing}=false,
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter,
allowmissing::Union{Bool, Nothing}=nothing,
allow_missing::Union{Bool, Nothing}=nothing,
allowempty::Bool=false)
if allow_missing !== nothing
Base.depwarn("allow_missing argument is deprecated, use allowmissing instead",
:cut!)
allowmissing = allow_missing
Base.depwarn("allow_missing argument is deprecated, use extend=missing instead",
:cut)
extend = missing
end
if allowmissing !== nothing
Base.depwarn("allowmissing argument is deprecated, use extend=missing instead",
:cut)
extend = missing
end
return _cut(x, breaks, extend, labels, allowempty)
end

# Separate function for inferability (thanks to inlining of cut)
function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
extend::Union{Bool, Missing},
labels::Union{AbstractVector{<:AbstractString},Function},
allowempty::Bool=false) where {T, N}
if !allowempty && !allunique(breaks)
throw(ArgumentError("all breaks must be unique unless `allowempty=true`"))
end
Expand All @@ -135,19 +132,38 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
breaks = sort(breaks)
end

if extend
min_x, max_x = extrema(x)
if extend === true
xnm = T >: Missing ? skipmissing(x) : x
length(breaks) >= 1 || throw(ArgumentError("at least one break must be provided"))
local min_x, max_x
try
min_x, max_x = extrema(xnm)
catch err
if T >: Missing && all(ismissing, xnm)
if length(breaks) < 2
throw(ArgumentError("could not extend breaks as all values are missing: " *
"please specify at least two breaks manually"))
else
min_x, max_x = missing, missing
end
else
rethrow(err)
end
end
if !ismissing(min_x) && breaks[1] > min_x
breaks = [min_x; breaks]
end
if !ismissing(max_x) && breaks[end] < max_x
breaks = [breaks; max_x]
end
length(breaks) > 1 ||
throw(ArgumentError("could not extend breaks as all values are equal: " *
"please specify at least two breaks manually"))
end

refs = Array{DefaultRefType, N}(undef, size(x))
try
fill_refs!(refs, x, breaks, extend, allowmissing)
fill_refs!(refs, x, breaks, extend)
catch err
# So that the error appears to come from cut() itself,
# since it refers to its keyword arguments
Expand All @@ -159,6 +175,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end

n = length(breaks)
n >= 2 || throw(ArgumentError("at least two breaks must be provided when extend is not true"))
if labels isa Function
from = map(x -> sprint(show, x, context=:compact=>true), breaks[1:n-1])
to = map(x -> sprint(show, x, context=:compact=>true), breaks[2:n])
Expand All @@ -168,7 +185,8 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
leftclosed=breaks[i] != breaks[i+1], rightclosed=false)
end
levs[end] = labels(from[end], to[end], n-1,
leftclosed=breaks[end-1] != breaks[end], rightclosed=extend)
leftclosed=breaks[end-1] != breaks[end],
rightclosed=coalesce(extend, false))
else
length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
# Levels must have element type String for type stability of the result
Expand All @@ -184,7 +202,7 @@ function cut(x::AbstractArray{T, N}, breaks::AbstractVector;
end

pool = CategoricalPool(levs, true)
S = T >: Missing ? Union{String, Missing} : String
S = T >: Missing || extend isa Missing ? Union{String, Missing} : String
CategoricalArray{S, N}(refs, pool)
end

Expand All @@ -203,6 +221,9 @@ quantile_formatter(from, to, i; leftclosed, rightclosed) =
Cut a numeric array into `ngroups` quantiles, determined using `quantile`.
If `x` contains `missing` values, they are automatically skipped when computing
quantiles.
# Keyword arguments
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
Expand All @@ -216,7 +237,8 @@ Cut a numeric array into `ngroups` quantiles, determined using `quantile`.
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function}=quantile_formatter,
allowempty::Bool=false)
breaks = Statistics.quantile(x, (1:ngroups-1)/ngroups)
xnm = eltype(x) >: Missing ? skipmissing(x) : x
breaks = Statistics.quantile(xnm, (1:ngroups-1)/ngroups)
if !allowempty && !allunique(breaks)
n = length(unique(breaks)) - 1
throw(ArgumentError("cannot compute $ngroups quantiles: `quantile` " *
Expand Down
50 changes: 32 additions & 18 deletions test/15_extras.jl
Expand Up @@ -12,30 +12,21 @@ const ≅ = isequal
@test levels(x) == ["[1, 3)", "[3, 6)"]

err = @test_throws ArgumentError cut(Vector{Union{T, Int}}([2, 3, 5]), [3, 6])
if T === Missing
@test err.value.msg == "value 2 (at index 1) does not fall inside the breaks: adapt them manually, or pass extend=true or allowmissing=true"
else
@test err.value.msg == "value 2 (at index 1) does not fall inside the breaks: adapt them manually, or pass extend=true"
end
@test err.value.msg == "value 2 (at index 1) does not fall inside the breaks: adapt them manually, or pass extend=true or extend=missing"


err = @test_throws ArgumentError cut(Vector{Union{T, Int}}([2, 3, 5]), [2, 5])
if T === Missing
@test err.value.msg == "value 5 (at index 3) does not fall inside the breaks: adapt them manually, or pass extend=true or allowmissing=true"
else
@test err.value.msg == "value 5 (at index 3) does not fall inside the breaks: adapt them manually, or pass extend=true"
end
@test err.value.msg == "value 5 (at index 3) does not fall inside the breaks: adapt them manually, or pass extend=true or extend=missing"

if T === Missing
x = @inferred cut(Vector{Union{T, Int}}([2, 3, 5]), [2, 5], allowmissing=true)
@test x ["[2, 5)", "[2, 5)", missing]
@test isa(x, CategoricalVector{Union{String, T}})
@test isordered(x)
@test levels(x) == ["[2, 5)"]
x = @inferred cut(Vector{Union{T, Int}}([2, 3, 5]), [2, 5], extend=missing)
else
err = @test_throws ArgumentError cut(Vector{Int}([2, 3, 5]), [2, 5], allowmissing=true)
@test err.value.msg == "value 5 (at index 3) does not fall inside the breaks: adapt them manually, or pass extend=true"
x = cut(Vector{Union{T, Int}}([2, 3, 5]), [2, 5], extend=missing)
end
@test x ["[2, 5)", "[2, 5)", missing]
@test isa(x, CategoricalVector{Union{String, Missing}})
@test isordered(x)
@test levels(x) == ["[2, 5)"]

x = @inferred cut(Vector{Union{T, Int}}([2, 3, 5]), [3, 6], extend=true)
@test x == ["[2, 3)", "[3, 6]", "[3, 6]"]
Expand Down Expand Up @@ -106,7 +97,6 @@ end
@test all(ismissing, y)
end

# TODO: test on arrays supporting missing values once a quantile() method is provided for them
@testset "cut([5, 4, 3, 2], 2)" begin
x = @inferred cut([5, 4, 3, 2], 2)
@test x == ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", "Q1: [2.0, 3.5)"]
Expand All @@ -115,6 +105,14 @@ end
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
end

@testset "cut(x, n) with missing values" begin
x = @inferred cut([5, 4, 3, missing, 2], 2)
@test x ["Q2: [3.5, 5.0]", "Q2: [3.5, 5.0]", "Q1: [2.0, 3.5)", missing, "Q1: [2.0, 3.5)"]
@test isa(x, CategoricalArray)
@test isordered(x)
@test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"]
end

@testset "cut with formatter function" begin
my_formatter(from, to, i; leftclosed, rightclosed) = "$i: $from -- $to"

Expand Down Expand Up @@ -182,4 +180,20 @@ end
labels=["1", "2", "3", "2", "4"])
end

@testset "cut with extend=true" begin
err = @test_throws ArgumentError cut([1, 1], [], extend=true)
@test err.value.msg == "at least one break must be provided"

err = @test_throws ArgumentError cut([1, 1], [1], extend=true)
@test err.value.msg == "could not extend breaks as all values are equal: please specify at least two breaks manually"

err = @test_throws ArgumentError cut([1, 1, missing], [1], extend=true)
@test err.value.msg == "could not extend breaks as all values are equal: please specify at least two breaks manually"

err = @test_throws ArgumentError cut([missing], [1], extend=true)
@test err.value.msg == "could not extend breaks as all values are missing: please specify at least two breaks manually"

@test cut([missing], [1, 2], extend=true) [missing]
end

end

0 comments on commit 49e822c

Please sign in to comment.