Skip to content

Commit

Permalink
Added support for sort_mode avg for sorting by geo_distance.
Browse files Browse the repository at this point in the history
Closes #2962
  • Loading branch information
martijnvg committed May 1, 2013
1 parent c21ab1a commit 0d3b787
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
Expand Up @@ -176,6 +176,7 @@ public double computeDistance(int doc) {

GeoPoint point = iter.next();
double distance = fixedSourceDistance.calculate(point.lat(), point.lon());
int counter = 1;
while (iter.hasNext()) {
point = iter.next();
double newDistance = fixedSourceDistance.calculate(point.lat(), point.lon());
Expand All @@ -190,9 +191,18 @@ public double computeDistance(int doc) {
distance = newDistance;
}
break;
case AVG:
distance += newDistance;
counter++;
break;
}
}
return distance;

if (sortMode == SortMode.AVG && counter > 1) {
return distance / counter;
} else {
return distance;
}
}

}
Expand Down
Expand Up @@ -115,6 +115,10 @@ public SortField parse(XContentParser parser, SearchContext context) throws Exce
sortMode = reverse ? SortMode.MAX : SortMode.MIN;
}

if (sortMode == SortMode.SUM) {
throw new ElasticSearchIllegalArgumentException("sort_mode [sum] isn't supported for sorting by geo distance");
}

FieldMapper mapper = context.smartNameFieldMapper(fieldName);
if (mapper == null) {
throw new ElasticSearchIllegalArgumentException("failed to find mapper for [" + fieldName + "] for geo distance based sort");
Expand Down
Expand Up @@ -19,12 +19,14 @@

package org.elasticsearch.test.integration.search.geo;

import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.geo.GeoDistance;
import org.elasticsearch.common.unit.DistanceUnit;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
Expand All @@ -41,6 +43,7 @@
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.testng.AssertJUnit.fail;

/**
*/
Expand Down Expand Up @@ -358,6 +361,45 @@ public void testDistanceSortingMVFields() throws Exception {
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(0.4621d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), equalTo(0d));

searchResponse = client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).sortMode("avg").order(SortOrder.ASC))
.execute().actionGet();

assertThat(searchResponse.getHits().getTotalHits(), equalTo(4l));
assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(0).sortValues()[0]).doubleValue(), equalTo(0d));
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("3"));
assertThat(((Number) searchResponse.getHits().getAt(1).sortValues()[0]).doubleValue(), closeTo(1.157d, 0.01d));
assertThat(searchResponse.getHits().getAt(2).id(), equalTo("2"));
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(2.874d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("4"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), closeTo(5.301d, 0.01d));

searchResponse = client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).sortMode("avg").order(SortOrder.DESC))
.execute().actionGet();

assertThat(searchResponse.getHits().getTotalHits(), equalTo(4l));
assertThat(searchResponse.getHits().hits().length, equalTo(4));
assertThat(searchResponse.getHits().getAt(0).id(), equalTo("4"));
assertThat(((Number) searchResponse.getHits().getAt(0).sortValues()[0]).doubleValue(), closeTo(5.301d, 0.01d));
assertThat(searchResponse.getHits().getAt(1).id(), equalTo("2"));
assertThat(((Number) searchResponse.getHits().getAt(1).sortValues()[0]).doubleValue(), closeTo(2.874d, 0.01d));
assertThat(searchResponse.getHits().getAt(2).id(), equalTo("3"));
assertThat(((Number) searchResponse.getHits().getAt(2).sortValues()[0]).doubleValue(), closeTo(1.157d, 0.01d));
assertThat(searchResponse.getHits().getAt(3).id(), equalTo("1"));
assertThat(((Number) searchResponse.getHits().getAt(3).sortValues()[0]).doubleValue(), equalTo(0d));

try {
client.prepareSearch("test").setQuery(matchAllQuery())
.addSort(SortBuilders.geoDistanceSort("locations").point(40.7143528, -74.0059731).sortMode("sum"))
.execute().actionGet();
fail("Expected error");
} catch (SearchPhaseExecutionException e) {
assertThat(e.shardFailures()[0].status(), equalTo(RestStatus.BAD_REQUEST));
}
}

@Test
Expand Down

0 comments on commit 0d3b787

Please sign in to comment.