diff --git a/src/LIBSVM.jl b/src/LIBSVM.jl index c11b017..f053cfd 100644 --- a/src/LIBSVM.jl +++ b/src/LIBSVM.jl @@ -262,11 +262,9 @@ function set_num_threads(nt::Integer) ccall((:svm_set_num_threads, libsvm), Cvoid, (Cint,), nt) end -function check_parameter(problem::Vector{SVMProblem}, param::Vector{SVMParameter}) - @assert length(problem) == 1 - @assert length(param) == 1 +function check_parameter(problem::SVMProblem, param::SVMParameter) err = ccall((:svm_check_parameter, libsvm), Cstring, - (Ptr{SVMProblem}, Ptr{SVMParameter}), + (Ref{SVMProblem}, Ref{SVMParameter}), problem, param) if err != C_NULL throw(ArgumentError("Incorrect parameter: $(unsafe_string(err))")) @@ -359,16 +357,15 @@ function svmtrain( idx, reverse_labels, weights, weight_labels = indices_and_weights(y, X, weights) end - param = Array{SVMParameter}(undef, 1) - param[1] = SVMParameter(_svmtype, _kernel, Int32(degree), Float64(gamma), + param = SVMParameter( + _svmtype, _kernel, Int32(degree), Float64(gamma), coef0, cachesize, tolerance, cost, Int32(length(weights)), pointer(weight_labels), pointer(weights), nu, epsilon, Int32(shrinking), Int32(probability)) # Construct SVMProblem (nodes, nodeptrs) = instances2nodes(X) - problem = SVMProblem[SVMProblem(Int32(size(X, 2)), pointer(idx), - pointer(nodeptrs))] + problem = SVMProblem(Int32(size(X, 2)), pointer(idx), pointer(nodeptrs)) # Validate the given parameters check_parameter(problem, param) @@ -382,10 +379,10 @@ function svmtrain( (Ptr{Cvoid},), @cfunction(svmnoprint, Cvoid, (Ptr{UInt8},))) end - mod = ccall((:svm_train, libsvm), Ptr{SVMModel}, (Ptr{SVMProblem}, - Ptr{SVMParameter}), problem, param) - svm = SVM(unsafe_load(mod), y, X, wts, reverse_labels, - svmtype, kernel) + mod = ccall((:svm_train, libsvm), Ptr{SVMModel}, + (Ref{SVMProblem}, Ref{SVMParameter}), + problem, param) + svm = SVM(unsafe_load(mod), y, X, wts, reverse_labels, svmtype, kernel) ccall((:svm_free_model_content, libsvm), Cvoid, (Ptr{Cvoid},), mod) return svm