-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #190 from nno/_enh/new_svm
NF: support for fitcsvm
- Loading branch information
Showing
6 changed files
with
246 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
function predicted=cosmo_classify_matlabcsvm(samples_train, targets_train, samples_test, opt) | ||
% svm classifier wrapper (around fitcsvm) | ||
% | ||
% predicted=cosmo_classify_matlabcsvm(samples_train, targets_train, samples_test, opt) | ||
% | ||
% Inputs: | ||
% samples_train PxR training data for P samples and R features | ||
% targets_train Px1 training data classes | ||
% samples_test QxR test data | ||
% opt struct with options. supports any option that | ||
% fitcsvm supports | ||
% | ||
% Output: | ||
% predicted Qx1 predicted data classes for samples_test | ||
% | ||
% Notes: | ||
% - this function uses Matlab's builtin fitcsvm function, which was the | ||
% successor of svmtrain. | ||
% - Matlab's SVM classifier is rather slow, especially for multi-class | ||
% data (more than two classes). When classification takes a long time, | ||
% consider using libsvm. | ||
% - for a guide on svm classification, see | ||
% http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf | ||
% Note that cosmo_crossvalidate and cosmo_crossvalidation_measure | ||
% provide an option 'normalization' to perform data scaling | ||
% - As of Matlab 2017a (maybe earlier), Matlab gives the warning that | ||
% 'svmtrain will be removed in a future release. Use fitcsvm instead.' | ||
% however fitcsvm gives different results than svmtrain; as a result | ||
% cosmo_classify_matlabcsvm gives different results than | ||
% cosmo_classify_matlabsvm. | ||
% | ||
% See also fitcsvm, svmclassify, cosmo_classify_matlabsvm. | ||
% | ||
% # For CoSMoMVPA's copyright information and license terms, # | ||
% # see the COPYING file distributed with CoSMoMVPA. # | ||
|
||
if nargin<4, opt=struct(); end | ||
|
||
[ntrain, nfeatures]=size(samples_train); | ||
[ntest, nfeatures_]=size(samples_test); | ||
ntrain_=numel(targets_train); | ||
|
||
if nfeatures~=nfeatures_ || ntrain_~=ntrain | ||
error('illegal input size'); | ||
end | ||
|
||
if ~cached_has_matlabcsvm() | ||
cosmo_check_external('matlabcsvm'); | ||
end | ||
|
||
[class_idxs,classes]=cosmo_index_unique(targets_train(:)); | ||
nclasses=numel(classes); | ||
|
||
if nfeatures==0 || nclasses==1 | ||
% matlab's svm cannot deal with empty data, so predict all | ||
% test samples as the class of the first sample | ||
predicted=targets_train(1) * ones(ntest,1); | ||
return | ||
end | ||
|
||
|
||
opt_cell=opt2cell(opt); | ||
|
||
% number of pair-wise comparisons | ||
ncombi=nclasses*(nclasses-1)/2; | ||
|
||
% allocate space for all predictions | ||
all_predicted=NaN(ntest, ncombi); | ||
|
||
% Consider all pairwise comparisons (over classes) | ||
% and store the predictions in all_predicted | ||
pos=0; | ||
for k=1:(nclasses-1) | ||
for j=(k+1):nclasses | ||
pos=pos+1; | ||
% classify between 2 classes only | ||
idxs=cat(1,class_idxs{k},class_idxs{j}); | ||
|
||
|
||
model=fitcsvm(samples_train(idxs,:), targets_train(idxs), ... | ||
opt_cell{:}); | ||
|
||
pred=predict(model, samples_test(idxs, :)); | ||
all_predicted(idxs,pos)=pred; | ||
end | ||
end | ||
|
||
assert(pos==ncombi); | ||
|
||
% find the classes that were predicted most often. | ||
% ties are handled by cosmo_winner_indices | ||
[winners, test_classes]=cosmo_winner_indices(all_predicted); | ||
|
||
predicted=test_classes(winners); | ||
|
||
|
||
|
||
% helper function to convert cell to struct | ||
function opt_cell=opt2cell(opt) | ||
|
||
if isempty(opt) | ||
opt_cell=cell(0); | ||
return; | ||
end | ||
|
||
fns=fieldnames(opt); | ||
|
||
n=numel(fns); | ||
opt_cell=cell(1,2*n); | ||
for k=1:n | ||
fn=fns{keep_id(k)}; | ||
opt_cell{k*2-1}=fn; | ||
opt_cell{k*2}=opt.(fn); | ||
end | ||
|
||
|
||
function tf=cached_has_matlabcsvm() | ||
persistent cached_tf; | ||
|
||
if isequal(cached_tf,true) | ||
tf=true; | ||
return | ||
end | ||
|
||
cached_tf=cosmo_check_external('matlabcsvm'); | ||
tf=cached_tf; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters