Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 8 additions & 22 deletions CerebNet/data_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ def __init__(
slice_thickness: int,
primary_slice: str | None = None,
):
from numpy.linalg import inv

self.slice_thickness = slice_thickness
self.transforms = Compose([ToTensor()])
self.img_org = img_org
Expand All @@ -252,8 +254,6 @@ def __init__(
# binarize the cerebellum from brain_seg
cereb_aseg_mask = utils.get_aseg_cereb_mask(np.asarray(brain_seg.dataobj))

from numpy.linalg import inv

affine = inv(brain_seg.affine) @ img_org.affine

# print(brain_seg.affine, img_org.affine)
Expand All @@ -264,27 +264,19 @@ def __init__(
)
from scipy.ndimage import affine_transform

cereb_aseg = affine_transform(
cereb_aseg_mask.astype(np.float32), affine, output_shape=img_org.shape
)
cereb_aseg = affine_transform(cereb_aseg_mask.astype(np.float32), affine, output_shape=img_org.shape)
cereb_aseg_mask = cereb_aseg > 0.5

bbox = self.locate_mask_bbox(cereb_aseg_mask)

# create the roi from cereb_aseg (where labels after interpolation > 0.05 --> membership rounded to 1 decimal)
self.roi: LocalizerROI = {
"source_shape": img_org.shape,
"offsets": bounding_volume_offset(
bbox, patch_size, image_shape=cereb_aseg_mask.shape
),
"offsets": bounding_volume_offset(bbox, patch_size, image_shape=cereb_aseg_mask.shape),
"target_shape": patch_size,
}
# crop the region of interest
img = crop_transform(
self.img_org_data,
offsets=self.roi["offsets"],
target_shape=self.roi["target_shape"],
)
img = crop_transform(self.img_org_data, offsets=self.roi["offsets"], target_shape=self.roi["target_shape"])
# reorient the data to lia
img_lia, self.back_to_native = to_target_orientation(img, self.img_org.affine, target_orientation="LIA")

Expand All @@ -298,13 +290,9 @@ def __init__(
}
for plane, data_i in data.items():
# data is transformed to 'plane'-direction in axis 2
thick_slices = get_thick_slices(
data_i, self.slice_thickness
) # [H, W, n_slices, C]
thick_slices = get_thick_slices(data_i, self.slice_thickness) # [H, W, n_slices, C]
# it seems x and y are flipped with respect to expectations here
self.images_per_plane[plane] = np.transpose(
thick_slices, (2, 0, 1, 3)
) # [n_slices, H, W, C]
self.images_per_plane[plane] = np.transpose(thick_slices, (2, 0, 1, 3)) # [n_slices, H, W, C]

def locate_mask_bbox(self, mask: npt.NDArray[bool]):
"""Find the largest connected component of the mask.
Expand All @@ -329,9 +317,7 @@ def get_bounding_offsets(self) -> LocalizerROI:
def set_plane(self, plane: Plane):
"""Set the active plane."""
if plane not in self.images_per_plane.keys():
raise ValueError(
f"Invalid plane name, must be in {tuple(self.images_per_plane.keys())}"
)
raise ValueError(f"Invalid plane name, must be in {tuple(self.images_per_plane.keys())}")
self._plane = plane

@property
Expand Down
71 changes: 24 additions & 47 deletions CerebNet/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Inference:
cerebnet_labels: Mapper[str, int]
cereb_name2fs_id: Mapper[str, int]
freesurfer_name2id: Mapper[str, int]
cereb_id2fs_id: Mapper[int, int]

def __init__(
self,
Expand Down Expand Up @@ -129,9 +130,7 @@ def __init__(
self.viewagg_device = _viewagg_device


def prep_lut(
file: Path, *args, **kwargs,
) -> Future[TSVLookupTable | JsonColorLookupTable]:
def prep_lut(file: Path, *args, **kwargs) -> Future[TSVLookupTable | JsonColorLookupTable]:
_cls = TSVLookupTable
cls = {".json": JsonColorLookupTable, ".txt": _cls, ".tsv": _cls}
return self.pool.submit(cls[file.suffix], file, *args, **kwargs)
Expand All @@ -153,12 +152,8 @@ def lut_path(module: str, file: str) -> Path:

self.cerebnet_labels = _cerebnet_mapper.result().labelname2id()
self.freesurfer_name2id = fs_color_map.result().labelname2id()
cereb_name2fs_name: Mapper[str, str] = (
cereb2freesurfer_mapper.result().labelname2id()
)
cerebsag_name2cereb_name: Mapper[str, str] = (
sagittal_cereb2cereb_mapper.result().labelname2id()
)
cereb_name2fs_name: Mapper[str, str] = cereb2freesurfer_mapper.result().labelname2id()
cerebsag_name2cereb_name: Mapper[str, str] = sagittal_cereb2cereb_mapper.result().labelname2id()

cereb_id2name = self.cerebnet_labels.__reversed__()
self.cereb_name2fs_id = cereb_name2fs_name.chain(self.freesurfer_name2id)
Expand Down Expand Up @@ -204,13 +199,9 @@ def __load_model(cfg: "yacs.config.CfgNode", plane: Plane) -> torch.nn.Module:
return dict(zip(PLANES, self.pool.map(_load_model_func, PLANES), strict=False))

@torch.no_grad()
def _predict_single_subject(
self, subject_dataset: SubjectDataset
) -> dict[Plane, list[torch.Tensor]]:
"""Predict the classes based on a SubjectDataset."""
img_loader = DataLoader(
subject_dataset, batch_size=self.batch_size, shuffle=False
)
def _predict_single_subject(self, subject_dataset: SubjectDataset) -> dict[Plane, list[torch.Tensor]]:
"""Predict the classes based on a SubjectDataset (operates fully in LIA)."""
img_loader = DataLoader(subject_dataset, batch_size=self.batch_size, shuffle=False)
prediction_logits = {}
try:
for plane in PLANES:
Expand All @@ -220,8 +211,7 @@ def _predict_single_subject(
from CerebNet.data_loader.data_utils import slice_lia2ras, slice_ras2lia

for img in img_loader:
# CerebNet is trained on RAS+ conventions, so we need to map between
# lia (FastSurfer) and RAS+
# CerebNet is trained on RAS+ conventions, so we need to map between lia (FastSurfer) and RAS+
# map LIA 2 RAS
img = slice_lia2ras(plane, img, thick_slices=True)
batch = img.to(self.device)
Expand Down Expand Up @@ -367,9 +357,7 @@ def _save_cerebnet_seg(
if cerebnet_seg.shape != orig.shape:
raise RuntimeError("Cereb segmentation shape inconsistent with Orig shape!")
logger.info(f"Saving CerebNet cerebellum segmentation at {filename}")
return self.pool.submit(
save_image, orig.header, orig.affine, cerebnet_seg, filename, dtype=np.int16
)
return self.pool.submit(save_image, orig.header, orig.affine, cerebnet_seg, filename, dtype=np.int16)

def _get_subject_dataset(
self, subject: SubjectDirectory
Expand All @@ -381,7 +369,7 @@ def _get_subject_dataset(

from FastSurferCNN.data_loader.data_utils import load_image, load_maybe_conform

norm_file, norm_data, norm = None, None, None
norm_file, norm_data, _norm = None, None, None
if subject.has_attribute("cereb_statsfile") :
if not subject.can_resolve_attribute("cereb_statsfile"):
from FastSurferCNN.utils.parser_defaults import ALL_FLAGS
Expand All @@ -401,34 +389,31 @@ def _get_subject_dataset(

norm_file = subject.filename_by_attribute("norm_name")
# finally, load the bias field file
norm = self.pool.submit(load_maybe_conform, norm_file, norm_file, **self._conform_kwargs)
_norm = self.pool.submit(load_maybe_conform, norm_file, norm_file, **self._conform_kwargs)

# localization
if not subject.fileexists_by_attribute("asegdkt_segfile"):
raise RuntimeError(
f"The aseg.DKT-segmentation file '{subject.asegdkt_segfile}' did not "
f"exist, please run FastSurferVINN first."
)
seg = self.pool.submit(
load_image, subject.filename_by_attribute("asegdkt_segfile")
)
_seg = self.pool.submit(load_image, subject.filename_by_attribute("asegdkt_segfile"))
# create conformed image
conf_img = self.pool.submit(
_conf_img = self.pool.submit(
load_maybe_conform,
subject.filename_by_attribute("conf_name"),
subject.filename_by_attribute("orig_name"),
**self._conform_kwargs,
)

seg, seg_data = seg.result()
conf_file, conf_img, conf_data = conf_img.result()
seg, seg_data = _seg.result()
conf_file, conf_img, conf_data = _conf_img.result()

if np.allclose(conf_img.header.get_zooms(), 1.0, atol=0.01):
logger.warning(
"CerebNet does not support images that are not conformed to 1.0mm. We detected a voxel sizes of "
f"{tuple(conf_img.header.get_zooms())} in {conf_file}!"
)

subject_dataset = SubjectDataset(
img_org=conf_img,
brain_seg=seg,
Expand All @@ -437,8 +422,8 @@ def _get_subject_dataset(
# obsolete: primary_slice=self.cfg.DATA.PRIMARY_SLICE_DIR,
)
subject_dataset.transforms = ToTensorTest()
if norm is not None:
norm_file, _, norm_data = norm.result()
if _norm is not None:
norm_file, _, norm_data = _norm.result()
return norm_data, norm_file, subject_dataset

def run(self, subject_dirs: SubjectList):
Expand All @@ -454,16 +439,13 @@ def run(self, subject_dirs: SubjectList):
from FastSurferCNN.utils.common import iterate
iter_subjects = iterate(self.pool, self._get_subject_dataset, subject_dirs)
futures = []
for idx, (subject, (norm, norm_file, subject_dataset)) in tqdm(
enumerate(iter_subjects), total=len(subject_dirs), desc="Subject",
):
for idx, (subject, _data) in tqdm(enumerate(iter_subjects), total=len(subject_dirs), desc="Subject"):
norm, norm_file, subject_dataset = _data
try:
# predict CerebNet, returns logits (input and output are LIA)
preds = self._predict_single_subject(subject_dataset)
# create the folder for the output file, if it does not exist
_mkdir = self.pool.submit(
subject.segfile.parent.mkdir, exist_ok=True, parents=True,
)
_mkdir = self.pool.submit(subject.segfile.parent.mkdir, exist_ok=True, parents=True)

# postprocess logits (move axes, map sagittal to all classes, still LIA)
preds_per_plane = self._post_process_preds(preds)
Expand All @@ -487,11 +469,7 @@ def run(self, subject_dirs: SubjectList):
# this is None, but synchronizes the creation of the directory
_ = _mkdir.result()
futures.append(
self._save_cerebnet_seg(
full_cereb_seg,
subject.segfile,
subject_dataset.get_nibabel_img(),
)
self._save_cerebnet_seg(full_cereb_seg, subject.segfile, subject_dataset.get_nibabel_img())
)

if subject.has_attribute("cereb_statsfile"):
Expand Down Expand Up @@ -521,10 +499,9 @@ def run(self, subject_dirs: SubjectList):
)
)

logger.info(
f"Subject {idx + 1}/{len(subject_dirs)} with id '{subject.id}' processed in "
f"{pred_time - start_time :.2f} sec."
)
duration = pred_time - start_time
num = len(subject_dirs)
logger.info(f"Subject {idx + 1}/{num} with id '{subject.id}' processed in {duration:.2f} sec.")
except Exception as e:
logger.exception(e)
return "\n".join(map(str, e.args))
Expand Down
18 changes: 6 additions & 12 deletions CerebNet/run_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,11 @@ def setup_options():
# Training settings
parser = argparse.ArgumentParser(description="Segmentation")

# 1. Directory information (where to read from, where to write from and to incl.
# search-tag)
parser = parser_defaults.add_arguments(
parser, ["in_dir", "tag", "csv_file", "sd", "sid", "remove_suffix"]
)
# 1. Directory information (where to read from, where to write from and to incl. search-tag)
parser = parser_defaults.add_arguments(parser, ["in_dir", "tag", "csv_file", "sd", "sid", "remove_suffix"])

# 2. Options for the MRI volumes
parser = parser_defaults.add_arguments(
parser, ["t1", "conformed_name", "norm_name", "asegdkt_segfile"]
)
parser = parser_defaults.add_arguments(parser, ["t1", "conformed_name", "norm_name", "asegdkt_segfile"])
parser.add_argument(
"--cereb_segfile",
dest="cereb_segfile",
Expand Down Expand Up @@ -136,10 +131,9 @@ def main(args: argparse.Namespace) -> int | str:

Returns
-------
int
Returns 0 upon successful execution to indicate success.
str
A message indicating the failure reason in case of an exception.
int, str
Returns 0 upon successful execution to indicate success or
a message indicating the failure reason in case of an exception.

References
----------
Expand Down
9 changes: 1 addition & 8 deletions Docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,6 @@ To build a docker image with attestation and provenance, i.e. Software Bill Of M
[[worker.containerd.gcpolicy]]
all = true
keepBytes = 1024000000
# settings to push to a "local", registry with self-signed certificates
# see for example https://tech.paulcz.net/2016/01/secure-docker-with-tls/ https://github.com/paulczar/omgwtfssl
[registry."host:5000"]
ca=["/path/to/registry/ssl/ca.pem"]
[[registry."landau.dzne.ds:5000".keypair]]
key="/path/to/registry/ssl/key.pem"
cert="/path/to/registry/ssl/cert.pem"
```
3. Attestation files are not supported by the standard docker image storage driver. Therefore, images cannot be tested locally.
There are two solutions to this limitation.
Expand Down Expand Up @@ -231,7 +224,7 @@ rocms=("rocm$rocm")
# end of config

# code
git clone --branch stable --single-branch gtihub.com/Deep-MI/FastSurfer $build_dir
git clone --branch stable --single-branch github.com/Deep-MI/FastSurfer $build_dir
cd $build_dir
all_tags=("latest" "gpu-latest" "cuda-v$version" "rocm-v$version" "cpu-latest")
# build all distinct images
Expand Down
8 changes: 3 additions & 5 deletions FastSurferCNN/data_loader/conform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,16 +1412,14 @@ def crop_transform(
_target_shape = image.shape[:-len_off] + tuple(
i - 2 * o for i, o in zip(image.shape[-len_off:], offsets, strict=False)
)
elif len_off != len(target_shape):
raise ValueError(
"Incompatible offset and target_shape dimensionality (at least once)."
)
else:
elif len_off == len(target_shape):
_target_shape = tuple(
i if t == -1 else t
for i, t in zip(image.shape[-len_off:], target_shape, strict=False)
)
_target_shape = image.shape[:-len_off] + _target_shape
else:
raise ValueError("Incompatible offset and target_shape dimensionality (at least once).")

if len_off > image.ndim:
raise RuntimeError("shape of offsets is larger than dim of image allows.")
Expand Down
32 changes: 9 additions & 23 deletions FastSurferCNN/data_loader/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,31 +206,22 @@ def load_maybe_conform(
dst_file = file
else:
# the image is not conformed to 1mm, do this now.

fileext = [
ext for ext in SUPPORTED_OUTPUT_FILE_FORMATS
if file.name.endswith("." + ext)
]
fileext = [ext for ext in SUPPORTED_OUTPUT_FILE_FORMATS if file.name.endswith("." + ext)]
if len(fileext) != 1:
raise RuntimeError(
f"Invalid file extension of conf_name: {file}, must be one of "
f"{SUPPORTED_OUTPUT_FILE_FORMATS}."
f"Invalid file extension of conf_name: {file}, must be one of {SUPPORTED_OUTPUT_FILE_FORMATS}."
)
file_no_fileext = str(file)[:-len(fileext[0]) - 1]
if (vox_size := conform_kwargs.get("vox_size", 1.0)) == "min":
vox_suffix = ".min"
else:
vox_suffix = f".{str(vox_size).replace('.', '')}mm"
vox_size = conform_kwargs.get("vox_size", 1.0)
vox_suffix = ".min" if vox_size == "min" else f".{str(vox_size).replace('.', '')}mm"
if not file_no_fileext.endswith(vox_suffix):
file_no_fileext += vox_suffix
# if the orig file is neither absolute nor in the subject path, use the
# conformed file
# if the orig file is neither absolute nor in the subject path, use the conformed file
src_file = alt_file if alt_file.is_file() else file
if not alt_file.is_file():
LOGGER.warning(
f"No valid alternative file (e.g. orig, here: {alt_file}) was given to "
f"interpolate from, so we might lose quality due to multiple chained "
f"interpolations."
f"No valid alternative file (e.g. orig, here: {alt_file}) was given to interpolate from, so we might "
f"lose quality due to multiple chained interpolations. "
)

dst_file = Path(file_no_fileext + "." + fileext[0])
Expand Down Expand Up @@ -272,13 +263,8 @@ def save_image(
Image array type; if provided, the image object is explicitly set to match this type.
"""
save_as = Path(save_as)
assert (
save_as.suffix[1:] in SUPPORTED_OUTPUT_FILE_FORMATS or
save_as.suffixes[-2:] == [".nii", ".gz"]
), (
f"Output filename does not contain a supported file format "
f"{SUPPORTED_OUTPUT_FILE_FORMATS}!"
)
valid_ext = save_as.suffix[1:] in SUPPORTED_OUTPUT_FILE_FORMATS or save_as.suffixes[-2:] == [".nii", ".gz"]
assert valid_ext, f"Output filename does not contain a supported file format {SUPPORTED_OUTPUT_FILE_FORMATS}!"

mgh_img = None
if save_as.suffix == ".mgz":
Expand Down
Loading