In [9]:
import numpy as np
import glob, os

from mvpa2.base import externals
from mvpa2.base.param import Parameter
from mvpa2.base.constraints import *
from mvpa2.base.dataset import hstack, vstack
from mvpa2.base.types import is_datasetlike

from mvpa2.datasets import Dataset

from mvpa2.mappers.zscore import zscore
from mvpa2.mappers.fxy import FxyMapper
from mvpa2.mappers.svd import SVDMapper

from mvpa2.measures.searchlight import Searchlight
from mvpa2.measures.base import Measure

from mvpa2.algorithms.searchlight_hyperalignment import SearchlightHyperalignment
# experimental sparse matrix usage for faster computing with possible extra mem load
# from searchlight_hyperalignment import SearchlightHyperalignment

from mvpa2.algorithms.hyperalignment import Hyperalignment

from mvpa2.support.due import due, Doi

if externals.exists('h5py'):
    from mvpa2.base.hdf5 import h5save, h5load

# `_chpaldebug(msg)` sends `msg` to the 'CHPAL' debug channel, but only when
# Python runs with __debug__ enabled and that channel is active in
# mvpa2's debug registry. In every other configuration it is a no-op that
# accepts and ignores any positional arguments.
if not __debug__:
    def _chpaldebug(*args):
        return None
else:
    from mvpa2.base import debug
    if 'CHPAL' not in debug.active:
        def _chpaldebug(*args):
            return None
    else:
        def _chpaldebug(msg):
            debug('CHPAL', "%s" % msg)

class MeanFeatureMeasure(Measure):
    """Measure that averages each sample over all of its features.

    Passed to `_get_seed_means` to compute seed mean time series. It was
    written as a stand-in for PyMVPA's stock mean measure, which reportedly
    did not work for this use case (Swaroop).
    TODO: figure out why the stock measure "didn't work" exactly, and adjust
    this description and possibly the name.
    """

    # There is no training step (no _train is defined), so the measure is
    # declared trained from the start.
    is_trained = True

    def __init__(self, **kwargs):
        super(MeanFeatureMeasure, self).__init__(**kwargs)

    def _call(self, dataset):
        # Collapse axis 1 (features): one mean value per sample.
        per_sample_mean = np.mean(dataset.samples, axis=1)
        return Dataset(samples=per_sample_mean)
    
