diff --git a/AI/SVM/Simple.hs b/AI/SVM/Simple.hs index 8c32cb5..29e4bd4 100644 --- a/AI/SVM/Simple.hs +++ b/AI/SVM/Simple.hs @@ -1,5 +1,5 @@ -{-# LANGUAGE ForeignFunctionInterface, BangPatterns, ScopedTypeVariables, TupleSections, - RecordWildCards #-} +{-# LANGUAGE ForeignFunctionInterface, BangPatterns, ScopedTypeVariables, + TupleSections, ViewPatterns, RecordWildCards #-} ------------------------------------------------------------------------------- -- | -- Module : Bindings.SVM @@ -29,6 +29,8 @@ import Foreign.Ptr import Foreign.ForeignPtr import qualified Foreign.Concurrent as C import Foreign.Marshal.Utils +import Foreign.Marshal.Array +import Foreign.Marshal.Alloc import Control.Applicative import System.IO.Unsafe import Foreign.Storable @@ -45,28 +47,32 @@ convertDense v = V.generate (dim+1) readVal readVal !n | n >= dim = C'svm_node (-1) 0 readVal !n = C'svm_node (fromIntegral n+1) (realToFrac $ v ! n) - -withProblem :: [(Double, V.Vector Double)] -> (Ptr C'svm_problem -> IO b) -> IO b -withProblem v op = -- Well. This turned out super ugly. Also, this is a veritable - -- bug magnet. - V.unsafeWith xs $ \ptr_xs -> - V.unsafeWith y $ \ptr_y -> - let optrs = offsetPtrs ptr_xs - in V.unsafeWith optrs $ \ptr_offsets -> - with (C'svm_problem (fromIntegral dim) ptr_y ptr_offsets) op +createProblem v = do -- #TODO Check the problem dimension. Libsvm doesn't + node_array <- newArray xs + class_array <- newArray y + offset_array <- newArray $ offsetPtrs node_array + return (C'svm_problem (fromIntegral dim) + class_array + offset_array + ,node_array) where dim = length v lengths = map ((+1) . V.length . snd) v - offsetPtrs addr = V.fromList . take dim $ - [addr `plusPtr` (idx * sizeOf (xs ! 0)) + offsetPtrs addr = take dim + [addr `plusPtr` (idx * sizeOf (head xs)) -- #TODO: Safer alternative to head | idx <- scanl (+) 0 lengths] - y = V.fromList . map (realToFrac . fst) $ v - xs = V.concat . map (extractSvmNode.snd) $ v + y = map (realToFrac . fst) v + xs = concatMap (V.toList . extractSvmNode . snd) $ v extractSvmNode x = convertDense $ V.generate (V.length x) (x !) +deleteProblem (C'svm_problem l class_array offset_array , node_array) = + free class_array >> free offset_array >> free node_array + + -- | A Support Vector Machine -newtype SVM = SVM (ForeignPtr C'svm_model) +newtype SVM = SVM (ForeignPtr C'svm_model) +getModelPtr (SVM fp) = fp modelFinalizer :: Ptr C'svm_model -> IO () modelFinalizer modelPtr = with modelPtr c'svm_free_and_destroy_model @@ -77,7 +83,7 @@ modelFinalizer modelPtr = with modelPtr c'svm_free_and_destroy_model loadSVM :: FilePath -> IO SVM loadSVM fp = do e <- doesFileExist fp - when (not e) $ error "Model does not exist" + unless e $ error "Model does not exist" -- Not finding the file causes a bus error. Could do without that.. -- #TODO: Make a smarter error ptr <- withCString fp c'svm_load_model @@ -86,16 +92,17 @@ loadSVM fp = do -- | Save an svm to a file. saveSVM :: FilePath -> SVM -> IO () -saveSVM fp (SVM fptr) = +saveSVM fp (getModelPtr -> fptr) = withForeignPtr fptr $ \model_ptr -> withCString fp $ \cstr -> c'svm_save_model cstr model_ptr -getNRClasses (SVM fptr) = fromIntegral <$> withForeignPtr fptr c'svm_get_nr_class +getNRClasses (getModelPtr -> fptr) + = fromIntegral <$> withForeignPtr fptr c'svm_get_nr_class -- | Predict the class of a vector with an SVM. predict :: SVM -> V.Vector Double -> Double -predict (SVM fptr) vec = unsafePerformIO $ +predict (getModelPtr -> fptr) vec = unsafePerformIO $ withForeignPtr fptr $ \modelPtr -> let nodes = convertDense vec in realToFrac <$> V.unsafeWith nodes @@ -173,7 +180,7 @@ setTypeParameters (NU_SVR {..}) p = p{c'svm_parameter'C=rf cost ,c'svm_parameter'svm_type=c'NU_SVR} -withParameters svm kernel op = with parameters op +setParameters svm kernel = parameters where parameters = setTypeParameters svm . setKernelParameters kernel @@ -197,10 +204,14 @@ trainSVM svm kernel dataSet = do -- should be an ioref that captures the output which would then -- be returned from this function. c'svm_set_print_string_function pf - modelPtr <- withProblem dataSet $ \ptr_problem -> - withParameters svm kernel $ \ptr_parameters -> - c'svm_train ptr_problem ptr_parameters - SVM <$> C.newForeignPtr modelPtr (modelFinalizer modelPtr) + (problem, ptr_nodes) <- createProblem dataSet + ptr_parameters <- malloc + poke ptr_parameters (setParameters svm kernel) + modelPtr <- with problem $ \ptr_problem -> + c'svm_train ptr_problem ptr_parameters + SVM <$> C.newForeignPtr modelPtr + (free ptr_parameters>>deleteProblem (problem, ptr_nodes) + >>modelFinalizer modelPtr) diff --git a/Examples/SmokeTest.hs b/Examples/SmokeTest.hs index c4bf99c..28d5309 100644 --- a/Examples/SmokeTest.hs +++ b/Examples/SmokeTest.hs @@ -27,7 +27,9 @@ main = do ,(1, V.fromList [1,1]) ,(1, V.fromList [0,0]) ] - svm2 <- trainSVM (C_SVC 1) Linear trainingData + svm_ <- trainSVM (C_SVC 1) (RBF 1) trainingData + saveSVM "model2" svm_ + svm2 <- loadSVM "model2" print $ predict svm2 $ V.fromList [0,1] print $ predict svm2 $ V.fromList [1,0] print $ predict svm2 $ V.fromList [0.5,0.5]