Skip to content

Commit

Permalink
Fixed a possible null-ptr
Browse files Browse the repository at this point in the history
Since libsvm re-uses stuff from the problem 
instance, we cannot free the problem before
the model. This is hopefully accomplished with
Foreign.Concurrent.newForeignPtr and explicit
free's on the stored objects.

Atleast valgrind is happy and no segfaults appear.
  • Loading branch information
aleator committed Jun 30, 2011
1 parent 66e8d88 commit 53fb362
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
61 changes: 36 additions & 25 deletions AI/SVM/Simple.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE ForeignFunctionInterface, BangPatterns, ScopedTypeVariables, TupleSections,
RecordWildCards #-}
{-# LANGUAGE ForeignFunctionInterface, BangPatterns, ScopedTypeVariables,
TupleSections, ViewPatterns, RecordWildCards #-}
-------------------------------------------------------------------------------
-- |
-- Module : Bindings.SVM
Expand Down Expand Up @@ -29,6 +29,8 @@ import Foreign.Ptr
import Foreign.ForeignPtr
import qualified Foreign.Concurrent as C
import Foreign.Marshal.Utils
import Foreign.Marshal.Array
import Foreign.Marshal.Alloc
import Control.Applicative
import System.IO.Unsafe
import Foreign.Storable
Expand All @@ -45,28 +47,32 @@ convertDense v = V.generate (dim+1) readVal
readVal !n | n >= dim = C'svm_node (-1) 0
readVal !n = C'svm_node (fromIntegral n+1) (realToFrac $ v ! n)


withProblem :: [(Double, V.Vector Double)] -> (Ptr C'svm_problem -> IO b) -> IO b
withProblem v op = -- Well. This turned out super ugly. Also, this is a veritable
-- bug magnet.
V.unsafeWith xs $ \ptr_xs ->
V.unsafeWith y $ \ptr_y ->
let optrs = offsetPtrs ptr_xs
in V.unsafeWith optrs $ \ptr_offsets ->
with (C'svm_problem (fromIntegral dim) ptr_y ptr_offsets) op
createProblem v = do -- #TODO Check the problem dimension. Libsvm doesn't
node_array <- newArray xs
class_array <- newArray y
offset_array <- newArray $ offsetPtrs node_array
return (C'svm_problem (fromIntegral dim)
class_array
offset_array
,node_array)
where
dim = length v
lengths = map ((+1) . V.length . snd) v
offsetPtrs addr = V.fromList . take dim $
[addr `plusPtr` (idx * sizeOf (xs ! 0))
offsetPtrs addr = take dim
[addr `plusPtr` (idx * sizeOf (head xs)) -- #TODO: Safer alternative to head
| idx <- scanl (+) 0 lengths]
y = V.fromList . map (realToFrac . fst) $ v
xs = V.concat . map (extractSvmNode.snd) $ v
y = map (realToFrac . fst) v
xs = concatMap (V.toList . extractSvmNode . snd) $ v
extractSvmNode x = convertDense $ V.generate (V.length x) (x !)

deleteProblem (C'svm_problem l class_array offset_array , node_array) =
free class_array >> free offset_array >> free node_array


-- | A Support Vector Machine
newtype SVM = SVM (ForeignPtr C'svm_model)
newtype SVM = SVM (ForeignPtr C'svm_model)

getModelPtr (SVM fp) = fp

modelFinalizer :: Ptr C'svm_model -> IO ()
modelFinalizer modelPtr = with modelPtr c'svm_free_and_destroy_model
Expand All @@ -77,7 +83,7 @@ modelFinalizer modelPtr = with modelPtr c'svm_free_and_destroy_model
loadSVM :: FilePath -> IO SVM
loadSVM fp = do
e <- doesFileExist fp
when (not e) $ error "Model does not exist"
unless e $ error "Model does not exist"
-- Not finding the file causes a bus error. Could do without that..
-- #TODO: Make a smarter error
ptr <- withCString fp c'svm_load_model
Expand All @@ -86,16 +92,17 @@ loadSVM fp = do

-- | Save an svm to a file.
saveSVM :: FilePath -> SVM -> IO ()
saveSVM fp (SVM fptr) =
saveSVM fp (getModelPtr -> fptr) =
withForeignPtr fptr $ \model_ptr ->
withCString fp $ \cstr ->
c'svm_save_model cstr model_ptr

getNRClasses (SVM fptr) = fromIntegral <$> withForeignPtr fptr c'svm_get_nr_class
getNRClasses (getModelPtr -> fptr)
= fromIntegral <$> withForeignPtr fptr c'svm_get_nr_class

-- | Predict the class of a vector with an SVM.
predict :: SVM -> V.Vector Double -> Double
predict (SVM fptr) vec = unsafePerformIO $
predict (getModelPtr -> fptr) vec = unsafePerformIO $
withForeignPtr fptr $ \modelPtr ->
let nodes = convertDense vec
in realToFrac <$> V.unsafeWith nodes
Expand Down Expand Up @@ -173,7 +180,7 @@ setTypeParameters (NU_SVR {..}) p = p{c'svm_parameter'C=rf cost
,c'svm_parameter'svm_type=c'NU_SVR}


withParameters svm kernel op = with parameters op
setParameters svm kernel = parameters
where
parameters = setTypeParameters svm
. setKernelParameters kernel
Expand All @@ -197,10 +204,14 @@ trainSVM svm kernel dataSet = do
-- should be an ioref that captures the output which would then
-- be returned from this function.
c'svm_set_print_string_function pf
modelPtr <- withProblem dataSet $ \ptr_problem ->
withParameters svm kernel $ \ptr_parameters ->
c'svm_train ptr_problem ptr_parameters
SVM <$> C.newForeignPtr modelPtr (modelFinalizer modelPtr)
(problem, ptr_nodes) <- createProblem dataSet
ptr_parameters <- malloc
poke ptr_parameters (setParameters svm kernel)
modelPtr <- with problem $ \ptr_problem ->
c'svm_train ptr_problem ptr_parameters
SVM <$> C.newForeignPtr modelPtr
(free ptr_parameters>>deleteProblem (problem, ptr_nodes)
>>modelFinalizer modelPtr)



Expand Down
4 changes: 3 additions & 1 deletion Examples/SmokeTest.hs
Expand Up @@ -27,7 +27,9 @@ main = do
,(1, V.fromList [1,1])
,(1, V.fromList [0,0])
]
svm2 <- trainSVM (C_SVC 1) Linear trainingData
svm_ <- trainSVM (C_SVC 1) (RBF 1) trainingData
saveSVM "model2" svm_
svm2 <- loadSVM "model2"
print $ predict svm2 $ V.fromList [0,1]
print $ predict svm2 $ V.fromList [1,0]
print $ predict svm2 $ V.fromList [0.5,0.5]
Expand Down

0 comments on commit 53fb362

Please sign in to comment.