In [38]:
from nilearn import datasets
from os import listdir
from os.path import isfile, isdir, join, basename
import math
import codecs, json
import numpy as np

from tqdm import tqdm, tqdm_gui, tqdm_notebook

from nilearn.input_data import NiftiLabelsMasker
from nilearn.connectome import ConnectivityMeasure


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes NumPy arrays and NumPy scalars.

    - ``np.ndarray`` values become (nested) Python lists via ``ndarray.tolist()``.
    - NumPy scalar types (e.g. ``np.int64``, ``np.float32``, ``np.bool_``), which the
      stdlib encoder does not recognise, become the equivalent Python scalar via
      ``.item()``.
    - Anything else is delegated to ``json.JSONEncoder.default``, which raises
      ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return json.JSONEncoder.default(self, obj)
    
    
# Parcellation atlases (NIfTI label volumes) used to extract regional time series.
# Alternative previously tried: nilearn's bundled AAL atlas
#     dataset = datasets.fetch_atlas_aal(); atlas_filename = dataset.maps; labels = dataset.labels
atlas_dir = '/data_59/huze/Atlases'
# sorted() so that atlas_list (and therefore atlas_filename) does not depend on the
# arbitrary order in which os.listdir returns directory entries.
atlas_list = sorted(f for f in listdir(atlas_dir) if f.endswith('.nii.gz'))
if not atlas_list:
    raise FileNotFoundError('no .nii.gz atlas found in {}'.format(atlas_dir))
atlas_filename = join(atlas_dir, atlas_list[0])

# Shared estimator that turns a block of regional time series into a correlation matrix.
correlation_measure = ConnectivityMeasure(kind='correlation')

In [39]:
# Exclude the scale-60 atlas from this run. Rebuilding the list instead of calling
# list.remove keeps the cell idempotent: re-running it no longer raises ValueError.
atlas_list = [atlas for atlas in atlas_list if atlas != 'QSDR.scale60.thick2.MNI152.nii.gz']

In [40]:
# Display the atlases that remain to be processed.
atlas_list

['QSDR.scale125.thick2.MNI152.nii.gz',
 'QSDR.scale250.thick2.MNI152.nii.gz',
 'QSDR.scale33.thick2.MNI152.nii.gz',
 'QSDR.scale500.thick2.MNI152.nii.gz']

In [25]:
# Directory containing one "<subject>.feat" folder per subject.
working_dir = '/data_59/huze/MRS/FEAT.linear'

# Subject IDs are the folder names without the ".feat" suffix. The previous
# f.strip('.feat') removed any of the characters '.', 'f', 'e', 'a', 't' from BOTH
# ends, mangling IDs that begin or end with those letters, so slice the suffix off.
# sorted() makes the processing order independent of os.listdir ordering.
_feat_suffix = '.feat'
subject_list = sorted(f[:-len(_feat_suffix)] for f in listdir(working_dir)
                      if f.endswith(_feat_suffix))
correlation_dict = dict()
time_series_dict = dict()

In [26]:
# Number of subject folders found in working_dir.
len(subject_list)

212

In [28]:
def slice_time_series(time_series, tr=3, block_s=30, cue_s=2, n_blocks=11):
    """Split a run into concatenated RESTBLOCK and IPBLOCK segments.

    The run is modelled as ``n_blocks`` consecutive blocks, each made of a
    ``cue_s``-second directions screen followed by a ``block_s``-second task block
    (defaults: 2 s directions, 30 s RESTBLOCK/IPBLOCK). With
    ``period = cue_s + block_s``, block ``i`` (1-based) spans the seconds
    ``[period*i - block_s, period*i]``. These times are converted to row indices of
    ``time_series`` by dividing by ``tr``, the seconds per row (default 3, the
    divisor originally hard-coded). Odd-numbered blocks (1st, 3rd, ...) are rest
    blocks and even-numbered blocks are IP blocks.

    Parameters
    ----------
    time_series : np.ndarray
        Array indexed along axis 0 by time point, presumably the
        (n_timepoints, n_regions) output of a nilearn masker.
    tr : float, optional
        Seconds per row of ``time_series``.
    block_s : float, optional
        Duration of each RESTBLOCK/IPBLOCK in seconds.
    cue_s : float, optional
        Duration of the directions screen preceding each block, in seconds.
    n_blocks : int, optional
        Total number of blocks (rest and IP together).

    Returns
    -------
    rest_block, ip_block : np.ndarray
        Rows of all rest blocks (resp. IP blocks), concatenated along axis 0 in
        temporal order.
    """
    period = cue_s + block_s
    segments = []
    for i in range(1, n_blocks + 1):
        # Round the block's time window inward to whole rows.
        start = math.ceil((period * i - block_s) / tr)
        end = math.floor((period * i) / tr)
        # NOTE(review): the slice end is exclusive, so row `end` itself is not kept;
        # confirm this is the intended convention.
        segments.append(time_series[start:end])

    rest_block = np.concatenate(segments[::2], axis=0)
    ip_block = np.concatenate(segments[1::2], axis=0)
    return rest_block, ip_block

