In [None]:
# this function will do the dice comparison between two different runs of pyAFQ
# it gets data from s3, and then uploads results to s3
def compare_sessions_dict(subjects, session1, session2):
    """Compute per-bundle dice coefficients between two pyAFQ runs.

    For every subject and every bundle, the cleaned bundle tractogram of
    each session is downloaded from S3, converted to a density map, and the
    two maps are compared with ``AFQ.utils.volume.dice_coeff``. The results
    are written to ``df.csv`` in the working directory and uploaded to
    ``my_bucket/hcp_trt/{session1}_{session2}_dice.csv``.

    Parameters
    ----------
    subjects : iterable
        Subject identifiers; each one is formatted into ``sub-{subject}``.
    session1, session2 : str
        Keys of the session tables defined below (e.g. ``"DKI"``,
        ``"CSD_retest"``, ``"DKI_reco"``). An unknown key raises ``KeyError``.

    Returns
    -------
    None
        The function is used for its side effects (local files and the
        CSV uploaded to S3).
    """
    # All imports live inside the function body: the function object is
    # handed to cloudknot (see the ck.Knot cell), so it cannot rely on
    # names imported at notebook level.
    import s3fs
    import os.path as op

    import pandas as pd
    from AFQ.utils.volume import density_map, dice_coeff
    from AFQ.data import BUNDLE_RECO_2_AFQ
    from dipy.io.streamline import load_tractogram
    from dipy.io.stateful_tractogram import Space
    import nibabel as nib

    fs = s3fs.S3FileSystem()

    sessions = [session1, session2]
    base_path = "my_bucket/hcp_trt"

    # Per-session lookup tables used to build the remote paths.
    session_ses = {
        "DKI": "01",
        "CSD": "01",
        "DTI": "01",
        "DKI_retest": "01",
        "CSD_retest": "01",
        "DTI_retest": "01",
        "DKI_reco": "01"}
    session_ses_prefix = {
        "DKI": "",
        "CSD": "",
        "DTI": "",
        "DKI_retest": "",
        "CSD_retest": "",
        "DTI_retest": "",
        "DKI_reco": ""}
    session_folders = {
        "DKI": "multi_shell/hcp_1200_afq",
        "CSD": "multi_shell/hcp_1200_afq_CSD",
        "DTI": "single_shell/hcp_1200_afq",
        "DKI_retest": "multi_shell/hcp_retest_afq",
        "CSD_retest": "multi_shell/hcp_retest_afq_CSD",
        "DTI_retest": "single_shell/hcp_retest_afq",
        "DKI_reco": "multi_shell/hcp_1200_reco80"}
    session_seg_algo = {
        "DKI": "afq",
        "CSD": "afq",
        "DTI": "afq",
        "DKI_retest": "afq",
        "CSD_retest": "afq",
        "DTI_retest": "afq",
        "DKI_reco": "reco80"}

    if "reco" in session1 or "reco" in session2:
        # When one side is a recobundles run, only bundles that have a
        # recobundles counterpart are compared; the mapping is inverted so
        # it can be looked up by AFQ bundle name.
        afq_2_reco = {v: k for k, v in BUNDLE_RECO_2_AFQ.items()}
        bundles = list(afq_2_reco.keys())
    else:
        bundles = [
            "ATR_L", "CGC_L", "CST_L", "IFO_L", "ILF_L", "SLF_L", "ARC_L", "UNC_L",
            "ATR_R", "CGC_R", "CST_R", "IFO_R", "ILF_R", "SLF_R", "ARC_R", "UNC_R",
            "FA", "FP"]

    # Rows are accumulated in a plain list and turned into a DataFrame once
    # at the end: DataFrame.append copied the whole frame on every call and
    # no longer exists in pandas >= 2.0.
    records = []
    for subject in subjects:
        sub_paths = []
        for i, session in enumerate(sessions):
            sub_paths.append(
                op.join(base_path, session_folders[session], f"sub-{subject}/ses-{session_ses[session]}/"))
            # b0 image of this session, used as the reference for the
            # tractograms loaded below
            fs.get(
                op.join(sub_paths[i], f"sub-{subject}{session_ses_prefix[session]}_dwi_b0.nii.gz"),
                f"img{i}.nii.gz")

        for bundle in bundles:
            density_maps = []
            for i, session in enumerate(sessions):
                if "reco" in session:
                    bundle_remote_name = afq_2_reco[bundle]
                else:
                    bundle_remote_name = bundle
                fs.get(
                    op.join(
                        sub_paths[i],
                        (f"clean_bundles/sub-{subject}{session_ses_prefix[session]}"
                         # the first three characters of the session key
                         # (DKI / CSD / DTI) name the model in the file name
                         f"_dwi_space-RASMM_model-{session[:3]}"
                         f"_desc-det-{session_seg_algo[session]}-{bundle_remote_name}_tractography.trk")),
                    f"{bundle}{i}.trk")
                density_maps.append(density_map(load_tractogram(
                    f"{bundle}{i}.trk",
                    nib.load(f"img{i}.nii.gz"),
                    Space.VOX,
                    bbox_valid_check=False)))

            records.append({
                "subjectID": subject,
                "tractID": bundle,
                "dice": dice_coeff(density_maps[0], density_maps[1])})

    # Passing `columns` keeps the header even when there are no records.
    df = pd.DataFrame(records, columns=["subjectID", "tractID", "dice"])
    df.to_csv("df.csv", index=False)
    fs.put("df.csv", op.join(base_path, f"{session1}_{session2}_dice.csv"))


