In [1]:
## imports
import os.path as op
import dipy.reconst.csdeconv as csd
from dipy.data.fetcher import fetch_hcp
from dipy.core.gradients import gradient_table
import nibabel  as nib
import time
from dipy.reconst.csdeconv import (auto_response_ssst,
                                   mask_for_response_ssst,
                                   response_from_mask_ssst)
from dipy.align import resample

In [2]:
## fetch subject data
subject = 100307
dataset_path = fetch_hcp(subject)[1]
subject_dir = op.join(
    dataset_path,
    "derivatives",
    "hcp_pipeline",
    f"sub-{subject}",
    )
print(subject_dir)
subject_files = [op.join(subject_dir, "dwi",
        f"sub-{subject}_dwi.{ext}") for ext in ["nii.gz", "bval", "bvec"]]

/Users/asagilmore/.dipy/HCP_1200/derivatives/hcp_pipeline/sub-100307


In [3]:
## load data
dwi_img = nib.load(subject_files[0])
data = dwi_img.get_fdata()
seg_img = nib.load(op.join(
    subject_dir, "anat", f'sub-{subject}_aparc+aseg_seg.nii.gz'))
seg_data = seg_img.get_fdata()
brain_mask = seg_data > 0
dwi_volume = nib.Nifti1Image(data[..., 0], dwi_img.affine)
brain_mask_xform = resample(brain_mask, dwi_volume,
                            moving_affine=seg_img.affine)
brain_mask_data = brain_mask_xform.get_fdata().astype(int)
gtab = gradient_table(subject_files[1], subject_files[2])
response_mask = mask_for_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
response, _ = response_from_mask_ssst(gtab, data, response_mask)


csdm = csd.ConstrainedSphericalDeconvModel(gtab, response=response)

In [4]:
import csv
import os
import math
from scipy.stats import norm
import time

In [5]:
import time
import statistics


runTimeData = []

## run csdm with the given engine and vox_per_chunk
# appends the given time, engine, and vox_per_chunk to the data dataframe
# returns the time it took to run
def run_csdm(engine, vox_per_chunk,save = True):

    global runTimeData

    start = time.time()
    csdm.fit(data, mask=brain_mask_data, engine=engine, vox_per_chunk=vox_per_chunk)
    end = time.time()

    runTime = end-start

    if (save):
        runTimeData.append({'engine':engine,'vox_per_chunk':vox_per_chunk,'time':runTime})
    else:
        print("save turned off, runTime not saved")

    print("engine: ", engine, " vox_per_chunk: ", vox_per_chunk, " time: ", runTime)

    return runTime

def save_data(filename):
    global runTimeData

    # specify the names for CSV column headers
    fieldnames = runTimeData[0].keys() if runTimeData else Error("No data to save")

    # writing to csv file
    with open(filename, 'a', newline='') as csvfile:
        # creating a csv writer object
        csvwriter = csv.DictWriter(csvfile, fieldnames=fieldnames)

        # writing headers (field names) if the file doesn't exist or it is empty
        if not os.path.isfile(filename) or os.path.getsize(filename) == 0:
            csvwriter.writeheader()

        # writing the data rows
        csvwriter.writerows(runTimeData)

    runTimeData.clear()

"""
Runs the csdm model with the given engine and vox_per_chunk until a certain confidence interval is reached
for a certain confidence level.
Here we assume that the distribution of error is roughly gassian, this has yet to be tested.

Args:
engine (str): the engine to use for parallelization
vox_per_chunk (int): the number of voxels to process in each chunk
conf_int (float): the confidence interval to reach as a percentage of mean
conf_level (float): the confidence level to reach as a percentage
"""
def compute_to_confidence(engine, vox_per_chunk, conf_int, conf_level=0.95,max_iter=30):
    times = []

    # z score for conf level
    Z = norm.ppf((1 + conf_level) / 2)

    ##run first empty once to get rid of startup overhead
    run_csdm(engine, vox_per_chunk,False)
    ##run a couple times to get standard deviation
    print("running inital 6 to get standard deviation")
    for i in range(0,6):
        times.append(run_csdm(engine, vox_per_chunk))

    mean = statistics.mean(times)
    std = statistics.stdev(times)

    margin_err = conf_int * mean

    n = math.ceil((Z * std / conf_int) ** 2)

    print("mean: ", mean, " std: ", std, " margin_err: ", margin_err, "conf_level: ", conf_level, " n(needed to meet margin_err + conf_level): ", n)


    # if we have enough samples, return,v
    # else run more until we do
    # if we run more than max_iter times, return
    if (len(times) >= n):
        print("enough samples taken")
        return
    else:
        while(len(times) < n):
            times.append(run_csdm(engine, vox_per_chunk))

            # update mean and std
            mean = statistics.mean(times)
            std = statistics.stdev(times)
            margin_err = conf_int * mean
            n = math.ceil((Z * std / conf_int) ** 2)

            print("mean: ", mean, " std: ", std, " margin_err: ", margin_err, "conf_level: ", conf_level, " n(needed to meet margin_err + conf_level): ", n)

            if(len(times) > max_iter):
                print("max iterations reached")
                return

    return






In [6]:
engines = ["ray","dask"]

vox_per_chunk = [2**i for i in range(3,17)]


for i in range(30):
    run_csdm("serial",1)
    save_data('runTimeData.csv')
    for v in vox_per_chunk:
        for e in engines:
            run_csdm(e,v)
            save_data('runTimeData.csv')


Fitting reconstruction model: : 3658350it [03:38, 16719.91it/s]                           


engine:  serial  vox_per_chunk:  1  time:  218.91965508460999


2024-01-30 10:49:58,296	INFO worker.py:1724 -- Started a local Ray instance.
[36m(raylet)[0m Spilled 2141 MiB, 6818 objects, write throughput 2044 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
[36m(raylet)[0m Spilled 5820 MiB, 18534 objects, write throughput 2704 MiB/s.
[36m(raylet)[0m Spilled 16885 MiB, 53774 objects, write throughput 3515 MiB/s.[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[36m(raylet)[0m Spilled 32986 MiB, 105039 objects, write throughput 2872 MiB/s.


engine:  ray  vox_per_chunk:  8  time:  127.1330931186676


KeyboardInterrupt: 

In [None]:
csdm.fit(data, mask=brain_mask_data, engine="ray", vox_per_chunk=(2**10))