Hashing query performance improvements (1.5-2x faster on benchmarks) (#114)

Rewrote the custom hashing query (`MatchHashesAndScoreQuery`) in Java.
This doesn't necessarily make it faster; rather, it makes it less likely to accidentally introduce an expensive Scala abstraction.
It also makes it easier to get help from Lucene users.

Made match counting faster by using an array instead of a map.
This works because each counter only deals with the consecutive doc ids in a single segment.
So instead of a map from doc id to count, you have an array where the index is the doc id and the value is the count.
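
A minimal sketch of the idea (hypothetical names; the real logic lives in `countMatches` in the new query class below):

// Before: a map from doc id to match count, e.g.
// counts.merge(docId, (short) 1, (a, b) -> (short) (a + b));
// After: doc ids within a segment are consecutive ints starting at 0,
// so an array indexed by doc id does the same job with less overhead.
short[] counts = new short[numDocsInSegment];
counts[docId] += 1;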

Made candidate identification faster using a similar construct.
Since you know the highest possible count is the number of terms, you can use an array to build a histogram of the counts,
then traverse from the end of the array to find the kth largest count.
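
For example, here is a worked case (not from the PR) using the `kthGreatest` method added to `ArrayUtils` below:

// counts = {2, 5, 3, 5, 4}, k = 2.
// Histogram over [min=2, max=5]: {1, 1, 1, 2}.
// Traversing from the top: 5 accounts for 2 values (not > k);
// 4 brings the running total to 3 (> k), so the kth greatest is 4.
short kth = ArrayUtils.kthGreatest(new short[]{2, 5, 3, 5, 4}, 2); // == 4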

Specific timing improvements (p90 benchmark times):
- Angular LSH: 121ms -> 50ms
- L2 LSH: 18ms -> 11ms
- Jaccard LSH: 58ms -> 36ms

Still need to understand how `PrefixCodedTerms` works and whether there's any possible optimization there.
alexklibisz authored Jul 24, 2020
1 parent 30c9e87 commit c75b23f
Showing 18 changed files with 501 additions and 271 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
@@ -85,8 +85,11 @@ jobs:
run: make test/gradle
- name: Test Python
run: make test/python
- name: Docker logs
if: ${{ failure() }}
run: cd testing && docker-compose logs

# Site and Docs
- name: Compile Site and Docs
run: |
gem install bundler
2 changes: 2 additions & 0 deletions changelog.md
@@ -1,3 +1,5 @@
- Performance improvements for LSH queries. 1.5-2x faster on regular benchmarks with randomized data. See PR #114.
---
- Fixed error with KNN queries against vectors that are stored in nested fields, e.g. `outer.inner.vec`.
---
- Switched LSH parameter names to more canonical equivalents: `bands -> L`, `rows -> k`,
43 changes: 43 additions & 0 deletions core/src/main/java/com/klibisz/elastiknn/storage/BitBuffer.java
@@ -0,0 +1,43 @@
package com.klibisz.elastiknn.storage;

public interface BitBuffer {
void putOne();
void putZero();
byte[] toByteArray();

class IntBuffer implements BitBuffer {

private final byte[] prefix;
private int i = 0;
private int b = 0;

public IntBuffer(byte[] prefix) {
this.prefix = prefix;
}

public IntBuffer() {
this.prefix = new byte[0];
}

@Override
public void putOne() {
this.b += (1 << this.i);
this.i += 1;
}

@Override
public void putZero() {
this.i += 1;
}

@Override
public byte[] toByteArray() {
byte[] barr = UnsafeSerialization.writeInt(b);
byte[] res = new byte[prefix.length + barr.length];
System.arraycopy(prefix, 0, res, 0, prefix.length);
System.arraycopy(barr, 0, res, prefix.length, barr.length);
return res;
}
}

}
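
A quick usage sketch (assuming only what's in this file plus `UnsafeSerialization.writeInt`); bits are accumulated little-endian into a single int:

// putOne(); putZero(); putOne(); yields b = 1*2^0 + 0*2^1 + 1*2^2 = 5,
// which toByteArray() then serializes via UnsafeSerialization.writeInt.
BitBuffer buf = new BitBuffer.IntBuffer();
buf.putOne();
buf.putZero();
buf.putOne();
byte[] bytes = buf.toByteArray();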
90 changes: 37 additions & 53 deletions core/src/main/java/com/klibisz/elastiknn/utils/ArrayUtils.java
@@ -36,64 +36,48 @@ public static int sortedIntersectionCount(final int [] xs, final int [] ys) {
}

Removed: quickSelect, quickSelectCopy, qsPartition, and qsSwap (a quickselect implementation based on https://github.com/bephrem1/backtobackswe and https://www.youtube.com/watch?v=hGK_5n81drs; Lucene also has an implementation, https://lucene.apache.org/core/8_0_0/core/org/apache/lucene/util/IntroSelector.html, but it's more abstract and was slower when benchmarked). Added in their place:

/**
 * Find the kth greatest value in the given array of shorts in O(N) time and space.
 * Works by creating a histogram of the array values and traversing the histogram in reverse order.
 * Assumes the max value in the array is small enough that you can keep an array of that length in memory.
 * This is generally true for term counts.
 *
 * @param arr array of non-negative shorts, presumably some type of count.
 * @param k the desired largest value.
 * @return the kth largest value.
 */
public static short kthGreatest(short[] arr, int k) {
    if (arr.length == 0) {
        throw new IllegalArgumentException("Array must be non-empty");
    } else if (k < 0 || k >= arr.length) {
        throw new IllegalArgumentException(String.format(
            "k [%d] must be >= 0 and less than length of array [%d]",
            k, arr.length
        ));
    } else {
        // Find the min and max values.
        short max = arr[0];
        short min = arr[0];
        for (short a: arr) {
            if (a > max) max = a;
            else if (a < min) min = a;
        }

        // Build and populate a histogram for non-zero values.
        int[] hist = new int[max - min + 1];
        for (short a: arr) {
            hist[a - min] += 1;
        }

        // Find the kth largest value by iterating from the end of the histogram.
        int geqk = 0;
        short kthLargest = max;
        while (kthLargest >= min) {
            geqk += hist[kthLargest - min];
            if (geqk > k) break;
            else kthLargest--;
        }
        return kthLargest;
    }
}
}
29 changes: 0 additions & 29 deletions core/src/main/scala/com/klibisz/elastiknn/storage/BitBuffer.scala

This file was deleted.

2 changes: 2 additions & 0 deletions gradle.properties
@@ -2,4 +2,6 @@ pluginName=elastiknn
group=com.klibisz.elastiknn

scalaVersion=2.12
esVersion=7.6.2
luceneVersion=8.4.0
circeVersion=0.13.0
1 change: 0 additions & 1 deletion plugin/build.gradle
@@ -36,7 +36,6 @@ dependencies {
implementation "org.scala-lang:scala-library:${scalaVersion}"
implementation "com.google.guava:guava:28.1-jre"
runtime "com.google.guava:guava:28.1-jre"
implementation 'com.carrotsearch:hppc:0.8.2'
}

esplugin {
4 changes: 0 additions & 4 deletions plugin/src/main/java/com/klibisz/elastiknn/Empty.java

This file was deleted.

180 changes: 180 additions & 0 deletions org/apache/lucene/search/MatchHashesAndScoreQuery.java
@@ -0,0 +1,180 @@
package org.apache.lucene.search;

import com.klibisz.elastiknn.utils.ArrayUtils;
import org.apache.lucene.index.*;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

public class MatchHashesAndScoreQuery extends Query {

public interface ScoreFunction {
double score(int docId, int numMatchingHashes);
}

private final String field;
private final BytesRef[] hashes;
private final int candidates;
private final IndexReader indexReader;
private final Function<LeafReaderContext, ScoreFunction> scoreFunctionBuilder;
private final PrefixCodedTerms prefixCodedTerms;
private final int numDocsInSegment;

private static PrefixCodedTerms makePrefixCodedTerms(String field, BytesRef[] hashes) {
// PrefixCodedTerms.Builder expects the hashes in sorted order.
ArrayUtil.timSort(hashes);
PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder();
for (BytesRef br : hashes) builder.add(field, br);
return builder.finish();
}

public MatchHashesAndScoreQuery(final String field,
final BytesRef[] hashes,
final int candidates,
final IndexReader indexReader,
final Function<LeafReaderContext, ScoreFunction> scoreFunctionBuilder) {
this.field = field;
this.hashes = hashes;
this.candidates = candidates;
this.indexReader = indexReader;
this.scoreFunctionBuilder = scoreFunctionBuilder;
this.prefixCodedTerms = makePrefixCodedTerms(field, hashes);
this.numDocsInSegment = indexReader.numDocs();
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {

return new Weight(this) {

private short[] countMatches(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();
Terms terms = reader.terms(field);
TermsEnum termsEnum = terms.iterator();
PrefixCodedTerms.TermIterator iterator = prefixCodedTerms.iterator();
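// Doc ids within a segment are consecutive ints starting at 0, so the
// counts array is indexed directly by doc id (no map needed).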
short[] counts = new short[numDocsInSegment];
PostingsEnum docs = null;
BytesRef term = iterator.next();
while (term != null) {
if (termsEnum.seekExact(term)) {
docs = termsEnum.postings(docs, PostingsEnum.NONE);
for (int i = 0; i < docs.cost(); i++) {
int docId = docs.nextDoc();
counts[docId] += 1;
}
}
term = iterator.next();
}
return counts;
}

private DocIdSetIterator buildDocIdSetIterator(short[] counts) {
if (candidates >= numDocsInSegment) return DocIdSetIterator.all(indexReader.maxDoc());
else {
int minCandidateCount = ArrayUtils.kthGreatest(counts, candidates);
// DocIdSetIterator that iterates over the doc ids but only emits ids whose counts are >= the min candidate count.
return new DocIdSetIterator() {

private int doc = 0;

@Override
public int docID() {
return doc;
}

@Override
public int nextDoc() {
// Increment doc until its count meets or exceeds the min candidate count.
do doc++;
while (doc < counts.length && counts[doc] < minCandidateCount);
if (doc == counts.length) return DocIdSetIterator.NO_MORE_DOCS;
else return docID();
}

@Override
public int advance(int target) {
while (doc < target) nextDoc();
return docID();
}

@Override
public long cost() {
return counts.length;
}
};
}
}

@Override
public void extractTerms(Set<Term> terms) { }

@Override
public Explanation explain(LeafReaderContext context, int doc) {
return Explanation.match(0, "If someone knows what this should return, please submit a PR. :)");
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
ScoreFunction scoreFunction = scoreFunctionBuilder.apply(context);
short[] counts = countMatches(context);
DocIdSetIterator disi = buildDocIdSetIterator(counts);

return new Scorer(this) {
@Override
public DocIdSetIterator iterator() {
return disi;
}

@Override
public float getMaxScore(int upTo) {
return Float.MAX_VALUE;
}

@Override
public float score() {
return (float) scoreFunction.score(docID(), counts[docID()]);
}

@Override
public int docID() {
return disi.docID();
}
};
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
};
}

@Override
public String toString(String field) {
return String.format(
"%s for field [%s] with [%d] hashes and [%d] candidates",
this.getClass().getSimpleName(),
this.field,
this.hashes.length,
this.candidates);
}

@Override
public boolean equals(Object obj) {
if (obj instanceof MatchHashesAndScoreQuery) {
MatchHashesAndScoreQuery q = (MatchHashesAndScoreQuery) obj;
return q.hashCode() == this.hashCode();
} else {
return false;
}
}

@Override
public int hashCode() {
return Objects.hash(field, hashes, candidates, indexReader, scoreFunctionBuilder);
}
}
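
For context, a hypothetical usage sketch (not part of this commit; the field name, candidate count, and score function are placeholders, and `hashes` and `searcher` are assumed to exist):

// Score each candidate by its raw match count; a real score function would
// typically recompute the similarity from the stored vector.
Function<LeafReaderContext, MatchHashesAndScoreQuery.ScoreFunction> builder =
    ctx -> (docId, numMatchingHashes) -> (double) numMatchingHashes;
Query query = new MatchHashesAndScoreQuery(
    "vec_hashes",               // hypothetical field holding the hash terms
    hashes,                     // BytesRef[] of hashes for the query vector
    200,                        // number of candidates to fully score
    searcher.getIndexReader(),
    builder);
TopDocs topDocs = searcher.search(query, 10);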