diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java index b4dc98cff5a0..ecd830b87d11 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/models/OrdinaryLeastSquares.java @@ -16,17 +16,17 @@ */ package org.apache.nifi.controller.status.analytics.models; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; - import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.math3.linear.SingularMatrixException; import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; import org.apache.nifi.controller.status.analytics.StatusAnalyticsModel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + /** *

* An implementation of the {@link StatusAnalyticsModel} that uses Ordinary Least Squares computation for regression. @@ -46,9 +46,14 @@ public OrdinaryLeastSquares() { @Override public void learn(Stream features, Stream labels) { double[] labelArray = ArrayUtils.toPrimitive(labels.toArray(Double[]::new)); - double[][] featuresMatrix = features.map(feature -> ArrayUtils.toPrimitive(feature)).toArray(double[][]::new); + double[][] featuresMatrix = features.map(ArrayUtils::toPrimitive).toArray(double[][]::new); this.olsModel.newSampleData(labelArray, featuresMatrix); - this.coefficients = olsModel.estimateRegressionParameters(); + try { + this.coefficients = olsModel.estimateRegressionParameters(); + } catch (SingularMatrixException sme) { + LOG.debug("The OLSMultipleLinearRegression model's matrix has no inverse (i.e. it is singular) so regression parameters can not be estimated at this time."); + + } } @Override @@ -76,8 +81,7 @@ public Double predictVariable(Integer predictVariableIndex, Map double sumX = 0; if (knownVariablesWithIndex.size() > 0) { sumX = knownVariablesWithIndex.entrySet().stream().map(featureTuple -> coefficients[olsModel.isNoIntercept() - ? featureTuple.getKey() : featureTuple.getKey() + 1] * featureTuple.getValue()) - .collect(Collectors.summingDouble(Double::doubleValue)); + ? featureTuple.getKey() : featureTuple.getKey() + 1] * featureTuple.getValue()).mapToDouble(Double::doubleValue).sum(); } return (label - intercept - sumX) / predictorCoeff; } @@ -89,10 +93,13 @@ public Map getScores() { return null; } else { Map scores = new HashMap<>(); - scores.put("rSquared", olsModel.calculateRSquared()); - scores.put("totalSumOfSquares", olsModel.calculateTotalSumOfSquares()); + try { + scores.put("rSquared", olsModel.calculateRSquared()); + scores.put("totalSumOfSquares", olsModel.calculateTotalSumOfSquares()); + } catch (SingularMatrixException sme) { + LOG.debug("The OLSMultipleLinearRegression model's matrix has no inverse (i.e. it is singular) so no scores can be calculated at this time."); + } return scores; - } } diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSqaures.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSquares.java similarity index 96% rename from nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSqaures.java rename to nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSquares.java index b51b2fd95081..de466ea07e7e 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSqaures.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/models/TestOrdinaryLeastSquares.java @@ -16,7 +16,7 @@ */ package org.apache.nifi.controller.status.analytics.models; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; import java.util.Date; import java.util.HashMap; @@ -27,7 +27,7 @@ import org.apache.commons.math3.linear.SingularMatrixException; import org.junit.Test; -public class TestOrdinaryLeastSqaures { +public class TestOrdinaryLeastSquares { @Test @@ -53,7 +53,8 @@ public void testConstantPrediction(){ } catch (SingularMatrixException sme){ exOccurred = true; } - assertTrue(exOccurred); + // SingularMatrixException should not be thrown, it will instead be logged + assertFalse(exOccurred); } @@ -149,8 +150,4 @@ public void comparePredictions(){ assert(olsR2 > srR2); } - - - - }