Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-31681][ML][PySpark] Python multiclass logistic regression evaluate should return LogisticRegressionSummary #28503

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/pyspark/ml/classification.py
Expand Up @@ -932,7 +932,10 @@ def evaluate(self, dataset):
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)
if self.numClasses <= 2:
return BinaryLogisticRegressionSummary(java_blr_summary)
else:
return LogisticRegressionSummary(java_blr_summary)


class LogisticRegressionSummary(JavaWrapper):
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/ml/tests/test_training_summary.py
Expand Up @@ -21,7 +21,8 @@
if sys.version > '3':
basestring = str

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
LogisticRegressionSummary
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self):
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_multiclass_logistic_regression_summary(self):
Expand Down Expand Up @@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self):
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary))
self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)

def test_gaussian_mixture_summary(self):
Expand Down