From a08adbd7f46d9fb0831dd8f302869805df2db4a8 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 2 Oct 2015 08:53:56 -0700 Subject: [PATCH 01/12] Removing superfluous bins in decision tree training --- .../spark/ml/tree/impl/RandomForest.scala | 56 +++++++++--- .../spark/mllib/tree/DecisionTree.scala | 89 ++++++++++++++++++- .../mllib/tree/impl/DTStatsAggregator.scala | 21 +++++ .../tree/impl/DecisionTreeMetadata.scala | 7 +- 4 files changed, 157 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index f994c258b2cad..aa3ece58f6edd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -81,6 +81,10 @@ private[ml] object RandomForest extends Logging { s"\t$featureIndex\t${metadata.numBins(featureIndex)}" }.mkString("\n")) + println("*****************") + metadata.numBins.foreach(x => printf(x.toString + "/")) + println("*****************") + // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) @@ -250,14 +254,22 @@ private[ml] object RandomForest extends Logging { val numSplits = agg.metadata.numSplits(featureIndex) val featureSplits = splits(featureIndex) var splitIndex = 0 +// while (splitIndex < numSplits) { +// if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { +// agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) +// } else { +// agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) +// } +// splitIndex += 1 +// } + while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } else { - agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } splitIndex += 1 } + } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) @@ -697,17 +709,39 @@ private[ml] object RandomForest extends Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = - binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) + + // SETH + val leftSplits = Range(0, numSplits).map { splitIndex => + binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + } + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits - 1) { + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + + val (bestFeatureSplitIndex, bestFeatureGainStats) = leftSplits.map { leftChildStats => + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, 
numSplits - 1) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + // SETH + +// val (bestFeatureSplitIndex, bestFeatureGainStats) = +// Range(0, numSplits).map { splitIndex => +// val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) +// val rightChildStats = +// binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) +// gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, +// leftChildStats, rightChildStats, binAggregates.metadata) +// (splitIndex, gainAndImpurityStats) +// }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 76c32208ea098..f11bd4fbd46a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy) /** * Method to train a decision tree model over an RDD + * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return DecisionTreeModel that can be used for prediction. */ @@ -377,12 +378,19 @@ object DecisionTree extends Serializable with Logging { if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } else { - agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, - instanceWeight) } splitIndex += 1 } +// while (splitIndex < numSplits) { +// if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { +// agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, +// instanceWeight) +// } else { +// agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, +// instanceWeight) +// } +// splitIndex += 1 +// } } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) @@ -831,6 +839,81 @@ object DecisionTree extends Serializable with Logging { binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) splitIndex += 1 } + // Find best split. 
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { case splitIdx =>
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+ (splitIdx, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val (leftChildOffset, rightChildOffset) =
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+
+ // SETH
+ val leftSplits = Range(0, numSplits).map { splitIndex =>
+ binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ }
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits - 1) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
+ }
+
+ val (bestFeatureSplitIndex, bestFeatureGainStats) = leftSplits.map { leftChildStats =>
+ val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits - 1)
+ rightChildStats.subtract(leftChildStats)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ // SETH
+
+// val (bestFeatureSplitIndex, bestFeatureGainStats) =
+// Range(0, numSplits).map { splitIndex =>
+// val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+// val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+// predictWithImpurity = Some(predictWithImpurity.getOrElse(
+// calculatePredictImpurity(leftChildStats, rightChildStats)))
+// val gainStats = calculateGainForSplit(leftChildStats,
+// rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+// (splitIndex, gainStats)
+// }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numBins = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ Range(0, numBins).map { case featureValue =>
+ val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.calculate()
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
+ }
 // Find best split. 
val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { case splitIdx => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 7985ed4b4c0fa..3fa0011c82d7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -164,4 +164,25 @@ private[spark] class DTStatsAggregator( } this } + + def copy: DTStatsAggregator = { + val copyAggregator = new DTStatsAggregator(this.metadata, featureSubset) + copyAggregator.merge(this) + } + + def totalsForFeature(featureIndex: Int): Array[Double] = { + val numBins = metadata.numSplits(featureIndex) + val featureOffset = featureOffsets(featureIndex) + var i = 0 + val totals = Array.fill[Double](statsSize)(0.0) + while (i < numBins) { + var j = 0 + while (j < statsSize) { + totals(j) += allStats(featureOffset + i*statsSize + j) + j +=1 + } + i += 1 + } + totals + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 21ee49c45788c..f241019bd6cb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -71,7 +71,7 @@ private[spark] class DecisionTreeMetadata( * For ordered features, there is 1 more bin than split. */ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) >> 1 + numBins(featureIndex) //>> 1 } else { numBins(featureIndex) - 1 } @@ -140,6 +140,8 @@ private[spark] object DecisionTreeMetadata extends Logging { val unorderedFeatures = new mutable.HashSet[Int]() val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) if (numClasses > 2) { + println("Multiclass") + println(strategy.categoricalFeaturesInfo) // Multiclass classification val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt @@ -151,6 +153,7 @@ private[spark] object DecisionTreeMetadata extends Logging { // which require 2 * ((1 << numCategories - 1) - 1) bins. // We do this check with log values to prevent overflows in case numCategories is large. // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + println(maxCategoriesForUnorderedFeature.toString+numCategories.toString+"[") if (numCategories <= maxCategoriesForUnorderedFeature) { unorderedFeatures.add(featureIndex) numBins(featureIndex) = numUnorderedBins(numCategories) @@ -212,6 +215,6 @@ private[spark] object DecisionTreeMetadata extends Logging { * there are math.pow(2, arity - 1) - 1 such splits. * Each split has 2 corresponding bins. 
*/ - def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1) + def numUnorderedBins(arity: Int): Int = ((1 << arity - 1) - 1) } From 59a45187e6b51ef66e3269b99131e2f8e96db295 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 2 Nov 2015 10:15:41 -0800 Subject: [PATCH 02/12] adding parent stats to aggregator --- .../ml/tree/impl/DTStatsAggregator.scala | 200 ++++++++++++++++ .../ml/tree/impl/DecisionTreeMetadata.scala | 221 ++++++++++++++++++ .../spark/ml/tree/impl/RandomForest.scala | 47 +--- .../apache/spark/ml/tree/impl/TreePoint.scala | 2 +- .../spark/mllib/tree/DecisionTree.scala | 60 ++--- .../spark/mllib/tree/RandomForest.scala | 6 +- .../mllib/tree/impl/DTStatsAggregator.scala | 29 ++- .../spark/mllib/tree/impurity/Entropy.scala | 6 +- .../spark/mllib/tree/impurity/Gini.scala | 6 +- .../spark/mllib/tree/impurity/Impurity.scala | 6 + .../spark/mllib/tree/impurity/Variance.scala | 6 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 48 ++-- .../spark/mllib/tree/RandomForestSuite.scala | 4 +- 13 files changed, 520 insertions(+), 121 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala new file mode 100644 index 0000000000000..f41ebb4b71707 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ + + + +/** + * DecisionTree statistics aggregator for a node. + * This holds a flat array of statistics for a set of (features, bins) + * and helps with indexing. + * This class is abstract to support learning with and without feature subsampling. + */ +private[spark] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]) extends Serializable { + + /** + * [[ImpurityAggregator]] instance specifying the impurity type. + */ + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + /** + * Number of elements (Double values) used for the sufficient statistics of each bin. + */ + private val statsSize: Int = impurityAggregator.statsSize + + /** + * Number of bins for each feature. This is indexed by the feature index. 
+ */ + private val numBins: Array[Int] = { + if (featureSubset.isDefined) { + featureSubset.get.map(metadata.numBins(_)) + } else { + metadata.numBins + } + } + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + + /** + * Total number of elements stored in this aggregator + */ + private val allStatsSize: Int = featureOffsets.last + + /** + * Flat array of elements. + * Index for start of stats for a (feature, bin) is: + * index = featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, + * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) + * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) + */ + private val allStats: Array[Double] = new Array[Double](allStatsSize) + + private val parentStats: Array[Double] = new Array[Double](statsSize) + + + /** + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. + */ + def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) + } + + def getParentImpurityCalculator(): ImpurityCalculator = { + impurityAggregator.getCalculator(parentStats) + } + + /** + * Update the stats for a given (feature, bin) for ordered features, using the given label. + */ + def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + val i = featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + def updateParent(label: Double, instanceWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, instanceWeight) + } + + /** + * Faster version of [[update]]. + * Update the stats for a given (feature, bin), using the given label. + * @param featureOffset For ordered features, this is a pre-computed feature offset + * from [[getFeatureOffset]]. + * For unordered features, this is a pre-computed + * (feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. + */ + def featureUpdate( + featureOffset: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, + label, instanceWeight) + } + + /** + * Pre-compute feature offset for use with [[featureUpdate]]. + * For ordered features only. + */ + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) + + /** + * Pre-compute feature offset for use with [[featureUpdate]]. + * For unordered features only. + */ + def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { + val baseOffset = featureOffsets(featureIndex) + (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) + } + + /** + * For a given feature, merge the stats for two bins. + * @param featureOffset For ordered features, this is a pre-computed feature offset + * from [[getFeatureOffset]]. + * For unordered features, this is a pre-computed + * (feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. + * @param binIndex The other bin is merged into this bin. 
+ * @param otherBinIndex This bin is not modified. + */ + def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, + featureOffset + otherBinIndex * statsSize) + } + + /** + * Merge this aggregator with another, and returns this aggregator. + * This method modifies this aggregator in-place. + */ + def merge(other: DTStatsAggregator): DTStatsAggregator = { + require(allStatsSize == other.allStatsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + var i = 0 + // TODO: Test BLAS.axpy + while (i < allStatsSize) { + allStats(i) += other.allStats(i) + i += 1 + } + + var j = 0 + // TODO: Test BLAS.axpy + while (j < statsSize) { + parentStats(j) += other.parentStats(j) + j += 1 + } + + + this + } + + def totalsForFeature(featureOffset: Int): Array[Double] = { + val numBins = metadata.numSplits(featureOffset) + var i = 0 + val totals = Array.fill[Double](statsSize)(0.0) + while (i < numBins) { + var j = 0 + while (j < statsSize) { + totals(j) += allStats(featureOffset + i*statsSize + j) + j +=1 + } + i += 1 + } + totals + } +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000000..644c88f34e92a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + +import scala.collection.mutable + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + * @param numBins Number of bins for each feature. 
+ */
+private[spark] class DecisionTreeMetadata(
+ val numFeatures: Int,
+ val numExamples: Long,
+ val numClasses: Int,
+ val maxBins: Int,
+ val featureArity: Map[Int, Int],
+ val unorderedFeatures: Set[Int],
+ val numBins: Array[Int],
+ val impurity: Impurity,
+ val quantileStrategy: QuantileStrategy,
+ val maxDepth: Int,
+ val minInstancesPerNode: Int,
+ val minInfoGain: Double,
+ val numTrees: Int,
+ val numFeaturesPerNode: Int) extends Serializable {
+
+ def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+ def isClassification: Boolean = numClasses >= 2
+
+ def isMulticlass: Boolean = numClasses > 2
+
+ def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+ def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+ def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+ /**
+ * Number of splits for the given feature.
+ * For unordered features, there is 1 bin per split.
+ * For ordered features, there is 1 more bin than split.
+ */
+ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+ numBins(featureIndex)
+ } else {
+ numBins(featureIndex) - 1
+ }
+
+
+ /**
+ * Set number of splits for a continuous feature.
+ * For a continuous feature, number of bins is number of splits plus 1.
+ */
+ def setNumSplits(featureIndex: Int, numSplits: Int) {
+ require(isContinuous(featureIndex),
+ s"Only the number of bins for a continuous feature can be set.")
+ numBins(featureIndex) = numSplits + 1
+ }
+
+ /**
+ * Indicates if feature subsampling is being used.
+ */
+ def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
+
+}
+
+private[spark] object DecisionTreeMetadata extends Logging {
+
+ /**
+ * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+ * This computes which categorical features will be ordered vs. unordered,
+ * as well as the number of splits and bins for each feature.
+ */
+ def buildMetadata(
+ input: RDD[LabeledPoint],
+ strategy: Strategy,
+ numTrees: Int,
+ featureSubsetStrategy: String): DecisionTreeMetadata = {
+
+ val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
+ throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
+ s"but was given an empty one.")
+ }
+ val numExamples = input.count()
+ val numClasses = strategy.algo match {
+ case Classification => strategy.numClasses
+ case Regression => 0
+ }
+
+ val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+ if (maxPossibleBins < strategy.maxBins) {
+ logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+ s" (= number of training instances)")
+ }
+
+ // We check the number of bins here against maxPossibleBins.
+ // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+ // based on the number of training examples.
+ if (strategy.categoricalFeaturesInfo.nonEmpty) {
+ val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ val maxCategory =
+ strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
+ require(maxCategoriesPerFeature <= maxPossibleBins,
+ s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
+ s"number of values in each categorical feature, but categorical feature $maxCategory " +
+ s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " +
+ "features with a large number of values, or add more training examples.")
+ }
+
+ val unorderedFeatures = new mutable.HashSet[Int]()
+ val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
+ if (numClasses > 2) {
+ println("Multiclass")
+ println(strategy.categoricalFeaturesInfo)
+ // Multiclass classification
+ val maxCategoriesForUnorderedFeature =
+ ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // Hack: If a categorical feature has only 1 category, we treat it as continuous.
+ // TODO(SPARK-9957): Handle this properly by filtering out those features.
+ // TODO: update this check or not? Change wording of comments
+ if (numCategories > 1) {
+ // Decide if some categorical features should be treated as unordered features,
+ // which require 2 * ((1 << numCategories - 1) - 1) bins.
+ // We do this check with log values to prevent overflows in case numCategories is large.
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+ println(maxCategoriesForUnorderedFeature.toString+numCategories.toString+"[")
+ if (numCategories <= maxCategoriesForUnorderedFeature) {
+ unorderedFeatures.add(featureIndex)
+ numBins(featureIndex) = numUnorderedBins(numCategories)
+ } else {
+ numBins(featureIndex) = numCategories
+ }
+ }
+ }
+ } else {
+ // Binary classification or regression
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
+ if (numCategories > 1) {
+ numBins(featureIndex) = numCategories
+ }
+ }
+ }
+
+ // Set number of features to use per node (for random forests).
+ val _featureSubsetStrategy = featureSubsetStrategy match {
+ case "auto" =>
+ if (numTrees == 1) {
+ "all"
+ } else {
+ if (strategy.algo == Classification) {
+ "sqrt"
+ } else {
+ "onethird"
+ }
+ }
+ case _ => featureSubsetStrategy
+ }
+ val numFeaturesPerNode: Int = _featureSubsetStrategy match {
+ case "all" => numFeatures
+ case "sqrt" => math.sqrt(numFeatures).ceil.toInt
+ case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ case "onethird" => (numFeatures / 3.0).ceil.toInt
+ }
+
+ new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+ strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
+ strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+ }
+
+ /**
+ * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
+ */
+ def buildMetadata(
+ input: RDD[LabeledPoint],
+ strategy: Strategy): DecisionTreeMetadata = {
+ buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
+ }
+
+ /**
+ * Given the arity of a categorical feature (arity = number of categories),
+ * return the number of bins for the feature if it is to be treated as an unordered feature.
+ * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+ * there are math.pow(2, arity - 1) - 1 such splits.
+ * Each split has 1 corresponding bin. 
+ */ + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index aa3ece58f6edd..ed89208a84762 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -254,22 +254,12 @@ private[ml] object RandomForest extends Logging { val numSplits = agg.metadata.numSplits(featureIndex) val featureSplits = splits(featureIndex) var splitIndex = 0 -// while (splitIndex < numSplits) { -// if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { -// agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) -// } else { -// agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) -// } -// splitIndex += 1 -// } - while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } splitIndex += 1 } - } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) @@ -406,6 +396,7 @@ private[ml] object RandomForest extends Logging { mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, metadata.unorderedFeatures, instanceWeight, featuresForNode) } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) } } @@ -670,7 +661,7 @@ private[ml] object RandomForest extends Logging { // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level ==0) { + var gainAndImpurityStats: ImpurityStats = if (level == 0) { null } else { node.stats @@ -714,34 +705,16 @@ private[ml] object RandomForest extends Logging { binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) // SETH - val leftSplits = Range(0, numSplits).map { splitIndex => - binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - } - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits - 1) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - - val (bestFeatureSplitIndex, bestFeatureGainStats) = leftSplits.map { leftChildStats => - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits - 1) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator().subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) // SETH -// val (bestFeatureSplitIndex, bestFeatureGainStats) = -// Range(0, numSplits).map { splitIndex => -// val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) -// val rightChildStats = -// binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) -// 
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, -// leftChildStats, rightChildStats, binAggregates.metadata) -// (splitIndex, gainAndImpurityStats) -// }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 9fa27e5e1f721..9b971a0b85439 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata +//import org.apache.spark.ml.tree.impl.DecisionTreeMetadata import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f11bd4fbd46a6..6559bece3bed0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -381,16 +381,6 @@ object DecisionTree extends Serializable with Logging { } splitIndex += 1 } -// while (splitIndex < numSplits) { -// if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { -// agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, -// instanceWeight) -// } else { -// agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, -// instanceWeight) -// } -// splitIndex += 1 -// } } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) @@ -460,7 +450,7 @@ object DecisionTree extends Serializable with Logging { */ private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], - metadata: DecisionTreeMetadata, + metadata: impl.DecisionTreeMetadata, topNodes: Array[Node], nodesForGroup: Map[Int, Array[Node]], treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], @@ -529,6 +519,7 @@ object DecisionTree extends Serializable with Logging { mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, metadata.unorderedFeatures, instanceWeight, featuresForNode) } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) } } @@ -739,7 +730,7 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, + metadata: impl.DecisionTreeMetadata, impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -857,38 +848,17 @@ object DecisionTree extends Serializable with Logging { val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - // SETH - val leftSplits = Range(0, numSplits).map { splitIndex => - binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - } - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits - 1) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - - val (bestFeatureSplitIndex, bestFeatureGainStats) = leftSplits.map { leftChildStats => - val rightChildStats = 
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits - 1) - rightChildStats.subtract(leftChildStats) - predictWithImpurity = Some(predictWithImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) - (splitIndex, gainStats) - }.maxBy(_._2.gain) - // SETH - -// val (bestFeatureSplitIndex, bestFeatureGainStats) = -// Range(0, numSplits).map { splitIndex => -// val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator().subtract(leftChildStats) // val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) -// predictWithImpurity = Some(predictWithImpurity.getOrElse( -// calculatePredictImpurity(leftChildStats, rightChildStats))) -// val gainStats = calculateGainForSplit(leftChildStats, -// rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) -// (splitIndex, gainStats) -// }.maxBy(_._2.gain) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) + (splitIndex, gainStats) + }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature @@ -1055,7 +1025,7 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { + metadata: impl.DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { logDebug("isMulticlass = " + metadata.isMulticlass) @@ -1202,7 +1172,7 @@ object DecisionTree extends Serializable with Logging { */ private[tree] def findSplitsForContinuousFeature( featureSamples: Array[Double], - metadata: DecisionTreeMetadata, + metadata: impl.DecisionTreeMetadata, featureIndex: Int): Array[Double] = { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index b7714b382a594..c23492a2dff0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -135,7 +135,7 @@ private class RandomForest ( val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = - DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + impl.DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) logDebug("algo = " + strategy.algo) logDebug("numTrees = " + numTrees) logDebug("seed = " + seed) @@ -468,7 +468,7 @@ object RandomForest extends Serializable with Logging { private[tree] def selectNodesToSplit( nodeQueue: mutable.Queue[(Int, Node)], maxMemoryUsage: Long, - metadata: DecisionTreeMetadata, + metadata: impl.DecisionTreeMetadata, rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = { // 
Collect some nodes to split: // nodesForGroup(treeIndex) = nodes to split @@ -510,7 +510,7 @@ object RandomForest extends Serializable with Logging { * If None, then use all features. */ private[tree] def aggregateSizeForNode( - metadata: DecisionTreeMetadata, + metadata: impl.DecisionTreeMetadata, featureSubset: Option[Array[Int]]): Long = { val totalBins = if (featureSubset.nonEmpty) { featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 3fa0011c82d7f..be40910da82ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -79,6 +79,8 @@ private[spark] class DTStatsAggregator( */ private val allStats: Array[Double] = new Array[Double](allStatsSize) + private val parentStats: Array[Double] = new Array[Double](statsSize) + /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). @@ -92,6 +94,10 @@ private[spark] class DTStatsAggregator( impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } + def getParentImpurityCalculator(): ImpurityCalculator = { + impurityAggregator.getCalculator(parentStats) + } + /** * Update the stats for a given (feature, bin) for ordered features, using the given label. */ @@ -99,6 +105,9 @@ private[spark] class DTStatsAggregator( val i = featureOffsets(featureIndex) + binIndex * statsSize impurityAggregator.update(allStats, i, label, instanceWeight) } + def updateParent(label: Double, instanceWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, instanceWeight) + } /** * Faster version of [[update]]. @@ -162,17 +171,21 @@ private[spark] class DTStatsAggregator( allStats(i) += other.allStats(i) i += 1 } - this - } - def copy: DTStatsAggregator = { - val copyAggregator = new DTStatsAggregator(this.metadata, featureSubset) - copyAggregator.merge(this) + require(statsSize == other.statsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length parent stats vectors." + + s" This aggregator is of length $statsSize, but the other is ${other.statsSize}.") + var j = 0 + while (j < statsSize) { + parentStats(j) += other.parentStats(j) + j += 1 + } + + this } - def totalsForFeature(featureIndex: Int): Array[Double] = { - val numBins = metadata.numSplits(featureIndex) - val featureOffset = featureOffsets(featureIndex) + def totalsForFeature(featureOffset: Int): Array[Double] = { + val numBins = metadata.numSplits(featureOffset) var i = 0 val totals = Array.fill[Double](statsSize)(0.0) while (i < numBins) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 73df6b054a8ce..252c38d45d1a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -85,7 +85,7 @@ object Entropy extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. 
 */
-private[tree] class EntropyAggregator(numClasses: Int)
+private[spark] class EntropyAggregator(numClasses: Int)
 extends ImpurityAggregator(numClasses) with Serializable {
 
 /**
@@ -114,6 +114,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
 new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
 }
 
+ def getCalculator(sufficientStats: Array[Double]): EntropyCalculator = {
+ new EntropyCalculator(sufficientStats)
+ }
+
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b21a802..20400da3da48b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -81,7 +81,7 @@ object Gini extends Impurity {
 * Note: Instances of this class do not hold the data; they operate on views of the data.
 * @param numClasses Number of classes for label.
 */
-private[tree] class GiniAggregator(numClasses: Int)
+private[spark] class GiniAggregator(numClasses: Int)
 extends ImpurityAggregator(numClasses) with Serializable {
 
 /**
@@ -110,6 +110,10 @@ private[tree] class GiniAggregator(numClasses: Int)
 new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
 }
 
+ def getCalculator(sufficientStats: Array[Double]): GiniCalculator = {
+ new GiniCalculator(sufficientStats.clone())
+ }
+
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 4637dcceea7f8..663a0bd2d13f1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -90,6 +90,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
 */
 def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
 
+ /**
+ * Get an [[ImpurityCalculator]] from a complete array of sufficient statistics.
+ * @param sufficientStats Array holding one full set of sufficient statistics, e.g. a node's parent stats.
+ */
+ def getCalculator(sufficientStats: Array[Double]): ImpurityCalculator
+
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d482a73c..0be1ebb609dea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -71,7 +71,7 @@ object Variance extends Impurity {
 * in order to compute impurity from a sample.
 * Note: Instances of this class do not hold the data; they operate on views of the data. 
*/ -private[tree] class VarianceAggregator() +private[spark] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { /** @@ -94,6 +94,10 @@ private[tree] class VarianceAggregator() new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) } + def getCalculator(sufficientStats: Array[Double]): VarianceCalculator = { + new VarianceCalculator(sufficientStats) + } + } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5518bdf527c8a..7c723792cedaa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import org.apache.spark.ml.tree.impl + import scala.collection.JavaConverters._ import scala.collection.mutable @@ -45,7 +47,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) @@ -67,7 +69,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -91,7 +93,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -111,7 +113,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -128,7 +130,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -142,7 +144,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -156,7 +158,7 @@ class DecisionTreeSuite extends SparkFunSuite with 
MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -181,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -241,7 +243,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -262,7 +264,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) @@ -305,7 +307,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) @@ -372,7 +374,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -453,7 +455,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -487,7 +489,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) 
assert(!metadata.isUnordered(featureIndex = 1)) @@ -531,7 +533,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -555,7 +557,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -580,7 +582,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -605,7 +607,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -629,7 +631,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -686,7 +688,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -714,7 +716,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) DecisionTreeSuite.validateClassifier(model, arr, 0.9) @@ -735,7 +737,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses 
= 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) @@ -757,7 +759,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index c72fc9bb4f5d0..acc37c1ef0c31 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import org.apache.spark.ml.tree.impl + import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -121,7 +123,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val seeds = Array(123, 5354, 230, 349867, 23987) val maxMemoryUsage: Long = 128 * 1024L * 1024L val metadata = - DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + impl.DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) seeds.foreach { seed => val failString = s"Failed on test with:" + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + From 072b2aec6726ae70b9b4b9ea495cb4d176e67f76 Mon Sep 17 00:00:00 2001 From: sethah Date: Mon, 2 Nov 2015 12:48:29 -0800 Subject: [PATCH 03/12] reverting to mllib based change --- .../ml/tree/impl/DTStatsAggregator.scala | 200 ---------------- .../ml/tree/impl/DecisionTreeMetadata.scala | 221 ------------------ .../spark/ml/tree/impl/RandomForest.scala | 6 - .../apache/spark/ml/tree/impl/TreePoint.scala | 2 +- .../spark/mllib/tree/DecisionTree.scala | 8 +- .../spark/mllib/tree/RandomForest.scala | 6 +- .../mllib/tree/impl/DTStatsAggregator.scala | 15 -- .../tree/impl/DecisionTreeMetadata.scala | 4 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 48 ++-- .../spark/mllib/tree/RandomForestSuite.scala | 4 +- 10 files changed, 34 insertions(+), 480 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala deleted file mode 100644 index f41ebb4b71707..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import org.apache.spark.mllib.tree.impurity._ - - - -/** - * DecisionTree statistics aggregator for a node. - * This holds a flat array of statistics for a set of (features, bins) - * and helps with indexing. - * This class is abstract to support learning with and without feature subsampling. - */ -private[spark] class DTStatsAggregator( - val metadata: DecisionTreeMetadata, - featureSubset: Option[Array[Int]]) extends Serializable { - - /** - * [[ImpurityAggregator]] instance specifying the impurity type. - */ - val impurityAggregator: ImpurityAggregator = metadata.impurity match { - case Gini => new GiniAggregator(metadata.numClasses) - case Entropy => new EntropyAggregator(metadata.numClasses) - case Variance => new VarianceAggregator() - case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") - } - - /** - * Number of elements (Double values) used for the sufficient statistics of each bin. - */ - private val statsSize: Int = impurityAggregator.statsSize - - /** - * Number of bins for each feature. This is indexed by the feature index. - */ - private val numBins: Array[Int] = { - if (featureSubset.isDefined) { - featureSubset.get.map(metadata.numBins(_)) - } else { - metadata.numBins - } - } - - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - */ - private val featureOffsets: Array[Int] = { - numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Total number of elements stored in this aggregator - */ - private val allStatsSize: Int = featureOffsets.last - - /** - * Flat array of elements. - * Index for start of stats for a (feature, bin) is: - * index = featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, - * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) - * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) - */ - private val allStats: Array[Double] = new Array[Double](allStatsSize) - - private val parentStats: Array[Double] = new Array[Double](statsSize) - - - /** - * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. - */ - def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { - impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) - } - - def getParentImpurityCalculator(): ImpurityCalculator = { - impurityAggregator.getCalculator(parentStats) - } - - /** - * Update the stats for a given (feature, bin) for ordered features, using the given label. 
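The offset arithmetic above is the heart of the aggregator: the sufficient stats for a (feature, bin) pair start at featureOffsets(featureIndex) + binIndex * statsSize. A minimal standalone sketch of that indexing scheme, with assumed sizes (an illustration, not the Spark class):

object FlatStatsSketch {
  val statsSize = 3                      // e.g. sufficient stats for 3 class labels
  val numBins = Array(4, 2, 8)           // assumed bins per feature
  // scanLeft yields one offset per feature; the last element is the total size.
  val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  val allStats = new Array[Double](featureOffsets.last)

  def update(featureIndex: Int, binIndex: Int, label: Int, weight: Double): Unit =
    allStats(featureOffsets(featureIndex) + binIndex * statsSize + label) += weight

  def main(args: Array[String]): Unit = {
    update(featureIndex = 1, binIndex = 1, label = 2, weight = 1.0)
    println(featureOffsets.mkString(","))   // 0,12,18,42
    println(allStats(12 + 1 * 3 + 2))       // 1.0
  }
}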
- */ - def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { - val i = featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - def updateParent(label: Double, instanceWeight: Double): Unit = { - impurityAggregator.update(parentStats, 0, label, instanceWeight) - } - - /** - * Faster version of [[update]]. - * Update the stats for a given (feature, bin), using the given label. - * @param featureOffset For ordered features, this is a pre-computed feature offset - * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. - */ - def featureUpdate( - featureOffset: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, - label, instanceWeight) - } - - /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For ordered features only. - */ - def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) - - /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For unordered features only. - */ - def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - val baseOffset = featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) - } - - /** - * For a given feature, merge the stats for two bins. - * @param featureOffset For ordered features, this is a pre-computed feature offset - * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. - * @param binIndex The other bin is merged into this bin. - * @param otherBinIndex This bin is not modified. - */ - def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { - impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, - featureOffset + otherBinIndex * statsSize) - } - - /** - * Merge this aggregator with another, and returns this aggregator. - * This method modifies this aggregator in-place. - */ - def merge(other: DTStatsAggregator): DTStatsAggregator = { - require(allStatsSize == other.allStatsSize, - s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." 
- + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") - var i = 0 - // TODO: Test BLAS.axpy - while (i < allStatsSize) { - allStats(i) += other.allStats(i) - i += 1 - } - - var j = 0 - // TODO: Test BLAS.axpy - while (j < statsSize) { - parentStats(j) += other.parentStats(j) - j += 1 - } - - - this - } - - def totalsForFeature(featureOffset: Int): Array[Double] = { - val numBins = metadata.numSplits(featureOffset) - var i = 0 - val totals = Array.fill[Double](statsSize)(0.0) - while (i < numBins) { - var j = 0 - while (j < statsSize) { - totals(j) += allStats(featureOffset + i*statsSize + j) - j +=1 - } - i += 1 - } - totals - } -} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala deleted file mode 100644 index 644c88f34e92a..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import org.apache.spark.Logging -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.rdd.RDD - -import scala.collection.mutable - -/** - * Learning and dataset metadata for DecisionTree. - * - * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. - * For regression: fixed at 0 (no meaning). - * @param maxBins Maximum number of bins, for all features. - * @param featureArity Map: categorical feature index --> arity. - * I.e., the feature takes values in {0, ..., arity - 1}. - * @param numBins Number of bins for each feature. 
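The merge method shown above sums two flat stats vectors element-wise, in place. A hedged REPL-style sketch of that operation (illustrative names; the real class applies the same loop to parentStats as well):

def mergeInPlace(dst: Array[Double], src: Array[Double]): Array[Double] = {
  require(dst.length == src.length,
    s"merge requires equal-length stats vectors, got ${dst.length} and ${src.length}")
  var i = 0
  while (i < dst.length) {  // bare while keeps this allocation-free, cf. the BLAS.axpy TODO
    dst(i) += src(i)
    i += 1
  }
  dst
}

mergeInPlace(Array(1.0, 2.0), Array(0.5, 0.5))  // Array(1.5, 2.5)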
- */ -private[spark] class DecisionTreeMetadata( - val numFeatures: Int, - val numExamples: Long, - val numClasses: Int, - val maxBins: Int, - val featureArity: Map[Int, Int], - val unorderedFeatures: Set[Int], - val numBins: Array[Int], - val impurity: Impurity, - val quantileStrategy: QuantileStrategy, - val maxDepth: Int, - val minInstancesPerNode: Int, - val minInfoGain: Double, - val numTrees: Int, - val numFeaturesPerNode: Int) extends Serializable { - - def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) - - def isClassification: Boolean = numClasses >= 2 - - def isMulticlass: Boolean = numClasses > 2 - - def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) - - def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) - - def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) - - /** - * Number of splits for the given feature. - * For unordered features, there are 1 bin per split. - * For ordered features, there is 1 more bin than split. - */ - def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) - } else { - numBins(featureIndex) - 1 - } - - - /** - * Set number of splits for a continuous feature. - * For a continuous feature, number of bins is number of splits plus 1. - */ - def setNumSplits(featureIndex: Int, numSplits: Int) { - require(isContinuous(featureIndex), - s"Only number of bin for a continuous feature can be set.") - numBins(featureIndex) = numSplits + 1 - } - - /** - * Indicates if feature subsampling is being used. - */ - def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode - -} - -private[spark] object DecisionTreeMetadata extends Logging { - - /** - * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. - * This computes which categorical features will be ordered vs. unordered, - * as well as the number of splits and bins for each feature. - */ - def buildMetadata( - input: RDD[LabeledPoint], - strategy: Strategy, - numTrees: Int, - featureSubsetStrategy: String): DecisionTreeMetadata = { - - val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { - throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + - s"but was given by empty one.") - } - val numExamples = input.count() - val numClasses = strategy.algo match { - case Classification => strategy.numClasses - case Regression => 0 - } - - val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt - if (maxPossibleBins < strategy.maxBins) { - logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + - s" (= number of training instances)") - } - - // We check the number of bins here against maxPossibleBins. - // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified - // based on the number of training examples. - if (strategy.categoricalFeaturesInfo.nonEmpty) { - val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max - val maxCategory = - strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 - require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + - s"number of values in each categorical feature, but categorical feature $maxCategory " + - s"has $maxCategoriesPerFeature values. 
Considering remove this and other categorical " + - "features with a large number of values, or add more training examples.") - } - - val unorderedFeatures = new mutable.HashSet[Int]() - val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) - if (numClasses > 2) { - println("Multiclass") - println(strategy.categoricalFeaturesInfo) - // Multiclass classification - val maxCategoriesForUnorderedFeature = - ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt - strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Hack: If a categorical feature has only 1 category, we treat it as continuous. - // TODO(SPARK-9957): Handle this properly by filtering out those features. - // TODO: update this check or not? Change wording of comments - if (numCategories > 1) { - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - println(maxCategoriesForUnorderedFeature.toString+numCategories.toString+"[") - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories - } - } - } - } else { - // Binary classification or regression - strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 - if (numCategories > 1) { - numBins(featureIndex) = numCategories - } - } - } - - // Set number of features to use per node (for random forests). - val _featureSubsetStrategy = featureSubsetStrategy match { - case "auto" => - if (numTrees == 1) { - "all" - } else { - if (strategy.algo == Classification) { - "sqrt" - } else { - "onethird" - } - } - case _ => featureSubsetStrategy - } - val numFeaturesPerNode: Int = _featureSubsetStrategy match { - case "all" => numFeatures - case "sqrt" => math.sqrt(numFeatures).ceil.toInt - case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) - case "onethird" => (numFeatures / 3.0).ceil.toInt - } - - new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) - } - - /** - * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. - */ - def buildMetadata( - input: RDD[LabeledPoint], - strategy: Strategy): DecisionTreeMetadata = { - buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") - } - - /** - * Given the arity of a categorical feature (arity = number of categories), - * return the number of bins for the feature if it is to be treated as an unordered feature. - * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; - * there are math.pow(2, arity - 1) - 1 such splits. - * Each split has 1 corresponding bin. 
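The count stated above is easy to sanity-check: fix category 0 on the left side, then each split is a choice of which of the remaining arity - 1 categories join it, minus the one choice that leaves the right side empty, giving 2^(arity - 1) - 1. One precedence note for the line below: in Scala, 1 << arity - 1 already parses as 1 << (arity - 1), since - binds tighter than <<, so the parenthesized and unparenthesized forms in these patches are equivalent. A REPL-style brute-force check:

def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

// Enumerate bitmasks: bit i set means category i goes to the left side.
def bruteForceCount(arity: Int): Int =
  (1 until (1 << arity)).count { mask =>
    (mask & 1) == 1 && mask != (1 << arity) - 1  // category 0 left, right non-empty
  }

(2 to 8).foreach(a => assert(numUnorderedBins(a) == bruteForceCount(a)))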
- */ - def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 - -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index ed89208a84762..e279192003d09 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -81,10 +81,6 @@ private[ml] object RandomForest extends Logging { s"\t$featureIndex\t${metadata.numBins(featureIndex)}" }.mkString("\n")) - println("*****************") - metadata.numBins.foreach(x => printf(x.toString + "/")) - println("*****************") - // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) @@ -704,7 +700,6 @@ private[ml] object RandomForest extends Logging { val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - // SETH val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) @@ -713,7 +708,6 @@ private[ml] object RandomForest extends Logging { leftChildStats, rightChildStats, binAggregates.metadata) (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) - // SETH (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 9b971a0b85439..9fa27e5e1f721 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.mllib.regression.LabeledPoint -//import org.apache.spark.ml.tree.impl.DecisionTreeMetadata +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6559bece3bed0..787843bfded74 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -450,7 +450,7 @@ object DecisionTree extends Serializable with Logging { */ private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], - metadata: impl.DecisionTreeMetadata, + metadata: DecisionTreeMetadata, topNodes: Array[Node], nodesForGroup: Map[Int, Array[Node]], treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], @@ -730,7 +730,7 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: impl.DecisionTreeMetadata, + metadata: DecisionTreeMetadata, impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -1025,7 +1025,7 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - metadata: impl.DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { + metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { 
logDebug("isMulticlass = " + metadata.isMulticlass) @@ -1172,7 +1172,7 @@ object DecisionTree extends Serializable with Logging { */ private[tree] def findSplitsForContinuousFeature( featureSamples: Array[Double], - metadata: impl.DecisionTreeMetadata, + metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index c23492a2dff0b..b7714b382a594 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -135,7 +135,7 @@ private class RandomForest ( val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = - impl.DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) logDebug("algo = " + strategy.algo) logDebug("numTrees = " + numTrees) logDebug("seed = " + seed) @@ -468,7 +468,7 @@ object RandomForest extends Serializable with Logging { private[tree] def selectNodesToSplit( nodeQueue: mutable.Queue[(Int, Node)], maxMemoryUsage: Long, - metadata: impl.DecisionTreeMetadata, + metadata: DecisionTreeMetadata, rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = { // Collect some nodes to split: // nodesForGroup(treeIndex) = nodes to split @@ -510,7 +510,7 @@ object RandomForest extends Serializable with Logging { * If None, then use all features. */ private[tree] def aggregateSizeForNode( - metadata: impl.DecisionTreeMetadata, + metadata: DecisionTreeMetadata, featureSubset: Option[Array[Int]]): Long = { val totalBins = if (featureSubset.nonEmpty) { featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index be40910da82ec..07f6248c587a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -183,19 +183,4 @@ private[spark] class DTStatsAggregator( this } - - def totalsForFeature(featureOffset: Int): Array[Double] = { - val numBins = metadata.numSplits(featureOffset) - var i = 0 - val totals = Array.fill[Double](statsSize)(0.0) - while (i < numBins) { - var j = 0 - while (j < statsSize) { - totals(j) += allStats(featureOffset + i*statsSize + j) - j +=1 - } - i += 1 - } - totals - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index f241019bd6cb4..963510b270a94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -71,7 +71,7 @@ private[spark] class DecisionTreeMetadata( * For ordered features, there is 1 more bin than split. 
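A short REPL-style illustration of this bookkeeping with assumed values: after this patch series an unordered categorical feature of arity 3 has 2^(3 - 1) - 1 = 3 splits and 3 bins (one per split), while an ordered feature with 32 bins has 31 splits:

def numSplits(isUnordered: Boolean, numBins: Int): Int =
  if (isUnordered) numBins else numBins - 1

assert(numSplits(isUnordered = true, numBins = 3) == 3)     // unordered, arity 3
assert(numSplits(isUnordered = false, numBins = 32) == 31)  // ordered feature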
*/ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) //>> 1 + numBins(featureIndex) } else { numBins(featureIndex) - 1 } @@ -215,6 +215,6 @@ private[spark] object DecisionTreeMetadata extends Logging { * there are math.pow(2, arity - 1) - 1 such splits. * Each split has 2 corresponding bins. */ - def numUnorderedBins(arity: Int): Int = ((1 << arity - 1) - 1) + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 7c723792cedaa..5518bdf527c8a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.tree -import org.apache.spark.ml.tree.impl - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -47,7 +45,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) @@ -69,7 +67,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -93,7 +91,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -113,7 +111,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -130,7 +128,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -144,7 +142,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -158,7 +156,7 @@ 
class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new impl.DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -183,7 +181,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -243,7 +241,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -264,7 +262,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(input, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) @@ -307,7 +305,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(input, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) @@ -374,7 +372,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -455,7 +453,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -489,7 +487,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) 
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -533,7 +531,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -557,7 +555,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -582,7 +580,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -607,7 +605,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -631,7 +629,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -688,7 +686,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -716,7 +714,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100) assert(strategy.isMulticlassClassification) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) DecisionTreeSuite.validateClassifier(model, arr, 0.9) @@ -737,7 +735,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = 
Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) @@ -759,7 +757,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val metadata = impl.DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index acc37c1ef0c31..c72fc9bb4f5d0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.tree -import org.apache.spark.ml.tree.impl - import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -123,7 +121,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val seeds = Array(123, 5354, 230, 349867, 23987) val maxMemoryUsage: Long = 128 * 1024L * 1024L val metadata = - impl.DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) seeds.foreach { seed => val failString = s"Failed on test with:" + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + From 2af3dc8369164bf9dea4c01eb1cd15ea1503e708 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 3 Nov 2015 16:38:52 -0800 Subject: [PATCH 04/12] style cleanup --- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 3 --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 2 -- .../spark/mllib/tree/impl/DTStatsAggregator.scala | 10 ++++++++-- .../spark/mllib/tree/impl/DecisionTreeMetadata.scala | 3 --- .../org/apache/spark/mllib/tree/impurity/Entropy.scala | 5 ++++- .../org/apache/spark/mllib/tree/impurity/Gini.scala | 5 ++++- .../apache/spark/mllib/tree/impurity/Impurity.scala | 5 ++--- .../apache/spark/mllib/tree/impurity/Variance.scala | 5 ++++- 8 files changed, 22 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index e279192003d09..07b9b7db486ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -696,10 +696,8 @@ private[ml] object RandomForest extends Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) @@ -708,7 +706,6 @@ private[ml] object RandomForest extends Logging { leftChildStats, 
rightChildStats, binAggregates.metadata) (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 787843bfded74..0d9c7ed78c20b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -847,12 +847,10 @@ object DecisionTree extends Serializable with Logging { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getParentImpurityCalculator().subtract(leftChildStats) -// val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predictWithImpurity = Some(predictWithImpurity.getOrElse( calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 07f6248c587a7..8c8d99bb9a011 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -81,7 +81,6 @@ private[spark] class DTStatsAggregator( private val parentStats: Array[Double] = new Array[Double](statsSize) - /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset @@ -94,6 +93,9 @@ private[spark] class DTStatsAggregator( impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } + /** + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + */ def getParentImpurityCalculator(): ImpurityCalculator = { impurityAggregator.getCalculator(parentStats) } @@ -105,6 +107,10 @@ private[spark] class DTStatsAggregator( val i = featureOffsets(featureIndex) + binIndex * statsSize impurityAggregator.update(allStats, i, label, instanceWeight) } + + /** + * Update the parent node stats using the given label. + */ def updateParent(label: Double, instanceWeight: Double): Unit = { impurityAggregator.update(parentStats, 0, label, instanceWeight) } @@ -174,7 +180,7 @@ private[spark] class DTStatsAggregator( require(statsSize == other.statsSize, s"DTStatsAggregator.merge requires that both aggregators have the same length parent stats vectors." 
- + s" This aggregator is of length $statsSize, but the other is ${other.statsSize}.") + + s" This aggregator's parent stats are length $statsSize, but the other is ${other.statsSize}.") var j = 0 while (j < statsSize) { parentStats(j) += other.parentStats(j) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 963510b270a94..bd2501adf6dc9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -140,8 +140,6 @@ private[spark] object DecisionTreeMetadata extends Logging { val unorderedFeatures = new mutable.HashSet[Int]() val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) if (numClasses > 2) { - println("Multiclass") - println(strategy.categoricalFeaturesInfo) // Multiclass classification val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt @@ -153,7 +151,6 @@ private[spark] object DecisionTreeMetadata extends Logging { // which require 2 * ((1 << numCategories - 1) - 1) bins. // We do this check with log values to prevent overflows in case numCategories is large. // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - println(maxCategoriesForUnorderedFeature.toString+numCategories.toString+"[") if (numCategories <= maxCategoriesForUnorderedFeature) { unorderedFeatures.add(featureIndex) numBins(featureIndex) = numUnorderedBins(numCategories) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 252c38d45d1a9..a1f47c9103538 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -114,10 +114,13 @@ private[spark] class EntropyAggregator(numClasses: Int) new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) } + /** + * Get an [[ImpurityCalculator]] for a node. + * @param sufficientStats Sufficient stats array for a node. + */ def getCalculator(sufficientStats: Array[Double]): EntropyCalculator = { new EntropyCalculator(sufficientStats) } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 20400da3da48b..5d404d09acd9e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -110,10 +110,13 @@ private[spark] class GiniAggregator(numClasses: Int) new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) } + /** + * Get an [[ImpurityCalculator]] for a node. + * @param sufficientStats Sufficient stats array for a node. 
+ */ def getCalculator(sufficientStats: Array[Double]): GiniCalculator = { new GiniCalculator(sufficientStats.clone()) } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 663a0bd2d13f1..0804920d0e718 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -91,11 +91,10 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator /** - * Get an [[ImpurityCalculator]] for a (node, feature, bin). - * @param sufficientStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * Get an [[ImpurityCalculator]] for a node. + * @param sufficientStats Sufficient stats array for a node. */ def getCalculator(sufficientStats: Array[Double]): ImpurityCalculator - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 0be1ebb609dea..2f8c772d92575 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -94,10 +94,13 @@ private[spark] class VarianceAggregator() new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) } + /** + * Get an [[ImpurityCalculator]] for a node. + * @param sufficientStats Sufficient stats array for a node. + */ def getCalculator(sufficientStats: Array[Double]): VarianceCalculator = { new VarianceCalculator(sufficientStats) } - } /** From 1b0c6b3c9266c13a73dbfd339faaad1fb3de9c26 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 3 Nov 2015 16:47:37 -0800 Subject: [PATCH 05/12] changing scopes --- .../scala/org/apache/spark/mllib/tree/impurity/Entropy.scala | 2 +- .../main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala | 2 +- .../scala/org/apache/spark/mllib/tree/impurity/Variance.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index a1f47c9103538..473f6749563d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -85,7 +85,7 @@ object Entropy extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[spark] class EntropyAggregator(numClasses: Int) +private[tree] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 5d404d09acd9e..a62c89df1a112 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -81,7 +81,7 @@ object Gini extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. 
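For context on the visibility change in this patch: private[tree] narrows access from everything under org.apache.spark to the tree package and its subpackages. A minimal single-file illustration of Scala's qualified private, using hypothetical package names:

package org {
  package tree {
    package impurity {
      private[tree] class Hidden
    }
    object InsideTree {
      def ok = new impurity.Hidden  // compiles: InsideTree lives in org.tree
    }
  }
  object OutsideTree {
    // def bad = new tree.impurity.Hidden  // would not compile: private[tree]
  }
}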
*/ -private[spark] class GiniAggregator(numClasses: Int) +private[tree] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 2f8c772d92575..4ecc1951059cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -71,7 +71,7 @@ object Variance extends Impurity { * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. */ -private[spark] class VarianceAggregator() +private[tree] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { /** From e6226264cb9787cc2396d3164e1eb010b406cad7 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 4 Nov 2015 11:25:18 -0800 Subject: [PATCH 06/12] removing obsolete methods --- .../spark/ml/tree/impl/RandomForest.scala | 6 ++-- .../spark/mllib/tree/DecisionTree.scala | 6 ++-- .../mllib/tree/impl/DTStatsAggregator.scala | 32 ++++--------------- 3 files changed, 11 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 07b9b7db486ee..1c1e57435ddaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging { if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) - val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightFeatureOffsets(featureIndexIdx) + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) // Update the left or right bin for each split. val numSplits = agg.metadata.numSplits(featureIndex) val featureSplits = splits(featureIndex) @@ -696,8 +695,7 @@ private[ml] object RandomForest extends Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 0d9c7ed78c20b..984a020b42638 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -369,8 +369,7 @@ object DecisionTree extends Serializable with Logging { if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) - val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightFeatureOffsets(featureIndexIdx) + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) // Update the left or right bin for each split. 
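The substitution below is the core of the whole series: unordered splits now accumulate only left-child bin stats, and each right child is recovered as parent minus left, halving the bins stored per unordered feature. A hedged REPL-style sketch with illustrative class counts:

def rightFromParent(parent: Array[Double], left: Array[Double]): Array[Double] = {
  require(parent.length == left.length)
  Array.tabulate(parent.length)(i => parent(i) - left(i))
}

val parentCounts = Array(10.0, 6.0)  // assumed label counts at the node
val leftCounts = Array(4.0, 1.0)     // counts routed left by one candidate split
assert(rightFromParent(parentCounts, leftCounts).sameElements(Array(6.0, 5.0)))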
val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 @@ -845,8 +844,7 @@ object DecisionTree extends Serializable with Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 8c8d99bb9a011..19893e18bfb0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -73,28 +73,25 @@ private[spark] class DTStatsAggregator( * Flat array of elements. * Index for start of stats for a (feature, bin) is: * index = featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, - * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) - * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) */ private val allStats: Array[Double] = new Array[Double](allStatsSize) + /** + * Array of parent node sufficient stats. + */ private val parentStats: Array[Double] = new Array[Double](statsSize) /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset + * @param featureOffset This is a pre-computed (node, feature) offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. */ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } /** - * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * Get an [[ImpurityCalculator]] for the parent node. */ def getParentImpurityCalculator(): ImpurityCalculator = { impurityAggregator.getCalculator(parentStats) @@ -118,11 +115,8 @@ private[spark] class DTStatsAggregator( /** * Faster version of [[update]]. * Update the stats for a given (feature, bin), using the given label. - * @param featureOffset For ordered features, this is a pre-computed feature offset + * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. */ def featureUpdate( featureOffset: Int, @@ -139,22 +133,10 @@ private[spark] class DTStatsAggregator( */ def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) - /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For unordered features only. - */ - def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - val baseOffset = featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) - } - /** * For a given feature, merge the stats for two bins. 
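In the flat layout, merging two bins is an element-wise add of one statsSize slice into another; calling it left to right across an ordered feature's bins (mergeForFeature(offset, s + 1, s) for s = 0 .. numSplits - 1) leaves bin b holding the totals of bins 0..b, i.e. the cumulative stats used to scan ordered splits in one pass. An assumed-standalone sketch:

def mergeForFeature(allStats: Array[Double], featureOffset: Int, statsSize: Int,
    binIndex: Int, otherBinIndex: Int): Unit = {
  var j = 0
  while (j < statsSize) {
    allStats(featureOffset + binIndex * statsSize + j) +=
      allStats(featureOffset + otherBinIndex * statsSize + j)
    j += 1
  }
}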
- * @param featureOffset For ordered features, this is a pre-computed feature offset + * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. * @param binIndex The other bin is merged into this bin. * @param otherBinIndex This bin is not modified. */ From e210d87156ad9557b195c2b7e915d6cb8d08bf18 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 4 Nov 2015 13:44:14 -0800 Subject: [PATCH 07/12] adding test for number of bins --- .../apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala | 2 +- .../scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index bd2501adf6dc9..235f0c8dd5966 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -67,7 +67,7 @@ private[spark] class DecisionTreeMetadata( /** * Number of splits for the given feature. - * For unordered features, there are 2 bins per split. + * For unordered features, there is 1 bin per split. * For ordered features, there is 1 more bin than split. */ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5518bdf527c8a..89b64fce96ebf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(bins.length === 2) assert(splits(0).length === 3) assert(bins(0).length === 0) + assert(metadata.numSplits(0) === 3) + assert(metadata.numBins(0) === 3) + assert(metadata.numSplits(1) === 3) + assert(metadata.numBins(1) === 3) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) From c29e818c8db2ac81434f0bd9e3142b8bfa6509fb Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 4 Nov 2015 14:21:30 -0800 Subject: [PATCH 08/12] clone parent stats in getImpurityCalculator --- .../scala/org/apache/spark/mllib/tree/impurity/Entropy.scala | 2 +- .../scala/org/apache/spark/mllib/tree/impurity/Variance.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 473f6749563d1..10fe32b4307b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -119,7 +119,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * @param sufficientStats Sufficient stats array for a node. 
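The .clone() added here (GiniAggregator already had one) matters because ImpurityCalculator.subtract mutates its backing array in place, so a calculator built directly over the shared parentStats buffer would corrupt the parent's stats the first time a right child was derived from it. A minimal REPL-style illustration with an assumed calculator shape:

class Calc(val stats: Array[Double]) {
  def subtract(other: Calc): Calc = {  // in-place, mirroring ImpurityCalculator.subtract
    var i = 0
    while (i < stats.length) { stats(i) -= other.stats(i); i += 1 }
    this
  }
}

val parentStats = Array(10.0, 6.0)
val calc = new Calc(parentStats.clone())  // the fix: hand out a copy
calc.subtract(new Calc(Array(4.0, 1.0)))
assert(parentStats.sameElements(Array(10.0, 6.0)))  // shared buffer left intact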
From c29e818c8db2ac81434f0bd9e3142b8bfa6509fb Mon Sep 17 00:00:00 2001
From: sethah
Date: Wed, 4 Nov 2015 14:21:30 -0800
Subject: [PATCH 08/12] clone parent stats in getImpurityCalculator

---
 .../scala/org/apache/spark/mllib/tree/impurity/Entropy.scala  | 2 +-
 .../scala/org/apache/spark/mllib/tree/impurity/Variance.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 473f6749563d1..10fe32b4307b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -119,7 +119,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
    * @param sufficientStats  Sufficient stats array for a node.
    */
   def getCalculator(sufficientStats: Array[Double]): EntropyCalculator = {
-    new EntropyCalculator(sufficientStats)
+    new EntropyCalculator(sufficientStats.clone())
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 4ecc1951059cc..7f9b5156c6a2d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -99,7 +99,7 @@ private[tree] class VarianceAggregator()
    * @param sufficientStats  Sufficient stats array for a node.
    */
   def getCalculator(sufficientStats: Array[Double]): VarianceCalculator = {
-    new VarianceCalculator(sufficientStats)
+    new VarianceCalculator(sufficientStats.clone())
   }
 }
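Why the defensive copy matters: ImpurityCalculator.subtract modifies its stats in
place, so a calculator built directly over the aggregator's live parentStats buffer
would corrupt that buffer the first time a caller subtracted left-child stats from
it. A toy Scala sketch of the aliasing bug (ToyCalculator is hypothetical, not the
real class):

    // Mimics ImpurityCalculator.subtract: element-wise, in-place subtraction.
    class ToyCalculator(val stats: Array[Double]) {
      def subtract(other: ToyCalculator): ToyCalculator = {
        var i = 0
        while (i < stats.length) { stats(i) -= other.stats(i); i += 1 }
        this
      }
    }

    val parentStats = Array(3.0, 5.0)                 // aggregator's live buffer
    val aliased = new ToyCalculator(parentStats)      // pre-patch behavior
    aliased.subtract(new ToyCalculator(Array(1.0, 2.0)))
    // parentStats is now [2.0, 3.0]: later splits would see corrupted parent stats.
    val safe = new ToyCalculator(parentStats.clone()) // post-patch: copy first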
From 31574c6f386c7b35469f5bba040bb05b78c105f6 Mon Sep 17 00:00:00 2001
From: sethah
Date: Wed, 4 Nov 2015 14:37:45 -0800
Subject: [PATCH 09/12] style fixes

---
 .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala   | 3 ++-
 .../scala/org/apache/spark/mllib/tree/DecisionTree.scala     | 3 ++-
 .../org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala | 5 +++--
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 1c1e57435ddaa..03a5cb0ba0b1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -699,7 +699,8 @@ private[ml] object RandomForest extends Logging {
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats = binAggregates.getParentImpurityCalculator().subtract(leftChildStats)
+            val rightChildStats = binAggregates.getParentImpurityCalculator()
+              .subtract(leftChildStats)
             gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
               leftChildStats, rightChildStats, binAggregates.metadata)
             (splitIndex, gainAndImpurityStats)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 984a020b42638..980646b2aa153 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -848,7 +848,8 @@ object DecisionTree extends Serializable with Logging {
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats = binAggregates.getParentImpurityCalculator().subtract(leftChildStats)
+            val rightChildStats = binAggregates.getParentImpurityCalculator()
+              .subtract(leftChildStats)
             predictWithImpurity = Some(predictWithImpurity.getOrElse(
               calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 19893e18bfb0a..d09334b82ca82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -161,8 +161,9 @@ private[spark] class DTStatsAggregator(
     }
     require(statsSize == other.statsSize,
-      s"DTStatsAggregator.merge requires that both aggregators have the same length parent stats vectors."
-      + s" This aggregator's parent stats are length $statsSize, but the other is ${other.statsSize}.")
+      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+      s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+      s"but the other is ${other.statsSize}.")
     var j = 0
     while (j < statsSize) {
       parentStats(j) += other.parentStats(j)

From ed74ab2718c8ede81567d4fb93169d31c4d9ad61 Mon Sep 17 00:00:00 2001
From: sethah
Date: Tue, 15 Mar 2016 15:00:56 -0700
Subject: [PATCH 10/12] merge conflicts

---
 .../spark/mllib/tree/DecisionTree.scala | 54 -------------------
 1 file changed, 54 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 980646b2aa153..480da37f9adb4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -858,60 +858,6 @@ object DecisionTree extends Serializable with Logging {
       }.maxBy(_._2.gain)
       (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
     } else {
-      // Ordered categorical feature
-      val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-      val numBins = binAggregates.metadata.numBins(featureIndex)
-
-      /* Each bin is one category (feature value).
-       * The bins are ordered based on centroidForCategories, and this ordering determines which
-       * splits are considered.  (With K categories, we consider K - 1 possible splits.)
-       *
-       * centroidForCategories is a list: (category, centroid)
-       */
-      val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
-        // For categorical variables in multiclass classification,
-        // the bins are ordered by the impurity of their corresponding labels.
-        Range(0, numBins).map { case featureValue =>
-          val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-          val centroid = if (categoryStats.count != 0) {
-            categoryStats.calculate()
-          } else {
-            Double.MaxValue
-          }
-          (featureValue, centroid)
->>>>>>> Removing superfluous bins in decision tree training
-        }
-        // Find best split.
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { case splitIdx =>
-            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-            val rightChildStats =
-              binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
-            rightChildStats.subtract(leftChildStats)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIdx, gainStats)
-          }.maxBy(_._2.gain)
-        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
-        // Unordered categorical feature
-        val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
-        val (bestFeatureSplitIndex, bestFeatureGainStats) =
-          Range(0, numSplits).map { splitIndex =>
-            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats =
-              binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-            predictWithImpurity = Some(predictWithImpurity.getOrElse(
-              calculatePredictImpurity(leftChildStats, rightChildStats)))
-            val gainStats = calculateGainForSplit(leftChildStats,
-              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-            (splitIndex, gainStats)
-          }.maxBy(_._2.gain)
-        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-      } else {
       // Ordered categorical feature
       val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
       val numBins = binAggregates.metadata.numBins(featureIndex)
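For context on the duplicated block this patch drops: ordered categorical features
put one category per bin, rank the bins by a per-category centroid, and evaluate only
the K - 1 prefix splits of that ranking. A simplified sketch of the ordering step
(toy values; the real code derives each centroid from the bin's ImpurityCalculator
stats):

    // (category, centroid) pairs, e.g. label impurity or mean label per category.
    val centroidForCategories = Seq((0, 0.4), (1, 0.1), (2, 0.7))
    val orderedCategories = centroidForCategories.sortBy(_._2).map(_._1)
    // orderedCategories == Seq(1, 0, 2); with K = 3 the K - 1 = 2 candidate
    // splits are the prefixes: {1} vs {0, 2} and {1, 0} vs {2}.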
From 16e6f64ac51304ae8dda1f1cd079dac2603998ff Mon Sep 17 00:00:00 2001
From: sethah
Date: Tue, 15 Mar 2016 15:25:20 -0700
Subject: [PATCH 11/12] indentation

---
 .../spark/mllib/tree/DecisionTree.scala | 58 +++++++++----------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 480da37f9adb4..ac23c9ccd6a2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -829,35 +829,35 @@ object DecisionTree extends Serializable with Logging {
           binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
           splitIndex += 1
         }
-          // Find best split.
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { case splitIdx =>
-              val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-              val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
-              rightChildStats.subtract(leftChildStats)
-              predictWithImpurity = Some(predictWithImpurity.getOrElse(
-                calculatePredictImpurity(leftChildStats, rightChildStats)))
-              val gainStats = calculateGainForSplit(leftChildStats,
-                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-              (splitIdx, gainStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-        } else if (binAggregates.metadata.isUnordered(featureIndex)) {
-          // Unordered categorical feature
-          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { splitIndex =>
-              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats = binAggregates.getParentImpurityCalculator()
-                .subtract(leftChildStats)
-              predictWithImpurity = Some(predictWithImpurity.getOrElse(
-                calculatePredictImpurity(leftChildStats, rightChildStats)))
-              val gainStats = calculateGainForSplit(leftChildStats,
-                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
-              (splitIndex, gainStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
-        } else {
+        // Find best split.
+        val (bestFeatureSplitIndex, bestFeatureGainStats) =
+          Range(0, numSplits).map { case splitIdx =>
+            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+            rightChildStats.subtract(leftChildStats)
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+            (splitIdx, gainStats)
+          }.maxBy(_._2.gain)
+        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+        // Unordered categorical feature
+        val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+        val (bestFeatureSplitIndex, bestFeatureGainStats) =
+          Range(0, numSplits).map { splitIndex =>
+            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+            val rightChildStats = binAggregates.getParentImpurityCalculator()
+              .subtract(leftChildStats)
+            predictWithImpurity = Some(predictWithImpurity.getOrElse(
+              calculatePredictImpurity(leftChildStats, rightChildStats)))
+            val gainStats = calculateGainForSplit(leftChildStats,
+              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+            (splitIndex, gainStats)
+          }.maxBy(_._2.gain)
+        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+      } else {
         // Ordered categorical feature
         val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
         val numBins = binAggregates.metadata.numBins(featureIndex)
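The re-indented block reads most naturally as a prefix-sum scan: the mergeForFeature
loop above it turns bin i into the cumulative stats of bins 0..i, so the last bin
holds the whole node's stats and each right child is the total minus the left prefix.
A minimal sketch with plain Doubles standing in for the per-bin sufficient stats:

    val binStats = Array(2.0, 1.0, 4.0, 3.0)  // one entry per bin, in split order
    var i = 0
    while (i < binStats.length - 1) {
      binStats(i + 1) += binStats(i)          // bin i+1 now covers bins 0..i+1
      i += 1
    }
    val total = binStats.last                 // stats for the entire node
    // Split i: left child = binStats(i), right child = total - binStats(i).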
From 1cae10c48f34e91b0a7b4dbead698cbd771e676a Mon Sep 17 00:00:00 2001
From: sethah
Date: Thu, 17 Mar 2016 11:58:03 -0700
Subject: [PATCH 12/12] addressing comments

---
 .../scala/org/apache/spark/mllib/tree/DecisionTree.scala   | 3 ++-
 .../apache/spark/mllib/tree/impl/DTStatsAggregator.scala   | 5 ++++-
 .../org/apache/spark/mllib/tree/impurity/Entropy.scala     | 8 --------
 .../scala/org/apache/spark/mllib/tree/impurity/Gini.scala  | 8 --------
 .../org/apache/spark/mllib/tree/impurity/Impurity.scala    | 6 ------
 .../org/apache/spark/mllib/tree/impurity/Variance.scala    | 8 --------
 6 files changed, 6 insertions(+), 32 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index ac23c9ccd6a2a..fa6dde38c414d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -833,7 +833,8 @@ object DecisionTree extends Serializable with Logging {
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { case splitIdx =>
             val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
-            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+            val rightChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predictWithImpurity = Some(predictWithImpurity.getOrElse(
               calculatePredictImpurity(leftChildStats, rightChildStats)))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index d09334b82ca82..c745e9f8dbed5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -78,6 +78,9 @@ private[spark] class DTStatsAggregator(
 
   /**
    * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not available
+   * on the first iteration of tree learning.
    */
   private val parentStats: Array[Double] = new Array[Double](statsSize)
 
@@ -94,7 +97,7 @@ private[spark] class DTStatsAggregator(
    * Get an [[ImpurityCalculator]] for the parent node.
    */
   def getParentImpurityCalculator(): ImpurityCalculator = {
-    impurityAggregator.getCalculator(parentStats)
+    impurityAggregator.getCalculator(parentStats, 0)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 10fe32b4307b3..13aff110079ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,14 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
     new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
-  /**
-   * Get an [[ImpurityCalculator]] for a node.
-   * @param sufficientStats  Sufficient stats array for a node.
-   */
-  def getCalculator(sufficientStats: Array[Double]): EntropyCalculator = {
-    new EntropyCalculator(sufficientStats.clone())
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index a62c89df1a112..39c7f9c3be8ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,14 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
     new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
-  /**
-   * Get an [[ImpurityCalculator]] for a node.
-   * @param sufficientStats  Sufficient stats array for a node.
-   */
-  def getCalculator(sufficientStats: Array[Double]): GiniCalculator = {
-    new GiniCalculator(sufficientStats.clone())
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 0804920d0e718..ae2cdcb3879c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,12 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
    * @param offset  Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
-  /**
-   * Get an [[ImpurityCalculator]] for a node.
-   * @param sufficientStats  Sufficient stats array for a node.
-   */
-  def getCalculator(sufficientStats: Array[Double]): ImpurityCalculator
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 7f9b5156c6a2d..92d74a1b83341 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,14 +93,6 @@ private[tree] class VarianceAggregator()
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
     new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
-  /**
-   * Get an [[ImpurityCalculator]] for a node.
-   * @param sufficientStats  Sufficient stats array for a node.
-   */
-  def getCalculator(sufficientStats: Array[Double]): VarianceCalculator = {
-    new VarianceCalculator(sufficientStats.clone())
-  }
 }
 
 /**