Skip to content

Commit

Permalink
refine check_parameter, use proper type mapping (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Jun 11, 2021
1 parent 35a640b commit e4c64c3
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions src/LIBSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,9 @@ function set_num_threads(nt::Integer)
ccall((:svm_set_num_threads, libsvm), Cvoid, (Cint,), nt)
end

function check_parameter(problem::Vector{SVMProblem}, param::Vector{SVMParameter})
@assert length(problem) == 1
@assert length(param) == 1
function check_parameter(problem::SVMProblem, param::SVMParameter)
err = ccall((:svm_check_parameter, libsvm), Cstring,
(Ptr{SVMProblem}, Ptr{SVMParameter}),
(Ref{SVMProblem}, Ref{SVMParameter}),
problem, param)
if err != C_NULL
throw(ArgumentError("Incorrect parameter: $(unsafe_string(err))"))
Expand Down Expand Up @@ -359,16 +357,15 @@ function svmtrain(
idx, reverse_labels, weights, weight_labels = indices_and_weights(y, X, weights)
end

param = Array{SVMParameter}(undef, 1)
param[1] = SVMParameter(_svmtype, _kernel, Int32(degree), Float64(gamma),
param = SVMParameter(
_svmtype, _kernel, Int32(degree), Float64(gamma),
coef0, cachesize, tolerance, cost, Int32(length(weights)),
pointer(weight_labels), pointer(weights), nu, epsilon, Int32(shrinking),
Int32(probability))

# Construct SVMProblem
(nodes, nodeptrs) = instances2nodes(X)
problem = SVMProblem[SVMProblem(Int32(size(X, 2)), pointer(idx),
pointer(nodeptrs))]
problem = SVMProblem(Int32(size(X, 2)), pointer(idx), pointer(nodeptrs))

# Validate the given parameters
check_parameter(problem, param)
Expand All @@ -382,10 +379,10 @@ function svmtrain(
(Ptr{Cvoid},), @cfunction(svmnoprint, Cvoid, (Ptr{UInt8},)))
end

mod = ccall((:svm_train, libsvm), Ptr{SVMModel}, (Ptr{SVMProblem},
Ptr{SVMParameter}), problem, param)
svm = SVM(unsafe_load(mod), y, X, wts, reverse_labels,
svmtype, kernel)
mod = ccall((:svm_train, libsvm), Ptr{SVMModel},
(Ref{SVMProblem}, Ref{SVMParameter}),
problem, param)
svm = SVM(unsafe_load(mod), y, X, wts, reverse_labels, svmtype, kernel)

ccall((:svm_free_model_content, libsvm), Cvoid, (Ptr{Cvoid},), mod)
return svm
Expand Down

0 comments on commit e4c64c3

Please sign in to comment.