Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Added (test version of) trainWtdClassifier

  • Loading branch information...
commit 5cdf18755d01e00f59bf490048177c30ae9affe3 1 parent 921122e
@aleator authored
Showing with 20 additions and 2 deletions.
  1. +17 −2 AI/SVM/Simple.hs
  2. +3 −0  Examples/Test.hs
View
19 AI/SVM/Simple.hs
@@ -39,7 +39,7 @@ module AI.SVM.Simple (
,Kernel(..)
,SVMOneClass(), SVMClassifier(), SVMRegressor()
-- * Classifier machines
- ,trainClassifier, crossvalidateClassifier, classify
+ ,trainClassifier, trainWtdClassifier, crossvalidateClassifier, classify
-- * One class machines
,trainOneClass, inSet, OneClassResult(..)
-- * Regression machines
@@ -50,7 +50,7 @@ module AI.SVM.Simple (
import AI.SVM.Base
import Control.Applicative
-import Control.Arrow (second, (***), (&&&))
+import Control.Arrow (first, second, (***), (&&&))
import Control.Monad
import Data.Binary
import Data.List
@@ -135,6 +135,21 @@ trainClassifier ctype kernel dataset = unsafePerformIO $ do
(m,svm) <- trainSVM (generalizeClassifier ctype) kernel [] doubleDataSet
return . (m,) $ SVMClassifier svm to from
+-- | Train an SVM classifier of given type.
+trainWtdClassifier
+ :: (SVMVector b, Ord a) =>
+ ClassifierType -- ^ The type of the classifier
+ -> Kernel -- ^ Kernel
+ -> [(a, Double)] -- ^ Training weights
+ -> [(a, b)] -- ^ Training data
+ -> (String, SVMClassifier a)
+trainWtdClassifier ctype kernel ws dataset = unsafePerformIO $ do
+ let (to,from, doubleDataSet) = convertToDouble dataset
+ cw = map (first conv) ws
+ conv i = round $ to Map.! i
+ (m,svm) <- trainSVM (generalizeClassifier ctype) kernel [] doubleDataSet
+ return . (m,) $ SVMClassifier svm to from
+
convertToDouble dataset = let
l = zip (nub . map fst $ dataset) [1..]
to = Map.fromList l
View
3  Examples/Test.hs
@@ -4,8 +4,11 @@ main = do
let trnC = [(x==y,[x,y]) | x <- [0,1], y <- [0,1::Double]]
trnR = [(x-y,[x,y]) | x <- [0,0.1..1], y <- [0,0.1..1::Double]]
(msgs1,cf) = trainClassifier (C 0.5) (RBF 0.3) trnC
+ (msgs3,cf2) = trainWtdClassifier (C 0.5) (RBF 0.3) [(True,10),(False,1)] trnC
(msgs2,re) = trainRegressor (Epsilon 0.1 0.1) (RBF 0.3) trnR
print msgs1
print msgs2
+ print msgs3
print (map (classify cf) $ map snd trnC)
+ print (map (classify cf2) $ map snd trnC)
print (map (predictRegression re) $ [[0,1],[0.5,0.2],[1,2::Double]])
Please sign in to comment.
Something went wrong with that request. Please try again.