Skip to content

Commit

Permalink
only reduce_percentiles
Browse files Browse the repository at this point in the history
tests pass 500 iters
  • Loading branch information
andyb-elastic committed Sep 28, 2018
1 parent 2a7395e commit 88d5189
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 114 deletions.
Expand Up @@ -38,57 +38,42 @@
public class InternalMAD extends InternalNumericMetricsAggregation.SingleValue implements MedianAbsoluteDeviation {

private final TDigestState valueSketch;
private final TDigestState deviationSketch;
private final String method;

public InternalMAD(String name,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData,
DocValueFormat format,
TDigestState valueSketch,
TDigestState deviationSketch,
String method) {
TDigestState valueSketch) {

super(name, pipelineAggregators, metaData);
this.format = format;
this.valueSketch = valueSketch;
this.deviationSketch = deviationSketch;
this.method = method;
}

public InternalMAD(StreamInput in) throws IOException {
super(in);
format = in.readNamedWriteable(DocValueFormat.class);
valueSketch = TDigestState.read(in);
deviationSketch = TDigestState.read(in);
method = in.readString();
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(format);
TDigestState.write(valueSketch, out);
TDigestState.write(deviationSketch, out);
out.writeString(method);
}

@Override
public InternalAggregation doReduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
TDigestState valueMerged = null;
TDigestState deviationMerged = null;
for (InternalAggregation aggregation : aggregations) {
final InternalMAD magAgg = (InternalMAD) aggregation;
if (valueMerged == null) {
valueMerged = new TDigestState(magAgg.valueSketch.compression());
}
if (deviationMerged == null) {
deviationMerged = new TDigestState(magAgg.deviationSketch.compression());
}
valueMerged.add(magAgg.valueSketch);
deviationMerged.add(magAgg.deviationSketch);
}

return new InternalMAD(name, pipelineAggregators(), metaData, format, valueMerged, deviationSketch, method);
return new InternalMAD(name, pipelineAggregators(), metaData, format, valueMerged);
}

