# Download and load data from ABIDE

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os, sys

# Running on Google Colab? Detected from the host name in the environment dump.
gdrive = "colab.research.google.com" in str(os.environ)
if gdrive:
  from google.colab import drive, output

  # Mount Google Drive and enable Colab's custom widget manager.
  drive.mount('/content/drive')
  output.enable_custom_widget_manager()

  # Course folder: put it on the import path and make it the working directory.
  path_list = [os.path.sep, 'content', 'drive', 'MyDrive', 'Colab Notebooks', 'clases']
  path = os.path.join(*path_list)
  sys.path.append(path)
  get_ipython().run_line_magic('cd', path)

  # Folder with the ABIDE files read by the cells below.
  path_data = os.path.join(os.path.sep, 'content', 'drive', 'MyDrive', 'Colab Notebooks', 'datos', 'abide')

try:
  import nilearn
  %matplotlib widget
except:
  !pip install nilearn
  !pip install ipympl
  %matplotlib widget


In [None]:
# List the contents of the current working directory.
%ls

In [None]:
# Show the ABIDE data folder (path_data is only assigned when running on Google Colab).
path_data

## Loading the atlas and ABIDE filename(s)

In [None]:
# apply_all selects, in the batch cell further down, between every subject
# listed in the phenotypic CSV (True) and a single example subject (False).
apply_all = False
atlas_name = 'aal'

# NOTE: load_atlas and load_atlas_metadata are defined in a later cell of this
# notebook; that cell has to be run first.
atlas = load_atlas(atlas_name)
# Pass atlas_name (not a repeated literal) so the metadata always matches the
# atlas selected above.
atlas_filename, roi_numbers, N_labels, interpolation = load_atlas_metadata(atlas_name)

print(atlas_filename)
print(roi_numbers)
print(N_labels)


## One subject analysis

### Loading a ROI step-by-step

In [None]:
atlas_name = 'aal'
atlas = load_atlas(atlas_name)
# Pass atlas_name (not a repeated literal) so the metadata always matches the
# atlas selected above.
atlas_filename, roi_numbers, N_labels, interpolation = load_atlas_metadata(atlas_name)
atlas_data = load_nii(atlas_filename)
print(atlas_data.shape)

# Browse the atlas volume slice by slice in a fresh figure.
plt.close('all')
MultiImageDisplay([atlas_data])
plt.show()

In [None]:
file_id = 'Pitt_0050003'
dx = 1

atlas_name = 'aal'
atlas = load_atlas(atlas_name)
# Pass atlas_name (not a repeated literal) so the metadata always matches the
# atlas selected above.
atlas_filename, roi_numbers, N_labels, interpolation = load_atlas_metadata(atlas_name)

# Per-subject ABIDE files, located by FILE_ID under path_data.
func_mean_filename = os.path.join(path_data, 'func_mean',f'{file_id}_func_mean.nii.gz')
func_filename = os.path.join(path_data, 'nofilt_noglobal', f'{file_id}_func_preproc.nii.gz')
reho_filename = os.path.join(path_data, 'reho', f'{file_id}_reho.nii.gz')
# Resample the atlas to the subject's mean functional image so that it can be
# used to index the subject's data; singleton axes are dropped.
atlas_resampled = np.squeeze(resample_to_img(atlas_filename, func_mean_filename, interpolation=interpolation).get_fdata())

func_data = load_nii(func_filename)
reho_data = load_nii(reho_filename)
func_mean_data = load_nii(func_mean_filename)

print('func shape', func_data.shape)
print('reho shape', reho_data.shape)
print('atlas resampled shape', atlas_resampled.shape)
print('func mean shape', func_mean_data.shape)

# ReHo is rescaled so that its maximum is 100 because the viewer's intensity
# slider moves in integer steps.
plt.close('all')
MultiImageDisplay([atlas_resampled, 100*(reho_data/reho_data.max()), func_mean_data],
                  shared_slider=True,
                  title_list=['atlas resampled', 'reho', 'functional mean'],
                  figure_size=(10,6))
plt.show()


In [None]:
# Step-by-step extraction of the signals of one ROI for the subject loaded above.

# ROI to inspect: a label value for a 3-D atlas, or a 1-based map number for a
# 4-D atlas. NOTE(review): presumably an AAL label - confirm which region it is.
n = 2101

