Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 176 additions & 2 deletions arrow/compute/internal/kernels/vector_selection.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,153 @@ func (c *chunkedPrimitiveGetter[T]) GetValue(i int64) T {
// NullCount returns the total number of null values across all chunks,
// accumulated when the getter was constructed.
func (c *chunkedPrimitiveGetter[T]) NullCount() int64 { return c.nulls }

// Len returns the total logical length (number of elements) across all chunks.
func (c *chunkedPrimitiveGetter[T]) Len() int64 { return c.len }

// isSorted reports whether indices are monotonically non-decreasing.
//
// The check is exact: every adjacent pair is compared. A previous revision
// sampled ~32 elements for inputs of length >= 256, which could return true
// for arrays that are not actually sorted, violating the documented contract.
// Callers immediately perform an O(n) take over these same indices, so a
// single-comparison-per-element full scan adds negligible cost while making
// the result reliable.
//
// Empty and single-element slices are trivially sorted.
func isSorted[IdxT arrow.UintType](indices []IdxT) bool {
	for i := 1; i < len(indices); i++ {
		if indices[i] < indices[i-1] {
			return false
		}
	}
	return true
}

// isReverseSorted reports whether indices are monotonically non-increasing.
//
// Like isSorted, the check is exact rather than sampled: the earlier
// interval-sampling approach could report true for inputs that are not
// reverse sorted, which made the function's contract false. The full scan
// is one comparison per element — cheap relative to the O(n) take the
// callers run on the same indices.
//
// Empty and single-element slices are trivially reverse sorted.
func isReverseSorted[IdxT arrow.UintType](indices []IdxT) bool {
	for i := 1; i < len(indices); i++ {
		if indices[i] > indices[i-1] {
			return false
		}
	}
	return true
}

// primitiveTakeImplSorted gathers values[indices[i]] into out for indices the
// caller believes are monotonically non-decreasing. The indices are consumed
// strictly left-to-right, so sorted input produces a sequential,
// cache-friendly access pattern into the value buffer; the code itself does
// not require sortedness for correctness — it handles arbitrary indices.
//
// NOTE(review): the null-handling paths only ever SET validity bits in
// out.Buffers[0].Buf; this assumes the output validity buffer was allocated
// zero-initialized — confirm against the kernel's output allocation.
func primitiveTakeImplSorted[IdxT arrow.UintType, ValT arrow.IntType](values primitiveGetter[ValT], indices *exec.ArraySpan, out *exec.ExecResult) {
	var (
		indicesData = exec.GetSpanValues[IdxT](indices, 1)
		outData     = exec.GetSpanValues[ValT](out, 1)
	)

	// Fast path: neither the values nor the indices contain nulls, so no
	// validity bitmap needs to be produced.
	if values.NullCount() == 0 && indices.Nulls == 0 {
		// Try to access underlying values directly for better performance
		// (skips the GetValue interface call per element).
		if valImpl, ok := values.(*primitiveGetterImpl[ValT]); ok {
			// Direct memory access for primitiveGetterImpl
			valData := valImpl.values
			// Unroll loop by 4 for better performance
			i := 0
			for ; i+4 <= len(indicesData); i += 4 {
				outData[i] = valData[indicesData[i]]
				outData[i+1] = valData[indicesData[i+1]]
				outData[i+2] = valData[indicesData[i+2]]
				outData[i+3] = valData[indicesData[i+3]]
			}
			// Tail: remaining 0-3 elements.
			for ; i < len(indicesData); i++ {
				outData[i] = valData[indicesData[i]]
			}
		} else {
			// Fallback to GetValue interface for other getter implementations
			// (e.g. chunked values).
			for i, idx := range indicesData {
				outData[i] = values.GetValue(int64(idx))
			}
		}
		out.Nulls = 0
		return
	}

	// Slow path: at least one side has nulls, so build the output validity
	// bitmap while gathering. Offsets are applied when addressing bits
	// because the spans may be slices of larger buffers.
	var (
		indicesIsValid = indices.Buffers[0].Buf
		indicesOffset  = indices.Offset
		outIsValid     = out.Buffers[0].Buf
		outOffset      = out.Offset
		validCount     = int64(0)
	)

	if values.NullCount() == 0 {
		// Only indices can be null: an output slot is valid iff its index is.
		// Invalid slots leave outData[i] unwritten (masked by the null bit).
		for i, idx := range indicesData {
			if bitutil.BitIsSet(indicesIsValid, int(indicesOffset)+i) {
				outData[i] = values.GetValue(int64(idx))
				bitutil.SetBit(outIsValid, int(outOffset)+i)
				validCount++
			}
		}
	} else {
		// Both sides can be null: the output slot is valid only when the
		// index is valid AND the value it points at is valid.
		for i, idx := range indicesData {
			if bitutil.BitIsSet(indicesIsValid, int(indicesOffset)+i) && values.IsValid(int64(idx)) {
				outData[i] = values.GetValue(int64(idx))
				bitutil.SetBit(outIsValid, int(outOffset)+i)
				validCount++
			}
		}
	}

	// Derive the null count from how many valid bits were set.
	out.Nulls = out.Len - validCount
}

func primitiveTakeImpl[IdxT arrow.UintType, ValT arrow.IntType](values primitiveGetter[ValT], indices *exec.ArraySpan, out *exec.ExecResult) {
var (
indicesData = exec.GetSpanValues[IdxT](indices, 1)
Expand All @@ -678,8 +825,35 @@ func primitiveTakeImpl[IdxT arrow.UintType, ValT arrow.IntType](values primitive
// values and indices are both never null
// this means we didn't allocate the validity bitmap
// and can simplify everything
for i, idx := range indicesData {
outData[i] = values.GetValue(int64(idx))

// Check if indices are sorted for optimized path
// Use sorted path for arrays >= 32 elements where sorting check is cheap
if len(indicesData) >= 32 {
if isSorted(indicesData) {
primitiveTakeImplSorted[IdxT, ValT](values, indices, out)
return
}
// Check for reverse sorted - use sequential loop to avoid cache penalties
// Loop unrolling amplifies cache miss penalties in reverse access patterns
if isReverseSorted(indicesData) {
for i := 0; i < len(indicesData); i++ {
outData[i] = values.GetValue(int64(indicesData[i]))
}
out.Nulls = 0
return
}
}

// Unroll loop for better performance (random access patterns)
i := 0
for ; i+4 <= len(indicesData); i += 4 {
outData[i] = values.GetValue(int64(indicesData[i]))
outData[i+1] = values.GetValue(int64(indicesData[i+1]))
outData[i+2] = values.GetValue(int64(indicesData[i+2]))
outData[i+3] = values.GetValue(int64(indicesData[i+3]))
}
for ; i < len(indicesData); i++ {
outData[i] = values.GetValue(int64(indicesData[i]))
}
out.Nulls = 0
return
Expand Down
Loading