Add DataAPI refarray/refvalue/refpool support for categorical arrays.

src/array.jl
  * CategoricalRefPool{T, P} is an AbstractVector{T} wrapping a pool. Indexing
    with i > 0 returns pool[i]; any other in-bounds index returns `missing`.
    Its axes start at 0 when T >: Missing and at 1 otherwise, and its size is
    length(pool) plus one when T >: Missing.
  * DataAPI.refarray(A) returns refs(A).
  * DataAPI.refvalue(A, i) bounds-checks i against the same 0-or-1 based range
    over pool(A), throwing BoundsError when out of range, then returns
    pool(A)[i] for i > 0 and `missing` otherwise.
  * DataAPI.refpool(A) wraps pool(A) in a CategoricalRefPool whose element
    type is eltype(A).

test/13_arraycommon.jl
  * New testset exercising the three functions on a vector, a vector view, a
    matrix containing `missing`, and a view of that matrix, including the
    out-of-bounds cases at -1, 0 and one past the last level.

diff --git a/src/array.jl b/src/array.jl
index b708f897..02d599b5 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -1050,3 +1050,26 @@ StructTypes.construct(::Type{<:CategoricalArray{Union{Nothing, T}}}, A::Vector) where {T} =
     categoricalnothing(T, A)
 categoricalnothing(T, A::AbstractVector) =
     CategoricalArray{Union{Nothing, T}}(A)
+
+# DataAPI refarray/refvalue/refpool support
+struct CategoricalRefPool{T, P} <: AbstractVector{T}
+    pool::P
+end
+
+Base.IndexStyle(::Type{<: CategoricalRefPool}) = Base.IndexLinear()
+@inline function Base.getindex(x::CategoricalRefPool, i::Int)
+    @boundscheck checkbounds(x, i)
+    i > 0 ? @inbounds(x.pool[i]) : missing
+end
+Base.size(x::CategoricalRefPool{T}) where {T} = (length(x.pool) + (T >: Missing),)
+Base.axes(x::CategoricalRefPool{T}) where {T} =
+    ((T >: Missing ? 0 : 1):length(x.pool),)
+
+DataAPI.refarray(A::CatArrOrSub) = refs(A)
+@inline function DataAPI.refvalue(A::CatArrOrSub{T}, i::Integer) where T
+    @boundscheck checkindex(Bool, (T >: Missing ? 0 : 1):length(pool(A)), i) ||
+        throw(BoundsError())
+    i > 0 ? @inbounds(pool(A)[i]) : missing
+end
+DataAPI.refpool(A::CatArrOrSub{T}) where {T} =
+    CategoricalRefPool{eltype(A), typeof(pool(A))}(pool(A))
diff --git a/test/13_arraycommon.jl b/test/13_arraycommon.jl
index fdd6484b..7fd8d0ee 100644
--- a/test/13_arraycommon.jl
+++ b/test/13_arraycommon.jl
@@ -2027,4 +2027,35 @@ StructTypes.StructType(::Type{<:MyCustomType}) = StructTypes.Struct()
     @test levels(readx.var) == levels(x.var)
 end
 
+@testset "refarray, refvalue and refpool" begin
+    for y in (categorical(["b", "a", "c", "b"]),
+              view(categorical(["a", "a", "c", "b"]), 1:3),
+              categorical(["b" missing; "a" "c"; "b" missing]),
+              view(categorical(["b" missing; "a" "c"; "b" missing]), 2:3, 1))
+        @test DataAPI.refarray(y) === CategoricalArrays.refs(y)
+        @test DataAPI.refvalue.(Ref(y), DataAPI.refarray(y)) ≅ y
+        @test DataAPI.getindex.(Ref(DataAPI.refpool(y)), DataAPI.refarray(y)) ≅ y
+        @test_throws BoundsError DataAPI.refvalue(y, -1)
+        @test_throws BoundsError DataAPI.refvalue(y, length(levels(y))+1)
+        if !(eltype(y) >: Missing)
+            @test_throws BoundsError DataAPI.refvalue(y, 0)
+        end
+
+        rp = DataAPI.refpool(y)
+        @test rp isa AbstractVector{eltype(y)}
+        if eltype(y) >: Missing
+            @test collect(rp) ≅ [missing; levels(y)]
+            @test size(rp) == (length(levels(y)) + 1,)
+            @test axes(rp) == (0:length(levels(y)),)
+        else
+            @test collect(rp) == levels(y)
+            @test size(rp) == (length(levels(y)),)
+            @test axes(rp) == (1:length(levels(y)),)
+            @test_throws BoundsError rp[0]
+        end
+        @test_throws BoundsError rp[-1]
+        @test_throws BoundsError rp[end + 1]
+    end
+end
+
 end