# Boolean mask of the voxels that belong to the ROI.
if atlas_resampled.ndim == 3: atlas_idx = atlas_resampled == n
else: atlas_idx = atlas_resampled[:, :, :, n-1] > 0

# Copy of the atlas that is zero everywhere except inside the ROI.
atlas_mask = np.zeros(atlas_resampled.shape)
atlas_mask[atlas_idx] = atlas_resampled[atlas_idx]
# Indices along the third axis of the slices that contain ROI voxels.
(i,j,k) = np.where(atlas_idx)
zvals1 = np.unique(k)

# Keep only the ROI voxels whose ReHo value is at or above the reho_p-th
# percentile of the ReHo values found inside the ROI.
reho_p = 90
reho_th = np.percentile(reho_data[atlas_idx].flatten(), reho_p)
idx = np.logical_and( atlas_idx, reho_data>=reho_th )
(i,j,k) = np.where(idx)
zvals2 = np.unique(k)

print(zvals1)
# With a 3-D label atlas every ROI voxel holds the value n, so this is the
# number of voxels in the ROI.
print(atlas_mask.sum()/n)
print(zvals2)

# ReHo values of the selected voxels, zero elsewhere.
reho_mask = np.zeros(reho_data.shape)
reho_mask[idx] = reho_data[idx]

# # print(np.unique(atlas_mask.flatten()), atlas_mask.shape)

