Skip to content

Commit

Permalink
add decision trees machinery
Browse files Browse the repository at this point in the history
  • Loading branch information
ocramz committed Oct 22, 2018
1 parent fa4cfa8 commit 6bba752
Show file tree
Hide file tree
Showing 7 changed files with 516 additions and 0 deletions.
9 changes: 9 additions & 0 deletions dh-core/core.cabal
Expand Up @@ -22,11 +22,20 @@ library
exposed-modules: Lib
Core.Numeric.BLAS
Core.Numeric.BLAS.Class
Core.Numeric.Statistics.Classification.DecisionTrees
Core.Numeric.Statistics.Classification.Utils
Core.Numeric.Statistics.Classification.Exceptions
Core.Numeric.Statistics.InformationTheory
Core.Data.Dataset
Core.Data.Datum.Vector

build-depends:
base >=4.10 && <5,
bytestring >=0.10.8.1,
containers >=0.5.7.1,
exceptions >=0.8.3,
mwc-random,
primitive,
text >=1.2.2.2,
vector >=0.12.0.1,
vector-algorithms >=0.7.0.1,
Expand Down
81 changes: 81 additions & 0 deletions dh-core/src/Core/Data/Dataset.hs
@@ -0,0 +1,81 @@
{-# language DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
{-# language TypeFamilies #-}
module Core.Data.Dataset where

import qualified Data.Foldable as F (maximumBy)
import Data.Ord (comparing)

import qualified Data.Map.Strict as M (Map, empty, fromList, toList, fromListWith, mapWithKey, foldl', foldrWithKey, foldlWithKey', insert)
import qualified Data.Map.Internal.Debug as M (showTree)
-- import qualified Data.IntMap.Strict as IM
-- import qualified Data.Set as S

import System.Random.MWC
import Control.Monad.Primitive

import Core.Numeric.Statistics.Classification.Utils (Indexed(..), bootstrapNP)

-- | A labeled dataset, stored as a 'M.Map' from class label @k@ to the
-- payload @a@ associated with that class (typically a container of data points).
newtype Dataset k a = Dataset
  { unDataset :: M.Map k a -- ^ Underlying label-indexed map
  }
  deriving (Eq, Show, Functor, Foldable, Traversable)

-- | Render the internal structure of the underlying map (debugging aid;
-- delegates to 'M.showTree').
showTree :: (Show k, Show a) => Dataset k a -> String
showTree = M.showTree . unDataset

-- | The dataset with no classes.
empty :: Dataset k a
empty = Dataset { unDataset = M.empty }

-- | Associate a payload with class label @k@ (delegates to 'M.insert' on the
-- underlying map).
insert :: Ord k => k -> a -> Dataset k a -> Dataset k a
insert k v = Dataset . M.insert k v . unDataset

-- | Transform every class payload with a function that also sees the class label.
mapWithKey :: (k -> a -> b) -> Dataset k a -> Dataset k b
mapWithKey f = Dataset . M.mapWithKey f . unDataset

-- | Right fold over the (label, payload) pairs of the dataset
-- (delegates to 'M.foldrWithKey').
foldrWithKey :: (k -> a -> b -> b) -> b -> Dataset k a -> b
foldrWithKey f z = M.foldrWithKey f z . unDataset

-- | Left fold over the (label, payload) pairs of the dataset
-- (delegates to 'M.foldlWithKey'').
foldlWithKey' :: (a -> k -> b -> a) -> a -> Dataset k b -> a
foldlWithKey' f z = M.foldlWithKey' f z . unDataset

-- | Build a dataset from (label, payload) pairs (delegates to 'M.fromList').
fromList :: Ord k => [(k, a)] -> Dataset k a
fromList = Dataset . M.fromList

-- | Build a dataset from (label, payload) pairs, merging payloads that share a
-- label with the supplied function (delegates to 'M.fromListWith').
fromListWith :: Ord k => (a -> a -> a) -> [(k, a)] -> Dataset k a
fromListWith f = Dataset . M.fromListWith f

-- | The (label, payload) pairs of the dataset, as produced by 'M.toList'.
toList :: Dataset k a -> [(k, a)]
toList = M.toList . unDataset

-- lookup :: Ord k => k -> Dataset k a -> Maybe a
-- lookup k (Dataset ds) = M.lookup k ds

-- | Total number of items in the dataset, i.e. the sum of the lengths of all
-- class payloads.
size :: Foldable t => Dataset k (t a) -> Int
size = M.foldl' addLength 0 . unDataset
  where
    addLength acc xs = acc + length xs

-- | Maximum likelihood estimate of the class label: the label whose class
-- contains the most items.
--
-- NB: partial. 'F.maximumBy' is applied to the list of classes, so calling
-- this on a dataset with no classes raises an error.
mlClass :: Dataset k [a] -> k
mlClass ds = fst (F.maximumBy (comparing classSize) (toList ds))
  where
    classSize = length . snd


-- | Number of items in each class, converted to the caller's numeric type.
sizeClasses :: (Foldable t, Num n) => Dataset k (t a) -> M.Map k n
sizeClasses (Dataset ds) = fmap count ds
  where
    count = fromIntegral . length

-- | Empirical class probabilities: for each label @k@, the number of items in
-- class @k@ divided by the total number of items in the dataset.
--
-- The total is computed once and shared across all classes rather than being
-- recomputed inside the per-class function.
--
-- NB: if the dataset has classes but no items, the denominator is zero.
probClasses :: (Fractional prob, Foldable t) => Dataset k (t a) -> M.Map k prob
probClasses ds = (/ total) <$> sizeClasses ds
  where
    total = fromIntegral (size ds)


-- * Bootstrap

-- | Nonparametric bootstrap: each class is resampled (i.e. sampled with
-- replacement) independently via 'bootstrapNP', and the per-class resamples
-- are then recombined into whole datasets.
--
-- NOTE(review): the recombination uses 'sequenceA' at the list 'Applicative',
-- which enumerates every combination of per-class resamples. If 'bootstrapNP'
-- returns one list per requested resample (its definition is not visible
-- here), the result contains more datasets than the requested number of
-- resamples whenever there is more than one class. Confirm this is intended.
bootstrap :: (Indexed f, PrimMonad m, Ix f ~ Int) =>
     Dataset k (f a)
  -> Int -- ^ Number of samples
  -> Int -- ^ Number of bootstrap resamples
  -> Gen (PrimState m)
  -> m [Dataset k [a]]
bootstrap ds nsamples nboot gen =
  sequenceA <$> traverse (bootstrapNP nsamples nboot gen) ds
83 changes: 83 additions & 0 deletions dh-core/src/Core/Data/Datum/Vector.hs
@@ -0,0 +1,83 @@
{-# language DeriveFunctor, DeriveFoldable, DeriveTraversable, GeneralizedNewtypeDeriving #-}
module Core.Data.Datum.Vector where

import qualified Data.IntMap as IM
-- import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector as V

import Control.Monad.Catch (MonadThrow(..))
import Core.Numeric.Statistics.Classification.Exceptions

-- | Human-readable names for features, keyed by feature index.
newtype FeatureLabels = FeatureLabels (IM.IntMap String)
  deriving (Eq, Show)

-- | Build a 'FeatureLabels' table, assigning indices 0, 1, ... to the labels
-- in list order.
--
-- Throws 'DimMismatchE' if the number of labels differs from the expected
-- dimension @n@.
featureLabels :: MonadThrow m => Int -> [String] -> m FeatureLabels
featureLabels n ls =
  if nLabels == n
    then pure (FeatureLabels (IM.fromList (zip [0 ..] ls)))
    else throwM (DimMismatchE "featureLabels" n nLabels)
  where
    nLabels = length ls

-- | Look up the label of feature @i@ using 'IM.!'.
--
-- NB: partial. No check is made here that the index is present in the table.
lookupFeatureLabelUnsafe :: IM.Key -> FeatureLabels -> String
lookupFeatureLabelUnsafe i (FeatureLabels labels) = (IM.!) labels i


-- | A data point, i.e. a vector in R^n, backed by a boxed 'V.Vector'.
newtype V a = V (V.Vector a)
  deriving (Eq, Show, Functor, Foldable, Traversable, Applicative, Monad)

-- | Build a 'V' from a list (wraps 'V.fromList').
fromListV :: [a] -> V a
fromListV = V . V.fromList
{-# inline fromListV #-}

-- | The components of a 'V' as a list, in index order (wraps 'V.toList').
toListV :: V a -> [a]
toListV v = case v of
  V components -> V.toList components

-- | Pair up the components of two vectors position by position (wraps 'V.zip').
zipV :: V a -> V b -> V (a, b)
zipV (V v1) (V v2) = V $ V.zip v1 v2
{-# inline zipV #-}

-- | Split a vector of pairs into a pair of vectors (wraps 'V.unzip').
unzipV :: V (a, b) -> (V a, V b)
unzipV (V vs) = (V v1, V v2) where
  (v1, v2) = V.unzip vs
{-# inline unzipV #-}

-- | Smart constructor: wrap a 'V.Vector' as a 'V' after checking that it has
-- the expected dimension @n@.
--
-- Throws 'DimMismatchE' when the vector length differs from @n@.
mkV :: MonadThrow m => Int -> V.Vector a -> m (V a)
mkV n xs =
  if actual == n
    then pure (V xs)
    else throwM (DimMismatchE "mkV" n actual)
  where
    actual = V.length xs

-- | Component at position @j@, via 'V.!'. This wrapper does no validation of
-- its own; see '(!)' for the variant that checks the index first and reports
-- failures through 'MonadThrow'.
indexUnsafe :: V a -> Int -> a
(V vv) `indexUnsafe` j = vv V.! j
{-# inline indexUnsafe #-}

-- | Checked indexing: the component at position @j@.
--
-- Throws 'IndexOobE' (carrying the offending index and the bounds @0@ and
-- @dim v@) when @j@ is negative or not smaller than the vector dimension.
(!) :: MonadThrow m => V a -> Int -> m a
v ! j
  | j < 0 || j >= d = throwM (IndexOobE "(!)" j 0 d)
  | otherwise = pure (indexUnsafe v j)
  where
    d = dim v

-- | Dimension of the vector, i.e. its number of components (wraps 'V.length').
dim :: V a -> Int
dim (V vv) = V.length vv
{-# inline dim #-}

-- | Right fold over the components of a vector together with their 0-based
-- positions.
--
-- Implemented with 'V.ifoldr', which supplies the index directly. Pairing the
-- components with a vector built from the infinite list @[0..]@ would require
-- materialising an infinite vector unless stream fusion happens to fire, so
-- that formulation is avoided.
foldrWithKey :: (Int -> a -> b -> b) -> b -> V a -> b
foldrWithKey f z (V vv) = V.ifoldr f z vv
{-# inline foldrWithKey #-}


-- | Lift an element-level predicate to a predicate on whole data points by
-- applying it to the @j@'th component.
--
-- The component is fetched with 'indexUnsafe', so @j@ is not validated here.
dataSplitDecision :: (a -> Bool) -> Int -> (V a -> Bool)
dataSplitDecision p j = p . (`indexUnsafe` j)

-- allComponents :: V (a -> Bool) -> V a -> Bool
-- allComponents ps dat = all (== True) $ f <$> vps where
-- vps = zipV ps dat
-- f (p, vi) = p vi
-- -- allComponents ps dat = all (== True) $ ps <*> dat




-- Vectors with measurable entries (sketch only; the definitions below are commented out)

-- data Measurable a = BoundedBoth a a deriving (Eq, Show)

-- newtype Xf f a = Xf (V.Vector (f a))

-- type XMeas = Xf Measurable
155 changes: 155 additions & 0 deletions dh-core/src/Core/Numeric/Statistics/Classification/DecisionTrees.hs
@@ -0,0 +1,155 @@
{-# language DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
module Core.Numeric.Statistics.Classification.DecisionTrees where

import qualified Data.Foldable as F
import qualified Data.Set as S
import Data.Ord (comparing)

import Core.Data.Dataset
import qualified Core.Data.Datum.Vector as XV
import Core.Numeric.Statistics.Classification.Utils
import Core.Numeric.Statistics.InformationTheory (entropyR)

-- | A binary tree.
--
-- Leaves carry data of type @a@; every branching point carries metadata of
-- type @d@. The derived 'Functor', 'Foldable' and 'Traversable' instances act
-- on the leaf type @a@.
data Tree d a
  = Node d (Tree d a) (Tree d a) -- ^ Branching point: metadata, left subtree, right subtree
  | Leaf a                       -- ^ Terminal node carrying data
  deriving (Eq, Show, Functor, Foldable, Traversable)

-- | Grow a 'Tree' from a seed.
--
-- The step function either finishes with a leaf value (@Left@) or produces
-- node metadata together with the seeds for the left and right subtrees
-- (@Right@), which are then expanded recursively.
unfoldTree :: (t -> Either a (d, t, t)) -> t -> Tree d a
unfoldTree f = go
  where
    go seed = case f seed of
      Left a -> Leaf a
      Right (d, l, r) -> Node d (go l) (go r)


-- | Tree state: the set of candidate dataset cuts, each a
-- (feature index, threshold) pair, together with the dataset they apply to.
data TState k a = TState
  { tsFeatCuts :: S.Set (Int, a)        -- ^ Candidate (feature index, threshold) cuts
  , tsDataset :: Dataset k [XV.V a]     -- ^ Data points, grouped by class label
  }

-- | Tree state paired with the local tree depth at which it occurs.
data TSd k a = TSd
  { tsDepth :: !Int        -- ^ Depth of this state in the tree being grown
  , tState :: TState k a   -- ^ Candidate cuts and dataset at this depth
  }

-- | Global options for growing decision trees.
data TOptions = TOptions
  { toMaxDepth :: !Int    -- ^ Max tree depth
  , toMinLeafSize :: !Int -- ^ Minimum size of the contents of a leaf
  , toOrder :: Order      -- ^ Comparison used for threshold tests: less than, or greater than or equal
  }
  deriving (Eq, Show)

-- | Tree node metadata.
--
-- For decision trees, each node stores the feature used for the decision and
-- the threshold that feature is compared against.
data TNData a = TNData
  { tJStar :: !Int -- ^ Decision feature index
  , tTStar :: a    -- ^ Decision threshold
  }
  deriving (Eq)

-- | Compact rendering of the form @(j = .. , t = .. )@.
instance Show a => Show (TNData a) where
  show nd = case nd of
    TNData j t -> unwords ["(j =", show j, ", t =", show t, ")"]


-- | Split decision: one step of decision-tree growth, in the shape expected
-- by 'unfoldTree'. The split chosen is the candidate cut that maximizes the
-- information gain (see 'maxInfoGainSplit').
--
-- Returns @Left@ (a leaf holding a dataset) when:
--
-- * the depth limit is reached, or the dataset holds at most the minimum
--   leaf size number of points;
-- * there are no candidate cuts left to evaluate (without this check the
--   search for the best cut would be run on an empty list and fail);
-- * the best split puts every point on one side, in which case that side is
--   returned as the leaf.
--
-- Otherwise returns @Right@ with the chosen feature index and threshold, and
-- the left and right child states at depth + 1.
--
-- NB: a leaf can be empty if the incoming dataset is empty.
treeUnfoldStep :: (Ord a, Ord k) =>
     TOptions
  -> TSd k a
  -> Either (Dataset k [XV.V a]) (TNData a, TSd k a, TSd k a)
treeUnfoldStep (TOptions maxdepth minls ord) (TSd depth tst)
  | depth >= maxdepth || sizeDs tst <= minls = Left (tsDataset tst)
  | S.null (tsFeatCuts tst) = Left (tsDataset tst) -- nothing left to split on
  | sizeDs tsl == 0 = Left (tsDataset tsr)
  | sizeDs tsr == 0 = Left (tsDataset tsl)
  | otherwise = Right (mdata, tdsl, tdsr)
  where
    sizeDs = size . tsDataset
    mdata = TNData jstar tstar
    -- Lazily bound: only forced once the guards above have ruled out the
    -- cases where no split should be attempted.
    (jstar, tstar, tsl, tsr) = maxInfoGainSplit ordf tst
    ordf = fromOrder ord
    d' = depth + 1
    tdsl = TSd d' tsl
    tdsr = TSd d' tsr



{- | Note (OPTIMIZATIONS maxInfoGainSplit)
1. After splitting a dataset, remove the (threshold, feature index) pair corresponding to the successful split.
2. After splitting a dataset, also remove /all/ (threshold, index) pairs that are subsumed by the successful test (e.g. for the test ((<), 3.2, 27), remove all [(t, 27) | t <- [tmin ..], t < 3.2]). This is only a useful optimization for /monotonic/ class boundaries.
-}

-- | Tabulate the information gain for every candidate (feature index,
-- threshold) cut and return the cut with the maximum gain, together with the
-- two states obtained by splitting the dataset at that cut.
--
-- The winning cut is removed from the candidate set handed to both children.
--
-- NB: partial. The best cut is selected with 'F.maximumBy', which fails when
-- the candidate set is empty.
maxInfoGainSplit :: (Ord k, Ord a, Eq a) =>
     (a -> a -> Bool)
  -> TState k a
  -> (Int, a, TState k a, TState k a)
maxInfoGainSplit decision (TState cuts ds) =
  (jstar, tstar, TState cuts' dsl, TState cuts' dsr)
  where
    -- See Note (OPTIMIZATIONS maxInfoGainSplit)
    cuts' = S.delete (jstar, tstar) cuts
    (jstar, tstar, _, dsl, dsr) =
      F.maximumBy (comparing third5) [score cut | cut <- S.toList cuts]
    -- Information gain and resulting partition for a single candidate cut
    score (j, t) =
      let (gain, l, r) = infoGainR (decision t) j ds
       in (j, t, gain, l, r)


-- | Project the third component out of a 5-tuple.
third5 :: (a, b, c, d, e) -> c
third5 quint = case quint of
  (_, _, x, _, _) -> x

-- | Information gain due to a dataset split (regularized, H(0) := 0).
--
-- The dataset is split on feature @j@ with predicate @p@. The gain is the
-- entropy of the whole dataset minus the size-weighted sum of the entropies
-- of the two parts. Both parts are returned alongside the gain.
infoGainR :: (Ord k, Ord h, Floating h) =>
     (a -> Bool)
  -> Int
  -> Dataset k [XV.V a]
  -> (h, Dataset k [XV.V a], Dataset k [XV.V a])
infoGainR p j ds = (hBefore - hAfter, dsl, dsr)
  where
    (dsl, pl, dsr, pr) = splitDatasetAtAttr p j ds
    hBefore = entropyR ds
    -- Entropy after the split, weighted by the fraction of points on each side
    hAfter = pl * entropyR dsl + pr * entropyR dsr


-- | Helper function for 'infoGainR': split a dataset on feature @j@ with
-- predicate @p@ and report, for each side, the fraction of the original
-- points it contains.
--
-- Returns @(left part, left fraction, right part, right fraction)@.
--
-- For an empty input dataset both fractions are defined to be 0, so that no
-- division by zero takes place.
splitDatasetAtAttr :: (Fractional n, Ord k) =>
     (a -> Bool)
  -> Int
  -> Dataset k [XV.V a]
  -> (Dataset k [XV.V a], n, Dataset k [XV.V a], n)
splitDatasetAtAttr p j ds = (dsl, fraction dsl, dsr, fraction dsr)
  where
    (dsl, dsr) = partition p j ds
    total = size ds
    -- The emptiness test is done on the Int count, so no Eq constraint is
    -- needed on the fractional result type.
    fraction part
      | total == 0 = 0
      | otherwise = fromIntegral (size part) / fromIntegral total


-- | Partition a Dataset in two, according to a decision predicate applied to
-- a given feature.
--
-- e.g. "is the j'th component of datum X_i larger than threshold t ?"
--
-- Every class label of the input appears in both outputs; the points of each
-- class for which the predicate holds go to the first dataset, the rest to
-- the second.
partition :: (Foldable t, Ord k) =>
     (a -> Bool) -- ^ Decision function (element-level)
  -> Int -- ^ Feature index
  -> Dataset k (t (XV.V a))
  -> (Dataset k [XV.V a], Dataset k [XV.V a])
partition p j ds = foldrWithKey step (empty, empty) ds
  where
    decide = XV.dataSplitDecision p j
    step k rows (accL, accR) =
      let (yes, no) = partition1 decide rows
       in (insert k yes accL, insert k no accR)


-- | Partition a Foldable in two lists according to a predicate: elements
-- satisfying it go to the first list, the others to the second. The relative
-- order of elements is preserved in both lists.
partition1 :: Foldable t => (a -> Bool) -> t a -> ([a], [a])
partition1 p = foldr step ([], [])
  where
    step x (yes, no) =
      if p x
        then (x : yes, no)
        else (yes, x : no)



-- | A well-defined Ordering, for strict half-plane separation: the two
-- alternatives are complementary, so every value falls on exactly one side
-- of a threshold.
data Order
  = LessThan
  | GreaterOrEqual
  deriving (Eq, Ord, Enum, Bounded)

-- | Shown as the corresponding comparison operator symbol.
instance Show Order where
  show o = case o of
    LessThan -> "<"
    GreaterOrEqual -> ">="

-- | Interpret an 'Order' as the comparison function it denotes.
fromOrder :: Ord a => Order -> (a -> a -> Bool)
fromOrder LessThan = (<)
fromOrder GreaterOrEqual = (>=)
25 changes: 25 additions & 0 deletions dh-core/src/Core/Numeric/Statistics/Classification/Exceptions.hs
@@ -0,0 +1,25 @@
module Core.Numeric.Statistics.Classification.Exceptions where

import Control.Exception
import Data.Typeable


-- * Exceptions

-- | Exceptions about invalid numerical values. The 'String' payload is a
-- free-form message (presumably identifying where the problem arose; its
-- producers are not visible here).
data ValueException
  = ZeroProbabilityE String
  deriving (Eq, Show, Typeable)

instance Exception ValueException


-- | Exceptions about malformed data. The leading 'String' of each constructor
-- is printed as a prefix of the rendered message.
data DataException
  = -- MissingFeatureE i
    IndexOobE String Int Int Int  -- ^ Message prefix, offending index, lower bound, upper bound
  | DimMismatchE String Int Int   -- ^ Message prefix, expected dimension, actual dimension
  deriving (Eq, Typeable)

-- | Human-readable rendering of each exception.
instance Show DataException where
  show (IndexOobE errMsg ix blo bhi) =
    unwords [errMsg, ": index", show ix, "out of bounds", show (blo, bhi)]
  show (DimMismatchE errMsg d1 d2) =
    unwords [errMsg, ": dimension mismatch : expecting", show d1, "but got", show d2]

instance Exception DataException

0 comments on commit 6bba752

Please sign in to comment.