In [None]:
import configparser
import itertools
import os.path as op

In [None]:
# Read the HCP access keys from the "hcp" profile of ~/.aws/credentials.
# You will need to get access to the HCP data in order to run this script.
CP = configparser.ConfigParser()
# Open the credentials file in a context manager so the handle is closed as
# soon as it has been parsed (a bare open() passed to read_file never is).
# A missing file still raises, exactly as before.
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as credentials_file:
    CP.read_file(credentials_file)
aws_access_key = CP.get('hcp', 'AWS_ACCESS_KEY_ID')
aws_secret_key = CP.get('hcp', 'AWS_SECRET_ACCESS_KEY')

In [None]:
# define a function that will be run on the cloud for each subject, session, and configuration
def afq_hcp_retest(subject, shell, session, seg_algo, reuse_tractography, use_callosal, aws_access_key, aws_secret_key):
    """Run pyAFQ on one HCP subject/session and upload the results to S3.

    Parameters
    ----------
    subject : str
        HCP subject ID; used to fetch the data and to build the
        ``sub-{subject}`` file paths.
    shell : str
        Configuration label. If it contains "single" (case-insensitive),
        DTI is used and b-values are restricted to 990-1010; otherwise DKI
        is used together with 'dki_fa' seed/stop masks and the DKI scalars.
        If it contains "csd" (case-insensitive), the ODF model is switched
        to CSD. The label is also part of the S3 paths (``{shell}_shell``).
    session : str
        Passed to ``fetch_hcp`` as ``study=f"HCP_{session}"``; its
        lower-cased form is part of the S3 paths.
    seg_algo : str
        Forwarded to AFQ in ``segmentation_params`` and appended to the
        upload path.
    reuse_tractography : bool
        If true, a tractography stored under the ``hcp_{session}_afq``
        prefix on S3 is downloaded and handed to AFQ through
        ``custom_tractography_bids_filters``.
    use_callosal : bool
        If true, the callosal bundles are added and "_callosal" is appended
        to the upload path.
    aws_access_key, aws_secret_key : str
        Credentials forwarded to ``fetch_hcp``.
    """
    # NOTE(review): the imports are function-local, presumably because this
    # function is shipped to the cloud on its own -- confirm before hoisting
    # any of them (or removing the unused ones) to module level.
    import logging
    import s3fs

    from AFQ.data import fetch_hcp
    import AFQ.api as api
    import AFQ.mask as afm
    import numpy as np
    import os.path as op

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    s3 = s3fs.S3FileSystem()
    s3_root = "my_bucket/hcp_trt"
    session_label = session.lower()
    shell_label = shell.lower()
    is_single_shell = "single" in shell_label

    # download the HCP data for this subject / session
    _, hcp_bids = fetch_hcp(
        [subject],
        profile_name=False,
        study=f"HCP_{session}",
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key)

    # add the callosal bundles on request; None leaves the choice to AFQ
    bundle_info = api.BUNDLES + api.CALLOSUM_BUNDLES if use_callosal else None

    # ODF model: "csd" in the label wins, otherwise DTI for single shell
    # and DKI for everything else
    if "csd" in shell_label:
        odf_model = "CSD"
    elif is_single_shell:
        odf_model = "DTI"
    else:
        odf_model = "DKI"

    if is_single_shell:
        # single shell: only use b values between 990 and 1010
        tracking_params = {"odf_model": odf_model}
        afq_kwargs = {"min_bval": 990, "max_bval": 1010}
    else:
        # multi shell: DKI-derived masks and scalars
        tracking_params = {
            'seed_mask': afm.ScalarMask('dki_fa'),
            'stop_mask': afm.ScalarMask('dki_fa'),
            "odf_model": odf_model}
        afq_kwargs = {"scalars": ["dki_fa", "dki_md"]}

    # Optionally reuse a tractography that another run of this function has
    # already uploaded to s3. Useful when trying new parameters that do not
    # change the tractography.
    trk_filters = None
    if reuse_tractography:
        remote_trk = (
            f"{s3_root}/{shell}_shell/hcp_{session_label}_afq/"
            f"sub-{subject}/ses-01/"
            f"sub-{subject}_dwi_space-RASMM_model-{odf_model}"
            "_desc-det_tractography.trk")
        local_trk = op.join(
            hcp_bids,
            f"derivatives/dmriprep/sub-{subject}/ses-01/sub-{subject}_customtrk.trk")
        s3.get(remote_trk, local_trk)
        trk_filters = {"suffix": "customtrk", "scope": "dmriprep"}

    # Build the AFQ object from everything configured above.
    # The brain mask is the one provided by HCP.
    # viz_backend='plotly' makes GIFs in addition to the default
    # html visualizations (this adds ~45 minutes)
    hcp_brain_mask = afm.LabelledMaskFile(
        'seg', {'scope':'dmriprep'}, exclusive_labels=[0])
    afq_run = api.AFQ(
        hcp_bids,
        brain_mask=hcp_brain_mask,
        custom_tractography_bids_filters=trk_filters,
        tracking_params=tracking_params,
        bundle_info=bundle_info,
        segmentation_params={"seg_algo": seg_algo, "reg_algo": "syn"},
        viz_backend='plotly',
        **afq_kwargs)
    # run the full pipeline
    afq_run.export_all()

    # upload the results under s3_root, organized by the parameters used
    callosal_suffix = "_callosal" if use_callosal else ""
    remote_export_path = (
        f"{s3_root}/{shell}_shell/hcp_{session_label}_{seg_algo}"
        + callosal_suffix)
    afq_run.upload_to_s3(s3, remote_export_path)

