LUCENE-9027: Use SIMD instructions to decode postings. #973
Conversation
This change looks great -- surprisingly small amount of impacted code, just a lot of code copying from the previous postings format to this new one, then tweaking how we encode/decode longs, and matching byte order to the native order of the server doing the decoding.
}

/**
* Encode 128 8-bits integers from {@code data} into {@code out}.
Hmm, should be {@code longs}? And, are they really 8-bit integers arriving in longs? If so, why is it long[]?
Sorry, I've had many iterations, and this comment is outdated; I'll fix.
final int patchedBitsRequired = Math.max(PackedInts.bitsRequired(top4[0]), maxBitsRequired - 8);
int numExceptions = 0;
for (int i = 1; i < 4; ++i) {
if (top4[i] > (1L << patchedBitsRequired) - 1) {
Maybe add a new local variable, long patchLimit = (1L << patchedBitsRequired) - 1;, used here and below?
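A standalone sketch of what that refactor might look like. This is self-contained, so PackedInts.bitsRequired is approximated with Long.numberOfLeadingZeros, and all names other than patchLimit and numExceptions are my own illustrative choices, not the PR's code:

```java
public class PatchLimitSketch {
  // Approximation of PackedInts.bitsRequired for non-negative values.
  static int bitsRequired(long v) {
    return Math.max(1, 64 - Long.numberOfLeadingZeros(v));
  }

  // top4 holds the four largest values of the block in ascending order, so
  // top4[0] is the 4th largest; values above patchLimit become exceptions.
  static int countExceptions(long[] top4, int maxBitsRequired) {
    final int patchedBitsRequired =
        Math.max(bitsRequired(top4[0]), maxBitsRequired - 8);
    final long patchLimit = (1L << patchedBitsRequired) - 1; // hoisted once, reused in the loop
    int numExceptions = 0;
    for (int i = 1; i < 4; ++i) {
      if (top4[i] > patchLimit) {
        numExceptions++;
      }
    }
    return numExceptions;
  }

  public static void main(String[] args) {
    // bitsRequired(3) == 2, so patchLimit == 3; 9, 200 and 300 all exceed it.
    if (countExceptions(new long[] {3, 9, 200, 300}, 9) != 3) {
      throw new AssertionError();
    }
    System.out.println("ok");
  }
}
```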
assert exceptionCount == numExceptions : exceptionCount + " " + numExceptions;
}

if (allEqual(longs) && maxBitsRequired <= 8) {
I think you can safely insert numExceptions == 0 into this if conjunction? I.e. when all longs are equal, there would be no exceptions, I think?
All values in longs have been masked at this point, so there might be exceptions. My intuition is that this case is important for rare terms that might have most of their term freqs equal to 1, but maybe a couple exceptions. Does it answer your question?
Yes, thanks for the explanation.
}

if (allEqual(longs) && maxBitsRequired <= 8) {
for (int i = 0; i < numExceptions; ++i) {
Hmm, maybe I'm wrong and confused about this, since you are iterating the exceptions here :)
@@ -63,7 +69,16 @@ public static ByteBufferIndexInput newInstance(String resourceDescription, ByteB
assert chunkSizePower >= 0 && chunkSizePower <= 30;
assert (length >>> chunkSizePower) < Integer.MAX_VALUE;
}

protected void setCurBuf(ByteBuffer curBuf) {
This code only runs when we cross into another mapped region, right (each chunkSize in MMapDirectory)? And when we initialize the first mapped region too. Hmm, but only from readByte(). This is likely a non-trivial fixed per-query / per-segment / per-field cost. And I suppose it's only needed for postings files, so e.g. terms dict / doc values really don't need to do this, but are doing so now.
I could also build the LongBuffer view on demand if that works better for you. Are you more worried about CPU, memory or both?
public void readLongs(ByteOrder byteOrder, long[] dst, int offset, int length) throws IOException {
try {
final int position = curBuf.position();
guard.getLongs(curLongBufferViews[position & 0x07].position(position >>> 3), dst, offset, length);
If we ensured that our long postings lists (long enough to use the 128 integer block encoding) were all written "long aligned", which I guess would mean inserting pad bytes for these high frequency terms, then we would not need these 63 clones? I.e. we'd know at read time that all reads are already long aligned?
For the record, there are 8, not 63 clones. Currently blocks all take bitsPerValue * 16 + 1 bytes, so adding padding would waste 7 bytes all the time. This would be a 42% increase for 1 bpv blocks and 21% for 2 bpv, which are not that uncommon bits per value for term freqs, especially when exceptions help lower the number of bits per value.
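Those percentages check out: with blocks of bitsPerValue * 16 + 1 bytes and up to 7 pad bytes needed per block for 8-byte alignment, a quick computation gives:

```java
public class PaddingOverhead {
  public static void main(String[] args) {
    for (int bpv : new int[] {1, 2}) {
      int blockBytes = bpv * 16 + 1; // 17 bytes at 1 bpv, 33 bytes at 2 bpv
      double overheadPct = 100.0 * 7 / blockBytes; // worst-case pad bytes per block
      System.out.printf("%d bpv: %d bytes, ~%.0f%% padding overhead%n",
          bpv, blockBytes, overheadPct);
    }
    // ~41% at 1 bpv and ~21% at 2 bpv, in line with the figures quoted above.
  }
}
```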
OK I must be a little confused here -- the loop above under the lazy init if (for (int i = 0; i < Math.min(Long.BYTES, curBuf.limit()); ++i) {) seems like it is initializing up to 64 clones, yet the array access here (position & 0x07) seems to only use the first 8.
And I thought the padding would only be needed at the start of postings lists, and only for terms whose frequency is >= 128, not at the start of each block, but you're right, since we have the one extra leading byte (to tell us numBits) for each block, pad bytes would indeed be needed for each block, hrmph. Nevermind ;)
I'm also surprised this trick actually helps -- under the hood the CPU must still do unaligned long decodes (which I think modern X86-64 are good at)? Maybe add a comment about why this trick is worthwhile?
OK, duh! Long.BYTES is 8 :)
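For illustration, here is a self-contained sketch (field and variable names are mine, not the actual Lucene code) of why 8 views, one per byte offset 0..7, are enough to serve a long read at any byte position with the position & 0x07 / position >>> 3 indexing quoted above: view i's long number j covers bytes i + 8*j, and any position p decomposes exactly as (p & 7) + 8 * (p >>> 3).

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.LongBuffer;

public class LongViewSketch {
  public static void main(String[] args) {
    ByteBuffer bytes = ByteBuffer.allocate(64).order(ByteOrder.LITTLE_ENDIAN);
    for (int i = 0; i < 64; i++) bytes.put(i, (byte) (i * 7 + 1));

    // One LongBuffer view per alignment offset 0..7; Long.BYTES == 8, hence 8 views.
    LongBuffer[] views = new LongBuffer[Long.BYTES];
    for (int i = 0; i < views.length; ++i) {
      ByteBuffer dup = bytes.duplicate().order(ByteOrder.LITTLE_ENDIAN);
      dup.position(i); // start the view i bytes into the buffer
      views[i] = dup.asLongBuffer();
    }

    // A long starting at arbitrary byte position p is long number (p >>> 3)
    // of the view whose offset is (p & 0x07).
    int p = 11;
    long viaView = views[p & 0x07].get(p >>> 3);
    long direct = bytes.getLong(p); // unaligned absolute read, little-endian
    if (viaView != direct) throw new AssertionError();
    System.out.println("ok");
  }
}
```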
try {
CodecUtil.writeIndexHeader(docOut, DOC_CODEC, VERSION_CURRENT,
state.segmentInfo.getId(), state.segmentSuffix);
ByteOrder byteOrder = ByteOrder.nativeOrder();
OK it looks like this postings format will write in the native byte order of the server doing indexing, which is typically (x86) little-endian.
But then at read (search) time, if that machine's native order is big-endian, it will reverse bytes on each decode. But if it's little-endian (common case), no reversal is needed.
This is correct.
for (int i = 0; i < numLongsPerShift; ++i) {
long l = tmp[i];
if (byteOrder != ByteOrder.BIG_ENDIAN) {
This is commonly the case -- x86 is little-endian. I suppose we could make things a bit faster by reversing this, i.e. fixing our Python code generator to write bytes in little-endian, and then reverse only if the current server is big-endian (unusual). But that's likely a minor gain...
So this one comes from the fact that I wanted to avoid adding new methods to DataOutput. So even though we read with DataInput#readLongs(long[]), we still write with DataOutput#writeLong(long). Then if I want to write longs with the native byte order, we need to check whether it differs from BIG_ENDIAN, which is the endianness of Java longs. I'll add a comment to clarify.
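The pattern being described could be sketched like this (the surrounding writer code is assumed; only the reversal logic is from the discussion): since DataOutput#writeLong always emits big-endian, reversing the bytes first whenever the native order differs makes the bytes land on disk in native order.

```java
import java.nio.ByteOrder;

public class NativeOrderWriteSketch {
  // Java's DataOutput#writeLong always writes big-endian, so to emit a long
  // in the machine's native order we reverse its bytes first whenever the
  // native order is not big-endian (the common little-endian case).
  static long toNativeOrderViaBigEndianWrite(long l) {
    if (ByteOrder.nativeOrder() != ByteOrder.BIG_ENDIAN) {
      l = Long.reverseBytes(l);
    }
    return l; // this value, written big-endian, lands on disk in native order
  }

  public static void main(String[] args) {
    long v = 0x0102030405060708L;
    long w = toNativeOrderViaBigEndianWrite(v);
    if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
        && w != 0x0807060504030201L) {
      throw new AssertionError();
    }
    System.out.println("ok");
  }
}
```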
}
curBuf.position(position + (length << 3));
} catch (BufferUnderflowException e) {
super.readLongs(byteOrder, dst, offset, length);
Hmm, how come we don't also do the setCurBuf here? It seems like we are falling back to the slow (super.readLongs) implementation always on the first readLongs call, but since that slow version will use readByte it will then initialize the curLongBufferViews for subsequent readLongs calls?
This is correct.
@mikemccand I changed ByteBufferIndexInput to create the LongBuffer views in order to hopefully address your concern. I also removed write support from the Lucene50 postings format and moved it to backward-codecs. I couldn't do it for Completion50 because it would require adding a dependency on suggest to backward-codecs, so I left it in the suggest module.
for (int i = 0; i < ForUtil.BLOCK_SIZE; ++i) {
if (longs[i] > top4[0]) {
top4[0] = longs[i];
Arrays.sort(top4); // For only 4 entries we just sort on every iteration instead of maintaining a PQ
You could probably inline a small PQ ... it'd be just a few nested if's I think ... but likely not worth it, i.e. it wouldn't move the needle on indexing throughput much, and maybe Arrays.sort is already specializing small arrays anyways.
Yes, I would be surprised if that made a difference for 4 slots.
Well I got too curious about this and played around with some silly micro-benchmarks. I tested three approaches.
First approach is to inline the PQ as a long[4]:
private static long[] top4_a(long[] input) {
  long[] top4 = new long[4];
  Arrays.fill(top4, Long.MIN_VALUE);
  for (long elem : input) {
    if (elem > top4[3]) {
      if (elem > top4[1]) {
        if (elem > top4[0]) {
          top4[3] = top4[2];
          top4[2] = top4[1];
          top4[1] = top4[0];
          top4[0] = elem;
        } else {
          top4[3] = top4[2];
          top4[2] = top4[1];
          top4[1] = elem;
        }
      } else if (elem > top4[2]) {
        top4[3] = top4[2];
        top4[2] = elem;
      } else {
        top4[3] = elem;
      }
    }
  }
  return top4;
}
Second approach is the same thing, using local variables for the four slots instead of a long[]:
private static long[] top4_b(long[] input) {
  long first = Long.MIN_VALUE;
  long second = Long.MIN_VALUE;
  long third = Long.MIN_VALUE;
  long fourth = Long.MIN_VALUE;
  for (long elem : input) {
    if (elem > fourth) {
      if (elem > second) {
        if (elem > first) {
          fourth = third;
          third = second;
          second = first;
          first = elem;
        } else {
          fourth = third;
          third = second;
          second = elem;
        }
      } else if (elem > third) {
        fourth = third;
        third = elem;
      } else {
        fourth = elem;
      }
    }
  }
  return new long[] {first, second, third, fourth};
}
Last approach just uses Arrays.sort (like here):
private static long[] top4_c(long[] input) {
  long[] top4 = new long[4];
  Arrays.fill(top4, Long.MIN_VALUE);
  for (long elem : input) {
    if (elem > top4[0]) {
      top4[0] = elem;
      Arrays.sort(top4);
    }
  }
  for (int i = 0; i < 2; i++) {
    // swap
    long x = top4[i];
    top4[i] = top4[4 - i - 1];
    top4[4 - i - 1] = x;
  }
  return top4;
}
And then I wrote a silly micro-benchmark on a randomly generated long[]:
public static void main(String[] args) throws Exception {
  // long seed = System.currentTimeMillis();
  long seed = 1574030230334L;
  System.out.println("SEED: " + seed);
  Random r = new Random(seed);
  /*
  for (int i = 0; i < 10000; i++) {
    int len = 1000 + r.nextInt(9000);
    int valueRange = 1000 + r.nextInt(9000);
    long[] values = new long[len];
    for (int j = 0; j < len; j++) {
      values[j] = r.nextInt(valueRange);
    }
    long[] top4 = top4_b(values);
    Arrays.sort(values);
    long[] correctTop4 = new long[4];
    for (int j = 0; j < 4; j++) {
      correctTop4[j] = values[len - j - 1];
    }
    if (Arrays.equals(top4, correctTop4) == false) {
      throw new RuntimeException("FAILED:\n seed=" + seed + "\n input=" + Arrays.toString(values) +
          "\n top4=" + Arrays.toString(top4) + "\n answer=" + Arrays.toString(correctTop4));
    }
  }
  */
  int len = 1000000;
  long[] values = new long[len];
  int valueRange = 10000 + r.nextInt(90000);
  for (int j = 0; j < len; j++) {
    values[j] = r.nextInt(valueRange);
  }
  // warmup
  for (int i = 0; i < 5000; i++) {
    long[] top4 = top4_c(values);
  }
  System.out.println("Done warmup");
  // test
  long bestNS = Long.MAX_VALUE;
  for (int iter = 0; iter < 10; iter++) {
    long t0 = System.nanoTime();
    for (int i = 0; i < 2000; i++) {
      long[] top4 = top4_c(values);
    }
    long t1 = System.nanoTime();
    long ns = t1 - t0;
    String extra;
    if (ns < bestNS) {
      bestNS = ns;
      extra = " ***";
    } else {
      extra = "";
    }
    System.out.println(String.format(Locale.ROOT, "iter %d: %.3f sec%s", iter, (t1 - t0) / 1000000000., extra));
  }
}
And the results are interesting: it is a bit faster to "inline" the PQ approach:
top4_a:
SEED: 1574030230334
Done warmup
iter 0: 0.875 sec ***
iter 1: 0.876 sec
iter 2: 0.876 sec
iter 3: 0.875 sec ***
iter 4: 0.876 sec
iter 5: 0.874 sec ***
iter 6: 0.874 sec
iter 7: 0.872 sec ***
iter 8: 0.874 sec
iter 9: 0.872 sec
top4_b:
SEED: 1574030230334
Done warmup
iter 0: 0.788 sec ***
iter 1: 0.790 sec
iter 2: 0.786 sec ***
iter 3: 0.784 sec ***
iter 4: 0.784 sec
iter 5: 0.784 sec
iter 6: 0.785 sec
iter 7: 0.785 sec
iter 8: 0.786 sec
iter 9: 0.785 sec
top4_c:
SEED: 1574030230334
Done warmup
iter 0: 1.317 sec ***
iter 1: 1.331 sec
iter 2: 1.330 sec
iter 3: 1.323 sec
iter 4: 1.323 sec
iter 5: 1.316 sec ***
iter 6: 1.319 sec
iter 7: 1.318 sec
iter 8: 1.317 sec
iter 9: 1.313 sec ***
I'm using Corretto JDK11: openjdk full version "11.0.4+11-LTS", with all defaults for the JVM.
Net/net I don't think this warrants inlining the PQ; this is likely a small part of the overall indexing time, and the Arrays.sort approach is nice and compact and understandable. I was just curious ;)
void decodeAndPrefixSum(DataInput in, long base, long[] longs) throws IOException {
final int bitsPerValue = Byte.toUnsignedInt(in.readByte());
if (bitsPerValue == 0) {
// Note: can this special-case be optimized?
Maybe hotspot is able to SIMD this loop?
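For context, bitsPerValue == 0 means every delta in the block is a single repeated value, so the decode collapses to filling an arithmetic progression. A sketch of what such a loop could look like (whether base itself is included, and how the repeated delta is obtained, are assumptions of mine, not the PR's exact code); a loop this simple is the kind HotSpot may be able to auto-vectorize:

```java
public class PrefixSumSketch {
  static final int BLOCK_SIZE = 128;

  // When bitsPerValue == 0 every delta in the block is the same value,
  // so the prefix sum is just an arithmetic progression.
  static void prefixSumOfConstant(long base, long delta, long[] longs) {
    for (int i = 0; i < BLOCK_SIZE; ++i) {
      longs[i] = base + (i + 1) * delta;
    }
  }

  public static void main(String[] args) {
    long[] longs = new long[BLOCK_SIZE];
    prefixSumOfConstant(100, 3, longs);
    if (longs[0] != 103 || longs[127] != 484) throw new AssertionError();
    System.out.println("ok");
  }
}
```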
@Override
public void readLELongs(long[] dst, int offset, int length) throws IOException {
if (curLongBufferViews == null) {
// Lazy init to not make pay for memory and initialization cost if you don't need to read arrays of longs
Maybe also explain that only certain index files (just postings today, but maybe doc values soon too) use the read/writeLongs APIs, justifying the laziness?
Actually I was thinking of using it to read doc IDs in BKDReader next. I agree doc values should probably look into it as well.
under the hood the CPU must still do unaligned long decodes (which I think modern X86-64 are good at)? Maybe add a comment about why this trick is worthwhile?
This is mostly a workaround for the lack of ByteBuffer#getLongs(long[]). Avoiding the LongBuffer view would require to call ByteBuffer#getLong in a loop, which seems to have some per-long overhead according to the microbenchmarks I ran. I'll leave a comment.
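A minimal illustration of the two call shapes being compared (the microbenchmark findings are the author's; this only shows the API difference between one bulk copy through a LongBuffer view and a per-long getLong loop):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class BulkVsLoopSketch {
  public static void main(String[] args) {
    ByteBuffer bytes = ByteBuffer.allocate(128 * Long.BYTES).order(ByteOrder.LITTLE_ENDIAN);
    for (int i = 0; i < 128; i++) bytes.putLong(i * Long.BYTES, i * 31L);

    // Workaround path: one bulk copy through a LongBuffer view,
    // standing in for the missing ByteBuffer#getLongs(long[]).
    long[] bulk = new long[128];
    bytes.asLongBuffer().get(bulk, 0, 128);

    // Alternative path: ByteBuffer#getLong in a loop, which the author's
    // microbenchmarks found to carry per-long overhead.
    long[] loop = new long[128];
    for (int i = 0; i < 128; i++) loop[i] = bytes.getLong(i * Long.BYTES);

    if (!Arrays.equals(bulk, loop)) throw new AssertionError();
    System.out.println("ok");
  }
}
```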
// readLELongs is only used for postings today, so we compute the long
// views lazily so that other data-structures don't have to pay for the
// associated initialization/memory overhead.
curLongBufferViews = new LongBuffer[Long.BYTES];
Thanks!
Thanks @mikemccand for taking the time to look at this large PR!
Here is a draft PR in case somebody would like to play with this. This is not ready to merge as the changes to DataInput might be controversial and we'd also need to