Skip to content

Commit

Permalink
Fix _all boosting.
Browse files Browse the repository at this point in the history
_all boosting used to rely on the fact that the TokenStream doesn't eagerly
consume the input java.io.Reader. This fixes the issue by using binary search
in order to find the right boost given a token's start offset.

Close #4315
  • Loading branch information
jpountz committed Dec 5, 2013
1 parent 869c80d commit 0ef6ed9
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 13 deletions.
38 changes: 33 additions & 5 deletions src/main/java/org/elasticsearch/common/lucene/all/AllEntries.java
Expand Up @@ -40,14 +40,20 @@ public class AllEntries extends Reader {
public static class Entry {
private final String name;
private final FastStringReader reader;
private final int startOffset;
private final float boost;

public Entry(String name, FastStringReader reader, float boost) {
/**
 * Creates an entry for one field value that is concatenated into the _all text.
 *
 * @param name        name of the field the text came from
 * @param reader      reader over the field's text
 * @param startOffset character offset of this entry's first character within the
 *                    concatenated _all text (entries are separated by a single space,
 *                    see {@code addText})
 * @param boost       index-time boost to apply to tokens produced from this entry
 */
public Entry(String name, FastStringReader reader, int startOffset, float boost) {
this.name = name;
this.reader = reader;
this.startOffset = startOffset;
this.boost = boost;
}

/** Returns the start offset of this entry within the concatenated _all text. */
public int startOffset() {
return startOffset;
}

/** Returns the name of the field this entry's text was read from. */
public String name() {
return this.name;
}
Expand Down Expand Up @@ -75,7 +81,15 @@ public void addText(String name, String text, float boost) {
if (boost != 1.0f) {
customBoost = true;
}
Entry entry = new Entry(name, new FastStringReader(text), boost);
final int lastStartOffset;
if (entries.isEmpty()) {
lastStartOffset = -1;
} else {
final Entry last = entries.get(entries.size() - 1);
lastStartOffset = last.startOffset() + last.reader().length();
}
final int startOffset = lastStartOffset + 1; // +1 because we insert a space between tokens
Entry entry = new Entry(name, new FastStringReader(text), startOffset, boost);
entries.add(entry);
}

Expand Down Expand Up @@ -129,8 +143,22 @@ public Set<String> fields() {
return fields;
}

public Entry current() {
return this.current;
// compute the boost for a token with the given startOffset
/**
 * Computes the boost for a token with the given start offset.
 * <p>
 * Entries are stored in ascending {@code startOffset} order, so this binary
 * search finds the LAST entry whose start offset is {@code <=} the token's
 * start offset — i.e. the entry the token was read from.
 */
public float boost(int startOffset) {
int lo = 0, hi = entries.size() - 1;
while (lo <= hi) {
final int mid = (lo + hi) >>> 1; // unsigned shift avoids int overflow on (lo + hi)
final int midOffset = entries.get(mid).startOffset();
if (startOffset < midOffset) {
hi = mid - 1;
} else {
lo = mid + 1;
}
}
// On exit 'hi' is the index of the last entry with startOffset <= startOffset,
// or -1 if the token's offset precedes every entry (shouldn't happen for a
// well-behaved tokenizer) — clamp to 0 in that case.
final int index = Math.max(0, hi); // protection against broken token streams
assert entries.get(index).startOffset() <= startOffset;
assert index == entries.size() - 1 || entries.get(index + 1).startOffset() > startOffset;
return entries.get(index).boost();
}

@Override
Expand Down Expand Up @@ -186,7 +214,7 @@ public int read(char[] cbuf, int off, int len) throws IOException {
@Override
/**
 * Releases the current entry. {@code FastStringReader} wraps an in-memory
 * string and holds no external resources, so the reader itself never needs
 * to be closed — only the reference is dropped. (A stale
 * {@code current.reader().close()} call that contradicted this was removed.)
 */
public void close() {
if (current != null) {
// no need to close, these are readers on strings
current = null;
}
}
Expand Down
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;

Expand All @@ -42,11 +43,13 @@ public static TokenStream allTokenStream(String allFieldName, AllEntries allEntr

private final AllEntries allEntries;

private final OffsetAttribute offsetAttribute;
private final PayloadAttribute payloadAttribute;

/**
 * Wraps {@code input} and attaches, to each emitted token, a payload encoding
 * the boost of the _all entry the token originated from. The entry is located
 * via the token's start offset (see {@code AllEntries#boost(int)}).
 *
 * @param input      token stream produced by the analyzer over the _all text
 * @param allEntries the entries (with per-field boosts) backing the _all text
 */
AllTokenStream(TokenStream input, AllEntries allEntries) {
super(input);
this.allEntries = allEntries;
offsetAttribute = addAttribute(OffsetAttribute.class);
payloadAttribute = addAttribute(PayloadAttribute.class);
}

Expand All @@ -59,14 +62,12 @@ public final boolean incrementToken() throws IOException {
if (!input.incrementToken()) {
return false;
}
if (allEntries.current() != null) {
float boost = allEntries.current().boost();
if (boost != 1.0f) {
encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
payloadAttribute.setPayload(payloadSpare);
} else {
payloadAttribute.setPayload(null);
}
final float boost = allEntries.boost(offsetAttribute.startOffset());
if (boost != 1.0f) {
encodeFloat(boost, payloadSpare.bytes, payloadSpare.offset);
payloadAttribute.setPayload(payloadSpare);
} else {
payloadAttribute.setPayload(null);
}
return true;
}
Expand Down
Expand Up @@ -19,6 +19,11 @@

package org.elasticsearch.common.lucene.all;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.payloads.PayloadHelper;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StoredField;
Expand All @@ -27,6 +32,7 @@
import org.apache.lucene.search.*;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.Test;
Expand All @@ -40,6 +46,51 @@
*/
public class SimpleAllTests extends ElasticsearchTestCase {

@Test
public void testBoostOnEagerTokenizer() throws Exception {
AllEntries allEntries = new AllEntries();
allEntries.addText("field1", "all", 2.0f);
allEntries.addText("field2", "your", 1.0f);
allEntries.addText("field1", "boosts", 0.5f);
allEntries.reset();
// The whitespace analyzer's tokenizer consumes its Reader eagerly, unlike the
// standard tokenizer, which is exactly the case the offset-based boost lookup fixes.
final TokenStream ts = AllTokenStream.allTokenStream("any", allEntries, new WhitespaceAnalyzer(Lucene.VERSION));
final CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
final PayloadAttribute payloadAtt = ts.addAttribute(PayloadAttribute.class);
// Expected term/boost pairs, in emission order.
final String[] expectedTerms = new String[] { "all", "your", "boosts" };
final float[] expectedBoosts = new float[] { 2.0f, 1.0f, 0.5f };
ts.reset();
for (int i = 0; i < expectedTerms.length; ++i) {
assertTrue(ts.incrementToken());
assertEquals(expectedTerms[i], termAtt.toString());
final BytesRef payload = payloadAtt.getPayload();
if (payload == null || payload.length == 0) {
// absence of a payload encodes the default boost of 1
assertEquals(expectedBoosts[i], 1f, 0.001f);
} else {
assertEquals(4, payload.length);
final float decoded = PayloadHelper.decodeFloat(payload.bytes, payload.offset);
assertEquals(expectedBoosts[i], decoded, 0.001f);
}
}
assertFalse(ts.incrementToken());
}

@Test
public void testAllEntriesRead() throws Exception {
AllEntries allEntries = new AllEntries();
Expand Down

0 comments on commit 0ef6ed9

Please sign in to comment.