Skip to content

Commit

Permalink
For comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jan 21, 2016
1 parent 2c32350 commit cd25214
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,11 @@ private[ml] object RandomForest extends Logging {
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.prob(categoryStats.predict)
if (categoryStats.count == 2) {
categoryStats.stats(1)
} else {
categoryStats.predict
}
} else {
Double.MaxValue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,11 @@ object DecisionTree extends Serializable with Logging {
val categoryStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
categoryStats.prob(categoryStats.predict)
if (categoryStats.count == 2) {
categoryStats.stats(1)
} else {
categoryStats.predict
}
} else {
Double.MaxValue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rootNode = DecisionTree.train(rdd, strategy).topNode

val split = rootNode.split.get
assert(split.categories === List(0.0))
assert(split.categories === List(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)

Expand Down Expand Up @@ -520,7 +520,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {

val split = rootNode.split.get
assert(split.categories.length === 1)
assert(split.categories.contains(0.0))
assert(split.categories.contains(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)

Expand Down

0 comments on commit cd25214

Please sign in to comment.