Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LST_AI/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def annotate_lesions(atlas_t1, atlas_mask, t1w_native, seg_native, out_atlas_war

if __name__ == "__main__":

# Only for testing purposes
lst_dir = os.getcwd()
parent_directory = os.path.dirname(lst_dir)
atlas_t1w_path = os.path.join(parent_directory, "atlas", "sub-mni152_space-mni_t1.nii.gz")
Expand Down
93 changes: 93 additions & 0 deletions LST_AI/custom_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import tensorflow as tf
import numpy as np

def load_custom_model(model_path, compile=False):
    """
    Load a Keras model whose saved graph references `tfa.InstanceNormalization`.

    TensorFlow Addons (tfa) is no longer maintained, so models serialized with
    its `InstanceNormalization` layer cannot be deserialized out of the box.
    This helper maps the saved layer identifier to `CustomGroupNormalization`,
    a drop-in replacement, removing the need to install tfa.

    Args:
        model_path (str): File path to the saved Keras model.
        compile (bool): Whether to compile the model after loading.
            Defaults to False.

    Returns:
        tf.keras.Model: The loaded model with `InstanceNormalization` layers
        substituted by `CustomGroupNormalization`.

    Example:
        >>> model = load_custom_model('path/to/model.h5', compile=True)
    """
    # Map the identifier tfa wrote into the saved model to our replacement.
    replacements = {'Addons>InstanceNormalization': CustomGroupNormalization}
    return tf.keras.models.load_model(
        model_path,
        custom_objects=replacements,
        compile=compile,
    )



class CustomGroupNormalization(tf.keras.layers.Layer):
    """
    Drop-in replacement for `tfa.InstanceNormalization`.

    TensorFlow Addons (tfa) is no longer maintained and is not available for
    Mac ARM platforms. This layer wraps `tf.keras.layers.GroupNormalization`
    so that models saved with tfa's `InstanceNormalization` can be loaded
    without the tfa dependency (see `load_custom_model`).

    Args:
        groups (int): Number of groups for Group Normalization. Defaults to -1.
        **kwargs: Saved layer configuration (epsilon, center, scale,
            initializers, regularizers, constraints) plus base-layer keyword
            arguments (name, dtype, ...).
    """
    def __init__(self, groups=-1, **kwargs):
        # Prefer a 'groups' entry from the saved config, but fall back to the
        # explicit argument. BUGFIX: the previous code used
        # kwargs.pop('groups', -1), which silently discarded a positionally
        # or keyword-passed `groups` value.
        self.groups = kwargs.pop('groups', groups)
        self.epsilon = kwargs.pop('epsilon', 0.001)
        self.center = kwargs.pop('center', True)
        self.scale = kwargs.pop('scale', True)
        self.beta_initializer = kwargs.pop('beta_initializer', 'zeros')
        self.gamma_initializer = kwargs.pop('gamma_initializer', 'ones')
        self.beta_regularizer = kwargs.pop('beta_regularizer', None)
        self.gamma_regularizer = kwargs.pop('gamma_regularizer', None)
        self.beta_constraint = kwargs.pop('beta_constraint', None)
        self.gamma_constraint = kwargs.pop('gamma_constraint', None)

        # 'axis' may appear in configs saved by tfa but GroupNormalization
        # does not accept it, so drop it.
        kwargs.pop('axis', None)

        super(CustomGroupNormalization, self).__init__(**kwargs)
        # NOTE(review): the remaining **kwargs (e.g. name/dtype) are forwarded
        # to the inner layer as well, matching the original behavior.
        self.group_norm = tf.keras.layers.GroupNormalization(
            groups=self.groups,
            epsilon=self.epsilon,
            center=self.center,
            scale=self.scale,
            beta_initializer=self.beta_initializer,
            gamma_initializer=self.gamma_initializer,
            beta_regularizer=self.beta_regularizer,
            gamma_regularizer=self.gamma_regularizer,
            beta_constraint=self.beta_constraint,
            gamma_constraint=self.gamma_constraint,
            **kwargs
        )

    def call(self, inputs, training=None):
        # Delegate normalization entirely to the wrapped layer.
        return self.group_norm(inputs, training=training)

    def get_config(self):
        """Return the layer config so the model can be re-serialized."""
        config = super(CustomGroupNormalization, self).get_config()
        config.update({
            'groups': self.groups,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': self.beta_initializer,
            'gamma_initializer': self.gamma_initializer,
            'beta_regularizer': self.beta_regularizer,
            'gamma_regularizer': self.gamma_regularizer,
            'beta_constraint': self.beta_constraint,
            'gamma_constraint': self.gamma_constraint
        })
        return config
62 changes: 36 additions & 26 deletions LST_AI/lst
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,11 @@ import tempfile
import shutil
import argparse

# to filter the warning:
# WARNING:root:The given value for groups will be overwritten.
import logging
class Filter(logging.Filter):
def filter(self, record):
return 'The given value for groups will be overwritten.' not in record.getMessage()

logging.getLogger().addFilter(Filter())

from LST_AI.strip import run_hdbet, apply_mask
from LST_AI.register import mni_registration, apply_warp, rigid_reg
from LST_AI.segment import unet_segmentation
from LST_AI.annotate import annotate_lesions
from LST_AI.stats import compute_stats
from LST_AI.utils import download_data

if __name__ == "__main__":
Expand Down Expand Up @@ -135,10 +127,10 @@ if __name__ == "__main__":
os.makedirs(work_dir)

# Define Image Paths (original space)
path_org_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_T1w.nii.gz')
path_org_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_FLAIR.nii.gz')
path_org_stripped_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_desc-stripped_T1w.nii.gz')
path_org_stripped_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_desc-stripped_FLAIR.nii.gz')
path_org_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-t1w_T1w.nii.gz')
path_org_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_FLAIR.nii.gz')
path_org_stripped_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-t1w_desc-stripped_T1w.nii.gz')
path_org_stripped_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_desc-stripped_FLAIR.nii.gz')

# Define Image Paths (MNI space)
path_mni_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_T1w.nii.gz')
Expand All @@ -147,15 +139,23 @@ if __name__ == "__main__":
path_mni_stripped_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_desc-stripped_FLAIR.nii.gz')

# Masks
path_orig_brainmask_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-org_T1w_mask.nii.gz')
path_orig_brainmask_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-org_FLAIR_mask.nii.gz')
path_orig_brainmask_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-t1w_brainmask.nii.gz')
path_orig_brainmask_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_brainmask.nii.gz')
path_mni_brainmask = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_brainmask.nii.gz')

# Segmentation results
path_orig_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_seg.nii.gz')
path_mni_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_seg.nii.gz')
path_orig_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_seg-annotated.nii.gz')
path_mni_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_seg-annotated.nii.gz')
# Temp Segmentation results
path_orig_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_seg-lst.nii.gz')
path_mni_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_seg-lst.nii.gz')
path_orig_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_desc-annotated_seg-lst.nii.gz')
path_mni_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_desc-annotated_seg-lst.nii.gz')

# Output paths (in original space)
filename_output_segmentation = "space-flair_seg-lst.nii.gz"
filename_output_annotated_segmentation = "space-flair_desc-annotated_seg-lst.nii.gz"

# Stats
filename_output_stats_segmentation = "lesion_stats.csv"
filename_output_stats_annotated_segmentation = "annotated_lesion_stats.csv"

# affines
path_affine_mni_t1w = os.path.join(work_dir, 'affine_t1w_to_mni.mat')
Expand Down Expand Up @@ -187,6 +187,7 @@ if __name__ == "__main__":

# Annotation only
if args.annotate_only:
print("LST-AI assumes existing segmentation to be in FLAIR space.")
if os.path.isfile(args.existing_seg):
shutil.copy(args.existing_seg, path_orig_segmentation)
else:
Expand Down Expand Up @@ -240,7 +241,7 @@ if __name__ == "__main__":
out_annotated_native=path_orig_annotated_segmentation)

