From e57ffaaad1666577d956c1f8f734f97569b93969 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 7 Mar 2018 18:37:22 +0800 Subject: [PATCH] init pr --- .../scala/org/apache/spark/ml/tree/Node.scala | 5 ++- .../spark/mllib/tree/impurity/Entropy.scala | 2 + .../spark/mllib/tree/impurity/Gini.scala | 2 + .../spark/mllib/tree/impurity/Impurity.scala | 37 +++++++++++++++++++ .../spark/mllib/tree/impurity/Variance.scala | 2 + .../DecisionTreeClassifierSuite.scala | 26 +++++++++++++ .../DecisionTreeRegressorSuite.scala | 14 +++++++ 7 files changed, 87 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d30be452a436e..29b5713e0a6fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.linalg.Vector -import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.impurity.{ImpurityCalculator, TreeStatInfo} import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** @@ -35,6 +35,9 @@ sealed abstract class Node extends Serializable { /** Impurity measure at this node (for training data) */ def impurity: Double + /** label/impurity stats at this node */ + def statInfo: TreeStatInfo = impurityStats.getStatInfo + /** * Statistics aggregated from training data at this node, used to compute prediction, impurity, * and probabilities. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index d4448da9eef51..e08a167ca97d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -160,6 +160,8 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal } } + override def getStatInfo: TreeStatInfo = new TreeClassifierStatInfo(stats) + override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c5e34ffa4f2e5..fe32577a92c71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -157,6 +157,8 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul } } + override def getStatInfo: TreeStatInfo = new TreeClassifierStatInfo(stats) + override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index f151a6a01b658..7c0138aa6a0be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -177,6 +177,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten result._1 } + def getStatInfo: TreeStatInfo } private[spark] object ImpurityCalculator { @@ -196,3 +197,39 @@ private[spark] object ImpurityCalculator { } } } + +@Since("2.4.0") +trait TreeStatInfo extends Serializable { + + @Since("2.4.0") + def asTreeClassifierStatInfo: 
TreeClassifierStatInfo = this.asInstanceOf[TreeClassifierStatInfo] + + @Since("2.4.0") + def asTreeRegressorStatInfo: TreeRegressorStatInfo = this.asInstanceOf[TreeRegressorStatInfo] +} + +@Since("2.4.0") +class TreeClassifierStatInfo(val stats: Array[Double]) extends TreeStatInfo { + + @Since("2.4.0") + def getLabelCount(label: Int): Double = { + require(label >= 0 && label < stats.length, + s"label must be between 0(inclusive) and ${stats.length}(exclusive).") + stats(label) + } +} + +@Since("2.4.0") +class TreeRegressorStatInfo(val stats: Array[Double]) extends TreeStatInfo { + + require(stats.length == 3, s"TreeRegressorStatInfo requires 3 stats (count, sum, sum of squares), but got ${stats.length}.") + + @Since("2.4.0") + def getCount(): Double = stats(0) + + @Since("2.4.0") + def getSum(): Double = stats(1) + + @Since("2.4.0") + def getSquareSum(): Double = stats(2) +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index c9bf0db4de3c2..a1c1b3a97c550 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -129,6 +129,8 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa stats(1) / count } + override def getStatInfo: TreeStatInfo = new TreeRegressorStatInfo(stats) + override def toString: String = { s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})" } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index eeb0324187c5b..40624f77f17f3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -360,6 +360,32 @@ class DecisionTreeClassifierSuite extends MLTest with
DefaultReadWriteTest { testDefaultReadWrite(model) } + + test("label/impurity stats") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) + val dt1 = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model1 = dt1.fit(df) + + val statInfo1 = model1.rootNode.statInfo.asTreeClassifierStatInfo + assert(Array(statInfo1.getLabelCount(0), statInfo1.getLabelCount(1)) === Array(2.0, 1.0)) + + val dt2 = new DecisionTreeClassifier() + .setImpurity("Entropy") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model2 = dt2.fit(df) + + val statInfo2 = model2.rootNode.statInfo.asTreeClassifierStatInfo + assert(Array(statInfo2.getLabelCount(0), statInfo2.getLabelCount(1)) === Array(2.0, 1.0)) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 68a1218c23ece..f7860f8696143 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -144,6 +144,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("label/impurity stats") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + val statInfo = model.rootNode.statInfo.asTreeRegressorStatInfo + + assert(statInfo.getCount() == 1000.0 && statInfo.getSum() == 600.0 +
&& statInfo.getSquareSum() == 600.0) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load /////////////////////////////////////////////////////////////////////////////