In [2]:
import configparser
import os.path as op

# Read the "hcp" profile from the AWS shared-credentials file. These secrets are
# later passed as job arguments, so never display them in cell output.
AWS_CREDENTIALS_PATH = op.join(op.expanduser('~'), '.aws', 'credentials')
HCP_PROFILE = 'hcp'

CP = configparser.ConfigParser()
# `with` closes the file handle instead of leaking it until garbage collection;
# read_file (unlike CP.read) raises if the file is missing, which we want.
with open(AWS_CREDENTIALS_PATH) as credentials_file:
    CP.read_file(credentials_file)
aws_access_key = CP.get(HCP_PROFILE, 'AWS_ACCESS_KEY_ID')
aws_secret_key = CP.get(HCP_PROFILE, 'AWS_SECRET_ACCESS_KEY')

def attach_keys(arr, access_key=None, secret_key=None):
    """Pair each element of `arr` with a set of AWS credentials.

    Produces the ``(element, access_key, secret_key)`` tuples that
    ``profile_hcp`` unpacks from its single ``args`` argument.

    Parameters
    ----------
    arr : iterable
        Items (e.g. HCP subject IDs) to send to the remote jobs.
    access_key, secret_key : str, optional
        Credentials to attach. When omitted, the module-level
        ``aws_access_key`` / ``aws_secret_key`` read from the AWS
        credentials file are used.

    Returns
    -------
    list of tuple
        One ``(element, access_key, secret_key)`` tuple per element of `arr`.
    """
    # Fall back to the notebook-level credentials only when needed, so the
    # globals are not required when keys are passed explicitly.
    if access_key is None:
        access_key = aws_access_key
    if secret_key is None:
        secret_key = aws_secret_key
    return [(e, access_key, secret_key) for e in arr]

In [3]:
def profile_hcp(args):
    """Compute FA tract profiles for one HCP subject.

    This function is submitted as the ``func`` of a cloudknot ``Knot``,
    which is presumably why all of its imports live inside the body.

    Parameters
    ----------
    args : sequence
        ``(subject, aws_access_key, aws_secret_key)``: the HCP subject ID
        and the AWS credentials used to download that subject's data.

    Returns
    -------
    list
        ``[bundles, profiles]``. ``bundles`` maps bundle names
        (``"CST_R"``, ``"CST_L"``, ``"ILF_R"``, ``"ILF_L"``) to their
        segmentation definitions. ``profiles`` holds one ``afq_profile``
        result per bundle, in the iteration order of ``bundles``.

    Notes
    -----
    Writes the DTI fit outputs (``out_dir='.'``) and ``seed_roi.nii.gz``
    into the current working directory.
    """
    import numpy as np
    import nibabel as nib
    import dipy.data as dpd
    import dipy.tracking.utils as dtu
    import dipy.tracking.streamline as dts
    from dipy.stats.analysis import afq_profile, gaussian_weights
    import dipy.core.gradients as dpg

    import AFQ.data as afd
    import AFQ.tractography as aft
    import AFQ.registration as reg
    import AFQ.dti as dti
    import AFQ.segmentation as seg
    from AFQ.utils.volume import patch_up_roi
    from AFQ.data import fetch_hcp

    subject = args[0]
    aws_access_key = args[1]
    aws_secret_key = args[2]

    # NOTE(review): the positional picks below assume the keys of the dict
    # returned by fetch_hcp start with the bval, bvec and DWI files, in that
    # order. Confirm against the installed pyAFQ version.
    data_paths = list(fetch_hcp([subject], profile_name=False,
                                aws_access_key_id=aws_access_key,
                                aws_secret_access_key=aws_secret_key).keys())
    fbval = data_paths[0]
    fbvec = data_paths[1]
    fdata = data_paths[2]

    img = nib.load(fdata)

    dti_params = dti.fit_dti(fdata, fbval, fbvec,
                             out_dir='.')
    FA_img = nib.load(dti_params['FA'])
    FA_data = FA_img.get_fdata()

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]
    hemispheres = ['_R', '_L']

    # Build one bundle definition per (bundle, hemisphere) from the templates.
    bundles = {}
    for name in bundle_names:
        for hemi in hemispheres:
            bundles[name + hemi] = {
                'ROIs': [templates[name + '_roi1' + hemi],
                         templates[name + '_roi2' + hemi]],
                'rules': [True, True],
                'prob_map': templates[name + hemi + '_prob_map'],
                'cross_midline': False}

    MNI_T2_img = dpd.read_mni_template()
    gtab = dpg.gradient_table(fbval, fbvec)
    _, mapping = reg.syn_register_dwi(fdata, gtab)

    print("Tracking...")
    # Seed tractography from the union of every bundle ROI, each warped with
    # the inverse of the registration mapping into this subject's space.
    seed_roi = np.zeros(img.shape[:-1], dtype=bool)
    for bundle_def in bundles.values():
        for roi in bundle_def['ROIs']:
            # get_fdata() replaces get_data(), which nibabel deprecated and
            # removed in 5.0.
            warped_roi = patch_up_roi(
                mapping.transform_inverse(
                    roi.get_fdata().astype(np.float32),
                    interpolation='linear') > 0)
            seed_roi = np.logical_or(seed_roi, warped_roi)

    nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine),
             'seed_roi.nii.gz')
    # The FA map serves as the stopping mask, with a threshold of 0.1.
    streamlines = aft.track(dti_params['params'], seed_mask=seed_roi,
                            stop_mask=FA_data, stop_threshold=0.1)

    # Map streamlines into the DWI's voxel coordinates (inverse of its affine).
    streamlines = dts.Streamlines(
        dtu.transform_tracking_output(streamlines,
                                      np.linalg.inv(img.affine)))

    print("Segmenting fiber groups...")
    segmentation = seg.Segmentation()
    segmentation.segment(bundles,
                         streamlines,
                         fdata=fdata,
                         fbval=fbval,
                         fbvec=fbvec,
                         mapping=mapping,
                         reg_template=MNI_T2_img)

    fiber_groups = segmentation.fiber_groups

    for bundle in bundles:
        fiber_groups[bundle] = seg.clean_fiber_group(fiber_groups[bundle])

    profiles = []
    print("Extracting tract profiles...")
    for bundle in bundles:
        weights = gaussian_weights(fiber_groups[bundle])
        # Streamlines are already in voxel coordinates, hence the identity affine.
        profile = afq_profile(FA_data, fiber_groups[bundle],
                              np.eye(4), weights=weights)
        profiles.append(profile)

    return [bundles, profiles]


