diff --git a/src/counts.jl b/src/counts.jl index 2b017e4b9..3278fce9d 100644 --- a/src/counts.jl +++ b/src/counts.jl @@ -255,7 +255,9 @@ raw counts. - `:dict`: use `Dict`-based method which is generally slower but uses less RAM and is safe for any data type. """ -function addcounts!(cm::Dict{T}, x::AbstractArray{T}; alg = :auto) where T +addcounts!(cm::Dict, x; alg = :auto) = _addcounts!(eltype(x), cm, x, alg = alg) + +function _addcounts!(::Type{T}, cm::Dict, x; alg = :auto) where T # if it's safe to be sorted using radixsort then it should be faster # albeit using more RAM if radixsort_safe(T) && (alg == :auto || alg == :radixsort) @@ -269,7 +271,7 @@ function addcounts!(cm::Dict{T}, x::AbstractArray{T}; alg = :auto) where T end """Dict-based addcounts method""" -function addcounts_dict!(cm::Dict{T}, x::AbstractArray{T}) where T +function addcounts_dict!(cm::Dict{T}, x) where T for v in x index = ht_keyindex2!(cm, v) if index > 0 @@ -286,14 +288,27 @@ end # faster results and less memory usage. However we still wish to enable others # to write generic algorithms, therefore the methods below still accept the # `alg` argument but it is ignored. -function addcounts!(cm::Dict{Bool}, x::AbstractArray{Bool}; alg = :ignored) +function _addcounts!(::Type{Bool}, cm::Dict{Bool}, x::AbstractArray{Bool}; alg = :ignored) sumx = sum(x) cm[true] = get(cm, true, 0) + sumx cm[false] = get(cm, false, 0) + length(x) - sumx cm end -function addcounts!(cm::Dict{T}, x::AbstractArray{T}; alg = :ignored) where T <: Union{UInt8, UInt16, Int8, Int16} +# specialized for `Bool` iterator +function _addcounts!(::Type{Bool}, cm::Dict{Bool}, x; alg = :ignored) + sumx = 0 + len = 0 + for i in x + sumx += i + len += 1 + end + cm[true] = get(cm, true, 0) + sumx + cm[false] = get(cm, false, 0) + len - sumx + cm +end + +function _addcounts!(::Type{T}, cm::Dict{T}, x; alg = :ignored) where T <: Union{UInt8, UInt16, Int8, Int16} counts = zeros(Int, 2^(8sizeof(T))) @inbounds for xi in x @@ -318,8 +333,7 @@ const BaseRadixSortSafeTypes = Union{Int8, Int16, Int32, Int64, Int128, Float32, Float64} "Can the type be safely sorted by radixsort" -radixsort_safe(::Type{T}) where {T<:BaseRadixSortSafeTypes} = true -radixsort_safe(::Type) = false +radixsort_safe(::Type{T}) where T = T<:BaseRadixSortSafeTypes function _addcounts_radix_sort_loop!(cm::Dict{T}, sx::AbstractArray{T}) where T last_sx = sx[1] @@ -353,6 +367,12 @@ function addcounts_radixsort!(cm::Dict{T}, x::AbstractArray{T}) where T return _addcounts_radix_sort_loop!(cm, sx) end +# fall-back for `x` an iterator +function addcounts_radixsort!(cm::Dict{T}, x) where T + sx = sort!(collect(x), alg = RadixSort) + return _addcounts_radix_sort_loop!(cm, sx) +end + function addcounts!(cm::Dict{T}, x::AbstractArray{T}, wv::AbstractVector{W}) where {T,W<:Real} n = length(x) length(wv) == n || throw(DimensionMismatch()) @@ -386,7 +406,7 @@ of occurrences. - `:dict`: use `Dict`-based method which is generally slower but uses less RAM and is safe for any data type. """ -countmap(x::AbstractArray{T}; alg = :auto) where {T} = addcounts!(Dict{T,Int}(), x; alg = alg) +countmap(x; alg = :auto) = addcounts!(Dict{eltype(x),Int}(), x; alg = alg) countmap(x::AbstractArray{T}, wv::AbstractVector{W}) where {T,W<:Real} = addcounts!(Dict{T,W}(), x, wv) diff --git a/test/counts.jl b/test/counts.jl index 2fd508327..9f684df86 100644 --- a/test/counts.jl +++ b/test/counts.jl @@ -80,6 +80,14 @@ cm = countmap(x) @test cm["a"] == 3 @test cm["b"] == 2 @test cm["c"] == 1 + +# iterator, non-radixsort +cm_missing = countmap(skipmissing(x)) +cm_any_itr = countmap((i for i in x)) +@test cm_missing == cm_any_itr == cm +@test cm_missing isa Dict{String, Int} +@test cm_any_itr isa Dict{Any, Int} + pm = proportionmap(x) @test pm["a"] ≈ (1/2) @test pm["b"] ≈ (1/3) @@ -91,6 +99,15 @@ xx = repeat([6, 1, 3, 1], outer=100_000) cm = countmap(xx) @test cm == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000) +# with iterator +cm_missing = countmap(skipmissing(xx)) +@test cm_missing isa Dict{Int, Int} +@test cm_missing == cm + +cm_any_itr = countmap((i for i in xx)) +@test cm_any_itr isa Dict{Any,Int} # no knowledge about type +@test cm_missing == cm + # testing the radixsort-based addcounts xx = repeat([6, 1, 3, 1], outer=100_000) cm = Dict{Int, Int}() @@ -99,11 +116,20 @@ StatsBase.addcounts_radixsort!(cm,xx) xx2 = repeat([7, 1, 3, 1], outer=100_000) StatsBase.addcounts_radixsort!(cm,xx2) @test cm == Dict(1 => 400_000, 3 => 200_000, 6 => 100_000, 7 => 100_000) +# with iterator +cm_missing = Dict{Int, Int}() +StatsBase.addcounts_radixsort!(cm_missing,skipmissing(xx)) +@test cm_missing == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000) +StatsBase.addcounts_radixsort!(cm_missing,skipmissing(xx2)) +@test cm_missing == Dict(1 => 400_000, 3 => 200_000, 6 => 100_000, 7 => 100_000) # testing the Dict-based addcounts cm = Dict{Int, Int}() +cm_itr = Dict{Int, Int}() StatsBase.addcounts_dict!(cm,xx) -@test cm == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000) +StatsBase.addcounts_dict!(cm_itr,skipmissing(xx)) +@test cm_itr == cm == Dict(1 => 200_000, 3 => 100_000, 6 => 100_000) +@test cm_itr isa Dict{Int, Int} cm = countmap(x, weights(w)) @test cm["a"] == 5.5 @@ -119,11 +145,16 @@ pm = proportionmap(x, weights(w)) # testing small bits type bx = [true, false, true, true, false] -@test countmap(bx) == Dict(true => 3, false => 2) +cm_bx_missing = countmap(skipmissing(bx)) +@test cm_bx_missing == countmap(bx) == Dict(true => 3, false => 2) +@test cm_bx_missing isa Dict{Bool, Int} for T in [UInt8, UInt16, Int8, Int16] tx = T[typemin(T), 8, typemax(T), 19, 8] - @test countmap(tx) == Dict(typemin(T) => 1, typemax(T) => 1, 8 => 2, 19 => 1) + tx_missing = skipmissing(T[typemin(T), 8, typemax(T), 19, 8]) + cm_tx_missing = countmap(tx_missing) + @test cm_tx_missing == countmap(tx) == Dict(typemin(T) => 1, typemax(T) => 1, 8 => 2, 19 => 1) + @test cm_tx_missing isa Dict{T, Int} end @testset "views" begin