Skip to content

Commit

Permalink
cosmetics
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jul 1, 2016
1 parent 7ee0837 commit 04bdb3a
Showing 1 changed file with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,17 @@ class DecisionTreeRegressorSuite
dt.setMaxDepth(1)
.setMaxBins(6)
.setSeed(0)
val expectVariances = dt.fit(toyDF).transform(toyDF).select("variance").collect().map {
val calculatedVariances = dt.fit(toyDF).transform(toyDF).select("variance").collect().map {
case Row(variance: Double) => variance }
val trueVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
trueVariances.zip(expectVariances).foreach(x => x._1 ~== x._2 absTol 1e-3)

// Since max depth is set to 1, the best split point is that which splits the data
// into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each
// data point in the left node is 0.667 and for each data point in the right node
// is 2.667
val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
assert(actual ~== expected absTol 1e-3)
}
}

test("Feature importance with toy data") {
Expand Down

0 comments on commit 04bdb3a

Please sign in to comment.