# Middle one of the slices that contain selected voxels.
zval = zvals2[len(zvals2)//2]

# One column per selected voxel; rows follow the last axis of func_data
# (assumed to be time - TODO confirm).
roi_signals = func_data[idx, :].T


# Figure 1: ROI mask (left) and thresholded ReHo (right) on the chosen slice.
plt.close('all')
fig, axs = plt.subplots(1, 2, num=1)
axs[0].imshow(atlas_mask[:,:,zval])
axs[1].imshow(100*reho_mask[:,:,zval]/reho_mask.max())
plt.show()



# Figure 2: all the selected signals shown as an image.
plt.figure(2)
plt.imshow(roi_signals)
plt.show()

# Figure 3: signal of the selected voxel with column index 10.
plt.figure(3)
plt.plot(roi_signals[:, 10])
plt.show()


## Loading all ROI signals of one ABIDE subject

In [None]:

def load_roi_signals(n, func_data, atlas_data, reho_data=None, reho_p=90):
  '''
  This function returns the signals of the voxels that belong to one ROI
  Inputs:
      n: ROI label (3-D atlas) or 1-based map number (4-D atlas)
      func_data: array whose first three axes are spatial; its last axis
          becomes the rows of the output
      atlas_data: 3-D array of labels, or 4-D array with one map per ROI
          along the last axis (voxels with a value > 0 belong to the ROI)
      reho_data: optional 3-D array. When given, only the ROI voxels whose
          value is at or above the reho_p-th percentile of the values inside
          the ROI are kept
      reho_p: percentile (0-100) used for that threshold
  Output:
      array of shape (func_data.shape[-1], number of selected voxels)
  '''
  if atlas_data.ndim == 3:
    atlas_idx = atlas_data == n  # atlas coordinates with label n
  else:
    atlas_idx = atlas_data[:, :, :, n-1] > 0  # atlas coordinates of map n

  idx = atlas_idx
  # The percentile is undefined for an empty ROI, so the filter is skipped
  # and the function returns an array with zero columns in that case.
  if reho_data is not None and atlas_idx.any():
    reho_th = np.percentile(reho_data[atlas_idx], reho_p)
    idx = np.logical_and(atlas_idx, reho_data >= reho_th)

  return func_data[idx, :].T


# Compute the mean signal of every ROI for each selected subject.
# NOTE: load_nii and resample_to_img come from a later cell of this notebook,
# and atlas_filename / roi_numbers / N_labels / interpolation from the atlas
# cell; both have to be run first.
if apply_all:
  # Imported here because the notebook only imports pandas in a later cell.
  import pandas as pd
  r = pd.read_csv(os.path.join(path_data, 'Phenotypic_V1_0b_preprocessed1.csv'))
  r = r[r['FILE_ID'] != 'no_filename']
  apply_to = r[['DX_GROUP', 'FILE_ID']].iterrows()
else:
  apply_to = enumerate([{'DX_GROUP': 1, 'FILE_ID': 'Pitt_0050003'}])

# Mean ROI signals of every processed subject, keyed by FILE_ID. Xmean alone
# only holds the last processed subject.
Xmean_all = {}

for i, row in apply_to:

  file_id = row['FILE_ID']
  dx = row['DX_GROUP']

  func_mean_filename = os.path.join(path_data, 'func_mean',f'{file_id}_func_mean.nii.gz')
  func_filename = os.path.join(path_data, 'nofilt_noglobal', f'{file_id}_func_preproc.nii.gz')
  reho_filename = os.path.join(path_data, 'reho', f'{file_id}_reho.nii.gz')

  # Skip subjects without a functional file (checked directly instead of
  # listing the whole folder on every iteration).
  if not os.path.isfile(func_filename): continue
  print(f'processing subject {file_id}, {i}, dx group {dx}')

  # Resample the atlas to this subject's mean functional image.
  atlas_data = np.squeeze(resample_to_img(atlas_filename, func_mean_filename, interpolation=interpolation).get_fdata())
  func_data = load_nii(func_filename)
  reho_data = load_nii(reho_filename)
  func_mean_data = load_nii(func_mean_filename)

  print('func_data shape', func_data.shape)
  print('atlas_data shape', atlas_data.shape)
  print('func_data mean shape', func_mean_data.shape)

  # One row per sample along the last axis of func_data, one column per ROI.
  N_pts = func_data.shape[-1]
  Xmean = np.zeros( (N_pts, N_labels) )

  for j, roi_number in enumerate(roi_numbers):
    print('ROI', roi_number, dx, file_id)
    # ReHo filtering is left disabled here; pass reho_data to enable it.
    roi_signals = load_roi_signals(roi_number, func_data, atlas_data)
    print('Shape of roi signals', roi_signals.shape)

    Xmean[:, j] = roi_signals.mean(axis=1)

  Xmean_all[file_id] = Xmean







In [None]:
# Show the mean ROI signals computed by the previous cell as an image
# (rows and columns as laid out in Xmean: N_pts x N_labels).
plt.figure()
plt.imshow(Xmean)
plt.show()

In [None]:
import numpy as np
import pandas as pd
import os
import nibabel as nb

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import numpy as np
from matplotlib.widgets import  RectangleSelector
import matplotlib.patches as patches
import matplotlib.cm as cm
from matplotlib.ticker import MaxNLocator
import copy


from nilearn import datasets
from nilearn.image import resample_to_img


def load_nii(filename, verbose=False):
  '''
  This function reads an image file with nibabel and returns its data
  Inputs:
      filename: path of the file to read
      verbose: if True, the filename is printed before reading
  Output:
      data array, as given by get_fdata()
  '''
  if verbose:
    print('Reading the file', filename)
  image = nb.load(filename)
  return image.get_fdata()


def save_nii(filename, data, verbose=False, affine=None):
  '''
  This function saves data in the nifti format
  Inputs:
      filename: is the filename
      data: is the data to be saved
      verbose: if True, the filename is printed before saving
      affine: 4x4 affine stored with the image. The default (None) stores
          the identity matrix, so pass the affine of the source image to
          keep its spatial information
  '''
  if verbose:
    print('Saving the file', filename)
  if affine is None:
    affine = np.eye(4)
  nb.save(nb.Nifti1Image(data, affine), filename)





def load_atlas_metadata(atlas):
  '''
  This function returns what is needed to extract ROI signals with an atlas
  Inputs:
      atlas: atlas name. One of 'aal', 'ho_cort', 'ho_cort_maxprob', 'ho_sub',
          'ho_sub_maxprob', 'msdl', 'pauly', 'yeo_thin_17', 'yeo_thick_17',
          'basc007', 'basc012', 'basc020', 'basc036', 'basc064', 'basc122',
          'basc197', 'basc325', 'basc444' or 'seitzman'
  Output:
      atlas_filename: atlas maps as given by nilearn (a fixed local path
          for 'seitzman')
      roi_numbers: list with the number that identifies each ROI
      N_labels: number of ROIs
      interpolation: 'nearest' or 'continuous', the interpolation to use
          when the atlas is resampled
  Raises:
      ValueError: if atlas is not one of the names listed above
  '''
  # name -> (nilearn Harvard-Oxford atlas name, number of ROIs, interpolation)
  harvard_oxford = {
    'ho_cort': ('cort-prob-1mm', 48, 'continuous'),
    'ho_cort_maxprob': ('cort-maxprob-thr50-1mm', 48, 'nearest'),
    'ho_sub': ('sub-prob-1mm', 21, 'continuous'),
    'ho_sub_maxprob': ('sub-maxprob-thr50-1mm', 21, 'nearest'),
  }
  # name -> key of the Yeo 2011 atlas
  yeo_maps = {'yeo_thin_17': 'thin_17', 'yeo_thick_17': 'thick_17'}
  # name -> number of ROIs of the BASC scale ('basc007' -> 7, ...)
  basc_scales = {f'basc{scale:03d}': scale
                 for scale in (7, 12, 20, 36, 64, 122, 197, 325, 444)}

  interpolation = 'nearest'

  if atlas == 'aal':
    aal = datasets.fetch_atlas_aal(version='SPM12')
    # The AAL ROI numbers are read from the atlas instead of being 1..N.
    roi_numbers = list(map(int, aal.indices))
    return aal.maps, roi_numbers, len(roi_numbers), interpolation

  if atlas in harvard_oxford:
    ho_name, N_labels, interpolation = harvard_oxford[atlas]
    atlas_filename = datasets.fetch_atlas_harvard_oxford(ho_name)['filename']
  elif atlas == 'msdl':
    atlas_filename = datasets.fetch_atlas_msdl().maps
    N_labels = 39
    interpolation = 'continuous'
  elif atlas == 'pauly':
    atlas_filename = datasets.fetch_atlas_pauli_2017().maps
    N_labels = 16
    interpolation = 'continuous'
  elif atlas in yeo_maps:
    atlas_filename = datasets.fetch_atlas_yeo_2011()[yeo_maps[atlas]]
    N_labels = 17
  elif atlas in basc_scales:
    N_labels = basc_scales[atlas]
    multiscale = datasets.fetch_atlas_basc_multiscale_2015()
    atlas_filename = getattr(multiscale, f'scale{N_labels:03d}')
  elif atlas == 'seitzman':
    atlas_filename = 'utilities/Parcels_MNI_111.nii'
    N_labels = 333
  else:
    # Fail with a clear message instead of an UnboundLocalError at return.
    raise ValueError(f'Unknown atlas name: {atlas!r}')

  roi_numbers = list(range(1, N_labels+1))
  return atlas_filename, roi_numbers, N_labels, interpolation


def load_atlas(atlas):
  '''
  This function fetches an atlas with nilearn
  Inputs:
      atlas: atlas name. One of 'aal', 'ho_cort', 'ho_cort_maxprob', 'ho_sub',
          'ho_sub_maxprob', 'msdl', 'pauly', 'yeo', 'basc' or 'seitzman'.
          The names used by load_atlas_metadata that start with 'yeo' or
          'basc' (e.g. 'yeo_thin_17', 'basc064') are accepted as well
  Output:
      the object returned by the corresponding nilearn fetcher
  Raises:
      ValueError: if the atlas name is not recognised
  '''
  # Map the per-variant names of load_atlas_metadata to their atlas family.
  key = atlas
  if isinstance(atlas, str):
    if atlas.startswith('yeo'):
      key = 'yeo'
    elif atlas.startswith('basc'):
      key = 'basc'

  # The fetchers are wrapped in lambdas so that only the requested one runs.
  fetchers = {
    'aal': lambda: datasets.fetch_atlas_aal(version='SPM12'),
    'ho_cort': lambda: datasets.fetch_atlas_harvard_oxford('cort-prob-1mm'),
    'ho_cort_maxprob': lambda: datasets.fetch_atlas_harvard_oxford('cort-maxprob-thr50-1mm'),
    'ho_sub': lambda: datasets.fetch_atlas_harvard_oxford('sub-prob-1mm'),
    'ho_sub_maxprob': lambda: datasets.fetch_atlas_harvard_oxford('sub-maxprob-thr50-1mm'),
    'msdl': lambda: datasets.fetch_atlas_msdl(),
    'pauly': lambda: datasets.fetch_atlas_pauli_2017(),
    'yeo': lambda: datasets.fetch_atlas_yeo_2011(),
    'basc': lambda: datasets.fetch_atlas_basc_multiscale_2015(),
    'seitzman': lambda: datasets.fetch_coords_seitzman_2018(ordered_regions=True, legacy_format=True),
  }

  if not isinstance(key, str) or key not in fetchers:
    # Fail loudly instead of silently returning None.
    raise ValueError(f'Unknown atlas name: {atlas!r}')
  return fetchers[key]()




In [None]:
class MultiImageDisplay(object):
    '''
    Interactive viewer that shows one slice of each array of a list, with
    ipywidgets sliders to choose the slice and the displayed intensity range.

    Creating an instance displays the sliders and draws the figure.
    '''
    def __init__(self, image_list, axis=0, shared_slider=False, title_list=None, window_level_list= None, figure_size=(10,8), horizontal=True):
        '''
        Inputs:
            image_list: list of arrays to show (3-D, or 4-D for color images)
            axis: axis along which the slice sliders scroll
            shared_slider: if True a single slice slider drives every image
            title_list: optional list with one title per image
            window_level_list: optional list with one (window, level) pair per
                image, or None entries to use the default range; the initial
                intensity range is (level - window/2, level + window/2)
            figure_size: passed to plt.subplots as figsize
            horizontal: if True the images are laid out in one row,
                otherwise in one column
        Raises:
            ValueError: if title_list and image_list have different lengths, or
                if shared_slider is True and the images differ in size along axis
        '''

        self.npa_list, wl_range, wl_init = self.get_window_level_numpy_array(image_list, window_level_list)
        if title_list:
            if len(image_list)!=len(title_list):
                raise ValueError('Title list and image list lengths do not match')
            self.title_list = list(title_list)
        else:
            self.title_list = ['']*len(image_list)

        # Our dynamic slice, based on the axis the user specifies
        self.slc = [slice(None)]*3
        self.axis = axis

        # Build the sliders (this also fills self.slider_list and self.wl_list) and show them.
        ui = self.create_ui(shared_slider, wl_range, wl_init)
        display(ui)

        # Create a figure.
        col_num, row_num = (len(image_list), 1)  if horizontal else (1, len(image_list))
        self.fig, self.axes = plt.subplots(row_num,col_num,figsize=figure_size)
        # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so the loops below work.
        if len(image_list)==1:
            self.axes = [self.axes]

        # Display the data and the controls, first time we display the image is outside the "update_display" method
        # as that method relies on the previous zoom factor which doesn't exist yet.
        for ax, npa, slider, wl_slider in zip(self.axes, self.npa_list, self.slider_list, self.wl_list):
            self.slc[self.axis] = slice(slider.value, slider.value+1)
            # Need to use squeeze to collapse degenerate dimension (e.g. RGB image size 124 124 1 3)
            ax.imshow(np.squeeze(npa[tuple(self.slc)]),
                      cmap=plt.cm.Greys_r,
                      vmin=wl_slider.value[0],
                      vmax=wl_slider.value[1])
        self.update_display()
        plt.tight_layout()


    def create_ui(self, shared_slider, wl_range, wl_init):
        '''
        Create the slice and intensity sliders and return them in a VBox.

        Fills self.slider_list and self.wl_list with one entry per image (the
        same slider object is repeated when shared_slider is True).
        wl_range and wl_init hold, per image, the (min, max) limits and the
        initial (low, high) value of the intensity slider.
        '''
        # Create the active UI components. Height and width are specified in 'em' units. This is
        # a html size specification, size relative to current font size.

        if shared_slider:
            # Validate that all the images have the same size along the axis which we scroll through
            sz = self.npa_list[0].shape[self.axis]
            for npa in self.npa_list:
                       if npa.shape[self.axis]!=sz:
                           raise ValueError('Not all images have the same size along the specified axis, cannot share slider.')

            # The slider starts at the middle slice.
            slider = widgets.IntSlider(description='image slice:',
                                      min=0,
                                      max=sz-1,
                                      step=1,
                                      value = int((sz-1)/2),
                                      width='20em')
            slider.observe(self.on_slice_slider_value_change, names='value')
            self.slider_list = [slider]*len(self.npa_list)
            slicer_box = widgets.Box(padding=7, children=[slider])
        else:
            # One independent slider per image, each starting at the middle slice.
            self.slider_list = []
            for npa in self.npa_list:
                slider = widgets.IntSlider(description='image slice:',
                                           min=0,
                                           max=npa.shape[self.axis]-1,
                                           step=1,
                                           value = int((npa.shape[self.axis]-1)/2),
                                           width='20em')
                slider.observe(self.on_slice_slider_value_change, names='value')
                self.slider_list.append(slider)
            slicer_box = widgets.Box(padding=7, children=self.slider_list)
        self.wl_list = []
        # Each image has a window-level slider, but it is disabled if the image
        # is a color image len(npa.shape)==4 . This allows us to display both
        # color and grayscale images in the same UI while retaining a reasonable
        # layout for the sliders.
        for r_values, i_values, npa in zip(wl_range, wl_init, self.npa_list):
            # NOTE(review): this is an integer slider with step 1 while the limits come
            # from percentiles of the data - confirm how non-integer limits are handled.
            wl_range_slider = widgets.IntRangeSlider(description='intensity:',
                                              min=r_values[0],
                                              max=r_values[1],
                                              step=1,
                                              value = [i_values[0], i_values[1]],
                                              width='20em',
                                              disabled = len(npa.shape) == 4)
            wl_range_slider.observe(self.on_wl_slider_value_change, names='value')
            self.wl_list.append(wl_range_slider)
        wl_box = widgets.Box(padding=7, children=self.wl_list)
        return widgets.VBox(children=[slicer_box,wl_box])

    def get_window_level_numpy_array(self, npa_list, window_level_list):
        '''
        Compute the intensity slider limits and initial values of each image.

        Returns (npa_list, wl_range, wl_init): the list of arrays as received,
        the per-image (min, max) limits and the per-image initial (low, high).
        '''
        # NOTE: the arrays are used as passed in (no copy is made here), so
        # changes made to them elsewhere show up the next time the display
        # is redrawn.

        wl_range = []
        wl_init = []
        # We need to iterate over the images because they can be a mix of
        # grayscale and color images. If they are color we set the wl_range
        # to [0,255] and the wl_init is equal, ignoring the window_level_list
        # entry.
        for i, npa in enumerate(npa_list):
            if len(npa.shape) == 4: #color image
                wl_range.append((0,255))
                wl_init.append((0,255))
                # ignore any window_level_list entry
            else:
                # We don't take the minimum/maximum values, just in case there are outliers (top/bottom 2%)
                min_max = np.percentile(npa.flatten(), [2,98])
                wl_range.append((min_max[0], min_max[1]))
                if not window_level_list:
                    wl_init.append(wl_range[-1])
                else:
                    wl = window_level_list[i]
                    if wl:
                        # wl is (window, level): the range is centered on the level.
                        wl_init.append((wl[1]-wl[0]/2.0, wl[1]+wl[0]/2.0))
                    else:
                        wl_init.append(wl_range[-1])
        return (npa_list, wl_range, wl_init)

    def on_slice_slider_value_change(self, change):
        '''Redraw when a slice slider moves (change is not used).'''
        self.update_display()

    def on_wl_slider_value_change(self, change):
        '''Redraw when an intensity slider moves (change is not used).'''
        self.update_display()

    def update_display(self):
        '''
        Redraw every image with the slice and the intensity range selected on
        its sliders, keeping the axis limits (zoom) that were set before.
        '''

        # Draw the image(s)
        for ax, npa, title, slider, wl_slider in zip(self.axes, self.npa_list, self.title_list, self.slider_list, self.wl_list):
            # We want to keep the zoom factor which was set prior to display, so we log it before
            # clearing the axes.
            xlim = ax.get_xlim()
            ylim = ax.get_ylim()

            self.slc[self.axis] = slice(slider.value, slider.value+1)
            ax.clear()
            # Need to use squeeze to collapse degenerate dimension (e.g. RGB image size 124 124 1 3)
            ax.imshow(np.squeeze(npa[tuple(self.slc)]),
                      cmap=plt.cm.Greys_r,
                      vmin=wl_slider.value[0],
                      vmax=wl_slider.value[1])
            ax.set_title(title)
            ax.set_axis_off()

            # Set the zoom factor back to what it was before we cleared the axes, and rendered our data.
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

        self.fig.canvas.draw_idle()
