Skip to content

Commit

Permalink
added documentation, fixed off by 1 error in max level calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 6, 2014
1 parent cbd9f14 commit 5e82202
Showing 1 changed file with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,20 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage)
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
val numElementsPerNode = {
strategy.algo match {
case Classification => 2 * numBins * numFeatures
case Classification => 2 * numBins * numFeatures
case Regression => 3 * numBins * numFeatures
}
}
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)
logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
// nodes at a level is 2^(level-1). level is zero indexed.
// nodes at a level is 2^level. level is zero indexed.
val maxLevelForSingleGroup = math.max(
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0)
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
logDebug("max level for single group = " + maxLevelForSingleGroup)

/*
Expand Down Expand Up @@ -299,11 +299,16 @@ object DecisionTree extends Serializable with Logging {
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
// the nodes are divided into multiple groups at each level with the number of groups
// increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt
logDebug("numGroups = " + numGroups)
var groupIndex = 0
var bestSplits = new Array[(Split, InformationGainStats)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
filters, splits, bins, numGroups, groupIndex)
Expand Down

0 comments on commit 5e82202

Please sign in to comment.