def _get_connectomes(self, datasets):
    """Compute (or load) connectivity-based connectomes for each dataset.

    This appears to be the body of ``ConnectivityHyperalignment._get_connectomes``
    lifted out of its class. It must therefore be called with a configured
    hyperalignment instance as ``self``, i.e. ``_get_connectomes(chpal, datasets)``.
    A bare ``_get_connectomes(datasets)`` raises TypeError.

    Parameters
    ----------
    self : hyperalignment instance
      Must provide ``self.params`` with ``connectomes``, ``conn_metric``,
      ``seed_radius``, ``seed_queryengines``, ``ref_ds``, ``seed_indices``,
      ``npcs``, ``common_model`` and ``save_model``. It must also provide the
      methods ``_get_seed_means``, ``_get_sl_connectomes`` and ``_get_hypesvs``,
      and be a SearchlightHyperalignment (for ``_get_trained_queryengines`` /
      ``_get_verified_ids``).
    datasets : list of Dataset
      One dataset per subject; they are zipped with the per-subject query
      engines.

    Returns
    -------
    list
      If ``params.connectomes`` names an existing file, whatever ``h5load``
      returns for it. Otherwise, one z-scored connectome per dataset; float64
      samples are downcast to float32. When a saved common model is used, its
      ``'connectome_model'`` entry is prepended, giving ``len(datasets) + 1``
      items.

    Raises
    ------
    NotImplementedError
      If ``params.seed_queryengines`` is None.
    """
    params = self.params
    # Reuse pre-computed connectomes when a file with them already exists.
    if params.connectomes is not None and os.path.exists(params.connectomes):
        _chpaldebug("Loading pre-computed connectomes from %s" % params.connectomes)
        return h5load(params.connectomes)
    # A saved common model is used only if its file actually exists. Decide
    # this once so that every place reading the model agrees. Otherwise a
    # set-but-missing path causes a KeyError on the empty local-model dict,
    # and npcs=None with an existing model reads a never-loaded variable.
    use_saved_model = (params.common_model is not None
                       and os.path.exists(params.common_model))
    common_model = None
    connectivity_mapper = FxyMapper(params.conn_metric)
    # TODO Handle seed_radius if seed queryengines are not provided
    seed_radius = params.seed_radius
    _chpaldebug("Performing surface connectivity hyperalignment with seeds")
    _chpaldebug("Computing connectomes.")
    ndatasets = len(datasets)
    if params.seed_queryengines is None:
        raise NotImplementedError("For now, we need seed queryengines.")
    # Explicit base-class calls avoid needing the ConnectivityHyperalignment
    # name, which is not defined in this notebook.
    # NOTE(review): equivalent to super(ConnectivityHyperalignment, self)
    # assuming that class derives directly from SearchlightHyperalignment --
    # confirm.
    qe_all = SearchlightHyperalignment._get_trained_queryengines(
        self, datasets, params.seed_queryengines, seed_radius, params.ref_ds)
    # If seed_indices are not supplied, use all as centers
    if not params.seed_indices:
        roi_ids = SearchlightHyperalignment._get_verified_ids(self, qe_all)
    else:
        roi_ids = params.seed_indices
    if len(qe_all) == 1:
        # Rebind rather than `*=` so the returned list is never mutated.
        qe_all = list(qe_all) * ndatasets
    # Computing seed means to be used for aligning seed features
    seed_means = [self._get_seed_means(MeanFeatureMeasure(), qe, ds, params.seed_indices)
                  for qe, ds in zip(qe_all, datasets)]
    if params.npcs is None:
        # Connect directly to the (z-scored in place) seed means.
        for seed_mean in seed_means:
            zscore(seed_mean, chunks_attr=None)
        conn_targets = list(seed_means)
    else:
        # compute all PC-seed connectivity in each subject
        # 1. make common model SVs in each seed SL based on connectivity to seed_means
        # 2. Use these SVs for computing connectomes
        _chpaldebug("Aligning SVs in each searchlight across subjects")
        pc_data = [[] for _ in range(ndatasets)]
        sl_common_models = dict()
        if use_saved_model:
            _chpaldebug("Loading common model from %s" % params.common_model)
            common_model = h5load(params.common_model)
            sl_common_models = common_model['local_models']
        # Looping over all seeds in which SVD is done
        for inode in roi_ids:
            # Connectivity of features to seed means in this SL (no common model needed).
            sl_connectomes = self._get_sl_connectomes(seed_means, qe_all, datasets,
                                                      inode, connectivity_mapper)
            # Hyperalign connectomes in SL
            # XXX TODO Common model input to below function should be updated.
            local_common_model = (sl_common_models[inode][:, :params.npcs]
                                  if use_saved_model else None)
            sl_hmappers, svm, sl_common_model = self._get_hypesvs(
                sl_connectomes, local_common_model=local_common_model)
            if sl_common_model is not None:
                sl_common_models[inode] = sl_common_model
            # make common model SV timeseries data in each subject
            for sd, slhm, qe, pcd in zip(datasets, sl_hmappers, qe_all, pc_data):
                sd_svs = slhm.forward(sd[:, qe[inode]])
                zscore(sd_svs, chunks_attr=None)
                if svm is not None:
                    sd_svs = svm.forward(sd_svs)
                    sd_svs = sd_svs[:, :params.npcs]
                    zscore(sd_svs, chunks_attr=None)
                pcd.append(sd_svs)
        if params.save_model is not None:
            # TODO: should use debug
            print('Saving local models to %s' % params.save_model)
            h5save(params.save_model, sl_common_models)
        conn_targets = [hstack(pcd) for pcd in pc_data]
    # compute connectomes using connectivity targets (PCs or seed means)
    connectomes = []
    if use_saved_model:
        # TODO: should use debug
        print('Loading from saved common model: %s' % params.common_model)
        if common_model is None:
            # Only the PC branch above loads the model; load it here otherwise.
            common_model = h5load(params.common_model)
        connectomes.append(common_model['connectome_model'])
    for target, ds in zip(conn_targets, datasets):
        connectivity_mapper.train(target)
        connectome = connectivity_mapper.forward(ds)
        connectome.fa = ds.fa
        if connectome.samples.dtype == 'float64':
            connectome.samples = connectome.samples.astype('float32')
        zscore(connectome, chunks_attr=None)
        connectomes.append(connectome)
    if params.connectomes is not None and not os.path.exists(params.connectomes):
        # Single pre-formatted message: _chpaldebug takes exactly one argument
        # when the CHPAL channel is active.
        _chpaldebug("Saving connectomes to %s" % params.connectomes)
        h5save(params.connectomes, connectomes)
    return connectomes



