Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[SPARK-31681][ML][PYSPARK] Python multiclass logistic regression eval…
…uate should return LogisticRegressionSummary

### What changes were proposed in this pull request?
Return LogisticRegressionSummary for multiclass logistic regression evaluate in PySpark

### Why are the changes needed?
Currently we have
```
    since("2.0.0")
    def evaluate(self, dataset):
        if not isinstance(dataset, DataFrame):
            raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
        java_blr_summary = self._call_java("evaluate", dataset)
        return BinaryLogisticRegressionSummary(java_blr_summary)
```
we should return LogisticRegressionSummary for multiclass logistic regression

### Does this PR introduce _any_ user-facing change?
Yes
return LogisticRegressionSummary instead of BinaryLogisticRegressionSummary for multiclass logistic regression in Python

### How was this patch tested?
unit test

Closes #28503 from huaxingao/lr_summary.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
huaxingao authored and srowen committed May 14, 2020
1 parent b2300fc commit e10516a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python/pyspark/ml/classification.py
Expand Up @@ -932,7 +932,10 @@ def evaluate(self, dataset):
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)
if self.numClasses <= 2:
return BinaryLogisticRegressionSummary(java_blr_summary)
else:
return LogisticRegressionSummary(java_blr_summary)


class LogisticRegressionSummary(JavaWrapper):
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/ml/tests/test_training_summary.py
Expand Up @@ -21,7 +21,8 @@
if sys.version > '3':
basestring = str

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
LogisticRegressionSummary
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self):
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_multiclass_logistic_regression_summary(self):
Expand Down Expand Up @@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self):
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary))
self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)

def test_gaussian_mixture_summary(self):
Expand Down

0 comments on commit e10516a

Please sign in to comment.