-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9ae2311
Showing
7 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.cabal-sandbox/ | ||
cabal.sandbox.config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License (MIT) | ||
|
||
Copyright (c) 2013 Andrew Tulloch | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
-- Initial MachineLearning.cabal generated by cabal init. For further | ||
-- documentation, see http://haskell.org/cabal/users-guide/ | ||
|
||
-- The name of the package. | ||
name: MachineLearning | ||
|
||
-- The package version. See the Haskell package versioning policy (PVP) | ||
-- for standards guiding when and how versions should be incremented. | ||
-- http://www.haskell.org/haskellwiki/Package_versioning_policy | ||
-- PVP summary: +-+------- breaking API changes | ||
-- | | +----- non-breaking API additions | ||
-- | | | +--- code changes with no API change | ||
version: 0.1.0.0 | ||
|
||
-- A short (one-line) description of the package. | ||
-- synopsis: | ||
|
||
-- A longer description of the package. | ||
-- description: | ||
|
||
-- The license under which the package is released. | ||
license: MIT | ||
|
||
-- The file containing the license text. | ||
license-file: LICENSE | ||
|
||
-- The package author(s). | ||
author: Andrew Tulloch | ||
|
||
-- An email address to which users can send suggestions, bug reports, and | ||
-- patches. | ||
maintainer: andrew@tullo.ch | ||
|
||
-- A copyright notice. | ||
-- copyright: | ||
|
||
category: Math | ||
|
||
build-type: Simple | ||
|
||
-- Extra files to be distributed with the package, such as examples or a | ||
-- README. | ||
-- extra-source-files: | ||
|
||
-- Constraint on the version of Cabal needed to build this package. | ||
cabal-version: >=1.10 | ||
|
||
|
||
library | ||
-- Modules exported by the library. | ||
exposed-modules: MachineLearning.Common, MachineLearning.DecisionTrees, MachineLearning.LogisticRegression | ||
|
||
-- Modules included in this library but not exported. | ||
-- other-modules: | ||
|
||
-- LANGUAGE extensions used by modules in this package. | ||
-- other-extensions: | ||
|
||
-- Other library packages from which modules are imported. | ||
build-depends: base >=4.6 && <4.7, vector >=0.10 && <0.11 | ||
|
||
-- Directories containing source files. | ||
-- hs-source-dirs: | ||
|
||
-- Base language which the package is written in. | ||
default-language: Haskell2010 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
module MachineLearning.Common where | ||
|
||
import Data.Vector as V | ||
|
||
-- | A single training/evaluation instance: a dense feature vector paired
-- with its target value.
data Example = Example {
      features :: V.Vector Double  -- ^ Dense feature values, indexed by feature id.
    , label :: Double              -- ^ Target value.  NOTE(review): the logit-loss
                                   -- boosting code appears to treat this as a ±1
                                   -- class encoding — confirm with callers.
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
module MachineLearning.DecisionTrees where | ||
|
||
import MachineLearning.Common | ||
|
||
import Data.Vector as V | ||
import qualified Data.List | ||
import Data.Function | ||
|
||
-- | A binary regression tree.  A 'Leaf' stores the prediction for every
-- example that reaches it; a 'Branch' routes an example to 'left' when its
-- value for 'feature' is strictly below the threshold 'value', and to
-- 'right' otherwise (see 'predict').
data DecisionTree = Leaf {
      value :: Double        -- ^ Prediction emitted at this leaf.
    } | Branch {
      feature :: Int         -- ^ Index of the feature tested at this node.
    , value :: Double        -- ^ Split threshold (field name shared with 'Leaf').
    , left :: DecisionTree   -- ^ Subtree for examples with feature value below the threshold.
    , right :: DecisionTree  -- ^ Subtree for the remaining examples.
    }
|
||
-- | A candidate split of a node's examples along one feature.
data Split = Split {
      splitFeature :: Int    -- ^ Feature index the split tests.
    , splitValue :: Double   -- ^ Threshold; values strictly below it go left.
    , averageGain :: Double  -- ^ Information gain of the split, averaged per example.
    }
|
||
-- | Running squared-loss statistics over a set of examples, maintained
-- incrementally (Welford-style) by 'addExample' / 'removeExample'.
data LossState = LossState {
      averageLabel :: Double          -- ^ Mean of the labels folded in so far.
    , sumSquaredDivergence :: Double  -- ^ Sum of squared deviations from that mean.
    , numExamples :: Int              -- ^ Count of examples folded in.
    }
|
||
-- | Per-split-point information gain for a vector of examples that is
-- already sorted by the candidate feature.  Entry @i@ is the reduction in
-- total squared divergence obtained by splitting off the first @i@
-- examples from the remaining @n - i@, averaged over @n@.  Entries @0@
-- and @n@ are the degenerate splits and evaluate to zero gain.
informationGain :: Vector Example -> Vector Double
informationGain examples =
    V.zipWith gain prefixLosses suffixLosses
  where
    -- prefixLosses ! i describes the first i examples ...
    prefixLosses = incrementalLoss examples
    -- ... and suffixLosses ! i the last (n - i), so the two scans are
    -- aligned on the same split point.  (The previous version paired
    -- prefix i with the *last i* examples, double-counting/omitting data.)
    suffixLosses = V.reverse (incrementalLoss (V.reverse examples))
    -- Loss of the whole node: the final prefix covers every example.
    totalLoss = V.last prefixLosses
    -- Gain = parent loss minus the losses of BOTH children.  (The previous
    -- version added, rather than subtracted, the right-hand loss.)
    gain l r =
        (sumSquaredDivergence totalLoss -
         sumSquaredDivergence l -
         sumSquaredDivergence r) / fromIntegral (V.length examples)
    incrementalLoss = V.scanl addExample LossState{
        averageLabel=0
      , sumSquaredDivergence=0
      , numExamples=0
      }
|
||
-- | Fold one example into the running loss statistics using Welford's
-- online update.
addExample :: LossState -> Example -> LossState
addExample state example = LossState {
    numExamples=newCount
  , averageLabel=newAverageLabel
  , sumSquaredDivergence=newSumSquaredDivergence
  } where
    newCount = numExamples state + 1
    -- Welford's recurrence divides by the count *after* insertion; the
    -- previous version divided by the old count, producing Infinity/NaN
    -- on the very first example (division by zero).
    newAverageLabel = averageLabel state + delta / fromIntegral newCount
    delta = label example - averageLabel state
    newDelta = label example - newAverageLabel
    newSumSquaredDivergence = sumSquaredDivergence state + delta * newDelta
|
||
-- | Remove one example from the running loss statistics (inverse of
-- 'addExample').
removeExample :: LossState -> Example -> LossState
removeExample state example
    -- Removing the only example resets to the empty statistics; the
    -- general formula below would divide by zero.
    | numExamples state <= 1 =
        LossState { numExamples = 0, averageLabel = 0, sumSquaredDivergence = 0 }
    | otherwise = LossState {
        numExamples=newCount
      , averageLabel=newAverageLabel
      , sumSquaredDivergence=newSumSquaredDivergence
      }
  where
    newCount = numExamples state - 1
    -- Reverse Welford divides by the count *after* removal (n - 1); the
    -- previous version divided by the old count, skewing the mean.
    newAverageLabel = averageLabel state - delta / fromIntegral newCount
    delta = label example - averageLabel state
    newDelta = label example - newAverageLabel
    newSumSquaredDivergence = sumSquaredDivergence state - delta * newDelta
|
||
-- | Stable sort of the examples by their value for one feature index.
sortFeature :: Vector Example -> Int -> Vector Example
sortFeature examples feature = V.fromList sortedList
  where
    sortedList = Data.List.sortBy (compare `on` featureValue) (V.toList examples)
    featureValue example = features example ! feature
|
||
-- | Best split of the examples along a single feature: sort by that
-- feature, score every split point, and keep the highest-gain one.
findBestSplit :: Vector Example -> Int -> Split
findBestSplit examples feature =
    Split { splitFeature = feature
          , splitValue   = features (sorted ! best) ! feature
          , averageGain  = V.maximum gains
          }
  where
    sorted = sortFeature examples feature
    gains  = informationGain sorted
    best   = maxIndex gains
|
||
-- TODO(tulloch) - make this more intelligent (support subsampling
-- features for random forests, etc)
-- | Indices of the features available for splitting, taken from the
-- first example's feature vector.
getFeatures :: Vector Example -> Vector Int
getFeatures examples
    -- No examples means no features (V.head on an empty vector is partial).
    | V.null examples = V.empty
    -- Valid indices are 0 .. numFeatures - 1; the previous version used
    -- [0 .. numFeatures], producing one index past the end of the vector.
    | otherwise = V.fromList [0 .. numFeatures - 1]
  where
    numFeatures = V.length (features $ V.head examples)
|
||
-- | Optional stopping criteria for tree growth; 'Nothing' disables the
-- corresponding check (see 'shouldSplit').
data SplittingConstraint = SplittingConstraint {
      maximumLevels :: Maybe Int          -- ^ Maximum tree depth.
    , minimumAverageGain :: Maybe Double  -- ^ Minimum average gain the best split must exceed.
    , minimumSamplesAtLeaf :: Maybe Int   -- ^ Minimum sample count at the node being split.
    }
|
||
-- | Determines whether a candidate split should happen: every constraint
-- that is set must hold; an unset ('Nothing') constraint is satisfied.
shouldSplit :: SplittingConstraint -> Int -> Vector Example -> Split -> Bool
shouldSplit constraint currentLevel currentExamples candidateSplit =
    depthOk && gainOk && sizeOk
  where
    depthOk = maybe True (currentLevel <) (maximumLevels constraint)
    gainOk  = maybe True (< averageGain candidateSplit) (minimumAverageGain constraint)
    -- NOTE(review): this checks the sample count of the node being split,
    -- not of the resulting children — confirm that is the intent.
    sizeOk  = maybe True (< V.length currentExamples) (minimumSamplesAtLeaf constraint)
|
||
-- | Grow one node of the tree: pick the best split across all features
-- and either branch (recursing on each side) or emit a leaf whose value
-- is produced by the supplied leaf-weight function.
buildTreeAtLevel :: (Vector Example -> Double)  -- ^ Computes a leaf's prediction from its examples.
                 -> SplittingConstraint
                 -> Int                         -- ^ Current depth (root is 0).
                 -> Vector Example
                 -> DecisionTree
buildTreeAtLevel leafWeight splittingConstraint level examples =
    if shouldSplit splittingConstraint level examples bestSplit
    then Branch {
        feature=splitFeature bestSplit
      , value=splitValue bestSplit
      , left=recur $ V.takeWhile takePredicate orderedExamples
      , right=recur $ V.dropWhile takePredicate orderedExamples
      }
    else Leaf (leafWeight examples) where
        -- candidate splits, one per feature
        candidates = V.map (findBestSplit examples) (getFeatures examples)
        -- best candidate across all the features
        bestSplit = V.maximumBy (compare `on` averageGain) candidates
        -- sort the examples at this branch by the best feature
        orderedExamples = sortFeature examples (splitFeature bestSplit)
        -- left branch takes <, right branch takes >=
        takePredicate ex = features ex ! splitFeature bestSplit < splitValue bestSplit
        -- construct the next level of the tree
        -- NOTE(review): if the chosen threshold routes every example to one
        -- side, recursion terminates only via 'maximumLevels' — confirm a
        -- depth bound is always supplied.
        recur = buildTreeAtLevel leafWeight splittingConstraint (level + 1)
|
||
buildTree leafWeight splittingConstraint = buildTreeAtLevel leafWeight splittingConstraint 0 | ||
|
||
-- | Evaluate the tree on one feature vector, descending left when the
-- tested feature is strictly below the node's threshold.
predict :: DecisionTree -> Vector Double -> Double
predict (Leaf v) _ = v
predict (Branch f v l r) featureVector
    | featureVector ! f < v = predict l featureVector
    | otherwise             = predict r featureVector
|
||
-- | Additive-ensemble prediction: the sum of every tree's prediction on
-- the feature vector.
predictForest :: Vector DecisionTree -> Vector Double -> Double
predictForest trees featureVector =
    V.foldl' (\acc tree -> acc + predict tree featureVector) 0 trees
|
||
-- Typeclass for a given loss function
class LossFunction a where
    -- | Constant prediction used for the initial (root) tree.
    prior :: a -> Vector Example -> Double
    -- | Leaf value computed from the (re-labelled) examples at a leaf.
    leaf :: a -> Vector Example -> Double
    -- | Pseudo-label for one example given the current forest.
    weight :: a -> Vector DecisionTree -> Example -> Double
|
||
data LogitLoss = LogitLoss deriving (Show, Eq) | ||
|
||
-- From Algorithm 5 in http://www-stat.stanford.edu/~jhf/ftp/trebst.pdf
instance LossFunction LogitLoss where
    -- Half the log-odds of the mean label.
    -- NOTE(review): a mean label of exactly +/-1 takes log of 0 or divides
    -- by zero, and an empty example vector divides 0 by 0 — confirm inputs.
    prior _ examples = 0.5 * log ((1 + averageLabel) / (1 - averageLabel)) where
        averageLabel = V.sum (V.map label examples) / fromIntegral (V.length examples)

    -- Newton-step leaf value: sum of pseudo-labels over a curvature term.
    leaf _ examples = numerator / denominator where
        numerator = V.sum (V.map label examples)

        denominator = V.sum (V.map (\e -> abs (label e) * (2 - abs (label e))) examples)

    -- Negative-gradient pseudo-response under the current forest.
    weight _ trees example = (2 * label example) /
        (1 + exp (2 * label example * predictForest trees (features example)))
|
||
-- | One boosting iteration: re-label every example with the loss
-- function's pseudo-response under the current forest, then fit a fresh
-- tree to those pseudo-labels.
runBoostingRound :: (LossFunction a) => a -> SplittingConstraint -> Vector Example -> Vector DecisionTree -> DecisionTree
runBoostingRound lossFunction splittingConstraint examples forest =
    buildTree (leaf lossFunction) splittingConstraint reweighted
  where
    reweighted = V.map relabel examples
    relabel e = e {label = pseudoLabel e}
    pseudoLabel = weight lossFunction forest
|
||
-- | Run gradient boosting: start from a constant prior tree and append
-- one weak learner per round, each fitted against the forest so far.
boosting :: (LossFunction a) => a -> Int -> SplittingConstraint -> Vector Example -> Vector DecisionTree
boosting lossFunction numRounds splittingConstraint examples =
    go numRounds (V.singleton priorTree)
  where
    priorTree = Leaf (prior lossFunction examples)
    go rounds forest
        | rounds <= 0 = forest
        | otherwise   = go (rounds - 1) (V.snoc forest newTree)
      where
        newTree = runBoostingRound lossFunction splittingConstraint examples forest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
module MachineLearning.LogisticRegression where | ||
|
||
import MachineLearning.Common | ||
|
||
import Data.Vector as V | ||
import qualified Data.List | ||
import Data.Function | ||
|
||
-- | State of a logistic-regression learner, updated by copy on each step.
data LogisticRegressionState = LogisticRegressionState {
      weights :: V.Vector Double  -- ^ One coefficient per feature.
    , learningRate :: Double      -- ^ Step-size multiplier applied to gradients.
    } deriving (Show, Eq, Read)
|
||
-- | Inner product of two vectors; excess elements of the longer vector
-- are ignored (zipWith semantics).
dotProduct :: Num a => Vector a -> Vector a -> a
dotProduct a b = V.foldl' (+) 0 (V.zipWith (*) a b)
|
||
-- | One stochastic-gradient step: update the weights from a single
-- example; the learning rate is carried over unchanged.
onlineLogisticRegression :: Example -> LogisticRegressionState -> LogisticRegressionState
onlineLogisticRegression example oldState =
    oldState { weights = computeUpdate example oldState }
|
||
-- | Predicted probability for the example: the logistic function applied
-- to the inner product of its features and the current weights.
predict :: Example -> LogisticRegressionState -> Double
predict example state = sigmoid logit
  where
    logit = dotProduct (features example) (weights state)
    sigmoid z = 1.0 / (1.0 + exp (-1 * z))
|
||
-- | Per-weight gradient step for one example, already scaled by the
-- learning rate and the prediction error (label minus prediction).
gradients :: Example -> LogisticRegressionState -> Vector Double
gradients example state = V.map step (features example)
  where
    step x = learningRate state * residual * x
    residual = label example - predict example state
|
||
-- | New weight vector: the example's gradient step added onto the
-- current weights.
computeUpdate :: Example -> LogisticRegressionState -> Vector Double
computeUpdate example state = V.zipWith (+) step current
  where
    step = gradients example state
    current = weights state
|
||
-- | Run @n@ full-batch gradient-descent rounds over the examples.
-- Non-positive @n@ returns the state unchanged; the previous version
-- pattern-matched on exactly 0 and recursed forever for negative counts.
batchLogisticRegression :: Vector Example -> Int -> LogisticRegressionState -> LogisticRegressionState
batchLogisticRegression examples n state
    | n <= 0 = state
    | otherwise = batchLogisticRegression examples (n - 1) newState
  where
    newState = runBatchRound examples state
|
||
-- | One full-batch gradient step: average the per-example gradient steps
-- (each already scaled by the learning rate in 'gradients') and add the
-- result onto the current weights.
runBatchRound :: Vector Example -> LogisticRegressionState -> LogisticRegressionState
runBatchRound examples initialState
    -- Nothing to learn from; also avoids a division by zero below.
    | V.null examples = initialState
    | otherwise = initialState { weights = newWeights }
  where
    -- The previous version REPLACED the weights with the mean gradient
    -- instead of adding it (inconsistent with 'computeUpdate'), and sized
    -- the zero accumulator by the number of *examples* rather than the
    -- number of weights, silently truncating the gradient sum.
    newWeights = V.zipWith (+) (weights initialState) meanGradient
    meanGradient = V.map (* scalingFactor) totalGradient
    scalingFactor = 1.0 / fromIntegral (V.length examples)
    totalGradient = V.foldl' (V.zipWith (+)) zeroes exampleGradients
    exampleGradients = V.map (`gradients` initialState) examples
    zeroes = V.replicate (V.length (weights initialState)) 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
-- Standard Cabal Setup script: delegate everything to the default
-- "Simple" build (matches build-type: Simple in the .cabal file).
import Distribution.Simple

-- | Entry point required by Cabal; was missing its type signature.
main :: IO ()
main = defaultMain