Skip to content

Commit

Permalink
fix bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
qiping.lqp committed Sep 9, 2014
1 parent e72c7e4 commit 46b891f
Showing 1 changed file with 11 additions and 8 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -836,10 +836,11 @@ object DecisionTree extends Serializable with Logging {
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
- if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
+ if (bestFeatureGainStats.gain < metadata.minInfoGain) {
  (Split.noSplit, InformationGainStats.invalidInformationGainStats)
+ } else {
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
  }
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
Expand All @@ -855,8 +856,9 @@ object DecisionTree extends Serializable with Logging {
}.maxBy(_._2.gain)
  if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
  (Split.noSplit, InformationGainStats.invalidInformationGainStats)
+ } else {
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
  }
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
Expand Down Expand Up @@ -930,12 +932,13 @@ object DecisionTree extends Serializable with Logging {
}.maxBy(_._2.gain)
  if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
  (Split.noSplit, InformationGainStats.invalidInformationGainStats)
+ } else {
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+ (bestFeatureSplit, bestFeatureGainStats)
  }
- val categoriesForSplit =
- categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
- val bestFeatureSplit =
- new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
- (bestFeatureSplit, bestFeatureGainStats)
}
}.maxBy(_._2.gain)

Expand Down

0 comments on commit 46b891f

Please sign in to comment.