[Machine Learning] Using Google Protocol Buffers for common records
Summary:
Also significant cleanups to the code

Test Plan:
ghc -Wall *.hs

Reviewers:

CC:
ajtulloch committed Nov 3, 2013
1 parent 9ae2311 commit 42405ad
Showing 29 changed files with 1,826 additions and 164 deletions.
4 changes: 2 additions & 2 deletions MachineLearning.cabal
@@ -48,7 +48,7 @@ cabal-version: >=1.10

library
-- Modules exported by the library.
-exposed-modules: MachineLearning.Common, MachineLearning.DecisionTrees, MachineLearning.LogisticRegression
+exposed-modules: MachineLearning.DecisionTrees, MachineLearning.LogisticRegression

-- Modules included in this library but not exported.
-- other-modules:
@@ -57,7 +57,7 @@ library
-- other-extensions:

-- Other library packages from which modules are imported.
-build-depends: base >=4.6 && <4.7, vector >=0.10 && <0.11
+build-depends: base >=4.6 && <4.7, vector >=0.10 && <0.11, containers, protocol-buffers

-- Directories containing source files.
-- hs-source-dirs:
8 changes: 0 additions & 8 deletions MachineLearning/Common.hs

This file was deleted.

258 changes: 143 additions & 115 deletions MachineLearning/DecisionTrees.hs
@@ -1,172 +1,200 @@
-module MachineLearning.DecisionTrees where
-
-import MachineLearning.Common
-
-import Data.Vector as V
-import qualified Data.List
-import Data.Function
+module MachineLearning.DecisionTrees
+    (LossFunction(..),
+     Examples,
+     trainBoosting,
+     predictForest) where
+
+import Data.Function (on)
+import Data.List (and, sortBy)
+import Data.Maybe (fromJust, isJust)
+import qualified Data.Sequence as S
+import Data.Vector ((!))
+import qualified Data.Vector as V
+
+-- Protocol Buffer records
+import qualified MachineLearning.Protobufs.Example as PB
+import qualified MachineLearning.Protobufs.SplittingConstraints as PB
+import qualified MachineLearning.Protobufs.TreeNode as PB
+import Text.ProtocolBuffers.Header (defaultValue)
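
For orientation, the generated Protobuf module presumably exposes records along these lines. This is a hedged reconstruction inferred only from the accessors used in this file (PB.label, PB.features, PB.feature, PB.splitValue, PB.left, PB.right, PB.leafValue) and the six-field pattern match in fromPBTree' below; the actual hprotoc output differs in detail (exact numeric types, the trailing unknown-field slot):

    -- Sketch only, not the generated code.
    data Example = Example
        { label    :: Maybe Double   -- optional double
        , features :: S.Seq Double   -- repeated double
        }

    data TreeNode = TreeNode
        { feature       :: Maybe Int32     -- set on branches (type assumed)
        , splitValue    :: Maybe Double    -- set on branches
        , left          :: Maybe TreeNode
        , right         :: Maybe TreeNode
        , leafValue     :: Maybe Double    -- set iff the node is a leaf
        , unknown'field :: UnknownField    -- matched as _ below (assumed)
        }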

data DecisionTree = Leaf {
-    value :: Double
+    _value :: Double
} | Branch {
-    feature :: Int
-  , value :: Double
-  , left :: DecisionTree
-  , right :: DecisionTree
+    _feature :: Int
+  , _value :: Double
+  , _left :: DecisionTree
+  , _right :: DecisionTree
}

data Split = Split {
-    splitFeature :: Int
-  , splitValue :: Double
-  , averageGain :: Double
+    _splitFeature :: Int
+  , _splitValue :: Double
+  , _averageGain :: Double
}

data LossState = LossState {
-    averageLabel :: Double
-  , sumSquaredDivergence :: Double
-  , numExamples :: Int
+    _averageLabel :: Double
+  , _sumSquaredDivergence :: Double
+  , _numExamples :: Int
}

-informationGain :: Vector Example -> Vector Double
+type Examples = V.Vector PB.Example
+type Trees = V.Vector DecisionTree
+
+informationGain :: Examples -> V.Vector Double
informationGain examples =
V.zipWith gain (incrementalLoss examples) (incrementalLoss (V.reverse examples)) where
totalLoss = V.last $ incrementalLoss examples
-    gain l r =
-        (sumSquaredDivergence totalLoss -
-         sumSquaredDivergence l +
-         sumSquaredDivergence r) / fromIntegral (V.length examples)
+    gain l r =
+        (_sumSquaredDivergence totalLoss -
+         _sumSquaredDivergence l +
+         _sumSquaredDivergence r) / fromIntegral (V.length examples)
incrementalLoss = V.scanl addExample LossState{
-    averageLabel=0
-  , sumSquaredDivergence=0
-  , numExamples=0
+    _averageLabel=0
+  , _sumSquaredDivergence=0
+  , _numExamples=0
}
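
The two scanl sweeps above give running loss statistics over every prefix of the examples (and, via V.reverse, over every suffix-sized prefix), so each candidate split point is scored in O(1) after an O(n) pass, rather than recomputing both sides' statistics from scratch per split. The same idea in miniature, as a plain prefix sum:

    -- prefixSums v ! i == sum of the first i elements of v
    prefixSums :: V.Vector Double -> V.Vector Double
    prefixSums = V.scanl (+) 0

    -- prefixSums (V.fromList [1,2,3]) == V.fromList [0,1,3,6]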

-addExample :: LossState -> Example -> LossState
+-- Convenience accessors for the label and features fields of the
+-- Protobuf-generated Example record.
+label' :: PB.Example -> Double
+label' e = fromJust $ PB.label e
+
+features' :: PB.Example -> V.Vector Double
+features' example = V.generate (S.length features) (S.index features) where
+    features = PB.features example
+
+asPBTree' :: DecisionTree -> PB.TreeNode
+asPBTree' (Leaf value) = defaultValue { PB.leafValue = Just value }
+asPBTree' (Branch f v l r) = defaultValue { PB.feature = Just $ fromIntegral f
+                                          , PB.splitValue = Just v
+                                          , PB.left = Just $ asPBTree' l
+                                          , PB.right = Just $ asPBTree' r
+                                          }
+
+fromPBTree' :: PB.TreeNode -> DecisionTree
+fromPBTree' (PB.TreeNode feature splitValue left right leafValue _)
+    | isJust leafValue = Leaf $ fromJust leafValue
+    | otherwise = Branch { _feature=(fromIntegral . fromJust) feature
+                         , _value=fromJust splitValue
+                         , _left=(fromPBTree' . fromJust) left
+                         , _right=(fromPBTree' . fromJust) right
+                         }
+
+addExample :: LossState -> PB.Example -> LossState
addExample state example = LossState {
-    numExamples=numExamples state + 1
-  , averageLabel=newAverageLabel
-  , sumSquaredDivergence=newSumSquaredDivergence
+    _numExamples=_numExamples state + 1
+  , _averageLabel=newAverageLabel
+  , _sumSquaredDivergence=newSumSquaredDivergence
} where
-    newAverageLabel = averageLabel state + delta / fromIntegral (numExamples state)
-    delta = label example - averageLabel state
-    newDelta = label example - newAverageLabel
-    newSumSquaredDivergence = sumSquaredDivergence state + delta * newDelta
+    newAverageLabel = _averageLabel state + delta / fromIntegral (_numExamples state)
+    delta = label' example - _averageLabel state
+    newDelta = label' example - newAverageLabel
+    newSumSquaredDivergence = _sumSquaredDivergence state + delta * newDelta

-removeExample :: LossState -> Example -> LossState
-removeExample state example = LossState {
-    numExamples=numExamples state - 1
-  , averageLabel=newAverageLabel
-  , sumSquaredDivergence=newSumSquaredDivergence
-  } where
-    newAverageLabel = averageLabel state - delta / fromIntegral (numExamples state)
-    delta = label example - averageLabel state
-    newDelta = label example - newAverageLabel
-    newSumSquaredDivergence = sumSquaredDivergence state - delta * newDelta
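
The update in addExample is the Welford-style streaming recurrence: keep a running mean and a running sum of squared deviations (the "M2" term) so each example updates the loss state in O(1). A self-contained version of the textbook recurrence, for reference:

    data Welford = Welford { wCount :: Int, wMean :: Double, wM2 :: Double }

    welfordStep :: Welford -> Double -> Welford
    welfordStep (Welford n mean m2) x = Welford n' mean' m2' where
        n'    = n + 1
        delta = x - mean
        mean' = mean + delta / fromIntegral n'
        m2'   = m2 + delta * (x - mean')   -- M2 / n' is the population variance

    -- foldl welfordStep (Welford 0 0 0) [1, 2, 4] yields mean 7/3 and M2 14/3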

-sortFeature :: Vector Example -> Int -> Vector Example
+sortFeature :: Examples -> Int -> Examples
sortFeature examples feature =
V.fromList
-    (Data.List.sortBy
-      (\l r -> compare (features l ! feature) (features r ! feature))
+    (sortBy
+      (\l r -> compare (features' l ! feature) (features' r ! feature))
(V.toList examples))

-findBestSplit :: Vector Example -> Int -> Split
+findBestSplit :: Examples -> Int -> Split
findBestSplit examples feature = Split {
-    splitFeature=feature
-  , splitValue=features (samples ! splitPoint) ! feature
-  , averageGain=V.maximum informationGains
+    _splitFeature=feature
+  , _splitValue=features' (samples ! splitPoint) ! feature
+  , _averageGain=V.maximum informationGains
} where
samples = sortFeature examples feature
informationGains = informationGain samples
-    splitPoint = maxIndex informationGains
+    splitPoint = V.maxIndex informationGains
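
A hypothetical invocation, assuming the generated Example record can be built with a defaultValue record update the way TreeNode is above (toyExamples is an illustration, not part of the commit):

    toyExamples :: Examples
    toyExamples = V.fromList
        [ defaultValue { PB.label = Just y, PB.features = S.fromList [x] }
        | (x, y) <- [(0.0, -1.0), (1.0, -1.0), (2.0, 1.0), (3.0, 1.0)] ]

    -- Scores every threshold on feature 0 and keeps the highest-gain one.
    bestSplit0 :: Split
    bestSplit0 = findBestSplit toyExamples 0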

-- TODO(tulloch) - make this more intelligent (support subsampling
-- features for random forests, etc)
-getFeatures :: Vector Example -> Vector Int
+getFeatures :: Examples -> V.Vector Int
getFeatures examples = V.fromList [0..numFeatures] where
-    numFeatures = V.length (features $ V.head examples)
-
-data SplittingConstraint = SplittingConstraint {
-    maximumLevels :: Maybe Int
-  , minimumAverageGain :: Maybe Double
-  , minimumSamplesAtLeaf :: Maybe Int
-  }
+    numFeatures = V.length (features' $ V.head examples)

-- Determines whether a candidate set of splits should happen
-shouldSplit :: SplittingConstraint -> Int -> Vector Example -> Split -> Bool
-shouldSplit constraint currentLevel currentExamples candidateSplit =
-    Data.List.and constraints where
-      constraints = [
-        case maximumLevels constraint of
-          Nothing -> True
-          Just maxLevels -> currentLevel < maxLevels,
-        case minimumAverageGain constraint of
-          Nothing -> True
-          Just minGain -> minGain < averageGain candidateSplit,
-        case minimumSamplesAtLeaf constraint of
-          Nothing -> True
-          Just minSamples -> minSamples < V.length currentExamples]
+shouldSplit :: PB.SplittingConstraints -> Int -> Examples -> Split -> Bool
+shouldSplit constraint currentLevel currentExamples candidateSplit = and constraints where
+    constraints = [
+      case PB.maximumLevels constraint of
+        Nothing -> True
+        Just maxLevels -> fromIntegral currentLevel < maxLevels,
+      case PB.minimumAverageGain constraint of
+        Nothing -> True
+        Just minGain -> minGain < _averageGain candidateSplit,
+      case PB.minimumSamplesAtLeaf constraint of
+        Nothing -> True
+        Just minSamples -> minSamples < (fromIntegral . V.length) currentExamples]

-buildTreeAtLevel leafWeight splittingConstraint level examples =
-    if shouldSplit splittingConstraint level examples bestSplit
+buildTreeAtLevel :: (Examples -> Double) -> PB.SplittingConstraints -> Int -> Examples -> DecisionTree
+buildTreeAtLevel leafWeight splittingConstraints level examples =
+    if shouldSplit splittingConstraints level examples bestSplit
then Branch {
-        feature=splitFeature bestSplit
-      , value=splitValue bestSplit
-      , left=recur $ V.takeWhile takePredicate orderedExamples
-      , right=recur $ V.dropWhile takePredicate orderedExamples
+        _feature=_splitFeature bestSplit
+      , _value=_splitValue bestSplit
+      , _left=recur $ V.takeWhile takePredicate orderedExamples
+      , _right=recur $ V.dropWhile takePredicate orderedExamples
}
else Leaf (leafWeight examples) where
-- candidate splits
candidates = V.map (findBestSplit examples) (getFeatures examples)
-- best candidate from all the features
-    bestSplit = V.maximumBy (compare `on` averageGain) candidates
+    bestSplit = V.maximumBy (compare `on` _averageGain) candidates
-- sort the examples at this branch by the best feature
-    orderedExamples = sortFeature examples (splitFeature bestSplit)
+    orderedExamples = sortFeature examples (_splitFeature bestSplit)
-- left branch takes <, right branch takes >
-    takePredicate ex = features ex ! splitFeature bestSplit < splitValue bestSplit
+    takePredicate ex = features' ex ! _splitFeature bestSplit < _splitValue bestSplit
-- construct the next level of the tree
-    recur = buildTreeAtLevel leafWeight splittingConstraint (level + 1)
+    recur = buildTreeAtLevel leafWeight splittingConstraints (level + 1)
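
Since SplittingConstraints is now a Protobuf message, every constraint is an optional field and Nothing means "unconstrained". A hypothetical configuration, built the same way the TreeNode defaults are above:

    -- Cap depth at 5 and require more than 10 examples to keep splitting;
    -- minimumAverageGain is left as Nothing, i.e. no gain threshold.
    constraints :: PB.SplittingConstraints
    constraints = defaultValue
        { PB.maximumLevels        = Just 5
        , PB.minimumSamplesAtLeaf = Just 10
        }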

-buildTree leafWeight splittingConstraint = buildTreeAtLevel leafWeight splittingConstraint 0
+buildTree :: (Examples -> Double) -> PB.SplittingConstraints -> Examples -> DecisionTree
+buildTree leafWeight splittingConstraints = buildTreeAtLevel leafWeight splittingConstraints 0

-predict :: DecisionTree -> Vector Double -> Double
-predict (Leaf v) _ = v
-predict (Branch f v l r) featureVector =
-    if featureVector ! f < v then predict l featureVector else predict r featureVector
+predict' :: DecisionTree -> V.Vector Double -> Double
+predict' (Leaf v) _ = v
+predict' (Branch f v l r) featureVector =
+    if featureVector ! f < v then predict' l featureVector else predict' r featureVector

-predictForest :: Vector DecisionTree -> Vector Double -> Double
-predictForest trees featureVector = V.sum (V.map (`predict` featureVector) trees)
+predictForest' :: Trees -> V.Vector Double -> Double
+predictForest' trees featureVector = V.sum (V.map (`predict'` featureVector) trees)
+
+predictForest :: V.Vector PB.TreeNode -> V.Vector Double -> Double
+predictForest trees featureVector = predictForest' (V.map fromPBTree' trees) featureVector
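
The traversal sends a feature vector left on strictly-less-than and right otherwise, mirroring takePredicate above. A hand-built stump as a sanity check (illustration only):

    -- "x0 < 2.0 ? -1 : +1"
    stump :: DecisionTree
    stump = Branch { _feature = 0, _value = 2.0
                   , _left = Leaf (-1.0), _right = Leaf 1.0 }

    -- predict' stump (V.fromList [1.5]) == -1.0
    -- predict' stump (V.fromList [3.0]) ==  1.0
    -- and fromPBTree' (asPBTree' stump) reproduces the same structure.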

--- Typeclass for a given loss function
+-- Typeclass representing a usable loss function
class LossFunction a where
-  prior :: a -> Vector Example -> Double
-  leaf :: a -> Vector Example -> Double
-  weight :: a -> Vector DecisionTree -> Example -> Double
+  -- prior maps a training set to an initial constant prediction
+  prior :: a -> Examples -> Double
+  leaf :: a -> Examples -> Double
+  weight :: a -> Trees -> PB.Example -> Double

data LogitLoss = LogitLoss deriving (Show, Eq)

-- From Algorithm 5 in http://www-stat.stanford.edu/~jhf/ftp/trebst.pdf
instance LossFunction LogitLoss where
prior _ examples = 0.5 * log ((1 + averageLabel) / (1 - averageLabel)) where
-    averageLabel = V.sum (V.map label examples) / fromIntegral (V.length examples)
+    averageLabel = V.sum (V.map label' examples) / fromIntegral (V.length examples)

leaf _ examples = numerator / denominator where
-    numerator = V.sum (V.map label examples)
-    denominator = V.sum (V.map (\e -> abs (label e) * (2 - abs (label e))) examples)
+    numerator = V.sum (V.map label' examples)
+    denominator = V.sum (V.map (\e -> abs (label' e) * (2 - abs (label' e))) examples)

-  weight _ trees example = (2 * label example) /
-    (1 + exp (2 * label example * predictForest trees (features example)))
+  weight _ trees example = (2 * label' example) /
+    (1 + exp (2 * label' example * predictForest' trees (features' example)))
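
These three methods are the binary-classification updates from Algorithm 5 of the cited paper, with labels y_i in {-1, +1}: the prior is the log-odds of the mean label, the per-example weight is the pseudo-response

    \tilde{y}_i = \frac{2 y_i}{1 + \exp(2 y_i F(x_i))}

and each leaf fits the one-step Newton value

    \gamma = \frac{\sum_i \tilde{y}_i}{\sum_i |\tilde{y}_i| \, (2 - |\tilde{y}_i|)}

where F(x_i) is the current forest's prediction (predictForest' above).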

-runBoostingRound :: (LossFunction a) => a -> SplittingConstraint -> Vector Example -> Vector DecisionTree -> DecisionTree
-runBoostingRound lossFunction splittingConstraint examples forest =
-    buildTree (leaf lossFunction) splittingConstraint weightedExamples where
-      weightedExamples = V.map (\e -> e {label=weightedLabel e}) examples
+runBoostingRound :: LossFunction a => a -> PB.SplittingConstraints -> Examples -> Trees -> DecisionTree
+runBoostingRound lossFunction splittingConstraints examples forest =
+    buildTree (leaf lossFunction) splittingConstraints weightedExamples where
+      weightedExamples = V.map (\e -> e {PB.label=Just (weightedLabel e)}) examples
weightedLabel = weight lossFunction forest
-boosting :: (LossFunction a) => a -> Int -> SplittingConstraint -> Vector Example -> Vector DecisionTree
-boosting lossFunction numRounds splittingConstraint examples =
-    V.foldl addTree (V.singleton priorTree) (V.replicate numRounds 0) where
-      priorTree = Leaf (prior lossFunction examples)
-      addTree currentForest _ = V.snoc currentForest weakLearner where
-        weakLearner = runBoostingRound lossFunction splittingConstraint examples currentForest
+trainBoosting :: (LossFunction a) => a -> Int -> PB.SplittingConstraints -> Examples -> V.Vector PB.TreeNode
+trainBoosting lossFunction numRounds splittingConstraints examples = V.map asPBTree' trees where
+    trees = V.foldl addTree (V.singleton priorTree) (V.replicate numRounds (0 :: Int))
+    priorTree = Leaf (prior lossFunction examples)
+    addTree currentForest _ = V.snoc currentForest weakLearner where
+      weakLearner = runBoostingRound lossFunction splittingConstraints examples currentForest
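
End to end, the module's exported surface is now just trainBoosting and predictForest, with Protobuf messages on both sides of the boundary. A hypothetical driver, reusing the toyExamples and constraints sketched above (note that LogitLoss itself is not in the export list, so this would live inside the module or the exports would need to grow):

    main :: IO ()
    main = do
        let forest = trainBoosting LogitLoss 10 constraints toyExamples
        -- forest :: V.Vector PB.TreeNode, so each tree can also be
        -- serialized (e.g. with messagePut from protocol-buffers).
        print (predictForest forest (V.fromList [1.5]))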
