Skip to content

Commit

Permalink
Add in() methods
Browse files Browse the repository at this point in the history
Fixes comparisons between categorical values/arrays and plain values/arrays,
and optimizes them between categorical values/arrays.
  • Loading branch information
nalimilan committed May 7, 2017
1 parent 8434ae6 commit e380d86
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 2 deletions.
16 changes: 15 additions & 1 deletion src/array.jl
@@ -1,7 +1,7 @@
## Common code for CategoricalArray and NullableCategoricalArray

import Base: convert, copy, copy!, getindex, setindex!, similar, size,
linearindexing, unique, vcat
linearindexing, unique, vcat, in

# Used for keyword argument default value (simply forwards to `isordered`)
function _isordered(x::AbstractCategoricalArray)
    return isordered(x)
end
Expand Down Expand Up @@ -711,3 +711,17 @@ levels!(A::CategoricalArray, newlevels::Vector) = _levels!(A, newlevels)
# Reset the levels of `A` to the values returned by `unique(A)`
function droplevels!(A::CategoricalArray)
    used = unique(A)
    return levels!(A, used)
end

# Delegate to the shared `_unique` helper, passing `Array` as its first argument
# together with the reference codes and the pool of `A`
function unique(A::CategoricalArray)
    return _unique(Array, A.refs, A.pool)
end

# Membership test for a plain value: look `x` up in the pool of `y` and, when the
# pool knows it, search the reference codes for the matching code.
function in{T, N, R}(x::Any, y::CategoricalArray{T, N, R})
    # zero(R) is the "not found" marker returned when `x` is not in the pool
    code = get(y.pool, x, zero(R))
    code == 0 && return false
    return code in y.refs
end

# Membership test for a `CategoricalValue`: compare reference codes directly when
# `x` and `y` share a pool, otherwise go through the level value of `x`.
function in{T, N, R}(x::CategoricalValue, y::CategoricalArray{T, N, R})
    # Same pool object: the level code of `x` can be searched for as is
    x.pool === y.pool && return x.level in y.refs

    # Different pools: translate the value of `x` into a code of `y`'s pool,
    # with zero(R) standing for "not found"
    code = get(y.pool, index(x.pool)[x.level], zero(R))
    code == 0 && return false
    return code in y.refs
end
20 changes: 19 additions & 1 deletion src/nullablearray.jl
@@ -1,4 +1,4 @@
import Base: convert, getindex, setindex!, similar
import Base: convert, getindex, setindex!, similar, in
using NullableArrays: NullableArray

## Constructors and converters
Expand Down Expand Up @@ -137,3 +137,21 @@ levels!(A::NullableCategoricalArray, newlevels::Vector; nullok=false) = _levels!
# Reset the levels of `A` to the result of `_unique` applied to its reference
# codes and pool
function droplevels!(A::NullableCategoricalArray)
    used = _unique(Array, A.refs, A.pool)
    return levels!(A, used)
end

# Delegate to the shared `_unique` helper, passing `NullableArray{T}` as its first
# argument together with the reference codes and the pool of `A`
function unique{T}(A::NullableCategoricalArray{T})
    return _unique(NullableArray{T}, A.refs, A.pool)
end

# Membership test for a `Nullable`-wrapped plain value: unwrap it, look it up in
# the pool of `y` and search the reference codes for the matching code.
function in{T, N, R}(x::Nullable, y::NullableCategoricalArray{T, N, R})
    # NOTE(review): `get(x)` raises when `x` is null instead of returning a
    # Bool -- confirm this is the intended behaviour for null inputs.
    value = get(x)
    # zero(R) is the "not found" marker returned when the value is not in the pool
    code = get(y.pool, value, zero(R))
    code == 0 && return false
    return code in y.refs
end

# Membership test for a `Nullable`-wrapped `CategoricalValue`: compare reference
# codes directly when the pools are shared, otherwise go through the level value.
function in{S<:CategoricalValue, T, N, R}(x::Nullable{S}, y::NullableCategoricalArray{T, N, R})
    # NOTE(review): `get(x)` raises when `x` is null instead of returning a
    # Bool -- confirm this is the intended behaviour for null inputs.
    v = get(x)

    # Same pool object: the level code can be searched for as is
    v.pool === y.pool && return v.level in y.refs

    # Different pools: translate the value into a code of `y`'s pool,
    # with zero(R) standing for "not found"
    code = get(y.pool, index(v.pool)[v.level], zero(R))
    code == 0 && return false
    return code in y.refs
end

# A value that is not wrapped in a `Nullable` is never reported as a member of a
# `NullableCategoricalArray`.
function in{T, N, R}(x::Any, y::NullableCategoricalArray{T, N, R})
    return false
end

