Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions CerebNet/data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# IMPORTS
from numbers import Number
from typing import Generic, Literal, TypedDict, TypeVar
from typing import Literal, TypedDict, TypeVar, cast

import h5py
import numpy as np
Expand All @@ -28,13 +28,13 @@
from CerebNet.datasets.utils import bounding_volume_offset
from FastSurferCNN.data_loader.conform import Reorientation, crop_transform
from FastSurferCNN.data_loader.data_utils import get_thick_slices, transform_axial, transform_sagittal
from FastSurferCNN.utils import AffineMatrix4x4, Mask3d, Plane, Shape3d, ShapeType, logging, nibabelImage
from FastSurferCNN.utils import AffineMatrix4x4, Mask3d, Plane, Shape3d, logging, nibabelImage

ROIKeys = Literal["source_shape", "offsets", "target_shape"]
class LocalizerROI(TypedDict, Generic[ShapeType]):
source_shape: ShapeType
offsets: ShapeType
target_shape: ShapeType
class LocalizerROI(TypedDict):
source_shape: Shape3d
offsets: Shape3d
target_shape: Shape3d

NT = TypeVar("NT", bound=Number)

Expand Down Expand Up @@ -231,7 +231,7 @@ class SubjectDataset(Dataset):

"""

roi = LocalizerROI[Shape3d]
roi: LocalizerROI

def __init__(
self,
Expand Down Expand Up @@ -269,14 +269,14 @@ def __init__(
bbox = self.locate_mask_bbox(cereb_aseg_mask)

# create the roi from cereb_aseg (where labels after interpolation > 0.05 --> membership rounded to 1 decimal)
self.roi: LocalizerROI = {
"source_shape": img_org.shape,
"offsets": bounding_volume_offset(bbox, patch_size, image_shape=cereb_aseg_mask.shape),
"target_shape": patch_size,
}
self.roi = LocalizerROI(
source_shape=cast(Shape3d, img_org.shape),
offsets=cast(Shape3d, bounding_volume_offset(bbox, patch_size, image_shape=cereb_aseg_mask.shape)),
target_shape=cast(Shape3d, patch_size),
)
# crop the region of interest
img = crop_transform(self.img_org_data, offsets=self.roi["offsets"], target_shape=self.roi["target_shape"])
patch_vox2vox = np.concatenate([np.eye(4)[:3], np.append(self.roi["offsets"], 1)[None]], axis=0)
patch_vox2vox = np.concatenate([np.eye(4)[:, :3], np.append(self.roi["offsets"], 1)[:, None]], axis=1)
patch_vox2ras = self.img_org.affine @ patch_vox2vox
# reorient the data to lia
self.native_to_lia = Reorientation.from_target_orientation(patch_vox2ras, "soft LIA", self.roi["target_shape"])
Expand Down Expand Up @@ -320,7 +320,7 @@ def locate_mask_bbox(self, mask: Mask3d) -> tuple[int, int, int, int, int, int]:
def get_nibabel_img(self):
return self.img_org

def get_bounding_offsets(self) -> LocalizerROI[Shape3d]:
def get_bounding_offsets(self) -> LocalizerROI:
return self.roi

def set_plane(self, plane: Plane):
Expand Down
38 changes: 24 additions & 14 deletions FastSurferCNN/data_loader/conform.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,22 @@ def from_target_orientation(
# by setting only the 3x3 rotational part here, we force from_target_affine to determine the translation
# (as center-conserving)
_reorder_ornt = axcodes2ornt(_target_orientation, AXCODES)
target_strict_affine: AffineMatrix3x3 = ornt2vox2vox(_reorder_ornt, (1,) * 3, vox_size)[:3, :3]

# strict version of the source vox2ras so can generate a soft transform
_source_affine = ornt2vox2vox(io_orientation(source_affine), shape,source_vox_size)

# first, target affine without voxelsize to determine the reordering of voxel sizes
target_strict_affine: AffineMatrix3x3 = ornt2vox2vox(_reorder_ornt, (1,) * 3, )[:3, :3]
matrix = np.pad(np.linalg.inv(_source_affine[:3, :3]) @ target_strict_affine, ((0, 1), (0, 1)))
reorder = io_orientation(matrix)

vox_size_in_target = vox_size[reorder.astype(np.int16)[:, 0]]

# second run, now with correct ordering of output voxel sizes in vox2ras
target_strict_affine: AffineMatrix3x3 = ornt2vox2vox(_reorder_ornt, (1,) * 3, vox_size_in_target)[:3, :3]
if not is_soft:
return cls.from_target_affine(source_affine, target_strict_affine, shape, target_shape, tol)

# soft transform
_source_affine = ornt2vox2vox(io_orientation(source_affine), shape,source_vox_size)
rot_mat = np.linalg.inv(_source_affine[:3, :3]) @ target_strict_affine

if np.allclose(rot_mat, np.round(rot_mat), atol=tol):
Expand Down Expand Up @@ -559,7 +569,7 @@ def inverse(self: SelfReorientation) -> SelfReorientation:
self.target_affine,
np.linalg.inv(self.vox2vox),
self.target_shape,
self.source_shape,
self.reorder_axes(self.source_shape),
self.tol,
)

Expand Down Expand Up @@ -948,8 +958,8 @@ def rescale(
def conform(
img: nibabelImage,
order: int = 1,
vox_size: VoxSizeOption | None = 1.0,
img_size: ImageSizeOption | None = 256,
vox_size: VoxSizeOption = 1.0,
img_size: ImageSizeOption = 256,
dtype: npt.DTypeLike | None = np.uint8,
orientation: OrientationType | None = "lia",
threshold_1mm: float | None = None,
Expand Down Expand Up @@ -1234,8 +1244,8 @@ def isclose(x, y, eps):

def is_conform(
img: nibabelImage,
vox_size: VoxSizeOption | None = 1.0,
img_size: ImageSizeOption | None = 256,
vox_size: VoxSizeOption = 1.0,
img_size: ImageSizeOption = 256,
dtype: npt.DTypeLike | None = np.uint8,
orientation: OrientationType | None = "lia",
verbose: bool = True,
Expand Down Expand Up @@ -1437,8 +1447,8 @@ def is_orientation(

def conformed_vox_img_size(
img: nibabelImage,
vox_size: VoxSizeOption | None,
img_size: ImageSizeOption | None,
vox_size: VoxSizeOption,
img_size: ImageSizeOption,
threshold_1mm: float | None = None,
vox_eps: float = 1e-4,
**kwargs,
Expand Down Expand Up @@ -1483,7 +1493,7 @@ def conformed_vox_img_size(
target_img_size: IntVector3d | None
MAX_VOX_SIZE = 1.0
MAX_DIMENSION = 256
# this is similar to mri_convert --conform_min
# this is similar to mri_convert --conform_min, note, vox_size == 'auto' is extra, but not covered by VoxSizeOption
if isinstance(vox_size, str) and (vox_size := cast(VoxSizeOption, vox_size.lower())) in ["min", "auto"]:
# find minimal voxel side length
min_vox_size = np.round(np.min(img.header.get_zooms()[:3]), decimals=int(np.ceil(-np.log10(vox_eps))))
Expand All @@ -1498,7 +1508,7 @@ def conformed_vox_img_size(
elif vox_size is None:
target_vox_size = None
else:
raise ValueError("Invalid value for vox_size passed.")
raise ValueError(f"Invalid value for vox_size passed: {vox_size}.")
if img_size is None and target_vox_size is not None:
# if we did specify a vox_size, no image size. use the field of view (which is essentially the old image size
# scaled with the voxel size)
Expand Down Expand Up @@ -1843,8 +1853,8 @@ class _OptKwargs(TypedDict, total=False):
threshold_1mm: float

class OptKwargs(_OptKwargs):
vox_size: VoxSizeOption | None
img_size: ImageSizeOption | None
vox_size: VoxSizeOption
img_size: ImageSizeOption
dtype: npt.DTypeLike | None
orientation: OrientationType | None
verbose: bool
Expand Down
16 changes: 8 additions & 8 deletions FastSurferCNN/utils/arg_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import nibabel as nib
import numpy as np

VoxSizeOption = float | Literal["min"]
ImageSizeOption = int | Literal["fov", "auto"]
VoxSizeOption = float | Literal["min"] | None
ImageSizeOption = int | Literal["fov", "auto"] | None

__axcode = ("rl", "ap", "si")
__orders = tuple(permutations(range(3)))
Expand Down Expand Up @@ -81,7 +81,7 @@ def string_to_bool(a: str) -> bool:
return bool(a)
return a.lower() in ("on", "true", "yes", "y", "1")

def vox_size(a: str | float | None) -> VoxSizeOption | None:
def vox_size(a: str | float | None) -> VoxSizeOption:
"""
Convert the vox_size argument to 'min' or a valid voxel size.

