From e10516ae63cfc58f2d493e4d3f19940d45c8f033 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 14 May 2020 10:54:35 -0500 Subject: [PATCH] [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary ### What changes were proposed in this pull request? Return LogisticRegressionSummary for multiclass logistic regression evaluate in PySpark ### Why are the changes needed? Currently we have ``` since("2.0.0") def evaluate(self, dataset): if not isinstance(dataset, DataFrame): raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) ``` we should return LogisticRegressionSummary for multiclass logistic regression ### Does this PR introduce _any_ user-facing change? Yes return LogisticRegressionSummary instead of BinaryLogisticRegressionSummary for multiclass logistic regression in Python ### How was this patch tested? unit test Closes #28503 from huaxingao/lr_summary. Authored-by: Huaxin Gao Signed-off-by: Sean Owen --- python/pyspark/ml/classification.py | 5 ++++- python/pyspark/ml/tests/test_training_summary.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d635be1d8db80..3bc862cc42af9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -932,7 +932,10 @@ def evaluate(self, dataset): if not isinstance(dataset, DataFrame): raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) java_blr_summary = self._call_java("evaluate", dataset) - return BinaryLogisticRegressionSummary(java_blr_summary) + if self.numClasses <= 2: + return BinaryLogisticRegressionSummary(java_blr_summary) + else: + return LogisticRegressionSummary(java_blr_summary) class LogisticRegressionSummary(JavaWrapper): diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 1d19ebf9a34a0..b5054095d190b 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -21,7 +21,8 @@ if sys.version > '3': basestring = str -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \ + LogisticRegressionSummary from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans from pyspark.ml.linalg import Vectors from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression @@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) def test_multiclass_logistic_regression_summary(self): @@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary)) + self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) def test_gaussian_mixture_summary(self):