diff --git a/CerebNet/data_loader/dataset.py b/CerebNet/data_loader/dataset.py index dacfc4743..91e9cb8d7 100644 --- a/CerebNet/data_loader/dataset.py +++ b/CerebNet/data_loader/dataset.py @@ -14,7 +14,7 @@ # IMPORTS from numbers import Number -from typing import Generic, Literal, TypedDict, TypeVar +from typing import Literal, TypedDict, TypeVar, cast import h5py import numpy as np @@ -28,13 +28,13 @@ from CerebNet.datasets.utils import bounding_volume_offset from FastSurferCNN.data_loader.conform import Reorientation, crop_transform from FastSurferCNN.data_loader.data_utils import get_thick_slices, transform_axial, transform_sagittal -from FastSurferCNN.utils import AffineMatrix4x4, Mask3d, Plane, Shape3d, ShapeType, logging, nibabelImage +from FastSurferCNN.utils import AffineMatrix4x4, Mask3d, Plane, Shape3d, logging, nibabelImage ROIKeys = Literal["source_shape", "offsets", "target_shape"] -class LocalizerROI(TypedDict, Generic[ShapeType]): - source_shape: ShapeType - offsets: ShapeType - target_shape: ShapeType +class LocalizerROI(TypedDict): + source_shape: Shape3d + offsets: Shape3d + target_shape: Shape3d NT = TypeVar("NT", bound=Number) @@ -231,7 +231,7 @@ class SubjectDataset(Dataset): """ - roi = LocalizerROI[Shape3d] + roi: LocalizerROI def __init__( self, @@ -269,14 +269,14 @@ def __init__( bbox = self.locate_mask_bbox(cereb_aseg_mask) # create the roi from cereb_aseg (where labels after interpolation > 0.05 --> membership rounded to 1 decimal) - self.roi: LocalizerROI = { - "source_shape": img_org.shape, - "offsets": bounding_volume_offset(bbox, patch_size, image_shape=cereb_aseg_mask.shape), - "target_shape": patch_size, - } + self.roi = LocalizerROI( + source_shape=cast(Shape3d, img_org.shape), + offsets=cast(Shape3d, bounding_volume_offset(bbox, patch_size, image_shape=cereb_aseg_mask.shape)), + target_shape=cast(Shape3d, patch_size), + ) # crop the region of interest img = crop_transform(self.img_org_data, offsets=self.roi["offsets"], target_shape=self.roi["target_shape"]) - patch_vox2vox = np.concatenate([np.eye(4)[:3], np.append(self.roi["offsets"], 1)[None]], axis=0) + patch_vox2vox = np.concatenate([np.eye(4)[:, :3], np.append(self.roi["offsets"], 1)[:, None]], axis=1) patch_vox2ras = self.img_org.affine @ patch_vox2vox # reorient the data to lia self.native_to_lia = Reorientation.from_target_orientation(patch_vox2ras, "soft LIA", self.roi["target_shape"]) @@ -320,7 +320,7 @@ def locate_mask_bbox(self, mask: Mask3d) -> tuple[int, int, int, int, int, int]: def get_nibabel_img(self): return self.img_org - def get_bounding_offsets(self) -> LocalizerROI[Shape3d]: + def get_bounding_offsets(self) -> LocalizerROI: return self.roi def set_plane(self, plane: Plane): diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index 1095ad52e..5792d73db 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -373,12 +373,22 @@ def from_target_orientation( # by setting only the 3x3 rotational part here, we force from_target_affine to determine the translation # (as center-conserving) _reorder_ornt = axcodes2ornt(_target_orientation, AXCODES) - target_strict_affine: AffineMatrix3x3 = ornt2vox2vox(_reorder_ornt, (1,) * 3, vox_size)[:3, :3] + + # strict version of the source vox2ras so can generate a soft transform + _source_affine = ornt2vox2vox(io_orientation(source_affine), shape,source_vox_size) + + # first, target affine without voxelsize to determine the reordering of voxel sizes + target_strict_affine: AffineMatrix3x3 = ornt2vox2vox(_reorder_ornt, (1,) * 3, )[:3, :3] + matrix = np.pad(np.linalg.inv(_source_affine[:3, :3]) @ target_strict_affine, ((0, 1), (0, 1))) + reorder = io_orientation(matrix) + + vox_size_in_target = vox_size[reorder.astype(np.int16)[:, 0]] + + # second run, now with correct ordering of output voxel sizes in vox2ras + target_strict_affine: AffineMatrix3x3 = ornt2vox2vox(_reorder_ornt, (1,) * 3, vox_size_in_target)[:3, :3] if not is_soft: return cls.from_target_affine(source_affine, target_strict_affine, shape, target_shape, tol) - # soft transform - _source_affine = ornt2vox2vox(io_orientation(source_affine), shape,source_vox_size) rot_mat = np.linalg.inv(_source_affine[:3, :3]) @ target_strict_affine if np.allclose(rot_mat, np.round(rot_mat), atol=tol): @@ -559,7 +569,7 @@ def inverse(self: SelfReorientation) -> SelfReorientation: self.target_affine, np.linalg.inv(self.vox2vox), self.target_shape, - self.source_shape, + self.reorder_axes(self.source_shape), self.tol, ) @@ -948,8 +958,8 @@ def rescale( def conform( img: nibabelImage, order: int = 1, - vox_size: VoxSizeOption | None = 1.0, - img_size: ImageSizeOption | None = 256, + vox_size: VoxSizeOption = 1.0, + img_size: ImageSizeOption = 256, dtype: npt.DTypeLike | None = np.uint8, orientation: OrientationType | None = "lia", threshold_1mm: float | None = None, @@ -1234,8 +1244,8 @@ def isclose(x, y, eps): def is_conform( img: nibabelImage, - vox_size: VoxSizeOption | None = 1.0, - img_size: ImageSizeOption | None = 256, + vox_size: VoxSizeOption = 1.0, + img_size: ImageSizeOption = 256, dtype: npt.DTypeLike | None = np.uint8, orientation: OrientationType | None = "lia", verbose: bool = True, @@ -1437,8 +1447,8 @@ def is_orientation( def conformed_vox_img_size( img: nibabelImage, - vox_size: VoxSizeOption | None, - img_size: ImageSizeOption | None, + vox_size: VoxSizeOption, + img_size: ImageSizeOption, threshold_1mm: float | None = None, vox_eps: float = 1e-4, **kwargs, @@ -1483,7 +1493,7 @@ def conformed_vox_img_size( target_img_size: IntVector3d | None MAX_VOX_SIZE = 1.0 MAX_DIMENSION = 256 - # this is similar to mri_convert --conform_min + # this is similar to mri_convert --conform_min, note, vox_size == 'auto' is extra, but not covered by VoxSizeOption if isinstance(vox_size, str) and (vox_size := cast(VoxSizeOption, vox_size.lower())) in ["min", "auto"]: # find minimal voxel side length min_vox_size = np.round(np.min(img.header.get_zooms()[:3]), decimals=int(np.ceil(-np.log10(vox_eps)))) @@ -1498,7 +1508,7 @@ def conformed_vox_img_size( elif vox_size is None: target_vox_size = None else: - raise ValueError("Invalid value for vox_size passed.") + raise ValueError(f"Invalid value for vox_size passed: {vox_size}.") if img_size is None and target_vox_size is not None: # if we did specify a vox_size, no image size. use the field of view (which is essentially the old image size # scaled with the voxel size) @@ -1843,8 +1853,8 @@ class _OptKwargs(TypedDict, total=False): threshold_1mm: float class OptKwargs(_OptKwargs): - vox_size: VoxSizeOption | None - img_size: ImageSizeOption | None + vox_size: VoxSizeOption + img_size: ImageSizeOption dtype: npt.DTypeLike | None orientation: OrientationType | None verbose: bool diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index ece07d9c2..5d4e9c584 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -19,8 +19,8 @@ import nibabel as nib import numpy as np -VoxSizeOption = float | Literal["min"] -ImageSizeOption = int | Literal["fov", "auto"] +VoxSizeOption = float | Literal["min"] | None +ImageSizeOption = int | Literal["fov", "auto"] | None __axcode = ("rl", "ap", "si") __orders = tuple(permutations(range(3))) @@ -81,7 +81,7 @@ def string_to_bool(a: str) -> bool: return bool(a) return a.lower() in ("on", "true", "yes", "y", "1") -def vox_size(a: str | float | None) -> VoxSizeOption | None: +def vox_size(a: str | float | None) -> VoxSizeOption: """ Convert the vox_size argument to 'min' or a valid voxel size. @@ -95,14 +95,14 @@ def vox_size(a: str | float | None) -> VoxSizeOption | None: str or float or None If 'auto' or 'min' is provided, it returns a string('auto' or 'min'). If a valid voxel size (between 0 and 1) is provided, it returns a float. - If 'any', it returns None. + If 'any' or 'keep', it returns None. Raises ------ ValueError If the argument is not "min", "auto" or convertible to a float between 0 and 1. """ - if a is None or isinstance(a, str) and a.lower() == "any": + if a is None or isinstance(a, str) and a.lower() in ["any", "keep"]: return None if isinstance(a, str) and a.lower() in ["auto", "min"]: return "min" @@ -111,7 +111,7 @@ def vox_size(a: str | float | None) -> VoxSizeOption | None: except ValueError as e: raise ValueError(e.args[0] + " Additionally, vox_size may be 'min'.") from None -def img_size(a: str) -> ImageSizeOption | None: +def img_size(a: str) -> ImageSizeOption: """ Convert the img_size argument to 'fov', 'auto' or int as a valid image size. @@ -125,7 +125,7 @@ def img_size(a: str) -> ImageSizeOption | None: str or int If 'auto' or 'fov' is provided, it returns a string('auto' or 'fov'). If a valid image size (greater than 0) is provided, it returns an int. - If 'any', it returns None. + If 'any' or 'keep', it returns None. Raises ------ @@ -134,7 +134,7 @@ def img_size(a: str) -> ImageSizeOption | None: """ if a.lower() in ("auto", "fov"): return cast(ImageSizeOption, a.lower()) - if a.lower() == "any": + if a.lower() in ("any", "keep"): return None try: return int_gt_zero(a) diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py index 096bbf9b1..fc65ab3e2 100644 --- a/HypVINN/data_loader/data_utils.py +++ b/HypVINN/data_loader/data_utils.py @@ -16,7 +16,7 @@ # IMPORTS import numpy as np -from FastSurferCNN.data_loader.conform import AXCODES, conform, getscale, scalecrop +from FastSurferCNN.data_loader.conform import getscale, scalecrop from FastSurferCNN.utils import ShapeType from HypVINN.config.hypvinn_global_var import FS_CLASS_NAMES, HYPVINN_CLASS_NAMES, SAG2FULL_MAP, hyposubseg_labels @@ -25,31 +25,6 @@ ## -def reorient_img(img, ref_img): - """ - Reorient a Nibabel image based on the orientation of a reference nibabel image. - - Parameters - ---------- - img : nibabel.Nifti1Image - Nibabel Image to reorient. - ref_img : nibabel.Nifti1Image - Reference orientation nibabel image. - - Returns - ------- - img : nibabel.Nifti1Image - Reoriented image. - """ - # if the affines are the same, no reorientation is required and we can skip this - if np.array_equal(ref_img.affine, img.affine): - return img - from nibabel.orientations import aff2axcodes - target_orientation = "soft " + "".join(aff2axcodes(ref_img.affine, AXCODES)) - # returns the same class as img - return conform(img, orientation=target_orientation, vox_size=None, img_size=None, dtype=None, rescale=None) - - def transform_axial2coronal(vol: np.ndarray, axial2coronal: bool = True) -> np.ndarray: """ Transforms a volume into the coronal axis and back. diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py index 725bb5d11..379f18d12 100644 --- a/HypVINN/utils/img_processing_utils.py +++ b/HypVINN/utils/img_processing_utils.py @@ -16,12 +16,14 @@ import nibabel as nib import numpy as np +from nibabel import Nifti1Image from nibabel.orientations import aff2axcodes from scipy import ndimage from skimage.measure import label import FastSurferCNN.utils.logging as logging -from FastSurferCNN.utils import AffineMatrix4x4, Image4d, nibabelImage +from FastSurferCNN.data_loader.conform import Reorientation, does_vox2vox_rot_require_interpolation +from FastSurferCNN.utils import AffineMatrix4x4, Image4d, nibabelHeader, nibabelImage from HypVINN.data_loader.data_utils import hypo_map_subseg_2_fsseg LOGGER = logging.get_logger(__name__) @@ -70,24 +72,41 @@ def save_segmentation( """ from time import time starttime = time() - from HypVINN.data_loader.data_utils import reorient_img pred_arr, labels_cc = get_clean_labels(np.array(prediction, dtype=np.uint8)) - # Mapped HypVINN labelst to FreeSurfer Hypvinn Labels + # Mapped HypVINN labels to FreeSurfer Hypvinn Labels pred_arr = hypo_map_subseg_2_fsseg(pred_arr) orig_img = cast(nibabelImage, nib.load(orig_path)) + + reorient = Reorientation.from_target_affine(ras_affine, orig_img.affine, labels_cc.shape) LOGGER.info(f"Orig data orientation : {aff2axcodes(orig_img.affine)}") + for data, name in ((pred_arr, "segmentation"), (labels_cc, "mask")): + if not np.allclose(reorient.reorder_axes(np.asarray(data.shape)), orig_img.shape): + raise RuntimeError(f"Hypothalamus {name} and orig image have different shapes!") + + if does_vox2vox_rot_require_interpolation(reorient.vox2vox): + LOGGER.warning("Hypothalamus mask and segmentation reorientation requires lossy interpolation.") + if save_mask: - mask_img = nib.Nifti1Image(labels_cc, affine=ras_affine, header=ras_header) - LOGGER.info(f"HypVINN Mask orientation: {aff2axcodes(mask_img.affine)}") - mask_img = reorient_img(mask_img, orig_img) + mask_header: nibabelHeader = Nifti1Image.header_class.from_header(orig_img.header) + mask_header.set_data_dtype(np.uint8) + mask_img = nib.Nifti1Image( + reorient(labels_cc.astype(np.uint8), order=0), + affine=orig_img.affine, + header=mask_header, + ) + mask_img.set_data_dtype(np.float32) LOGGER.info(f"HypVINN Mask after re-orientation: {aff2axcodes(mask_img.affine)}") nib.save(mask_img, subject_dir / "mri" / mask_file) - pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header) - LOGGER.info(f"HypVINN Prediction orientation: {aff2axcodes(pred_img.affine)}") - pred_img = reorient_img(pred_img, orig_img) + pred_header: nibabelHeader = Nifti1Image.header_class.from_header(orig_img.header) + pred_header.set_data_dtype(np.uint8) + pred_img = nib.Nifti1Image( + reorient(pred_arr.astype(np.int16), order=0), + affine=orig_img.affine, + header=pred_header, + ) LOGGER.info(f"HypVINN Prediction after re-orientation: {aff2axcodes(pred_img.affine)}") pred_img.set_data_dtype(np.int16) # Maximum value 984 nib.save(pred_img, subject_dir / "mri" / seg_file) @@ -129,16 +148,16 @@ def save_logits( The path where the logits were saved. """ - from HypVINN.data_loader.data_utils import reorient_img orig_img = cast(nibabelImage, nib.load(orig_path)) LOGGER.info(f"Orig data orientation: {aff2axcodes(orig_img.affine)}") + header: nibabelHeader = Nifti1Image.header_class.from_header(orig_img.header) + header.set_data_type(np.float32) + reorient = Reorientation.from_target_affine(ras_affine, orig_img.affine, logits.shape) nifti_img = nib.Nifti1Image( - logits.astype(np.float32), - affine=ras_affine, - header=ras_header, + reorient(logits.astype(np.float32)), + affine=orig_img.affine, + header=header, ) - LOGGER.info(f"HypVINN logits orientation: {aff2axcodes(nifti_img.affine)}") - nifti_img = reorient_img(nifti_img, orig_img) LOGGER.info(f"HypVINN logits after re-orientation: {aff2axcodes(nifti_img.affine)}") nifti_img.set_data_dtype(np.float32) save_as = save_dir / f"HypVINN_logits_{mode}.nii.gz"