Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions CerebNet/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from numpy import typing as npt

from FastSurferCNN.data_loader.conform import getscale, scalecrop
from FastSurferCNN.utils import logging

logger = logging.getLogger(__name__)

CLASS_NAMES = {
"Background": 0,
Expand Down Expand Up @@ -215,14 +218,12 @@ def bounding_volume_offset(
else None
)
if img_shape is not None:
offset = tuple(
min(max(0, o), imgs - ts)
for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False)
)
if any(o < 0 for o in offset):
raise RuntimeError(
f"Insufficient image size {img_shape} for target image size {target_img_size}"
)
# try to set the offset so the bounding volume is fully inside
_offset = list((max(0, o), imgs - ts) for o, ts, imgs in zip(offset, target_img_size, img_shape, strict=False))
# if it does not fit fully inside, warn
if any(min(left, right) < 0 for left, right in _offset):
logger.warning(f"The image is not large enough to cut a {target_img_size} patch, padding!")
offset = tuple(min(left, right) if min(left, right) >= 0 else int((left + right)/2) for left, right in _offset)
return offset


Expand Down
10 changes: 3 additions & 7 deletions CerebNet/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@
from FastSurferCNN.data_loader.conform import crop_transform
from FastSurferCNN.utils import PLANES, Plane, logging
from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType
from FastSurferCNN.utils.common import (
SerialExecutor,
SubjectDirectory,
SubjectList,
find_device,
)
from FastSurferCNN.utils.common import SerialExecutor, SubjectDirectory, SubjectList, find_device
from FastSurferCNN.utils.mapper import JsonColorLookupTable, Mapper, TSVLookupTable
from FastSurferCNN.utils.threads import get_num_threads

Expand Down Expand Up @@ -122,6 +117,7 @@ def __init__(
viewagg_device,
flag_name="viewagg_device",
min_memory=2 * (2**30),
default_cuda_device=_device,
)

self.batch_size = cfg.TEST.BATCH_SIZE
Expand Down Expand Up @@ -409,7 +405,7 @@ def _get_subject_dataset(
seg, seg_data = _seg.result()
conf_file, conf_img, conf_data = _conf_img.result()

if np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01):
if not np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01):
logger.warning(
"CerebNet does not support images that are not conformed to 1.0mm. We detected a voxel sizes of "
f"{tuple(conf_img.header.get_zooms())} in {conf_file}!"
Expand Down
8 changes: 3 additions & 5 deletions CerebNet/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def _vox_size(a):

advanced.add_argument(
"--vox_size",
choices=("1", "1.0", "none"),
choices=(1.0, None),
type=_vox_size,
default=1,
default=1.0,
dest="vox_size",
help="Choose the voxelsize to process, CerebNet only supports 1 or 'none' to ignore the voxelsize. ",
)
Expand Down Expand Up @@ -168,9 +168,7 @@ def main(args: argparse.Namespace) -> int | str:
get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, urls=urls)

# Check input and output options and get all subjects of interest
subjects = SubjectList(
args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs,
)
subjects = SubjectList(args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs)