In [None]:
# These are all of the HCP-TRT subjects with DWI data
# (listed in ascending order, eight IDs per row)
all_subjects = [
    103818, 105923, 111312, 114823, 115320, 122317, 125525, 130518,
    135528, 137128, 139839, 143325, 144226, 146129, 149337, 149741,
    151526, 158035, 169343, 172332, 175439, 177746, 185442, 187547,
    192439, 194140, 195041, 200109, 200614, 204521, 250427, 287248,
    341834, 433839, 562345, 599671, 601127, 627549, 660951, 783462,
    859671, 861456, 877168, 917255,
]
len(all_subjects)

In [None]:
# These are the DICE comparisons we used in the paper.
# Each entry is one [subjects, session1, session2] argument list for
# compare_sessions_dict.
args = [
    [all_subjects, first_session, second_session]
    for first_session, second_session in (
        ("DKI", "CSD"),
        ("DKI_reco", "DKI"),
        ("DKI", "DKI_retest"),
        ("CSD", "CSD_retest"),
    )]

In [None]:
# we use cloudknot to calculate the dice comparisons on aws
import cloudknot as ck
# select the AWS region ('us-west-2') that cloudknot will work in
ck.set_region('us-west-2')

In [None]:
# make knot to process DICE comparisons
# `func` is the function object defined in the first cell of this notebook.
# NOTE(review): the remaining keyword arguments are cloudknot settings whose
# exact semantics are not visible here -- confirm against the cloudknot
# documentation before changing them.
knot = ck.Knot(
    name='hcp-dice-210101-0',
    func=compare_sessions_dict,
    base_image='python:3.8',
    # pyAFQ is taken from this GitHub fork at the "0.6" ref
    image_github_installs="https://github.com/36000/pyAFQ.git@0.6",
    # compare_sessions_dict reads from and writes to S3, hence the S3 policy
    pars_policies=('AmazonS3FullAccess',),
    bid_percentage=100)


In [None]:
# run the first 3 argument combinations on cloud (do the rest if these succeed)
# `args` holds 4 entries, so args[3:] (CSD vs CSD_retest) is left for a later call.
# NOTE(review): with starmap=True each inner list is presumably unpacked into
# the positional arguments of compare_sessions_dict -- confirm in cloudknot docs.
result_futures = knot.map(args[:3], starmap=True)

In [None]:
# check the status of your jobs
# (the region is set again here; presumably so this cell also works when it
# is run on its own in a new session -- confirm)
ck.set_region('us-west-2')
knot.view_jobs()
# check the status of an individual job
j0 = knot.jobs[0]
j0.status

In [None]:
# clobber your knot resource when you are done
#knot.clobber(clobber_pars=True, clobber_repo=True, clobber_image=True)

In [None]:
import pandas as pd
import os.path as op
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from AFQ.viz.utils import BrainAxes, gen_color_dict
from AFQ.api import make_bundle_dict

In [None]:
# The dice coefficient CSVs have to be downloaded to a local folder first;
# point `dice_folder` at that folder and every table is read from it.
dice_folder = "/my/dice/folder"


def _read_dice(csv_name):
    """Read one dice-coefficient CSV located in ``dice_folder``."""
    return pd.read_csv(op.join(dice_folder, csv_name))


