In [18]:
import os
aws_cred = os.path.expanduser('~/.aws/credentials')

f = open(aws_cred, "r")

while(not '[hcp]' in f.readline()):
    pass

aws_access_key = f.readline().split(" ")[-1].strip()
aws_secret_key = f.readline().split(" ")[-1].strip()

def attach_keys(arr):
    new_arr = []
    for e in arr:
        arg_arr = [e]
        arg_arr.append(aws_access_key)
        arg_arr.append(aws_secret_key)
        new_arr.append(arg_arr)
    return new_arr

In [29]:
def profile_hcp(args):
    import os.path as op
    import matplotlib.pyplot as plt
    import numpy as np
    import nibabel as nib
    import dipy.data as dpd
    from dipy.data import fetcher
    import dipy.tracking.utils as dtu
    import dipy.tracking.streamline as dts
    from dipy.io.streamline import save_tractogram, load_tractogram
    from dipy.stats.analysis import afq_profile, gaussian_weights
    from dipy.io.stateful_tractogram import StatefulTractogram
    from dipy.io.stateful_tractogram import Space

    import AFQ.utils.streamlines as aus
    import AFQ.data as afd
    import AFQ.tractography as aft
    import AFQ.registration as reg
    import AFQ.dti as dti
    import AFQ.segmentation as seg
    from AFQ.utils.volume import patch_up_roi
    from AFQ.data import fetch_hcp

    subject = args[0]
    aws_access_key = args[1]
    aws_secret_key = args[2]
    print("here")
    data_paths = list(fetch_hcp([subject], profile_name=False,
                                aws_access_key_id=aws_access_key,
                                aws_secret_access_key=aws_secret_key).keys())
    
    fdata = data_paths[2]
    fbval = data_paths[0]
    fbvec = data_paths[1]

    img = nib.load(fdata)

    if not op.exists('./dti_FA.nii.gz'):
        dti_params = dti.fit_dti(fdata, fbval, fbvec,
                                 out_dir='.')
    else:
        dti_params = {'FA': './dti_FA.nii.gz',
                      'params': './dti_params.nii.gz'}

    FA_img = nib.load(dti_params['FA'])
    FA_data = FA_img.get_fdata()

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]

    bundles = {}
    for name in bundle_names:
        for hemi in ['_R', '_L']:
            bundles[name + hemi] = {
                'ROIs': [templates[name + '_roi1' + hemi],
                         templates[name + '_roi2' + hemi]],
                'rules': [True, True],
                'prob_map': templates[name + hemi + '_prob_map'],
                'cross_midline': False}

    MNI_T2_img = dpd.read_mni_template()
    if not op.exists('mapping.nii.gz'):
        import dipy.core.gradients as dpg
        gtab = dpg.gradient_table(fbval, fbvec)
        _, mapping = reg.syn_register_dwi(fdata, gtab)
        reg.write_mapping(mapping, './mapping.nii.gz')
    else:
        mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

    print("Tracking...")
    if not op.exists('dti_streamlines.trk'):
        seed_roi = np.zeros(img.shape[:-1])
        for name in bundle_names:
            for hemi in ['_R', '_L']:
                for roi in bundles[name + hemi]['ROIs']:
                    warped_roi = patch_up_roi(
                        (mapping.transform_inverse(
                            roi.get_data().astype(np.float32),
                         interpolation='linear')) > 0)

                    # Add voxels that aren't there yet:
                    seed_roi = np.logical_or(seed_roi, warped_roi)

        nib.save(nib.Nifti1Image(seed_roi.astype(float), img.affine), 'seed_roi.nii.gz')
        streamlines = aft.track(dti_params['params'], seed_mask=seed_roi,
                                stop_mask=FA_data, stop_threshold=0.1)

        sft = StatefulTractogram(streamlines, img, Space.RASMM)
        save_tractogram(sft, './dti_streamlines.trk',
                        bbox_valid_check=False)
    else:
        tg = load_tractogram('./dti_streamlines.trk', img)
        streamlines = tg.streamlines

    streamlines = dts.Streamlines(
        dtu.transform_tracking_output(streamlines,
                                      np.linalg.inv(img.affine)))

    print("Segmenting fiber groups...")
    segmentation = seg.Segmentation()
    segmentation.segment(bundles,
                         streamlines,
                         fdata=fdata,
                         fbval=fbval,
                         fbvec=fbvec,
                         mapping=mapping,
                         reg_template=MNI_T2_img)


    fiber_groups = segmentation.fiber_groups

    print("Cleaning fiber groups...")
    for bundle in bundles:
        fiber_groups[bundle] = seg.clean_fiber_group(fiber_groups[bundle])

    for kk in fiber_groups:
        print(kk, len(fiber_groups[kk]))

        sft = StatefulTractogram(
            dtu.transform_tracking_output(fiber_groups[kk], img.affine),
            img, Space.RASMM)

        save_tractogram(sft, './%s_afq.trk'%kk,
                        bbox_valid_check=False)

    axs = []
    print("Extracting tract profiles...")
    for bundle in bundles:
        fig, ax = plt.subplots(1)
        weights = gaussian_weights(fiber_groups[bundle])
        profile = afq_profile(FA_data, fiber_groups[bundle],
                              np.eye(4), weights=weights)
        ax.plot(profile)
        ax.set_title(bundle)
        axs.append(ax)

    return axs


In [37]:
#profile_hcp(["100206", aws_access_key, aws_secret_key])

In [31]:
import cloudknot as ck

In [33]:
knot = ck.Knot(name='profile-hcp',
               image_github_installs=("https://github.com/arokem/pyAFQ.git"),
               func=profile_hcp, pars_policies=('AmazonS3FullAccess',))

In [34]:
# import boto3
# bucket_name = 'profile-hcp'
# aws_cred_local = os.path.expanduser('~/.aws/credentials')
# aws_cred_bucket = '.aws/credentials'

# upload aws credentials
# s3 = boto3.client('s3')
# with open(aws_cred_local, "rb") as f:
#    s3.upload_fileobj(f, bucket_name, aws_cred_bucket)

In [35]:
result_futures = knot.map(attach_keys(["100206", "162228", "175540"]))

In [43]:
knot.view_jobs()

Job ID              Name                        Status   
---------------------------------------------------------
0c892b21-37ac-42b6-a2df-4fd881cea8c8        profile-hcp-0               FAILED   


In [12]:
r0 = result_futures
results = r0.result()

KeyboardInterrupt: 

In [30]:
knot.clobber(clobber_pars=True, clobber_repo=True, clobber_image=True)

In [38]:
import boto3
s3 = boto3.client('s3')
s3.create_bucket(Bucket='fails-to-create')

ClientError: An error occurred (AccessDenied) when calling the CreateBucket operation: Access Denied