Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 62 additions & 25 deletions brainles_preprocessing/brain_extraction/brain_extractor.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,79 @@
# TODO add typing and docs
from abc import abstractmethod
import os

import nibabel as nib
import numpy as np
from brainles_hd_bet import run_hd_bet

from auxiliary.nifti.io import read_nifti, write_nifti
from auxiliary.turbopath import name_extractor


from shutil import copyfile


class BrainExtractor:
@abstractmethod
def extract(
self,
input_image,
output_image,
log_file,
mode,
):
input_image_path: str,
masked_image_path: str,
brain_mask_path: str,
log_file_path: str,
# TODO convert mode to enum
mode: str,
) -> None:
pass

def apply_mask(
self,
input_image,
mask_image,
output_image,
):
"""masks images with brain masks"""
inputnifti = nib.load(input_image)
mask = nib.load(mask_image)
input_image_path: str,
mask_image_path: str,
masked_image_path: str,
) -> None:
"""
Apply a brain mask to an input image.

# mask it
masked_file = np.multiply(inputnifti.get_fdata(), mask.get_fdata())
masked_file = nib.Nifti1Image(masked_file, inputnifti.affine, inputnifti.header)
Parameters:
- input_image_path (str): Path to the input image (NIfTI format).
- mask_image_path (str): Path to the brain mask image (NIfTI format).
- masked_image_path (str): Path to save the resulting masked image (NIfTI format).

# save it
nib.save(masked_file, output_image)
Returns:
- str: Path to the saved masked image.
"""

# read data
input_data = read_nifti(input_image_path)
mask_data = read_nifti(mask_image_path)

# mask and save it
masked_data = input_data * mask_data

write_nifti(
input_array=masked_data,
output_nifti_path=masked_image_path,
reference_nifti_path=input_image_path,
create_parent_directory=True,
)


class HDBetExtractor(BrainExtractor):
def extract(
self,
input_image,
masked_image,
# TODO implement logging!
log_file,
mode="accurate",
):
input_image_path: str,
masked_image_path: str,
brain_mask_path: str,
log_file_path: str = None,
# TODO convert mode to enum
mode: str = "accurate",
) -> None:
# GPU + accurate + TTA
"""skullstrips images with HD-BET generates a skullstripped file and mask"""
run_hd_bet(
mri_fnames=[input_image],
output_fnames=[masked_image],
mri_fnames=[input_image_path],
output_fnames=[masked_image_path],
# device=0,
# TODO consider postprocessing
# postprocess=False,
Expand All @@ -59,3 +84,15 @@ def extract(
keep_mask=True,
overwrite=True,
)

hdbet_mask_path = (
masked_image_path.parent
+ "/"
+ name_extractor(masked_image_path)
+ "_masked.nii.gz"
)

copyfile(
src=hdbet_mask_path,
dst=brain_mask_path,
)
13 changes: 7 additions & 6 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def apply_mask(
f"brain_masked__{self.modality_name}.nii.gz",
)
brain_extractor.apply_mask(
input_image=self.current,
mask_image=atlas_mask,
output_image=brain_masked,
input_image_path=self.current,
mask_image_path=atlas_mask,
masked_image_path=brain_masked,
)
self.current = brain_masked

Expand Down Expand Up @@ -153,9 +153,10 @@ def extract_brain_region(
)

brain_extractor.extract(
input_image=self.current,
masked_image=atlas_bet_cm,
log_file=bet_log,
input_image_path=self.current,
masked_image_path=atlas_bet_cm,
brain_mask_path=atlas_mask,
log_file_path=bet_log,
)
self.current = atlas_bet_cm
return atlas_mask