diff --git a/src/array.jl b/src/array.jl index b3c845ed..9116a0b8 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,7 +1,7 @@ ## Common code for CategoricalArray and NullableCategoricalArray import Base: convert, copy, copy!, getindex, setindex!, similar, size, - linearindexing, unique, vcat + linearindexing, unique, vcat, in # Used for keyword argument default value _isordered(x::AbstractCategoricalArray) = isordered(x) @@ -711,3 +711,17 @@ levels!(A::CategoricalArray, newlevels::Vector) = _levels!(A, newlevels) droplevels!(A::CategoricalArray) = levels!(A, unique(A)) unique(A::CategoricalArray) = _unique(Array, A.refs, A.pool) + +function in{T, N, R}(x::Any, y::CategoricalArray{T, N, R}) + ref = get(y.pool, x, zero(R)) + ref != 0 ? ref in y.refs : false +end + +function in{T, N, R}(x::CategoricalValue, y::CategoricalArray{T, N, R}) + if x.pool === y.pool + return x.level in y.refs + else + ref = get(y.pool, index(x.pool)[x.level], zero(R)) + return ref != 0 ? ref in y.refs : false + end +end diff --git a/src/nullablearray.jl b/src/nullablearray.jl index 42f3eda2..4def28a1 100644 --- a/src/nullablearray.jl +++ b/src/nullablearray.jl @@ -1,4 +1,4 @@ -import Base: convert, getindex, setindex!, similar +import Base: convert, getindex, setindex!, similar, in using NullableArrays: NullableArray ## Constructors and converters @@ -137,3 +137,21 @@ levels!(A::NullableCategoricalArray, newlevels::Vector; nullok=false) = _levels! droplevels!(A::NullableCategoricalArray) = levels!(A, _unique(Array, A.refs, A.pool)) unique{T}(A::NullableCategoricalArray{T}) = _unique(NullableArray{T}, A.refs, A.pool) + +function in{T, N, R}(x::Nullable, y::NullableCategoricalArray{T, N, R}) + ref = get(y.pool, get(x), zero(R)) + ref != 0 ? ref in y.refs : false +end + +function in{S<:CategoricalValue, T, N, R}(x::Nullable{S}, y::NullableCategoricalArray{T, N, R}) + v = get(x) + if v.pool === y.pool + return v.level in y.refs + else + ref = get(y.pool, index(v.pool)[v.level], zero(R)) + return ref != 0 ? ref in y.refs : false + end +end + +in{T, N, R}(x::Any, y::NullableCategoricalArray{T, N, R}) = false +in{T, N, R}(x::CategoricalValue, y::NullableCategoricalArray{T, N, R}) = false diff --git a/src/pool.jl b/src/pool.jl index 52d03d9a..6aa246d5 100644 --- a/src/pool.jl +++ b/src/pool.jl @@ -86,6 +86,7 @@ Base.length(pool::CategoricalPool) = length(pool.index) Base.getindex(pool::CategoricalPool, i::Integer) = pool.valindex[i] Base.get(pool::CategoricalPool, level::Any) = pool.invindex[level] +Base.get(pool::CategoricalPool, level::Any, default::Any) = get(pool.invindex, level, default) function Base.get!{T, R, V}(pool::CategoricalPool{T, R, V}, level) get!(pool.invindex, level) do diff --git a/src/value.jl b/src/value.jl index 1981f300..b80e5024 100644 --- a/src/value.jl +++ b/src/value.jl @@ -62,6 +62,9 @@ end Base.isequal(x::CategoricalValue, y::Any) = isequal(index(x.pool)[x.level], y) Base.isequal(x::Any, y::CategoricalValue) = isequal(y, x) +Base.in(x::CategoricalValue, y::Any) = index(x.pool)[x.level] in y +Base.in{T<:Integer}(x::CategoricalValue, y::Range{T}) = index(x.pool)[x.level] in y + Base.hash(x::CategoricalValue, h::UInt) = hash(index(x.pool)[x.level], h) function Base.isless{S, T}(x::CategoricalValue{S}, y::CategoricalValue{T}) diff --git a/test/08_equality.jl b/test/08_equality.jl index 0438a64a..3b7d7423 100644 --- a/test/08_equality.jl +++ b/test/08_equality.jl @@ -142,4 +142,13 @@ module TestEquality @test (ov2b == nv2a) === false @test (ov2b == nv1b) === false @test (ov2b == nv2b) === true + + # Check in() + pool = CategoricalPool([5, 1, 3]) + nv = CategoricalValue(2, pool) + + @test (nv in 1:3) === true + @test (nv in [1, 2, 3]) === true + @test (nv in 2:3) === false + @test (nv in [2, 3]) === false end diff --git a/test/13_arraycommon.jl b/test/13_arraycommon.jl index 6c438698..fe116152 100644 --- a/test/13_arraycommon.jl +++ b/test/13_arraycommon.jl @@ -443,4 +443,29 @@ x = NullableCategoricalArray(1) @test_throws NullException CategoricalArray(x) @test_throws NullException convert(CategoricalArray, x) + +# Test in() +ca1 = CategoricalArray([1, 2, 3]) +ca2 = CategoricalArray([4, 3, 2]) + +@test (ca1[1] in ca1) === true +@test (ca2[2] in ca1) === true +@test (ca2[1] in ca1) === false + +@test (1 in ca1) === true +@test (5 in ca1) === false + +nca1 = NullableCategoricalArray([1, 2, 3]) +nca2 = NullableCategoricalArray([4, 3, 2]) + +@test (ca1[1] in nca1) === false +@test (1 in nca1) === false + +@test (nca1[1] in nca1) === true +@test (nca2[2] in nca1) === true +@test (nca2[1] in nca1) === false + +@test (Nullable(1) in nca1) === true +@test (Nullable(5) in nca1) === false + end