Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
516 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
{-# language DeriveFunctor, DeriveFoldable, DeriveTraversable #-} | ||
{-# language TypeFamilies #-} | ||
module Core.Data.Dataset where | ||
|
||
import qualified Data.Foldable as F (maximumBy) | ||
import Data.Ord (comparing) | ||
|
||
import qualified Data.Map.Strict as M (Map, empty, fromList, toList, fromListWith, mapWithKey, foldl', foldrWithKey, foldlWithKey', insert) | ||
import qualified Data.Map.Internal.Debug as M (showTree) | ||
-- import qualified Data.IntMap.Strict as IM | ||
-- import qualified Data.Set as S | ||
|
||
import System.Random.MWC | ||
import Control.Monad.Primitive | ||
|
||
import Core.Numeric.Statistics.Classification.Utils (Indexed(..), bootstrapNP) | ||
|
||
-- | Labeled dataset represented as a 'Map'. The map keys are the class labels | ||
newtype Dataset k a = Dataset { unDataset :: M.Map k a } deriving (Eq, Show, Functor, Foldable, Traversable) | ||
|
||
showTree :: (Show k, Show a) => Dataset k a -> String | ||
showTree (Dataset mm) = M.showTree mm | ||
|
||
empty :: Dataset k a | ||
empty = Dataset M.empty | ||
|
||
insert :: Ord k => k -> a -> Dataset k a -> Dataset k a | ||
insert k ls (Dataset ds) = Dataset $ M.insert k ls ds | ||
|
||
mapWithKey :: (k -> a -> b) -> Dataset k a -> Dataset k b | ||
mapWithKey f (Dataset ds) = Dataset $ M.mapWithKey f ds | ||
|
||
foldrWithKey :: (k -> a -> b -> b) -> b -> Dataset k a -> b | ||
foldrWithKey f z (Dataset ds) = M.foldrWithKey f z ds | ||
|
||
foldlWithKey' :: (a -> k -> b -> a) -> a -> Dataset k b -> a | ||
foldlWithKey' f z (Dataset ds) = M.foldlWithKey' f z ds | ||
|
||
fromList :: Ord k => [(k, a)] -> Dataset k a | ||
fromList ld = Dataset $ M.fromList ld | ||
|
||
fromListWith :: Ord k => (a -> a -> a) -> [(k, a)] -> Dataset k a | ||
fromListWith f ld = Dataset $ M.fromListWith f ld | ||
|
||
toList :: Dataset k a -> [(k, a)] | ||
toList (Dataset ds) = M.toList ds | ||
|
||
-- lookup :: Ord k => k -> Dataset k a -> Maybe a | ||
-- lookup k (Dataset ds) = M.lookup k ds | ||
|
||
-- | Size of the dataset | ||
size :: Foldable t => Dataset k (t a) -> Int | ||
size (Dataset ds) = M.foldl' (\acc l -> acc + length l) 0 ds | ||
|
||
-- | Maximum likelihood estimate of class label | ||
mlClass :: Dataset k [a] -> k | ||
mlClass = fst . F.maximumBy (comparing f) . toList where | ||
f (_, ll) = length ll | ||
|
||
|
||
-- | Number of items in each class | ||
sizeClasses :: (Foldable t, Num n) => Dataset k (t a) -> M.Map k n | ||
sizeClasses (Dataset ds) = (fromIntegral . length) <$> ds | ||
|
||
-- | Empirical class probabilities i.e. for each k, number of items in class k / total number of items | ||
probClasses :: (Fractional prob, Foldable t) => Dataset k (t a) -> M.Map k prob | ||
probClasses ds = (\n -> n / fromIntegral (size ds)) <$> sizeClasses ds | ||
|
||
|
||
-- * Bootstrap | ||
|
||
-- | Nonparametric bootstrap: each class is resampled (i.e. sampled with replacement) | ||
bootstrap :: (Indexed f, PrimMonad m, Ix f ~ Int) => | ||
Dataset k (f a) | ||
-> Int -- ^ Number of samples | ||
-> Int -- ^ Number of bootstrap resamples | ||
-> Gen (PrimState m) | ||
-> m [Dataset k [a]] | ||
bootstrap ds@Dataset{} nsamples nboot gen = do | ||
dss <- traverse (bootstrapNP nsamples nboot gen) ds | ||
pure $ sequenceA dss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
{-# language DeriveFunctor, DeriveFoldable, DeriveTraversable, GeneralizedNewtypeDeriving #-} | ||
module Core.Data.Datum.Vector where | ||
|
||
import qualified Data.IntMap as IM | ||
-- import qualified Data.Vector.Unboxed as VU | ||
import qualified Data.Vector as V | ||
|
||
import Control.Monad.Catch (MonadThrow(..)) | ||
import Core.Numeric.Statistics.Classification.Exceptions | ||
|
||
newtype FeatureLabels = FeatureLabels (IM.IntMap String) deriving (Eq, Show) | ||
|
||
featureLabels :: MonadThrow m => Int -> [String] -> m FeatureLabels | ||
featureLabels n ls | ||
| length ls == n = pure $ FeatureLabels $ IM.fromList $ zip [0..] ls | ||
| otherwise = throwM $ DimMismatchE "featureLabels" n (length ls) | ||
|
||
lookupFeatureLabelUnsafe :: IM.Key -> FeatureLabels -> String | ||
lookupFeatureLabelUnsafe i (FeatureLabels fl) = fl IM.! i | ||
|
||
|
||
-- | A data point i.e. a vector in R^n | ||
newtype V a = V (V.Vector a) deriving (Eq, Show, Functor, Foldable, Traversable, Applicative, Monad) | ||
|
||
fromListV :: [a] -> V a | ||
fromListV = V . V.fromList | ||
{-# inline fromListV #-} | ||
|
||
toListV :: V a -> [a] | ||
toListV (V vv) = V.toList vv | ||
|
||
zipV :: V a -> V b -> V (a, b) | ||
zipV (V v1) (V v2) = V $ V.zip v1 v2 | ||
{-# inline zipV #-} | ||
|
||
unzipV :: V (a, b) -> (V a, V b) | ||
unzipV (V vs) = (V v1, V v2) where | ||
(v1, v2) = V.unzip vs | ||
{-# inline unzipV #-} | ||
|
||
mkV :: MonadThrow m => Int -> V.Vector a -> m (V a) | ||
mkV n xs | dxs == n = pure $ V xs | ||
| otherwise = throwM $ DimMismatchE "mkV" n dxs where | ||
dxs = V.length xs | ||
|
||
indexUnsafe :: V a -> Int -> a | ||
(V vv) `indexUnsafe` j = vv V.! j | ||
{-# inline indexUnsafe #-} | ||
|
||
(!) :: MonadThrow m => V a -> Int -> m a | ||
v ! j | j >= 0 && j < d = pure $ v `indexUnsafe` j | ||
| otherwise = throwM $ IndexOobE "(!)" j 0 d where | ||
d = dim v | ||
|
||
dim :: V a -> Int | ||
dim (V vv) = V.length vv | ||
{-# inline dim #-} | ||
|
||
foldrWithKey :: (Int -> a -> b -> b) -> b -> V a -> b | ||
foldrWithKey f z vv = foldr ins z $ zipV (fromListV [0..]) vv where | ||
ins (i, x) acc = f i x acc | ||
{-# inline foldrWithKey #-} | ||
|
||
|
||
dataSplitDecision :: (a -> Bool) -> Int -> (V a -> Bool) | ||
dataSplitDecision p j dat = p (dat `indexUnsafe` j) | ||
|
||
-- allComponents :: V (a -> Bool) -> V a -> Bool | ||
-- allComponents ps dat = all (== True) $ f <$> vps where | ||
-- vps = zipV ps dat | ||
-- f (p, vi) = p vi | ||
-- -- allComponents ps dat = all (== True) $ ps <*> dat | ||
|
||
|
||
|
||
|
||
-- | Vectors with measurable entries | ||
|
||
-- data Measurable a = BoundedBoth a a deriving (Eq, Show) | ||
|
||
-- newtype Xf f a = Xf (V.Vector (f a)) | ||
|
||
-- type XMeas = Xf Measurable |
155 changes: 155 additions & 0 deletions
155
dh-core/src/Core/Numeric/Statistics/Classification/DecisionTrees.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
{-# language DeriveFunctor, DeriveFoldable, DeriveTraversable #-} | ||
module Core.Numeric.Statistics.Classification.DecisionTrees where | ||
|
||
import qualified Data.Foldable as F | ||
import qualified Data.Set as S | ||
import Data.Ord (comparing) | ||
|
||
import Core.Data.Dataset | ||
import qualified Core.Data.Datum.Vector as XV | ||
import Core.Numeric.Statistics.Classification.Utils | ||
import Core.Numeric.Statistics.InformationTheory (entropyR) | ||
|
||
-- | A binary tree. | ||
-- | ||
-- Each leaf carries data of type 'a' and we can attach metadata of type 'd' at each branching point. | ||
data Tree d a = | ||
Node d (Tree d a) (Tree d a) | ||
| Leaf a | ||
deriving (Eq, Show, Functor, Foldable, Traversable) | ||
|
||
unfoldTree :: (t -> Either a (d, t, t)) -> t -> Tree d a | ||
unfoldTree f x = | ||
either Leaf (\(d, l, r) -> Node d (unfoldTree f l) (unfoldTree f r) ) (f x) | ||
|
||
|
||
-- | Tree state : list of candidate dataset cuts (feature #, level) | ||
data TState k a = TState { | ||
tsFeatCuts :: S.Set (Int, a) | ||
, tsDataset :: Dataset k [XV.V a] } | ||
|
||
-- | Tree state + local tree depth | ||
data TSd k a = TSd { tsDepth :: !Int, tState :: TState k a } | ||
|
||
-- | Global options for growing decision trees | ||
data TOptions = TOptions { | ||
toMaxDepth :: !Int -- ^ Max tree depth | ||
, toMinLeafSize :: !Int -- ^ Minimum size of the contents of a leaf | ||
, toOrder :: Order -- ^ Less than | Equal or larger than | ||
} deriving (Eq, Show) | ||
|
||
-- | Tree node metadata | ||
-- | ||
-- For decision trees, at each node we store the decision feature and its decision threshold | ||
data TNData a = TNData { | ||
tJStar :: !Int -- ^ Decision feature index | ||
, tTStar :: a -- ^ Decision threshold | ||
} deriving (Eq) | ||
|
||
instance Show a => Show (TNData a) where | ||
show (TNData j t) = unwords ["(j =", show j, ", t =", show t, ")"] | ||
|
||
|
||
-- | Split decision: find feature (value, index) that maximizes the entropy drop (i.e the information gain, or KL divergence between the joint and factored datasets) | ||
-- | ||
-- NB generates empty leaves | ||
treeUnfoldStep :: (Ord a, Ord k) => | ||
TOptions | ||
-> TSd k a | ||
-> Either (Dataset k [XV.V a]) (TNData a, TSd k a, TSd k a) | ||
treeUnfoldStep (TOptions maxdepth minls ord) (TSd depth tst) | ||
| depth >= maxdepth || sizeDs tst <= minls = Left (tsDataset tst) | ||
| sizeDs tsl == 0 = Left (tsDataset tsr) | ||
| sizeDs tsr == 0 = Left (tsDataset tsl) | ||
| otherwise = Right (mdata, tdsl, tdsr) | ||
where | ||
sizeDs = size . tsDataset | ||
mdata = TNData jstar tstar | ||
(jstar, tstar, tsl, tsr) = maxInfoGainSplit ordf tst | ||
ordf = fromOrder ord | ||
d' = depth + 1 | ||
tdsl = TSd d' tsl | ||
tdsr = TSd d' tsr | ||
|
||
|
||
|
||
{- | Note (OPTIMIZATIONS maxInfoGainSplit) | ||
1. After splitting a dataset, remove the (threshold, feature index) pair corresponding to the succesful split | ||
2. " " " " , remove /all/ (threshold, index) pairs that are subsumed by the successful test (e.g in the test ((<), 3.2, 27) , remove all [(t, 27) | t <- [tmin ..], t < 3.2 ] ). This is only a useful optimization for /monotonic/ class boundaries. | ||
-} | ||
|
||
-- | Tabulate the information gain for a number of decision thresholds and return a decision function corresponding to the threshold that yields the maximum information gain. | ||
maxInfoGainSplit :: (Ord k, Ord a, Eq a) => | ||
(a -> a -> Bool) | ||
-> TState k a | ||
-> (Int, a, TState k a, TState k a) | ||
maxInfoGainSplit decision (TState tjs ds) = (jstar, tstar, TState tjs' dsl, TState tjs' dsr) where | ||
tjs' = S.delete (jstar, tstar) tjs -- See Note (OPTIMIZATIONS maxInfoGainSPlit) | ||
(jstar, tstar, _, dsl, dsr) = F.maximumBy (comparing third5) $ infog `map` S.toList tjs | ||
infog (j, t) = (j, t, h, dsl, dsr) where | ||
(h, dsl, dsr) = infoGainR (decision t) j ds | ||
|
||
|
||
third5 :: (a, b, c, d, e) -> c | ||
third5 (_, _, c, _, _) = c | ||
|
||
-- | Information gain due to a dataset split (regularized, H(0) := 0) | ||
infoGainR :: (Ord k, Ord h, Floating h) => | ||
(a -> Bool) | ||
-> Int | ||
-> Dataset k [XV.V a] | ||
-> (h, Dataset k [XV.V a], Dataset k [XV.V a]) | ||
infoGainR p j ds = (infoGain, dsl, dsr) where | ||
(dsl, pl, dsr, pr) = splitDatasetAtAttr p j ds | ||
(h0, hl, hr) = (entropyR ds, entropyR dsl, entropyR dsr) | ||
infoGain = h0 - (pl * hl + pr * hr) | ||
|
||
|
||
-- | helper function for 'infoGain' and 'infoGainR' | ||
splitDatasetAtAttr :: (Fractional n, Ord k) => | ||
(a -> Bool) | ||
-> Int | ||
-> Dataset k [XV.V a] | ||
-> (Dataset k [XV.V a], n, Dataset k [XV.V a], n) | ||
splitDatasetAtAttr p j ds = (dsl, pl, dsr, pr) where | ||
sz = fromIntegral . size | ||
(dsl, dsr) = partition p j ds | ||
(s0, sl, sr) = (sz ds, sz dsl, sz dsr) | ||
pl = sl / s0 | ||
pr = sr / s0 | ||
|
||
|
||
-- | Partition a Dataset in two, according to a decision predicate applied to a given feature. | ||
-- | ||
-- e.g. "is the j'th component of datum X_i larger than threshold t ?" | ||
partition :: (Foldable t, Ord k) => | ||
(a -> Bool) -- ^ Decision function (element-level) | ||
-> Int -- ^ Feature index | ||
-> Dataset k (t (XV.V a)) | ||
-> (Dataset k [XV.V a], Dataset k [XV.V a]) | ||
partition p j ds@Dataset{} = foldrWithKey insf (empty, empty) ds where | ||
insf k lrow (l, r) = (insert k lp l, insert k rp r) where | ||
(lp, rp) = partition1 (XV.dataSplitDecision p j) lrow | ||
|
||
|
||
|
||
-- | Partition a Foldable in two lists according to a predicate | ||
partition1 :: Foldable t => (a -> Bool) -> t a -> ([a], [a]) | ||
partition1 p = foldr ins ([], []) where | ||
ins x (l, r) | p x = (x : l, r) | ||
| otherwise = (l , x : r) | ||
|
||
|
||
|
||
-- | A well-defined Ordering, for strict half-plane separation | ||
data Order = LessThan | GreaterOrEqual deriving (Eq, Ord, Enum, Bounded) | ||
instance Show Order where | ||
show LessThan = "<" | ||
show GreaterOrEqual = ">=" | ||
|
||
fromOrder :: Ord a => Order -> (a -> a -> Bool) | ||
fromOrder o = case o of | ||
LessThan -> (<) | ||
_ -> (>=) |
25 changes: 25 additions & 0 deletions
25
dh-core/src/Core/Numeric/Statistics/Classification/Exceptions.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module Core.Numeric.Statistics.Classification.Exceptions where | ||
|
||
import Control.Exception | ||
import Data.Typeable | ||
|
||
|
||
-- * Exceptions | ||
|
||
data ValueException = ZeroProbabilityE String deriving (Eq, Show, Typeable) | ||
|
||
instance Exception ValueException | ||
|
||
|
||
data DataException = | ||
-- MissingFeatureE i | ||
IndexOobE String Int Int Int | ||
| DimMismatchE String Int Int | ||
deriving (Eq, Typeable) | ||
|
||
instance Show DataException where | ||
show e = case e of | ||
IndexOobE errMsg ix blo bhi -> unwords [errMsg, ": index", show ix,"out of bounds", show (blo, bhi)] | ||
DimMismatchE errMsg d1 d2 -> unwords [errMsg, ": dimension mismatch : expecting", show d1, "but got", show d2] | ||
|
||
instance Exception DataException |
Oops, something went wrong.