try:
tester = Inference(
Expand Down
116 changes: 73 additions & 43 deletions FastSurferCNN/data_loader/conform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# IMPORTS
import argparse
import re
import sys
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Literal, TypeVar, cast
Expand All @@ -35,7 +36,7 @@ class Tensor:
pass

from FastSurferCNN.utils import logging
from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, VoxSizeOption
from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption
from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm
from FastSurferCNN.utils.arg_types import img_size as __img_size
from FastSurferCNN.utils.arg_types import orientation as __orientation
Expand Down Expand Up @@ -260,7 +261,7 @@ def options_parse():
def to_target_orientation(
image_data: _TA,
source_affine: npt.NDArray[float],
target_orientation: str,
target_orientation: StrictOrientationType,
) -> tuple[_TA, Callable[[_TB], _TB]]:
"""
Reorder and flip image_data such that the data is in orientation. This will always be without interpolation.
Expand All @@ -271,7 +272,7 @@ def to_target_orientation(
The image data to reorder/flip.
source_affine : npt.NDArray[float]
The affine to detect the reorientation operations.
target_orientation : str
target_orientation : StrictOrientationType
The target orientation to reorient to.

Returns
Expand All @@ -281,13 +282,7 @@ def to_target_orientation(
Callable[[np.ndarray], np.ndarray], Callable[[torch.Tensor], torch.Tensor]
A function that flips and reorders the data back (returns same type as output).
"""
from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform

source_ornt = io_orientation(source_affine)
target_ornt = axcodes2ornt(target_orientation)

reorient_ornt = ornt_transform(source_ornt, target_ornt)
unorient_ornt = ornt_transform(target_ornt, source_ornt)
reorient_ornt, unorient_ornt = orientation_to_ornts(source_affine, target_orientation)

if np.any([reorient_ornt[:, 1] != 1, reorient_ornt[:, 0] != np.arange(reorient_ornt.shape[0])]): # is not lia yet
def back_to_native(data: _TB) -> _TB:
Expand All @@ -301,6 +296,36 @@ def do_nothing(data: _TB) -> _TB:
return image_data, do_nothing


def orientation_to_ornts(
source_affine: npt.NDArray[float],
target_orientation: StrictOrientationType,
) -> tuple[npt.NDArray[int], npt.NDArray[int]]:
"""
Determine the nibabel `ornt` Array to reorder and flip data from source_affine such that the data is in orientation.

Parameters
----------
source_affine : npt.NDArray[float]
The affine to detect the reorientation operations.
target_orientation : StrictOrientationType
The target orientation to reorient to.

Returns
-------
npt.NDArray[int]
The `ornt` transform from source_affine to target_orientation.
npt.NDArray[int]
The `ornt` transform back from target_orientation to source_affine.
"""
from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform

source_ornt = io_orientation(source_affine)
target_ornt = axcodes2ornt(target_orientation.upper())
reorient_ornt = ornt_transform(source_ornt, target_ornt)
unorient_ornt = ornt_transform(target_ornt, source_ornt)
return reorient_ornt.astype(int), unorient_ornt.astype(int)


def apply_orientation(arr: _TB | npt.ArrayLike, ornt: npt.NDArray[int]) -> _TB:
"""
Apply transformations implied by `ornt` to the first n axes of the array `arr`.
Expand Down Expand Up @@ -797,14 +822,11 @@ def prepare_mgh_header(
re_order_axes = [0, 1, 2]
rot_scale_mat = img.affine[:3, :3]
else:
in_ornt = nib.orientations.io_orientation(img.affine)
out_ornt = nib.orientations.axcodes2ornt(orientation[-3:].upper())
LOGGER.debug(f"{nib.orientations.ornt2axcodes(in_ornt)} => {nib.orientations.ornt2axcodes(out_ornt)}")
_ornt_transform = nib.orientations.ornt_transform(in_ornt, out_ornt)
LOGGER.debug(str(_ornt_transform))
re_order_axes = _ornt_transform[:, 0].astype(int)
_ornt_transform, _ = orientation_to_ornts(img.affine, orientation[-3:])
re_order_axes = _ornt_transform[:, 0]
if len(orientation) == 3: # lia, ras, etc
# this is a 3x3 matrix
out_ornt = nib.orientations.axcodes2ornt(orientation[-3:].upper())
rot_scale_mat = nib.orientations.inv_ornt_aff(out_ornt, source_img_shape)[:3, :3]
else: # soft lia, ras, ....
aff = _ornt_transform[:, 1][None] * img.affine[:3, :3]
Expand Down Expand Up @@ -943,43 +965,49 @@ def is_conform(
if len(img.shape) > 3 and img.shape[3] != 1:
raise ValueError(f"Multiple input frames ({img.shape[3]}) not supported!")

checks = {"Number of Dimensions 3": (len(img.shape) == 3, f"image ndim {img.ndim}")}
checks: dict[str, tuple[bool | Literal["IGNORED"], str]] = {
"Number of Dimensions 3": (img.ndim == 3, f"image ndim {img.ndim}")
}

# check dimensions
if img_size is not None and _img_size is not None:
# if not isinstance(_img_size, np.ndarray):
# raise TypeError("_img_size should be numpy.ndarray here")
img_size_criteria = f"Dimensions {'x'.join(map(str, _img_size[:3]))}"
is_correct_img_size = np.array_equal(np.asarray(img.shape[:3]), _img_size)
checks[img_size_criteria] = is_correct_img_size, f"image dimensions {img.shape}"
img_size_text = f"image dimensions {img.shape}"
if img_size in (None, "fov") or _img_size is None:
img_size_criteria = f"Dimensions {img_size}"
checks[img_size_criteria] = "IGNORED", img_size_text
else:
img_size_criteria = f"Dimensions {img_size}={'x'.join(map(str, _img_size[:3]))}"
checks[img_size_criteria] = np.array_equal(np.asarray(img.shape[:3]), _img_size), img_size_text

# check voxel size, drop voxel sizes of dimension 4 if available
izoom = np.array(img.header.get_zooms())
if _vox_size is not None:
vox_size_text = f"image {'x'.join(map(str, izoom))}"
if _vox_size is None:
checks[f"Voxel Size {vox_size}"] = "IGNORED", vox_size_text
else:
if not isinstance(_vox_size, np.ndarray):
raise TypeError("_vox_size should be numpy.ndarray here")
is_correct_vox_size = np.allclose(izoom[:3], _vox_size, atol=vox_eps, rtol=0)
vox_size_criteria = f"Voxel Size {'x'.join(map(str, _vox_size))}"
checks[vox_size_criteria] = (is_correct_vox_size, f"image {'x'.join(map(str, izoom))}")
vox_size_criteria = f"Voxel Size {vox_size}={'x'.join(map(str, _vox_size))}"
checks[vox_size_criteria] = np.allclose(izoom[:3], _vox_size, atol=vox_eps, rtol=0), vox_size_text

# check orientation LIA
if orientation is not None and orientation != "native":
affcode = "".join(nib.orientations.aff2axcodes(img.affine))
with np.printoptions(precision=2, suppress=True):
orientation_text = "affine=" + re.sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}"
if orientation is None or orientation == "native":
checks[f"Orientation {orientation}"] = "IGNORED", orientation_text
else:
is_soft = not orientation.startswith("soft")
if is_correct_orientation := is_orientation(img.affine, orientation[-3:], is_soft, eps):
orientation_text = orientation
else:
from re import sub

affcode = "".join(nib.orientations.aff2axcodes(img.affine))
with np.printoptions(precision=2, suppress=True):
orientation_text = "affine: " + sub("\\s+", " ", str(img.affine[:3, :3])) + f" => {affcode}"
checks[f"Orientation {orientation.upper()}"] = (is_correct_orientation, orientation_text)
is_correct_orientation = is_orientation(img.affine, orientation[-3:], is_soft, eps)
checks[f"Orientation {orientation.upper()}"] = is_correct_orientation, orientation_text

# check dtype uchar
if dtype is not None:
dtype_text = f"dtype {img.get_data_dtype().name}"
if dtype is None:
checks["Dtype None"] = "IGNORED", dtype_text
else:
_dtype: npt.DTypeLike = to_dtype(dtype)
_dtype_name = _dtype.name if hasattr(_dtype, "name") else str(dtype)
is_correct_dtype = np.issubdtype(img.get_data_dtype(), _dtype)
checks[f"Dtype {_dtype_name}"] = (is_correct_dtype, f"dtype {img.get_data_dtype().name}")
checks[f"Dtype {_dtype_name}"] = np.issubdtype(img.get_data_dtype(), _dtype), dtype_text

_is_conform = all(map(lambda x: x[0], checks.values()))

Expand All @@ -994,10 +1022,12 @@ def is_conform(
conform_str = f"{np.round(_vox_size[0], decimals=2):.2f}-"
else:
with np.printoptions(precision=2, suppress=True):
conform_str = f"{str(_vox_size)}-"
conform_str = str(_vox_size) + "-"
logger.info(f"A {conform_str}conformed image must satisfy the following criteria:")
for condition, (value, message) in checks.items():
logger.info(f" - {condition:<30}: {value if value else 'BUT ' + message}")
if isinstance(value, bool):
value = "GOOD" if value else "BUT"
logger.info(f" - {condition:<30}: {value} {message}")
return _is_conform


Expand Down Expand Up @@ -1119,7 +1149,7 @@ def conformed_vox_img_size(
_conformed_vox_size = MAX_VOX_SIZE
target_vox_size = np.full((3,), _conformed_vox_size)
# this is similar to mri_convert --conform_size <float>
elif isinstance(vox_size, float) and 0.0 < vox_size <= MAX_VOX_SIZE:
elif isinstance(vox_size, float | int) and 0.0 < vox_size <= MAX_VOX_SIZE:
target_vox_size = np.full((3,), vox_size)
elif vox_size is None:
target_vox_size = None
Expand Down
25 changes: 9 additions & 16 deletions FastSurferCNN/data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# IMPORTS
import time
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Optional

import h5py
Expand All @@ -36,10 +36,12 @@ class MultiScaleOrigDataThickSlices(Dataset):
Load MRI-Image and process it to correct format for network inference.
"""

zoom : npt.NDArray[float]

def __init__(
self,
orig_data: npt.NDArray,
orig_zoom: npt.NDArray,
orig_zoom: npt.NDArray[float] | Sequence[float],
cfg: yacs.config.CfgNode,
transforms: Callable[[npt.NDArray[float]], npt.NDArray[float]] | None = None,
):
Expand All @@ -65,16 +67,16 @@ def __init__(

if self.plane == "sagittal":
orig_data = du.transform_sagittal(orig_data)
self.zoom = orig_zoom[::-1][:2]
self.zoom = np.asarray(orig_zoom)[[2, 1]]
logger.info(f"Loading Sagittal with input voxelsize {self.zoom}")

elif self.plane == "axial":
orig_data = du.transform_axial(orig_data)
self.zoom = orig_zoom[::-1][:2]
self.zoom = np.asarray(orig_zoom)[[2, 0]]
logger.info(f"Loading Axial with input voxelsize {self.zoom}")

else:
self.zoom = orig_zoom[:2]
self.zoom = np.asarray(orig_zoom)[[0, 1]]
logger.info(f"Loading Coronal with input voxelsize {self.zoom}")

# Create thick slices
Expand All @@ -89,8 +91,6 @@ def _get_scale_factor(self) -> npt.NDArray[float]:
Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network.

Input resolution is taken from voxel size in image header.
ToDO: This needs to be updated based on the plane we are looking at in case we
are dealing with non-isotropic images as inputs.

Returns
-------
Expand Down Expand Up @@ -236,9 +236,7 @@ def _get_scale_factor(
Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network.

Input resolution is taken from voxel size in image header.

ToDO: This needs to be updated based on the plane we are looking at in case we
are dealing with non-isotropic images as inputs.


Parameters
----------
Expand All @@ -258,9 +256,7 @@ def _get_scale_factor(
scale = self.base_res / img_zoom

if self.gn_noise:
scale += (
torch.randn(1) * 0.1 + 0
) # needs to be changed to torch.tensor stuff
scale += torch.randn(1) * 0.1 + 0 # needs to be changed to torch.tensor stuff
scale = torch.clamp(scale, min=0.1)

return scale
Expand Down Expand Up @@ -468,9 +464,6 @@ def _get_scale_factor(self, img_zoom):

Input resolution is taken from voxel size in image header.

ToDO: This needs to be updated based on the plane we are looking at in case we
are dealing with non-isotropic images as inputs.

Parameters
----------
img_zoom : np.ndarray
Expand Down
4 changes: 1 addition & 3 deletions FastSurferCNN/models/interpolation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ def target_shape(self, target_shape: _T.Sequence[int] | None) -> None:
"""
Validate and set the target_shape.
"""
tup_target_shape = (
tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple()
)
tup_target_shape = tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple()
if tup_target_shape != self._target_shape:
LOGGER.debug(
f"Changing the target_shape of {type(self).__name__} to {tup_target_shape} from {self._target_shape}."
Expand Down
Loading