diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
index d74e40b24..051d75137 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
@@ -20,8 +20,6 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
 import org.apache.flink.api.common.state.ListState;
@@ -64,7 +62,6 @@
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Iterator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
@@ -124,7 +121,7 @@ public void mapPartition(
                             Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
                             Collector<Tuple3<Double, Boolean, Double>> out) {
                         List<Tuple3<Double, Boolean, Double>> bufferedData =
-                                new LinkedList<>();
+                                new ArrayList<>();
                         for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
                             bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
                         }
@@ -142,48 +139,8 @@ public void mapPartition(
                         TypeInformation.of(BinarySummary.class),
                         new PartitionSummaryOperator());
 
-        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight>. */
-        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
-                BroadcastUtils.withBroadcastStream(
-                        Collections.singletonList(sortEvalData),
-                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
-                        inputList -> {
-                            DataStream input = inputList.get(0);
-                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
-                        });
-
-        DataStream<double[]> localAreaUnderROCVariable =
-                dataWithOrders.transform(
-                        "AccumulateMultiScore",
-                        TypeInformation.of(double[].class),
-                        new AccumulateMultiScoreOperator());
-
-        DataStream<double[]> middleAreaUnderROC =
-                DataStreamUtils.reduce(
-                        localAreaUnderROCVariable,
-                        (ReduceFunction<double[]>)
-                                (t1, t2) -> {
-                                    t2[0] += t1[0];
-                                    t2[1] += t1[1];
-                                    t2[2] += t1[2];
-                                    return t2;
-                                });
-
-        DataStream<Double> areaUnderROC =
-                middleAreaUnderROC.map(
-                        (MapFunction<double[], Double>)
-                                value -> {
-                                    if (value[1] > 0 && value[2] > 0) {
-                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
-                                                / (value[1] * value[2]);
-                                    } else {
-                                        return Double.NaN;
-                                    }
-                                });
-
         Map<String, DataStream<?>> broadcastMap = new HashMap<>();
         broadcastMap.put(partitionSummariesKey, partitionSummaries);
-        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
         DataStream<BinaryMetrics> localMetrics =
                 BroadcastUtils.withBroadcastStream(
                         Collections.singletonList(sortEvalData),
@@ -218,89 +175,6 @@ public void mapPartition(
         return new Table[] {tEnv.fromDataStream(evalResult)};
     }
 
-    /** Updates variables for calculating AreaUnderROC. */
-    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
-            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
-                    BoundedOneInput {
-        private ListState<double[]> accValueState;
-        private ListState<Double> scoreState;
-
-        double[] accValue;
-        double score;
-
-        @Override
-        public void endInput() {
-            if (accValue != null) {
-                output.collect(
-                        new StreamRecord<>(
-                                new double[] {
-                                    accValue[0] / accValue[1] * accValue[2],
-                                    accValue[2],
-                                    accValue[3]
-                                }));
-            }
-        }
-
-        @Override
-        public void processElement(
-                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
-            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
-            if (accValue == null) {
-                accValue = new double[4];
-                score = t.f0;
-            } else if (score != t.f0) {
-                output.collect(
-                        new StreamRecord<>(
-                                new double[] {
-                                    accValue[0] / accValue[1] * accValue[2],
-                                    accValue[2],
-                                    accValue[3]
-                                }));
-                Arrays.fill(accValue, 0.0);
-            }
-            accValue[0] += t.f1;
-            accValue[1] += 1.0;
-            if (t.f2) {
-                accValue[2] += t.f3;
-            } else {
-                accValue[3] += t.f3;
-            }
-        }
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public void initializeState(StateInitializationContext context) throws Exception {
-            super.initializeState(context);
-            accValueState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "accValueState", TypeInformation.of(double[].class)));
-            accValue =
-                    OperatorStateUtils.getUniqueElement(accValueState, "accValueState")
-                            .orElse(null);
-
-            scoreState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "scoreState", TypeInformation.of(Double.class)));
-            score = OperatorStateUtils.getUniqueElement(scoreState, "scoreState").orElse(0.0);
-        }
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public void snapshotState(StateSnapshotContext context) throws Exception {
-            super.snapshotState(context);
-            accValueState.clear();
-            scoreState.clear();
-            if (accValue != null) {
-                accValueState.add(accValue);
-                scoreState.add(score);
-            }
-        }
-    }
-
     private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
             implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
                     BoundedOneInput {
@@ -320,7 +194,6 @@ public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
         }
 
         @Override
-        @SuppressWarnings("unchecked")
        public void initializeState(StateInitializationContext context) throws Exception {
             super.initializeState(context);
             summaryState =
@@ -340,7 +213,6 @@ public void initializeState(StateInitializationContext context) throws Exception
         }
 
         @Override
-        @SuppressWarnings("unchecked")
         public void snapshotState(StateSnapshotContext context) throws Exception {
             super.snapshotState(context);
             summaryState.clear();
@@ -385,24 +257,24 @@ public void mapPartition(
             List<BinarySummary> statistics =
                     getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
-            long[] countValues =
+            double[] accWeights =
                     reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
-            double areaUnderROC =
-                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
-            long totalTrue = countValues[2];
-            long totalFalse = countValues[3];
-            if (totalTrue == 0) {
-                LOG.warn("There is no positive sample in data!");
+            double totalSumWeightsPos = accWeights[2];
+            double totalSumWeightsNeg = accWeights[3];
+            if (totalSumWeightsPos == 0) {
+                LOG.warn("There are no positive samples in the data!");
             }
-            if (totalFalse == 0) {
-                LOG.warn("There is no negative sample in data!");
+            if (totalSumWeightsNeg == 0) {
+                LOG.warn("There are no negative samples in the data!");
             }
-            BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+            BinaryMetrics metrics = new BinaryMetrics(0);
+            // Stores the TPR, FPR, precision, and positive rate computed from the samples whose
+            // scores range from the maximum down to the current one.
             double[] tprFprPrecision = new double[4];
             for (Tuple3<Double, Boolean, Double> t3 : iterable) {
-                updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+                updateBinaryMetrics(t3, metrics, accWeights, tprFprPrecision);
             }
             collector.collect(metrics);
         }
@@ -411,35 +283,36 @@ public void mapPartition(
     private static void updateBinaryMetrics(
             Tuple3<Double, Boolean, Double> cur,
             BinaryMetrics binaryMetrics,
-            long[] countValues,
+            double[] accWeights,
             double[] recordValues) {
-        if (binaryMetrics.count == 0) {
-            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
-            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        if (binaryMetrics.sumWeights == 0) {
+            recordValues[0] = accWeights[2] == 0 ? 1.0 : accWeights[0] / accWeights[2];
+            recordValues[1] = accWeights[3] == 0 ? 1.0 : accWeights[1] / accWeights[3];
             recordValues[2] =
-                    countValues[0] + countValues[1] == 0
+                    accWeights[0] + accWeights[1] == 0
                             ? 1.0
-                            : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
-            recordValues[3] =
-                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+                            : accWeights[0] / (accWeights[0] + accWeights[1]);
+            recordValues[3] = (accWeights[0] + accWeights[1]) / (accWeights[2] + accWeights[3]);
         }
-        binaryMetrics.count++;
-        if (cur.f1) {
-            countValues[0]++;
+        boolean isPos = cur.f1;
+        double weight = cur.f2;
+        binaryMetrics.sumWeights += weight;
+        if (isPos) {
+            accWeights[0] += weight;
         } else {
-            countValues[1]++;
+            accWeights[1] += weight;
         }
-        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
-        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        double tpr = accWeights[2] == 0 ? 1.0 : accWeights[0] / accWeights[2];
+        double fpr = accWeights[3] == 0 ? 1.0 : accWeights[1] / accWeights[3];
         double precision =
-                countValues[0] + countValues[1] == 0
+                accWeights[0] + accWeights[1] == 0
                         ? 1.0
-                        : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
-        double positiveRate =
-                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+                        : accWeights[0] / (accWeights[0] + accWeights[1]);
+        double positiveRate = (accWeights[0] + accWeights[1]) / (accWeights[2] + accWeights[3]);
+
+        binaryMetrics.areaUnderROC += (fpr - recordValues[1]) * (tpr + recordValues[0]) / 2;
         binaryMetrics.areaUnderLorenz +=
                 ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2);
         binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2);
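The reworked `updateBinaryMetrics` above drops the old rank-based AUC pipeline (`AccumulateMultiScoreOperator` plus the broadcast `areaUnderROC`) in favor of a single weighted pass: each sample advances the weighted TP/FP accumulators, and each area metric grows by one trapezoid between consecutive curve points. A minimal self-contained sketch of that accumulation, with made-up example data (score ties are not merged into a single curve point here):

```java
// Standalone sketch of weighted, single-pass AUC-ROC via the trapezoidal
// rule; the sample data below is hypothetical, not taken from the tests.
public class WeightedAucSketch {
    public static void main(String[] args) {
        // Samples already sorted by score, descending -- the evaluator's
        // global sort guarantees this ordering before metrics are updated.
        boolean[] labels = {true, false, true, false};
        double[] weights = {1.0, 0.5, 2.0, 1.0};

        double wPos = 0, wNeg = 0; // total positive / negative weight
        for (int i = 0; i < labels.length; i++) {
            if (labels[i]) { wPos += weights[i]; } else { wNeg += weights[i]; }
        }

        double wTp = 0, wFp = 0;         // weighted TP / FP seen so far
        double prevTpr = 0, prevFpr = 0; // previous point on the ROC curve
        double auc = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i]) { wTp += weights[i]; } else { wFp += weights[i]; }
            double tpr = wPos == 0 ? 1.0 : wTp / wPos;
            double fpr = wNeg == 0 ? 1.0 : wFp / wNeg;
            // Same trapezoid as (fpr - recordValues[1]) * (tpr + recordValues[0]) / 2.
            auc += (fpr - prevFpr) * (tpr + prevTpr) / 2;
            prevTpr = tpr;
            prevFpr = fpr;
        }
        System.out.println("weighted areaUnderROC = " + auc); // 7/9 for this data
    }
}
```

`areaUnderPR` and `areaUnderLorenz` follow the same pattern with (recall, precision) and (positive rate, TPR) as curve coordinates, which is why `merge` can now simply add the partial areas from all partitions.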
@@ -451,65 +324,31 @@ private static void updateBinaryMetrics(
         recordValues[3] = positiveRate;
     }
 
-    /**
-     * For each sample, calculates its score order among all samples. The sample with minimum score
-     * has order 1, while the sample with maximum score has order equal to the number of samples.
-     *
-     * <p>Input is a dataset of tuples (score, isRealPositive, weight); output is a dataset of
-     * tuples (score, order, isRealPositive, weight).
-     */
-    private static class CalcSampleOrders
-            extends RichFlatMapFunction<
-                    Tuple3<Double, Boolean, Double>, Tuple4<Double, Long, Boolean, Double>> {
-        private long startIndex;
-        private long total = -1;
-        private final String partitionSummariesKey;
-
-        public CalcSampleOrders(String partitionSummariesKey) {
-            this.partitionSummariesKey = partitionSummariesKey;
-        }
-
-        @Override
-        public void flatMap(
-                Tuple3<Double, Boolean, Double> value,
-                Collector<Tuple4<Double, Long, Boolean, Double>> out)
-                throws Exception {
-            if (total == -1) {
-                List<BinarySummary> statistics =
-                        getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
-                long[] countValues =
-                        reduceBinarySummary(
-                                statistics, getRuntimeContext().getIndexOfThisSubtask());
-                startIndex = countValues[1] + countValues[0] + 1;
-                total = countValues[2] + countValues[3];
-            }
-            out.collect(Tuple4.of(value.f0, total - startIndex + 1, value.f1, value.f2));
-            startIndex++;
-        }
-    }
-
     /**
      * @param values Reduce Summary of all workers.
      * @param taskId current taskId.
-     * @return [curTrue, curFalse, TotalTrue, TotalFalse]
+     * @return An array storing the sums of weights of positives/negatives on the tasks before the
+     *     current one, followed by the sums of weights of positives/negatives across all tasks.
      */
-    private static long[] reduceBinarySummary(List<BinarySummary> values, int taskId) {
+    private static double[] reduceBinarySummary(List<BinarySummary> values, int taskId) {
         List<BinarySummary> list = new ArrayList<>(values);
         list.sort(Comparator.comparingDouble(t -> -t.maxScore));
-        long curTrue = 0;
-        long curFalse = 0;
-        long totalTrue = 0;
-        long totalFalse = 0;
+        double prefixSumWeightsPos = 0;
+        double prefixSumWeightsNeg = 0;
+        double totalSumWeightsPos = 0;
+        double totalSumWeightsNeg = 0;
 
         for (BinarySummary statistics : list) {
             if (statistics.taskId == taskId) {
-                curFalse = totalFalse;
-                curTrue = totalTrue;
+                prefixSumWeightsNeg = totalSumWeightsNeg;
+                prefixSumWeightsPos = totalSumWeightsPos;
             }
-            totalTrue += statistics.curPositive;
-            totalFalse += statistics.curNegative;
+            totalSumWeightsPos += statistics.sumWeightsPos;
+            totalSumWeightsNeg += statistics.sumWeightsNeg;
         }
-        return new long[] {curTrue, curFalse, totalTrue, totalFalse};
+        return new double[] {
+            prefixSumWeightsPos, prefixSumWeightsNeg, totalSumWeightsPos, totalSumWeightsNeg
+        };
     }
 
     /**
@@ -520,13 +359,16 @@ private static long[] reduceBinarySummary(List<BinarySummary> values, int taskId
      */
     private static void updateBinarySummary(
             BinarySummary statistics, Tuple3<Double, Boolean, Double> evalElement) {
-        if (evalElement.f1) {
-            statistics.curPositive++;
+        boolean isPos = evalElement.f1;
+        double weight = evalElement.f2;
+        double score = evalElement.f0;
+        if (isPos) {
+            statistics.sumWeightsPos += weight;
         } else {
-            statistics.curNegative++;
+            statistics.sumWeightsNeg += weight;
         }
-        if (Double.compare(statistics.maxScore, evalElement.f0) < 0) {
-            statistics.maxScore = evalElement.f0;
+        if (Double.compare(statistics.maxScore, score) < 0) {
+            statistics.maxScore = score;
         }
     }
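The new `reduceBinarySummary` above is a small prefix sum: summaries are ordered by `maxScore` descending, matching the global sort of the samples, so when the current task's summary is reached, the running totals are exactly the weight of everything scored higher than this partition. A toy illustration with hypothetical numbers:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Toy model of reduceBinarySummary's prefix sum; the Summary class and the
// numbers below are hypothetical stand-ins for BinarySummary instances.
public class PartitionPrefixSumSketch {
    static class Summary {
        final int taskId;
        final double maxScore;
        final double wPos;
        final double wNeg;

        Summary(int taskId, double maxScore, double wPos, double wNeg) {
            this.taskId = taskId;
            this.maxScore = maxScore;
            this.wPos = wPos;
            this.wNeg = wNeg;
        }
    }

    /** Returns {prefixPos, prefixNeg, totalPos, totalNeg} for the given task. */
    static double[] reduce(List<Summary> summaries, int taskId) {
        List<Summary> sorted = new ArrayList<>(summaries);
        sorted.sort(Comparator.comparingDouble(s -> -s.maxScore));
        double prefixPos = 0, prefixNeg = 0, totalPos = 0, totalNeg = 0;
        for (Summary s : sorted) {
            if (s.taskId == taskId) {
                // The totals so far cover only partitions whose samples
                // score above this partition's samples.
                prefixPos = totalPos;
                prefixNeg = totalNeg;
            }
            totalPos += s.wPos;
            totalNeg += s.wNeg;
        }
        return new double[] {prefixPos, prefixNeg, totalPos, totalNeg};
    }

    public static void main(String[] args) {
        List<Summary> summaries =
                Arrays.asList(new Summary(0, 0.9, 2.0, 1.0), new Summary(1, 0.5, 1.0, 2.5));
        // Task 1 starts after task 0's weights: prints [2.0, 1.0, 3.0, 3.5].
        System.out.println(Arrays.toString(reduce(summaries, 1)));
    }
}
```

These prefix weights seed `recordValues` in `updateBinaryMetrics`, so each partition starts its trapezoids at the correct point on the global curve.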
@@ -673,25 +515,26 @@ public static class BinarySummary implements Serializable {
         public Integer taskId;
         // maximum score in this partition
         public double maxScore;
-        // real positives in this partition
-        public long curPositive;
-        // real negatives in this partition
-        public long curNegative;
+        // sum of weights of positives in this partition
+        public double sumWeightsPos;
+        // sum of weights of negatives in this partition
+        public double sumWeightsNeg;
 
         public BinarySummary() {}
 
-        public BinarySummary(Integer taskId, double maxScore, long curPositive, long curNegative) {
+        public BinarySummary(
+                Integer taskId, double maxScore, double sumWeightsPos, double sumWeightsNeg) {
             this.taskId = taskId;
             this.maxScore = maxScore;
-            this.curPositive = curPositive;
-            this.curNegative = curNegative;
+            this.sumWeightsPos = sumWeightsPos;
+            this.sumWeightsNeg = sumWeightsNeg;
         }
     }
 
     /** The evaluation metrics for binary classification. */
     public static class BinaryMetrics {
-        /* The count of samples. */
-        public long count;
+        /* The sum of weights of samples. */
+        public double sumWeights;
 
         /* Area under ROC */
         public double areaUnderROC;
@@ -707,22 +550,19 @@ public static class BinaryMetrics {
 
         public BinaryMetrics() {}
 
-        public BinaryMetrics(long count, double areaUnderROC) {
-            this.count = count;
-            this.areaUnderROC = areaUnderROC;
+        public BinaryMetrics(long sumWeights) {
+            this.sumWeights = sumWeights;
         }
 
         public BinaryMetrics merge(BinaryMetrics binaryClassMetrics) {
             if (null == binaryClassMetrics) {
                 return this;
             }
-            Preconditions.checkState(
-                    Double.compare(areaUnderROC, binaryClassMetrics.areaUnderROC) == 0,
-                    "AreaUnderROC not equal!");
-            count += binaryClassMetrics.count;
-            ks = Math.max(ks, binaryClassMetrics.ks);
-            areaUnderPR += binaryClassMetrics.areaUnderPR;
+            sumWeights += binaryClassMetrics.sumWeights;
+            areaUnderROC += binaryClassMetrics.areaUnderROC;
             areaUnderLorenz += binaryClassMetrics.areaUnderLorenz;
+            areaUnderPR += binaryClassMetrics.areaUnderPR;
+            ks = Math.max(ks, binaryClassMetrics.ks);
             return this;
         }
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
index 0c146a3d2..f6f38b298 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
@@ -118,7 +118,8 @@ public class BinaryClassificationEvaluatorTest extends AbstractTestBase {
             new double[] {
                 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
             };
-    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double[] EXPECTED_DATA_W =
+            new double[] {0.8717948717948718, 0.9510202726261435};
     private static final double EPS = 1.0e-5;
 
     @Before
@@ -297,14 +298,20 @@ public void testEvaluateWithMultiScore() {
     public void testEvaluateWithWeight() {
         BinaryClassificationEvaluator eval =
                 new BinaryClassificationEvaluator()
-                        .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+                        .setMetricsNames(
+                                BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
+                                BinaryClassificationEvaluatorParams.AREA_UNDER_PR)
                         .setWeightCol("weight");
         Table evalResult = eval.transform(inputDataTableWithWeight)[0];
-        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        Row result = (Row) IteratorUtils.toList(evalResult.execute().collect()).get(0);
         assertArrayEquals(
-                new String[] {BinaryClassificationEvaluatorParams.AREA_UNDER_ROC},
+                new String[] {
+                    BinaryClassificationEvaluatorParams.AREA_UNDER_ROC,
+                    BinaryClassificationEvaluatorParams.AREA_UNDER_PR
+                },
                 evalResult.getResolvedSchema().getColumnNames().toArray());
-        assertEquals(EXPECTED_DATA_W, results.get(0).getFieldAs(0), EPS);
+        assertArrayEquals(
+                EXPECTED_DATA_W, new double[] {result.getFieldAs(0), result.getFieldAs(1)}, EPS);
     }
 
     @Test
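The two new expected values replace the single AUC constant because the test now checks `areaUnderPR` as well. For small inputs like these, weighted AUC-ROC can also be cross-checked against the quadratic pairwise definition; a reference sketch with made-up data (ties counted as half, which may differ from the streaming implementation's tie handling):

```java
// Brute-force O(n^2) reference for weighted AUC-ROC: the total weight of
// correctly ordered (positive, negative) pairs, each pair weighted by the
// product of the two sample weights. The data below is hypothetical.
public class PairwiseWeightedAuc {
    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.7, 0.3};
        boolean[] labels = {true, false, true, false};
        double[] weights = {1.0, 0.5, 2.0, 1.0};

        double win = 0, total = 0;
        for (int i = 0; i < scores.length; i++) {
            if (!labels[i]) { continue; } // i ranges over positives
            for (int j = 0; j < scores.length; j++) {
                if (labels[j]) { continue; } // j ranges over negatives
                double w = weights[i] * weights[j];
                total += w;
                if (scores[i] > scores[j]) {
                    win += w;
                } else if (scores[i] == scores[j]) {
                    win += w / 2; // count score ties as half a win
                }
            }
        }
        // Matches the trapezoidal result for this data: 3.5 / 4.5 = 7/9.
        System.out.println("pairwise weighted AUC = " + win / total);
    }
}
```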
diff --git a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
index 7f34eb4c9..55a8f2b91 100644
--- a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
+++ b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py
@@ -16,10 +16,10 @@
 # limitations under the License.
 ################################################################################
 import os
-
 from pyflink.common import Types
-from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
+
 from pyflink.ml.evaluation.binaryclassification import BinaryClassificationEvaluator
+from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
 from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
@@ -111,7 +111,7 @@ def setUp(self):
         self.expected_data_m = [0.8571428571428571, 0.9377705627705628, 0.8571428571428571,
                                 0.6488095238095237]
 
-        self.expected_data_w = 0.8911680911680911
+        self.expected_data_w = [0.8717948717948718, 0.9510202726261435]
 
         self.eps = 1e-5
@@ -185,11 +185,11 @@ def test_evaluate_with_multi_score(self):
 
     def test_evaluate_with_weight(self):
         evaluator = BinaryClassificationEvaluator() \
-            .set_metrics_names("areaUnderROC") \
+            .set_metrics_names("areaUnderROC", "areaUnderPR") \
             .set_weight_col("weight")
         output = evaluator.transform(self.input_data_table_with_weight)[0]
         self.assertEqual(
-            ["areaUnderROC"],
+            ["areaUnderROC", "areaUnderPR"],
             output.get_schema().get_field_names())
         results = [result for result in output.execute().collect()]
         result = results[0]
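For completeness, an end-to-end sketch of the weighted evaluation path these tests exercise, written against the Java API shown above; the toy rows, weights, and column values are assumptions for illustration, not the tests' data:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class WeightedEvaluationExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Toy (label, rawPrediction, weight) rows; rawPrediction holds the
        // per-class scores, as in the tests above.
        List<Row> rows =
                Arrays.asList(
                        Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
                        Row.of(0.0, Vectors.dense(0.7, 0.3), 1.2),
                        Row.of(1.0, Vectors.dense(0.4, 0.6), 1.0),
                        Row.of(0.0, Vectors.dense(0.6, 0.4), 0.5));
        Table input =
                tEnv.fromDataStream(
                                env.fromCollection(
                                        rows,
                                        new RowTypeInfo(
                                                Types.DOUBLE,
                                                DenseVectorTypeInfo.INSTANCE,
                                                Types.DOUBLE)))
                        .as("label", "rawPrediction", "weight");

        BinaryClassificationEvaluator evaluator =
                new BinaryClassificationEvaluator()
                        .setMetricsNames("areaUnderROC", "areaUnderPR")
                        .setWeightCol("weight");

        // One result row; the column order follows the metrics names above.
        evaluator.transform(input)[0].execute().print();
    }
}
```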