In [None]:
from glob import glob
import os
import numpy as np
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut
import nibabel as nib
from nibabel import Nifti1Image
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


# Define paths -- absolute locations on the mounted "advent" volume; edit them
# to point at your own input series and output folder.
image_path = "/Volumes/advent/test/SBTMRI_001/Axial_T1weight_Lumbar"

# One folder of DICOM mask slices per muscle label (insertion order is kept).
mask_paths = {}
mask_paths["L_ES"] = "/Volumes/advent/test/SBT001/L_ES"
mask_paths["R_ES"] = "/Volumes/advent/test/SBT001/R_ES"
mask_paths["L_Mult"] = "/Volumes/advent/test/SBT001/L_Mult"
mask_paths["R_Mult"] = "/Volumes/advent/test/SBT001/R_Mult"

# Converted NIfTI files are written here; the folder is created if missing.
output_dir = "/Volumes/advent/processed_data/SBT001"
os.makedirs(output_dir, exist_ok=True)

def load_dicom_images(folder):
    """
    Read every ``*.dcm`` file in ``folder`` and stack them into one volume.

    Files are visited in sorted (lexicographic) filename order. The DICOM
    modality LUT is applied to each file's pixel data before stacking, so the
    first axis of the result indexes the files in that order.

    :param folder: directory containing the series' ``.dcm`` files.
    :return: numpy array stacked along a new first axis.
    :raises ValueError: if the folder contains no ``.dcm`` files.
    """
    paths = sorted(glob(os.path.join(folder, "*.dcm")))
    if len(paths) == 0:
        raise ValueError(f"No DICOM files found in {folder}")
    print(f"Found {len(paths)} DICOM files in {folder}")

    # NOTE(review): slice order comes from filenames, not from the DICOM
    # position/instance tags -- confirm the files are named so that sorting
    # matches the anatomical order (e.g. zero-padded numbers).
    def _read_slice(path):
        dataset = pydicom.dcmread(path)
        return apply_modality_lut(dataset.pixel_array, dataset)

    return np.stack([_read_slice(p) for p in paths])

def load_dicom_masks(mask_folder):
    """
    Load every ``*.dcm`` mask slice in ``mask_folder`` into one stacked array.

    Slices are read in sorted filename order (the same ``sorted(glob(...))``
    ordering ``load_dicom_images`` uses), and the DICOM modality LUT is applied.
    The values are NOT binarized: whatever the mask series stores is returned.

    :param mask_folder: directory containing the mask's ``.dcm`` files.
    :return: numpy array stacked along a new first axis (presumably
        (num_slices, rows, cols) for single-frame files -- confirm).
    :raises ValueError: if the folder contains no ``.dcm`` files.
    """
    mask_files = sorted(glob(os.path.join(mask_folder, "*.dcm")))
    # Fail with a descriptive message instead of numpy's opaque
    # "need at least one array to stack" when the folder is empty or mistyped.
    if not mask_files:
        raise ValueError(f"No DICOM mask files found in {mask_folder}")
    masks = []
    for file in mask_files:
        ds = pydicom.dcmread(file)
        mask = apply_modality_lut(ds.pixel_array, ds)
        masks.append(mask)
    return np.stack(masks)

def save_nifti(data, affine, file_name):
    """
    Write ``data`` to ``file_name`` as a NIfTI-1 image with the given ``affine``.

    No dtype conversion is done here; the array is handed to ``Nifti1Image``
    exactly as given.
    """
    nib.save(Nifti1Image(data, affine), file_name)


