[SPARK-20631][PYTHON][ML] LogisticRegression._checkThresholdConsistency should use values not Params

## What changes were proposed in this pull request?

- Replace `getParam` calls with `getOrDefault` calls, so the consistency check compares parameter values rather than `Param` objects.
- Fix the exception message to avoid an unintended `TypeError`.
- Add unit tests.
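
The exception-message fix can be illustrated without Spark. The following is a minimal sketch, assuming `thresholds` holds Python floats (the list value is only an illustration). `str.join` accepts only string items, so the old message construction raises `TypeError` before the intended `ValueError` can be built, whereas formatting the whole list works for any element type.

```python
# Illustrative value: thresholds are floats, not strings.
ts = [0.5, 0.5]

# Old message construction: str.join rejects non-string items.
try:
    old_message = " thresholds: " + ",".join(ts)
    join_error = None
except TypeError as exc:
    join_error = exc

# Patched message construction: format the list as a whole.
new_message = " thresholds: {0}".format(str(ts))

print(type(join_error).__name__)
print(new_message)
```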

## How was this patch tested?

New unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17891 from zero323/SPARK-20631.

(cherry picked from commit 804949c)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
zero323 authored and yanboliang committed May 10, 2017
1 parent ef50a95 commit 3ed2f4d
Showing 2 changed files with 15 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/pyspark/ml/classification.py
@@ -349,13 +349,13 @@ def getThresholds(self):

     def _checkThresholdConsistency(self):
         if self.isSet(self.threshold) and self.isSet(self.thresholds):
-            ts = self.getParam(self.thresholds)
+            ts = self.getOrDefault(self.thresholds)
             if len(ts) != 2:
                 raise ValueError("Logistic Regression getThreshold only applies to" +
                                  " binary classification, but thresholds has length != 2." +
-                                 " thresholds: " + ",".join(ts))
+                                 " thresholds: {0}".format(str(ts)))
             t = 1.0/(1.0 + ts[0]/ts[1])
-            t2 = self.getParam(self.threshold)
+            t2 = self.getOrDefault(self.threshold)
             if abs(t2 - t) >= 1E-5:
                 raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
                                  " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
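
The patched check can be read as a plain function over values. The following is a sketch only: the function name `check_threshold_consistency` and the `tol` parameter are hypothetical, while the arithmetic and messages follow the diff above. Assuming the class with the largest ratio of probability to threshold is predicted, binary `thresholds` `[t0, t1]` select class 1 when `p > 1 / (1 + t0 / t1)`, which is the single threshold the check compares against.

```python
def check_threshold_consistency(threshold, thresholds, tol=1e-5):
    """Hypothetical standalone version of the check, operating on plain values."""
    if len(thresholds) != 2:
        raise ValueError("Logistic Regression getThreshold only applies to" +
                         " binary classification, but thresholds has length != 2." +
                         " thresholds: {0}".format(str(thresholds)))
    # Single threshold equivalent to the binary pair [t0, t1].
    equivalent = 1.0 / (1.0 + thresholds[0] / thresholds[1])
    if abs(threshold - equivalent) >= tol:
        raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
                         " threshold (%g) and thresholds (equivalent to %g)"
                         % (threshold, equivalent))
    return equivalent


# Consistent pair: [0.5, 0.5] is equivalent to a threshold of 0.5.
print(check_threshold_consistency(0.5, [0.5, 0.5]))
```

Calling it with `threshold=0.42` and the same `thresholds` raises the second `ValueError`; a list of any other length raises the first.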
12 changes: 12 additions & 0 deletions python/pyspark/ml/tests.py
@@ -807,6 +807,18 @@ def test_logistic_regression(self):
         except OSError:
             pass

+    def logistic_regression_check_thresholds(self):
+        self.assertIsInstance(
+            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
+            LogisticRegressionModel
+        )
+
+        self.assertRaisesRegexp(
+            ValueError,
+            "Logistic Regression getThreshold found inconsistent.*$",
+            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
+        )
+
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value
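
The shape of the added test can be tried without a Spark installation. This is a sketch under stated assumptions: `ThresholdHolder` is a hypothetical stand-in that validates in its constructor, as the test above expects of `LogisticRegression`, and it is not part of PySpark. It uses `assertRaisesRegex`, the Python 3 name for `assertRaisesRegexp` (the old alias was removed in Python 3.12).

```python
import unittest


class ThresholdHolder(object):
    """Hypothetical stand-in for an estimator that validates on construction."""

    def __init__(self, threshold, thresholds):
        # Same arithmetic as the patched check: binary thresholds [t0, t1]
        # correspond to the single threshold 1 / (1 + t0 / t1).
        equivalent = 1.0 / (1.0 + thresholds[0] / thresholds[1])
        if abs(threshold - equivalent) >= 1E-5:
            raise ValueError(
                "Logistic Regression getThreshold found inconsistent values for"
                " threshold (%g) and thresholds (equivalent to %g)"
                % (threshold, equivalent))
        self.threshold = threshold
        self.thresholds = thresholds


class ThresholdConsistencyTest(unittest.TestCase):

    def test_consistent_values_are_accepted(self):
        holder = ThresholdHolder(threshold=0.5, thresholds=[0.5, 0.5])
        self.assertIsInstance(holder, ThresholdHolder)

    def test_inconsistent_values_are_rejected(self):
        self.assertRaisesRegex(
            ValueError,
            "Logistic Regression getThreshold found inconsistent.*$",
            ThresholdHolder, threshold=0.42, thresholds=[0.5, 0.5])
```

The method names here start with `test_` because unittest's default loader collects only methods whose names begin with `test`. The method added in the diff is named `logistic_regression_check_thresholds`, so it may be worth checking that the runner in use actually picks it up.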