shutil.copy(path_orig_annotated_segmentation,
os.path.join(args.output, "space-orig_desc-annotated_seg-lst.nii.gz"))
os.path.join(args.output, filename_output_annotated_segmentation))


# Segmentation only + (opt. Annotation)
Expand Down Expand Up @@ -283,8 +284,7 @@ if __name__ == "__main__":

# move processed mask to correct naming convention
hdbet_mask = path_mni_stripped_t1w.replace(".nii.gz", "_mask.nii.gz")
print(hdbet_mask)
shutil.copy(hdbet_mask, path_mni_brainmask)
shutil.move(hdbet_mask, path_mni_brainmask)

# then apply brain mask to FLAIR
apply_mask(input_image=path_mni_flair,
Expand Down Expand Up @@ -333,7 +333,7 @@ if __name__ == "__main__":
n_threads=args.threads)

# store the segmentations
shutil.copy(path_orig_segmentation, os.path.join(args.output, "space-orig_seg-lst.nii.gz"))
shutil.copy(path_orig_segmentation, os.path.join(args.output, filename_output_segmentation))

# Annotation
if not args.segment_only:
Expand All @@ -354,8 +354,18 @@ if __name__ == "__main__":
n_threads=args.threads)

# store the segmentations
shutil.copy(path_orig_annotated_segmentation, os.path.join(args.output, "space-orig_desc-annotated_seg-lst.nii.gz"))

shutil.copy(path_orig_annotated_segmentation, os.path.join(args.output, filename_output_annotated_segmentation))

# Compute Stats of (annotated) segmentation if they exist
if os.path.exists(path_orig_segmentation):
compute_stats(mask_file=path_orig_segmentation,
output_file=os.path.join(args.output, filename_output_stats_segmentation),
multi_class=False)

if os.path.exists(path_orig_annotated_segmentation):
compute_stats(mask_file=path_orig_annotated_segmentation,
output_file=os.path.join(args.output, filename_output_stats_annotated_segmentation),
multi_class=True)