Expand All @@ -95,14 +95,14 @@ def vox_size(a: str | float | None) -> VoxSizeOption | None:
str or float or None
If 'auto' or 'min' is provided, it returns a string('auto' or 'min').
If a valid voxel size (between 0 and 1) is provided, it returns a float.
If 'any', it returns None.
If 'any' or 'keep', it returns None.

Raises
------
ValueError
If the argument is not "min", "auto" or convertible to a float between 0 and 1.
"""
if a is None or isinstance(a, str) and a.lower() == "any":
if a is None or isinstance(a, str) and a.lower() in ["any", "keep"]:
return None
if isinstance(a, str) and a.lower() in ["auto", "min"]:
return "min"
Expand All @@ -111,7 +111,7 @@ def vox_size(a: str | float | None) -> VoxSizeOption | None:
except ValueError as e:
raise ValueError(e.args[0] + " Additionally, vox_size may be 'min'.") from None

def img_size(a: str) -> ImageSizeOption | None:
def img_size(a: str) -> ImageSizeOption:
"""
Convert the img_size argument to 'fov', 'auto' or int as a valid image size.

Expand All @@ -125,7 +125,7 @@ def img_size(a: str) -> ImageSizeOption | None:
str or int
If 'auto' or 'fov' is provided, it returns a string('auto' or 'fov').
If a valid image size (greater than 0) is provided, it returns an int.
If 'any', it returns None.
If 'any' or 'keep', it returns None.

