Skip to content

Commit

Permalink
added unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed Apr 30, 2014
1 parent 1517155 commit 718506b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object DecisionTreeRunner {
algo: Algo = Classification,
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 20)
maxBins: Int = 100)

def main(args: Array[String]) {
val defaultParams = Params()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Split
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
Expand Down Expand Up @@ -390,6 +391,53 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
}

// Verifies that the best splits found for a level of the tree are identical whether
// findBestSplits processes that level as a single group or is forced to split it into groups.
test("test second level node building with/without groups") {
  val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
  assert(arr.length === 1000)
  val rdd = sc.parallelize(arr)
  val strategy = new Strategy(Classification, Entropy, 3, 100)
  val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
  assert(splits.length === 2)
  assert(splits(0).length === 99)
  assert(bins.length === 2)
  assert(bins(0).length === 100)

  // One filter list per entry of parentImpurities: no filter, then the -1 and +1 sides of a
  // split on feature 0 at threshold 400.
  // NOTE(review): presumably these correspond to the root, left child and right child nodes.
  val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
  val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), 1)
  val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
  val parentImpurities = Array(0.5, 0.5, 0.5)

  // Single group second level tree construction.
  val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
    splits, bins, 10)
  assert(bestSplits.length === 2)
  assert(bestSplits(0)._2.gain > 0)
  assert(bestSplits(1)._2.gain > 0)

  // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
  // level tree construction.
  val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
    filters, splits, bins, 0)
  assert(bestSplitsWithGroups.length === 2)
  assert(bestSplitsWithGroups(0)._2.gain > 0)
  assert(bestSplitsWithGroups(1)._2.gain > 0)

  // Verify whether the splits obtained using single group and multiple group level
  // construction strategies are the same.
  for (i <- bestSplits.indices) {
    val single = bestSplits(i)
    val grouped = bestSplitsWithGroups(i)
    assert(single._1 === grouped._1)
    assert(single._2.gain === grouped._2.gain)
    assert(single._2.impurity === grouped._2.impurity)
    assert(single._2.leftImpurity === grouped._2.leftImpurity)
    assert(single._2.rightImpurity === grouped._2.rightImpurity)
    assert(single._2.predict === grouped._2.predict)
  }
}

}

object DecisionTreeSuite {
Expand All @@ -412,6 +460,20 @@ object DecisionTreeSuite {
arr
}

/**
 * Builds 1000 labeled points ordered by index i in [0, 1000).
 *
 * Each point has the two features (i, 1000 - i). Points with i < 600 are labeled 0.0 and the
 * remaining points are labeled 1.0.
 */
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
  Array.tabulate(1000) { idx =>
    // The label switches from 0.0 to 1.0 at index 600.
    val label = if (idx < 600) 0.0 else 1.0
    new LabeledPoint(label, Vectors.dense(idx.toDouble, 1000.0 - idx))
  }
}

def generateCategoricalDataPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
Expand Down

0 comments on commit 718506b

Please sign in to comment.