Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,49 @@ def __init__(
self,
modality_name: str,
input_path: str,
output_path: str,
bet: bool,
atlas_correction: bool = True,
raw_bet_output_path: str = None,
raw_skull_output_path: str = None,
normalized_bet_output_path: str = None,
normalized_skull_output_path: str = None,
normalizer: Optional[Normalizer] = None,
atlas_correction: bool = True,
) -> None:
self.modality_name = modality_name
self.input_path = turbopath(input_path)
self.output_path = turbopath(output_path)
self.bet = bet
self.atlas_correction = atlas_correction
if (
raw_bet_output_path is None
and normalized_bet_output_path is None
and raw_skull_output_path is None
and normalized_skull_output_path is None
):
raise ValueError(
"All output paths are None. At least one output path must be provided."
)
self.raw_bet_output_path = turbopath(raw_bet_output_path)
self.raw_skull_output_path = turbopath(raw_skull_output_path)
if normalized_bet_output_path is not None:
if normalizer is None:
raise ValueError(
"A normalizer must be provided if normalized_bet_output_path is not None."
)
self.normalized_bet_output_path = turbopath(normalized_bet_output_path)

if normalized_skull_output_path is not None:
if normalizer is None:
raise ValueError(
"A normalizer must be provided if normalized_skull_output_path is not None."
)
self.normalized_skull_output_path = turbopath(normalized_skull_output_path)

self.normalizer = normalizer
self.atlas_correction = atlas_correction

self.current = self.input_path

@property
def bet(self) -> bool:
return any([self.raw_bet_output_path, self.normalized_bet_output_path])

def normalize(
self,
temporary_directory: str,
Expand Down Expand Up @@ -225,3 +255,24 @@ def extract_brain_region(
if self.bet is True:
self.current = atlas_bet_cm
return atlas_mask_path

def save_current_image(
self,
output_path: str,
normalization=False,
) -> None:
os.makedirs(output_path.parent, exist_ok=True)

if normalization is False:
shutil.copyfile(
self.current,
output_path,
)
elif normalization is True:
image = read_nifti(self.current)
normalized_image = self.normalizer.normalize(image=image)
write_nifti(
input_array=normalized_image,
output_nifti_path=output_path,
reference_nifti_path=self.current,
)
65 changes: 42 additions & 23 deletions brainles_preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,25 @@ def run(
save_dir_atlas_registration: Optional[str] = None,
save_dir_atlas_correction: Optional[str] = None,
save_dir_brain_extraction: Optional[str] = None,
save_dir_unnormalized: Optional[str] = None,
):
"""
Execute the preprocessing pipeline.
Execute the preprocessing pipeline, encompassing coregistration, atlas-based registration,
atlas correction, and optional brain extraction.

Args:
save_dir_coregistration (str, optional): Directory path to save coregistration results.
save_dir_atlas_registration (str, optional): Directory path to save atlas registration results.
save_dir_atlas_correction (str, optional): Directory path to save atlas correction results.
save_dir_brain_extraction (str, optional): Directory path to save brain extraction results.
save_dir_unnormalized (str, optional): Directory path to save unnormalized images.

This method orchestrates the entire preprocessing workflow by sequentially performing:

1. Coregistration: Aligning moving modalities to the central modality.
2. Atlas Registration: Aligning the central modality to a predefined atlas.
3. Atlas Correction: Applying additional correction in atlas space if specified.
4. Brain Extraction: Optionally extracting brain regions using specified masks.

Results are saved in the specified directories, allowing for modular and configurable output storage.
"""
# Coregister moving modalities to center modality
coregistration_dir = os.path.join(self.temp_folder, "coregistration")
Expand Down Expand Up @@ -142,22 +150,22 @@ def run(
)

# Register center modality to atlas
file_name = f"atlas__{self.center_modality.modality_name}"
center_file_name = f"atlas__{self.center_modality.modality_name}"
transformation_matrix = self.center_modality.register(
registrator=self.registrator,
fixed_image_path=self.atlas_image_path,
registration_dir=self.atlas_dir,
moving_image_name=file_name,
moving_image_name=center_file_name,
)

# Transform moving modalities to atlas
for moving_modality in self.moving_modalities:
file_name = f"atlas__{moving_modality.modality_name}"
moving_file_name = f"atlas__{moving_modality.modality_name}"
moving_modality.transform(
registrator=self.registrator,
fixed_image_path=self.atlas_image_path,
registration_dir_path=self.atlas_dir,
moving_image_name=file_name,
moving_image_name=moving_file_name,
transformation_matrix_path=transformation_matrix,
)
self._save_output(
Expand All @@ -171,12 +179,12 @@ def run(

for moving_modality in self.moving_modalities:
if moving_modality.atlas_correction is True:
file_name = f"atlas_corrected__{self.center_modality.modality_name}__{moving_modality.modality_name}"
moving_file_name = f"atlas_corrected__{self.center_modality.modality_name}__{moving_modality.modality_name}"
moving_modality.register(
registrator=self.registrator,
fixed_image_path=self.center_modality.current,
registration_dir=atlas_correction_dir,
moving_image_name=file_name,
moving_image_name=moving_file_name,
)

if self.center_modality.atlas_correction is True:
Expand All @@ -193,8 +201,22 @@ def run(
save_dir=save_dir_atlas_correction,
)

# now we save images that are not skullstripped
for modality in self.all_modalities:
if modality.raw_skull_output_path:
modality.save_current_image(
modality.raw_skull_output_path,
normalization=False,
)
if modality.normalized_skull_output_path:
modality.save_current_image(
modality.normalized_skull_output_path,
normalization=True,
)

# Optional: Brain extraction
brain_extraction = any(modality.bet for modality in self.all_modalities)

if brain_extraction:
bet_dir = os.path.join(self.temp_folder, "brain-extraction")
os.makedirs(bet_dir, exist_ok=True)
Expand All @@ -216,21 +238,18 @@ def run(
save_dir=save_dir_brain_extraction,
)

# Optional: Normalization
normalization = any(modality.normalizer for modality in self.all_modalities)
if normalization:
for modality in [self.center_modality] + self.moving_modalities:
modality.normalize(
temporary_directory=self.temp_folder,
store_unnormalized=save_dir_unnormalized,
)

# now we save images that are skullstripped
for modality in self.all_modalities:
os.makedirs(modality.output_path.parent, exist_ok=True)
shutil.copyfile(
modality.current,
modality.output_path,
)
if modality.raw_bet_output_path:
modality.save_current_image(
modality.raw_bet_output_path,
normalization=False,
)
if modality.normalized_bet_output_path:
modality.save_current_image(
modality.normalized_bet_output_path,
normalization=True,
)

def _save_output(
self,
Expand Down
Loading