Skip to content

Commit

Permalink
Merge pull request #221 from Workiva/add-bitarray-getsetbits
Browse files Browse the repository at this point in the history
Add `GetSetBits` and `Count` to `BitArray`
  • Loading branch information
dustinhiatt-wf committed May 18, 2023
2 parents c466da2 + 8f1c722 commit 68e77ee
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 2 deletions.
80 changes: 79 additions & 1 deletion bitarray/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ efficient way. This is *NOT* a threadsafe package.
*/
package bitarray

import "math/bits"

// bitArray is a struct that maintains state of a bit array.
type bitArray struct {
blocks []block
Expand Down Expand Up @@ -116,7 +118,74 @@ func (ba *bitArray) GetBit(k uint64) (bool, error) {
return result, nil
}

//ClearBit will unset a bit at the given index if it is set.
// GetSetBits gets the positions of the bits set in the array, scanning
// upward from position from. At most cap(buffer) positions are returned and
// they are written into buffer's backing array, so the result aliases buffer.
//
// If from falls in a block past the end of the array there is nothing to
// scan and an empty slice is returned; this matches the sparse bit array and
// avoids the out-of-range slice panic that ba.blocks[fromBlockIndex:] would
// otherwise trigger.
func (ba *bitArray) GetSetBits(from uint64, buffer []uint64) []uint64 {
	fromBlockIndex, fromOffset := getIndexAndRemainder(from)
	if fromBlockIndex >= uint64(len(ba.blocks)) {
		// No block at or after the requested position.
		return buffer[:0]
	}

	return getSetBitsInBlocks(
		fromBlockIndex,
		fromOffset,
		ba.blocks[fromBlockIndex:],
		nil, // dense array: block i of the slice is block fromBlockIndex+i
		buffer,
	)
}

// getSetBitsInBlocks writes the positions of the set bits found in blocks
// into buffer's backing array and returns the filled prefix. At most
// cap(buffer) positions are produced; a buffer without capacity yields
// buffer[:0] untouched.
//
// When indices is nil the blocks are treated as consecutive, beginning at
// fromBlockIndex. Otherwise indices[i] supplies the block index of blocks[i],
// which allows sparse/non-consecutive blocks. Bits below fromOffset are
// skipped only in a block whose index equals fromBlockIndex.
func getSetBitsInBlocks(
	fromBlockIndex, fromOffset uint64,
	blocks []block,
	indices []uint64,
	buffer []uint64,
) []uint64 {
	limit := cap(buffer)
	if limit == 0 {
		return buffer[:0]
	}

	// Appending to a zero-length view never reallocates because we stop as
	// soon as the capacity is reached.
	found := buffer[:0]

	for n, word := range blocks {
		index := fromBlockIndex + uint64(n)
		if indices != nil {
			index = indices[n]
		}

		// Position of bit 0 of this block (each block holds 64 bits).
		base := index << 6
		if index == fromBlockIndex {
			// Drop the bits below the starting offset and compensate in the
			// base so reported positions stay absolute.
			word >>= fromOffset
			base += fromOffset
		}

		// word &= word - 1 clears the lowest set bit, i.e. the one that was
		// just reported. Ex.:
		//   word            01001100
		//   word - 1        01001011
		//   word & (word-1) 01001000
		for ; word != 0; word &= word - 1 {
			found = append(found, base+uint64(bits.TrailingZeros64(uint64(word))))
			if len(found) == limit {
				return found
			}
		}
	}

	return found
}

// ClearBit will unset a bit at the given index if it is set.
func (ba *bitArray) ClearBit(k uint64) error {
if k >= ba.Capacity() {
return OutOfRangeError(k)
Expand All @@ -137,6 +206,15 @@ func (ba *bitArray) ClearBit(k uint64) error {
return nil
}

// Count returns the number of set bits in this array, obtained by summing
// the population count of every block.
func (ba *bitArray) Count() int {
	total := 0
	for i := range ba.blocks {
		total += bits.OnesCount64(uint64(ba.blocks[i]))
	}
	return total
}

// Or will bitwise or two bit arrays and return a new bit array
// representing the result.
func (ba *bitArray) Or(other BitArray) BitArray {
Expand Down
70 changes: 70 additions & 0 deletions bitarray/bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBitOperations(t *testing.T) {
Expand Down Expand Up @@ -142,6 +143,28 @@ func TestIsEmpty(t *testing.T) {
assert.False(t, ba.IsEmpty())
}

// TestCount checks that Count follows the number of set bits as bits are
// set, cleared, and the whole array is reset.
func TestCount(t *testing.T) {
	ba := newBitArray(500)
	assert.Equal(t, 0, ba.Count())

	require.NoError(t, ba.SetBit(0))
	assert.Equal(t, 1, ba.Count())

	// Six more bits on top of bit 0 gives seven in total.
	for _, k := range []uint64{40, 64, 100, 200, 469, 500} {
		require.NoError(t, ba.SetBit(k))
	}
	assert.Equal(t, 7, ba.Count())

	require.NoError(t, ba.ClearBit(200))
	assert.Equal(t, 6, ba.Count())

	ba.Reset()
	assert.Equal(t, 0, ba.Count())
}

func TestClear(t *testing.T) {
ba := newBitArray(10)

Expand Down Expand Up @@ -195,6 +218,53 @@ func BenchmarkGetBit(b *testing.B) {
}
}

// TestGetSetBits exercises GetSetBits on a dense bit array: buffers without
// capacity, truncation to the buffer capacity, different starting positions,
// and the effect of ClearBit and Reset.
func TestGetSetBits(t *testing.T) {
	ba := newBitArray(1000)
	buf := make([]uint64, 0, 5)

	for _, k := range []uint64{1, 4, 8, 63, 64, 200, 1000} {
		require.NoError(t, ba.SetBit(k))
	}

	// A buffer with no capacity is handed back as-is.
	assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil))
	assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{}))

	cases := []struct {
		from uint64
		want []uint64
	}{
		{from: 0, want: []uint64{1, 4, 8, 63, 64}}, // capped at cap(buf)
		{from: 10, want: []uint64{63, 64, 200, 1000}},
		{from: 63, want: []uint64{63, 64, 200, 1000}},
		{from: 128, want: []uint64{200, 1000}},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, ba.GetSetBits(c.from, buf))
	}

	require.NoError(t, ba.ClearBit(4))
	require.NoError(t, ba.ClearBit(64))
	assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf))
	assert.Empty(t, ba.GetSetBits(1001, buf))

	ba.Reset()
	assert.Empty(t, ba.GetSetBits(0, buf))
}