@Override
Expand All @@ -114,9 +99,7 @@ protected int doHashCode() {
@Override
protected boolean doEquals(Object obj) {
InternalMAD other = (InternalMAD) obj;
return Objects.equals(valueSketch, other.valueSketch)
&& Objects.equals(deviationSketch, other.deviationSketch)
&& Objects.equals(method, other.method);
return Objects.equals(valueSketch, other.valueSketch);
}

@Override
Expand All @@ -132,6 +115,6 @@ public double value() {
// todo maybe - compute this when the object is constructed so we don't have to build a new tdigest for the deviations every time
@Override
public double getMAD() {
return computeMAD(valueSketch, deviationSketch, method);
return computeMAD(valueSketch);
}
}
Expand Up @@ -48,25 +48,20 @@ public class MADAggregationBuilder extends LeafOnly<ValuesSource.Numeric, MADAgg
public static final String NAME = "mad";

private static final ParseField COMPRESSION_FIELD = new ParseField("compression");
private static final ParseField METHOD_FIELD = new ParseField("method"); // todo remove

public static final List<String> METHODS = Arrays.asList("collection_median", "reduce_percentiles", "reduce_centroids"); // todo remove

private static final ObjectParser<MADAggregationBuilder, Void> PARSER;

static {
PARSER = new ObjectParser<>(NAME);
ValuesSourceParserHelper.declareNumericFields(PARSER, true, true, false); // todo verify these arguments
PARSER.declareDouble(MADAggregationBuilder::setCompression, COMPRESSION_FIELD);
PARSER.declareString(MADAggregationBuilder::setMethod, METHOD_FIELD); // todo remove
}

public static MADAggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException {
return PARSER.parse(parser, new MADAggregationBuilder(aggregationName), null);
}

private double compression = 100.0d;
private String method = "collection_median"; // todo remove

public MADAggregationBuilder(String name) {
super(name, ValuesSourceType.NUMERIC, ValueType.NUMERIC);
Expand All @@ -75,15 +70,13 @@ public MADAggregationBuilder(String name) {
public MADAggregationBuilder(StreamInput in) throws IOException {
super(in, ValuesSourceType.NUMERIC, ValueType.NUMERIC);
compression = in.readDouble();
method = in.readString(); //todo remove
}

protected MADAggregationBuilder(MADAggregationBuilder clone,
AggregatorFactories.Builder factoriesBuilder,
Map<String, Object> metaData) {
super(clone, factoriesBuilder, metaData);
this.compression = clone.compression;
this.method = clone.method; // todo remove
}

/**
Expand All @@ -105,19 +98,6 @@ public MADAggregationBuilder setCompression(double compression) {
return this;
}

public String getMethod() { // todo remove
return method;
}

public MADAggregationBuilder setMethod(String method) { // todo remove
if (METHODS.contains(method) == false) {
throw new IllegalArgumentException("Invalid MAD method [" + method + "]");
}

this.method = method;
return this;
}

@Override
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metaData) {
return new MADAggregationBuilder(this, factoriesBuilder, metaData);
Expand All @@ -126,7 +106,6 @@ protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBu
@Override
protected void innerWriteTo(StreamOutput out) throws IOException {
out.writeDouble(compression);
out.writeString(method); // todo remove
}

@Override
Expand All @@ -136,27 +115,24 @@ protected void innerWriteTo(StreamOutput out) throws IOException {
AggregatorFactories.Builder subFactoriesBuilder)
throws IOException {

return new MADAggregatorFactory(name, config, context, parent, subFactoriesBuilder, metaData, compression, method);
// todo remove method
return new MADAggregatorFactory(name, config, context, parent, subFactoriesBuilder, metaData, compression);
}

@Override
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(COMPRESSION_FIELD.getPreferredName(), compression);
builder.field(METHOD_FIELD.getPreferredName(), method); // todo remove
return builder;
}

@Override
protected int innerHashCode() {
return Objects.hash(compression, method);
} // todo remove method
return Objects.hash(compression);
}

@Override
protected boolean innerEquals(Object obj) {
MADAggregationBuilder other = (MADAggregationBuilder) obj;
return Objects.equals(compression, other.compression)
&& Objects.equals(method, other.method); // todo remove
return Objects.equals(compression, other.compression);
}

@Override
Expand Down
Expand Up @@ -23,7 +23,6 @@
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.DocValueFormat;
Expand All @@ -47,11 +46,8 @@ public class MADAggregator extends NumericMetricsAggregator.SingleValue {
private final DocValueFormat format;

private final double compression;
private final String method; // todo remove

private ObjectArray<TDigestState> valueSketches;
private ObjectArray<TDigestState> deviationSketches;
private DoubleArray approximateMedians;

public MADAggregator(String name,
SearchContext context,
Expand All @@ -60,19 +56,15 @@ public MADAggregator(String name,
Map<String, Object> metaData,
ValuesSource.Numeric valuesSource,
DocValueFormat format,
double compression,
String method) throws IOException { // todo remov methhod
double compression) throws IOException {

super(name, context, parent, pipelineAggregators, metaData);

this.valuesSource = valuesSource;
this.format = format;

this.valueSketches = context.bigArrays().newObjectArray(1);
this.deviationSketches = context.bigArrays().newObjectArray(1);
this.approximateMedians = context.bigArrays().newDoubleArray(1);
this.compression = compression;
this.method = method; // todo remove
}

/*
Expand All @@ -85,7 +77,7 @@ public double metric(long owningBucketOrd) {
if (owningBucketOrd >= valueSketches.size() || valueSketches.get(owningBucketOrd) == null) {
return Double.NaN;
} else {
return computeMAD(valueSketches.get(owningBucketOrd), deviationSketches.get(owningBucketOrd), method);
return computeMAD(valueSketches.get(owningBucketOrd));
}
}

Expand All @@ -112,41 +104,18 @@ protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucket
public void collect(int doc, long bucket) throws IOException {

valueSketches = bigArrays.grow(valueSketches, bucket + 1);
deviationSketches = bigArrays.grow(deviationSketches, bucket + 1);

if (bucket >= approximateMedians.size()) {
final long from = approximateMedians.size();
approximateMedians = bigArrays.grow(approximateMedians, bucket + 1);
approximateMedians.fill(from, approximateMedians.size(), Double.NEGATIVE_INFINITY); // todo use NaN? is neginf right
}

TDigestState valueSketch = valueSketches.get(bucket);
if (valueSketch == null) {
valueSketch = new TDigestState(compression);
valueSketches.set(bucket, valueSketch);
}

TDigestState deviationSketch = deviationSketches.get(bucket);
if (deviationSketch == null) {
deviationSketch = new TDigestState(compression);
deviationSketches.set(bucket, deviationSketch);
}

double approximateMedian = approximateMedians.get(bucket);

if (values.advanceExact(doc)) {
final int valueCount = values.docValueCount();
for (int i = 0; i < valueCount; i++) {
final double value = values.nextValue();
if (approximateMedian == Double.NEGATIVE_INFINITY) {
approximateMedian = value; // when starting out, we set approx median to the first value
}
valueSketch.add(value);
final double deviation = Math.abs(approximateMedian - value);
deviationSketch.add(deviation);

approximateMedian = valueSketch.quantile(0.5);
approximateMedians.set(bucket, approximateMedian);
}
}
}
Expand All @@ -159,29 +128,24 @@ public InternalAggregation buildAggregation(long bucket) throws IOException {
return buildEmptyAggregation();
} else {
final TDigestState valueSketch = valueSketches.get(bucket);
final TDigestState deviationSketch = deviationSketches.get(bucket);
return new InternalMAD(name, pipelineAggregators(), metaData(), format, valueSketch, deviationSketch, method);
return new InternalMAD(name, pipelineAggregators(), metaData(), format, valueSketch);
}
}

@Override
public InternalAggregation buildEmptyAggregation() {
return new InternalMAD(name, pipelineAggregators(), metaData(), format,
new TDigestState(compression), new TDigestState(compression), method);
return new InternalMAD(name, pipelineAggregators(), metaData(), format, new TDigestState(compression));
}

@Override
public void doClose() {
Releasables.close(valueSketches, deviationSketches, approximateMedians);
Releasables.close(valueSketches);
}

// todo maybe this should live elsewhere
public static double computeMAD(TDigestState valuesSketch, TDigestState deviationsSketch, String method) {
if (method.equals("collection_median")) {
public static double computeMAD(TDigestState valuesSketch) {

return deviationsSketch.quantile(0.5);

} else if (method.equals("reduce_percentiles")) {
if (valuesSketch.size() > 0) {

final double approximateMedian = valuesSketch.quantile(0.5);
final TDigestState approximateDeviationsSketch = new TDigestState(valuesSketch.compression());
Expand All @@ -194,19 +158,8 @@ public static double computeMAD(TDigestState valuesSketch, TDigestState deviatio
}

return approximateDeviationsSketch.quantile(0.5);

} else if (method.equals("reduce_centroids")) {
final double approximateMedian = valuesSketch.quantile(0.5);
final TDigestState approximatedDeviationsSketch = new TDigestState(valuesSketch.compression());

valuesSketch.centroids().forEach(centroid -> {
final double deviation = Math.abs(approximateMedian - centroid.mean());
approximatedDeviationsSketch.add(deviation, centroid.count());
});

return approximatedDeviationsSketch.quantile(0.5);
} else {
throw new IllegalStateException("Invalid MAD method [" + method + "]");
return Double.NaN;
}
}
}
Expand Up @@ -35,27 +35,24 @@
public class MADAggregatorFactory extends ValuesSourceAggregatorFactory<ValuesSource.Numeric, MADAggregatorFactory> {

private final double compression;
private final String method; // todo remove

public MADAggregatorFactory(String name,
ValuesSourceConfig<ValuesSource.Numeric> config,
SearchContext context,
AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData,
double compression,
String method) throws IOException { // todo remove method
double compression) throws IOException {
super(name, config, context, parent, subFactoriesBuilder, metaData);
this.compression = compression;
this.method = method;
}

@Override
protected Aggregator createUnmapped(Aggregator parent,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {

return new MADAggregator(name, context, parent, pipelineAggregators, metaData, null, config.format(), compression, method);
return new MADAggregator(name, context, parent, pipelineAggregators, metaData, null, config.format(), compression);
}

@Override
Expand All @@ -65,6 +62,6 @@ protected Aggregator doCreateInternal(ValuesSource.Numeric valuesSource,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {

return new MADAggregator(name, context, parent, pipelineAggregators, metaData, valuesSource, config.format(), compression, method);
return new MADAggregator(name, context, parent, pipelineAggregators, metaData, valuesSource, config.format(), compression);
}
}
Expand Up @@ -149,8 +149,8 @@ private void testCase(Query query,
IndexSearcher indexSearcher = newSearcher(indexReader, true, true);

MADAggregationBuilder builder = new MADAggregationBuilder("mad")
.setMethod("reduce_centroids")
.field("number");
.field("number")
.setCompression(randomDoubleBetween(20, 1000, true));

MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
fieldType.setName("number");
Expand Down
Expand Up @@ -150,9 +150,8 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
private static MADAggregationBuilder randomBuilder() {
final MADAggregationBuilder builder = new MADAggregationBuilder("mad");
if (randomBoolean()) {
builder.setCompression(randomDoubleBetween(0, 1000, false));
builder.setCompression(randomDoubleBetween(20, 1000, false));
}
builder.setMethod("reduce_centroids");
return builder;
}

Expand Down

0 comments on commit 88d5189

Please sign in to comment.