Skip to content

Commit

Permalink
Add highlighting support for suggester.
Browse files Browse the repository at this point in the history
This commit adds general highlighting support to the suggest feature.
At this point, the phrase suggester is the only suggester that
implements it.
The API supports a 'pre_tag' and a 'post_tag' that are used to wrap
the parts of the user's input that the suggester changed.

Closes #3442
  • Loading branch information
nik9000 authored and s1monw committed Aug 6, 2013
1 parent 7511676 commit 2b76ac8
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 45 deletions.
21 changes: 20 additions & 1 deletion src/main/java/org/elasticsearch/search/suggest/Suggest.java
Expand Up @@ -496,18 +496,25 @@ public static class Option implements Streamable, ToXContent {
/**
* Field names used when an option is rendered as XContent; they are consumed by
* innerToXContent.
*/
static class Fields {

static final XContentBuilderString TEXT = new XContentBuilderString("text");
// Only written to the output when the option actually carries highlighted text.
static final XContentBuilderString HIGHLIGHTED = new XContentBuilderString("highlighted");
static final XContentBuilderString SCORE = new XContentBuilderString("score");

}

private Text text;
private Text highlighted;
private float score;

public Option(Text text, float score) {
public Option(Text text, Text highlighted, float score) {
this.text = text;
this.highlighted = highlighted;
this.score = score;
}

/**
* Creates an option that has no highlighted text. Delegates to the three-argument
* constructor, passing {@code null} as the highlighted value.
*
* @param text the suggested text
* @param score the score of this option
*/
public Option(Text text, float score) {
this(text, null, score);
}

public Option() {
}

Expand All @@ -518,6 +525,13 @@ public Text getText() {
return text;
}

/**
* Returns the highlighted form of the suggested text.
*
* @return Copy of suggested text with changes from user supplied text highlighted, or
* {@code null} if this option was created without a highlighted value (for example
* through the two-argument constructor, which passes {@code null}).
*/
public Text getHighlighted() {
return highlighted;
}

/**
* @return The score based on the edit distance difference between the suggested term and the
* term in the suggest text.
Expand All @@ -533,12 +547,14 @@ protected void setScore(float score) {
/**
* Populates this option from the stream. Fields are read in the order text, highlighted,
* score, which is the order in which writeTo emits them.
*
* @param in the stream to read from
* @throws IOException if reading from the stream fails
*/
@Override
public void readFrom(StreamInput in) throws IOException {
text = in.readText();
// highlighted can legitimately be null on the writing side, so it is read with the
// optional variant (presumably yielding null when nothing was written -- confirm
// against StreamInput).
highlighted = in.readOptionalText();
score = in.readFloat();
}

/**
* Serializes this option to the stream. Fields are written in the order text, highlighted,
* score, which is the order in which readFrom consumes them.
*
* @param out the stream to write to
* @throws IOException if writing to the stream fails
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeText(text);
// highlighted may be null (e.g. the two-argument constructor passes null), hence the
// optional write.
out.writeOptionalText(highlighted);
out.writeFloat(score);
}

Expand All @@ -552,6 +568,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

/**
 * Writes this option's fields into {@code builder}: the text, then the highlighted text
 * when one is present, then the score.
 *
 * @param builder the builder to write the fields into
 * @param params  rendering parameters (not consulted by this method)
 * @return the same {@code builder} instance that was passed in
 * @throws IOException if writing a field fails
 */
protected XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
    builder.field(Fields.TEXT, this.text);
    // The highlighted field is omitted entirely when no highlighted text is available.
    final boolean hasHighlight = this.highlighted != null;
    if (hasHighlight) {
        builder.field(Fields.HIGHLIGHTED, this.highlighted);
    }
    builder.field(Fields.SCORE, this.score);
    return builder;
}
Expand Down
Expand Up @@ -32,14 +32,18 @@ public abstract class CandidateGenerator {
public abstract long frequency(BytesRef term) throws IOException;

public CandidateSet drawCandidates(BytesRef term) throws IOException {
CandidateSet set = new CandidateSet(Candidate.EMPTY, createCandidate(term));
CandidateSet set = new CandidateSet(Candidate.EMPTY, createCandidate(term, true));
return drawCandidates(set);
}

public Candidate createCandidate(BytesRef term) throws IOException {
return createCandidate(term, frequency(term), 1.0);
public Candidate createCandidate(BytesRef term, boolean userInput) throws IOException {
return createCandidate(term, frequency(term), 1.0, userInput);
}
public abstract Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException;
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return createCandidate(term, frequency, channelScore, false);
}

public abstract Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException;

public abstract CandidateSet drawCandidates(CandidateSet set) throws IOException;

Expand Down
Expand Up @@ -18,11 +18,11 @@
*/
package org.elasticsearch.search.suggest.phrase;

import java.util.Arrays;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.search.suggest.SuggestUtils;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidate;

import java.util.Arrays;
//TODO public for tests
public final class Correction {

Expand All @@ -41,15 +41,33 @@ public String toString() {
}

public BytesRef join(BytesRef separator) {
return join(separator, new BytesRef());
return join(separator, null, null);
}

/**
* Joins this correction's candidate terms into a newly allocated {@link BytesRef},
* delegating to the four-argument overload.
*
* @param separator bytes handed to the four-argument overload as the term separator
* @param preTag tag inserted before a run of changed (non user-input) terms, or
* {@code null} to join the terms without any highlighting
* @param postTag tag inserted after a run of changed (non user-input) terms
* @return the joined bytes
*/
// NOTE(review): the four-argument overload reads postTag.length whenever preTag is
// non-null and a candidate is not user input, so postTag must not be null in that case.
public BytesRef join(BytesRef separator, BytesRef preTag, BytesRef postTag) {
return join(separator, new BytesRef(), preTag, postTag);
}

public BytesRef join(BytesRef separator, BytesRef result) {
public BytesRef join(BytesRef separator, BytesRef result, BytesRef preTag, BytesRef postTag) {
BytesRef[] toJoin = new BytesRef[this.candidates.length];
int len = separator.length * this.candidates.length - 1;
for (int i = 0; i < toJoin.length; i++) {
toJoin[i] = candidates[i].term;
len += toJoin[i].length;
Candidate candidate = candidates[i];
if (preTag == null || candidate.userInput) {
toJoin[i] = candidate.term;
} else {
final int maxLen = preTag.length + postTag.length + candidate.term.length;
final BytesRef highlighted = new BytesRef(maxLen);// just allocate once
if (i == 0 || candidates[i-1].userInput) {
highlighted.append(preTag);
}
highlighted.append(candidate.term);
if (toJoin.length == i + 1 || candidates[i+1].userInput) {
highlighted.append(postTag);
}
toJoin[i] = highlighted;
}
len += toJoin[i].length;
}
result.offset = 0;
result.grow(len);
Expand Down
Expand Up @@ -126,7 +126,7 @@ public CandidateSet drawCandidates(CandidateSet set) throws IOException {
for (int i = 0; i < suggestSimilar.length; i++) {
SuggestWord suggestWord = suggestSimilar[i];
BytesRef candidate = new BytesRef(suggestWord.string);
postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score, score(suggestWord.freq, suggestWord.score, dictSize)), spare, byteSpare, candidates);
postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score, score(suggestWord.freq, suggestWord.score, dictSize), false), spare, byteSpare, candidates);
}
set.addCandidates(candidates);
return set;
Expand Down Expand Up @@ -160,9 +160,9 @@ public void nextToken() throws IOException {
if (posIncAttr.getPositionIncrement() > 0 && result.bytesEquals(candidate.term)) {
BytesRef term = BytesRef.deepCopyOf(result);
long freq = frequency(term);
candidates.add(new Candidate(BytesRef.deepCopyOf(term), freq, candidate.stringDistance, score(candidate.frequency, candidate.stringDistance, dictSize)));
candidates.add(new Candidate(BytesRef.deepCopyOf(term), freq, candidate.stringDistance, score(candidate.frequency, candidate.stringDistance, dictSize), false));
} else {
candidates.add(new Candidate(BytesRef.deepCopyOf(result), candidate.frequency, nonErrorLikelihood, score(candidate.frequency, candidate.stringDistance, dictSize)));
candidates.add(new Candidate(BytesRef.deepCopyOf(result), candidate.frequency, nonErrorLikelihood, score(candidate.frequency, candidate.stringDistance, dictSize), false));
}
}
}, spare);
Expand Down Expand Up @@ -213,17 +213,20 @@ public static class Candidate {
public final double stringDistance;
public final long frequency;
public final double score;
public final boolean userInput;

public Candidate(BytesRef term, long frequency, double stringDistance, double score) {
public Candidate(BytesRef term, long frequency, double stringDistance, double score, boolean userInput) {
this.frequency = frequency;
this.term = term;
this.stringDistance = stringDistance;
this.score = score;
this.userInput = userInput;
}

@Override
public String toString() {
return "Candidate [term=" + term.utf8ToString() + ", stringDistance=" + stringDistance + ", frequency=" + frequency + "]";
return "Candidate [term=" + term.utf8ToString() + ", stringDistance=" + stringDistance + ", frequency=" + frequency +
(userInput ? ", userInput" : "" ) + "]";
}

@Override
Expand Down Expand Up @@ -253,8 +256,8 @@ public boolean equals(Object obj) {
}

@Override
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize));
public Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException {
return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize), userInput);
}

}
Expand Up @@ -72,8 +72,8 @@ public int compare(Candidate left, Candidate right) {
return set;
}
@Override
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return candidateGenerator[0].createCandidate(term, frequency, channelScore);
public Candidate createCandidate(BytesRef term, long frequency, double channelScore, boolean userInput) throws IOException {
return candidateGenerator[0].createCandidate(term, frequency, channelScore, userInput);
}

}
Expand Up @@ -93,7 +93,7 @@ public void nextToken() throws IOException {
if (currentSet != null) {
candidateSetsList.add(currentSet);
}
currentSet = new CandidateSet(Candidate.EMPTY, generator.createCandidate(BytesRef.deepCopyOf(term)));
currentSet = new CandidateSet(Candidate.EMPTY, generator.createCandidate(BytesRef.deepCopyOf(term), true));
}
}

