Skip to content

Commit

Permalink
Fixed the crossvalidation bug and bumped version
Browse files Browse the repository at this point in the history
  • Loading branch information
aleator committed Nov 29, 2011
1 parent 73b8a76 commit d082cf6
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
31 changes: 24 additions & 7 deletions AI/SVM/Simple.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE ScopedTypeVariables, TupleSections, ViewPatterns,
RecordWildCards, FlexibleInstances #-}
RecordWildCards, FlexibleInstances, ForeignFunctionInterface #-}
-------------------------------------------------------------------------------
-- |
-- Module : Bindings.SVM
Expand Down Expand Up @@ -60,18 +60,21 @@ import System.Directory
import System.IO.Unsafe
import qualified Data.ByteString.Lazy as B
import qualified Data.Map as Map
import Foreign.C.Types (CInt)
import AI.SVM.Common


-- | Supported SVM classifiers
data ClassifierType =
C {cost :: Double}
| NU {cost :: Double, nu :: Double}
deriving (Show)

-- | Supported SVM regression machines
data RegressorType =
Epsilon Double Double
| NU_r Double Double
deriving (Show)

data SVMClassifier a = SVMClassifier SVM (Map a Double) (Map Double a)
newtype SVMRegressor = SVMRegressor SVM
Expand Down Expand Up @@ -123,9 +126,9 @@ instance Persisting SVMOneClass where
-- | Train an SVM classifier of given type.
trainClassifier
:: (SVMVector b, Ord a) =>
ClassifierType
-> Kernel
-> [(a, b)]
ClassifierType -- ^ The type of the classifier
-> Kernel -- ^ Kernel
-> [(a, b)] -- ^ Training data
-> (String, SVMClassifier a)
trainClassifier ctype kernel dataset = unsafePerformIO $ do
let (to,from, doubleDataSet) = convertToDouble dataset
Expand All @@ -140,10 +143,15 @@ convertToDouble dataset = let

-- | Perform n-foldl cross validation for given set of SVM parameters
crossvalidateClassifier :: (SVMVector b, Ord a) =>
ClassifierType -> Kernel -> Int -> [(a, b)]
ClassifierType -- ^ The type of classifier
-> Kernel -- ^ Classifier kernel
-> Int -- ^ Number of folds to use
-> [(a, b)] -- ^ Dataset
-> Int -- ^ Seed value. The crossvalidation randomly partitions the data into subsets using this seed value
-> (String, [a])
crossvalidateClassifier ctype kernel folds dataset = unsafePerformIO $ do
crossvalidateClassifier ctype kernel folds dataset seed = unsafePerformIO $ do
let (to,from, doubleDataSet) = convertToDouble dataset
c_srand (fromIntegral seed)
(m,res :: [Double]) <- crossvalidate (generalizeClassifier ctype) kernel folds doubleDataSet
return . (m,) $ map (from Map.!) res
where
Expand Down Expand Up @@ -181,13 +189,22 @@ trainRegressor rtype kernel dataset = unsafePerformIO $ do
(m,svm) <- trainSVM (generalizeRegressor rtype) kernel doubleDataSet
return . (m,) $ SVMRegressor svm

crossvalidateRegressor rtype kernel folds dataset = unsafePerformIO $ do
crossvalidateRegressor :: (SVMVector b) =>
RegressorType -- ^ The type of the regressor
-> Kernel -- ^ Kernel
-> Int -- ^ Number of folds to use
-> [(Double, b)] -- ^ Dataset
-> Int -- ^ Seed value. The crossvalidation randomly partitions the data into subsets using this seed value
-> (String, [Double])
crossvalidateRegressor rtype kernel folds dataset seed = unsafePerformIO $ do
let doubleDataSet = map (second convert) dataset
c_srand (fromIntegral seed)
(m,res) <- crossvalidate (generalizeRegressor rtype) kernel folds doubleDataSet
return (m,res)

-- | Predict value for given vector via regression
predictRegression :: SVMVector a => SVMRegressor -> a -> Double
predictRegression (SVMRegressor svm) (convert -> v) = predict svm v

foreign import ccall "srand" c_srand :: CInt -> IO ()

6 changes: 5 additions & 1 deletion svm-simple.cabal
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
name: svm-simple
version: 0.2.5
version: 0.2.6
synopsis: Medium level, simplified, bindings to libsvm
description:
This is a set of simplified bindings to libsvm <http://www.csie.ntu.edu.tw/~cjlin/libsvm/> suite
of support vector machines. This package provides tools for classification, one-class classification
and support vector regression.
.
.
Changes in version 0.2.6
.
* Fixed a critical bug with training and crossvalidation
.
Changes in version 0.2.5
.
* Crossvalidation for the high level interface
Expand Down

0 comments on commit d082cf6

Please sign in to comment.