In [None]:
# IDs of all the HCP subjects with DWI data in the HCP test-retest dataset (HCP-TRT),
# in ascending order
all_subjects = [
    103818, 105923, 111312, 114823, 115320, 122317, 125525, 130518,
    135528, 137128, 139839, 143325, 144226, 146129, 149337, 149741,
    151526, 158035, 169343, 172332, 175439, 177746, 185442, 187547,
    192439, 194140, 195041, 200109, 200614, 204521, 250427, 287248,
    341834, 433839, 562345, 599671, 601127, 627549, 660951, 783462,
    859671, 861456, 877168, 917255,
]
len(all_subjects)

In [None]:
# Parameter grid: each name below holds the list of values to sweep over.
subjects = list(map(str, all_subjects))  # subject numbers as strings
shell = ["multi", "single", "CSD"]  # try different ODF models for tractography
session = ["1200", "Retest"]  # use both the test and retest datasets
seg_algo = ["afq", "reco80"]  # try different segmentation algorithms
reuse_tractography = [False]  # Set to True to reuse tractography from a previous run with the same shell
use_callosal = [False]  # set to True to get callosal bundles

# One tuple per combination of subject, shell, session, and the other
# parameters, in the positional order that afq_hcp_retest expects.
args = list(itertools.product(
    subjects, shell, session, seg_algo, reuse_tractography, use_callosal))
print(args)


def attach_keys(list_of_arg_lists):
    """Return a new list where each argument sequence is a list ending with the AWS keys.

    The input sequences are not modified; ``aws_access_key`` and
    ``aws_secret_key`` are read from the notebook's global namespace.
    """
    return [
        [*arg_seq, aws_access_key, aws_secret_key]
        for arg_seq in list_of_arg_lists]


args = attach_keys(args)

In [None]:
# Use cloudknot to process the HCP data on AWS.
# This requires your own AWS account with access to S3 and EC2 resources.
# Your AWS credentials should be in ~/.aws/credentials, together with the
# "hcp" profile that holds the HCP keys.
import cloudknot as ck
ck.set_region('us-west-2')

In [None]:
# Build and push a docker image with non-conflicting requirements that uses version 0.6 of pyAFQ.
di = ck.DockerImage(
    name='hcp-api-jk',
    func=afq_hcp_retest,
    base_image="python:3.8",
    github_installs='https://github.com/yeatmanlab/pyAFQ.git@0.6',
    overwrite=True)

