From f4b38604fda97970f64c5d2f375e372c1f7945e3 Mon Sep 17 00:00:00 2001 From: Luvsandondov Lkhamsuren Date: Fri, 4 Sep 2015 18:29:10 -0700 Subject: [PATCH 1/2] [SPARK-9963] [ML] ML RandomForest cleanup: replace predictNodeIndex with predictImpl --- .../scala/org/apache/spark/ml/tree/Node.scala | 32 +++++++++++++ .../spark/ml/tree/impl/RandomForest.scala | 46 +------------------ 2 files changed, 34 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index cd24931293903..673465bcc1ef3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -279,6 +279,38 @@ private[tree] class LearningNode( } } + /** + * Get the node corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node is returned. + * + * @param binnedFeatures Binned feature vector for data point. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + */ + def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): LearningNode = { + if (this.isLeaf || this.split.isEmpty) { + this + } else { + val split = this.split.get + val featureIndex = split.featureIndex + val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) + if (this.leftChild.isEmpty) { + // Not yet split. 
Return next layer of nodes to train + if (splitLeft) { + leftChild.get + } else { + rightChild.get + } + } else { + if (splitLeft) { + this.leftChild.get.predictImpl(binnedFeatures, splits) + } else { + this.rightChild.get.predictImpl(binnedFeatures, splits) + } + } + } + } + } private[tree] object LearningNode { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index c494556085e95..9d51f7f930a1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -205,47 +205,6 @@ private[ml] object RandomForest extends Logging { } } - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a leaf - * or unsplit node; that node's index is returned. - * - * @param node Node in tree from which to classify the given data point. - * @param binnedFeatures Binned feature vector for data point. - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - * Note: This is the global node index, i.e., the index used in the tree. - * This index is different from the index used during training a particular - * group of nodes on one call to [[findBestSplits()]]. - */ - private def predictNodeIndex( - node: LearningNode, - binnedFeatures: Array[Int], - splits: Array[Array[Split]]): Int = { - if (node.isLeaf || node.split.isEmpty) { - node.id - } else { - val split = node.split.get - val featureIndex = split.featureIndex - val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) - if (node.leftChild.isEmpty) { - // Not yet split. 
Return index from next layer of nodes to train - if (splitLeft) { - LearningNode.leftChildIndex(node.id) - } else { - LearningNode.rightChildIndex(node.id) - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftChild.get, binnedFeatures, splits) - } else { - predictNodeIndex(node.rightChild.get, binnedFeatures, splits) - } - } - } - } - /** * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. * @@ -453,9 +412,8 @@ private[ml] object RandomForest extends Logging { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = - predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + val node = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(node.id, null), agg, baggedPoint) } agg } From c3bb2f3c849f72ed0871549fbd9924677167228a Mon Sep 17 00:00:00 2001 From: Luvsandondov Lkhamsuren Date: Thu, 8 Oct 2015 09:20:51 -0700 Subject: [PATCH 2/2] Change Learning Node predictImpl to return Node's index. --- .../scala/org/apache/spark/ml/tree/Node.scala | 17 +++++++++++------ .../spark/ml/tree/impl/RandomForest.scala | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 673465bcc1ef3..d89682611e3f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -280,16 +280,21 @@ private[tree] class LearningNode( } /** - * Get the node corresponding to this data point. + * Get the node index corresponding to this data point. 
* This function mimics prediction, passing an example from the root node down to a leaf - * or unsplit node; that node is returned. + * or unsplit node; that node's index is returned. * * @param binnedFeatures Binned feature vector for data point. * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular group of + * nodes on one call to [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]]. */ - def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): LearningNode = { + def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = { if (this.isLeaf || this.split.isEmpty) { - this + this.id } else { val split = this.split.get val featureIndex = split.featureIndex @@ -297,9 +302,9 @@ private[tree] class LearningNode( if (this.leftChild.isEmpty) { // Not yet split. 
Return next layer of nodes to train if (splitLeft) { - leftChild.get + LearningNode.leftChildIndex(this.id) } else { - rightChild.get + LearningNode.rightChildIndex(this.id) } } else { if (splitLeft) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 9d51f7f930a1d..96d5652857e08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -412,8 +412,8 @@ private[ml] object RandomForest extends Logging { agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val node = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(node.id, null), agg, baggedPoint) + val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } agg }