
[SPARK-2207][SPARK-3272][MLLib] Add minimum information gain and minimum instances per node as training parameters for decision tree. #2332

Closed
wants to merge 17 commits

Conversation

chouqin (Contributor) commented Sep 9, 2014

These two parameters act as early stopping rules for pre-pruning. When a split would cause the left or right child to have fewer than minInstancesPerNode instances, or would have less information gain than minInfoGain, the current node will not be split with that split.

When no possible split satisfies these requirements, there are no useful information gain stats, but we still need to calculate the predict value for the current node. So I separated the calculation of the predict value from the calculation of information gain, which can also save computation when the number of possible splits is large. Please see SPARK-3272 for more details.
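For context, a rough sketch of how these two parameters could be set when training a tree after this patch (the Strategy constructor arguments and parameter names below are assumptions based on the PR description, not quoted from the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    object MinGainExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("MinGainExample").setMaster("local[2]"))

        // Toy binary-classification data.
        val data = sc.parallelize(Seq(
          LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
          LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
          LabeledPoint(1.0, Vectors.dense(2.0, 1.0)),
          LabeledPoint(1.0, Vectors.dense(3.0, 0.0))))

        // Pre-pruning: a node is not split if either child would receive fewer
        // than minInstancesPerNode instances, or if the best split's gain is
        // below minInfoGain. Parameter names are assumed, not confirmed.
        val strategy = new Strategy(
          algo = Algo.Classification,
          impurity = Gini,
          maxDepth = 5,
          numClassesForClassification = 2,
          minInstancesPerNode = 2,
          minInfoGain = 0.01)

        val model = DecisionTree.train(data, strategy)
        println(model)
        sc.stop()
      }
    }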

CC: @mengxr @manishamde @jkbradley, please help me review this, thanks.

SparkQA commented Sep 9, 2014

Can one of the admins verify this patch?

@manishamde (Contributor)

@chouqin Thanks for the PR. I won't be able to comment since I am on a break now. Reviews from @jkbradley and @mengxr should be sufficient. :-)

mengxr (Contributor) commented Sep 9, 2014

Jenkins, add to whitelist.

mengxr (Contributor) commented Sep 9, 2014

this is ok to test

val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
Review comment (Member):

I think you could avoid explicitly checking for invalidInformationGainStats since the gain is Double.MinValue. At the very end of the maxBy calls, you could then check whether the information gain is Double.MinValue, in which case we know that no split is worth doing. That should simplify the code here and in the other maxBy calls below.
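As a self-contained illustration of that suggestion (a toy sketch, not the patch itself): give rejected candidate splits a gain of Double.MinValue, take maxBy once, and check the winner's gain instead of comparing against a sentinel stats object.

    // All three hypothetical candidates violate minInstancesPerNode, so they
    // carry gain = Double.MinValue and the single check after maxBy detects
    // that no split is worth doing.
    case class GainStats(gain: Double)

    object BestSplitSketch {
      def main(args: Array[String]): Unit = {
        val candidates = Seq(
          0 -> GainStats(Double.MinValue),
          1 -> GainStats(Double.MinValue),
          2 -> GainStats(Double.MinValue))

        val (bestIdx, bestStats) = candidates.maxBy(_._2.gain)
        if (bestStats.gain == Double.MinValue) {
          println("no split is worth doing; keep the node as a leaf")
        } else {
          println(s"split on candidate $bestIdx with gain ${bestStats.gain}")
        }
      }
    }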

@jkbradley (Member)

@chouqin Thanks for this update---it will be great to have these 2 options supported. My comments are mostly about simplifying the code: removing Predict and Split.noSplit, and the related simplifications. We will not save much computation by avoiding calculating predictions. However, we will definitely save a lot of computation by supporting these 2 options to allow early stopping on some datasets!

SparkQA commented Sep 9, 2014

QA tests have started for PR 2332 at commit efcc736.

  • This patch merges cleanly.

SparkQA commented Sep 9, 2014

QA tests have finished for PR 2332 at commit efcc736.

  • This patch passes unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class Predict(

chouqin (Contributor, Author) commented Sep 9, 2014

@jkbradley Thanks for your comments; I will change my code accordingly. As for the Predict class, I still think it is needed, for the following reasons:

  1. Saving of computation: for each split, it would traverse an array of bins twice (once to aggregate, once to find the index with the maximum value). I don't think this saving is trivial.
  2. Code simplicity: the predict value for a node should not be tied to the information gain of a split. I have read Weka's and scikit-learn's decision tree code, and they don't store a predict value along with a split's information gain stats. I think the changed code may be easier to understand.
  3. Early return in calculateGainForSplit: when the left or right child has fewer than minInstancesPerNode instances, we can just return invalid information gain stats without calculating the predict value. If all splits are rejected this way, we still need a way to calculate the predict value for the current node.

If you don't like creating a new Predict class, I can use a tuple instead, but that seems harder to understand.
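For reference, a minimal sketch of the kind of separation being discussed, with the prediction kept apart from the split's gain statistics (the field names, package, and visibility shown here are illustrative assumptions, not the exact code in this patch):

    package sketch.tree.model

    // A node's prediction, computed once per node even when no valid split exists.
    private[tree] class Predict(
        val predict: Double,     // predicted label (classification) or value (regression)
        val prob: Double = 0.0)  // probability of the predicted label, when meaningful
      extends Serializable {
      override def toString: String = s"$predict (prob = $prob)"
    }

    // Gain statistics for a candidate split; the prediction no longer lives here.
    private[tree] class InformationGainStats(
        val gain: Double,
        val impurity: Double,
        val leftImpurity: Double,
        val rightImpurity: Double) extends Serializable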

@jkbradley (Member)

@chouqin Thanks for your responses. I think you've convinced me that Predict is reasonable, since it is a different concept from info gain. Could you please make it private[tree] though?

Clarification for 1.: By "array of bins," do you mean the array of classes to calculate the prediction (for classification)? Unless there are a very large number of classes, I do not think the savings will be that much.

Thanks!

@jkbradley (Member)

@chouqin Could you also please add the [mllib] tag to the PR title?

SparkQA commented Sep 10, 2014

QA tests have started for PR 2332 at commit d593ec7.

  • This patch merges cleanly.

SparkQA commented Sep 10, 2014

QA tests have finished for PR 2332 at commit d593ec7.

  • This patch passes unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class Predict(

chouqin (Contributor, Author) commented Sep 10, 2014

@jkbradley I have removed the noSplit object and added private[tree] to Predict.

SparkQA commented Sep 10, 2014

QA tests have started for PR 2332 at commit 0278a11.

  • This patch merges cleanly.

SparkQA commented Sep 10, 2014

QA tests have finished for PR 2332 at commit 0278a11.

  • This patch passes unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression]
    • case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction
    • case class Abs(child: Expression) extends UnaryExpression

@jkbradley (Member)

@chouqin Thanks for the updates! This looks basically ready, except for the edge cases in the test suite. I tested it and it ran fine. I think those complaints about public classes are unrelated. Once the test suite is updated, I'd say it is ready.

chouqin (Contributor, Author) commented Sep 10, 2014

@jkbradley Thanks for your replies. As I replied in your comments, I have changed minInstancesPerNode to 2 in the test cases and added one more test case, which checks that when a split doesn't satisfy the min-instances-per-node requirement it will not be chosen, even though its info gain is large (in this test case, the total number of instances is 4, and we can find a split that leaves both the left and right child with 2 instances). Do you think this is OK?
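A rough sketch of the kind of test being described (the data, Strategy parameter names, and assertions here are illustrative assumptions, not the actual suite):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    object MinInstancesPerNodeSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("MinInstancesPerNodeSketch").setMaster("local[2]"))

        // Four instances: with minInstancesPerNode = 2, any split that leaves a
        // child with a single instance must be rejected; the 2-vs-2 split that
        // separates the labels perfectly remains admissible.
        val data = sc.parallelize(Seq(
          LabeledPoint(0.0, Vectors.dense(0.0)),
          LabeledPoint(0.0, Vectors.dense(1.0)),
          LabeledPoint(1.0, Vectors.dense(2.0)),
          LabeledPoint(1.0, Vectors.dense(3.0))))

        val strategy = new Strategy(
          algo = Algo.Classification,
          impurity = Gini,
          maxDepth = 2,
          numClassesForClassification = 2,
          minInstancesPerNode = 2)

        val model = DecisionTree.train(data, strategy)
        // Expect a root split with two instances on each side, i.e. both children exist.
        assert(model.topNode.leftNode.isDefined && model.topNode.rightNode.isDefined)
        sc.stop()
      }
    }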

SparkQA commented Sep 10, 2014

QA tests have started for PR 2332 at commit f1d11d1.

  • This patch merges cleanly.

SparkQA commented Sep 10, 2014

QA tests have finished for PR 2332 at commit f1d11d1.

  • This patch fails unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

mengxr (Contributor) commented Sep 10, 2014

test this please

@jkbradley (Member)

@chouqin Thanks for the updates! LGTM

@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

require(predict.isDefined, "must calculate predict for each node")
Review comment (Contributor):

nit: Use assert instead of require. The latter throws IllegalArgumentException, which doesn't apply here. (not necessary to update)
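A tiny illustration of the distinction (not part of the patch): require signals a bad argument from the caller and throws IllegalArgumentException, while assert signals a broken internal invariant and throws AssertionError.

    object RequireVsAssert {
      def main(args: Array[String]): Unit = {
        try {
          require(args.nonEmpty, "caller must pass at least one argument")
        } catch {
          case e: IllegalArgumentException => println(s"require failed: ${e.getMessage}")
        }
        try {
          assert(1 + 1 == 3, "internal invariant violated")
        } catch {
          case e: AssertionError => println(s"assert failed: ${e.getMessage}")
        }
      }
    }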

mengxr (Contributor) commented Sep 10, 2014

@chouqin I made very minor inline comments. It is not necessary to update the PR. I'm going to merge this if Jenkins is happy, and @jkbradley will make those changes in his follow-up PR.

SparkQA commented Sep 10, 2014

QA tests have started for PR 2332 at commit f1d11d1.

  • This patch merges cleanly.

SparkQA commented Sep 10, 2014

QA tests have finished for PR 2332 at commit f1d11d1.

  • This patch passes unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

asfgit closed this in 79cdb9b on Sep 10, 2014
mengxr (Contributor) commented Sep 10, 2014

Merged into master. Thanks!