# Overwrite the generated requirements file with explicitly pinned versions.
# One requirement per line, no leading whitespace, and a trailing newline
# (the previous triple-quoted string wrote every line after the first with
# 8 leading spaces and left the file without a final newline).
pinned_requirements = [
    "pandas==1.1.4",
    "nibabel==3.2.1",
    "boto3==1.14.18",
    "s3fs==0.5.1",
    "dipy==1.3.0",
    "cloudpickle==1.6.0",
]
with open(di.req_path, "w") as f:
    f.write("\n".join(pinned_requirements) + "\n")

di.build(tags=["hcp-trt-210101-0"])
repo = ck.aws.DockerRepo(name=ck.get_ecr_repo())
di.push(repo=repo)

In [None]:
# Create the knot that will process the HCP data. HCP data is high quality,
# so the jobs need a lot of memory and disk space.
# NOTE(review): units of volume_size / memory are defined by cloudknot -- check its docs before tuning.
knot_settings = dict(
    name='hcp-trt-210101-0',
    docker_image=di,
    pars_policies=('AmazonS3FullAccess',),
    bid_percentage=100,
    volume_size=64,
    max_vcpus=256,
    memory=128000,
)
knot = ck.Knot(**knot_settings)

In [None]:
# Trial run: submit only the first 3 argument combinations to the cloud
# (do the rest if these succeed).
trial_args = args[:3]
result_futures = knot.map(trial_args, starmap=True)

In [None]:
# Check the status of your jobs.
# NOTE(review): the region is set again here, presumably so this cell can be
# re-run on its own when checking back later -- confirm.
ck.set_region('us-west-2')
knot.view_jobs()


In [None]:
# Check the status of an individual job (here: the first entry of knot.jobs).
j0 = knot.jobs[0]
j0.status

In [None]:
# Clobber your knot's resources when you are done (uncomment the line below to run it).
#knot.clobber(clobber_pars=True, clobber_repo=True, clobber_image=True)

In [None]:
import AFQ.api as api

In [None]:
# When you're done, download and combine the AFQ profiles from all of the various jobs.
#
# NOTE(review): afq_hcp_retest uploads to "{shell}_shell/hcp_{session}_{seg_algo}",
# which for shell="CSD" would be "CSD_shell/hcp_1200_afq", not the
# "multi_shell/hcp_1200_afq_CSD" prefixes used for the CSD entries below --
# confirm which location actually holds the CSD results.
profiles_bucket = "my_bucket"
local_profiles_dir = "~/AFQ_data/hcp_reliability_profiles"

# (local CSV name, S3 prefix inside the bucket), in download order
profile_sources = [
    # DTI, single shell
    ("single_shell_test_profiles.csv", "hcp_trt/single_shell/hcp_1200_afq"),
    ("single_shell_retest_profiles.csv", "hcp_trt/single_shell/hcp_retest_afq"),
    # DKI, multi shell
    ("multi_shell_test_profiles.csv", "hcp_trt/multi_shell/hcp_1200_afq"),
    ("multi_shell_retest_profiles.csv", "hcp_trt/multi_shell/hcp_retest_afq"),
    # reco80, DKI, multi shell
    ("multi_shell_test_reco80_profiles.csv", "hcp_trt/multi_shell/hcp_1200_reco80"),
    ("multi_shell_retest_reco80_profiles.csv", "hcp_trt/multi_shell/hcp_retest_reco80"),
    # CSD, multi shell
    ("multi_shell_CSD_test_profiles.csv", "hcp_trt/multi_shell/hcp_1200_afq_CSD"),
    ("multi_shell_CSD_retest_profiles.csv", "hcp_trt/multi_shell/hcp_retest_afq_CSD"),
]

for csv_name, remote_prefix in profile_sources:
    last_combined = api.download_and_combine_afq_profiles(
        f"{local_profiles_dir}/{csv_name}",
        profiles_bucket,
        remote_prefix
        )
# keep the final call's return value as the cell's last expression, so the
# cell displays whatever the last of the separate calls would have displayed
last_combined

In [None]:
# Load the json file with the RTP results.
# Get it from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7705748/ and place
# it in ~/AFQ_data/hcp_reliability_profiles before running this cell.
import pandas as pd
rtp_json_name = "AllV04_multiSiteAndMeas_ComputationalReproducibility_2020-08-03T12-36.json"
data_path = op.join(
    op.expanduser("~"),
    "AFQ_data/hcp_reliability_profiles",
    rtp_json_name)
