Skip to content

Commit

Permalink
Merge pull request #108 from nno/rf/parallel_mc_stat
Browse files Browse the repository at this point in the history
NF: parallel montecarlo_cluster_stat
  • Loading branch information
nno committed Nov 22, 2016
2 parents f689bf5 + 0dd6d0a commit ac10c4e
Show file tree
Hide file tree
Showing 7 changed files with 690 additions and 184 deletions.
14 changes: 13 additions & 1 deletion mvpa/cosmo_check_external.m
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,19 @@
externals.mocov.url='https://github.com/MOcov/MOcov';

function tf=has_octave_package(label)
% Return whether the GNU Octave package named by label is currently loaded
%
% Input:
%   label     package name, passed on to pkg('list',label)
%
% Output:
%   tf        false if cosmo_wtf('is_octave') is false, or if
%             pkg('list',label) returns an empty result; otherwise the
%             value of the 'loaded' field of the single package entry
%             that was returned.
%
% Note: a package that is installed but not loaded yields tf=false.
    tf=false;
    if ~cosmo_wtf('is_octave')
        % pkg is only queried when running under Octave
        return;
    end

    result=pkg('list',label);
    if isempty(result)
        % no package matching this label was found
        return;
    end

    % exactly one package entry is expected for a single label
    assert(numel(result)==1);

    tf=result{1}.loaded;

function tf=same_path(args)
pths=cellfun(@(x)fileparts(which(x)),args,'UniformOutput',false);
Expand Down
205 changes: 153 additions & 52 deletions mvpa/cosmo_montecarlo_cluster_stat.m
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@
% 'niter' if the 'feature_stat' option is not 'none'.
% 'progress',p Show progress every p steps (default: 10). Use
% p=false to not show progress.
% 'nproc', np If the Matlab Parallel Computing Toolbox, or the
% GNU Octave parallel package is available, use
% np parallel threads. (Multiple threads may speed
% up computations).
% If parallel processing is not available, or if
% this option is not provided, then a single thread
% is used.
% 'seed',s Use seed s for pseudo-random number generation. If
% this option is provided, then this function behaves
% deterministically. If this option is omitted (the
Expand Down Expand Up @@ -230,6 +237,7 @@
defaults.feature_stat='auto';
defaults.cluster_stat='tfce';
defaults.progress=10;
defaults.nproc=1;

opt=cosmo_structjoin(defaults,varargin);

Expand All @@ -239,7 +247,13 @@
% check input options
check_opt(ds,opt);

% get number of processes
nproc_available=cosmo_parallel_get_nproc_available(opt);

% Matlab needs newline character at progress message to show it in
% parallel mode; Octave should not have newline character
environment=cosmo_wtf('environment');
progress_suffix=get_progress_suffix(environment);

% the heavy lifting is done by four helper functions:
% 1) ds_preproc=preproc_func(ds) takes a dataset and preprocesses it.
Expand Down Expand Up @@ -273,22 +287,110 @@

nfeatures=size(ds.samples,2);
orig_cluster_vals=zeros(2,nfeatures);
less_than_orig_count=zeros(2,nfeatures);

niter=get_niter(opt);

prev_progress_msg='';
clock_start=clock();
%compute the original values first
ds_perm=preproc_func(ds);
ds_perm_zscore=stat_func(ds_perm);

for neg_pos=1:2
% treat negative and positive values separately
perm_sign=2*neg_pos-3; % -1 or 1
% multiply samples by either 1 or -1
signed_perm_zscore=ds_perm_zscore.samples*perm_sign;
% apply clustering to z-scored data
cluster_vals=cluster_func(signed_perm_zscore);
orig_cluster_vals(neg_pos,:)=cluster_vals;
end

