Skip to content

Commit

Permalink
change edge minInstancesPerNode to 2 and add one more test
Browse files Browse the repository at this point in the history
  • Loading branch information
chouqin committed Sep 10, 2014
1 parent 0278a11 commit 39f9b60
Showing 1 changed file with 30 additions and 4 deletions.
Expand Up @@ -683,8 +683,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))

val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
numClassesForClassification = 2, minInstancesPerNode = 4)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)

val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
Expand All @@ -701,11 +701,37 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)

assert(bestSplits.length === 1)
assert(bestSplits.length == 1)
val bestInfoStats = bestSplits(0)._2
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
}

test("don't chose split that doesn't satify min instance per node requirements") {
// if a split doesn't satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))

val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
numClassesForClassification = 2, minInstancesPerNode = 2)
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)

assert(bestSplits.length == 1)
val bestSplit = bestSplits(0)._1
val bestSplitStats = bestSplits(0)._1
assert(bestSplit.feature == 1)
assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
}

test("split must satisfy min info gain requirements") {
val arr = new Array[LabeledPoint](3)
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
Expand All @@ -731,7 +757,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)

assert(bestSplits.length === 1)
assert(bestSplits.length == 1)
val bestInfoStats = bestSplits(0)._2
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
}
Expand Down

0 comments on commit 39f9b60

Please sign in to comment.