From c6b99dd2af21986a0ad6555c1a54921db8376222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Fri, 16 May 2025 12:54:37 +0200 Subject: [PATCH 1/3] conform did not support int datatype voxel sizes, which only happen when manually passed. Fix CerebNet bounding boxes that are too small (smaller than the patch size). Fix when CerebNet was complaining about voxel sizes (inverted check) --- CerebNet/datasets/utils.py | 18 ++++++++++-------- CerebNet/inference.py | 2 +- CerebNet/run_prediction.py | 4 ++-- FastSurferCNN/data_loader/conform.py | 2 +- FastSurferCNN/utils/arg_types.py | 6 ++---- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/CerebNet/datasets/utils.py b/CerebNet/datasets/utils.py index d8241dcf0..dd3089eef 100644 --- a/CerebNet/datasets/utils.py +++ b/CerebNet/datasets/utils.py @@ -24,6 +24,10 @@ from numpy import typing as npt from FastSurferCNN.data_loader.conform import getscale, scalecrop +from FastSurferCNN.utils import logging + + +logger = logging.getLogger(__name__) CLASS_NAMES = { "Background": 0, @@ -215,14 +219,12 @@ def bounding_volume_offset( else None ) if img_shape is not None: - offset = tuple( - min(max(0, o), imgs - ts) - for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False) - ) - if any(o < 0 for o in offset): - raise RuntimeError( - f"Insufficient image size {img_shape} for target image size {target_img_size}" - ) + # try to set the offset so the bounding volume is fully inside + _offset = list((max(0, o), imgs - ts) for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False)) + # if it does not fit fully inside, warn + if any(min(left, right) < 0 for left, right in _offset): + logger.warning(f"The image is not large enough to cut a {target_img_size} patch, padding!") + offset = tuple(min(left, right) if min(left, right) >= 0 else int((left + right)/2) for left, right in _offset) return offset diff --git a/CerebNet/inference.py b/CerebNet/inference.py index 25449a0a5..4ff8e3796 100644 --- a/CerebNet/inference.py +++ b/CerebNet/inference.py @@ -409,7 +409,7 @@ def _get_subject_dataset( seg, seg_data = _seg.result() conf_file, conf_img, conf_data = _conf_img.result() - if np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01): + if not np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01): logger.warning( "CerebNet does not support images that are not conformed to 1.0mm. We detected a voxel sizes of " f"{tuple(conf_img.header.get_zooms())} in {conf_file}!" diff --git a/CerebNet/run_prediction.py b/CerebNet/run_prediction.py index d8d432f8b..25ce1a974 100644 --- a/CerebNet/run_prediction.py +++ b/CerebNet/run_prediction.py @@ -91,9 +91,9 @@ def _vox_size(a): advanced.add_argument( "--vox_size", - choices=("1", "1.0", "none"), + choices=(1.0, None), type=_vox_size, - default=1, + default=1.0, dest="vox_size", help="Choose the voxelsize to process, CerebNet only supports 1 or 'none' to ignore the voxelsize. ", ) diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index e403dfc28..f044c66cb 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -1119,7 +1119,7 @@ def conformed_vox_img_size( _conformed_vox_size = MAX_VOX_SIZE target_vox_size = np.full((3,), _conformed_vox_size) # this is similar to mri_convert --conform_size - elif isinstance(vox_size, float) and 0.0 < vox_size <= MAX_VOX_SIZE: + elif isinstance(vox_size, float | int) and 0.0 < vox_size <= MAX_VOX_SIZE: target_vox_size = np.full((3,), vox_size) elif vox_size is None: target_vox_size = None diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index 5780e2437..b85c3d607 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -128,10 +128,8 @@ def img_size(a: str) -> ImageSizeOption | None: argparse.ArgumentTypeError If the argument is not "fov", "auto" or convertible to an int greater than 0. """ - if a.lower() == "auto": - return "auto" - if a.lower() == "fov": - return "fov" + if a.lower() in ("auto", "fov"): + return cast(ImageSizeOption, a.lower()) if a.lower() == "any": return None try: From c47ae8e21ffbe4ab58f9e329b8c5ede17caa08a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Fri, 16 May 2025 17:26:16 +0200 Subject: [PATCH 2/3] Fix brun_fastsurfer.sh issue: brun did not properly assign the device in some cases (when multiple devices were passed). Fix viewagg_device issue, where auto always defaulted to device number 0 instead of to the device set with --device (for VINN, CerebNet, HypVINN). Extract function orientation_to_ornts from to_target_orientation for consistency, also introduce StrictOrientationType Add better output to is_conform that also shows the current situation (image size and voxel size and affine) to the output Fix the passing of the voxel size (zoom) to FastSurferVINN and HypVINN Some formatting fixes --- CerebNet/datasets/utils.py | 1 - CerebNet/inference.py | 8 +- CerebNet/run_prediction.py | 4 +- FastSurferCNN/data_loader/conform.py | 114 ++++++++++++-------- FastSurferCNN/data_loader/dataset.py | 12 ++- FastSurferCNN/models/interpolation_layer.py | 4 +- FastSurferCNN/models/networks.py | 10 +- FastSurferCNN/run_prediction.py | 13 ++- FastSurferCNN/utils/arg_types.py | 6 +- FastSurferCNN/utils/common.py | 23 ++-- HypVINN/inference.py | 11 +- HypVINN/run_prediction.py | 4 +- brun_fastsurfer.sh | 10 +- 13 files changed, 122 insertions(+), 98 deletions(-) diff --git a/CerebNet/datasets/utils.py b/CerebNet/datasets/utils.py index dd3089eef..0702f5b53 100644 --- a/CerebNet/datasets/utils.py +++ b/CerebNet/datasets/utils.py @@ -26,7 +26,6 @@ from FastSurferCNN.data_loader.conform import getscale, scalecrop from FastSurferCNN.utils import logging - logger = logging.getLogger(__name__) CLASS_NAMES = { diff --git a/CerebNet/inference.py b/CerebNet/inference.py index 4ff8e3796..1edeea649 100644 --- a/CerebNet/inference.py +++ b/CerebNet/inference.py @@ -32,12 +32,7 @@ from FastSurferCNN.data_loader.conform import crop_transform from FastSurferCNN.utils import PLANES, Plane, logging from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType -from FastSurferCNN.utils.common import ( - SerialExecutor, - SubjectDirectory, - SubjectList, - find_device, -) +from FastSurferCNN.utils.common import SerialExecutor, SubjectDirectory, SubjectList, find_device from FastSurferCNN.utils.mapper import JsonColorLookupTable, Mapper, TSVLookupTable from FastSurferCNN.utils.threads import get_num_threads @@ -122,6 +117,7 @@ def __init__( viewagg_device, flag_name="viewagg_device", min_memory=2 * (2**30), + default_cuda_device=_device, ) self.batch_size = cfg.TEST.BATCH_SIZE diff --git a/CerebNet/run_prediction.py b/CerebNet/run_prediction.py index 25ce1a974..0866afd92 100644 --- a/CerebNet/run_prediction.py +++ b/CerebNet/run_prediction.py @@ -168,9 +168,7 @@ def main(args: argparse.Namespace) -> int | str: get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, urls=urls) # Check input and output options and get all subjects of interest - subjects = SubjectList( - args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs, - ) + subjects = SubjectList(args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs) try: tester = Inference( diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index f044c66cb..87cc07b5e 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -16,6 +16,7 @@ # IMPORTS import argparse +import re import sys from collections.abc import Callable, Iterable, Sequence from typing import TYPE_CHECKING, Literal, TypeVar, cast @@ -35,7 +36,7 @@ class Tensor: pass from FastSurferCNN.utils import logging -from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, VoxSizeOption +from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm from FastSurferCNN.utils.arg_types import img_size as __img_size from FastSurferCNN.utils.arg_types import orientation as __orientation @@ -260,7 +261,7 @@ def options_parse(): def to_target_orientation( image_data: _TA, source_affine: npt.NDArray[float], - target_orientation: str, + target_orientation: StrictOrientationType, ) -> tuple[_TA, Callable[[_TB], _TB]]: """ Reorder and flip image_data such that the data is in orientation. This will always be without interpolation. @@ -271,7 +272,7 @@ def to_target_orientation( The image data to reorder/flip. source_affine : npt.NDArray[float] The affine to detect the reorientation operations. - target_orientation : str + target_orientation : StrictOrientationType The target orientation to reorient to. Returns @@ -281,13 +282,7 @@ def to_target_orientation( Callable[[np.ndarray], np.ndarray], Callable[[torch.Tensor], torch.Tensor] A function that flips and reorders the data back (returns same type as output). """ - from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform - - source_ornt = io_orientation(source_affine) - target_ornt = axcodes2ornt(target_orientation) - - reorient_ornt = ornt_transform(source_ornt, target_ornt) - unorient_ornt = ornt_transform(target_ornt, source_ornt) + reorient_ornt, unorient_ornt = orientation_to_ornts(source_affine, target_orientation) if np.any([reorient_ornt[:, 1] != 1, reorient_ornt[:, 0] != np.arange(reorient_ornt.shape[0])]): # is not lia yet def back_to_native(data: _TB) -> _TB: @@ -301,6 +296,36 @@ def do_nothing(data: _TB) -> _TB: return image_data, do_nothing +def orientation_to_ornts( + source_affine: npt.NDArray[float], + target_orientation: StrictOrientationType, +) -> tuple[npt.NDArray[int], npt.NDArray[int]]: + """ + Determine the nibabel `ornt` Array to reorder and flip data from source_affine such that the data is in orientation. + + Parameters + ---------- + source_affine : npt.NDArray[float] + The affine to detect the reorientation operations. + target_orientation : StrictOrientationType + The target orientation to reorient to. + + Returns + ------- + npt.NDArray[int] + The `ornt` transform from source_affine to target_orientation. + npt.NDArray[int] + The `ornt` transform back from target_orientation to source_affine. + """ + from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform + + source_ornt = io_orientation(source_affine) + target_ornt = axcodes2ornt(target_orientation.upper()) + reorient_ornt = ornt_transform(source_ornt, target_ornt) + unorient_ornt = ornt_transform(target_ornt, source_ornt) + return reorient_ornt.astype(int), unorient_ornt.astype(int) + + def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB: """ Apply transformations implied by `ornt` to the first n axes of the array `arr`. @@ -797,14 +822,11 @@ def prepare_mgh_header( re_order_axes = [0, 1, 2] rot_scale_mat = img.affine[:3, :3] else: - in_ornt = nib.orientations.io_orientation(img.affine) - out_ornt = nib.orientations.axcodes2ornt(orientation[-3:].upper()) - LOGGER.debug(f"{nib.orientations.ornt2axcodes(in_ornt)} => {nib.orientations.ornt2axcodes(out_ornt)}") - _ornt_transform = nib.orientations.ornt_transform(in_ornt, out_ornt) - LOGGER.debug(str(_ornt_transform)) - re_order_axes = _ornt_transform[:, 0].astype(int) + _ornt_transform, _ = orientation_to_ornts(img.affine, orientation[-3:]) + re_order_axes = _ornt_transform[:, 0] if len(orientation) == 3: # lia, ras, etc # this is a 3x3 matrix + out_ornt = nib.orientations.axcodes2ornt(orientation[-3:].upper()) rot_scale_mat = nib.orientations.inv_ornt_aff(out_ornt, source_img_shape)[:3, :3] else: # soft lia, ras, .... aff = _ornt_transform[:, 1][None] * img.affine[:3, :3] @@ -943,43 +965,49 @@ def is_conform( if len(img.shape) > 3 and img.shape[3] != 1: raise ValueError(f"Multiple input frames ({img.shape[3]}) not supported!") - checks = {"Number of Dimensions 3": (len(img.shape) == 3, f"image ndim {img.ndim}")} + checks: dict[str, tuple[bool | Literal["IGNORED"], str]] = { + "Number of Dimensions 3": (img.ndim == 3, f"image ndim {img.ndim}") + } + # check dimensions - if img_size is not None and _img_size is not None: - # if not isinstance(_img_size, np.ndarray): - # raise TypeError("_img_size should be numpy.ndarray here") - img_size_criteria = f"Dimensions {'x'.join(map(str, _img_size[:3]))}" - is_correct_img_size = np.array_equal(np.asarray(img.shape[:3]), _img_size) - checks[img_size_criteria] = is_correct_img_size, f"image dimensions {img.shape}" + img_size_text = f"image dimensions {img.shape}" + if img_size in (None, "fov") or _img_size is None: + img_size_criteria = f"Dimensions {img_size}" + checks[img_size_criteria] = "IGNORED", img_size_text + else: + img_size_criteria = f"Dimensions {img_size}={'x'.join(map(str, _img_size[:3]))}" + checks[img_size_criteria] = np.array_equal(np.asarray(img.shape[:3]), _img_size), img_size_text # check voxel size, drop voxel sizes of dimension 4 if available izoom = np.array(img.header.get_zooms()) - if _vox_size is not None: + vox_size_text = f"image {'x'.join(map(str, izoom))}" + if _vox_size is None: + checks[f"Voxel Size {vox_size}"] = "IGNORED", vox_size_text + else: if not isinstance(_vox_size, np.ndarray): raise TypeError("_vox_size should be numpy.ndarray here") - is_correct_vox_size = np.allclose(izoom[:3], _vox_size, atol=vox_eps, rtol=0) - vox_size_criteria = f"Voxel Size {'x'.join(map(str, _vox_size))}" - checks[vox_size_criteria] = (is_correct_vox_size, f"image {'x'.join(map(str, izoom))}") + vox_size_criteria = f"Voxel Size {vox_size}={'x'.join(map(str, _vox_size))}" + checks[vox_size_criteria] = np.allclose(izoom[:3], _vox_size, atol=vox_eps, rtol=0), vox_size_text # check orientation LIA - if orientation is not None and orientation != "native": + affcode = "".join(nib.orientations.aff2axcodes(img.affine)) + with np.printoptions(precision=2, suppress=True): + orientation_text = "affine=" + re.sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}" + if orientation is None or orientation == "native": + checks[f"Orientation {orientation}"] = "IGNORED", orientation_text + else: is_soft = not orientation.startswith("soft") - if is_correct_orientation := is_orientation(img.affine, orientation[-3:], is_soft, eps): - orientation_text = orientation - else: - from re import sub - - affcode = "".join(nib.orientations.aff2axcodes(img.affine)) - with np.printoptions(precision=2, suppress=True): - orientation_text = "affine: " + sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}" - checks[f"Orientation {orientation.upper()}"] = (is_correct_orientation, orientation_text) + is_correct_orientation = is_orientation(img.affine, orientation[-3:], is_soft, eps) + checks[f"Orientation {orientation.upper()}"] = is_correct_orientation, orientation_text # check dtype uchar - if dtype is not None: + dtype_text = f"dtype {img.get_data_dtype().name}" + if dtype is None: + checks["Dtype None"] = "IGNORED", dtype_text + else: _dtype: npt.DTypeLike = to_dtype(dtype) _dtype_name = _dtype.name if hasattr(_dtype, "name") else str(dtype) - is_correct_dtype = np.issubdtype(img.get_data_dtype(), _dtype) - checks[f"Dtype {_dtype_name}"] = (is_correct_dtype, f"dtype {img.get_data_dtype().name}") + checks[f"Dtype {_dtype_name}"] = np.issubdtype(img.get_data_dtype(), _dtype), dtype_text _is_conform = all(map(lambda x: x[0], checks.values())) @@ -994,10 +1022,12 @@ def is_conform( conform_str = f"{np.round(_vox_size[0], decimals=2):.2f}-" else: with np.printoptions(precision=2, suppress=True): - conform_str = f"{str(_vox_size)}-" + conform_str = str(_vox_size) + "-" logger.info(f"A {conform_str}conformed image must satisfy the following criteria:") for condition, (value, message) in checks.items(): - logger.info(f" - {condition:<30}: {value if value else 'BUT ' + message}") + if isinstance(value, bool): + value = "GOOD" if value else "BUT" + logger.info(f" - {condition:<30}: {value} {message}") return _is_conform diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index e40487b54..f2c5a0272 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -14,7 +14,7 @@ # IMPORTS import time -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Optional import h5py @@ -36,10 +36,12 @@ class MultiScaleOrigDataThickSlices(Dataset): Load MRI-Image and process it to correct format for network inference. """ + zoom : npt.NDArray[float] + def __init__( self, orig_data: npt.NDArray, - orig_zoom: npt.NDArray, + orig_zoom: npt.NDArray[float] | Sequence[float], cfg: yacs.config.CfgNode, transforms: Callable[[npt.NDArray[float]], npt.NDArray[float]] | None = None, ): @@ -65,16 +67,16 @@ def __init__( if self.plane == "sagittal": orig_data = du.transform_sagittal(orig_data) - self.zoom = orig_zoom[::-1][:2] + self.zoom = np.asarray(orig_zoom)[[2, 1]] logger.info(f"Loading Sagittal with input voxelsize {self.zoom}") elif self.plane == "axial": orig_data = du.transform_axial(orig_data) - self.zoom = orig_zoom[::-1][:2] + self.zoom = np.asarray(orig_zoom)[[2, 0]] logger.info(f"Loading Axial with input voxelsize {self.zoom}") else: - self.zoom = orig_zoom[:2] + self.zoom = np.asarray(orig_zoom)[[0, 1]] logger.info(f"Loading Coronal with input voxelsize {self.zoom}") # Create thick slices diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index dbe37d6fc..947207b79 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -91,9 +91,7 @@ def target_shape(self, target_shape: _T.Sequence[int] | None) -> None: """ Validate and set the target_shape. """ - tup_target_shape = ( - tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple() - ) + tup_target_shape = tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple() if tup_target_shape != self._target_shape: LOGGER.debug( f"Changing the target_shape of {type(self).__name__} to {tup_target_shape} from {self._target_shape}." diff --git a/FastSurferCNN/models/networks.py b/FastSurferCNN/models/networks.py index 2ef064ec8..ce974628e 100644 --- a/FastSurferCNN/models/networks.py +++ b/FastSurferCNN/models/networks.py @@ -361,18 +361,12 @@ def forward( if scale_factor_out is None: scale_factor_out = rescale_factor else: - scale_factor_out = ( - np.asarray(scale_factor_out) - * np.asarray(rescale_factor) - / np.asarray(scale_factor) - ) + scale_factor_out = np.asarray(scale_factor_out) * np.asarray(rescale_factor) / np.asarray(scale_factor) prior_target_shape = self.interpol2.target_shape self.interpol2.target_shape = skip_encoder_0.shape[2:] try: - decoder_output0, sf = self.interpol2( - decoder_output1, scale_factor_out, rescale=True - ) + decoder_output0, sf = self.interpol2(decoder_output1, scale_factor_out, rescale=True) finally: self.interpol2.target_shape = prior_target_shape outblock = self.outp_block(decoder_output0, skip_encoder_0) diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index 48bac2331..1b18158b2 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -41,7 +41,7 @@ import FastSurferCNN.reduce_to_aseg as rta from FastSurferCNN.data_loader import data_utils as du -from FastSurferCNN.data_loader.conform import conform, is_conform, to_target_orientation +from FastSurferCNN.data_loader.conform import conform, is_conform, orientation_to_ornts, to_target_orientation from FastSurferCNN.inference import Inference from FastSurferCNN.quick_qc import check_volume from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults @@ -237,7 +237,12 @@ def __init__( self.viewagg_device = self.device else: # check, if GPU is big enough to run view agg on it (this currently takes the memory of the passed device) - self.viewagg_device = find_device(viewagg_device, flag_name="viewagg_device", min_memory=4 * (2**30)) + self.viewagg_device = find_device( + viewagg_device, + flag_name="viewagg_device", + min_memory=4 * (2**30), + default_cuda_device=self.device, + ) LOGGER.info(f"Running view aggregation on {self.viewagg_device}") @@ -384,6 +389,8 @@ def get_prediction( orig_in_lia, back_to_native = to_target_orientation(orig_data, affine, target_orientation="LIA") shape = orig_in_lia.shape + (self.get_num_classes(),) + _ornt_transform, _ = orientation_to_ornts(affine, target_orientation="LIA") + _zoom = _zoom[_ornt_transform[:, 0]] pred_prob = torch.zeros(shape, **kwargs) @@ -392,7 +399,7 @@ def get_prediction( LOGGER.info(f"Run {plane} prediction") self.set_model(plane) # pred_prob is updated inplace to conserve memory - pred_prob = model.run(pred_prob, image_name, orig_in_lia, zoom, out=pred_prob) + pred_prob = model.run(pred_prob, image_name, orig_in_lia, _zoom, out=pred_prob) # Get hard predictions pred_classes = torch.argmax(pred_prob, 3) diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index b85c3d607..a2d1478c9 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -25,12 +25,14 @@ __axcode = ("rl", "ap", "si") __orders = tuple(permutations(range(3))) __flips = ((0, 1),) * 3 -__axcodes = ["".join(__axcode[ii[i]][j] for i, j in enumerate(jj)) for ii, jj in product(__orders, product(*__flips))] -VALID_ORIENTATIONS = ["native", *map(lambda x: "soft " + x, __axcodes), *__axcodes] +ORIENTATIONS = ["".join(__axcode[ii[i]][j] for i, j in enumerate(k)) for ii, k in product(__orders, product(*__flips))] +VALID_ORIENTATIONS = ["native", *map(lambda x: "soft " + x, ORIENTATIONS), *ORIENTATIONS] +StrictOrientationType = str OrientationType = str # future better typing, requires Python 3.11 (Syntax Error before that) # OrientationType = Literal[*VALID_ORIENTATIONS] +# StrictOrientationType = Literal[*ORIENTATIONS] diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index 2e39f7840..3b9a826f8 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -48,6 +48,7 @@ def find_device( device: torch.device | str = "auto", flag_name: str = "device", min_memory: int = 0, + default_cuda_device: torch.device | str = "cuda", ) -> torch.device: """ Create a device object from the device string passed. @@ -56,14 +57,14 @@ def find_device( Parameters ---------- - device : torch.device, str - The device to search for and test following pytorch device naming - conventions, e.g. 'cuda:0', 'cpu', etc. (default: 'auto'). + device : torch.device, str, default="auto" + The device to search for and test following pytorch device naming conventions, e.g. 'cuda:0', 'cpu', etc. flag_name : str Name of the corresponding flag for error messages (default: 'device'). min_memory : int - The minimum memory in bytes required for cuda-devices to - be valid (default: 0, works always). + The minimum memory in bytes required for cuda-devices to be valid (default: 0, works always). + default_cuda_device : str, torch.device, default="cuda" + Default cuda device to use, if cuda is available and device is "auto". Returns ------- @@ -85,20 +86,18 @@ def find_device( # If auto detect: if str(device) == "auto" or not device: # 1st check cuda / also finds AMD ROCm, then mps, finally cpu - device = "cuda" if has_cuda else "mps" if has_mps else "cpu" + device = default_cuda_device if has_cuda else "mps" if has_mps else "cpu" device = torch.device(device) if device.type == "cuda" and min_memory > 0: dev_num = torch.cuda.current_device() if device.index is None else device.index - total_gpu_memory = torch.cuda.get_device_properties(dev_num).__getattribute__( - "total_memory" - ) + total_gpu_memory = torch.cuda.get_device_properties(dev_num).__getattribute__("total_memory") if total_gpu_memory < min_memory: giga = 1024**3 - logger.info( - f"Found {total_gpu_memory/giga:.1f} GB GPU memory, but " - f"{min_memory/giga:.1f} GB was required." + logger.warning( + f"Found {total_gpu_memory/giga:.1f} GB GPU memory on device {device}, but {min_memory/giga:.1f} GB was " + f"required. Falling back to {flag_name} cpu." ) device = torch.device("cpu") diff --git a/HypVINN/inference.py b/HypVINN/inference.py index 5255fcff2..017f880e2 100644 --- a/HypVINN/inference.py +++ b/HypVINN/inference.py @@ -98,12 +98,11 @@ def __init__( else: # check, if GPU is big enough to run view agg on it # (this currently takes the memory of the passed device) - self.viewagg_device = torch.device( - find_device( - viewagg_device, - flag_name="viewagg_device", - min_memory=4 * (2 ** 30), - ) + self.viewagg_device = find_device( + viewagg_device, + flag_name="viewagg_device", + min_memory=4 * (2 ** 30), + default_cuda_device=self.device, ) logger.info(f"Running view aggregation on {self.viewagg_device}") diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py index ef4811307..79fe56f12 100644 --- a/HypVINN/run_prediction.py +++ b/HypVINN/run_prediction.py @@ -568,9 +568,7 @@ def get_prediction( # Solution: make this script/function more similar to the optimized FastSurferVINN device, viewagg_device = model.get_device() - h, w, d = target_shape - - pred_shape = (h,w,d, model.get_num_classes()) + pred_shape = tuple(target_shape) + (model.get_num_classes(),) # Set up tensor to hold probabilities and run inference pred_prob = torch.zeros(pred_shape, dtype=torch.float, device=viewagg_device) for plane, opts in view_opts.items(): diff --git a/brun_fastsurfer.sh b/brun_fastsurfer.sh index c7a896aa5..7bbc9e845 100755 --- a/brun_fastsurfer.sh +++ b/brun_fastsurfer.sh @@ -540,6 +540,8 @@ function process_by_token() if [[ "$mode" == "surf" ]] ; then max_processes="$num_parallel_surf" ; else max_processes="$num_parallel_seg" ; fi while [[ "$read_in" == 1 ]] || [[ "${#subject_buffer[@]}" -gt 0 ]] do + # initialize res_args + res_args=() if [[ "$read_in" == 1 ]] then IFS="" @@ -578,20 +580,20 @@ function process_by_token() for name in "${device[@]}" ; do i="$(device2number "$name")" if [[ -z "${used_device[i]}" ]] || [[ -z "$(ps --no-headers "${used_device[i]}")" ]] ; then - res_args=("--device" "$name") ; used_device[i]=""; device_ready=$((device_ready + 1)) ; dev="$name" ; break + res_args+=("--device" "$name") ; used_device[i]=""; device_ready=$((device_ready + 1)) ; dev="$name" ; break fi done - else device_ready=$((device_ready + 1)) ; res_args=("--device" "${device[0]}") + else device_ready=$((device_ready + 1)) ; res_args+=("--device" "${device[0]}") ; dev="${device[0]}" fi if [[ "${#vdevice[@]}" -gt 1 ]] ; then # go through viewagg device assignments, if the processes finished, release the device assignment for name in "${vdevice[@]}" ; do i="$(device2number "$name")" if [[ -z "${used_vdevice[i]}" ]] || [[ -z "$(ps --no-headers "${used_vdevice[i]}")" ]] ; then - res_args=("--viewagg_device" "$name") ; used_vdevice[i]="" ; device_ready=$((device_ready + 1)) ; vdev="$name" ; break + res_args+=("--viewagg_device" "$name") ; used_vdevice[i]="" ; device_ready=$((device_ready + 1)) ; vdev="$name" ; break fi done - else device_ready=$((device_ready + 1)) ; res_args=("--viewagg_device" "${vdevice[0]}") + else device_ready=$((device_ready + 1)) ; res_args+=("--viewagg_device" "${vdevice[0]}") ; vdev="${vdevice[0]}" fi fi if [[ "$device_ready" -gt 1 ]] From d845922f38b42119b160f448566a6e41c8e20615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20K=C3=BCgler?= Date: Mon, 19 May 2025 18:55:44 +0200 Subject: [PATCH 3/3] Fix anisotropic images warning texts Remove anisotropy TODOs --- FastSurferCNN/data_loader/dataset.py | 13 ++----------- FastSurferCNN/run_prediction.py | 7 ++----- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index f2c5a0272..0c0140b6d 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -91,8 +91,6 @@ def _get_scale_factor(self) -> npt.NDArray[float]: Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. Returns ------- @@ -238,9 +236,7 @@ def _get_scale_factor( Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. - - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. + Parameters ---------- @@ -260,9 +256,7 @@ def _get_scale_factor( scale = self.base_res / img_zoom if self.gn_noise: - scale += ( - torch.randn(1) * 0.1 + 0 - ) # needs to be changed to torch.tensor stuff + scale += torch.randn(1) * 0.1 + 0 # needs to be changed to torch.tensor stuff scale = torch.clamp(scale, min=0.1) return scale @@ -470,9 +464,6 @@ def _get_scale_factor(self, img_zoom): Input resolution is taken from voxel size in image header. - ToDO: This needs to be updated based on the plane we are looking at in case we - are dealing with non-isotropic images as inputs. - Parameters ---------- img_zoom : np.ndarray diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index 1b18158b2..1cdc63fa2 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -324,11 +324,8 @@ def conform_and_save_orig( if not is_conform(orig, **self.__conform_kwargs(verbose=True)): if (self.orientation is None or self.orientation == "native") and \ - not is_conform(orig, **self.__conform_kwargs(verbose=False, dtype=None)): - raise RuntimeError( - f"To store images in native image space, the input image {subject.orig_name} must have isotropic " - f"voxels." - ) + not is_conform(orig, **self.__conform_kwargs(verbose=False, dtype=None, vox_size="min")): + LOGGER.warning("Support for anisotropic voxels is experimental. Careful QC of all images is needed!") LOGGER.info("Conforming image...") orig = conform(orig, **self.__conform_kwargs()) orig_data = np.asanyarray(orig.dataobj)