diff --git a/mvpa/cosmo_phase_itc.m b/mvpa/cosmo_phase_itc.m new file mode 100644 index 00000000..d9569a3d --- /dev/null +++ b/mvpa/cosmo_phase_itc.m @@ -0,0 +1,159 @@ +function itc_ds=cosmo_phase_itc(ds,varargin) +% compute phase inter trial coherence +% +% itc_ds=cosmo_phase_itc(ds,varargin) +% +% Inputs: +% ds dataset struct with fields: +% .samples PxQ complex matrix for P samples (trials, +% observations) and Q features (e.g. combinations +% of time points, frequencies and channels) +% .sa.targets Px1 array with trial conditions. Each condition +% must occur equally often; that is, the +% design must be balanced. +% In the typical case of two conditions, +% .sa.targets must have exactly two unique +% values. +% .sa.chunks Px1 array indicating which samples can be +% considered to be independent. It is required +% that all samples are independent, therefore +% all values in .sa.chunks must be different from +% each other +% .fa } optional feature attributes +% .a } optional sample attributes +% 'samples_are_unit_length',u (optional) +% If u==true, then all elements in ds.samples +% are assumed to be already of unit length. If +% this is indeed true, this can speed up the +% computation of the output. +% +% Output: +% itc_ds dataset structu with fields +% .samples (N+1)xQ array with inter-trial coherence +% measure, where U=unique(ds.sa.targets) and +% N=numel(U). The first N rows correspond to the +% inter trial coherence for each condition. The +% last row is the inter trial coherence for all +% samples together. +% .sa.targets (N+1)x1 vector containing the vector [U(:);1]' +% with trial conditions +% +% .a } if present in the input, then the output +% .fa } contains these fields as well + + defaults=struct(); + defaults.samples_are_unit_length=false; + defaults.check_dataset=true; + + opt=cosmo_structjoin(defaults,varargin{:}); + + check_input(ds,opt); + + samples=ds.samples; + if opt.samples_are_unit_length + quick_check_some_samples_being_unit_length(samples); + else + % normalize + samples=samples./abs(samples); + end + + [idxs,classes]=cosmo_index_unique(ds.sa.targets); + nclasses=numel(classes); + nfeatures=size(samples,2); + itc=zeros(nclasses+1,nfeatures); + + % ITC for each class + for k=1:nclasses + samples_k=samples(idxs{k},:); + itc(k,:)=itc_on_unit_length_elements(samples_k); + end + + % overall ITC + itc(nclasses+1,:)=itc_on_unit_length_elements(samples); + + % set output + itc_ds=set_output(itc,ds,classes); + + +function itc_ds=set_output(itc,ds,classes) + % store results + itc_ds=struct(); + itc_ds.samples=itc; + itc_ds.sa.targets=[classes(:); NaN]; + + % copy .a and .fa fields, if present + if isfield(ds,'a') + itc_ds.a=ds.a; + + if isfield(ds.a,'sdim') + % remove sample dimensions if present + itc_ds.a=rmfield(itc_ds.a,'sdim'); + end + end + + if isfield(ds,'fa') + itc_ds.fa=ds.fa; + end + + +function itc=itc_on_unit_length_elements(samples) + % computes inter-trial coherence for each column seperately + itc=abs(sum(samples,1)./size(samples,1)); + + +function quick_check_some_samples_being_unit_length(samples) + % instead of checking all values, only verify for a subset of values. + % This should prevent most use cases where the user accidentally + % uses non-normalized data, whereas checking all values would be + % equivalent to actually computing their length for each of them. + count_to_check=10; + + % generate random positions to check for unit length + nelem=numel(samples); + pos=ceil(rand(1,count_to_check)*nelem); + + samples_subset=samples(pos); + lengths=abs(samples_subset); + + delta=eps(1); + if any(lengths+delta<1 | lengths-delta>1) + error('.samples input is not of unit length'); + end + + +function check_input(ds,opt) + % must be a proper dataset + raise_exception=true; + cosmo_check_dataset(ds,raise_exception); + + % must have targets and chunks + cosmo_isfield(ds,{'sa.targets','sa.chunks'},raise_exception); + + + % all chunks must be unique + if ~isequal(sort(ds.sa.chunks),unique(ds.sa.chunks)) + error(['All values in .sa.chunks must be different '... + 'from each other']); + end + + % trial counts must be balanced + [idxs,classes]=cosmo_index_unique(ds.sa.targets); + class_count=cellfun(@numel,idxs); + unequal_pos=find(class_count~=class_count(1),1); + if ~isempty(unequal_pos) + error(['.sa.targets indicates unbalanced targets, with '... + '.sa.targets==%d occurding %d times, and '... + '.sa.targets==%d occurding %d times'],... + 1,class_count(1),unequal_pos,class_count(unequal_pos)); + end + + % input must be complex + if isreal(ds.samples) + error('.samples must be complex'); + end + + v=opt.samples_are_unit_length; + if ~(islogical(v) ... + && isscalar(v)) + error('option ''samples_are_unit_length'' must be logical scalar'); + end diff --git a/tests/test_phase_itc.m b/tests/test_phase_itc.m new file mode 100644 index 00000000..2ede5af6 --- /dev/null +++ b/tests/test_phase_itc.m @@ -0,0 +1,139 @@ +function test_suite=test_phase_itc +% tests for test_phase_itc +% +% # For CoSMoMVPA's copyright information and license terms, # +% # see the COPYING file distributed with CoSMoMVPA. # + try % assignment of 'localfunctions' is necessary in Matlab >= 2016 + test_functions=localfunctions(); + catch % no problem; early Matlab versions can use initTestSuite fine + end + initTestSuite; + +function r=randint() + r=ceil(2+rand()*10); + +function test_phase_itc_basics + nclasses=randint(); + classes=1:2:(2*nclasses); + + nrepeats=randint(); + nfeatures=randint(); + + ds=generate_random_dataset(classes,nrepeats,nfeatures); + + % compute expected ITC + itc_ds=cosmo_phase_itc(ds); + expected_samples=zeros(nclasses+1,nfeatures); + for k=1:nclasses + msk=ds.sa.targets==classes(k); + expected_samples(k,:)=quick_itc(ds.samples(msk,:)); + end + expected_samples(nclasses+1,:)=quick_itc(ds.samples); + + % construct expected dataset + expected_itc_ds=struct(); + expected_itc_ds.samples=expected_samples; + expected_itc_ds.sa.targets=[classes,NaN]'; + expected_itc_ds.a=ds.a; + expected_itc_ds.fa=ds.fa; + + assert_datasets_almost_equal(itc_ds,expected_itc_ds); + + +function test_phase_itc_unit_length() + ds=generate_random_dataset(1:10,randint(),randint()); + ds_unit=ds; + ds_unit.samples=ds_unit.samples./abs(ds_unit.samples); + + itc_ds=cosmo_phase_itc(ds); + itc_unit_ds=cosmo_phase_itc(ds_unit,'samples_are_unit_length',true); + + assert_datasets_almost_equal(itc_ds,itc_unit_ds); + +function assert_datasets_almost_equal(p,q) + assertElementsAlmostEqual(p.samples,q.samples); + + p=rmfield(p,'samples'); + q=rmfield(q,'samples'); + + + assertEqual(p,q); + + + +function ds=generate_random_dataset(classes,nrepeats,nfeatures) + nclasses=numel(classes); + nsamples=nclasses*nrepeats; + sz=[nsamples,nfeatures]; + ds=struct(); + ds.samples=randn(sz)+1i*randn(sz); + ds.sa.targets=repmat(classes,1,nrepeats)'; + ds.sa.chunks=(1:nsamples)'; + ds.a='foo'; + ds.fa.bar=1:nfeatures; + + % permute randomly + ds=cosmo_slice(ds,cosmo_randperm(nsamples)); + + +function test_phase_itc_sdim_field + ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',3); + ds.samples=ds.samples+1i*randn(size(ds.samples)); + ds.sa.chunks(:)=1:9; + + % add sample dimension + ds=cosmo_dim_insert(ds,1,1,{'foo'},{[1:9]},{[1:9]'}); + + itc_ds=cosmo_phase_itc(ds); + assert(~isfield(itc_ds.a,'sdim')); + assert(~isfield(itc_ds.sa,'foo')); + + + +function itc=quick_itc(samples) + s=samples./abs(samples); + itc=abs(mean(s,1)); + + +function test_phase_itc_exceptions + aet=@(varargin)assertExceptionThrown(... + @()cosmo_phase_itc(varargin{:}),''); + + ds=cosmo_synthetic_dataset('ntargets',2,'nchunks',6); + nsamples=size(ds.samples,1); + sz=size(ds.samples); + ds.samples=randn(sz)+1i*randn(sz); + ds.sa.chunks(:)=1:nsamples; + cosmo_phase_itc(ds); % ok + + % input not imaginary + bad_ds=ds; + bad_ds.samples=randn(sz); + aet(bad_ds); + + % chunks not all unique + bad_ds=ds; + bad_ds.sa.chunks(1)=bad_ds.sa.chunks(2); + aet(bad_ds); + + % imbalance + bad_ds=ds; + bad_ds.sa.targets(:)=[repmat([1 2],1,5),[1 1]]; + aet(bad_ds); + + % bad values for samples_are_unit_length + bad_samples_are_unit_length_cell={[],'',1,[true false]}; + for k=1:numel(bad_samples_are_unit_length_cell) + arg={'samples_are_unit_length',... + bad_samples_are_unit_length_cell{k}}; + aet(ds,arg{:}); + end + + % with samples_are_unit_length=true, raise exception if some values + % are not unit length + aet(ds,'samples_are_unit_length',true); + + + + +