# Same rule for a bare `CategoricalValue`.
# NOTE(review): presumably also needed to disambiguate against the generic
# `in(x::CategoricalValue, y::Any)` method -- confirm.
function in{T, N, R}(x::CategoricalValue, y::NullableCategoricalArray{T, N, R})
    return false
end
1 change: 1 addition & 0 deletions src/pool.jl
Expand Up @@ -86,6 +86,7 @@ Base.length(pool::CategoricalPool) = length(pool.index)

# Indexing a pool with an integer returns the matching entry of `pool.valindex`
function Base.getindex(pool::CategoricalPool, i::Integer)
    return pool.valindex[i]
end
# Look `level` up in `pool.invindex`; errors as the underlying indexing does when
# the level is absent
function Base.get(pool::CategoricalPool, level::Any)
    return pool.invindex[level]
end

# Look `level` up in `pool.invindex`, returning `default` when it is absent
function Base.get(pool::CategoricalPool, level::Any, default::Any)
    return get(pool.invindex, level, default)
end

function Base.get!{T, R, V}(pool::CategoricalPool{T, R, V}, level)
get!(pool.invindex, level) do
Expand Down
3 changes: 3 additions & 0 deletions src/value.jl
Expand Up @@ -62,6 +62,9 @@ end
# Compare the level value held by `x` (not the wrapper itself) with `y`
function Base.isequal(x::CategoricalValue, y::Any)
    value = index(x.pool)[x.level]
    return isequal(value, y)
end

# Symmetric case: swap the arguments and reuse the method above
function Base.isequal(x::Any, y::CategoricalValue)
    return isequal(y, x)
end

# Membership of a `CategoricalValue` in an arbitrary collection is decided on its
# level value
function Base.in(x::CategoricalValue, y::Any)
    value = index(x.pool)[x.level]
    return value in y
end

# Same behaviour for integer ranges.
# NOTE(review): presumably a separate method to avoid an ambiguity with Base's
# range-specific `in` methods -- confirm.
function Base.in{T<:Integer}(x::CategoricalValue, y::Range{T})
    value = index(x.pool)[x.level]
    return value in y
end

# Hash the level value held by `x` rather than the wrapper, matching what the
# `isequal` methods for `CategoricalValue` compare
function Base.hash(x::CategoricalValue, h::UInt)
    value = index(x.pool)[x.level]
    return hash(value, h)
end

function Base.isless{S, T}(x::CategoricalValue{S}, y::CategoricalValue{T})
Expand Down
9 changes: 9 additions & 0 deletions test/08_equality.jl
Expand Up @@ -142,4 +142,13 @@ module TestEquality
@test (ov2b == nv2a) === false
@test (ov2b == nv1b) === false
@test (ov2b == nv2b) === true

# Check in() for a CategoricalValue against plain collections.
# `nv` wraps level code 2 of a pool built from [5, 1, 3]; the assertions below
# expect it to behave like the plain value 1.
# NOTE(review): assumes the pool keeps the given order, so code 2 maps to 1 -- confirm.
pool = CategoricalPool([5, 1, 3])
nv = CategoricalValue(2, pool)

# Found in a range and in a vector that contain 1
@test (nv in 1:3) === true
@test (nv in [1, 2, 3]) === true
# Not found when 1 is excluded, even though the level code 2 is in the collection
@test (nv in 2:3) === false
@test (nv in [2, 3]) === false
end
25 changes: 25 additions & 0 deletions test/13_arraycommon.jl
Expand Up @@ -443,4 +443,29 @@ x = NullableCategoricalArray(1)
@test_throws NullException CategoricalArray(x)
@test_throws NullException convert(CategoricalArray, x)


# Test in() on categorical arrays
ca1 = CategoricalArray([1, 2, 3])
ca2 = CategoricalArray([4, 3, 2])

# Elements used as the needle: one from the same array, then two from a different
# array (the shared value 3 is found, the unshared value 4 is not)
@test (ca1[1] in ca1) === true
@test (ca2[2] in ca1) === true
@test (ca2[1] in ca1) === false

# Plain values are matched against the values of the array
@test (1 in ca1) === true
@test (5 in ca1) === false

nca1 = NullableCategoricalArray([1, 2, 3])
nca2 = NullableCategoricalArray([4, 3, 2])

# Needles that are not wrapped in a Nullable are never found in a nullable array
@test (ca1[1] in nca1) === false
@test (1 in nca1) === false

# Elements of nullable arrays used as the needle: same array, then a different
# array (the shared value 3 is found, the unshared value 4 is not)
@test (nca1[1] in nca1) === true
@test (nca2[2] in nca1) === true
@test (nca2[1] in nca1) === false

# Nullable-wrapped plain values are matched against the values of the array
@test (Nullable(1) in nca1) === true
@test (Nullable(5) in nca1) === false

end

0 comments on commit e380d86

Please sign in to comment.