From 58d79b40ed837542e727c5f4f3a410ccb5f9ff24 Mon Sep 17 00:00:00 2001 From: Greg Hogan Date: Fri, 14 Oct 2016 16:18:52 -0400 Subject: [PATCH] [FLINK-4586] [core] Broken AverageAccumulator --- .../accumulators/AverageAccumulator.java | 27 ++++++++++--------- .../accumulators/AverageAccumulatorTest.java | 18 ++++++++----- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/common/accumulators/AverageAccumulator.java b/flink-core/src/main/java/org/apache/flink/api/common/accumulators/AverageAccumulator.java index 9c0f62f12a901..67cf572cd7e87 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/accumulators/AverageAccumulator.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/accumulators/AverageAccumulator.java @@ -28,29 +28,30 @@ public class AverageAccumulator implements SimpleAccumulator { private static final long serialVersionUID = 3672555084179165255L; - - private double localValue; + private long count; + private double sum; + @Override public void add(Double value) { this.count++; - this.localValue += value; + this.sum += value; } public void add(double value) { this.count++; - this.localValue += value; + this.sum += value; } public void add(long value) { this.count++; - this.localValue += value; + this.sum += value; } public void add(int value) { this.count++; - this.localValue += value; + this.sum += value; } @Override @@ -58,21 +59,21 @@ public Double getLocalValue() { if (this.count == 0) { return 0.0; } - return this.localValue / (double)this.count; + return this.sum / this.count; } @Override public void resetLocal() { this.count = 0; - this.localValue = 0; + this.sum = 0; } @Override public void merge(Accumulator other) { if (other instanceof AverageAccumulator) { - AverageAccumulator temp = (AverageAccumulator)other; - this.count += temp.count; - this.localValue += other.getLocalValue(); + AverageAccumulator avg = (AverageAccumulator)other; + this.count += avg.count; + this.sum += avg.sum; } else { throw new IllegalArgumentException("The merged accumulator must be AverageAccumulator."); } @@ -81,13 +82,13 @@ public void merge(Accumulator other) { @Override public AverageAccumulator clone() { AverageAccumulator average = new AverageAccumulator(); - average.localValue = this.localValue; average.count = this.count; + average.sum = this.sum; return average; } @Override public String toString() { - return "AverageAccumulator " + this.localValue + " count " + this.count; + return "AverageAccumulator " + this.getLocalValue() + " for " + this.count + " elements"; } } diff --git a/flink-core/src/test/java/org/apache/flink/api/common/accumulators/AverageAccumulatorTest.java b/flink-core/src/test/java/org/apache/flink/api/common/accumulators/AverageAccumulatorTest.java index 9ebd27c7165b9..585511f5308ac 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/accumulators/AverageAccumulatorTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/accumulators/AverageAccumulatorTest.java @@ -83,12 +83,18 @@ public void testAdd() { @Test public void testMergeSuccess() { - AverageAccumulator average = new AverageAccumulator(); - AverageAccumulator averageNew = new AverageAccumulator(); - average.add(1); - averageNew.add(2); - average.merge(averageNew); - assertEquals(1.5, average.getLocalValue(), 0.0); + AverageAccumulator avg1 = new AverageAccumulator(); + for (int i = 0; i < 5; i++) { + avg1.add(i); + } + + AverageAccumulator avg2 = new AverageAccumulator(); + for (int i = 5; i < 10; i++) { + avg2.add(i); + } + + avg1.merge(avg2); + assertEquals(4.5, avg1.getLocalValue(), 0.0); } @Test