In [1]:
import cloudknot as ck

In [2]:
# All cloudknot resources for this notebook are created in this AWS region.
AWS_REGION = 'us-east-1'
ck.set_region(AWS_REGION)

In [3]:
def bundle_hcp(params):
    """Track, segment and profile white-matter bundles for one HCP subject.

    The function is handed to ``ck.Knot`` and run remotely, so every import
    lives inside the body.

    Parameters
    ----------
    params : tuple
        ``(subject, hcp_ak, hcp_sk, algo)``: the HCP subject ID, the AWS
        access key ID and secret key passed to ``afd.fetch_hcp`` to download
        the subject, and the segmentation algorithm (``'reco'`` for
        RecoBundles or ``'afq'`` for AFQ).

    Raises
    ------
    ValueError
        If ``algo`` is neither ``'reco'`` nor ``'afq'``. This is checked
        before anything is downloaded.

    Notes
    -----
    Intermediate derivatives (WM mask, DTI, CSD, whole-brain streamlines) are
    cached under the ``hcp.derivatives`` S3 prefix and reused when they
    already exist. Per-bundle tractograms, bundle streamline indices,
    streamline counts and FA tract profiles are uploaded under
    ``hcp.<algo_name>``. Scratch files are written to the current working
    directory.
    """
    subject, hcp_ak, hcp_sk, algo = params
    # Fail fast: the rest of the pipeline only handles these two algorithms.
    if algo not in ('reco', 'afq'):
        raise ValueError(f"algo should be 'reco' or 'afq', got {algo!r}")

    import logging
    import os.path as op

    import numpy as np
    import pandas as pd
    import s3fs
    import nibabel as nib
    import dipy.tracking.utils as dtu
    import dipy.tracking.streamline as dts
    from dipy.io.streamline import save_tractogram, load_tractogram
    from dipy.stats.analysis import afq_profile, gaussian_weights
    from dipy.io.stateful_tractogram import StatefulTractogram
    from dipy.io.stateful_tractogram import Space
    import dipy.core.gradients as dpg

    import AFQ.data as afd
    import AFQ.tractography as aft
    import AFQ.registration as reg
    import AFQ.dti as dti
    import AFQ.segmentation as seg
    from AFQ import api
    from AFQ import csd

    deriv_name = 'hcp.derivatives'
    if algo == 'reco':
        algo_name = 'recobundles'
        algo_name_formal = "RecoBundles"
        # Each bundle name must appear only once.
        bundle_names = ['CST', 'C', 'F', 'UF', 'MCP', 'AF', 'CCMid',
                        'CC_ForcepsMajor', 'CC_ForcepsMinor', 'IFOF']
    else:
        algo_name = 'afq'
        algo_name_formal = "AFQ"
        bundle_names = ["ATR", "CGC", "CST", "HCC", "IFO", "ILF", "SLF",
                        "ARC", "UNC", "FA", "FP"]
    bucket_name = 'hcp.' + algo_name

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    log.info(f"Fetching HCP subject {subject}")
    afd.fetch_hcp([subject],
                  profile_name=False,
                  aws_access_key_id=hcp_ak,
                  aws_secret_access_key=hcp_sk)

    sub_dir = op.join(afd.afq_home, 'HCP', 'derivatives',
                      'dmriprep', f'sub-{subject}')
    dwi_dir = op.join(sub_dir, 'sess-01/dwi')
    anat_dir = op.join(sub_dir, 'sess-01/anat')

    hardi_fdata = op.join(dwi_dir, f"sub-{subject}_dwi.nii.gz")
    hardi_fbval = op.join(dwi_dir, f"sub-{subject}_dwi.bval")
    hardi_fbvec = op.join(dwi_dir, f"sub-{subject}_dwi.bvec")

    log.info(f"Reading data from file {hardi_fdata}")
    img = nib.load(hardi_fdata)
    log.info(f"Creating gradient table from {hardi_fbval} and {hardi_fbvec}")
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)

    fs = s3fs.S3FileSystem()
    # Common S3 key prefix for this subject's cached derivatives.
    deriv_prefix = f'{deriv_name}/sub-{subject}/sub-{subject}'

    # --- White-matter mask (cached on S3) ---
    wm_mask_fname = f'{deriv_prefix}_wm_mask.nii.gz'
    if fs.exists(wm_mask_fname):
        log.info(f"WM mask exists. Reading from {wm_mask_fname}")
        wm_img = afd.s3fs_nifti_read(wm_mask_fname)
        # get_data() is deprecated (removed in nibabel 5); reading the array
        # proxy keeps the stored dtype as get_data() did.
        wm_mask = np.asanyarray(wm_img.dataobj)
    else:
        log.info("Calculating WM segmentation")
        # aparc+aseg label values that are counted as white matter.
        wm_labels = [250, 251, 252, 253, 254, 255, 41, 2, 16, 77]
        seg_img = nib.load(op.join(anat_dir, f"sub-{subject}_aparc+aseg.nii.gz"))
        seg_data_orig = seg_img.get_fdata()
        # 1 where the voxel carries any of the WM labels, 0 elsewhere.
        wm_mask = np.isin(seg_data_orig, wm_labels).astype(int)

        # Resample to DWI space. Only the first DWI volume is read as the
        # reference instead of loading the whole 4D series into memory.
        dwi_ref = np.asarray(img.dataobj[..., 0], dtype=float)
        wm_mask = np.round(reg.resample(wm_mask,
                                        dwi_ref,
                                        seg_img.affine,
                                        img.affine)).astype(int)

        wm_img = nib.Nifti1Image(wm_mask, img.affine)
        log.info(f"Saving to {wm_mask_fname}")
        afd.s3fs_nifti_write(wm_img, wm_mask_fname)

    # --- DTI (cached on S3); only FA is used downstream ---
    fa_fname = f'{deriv_prefix}_dti_FA.nii.gz'
    dti_params_fname = f'{deriv_prefix}_dti.nii.gz'
    dti_meta_fname = f'{deriv_prefix}_dti.json'
    if fs.exists(fa_fname):
        log.info(f"DTI already exists. Reading FA from {fa_fname}")
        FA_img = afd.s3fs_nifti_read(fa_fname)
    else:
        log.info("Calculating DTI")
        # The fit writes ./dti_FA.nii.gz and ./dti_params.nii.gz, read back below.
        dti.fit_dti(hardi_fdata, hardi_fbval, hardi_fbvec,
                    out_dir='.', b0_threshold=50,
                    mask=wm_mask)
        FA_img = nib.load('./dti_FA.nii.gz')
        log.info(f"Writing FA to {fa_fname}")
        afd.s3fs_nifti_write(FA_img, fa_fname)
        dti_params_img = nib.load('./dti_params.nii.gz')
        log.info(f"Writing DTI params to {dti_params_fname}")
        afd.s3fs_nifti_write(dti_params_img, dti_params_fname)
        dti_params_json = {
            "Model": "Diffusion Tensor",
            "OrientationRepresentation": "param",
            "ReferenceAxes": "xyz",
            "Parameters": {
                "FitMethod": "ols",
                "OutlierRejection": False,
            },
        }
        log.info(f"Writing DTI metadata to {dti_meta_fname}")
        afd.s3fs_json_write(dti_params_json, dti_meta_fname)

    log.info("Reading FA data from img")
    FA_data = FA_img.get_fdata()

    # --- CSD (cached on S3) ---
    csd_fname = f'{deriv_prefix}_csd.nii.gz'
    csd_meta_fname = f'{deriv_prefix}_csd.json'
    if fs.exists(csd_fname):
        log.info(f"CSD already exists. Getting it from {csd_fname}")
        # Presumably a nibabel image here, whereas the branch below yields a
        # file path; both forms are handed to aft.track.
        csd_params = afd.s3fs_nifti_read(csd_fname)
    else:
        log.info("Calculating CSD")
        csd_params = csd.fit_csd(hardi_fdata, hardi_fbval, hardi_fbvec,
                                 out_dir='.', b0_threshold=50,
                                 mask=wm_mask)
        afd.s3fs_nifti_write(nib.load(csd_params), csd_fname)

        csd_params_json = {
            "Model": "Constrained Spherical Deconvolution (CSD)",
            "ModelURL": "https://github.com/nipy/dipy/commit/abf31d15a0ee5dc0704ee03ebbba57358d540612",
            "Shells": [0, 1000, 2000, 3000],
            "Parameters": {
                "ResponseFunctionTensor": "auto",
                "SphericalHarmonicBasis": "Descoteaux",
                "NonNegativityConstraint": "hard",
                "SphericalHarmonicDegree": 8,
            },
        }
        log.info(f"Writing CSD metadata to {csd_meta_fname}")
        afd.s3fs_json_write(csd_params_json, csd_meta_fname)

    # --- Whole-brain deterministic tractography (cached on S3) ---
    csd_streamlines_fname = f'{deriv_prefix}_model-csd_track-det.trk'
    csd_streamlines_meta_fname = f'{deriv_prefix}_model-csd_track-det.json'
    if fs.exists(csd_streamlines_fname):
        log.info(f"Streamlines already exist. Loading from {csd_streamlines_fname}")
        fs.get(csd_streamlines_fname, './csd_streamlines.trk')
        tg = load_tractogram('./csd_streamlines.trk', img)
        streamlines = tg.streamlines
    else:
        log.info("Generating streamlines")
        # Seed only inside the WM mask where FA > 0.1.
        seed_roi = np.zeros(img.shape[:-1])
        seed_roi[FA_data > 0.1] = 1
        seed_roi[wm_mask < 1] = 0
        streamlines = aft.track(csd_params, seed_mask=seed_roi,
                                directions='det', stop_mask=FA_data,
                                stop_threshold=0.1)
        log.info(f"After tracking, there are {len(streamlines)} streamlines")
        sft = StatefulTractogram(streamlines, img, Space.RASMM)
        save_tractogram(sft, './csd_streamlines.trk',
                        bbox_valid_check=False)
        log.info(f"Uploading streamlines to {csd_streamlines_fname}")
        fs.upload('./csd_streamlines.trk', csd_streamlines_fname)
        csd_streamlines_json = {
            "Algorithm": "LocalTracking",
            "AlgorithmURL": "https://github.com/yeatmanlab/pyAFQ/commit/c04835cd4ca13d28c20bb449d6f088e656c55e57",
            "Parameters": {
                "SeedRoi": "dti_FA>0.1",
                "Directions": "det",
                "StopMask": "dti_FA<0.1",
            },
        }
        log.info(f"Writing streamlines metadata to {csd_streamlines_meta_fname}")
        afd.s3fs_json_write(csd_streamlines_json, csd_streamlines_meta_fname)

    log.info("Segmenting")

    # NOTE(review): only the AFQ path moves streamlines into voxel coordinates
    # (inverse affine) here, yet both paths later save bundles as
    # Space.RASMM and profile them with an identity affine. One of the two
    # paths is probably in the wrong space -- confirm against pyAFQ's
    # Segmentation output conventions.
    if algo == 'afq':
        streamlines = dts.Streamlines(
            dtu.transform_tracking_output(streamlines,
                                          np.linalg.inv(img.affine)))

    bundles = api.make_bundle_dict(bundle_names=bundle_names, seg_algo=algo)
    # Index 1 of syn_register_dwi's result is the mapping used by segment().
    mapping = reg.syn_register_dwi(hardi_fdata, gtab)[1]

    segmentation = seg.Segmentation(algo=algo,
                                    model_clust_thr=20,
                                    reduction_thr=20,
                                    b0_threshold=50,
                                    return_idx=True)
    segmentation.segment(bundles, streamlines, img_affine=img.affine, mapping=mapping)
    fiber_groups = segmentation.fiber_groups

    log.info("Cleaning fiber groups...")
    for kk in fiber_groups:
        # Only bundles with at least 20 streamlines are cleaned; the kept
        # indices are mapped back so 'idx' stays aligned with 'sl'.
        if len(fiber_groups[kk]['sl']) >= 20:
            new_fibers, idx_in_bundle = seg.clean_fiber_group(
                fiber_groups[kk]['sl'], return_idx=True)
            fiber_groups[kk]['sl'] = new_fibers
            fiber_groups[kk]['idx'] = fiber_groups[kk]['idx'][idx_in_bundle]

    tg_prefix = (f'{bucket_name}/sub-{subject}/sub-{subject}'
                 f'_model-csd_track-det_segment-{algo_name}')
    # Segmentation metadata is the same for every bundle.
    tg_meta_json = {
        "Algorithm": algo_name_formal,
        "AlgorithmURL": "https://github.com/yeatmanlab/pyAFQ/commit/b63adc5",
        "Parameters": {"model_clust_thr": 20,
                       "reduction_thr": 20},
    }

    sl_count = []
    for kk in fiber_groups:
        sl_count.append(len(fiber_groups[kk]['sl']))
        log.info(f"There are {sl_count[-1]} streamlines in {kk}")
        sft = StatefulTractogram(fiber_groups[kk]['sl'], img, Space.RASMM)
        local_tg_fname = f'./{kk}_{algo}.trk'
        save_tractogram(sft, local_tg_fname,
                        bbox_valid_check=False)
        tg_fname = f'{tg_prefix}_bundle-{kk}.trk'
        log.info(f"Uploading {local_tg_fname} to {tg_fname}")
        fs.upload(local_tg_fname, tg_fname)

        tg_meta_fname = f'{tg_prefix}_bundle-{kk}.json'
        log.info(f"Uploading segmentation metadata to {tg_meta_fname}")
        afd.s3fs_json_write(tg_meta_json, tg_meta_fname)

        np.save('bundle_idx.npy', fiber_groups[kk]['idx'])
        idx_fname = f'{tg_prefix}_bundle-{kk}_idx.npy'
        log.info(f"Uploading bundle indices to {idx_fname}")
        fs.upload('bundle_idx.npy', idx_fname)

    log.info("Saving streamline counts")
    sl_count_df = pd.DataFrame(data=sl_count, index=fiber_groups.keys(),
                               columns=["streamlines"])
    sl_count_df.to_csv("./sl_count.csv")
    sl_count_fname = f'{tg_prefix}_counts.csv'
    log.info(f"Uploading streamline counts to {sl_count_fname}")
    fs.upload("./sl_count.csv", sl_count_fname)

    log.info("Extracting tract profiles...")
    profiles = []
    for kk in fiber_groups:
        bundle_sl = fiber_groups[kk]['sl']
        if len(bundle_sl) == 0:
            # afq_profile rejects an empty bundle; skip it rather than abort
            # the whole job after every other result has been uploaded.
            log.warning(f"No streamlines in {kk}; skipping its tract profile")
            continue
        weights = gaussian_weights(bundle_sl)
        profile = afq_profile(FA_data, bundle_sl,
                              np.eye(4), weights=weights)
        for ii, value in enumerate(profile):
            # Subject, Bundle, node, method, metric (FA, MD), value
            profiles.append([subject, kk, ii, algo_name_formal, 'FA', value])

    profiles = pd.DataFrame(data=profiles, columns=["Subject", "Bundle", "Node",
                                                    "Method", "Metric", "Value"])
    profiles.to_csv("./profiles.csv")
    profiles_fname = f'{tg_prefix}_profiles.csv'
    log.info(f"Uploading profiles to {profiles_fname}")
    fs.upload("./profiles.csv", profiles_fname)