for iter=0:niter
is_null_iter=iter>0;
% split iterations in multiple parts, so that each thread can do a
% subset of all the work.
% The iterations 1..niter are divided into exactly nproc_available
% contiguous ranges whose sizes differ by at most one. (Stepping through
% 1:ceil(niter/nproc_available):niter can give fewer ranges than
% nproc_available, e.g. niter=9 with 4 processes gives only 3 ranges, after
% which indexing iter_start(p) for the last worker fails.)
% If niter<nproc_available some ranges are empty (start>end), so that the
% corresponding worker gets an empty set of iterations.
block_bounds=floor((0:nproc_available)*niter/nproc_available);
iter_start=block_bounds(1:end-1)+1;
iter_end=block_bounds(2:end);

% set options for each worker process
worker_opt_cell=cell(1,nproc_available);
for p=1:nproc_available
worker_opt=struct();
worker_opt.orig_cluster_vals=orig_cluster_vals;
worker_opt.permutation_preproc_func=permutation_preproc_func;
worker_opt.stat_func=stat_func;
worker_opt.cluster_func=cluster_func;
worker_opt.nfeatures=nfeatures;
worker_opt.worker_id=p;
worker_opt.nworkers=nproc_available;
worker_opt.progress=opt.progress;
worker_opt.progress_suffix=progress_suffix;
worker_opt.iters=iter_start(p):iter_end(p);
worker_opt_cell{p}=worker_opt;
end

if is_null_iter
ds_perm=permutation_preproc_func(iter);
else
ds_perm=preproc_func(ds);
% Run process for each worker in parallel
% Note that when using nproc=1, cosmo_parcellfun does not actually use
% any parallelization; the result is a cell with a single element.
result_cell=cosmo_parcellfun(nproc_available,...
@run_with_worker,...
worker_opt_cell,...
'UniformOutput',false);

% join results from each worker
less_than_orig_count = sum(cat(3,result_cell{:}),3);

% safety check: each item is either positive or negative
assert(max(sum(less_than_orig_count>0,1))<=1);

% convert p-values of two tails into one p-value
ps_two_tailed=sum(bsxfun(@times,[-1;1],less_than_orig_count))/...
(niter*2)+.5;

% deal with extreme tails
min_p_value=1/niter;
ps_two_tailed(ps_two_tailed>1-min_p_value)=1-min_p_value;
ps_two_tailed(ps_two_tailed< min_p_value)=min_p_value;

% convert to z-score
z_two_tailed=norminv(ps_two_tailed);

% store result in dataset structure
ds_z=struct();
ds_z.samples=z_two_tailed;
ds_z.sa.stats={'Zscore()'};
ds_z.a=ds.a;
ds_z.fa=ds.fa;

function less_than_orig_count=run_with_worker(worker_opt)

orig_cluster_vals=worker_opt.orig_cluster_vals;
permutation_preproc_func=worker_opt.permutation_preproc_func;
stat_func=worker_opt.stat_func;
cluster_func=worker_opt.cluster_func;
nfeatures=worker_opt.nfeatures;
worker_id=worker_opt.worker_id;
nworkers=worker_opt.nworkers;
progress=worker_opt.progress;
progress_suffix=worker_opt.progress_suffix;
iters=worker_opt.iters;

less_than_orig_count=zeros(2,nfeatures);
niter = length(iters);

% see if progress is to be reported; only the first worker reports
show_progress=~isempty(progress) && ...
                progress && ...
                worker_id==1;
if show_progress
    progress_step=progress;
    if progress_step<1
        % a fractional value is taken as a fraction of the number of
        % iterations handled by this worker (niter=length(iters))
        progress_step=ceil(niter*progress_step);
    end
    prev_progress_msg='';
    clock_start=clock();
end

for iter=1:niter
ds_perm=permutation_preproc_func(iters(iter));
ds_perm_zscore=stat_func(ds_perm);

for neg_pos=1:2
Expand All @@ -302,55 +404,44 @@
% apply clustering to z-scored data
cluster_vals=cluster_func(signed_perm_zscore);