# HCP-TRT comparisons
dki_csd_dice = _read_dice("DKI_CSD_dice.csv")
dki_dki_dice = _read_dice("DKI_DKI_retest_dice.csv")
csd_csd_dice = _read_dice("CSD_CSD_retest_dice.csv")
roi_reco_dice = _read_dice("DKI_reco_DKI_dice.csv")
# pyAFQ ("py") / mAFQ ("m") comparisons
py_m_dice = _read_dice("DTI_prek_py_DTI_prek_m_dice.csv")
py_py_dice = _read_dice("DTI_prek_py_DTI_prek_py_post_dice.csv")
m_m_dice = _read_dice("DTI_prek_m_DTI_prek_m_post_dice.csv")
dki_csd_dice  # example dataframe

In [None]:
# we need to re-label columns and mask for NaNs before plotting
def prep_for_plotting(df, name):
    """Return a filtered, re-labelled copy of a dice table for plotting.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with at least a ``"dice"`` and a ``"tractID"`` column.
        It is not modified.
    name : str
        Prefix prepended to every ``tractID`` (may be the empty string).

    Returns
    -------
    pandas.DataFrame
        Rows of ``df`` whose dice value is neither 0 nor NaN, with the
        original index labels kept and ``tractID`` replaced by
        ``name + str(tractID)``.
    """
    dice = df["dice"]
    # NaN != 0 evaluates to True, so missing values need their own test.
    keep = dice.ne(0) & dice.notna()
    # Filter first and copy once: the result is an independent frame, so the
    # column assignment below neither touches `df` nor triggers pandas'
    # SettingWithCopyWarning.
    df_plot = df.loc[keep].copy()
    df_plot["tractID"] = name + df_plot["tractID"].astype(str)
    return df_plot

In [None]:
# Fig. 6A: Replicability plots
# Every bundle gets four neighbouring bars, one per test-retest comparison.
# The comparisons are told apart by a 3-character tractID prefix and by the
# hatch pattern of their bars.
py_py_dice_plot = prep_for_plotting(py_py_dice, "py_")
m_m_dice_plot = prep_for_plotting(m_m_dice, "ma_")
dki_dki_dice_plot = prep_for_plotting(dki_dki_dice, "dk_")
csd_csd_dice_plot = prep_for_plotting(csd_csd_dice, "cs_")

bundle_names = list(make_bundle_dict().keys())
bundle_names.sort()
# x-axis order: the four prefixed variants of each bundle, bundle by bundle
bundle_names_order = [
    prefix + bundle_name
    for bundle_name in bundle_names
    for prefix in ("py_", "ma_", "dk_", "cs_")]
color_dict = gen_color_dict(bundle_names)

sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(16, 4))

# (table, hatch) pairs; the first comparison is drawn without a hatch
hatched_comparisons = [
    (py_py_dice_plot, None),
    (m_m_dice_plot, "-"),
    (dki_dki_dice_plot, "/"),
    (csd_csd_dice_plot, "\\")]
for comparison, hatch in hatched_comparisons:
    for tractID in comparison["tractID"].unique():
        bar = sns.barplot(
            data=comparison[comparison["tractID"] == tractID],
            x='tractID', y='dice',
            order=bundle_names_order,
            estimator=np.median,
            # tractID[3:] drops the 3-character comparison prefix
            palette=[color_dict[tractID[3:]]],
            units="dice",
            ci=95,
            ax=ax)
        if hatch is not None:
            # only restyle the trailing patches, i.e. one slot per entry of
            # `bundle_names_order`, leaving earlier bars untouched
            for patch in bar.patches[-len(bundle_names_order):]:
                patch.set_hatch(hatch)

ax.set_ylabel("Median Dice's Coefficient", fontsize=18)
ax.set_xlabel("Bundle", fontsize=18)
ax.set_ylim([0.4, 1])
ax.set_yticks(np.arange(0.4, 1.1, 0.1))
# One tick per bundle, centred on its group of four bars. The positions
# must be fixed BEFORE the labels are assigned: seaborn leaves one tick per
# bar on the axis, and matplotlib rejects a label list whose length differs
# from the number of fixed ticks.
ax.set_xticks(np.arange(len(bundle_names)) * 4 + 2)
ax.set_xticklabels(bundle_names)
ax.tick_params(
    axis='x', which='major', labelsize=16)
ax.tick_params(
    axis='y', which='major', labelsize=16)
plt.setp(ax.get_xticklabels(),
         rotation=65,
         horizontalalignment='right')