Raises
------
Expand All @@ -134,7 +134,7 @@ def img_size(a: str) -> ImageSizeOption | None:
"""
if a.lower() in ("auto", "fov"):
return cast(ImageSizeOption, a.lower())
if a.lower() == "any":
if a.lower() in ("any", "keep"):
return None
try:
return int_gt_zero(a)
Expand Down
27 changes: 1 addition & 26 deletions HypVINN/data_loader/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# IMPORTS
import numpy as np

from FastSurferCNN.data_loader.conform import AXCODES, conform, getscale, scalecrop
from FastSurferCNN.data_loader.conform import getscale, scalecrop
from FastSurferCNN.utils import ShapeType
from HypVINN.config.hypvinn_global_var import FS_CLASS_NAMES, HYPVINN_CLASS_NAMES, SAG2FULL_MAP, hyposubseg_labels

Expand All @@ -25,31 +25,6 @@
##


def reorient_img(img, ref_img):
"""
Reorient a Nibabel image based on the orientation of a reference nibabel image.

Parameters
----------
img : nibabel.Nifti1Image
Nibabel Image to reorient.
ref_img : nibabel.Nifti1Image
Reference orientation nibabel image.

Returns
-------
img : nibabel.Nifti1Image
Reoriented image.
"""
# if the affines are the same, no reorientation is required and we can skip this
if np.array_equal(ref_img.affine, img.affine):
return img
from nibabel.orientations import aff2axcodes
target_orientation = "soft " + "".join(aff2axcodes(ref_img.affine, AXCODES))
# returns the same class as img
return conform(img, orientation=target_orientation, vox_size=None, img_size=None, dtype=None, rescale=None)


def transform_axial2coronal(vol: np.ndarray, axial2coronal: bool = True) -> np.ndarray:
"""
Transforms a volume into the coronal axis and back.
Expand Down
49 changes: 34 additions & 15 deletions HypVINN/utils/img_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

import nibabel as nib
import numpy as np
from nibabel import Nifti1Image
from nibabel.orientations import aff2axcodes
from scipy import ndimage
from skimage.measure import label

