From eb3e00cbc2023b6ab915a278532f13fb919356f2 Mon Sep 17 00:00:00 2001
From: phuthientran <58049725+phuthientran@users.noreply.github.com>
Date: Fri, 8 Jan 2021 16:59:04 -1000
Subject: [PATCH] Revert "NIFI-7516: Catch and log SingularMatrixExceptions in
OrdinaryLeastSquares model (#4323)"
This reverts commit d21e9560934e2477cef5f57389efbe5349c922b1.
---
.../models/OrdinaryLeastSquares.java | 31 +++++++------------
...res.java => TestOrdinaryLeastSqaures.java} | 11 ++++---
2 files changed, 19 insertions(+), 23 deletions(-)
rename nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/{TestOrdinaryLeastSquares.java => TestOrdinaryLeastSqaures.java} (96%)
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java
index ecd830b87d11..b4dc98cff5a0 100644
--- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java
@@ -16,17 +16,17 @@
*/
package org.apache.nifi.controller.status.analytics.models;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
import org.apache.commons.lang3.ArrayUtils;
-import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.apache.nifi.controller.status.analytics.StatusAnalyticsModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.stream.Stream;
-
/**
*
* An implementation of the {@link StatusAnalyticsModel} that uses Ordinary Least Squares computation for regression.
@@ -46,14 +46,9 @@ public OrdinaryLeastSquares() {
@Override
public void learn(Stream features, Stream labels) {
double[] labelArray = ArrayUtils.toPrimitive(labels.toArray(Double[]::new));
- double[][] featuresMatrix = features.map(ArrayUtils::toPrimitive).toArray(double[][]::new);
+ double[][] featuresMatrix = features.map(feature -> ArrayUtils.toPrimitive(feature)).toArray(double[][]::new);
this.olsModel.newSampleData(labelArray, featuresMatrix);
- try {
- this.coefficients = olsModel.estimateRegressionParameters();
- } catch (SingularMatrixException sme) {
- LOG.debug("The OLSMultipleLinearRegression model's matrix has no inverse (i.e. it is singular) so regression parameters can not be estimated at this time.");
-
- }
+ this.coefficients = olsModel.estimateRegressionParameters();
}
@Override
@@ -81,7 +76,8 @@ public Double predictVariable(Integer predictVariableIndex, Map
double sumX = 0;
if (knownVariablesWithIndex.size() > 0) {
sumX = knownVariablesWithIndex.entrySet().stream().map(featureTuple -> coefficients[olsModel.isNoIntercept()
- ? featureTuple.getKey() : featureTuple.getKey() + 1] * featureTuple.getValue()).mapToDouble(Double::doubleValue).sum();
+ ? featureTuple.getKey() : featureTuple.getKey() + 1] * featureTuple.getValue())
+ .collect(Collectors.summingDouble(Double::doubleValue));
}
return (label - intercept - sumX) / predictorCoeff;
}
@@ -93,13 +89,10 @@ public Map getScores() {
return null;
} else {
Map scores = new HashMap<>();
- try {
- scores.put("rSquared", olsModel.calculateRSquared());
- scores.put("totalSumOfSquares", olsModel.calculateTotalSumOfSquares());
- } catch (SingularMatrixException sme) {
- LOG.debug("The OLSMultipleLinearRegression model's matrix has no inverse (i.e. it is singular) so no scores can be calculated at this time.");
- }
+ scores.put("rSquared", olsModel.calculateRSquared());
+ scores.put("totalSumOfSquares", olsModel.calculateTotalSumOfSquares());
return scores;
+
}
}
diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSquares.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSqaures.java
similarity index 96%
rename from nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSquares.java
rename to nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSqaures.java
index de466ea07e7e..b51b2fd95081 100644
--- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSquares.java
+++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSqaures.java
@@ -16,7 +16,7 @@
*/
package org.apache.nifi.controller.status.analytics.models;
-import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
import java.util.Date;
import java.util.HashMap;
@@ -27,7 +27,7 @@
import org.apache.commons.math3.linear.SingularMatrixException;
import org.junit.Test;
-public class TestOrdinaryLeastSquares {
+public class TestOrdinaryLeastSqaures {
@Test
@@ -53,8 +53,7 @@ public void testConstantPrediction(){
} catch (SingularMatrixException sme){
exOccurred = true;
}
- // SingularMatrixException should not be thrown, it will instead be logged
- assertFalse(exOccurred);
+ assertTrue(exOccurred);
}
@@ -150,4 +149,8 @@ public void comparePredictions(){
assert(olsR2 > srR2);
}
+
+
+
+
}