From 87838f0372a926ddaf8b901e11b44699893f6ce9 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Sun, 20 Mar 2016 17:16:30 +0800 Subject: [PATCH 1/8] create lbfgs_check --- .../org/apache/spark/mllib/optimization/LBFGS.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 82c2ce4790055..f5c892600ed7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -52,7 +52,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Restriction: numCorrections > 0 */ def setNumCorrections(corrections: Int): this.type = { - assert(corrections > 0) + require(corrections > 0, s"Number of corrections must be greater than 0," + + s" but got ${corrections}") this.numCorrections = corrections this } @@ -64,6 +65,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * and therefore generally cause more iterations to be run. */ def setConvergenceTol(tolerance: Double): this.type = { + require(tolerance >= 0, s"Convergence tolerance must be no less than 0," + + s" but got ${tolerance}") this.convergenceTol = tolerance this } @@ -88,6 +91,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the maximal number of iterations for L-BFGS. Default 100. */ def setNumIterations(iters: Int): this.type = { + require(iters > 0, s"Maximum of iterations must be greater than 0," + + s" but got ${iters}") this.maxNumIterations = iters this } @@ -103,6 +108,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the regularization parameter. Default 0.0. */ def setRegParam(regParam: Double): this.type = { + require(regParam >= 0, s"Regularization parameter must be no less than 0," + + s" but got ${regParam}") this.regParam = regParam this } From d55a7e57ad0aee00cb31b432fb8b4eebdd816f4e Mon Sep 17 00:00:00 2001 From: mllabs Date: Wed, 23 Mar 2016 11:25:26 +0800 Subject: [PATCH 2/8] add sgd --- .../mllib/optimization/GradientDescent.scala | 11 ++++++++++- .../apache/spark/mllib/optimization/LBFGS.scala | 16 ++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index fbf657b0fac48..30ee465b90e97 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -46,6 +46,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * In subsequent steps, the step size will decrease with stepSize/sqrt(t) */ def setStepSize(step: Double): this.type = { + require(step > 0, + s"Initial step size must be greater than 0, but got ${step}") this.stepSize = step this } @@ -57,6 +59,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ @Experimental def setMiniBatchFraction(fraction: Double): this.type = { + require(fraction > 0 && fraction <= 1.0, + s"Fraction for mini-batch SGD must be in range (0, 1], but got ${fraction}") this.miniBatchFraction = fraction this } @@ -65,6 +69,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the number of iterations for SGD. 
Default 100. */ def setNumIterations(iters: Int): this.type = { + require(iters > 0, + s"Number of iterations must be greater than 0, but got ${iters}") this.numIterations = iters this } @@ -73,6 +79,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the regularization parameter. Default 0.0. */ def setRegParam(regParam: Double): this.type = { + require(regParam >= 0, + s"Regularization parameter must be no less than 0, but got ${regParam}") this.regParam = regParam this } @@ -91,7 +99,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Must be between 0.0 and 1.0 inclusively. */ def setConvergenceTol(tolerance: Double): this.type = { - require(0.0 <= tolerance && tolerance <= 1.0) + require(tolerance >= 0.0 && tolerance <= 1.0, + s"Convergence tolerance must be in range [0, 1], but got ${tolerance}") this.convergenceTol = tolerance this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index f5c892600ed7e..de88497db216a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -52,8 +52,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Restriction: numCorrections > 0 */ def setNumCorrections(corrections: Int): this.type = { - require(corrections > 0, s"Number of corrections must be greater than 0," + - s" but got ${corrections}") + require(corrections > 0, + s"Number of corrections must be greater than 0, but got ${corrections}") this.numCorrections = corrections this } @@ -65,8 +65,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * and therefore generally cause more iterations to be run. */ def setConvergenceTol(tolerance: Double): this.type = { - require(tolerance >= 0, s"Convergence tolerance must be no less than 0," + - s" but got ${tolerance}") + require(tolerance >= 0, + s"Convergence tolerance must be no less than 0, but got ${tolerance}") this.convergenceTol = tolerance this } @@ -91,8 +91,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the maximal number of iterations for L-BFGS. Default 100. */ def setNumIterations(iters: Int): this.type = { - require(iters > 0, s"Maximum of iterations must be greater than 0," + - s" but got ${iters}") + require(iters > 0, + s"Maximum of iterations must be greater than 0, but got ${iters}") this.maxNumIterations = iters this } @@ -108,8 +108,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the regularization parameter. Default 0.0. 
*/ def setRegParam(regParam: Double): this.type = { - require(regParam >= 0, s"Regularization parameter must be no less than 0," + - s" but got ${regParam}") + require(regParam >= 0, + s"Regularization parameter must be no less than 0, but got ${regParam}") this.regParam = regParam this } From e43b736fa895f069ccbd6e842a89933c50cbe6a5 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 23 Mar 2016 11:42:05 +0800 Subject: [PATCH 3/8] add clustering --- .../spark/mllib/clustering/GaussianMixture.scala | 6 ++++++ .../org/apache/spark/mllib/clustering/KMeans.scala | 11 ++++++++--- .../org/apache/spark/mllib/clustering/LDA.scala | 2 ++ .../mllib/clustering/PowerIterationClustering.scala | 4 ++++ .../spark/mllib/clustering/StreamingKMeans.scala | 6 ++++++ .../spark/mllib/optimization/GradientDescent.scala | 12 ++++++------ .../org/apache/spark/mllib/optimization/LBFGS.scala | 10 +++++----- 7 files changed, 37 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 88dbfe3fcc9f5..d6c28d63773c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -97,6 +97,8 @@ class GaussianMixture private ( */ @Since("1.3.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of Gaussians must be greater than 0 but got ${k}") this.k = k this } @@ -112,6 +114,8 @@ class GaussianMixture private ( */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be no less than 0 but got ${maxIterations}") this.maxIterations = maxIterations this } @@ -128,6 +132,8 @@ class GaussianMixture private ( */ @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { + require(convergenceTol >= 0.0, + s"Convergence tolerance must be no less than but got ${convergenceTol}") this.convergenceTol = convergenceTol this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 26f5600e6c078..aa3fb97f05e15 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -65,6 +65,8 @@ class KMeans private ( */ @Since("0.8.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of clusters must be greater than 0 but got ${k}") this.k = k this } @@ -80,6 +82,8 @@ class KMeans private ( */ @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be no less than 0 but got ${maxIterations}") this.maxIterations = maxIterations this } @@ -147,9 +151,8 @@ class KMeans private ( */ @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { - if (initializationSteps <= 0) { - throw new IllegalArgumentException("Number of initialization steps must be positive") - } + require(initializationSteps > 0, + s"Number of initialization steps must be greater than 0 but got ${initializationSteps}") this.initializationSteps = initializationSteps this } @@ -166,6 +169,8 @@ class KMeans private ( */ @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { + require(epsilon >= 0, + s"Distance threshold must be no less than 0 but got ${epsilon}") this.epsilon = epsilon this } diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index fad808857a788..61ce761f4de55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -232,6 +232,8 @@ class LDA private ( */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be no less than 0 but got ${maxIterations}") this.maxIterations = maxIterations this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index a422303dc933a..7d842297f6c0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -137,6 +137,8 @@ class PowerIterationClustering private[clustering] ( */ @Since("1.3.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of clusters must be greater than 0 but got ${k}") this.k = k this } @@ -146,6 +148,8 @@ class PowerIterationClustering private[clustering] ( */ @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations >= 0, + s"Maximum of iterations must be no less than 0 but got ${maxIterations}") this.maxIterations = maxIterations this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index a8d7b8fdedb1f..c9788ae1fde2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -178,6 +178,8 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setK(k: Int): this.type = { + require(k > 0, + s"Number of clusters must be greater than 0 but got ${k}") this.k = k this } @@ -187,6 +189,8 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setDecayFactor(a: Double): this.type = { + require(a >= 0, + s"Decay factor must be no less than 0 but got ${a}") this.decayFactor = a this } @@ -198,6 +202,8 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { + require(halfLife > 0, + s"Half life must be greater than 0 but got ${halfLife}") if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 30ee465b90e97..51f04d1dae4eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -47,7 +47,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ def setStepSize(step: Double): this.type = { require(step > 0, - s"Initial step size must be greater than 0, but got ${step}") + s"Initial step size must be greater than 0 but got ${step}") this.stepSize = step this } @@ -60,7 +60,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va @Experimental def 
setMiniBatchFraction(fraction: Double): this.type = { require(fraction > 0 && fraction <= 1.0, - s"Fraction for mini-batch SGD must be in range (0, 1], but got ${fraction}") + s"Fraction for mini-batch SGD must be in range (0, 1] but got ${fraction}") this.miniBatchFraction = fraction this } @@ -69,8 +69,8 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the number of iterations for SGD. Default 100. */ def setNumIterations(iters: Int): this.type = { - require(iters > 0, - s"Number of iterations must be greater than 0, but got ${iters}") + require(iters >= 0, + s"Number of iterations must be no less than 0 but got ${iters}") this.numIterations = iters this } @@ -80,7 +80,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ def setRegParam(regParam: Double): this.type = { require(regParam >= 0, - s"Regularization parameter must be no less than 0, but got ${regParam}") + s"Regularization parameter must be no less than 0 but got ${regParam}") this.regParam = regParam this } @@ -100,7 +100,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ def setConvergenceTol(tolerance: Double): this.type = { require(tolerance >= 0.0 && tolerance <= 1.0, - s"Convergence tolerance must be in range [0, 1], but got ${tolerance}") + s"Convergence tolerance must be in range [0, 1] but got ${tolerance}") this.convergenceTol = tolerance this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index de88497db216a..889ec6ea54ef7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -53,7 +53,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setNumCorrections(corrections: Int): this.type = { require(corrections > 0, - s"Number of corrections must be greater than 0, but got ${corrections}") + s"Number of corrections must be greater than 0 but got ${corrections}") this.numCorrections = corrections this } @@ -66,7 +66,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setConvergenceTol(tolerance: Double): this.type = { require(tolerance >= 0, - s"Convergence tolerance must be no less than 0, but got ${tolerance}") + s"Convergence tolerance must be no less than 0 but got ${tolerance}") this.convergenceTol = tolerance this } @@ -91,8 +91,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) * Set the maximal number of iterations for L-BFGS. Default 100. 
*/ def setNumIterations(iters: Int): this.type = { - require(iters > 0, - s"Maximum of iterations must be greater than 0, but got ${iters}") + require(iters >= 0, + s"Maximum of iterations must be no less than 0 but got ${iters}") this.maxNumIterations = iters this } @@ -109,7 +109,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setRegParam(regParam: Double): this.type = { require(regParam >= 0, - s"Regularization parameter must be no less than 0, but got ${regParam}") + s"Regularization parameter must be no less than 0 but got ${regParam}") this.regParam = regParam this } From f4ccbd6f19f9b5ed0354962787e31d07194fb755 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 23 Mar 2016 13:26:33 +0800 Subject: [PATCH 4/8] add others in mllib --- .../spark/mllib/classification/NaiveBayes.scala | 2 ++ .../org/apache/spark/mllib/feature/PCA.scala | 3 ++- .../org/apache/spark/mllib/feature/Word2Vec.scala | 15 ++++++++++++++- .../apache/spark/mllib/fpm/AssociationRules.scala | 3 ++- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 4 ++++ .../apache/spark/mllib/recommendation/ALS.scala | 12 ++++++++++++ 6 files changed, 36 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index bf0d9d9231ac7..ea476be16cca2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -326,6 +326,8 @@ class NaiveBayes private ( /** Set the smoothing parameter. Default: 1.0. */ @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { + require(lambda > 0, + s"Smoothing parameter must be greater than 0 but got ${lambda}") this.lambda = lambda this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 24e0a98c39bff..21835adc7b6ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD */ @Since("1.4.0") class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { - require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") + require(k > 0, + s"Number of principal components must be greater than 0 but got ${k}") /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d3356b783fc24..2a54da4a37677 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -84,6 +84,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("2.0.0") def setMaxSentenceLength(maxSentenceLength: Int): this.type = { + require(maxSentenceLength > 0, + s"Maximum length of sentences must be greater than 0 but got ${maxSentenceLength}") this.maxSentenceLength = maxSentenceLength this } @@ -93,6 +95,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { + require(vectorSize > 0, + s"vector size must be greater than 0 but got ${vectorSize}") this.vectorSize = vectorSize this } @@ -102,6 +106,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { + require(learningRate > 0, + s"Initial learning rate must be greater than 0 but got ${learningRate}") this.learningRate = learningRate this } @@ -111,7 +117,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { - require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") + require(numPartitions > 0, + s"Number of partitions must be greater than 0 but got ${numPartitions}") this.numPartitions = numPartitions this } @@ -122,6 +129,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { + require(numIterations >= 0, + s"Number of iterations must be greater than 0 but got ${numIterations}") this.numIterations = numIterations this } @@ -140,6 +149,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.6.0") def setWindowSize(window: Int): this.type = { + require(window > 0, + s"Window of words must be greater than 0 but got ${window}") this.window = window this } @@ -150,6 +161,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.3.0") def setMinCount(minCount: Int): this.type = { + require(minCount > 0, + s"Minimum number of times must be greater than 0 but got ${minCount}") this.minCount = minCount this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 5592416964226..9a63cc29dacb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -50,7 +50,8 @@ class AssociationRules private[fpm] ( */ @Since("1.5.0") def setMinConfidence(minConfidence: Double): this.type = { - require(minConfidence >= 0.0 && minConfidence <= 1.0) + require(minConfidence >= 0.0 && minConfidence <= 1.0, + s"Minimal confidence must be in range [0, 1] but got ${minConfidence}") this.minConfidence = minConfidence this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 3f40af8f3ada7..beeed99c9a47c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -180,6 +180,8 @@ class FPGrowth private ( */ @Since("1.3.0") def setMinSupport(minSupport: Double): this.type = 
{ + require(minSupport >= 0.0 && minSupport <= 1.0, + s"Minimal support level must be in range [0, 1] but got ${minSupport}") this.minSupport = minSupport this } @@ -190,6 +192,8 @@ class FPGrowth private ( */ @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, + s"Number of partitions must be greater than 0 but got ${numPartitions}") this.numPartitions = numPartitions this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index c5b02d6b2e9ce..c6195e8111ae0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -97,6 +97,8 @@ class ALS private ( */ @Since("0.8.0") def setBlocks(numBlocks: Int): this.type = { + require(numBlocks > 0 || numBlocks == -1, + s"Number of blocks must be -1 or greater than 0 but got ${numBlocks}") this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -107,6 +109,8 @@ class ALS private ( */ @Since("1.1.0") def setUserBlocks(numUserBlocks: Int): this.type = { + require(numUserBlocks > 0 || numUserBlocks == -1, + s"Number of blocks must be -1 or greater than 0 but got ${numUserBlocks}") this.numUserBlocks = numUserBlocks this } @@ -116,6 +120,8 @@ class ALS private ( */ @Since("1.1.0") def setProductBlocks(numProductBlocks: Int): this.type = { + require(numProductBlocks > 0 || numProductBlocks == -1, + s"Number of product blocks must be -1 or greater than 0 but got ${numProductBlocks}") this.numProductBlocks = numProductBlocks this } @@ -123,6 +129,8 @@ class ALS private ( /** Set the rank of the feature matrices computed (number of features). Default: 10. */ @Since("0.8.0") def setRank(rank: Int): this.type = { + require(rank > 0, + s"Rank of the feature matrices must be greater than 0 but got ${rank}") this.rank = rank this } @@ -130,6 +138,8 @@ class ALS private ( /** Set the number of iterations to run. Default: 10. */ @Since("0.8.0") def setIterations(iterations: Int): this.type = { + require(iterations >= 0, + s"Number of iterations must be no less than 0 but got ${iterations}") this.iterations = iterations this } @@ -137,6 +147,8 @@ class ALS private ( /** Set the regularization parameter, lambda. Default: 0.01. 
*/ @Since("0.8.0") def setLambda(lambda: Double): this.type = { + require(lambda >= 0.0, + s"Regularization parameter must be no less than 0 but got ${lambda}") this.lambda = lambda this } From c8e0ff5d1978d91ca4825193473ffa3768d889d9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 23 Mar 2016 15:33:06 +0800 Subject: [PATCH 5/8] fix w2v --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 2a54da4a37677..c91cc4bde3ce9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -130,7 +130,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { require(numIterations >= 0, - s"Number of iterations must be greater than 0 but got ${numIterations}") + s"Number of iterations must be no less than 0 but got ${numIterations}") this.numIterations = numIterations this } @@ -161,8 +161,8 @@ class Word2Vec extends Serializable with Logging { */ @Since("1.3.0") def setMinCount(minCount: Int): this.type = { - require(minCount > 0, - s"Minimum number of times must be greater than 0 but got ${minCount}") + require(minCount >= 0, + s"Minimum number of times must be no less than 0 but got ${minCount}") this.minCount = minCount this } From ebc37e1aa1453c86df497d6a7da24a7c4c7bf543 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 23 Mar 2016 15:39:51 +0800 Subject: [PATCH 6/8] fix gmm --- .../org/apache/spark/mllib/clustering/GaussianMixture.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index d6c28d63773c5..816f3ba31bbdd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -133,7 +133,7 @@ class GaussianMixture private ( @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { require(convergenceTol >= 0.0, - s"Convergence tolerance must be no less than but got ${convergenceTol}") + s"Convergence tolerance must be no less 0 than but got ${convergenceTol}") this.convergenceTol = convergenceTol this } From c811869da75bde7769078bcedfe33411458c6a27 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 23 Mar 2016 16:35:59 +0800 Subject: [PATCH 7/8] shorten --- .../mllib/classification/NaiveBayes.scala | 2 +- .../mllib/clustering/GaussianMixture.scala | 14 ++++++-------- .../apache/spark/mllib/clustering/KMeans.scala | 8 ++++---- .../apache/spark/mllib/clustering/LDA.scala | 2 +- .../clustering/PowerIterationClustering.scala | 4 ++-- .../mllib/clustering/StreamingKMeans.scala | 6 +++--- .../org/apache/spark/mllib/feature/PCA.scala | 2 +- .../apache/spark/mllib/feature/Word2Vec.scala | 14 +++++++------- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../mllib/optimization/GradientDescent.scala | 6 +++--- .../spark/mllib/optimization/LBFGS.scala | 8 ++++---- .../spark/mllib/recommendation/ALS.scala | 18 +++++++++--------- 12 files changed, 42 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index ea476be16cca2..d51cee65c08d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -327,7 +327,7 @@ class NaiveBayes private ( @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { require(lambda > 0, - s"Smoothing parameter must be greater than 0 but got ${lambda}") + s"Smoothing parameter must be positive but got ${lambda}") this.lambda = lambda this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 816f3ba31bbdd..03eb903bb8fee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -78,11 +78,9 @@ class GaussianMixture private ( */ @Since("1.3.0") def setInitialModel(model: GaussianMixtureModel): this.type = { - if (model.k == k) { - initialModel = Some(model) - } else { - throw new IllegalArgumentException("mismatched cluster count (model.k != k)") - } + require(model.k == k, + s"Mismatched cluster count (model.k ${model.k} != k ${k})") + initialModel = Some(model) this } @@ -98,7 +96,7 @@ class GaussianMixture private ( @Since("1.3.0") def setK(k: Int): this.type = { require(k > 0, - s"Number of Gaussians must be greater than 0 but got ${k}") + s"Number of Gaussians must be positive but got ${k}") this.k = k this } @@ -115,7 +113,7 @@ class GaussianMixture private ( @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { require(maxIterations >= 0, - s"Maximum of iterations must be no less than 0 but got ${maxIterations}") + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } @@ -133,7 +131,7 @@ class GaussianMixture private ( @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { require(convergenceTol >= 0.0, - s"Convergence tolerance must be no less than 0 but got ${convergenceTol}") + s"Convergence tolerance must be nonnegative but got ${convergenceTol}") this.convergenceTol = convergenceTol this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index aa3fb97f05e15..a7beb81980299 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -66,7 +66,7 @@ class KMeans private ( @Since("0.8.0") def setK(k: Int): this.type = { require(k > 0, - s"Number of clusters must be greater than 0 but got ${k}") + s"Number of clusters must be positive but got ${k}") this.k = k this } @@ -83,7 +83,7 @@ class KMeans private ( @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { require(maxIterations >= 0, - s"Maximum of iterations must be no less than 0 but got ${maxIterations}") + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } @@ -152,7 +152,7 @@ class KMeans private ( @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { require(initializationSteps > 0, - s"Number of initialization steps must be greater than 0 but got ${initializationSteps}") + s"Number of initialization steps must be positive but got ${initializationSteps}")
this.initializationSteps = initializationSteps this } @@ -170,7 +170,7 @@ class KMeans private ( @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { require(epsilon >= 0, - s"Distance threshold must be no less than 0 but got ${epsilon}") + s"Distance threshold must be nonnegative but got ${epsilon}") this.epsilon = epsilon this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 61ce761f4de55..12813fd412b11 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -233,7 +233,7 @@ class LDA private ( @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { require(maxIterations >= 0, - s"Maximum of iterations must be no less than 0 but got ${maxIterations}") + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 7d842297f6c0f..2e257ff9b7def 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -138,7 +138,7 @@ class PowerIterationClustering private[clustering] ( @Since("1.3.0") def setK(k: Int): this.type = { require(k > 0, - s"Number of clusters must be greater than 0 but got ${k}") + s"Number of clusters must be positive but got ${k}") this.k = k this } @@ -149,7 +149,7 @@ class PowerIterationClustering private[clustering] ( @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { require(maxIterations >= 0, - s"Maximum of iterations must be no less than 0 but got ${maxIterations}") + s"Maximum of iterations must be nonnegative but got ${maxIterations}") this.maxIterations = maxIterations this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index c9788ae1fde2c..4eb8fc049e611 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -179,7 +179,7 @@ class StreamingKMeans @Since("1.2.0") ( @Since("1.2.0") def setK(k: Int): this.type = { require(k > 0, - s"Number of clusters must be greater than 0 but got ${k}") + s"Number of clusters must be positive but got ${k}") this.k = k this } @@ -190,7 +190,7 @@ class StreamingKMeans @Since("1.2.0") ( @Since("1.2.0") def setDecayFactor(a: Double): this.type = { require(a >= 0, - s"Decay factor must be no less than 0 but got ${a}") + s"Decay factor must be nonnegative but got ${a}") this.decayFactor = a this } @@ -203,7 +203,7 @@ class StreamingKMeans @Since("1.2.0") ( @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { require(halfLife > 0, - s"Half life must be greater than 0 but got ${halfLife}") + s"Half life must be positive but got ${halfLife}") if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 
21835adc7b6ec..30c403e547bee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD @Since("1.4.0") class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { require(k > 0, - s"Number of principal components must be greater than 0 but got ${k}") + s"Number of principal components must be positive but got ${k}") /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index c91cc4bde3ce9..5b079fce3a83d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -85,7 +85,7 @@ class Word2Vec extends Serializable with Logging { @Since("2.0.0") def setMaxSentenceLength(maxSentenceLength: Int): this.type = { require(maxSentenceLength > 0, - s"Maximum length of sentences must be greater than 0 but got ${maxSentenceLength}") + s"Maximum length of sentences must be positive but got ${maxSentenceLength}") this.maxSentenceLength = maxSentenceLength this } @@ -96,7 +96,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { require(vectorSize > 0, - s"vector size must be greater than 0 but got ${vectorSize}") + s"vector size must be positive but got ${vectorSize}") this.vectorSize = vectorSize this } @@ -107,7 +107,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { require(learningRate > 0, - s"Initial learning rate must be greater than 0 but got ${learningRate}") + s"Initial learning rate must be positive but got ${learningRate}") this.learningRate = learningRate this } @@ -118,7 +118,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { require(numPartitions > 0, - s"Number of partitions must be greater than 0 but got ${numPartitions}") + s"Number of partitions must be positive but got ${numPartitions}") this.numPartitions = numPartitions this } @@ -130,7 +130,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { require(numIterations >= 0, - s"Number of iterations must be no less than 0 but got ${numIterations}") + s"Number of iterations must be nonnegative but got ${numIterations}") this.numIterations = numIterations this } @@ -150,7 +150,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.6.0") def setWindowSize(window: Int): this.type = { require(window > 0, - s"Window of words must be greater than 0 but got ${window}") + s"Window of words must be positive but got ${window}") this.window = window this } @@ -162,7 +162,7 @@ class Word2Vec extends Serializable with Logging { @Since("1.3.0") def setMinCount(minCount: Int): this.type = { require(minCount >= 0, - s"Minimum number of times must be no less than 0 but got ${minCount}") + s"Minimum number of times must be nonnegative but got ${minCount}") this.minCount = minCount this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index beeed99c9a47c..4f4996f3be617 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -193,7 +193,7 @@ class FPGrowth private ( @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { require(numPartitions > 0, - s"Number of partitions must be greater than 0 but got ${numPartitions}") + s"Number of partitions must be positive but got ${numPartitions}") this.numPartitions = numPartitions this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 51f04d1dae4eb..a67ea836e5681 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -47,7 +47,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ def setStepSize(step: Double): this.type = { require(step > 0, - s"Initial step size must be greater than 0 but got ${step}") + s"Initial step size must be positive but got ${step}") this.stepSize = step this } @@ -70,7 +70,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ def setNumIterations(iters: Int): this.type = { require(iters >= 0, - s"Number of iterations must be no less than 0 but got ${iters}") + s"Number of iterations must be nonnegative but got ${iters}") this.numIterations = iters this } @@ -80,7 +80,7 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va */ def setRegParam(regParam: Double): this.type = { require(regParam >= 0, - s"Regularization parameter must be no less than 0 but got ${regParam}") + s"Regularization parameter must be nonnegative but got ${regParam}") this.regParam = regParam this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 889ec6ea54ef7..16a33526414bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -53,7 +53,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setNumCorrections(corrections: Int): this.type = { require(corrections > 0, - s"Number of corrections must be greater than 0 but got ${corrections}") + s"Number of corrections must be positive but got ${corrections}") this.numCorrections = corrections this } @@ -66,7 +66,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setConvergenceTol(tolerance: Double): this.type = { require(tolerance >= 0, - s"Convergence tolerance must be no less than 0 but got ${tolerance}") + s"Convergence tolerance must be nonnegative but got ${tolerance}") this.convergenceTol = tolerance this } @@ -92,7 +92,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setNumIterations(iters: Int): this.type = { require(iters >= 0, - s"Maximum of iterations must be no less than 0 but got ${iters}") + s"Maximum of iterations must be nonnegative but got ${iters}") this.maxNumIterations = iters this } @@ -109,7 +109,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) */ def setRegParam(regParam: Double): this.type = { require(regParam >= 0, - s"Regularization parameter must be no less than 0 but got ${regParam}") + s"Regularization parameter must be nonnegative but got ${regParam}") this.regParam = regParam this } diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index c6195e8111ae0..467cb83cd1662 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -97,8 +97,8 @@ class ALS private ( */ @Since("0.8.0") def setBlocks(numBlocks: Int): this.type = { - require(numBlocks > 0 || numBlocks == -1, - s"Number of blocks must be -1 or greater than 0 but got ${numBlocks}") + require(numBlocks == -1 || numBlocks > 0, + s"Number of blocks must be -1 or positive but got ${numBlocks}") this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -109,8 +109,8 @@ class ALS private ( */ @Since("1.1.0") def setUserBlocks(numUserBlocks: Int): this.type = { - require(numUserBlocks > 0 || numUserBlocks == -1, - s"Number of blocks must be -1 or greater than 0 but got ${numUserBlocks}") + require(numUserBlocks == -1 || numUserBlocks > 0, + s"Number of blocks must be -1 or positive but got ${numUserBlocks}") this.numUserBlocks = numUserBlocks this } @@ -120,8 +120,8 @@ class ALS private ( */ @Since("1.1.0") def setProductBlocks(numProductBlocks: Int): this.type = { - require(numProductBlocks > 0 || numProductBlocks == -1, - s"Number of product blocks must be -1 or greater than 0 but got ${numProductBlocks}") + require(numProductBlocks == -1 || numProductBlocks > 0, + s"Number of product blocks must be -1 or positive but got ${numProductBlocks}") this.numProductBlocks = numProductBlocks this } @@ -130,7 +130,7 @@ class ALS private ( @Since("0.8.0") def setRank(rank: Int): this.type = { require(rank > 0, - s"Rank of the feature matrices must be greater than 0 but got ${rank}") + s"Rank of the feature matrices must be positive but got ${rank}") this.rank = rank this } @@ -139,7 +139,7 @@ class ALS private ( @Since("0.8.0") def setIterations(iterations: Int): this.type = { require(iterations >= 0, - s"Number of iterations must be no less than 0 but got ${iterations}") + s"Number of iterations must be nonnegative but got ${iterations}") this.iterations = iterations this } @@ -148,7 +148,7 @@ class ALS private ( @Since("0.8.0") def setLambda(lambda: Double): this.type = { require(lambda >= 0.0, - s"Regularization parameter must be no less than 0 but got ${lambda}") + s"Regularization parameter must be nonnegative but got ${lambda}") this.lambda = lambda this } From a826081118a588bf46dea08a3f9a0f25448e18ba Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 24 Mar 2016 09:30:44 +0800 Subject: [PATCH 8/8] fix nb --- .../org/apache/spark/mllib/classification/NaiveBayes.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index d51cee65c08d4..eb3ee41f7cf4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -326,8 +326,8 @@ class NaiveBayes private ( /** Set the smoothing parameter. Default: 1.0. */ @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { - require(lambda > 0, - s"Smoothing parameter must be positive but got ${lambda}") + require(lambda >= 0, + s"Smoothing parameter must be nonnegative but got ${lambda}") this.lambda = lambda this }
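
Net effect of the series: the touched setters no longer accept bad values silently, and the ad-hoc checks that did exist (a bare assert in LBFGS.setNumCorrections, hand-thrown IllegalArgumentExceptions in KMeans.setInitializationSteps and GaussianMixture.setInitialModel) are replaced by uniform require calls with short, descriptive messages. A minimal sketch of the resulting behavior, assuming the final message wording from PATCH 7/8 (KMeans is shown, but every patched setter behaves the same way; the demo object is purely illustrative):

    import org.apache.spark.mllib.clustering.KMeans

    object RequireCheckDemo {
      def main(args: Array[String]): Unit = {
        // Valid values pass through unchanged; each setter still returns `this`,
        // so the usual chained configuration style keeps working.
        val kmeans = new KMeans().setK(3).setMaxIterations(20).setEpsilon(1e-4)

        // Invalid values now fail fast at configuration time.
        try {
          kmeans.setK(0)
        } catch {
          case e: IllegalArgumentException =>
            // Scala's require prepends "requirement failed: ", so this prints:
            //   requirement failed: Number of clusters must be positive but got 0
            println(e.getMessage)
        }
      }
    }

The switch from assert to require in PATCH 1/8 is the important design point: assert is elidable (compiling with -Xdisable-assertions removes the check entirely) and throws AssertionError, whereas require always runs and throws IllegalArgumentException carrying the supplied message, so a bad hyperparameter surfaces immediately at the call site instead of as an obscure failure mid-training.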