diff --git a/CerebNet/datasets/utils.py b/CerebNet/datasets/utils.py index d8241dcf0..0702f5b53 100644 --- a/CerebNet/datasets/utils.py +++ b/CerebNet/datasets/utils.py @@ -24,6 +24,9 @@ from numpy import typing as npt from FastSurferCNN.data_loader.conform import getscale, scalecrop +from FastSurferCNN.utils import logging + +logger = logging.getLogger(__name__) CLASS_NAMES = { "Background": 0, @@ -215,14 +218,12 @@ def bounding_volume_offset( else None ) if img_shape is not None: - offset = tuple( - min(max(0, o), imgs - ts) - for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False) - ) - if any(o < 0 for o in offset): - raise RuntimeError( - f"Insufficient image size {img_shape} for target image size {target_img_size}" - ) + # try to set the offset so the bounding volume is fully inside + _offset = list((max(0, o), imgs - ts) for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False)) + # if it does not fit fully inside, warn + if any(min(left, right) < 0 for left, right in _offset): + logger.warning(f"The image is not large enough to cut a {target_img_size} patch, padding!") + offset = tuple(min(left, right) if min(left, right) >= 0 else int((left + right)/2) for left, right in _offset) return offset diff --git a/CerebNet/inference.py b/CerebNet/inference.py index 25449a0a5..1edeea649 100644 --- a/CerebNet/inference.py +++ b/CerebNet/inference.py @@ -32,12 +32,7 @@ from FastSurferCNN.data_loader.conform import crop_transform from FastSurferCNN.utils import PLANES, Plane, logging from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType -from FastSurferCNN.utils.common import ( - SerialExecutor, - SubjectDirectory, - SubjectList, - find_device, -) +from FastSurferCNN.utils.common import SerialExecutor, SubjectDirectory, SubjectList, find_device from FastSurferCNN.utils.mapper import JsonColorLookupTable, Mapper, TSVLookupTable from FastSurferCNN.utils.threads import get_num_threads @@ -122,6 +117,7 @@ def __init__( viewagg_device, flag_name="viewagg_device", min_memory=2 * (2**30), + default_cuda_device=_device, ) self.batch_size = cfg.TEST.BATCH_SIZE @@ -409,7 +405,7 @@ def _get_subject_dataset( seg, seg_data = _seg.result() conf_file, conf_img, conf_data = _conf_img.result() - if np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01): + if not np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01): logger.warning( "CerebNet does not support images that are not conformed to 1.0mm. We detected a voxel sizes of " f"{tuple(conf_img.header.get_zooms())} in {conf_file}!" diff --git a/CerebNet/run_prediction.py b/CerebNet/run_prediction.py index d8d432f8b..0866afd92 100644 --- a/CerebNet/run_prediction.py +++ b/CerebNet/run_prediction.py @@ -91,9 +91,9 @@ def _vox_size(a): advanced.add_argument( "--vox_size", - choices=("1", "1.0", "none"), + choices=(1.0, None), type=_vox_size, - default=1, + default=1.0, dest="vox_size", help="Choose the voxelsize to process, CerebNet only supports 1 or 'none' to ignore the voxelsize. ", ) @@ -168,9 +168,7 @@ def main(args: argparse.Namespace) -> int | str: get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, urls=urls) # Check input and output options and get all subjects of interest - subjects = SubjectList( - args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs, - ) + subjects = SubjectList(args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs) try: tester = Inference( diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index e403dfc28..87cc07b5e 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -16,6 +16,7 @@ # IMPORTS import argparse +import re import sys from collections.abc import Callable, Iterable, Sequence from typing import TYPE_CHECKING, Literal, TypeVar, cast @@ -35,7 +36,7 @@ class Tensor: pass from FastSurferCNN.utils import logging -from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, VoxSizeOption +from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm from FastSurferCNN.utils.arg_types import img_size as __img_size from FastSurferCNN.utils.arg_types import orientation as __orientation @@ -260,7 +261,7 @@ def options_parse(): def to_target_orientation( image_data: _TA, source_affine: npt.NDArray[float], - target_orientation: str, + target_orientation: StrictOrientationType, ) -> tuple[_TA, Callable[[_TB], _TB]]: """ Reorder and flip image_data such that the data is in orientation. This will always be without interpolation. @@ -271,7 +272,7 @@ def to_target_orientation( The image data to reorder/flip. source_affine : npt.NDArray[float] The affine to detect the reorientation operations. - target_orientation : str + target_orientation : StrictOrientationType The target orientation to reorient to. Returns @@ -281,13 +282,7 @@ def to_target_orientation( Callable[[np.ndarray], np.ndarray], Callable[[torch.Tensor], torch.Tensor] A function that flips and reorders the data back (returns same type as output). """ - from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform - - source_ornt = io_orientation(source_affine) - target_ornt = axcodes2ornt(target_orientation) - - reorient_ornt = ornt_transform(source_ornt, target_ornt) - unorient_ornt = ornt_transform(target_ornt, source_ornt) + reorient_ornt, unorient_ornt = orientation_to_ornts(source_affine, target_orientation) if np.any([reorient_ornt[:, 1] != 1, reorient_ornt[:, 0] != np.arange(reorient_ornt.shape[0])]): # is not lia yet def back_to_native(data: _TB) -> _TB: @@ -301,6 +296,36 @@ def do_nothing(data: _TB) -> _TB: return image_data, do_nothing +def orientation_to_ornts( + source_affine: npt.NDArray[float], + target_orientation: StrictOrientationType, +) -> tuple[npt.NDArray[int], npt.NDArray[int]]: + """ + Determine the nibabel `ornt` Array to reorder and flip data from source_affine such that the data is in orientation. + + Parameters + ---------- + source_affine : npt.NDArray[float] + The affine to detect the reorientation operations. + target_orientation : StrictOrientationType + The target orientation to reorient to. + + Returns + ------- + npt.NDArray[int] + The `ornt` transform from source_affine to target_orientation. + npt.NDArray[int] + The `ornt` transform back from target_orientation to source_affine. + """ + from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform + + source_ornt = io_orientation(source_affine) + target_ornt = axcodes2ornt(target_orientation.upper()) + reorient_ornt = ornt_transform(source_ornt, target_ornt) + unorient_ornt = ornt_transform(target_ornt, source_ornt) + return reorient_ornt.astype(int), unorient_ornt.astype(int) + + def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB: """ Apply transformations implied by `ornt` to the first n axes of the array `arr`. @@ -797,14 +822,11 @@ def prepare_mgh_header( re_order_axes = [0, 1, 2] rot_scale_mat = img.affine[:3, :3] else: - in_ornt = nib.orientations.io_orientation(img.affine) - out_ornt = nib.orientations.axcodes2ornt(orientation[-3:].upper()) - LOGGER.debug(f"{nib.orientations.ornt2axcodes(in_ornt)} => {nib.orientations.ornt2axcodes(out_ornt)}") - _ornt_transform = nib.orientations.ornt_transform(in_ornt, out_ornt) - LOGGER.debug(str(_ornt_transform)) - re_order_axes = _ornt_transform[:, 0].astype(int) + _ornt_transform, _ = orientation_to_ornts(img.affine, orientation[-3:]) + re_order_axes = _ornt_transform[:, 0] if len(orientation) == 3: # lia, ras, etc # this is a 3x3 matrix + out_ornt = nib.orientations.axcodes2ornt(orientation[-3:].upper()) rot_scale_mat = nib.orientations.inv_ornt_aff(out_ornt, source_img_shape)[:3, :3] else: # soft lia, ras, .... aff = _ornt_transform[:, 1][None] * img.affine[:3, :3] @@ -943,43 +965,49 @@ def is_conform( if len(img.shape) > 3 and img.shape[3] != 1: raise ValueError(f"Multiple input frames ({img.shape[3]}) not supported!") - checks = {"Number of Dimensions 3": (len(img.shape) == 3, f"image ndim {img.ndim}")} + checks: dict[str, tuple[bool | Literal["IGNORED"], str]] = { + "Number of Dimensions 3": (img.ndim == 3, f"image ndim {img.ndim}") + } + # check dimensions - if img_size is not None and _img_size is not None: - # if not isinstance(_img_size, np.ndarray): - # raise TypeError("_img_size should be numpy.ndarray here") - img_size_criteria = f"Dimensions {'x'.join(map(str, _img_size[:3]))}" - is_correct_img_size = np.array_equal(np.asarray(img.shape[:3]), _img_size) - checks[img_size_criteria] = is_correct_img_size, f"image dimensions {img.shape}" + img_size_text = f"image dimensions {img.shape}" + if img_size in (None, "fov") or _img_size is None: + img_size_criteria = f"Dimensions {img_size}" + checks[img_size_criteria] = "IGNORED", img_size_text + else: + img_size_criteria = f"Dimensions {img_size}={'x'.join(map(str, _img_size[:3]))}" + checks[img_size_criteria] = np.array_equal(np.asarray(img.shape[:3]), _img_size), img_size_text # check voxel size, drop voxel sizes of dimension 4 if available izoom = np.array(img.header.get_zooms()) - if _vox_size is not None: + vox_size_text = f"image {'x'.join(map(str, izoom))}" + if _vox_size is None: + checks[f"Voxel Size {vox_size}"] = "IGNORED", vox_size_text + else: if not isinstance(_vox_size, np.ndarray): raise TypeError("_vox_size should be numpy.ndarray here") - is_correct_vox_size = np.allclose(izoom[:3], _vox_size, atol=vox_eps, rtol=0) - vox_size_criteria = f"Voxel Size {'x'.join(map(str, _vox_size))}" - checks[vox_size_criteria] = (is_correct_vox_size, f"image {'x'.join(map(str, izoom))}") + vox_size_criteria = f"Voxel Size {vox_size}={'x'.join(map(str, _vox_size))}" + checks[vox_size_criteria] = np.allclose(izoom[:3], _vox_size, atol=vox_eps, rtol=0), vox_size_text # check orientation LIA - if orientation is not None and orientation != "native": + affcode = "".join(nib.orientations.aff2axcodes(img.affine)) + with np.printoptions(precision=2, suppress=True): + orientation_text = "affine=" + re.sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}" + if orientation is None or orientation == "native": + checks[f"Orientation {orientation}"] = "IGNORED", orientation_text + else: is_soft = not orientation.startswith("soft") - if is_correct_orientation := is_orientation(img.affine, orientation[-3:], is_soft, eps): - orientation_text = orientation - else: - from re import sub - - affcode = "".join(nib.orientations.aff2axcodes(img.affine)) - with np.printoptions(precision=2, suppress=True): - orientation_text = "affine: " + sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}" - checks[f"Orientation {orientation.upper()}"] = (is_correct_orientation, orientation_text) + is_correct_orientation = is_orientation(img.affine, orientation[-3:], is_soft, eps) + checks[f"Orientation {orientation.upper()}"] = is_correct_orientation, orientation_text # check dtype uchar - if dtype is not None: + dtype_text = f"dtype {img.get_data_dtype().name}" + if dtype is None: + checks["Dtype None"] = "IGNORED", dtype_text + else: _dtype: npt.DTypeLike = to_dtype(dtype) _dtype_name = _dtype.name if hasattr(_dtype, "name") else str(dtype) - is_correct_dtype = np.issubdtype(img.get_data_dtype(), _dtype) - checks[f"Dtype {_dtype_name}"] = (is_correct_dtype, f"dtype {img.get_data_dtype().name}") + checks[f"Dtype {_dtype_name}"] = np.issubdtype(img.get_data_dtype(), _dtype), dtype_text _is_conform = all(map(lambda x: x[0], checks.values())) @@ -994,10 +1022,12 @@ def is_conform( conform_str = f"{np.round(_vox_size[0], decimals=2):.2f}-" else: with np.printoptions(precision=2, suppress=True): - conform_str = f"{str(_vox_size)}-" + conform_str = str(_vox_size) + "-" logger.info(f"A {conform_str}conformed image must satisfy the following criteria:") for condition, (value, message) in checks.items(): - logger.info(f" - {condition:<30}: {value if value else 'BUT ' + message}") + if isinstance(value, bool): + value = "GOOD" if value else "BUT" + logger.info(f" - {condition:<30}: {value} {message}") return _is_conform @@ -1119,7 +1149,7 @@ def conformed_vox_img_size( _conformed_vox_size = MAX_VOX_SIZE target_vox_size = np.full((3,), _conformed_vox_size) # this is similar to mri_convert --conform_size - elif isinstance(vox_size, float) and 0.0 < vox_size <= MAX_VOX_SIZE: + elif isinstance(vox_size, float | int) and 0.0 < vox_size <= MAX_VOX_SIZE: target_vox_size = np.full((3,), vox_size) elif vox_size is None: target_vox_size = None diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index e40487b54..0c0140b6d 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -14,7 +14,7 @@ # IMPORTS import time -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Optional import h5py @@ -36,10 +36,12 @@ class MultiScaleOrigDataThickSlices(Dataset): Load MRI-Image and process it to correct format for network inference. """ + zoom : npt.NDArray[float] + def __init__( self, orig_data: npt.NDArray, - orig_zoom: npt.NDArray, + orig_zoom: npt.NDArray[float] | Sequence[float], cfg: yacs.config.CfgNode, transforms: Callable[[npt.NDArray[float]], npt.NDArray[float]] | None = None, ): @@ -65,16 +67,16 @@ def __init__( if self.plane == "sagittal": orig_data = du.transform_sagittal(orig_data) - self.zoom = orig_zoom[::-1][:2] + self.zoom = np.asarray(orig_zoom)[[2, 1]] logger.info(f"Loading Sagittal with input voxelsize {self.zoom}") elif self.plane == "axial": orig_data = du.transform_axial(orig_data) - self.zoom = orig_zoom[::-1][:2] + self.zoom = np.asarray(orig_zoom)[[2, 0]] logger.info(f"Loading Axial with input voxelsize {self.zoom}") else: - self.zoom = orig_zoom[:2] + self.zoom = np.asarray(orig_zoom)[[0, 1]] logger.info(f"Loading Coronal with input voxelsize {self.zoom}") # Create thick slices @@ -89,8 +91,6 @@ def _get_scale_factor(self) -> npt.NDArray[float]: Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. Returns ------- @@ -236,9 +236,7 @@ def _get_scale_factor( Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. - - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. + Parameters ---------- @@ -258,9 +256,7 @@ def _get_scale_factor( scale = self.base_res / img_zoom if self.gn_noise: - scale += ( - torch.randn(1) * 0.1 + 0 - ) # needs to be changed to torch.tensor stuff + scale += torch.randn(1) * 0.1 + 0 # needs to be changed to torch.tensor stuff scale = torch.clamp(scale, min=0.1) return scale @@ -468,9 +464,6 @@ def _get_scale_factor(self, img_zoom): Input resolution is taken from voxel size in image header. - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. - Parameters ---------- img_zoom : np.ndarray diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index dbe37d6fc..947207b79 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -91,9 +91,7 @@ def target_shape(self, target_shape: _T.Sequence[int] | None) -> None: """ Validate and set the target_shape. """ - tup_target_shape = ( - tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple() - ) + tup_target_shape = tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple() if tup_target_shape != self._target_shape: LOGGER.debug( f"Changing the target_shape of {type(self).__name__} to {tup_target_shape} from {self._target_shape}." diff --git a/FastSurferCNN/models/networks.py b/FastSurferCNN/models/networks.py index 2ef064ec8..ce974628e 100644 --- a/FastSurferCNN/models/networks.py +++ b/FastSurferCNN/models/networks.py @@ -361,18 +361,12 @@ def forward( if scale_factor_out is None: scale_factor_out = rescale_factor else: - scale_factor_out = ( - np.asarray(scale_factor_out) - * np.asarray(rescale_factor) - / np.asarray(scale_factor) - ) + scale_factor_out = np.asarray(scale_factor_out) * np.asarray(rescale_factor) / np.asarray(scale_factor) prior_target_shape = self.interpol2.target_shape self.interpol2.target_shape = skip_encoder_0.shape[2:] try: - decoder_output0, sf = self.interpol2( - decoder_output1, scale_factor_out, rescale=True - ) + decoder_output0, sf = self.interpol2(decoder_output1, scale_factor_out, rescale=True) finally: self.interpol2.target_shape = prior_target_shape outblock = self.outp_block(decoder_output0, skip_encoder_0) diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index 48bac2331..1cdc63fa2 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -41,7 +41,7 @@ import FastSurferCNN.reduce_to_aseg as rta from FastSurferCNN.data_loader import data_utils as du -from FastSurferCNN.data_loader.conform import conform, is_conform, to_target_orientation +from FastSurferCNN.data_loader.conform import conform, is_conform, orientation_to_ornts, to_target_orientation from FastSurferCNN.inference import Inference from FastSurferCNN.quick_qc import check_volume from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults @@ -237,7 +237,12 @@ def __init__( self.viewagg_device = self.device else: # check, if GPU is big enough to run view agg on it (this currently takes the memory of the passed device) - self.viewagg_device = find_device(viewagg_device, flag_name="viewagg_device", min_memory=4 * (2**30)) + self.viewagg_device = find_device( + viewagg_device, + flag_name="viewagg_device", + min_memory=4 * (2**30), + default_cuda_device=self.device, + ) LOGGER.info(f"Running view aggregation on {self.viewagg_device}") @@ -319,11 +324,8 @@ def conform_and_save_orig( if not is_conform(orig, **self.__conform_kwargs(verbose=True)): if (self.orientation is None or self.orientation == "native") and \ - not is_conform(orig, **self.__conform_kwargs(verbose=False, dtype=None)): - raise RuntimeError( - f"To store images in native image space, the input image {subject.orig_name} must have isotropic " - f"voxels." - ) + not is_conform(orig, **self.__conform_kwargs(verbose=False, dtype=None, vox_size="min")): + LOGGER.warning("Support for anisotropic voxels is experimental. Careful QC of all images is needed!") LOGGER.info("Conforming image...") orig = conform(orig, **self.__conform_kwargs()) orig_data = np.asanyarray(orig.dataobj) @@ -384,6 +386,8 @@ def get_prediction( orig_in_lia, back_to_native = to_target_orientation(orig_data, affine, target_orientation="LIA") shape = orig_in_lia.shape + (self.get_num_classes(),) + _ornt_transform, _ = orientation_to_ornts(affine, target_orientation="LIA") + _zoom = _zoom[_ornt_transform[:, 0]] pred_prob = torch.zeros(shape, **kwargs) @@ -392,7 +396,7 @@ def get_prediction( LOGGER.info(f"Run {plane} prediction") self.set_model(plane) # pred_prob is updated inplace to conserve memory - pred_prob = model.run(pred_prob, image_name, orig_in_lia, zoom, out=pred_prob) + pred_prob = model.run(pred_prob, image_name, orig_in_lia, _zoom, out=pred_prob) # Get hard predictions pred_classes = torch.argmax(pred_prob, 3) diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index 5780e2437..a2d1478c9 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -25,12 +25,14 @@ __axcode = ("rl", "ap", "si") __orders = tuple(permutations(range(3))) __flips = ((0, 1),) * 3 -__axcodes = ["".join(__axcode[ii[i]][j] for i, j in enumerate(jj)) for ii, jj in product(__orders, product(*__flips))] -VALID_ORIENTATIONS = ["native", *map(lambda x: "soft " + x, __axcodes), *__axcodes] +ORIENTATIONS = ["".join(__axcode[ii[i]][j] for i, j in enumerate(k)) for ii, k in product(__orders, product(*__flips))] +VALID_ORIENTATIONS = ["native", *map(lambda x: "soft " + x, ORIENTATIONS), *ORIENTATIONS] +StrictOrientationType = str OrientationType = str # future better typing, requires Python 3.11 (Syntax Error before that) # OrientationType = Literal[*VALID_ORIENTATIONS] +# StrictOrientationType = Literal[*ORIENTATIONS] @@ -128,10 +130,8 @@ def img_size(a: str) -> ImageSizeOption | None: argparse.ArgumentTypeError If the argument is not "fov", "auto" or convertible to an int greater than 0. """ - if a.lower() == "auto": - return "auto" - if a.lower() == "fov": - return "fov" + if a.lower() in ("auto", "fov"): + return cast(ImageSizeOption, a.lower()) if a.lower() == "any": return None try: diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index 2e39f7840..3b9a826f8 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -48,6 +48,7 @@ def find_device( device: torch.device | str = "auto", flag_name: str = "device", min_memory: int = 0, + default_cuda_device: torch.device | str = "cuda", ) -> torch.device: """ Create a device object from the device string passed. @@ -56,14 +57,14 @@ def find_device( Parameters ---------- - device : torch.device, str - The device to search for and test following pytorch device naming - conventions, e.g. 'cuda:0', 'cpu', etc. (default: 'auto'). + device : torch.device, str, default="auto" + The device to search for and test following pytorch device naming conventions, e.g. 'cuda:0', 'cpu', etc. flag_name : str Name of the corresponding flag for error messages (default: 'device'). min_memory : int - The minimum memory in bytes required for cuda-devices to - be valid (default: 0, works always). + The minimum memory in bytes required for cuda-devices to be valid (default: 0, works always). + default_cuda_device : str, torch.device, default="cuda" + Default cuda device to use, if cuda is available and device is "auto". Returns ------- @@ -85,20 +86,18 @@ def find_device( # If auto detect: if str(device) == "auto" or not device: # 1st check cuda / also finds AMD ROCm, then mps, finally cpu - device = "cuda" if has_cuda else "mps" if has_mps else "cpu" + device = default_cuda_device if has_cuda else "mps" if has_mps else "cpu" device = torch.device(device) if device.type == "cuda" and min_memory > 0: dev_num = torch.cuda.current_device() if device.index is None else device.index - total_gpu_memory = torch.cuda.get_device_properties(dev_num).__getattribute__( - "total_memory" - ) + total_gpu_memory = torch.cuda.get_device_properties(dev_num).__getattribute__("total_memory") if total_gpu_memory < min_memory: giga = 1024**3 - logger.info( - f"Found {total_gpu_memory/giga:.1f} GB GPU memory, but " - f"{min_memory/giga:.1f} GB was required." + logger.warning( + f"Found {total_gpu_memory/giga:.1f} GB GPU memory on device {device}, but {min_memory/giga:.1f} GB was " + f"required. Falling back to {flag_name} cpu." ) device = torch.device("cpu") diff --git a/HypVINN/inference.py b/HypVINN/inference.py index 5255fcff2..017f880e2 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -98,12 +98,11 @@ def __init__( else: # check, if GPU is big enough to run view agg on it # (this currently takes the memory of the passed device) - self.viewagg_device = torch.device( - find_device( - viewagg_device, - flag_name="viewagg_device", - min_memory=4 * (2 ** 30), - ) + self.viewagg_device = find_device( + viewagg_device, + flag_name="viewagg_device", + min_memory=4 * (2 ** 30), + default_cuda_device=self.device, ) logger.info(f"Running view aggregation on {self.viewagg_device}") diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index ef4811307..79fe56f12 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -568,9 +568,7 @@ def get_prediction( # Solution: make this script/function more similar to the optimized FastSurferVINN device, viewagg_device = model.get_device() - h, w, d = target_shape - - pred_shape = (h,w,d, model.get_num_classes()) + pred_shape = tuple(target_shape) + (model.get_num_classes(),) # Set up tensor to hold probabilities and run inference pred_prob = torch.zeros(pred_shape, dtype=torch.float, device=viewagg_device) for plane, opts in view_opts.items(): diff --git a/brun_fastsurfer.sh b/brun_fastsurfer.sh index c7a896aa5..7bbc9e845 100755 --- a/brun_fastsurfer.sh +++ b/brun_fastsurfer.sh @@ -540,6 +540,8 @@ function process_by_token() if [[ "$mode" == "surf" ]] ; then max_processes="$num_parallel_surf" ; else max_processes="$num_parallel_seg" ; fi while [[ "$read_in" == 1 ]] || [[ "${#subject_buffer[@]}" -gt 0 ]] do + # initialize res_args + res_args=() if [[ "$read_in" == 1 ]] then IFS="" @@ -578,20 +580,20 @@ function process_by_token() for name in "${device[@]}" ; do i="$(device2number "$name")" if [[ -z "${used_device[i]}" ]] || [[ -z "$(ps --no-headers "${used_device[i]}")" ]] ; then - res_args=("--device" "$name") ; used_device[i]=""; device_ready=$((device_ready + 1)) ; dev="$name" ; break + res_args+=("--device" "$name") ; used_device[i]=""; device_ready=$((device_ready + 1)) ; dev="$name" ; break fi done - else device_ready=$((device_ready + 1)) ; res_args=("--device" "${device[0]}") + else device_ready=$((device_ready + 1)) ; res_args+=("--device" "${device[0]}") ; dev="${device[0]}" fi if [[ "${#vdevice[@]}" -gt 1 ]] ; then # go through viewagg device assignments, if the processes finished, release the device assignment for name in "${vdevice[@]}" ; do i="$(device2number "$name")" if [[ -z "${used_vdevice[i]}" ]] || [[ -z "$(ps --no-headers "${used_vdevice[i]}")" ]] ; then - res_args=("--viewagg_device" "$name") ; used_vdevice[i]="" ; device_ready=$((device_ready + 1)) ; vdev="$name" ; break + res_args+=("--viewagg_device" "$name") ; used_vdevice[i]="" ; device_ready=$((device_ready + 1)) ; vdev="$name" ; break fi done - else device_ready=$((device_ready + 1)) ; res_args=("--viewagg_device" "${vdevice[0]}") + else device_ready=$((device_ready + 1)) ; res_args+=("--viewagg_device" "${vdevice[0]}") ; vdev="${vdevice[0]}" fi fi if [[ "$device_ready" -gt 1 ]]