Unify Logistic Regression convergence tolerance of ML & MLlib
yanboliang committed Feb 22, 2016
1 parent 8a4ed78 commit f6b7b8b
Showing 3 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -41,7 +41,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
extends Optimizer with Logging {

private var numCorrections = 10
- private var convergenceTol = 1E-4
+ private var convergenceTol = 1E-6
private var maxNumIterations = 100
private var regParam = 0.0

@@ -59,7 +59,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
}

/**
- * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
+ * Set the convergence tolerance of iterations for L-BFGS. Default 1E-6.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* This value must be nonnegative. Lower convergence values are less tolerant
* and therefore generally cause more iterations to be run.
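
The doc comment above is the knob this commit retunes: the default L-BFGS convergence tolerance in MLlib drops from 1E-4 to 1E-6, matching the spark.ml side referenced in the commit title. As a minimal sketch (not part of this commit), a caller that relied on the looser 1E-4 default can restore it explicitly through the optimizer, using only setters that already appear in this diff; `training` stands in for an assumed RDD[LabeledPoint] that is not defined here:

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.optimization.SquaredL2Updater

val trainer = new LogisticRegressionWithLBFGS().setIntercept(true)
// Override the new 1E-6 default; a smaller tolerance means tighter convergence
// and, as the doc comment above notes, generally more iterations.
trainer.optimizer
  .setUpdater(new SquaredL2Updater)
  .setRegParam(0.01)
  .setConvergenceTol(1e-4)
val model = trainer.run(training)  // training: assumed RDD[LabeledPoint]

With 1E-6 as the built-in default, the explicit setConvergenceTol(1E-6) calls in the test suite below become redundant, which is why the next file simply removes them.
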
16 changes: 8 additions & 8 deletions mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -667,9 +667,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w

test("binary logistic regression with intercept with L1 regularization") {
val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true)
- trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6)
+ trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12)
val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false)
- trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6)
+ trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12)

val model1 = trainer1.run(binaryDataset)
val model2 = trainer2.run(binaryDataset)
@@ -726,9 +726,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w

test("binary logistic regression without intercept with L1 regularization") {
val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true)
- trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6)
+ trainer1.optimizer.setUpdater(new L1Updater).setRegParam(0.12)
val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false)
- trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12).setConvergenceTol(1E-6)
+ trainer2.optimizer.setUpdater(new L1Updater).setRegParam(0.12)

val model1 = trainer1.run(binaryDataset)
val model2 = trainer2.run(binaryDataset)
@@ -786,9 +786,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w

test("binary logistic regression with intercept with L2 regularization") {
val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(true)
- trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6)
+ trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37)
val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false)
- trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6)
+ trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37)

val model1 = trainer1.run(binaryDataset)
val model2 = trainer2.run(binaryDataset)
@@ -845,9 +845,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w

test("binary logistic regression without intercept with L2 regularization") {
val trainer1 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(true)
- trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6)
+ trainer1.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37)
val trainer2 = new LogisticRegressionWithLBFGS().setIntercept(false).setFeatureScaling(false)
- trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37).setConvergenceTol(1E-6)
+ trainer2.optimizer.setUpdater(new SquaredL2Updater).setRegParam(1.37)

val model1 = trainer1.run(binaryDataset)
val model2 = trainer2.run(binaryDataset)
4 changes: 2 additions & 2 deletions python/pyspark/mllib/classification.py
@@ -327,7 +327,7 @@ class LogisticRegressionWithLBFGS(object):
@classmethod
@since('1.2.0')
def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
-               intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2):
+               intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2):
"""
Train a logistic regression model on the given data.
@@ -359,7 +359,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
(default: 10)
:param tolerance:
The convergence tolerance of iterations for L-BFGS.
-             (default: 1e-4)
+             (default: 1e-6)
:param validateData:
Boolean parameter which indicates if the algorithm should
validate data before training.