In [4]:
import configparser
import os.path as op

In [5]:
# Read the HCP data-access keys from the [hcp] section of the local AWS
# credentials file. The `with` block closes the file once it is parsed.
CP = configparser.ConfigParser()
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as credentials_file:
    CP.read_file(credentials_file)
ak = CP.get('hcp', 'AWS_ACCESS_KEY_ID')
sk = CP.get('hcp', 'AWS_SECRET_ACCESS_KEY')

In [6]:
# Job definition for bundle_hcp: the pyAFQ fork is installed on top of the
# python:3.7 image and the job runs on SPOT resources with memory=64000.
knot_config = dict(
    name='bundle_hcp-64gb-191206-02',
    func=bundle_hcp,
    base_image='python:3.7',
    image_github_installs="https://github.com/36000/pyAFQ.git",
    pars_policies=('AmazonS3FullAccess',),
    resource_type="SPOT",
    bid_percentage=100,
    memory=64000,
)
rb_knot = ck.Knot(**knot_config)

In [27]:
# HCP subjects to process; each one is run through both segmentation
# algorithms. The list is kept in one place so the two job sets cannot drift.
HCP_SUBJECTS = [
    100408,
    100610,
    101006,
    101107,
    101309,
    101410,
    101915,
    102008,
    102109,
    102311,
    102513,
]
# Each tuple matches bundle_hcp's params: (subject, hcp_ak, hcp_sk, algo).
inputs_afq = [(sub, ak, sk, 'afq') for sub in HCP_SUBJECTS]
inputs_reco = [(sub, ak, sk, 'reco') for sub in HCP_SUBJECTS]

In [28]:
# Submit every (subject, algorithm) job in one map call: AFQ jobs first,
# then RecoBundles jobs.
all_inputs = [*inputs_afq, *inputs_reco]
ft = rb_knot.map(all_inputs)