Skip to content

Commit

Permalink
Merge pull request #190 from nno/_enh/new_svm
Browse files Browse the repository at this point in the history
NF: support for fitcsvm
  • Loading branch information
nno committed Jul 25, 2020
2 parents c063cde + 21ade4a commit a431780
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 13 deletions.
33 changes: 31 additions & 2 deletions mvpa/cosmo_check_external.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
% 'xunit' xUnit unit test framework
% 'moxunit' MOxUnit unit test framework
% 'matlabsvm' SVM classifier in matlab stats
% toolbox
% toolbox (prior 2018a)
% 'matlabcsvm'
% 'svm' Either matlabsvm or libsvm
% '@{name}' Matlab toolbox {name}
% It can also be '-list', '-tic', '-toc',' or
Expand Down Expand Up @@ -477,7 +478,9 @@
externals.matlabsvm.is_present=@() (has_toolbox('stats') || ...
has_toolbox('bioinfo')) && ...
has('svmtrain') && ...
has('svmclassify');
has('svmclassify') && ...
is_matlab_prior_2018a();

externals.matlabsvm.is_recent=yes;
externals.matlabsvm.conflicts.neuroelf=@() isequal(...
path_of('svmtrain'),...
Expand All @@ -488,6 +491,14 @@
externals.matlabsvm.label='Matlab stats or bioinfo toolbox';
externals.matlabsvm.url='http://www.mathworks.com';

externals.matlabcsvm.is_present=@() cosmo_wtf('is_matlab') && ...
has('fitcsvm');

externals.matlabcsvm.is_recent=yes;
externals.matlabcsvm.label='Matlab stats or bioinfo toolbox';
externals.matlabcsvm.url='http://www.mathworks.com';


externals.svm={'libsvm', 'matlabsvm'}; % need either

externals.distatis.is_present=yes;
Expand Down Expand Up @@ -559,6 +570,24 @@
externals.modox.authors={'N. N. Oosterhof'};
externals.modox.url='https://github.com/MOdox/MOdox';


function tf=is_matlab_prior_2018a()
this_version=cosmo_wtf('version_number');

matlab_pivot=[9, 4]; % verison 2018a
n_elem = numel(matlab_pivot);

delta = this_version(1:n_elem) - matlab_pivot;
if all(delta==0)
% this is version 2018a
tf=false;
else
idx=find(delta~=0,1);
% positive delta means that the current matlab version
% is later than 2018a
tf=delta(idx)<0;
end

function tf=has_octave_package(label)
tf=false;
if ~cosmo_wtf('is_octave')
Expand Down
127 changes: 127 additions & 0 deletions mvpa/cosmo_classify_matlabcsvm.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
function predicted=cosmo_classify_matlabcsvm(samples_train, targets_train, samples_test, opt)
% svm classifier wrapper (around fitcsvm)
%
% predicted=cosmo_classify_matlabcsvm(samples_train, targets_train, samples_test, opt)
%
% Inputs:
% samples_train PxR training data for P samples and R features
% targets_train Px1 training data classes
% samples_test QxR test data
% opt struct with options. supports any option that
% fitcsvm supports
%
% Output:
% predicted Qx1 predicted data classes for samples_test
%
% Notes:
% - this function uses Matlab's builtin fitcsvm function, which was the
% successor of svmtrain.
% - Matlab's SVM classifier is rather slow, especially for multi-class
% data (more than two classes). When classification takes a long time,
% consider using libsvm.
% - for a guide on svm classification, see
% http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
% Note that cosmo_crossvalidate and cosmo_crossvalidation_measure
% provide an option 'normalization' to perform data scaling
% - As of Matlab 2017a (maybe earlier), Matlab gives the warning that
% 'svmtrain will be removed in a future release. Use fitcsvm instead.'
% however fitcsvm gives different results than svmtrain; as a result
% cosmo_classify_matlabcsvm gives different results than
% cosmo_classify_matlabsvm.
%
% See also fitcsvm, svmclassify, cosmo_classify_matlabsvm.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #

if nargin<4, opt=struct(); end

[ntrain, nfeatures]=size(samples_train);
[ntest, nfeatures_]=size(samples_test);
ntrain_=numel(targets_train);

if nfeatures~=nfeatures_ || ntrain_~=ntrain
error('illegal input size');
end

if ~cached_has_matlabcsvm()
cosmo_check_external('matlabcsvm');
end

[class_idxs,classes]=cosmo_index_unique(targets_train(:));
nclasses=numel(classes);

if nfeatures==0 || nclasses==1
% matlab's svm cannot deal with empty data, so predict all
% test samples as the class of the first sample
predicted=targets_train(1) * ones(ntest,1);
return
end


opt_cell=opt2cell(opt);

% number of pair-wise comparisons
ncombi=nclasses*(nclasses-1)/2;

% allocate space for all predictions
all_predicted=NaN(ntest, ncombi);

% Consider all pairwise comparisons (over classes)
% and store the predictions in all_predicted
pos=0;
for k=1:(nclasses-1)
for j=(k+1):nclasses
pos=pos+1;
% classify between 2 classes only
idxs=cat(1,class_idxs{k},class_idxs{j});


model=fitcsvm(samples_train(idxs,:), targets_train(idxs), ...
opt_cell{:});

pred=predict(model, samples_test(idxs, :));
all_predicted(idxs,pos)=pred;
end
end

assert(pos==ncombi);

% find the classes that were predicted most often.
% ties are handled by cosmo_winner_indices
[winners, test_classes]=cosmo_winner_indices(all_predicted);

predicted=test_classes(winners);



% helper function to convert cell to struct
function opt_cell=opt2cell(opt)

if isempty(opt)
opt_cell=cell(0);
return;
end

fns=fieldnames(opt);

n=numel(fns);
opt_cell=cell(1,2*n);
for k=1:n
fn=fns{keep_id(k)};
opt_cell{k*2-1}=fn;
opt_cell{k*2}=opt.(fn);
end


function tf=cached_has_matlabcsvm()
persistent cached_tf;

if isequal(cached_tf,true)
tf=true;
return
end

cached_tf=cosmo_check_external('matlabcsvm');
tf=cached_tf;

13 changes: 11 additions & 2 deletions mvpa/cosmo_classify_matlabsvm.m
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,18 @@
% - for a guide on svm classification, see
% http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
% note that cosmo_crossvalidate and cosmo_crossvalidation_measure
% provide an option 'normalization' to perform data scaling
% provide an option 'normalization' to perform data scaling.
% - As of Matlab 2017a (maybe earlier), Matlab gives the warning that
% 'svmtrain will be removed in a future release. Use fitcsvm instead.'
% however fitcsvm gives different results than svmtrain; as a result
% cosmo_classify_matlabcsvm gives different results than
% cosmo_classify_matlabsvm. In this function the warning message is
% . suppressed.
% - As of Matlab 2018a, this function cannot be used anymore. Use
% cosmo_classify_matlabcsvm instead.
%
% See also svmtrain, svmclassify, cosmo_classify_svm, cosmo_classify_libsvm
% See also svmtrain, svmclassify, cosmo_classify_svm,
% cosmo_classify_libsvm, cosmo_classify_matlabcsvm
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
Expand Down
19 changes: 11 additions & 8 deletions mvpa/cosmo_classify_matlabsvm_2class.m
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@
% provide an option 'normalization' to perform data scaling
% - As of Matlab 2017a (maybe earlier), Matlab gives the warning that
% 'svmtrain will be removed in a future release. Use fitcsvm instead.'
% however fitcsvm gives different results than svmtrain. In this
% function the warning message is suppressed.
% however fitcsvm gives different results than svmtrain; as a result
% cosmo_classify_matlabcsvm gives different results than
% cosmo_classify_matlabsvm. In this function the warning message is
% . suppressed.
% - As of Matlab 2018a, this function cannot be used anymore. Use
% cosmo_classify_matlabcsvm instead.
%
% See also svmtrain, svmclassify, cosmo_classify_matlabsvm
% See also svmtrain, svmclassify, cosmo_classify_matlabsvm,
% cosmo_classify_matlabcsvm
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
Expand Down Expand Up @@ -91,17 +96,15 @@
% only show warning once (by default) if this is a
% a stats:obsolete message
suffix=['CoSMoMVPA note: the more recent '...
'fitcsvm / svmsmoset classifiers'...
'fitcsvm / svmsmoset classifiers produce '...
'different results '...
'than the older svmtrain function. '...
'Currently there is no support '...
'in CoSMoMVPA for using fitcsvm in a '...
'classifier'];
'To use fitcsvm, use cosmo_classify_matlabcsvm'];
cosmo_warning('%s\n%s',warning_msg,suffix);
elseif ~strcmp(warning_id,orig_lastid)
% new warning was issued , different from stats:obsolete one;
% show warning message
warning(warning_id,warning_msg);
cosmo_warning(warning_id,warning_msg);
end