def visualize_overlay_colored_masks(image_data, masks, slice_idx):
    """
    Visualize a single image slice with multiple masks overlaid in specific colors.

    Each mask slice is reduced to a "voxel > 0" foreground map before drawing,
    and its color scale is pinned to [0, 1]. Background is therefore always
    transparent and every positive voxel gets the mask's color, whatever raw
    value range the mask stores.

    :param image_data: 3D numpy array of the image data, indexed slice-first.
    :param masks: Dictionary of masks, each a 3D numpy array indexed like
        image_data. Only the keys "L_ES", "R_ES", "L_Mult" and "R_Mult" are
        drawn; any other key is skipped.
    :param slice_idx: Index of the slice (along axis 0) to visualize.
    """
    # Two-entry colormaps: index 0 -> transparent, index 1 -> mask color.
    colors = {
        "L_ES": ListedColormap(["none", "red"]),
        "R_ES": ListedColormap(["none", "blue"]),
        "L_Mult": ListedColormap(["none", "green"]),
        "R_Mult": ListedColormap(["none", "yellow"])
    }

    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image_data[slice_idx], cmap="gray", interpolation="none")  # Base image

    for mask_name, mask_data in masks.items():
        cmap = colors.get(mask_name)
        if cmap is None:  # no color assigned to this mask -> not drawn
            continue
        # These masks are not binarized when loaded. If imshow were left to
        # autoscale each slice to its own min/max, a negative-valued background
        # could be colored, and a constant slice (vmin == vmax) would map to the
        # transparent color. Threshold at 0 and pin the scale to [0, 1] instead.
        foreground = (np.asarray(mask_data[slice_idx]) > 0).astype(float)
        ax.imshow(
            foreground,
            cmap=cmap,
            vmin=0,
            vmax=1,
            alpha=0.4,
            interpolation="none"
        )

    ax.set_title(f"Overlay of All Masks - Slice {slice_idx}")
    ax.axis("off")
    plt.show()


# Load images
print("Loading images...")
images = load_dicom_images(image_path)

# Load masks (raw values, one volume per label)
print("Loading masks...")
masks = {name: load_dicom_masks(path) for name, path in mask_paths.items()}

# Every mask must be voxel-aligned with the image volume. Otherwise the saved
# NIfTI files and the overlay below would silently pair the wrong slices (or
# fail later with an IndexError when slice_index exceeds a shorter mask).
for mask_name, mask_data in masks.items():
    if mask_data.shape != images.shape:
        raise ValueError(
            f"Mask {mask_name} has shape {mask_data.shape}, "
            f"but the image volume has shape {images.shape}"
        )

# Save as NIfTI. The affine is the identity: voxel spacing/orientation from the
# DICOM headers is not carried over into these files.
print("Saving images as NIfTI...")
save_nifti(images, np.eye(4), os.path.join(output_dir, "images.nii"))

for mask_name, mask_data in masks.items():
    print(f"Saving mask {mask_name} as NIfTI...")
    save_nifti(mask_data, np.eye(4), os.path.join(output_dir, f"{mask_name}_mask.nii"))

# Visualize overlay at the middle slice of the stack
slice_index = images.shape[0] // 2
visualize_overlay_colored_masks(images, masks, slice_index)