import FastSurferCNN.utils.logging as logging
from FastSurferCNN.utils import AffineMatrix4x4, Image4d, nibabelImage
from FastSurferCNN.data_loader.conform import Reorientation, does_vox2vox_rot_require_interpolation
from FastSurferCNN.utils import AffineMatrix4x4, Image4d, nibabelHeader, nibabelImage
from HypVINN.data_loader.data_utils import hypo_map_subseg_2_fsseg

LOGGER = logging.get_logger(__name__)
Expand Down Expand Up @@ -70,24 +72,41 @@ def save_segmentation(
"""
from time import time
starttime = time()
from HypVINN.data_loader.data_utils import reorient_img

pred_arr, labels_cc = get_clean_labels(np.array(prediction, dtype=np.uint8))
# Mapped HypVINN labelst to FreeSurfer Hypvinn Labels
# Mapped HypVINN labels to FreeSurfer Hypvinn Labels
pred_arr = hypo_map_subseg_2_fsseg(pred_arr)
orig_img = cast(nibabelImage, nib.load(orig_path))

reorient = Reorientation.from_target_affine(ras_affine, orig_img.affine, labels_cc.shape)
LOGGER.info(f"Orig data orientation : {aff2axcodes(orig_img.affine)}")

for data, name in ((pred_arr, "segmentation"), (labels_cc, "mask")):
if not np.allclose(reorient.reorder_axes(np.asarray(data.shape)), orig_img.shape):
raise RuntimeError(f"Hypothalamus {name} and orig image have different shapes!")

if does_vox2vox_rot_require_interpolation(reorient.vox2vox):
LOGGER.warning("Hypothalamus mask and segmentation reorientation requires lossy interpolation.")

if save_mask:
mask_img = nib.Nifti1Image(labels_cc, affine=ras_affine, header=ras_header)
LOGGER.info(f"HypVINN Mask orientation: {aff2axcodes(mask_img.affine)}")
mask_img = reorient_img(mask_img, orig_img)
mask_header: nibabelHeader = Nifti1Image.header_class.from_header(orig_img.header)
mask_header.set_data_dtype(np.uint8)
mask_img = nib.Nifti1Image(
reorient(labels_cc.astype(np.uint8), order=0),
affine=orig_img.affine,
header=mask_header,
)
mask_img.set_data_dtype(np.float32)
LOGGER.info(f"HypVINN Mask after re-orientation: {aff2axcodes(mask_img.affine)}")
nib.save(mask_img, subject_dir / "mri" / mask_file)

pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header)
LOGGER.info(f"HypVINN Prediction orientation: {aff2axcodes(pred_img.affine)}")
pred_img = reorient_img(pred_img, orig_img)
pred_header: nibabelHeader = Nifti1Image.header_class.from_header(orig_img.header)
pred_header.set_data_dtype(np.uint8)
pred_img = nib.Nifti1Image(
reorient(pred_arr.astype(np.int16), order=0),
affine=orig_img.affine,
header=pred_header,
)
LOGGER.info(f"HypVINN Prediction after re-orientation: {aff2axcodes(pred_img.affine)}")
pred_img.set_data_dtype(np.int16) # Maximum value 984
nib.save(pred_img, subject_dir / "mri" / seg_file)
Expand Down Expand Up @@ -129,16 +148,16 @@ def save_logits(
The path where the logits were saved.

"""
from HypVINN.data_loader.data_utils import reorient_img
orig_img = cast(nibabelImage, nib.load(orig_path))
LOGGER.info(f"Orig data orientation: {aff2axcodes(orig_img.affine)}")
header: nibabelHeader = Nifti1Image.header_class.from_header(orig_img.header)
header.set_data_type(np.float32)
reorient = Reorientation.from_target_affine(ras_affine, orig_img.affine, logits.shape)
nifti_img = nib.Nifti1Image(
logits.astype(np.float32),
affine=ras_affine,
header=ras_header,
reorient(logits.astype(np.float32)),
affine=orig_img.affine,
header=header,
)
LOGGER.info(f"HypVINN logits orientation: {aff2axcodes(nifti_img.affine)}")
nifti_img = reorient_img(nifti_img, orig_img)
LOGGER.info(f"HypVINN logits after re-orientation: {aff2axcodes(nifti_img.affine)}")
nifti_img.set_data_dtype(np.float32)
save_as = save_dir / f"HypVINN_logits_{mode}.nii.gz"
Expand Down
Loading