Expand Down
1 change: 1 addition & 0 deletions mvpa/cosmo_crossvalidation_measure.m
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@
folds=cat(1,fold_cell{:});
fold_targets=cat(1,target_cell{:});

2;


function winner_pred=compute_winner_predictions(pred)
Expand Down
66 changes: 65 additions & 1 deletion tests/test_classify.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,52 @@
end
initTestSuite;

function test_general_classifiers_strong_signal()
% note: part of these checks are also implemented in
% general_test_classifier (included below)

cfy_cell={{@cosmo_classify_lda,''},...
{@cosmo_classify_nn,''},...
{@cosmo_classify_naive_bayes,''},...
{@cosmo_classify_libsvm,'libsvm'},...
{@cosmo_classify_matlabsvm,'matlabsvm'},...
{@cosmo_classify_matlabcsvm,'matlabcsvm'}};

n_cfy=numel(cfy_cell);

no_information=0;
strong_information=2+rand();
for sigma=[no_information, strong_information]
[tr_s,tr_t, te_s, te_t]=generate_informative_data(sigma);

for k=1:n_cfy
cfy=cfy_cell{k}{1};
predictor_func=@()cfy(tr_s, tr_t, te_s);

external=cfy_cell{k}{2};
if ~strcmp(external,'')
if ~cosmo_check_external(external,false)
assertExceptionThrown(predictor_func,'*');
continue;
end
end

