In [None]:
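# Adapted from the pyAFQ examples. Runs DTI tractography on the Stanford HARDI
# dataset and studies how streamline resampling affects bundle segmentation.
"""
==========================
Plotting tract profiles
==========================

An example of tracking and segmenting the tracts defined by
``make_bundle_dict()``, comparing segmentation with and without streamline
resampling, and plotting their tract profiles for FA (calculated with DTI).

"""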
import os.path as op
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import dipy.data as dpd
from dipy.data import fetcher
import dipy.tracking.utils as dtu
import dipy.tracking.streamline as dts
from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.stats.analysis import afq_profile, gaussian_weights
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.io.stateful_tractogram import Space

import AFQ.utils.streamlines as aus
import AFQ.data as afd
import AFQ.tractography as aft
import AFQ.registration as reg
import AFQ.dti as dti
import AFQ.segmentation as seg
from AFQ.utils.volume import patch_up_roi
from AFQ.api import make_bundle_dict

dpd.fetch_stanford_hardi()

hardi_dir = op.join(fetcher.dipy_home, "stanford_hardi")
hardi_fdata = op.join(hardi_dir, "HARDI150.nii.gz")
hardi_fbval = op.join(hardi_dir, "HARDI150.bval")
hardi_fbvec = op.join(hardi_dir, "HARDI150.bvec")

img = nib.load(hardi_fdata)

print("Calculating DTI...")
# Cached DTI outputs. NOTE(review): assumes dti.fit_dti(..., out_dir='.')
# writes exactly these filenames -- confirm against AFQ.dti.
dti_fa_path = './dti_FA.nii.gz'
dti_params_path = './dti_params.nii.gz'
# Reuse the cache only if *both* files exist: the FA map is used below and
# the tensor parameters are needed for tracking.
if op.exists(dti_fa_path) and op.exists(dti_params_path):
    dti_params = {'FA': dti_fa_path,
                  'params': dti_params_path}
else:
    dti_params = dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
                             out_dir='.')

FA_img = nib.load(dti_params['FA'])
FA_data = FA_img.get_fdata()

templates = afd.read_templates()
# Bundle definitions (ROIs, rules, ...) keyed by bundle name; used for
# seeding and segmentation in later cells.
bundles = make_bundle_dict()

print("Registering to template...")
MNI_T2_img = dpd.read_mni_template()
mapping_path = './mapping.nii.gz'
if not op.exists(mapping_path):
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)
    warped_hardi, mapping = reg.syn_register_dwi(hardi_fdata, gtab)
    reg.write_mapping(mapping, mapping_path)
else:
    mapping = reg.read_mapping(mapping_path, img, MNI_T2_img)


print("Tracking...")
if not op.exists('dti_streamlines.trk'):
    # Union of every bundle ROI, warped (via mapping.transform_inverse) from
    # template space into subject space. It is saved to disk for inspection;
    # note that it is NOT currently passed to aft.track (seed_mask is
    # commented out in the call below).
    seed_roi = np.zeros(img.shape[:-1], dtype=bool)
    # Iterate over the bundle definitions themselves instead of rebuilding
    # keys from a separate list of names (the previous `bundle_names` was
    # never defined), so whatever keys make_bundle_dict() uses will work.
    for bundle_info in bundles.values():
        for roi in bundle_info['ROIs']:
            # get_fdata() replaces the deprecated nibabel get_data().
            warped_roi = patch_up_roi(
                mapping.transform_inverse(
                    roi.get_fdata().astype(np.float32),
                    interpolation='linear') > 0)

            # Add voxels that aren't there yet:
            seed_roi = np.logical_or(seed_roi, warped_roi)

    nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
             'seed_roi.nii.gz')
    streamlines = aft.track(dti_params['params'],  # seed_mask=seed_roi,
                            stop_mask=FA_data, stop_threshold=0.1)

    sft = StatefulTractogram(streamlines, img, Space.RASMM)
    save_tractogram(sft, './dti_streamlines.trk',
                    bbox_valid_check=False)
else:
    tg = load_tractogram('./dti_streamlines.trk', img)
    streamlines = tg.streamlines

# Map streamlines from RAS+ mm into the voxel coordinates of `img`. Later
# cells rely on this (e.g. afq_profile is called with an identity affine).
streamlines = dts.Streamlines(
    dtu.transform_tracking_output(streamlines,
                                  np.linalg.inv(img.affine)))

In [None]:
# Sanity check: render the tracked streamlines with fury so we can eyeball
# that tracking produced something plausible before segmenting.
from fury import actor, window
from fury.colormap import line_colors

scene = window.Scene()
sl_colors = line_colors(streamlines)
streamlines_actor = actor.line(streamlines, sl_colors)
scene.add(streamlines_actor)
window.show(scene)

In [None]:
# perform segmentation with different resampling and compare the fiber groups to
# fiber groups generated by segmentation with no resampling
import time
import dipy.tracking.streamlinespeed as dps

def resample_tests(nb_points_ls=None):
    """Segment the tractography once per resampling level and time each run.

    Relies on notebook globals defined in the first cell: ``bundles``,
    ``streamlines``, ``hardi_fdata``, ``hardi_fbval``, ``hardi_fbvec``,
    ``mapping`` and ``MNI_T2_img``.

    Parameters
    ----------
    nb_points_ls : sequence, optional
        Values passed as ``nb_points`` to ``seg.Segmentation``; ``False``
        means "no resampling". Defaults to
        ``[False, 2, 5, 10, 20, 30, 40, 50, 100]``. Keep ``False`` first:
        the evaluation code treats the first run as the reference.

    Returns
    -------
    nb_points_ls : list
        The resampling levels that were run, in order.
    times : list of float
        CPU time (``time.process_time``, seconds) spent in each run.
    fiber_groups_ls : list
        The ``Segmentation.segment`` result for each resampling level.
    """
    if nb_points_ls is None:
        nb_points_ls = [False, 2, 5, 10, 20, 30, 40, 50, 100]
    nb_points_ls = list(nb_points_ls)

    fiber_groups_ls = []
    times = []

    for nb_points in nb_points_ls:
        # time the segmentation
        start = time.process_time()
        print("Segmenting fiber groups using ", nb_points)
        segmentation = seg.Segmentation(nb_points=nb_points)
        fiber_groups = segmentation.segment(bundles, streamlines, hardi_fdata,
                                            hardi_fbval, hardi_fbvec,
                                            mapping=mapping,
                                            reg_template=MNI_T2_img)
        times.append(time.process_time() - start)
        fiber_groups_ls.append(fiber_groups)

        # save segmentation results for no resampling
        # to check if they are reasonable
        if nb_points is False:
            with open('all_bundles.txt', 'w') as f:
                f.write(str(fiber_groups))

    return nb_points_ls, times, fiber_groups_ls

def is_same_line(sl1, sl2, tol=1e-8):
    """Check whether two streamlines are identical up to a tolerance.

    Parameters
    ----------
    sl1, sl2 : array-like
        Streamlines as 2-D arrays of point coordinates (e.g. shape (N, 3)).
    tol : float
        Maximum allowed absolute difference per coordinate.

    Returns
    -------
    bool
        True if both streamlines have the same shape and no coordinate
        differs by more than ``tol``.
    """
    sl1 = np.asarray(sl1)
    sl2 = np.asarray(sl2)
    # Compare the full shape (not only the number of points) so that a
    # dimensionality mismatch returns False instead of raising IndexError.
    if sl1.shape != sl2.shape:
        return False
    # Vectorized replacement for the element-by-element loop. Phrased as
    # "no coordinate exceeds tol" to keep the original comparison semantics.
    return not np.any(np.abs(sl1 - sl2) > tol)

TPRs = []  # true positive rates
TPs = []  # number of true positives
FDRs = []  # false discovery rates
nb_points_ls, times, fiber_groups_ls = resample_tests()

# Take the first run as ground truth; resample_tests runs the
# non-resampled segmentation (nb_points=False) first.
fiber_groups_truth = fiber_groups_ls[0]


def _contains_line(candidates, sl):
    """Return True if any streamline in `candidates` matches `sl`."""
    return any(is_same_line(sl, cand) for cand in candidates)


for nb_points, fiber_groups in zip(nb_points_ls, fiber_groups_ls):
    # Resample each ground-truth streamline once per run (previously the
    # FDR loop re-resampled it for every candidate), so both metrics compare
    # streamlines with the same number of points.
    truth_by_bundle = {}
    for bundle_name, truth_sls in fiber_groups_truth.items():
        if nb_points is False:
            truth_by_bundle[bundle_name] = list(truth_sls)
        else:
            truth_by_bundle[bundle_name] = [
                dps.set_number_of_points(sl, nb_points) for sl in truth_sls]

    # True positives: truth streamlines also found in the same bundle of
    # this run. Bundles are matched by name, not by dict position.
    hits = 0
    total = 0
    for bundle_name, truth_sls in truth_by_bundle.items():
        hat_sls = fiber_groups.get(bundle_name, [])
        for sl_true in truth_sls:
            total += 1
            if _contains_line(hat_sls, sl_true):
                hits += 1
    # Guard the empty case the same way FDR does below.
    TPRs.append(hits / total if total > 0 else 0.0)
    TPs.append(hits)

    # False discoveries: streamlines of this run with no match among the
    # truth streamlines of the same bundle.
    err = 0
    total = 0
    for bundle_name, hat_sls in fiber_groups.items():
        truth_sls = truth_by_bundle.get(bundle_name, [])
        for sl_hat in hat_sls:
            total += 1
            if not _contains_line(truth_sls, sl_hat):
                err += 1
    FDRs.append(err / total if total > 0 else 0.0)

def plot_resample(x, y, ylabel, title, tpr=False):
    """Plot a metric against the number of resampling points and save it.

    Parameters
    ----------
    x : sequence
        Resampling levels; ``x[0]`` is the non-resampled reference run.
    y : sequence of float
        Metric value for each entry of ``x``.
    ylabel : str
        Y-axis label.
    title : str
        Figure title, also used as the output filename for ``plt.savefig``.
    tpr : bool
        If True, draw a vertical green line at the first resampling level
        (excluding the reference) whose value reaches 1.0.
    """
    plt.clf()
    ax = plt.gca()

    # The reference value is drawn as a horizontal red line, since x[0]
    # (no resampling) has no position on the "number of points" axis.
    ax.axhline(y=y[0], color='r')

    if tpr:
        first_perfect = next(
            (x_i for x_i, y_i in zip(x[1:], y[1:]) if y_i >= 1.0), None)
        if first_perfect is not None:
            ax.axvline(x=first_perfect, color='g')

    ax.plot(x[1:], y[1:])
    ax.set_xlabel('Number of points')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.savefig(title)

# One saved figure per metric, all against the same resampling levels.
for metric_values, metric_label, out_name, mark_tpr in (
        (times, 'Time', 'resample_time', False),
        (TPRs, 'True Positive Rate', 'resample_tpr', True),
        (TPs, 'True Positive', 'resample_tp', False),
        (FDRs, 'False Discovery Rate', 'resample_fdr', False)):
    plot_resample(nb_points_ls, metric_values, metric_label, out_name,
                  tpr=mark_tpr)

In [None]:
# For one specific resampling level, compare FA tract profiles against the
# non-resampled segmentation, bundle by bundle.

# Resampling level used for the comparison.
PROFILE_NB_POINTS = 100

# Arguments shared by both segmentation runs (globals from the first cell).
segment_args = (bundles, streamlines, hardi_fdata, hardi_fbval, hardi_fbvec)
segment_kwargs = dict(mapping=mapping, reg_template=MNI_T2_img)

# perform resampled segmentation
segmentation = seg.Segmentation(nb_points=PROFILE_NB_POINTS)
fiber_groups_resampled = segmentation.segment(*segment_args,
                                              **segment_kwargs)

# perform not resampled segmentation (Segmentation's default settings)
segmentation = seg.Segmentation()
fiber_groups_not_resampled = segmentation.segment(*segment_args,
                                                  **segment_kwargs)

# Index 0: reference (not resampled); index 1: resampled.
fiber_groups_ls = [fiber_groups_not_resampled, fiber_groups_resampled]
colors = ['b', 'r--']  # matplotlib format strings, one per run above
labels = ['not resampled', 'resampled (%d points)' % PROFILE_NB_POINTS]
# Distinct filename suffixes so the resampled run no longer overwrites the
# reference run's .trk files.
file_suffixes = ['', '_resampled']

# clean each fiber group and save it for inspection
for fiber_groups, suffix in zip(fiber_groups_ls, file_suffixes):
    print("Cleaning fiber groups...")
    for bundle in bundles:
        fiber_groups[bundle] = seg.clean_fiber_group(fiber_groups[bundle])

    for kk in fiber_groups:
        print(kk, len(fiber_groups[kk]))

        # Input streamlines were moved to voxel space in the tracking cell;
        # presumably segment() returns them in that same space, so map them
        # back to RAS+ mm with the image affine before saving.
        sft = StatefulTractogram(
            dtu.transform_tracking_output(fiber_groups[kk], img.affine),
            img, Space.RASMM)

        save_tractogram(sft, './%s%s_afq.trk' % (kk, suffix),
                        bbox_valid_check=False)

# plot both runs' FA profiles for each bundle, one figure per bundle
print("Extracting tract profiles...")
for bundle in bundles:
    fig, ax = plt.subplots(1)
    for fiber_groups, fmt, label in zip(fiber_groups_ls, colors, labels):
        weights = gaussian_weights(fiber_groups[bundle])
        # Identity affine: streamlines are already in FA_data voxel space.
        profile = afq_profile(FA_data, fiber_groups[bundle],
                              np.eye(4), weights=weights)
        ax.plot(profile, fmt, label=label)
    ax.set_title(bundle)
    ax.legend()
    fig.savefig(bundle)
    # Each figure is saved to disk; close it so figures do not pile up in
    # memory across all bundles.
    plt.close(fig)