
[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces #16441

Closed
wants to merge 20 commits into from

Conversation

6 participants
@imatiach-msft
Contributor

commented Dec 30, 2016

What changes were proposed in this pull request?

For all of the classifiers in MLlib we can predict probabilities except for GBTClassifier.
Also, all classifiers inherit from ProbabilisticClassifier, but GBTClassifier strangely inherits from Predictor, which is a bug.
This change corrects the interface and adds the ability for the classifier to give a probabilities vector.

How was this patch tested?

The basic ML tests were run after making the changes. I've marked this as WIP as I need to add more tests.

@SparkQA


commented Dec 30, 2016

Test build #70759 has finished for PR 16441 at commit 4468891.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@imatiach-msft

Contributor Author

commented Dec 30, 2016

Jenkins, retest this please

@SparkQA


commented Dec 30, 2016

Test build #70760 has finished for PR 16441 at commit 489e0e6.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@jkbradley

Member

commented Dec 31, 2016

Thanks for the PR; I do want to get this fixed. However, I don't think this is the right way to make predictions of probabilities for GBTs. I believe it should depend on the loss used. E.g., check out page 8 of Friedman (1999), "Greedy Function Approximation: A Gradient Boosting Machine".
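For context, under the logistic (deviance) loss in Friedman's paper the ensemble output F(x) models half the log-odds of the positive class, so a probability can be recovered from the raw GBT margin. A minimal sketch of that mapping (names are illustrative, not Spark API):

```scala
object GbtProbabilitySketch {
  // Under logistic loss, F(x) models half the log-odds of the positive class,
  // so the positive-class probability is the sigmoid of twice the margin:
  // p+(x) = 1 / (1 + e^(-2 * F(x)))
  def probabilityFromMargin(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))

  def main(args: Array[String]): Unit = {
    // A zero margin means maximal uncertainty.
    assert(math.abs(probabilityFromMargin(0.0) - 0.5) < 1e-12)
    // The two class probabilities always sum to one.
    val p = probabilityFromMargin(1.3)
    assert(math.abs(p + probabilityFromMargin(-1.3) - 1.0) < 1e-12)
  }
}
```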

@SparkQA


commented Jan 5, 2017

Test build #70935 has finished for PR 16441 at commit 4348c2e.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@imatiach-msft

Contributor Author

commented Jan 5, 2017

Thanks, I've updated the PR based on your comment. The only disadvantage of the current code is that I do the probability computation within the classifier, when it seems like it should be moved to the LogLoss.scala class. However, it's not a problem right now because GBTClassifier only uses logistic loss, and other learners would probably have to be modified in a similar way anyway.

@imatiach-msft imatiach-msft force-pushed the imatiach-msft:ilmat/fix-GBT branch Jan 5, 2017

@SparkQA


commented Jan 5, 2017

Test build #70938 has finished for PR 16441 at commit 2b842e5.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA


commented Jan 5, 2017

Test build #70939 has finished for PR 16441 at commit 9def0ca.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@imatiach-msft

Contributor Author

commented Jan 5, 2017

@jkbradley I've updated based on your comments, please take another look, thanks!

@sethah
Contributor

left a comment

Thanks for the patch. I made a first pass

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
this(uid, _trees, _treeWeights, -1)
def this(uid: String, _trees: Array[DecisionTreeRegressionModel],
_treeWeights: Array[Double]) =

@sethah

sethah Jan 5, 2017

Contributor

put this back on one line

@imatiach-msft

imatiach-msft Jan 5, 2017

Author Contributor

done

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -215,10 +223,23 @@ class GBTClassificationModel private[ml](
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
* @param numFeatures The number of features.
*/
@Since("1.6.0")

@sethah

sethah Jan 5, 2017

Contributor

This is actually not correct since the constructor was private[ml] before. Since this has always been private, and we aren't actually using it anywhere, I think we can remove this constructor entirely.

@imatiach-msft

imatiach-msft Jan 5, 2017

Author Contributor

removed

@sethah

sethah Jan 10, 2017

Contributor

Since tag not needed since it's private

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

removed since tag

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -248,12 +269,38 @@ class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}

override protected def predictRaw(features: Vector): Vector = {
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)

@sethah

sethah Jan 5, 2017

Contributor

We should import org.apache.spark.ml.linalg.BLAS and call BLAS.dot here and in predict.

@imatiach-msft

imatiach-msft Jan 5, 2017

Author Contributor

it looks like BLAS.dot is only for Vector, but these are both arrays. I'm worried that this may degrade performance. Is this specifically what you are looking for:
BLAS.dot(Vectors.dense(treePredictions), Vectors.dense(_treeWeights))
is the extra dense vector allocation worth it?

@sethah

sethah Jan 6, 2017

Contributor

Yeah, I see it's not quite the same as in other places. We can leave it

@imatiach-msft

imatiach-msft Jan 6, 2017

Author Contributor

oh ok, thank you for confirming

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
var i = 0
val size = dv.size
while (i < size) {
dv.values(i) = 1 / (1 + math.exp(-2 * dv.values(i)))

@sethah

sethah Jan 5, 2017

Contributor

my concern is that this is hard coded to logistic loss. Maybe we can add a static method to GBTClassificationModel

private def classProbability(label: Int, loss: String, rawPrediction: Double): Double = {
  loss match {
    case "logistic" => ...
    case _ => throw new Exception("Only logistic loss is supported ...")
  }
}

@imatiach-msft

imatiach-msft Jan 5, 2017

Author Contributor

done

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala Outdated
ParamsSuite.checkParams(model)
}

test("Verify raw scores correspond to labels") {
val rawPredictionCol = "MyRawPrediction"

@sethah

sethah Jan 5, 2017

Contributor

Just use defaults here. And I'm in favor of only setting parameters that matter for the given test, otherwise it may give the impression that the test depends on a certain, say checkpoint interval.

@imatiach-msft

imatiach-msft Jan 6, 2017

Author Contributor

done

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala Outdated
ParamsSuite.checkParams(model)
}

test("Verify raw scores correspond to labels") {

@sethah

sethah Jan 5, 2017

Contributor

Could you take a look at this test, and make it line up here? Specifically:

  • compute probabilities manually from rawPrediction and ensure that it matches the probabilities column
  • make sure that probabilities.argmax and rawPrediction.argmax equal the prediction
  • make sure probabilities sum to one
  • check the different code paths by unsetting some of the output columns
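The checklist above can be sketched as plain assertions over one scored row (the sigmoid of twice the raw value matches the logistic-loss mapping discussed in this thread; the names and the standalone shape are illustrative, not the final suite code):

```scala
object GbtRowChecksSketch {
  private def argmax(v: Array[Double]): Int = v.indices.maxBy(v(_))

  // Verify the review-checklist invariants for one scored row.
  def check(raw: Array[Double], prob: Array[Double], prediction: Double): Unit = {
    val eps = 1e-8
    // 1. Probabilities recomputed from rawPrediction match the probability column.
    val probFromRaw = raw.map(m => 1.0 / (1.0 + math.exp(-2.0 * m)))
    raw.indices.foreach(i => assert(math.abs(prob(i) - probFromRaw(i)) < eps))
    // 2. argmax of both columns equals the prediction.
    assert(argmax(raw) == prediction.toInt && argmax(prob) == prediction.toInt)
    // 3. Probabilities sum to one.
    assert(math.abs(prob.sum - 1.0) < eps)
  }

  def main(args: Array[String]): Unit = {
    val f = 0.8 // ensemble margin F(x)
    val p = 1.0 / (1.0 + math.exp(-2.0 * f))
    // rawPrediction is (-F(x), F(x)); probability is (1 - p, p).
    check(Array(-f, f), Array(1.0 - p, p), prediction = 1.0)
  }
}
```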

@imatiach-msft

imatiach-msft Jan 6, 2017

Author Contributor

done

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -248,12 +269,38 @@ class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}

override protected def predictRaw(features: Vector): Vector = {

@sethah

sethah Jan 5, 2017

Contributor

In logistic regression we had previously overridden some of the methods in probabilistic classifier since we were only dealing with two classes, which makes those methods a bit faster (hard to say how much). We can do it here for now, but I'd be slightly in favor of not doing it since I'm not sure how much we gain from it and it makes the code harder to follow. Thoughts?

@imatiach-msft

imatiach-msft Jan 5, 2017

Author Contributor

sorry I'm a bit confused, this classifier also only deals with two classes, it does not support multiclass data. Instead of overriding, what is the alternative? There is no default predictRaw or raw2probability implemented in probabilistic classifier, and it seems that this is the minimum required for GBTClassifier to use ProbabilisticClassifier. Can you please give more information on this point?

@sethah

sethah Jan 6, 2017

Contributor

I can see how my comment was confusing now :) Since GBT only supports two classes right now, we could override methods like probability2prediction, which by default call what is implemented in ProbabilisticClassifier. When thresholds are not defined, it calls probability.argmax, which for two classes we could simplify to

if (probability(1) > probability(0)) 1 else 0

Looking now, logistic regression also had a getThreshold method which allowed it to avoid loops in some cases, but we don't have it here. Let's leave things how they are.

@imatiach-msft

imatiach-msft Jan 6, 2017

Author Contributor

sorry, I'm still a little confused, should I override probability2prediction and simplify, or should I keep the argmax as is? The argmax seems better because it is more general anyway, but please let me know if you would prefer that I make any changes here.

@sethah

sethah Jan 10, 2017

Contributor

Let's not change anything for now, it's fine as it is. Sorry for the confusion.

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala Outdated
val gbtModel = gbt.fit(trainData.toDF(labelCol, featuresCol))
val scoredData = gbtModel.transform(validationData.toDF(labelCol, featuresCol))
scoredData.select(rawPredictionCol, predictionCol).collect()
.foreach(row => {

@sethah

sethah Jan 5, 2017

Contributor

you can use .foreach { case Row(raw: DenseVector, pred: Double, prob: DenseVector) => ... } here.

@imatiach-msft

imatiach-msft Jan 6, 2017

Author Contributor

done

@imatiach-msft

Contributor Author

commented Jan 6, 2017

@sethah @jkbradley thank you for the review - could you please take another look since I've updated the code review based on your comments?

@SparkQA


commented Jan 6, 2017

Test build #70963 has finished for PR 16441 at commit 0c0cb8b.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@imatiach-msft

Contributor Author

commented Jan 6, 2017

It looks like I am failing the binary compatibility tests despite this constructor being private:

class GBTClassificationModel private[ml](
    @Since("1.6.0") override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    private val _treeWeights: Array[Double],
    @Since("1.6.0") override val numFeatures: Int,
    @Since("2.2.0") override val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]

This is the same thing that happened in my original PR, where I had to add the additional this() overload to pass the tests. In the PR comment it was mentioned that I should be able to remove the unused constructor; does this mean that I need to change the binary compatibility test somehow as well? My guess is that the binary compat tests are Java-based rather than Scala-based, in which case private[ml] doesn't matter, so the solution would be to keep the extra constructor I had before, still marked private[ml], just so I can pass the binary compat tests.

@SparkQA


commented Jan 6, 2017

Test build #70982 has finished for PR 16441 at commit 2f06cb5.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@imatiach-msft

Contributor Author

commented Jan 6, 2017

Indeed re-adding the constructor seems to make the binary compatibility tests pass (see spark QA build above). I think in favor of making the binary compat tests pass, we can keep the extra private constructor, even though for most people it won't do anything.

Please let me know if there are any outstanding comments that still need to be addressed. Thank you!

@imatiach-msft imatiach-msft changed the title [SPARK-14975][ML][WIP] Fixed GBTClassifier to predict probability per training instance and fixed interfaces [SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces Jan 6, 2017

@imatiach-msft

Contributor Author

commented Jan 6, 2017

I've removed the WIP from title to reflect the status of the pull request.

@imatiach-msft

Contributor Author

commented Jan 9, 2017

ping @sethah @jkbradley could you please take another look since I've updated the code review based on your comments? Thank you!

@sethah
Contributor

left a comment

Made another pass, thanks for working on this, my apologies for the delayed review.

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala Outdated
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}

@sethah

sethah Jan 10, 2017

Contributor

DenseVector is unused

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

removed

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala Outdated
case Row(raw: Vector, prob: Vector) =>
assert(raw.size === 2)
assert(prob.size === 2)
val prodFromRaw = raw.toDense.values.map(value => 1 / (1 + math.exp(-2 * value)))

@sethah

sethah Jan 10, 2017

Contributor

predFromRaw ?

@sethah

sethah Jan 10, 2017

Contributor

Also, can we leave a comment regarding the fact that we'd want to check other loss types here for classification if they are ever added.

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

done and done

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala Outdated
assert(raw.size === 2)
assert(prob.size === 2)
val prodFromRaw = raw.toDense.values.map(value => 1 / (1 + math.exp(-2 * value)))
assert(prob(0) ~== prodFromRaw(0) relTol eps)

@sethah

sethah Jan 10, 2017

Contributor

check that prob(0) + prob(1) ~== 1.0 absTol 1e-8

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

good idea! done. I added absEps for 1e-8 so that there won't be any magic constants floating around.

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
// and negative result:
// p-(x) = 1 / (1 + e^(2 * F(x)))
case dv: DenseVector =>
var i = 0

@sethah

sethah Jan 10, 2017

Contributor

We can save ourselves some computation here:

case dv: DenseVector =>
  dv.values(0) = computeProb(dv.values(0))
  dv.values(1) = 1.0 - dv.values(0)
  dv

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

done

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -275,6 +321,13 @@ class GBTClassificationModel private[ml](
@Since("2.0.0")
lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)

private def classProbability(loss: String, rawPrediction: Double): Double = {

@sethah

sethah Jan 10, 2017

Contributor

Actually, this would be better served embedded in the loss object. One solution would be to make a few changes to the loss:

trait ClassificationLoss extends Loss {
  private[spark] def computeProbability(prediction: Double): Double
}
object LogLoss extends ClassificationLoss

Then we could add a class member to the model private val oldLoss: ClassificationLoss = getOldLossType, then we can just call oldLoss.computeProbability(pred) inside raw2ProbabilityInPlace. There might be a better solution too, but really I think it should be part of the loss.

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

adding "private val oldLoss: ClassificationLoss = getOldLossType" won't work because getOldLossType returns a Loss and not a LogLoss, which doesn't have computeProbability. However, I did add the ClassificationLoss trait, and in classProbability I just call LogLoss.computeProbability. I'm not sure if it will pass the binary compat checks though, let's see...

@sethah

sethah Jan 10, 2017

Contributor

You can change getOldLossType to return a classification loss, can't you?

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

good point, will update
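The narrowing discussed in this thread might look roughly like the following, using simplified stand-in types rather than the real org.apache.spark.mllib.tree.loss hierarchy (which also carries access modifiers like private[spark] and more members):

```scala
// Simplified stand-in for the MLlib Loss trait.
trait Loss {
  def computeError(prediction: Double, label: Double): Double
}

// Classification losses can additionally map a raw margin to a probability.
trait ClassificationLoss extends Loss {
  def computeProbability(prediction: Double): Double
}

object LogLoss extends ClassificationLoss {
  // 2 * log(1 + exp(-2 * label * prediction)), numerically stable via log1p.
  def computeError(prediction: Double, label: Double): Double =
    2.0 * math.log1p(math.exp(-2.0 * label * prediction))
  // Positive-class probability from the margin.
  def computeProbability(prediction: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * prediction))
}

object LossAccessSketch {
  // Narrowed return type: callers get computeProbability without a cast.
  def getOldLossType: ClassificationLoss = LogLoss

  def main(args: Array[String]): Unit = {
    assert(math.abs(getOldLossType.computeProbability(0.0) - 0.5) < 1e-9)
  }
}
```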

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -215,10 +223,23 @@ class GBTClassificationModel private[ml](
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
* @param numFeatures The number of features.
*/
@Since("1.6.0")

@sethah

sethah Jan 10, 2017

Contributor

Since tag not needed since it's private

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -159,14 +157,21 @@ class GBTClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

val instr = Instrumentation.create(this, oldDataset)
val numClasses: Int = getNumClasses(dataset)

@sethah

sethah Jan 10, 2017

Contributor

We should just use numClasses = 2 for now, since getNumClasses can make an extra pass over the data, and >2 classes are not supported anyway.

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

hmm, the logistic regression gets the number of classes and throws in the binomial case, and getNumClasses should ideally get the number of classes from the metadata which shouldn't make an extra pass (ideally the label column is categorical?), but I think it's ok for now to make it 2 until we make GBT support the multiclass case.

@sethah

sethah Jan 10, 2017

Contributor

If getNumClasses doesn't find metadata, then it will make a pass over the data.

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

right, I removed it for now, but ideally the user would preprocess the data and make the label column categorical. Either they would do that through the string indexer, or, if they know the labels ahead of time, they would just add the metadata themselves (although unfortunately only advanced users can currently do this; there is no transform that allows them to pre-specify the labels when they know ahead of time what the labels are)

@sethah

sethah Jan 10, 2017

Contributor

It's still there...

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

oops, I thought I changed it, sorry

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -159,14 +157,21 @@ class GBTClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

val instr = Instrumentation.create(this, oldDataset)
val numClasses: Int = getNumClasses(dataset)
if (isDefined(thresholds)) {

@sethah

sethah Jan 10, 2017

Contributor

I prefer to leave the handling of thresholds for another JIRA, but technically users will be able to set it. We can either do it here in this PR, or throw an error until we get it implemented in a follow up. Thoughts @jkbradley?

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

it looks like decision tree classifier has the same problem with thresholds

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

actually, it looks like both this classifier and decision tree handle thresholds already in method probability2prediction under ProbabilisticClassifier.scala. Can you give more information on why GBTClassifier is not handling thresholds correctly?

@sethah

sethah Jan 10, 2017

Contributor

There is no setThresholds method, and there are no unit tests off the top of my head.

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

I do see a setThresholds method both on the classifier and the model. It comes from ProbabilisticClassifier:

abstract class ProbabilisticClassifier[
    FeaturesType,
    E <: ProbabilisticClassifier[FeaturesType, E, M],
    M <: ProbabilisticClassificationModel[FeaturesType, M]]
  extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {

  /** @group setParam */
  def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]

  /** @group setParam */
  def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
}

@sethah

sethah Jan 10, 2017

Contributor

ah, ok good catch. We should handle thresholds in this PR then. Can you look at other test suites and add those tests?

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

sure, I've added more tests in the latest commit. I've also fixed an issue where predict was not using thresholds - if they are defined we now use them.
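For reference, the default thresholding inherited from ProbabilisticClassificationModel can be sketched roughly as follows; this is a simplified stand-in for probability2prediction, not the actual Spark implementation:

```scala
object ThresholdSketch {
  // Mirrors the default behavior: with thresholds set, scale each class
  // probability by 1/threshold and take the argmax; otherwise plain argmax.
  def predictWithThresholds(
      prob: Array[Double],
      thresholds: Option[Array[Double]]): Double =
    thresholds match {
      case Some(t) =>
        val scaled = prob.zip(t).map { case (p, th) =>
          // A zero threshold makes that class always win when it has any mass.
          if (th == 0.0) Double.PositiveInfinity else p / th
        }
        scaled.indices.maxBy(scaled(_)).toDouble
      case None => prob.indices.maxBy(prob(_)).toDouble
    }

  def main(args: Array[String]): Unit = {
    val prob = Array(0.4, 0.6)
    assert(predictWithThresholds(prob, None) == 1.0)
    // A high threshold on class 1 can flip the prediction to class 0.
    assert(predictWithThresholds(prob, Some(Array(0.5, 0.99))) == 0.0)
  }
}
```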

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -248,12 +268,38 @@ class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}

override protected def predictRaw(features: Vector): Vector = {
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)

@sethah

sethah Jan 10, 2017

Contributor

we can avoid duplicating this code. Maybe, as in LogisticRegression, we can create a private function called score or margin and then use that in predict and predictRaw

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

good idea, refactored to margin private method
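The refactor described here might look roughly like this, with plain arrays standing in for the trees and their predictImpl calls (a sketch, not the merged code):

```scala
object MarginSketch {
  // Shared helper: the raw margin F(x) is the weighted sum of per-tree
  // predictions (each tree's predictImpl result stands in as a Double).
  def margin(treePredictions: Array[Double], treeWeights: Array[Double]): Double =
    treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum

  // predict thresholds the margin at zero.
  def predict(preds: Array[Double], weights: Array[Double]): Double =
    if (margin(preds, weights) > 0.0) 1.0 else 0.0

  // predictRaw exposes the margin for both classes as (-F(x), F(x)).
  def predictRaw(preds: Array[Double], weights: Array[Double]): Array[Double] = {
    val m = margin(preds, weights)
    Array(-m, m)
  }

  def main(args: Array[String]): Unit = {
    val preds = Array(1.0, -1.0, 1.0)
    val weights = Array(0.5, 0.3, 0.2)
    assert(predict(preds, weights) == 1.0)
    val raw = predictRaw(preds, weights)
    // The two raw components are negatives of each other.
    assert(math.abs(raw(0) + raw(1)) < 1e-12)
  }
}
```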

case (pred1, pred2) => assert(pred1 === pred2)
}
}

@sethah

sethah Jan 10, 2017

Contributor

Shall we add a "default params" test for parity with other suites like LogisticRegression?

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

good idea, added the extra test

@SparkQA


commented Jan 10, 2017

Test build #71142 has finished for PR 16441 at commit 2bd32a0.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • trait ClassificationLoss extends Loss
@SparkQA


commented Jan 10, 2017

Test build #71144 has finished for PR 16441 at commit ffa0fe5.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA


commented Jan 10, 2017

Test build #71145 has finished for PR 16441 at commit 1dde99b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@imatiach-msft

Contributor Author

commented Jan 10, 2017

ping @sethah @jkbradley could you please take another look since I've updated the code review based on your comments? Thank you!

@SparkQA


commented Jan 10, 2017

Test build #71150 has finished for PR 16441 at commit 8cd6c2b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@sethah
Contributor

left a comment

Looking good! Thanks for all the updates

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -215,10 +224,21 @@ class GBTClassificationModel private[ml](
*
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
* @param numFeatures The number of features.
*/
private[ml] def this(uid: String, _trees: Array[DecisionTreeRegressionModel],

@sethah

sethah Jan 10, 2017

Contributor

style: put each arg on one line, using 4 space indentation as is done with the constructor

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

done, thanks, also updated the other constructor (my default intellij settings don't seem to match the suggested ones)

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
super.predict(features)
} else {
val prediction: Double = margin(features)
if (prediction > 0.0) 1.0 else 0.0

@sethah

sethah Jan 10, 2017

Contributor

nit: if (margin(features) > 0.0) 1.0 else 0.0

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

done

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated

override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
// The probability can be calculated for positive result:

@sethah

sethah Jan 10, 2017

Contributor

This comment should be removed since we made this function generic

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

moved comment to LogLoss computeProbability method (kept for positive result only)

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
// and negative result:
// p-(x) = 1 / (1 + e^(2 * F(x)))
case dv: DenseVector =>
dv.values(0) = getOldLossType.computeProbability(dv.values(0))

@sethah

sethah Jan 10, 2017

Contributor

Should we make a private class member private val loss = getOldLossType? Otherwise we call getOldLossType, (which calls getLossType) for every single instance.

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

hmm, this is a tricky point: in the future, if we have more than one loss, then when the user changes it the results should change as well. But since we only have one loss function I guess it is ok... I'll make the update but add a warning comment

@sethah

sethah Jan 11, 2017

Contributor

You mean that if someone takes a model and changes the loss type via set(lossType, "other") that the probability function should change? I don't think it makes sense to change the probability function for a model, since the probability is chosen to be optimal for a specific loss, but it's a good point. What do you think?

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala Outdated
@@ -159,14 +157,21 @@ class GBTClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

val instr = Instrumentation.create(this, oldDataset)
val numClasses: Int = getNumClasses(dataset)

@sethah

sethah Jan 10, 2017

Contributor

It's still there...

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala Outdated
*/
@Since("1.2.0")
@DeveloperApi
trait ClassificationLoss extends Loss {

@sethah

sethah Jan 10, 2017

Contributor

this can be private[spark] I think

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

done

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala Outdated
@@ -52,4 +61,8 @@ object LogLoss extends Loss {
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}

override private[spark] def computeProbability(prediction: Double): Double = {
1 / (1 + math.exp(-2 * prediction))

@sethah

sethah Jan 10, 2017

Contributor

nit: prefer explicit doubles like 1.0 instead of 1

@imatiach-msft

imatiach-msft Jan 10, 2017

Author Contributor

done

@imatiach-msft

Contributor Author

commented Jan 10, 2017

ping @sethah @jkbradley could you please take another look since I've updated the code review based on your comments? Thank you!

@SparkQA


commented Jan 11, 2017

Test build #71169 has finished for PR 16441 at commit 0b96223.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA


commented Jan 11, 2017

Test build #71170 has finished for PR 16441 at commit b4f9b34.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA


commented Jan 11, 2017

Test build #71171 has finished for PR 16441 at commit 92d1348.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@imatiach-msft imatiach-msft force-pushed the imatiach-msft:ilmat/fix-GBT branch to 1abfee0 Jan 18, 2017

@SparkQA


commented Jan 18, 2017

Test build #71616 has finished for PR 16441 at commit 1abfee0.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA


commented Jan 18, 2017

Test build #71617 has finished for PR 16441 at commit 818de81.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@jkbradley

Member

commented Jan 18, 2017

LGTM
Merging with master
Thanks @imatiach-msft and @sethah for reviewing!

@asfgit asfgit closed this in fe409f3 Jan 18, 2017

@MLnick

Contributor

commented Jan 19, 2017

@imatiach-msft thanks for this, really great to have GBT in the classification trait hierarchy, and now usable with binary evaluator metrics!

uzadude added a commit to uzadude/spark that referenced this pull request Jan 27, 2017

[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces

## What changes were proposed in this pull request?

For all of the classifiers in MLLib we can predict probabilities except for GBTClassifier.
Also, all classifiers inherit from ProbabilisticClassifier but GBTClassifier strangely inherits from Predictor, which is a bug.
This change corrects the interface and adds the ability for the classifier to give a probabilities vector.

## How was this patch tested?

The basic ML tests were run after making the changes.  I've marked this as WIP as I need to add more tests.

Author: Ilya Matiach <ilmat@microsoft.com>

Closes apache#16441 from imatiach-msft/ilmat/fix-GBT.

cmonkey added a commit to cmonkey/spark that referenced this pull request Feb 15, 2017

[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces


Author: Ilya Matiach <ilmat@microsoft.com>

Closes apache#16441 from imatiach-msft/ilmat/fix-GBT.
@yonglyhoo


commented Jul 15, 2017

In which release is this fix going to be available? Thanks!

@MLnick

Contributor

commented Jul 15, 2017

@yonglyhoo


commented Jul 15, 2017