legend_labels = [
    Patch(
        facecolor='k'),
    Patch(
        facecolor='k',
        hatch='-'),
    Patch(
        facecolor='k',
        hatch='/'),
    Patch(
        facecolor='k',
        hatch='\\')]
fig.legend(
    legend_labels,
    ["UW-PREK pyAFQ test retest", "UW-PREK mAFQ test retest", "HCP-TRT DKI test retest", "HCP-TRT CSD test retest"],
    fontsize=16,
    loc='center',
    bbox_to_anchor=(0.26, 0.48))
pass

In [None]:
# Fig. 7A: pyAFQ vs mAFQ wDSC robustness plot
bundle_list = list(make_bundle_dict().keys())
bundle_list.sort()

# no prefix is needed here: only one comparison is plotted
py_m_dice_plot = prep_for_plotting(py_m_dice, "")

# set the style before the axes are created
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(8, 4))

sns.barplot(
    data=py_m_dice_plot,
    x='tractID', y='dice',
    order=bundle_list,
    estimator=np.median,
    ci=95,
    ax=ax,
    # the palette is built from the bundle list with the imported helper,
    # as in the Fig. 8A cell; no COLOR_DICT name exists in this notebook
    palette=gen_color_dict(bundle_list))
ax.set_ylabel("Median Dice's Coefficient", fontsize=18)
ax.set_xlabel("Bundle", fontsize=18)
ax.set_ylim([0, 1])
ax.tick_params(
    axis='x', which='major', labelsize=16)
ax.tick_params(
    axis='y', which='major', labelsize=16)
# shift the ticks by half a bar so the rotated, right-aligned labels sit
# under their bars
ax.set_xticks(np.arange(len(bundle_list)) + 0.5)
plt.setp(ax.get_xticklabels(),
         rotation=65,
         horizontalalignment='right')
pass

In [None]:
# Fig. 8A: DKI v CSD wDSC robustness plot
bundle_names = list(make_bundle_dict().keys())
bundle_names.sort()
color_dict = gen_color_dict(bundle_names)

# no prefix is needed here: only one comparison is plotted
dki_csd_dice_plot = prep_for_plotting(dki_csd_dice, "")

sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(8, 4))

bar = sns.barplot(
    data=dki_csd_dice_plot,
    x='tractID', y='dice',
    order=bundle_names,
    estimator=np.median,
    palette=color_dict,
    units="dice",
    ci=95,
    ax=ax)

ax.set_ylabel("Median Dice's Coefficient", fontsize=18)
ax.set_xlabel("Bundle", fontsize=18)
ax.set_ylim([0, 1])
# Fix the tick positions first and label them afterwards; the number of
# ticks follows the bundle list instead of a hard-coded 18.
ax.set_xticks(np.arange(len(bundle_names)) + 0.4)
ax.set_xticklabels(bundle_names)
ax.tick_params(
    axis='x', which='major', labelsize=16)
ax.tick_params(
    axis='y', which='major', labelsize=16)
plt.setp(ax.get_xticklabels(),
         rotation=65,
         horizontalalignment='right')
pass

In [None]:
# Fig. 9A: Recobundles vs waypoint ROIs wDSC robustness plot
# The filtered table has to exist before the bundle list is derived from it
# (the list was previously built from a name that was not yet assigned,
# which fails on a fresh kernel).
roi_reco_dice_removed = prep_for_plotting(roi_reco_dice, "")

bundles_list = list(roi_reco_dice_removed["tractID"].unique())
bundles_list.sort()

sns.set(style="whitegrid")
# explicit figure/axes with default size instead of relying on plt.gca()
fig, ax = plt.subplots()
sns.barplot(
    data=roi_reco_dice_removed,
    x='tractID', y='dice',
    order=bundles_list,
    estimator=np.median,
    ci=95,
    ax=ax,
    # the palette is built from the plotted bundles with the imported
    # helper; no COLOR_DICT name exists in this notebook
    palette=gen_color_dict(bundles_list))
ax.set_ylabel("Median Dice's Coefficient", fontsize=18)
ax.set_xlabel("Bundle", fontsize=18)
ax.set_ylim([0, 1])
ax.tick_params(
    axis='x', which='major', labelsize=16)
ax.tick_params(
    axis='y', which='major', labelsize=16)
# one tick per plotted bundle (derived from the data, not hard-coded),
# shifted by half a bar for the rotated, right-aligned labels
ax.set_xticks(np.arange(len(bundles_list)) + 0.5)
plt.setp(ax.get_xticklabels(),
         rotation=65,
         horizontalalignment='right')
pass