In [4]:
# Uncomment to run the pipeline locally on a single subject before submitting it through cloudknot.
#profile_hcp(["100206", aws_access_key, aws_secret_key])

In [5]:
import cloudknot as ck
# AWS region that cloudknot will use for the resources created below.
ck.set_region('us-east-1')

In [6]:
# Remote execution environment for `profile_hcp`: SPOT instances, with pyAFQ
# installed from GitHub into the job image.
knot = ck.Knot(name='profile-hcp-3',
               # The trailing comma makes this a one-element tuple. Without it
               # the parentheses are a no-op and a bare string is passed.
               image_github_installs=("https://github.com/arokem/pyAFQ.git",),
               resource_type='SPOT',
               bid_percentage=100,
               memory=64000,
               func=profile_hcp,
               # S3 access so the job can download the HCP data.
               pars_policies=('AmazonS3FullAccess',))

In [7]:
# Submit one input per subject; each input is a (subject, access_key, secret_key) tuple.
result_futures = knot.map(attach_keys(["100206", "162228", "175540"]))

In [12]:
# Show the status of the jobs submitted by this knot.
knot.view_jobs()

Job ID              Name                        Status   
---------------------------------------------------------
d15a8177-1868-4a25-8744-002393f7e1ad        profile-hcp-3-0             SUCCEEDED


In [None]:
import matplotlib.pyplot as plt  # `plt` was used but never imported

j0 = knot.jobs[0]
# Fetch the (possibly remote) result once rather than once per use.
# NOTE(review): three subjects were mapped. Confirm whether j0.result()
# returns a single [bundles, profiles] pair or one pair per subject.
bundles, profiles = j0.result()

# `bundles` is a dict keyed by bundle name, and profile_hcp appended
# `profiles` in that same iteration order, so zip pairs each name with its
# profile. Iterating the dict directly yields plain strings, which cannot be
# unpacked into (index, name).
for bundle, profile in zip(bundles, profiles):
    fig, ax = plt.subplots(1)
    ax.plot(profile)
    ax.set_title(bundle)
    ax.set_xlabel('Node along tract')
    ax.set_ylabel('FA')

plt.show()

In [63]:
# Tear down this knot's cloud resources. The flags additionally request
# clobbering its PARS, container repo and image (cloudknot terms; see its docs).
knot.clobber(clobber_pars=True, clobber_repo=True, clobber_image=True)