Skip to content

Commit

Permalink
Test decision values.
Browse files Browse the repository at this point in the history
This adds a test for decision values and fixes passing the pointer to
the decision value array.
  • Loading branch information
barucden committed Jun 11, 2021
1 parent 467c15b commit 9ae9c4e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
6 changes: 4 additions & 2 deletions src/LIBSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,11 @@ function svmpredict(model::SVM{T}, X::AbstractMatrix{U}; nt::Integer = 0) where

for i = 1:ninstances
if model.probability
output = libsvm_predict_probability(cmod, nodeptrs[i], decvalues[:, i])
output = libsvm_predict_probability(cmod, nodeptrs[i],
Ref(decvalues, nlabels*(i-1)+1))
else
output = libsvm_predict_values(cmod, nodeptrs[i], decvalues[:, i])
output = libsvm_predict_values(cmod, nodeptrs[i],
Ref(decvalues, nlabels*(i-1)+1))
end
if model.SVMtype == EpsilonSVR || model.SVMtype == NuSVR
pred[i] = output
Expand Down
4 changes: 2 additions & 2 deletions src/libcalls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ function libsvm_free_model(model::Ptr{SVMModel})
end

function libsvm_predict_probability(model::SVMModel, nodes::Ptr{SVMNode},
decisions::Vector{Float64})
decisions::Ref{Float64})
return ccall((:svm_predict_probability, libsvm), Cdouble,
(Ref{SVMModel}, Ptr{SVMNode}, Ref{Float64}),
model, node, decisions)
end

function libsvm_predict_values(model::SVMModel, nodes::Ptr{SVMNode},
decisions::Vector{Float64})
decisions::Ref{Float64})
return ccall((:svm_predict_values, libsvm), Float64,
(Ref{SVMModel}, Ptr{SVMNode}, Ref{Float64}),
model, nodes, decisions)
Expand Down
15 changes: 13 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ end
end
end


@testset "AbstractVector as labels" begin
@info "test AbstractVector labels"

Expand All @@ -68,7 +67,6 @@ end
@test== predict(model, Xtest')
end


@testset "JLD2 save/load" begin
@info "JLD2 save/load"

Expand Down Expand Up @@ -156,5 +154,18 @@ end
@test_throws ArgumentError svmtrain(rand(2, 5), ones(5); bad_params...)
end

@testset "Decision values" begin
X = [-2 -1 -1 1 1 2;
-1 -1 -2 1 2 1]
y = [1, 1, 1, 2, 2, 2]
d = [1.5 1.0 1.5 -1.0 -1.5 -1.5;
0.0 0.0 0.0 0.0 0.0 0.0]

model = svmtrain(X, y, kernel=Kernel.Linear)
ỹ, d̃ = svmpredict(model, X)

@test== y
@test d
end

end # @testset "LIBSVM"

0 comments on commit 9ae9c4e

Please sign in to comment.