hcp_ready = pd.read_json(data_path)


In [None]:
# Convert the RTP results to our profile format: one row per node with the
# columns nodeID / tractID / subjectID / fa / md, written to one CSV for the
# test session and one for the retest session.
#
# Fixes relative to the previous version of this cell:
# * The loop read row["Proj"], row["AcquMD"], row["Struct"] and row["SubjID"]
#   after those columns had been dropped or renamed, which raises KeyError.
#   The rows are now read with the column names of the loaded JSON, and
#   `hcp_ready` itself is left untouched so the cell can be re-run.
# * Rows are collected in plain lists and each DataFrame is built once,
#   instead of calling DataFrame.append per node (quadratic, and removed in
#   pandas 2.0).
from tqdm import tqdm

profile_columns = ["nodeID", "tractID", "subjectID", "fa", "md"]
n_nodes = 100  # number of nodes read from each "fa" / "md" profile

rtp_hcp_rows = hcp_ready[hcp_ready["Proj"] == "HCP"]
rtp_records = {"TEST": [], "RETEST": []}
for row_index, row in tqdm(rtp_hcp_rows.iterrows(), total=rtp_hcp_rows.shape[0]):
    # only keep rows whose first AcquMD entry reports a b-value of 1000
    if row["AcquMD"][0]['scanbValue'] != 1000:
        continue
    trt_label = row["TRT"]
    if trt_label not in rtp_records:
        raise ValueError("TRT not test or retest")
    rtp_records[trt_label].extend(
        {
            "nodeID": node,
            "tractID": row["Struct"],
            "subjectID": row["SubjID"],
            "fa": row["fa"][node],
            "md": row["md"][node]}
        for node in range(n_nodes))

# passing `columns` keeps the header even if a session has no matching rows
hcp_ready_test = pd.DataFrame(rtp_records["TEST"], columns=profile_columns)
hcp_ready_retest = pd.DataFrame(rtp_records["RETEST"], columns=profile_columns)
hcp_ready_test.to_csv("~/AFQ_data/hcp_reliability_profiles/rtp_test_profiles.csv", index=False)
hcp_ready_retest.to_csv("~/AFQ_data/hcp_reliability_profiles/rtp_retest_profiles.csv", index=False)

In [None]:
import AFQ.viz.utils as vut
import logging
from importlib import reload
import AFQ.data as afd

In [None]:
# GroupCSVComparison objects contain a list of different CSVs to compare
# and methods for doing the comparison (ie: tract profiles, ACIPs, subject and profile reliabilities)

reload(vut) # reloads viz.utils (useful after making tweaks to plotting code)

# Read the profile CSVs from the directory the earlier cells write them to
# (~/AFQ_data/hcp_reliability_profiles). The paths previously pointed at
# ~/hcp_reliability_profiles, which nothing in this notebook populates.
# NOTE(review): if your CSVs really live in ~/hcp_reliability_profiles,
# change profiles_dir instead of the individual paths.
profiles_dir = "~/AFQ_data/hcp_reliability_profiles"

waypoint_comparisons = vut.GroupCSVComparison( # comparisons using waypoint ROI bundles
    'hcp_reliability_profiles/comparisons',
    [
        f"{profiles_dir}/single_shell_test_profiles.csv",
        f"{profiles_dir}/single_shell_retest_profiles.csv",
        f"{profiles_dir}/multi_shell_test_profiles.csv",
        f"{profiles_dir}/multi_shell_retest_profiles.csv",
        f"{profiles_dir}/multi_shell_CSD_test_profiles.csv",
        f"{profiles_dir}/multi_shell_CSD_retest_profiles.csv",
        f"{profiles_dir}/rtp_test_profiles.csv",
        f"{profiles_dir}/rtp_retest_profiles.csv"
    ], [
        'single_test',
        'single_retest',
        'multi_test',
        'multi_retest',
        'multi_CSD_test',
        'multi_CSD_retest',
        'rtp_test',
        'rtp_retest'
    ], is_special = [
        # one entry per CSV above; only the two RTP files are flagged 'mat'
        '',
        '',
        '',
        '',
        '',
        '',
        'mat',
        'mat'
    ],
    subjects=all_subjects,
    remove_model=True,
    scalar_bounds={'lb': {'FA': 0.2},
                   'ub': {'MD': 0.002}})