In [29]:
def preproc(subject, masker, correlation_measure):
    """Compute the RESTBLOCK and IPBLOCK correlation matrices for one subject.

    Reads ``<working_dir>/<subject>.feat/filtered_func_data.nii.gz``, extracts
    regional time series with ``masker.fit_transform``, splits them with
    ``slice_time_series`` and applies ``correlation_measure`` to each block
    separately (rest block first, then IP block).

    Returns
    -------
    tuple
        ``(rest_corr, ip_corr, time_series)``: the rest-block correlation matrix,
        the IP-block correlation matrix, and the full unsliced masker output.
    """
    func_path = join(working_dir, subject + '.feat', 'filtered_func_data.nii.gz')
    regional_ts = masker.fit_transform(func_path)
    # fit_transform is given a one-element list of subjects; take the single result.
    rest_corr, ip_corr = (correlation_measure.fit_transform([block])[0]
                          for block in slice_time_series(regional_ts))
    return rest_corr, ip_corr, regional_ts

In [None]:
# For every atlas: extract each subject's regional time series, split it into
# RESTBLOCK / IPBLOCK segments and save the per-block correlation matrices as JSON.
for atlas in tqdm_notebook(atlas_list):
    atlas_file = join('/data_59/huze/Atlases', atlas)
    correlation_dict = dict()
    # NOTE: str.strip removes any *characters* in '.MNI152.nii.gz' from both ends; it does
    # not remove a suffix. For 'QSDR.scale125.thick2.MNI152.nii.gz' it also eats the '2' of
    # 'thick2', giving 'QSDR.scale125.thick'. The merge cell below reads files named
    # 'correlation_matrix_sliced_QSDR.scale{N}.thick.json', so this naming is kept as-is.
    atlas_name = atlas.strip('.MNI152.nii.gz')
    masker = NiftiLabelsMasker(labels_img=atlas_file, standardize=True,
                               memory='nilearn_cache', verbose=5)
    for subject in tqdm_notebook(subject_list, leave=False):
        rest_corr, ip_corr, _ = preproc(subject, masker, correlation_measure)
        correlation_dict[subject] = {'RESTBLOCK': rest_corr,
                                     'IPBLOCK': ip_corr}
    out_path = '/data_59/huze/MRS/TSCM/correlation_matrix_sliced_{}.json'.format(atlas_name)
    # The context manager flushes and closes the file before the next (long) atlas
    # iteration; previously the handle was left open for the garbage collector.
    with codecs.open(out_path, 'w', encoding='utf-8') as out_file:
        json.dump(correlation_dict, out_file, separators=(',', ':'),
                  sort_keys=True, indent=4, cls=NumpyEncoder)

In [42]:
# Check which correlation blocks are stored for one example subject.
# `sliced_corr_dict` is only created by the 'WHOLE' merge loop further down, so on a
# fresh kernel this cell cannot work; fail with an explicit hint instead of a bare NameError.
if 'sliced_corr_dict' not in globals():
    raise NameError("sliced_corr_dict is undefined: run the 'WHOLE' merge cell below first")
sliced_corr_dict['3044_1'].keys()

dict_keys(['IPBLOCK', 'RESTBLOCK', 'WHOLE'])

In [43]:
# Atlas scales whose sliced and whole-run correlation files are merged below.
# (To process only the scale-60 atlas instead, use scale_list = ['60'].)
scale_list = [str(scale) for scale in (33, 125, 250, 500)]
tscm_dir = '/data_59/huze/MRS/TSCM/'

In [None]:
# Add each subject's whole-run correlation matrix to the sliced-block file under the
# key 'WHOLE', so every subject entry holds 'RESTBLOCK', 'IPBLOCK' and 'WHOLE'.
# The sliced file is rewritten in place; re-running this cell is idempotent.
for scale in tqdm_notebook(scale_list):
    sliced_path = join(tscm_dir, 'correlation_matrix_sliced_QSDR.scale{0}.thick.json'.format(scale))
    whole_path = join(tscm_dir, 'correlation_matrix_QSDR.scale{0}.thick.json'.format(scale))
    # Context managers close the read handles (in particular the sliced file's, before
    # the same path is reopened for writing below).
    with codecs.open(sliced_path, 'r', encoding='utf-8') as sliced_file:
        sliced_corr_dict = json.load(sliced_file)
    with codecs.open(whole_path, 'r', encoding='utf-8') as whole_file:
        whole_corr_dict = json.load(whole_file)
    for subject in tqdm_notebook(subject_list, leave=False):
        sliced_corr_dict[subject]['WHOLE'] = whole_corr_dict[subject]

    with codecs.open(sliced_path, 'w', encoding='utf-8') as sliced_file:
        json.dump(sliced_corr_dict, sliced_file, separators=(',', ':'),
                  sort_keys=True, indent=4, cls=NumpyEncoder)