print(f"Results in {work_dir}")
if not args.temp:
Expand Down
4 changes: 2 additions & 2 deletions LST_AI/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def apply_warp(image_org_space, affine, origin, target, reverse=False, n_threads

subprocess.run(shlex.split(warp_call), check=True)



if __name__ == "__main__":

# Testing only

# Working directory
script_dir = os.getcwd()
parent_directory = os.path.dirname(script_dir)
Expand Down
10 changes: 5 additions & 5 deletions LST_AI/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import tensorflow_addons as tfa
#logging.getLogger("tensorflow").setLevel(logging.CRITICAL)
#logging.getLogger("tensorflow_addons").setLevel(logging.CRITICAL)

from LST_AI.custom_tf import load_custom_model


def unet_segmentation(model_path, mni_t1, mni_flair, output_segmentation_path, device='cpu', input_shape=(192,192,192), threshold=0.5):
"""
Expand Down Expand Up @@ -99,7 +99,7 @@ def preprocess_intensities(img_arr):
for i, model in enumerate(unet_mdls):
with tf.device(tf_device):
print(f"Running model {i}. ")
mdl = tf.keras.models.load_model(model, compile=False)
mdl = load_custom_model(model, compile=False)

img_image = np.stack([flair, t1], axis=-1)
img_image = np.expand_dims(img_image, axis=0)
Expand Down Expand Up @@ -129,7 +129,7 @@ def preprocess_intensities(img_arr):


if __name__ == "__main__":

# Testing only
# Working directory
script_dir = os.getcwd()
parent_dir = os.path.dirname(script_dir)
Expand Down
95 changes: 95 additions & 0 deletions LST_AI/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import nibabel as nib
import numpy as np
import csv
import argparse
from scipy.ndimage import label

def compute_stats(mask_file, output_file, multi_class):
    """
    Compute statistics from a lesion mask and save the results to a CSV file.

    Parameters:
        mask_file (str): Path to the input mask file in NIfTI format.
        output_file (str): Path to the output CSV file where results will be saved.
        multi_class (bool): True if the mask encodes multiple lesion classes
            (labels 1-4), False for a binary lesion mask.

    For each class (or for the whole binary mask) the function reports the
    number of connected components (lesions), the number of lesion voxels,
    and the total lesion volume (voxel count times the voxel volume taken
    from the NIfTI header zooms).

    Raises:
        ValueError: If `multi_class` is False but the mask contains more than
            two distinct values.
    """
    # Load the mask file
    mask = nib.load(mask_file)
    mask_data = mask.get_fdata()

    # Physical volume of a single voxel, from the header voxel dimensions.
    voxel_volume = np.prod(mask.header.get_zooms())

    results = []

    if multi_class:
        # Multi-class processing: one row per anatomical lesion region.
        label_names = {
            1: 'Periventricular',
            2: 'Juxtacortical',
            3: 'Subcortical',
            4: 'Infratentorial'
        }
        fieldnames = ['Region', 'Num_Lesions', 'Num_Vox', 'Lesion_Volume']

        for lesion_label, region_name in label_names.items():
            class_mask = mask_data == lesion_label

            # Connected components within this class count as separate lesions.
            _, num_lesions = label(class_mask)

            voxel_count = np.count_nonzero(class_mask)
            results.append({
                'Region': region_name,
                'Num_Lesions': num_lesions,
                'Num_Vox': voxel_count,
                'Lesion_Volume': voxel_count * voxel_volume
            })

    else:
        # Binary mask processing: a binary mask may contain at most two
        # distinct values (0 and 1). Raise explicitly instead of using
        # `assert`, which is stripped under `python -O`.
        if len(np.unique(mask_data)) > 2:
            raise ValueError("Binary mask must contain no more than two unique values.")

        fieldnames = ['Num_Lesions', 'Num_Vox', 'Lesion_Volume']

        # Count lesions (connected components) in binary mask
        _, num_lesions = label(mask_data > 0)

        voxel_count = np.count_nonzero(mask_data)
        results.append({
            'Num_Lesions': num_lesions,
            'Num_Vox': voxel_count,
            'Lesion_Volume': voxel_count * voxel_volume
        })

    # DictWriter handles both layouts without duplicated writer code.
    with open(output_file, 'w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

if __name__ == "__main__":
    # Command-line entry point: parse arguments and delegate to compute_stats.
    arg_parser = argparse.ArgumentParser(description='Process a lesion mask file.')
    arg_parser.add_argument('--in', dest='input_file', required=True,
                            help='Input mask file path')
    arg_parser.add_argument('--out', dest='output_file', required=True,
                            help='Output CSV file path')
    arg_parser.add_argument('--multi-class', dest='multi_class', action='store_true',
                            help='Flag for multi-class processing')
    cli_args = arg_parser.parse_args()
    compute_stats(cli_args.input_file, cli_args.output_file, cli_args.multi_class)
Loading