if is_null_iter
% null permuted data, see which features show weaker
% cluster stat than the original data
perm_lt=max(cluster_vals)<orig_cluster_vals(neg_pos,:);
% null permuted data, see which features show weaker
% cluster stat than the original data
perm_lt=max(cluster_vals)<orig_cluster_vals(neg_pos,:);

% increase counter for those features
less_than_orig_count(neg_pos,perm_lt)=...
less_than_orig_count(neg_pos,perm_lt)+1;
else
% original data, store for comparison with null data
orig_cluster_vals(neg_pos,:)=cluster_vals;
end
% increase counter for those features
less_than_orig_count(neg_pos,perm_lt)=...
less_than_orig_count(neg_pos,perm_lt)+1;
end

show_progress=opt.progress && (iter<10 || ...
mod(iter, opt.progress)==0 || ...
iter==niter);

if show_progress
iter_pos=max(iter,1);
p_min=(iter_pos-max(less_than_orig_count,[],2))/iter_pos;
p_range=sqrt(1/4/max(iter,1));
msg=sprintf('p = %.3f / %.3f [+/-%.3f] (left/right)',...
p_min,p_range);
prev_progress_msg=cosmo_show_progress(clock_start, ...
(iter+1)/(niter+1), msg, prev_progress_msg);
if show_progress && (iter<10 || ...
~mod(iter, progress_step) || ...
iter==niter);
if nworkers>1
if iter==niter
% other workers may be slower than first worker
msg=sprintf(['worker %d has completed; waiting for '...
'other workers to finish...%s'],...
worker_id, progress_suffix);
else
% can only show progress from a single worker;
% therefore show progress of first worker
msg=sprintf('for worker %d / %d%s', worker_id, ...
nworkers, progress_suffix);
end
prev_progress_msg=cosmo_show_progress(clock_start, ...
iter/niter, msg, prev_progress_msg);
else
iter_pos=max(iter,1);
p_min=(iter_pos-max(less_than_orig_count,[],2))/iter_pos;
p_range=sqrt(1/4/max(iter,1));
msg=sprintf('p = %.3f / %.3f [+/-%.3f] (left/right)',...
p_min,p_range);
prev_progress_msg=cosmo_show_progress(clock_start, ...
(iter+1)/(niter+1), msg, prev_progress_msg);
end
end
end

assert(max(sum(less_than_orig_count>0,1))<=1);

% convert p-values of two tails into one p-value
ps_two_tailed=sum(bsxfun(@times,[-1;1],less_than_orig_count))/...
(niter*2)+.5;

% deal with extreme tails
ps_two_tailed(ps_two_tailed>1-1/niter)=1-1/niter;
ps_two_tailed(ps_two_tailed< 1/niter)=1/niter;

% convert to z-score
z_two_tailed=norminv(ps_two_tailed);

ds_z=struct();
ds_z.samples=z_two_tailed;
ds_z.sa.stats={'Zscore()'};
ds_z.a=ds.a;
ds_z.fa=ds.fa;


function stat_func=get_stat_func(ds,opt)
if has_feature_stat_auto(opt)
stat_func=get_stat_func_auto(ds);
Expand Down Expand Up @@ -774,3 +865,13 @@ function check_opt(ds,opt)
end


function suffix=get_progress_suffix(environment)
% Return the suffix to append to progress messages
%
% Matlab needs newline character at progress message to show it in
% parallel mode; Octave should not have newline character
%
% Input:
%   environment    'matlab' or 'octave'
%
% Output:
%   suffix         newline character for 'matlab', empty string for
%                  'octave'
%
% An error is raised for any other environment value, so that the output
% is never left unassigned.
    switch environment
        case 'matlab'
            suffix=sprintf('\n');
        case 'octave'
            suffix='';
        otherwise
            error('Unsupported environment ''%s''',environment);
    end
Loading

0 comments on commit ac10c4e

Please sign in to comment.