In [None]:
#WARNING: dipy version >= 0.16.0 is needed
import re
import dipy

# Enforce the requirement here so that an old dipy fails with a clear
# message at the top of the notebook instead of later in some other cell.
_dipy_major_minor = tuple(int(part) for part in re.findall(r'\d+', dipy.__version__)[:2])
if _dipy_major_minor < (0, 16):
    raise ImportError('dipy >= 0.16.0 is needed, found %s' % dipy.__version__)
dipy.__version__

In [None]:
import os
import numpy as np
import nibabel as nib
from functools import partial
from os.path import join as pjoin
from sklearn.cluster import KMeans
from dipy.segment.clustering import QuickBundles
from dipy.tracking.streamline import set_number_of_points
from dipy.tracking.distances import bundles_distances_mam
from euclidean_embeddings.dissimilarity import compute_dissimilarity
from euclidean_embeddings.distances import euclidean_distance, parallel_distance_computation
from dipy.data.fetcher import (fetch_target_tractogram_hcp,
                               fetch_bundle_atlas_hcp842,
                               get_bundle_atlas_hcp842,
                               get_target_tractogram_hcp)

In [None]:
def compute_centroid(bundle, nb_points=100, threshold=10.0):
    """Compute the centroid of a bundle.

    The streamlines are clustered with QuickBundles capped at a single
    cluster (``max_nb_clusters=1``). The resulting centroid is resampled
    to ``nb_points`` points.

    Parameters
    ----------
    bundle : sequence of streamlines
        Streamlines to summarize, e.g. a list of point arrays or a nibabel
        ``ArraySequence``.
    nb_points : int, optional
        Number of points of the returned centroid. Default is 100.
    threshold : float, optional
        QuickBundles distance threshold. Default is 10.0.

    Returns
    -------
    centroid : list
        The cluster centroid(s) produced by QuickBundles, each resampled to
        ``nb_points`` points by ``set_number_of_points``. With a single
        cluster this is expected to be a one-element list.
    """
    # The previous version also built an unused object array with
    # ``np.object``, which raises AttributeError on NumPy >= 1.24.
    qb = QuickBundles(threshold=threshold, max_nb_clusters=1)
    clusters = qb.cluster(bundle)
    centroids = [cluster.centroid for cluster in clusters]
    return set_number_of_points(centroids, nb_points)

In [None]:
if __name__ == '__main__':

    # Fetch the HCP-842 tractogram atlas and its 80 bundles. Only the
    # destination folder returned by the fetcher is needed afterwards.
    atlas_folder = fetch_bundle_atlas_hcp842()[1]
    # Path of the whole-brain atlas tractogram, plus the bundle file(s)
    # reference -- presumably a pattern over the bundles folder; TODO confirm.
    atlas_file, all_bundles_files = get_bundle_atlas_hcp842()

In [None]:
    # Read the tractogram atlas with the old (trackvis) API because it
    # does not apply any transformation during loading.
    # NOTE(review): nib.trackvis is deprecated and removed in nibabel 4.0,
    # so this cell needs an older nibabel -- TODO confirm the pinned version.
    atlas_tr, _ = nib.trackvis.read(atlas_file)
    # Keep only the first element (the points) of each streamline record.
    atlas_streamlines = [sl[0] for sl in atlas_tr]
    # Fill a 1-D object array element by element. np.array(..., dtype=object)
    # would silently build a multi-dimensional array if all streamlines had
    # the same number of points. Also, the np.object alias no longer exists
    # in NumPy >= 1.24.
    atlas = np.empty(len(atlas_streamlines), dtype=object)
    for i, streamline in enumerate(atlas_streamlines):
        atlas[i] = streamline
    len(atlas)

In [None]:
    #EXAMPLE 1: compute the dissimilarity representation of the atlas using
    #100 prototypes selected from the atlas itself with prototype_policy='sff'
    #(subset farthest first).

In [None]:
    # Number of prototypes used for the dissimilarity representation.
    n_prototypes = 100
    # Streamline-to-streamline distance: bundles_distances_mam wrapped by
    # parallel_distance_computation.
    distance = partial(parallel_distance_computation,
                       distance=bundles_distances_mam)

In [None]:
    # Dissimilarity representation of the whole atlas. The prototypes are
    # selected with the 'sff' policy, and their indices into `atlas` are
    # returned in prototype_idx (reused in Example 2).
    dissimilarity_atlas, prototype_idx = compute_dissimilarity(
        atlas, distance, n_prototypes, prototype_policy='sff', verbose=False)

In [None]:
    # Expected layout: one row per atlas streamline, one column per prototype.
    assert dissimilarity_atlas.shape == (len(atlas), len(prototype_idx))
    dissimilarity_atlas.shape

In [None]:
    #EXAMPLE 2: compute the dissimilarity of the IFOF_L bundle using the
    #same atlas prototypes selected in Example 1.

In [None]:
    tract_name = 'IFOF_L'
    # Build the path with os.path.join (pjoin, already imported) instead of
    # hard-coding '/' separators; this also tolerates a trailing separator
    # in atlas_folder.
    bundle_file = pjoin(atlas_folder, 'Atlas_80_Bundles', 'bundles',
                        '%s.trk' % tract_name)

In [None]:
    # Read the bundle with the old (trackvis) API because it does not
    # apply any transformation during loading.
    bundle_records, _ = nib.trackvis.read(bundle_file)
    # Keep only the points of each streamline record and pack them into an
    # ArraySequence.
    bundle = nib.streamlines.array_sequence.ArraySequence(
        [record[0] for record in bundle_records])
    len(bundle)

In [None]:
    # Compute the dissimilarity of the bundle w.r.t. the atlas prototypes
    # selected in Example 1.
    prototypes = atlas[prototype_idx]
    dissimilarity_bundle = distance(bundle, prototypes)
    # Expected layout: one row per bundle streamline, one column per prototype.
    assert dissimilarity_bundle.shape == (len(bundle), len(prototypes))
    dissimilarity_bundle.shape

In [None]:
    #EXAMPLE 3: compute the dissimilarity of the IFOF_L bundle using the
    #bundle's own centroid as the only prototype.

In [None]:
    # Use the centroid of the bundle itself as the prototype set.
    centroid = compute_centroid(bundle)
    dissimilarity_bundle_1 = distance(bundle, centroid)
    # Expected layout: one row per bundle streamline, one column per centroid.
    assert dissimilarity_bundle_1.shape == (len(bundle), len(centroid))
    dissimilarity_bundle_1.shape

In [None]:
    #EXERCISE 1: compute the dissimilarity of the IFOF_L bundle using as
    #prototypes the centroids of all 80 bundles contained in bundles_folder.
    #WARNING: the Fornix file (F_L_R.trk) contains both the left and the
    #right side in a single bundle; it should be easy to separate them.

In [None]:
    # Folder containing the 80 atlas bundles (.trk files), built with pjoin
    # rather than a hand-written '/' format string.
    bundles_folder = pjoin(atlas_folder, 'Atlas_80_Bundles', 'bundles')