decisiontrees/regression_splitter.go
package decisiontrees

import (
	"code.google.com/p/goprotobuf/proto"
	pb "github.com/ajtulloch/decisiontrees/protobufs"
	"github.com/golang/glog"
	"sync"
)

// lossState tracks the running mean of the weighted labels and the sum of
// squared deviations from that mean over a set of examples.
type lossState struct {
	averageLabel         float64
	sumSquaredDivergence float64
	numExamples          int
}

// constructLoss builds a lossState by adding every example in e.
func constructLoss(e Examples) *lossState {
	l := &lossState{}
	for _, ex := range e {
		l.addExample(ex)
	}
	return l
}

// addExample incrementally updates the running mean and the sum of squared
// deviations to include e.
func (l *lossState) addExample(e *pb.Example) {
	l.numExamples++
	delta := e.GetWeightedLabel() - l.averageLabel
	l.averageLabel += delta / float64(l.numExamples)
	newDelta := e.GetWeightedLabel() - l.averageLabel
	l.sumSquaredDivergence += delta * newDelta
}

// removeExample reverses addExample, excluding e from the running statistics.
func (l *lossState) removeExample(e *pb.Example) {
	l.numExamples--
	delta := e.GetWeightedLabel() - l.averageLabel
	l.averageLabel -= delta / float64(l.numExamples)
	newDelta := e.GetWeightedLabel() - l.averageLabel
	l.sumSquaredDivergence -= delta * newDelta
}
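
// The two updates above follow Welford's online algorithm for the running
// mean and the sum of squared deviations from the mean. As a worked trace,
// adding weighted labels 1, 2, 3 in order gives:
//   add 1: numExamples=1, delta=1.0, averageLabel=1.0, newDelta=0.0, sumSquaredDivergence=0.0
//   add 2: numExamples=2, delta=1.0, averageLabel=1.5, newDelta=0.5, sumSquaredDivergence=0.5
//   add 3: numExamples=3, delta=1.5, averageLabel=2.0, newDelta=1.0, sumSquaredDivergence=2.0
// which matches the batch value (1-2)^2 + (2-2)^2 + (3-2)^2 = 2.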

// regressionSplitter grows regression trees by greedily choosing the split
// with the largest reduction in squared error, subject to the configured
// splitting constraints and shrinkage.
type regressionSplitter struct {
	leafWeight           func(e Examples) float64
	featureSelector      FeatureSelector
	splittingConstraints *pb.SplittingConstraints
	shrinkageConfig      *pb.ShrinkageConfig
}

// shouldSplit decides whether to split the given examples at the current
// level, or to terminate and emit a leaf instead.
func (c *regressionSplitter) shouldSplit(
	examples Examples,
	bestSplit split,
	currentLevel int64) bool {
	// Too few examples to split further.
	if len(examples) <= 1 {
		glog.Infof("Num examples is %v, terminating", len(examples))
		return false
	}
	// The best split would leave one branch empty.
	if bestSplit.index == 0 || bestSplit.index == len(examples) {
		glog.Infof("Empty branch with bestSplit = %v, numExamples = %v, terminating", bestSplit, len(examples))
		return false
	}
	// Depth limit reached.
	maximumLevels := c.splittingConstraints.MaximumLevels
	if maximumLevels != nil && *maximumLevels < currentLevel {
		glog.Infof("Maximum levels is %v < current level %v, terminating", *maximumLevels, currentLevel)
		return false
	}
	// The split does not reduce the loss enough per example.
	minAverageGain := c.splittingConstraints.MinimumAverageGain
	if minAverageGain != nil && *minAverageGain > bestSplit.gain/float64(len(examples)) {
		return false
	}
	// Too few samples to justify another split.
	minSamplesAtLeaf := c.splittingConstraints.MinimumSamplesAtLeaf
	if minSamplesAtLeaf != nil && *minSamplesAtLeaf > int64(len(examples)) {
		return false
	}
	return true
}

// split describes a candidate split: the feature to split on, the index into
// the feature-sorted examples at which the right partition begins, and the
// resulting reduction in squared error.
type split struct {
	feature int
	index   int
	gain    float64
}

// getBestSplit sorts the examples by the given feature and scans the
// candidate split points, maintaining the left and right losses
// incrementally, and returns the split with the largest gain.
func getBestSplit(examples Examples, feature int) split {
	// Sort a copy so the caller's slice is left untouched.
	examplesCopy := make([]*pb.Example, len(examples))
	if copy(examplesCopy, examples) != len(examples) {
		glog.Fatal("Failed copying all examples for sorting")
	}
	by(func(e1, e2 *pb.Example) bool {
		return e1.Features[feature] < e2.Features[feature]
	}).Sort(Examples(examplesCopy))

	leftLoss := constructLoss(Examples{})
	rightLoss := constructLoss(examplesCopy)
	totalLoss := constructLoss(examplesCopy)
	bestSplit := split{
		feature: feature,
	}

	for index, example := range examplesCopy {
		func() {
			// index == 0 would leave the left partition empty.
			if index == 0 {
				return
			}
			// Only split between distinct feature values.
			previousValue := examplesCopy[index-1].Features[feature]
			currentValue := example.Features[feature]
			if previousValue == currentValue {
				return
			}
			gain := totalLoss.sumSquaredDivergence -
				leftLoss.sumSquaredDivergence -
				rightLoss.sumSquaredDivergence
			if gain > bestSplit.gain {
				bestSplit.gain = gain
				bestSplit.index = index
			}
		}()
		// Move the current example from the right partition to the left.
		leftLoss.addExample(example)
		rightLoss.removeExample(example)
	}
	return bestSplit
}
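
// The gain computed above is the reduction in total squared error from
// splitting at a given index:
//   gain = SSE(all examples) - SSE(left partition) - SSE(right partition),
// where SSE is the sumSquaredDivergence maintained by lossState. For
// instance, with weighted labels [1, 1, 5, 5] sorted by the candidate
// feature, the total SSE is 16 (mean 3, each deviation of magnitude 2), and
// splitting between the second and third examples gives SSE 0 on both sides,
// for a gain of 16.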

// generateTree recursively grows a tree: it evaluates the candidate features
// in parallel, recurs into the two branches if the best split passes
// shouldSplit, and otherwise emits a (possibly shrunk) leaf.
func (c *regressionSplitter) generateTree(examples Examples, currentLevel int64) *pb.TreeNode {
	glog.Infof("Generating tree at level %v with %v examples", currentLevel, len(examples))
	glog.V(2).Infof("Generating at level %v with examples %+v", currentLevel, examples)
	features := c.featureSelector.getFeatures(examples)

	// Evaluate each candidate feature concurrently.
	candidateSplits := make(chan split, len(features))
	for _, feature := range features {
		go func(feature int) {
			candidateSplits <- getBestSplit(examples, feature)
		}(feature)
	}

	// Keep the split with the largest gain across all features.
	bestSplit := split{}
	for range features {
		candidateSplit := <-candidateSplits
		if candidateSplit.gain > bestSplit.gain {
			bestSplit = candidateSplit
		}
	}

	if c.shouldSplit(examples, bestSplit, currentLevel) {
		glog.Infof("Splitting at level %v with split %v", currentLevel, bestSplit)
		// Re-sort the examples by the chosen feature so the split index
		// partitions them, and split halfway between the adjacent values.
		by(func(e1, e2 *pb.Example) bool {
			return e1.Features[bestSplit.feature] < e2.Features[bestSplit.feature]
		}).Sort(examples)
		bestValue := 0.5 * (examples[bestSplit.index-1].Features[bestSplit.feature] +
			examples[bestSplit.index].Features[bestSplit.feature])
		tree := &pb.TreeNode{
			Feature:    proto.Int64(int64(bestSplit.feature)),
			SplitValue: proto.Float64(bestValue),
			Annotation: &pb.Annotation{
				NumExamples:  proto.Int64(int64(len(examples))),
				AverageGain:  proto.Float64(bestSplit.gain / float64(len(examples))),
				LeftFraction: proto.Float64(float64(bestSplit.index) / float64(len(examples))),
			},
		}

		// Recur down the left and right branches in parallel
		w := sync.WaitGroup{}
		recur := func(child **pb.TreeNode, e Examples) {
			w.Add(1)
			go func() {
				*child = c.generateTree(e, currentLevel+1)
				w.Done()
			}()
		}
		recur(&tree.Left, examples[bestSplit.index:])
		recur(&tree.Right, examples[:bestSplit.index])
		w.Wait()
		return tree
	}

	glog.Infof("Terminating at level %v with %v examples", currentLevel, len(examples))
	glog.V(2).Infof("Terminating with examples: %v", examples)
	// Otherwise, return the leaf, scaled by the configured shrinkage.
	leafWeight := c.leafWeight(examples)
	shrinkage := 1.0
	if c.shrinkageConfig != nil && c.shrinkageConfig.Shrinkage != nil {
		shrinkage = c.shrinkageConfig.GetShrinkage()
	}
	glog.Infof("Leaf weight: %v, shrinkage: %v", leafWeight, shrinkage)
	return &pb.TreeNode{
		LeafValue: proto.Float64(leafWeight * shrinkage),
		Annotation: &pb.Annotation{
			NumExamples: proto.Int64(int64(len(examples))),
		},
	}
}

// GenerateTree generates a regression tree on the examples given
func (c *regressionSplitter) GenerateTree(examples Examples) *pb.TreeNode {
	return c.generateTree(examples, 0)
}
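
// For context, a minimal sketch of how this splitter might be wired up. The
// buildExampleTree name, the mean-of-labels leaf weight, and the depth limit
// of 5 are illustrative assumptions only; a FeatureSelector implementation
// and the training examples are assumed to come from the surrounding package.
func buildExampleTree(fs FeatureSelector, examples Examples) *pb.TreeNode {
	splitter := &regressionSplitter{
		// Leaf value: mean of the weighted labels (an assumed choice for this sketch).
		leafWeight: func(e Examples) float64 {
			sum := 0.0
			for _, ex := range e {
				sum += ex.GetWeightedLabel()
			}
			return sum / float64(len(e))
		},
		featureSelector: fs,
		// shouldSplit dereferences splittingConstraints, so it must be non-nil.
		splittingConstraints: &pb.SplittingConstraints{
			MaximumLevels: proto.Int64(5),
		},
		// shrinkageConfig may stay nil; the leaf code falls back to shrinkage = 1.0.
	}
	return splitter.GenerateTree(examples)
}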