Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
[Machine Learning] Using Google Protocol Buffers for common records
Summary: Also signficant cleanups to the code Test Plan: ghc -Wall *.hs Reviewers: CC:
- Loading branch information
Showing
29 changed files
with
1,826 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,172 +1,200 @@ | ||
module MachineLearning.DecisionTrees where | ||
|
||
import MachineLearning.Common | ||
|
||
import Data.Vector as V | ||
import qualified Data.List | ||
import Data.Function | ||
module MachineLearning.DecisionTrees | ||
(LossFunction(..), | ||
Examples, | ||
trainBoosting, | ||
predictForest) where | ||
|
||
import Data.Function (on) | ||
import Data.List (and, sortBy) | ||
import Data.Maybe (fromJust, | ||
isJust) | ||
import qualified Data.Sequence as S | ||
import Data.Vector ((!)) | ||
import qualified Data.Vector as V | ||
|
||
-- Protocol Buffer records | ||
import qualified MachineLearning.Protobufs.Example as PB | ||
import qualified MachineLearning.Protobufs.SplittingConstraints as PB | ||
import qualified MachineLearning.Protobufs.TreeNode as PB | ||
import Text.ProtocolBuffers.Header (defaultValue) | ||
|
||
data DecisionTree = Leaf { | ||
value :: Double | ||
_value :: Double | ||
} | Branch { | ||
feature :: Int | ||
, value :: Double | ||
, left :: DecisionTree | ||
, right :: DecisionTree | ||
_feature :: Int | ||
, _value :: Double | ||
, _left :: DecisionTree | ||
, _right :: DecisionTree | ||
} | ||
|
||
data Split = Split { | ||
splitFeature :: Int | ||
, splitValue :: Double | ||
, averageGain :: Double | ||
_splitFeature :: Int | ||
, _splitValue :: Double | ||
, _averageGain :: Double | ||
} | ||
|
||
data LossState = LossState { | ||
averageLabel :: Double | ||
, sumSquaredDivergence :: Double | ||
, numExamples :: Int | ||
_averageLabel :: Double | ||
, _sumSquaredDivergence :: Double | ||
, _numExamples :: Int | ||
} | ||
|
||
informationGain :: Vector Example -> Vector Double | ||
type Examples = V.Vector PB.Example | ||
type Trees = V.Vector DecisionTree | ||
|
||
informationGain :: Examples -> V.Vector Double | ||
informationGain examples = | ||
V.zipWith gain (incrementalLoss examples) (incrementalLoss (V.reverse examples)) where | ||
totalLoss = V.last $ incrementalLoss examples | ||
gain l r = | ||
(sumSquaredDivergence totalLoss - | ||
sumSquaredDivergence l + | ||
sumSquaredDivergence r) / fromIntegral (V.length examples) | ||
gain l r = | ||
(_sumSquaredDivergence totalLoss - | ||
_sumSquaredDivergence l + | ||
_sumSquaredDivergence r) / fromIntegral (V.length examples) | ||
incrementalLoss = V.scanl addExample LossState{ | ||
averageLabel=0 | ||
, sumSquaredDivergence=0 | ||
, numExamples=0 | ||
_averageLabel=0 | ||
, _sumSquaredDivergence=0 | ||
, _numExamples=0 | ||
} | ||
|
||
addExample :: LossState -> Example -> LossState | ||
-- Convenience accessors for label and feature from Protobuf generated | ||
-- code to | ||
label' :: PB.Example -> Double | ||
label' e = fromJust $ PB.label e | ||
|
||
features' :: PB.Example -> V.Vector Double | ||
features' example = V.generate (S.length features) (S.index features) where | ||
features = PB.features example | ||
|
||
asPBTree' :: DecisionTree -> PB.TreeNode | ||
asPBTree' (Leaf value) = defaultValue { PB.leafValue = Just value } | ||
asPBTree' (Branch f v l r) = defaultValue { PB.feature = Just $ fromIntegral f | ||
, PB.splitValue = Just v | ||
, PB.left = Just $ asPBTree' l | ||
, PB.right = Just $ asPBTree' r | ||
} | ||
fromPBTree' :: PB.TreeNode -> DecisionTree | ||
fromPBTree' (PB.TreeNode feature splitValue left right leafValue _) | ||
| isJust leafValue = Leaf $ fromJust leafValue | ||
| otherwise = Branch { _feature=(fromIntegral . fromJust) feature | ||
, _value=fromJust splitValue | ||
, _left=(fromPBTree' . fromJust) left | ||
, _right=(fromPBTree' . fromJust) right | ||
} | ||
|
||
addExample :: LossState -> PB.Example -> LossState | ||
addExample state example = LossState { | ||
numExamples=numExamples state + 1 | ||
, averageLabel=newAverageLabel | ||
, sumSquaredDivergence=newSumSquaredDivergence | ||
_numExamples=_numExamples state + 1 | ||
, _averageLabel=newAverageLabel | ||
, _sumSquaredDivergence=newSumSquaredDivergence | ||
} where | ||
newAverageLabel = averageLabel state + delta / fromIntegral (numExamples state) | ||
delta = label example - averageLabel state | ||
newDelta = label example - newAverageLabel | ||
newSumSquaredDivergence = sumSquaredDivergence state + delta * newDelta | ||
|
||
removeExample :: LossState -> Example -> LossState | ||
removeExample state example = LossState { | ||
numExamples=numExamples state - 1 | ||
, averageLabel=newAverageLabel | ||
, sumSquaredDivergence=newSumSquaredDivergence | ||
} where | ||
newAverageLabel = averageLabel state - delta / fromIntegral (numExamples state) | ||
delta = label example - averageLabel state | ||
newDelta = label example - newAverageLabel | ||
newSumSquaredDivergence = sumSquaredDivergence state - delta * newDelta | ||
newAverageLabel = _averageLabel state + delta / fromIntegral (_numExamples state) | ||
delta = label' example - _averageLabel state | ||
newDelta = label' example - newAverageLabel | ||
newSumSquaredDivergence = _sumSquaredDivergence state + delta * newDelta | ||
|
||
sortFeature :: Vector Example -> Int -> Vector Example | ||
sortFeature :: Examples -> Int -> Examples | ||
sortFeature examples feature = | ||
V.fromList | ||
(Data.List.sortBy | ||
(\l r -> compare (features l ! feature) (features r ! feature)) | ||
(sortBy | ||
(\l r -> compare (features' l ! feature) (features' r ! feature)) | ||
(V.toList examples)) | ||
|
||
findBestSplit :: Vector Example -> Int -> Split | ||
findBestSplit :: Examples -> Int -> Split | ||
findBestSplit examples feature = Split { | ||
splitFeature=feature | ||
, splitValue=features (samples ! splitPoint) ! feature | ||
, averageGain=V.maximum informationGains | ||
_splitFeature=feature | ||
, _splitValue=features' (samples ! splitPoint) ! feature | ||
, _averageGain=V.maximum informationGains | ||
} where | ||
samples = sortFeature examples feature | ||
informationGains = informationGain samples | ||
splitPoint = maxIndex informationGains | ||
splitPoint = V.maxIndex informationGains | ||
|
||
-- TODO(tulloch) - make this more intelligent (support subsampling | ||
-- features for random forests, etc) | ||
getFeatures :: Vector Example -> Vector Int | ||
getFeatures :: Examples -> V.Vector Int | ||
getFeatures examples = V.fromList [0..numFeatures] where | ||
numFeatures = V.length (features $ V.head examples) | ||
|
||
data SplittingConstraint = SplittingConstraint { | ||
maximumLevels :: Maybe Int | ||
, minimumAverageGain :: Maybe Double | ||
, minimumSamplesAtLeaf :: Maybe Int | ||
} | ||
numFeatures = V.length (features' $ V.head examples) | ||
|
||
-- Determines whether a candidate set of splits should happen | ||
shouldSplit :: SplittingConstraint -> Int -> Vector Example -> Split -> Bool | ||
shouldSplit constraint currentLevel currentExamples candidateSplit = | ||
Data.List.and constraints where | ||
constraints = [ | ||
case maximumLevels constraint of | ||
Nothing -> True | ||
Just maxLevels -> currentLevel < maxLevels, | ||
case minimumAverageGain constraint of | ||
Nothing -> True | ||
Just minGain -> minGain < averageGain candidateSplit, | ||
case minimumSamplesAtLeaf constraint of | ||
Nothing -> True | ||
Just minSamples -> minSamples < V.length currentExamples] | ||
buildTreeAtLevel leafWeight splittingConstraint level examples = | ||
if shouldSplit splittingConstraint level examples bestSplit | ||
shouldSplit :: PB.SplittingConstraints -> Int -> Examples -> Split -> Bool | ||
shouldSplit constraint currentLevel currentExamples candidateSplit = and constraints where | ||
constraints = [ | ||
case PB.maximumLevels constraint of | ||
Nothing -> True | ||
Just maxLevels -> fromIntegral currentLevel < maxLevels, | ||
case PB.minimumAverageGain constraint of | ||
Nothing -> True | ||
Just minGain -> minGain < _averageGain candidateSplit, | ||
case PB.minimumSamplesAtLeaf constraint of | ||
Nothing -> True | ||
Just minSamples -> minSamples < (fromIntegral . V.length) currentExamples] | ||
|
||
buildTreeAtLevel :: (Examples -> Double) -> PB.SplittingConstraints -> Int -> Examples -> DecisionTree | ||
buildTreeAtLevel leafWeight splittingConstraints level examples = | ||
if shouldSplit splittingConstraints level examples bestSplit | ||
then Branch { | ||
feature=splitFeature bestSplit | ||
, value=splitValue bestSplit | ||
, left=recur $ V.takeWhile takePredicate orderedExamples | ||
, right=recur $ V.dropWhile takePredicate orderedExamples | ||
_feature=_splitFeature bestSplit | ||
, _value=_splitValue bestSplit | ||
, _left=recur $ V.takeWhile takePredicate orderedExamples | ||
, _right=recur $ V.dropWhile takePredicate orderedExamples | ||
} | ||
else Leaf (leafWeight examples) where | ||
-- candidate splits | ||
candidates = V.map (findBestSplit examples) (getFeatures examples) | ||
-- best candidate from all the features | ||
bestSplit = V.maximumBy (compare `on` averageGain) candidates | ||
bestSplit = V.maximumBy (compare `on` _averageGain) candidates | ||
-- sort the examples at this branch by the best feature | ||
orderedExamples = sortFeature examples (splitFeature bestSplit) | ||
orderedExamples = sortFeature examples (_splitFeature bestSplit) | ||
-- left branch takes <, right branch takes > | ||
takePredicate ex = features ex ! splitFeature bestSplit < splitValue bestSplit | ||
takePredicate ex = features' ex ! _splitFeature bestSplit < _splitValue bestSplit | ||
-- construct the next level of the tree | ||
recur = buildTreeAtLevel leafWeight splittingConstraint (level + 1) | ||
recur = buildTreeAtLevel leafWeight splittingConstraints (level + 1) | ||
|
||
buildTree :: (Examples -> Double) -> PB.SplittingConstraints -> Examples -> DecisionTree | ||
buildTree leafWeight splittingConstraints = buildTreeAtLevel leafWeight splittingConstraints 0 | ||
|
||
buildTree leafWeight splittingConstraint = buildTreeAtLevel leafWeight splittingConstraint 0 | ||
predict' :: DecisionTree -> V.Vector Double -> Double | ||
predict' (Leaf v) _ = v | ||
predict' (Branch f v l r) featureVector = | ||
if featureVector ! f < v then predict' l featureVector else predict' r featureVector | ||
|
||
predict :: DecisionTree -> Vector Double -> Double | ||
predict (Leaf v) _ = v | ||
predict (Branch f v l r) featureVector = | ||
if featureVector ! f < v then predict l featureVector else predict r featureVector | ||
predictForest' :: Trees -> V.Vector Double -> Double | ||
predictForest' trees featureVector = V.sum (V.map (`predict'` featureVector) trees) | ||
|
||
predictForest :: Vector DecisionTree -> Vector Double -> Double | ||
predictForest trees featureVector = V.sum (V.map (`predict` featureVector) trees) | ||
predictForest :: V.Vector PB.TreeNode -> V.Vector Double -> Double | ||
predictForest trees featureVector = predictForest' (V.map fromPBTree' trees) featureVector | ||
|
||
-- Typeclass for a given loss function | ||
-- Typeclass representing a usable loss function | ||
class LossFunction a where | ||
prior :: a -> Vector Example -> Double | ||
leaf :: a -> Vector Example -> Double | ||
weight :: a -> Vector DecisionTree -> Example -> Double | ||
-- Prior maps | ||
prior :: a -> Examples -> Double | ||
leaf :: a -> Examples -> Double | ||
weight :: a -> Trees -> PB.Example -> Double | ||
|
||
data LogitLoss = LogitLoss deriving (Show, Eq) | ||
|
||
-- From Algorithm 5 in http://www-stat.stanford.edu/~jhf/ftp/trebst.pdf | ||
instance LossFunction LogitLoss where | ||
instance LossFunction LogitLoss where | ||
prior _ examples = 0.5 * log ((1 + averageLabel) / (1 - averageLabel)) where | ||
averageLabel = V.sum (V.map label examples) / fromIntegral (V.length examples) | ||
averageLabel = V.sum (V.map label' examples) / fromIntegral (V.length examples) | ||
|
||
leaf _ examples = numerator / denominator where | ||
numerator = V.sum (V.map label examples) | ||
numerator = V.sum (V.map label' examples) | ||
denominator = V.sum (V.map (\e -> abs (label' e) * (2 - abs (label' e))) examples) | ||
|
||
denominator = V.sum (V.map (\e -> abs (label e) * (2 - abs (label e))) examples) | ||
|
||
weight _ trees example = (2 * label example) / | ||
(1 + exp (2 * label example * predictForest trees (features example))) | ||
weight _ trees example = (2 * label' example) / | ||
(1 + exp (2 * label' example * predictForest' trees (features' example))) | ||
|
||
runBoostingRound :: (LossFunction a) => a -> SplittingConstraint -> Vector Example -> Vector DecisionTree -> DecisionTree | ||
runBoostingRound lossFunction splittingConstraint examples forest = | ||
buildTree (leaf lossFunction) splittingConstraint weightedExamples where | ||
weightedExamples = V.map (\e -> e {label=weightedLabel e}) examples | ||
|
||
runBoostingRound :: LossFunction a => a -> PB.SplittingConstraints -> Examples -> Trees -> DecisionTree | ||
runBoostingRound lossFunction splittingConstraints examples forest = | ||
buildTree (leaf lossFunction) splittingConstraints weightedExamples where | ||
weightedExamples = V.map (\e -> e {PB.label=Just (weightedLabel e)}) examples | ||
weightedLabel = weight lossFunction forest | ||
boosting :: (LossFunction a) => a -> Int -> SplittingConstraint -> Vector Example -> Vector DecisionTree | ||
boosting lossFunction numRounds splittingConstraint examples = | ||
V.foldl addTree (V.singleton priorTree) (V.replicate numRounds 0) where | ||
priorTree = Leaf (prior lossFunction examples) | ||
addTree currentForest _ = V.snoc currentForest weakLearner where | ||
weakLearner = runBoostingRound lossFunction splittingConstraint examples currentForest | ||
|
||
trainBoosting :: (LossFunction a) => a -> Int -> PB.SplittingConstraints -> Examples -> V.Vector PB.TreeNode | ||
trainBoosting lossFunction numRounds splittingConstraints examples = V.map asPBTree' trees where | ||
trees = V.foldl addTree (V.singleton priorTree) (V.replicate numRounds (0 :: Int)) | ||
priorTree = Leaf (prior lossFunction examples) | ||
addTree currentForest _ = V.snoc currentForest weakLearner where | ||
weakLearner = runBoostingRound lossFunction splittingConstraints examples currentForest |
Oops, something went wrong.