pred = cfy(tr_s, tr_t, te_s);
acc = mean(pred==te_t);

assertEqual(numel(pred), 400);

if sigma==no_information
assertTrue(0.4 < acc && acc < 0.6);
elseif sigma==strong_information
assertTrue(acc > 0.9);
else
assertFalse(true);
end
end
end


function test_classify_lda
cfy=@cosmo_classify_lda;
handle=get_predictor(cfy);
Expand Down Expand Up @@ -113,6 +159,24 @@
assert_predictions_equal(handle,[1 3 9 7 6 6 9 3 7 5 6 6 4 ...
1 7 7 7 7 1 7 7 1 7 6 7 1 9]');
general_test_classifier(cfy);

function test_classify_matlabcsvm
warning_state=cosmo_warning();
cleaner=onCleanup(@()cosmo_warning(warning_state));
cosmo_warning('off');

cfy=@cosmo_classify_matlabcsvm;
handle=get_predictor(cfy);
if ~cosmo_check_external('matlabcsvm',false)
assert_throws_illegal_input_exceptions(cfy);
assertExceptionThrown(handle,'');
notify_test_skipped('matlabcsvm');
return;
end

assert_predictions_equal(handle,[1 2 3 4 5 6 7 8 9 1 2 3 4 ...
5 6 7 8 9 1 2 3 4 5 6 7 8 9]');
general_test_classifier(cfy);

function test_classify_matlabsvm_2class
warning_state=cosmo_warning();
Expand Down Expand Up @@ -287,7 +351,7 @@ function assert_chance_null_data(cfy)
assert_accuracy_in_range(cfy, 0, 0.3, 0.7);

function assert_above_chance_informative_data(cfy)
assert_accuracy_in_range(cfy, 10, 0.8, 1);
assert_accuracy_in_range(cfy, 2, 0.8, 1);

function assert_accuracy_in_range(cfy, sigma, min_val, max_val)
[tr_s,tr_t, te_s, te_t]=generate_informative_data(sigma);
Expand Down

0 comments on commit a431780

Please sign in to comment.