Skip to content

Commit

Permalink
Check dimension of gram matrix (#87)
Browse files Browse the repository at this point in the history
* Implement dimension check for prediction gram matrix

* Add tests for gram matrix dimension check

* Disallow providing only support vector related entries

* Adjust error message

* Minor style changes

Resolves #85
  • Loading branch information
till-m committed Nov 11, 2021
1 parent b9f7a17 commit 2964e6b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/LIBSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,13 @@ function SVM(smc::SVMModel, y, X, weights, labels, svmtype, kernel)
unsafe_copyto!(pointer(libsvmweight_label), smc.param.weight_label, nw)
end

SVM(svmtype, kernel, weights, size(X,1),
if kernel == Kernel.Precomputed
nfeatures = size(X, 2)
else
nfeatures = size(X, 1)
end

SVM(svmtype, kernel, weights, nfeatures,
smc.nr_class, labels, libsvmlabel, libsvmweight, libsvmweight_label,
svs, smc.param.coef0, coefs, probA, probB,
rho, smc.param.degree,
Expand Down Expand Up @@ -373,7 +379,9 @@ function svmpredict(model::SVM{T}, X::AbstractMatrix{U}; nt::Integer = 0) where
set_num_threads(nt)

if model.kernel != Kernel.Precomputed && size(X, 1) != model.nfeatures
throw(DimensionMismatch("Model has $(model.nfeatures) but $(size(X, 1)) provided"))
throw(DimensionMismatch("Model has $(model.nfeatures) features but $(size(X, 1)) provided"))
elseif model.kernel == Kernel.Precomputed && size(X, 1) != model.nfeatures
throw(DimensionMismatch("Gram matrix should have $(model.nfeatures) but $(size(X, 1)) provided"))
end

ninstances = size(X, 2)
Expand Down
19 changes: 19 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,25 @@ end
@test model.coefs model₂.coefs
@test model.SVs.indices model₂.SVs.indices
end

@testset "Malformed prediction" begin
X = [-2 -1 -1 1 1 2;
-1 -1 -2 1 2 1]
y = [1, 1, 1, 2, 2, 2]

T = [-1 2 3;
-1 2 2]

K = X' * X

model = svmtrain(K, y, kernel=Kernel.Precomputed)

KK = X' * T

KK_malformed = KK[1:1,:]

@test_throws DimensionMismatch svmpredict(model, KK_malformed)
end
end

end # @testset "LIBSVM"

0 comments on commit 2964e6b

Please sign in to comment.