In [None]:
def visualize_normalized_masks(image_data, masks, slice_idx):
    """
    Show the image slice next to each mask's foreground, one panel per mask.

    A mask voxel counts as foreground when its value is > 0. Zero and negative
    values are background. Each panel title reports the foreground voxel count.

    Args:
        image_data (numpy.ndarray): 3D array of the main image data (slice-first).
        masks (dict): Mask name -> 3D array indexed like image_data; panels
            appear in dict order. Any number of masks is supported.
        slice_idx (int): Index of the slice (along axis 0) to visualize.
    """
    # One panel for the image plus one per mask; 4 masks -> (20, 4) as before.
    n_panels = 1 + len(masks)
    _, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
    axes = axes[0]

    # Plot original image
    axes[0].imshow(image_data[slice_idx], cmap='gray')
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    for ax, (name, mask) in zip(axes[1:], masks.items()):
        # Threshold directly at 0. The previous min/max "normalization" then
        # `> 0` had two problems:
        #   - it divided by zero when a slice held only non-positive values;
        #   - it dropped the minimum-valued voxels when a slice had no zeros
        #     (it effectively tested `x > min`, not `x > 0`).
        binary_mask = (np.asarray(mask[slice_idx]) > 0).astype(float)

        # Pin the scale so an all-foreground slice is not drawn like background.
        ax.imshow(binary_mask, cmap='jet', vmin=0, vmax=1)
        ax.set_title(f'{name}\nNon-zero pixels: {np.count_nonzero(binary_mask)}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def load_nifti(file_path):
    """
    Read the NIfTI file at ``file_path``.

    Returns a ``(data, affine)`` tuple: ``data`` is the image's ``get_fdata()``
    array and ``affine`` is the image's affine matrix.
    """
    image = nib.load(file_path)
    voxels = image.get_fdata()
    return voxels, image.affine

def load_and_show_masks(image_file, mask_files, slice_idx=None):
    """
    Load a NIfTI image and its NIfTI masks, then show one slice of each.

    Args:
        image_file (str): Path to the main image NIfTI file.
        mask_files (dict): Mask name -> NIfTI path; panels appear in dict order.
        slice_idx (int, optional): Slice (along axis 0) to display. When None,
            uses ``main_image.shape[0] // 4``, one quarter of the way through
            the volume.
    """
    # Load the main image (the affine is not needed for display)
    main_image, _ = load_nifti(image_file)

    # Load all masks, keeping only the voxel data
    masks = {name: load_nifti(path)[0] for name, path in mask_files.items()}

    # Default to the quarter-depth slice (not the middle slice)
    if slice_idx is None:
        slice_idx = main_image.shape[0] // 4
    visualize_normalized_masks(main_image, masks, slice_idx)

# Run the visualization on the NIfTI volumes exported to the processed-data folder
load_and_show_masks(
    mask_files=dict(
        [
            ("L_ES", "/Volumes/advent/processed_data/SBT001/L_ES_mask.nii"),
            ("R_ES", "/Volumes/advent/processed_data/SBT001/R_ES_mask.nii"),
            ("L_Mult", "/Volumes/advent/processed_data/SBT001/L_Mult_mask.nii"),
            ("R_Mult", "/Volumes/advent/processed_data/SBT001/R_Mult_mask.nii"),
        ]
    ),
    image_file="/Volumes/advent/processed_data/SBT001/images.nii",
)

In [None]:
import os
from glob import glob
import numpy as np
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut
import nibabel as nib
from nibabel import Nifti1Image
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# -----------------------------------------------------------------------------
# Define paths
# -----------------------------------------------------------------------------
# Absolute locations on the mounted "advent" volume; change these to match
# your actual input series and output folder.
image_path = "/Volumes/advent/test/SBTMRI_001/Axial_T1weight_Lumbar"

# Label -> folder of DICOM mask slices for that muscle (order is preserved).
mask_paths = dict(
    [
        ("L_ES", "/Volumes/advent/test/SBT001/L_ES"),
        ("R_ES", "/Volumes/advent/test/SBT001/R_ES"),
        ("L_Mult", "/Volumes/advent/test/SBT001/L_Mult"),
        ("R_Mult", "/Volumes/advent/test/SBT001/R_Mult"),
    ]
)

# Destination for the NIfTI exports; existing folders are left as they are.
output_dir = "/Volumes/advent/processed_data/SBT001"
os.makedirs(output_dir, exist_ok=True)


# -----------------------------------------------------------------------------
# Utility functions
# -----------------------------------------------------------------------------

def load_dicom_images(folder):
    """
    Build one volume from the ``*.dcm`` files in ``folder``.

    Files are read in sorted filename order. Each file's pixel data has the
    DICOM modality LUT applied, and the results are stacked along a new first
    axis (num_slices, height, width) for single-frame files.

    Raises ValueError when the folder holds no ``.dcm`` files.
    """
    dicom_files = sorted(glob(os.path.join(folder, "*.dcm")))
    if not dicom_files:
        raise ValueError(f"No DICOM files found in {folder}")

    print(f"Found {len(dicom_files)} DICOM files in {folder}")

    # Modality LUT converts stored values to the modality's output units.
    # NOTE(review): ordering relies on filenames sorting in anatomical order.
    slices = [
        apply_modality_lut(dataset.pixel_array, dataset)
        for dataset in map(pydicom.dcmread, dicom_files)
    ]
    return np.stack(slices, axis=0)

def load_dicom_masks(mask_folder):
    """
    Load the ``*.dcm`` slices in ``mask_folder`` as a stacked 0/1 mask volume.

    Files are read in sorted filename order and the modality LUT is applied.
    Every voxel whose value is > 0 becomes 1 and every other voxel becomes 0.

    Raises ValueError when the folder holds no ``.dcm`` files.
    """
    mask_files = sorted(glob(os.path.join(mask_folder, "*.dcm")))
    if not mask_files:
        raise ValueError(f"No DICOM mask files found in {mask_folder}")

    binary_slices = []
    for path in mask_files:
        dataset = pydicom.dcmread(path)
        values = apply_modality_lut(dataset.pixel_array, dataset)
        # Positive -> foreground (1), zero/negative -> background (0)
        binary_slices.append(np.where(values > 0, 1, 0))

    return np.stack(binary_slices, axis=0)

def save_nifti(data, affine, file_name, dtype=np.int16):
    """
    Saves the NumPy array `data` as a NIfTI file using the provided `affine` matrix.

    The array is converted to `dtype` (int16 by default) before saving, but
    only when that conversion is lossless. If converting would change any
    value, a warning is issued and the data is saved in its original dtype
    instead of being silently corrupted. Examples of lossy conversions:
    - fractional intensities produced by a DICOM rescale slope;
    - values outside the int16 range.

    :param data: array-like volume to save.
    :param affine: affine matrix passed to Nifti1Image.
    :param file_name: destination path.
    :param dtype: preferred numpy dtype for the saved array (default np.int16).
    """
    import warnings

    array = np.asarray(data)
    target_dtype = np.dtype(dtype)
    converted = array.astype(target_dtype)
    # astype() truncates fractions and wraps/garbles out-of-range values without
    # complaint; compare against the source to detect any such loss.
    if not np.array_equal(converted, array):
        warnings.warn(
            f"Data for {file_name} cannot be stored as {target_dtype} without "
            f"losing information; saving as {array.dtype} instead."
        )
        converted = array
    nifti_image = Nifti1Image(converted, affine)
    nib.save(nifti_image, file_name)
    print(f"Saved NIfTI: {file_name}")

def visualize_overlay_colored_masks(image_data, masks, slice_idx):
    """
    Visualize a single slice from 'image_data' with multiple binary masks overlaid.
    Each mask is assigned a distinct color via a ListedColormap.

    Only masks whose names appear in the color table below ("L_ES_II",
    "R_ES_II", "L_Mult_II", "R_Mult_II") are drawn; other entries are skipped.
    Voxels > 0 are painted in the mask's color at alpha 0.4; all other voxels
    stay transparent.

    :param image_data: 3D numpy array indexed slice-first.
    :param masks: dict of mask name -> 3D array indexed like image_data.
    :param slice_idx: index of the slice (along axis 0) to draw.
    """
    # Two-entry colormaps: index 0 -> transparent, index 1 -> label color.
    colors = {
        "L_ES_II": ListedColormap(["none", "red"]),
        "R_ES_II": ListedColormap(["none", "blue"]),
        "L_Mult_II": ListedColormap(["none", "green"]),
        "R_Mult_II": ListedColormap(["none", "yellow"])
    }

    _, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(image_data[slice_idx], cmap="gray", interpolation="none")

    for mask_name, mask_data in masks.items():
        cmap = colors.get(mask_name)
        if cmap is None:
            continue
        # Pin the color scale to [0, 1]. If left to autoscale, a slice that is
        # entirely foreground has vmin == vmax, maps to the first ("none")
        # color, and would be drawn fully transparent.
        foreground = (np.asarray(mask_data[slice_idx]) > 0).astype(float)
        ax.imshow(
            foreground,
            cmap=cmap,
            vmin=0,
            vmax=1,
            alpha=0.4,
            interpolation="none"
        )
    ax.set_title(f"Overlay of Masks - Slice {slice_idx}")
    ax.axis("off")
    plt.show()


# -----------------------------------------------------------------------------
# Main Execution
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # 1. Load DICOM images
    print("Loading images...")
    images = load_dicom_images(image_path)

    # 2. Load the four DICOM masks (left and right ES, left and right Multifidus)
    print("Loading masks...")
    L_ES_I = load_dicom_masks(mask_paths["L_ES"])
    R_ES_I = load_dicom_masks(mask_paths["R_ES"])
    L_Mult_I = load_dicom_masks(mask_paths["L_Mult"])
    R_Mult_I = load_dicom_masks(mask_paths["R_Mult"])

    # 2b. Every mask must be voxel-aligned with the image volume. Without this
    #     check, numpy broadcasting could silently combine e.g. a single-slice
    #     mask with a full stack below, and the overlay would pair wrong slices.
    for label, mask_volume in (
        ("L_ES", L_ES_I),
        ("R_ES", R_ES_I),
        ("L_Mult", L_Mult_I),
        ("R_Mult", R_Mult_I),
    ):
        if mask_volume.shape != images.shape:
            raise ValueError(
                f"Mask {label} has shape {mask_volume.shape}, "
                f"but the image volume has shape {images.shape}"
            )

    # 3. Compute overlap (element-wise multiplication) between Multifidus & ES
    #    and then subtract that overlap from the ES masks to avoid double labeling.
    #
    #    L_Mult_II = L_Mult_I * L_ES_I
    #    R_Mult_II = R_Mult_I * R_ES_I
    #
    #    L_ES_II = L_ES_I - L_Mult_II
    #    R_ES_II = R_ES_I - R_Mult_II
    #
    #    Explanation:
    #    - "Mult_II" mask keeps only the shared (overlapping) portion
    #      of Multifidus and ES on that side.
    #    - "ES_II" mask keeps the original ES minus the overlap,
    #      so it doesn't double label the same region.
    #
    #    NOTE(review): because Mult_II is a product, any multifidus voxel that
    #    lies outside the ES mask is dropped entirely -- confirm this is the
    #    intended definition rather than keeping the full multifidus mask.
    print("Creating refined masks...")
    L_Mult_II = L_Mult_I * L_ES_I
    R_Mult_II = R_Mult_I * R_ES_I

    L_ES_II = L_ES_I - L_Mult_II
    R_ES_II = R_ES_I - R_Mult_II

    # The loaded masks are 0/1, so Mult_II is contained in ES_I and the refined
    # ES masks must stay 0/1; a negative value would mean a mask was not binary.
    assert L_ES_II.min() >= 0 and R_ES_II.min() >= 0, "refined ES masks went negative"

    # 4. Save everything as NIfTI. These are the same filenames the earlier raw
    #    export in this notebook writes, so running this overwrites those files.
    print("\nSaving as NIfTI...")
    # 4a. Save the DICOM image volume (identity affine: no DICOM geometry kept)
    save_nifti(images, np.eye(4), os.path.join(output_dir, "images.nii"))

    # 4b. Save the new masks
    save_nifti(L_Mult_II, np.eye(4), os.path.join(output_dir, "L_Mult_mask.nii"))
    save_nifti(R_Mult_II, np.eye(4), os.path.join(output_dir, "R_Mult_mask.nii"))
    save_nifti(L_ES_II,   np.eye(4), os.path.join(output_dir, "L_ES_mask.nii"))
    save_nifti(R_ES_II,   np.eye(4), os.path.join(output_dir, "R_ES_mask.nii"))

    # 5. Visualize overlay one third of the way through the stack (change as needed)
    slice_index = images.shape[0] // 3
    mask_dict = {
        "L_Mult_II": L_Mult_II,
        "R_Mult_II": R_Mult_II,
        "L_ES_II": L_ES_II,
        "R_ES_II": R_ES_II,
    }

    visualize_overlay_colored_masks(images, mask_dict, slice_index)

    print("\nDone!")