Expand Down
Expand Up @@ -101,8 +101,27 @@ public SuggestionSearchContext.SuggestionContext parse(XContentParser parser, Ma
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support array field [" + fieldName + "]");
}
} else if (token == Token.START_OBJECT && "smoothing".equals(fieldName)) {
parseSmoothingModel(parser, suggestion, fieldName);
} else if (token == Token.START_OBJECT) {
if ("smoothing".equals(fieldName)) {
parseSmoothingModel(parser, suggestion, fieldName);
} else if ("highlight".equals(fieldName)) {
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
fieldName = parser.currentName();
} else if (token.isValue()) {
if ("pre_tag".equals(fieldName) || "preTag".equals(fieldName)) {
suggestion.setPreTag(parser.bytes());
} else if ("post_tag".equals(fieldName) || "postTag".equals(fieldName)) {
suggestion.setPostTag(parser.bytes());
} else {
throw new ElasticSearchIllegalArgumentException(
"suggester[phrase][highlight] doesn't support field [" + fieldName + "]");
}
}
}
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support array field [" + fieldName + "]");
}
} else {
throw new ElasticSearchIllegalArgumentException("suggester[phrase] doesn't support field [" + fieldName + "]");
}
Expand Down
Expand Up @@ -73,9 +73,14 @@ public Suggestion<? extends Entry<? extends Option>> execute(String name, Phrase
Suggestion.Entry<Option> resultEntry = new Suggestion.Entry<Option>(new StringText(spare.toString()), 0, spare.length);
BytesRef byteSpare = new BytesRef();
for (Correction correction : corrections) {
UnicodeUtil.UTF8toUTF16(correction.join(SEPARATOR, byteSpare), spare);
UnicodeUtil.UTF8toUTF16(correction.join(SEPARATOR, byteSpare, null, null), spare);
Text phrase = new StringText(spare.toString());
resultEntry.addOption(new Suggestion.Entry.Option(phrase, (float) (correction.score)));
Text highlighted = null;
if (suggestion.getPreTag() != null) {
UnicodeUtil.UTF8toUTF16(correction.join(SEPARATOR, byteSpare, suggestion.getPreTag(), suggestion.getPostTag()), spare);
highlighted = new StringText(spare.toString());
}
resultEntry.addOption(new Suggestion.Entry.Option(phrase, highlighted, (float) (correction.score)));
}
final Suggestion<Entry<Option>> response = new Suggestion<Entry<Option>>(name, suggestion.getSize());
response.addTerm(resultEntry);
Expand Down
Expand Up @@ -44,6 +44,8 @@ public final class PhraseSuggestionBuilder extends SuggestionBuilder<PhraseSugge
private SmoothingModel model;
private Boolean forceUnigrams;
private Integer tokenLimit;
private String preTag;
private String postTag;

public PhraseSuggestionBuilder(String name) {
super(name, "phrase");
Expand Down Expand Up @@ -147,6 +149,19 @@ public PhraseSuggestionBuilder tokenLimit(int tokenLimit) {
return this;
}

/**
 * Enables highlighting for the suggestions produced by this builder. Once called, the
 * request carries a {@code highlight} object so that suggestions come back with a
 * highlighted variant in which changed tokens are wrapped in the given tags.
 *
 * @param preTag  tag placed in front of changed tokens; must not be {@code null}
 * @param postTag tag placed behind changed tokens; must not be {@code null}
 * @return this builder, for chaining
 * @throws ElasticSearchIllegalArgumentException if either tag is {@code null}
 */
public PhraseSuggestionBuilder highlight(String preTag, String postTag) {
    final boolean bothTagsGiven = preTag != null && postTag != null;
    if (!bothTagsGiven) {
        throw new ElasticSearchIllegalArgumentException("Pre and post tag must not be null.");
    }
    this.postTag = postTag;
    this.preTag = preTag;
    return this;
}

@Override
public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
if (realWordErrorLikelihood != null) {
Expand Down Expand Up @@ -185,6 +200,12 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t
model.toXContent(builder, params);
builder.endObject();
}
if (preTag != null) {
builder.startObject("highlight");
builder.field("pre_tag", preTag);
builder.field("post_tag", postTag);
builder.endObject();
}
return builder;
}

Expand Down
Expand Up @@ -38,6 +38,8 @@ class PhraseSuggestionContext extends SuggestionContext {
private int gramSize = 1;
private float confidence = 1.0f;
private int tokenLimit = NoisyChannelSpellChecker.DEFAULT_TOKEN_LIMIT;
private BytesRef preTag;
private BytesRef postTag;

private WordScorer.WordScorerFactory scorer;

Expand Down Expand Up @@ -162,4 +164,20 @@ public void setTokenLimit(int tokenLimit) {
public int getTokenLimit() {
return tokenLimit;
}

/**
* Sets the tag to insert before highlighted parts of a suggestion.
*
* @param preTag the opening tag bytes
*/
public void setPreTag(BytesRef preTag) {
this.preTag = preTag;
}

/**
* Returns the tag to insert before highlighted parts of a suggestion.
*
* @return the opening tag bytes, or {@code null} if none has been set (the field has no
* initializer)
*/
public BytesRef getPreTag() {
return preTag;
}

/**
* Sets the tag to insert after highlighted parts of a suggestion.
*
* @param postTag the closing tag bytes
*/
public void setPostTag(BytesRef postTag) {
this.postTag = postTag;
}

/**
* Returns the tag to insert after highlighted parts of a suggestion.
*
* @return the closing tag bytes, or {@code null} if none has been set (the field has no
* initializer)
*/
public BytesRef getPostTag() {
return postTag;
}
}
Expand Up @@ -50,7 +50,6 @@
import static org.elasticsearch.search.suggest.SuggestBuilder.phraseSuggestion;
import static org.elasticsearch.search.suggest.SuggestBuilder.termSuggestion;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSuggestionSize;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;

/**
Expand Down Expand Up @@ -592,6 +591,24 @@ public void testMarvelHerosPhraseSuggest() throws ElasticSearchException, IOExce
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().get(0).getText().string(), equalTo("Xor the Got-Jewel"));
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().get(0).getOptions().get(0).getText().string(), equalTo("xorr the god jewel"));

// Ask for highlighting
searchSuggest = searchSuggest(client(), "Xor the Got-Jewel",
phraseSuggestion("simple_phrase").
realWordErrorLikelihood(0.95f).field("bigram").gramSize(2).analyzer("body")
.addCandidateGenerator(PhraseSuggestionBuilder.candidateGenerator("body").minWordLength(1).suggestMode("always"))
.maxErrors(0.5f)
.size(1)
.highlight("<em>", "</em>"));

assertThat(searchSuggest, notNullValue());
assertThat(searchSuggest.size(), equalTo(1));
assertThat(searchSuggest.getSuggestion("simple_phrase").getName(), equalTo("simple_phrase"));
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().size(), equalTo(1));
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().get(0).getOptions().size(), equalTo(1));
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().get(0).getText().string(), equalTo("Xor the Got-Jewel"));
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().get(0).getOptions().get(0).getText().string(), equalTo("xorr the god jewel"));
assertThat(searchSuggest.getSuggestion("simple_phrase").getEntries().get(0).getOptions().get(0).getHighlighted().string(), equalTo("<em>xorr</em> the <em>god</em> jewel"));


// pass in a correct phrase
searchSuggest = searchSuggest(client(), "Xorr the God-Jewel",
Expand Down

0 comments on commit 2b76ac8

Please sign in to comment.