From cd5ea88687b72335647d74c0aeef375de01724d9 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 26 Apr 2018 12:59:18 +0800
Subject: [PATCH 1/4] init pr

---
 .../DecisionTreeClassifier.scala              |   4 +-
 .../RandomForestClassifier.scala              |   3 +-
 .../ml/regression/DecisionTreeRegressor.scala |   4 +-
 .../ml/regression/RandomForestRegressor.scala |   3 +-
 .../ml/tree/impl/GradientBoostedTrees.scala   |  29 ++---
 .../spark/ml/tree/impl/RandomForest.scala     | 100 ++++++++++--------
 .../spark/ml/util/Instrumentation.scala       |  20 +++-
 .../spark/mllib/tree/RandomForest.scala       |   2 +-
 .../ml/tree/impl/RandomForestSuite.scala      |  20 ++--
 9 files changed, 108 insertions(+), 77 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 57797d1cc4978..1773499069559 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -114,7 +114,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
     instr.logParams(params: _*)

     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+      seed = $(seed), instr = OptionalInstrumentation.create(instr), parentUID = Some(uid))

     val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
     instr.logSuccess(m)
@@ -128,7 +128,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
     instr.logParams(params: _*)

     val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, instr = Some(instr), parentUID = Some(uid))
+      seed = 0L, instr = OptionalInstrumentation.create(instr), parentUID = Some(uid))

     val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
     instr.logSuccess(m)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f1ef26a07d3f8..dd6c58c7b45d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -136,7 +136,8 @@ class RandomForestClassifier @Since("1.4.0") (
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed,
+        instr = OptionalInstrumentation.create(instr))
       .map(_.asInstanceOf[DecisionTreeClassificationModel])

     val numFeatures = oldDataset.first().features.size
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 8bcf0793a64c1..76ca0048be06b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -109,7 +109,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
     instr.logParams(params: _*)

     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+      seed = $(seed), instr = OptionalInstrumentation.create(instr), parentUID = Some(uid))

     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
     instr.logSuccess(m)
@@ -125,7 +125,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
     instr.logParams(params: _*)

     val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+      seed = $(seed), instr = OptionalInstrumentation.create(instr), parentUID = Some(uid))

     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
     instr.logSuccess(m)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 4509f85aafd12..fd93307cb7c0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -127,7 +127,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)

     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed,
+        instr = OptionalInstrumentation.create(instr))
       .map(_.asInstanceOf[DecisionTreeRegressionModel])

     val numFeatures = oldDataset.first().features.size
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index bd8c9afb5e209..a41084081f168 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -17,10 +17,10 @@

 package org.apache.spark.ml.tree.impl

-import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util.OptionalInstrumentation
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
 import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
@@ -30,7 +30,7 @@ import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
 import org.apache.spark.storage.StorageLevel


-private[spark] object GradientBoostedTrees extends Logging {
+private[spark] object GradientBoostedTrees {

   /**
    * Method to train a gradient boosting model
@@ -250,7 +250,10 @@ private[spark] object GradientBoostedTrees extends Logging {
       boostingStrategy: OldBoostingStrategy,
       validate: Boolean,
       seed: Long,
-      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      featureSubsetStrategy: String,
+      instr: OptionalInstrumentation = OptionalInstrumentation
+        .create(GradientBoostedTrees.getClass)
+    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -287,9 +290,9 @@

     timer.stop("init")

-    logDebug("##########")
-    logDebug("Building tree 0")
-    logDebug("##########")
+    instr.logDebug("##########")
+    instr.logDebug("Building tree 0")
+    instr.logDebug("##########")

     // Initialize tree
     timer.start("building tree 0")
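The central abstraction in this patch is OptionalInstrumentation: helpers such as
GradientBoostedTrees and RandomForest no longer extend Logging themselves, but log
through a wrapper that prefers an active training session's Instrumentation and
falls back to a plain per-class logger. A minimal, dependency-free sketch of the
pattern (SessionInstr and OptionalInstr are illustrative stand-ins, not Spark's
classes):

    // Sketch only: mirrors the create(instr)/create(clazz) shape used in this
    // patch, with println standing in for the real Log4j-backed loggers.
    final case class SessionInstr(prefix: String) {
      def logDebug(msg: => String): Unit = println(prefix + msg)
    }

    final class OptionalInstr(
        val instrumentation: Option[SessionInstr],
        val className: String) {
      def logDebug(msg: => String): Unit = instrumentation match {
        case Some(instr) => instr.logDebug(msg)      // session-scoped logging
        case None => println(s"[$className] " + msg) // per-class fallback
      }
    }

    object OptionalInstr {
      def create(instr: SessionInstr): OptionalInstr =
        new OptionalInstr(Some(instr), instr.getClass.getName)
      def create(clazz: Class[_]): OptionalInstr =
        new OptionalInstr(None, clazz.getName.stripSuffix("$"))
    }

Either way the call shape stays the same, which is what lets the hunks in this
patch swap logDebug(...) for instr.logDebug(...) mechanically.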
@@ -302,7 +305,7 @@ private[spark] object GradientBoostedTrees {
     var predError: RDD[(Double, Double)] =
       computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
     predErrorCheckpointer.update(predError)
-    logDebug("error of gbt = " + predError.values.mean())
+    instr.logDebug("error of gbt = " + predError.values.mean())

     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
@@ -322,9 +325,9 @@
       }

       timer.start(s"building tree $m")
-      logDebug("###################################################")
-      logDebug("Gradient boosting tree iteration " + m)
-      logDebug("###################################################")
+      instr.logDebug("###################################################")
+      instr.logDebug("Gradient boosting tree iteration " + m)
+      instr.logDebug("###################################################")

       val dt = new DecisionTreeRegressor().setSeed(seed + m)
       val model = dt.train(data, treeStrategy, featureSubsetStrategy)
@@ -339,7 +342,7 @@
       predError = updatePredictionError(
         input, predError, baseLearnerWeights(m), baseLearners(m), loss)
       predErrorCheckpointer.update(predError)
-      logDebug("error of gbt = " + predError.values.mean())
+      instr.logDebug("error of gbt = " + predError.values.mean())

       if (validate) {
         // Stop training early if
@@ -364,8 +367,8 @@

     timer.stop("total")

-    logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
+    instr.logInfo("Internal timing for DecisionTree:")
+    instr.logInfo(s"$timer")

     predErrorCheckpointer.unpersistDataSet()
     predErrorCheckpointer.deleteAllCheckpoints()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 905870178e549..d20608ee27a23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -22,12 +22,11 @@ import java.io.IOException
 import scala.collection.mutable
 import scala.util.Random

-import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.ml.util.{Instrumentation, OptionalInstrumentation}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
@@ -77,7 +76,7 @@ import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
  * the heaviest part of the computation. In general, this implementation is bound by either
  * the cost of statistics computation on workers or by communicating the sufficient statistics.
  */
-private[spark] object RandomForest extends Logging {
+private[spark] object RandomForest {

   /**
    * Train a random forest.
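The hunk below changes run()'s instr parameter from Option[Instrumentation[_]] to a
defaulted OptionalInstrumentation, so callers that previously had to spell out
instr = None (the old mllib entry point and several tests later in this series)
can now omit the argument entirely. A standalone illustration of that API shape,
with String standing in for the wrapper type:

    // Before: every caller must mention the Option, even when it has nothing to pass.
    def runOld(seed: Long, instr: Option[String]): Unit = ()
    // After: a default value supplies the per-class fallback automatically.
    def runNew(seed: Long, instr: String = "object-level logger"): Unit = ()

    runOld(42L, None)                     // mandatory noise at no-op call sites
    runNew(42L)                           // default applies
    runNew(42L, instr = "session logger") // instrumented call sites opt in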
@@ -91,7 +90,7 @@ private[spark] object RandomForest {
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
-      instr: Option[Instrumentation[_]],
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass),
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {

@@ -104,24 +103,24 @@ private[spark] object RandomForest {
     val retaggedInput = input.retag(classOf[LabeledPoint])
     val metadata =
       DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
-    instr match {
+    instr.instrumentation match {
       case Some(instrumentation) =>
         instrumentation.logNumFeatures(metadata.numFeatures)
         instrumentation.logNumClasses(metadata.numClasses)
         instrumentation.logNumExamples(metadata.numExamples)
       case None =>
-        logInfo("numFeatures: " + metadata.numFeatures)
-        logInfo("numClasses: " + metadata.numClasses)
-        logInfo("numExamples: " + metadata.numExamples)
+        instr.logInfo("numFeatures: " + metadata.numFeatures)
+        instr.logInfo("numClasses: " + metadata.numClasses)
+        instr.logInfo("numExamples: " + metadata.numExamples)
     }

     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplits")
-    val splits = findSplits(retaggedInput, metadata, seed)
+    val splits = findSplits(retaggedInput, metadata, seed, instr = instr)
     timer.stop("findSplits")
-    logDebug("numBins: feature: number of bins")
-    logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+    instr.logDebug("numBins: feature: number of bins")
+    instr.logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
       s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
     }.mkString("\n"))

@@ -143,7 +142,7 @@
     // Max memory usage for aggregates
     // TODO: Calculate memory usage more precisely.
     val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
-    logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+    instr.logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")

     /*
      * The main idea here is to perform group-wise training of the decision tree nodes thus
@@ -187,7 +186,7 @@
       // Collect some nodes to split, and choose features for each node (if subsampling).
       // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
-        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng, instr = instr)
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
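The selectNodesToSplit call above implements the group-wise training described in
the surrounding comments: each iteration packs as many nodes as fit within
strategy.maxMemoryInMB. A self-contained sketch of that greedy idea (illustrative
accounting only, not Spark's exact cost model):

    // Take nodes while their estimated aggregate sizes fit the byte budget,
    // but always admit at least one node -- matching the warning later in this
    // patch that maxMemoryUsage = 0 still allows splitting a single node.
    def packNodes(costsBytes: Seq[Long], budgetBytes: Long): Seq[Long] = {
      val running = costsBytes.scanLeft(0L)(_ + _).tail
      val fitting = costsBytes.zip(running).takeWhile(_._2 <= budgetBytes).map(_._1)
      if (fitting.isEmpty) costsBytes.take(1) else fitting
    }

    packNodes(Seq(40L, 50L, 30L), budgetBytes = 100L) // Seq(40, 50)
    packNodes(Seq(400L), budgetBytes = 0L)            // Seq(400): one-node minimum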
@@ -199,7 +198,7 @@
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
       RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
+        treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache, instr = instr)
       timer.stop("findBestSplits")
     }

@@ -207,8 +206,8 @@

     timer.stop("total")

-    logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
+    instr.logInfo("Internal timing for DecisionTree:")
+    instr.logInfo(s"$timer")

     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
@@ -216,7 +215,7 @@
         nodeIdCache.get.deleteAllCheckpoints()
       } catch {
         case e: IOException =>
-          logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
+          instr.logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
       }
     }

@@ -373,7 +372,9 @@
       splits: Array[Array[Split]],
       nodeStack: mutable.ArrayStack[(Int, LearningNode)],
       timer: TimeTracker = new TimeTracker,
-      nodeIdCache: Option[NodeIdCache] = None): Unit = {
+      nodeIdCache: Option[NodeIdCache] = None,
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass)
+    ): Unit = {

     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -399,13 +400,13 @@
     // numNodes: Number of nodes in this group
     val numNodes = nodesForGroup.values.map(_.length).sum

-    logDebug("numNodes = " + numNodes)
-    logDebug("numFeatures = " + metadata.numFeatures)
-    logDebug("numClasses = " + metadata.numClasses)
-    logDebug("isMulticlass = " + metadata.isMulticlass)
-    logDebug("isMulticlassWithCategoricalFeatures = " +
+    instr.logDebug("numNodes = " + numNodes)
+    instr.logDebug("numFeatures = " + metadata.numFeatures)
+    instr.logDebug("numClasses = " + metadata.numClasses)
+    instr.logDebug("isMulticlass = " + metadata.isMulticlass)
+    instr.logDebug("isMulticlassWithCategoricalFeatures = " +
       metadata.isMulticlassWithCategoricalFeatures)
-    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+    instr.logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)

     /**
      * Performs a sequential aggregation over a partition for a particular tree and node.
@@ -562,7 +563,7 @@

         // find best split for each node
         val (split: Split, stats: ImpurityStats) =
-          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
+          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex), instr = instr)
         (nodeIndex, (split, stats))
       }.collectAsMap()
@@ -582,14 +583,14 @@
           val aggNodeIndex = nodeInfo.nodeIndexInGroup
           val (split: Split, stats: ImpurityStats) =
             nodeToBestSplits(aggNodeIndex)
-          logDebug("best split = " + split)
+          instr.logDebug("best split = " + split)

           // Extract info for this node. Create children if not leaf.
           val isLeaf =
             (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
           node.isLeaf = isLeaf
           node.stats = stats
-          logDebug("Node = " + node)
+          instr.logDebug("Node = " + node)

           if (!isLeaf) {
             node.split = Some(split)
@@ -616,9 +617,9 @@
               nodeStack.push((treeIndex, node.rightChild.get))
             }

-            logDebug("leftChildIndex = " + node.leftChild.get.id +
+            instr.logDebug("leftChildIndex = " + node.leftChild.get.id +
               ", impurity = " + stats.leftImpurity)
-            logDebug("rightChildIndex = " + node.rightChild.get.id +
+            instr.logDebug("rightChildIndex = " + node.rightChild.get.id +
               ", impurity = " + stats.rightImpurity)
           }
         }
@@ -699,7 +700,9 @@
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
       featuresForNode: Option[Array[Int]],
-      node: LearningNode): (Split, ImpurityStats) = {
+      node: LearningNode,
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass)
+    ): (Split, ImpurityStats) = {

     // Calculate InformationGain and ImpurityStats if current node is top node
     val level = LearningNode.indexToLevel(node.id)
@@ -793,12 +796,13 @@
             (featureValue, centroid)
           }

-          logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+          instr.logDebug("Centroids for categorical variable: " +
+            centroidForCategories.mkString(","))

           // bins sorted by centroids
           val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

-          logDebug("Sorted centroids for categorical variable = " +
+          instr.logDebug("Sorted centroids for categorical variable = " +
             categoriesSortedByCentroid.mkString(","))

           // Cumulative sum (scanLeft) of bin statistics.
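The centroid bookkeeping logged above is what makes categorical splits tractable:
once the k category values are ordered by a per-category centroid statistic, only
the k - 1 contiguous prefixes of that order need to be evaluated as candidate
binary splits, rather than all subsets. A toy version with made-up numbers:

    // category -> centroid (e.g. mean label for regression / binary classification)
    val centroidForCategories = Map(0 -> 0.8, 1 -> 0.1, 2 -> 0.5)
    val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
    // List((1,0.1), (2,0.5), (0,0.8))
    val candidateSplits = (1 until categoriesSortedByCentroid.length).map { i =>
      categoriesSortedByCentroid.take(i).map(_._1).toSet // left-side category set
    }
    // Vector(Set(1), Set(1, 2)): 2 candidates for 3 categories instead of 3 subsets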
@@ -885,9 +889,11 @@
   protected[tree] def findSplits(
       input: RDD[LabeledPoint],
       metadata: DecisionTreeMetadata,
-      seed: Long): Array[Array[Split]] = {
+      seed: Long,
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass)
+    ): Array[Array[Split]] = {

-    logDebug("isMulticlass = " + metadata.isMulticlass)
+    instr.logDebug("isMulticlass = " + metadata.isMulticlass)

     val numFeatures = metadata.numFeatures

@@ -895,19 +901,21 @@
     val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
     val sampledInput = if (continuousFeatures.nonEmpty) {
       val fraction = samplesFractionForFindSplits(metadata)
-      logDebug("fraction of data used for calculating quantiles = " + fraction)
+      instr.logDebug("fraction of data used for calculating quantiles = " + fraction)
       input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
     } else {
       input.sparkContext.emptyRDD[LabeledPoint]
     }

-    findSplitsBySorting(sampledInput, metadata, continuousFeatures)
+    findSplitsBySorting(sampledInput, metadata, continuousFeatures, instr = instr)
   }

   private def findSplitsBySorting(
       input: RDD[LabeledPoint],
       metadata: DecisionTreeMetadata,
-      continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+      continuousFeatures: IndexedSeq[Int],
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass)
+    ): Array[Array[Split]] = {

     val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
       // reduce the parallelism for split computations when there are less
@@ -920,9 +928,9 @@
           continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0)
         }.groupByKey(numPartitions)
         .map { case (idx, samples) =>
-          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
+          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx, instr = instr)
           val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
-          logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+          instr.logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
           (idx, splits)
         }.collectAsMap()
     }
@@ -992,7 +1000,9 @@
   private[tree] def findSplitsForContinuousFeature(
       featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
-      featureIndex: Int): Array[Double] = {
+      featureIndex: Int,
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass)
+    ): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
       "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

@@ -1032,7 +1042,7 @@
       } else {
         // stride between splits
         val stride: Double = numSamples.toDouble / (numSplits + 1)
-        logDebug("stride = " + stride)
+        instr.logDebug("stride = " + stride)

         // iterate `valueCount` to find splits
         val splitsBuilder = mutable.ArrayBuilder.make[Double]
@@ -1090,7 +1100,9 @@
       nodeStack: mutable.ArrayStack[(Int, LearningNode)],
       maxMemoryUsage: Long,
       metadata: DecisionTreeMetadata,
-      rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+      rng: Random,
+      instr: OptionalInstrumentation = OptionalInstrumentation.create(RandomForest.getClass)
+    ): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = {
     // Collect some nodes to split:
     //  nodesForGroup(treeIndex) = nodes to split
     val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]()
@@ -1127,8 +1139,8 @@
       }
       if (memUsage > maxMemoryUsage) {
         // If maxMemoryUsage is 0, we should still allow splitting 1 node.
-        logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" +
-          s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
+        instr.logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, " +
+          s"which exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" +
           s" $numNodesInGroup nodes in this iteration.")
       }
       // Convert mutable maps to immutable ones.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 3247c394dfa64..32c712d2df09c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -42,8 +42,8 @@ import org.apache.spark.sql.Dataset
  * @tparam E the type of the estimator
  */
 private[spark] class Instrumentation[E <: Estimator[_]] private (
-    val estimator: E,
-    val dataset: RDD[_]) extends Logging {
+    @transient val estimator: E,
+    @transient val dataset: RDD[_]) extends Logging with Serializable {

   private val id = UUID.randomUUID()
   private val prefix = {
@@ -79,6 +79,13 @@
     super.logInfo(prefix + msg)
   }

+  /**
+   * Logs a debug message with a prefix that uniquely identifies the training session.
+   */
+  override def logDebug(msg: => String): Unit = {
+    super.logDebug(prefix + msg)
+  }
+
   /**
    * Alias for logInfo, see above.
    */
@@ -171,10 +178,17 @@
  */
 private[spark] class OptionalInstrumentation private(
     val instrumentation: Option[Instrumentation[_ <: Estimator[_]]],
-    val className: String) extends Logging {
+    val className: String) extends Logging with Serializable {

   protected override def logName: String = className

+  override def logDebug(msg: => String) {
+    instrumentation match {
+      case Some(instr) => instr.logDebug(msg)
+      case None => super.logDebug(msg)
+    }
+  }
+
   override def logInfo(msg: => String) {
     instrumentation match {
       case Some(instr) => instr.logInfo(msg)
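A small detail worth noting in the overrides above: msg is declared as => String,
a by-name parameter, so the log message is only constructed if the call actually
logs. A standalone illustration:

    // By-name parameter: `msg` is evaluated at use, not at the call site.
    def logDebugIf(debugEnabled: Boolean)(msg: => String): Unit =
      if (debugEnabled) println(msg)

    logDebugIf(debugEnabled = false) {
      (0 until 1000000).map(_.toString).mkString(",") // never evaluated here
    }

This is why threading instr through the hot tree-building loops costs essentially
nothing when debug logging is disabled.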
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index a8c5286f3dc10..09970aa918d73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -92,7 +92,7 @@ private class RandomForest (
    */
   def run(input: RDD[LabeledPoint]): RandomForestModel = {
     val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
-      featureSubsetStrategy, seed.toLong, None)
+      featureSubsetStrategy, seed.toLong)
     new RandomForestModel(strategy.algo, trees.map(_.toOld))
   }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 4dbbd75d2466d..16d0cb8355ece 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -203,7 +203,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     withClue("DecisionTree requires number of features > 0," +
       " but was given an empty features vector") {
       intercept[IllegalArgumentException] {
-        RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+        RandomForest.run(rdd, strategy, 1, "all", 42L)
       }
     }
   }
@@ -219,7 +219,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 2, maxBins = 5,
       categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
-    val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+    val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L)
     assert(tree.rootNode.impurity === -1.0)
     assert(tree.depth === 0)
     assert(tree.rootNode.prediction === lp.label)
@@ -230,7 +230,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       Variance, maxDepth = 2,
       maxBins = 5)
-    val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
+    val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L)
     assert(tree2.rootNode.impurity === -1.0)
     assert(tree2.depth === 0)
     assert(tree2.rootNode.prediction === lp.label)
@@ -411,7 +411,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)

     val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None, prune = false).head
+      seed = 42, prune = false).head

     model.rootNode match {
       case n: InternalNode => n.split match {
@@ -435,9 +435,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0)

     val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
+      seed = 42).head
     val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 42, instr = None).head
+      seed = 42).head

     def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
       case n: InternalNode =>
@@ -661,10 +661,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = numClasses, maxBins = 32)

     val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
+      seed = 42).head

     val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+      seed = 42, prune = false).head

     assert(prunedTree.numNodes === 5)
     assert(unprunedTree.numNodes === 7)
@@ -691,10 +691,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       numClasses = 0, maxBins = 32)

     val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None).head
+      seed = 42).head

     val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
-      seed = 42, instr = None, prune = false).head
+      seed = 42, prune = false).head

     assert(prunedTree.numNodes === 3)
     assert(unprunedTree.numNodes === 5)

From e30377083fe635dd657e1a0547991ced862fe0c0 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 3 May 2018 18:14:54 +0800
Subject: [PATCH 2/4] GBT passing instr

---
 .../ml/classification/GBTClassifier.scala    |  2 +-
 .../spark/ml/regression/GBTRegressor.scala   |  2 +-
 .../ml/tree/impl/GradientBoostedTrees.scala  | 18 ++++++++++++------
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 0aa24f0a3cfcc..0001263c50421 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -177,7 +177,7 @@ class GBTClassifier @Since("1.4.0") (
     instr.logNumClasses(numClasses)

     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
-      $(seed), $(featureSubsetStrategy))
+      $(seed), $(featureSubsetStrategy), instr = OptionalInstrumentation.create(instr))
     val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
     instr.logSuccess(m)
     m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8598e808c4946..2e559137a4a71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -159,7 +159,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     instr.logNumFeatures(numFeatures)

     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
-      $(seed), $(featureSubsetStrategy))
+      $(seed), $(featureSubsetStrategy), instr = OptionalInstrumentation.create(instr))
     val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
     instr.logSuccess(m)
     m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index a41084081f168..6fa419a6d9024 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -43,17 +46,20 @@ private[spark] object GradientBoostedTrees {
       input: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
       seed: Long,
-      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      featureSubsetStrategy: String,
+      instr: OptionalInstrumentation = OptionalInstrumentation
+        .create(GradientBoostedTrees.getClass)
+    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
         GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false,
-          seed, featureSubsetStrategy)
+          seed, featureSubsetStrategy, instr = instr)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
-          seed, featureSubsetStrategy)
+          seed, featureSubsetStrategy, instr = instr)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
     }
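The label remapping in the Classification branch above is the standard trick that
lets gradient boosting reuse its regression machinery: binary labels in {0, 1}
become signed labels in {-1, +1}, so margin-based losses apply directly. In
isolation:

    val labels = Seq(0.0, 1.0, 1.0, 0.0)
    val remapped = labels.map(y => (y * 2) - 1) // Seq(-1.0, 1.0, 1.0, -1.0)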
@@ -76,12 +79,15 @@
       validationInput: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
       seed: Long,
-      featureSubsetStrategy: String): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      featureSubsetStrategy: String,
+      instr: OptionalInstrumentation = OptionalInstrumentation
+        .create(GradientBoostedTrees.getClass)
+    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
         GradientBoostedTrees.boost(input, validationInput, boostingStrategy,
-          validate = true, seed, featureSubsetStrategy)
+          validate = true, seed, featureSubsetStrategy, instr = instr)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(
@@ -89,7 +95,7 @@
           x => new LabeledPoint((x.label * 2) - 1, x.features))
         val remappedValidationInput = validationInput.map(
           x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
-          validate = true, seed, featureSubsetStrategy)
+          validate = true, seed, featureSubsetStrategy, instr = instr)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }

From 3ccf94e412a061277b58a56fa9b6cbdd6bdf08ba Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 3 May 2018 21:44:10 +0800
Subject: [PATCH 3/4] cleanup import

---
 .../main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index d20608ee27a23..edb12c63f8d54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.ml.util.{Instrumentation, OptionalInstrumentation}
+import org.apache.spark.ml.util.OptionalInstrumentation
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats

From ca587f37312d9ae838f677c68881400577a85a06 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Thu, 31 May 2018 23:09:47 +0800
Subject: [PATCH 4/4] fix build error

---
 .../scala/org/apache/spark/ml/util/Instrumentation.scala | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 2ce1864953f8b..acc10c360b639 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -86,13 +86,6 @@
     super.logInfo(prefix + msg)
   }

-  /**
-   * Logs a debug message with a prefix that uniquely identifies the training session.
-   */
-  override def logDebug(msg: => String): Unit = {
-    super.logDebug(prefix + msg)
-  }
-
   /**
    * Alias for logInfo, see above.
    */
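Finally, the reason patch 1 marks estimator and dataset as @transient and mixes
Serializable into Instrumentation: the wrapper gets captured in RDD closures
(binsToBestSplit above runs inside a distributed map), so it must serialize, but
the estimator and dataset themselves should never travel to executors. A
standalone sketch of those mechanics (illustrative class, not Spark's):

    // @transient fields are skipped by Java serialization: only the small
    // prefix string ships with a serialized closure, while the heavyweight
    // payload stays on the driver (and reads back as null on executors).
    class SessionLogger(@transient val payload: AnyRef) extends Serializable {
      val prefix: String = s"session-${java.util.UUID.randomUUID().toString.take(8)}: "
      def log(msg: String): Unit = println(prefix + msg)
    }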