// BenchmarkGetSetBits measures GetSetBits over a dense bit array in which
// every multiple of 5 or 13 below numItems is set, using a buffer sized to
// the array's capacity.
func BenchmarkGetSetBits(b *testing.B) {
	const numItems = uint64(168000)

	ba := newBitArray(numItems)
	for k := uint64(0); k < numItems; k++ {
		if k%13 != 0 && k%5 != 0 {
			continue
		}
		require.NoError(b, ba.SetBit(k))
	}

	buf := make([]uint64, 0, ba.Capacity())

	// Exclude the setup above from the measurement.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ba.GetSetBits(0, buf)
	}
}

func TestEquality(t *testing.T) {
ba := newBitArray(s + 1)
other := newBitArray(s + 1)
Expand Down
6 changes: 6 additions & 0 deletions bitarray/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type BitArray interface {
// function returns an error if the position is out
// of range. A sparse bit array never returns an error.
GetBit(k uint64) (bool, error)
// GetSetBits gets the positions of the bits set in the array. It
// returns as many set bits as fit in the capacity of the provided
// buffer, starting from the specified position in the array.
GetSetBits(from uint64, buffer []uint64) []uint64
// ClearBit clears the bit at the given position. This
// function returns an error if the position is out
// of range. A sparse bit array never returns an error.
Expand All @@ -55,6 +59,8 @@ type BitArray interface {
// in the case of a dense bit array or the highest possible
// seen capacity of the sparse array.
Capacity() uint64
// Count returns the number of set bits in this array.
Count() int
// Or will bitwise or the two bitarrays and return a new bitarray
// representing the result.
Or(other BitArray) BitArray
Expand Down
32 changes: 31 additions & 1 deletion bitarray/sparse_bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

package bitarray

import "sort"
import (
"math/bits"
"sort"
)

// uintSlice is an alias for a slice of ints. Len, Swap, and Less
// are exported to fulfill an interface needed for the search
Expand Down Expand Up @@ -127,6 +130,24 @@ func (sba *sparseBitArray) GetBit(k uint64) (bool, error) {
return sba.blocks[i].get(position), nil
}

// GetSetBits gets the positions of the bits set in the array, starting the
// scan at position from. At most cap(buffer) positions are returned, written
// into buffer's backing array.
func (sba *sparseBitArray) GetSetBits(from uint64, buffer []uint64) []uint64 {
	blockIndex, offset := getIndexAndRemainder(from)

	// A search result equal to len(sba.indices) is treated as "no stored
	// block to scan", so only delegate when the search lands inside indices.
	start := sba.indices.search(blockIndex)
	if int(start) != len(sba.indices) {
		return getSetBitsInBlocks(
			blockIndex,
			offset,
			sba.blocks[start:],
			sba.indices[start:],
			buffer,
		)
	}

	return buffer[:0]
}

// ToNums converts this sparse bitarray to a list of numbers contained
// within it.
func (sba *sparseBitArray) ToNums() []uint64 {
Expand Down Expand Up @@ -225,6 +246,15 @@ func (sba *sparseBitArray) Equals(other BitArray) bool {
return true
}

// Count returns the number of set bits in this array, obtained by summing
// the population count of every stored block.
func (sba *sparseBitArray) Count() int {
	total := 0
	for i := range sba.blocks {
		total += bits.OnesCount64(uint64(sba.blocks[i]))
	}
	return total
}

// Or will perform a bitwise or operation with the provided bitarray and
// return a new result bitarray.
func (sba *sparseBitArray) Or(other BitArray) BitArray {
Expand Down
68 changes: 68 additions & 0 deletions bitarray/sparse_bitarray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetCompressedBit(t *testing.T) {
Expand Down Expand Up @@ -76,6 +77,73 @@ func BenchmarkSetCompressedBit(b *testing.B) {
}
}

// TestGetSetCompressedBits exercises GetSetBits on a sparse bit array:
// buffers without capacity, truncation to the buffer capacity, different
// starting positions, and the effect of ClearBit and Reset.
func TestGetSetCompressedBits(t *testing.T) {
	ba := newSparseBitArray()
	buf := make([]uint64, 0, 5)

	for _, k := range []uint64{1, 4, 8, 63, 64, 200, 1000} {
		require.NoError(t, ba.SetBit(k))
	}

	// A buffer with no capacity is handed back as-is.
	assert.Equal(t, []uint64(nil), ba.GetSetBits(0, nil))
	assert.Equal(t, []uint64{}, ba.GetSetBits(0, []uint64{}))

	cases := []struct {
		from uint64
		want []uint64
	}{
		{from: 0, want: []uint64{1, 4, 8, 63, 64}}, // capped at cap(buf)
		{from: 10, want: []uint64{63, 64, 200, 1000}},
		{from: 63, want: []uint64{63, 64, 200, 1000}},
		{from: 128, want: []uint64{200, 1000}},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, ba.GetSetBits(c.from, buf))
	}

	require.NoError(t, ba.ClearBit(4))
	require.NoError(t, ba.ClearBit(64))
	assert.Equal(t, []uint64{1, 8, 63, 200, 1000}, ba.GetSetBits(0, buf))
	assert.Empty(t, ba.GetSetBits(1001, buf))

	ba.Reset()
	assert.Empty(t, ba.GetSetBits(0, buf))
}

// BenchmarkGetSetCompressedBits measures GetSetBits over a sparse bit array
// in which every multiple of 5 or 13 below 168000 is set, using a buffer
// sized to the array's capacity.
func BenchmarkGetSetCompressedBits(b *testing.B) {
	ba := newSparseBitArray()
	for k := uint64(0); k < 168000; k++ {
		if k%13 != 0 && k%5 != 0 {
			continue
		}
		require.NoError(b, ba.SetBit(k))
	}

	buf := make([]uint64, 0, ba.Capacity())

	// Exclude the setup above from the measurement.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ba.GetSetBits(0, buf)
	}
}

// TestCompressedCount checks that Count on a sparse bit array follows the
// number of set bits as bits are set, cleared, and the array is reset.
func TestCompressedCount(t *testing.T) {
	ba := newSparseBitArray()
	assert.Equal(t, 0, ba.Count())

	require.NoError(t, ba.SetBit(0))
	assert.Equal(t, 1, ba.Count())

	// Six more bits on top of bit 0 gives seven in total.
	for _, k := range []uint64{40, 64, 100, 200, 469, 500} {
		require.NoError(t, ba.SetBit(k))
	}
	assert.Equal(t, 7, ba.Count())

	require.NoError(t, ba.ClearBit(200))
	assert.Equal(t, 6, ba.Count())

	ba.Reset()
	assert.Equal(t, 0, ba.Count())
}

func TestClearCompressedBit(t *testing.T) {
ba := newSparseBitArray()
ba.SetBit(5)
Expand Down

0 comments on commit 68e77ee

Please sign in to comment.