Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hide the internal data structure of HeapPointWriter #12762

Merged
merged 6 commits into from Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Expand Up @@ -363,6 +363,8 @@ Other
overflows and slices that are too large. Some bits of code are simplified. Documentation is updated and expanded.
(Stefan Vodita)

* GITHUB#12762: Refactor BKD HeapPointWriter to hide the internal data structure. (Ignacio Vera)

======================== Lucene 9.8.0 =======================

API Changes
Expand Down
120 changes: 16 additions & 104 deletions lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java
Expand Up @@ -19,8 +19,6 @@
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.IntroSorter;
Expand Down Expand Up @@ -181,12 +179,8 @@ private int findCommonPrefixAndHistogram(
break;
} else {
// Check common prefix and adjust histogram
final int startIndex =
(dimCommonPrefix > config.bytesPerDim) ? config.bytesPerDim : dimCommonPrefix;
final int endIndex =
(commonPrefixPosition > config.bytesPerDim)
? config.bytesPerDim
: commonPrefixPosition;
final int startIndex = Math.min(dimCommonPrefix, config.bytesPerDim);
final int endIndex = Math.min(commonPrefixPosition, config.bytesPerDim);
packedValueDocID = pointValue.packedValueDocIDBytes();
int j =
Arrays.mismatch(
Expand Down Expand Up @@ -427,24 +421,13 @@ protected void swap(int i, int j) {
@Override
protected int byteAt(int i, int k) {
assert k >= 0 : "negative prefix " + k;
if (k < dimCmpBytes) {
// dim bytes
return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
} else {
// data bytes
return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
}
return points.byteAt(i, k < dimCmpBytes ? dimOffset + k : dataOffset + k);
}

@Override
protected Selector getFallbackSelector(int d) {
final int skypedBytes = d + commonPrefixLength;
final int dimStart = dim * config.bytesPerDim;
// data length is composed of the data dimensions plus the docID
final int dataLength =
(config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
final ByteArrayComparator dimComparator =
ArrayUtil.getUnsignedComparator(config.bytesPerDim);
return new IntroSelector() {

@Override
Expand All @@ -455,61 +438,31 @@ protected void swap(int i, int j) {
@Override
protected void setPivot(int i) {
if (skypedBytes < config.bytesPerDim) {
System.arraycopy(
points.block,
i * config.bytesPerDoc + dim * config.bytesPerDim,
scratch,
0,
config.bytesPerDim);
points.copyDim(i, dimStart, scratch, 0);
}
System.arraycopy(
points.block,
i * config.bytesPerDoc + config.packedIndexBytesLength,
scratch,
config.bytesPerDim,
dataLength);
points.copyDataDimsAndDoc(i, scratch, config.bytesPerDim);
}

@Override
protected int compare(int i, int j) {
if (skypedBytes < config.bytesPerDim) {
int iOffset = i * config.bytesPerDoc;
int jOffset = j * config.bytesPerDoc;
int cmp =
dimComparator.compare(
points.block, iOffset + dimStart, points.block, jOffset + dimStart);
int cmp = points.compareDim(i, j, dimStart);
if (cmp != 0) {
return cmp;
}
}
int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return Arrays.compareUnsigned(
points.block,
iOffset,
iOffset + dataLength,
points.block,
jOffset,
jOffset + dataLength);
return points.compareDataDimsAndDoc(i, j);
}

@Override
protected int comparePivot(int j) {
if (skypedBytes < config.bytesPerDim) {
int jOffset = j * config.bytesPerDoc;
int cmp = dimComparator.compare(scratch, 0, points.block, jOffset + dimStart);
int cmp = points.compareDim(j, scratch, 0, dimStart);
if (cmp != 0) {
return cmp;
}
}
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return Arrays.compareUnsigned(
scratch,
config.bytesPerDim,
config.bytesPerDim + dataLength,
points.block,
jOffset,
jOffset + dataLength);
return points.compareDataDimsAndDoc(j, scratch, config.bytesPerDim);
}
};
}
Expand Down Expand Up @@ -538,13 +491,7 @@ public void heapRadixSort(
@Override
protected int byteAt(int i, int k) {
assert k >= 0 : "negative prefix " + k;
if (k < dimCmpBytes) {
// dim bytes
return points.block[i * config.bytesPerDoc + dimOffset + k] & 0xff;
} else {
// data bytes
return points.block[i * config.bytesPerDoc + dataOffset + k] & 0xff;
}
return points.byteAt(i, k < dimCmpBytes ? dimOffset + k : dataOffset + k);
}

@Override
Expand All @@ -556,11 +503,6 @@ protected void swap(int i, int j) {
protected Sorter getFallbackSorter(int k) {
final int skypedBytes = k + commonPrefixLength;
final int dimStart = dim * config.bytesPerDim;
// data length is composed of the data dimensions plus the docID
final int dataLength =
(config.numDims - config.numIndexDims) * config.bytesPerDim + Integer.BYTES;
final ByteArrayComparator dimComparator =
ArrayUtil.getUnsignedComparator(config.bytesPerDim);
return new IntroSorter() {

@Override
Expand All @@ -571,61 +513,31 @@ protected void swap(int i, int j) {
@Override
protected void setPivot(int i) {
if (skypedBytes < config.bytesPerDim) {
System.arraycopy(
points.block,
i * config.bytesPerDoc + dim * config.bytesPerDim,
scratch,
0,
config.bytesPerDim);
points.copyDim(i, dimStart, scratch, 0);
}
System.arraycopy(
points.block,
i * config.bytesPerDoc + config.packedIndexBytesLength,
scratch,
config.bytesPerDim,
dataLength);
points.copyDataDimsAndDoc(i, scratch, config.bytesPerDim);
}

@Override
protected int compare(int i, int j) {
if (skypedBytes < config.bytesPerDim) {
int iOffset = i * config.bytesPerDoc;
int jOffset = j * config.bytesPerDoc;
int cmp =
dimComparator.compare(
points.block, iOffset + dimStart, points.block, jOffset + dimStart);
final int cmp = points.compareDim(i, j, dimStart);
if (cmp != 0) {
return cmp;
}
}
int iOffset = i * config.bytesPerDoc + config.packedIndexBytesLength;
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return Arrays.compareUnsigned(
points.block,
iOffset,
iOffset + dataLength,
points.block,
jOffset,
jOffset + dataLength);
return points.compareDataDimsAndDoc(i, j);
}

@Override
protected int comparePivot(int j) {
if (skypedBytes < config.bytesPerDim) {
int jOffset = j * config.bytesPerDoc;
int cmp = dimComparator.compare(scratch, 0, points.block, jOffset + dimStart);
int cmp = points.compareDim(j, scratch, 0, dimStart);
if (cmp != 0) {
return cmp;
}
}
int jOffset = j * config.bytesPerDoc + config.packedIndexBytesLength;
return Arrays.compareUnsigned(
scratch,
config.bytesPerDim,
config.bytesPerDim + dataLength,
points.block,
jOffset,
jOffset + dataLength);
return points.compareDataDimsAndDoc(j, scratch, config.bytesPerDim);
}
};
}
Expand Down
Expand Up @@ -16,8 +16,7 @@
*/
package org.apache.lucene.util.bkd;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import java.util.function.IntFunction;

/**
* Utility class to read buffered points from in-heap arrays.
Expand All @@ -26,22 +25,13 @@
*/
public final class HeapPointReader implements PointReader {
private int curRead;
final byte[] block;
final BKDConfig config;
final int end;
private final HeapPointValue pointValue;
private final int end;
private final IntFunction<PointValue> points;

public HeapPointReader(BKDConfig config, byte[] block, int start, int end) {
this.block = block;
HeapPointReader(IntFunction<PointValue> points, int start, int end) {
curRead = start - 1;
this.end = end;
this.config = config;
if (start < end) {
this.pointValue = new HeapPointValue(config, block);
} else {
// no values
this.pointValue = null;
}
this.points = points;
}

@Override
Expand All @@ -52,46 +42,9 @@ public boolean next() {

@Override
public PointValue pointValue() {
pointValue.setOffset(curRead * config.bytesPerDoc);
return pointValue;
return points.apply(curRead);
}

@Override
public void close() {}

/** Reusable implementation for a point value on-heap */
static class HeapPointValue implements PointValue {

final BytesRef packedValue;
final BytesRef packedValueDocID;
final int packedValueLength;

HeapPointValue(BKDConfig config, byte[] value) {
this.packedValueLength = config.packedBytesLength;
this.packedValue = new BytesRef(value, 0, packedValueLength);
this.packedValueDocID = new BytesRef(value, 0, config.bytesPerDoc);
}

/** Sets a new value by changing the offset. */
public void setOffset(int offset) {
packedValue.offset = offset;
packedValueDocID.offset = offset;
}

@Override
public BytesRef packedValue() {
return packedValue;
}

@Override
public int docID() {
int position = packedValueDocID.offset + packedValueLength;
return (int) BitUtil.VH_BE_INT.get(packedValueDocID.bytes, position);
}

@Override
public BytesRef packedValueDocIDBytes() {
return packedValueDocID;
}
}
}