waypoint_comparisons.logger.setLevel(logging.WARNING)

reco_waypoint_comparison = vut.GroupCSVComparison( # comparisons on overlap between reco and afq bundles
    'hcp_reliability_profiles/comparisons',
    [
        f"{profiles_dir}/multi_shell_test_reco80_profiles.csv",
        f"{profiles_dir}/multi_shell_test_profiles.csv",
    ],
    ['multi_test_reco80', 'multi_test'],
    subjects=all_subjects, bundles=list(afd.BUNDLE_RECO_2_AFQ.values()),
    is_special = ["reco", ""])
reco_waypoint_comparison.logger.setLevel(logging.WARNING)

In [None]:
# Example tract profile: plot the profiles of the 'single_test' dataset only
# and show the plots inline.
waypoint_comparisons.tract_profiles(names=['single_test'], show_plots=True)

In [None]:
# HCP TRR, multi shell vs csd
# reliability_plots returns 8 values; only the bundle list (4th), the
# intersubject values (5th) and the profile values (7th) are kept.
(_, _, _, suf_bundles, multi_intersubject, _, multi_profile, _) = (
    waypoint_comparisons.reliability_plots(
        names=['multi_test', 'multi_retest'], show_plots=True))
(_, _, _, _, CSD_intersubject, _, CSD_profile, _) = (
    waypoint_comparisons.reliability_plots(
        names=['multi_CSD_test', 'multi_CSD_retest'], show_plots=True))

# compare CSD against DKI, first on profile TRR and then on subject TRR
trr_display = dict(show_plots=True, show_legend=False)
waypoint_comparisons.compare_reliability(
    CSD_profile, multi_profile, "CSD", "DKI", suf_bundles,
    rtype="Profile TRR", **trr_display)
waypoint_comparisons.compare_reliability(
    CSD_intersubject, multi_intersubject, "CSD", "DKI", suf_bundles,
    rtype="Subject TRR", **trr_display)

In [None]:
# HCP TRR, single shell vs RTP
# Uses `suf_bundles` from the multi shell vs csd cell, so run that cell first.
(_, _, _, _, single_intersubject, _, single_profile, _) = (
    waypoint_comparisons.reliability_plots(
        names=['single_test', 'single_retest'], show_plots=True))
(_, _, _, _, ready_intersubject, _, ready_profile, _) = (
    waypoint_comparisons.reliability_plots(
        names=['rtp_test', 'rtp_retest'], show_plots=True))

# compare pyAFQ DTI against RTP DTI, first on profile TRR and then on subject TRR
dti_display = dict(show_plots=True, show_legend=False)
waypoint_comparisons.compare_reliability(
    single_profile, ready_profile, "pyAFQ DTI", "RTP DTI", suf_bundles,
    rtype="Profile TRR", **dti_display)
waypoint_comparisons.compare_reliability(
    single_intersubject, ready_intersubject, "pyAFQ DTI", "RTP DTI", suf_bundles,
    rtype="Subject TRR", **dti_display)

In [None]:
# HCP ODF robustness: compare the DKI and CSD results of the test session
odf_pair = ("multi_test", "multi_CSD_test")
waypoint_comparisons.reliability_plots(
    names=list(odf_pair),
    scalars=["FA", "MD"],
    rtype="Robustness",
    show_plots=True)
waypoint_comparisons.contrast_index(names=list(odf_pair), show_plots=True)

In [None]:
# HCP Recobundles robustness: contrast index and robustness plots for the
# reco80-vs-waypoint comparison object, using its default dataset names.
reco_waypoint_comparison.contrast_index(show_plots=True)
reco_waypoint_comparison.reliability_plots(rtype="Robustness", show_plots=True)