Skip to content

Commit

Permalink
[Fix] Fix threading on DRBUDDI interface (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattcieslak committed Mar 31, 2023
1 parent 5707897 commit 55982e3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
4 changes: 2 additions & 2 deletions qsiprep/interfaces/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def _run_interface(self, runtime):
skulled_img, brainmask_img)

actual_brain_to_skull_ratio = brain_median / nonbrain_head_median
LOGGER.info("found brain to skull ratio:", actual_brain_to_skull_ratio)
LOGGER.info("found brain to skull ratio: %.3f", actual_brain_to_skull_ratio)
desat_data = skulled_img.get_fdata(dtype=np.float32).copy()
adjustment = 1.
if actual_brain_to_skull_ratio < self.inputs.brain_to_skull_ratio:
Expand Down Expand Up @@ -375,4 +375,4 @@ def clip_values(values):
in_brain_median = np.median(head_data[brain_mask])
non_brain_head_median = np.median(head_data[nbmask > 0])

return in_brain_median, non_brain_head_median
return in_brain_median, non_brain_head_median
44 changes: 38 additions & 6 deletions qsiprep/interfaces/tortoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,40 @@
"\[learning_rate=\{1.\},cfs=\{20:1:0\},field_smoothing=\{4:0\}," \
"metrics=\{MSJac:CC\},restrict_constrain=\{0:0\}\]"

class TORTOISEInputSpec(BaseInterfaceInputSpec):
pass
class TORTOISEInputSpec(CommandLineInputSpec):
num_threads = traits.Int(desc="numpy of OMP threads")

class TORTOISECommandLine(CommandLine):
"""Support for TORTOISE commands that utilize OpenMP
Sets the environment variable 'OMP_NUM_THREADS' to the number
of threads specified by the input num_threads.
"""

input_spec = TORTOISEInputSpec
_num_threads = None

def __init__(self, **inputs):
super(TORTOISECommandLine, self).__init__(**inputs)
self.inputs.on_trait_change(self._num_threads_update, "num_threads")
if not self._num_threads:
self._num_threads = os.environ.get("OMP_NUM_THREADS", None)
if not self._num_threads:
self._num_threads = os.environ.get("NSLOTS", None)
if not isdefined(self.inputs.num_threads) and self._num_threads:
self.inputs.num_threads = int(self._num_threads)
self._num_threads_update()

def _num_threads_update(self):
if self.inputs.num_threads:
self.inputs.environ.update(
{"OMP_NUM_THREADS": str(self.inputs.num_threads)}
)

def run(self, **inputs):
if "num_threads" in inputs:
self.inputs.num_threads = inputs["num_threads"]
self._num_threads_update()
return super(TORTOISECommandLine, self).run(**inputs)


class _GatherDRBUDDIInputsInputSpec(TORTOISEInputSpec):
Expand Down Expand Up @@ -206,7 +238,6 @@ class _DRBUDDIInputSpec(TORTOISEInputSpec):
File(exists=True, copyfile=False),
argstr='-s %s',
help="Path(s) to anatomical image files. Can provide more than one. NO T1W's!!")
nthreads=traits.Int(1, usedefault=True, hash_files=False)
fieldmap_type = traits.Enum("epi", "rpe_series", mandatory=True)
blip_assignments = traits.List()
tensor_fit_bval_max = traits.Int(
Expand Down Expand Up @@ -261,7 +292,7 @@ class _DRBUDDIOutputSpec(TraitedSpec):
structural_image = File(exists=True)


class DRBUDDI(CommandLine):
class DRBUDDI(TORTOISECommandLine):
input_spec = _DRBUDDIInputSpec
output_spec = _DRBUDDIOutputSpec
_cmd = "DRBUDDI"
Expand Down Expand Up @@ -427,7 +458,7 @@ def _run_interface(self, runtime):
return runtime


class _GibbsInputSpec(SeriesPreprocReportInputSpec):
class _GibbsInputSpec(TORTOISEInputSpec, SeriesPreprocReportInputSpec):
"""Gibbs input_nifti output_nifti kspace_coverage(1,0.875,0.75) phase_encoding_dir nsh minW(optional) maxW(optional)"""
in_file = traits.File(
exists=True,
Expand Down Expand Up @@ -455,13 +486,14 @@ class _GibbsInputSpec(SeriesPreprocReportInputSpec):
position=4)
min_w = traits.Int()
mask = File()
num_threads = traits.Int(1, usedefault=True, nohash=True)


class _GibbsOutputSpec(SeriesPreprocReportOutputSpec):
out_file = File(exists=True)


class Gibbs(SeriesPreprocReport, CommandLine):
class Gibbs(SeriesPreprocReport, TORTOISECommandLine):
input_spec = _GibbsInputSpec
output_spec = _GibbsOutputSpec
_cmd = "Gibbs"
Expand Down
4 changes: 2 additions & 2 deletions qsiprep/workflows/dwi/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def get_buffernode():
if do_unringing:
if unringing_method == 'mrdegibbs':
degibbser = pe.Node(
MRDeGibbs(),
MRDeGibbs(nthreads=omp_nthreads),
name='degibbser',
n_procs=omp_nthreads)
elif unringing_method == 'rpg':
Expand Down Expand Up @@ -528,4 +528,4 @@ def get_merged_parameter(parameter_df, parameter_name,
if selection_mode == 'mode':
return col.mode()[0]

raise Exception("selection_mode must be 'all' or 'mode'")
raise Exception("selection_mode must be 'all' or 'mode'")
2 changes: 1 addition & 1 deletion qsiprep/workflows/fieldmap/drbuddi.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def init_drbuddi_wf(scan_groups, b0_threshold, raw_image_sdc, t2w_sdc, omp_nthre
drbuddi = pe.Node(
DRBUDDI(
fieldmap_type=fieldmap_info['suffix'],
nthreads=omp_nthreads,
num_threads=omp_nthreads,
sloppy=sloppy),
name='drbuddi',
n_procs=omp_nthreads)
Expand Down

0 comments on commit 55982e3

Please sign in to comment.