From 1cba34f68670a3c03746acb4a629c3bae2aeabe2 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Thu, 17 Aug 2023 19:23:38 +0800 Subject: [PATCH 1/5] Correct the weighted case in BinaryClassificationEvaluatorTest. --- .../BinaryClassificationEvaluatorTest.java | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java index 0c146a3d2..b55e0fcde 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java @@ -118,7 +118,8 @@ public class BinaryClassificationEvaluatorTest extends AbstractTestBase { new double[] { 0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237 }; - private static final double EXPECTED_DATA_W = 0.8911680911680911; + private static final double[] EXPECTED_DATA_W = + new double[] {0.8717948717948718, 0.9510202726261435}; private static final double EPS = 1.0e-5; @Before @@ -297,14 +298,20 @@ public void testEvaluateWithMultiScore() { public void testEvaluateWithWeight() { BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator() - .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC) + .setMetricsNames( + BinaryClassificationEvaluatorParams.AREA_UNDER_ROC, + BinaryClassificationEvaluatorParams.AREA_UNDER_PR) .setWeightCol("weight"); Table evalResult = eval.transform(inputDataTableWithWeight)[0]; - List results = IteratorUtils.toList(evalResult.execute().collect()); + Row result = (Row) IteratorUtils.toList(evalResult.execute().collect()).get(0); assertArrayEquals( - new String[] {BinaryClassificationEvaluatorParams.AREA_UNDER_ROC}, + new String[] { + BinaryClassificationEvaluatorParams.AREA_UNDER_ROC, + BinaryClassificationEvaluatorParams.AREA_UNDER_PR + }, evalResult.getResolvedSchema().getColumnNames().toArray()); - assertEquals(EXPECTED_DATA_W, results.get(0).getFieldAs(0), EPS); + assertArrayEquals( + EXPECTED_DATA_W, new double[] {result.getFieldAs(0), result.getFieldAs(1)}, 1e-9); } @Test From 9a7b8516077716961d43a6c02d01b99eabd138de Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 11 Aug 2023 16:29:42 +0800 Subject: [PATCH 2/5] Make weighted auROc and auPRC correct. --- .../BinaryClassificationEvaluator.java | 267 ++++-------------- 1 file changed, 51 insertions(+), 216 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java index d74e40b24..33a8d853f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java @@ -20,8 +20,6 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; -import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.state.ListState; @@ -64,7 +62,6 @@ import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; @@ -124,7 +121,7 @@ public void mapPartition( Iterable> values, Collector> out) { List> bufferedData = - new LinkedList<>(); + new ArrayList<>(); for (Tuple4 t4 : values) { bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2)); } @@ -142,48 +139,8 @@ public void mapPartition( TypeInformation.of(BinarySummary.class), new PartitionSummaryOperator()); - /* Sorts global data. Output Tuple4 : . */ - DataStream> dataWithOrders = - BroadcastUtils.withBroadcastStream( - Collections.singletonList(sortEvalData), - Collections.singletonMap(partitionSummariesKey, partitionSummaries), - inputList -> { - DataStream input = inputList.get(0); - return input.flatMap(new CalcSampleOrders(partitionSummariesKey)); - }); - - DataStream localAreaUnderROCVariable = - dataWithOrders.transform( - "AccumulateMultiScore", - TypeInformation.of(double[].class), - new AccumulateMultiScoreOperator()); - - DataStream middleAreaUnderROC = - DataStreamUtils.reduce( - localAreaUnderROCVariable, - (ReduceFunction) - (t1, t2) -> { - t2[0] += t1[0]; - t2[1] += t1[1]; - t2[2] += t1[2]; - return t2; - }); - - DataStream areaUnderROC = - middleAreaUnderROC.map( - (MapFunction) - value -> { - if (value[1] > 0 && value[2] > 0) { - return (value[0] - 1. * value[1] * (value[1] + 1) / 2) - / (value[1] * value[2]); - } else { - return Double.NaN; - } - }); - Map> broadcastMap = new HashMap<>(); broadcastMap.put(partitionSummariesKey, partitionSummaries); - broadcastMap.put(AREA_UNDER_ROC, areaUnderROC); DataStream localMetrics = BroadcastUtils.withBroadcastStream( Collections.singletonList(sortEvalData), @@ -218,89 +175,6 @@ public void mapPartition( return new Table[] {tEnv.fromDataStream(evalResult)}; } - /** Updates variables for calculating AreaUnderROC. */ - private static class AccumulateMultiScoreOperator extends AbstractStreamOperator - implements OneInputStreamOperator, double[]>, - BoundedOneInput { - private ListState accValueState; - private ListState scoreState; - - double[] accValue; - double score; - - @Override - public void endInput() { - if (accValue != null) { - output.collect( - new StreamRecord<>( - new double[] { - accValue[0] / accValue[1] * accValue[2], - accValue[2], - accValue[3] - })); - } - } - - @Override - public void processElement( - StreamRecord> streamRecord) { - Tuple4 t = streamRecord.getValue(); - if (accValue == null) { - accValue = new double[4]; - score = t.f0; - } else if (score != t.f0) { - output.collect( - new StreamRecord<>( - new double[] { - accValue[0] / accValue[1] * accValue[2], - accValue[2], - accValue[3] - })); - Arrays.fill(accValue, 0.0); - } - accValue[0] += t.f1; - accValue[1] += 1.0; - if (t.f2) { - accValue[2] += t.f3; - } else { - accValue[3] += t.f3; - } - } - - @Override - @SuppressWarnings("unchecked") - public void initializeState(StateInitializationContext context) throws Exception { - super.initializeState(context); - accValueState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "accValueState", TypeInformation.of(double[].class))); - accValue = - OperatorStateUtils.getUniqueElement(accValueState, "accValueState") - .orElse(null); - - scoreState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "scoreState", TypeInformation.of(Double.class))); - score = OperatorStateUtils.getUniqueElement(scoreState, "scoreState").orElse(0.0); - } - - @Override - @SuppressWarnings("unchecked") - public void snapshotState(StateSnapshotContext context) throws Exception { - super.snapshotState(context); - accValueState.clear(); - scoreState.clear(); - if (accValue != null) { - accValueState.add(accValue); - scoreState.add(score); - } - } - } - private static class PartitionSummaryOperator extends AbstractStreamOperator implements OneInputStreamOperator, BinarySummary>, BoundedOneInput { @@ -320,7 +194,6 @@ public void processElement(StreamRecord> streamR } @Override - @SuppressWarnings("unchecked") public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); summaryState = @@ -340,7 +213,6 @@ public void initializeState(StateInitializationContext context) throws Exception } @Override - @SuppressWarnings("unchecked") public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); summaryState.clear(); @@ -362,7 +234,7 @@ public void mapPartition( reduceMetrics = reduceMetrics.merge(iter.next()); } Map map = new HashMap<>(); - map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC); + map.put(AREA_UNDER_ROC, 1. - reduceMetrics.areaUnderROC); map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR); map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz); map.put(KS, reduceMetrics.ks); @@ -385,13 +257,11 @@ public void mapPartition( List statistics = getRuntimeContext().getBroadcastVariable(partitionSummariesKey); - long[] countValues = + double[] countValues = reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask()); - double areaUnderROC = - getRuntimeContext().getBroadcastVariable(AREA_UNDER_ROC).get(0); - long totalTrue = countValues[2]; - long totalFalse = countValues[3]; + double totalTrue = countValues[2]; + double totalFalse = countValues[3]; if (totalTrue == 0) { LOG.warn("There is no positive sample in data!"); } @@ -399,7 +269,7 @@ public void mapPartition( LOG.warn("There is no negative sample in data!"); } - BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC); + BinaryMetrics metrics = new BinaryMetrics(0L); double[] tprFprPrecision = new double[4]; for (Tuple3 t3 : iterable) { updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision); @@ -411,35 +281,36 @@ public void mapPartition( private static void updateBinaryMetrics( Tuple3 cur, BinaryMetrics binaryMetrics, - long[] countValues, + double[] countValues, double[] recordValues) { if (binaryMetrics.count == 0) { - recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2]; - recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3]; + recordValues[0] = countValues[2] == 0 ? 1.0 : countValues[0] / countValues[2]; + recordValues[1] = countValues[3] == 0 ? 1.0 : countValues[1] / countValues[3]; recordValues[2] = countValues[0] + countValues[1] == 0 ? 1.0 - : 1.0 * countValues[0] / (countValues[0] + countValues[1]); - recordValues[3] = - 1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + : countValues[0] / (countValues[0] + countValues[1]); + recordValues[3] = (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); } - binaryMetrics.count++; - if (cur.f1) { - countValues[0]++; + boolean isPos = cur.f1; + double weight = cur.f2; + binaryMetrics.count += weight; + if (isPos) { + countValues[0] += weight; } else { - countValues[1]++; + countValues[1] += weight; } - double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2]; - double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3]; + double tpr = countValues[2] == 0 ? 1.0 : countValues[0] / countValues[2]; + double fpr = countValues[3] == 0 ? 1.0 : countValues[1] / countValues[3]; double precision = countValues[0] + countValues[1] == 0 ? 1.0 - : 1.0 * countValues[0] / (countValues[0] + countValues[1]); - double positiveRate = - 1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + : countValues[0] / (countValues[0] + countValues[1]); + double positiveRate = (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + binaryMetrics.areaUnderROC += (fpr + recordValues[1]) * (tpr - recordValues[0]) / 2; binaryMetrics.areaUnderLorenz += ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2); binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2); @@ -451,65 +322,28 @@ private static void updateBinaryMetrics( recordValues[3] = positiveRate; } - /** - * For each sample, calculates its score order among all samples. The sample with minimum score - * has order 1, while the sample with maximum score has order samples. - * - *

Input is a dataset of tuple (score, is real positive, weight), output is a dataset of - * tuple (score, order, is real positive, weight). - */ - private static class CalcSampleOrders - extends RichFlatMapFunction< - Tuple3, Tuple4> { - private long startIndex; - private long total = -1; - private final String partitionSummariesKey; - - public CalcSampleOrders(String partitionSummariesKey) { - this.partitionSummariesKey = partitionSummariesKey; - } - - @Override - public void flatMap( - Tuple3 value, - Collector> out) - throws Exception { - if (total == -1) { - List statistics = - getRuntimeContext().getBroadcastVariable(partitionSummariesKey); - long[] countValues = - reduceBinarySummary( - statistics, getRuntimeContext().getIndexOfThisSubtask()); - startIndex = countValues[1] + countValues[0] + 1; - total = countValues[2] + countValues[3]; - } - out.collect(Tuple4.of(value.f0, total - startIndex + 1, value.f1, value.f2)); - startIndex++; - } - } - /** * @param values Reduce Summary of all workers. * @param taskId current taskId. * @return [curTrue, curFalse, TotalTrue, TotalFalse] */ - private static long[] reduceBinarySummary(List values, int taskId) { + private static double[] reduceBinarySummary(List values, int taskId) { List list = new ArrayList<>(values); list.sort(Comparator.comparingDouble(t -> -t.maxScore)); - long curTrue = 0; - long curFalse = 0; - long totalTrue = 0; - long totalFalse = 0; + double curTrue = 0; + double curFalse = 0; + double totalTrue = 0; + double totalFalse = 0; for (BinarySummary statistics : list) { if (statistics.taskId == taskId) { curFalse = totalFalse; curTrue = totalTrue; } - totalTrue += statistics.curPositive; - totalFalse += statistics.curNegative; + totalTrue += statistics.sumWeightPos; + totalFalse += statistics.sumWeightNeg; } - return new long[] {curTrue, curFalse, totalTrue, totalFalse}; + return new double[] {curTrue, curFalse, totalTrue, totalFalse}; } /** @@ -520,13 +354,16 @@ private static long[] reduceBinarySummary(List values, int taskId */ private static void updateBinarySummary( BinarySummary statistics, Tuple3 evalElement) { - if (evalElement.f1) { - statistics.curPositive++; + boolean isPos = evalElement.f1; + double weight = evalElement.f2; + double score = evalElement.f0; + if (isPos) { + statistics.sumWeightPos += weight; } else { - statistics.curNegative++; + statistics.sumWeightNeg += weight; } - if (Double.compare(statistics.maxScore, evalElement.f0) < 0) { - statistics.maxScore = evalElement.f0; + if (Double.compare(statistics.maxScore, score) < 0) { + statistics.maxScore = score; } } @@ -673,25 +510,26 @@ public static class BinarySummary implements Serializable { public Integer taskId; // maximum score in this partition public double maxScore; - // real positives in this partition - public long curPositive; - // real negatives in this partition - public long curNegative; + // sum of weights of positives in this partition + public double sumWeightPos; + // sum of weights of negatives in this partition + public double sumWeightNeg; public BinarySummary() {} - public BinarySummary(Integer taskId, double maxScore, long curPositive, long curNegative) { + public BinarySummary( + Integer taskId, double maxScore, double sumWeightPos, double sumWeightNeg) { this.taskId = taskId; this.maxScore = maxScore; - this.curPositive = curPositive; - this.curNegative = curNegative; + this.sumWeightPos = sumWeightPos; + this.sumWeightNeg = sumWeightNeg; } } /** The evaluation metrics for binary classification. */ public static class BinaryMetrics { /* The count of samples. */ - public long count; + public double count; /* Area under ROC */ public double areaUnderROC; @@ -707,22 +545,19 @@ public static class BinaryMetrics { public BinaryMetrics() {} - public BinaryMetrics(long count, double areaUnderROC) { + public BinaryMetrics(long count) { this.count = count; - this.areaUnderROC = areaUnderROC; } public BinaryMetrics merge(BinaryMetrics binaryClassMetrics) { if (null == binaryClassMetrics) { return this; } - Preconditions.checkState( - Double.compare(areaUnderROC, binaryClassMetrics.areaUnderROC) == 0, - "AreaUnderROC not equal!"); count += binaryClassMetrics.count; - ks = Math.max(ks, binaryClassMetrics.ks); - areaUnderPR += binaryClassMetrics.areaUnderPR; + areaUnderROC += binaryClassMetrics.areaUnderROC; areaUnderLorenz += binaryClassMetrics.areaUnderLorenz; + areaUnderPR += binaryClassMetrics.areaUnderPR; + ks = Math.max(ks, binaryClassMetrics.ks); return this; } } From 1d3b9c4c3521658cfd2d68f021f59b2eeb82557e Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 22 Aug 2023 11:49:08 +0800 Subject: [PATCH 3/5] Improve readability of variables names, documents according to comments. --- .../BinaryClassificationEvaluator.java | 101 +++++++++--------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java index 33a8d853f..051d75137 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java @@ -234,7 +234,7 @@ public void mapPartition( reduceMetrics = reduceMetrics.merge(iter.next()); } Map map = new HashMap<>(); - map.put(AREA_UNDER_ROC, 1. - reduceMetrics.areaUnderROC); + map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC); map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR); map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz); map.put(KS, reduceMetrics.ks); @@ -257,22 +257,24 @@ public void mapPartition( List statistics = getRuntimeContext().getBroadcastVariable(partitionSummariesKey); - double[] countValues = + double[] accWeights = reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask()); - double totalTrue = countValues[2]; - double totalFalse = countValues[3]; - if (totalTrue == 0) { - LOG.warn("There is no positive sample in data!"); + double totalSumWeightsPos = accWeights[2]; + double totalSumWeightsNeg = accWeights[3]; + if (totalSumWeightsPos == 0) { + LOG.warn("There is no positive samples in data!"); } - if (totalFalse == 0) { - LOG.warn("There is no negative sample in data!"); + if (totalSumWeightsNeg == 0) { + LOG.warn("There is no negative samples in data!"); } - BinaryMetrics metrics = new BinaryMetrics(0L); + BinaryMetrics metrics = new BinaryMetrics(0); + // Stores values of TPR, FPR, Precision, and PR calculated from samples with scores + // ranging from the maximum to the current one. double[] tprFprPrecision = new double[4]; for (Tuple3 t3 : iterable) { - updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision); + updateBinaryMetrics(t3, metrics, accWeights, tprFprPrecision); } collector.collect(metrics); } @@ -281,36 +283,36 @@ public void mapPartition( private static void updateBinaryMetrics( Tuple3 cur, BinaryMetrics binaryMetrics, - double[] countValues, + double[] accWeights, double[] recordValues) { - if (binaryMetrics.count == 0) { - recordValues[0] = countValues[2] == 0 ? 1.0 : countValues[0] / countValues[2]; - recordValues[1] = countValues[3] == 0 ? 1.0 : countValues[1] / countValues[3]; + if (binaryMetrics.sumWeights == 0) { + recordValues[0] = accWeights[2] == 0 ? 1.0 : accWeights[0] / accWeights[2]; + recordValues[1] = accWeights[3] == 0 ? 1.0 : accWeights[1] / accWeights[3]; recordValues[2] = - countValues[0] + countValues[1] == 0 + accWeights[0] + accWeights[1] == 0 ? 1.0 - : countValues[0] / (countValues[0] + countValues[1]); - recordValues[3] = (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + : accWeights[0] / (accWeights[0] + accWeights[1]); + recordValues[3] = (accWeights[0] + accWeights[1]) / (accWeights[2] + accWeights[3]); } boolean isPos = cur.f1; double weight = cur.f2; - binaryMetrics.count += weight; + binaryMetrics.sumWeights += weight; if (isPos) { - countValues[0] += weight; + accWeights[0] += weight; } else { - countValues[1] += weight; + accWeights[1] += weight; } - double tpr = countValues[2] == 0 ? 1.0 : countValues[0] / countValues[2]; - double fpr = countValues[3] == 0 ? 1.0 : countValues[1] / countValues[3]; + double tpr = accWeights[2] == 0 ? 1.0 : accWeights[0] / accWeights[2]; + double fpr = accWeights[3] == 0 ? 1.0 : accWeights[1] / accWeights[3]; double precision = - countValues[0] + countValues[1] == 0 + accWeights[0] + accWeights[1] == 0 ? 1.0 - : countValues[0] / (countValues[0] + countValues[1]); - double positiveRate = (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]); + : accWeights[0] / (accWeights[0] + accWeights[1]); + double positiveRate = (accWeights[0] + accWeights[1]) / (accWeights[2] + accWeights[3]); - binaryMetrics.areaUnderROC += (fpr + recordValues[1]) * (tpr - recordValues[0]) / 2; + binaryMetrics.areaUnderROC += (fpr - recordValues[1]) * (tpr + recordValues[0]) / 2; binaryMetrics.areaUnderLorenz += ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2); binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2); @@ -325,25 +327,28 @@ private static void updateBinaryMetrics( /** * @param values Reduce Summary of all workers. * @param taskId current taskId. - * @return [curTrue, curFalse, TotalTrue, TotalFalse] + * @return An array storing sum of weights of positives/negatives of tasks before the current + * one, and sum of weights of positives/negatives of all tasks. */ private static double[] reduceBinarySummary(List values, int taskId) { List list = new ArrayList<>(values); list.sort(Comparator.comparingDouble(t -> -t.maxScore)); - double curTrue = 0; - double curFalse = 0; - double totalTrue = 0; - double totalFalse = 0; + double prefixSumWeightsPos = 0; + double prefixSumWeightsNeg = 0; + double totalSumWeightsPos = 0; + double totalSumWeightsNeg = 0; for (BinarySummary statistics : list) { if (statistics.taskId == taskId) { - curFalse = totalFalse; - curTrue = totalTrue; + prefixSumWeightsNeg = totalSumWeightsNeg; + prefixSumWeightsPos = totalSumWeightsPos; } - totalTrue += statistics.sumWeightPos; - totalFalse += statistics.sumWeightNeg; + totalSumWeightsPos += statistics.sumWeightsPos; + totalSumWeightsNeg += statistics.sumWeightsNeg; } - return new double[] {curTrue, curFalse, totalTrue, totalFalse}; + return new double[] { + prefixSumWeightsPos, prefixSumWeightsNeg, totalSumWeightsPos, totalSumWeightsNeg + }; } /** @@ -358,9 +363,9 @@ private static void updateBinarySummary( double weight = evalElement.f2; double score = evalElement.f0; if (isPos) { - statistics.sumWeightPos += weight; + statistics.sumWeightsPos += weight; } else { - statistics.sumWeightNeg += weight; + statistics.sumWeightsNeg += weight; } if (Double.compare(statistics.maxScore, score) < 0) { statistics.maxScore = score; @@ -511,25 +516,25 @@ public static class BinarySummary implements Serializable { // maximum score in this partition public double maxScore; // sum of weights of positives in this partition - public double sumWeightPos; + public double sumWeightsPos; // sum of weights of negatives in this partition - public double sumWeightNeg; + public double sumWeightsNeg; public BinarySummary() {} public BinarySummary( - Integer taskId, double maxScore, double sumWeightPos, double sumWeightNeg) { + Integer taskId, double maxScore, double sumWeightsPos, double sumWeightsNeg) { this.taskId = taskId; this.maxScore = maxScore; - this.sumWeightPos = sumWeightPos; - this.sumWeightNeg = sumWeightNeg; + this.sumWeightsPos = sumWeightsPos; + this.sumWeightsNeg = sumWeightsNeg; } } /** The evaluation metrics for binary classification. */ public static class BinaryMetrics { - /* The count of samples. */ - public double count; + /* The sum of weights of samples. */ + public double sumWeights; /* Area under ROC */ public double areaUnderROC; @@ -545,15 +550,15 @@ public static class BinaryMetrics { public BinaryMetrics() {} - public BinaryMetrics(long count) { - this.count = count; + public BinaryMetrics(long sumWeights) { + this.sumWeights = sumWeights; } public BinaryMetrics merge(BinaryMetrics binaryClassMetrics) { if (null == binaryClassMetrics) { return this; } - count += binaryClassMetrics.count; + sumWeights += binaryClassMetrics.sumWeights; areaUnderROC += binaryClassMetrics.areaUnderROC; areaUnderLorenz += binaryClassMetrics.areaUnderLorenz; areaUnderPR += binaryClassMetrics.areaUnderPR; From 5b7a0738ff8c1af4e0c4d5129b75e86f83905748 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 22 Aug 2023 11:50:31 +0800 Subject: [PATCH 4/5] Correct Python UT in BinaryClassificationEvaluatorTest. --- .../ml/evaluation/tests/tests_binaryclassification.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py index 7f34eb4c9..55a8f2b91 100644 --- a/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py +++ b/flink-ml-python/pyflink/ml/evaluation/tests/tests_binaryclassification.py @@ -16,10 +16,10 @@ # limitations under the License. ################################################################################ import os - from pyflink.common import Types -from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo + from pyflink.ml.evaluation.binaryclassification import BinaryClassificationEvaluator +from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo from pyflink.ml.tests.test_utils import PyFlinkMLTestCase @@ -111,7 +111,7 @@ def setUp(self): self.expected_data_m = [0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237] - self.expected_data_w = 0.8911680911680911 + self.expected_data_w = [0.8717948717948718, 0.9510202726261435] self.eps = 1e-5 @@ -185,11 +185,11 @@ def test_evaluate_with_multi_score(self): def test_evaluate_with_weight(self): evaluator = BinaryClassificationEvaluator() \ - .set_metrics_names("areaUnderROC") \ + .set_metrics_names("areaUnderROC", "areaUnderPR") \ .set_weight_col("weight") output = evaluator.transform(self.input_data_table_with_weight)[0] self.assertEqual( - ["areaUnderROC"], + ["areaUnderROC", "areaUnderPR"], output.get_schema().get_field_names()) results = [result for result in output.execute().collect()] result = results[0] From 406378a392677b65f93b1c11ada3e0749adeb799 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Thu, 24 Aug 2023 14:30:19 +0800 Subject: [PATCH 5/5] Change 1e-9 to EPS in BinaryClassificationEvaluatorTest. --- .../flink/ml/evaluation/BinaryClassificationEvaluatorTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java index b55e0fcde..f6f38b298 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java @@ -311,7 +311,7 @@ public void testEvaluateWithWeight() { }, evalResult.getResolvedSchema().getColumnNames().toArray()); assertArrayEquals( - EXPECTED_DATA_W, new double[] {result.getFieldAs(0), result.getFieldAs(1)}, 1e-9); + EXPECTED_DATA_W, new double[] {result.getFieldAs(0), result.getFieldAs(1)}, EPS); } @Test