Skip to content

Commit

Permalink
separate calculation of predict of node from calculation of info gain
Browse files Browse the repository at this point in the history
  • Loading branch information
qiping.lqp committed Sep 9, 2014
1 parent ac42378 commit ff34845
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 29 deletions.
47 changes: 33 additions & 14 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

// Find best split for all nodes at a level.
timer.start("findBestSplits")
val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
DecisionTree.findBestSplits(treeInput, parentImpurities,
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
Expand All @@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
timer.start("extractNodeInfo")
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val predict = nodeSplitStats._3
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
timer.stop("extractNodeInfo")
Expand Down Expand Up @@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
Expand All @@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
Expand Down Expand Up @@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {

/*
* The high-level descriptions of the best split optimizations are noted here.
Expand Down Expand Up @@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {

// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
// Iterating over all nodes at this level
var nodeIndex = 0
while (nodeIndex < numNodes) {
Expand Down Expand Up @@ -747,18 +748,16 @@ object DecisionTree extends Serializable with Logging {

val totalCount = leftCount + rightCount

val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)

// impurity of parent node
val impurity = if (level > 0) {
topImpurity
} else {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
parentNodeAgg.calculate()
}

val predict = parentNodeAgg.predict
val prob = parentNodeAgg.prob(predict)

val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()

Expand All @@ -770,7 +769,18 @@ object DecisionTree extends Serializable with Logging {
return InformationGainStats.invalidInformationGainStats
}

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
}

private def calculatePredict(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator): Predict = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val predict = parentNodeAgg.predict
val prob = parentNodeAgg.prob(predict)

new Predict(predict, prob)
}

/**
Expand All @@ -786,12 +796,14 @@ object DecisionTree extends Serializable with Logging {
nodeImpurity: Double,
level: Int,
metadata: DecisionTreeMetadata,
splits: Array[Array[Split]]): (Split, InformationGainStats) = {
splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {

logDebug("node impurity = " + nodeImpurity)

var predict: Option[Predict] = None

// For each (feature, split), calculate the gain, and select the best (feature, split).
Range(0, metadata.numFeatures).map { featureIndex =>
val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
val numSplits = metadata.numSplits(featureIndex)
if (metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
Expand All @@ -809,6 +821,7 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
Expand All @@ -825,6 +838,7 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
Expand Down Expand Up @@ -899,6 +913,7 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
val gainStats =
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIndex, gainStats)
Expand All @@ -913,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

require(predict.isDefined, "must calculate predict for each node")

(bestSplit, bestSplitStats, predict.get)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,21 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double,
val predict: Double,
val prob: Double = 0.0) extends Serializable {
val rightImpurity: Double) extends Serializable {

override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
}


private[tree] object InformationGainStats {
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0)
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi

/**
* :: DeveloperApi ::
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
@DeveloperApi
class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable{

override def toString = {
"predict = %f, prob = %f".format(predict, prob)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)

val stats = bestSplits(0)._2
val predict = bestSplits(0)._3
assert(stats.gain > 0)
assert(stats.predict === 1)
assert(stats.prob === 0.6)
assert(predict.predict === 1)
assert(predict.prob === 0.6)
assert(stats.impurity > 0.2)
}

Expand Down Expand Up @@ -313,8 +314,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(split.threshold === Double.MinValue)

val stats = bestSplits(0)._2
val predict = bestSplits(0)._3.predict
assert(stats.gain > 0)
assert(stats.predict === 0.6)
assert(predict === 0.6)
assert(stats.impurity > 0.2)
}

Expand Down Expand Up @@ -392,7 +394,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
assert(bestSplits(0)._3.predict === 1)
}

test("Binary classification stump with fixed label 0 for Entropy") {
Expand Down Expand Up @@ -421,7 +423,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 0)
assert(bestSplits(0)._3.predict === 0)
}

test("Binary classification stump with fixed label 1 for Entropy") {
Expand Down Expand Up @@ -450,7 +452,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
assert(bestSplits(0)._3.predict === 1)
}

test("Second level node building with vs. without groups") {
Expand Down Expand Up @@ -501,7 +503,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
}
}

Expand Down

0 comments on commit ff34845

Please sign in to comment.