In [44]:
# Input/output locations and the subject/hemisphere selection.
# NOTE(review): absolute cluster paths -- adjust for other machines.
data_path = '/dartfs-hpc/scratch/psyc164/tcat/data/budapest/'
save_path = '/dartfs-hpc/scratch/psyc164/tcat/data'
hemis = ['L']  # add 'R' for the right hemisphere
subids = [5, 7, 9]  # full set also includes 10, 13, 20, 21, 24, 29, 34, 52, 114, 120, 134, 142, 278, 416, 499, 522, 535, 560
# files[i] is the sorted list of .npy paths matching the i-th (hemi, subject) pair.
files = []

for hemi in hemis:
    for subid in subids:
        # Subject ids are zero-padded to six digits, e.g. 5 -> '000005'.
        sub = '{:0>6}'.format(subid)
        fn = os.path.join(data_path, 'sub-rid' + sub + '*' + hemi + '.npy')
        matches = sorted(glob.glob(fn))
        if not matches:
            # Fail here with context instead of an IndexError when loading.
            raise IOError('No data files match %s' % fn)
        files.append(matches)

In [45]:
# Load one Dataset per subject from the first .npy file matched for it.
# The per-subject list is kept because `_get_connectomes` takes a list of
# datasets; all subjects are also stacked along the sample axis into `ds`.
targets = range(1, 21)  # only needed for the mv.gifti_dataset(files[x], targets=targets) alternative
subject_datasets = [Dataset(np.load(subject_files[0])) for subject_files in files]
if not subject_datasets:
    raise ValueError('No subject files were found to load')
# Stack once instead of re-stacking inside the loop, which copied the
# growing array on every iteration.
ds = vstack(subject_datasets)
ds.fa['node_indices'] = np.arange(ds.shape[1])

In [46]:
# (n_samples, n_features) of the stacked data -- read straight from the sample array.
ds.samples.shape

(9156, 9372)

In [49]:
# NOTE(review): this call raises the TypeError shown below.
# `_get_connectomes` is declared as `_get_connectomes(self, datasets)`. It reads
# `self.params`, calls `self._get_seed_means` / `self._get_sl_connectomes` /
# `self._get_hypesvs`, and refers to ConnectivityHyperalignment in its super()
# calls, so it needs a configured hyperalignment instance as the first argument.
# `datasets` is iterated with one query engine per entry (`zip(qe_all, datasets)`),
# so it presumably should be a list with one dataset per subject rather than
# the single vstacked `ds`.
# TODO: construct that instance and pass the per-subject datasets.
_get_connectomes(ds)

TypeError: _get_connectomes() takes exactly 2 arguments (1 given)