diff --git a/base/bitarray.jl b/base/bitarray.jl index aea54aa3385bf..afae953078d4d 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -1510,37 +1510,70 @@ function findprev(testf::Function, B::BitArray, start::Integer) end #findlast(testf::Function, B::BitArray) = findprev(testf, B, 1) ## defined in array.jl +# findall helper functions +# Generic case (>2 dimensions) +function allindices!(I, B::BitArray) + ind = first(keys(B)) + for k = 1:length(B) + I[k] = ind + ind = nextind(B, ind) + end +end + +# Optimized case for vector +function allindices!(I, B::BitVector) + I[:] .= 1:length(B) +end + +# Optimized case for matrix +function allindices!(I, B::BitMatrix) + k = 1 + for c = 1:size(B,2), r = 1:size(B,1) + I[k] = CartesianIndex(r, c) + k += 1 + end +end + +@inline _overflowind(i1, irest::Tuple{}, size) = (i1, irest) +@inline function _overflowind(i1, irest, size) + i2 = irest[1] + while i1 > size[1] + i1 -= size[1] + i2 += 1 + end + i2, irest = _overflowind(i2, tail(irest), tail(size)) + return (i1, (i2, irest...)) +end + +@inline _toind(i1, irest::Tuple{}) = i1 +@inline _toind(i1, irest) = CartesianIndex(i1, irest...) + function findall(B::BitArray) - l = length(B) nnzB = count(B) - ind = first(keys(B)) - I = Vector{typeof(ind)}(undef, nnzB) + I = Vector{eltype(keys(B))}(undef, nnzB) nnzB == 0 && return I + nnzB == length(B) && (allindices!(I, B); return I) Bc = B.chunks - Icount = 1 - for i = 1:length(Bc)-1 - u = UInt64(1) - c = Bc[i] - for j = 1:64 - if c & u != 0 - I[Icount] = ind - Icount += 1 - end - ind = nextind(B, ind) - u <<= 1 - end - end - u = UInt64(1) - c = Bc[end] - for j = 0:_mod64(l-1) - if c & u != 0 - I[Icount] = ind - Icount += 1 + Bs = size(B) + Bi = i1 = i = 1 + irest = ntuple(one, ndims(B) - 1) + c = Bc[1] + @inbounds while true + while c == 0 + Bi == length(Bc) && return I + i1 += 64 + Bi += 1 + c = Bc[Bi] end - ind = nextind(B, ind) - u <<= 1 + + tz = trailing_zeros(c) + c = _blsr(c) + + i1, irest = _overflowind(i1 + tz, irest, Bs) + I[i] = _toind(i1, irest) + i += 1 + i1 -= tz end - return I end # For performance diff --git a/test/bitarray.jl b/test/bitarray.jl index 939991b39858b..d1ac4bb5ed061 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1167,9 +1167,30 @@ timesofar("datamove") @test findnextnot((.~(b1 >> i)) .⊻ submask, j) == i+1 end + # Do a few more thorough tests for findall b1 = bitrand(n1, n2) @check_bit_operation findall(b1) Vector{CartesianIndex{2}} @check_bit_operation findall(!iszero, b1) Vector{CartesianIndex{2}} + + # tall-and-skinny (test index overflow logic in findall) + @check_bit_operation findall(bitrand(1, 1, 1, 250)) Vector{CartesianIndex{4}} + + # empty dimensions + @check_bit_operation findall(bitrand(0, 0, 10)) Vector{CartesianIndex{3}} + + # sparse (test empty 64-bit chunks in findall) + b1 = falses(8, 8, 8) + b1[3,3,3] = b1[6,6,6] = true + @check_bit_operation findall(b1) Vector{CartesianIndex{3}} + + # BitArrays of various dimensions + for dims = 0:8 + t = Tuple(fill(2, dims)) + ret_type = Vector{dims == 1 ? Int : CartesianIndex{dims}} + @check_bit_operation findall(trues(t)) ret_type + @check_bit_operation findall(falses(t)) ret_type + @check_bit_operation findall(bitrand(t)) ret_type + end end timesofar("find")