In [53]:
## imports
import os.path as op
import dipy.reconst.csdeconv as csd
from dipy.data.fetcher import fetch_hcp
from dipy.core.gradients import gradient_table
import nibabel  as nib
import time
from dipy.reconst.csdeconv import (auto_response_ssst,
                                   mask_for_response_ssst,
                                   response_from_mask_ssst)
from dipy.align import resample

In [54]:
## fetch subject data
# HCP subject id; the second element returned by fetch_hcp is used as the
# dataset root directory.
subject = 100307
dataset_path = fetch_hcp(subject)[1]

# <dataset_path>/derivatives/hcp_pipeline/sub-<id>
subject_dir = op.join(dataset_path, "derivatives", "hcp_pipeline", f"sub-{subject}")
print(subject_dir)

# Paths of the DWI image, b-values and b-vectors, in that order.
subject_files = [
    op.join(subject_dir, "dwi", f"sub-{subject}_dwi.{ext}")
    for ext in ("nii.gz", "bval", "bvec")
]

/Users/asagilmore/.dipy/HCP_1200/derivatives/hcp_pipeline/sub-100307


In [55]:
## load data
# Diffusion image and its voxel array.
dwi_img = nib.load(subject_files[0])
data = dwi_img.get_fdata()

# Segmentation volume: every non-zero label is treated as inside the brain.
seg_path = op.join(subject_dir, "anat", f'sub-{subject}_aparc+aseg_seg.nii.gz')
seg_img = nib.load(seg_path)
seg_data = seg_img.get_fdata()
brain_mask = seg_data > 0

# Resample the mask onto the grid of the DWI slice at index 0 of the last
# axis, then cast the resampled values to integers.
dwi_volume = nib.Nifti1Image(data[..., 0], dwi_img.affine)
brain_mask_xform = resample(
    brain_mask,
    dwi_volume,
    moving_affine=seg_img.affine,
)
brain_mask_data = brain_mask_xform.get_fdata().astype(int)

# Gradient table from the bval / bvec files, then the single-shell response
# estimated from the voxels selected by mask_for_response_ssst.
gtab = gradient_table(subject_files[1], subject_files[2])
response_mask = mask_for_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
response, _ = response_from_mask_ssst(gtab, data, response_mask)

# Model that the benchmark cells fit repeatedly.
csdm = csd.ConstrainedSphericalDeconvModel(gtab, response=response)

In [56]:
import pandas as pd
import math
from scipy.stats import norm
import time

In [84]:
import time
import statistics


# One row per timed fit: which engine, which chunk size, elapsed seconds.
runTimeData = pd.DataFrame(columns=["engine","vox_per_chunk","time"])


def run_csdm(engine, vox_per_chunk):
    """Fit the CSD model once and record how long the fit took.

    Calls ``csdm.fit`` on the module-level ``data`` with ``brain_mask_data``
    as mask, then appends one row to the module-level ``runTimeData`` frame.

    Args:
        engine (str): forwarded as the ``engine`` argument of ``csdm.fit``.
        vox_per_chunk (int): forwarded as the ``vox_per_chunk`` argument of
            ``csdm.fit``.

    Returns:
        float: elapsed seconds spent inside ``csdm.fit``.
    """
    # perf_counter is monotonic, so the measurement cannot be skewed by
    # system clock adjustments during a long fit.
    start = time.perf_counter()
    csdm.fit(data, mask=brain_mask_data, engine=engine, vox_per_chunk=vox_per_chunk)
    runTime = time.perf_counter() - start

    # Append exactly one row. Concatenating a bare Series (as this function
    # previously did) adds the three values as three rows under a new "0"
    # column and leaves engine / vox_per_chunk / time empty.
    runTimeData.loc[len(runTimeData)] = [engine, vox_per_chunk, runTime]

    print("engine: ", engine, " vox_per_chunk: ", vox_per_chunk, " time: ", runTime)

    return runTime

"""
Runs the csdm model with the given engine and vox_per_chunk until a certain confidence interval is reached
for a certain confidence level.
Here we assume that the distribution of error is roughly gassian, this has yet to be tested.

Args:
engine (str): the engine to use for parallelization
vox_per_chunk (int): the number of voxels to process in each chunk
conf_int (float): the confidence interval to reach as a percentage of mean
conf_level (float): the confidence level to reach as a percentage
"""
def compute_to_confidence(engine, vox_per_chunk, conf_int, conf_level=0.95,max_iter=30):
    times = []

    # z score for conf level
    Z = norm.ppf((1 + conf_level) / 2)

    ##run a couple times to get standard deviation
    print("running inital 6 to get standard deviation")
    for i in range(0,6):
        times.append(run_csdm(engine, vox_per_chunk))

    mean = statistics.mean(times)
    std = statistics.stdev(times)

    margin_err = conf_int * mean

    n = math.ceil((Z * std / conf_int) ** 2)

    print("mean: ", mean, " std: ", std, " margin_err: ", margin_err, "conf_level: ", conf_level, " n(needed to meet margin_err + conf_level): ", n)


    # if we have enough samples, return,v
    # else run more until we do
    # if we run more than max_iter times, return
    if (len(times) >= n):
        return
    else:
        while(len(times) < n):
            times.append(run_csdm(engine, vox_per_chunk))

            # update mean and std
            mean = statistics.mean(times)
            std = statistics.stdev(times)
            margin_err = conf_int * mean
            n = math.ceil((Z * std / conf_int) ** 2)

            print("mean: ", mean, " std: ", std, " margin_err: ", margin_err, "conf_level: ", conf_level, " n(needed to meet margin_err + conf_level): ", n)

            if(n > max_iter):
                print("max iterations reached")
                return

    return






In [85]:
# Engine names and chunk sizes (powers of two from 2**0 to 2**16). Neither
# list is consumed by the call below in this cell.
engines = ["ray", "joblib", "dask", "serial"]

vox_per_chunk = [2 ** exponent for exponent in range(17)]

# Single benchmark: ray engine, 5000 voxels per chunk, +/-5 % at 95 % confidence.
compute_to_confidence(
    engine="ray",
    vox_per_chunk=5000,
    conf_int=0.05,
    conf_level=0.95,
    max_iter=30,
)

running inital 6 to get standard deviation
engine:  ray  vox_per_chunk:  5000  time:  15.146624088287354
