In [None]:
# Load methods from src.visualization.py and necessary libraries
import sys
sys.path.append('..')  # Add the parent directory to the path
from src.visualization import *
from src.preprocessing import *

Cropping

In [None]:
import os
import json
import SimpleITK as sitk
from tqdm import tqdm
import pandas as pd

def _crop_to_bbox(volume_sitk, coords):
    """Crop a SimpleITK volume to the index bounding box stored in ``coords``.

    The ``x_*`` / ``y_*`` / ``z_*`` entries of ``coords`` are applied to axes
    0 / 1 / 2 of the array returned by ``sitk.GetArrayFromImage``. That array
    is indexed in the reverse order of SimpleITK's own indices, which is why
    the physical origin is looked up with ``(z_min, y_min, x_min)``.

    Returns:
        ``(cropped_image, origin_phys)`` where ``origin_phys`` is the physical
        position of the first cropped voxel, or ``None`` when the box is empty
        or does not fit inside the volume.
    """
    volume_np = sitk.GetArrayFromImage(volume_sitk)
    shape = volume_np.shape

    x_min, x_max = coords["x_min"], coords["x_max"]
    y_min, y_max = coords["y_min"], coords["y_max"]
    z_min, z_max = coords["z_min"], coords["z_max"]

    if not (0 <= x_min < x_max <= shape[0] and
            0 <= y_min < y_max <= shape[1] and
            0 <= z_min < z_max <= shape[2]):
        return None

    cropped_sitk = sitk.GetImageFromArray(volume_np[x_min:x_max, y_min:y_max, z_min:z_max])

    # GetImageFromArray resets the geometry, so restore it by hand; only the
    # origin changes, it moves to the first voxel that was kept.
    origin_phys = volume_sitk.TransformIndexToPhysicalPoint((z_min, y_min, x_min))
    cropped_sitk.SetSpacing(volume_sitk.GetSpacing())
    cropped_sitk.SetDirection(volume_sitk.GetDirection())
    cropped_sitk.SetOrigin(origin_phys)
    return cropped_sitk, origin_phys


def crop_all_images_and_masks(images_root, masks_root, json_folder, output_image_root, output_mask_root):
    """Crop every patient's images (and mask, when found) to the breast bounding box.

    Args:
        images_root: Folder with one sub-folder per patient holding ``*.nii.gz`` images.
        masks_root: Folder with the segmentation masks. Pass ``None`` or ``""``
            to skip mask cropping.
            NOTE(review): masks are looked up as ``<masks_root>/<patient_id>.nii.gz``,
            the flat layout the mask-resampling cell expects in ``final_masks``;
            confirm against the actual dataset layout.
        json_folder: Folder with ``<patient_id>.json`` files; the box is read from
            ``["primary_lesion"]["breast_coordinates"]``.
        output_image_root: Destination for cropped images (one sub-folder per
            patient) and for ``cropping_metadata.json``.
        output_mask_root: Destination for cropped masks (flat, ``<patient_id>.nii.gz``).

    Returns:
        dict mapping ``patient_id`` to its bounding box, extended with
        ``origin_phys`` and ``original_spacing`` of the last image cropped for
        that patient. Patients for which no image could be cropped are omitted.
    """
    patient_ids = sorted(os.listdir(images_root))
    patient_metadata = {}

    crop_masks = bool(masks_root) and bool(output_mask_root)
    if crop_masks:
        os.makedirs(output_mask_root, exist_ok=True)

    for patient_id in tqdm(patient_ids, desc="Cropping"):
        input_folder = os.path.join(images_root, patient_id)
        if not os.path.isdir(input_folder):
            continue

        # Get all phase images
        image_files = sorted([f for f in os.listdir(input_folder) if f.endswith(".nii.gz")])
        json_path = os.path.join(json_folder, f"{patient_id}.json")
        if not os.path.exists(json_path):
            print(f"⚠️ Missing JSON: {json_path}")
            continue

        with open(json_path, "r") as f:
            coords = json.load(f)["primary_lesion"]["breast_coordinates"]

        patient_out_folder = os.path.join(output_image_root, patient_id)
        os.makedirs(patient_out_folder, exist_ok=True)

        # Reset per patient so that values from a previous patient can never
        # leak into this patient's metadata.
        origin_phys = None
        spacing = None

        for image_file in image_files:
            image_path = os.path.join(input_folder, image_file)
            img_sitk = sitk.ReadImage(image_path, sitk.sitkFloat32)

            cropped = _crop_to_bbox(img_sitk, coords)
            if cropped is None:
                print(f"❌ Invalid coordinates for {patient_id} - {image_file}")
                continue

            cropped_sitk, origin_phys = cropped
            spacing = img_sitk.GetSpacing()

            out_img_path = os.path.join(patient_out_folder, image_file)
            sitk.WriteImage(cropped_sitk, out_img_path, useCompression=True)

        if origin_phys is None:
            # Nothing was cropped (no images, or the box never fitted): there
            # is no origin/spacing to record for this patient.
            print(f"⚠️ No image could be cropped for {patient_id}; metadata not stored")
            continue

        if crop_masks:
            mask_name = f"{patient_id}.nii.gz"
            mask_path = os.path.join(masks_root, mask_name)
            if not os.path.exists(mask_path):
                print(f"⚠️ Missing mask: {mask_path}")
            else:
                mask_sitk = sitk.ReadImage(mask_path, sitk.sitkUInt8)
                cropped_mask = _crop_to_bbox(mask_sitk, coords)
                if cropped_mask is None:
                    print(f"❌ Invalid coordinates for {patient_id} - {mask_name}")
                else:
                    out_mask_path = os.path.join(output_mask_root, mask_name)
                    sitk.WriteImage(cropped_mask[0], out_mask_path, useCompression=True)

        # === Store coordinates for later ===
        coords["origin_phys"] = origin_phys
        coords["original_spacing"] = spacing
        patient_metadata[patient_id] = coords

    # Save all metadata to a JSON file next to the cropped images. The folder
    # may not exist yet when no patient was processed.
    os.makedirs(output_image_root, exist_ok=True)
    metadata_path = os.path.join(output_image_root, "cropping_metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(patient_metadata, f, indent=2)

    print(f"\n✅ Saved cropping metadata to {metadata_path}")
    return patient_metadata

# Run the cropping step. The function returns a dict mapping each cropped
# patient to its bounding box, physical origin and original spacing.
cropping_metadata = crop_all_images_and_masks(
    images_root="/data/nnUNet/Cropped/images",
    masks_root="/data/nnUNet/Cropped/segmentation",
    json_folder="/data/MAMA-MIA/patient_info_files",
    output_image_root="/data/nnUNet/Cropped/final_images",
    output_mask_root="/data/nnUNet/Cropped/final_masks"
)

# Show the metadata as a table (one row per patient); as the last expression of
# the cell it is rendered by the notebook without any extra display helper.
pd.DataFrame.from_dict(cropping_metadata, orient="index")

Clipping, Normalising, Resampling

In [None]:
# Clip, z-score normalise and resample every cropped image to 1 mm isotropic
# spacing. Outputs are written flat into `output_folder`; files that already
# exist there are left untouched, so the cell can be re-run safely.

# Define the dataset paths and folders
dataset_path = '/data/nnUNet/Cropped'
images_folder = dataset_path + '/final_images'
output_folder = os.path.join(dataset_path, 'processed_cropped/images')

# Create output directory if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# === Preprocess loop ===
processed_count = 0
skipped_patients = []   # patients without usable normalisation statistics
failed_files = []       # files that raised during processing
phases_to_process = [0, 1, 2]

for patient_id in sorted(os.listdir(images_folder)):
    patient_path = os.path.join(images_folder, patient_id)

    # Non-directory entries (e.g. cropping_metadata.json) are not patients.
    if not os.path.isdir(patient_path):
        continue

    # The same mean/std pair is used to normalise every phase of the patient.
    mean, std = compute_precontrast_statistics(images_folder, patient_id)
    if mean is None or std is None or std == 0:
        # Report the skip instead of dropping the patient silently.
        print(f"⚠️ Skipping {patient_id}: unusable statistics (mean={mean}, std={std})")
        skipped_patients.append(patient_id)
        continue

    for phase_index in phases_to_process:
        file_name = f"{patient_id}_{phase_index:04d}.nii.gz"
        file_path = os.path.join(images_folder, patient_id, file_name)
        output_file_path = os.path.join(output_folder, f"{patient_id}_{phase_index:04d}.nii.gz")

        if os.path.exists(output_file_path):
            continue  # Already processed

        if not os.path.exists(file_path):
            print(f"⚠️ Missing: {file_path}")
            continue

        try:
            image_sitk = sitk.ReadImage(file_path, sitk.sitkFloat32)
            clipped = clip_image_sitk(image_sitk, percentiles=[0.1, 99.9])
            normalized = zscore_normalization_sitk(clipped, mean, std)
            resampled = resample_sitk(normalized, new_spacing=[1, 1, 1], interpolator=sitk.sitkBSpline)

            sitk.WriteImage(resampled, output_file_path, useCompression=True)
            print(f"✅ Processed: {output_file_path}")
            processed_count += 1

        except Exception as e:
            # Best effort: keep going with the remaining files, but remember
            # the failure so it shows up in the summary below.
            print(f"❌ Error processing {file_path}: {e}")
            failed_files.append(file_path)

print(f"\n✨ Total images processed: {processed_count}")
if skipped_patients:
    print(f"⚠️ Patients skipped (no usable statistics): {len(skipped_patients)}")
if failed_files:
    print(f"❌ Files that failed: {len(failed_files)}")

In [None]:
import sys
import os

import SimpleITK as sitk

# Make the `src` package importable. Notebooks do not define `__file__`, so
# fall back to the parent directory, as the first cell of this notebook does.
try:
    _src_parent = os.path.abspath(os.path.dirname(__file__))
except NameError:
    _src_parent = os.path.abspath('..')
if _src_parent not in sys.path:
    sys.path.append(_src_parent)

from src.preprocessing import *

# Define the dataset paths and folders
dataset_path = '/data/nnUNet/Cropped'
input_folder = dataset_path + '/final_masks'
output_folder = os.path.join(dataset_path, 'processed_cropped/segmentation')

# Create output directory if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# Number of masks resampled during this run
processed_count = 0

# === Process Each File ===
for filename in sorted(os.listdir(input_folder)):
    if not filename.endswith(".nii.gz"):
        continue

    # Strip only the trailing extension to obtain the patient id.
    patient_id = filename[:-len(".nii.gz")]
    input_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, f"{patient_id}.nii.gz")

    # Skip if already processed
    if os.path.exists(output_path):
        print(f"✔️ Already processed: {patient_id}")
        continue

    # The file can disappear between the directory listing and this point.
    if not os.path.exists(input_path):
        print(f"Warning: {input_path} not found for {patient_id}")
        continue

    print(f"🔄 Processing: {patient_id}")
    image_sitk = sitk.ReadImage(input_path, sitk.sitkUInt8)
    # Resample the mask to isotropic resolution; nearest-neighbour
    # interpolation keeps the label values intact.
    resampled_sitk = resample_sitk(image_sitk, new_spacing=[1,1,1], interpolator=sitk.sitkNearestNeighbor)
    print('Original image size:', image_sitk.GetSize())
    print('Resampled image size:', resampled_sitk.GetSize())

    # Save the resampled mask
    sitk.WriteImage(resampled_sitk, output_path, useCompression=True)

    print(f"Processed and saved: {output_path}")
    processed_count += 1

print(f"Successfully processed {processed_count} patients (segmentation mask).")

Postprocessing: restoring the predicted mask to the original image space

In [None]:
def resample_to_spacing(image_sitk, target_spacing):
    """Resample ``image_sitk`` onto a grid with ``target_spacing``.

    Origin and direction are taken from the input image; the output size is
    chosen so that the physical extent is preserved (up to rounding).
    Nearest-neighbour interpolation is used and voxels outside the input are
    filled with 0.

    Args:
        image_sitk: SimpleITK image to resample.
        target_spacing: Desired spacing, one value per image dimension.

    Returns:
        The resampled SimpleITK image.
    """
    # voxels_out = voxels_in * spacing_in / spacing_out, rounded to the nearest integer
    output_size = []
    for n_voxels, spacing_in, spacing_out in zip(
        image_sitk.GetSize(), image_sitk.GetSpacing(), target_spacing
    ):
        output_size.append(int(round(n_voxels * spacing_in / spacing_out)))

    resampler = sitk.ResampleImageFilter()

    # How values are sampled: default-constructed transform, nearest neighbour.
    resampler.SetTransform(sitk.Transform())
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetDefaultPixelValue(0)

    # Geometry of the output grid.
    resampler.SetOutputOrigin(image_sitk.GetOrigin())
    resampler.SetOutputDirection(image_sitk.GetDirection())
    resampler.SetOutputSpacing(target_spacing)
    resampler.SetSize(output_size)

    return resampler.Execute(image_sitk)

def restore_mask_to_original_space_from_data(patient_id, predicted_mask_path, original_image_path, coords, output_path):
    """Paste a predicted (cropped) mask back into the grid of the original image.

    Args:
        patient_id: Patient identifier (not used by the function body).
        predicted_mask_path: Path of the predicted mask to restore.
        original_image_path: Path of the uncropped image whose size and
            geometry the restored mask copies.
        coords: Cropping metadata; ``"original_spacing"`` and ``"origin_phys"``
            are read from it.
        output_path: Destination file; its parent folder is created if needed.
    """
    reference_img = sitk.ReadImage(original_image_path)

    # Bring the prediction back to the voxel spacing of the original image and
    # anchor it at the physical origin recorded for the crop.
    prediction = sitk.ReadImage(predicted_mask_path, sitk.sitkUInt8)
    prediction = resample_to_spacing(prediction, coords["original_spacing"])
    crop_origin = coords["origin_phys"]
    prediction.SetOrigin(crop_origin)

    # Empty UInt8 canvas with the size and geometry of the original image.
    canvas = sitk.Image(reference_img.GetSize(), sitk.sitkUInt8)
    canvas.CopyInformation(reference_img)

    # Voxel of the original image at which the crop starts.
    paste_index = reference_img.TransformPhysicalPointToIndex(crop_origin)

    restored = sitk.Paste(
        destinationImage=canvas,
        sourceImage=prediction,
        sourceSize=prediction.GetSize(),
        sourceIndex=[0, 0, 0],
        destinationIndex=paste_index
    )

    out_dir = os.path.dirname(output_path)
    os.makedirs(out_dir, exist_ok=True)
    sitk.WriteImage(restored, output_path, useCompression=True)

def restore_all_masks(cropping_metadata_path, predicted_masks_dir, original_images_root, output_root):
    """Restore every predicted mask listed in the cropping metadata.

    Args:
        cropping_metadata_path: JSON file mapping patient ids to cropping metadata.
        predicted_masks_dir: Folder containing ``<patient_id>.nii.gz`` predictions.
        original_images_root: Folder with one sub-folder per patient; the
            ``<patient_id>_0001.nii.gz`` image is used as the reference grid.
        output_root: Results are written to
            ``<output_root>/<patient_id>/<patient_id>_restored.nii.gz``.

    Patients whose prediction or reference image is missing are reported and
    skipped.
    """
    with open(cropping_metadata_path) as metadata_file:
        crop_info_by_patient = json.load(metadata_file)

    progress = tqdm(crop_info_by_patient.items(), desc="Restoring masks")
    for patient_id, coords in progress:
        prediction_file = os.path.join(predicted_masks_dir, f"{patient_id}.nii.gz")
        reference_file = os.path.join(original_images_root, patient_id, f"{patient_id}_0001.nii.gz")
        restored_file = os.path.join(output_root, patient_id, f"{patient_id}_restored.nii.gz")

        if not os.path.exists(prediction_file):
            print(f"⚠️ Missing prediction for {patient_id}")
        elif not os.path.exists(reference_file):
            print(f"⚠️ Missing original image for {patient_id}")
        else:
            restore_mask_to_original_space_from_data(
                patient_id,
                prediction_file,
                reference_file,
                coords,
                restored_file,
            )
        
# Restore all predictions. The metadata file is the one written by the
# cropping step: <output_image_root>/cropping_metadata.json.
# NOTE(review): `predicted_masks_dir` still contains a "..." placeholder in the
# trainer folder name; replace it with the real nnU-Net results folder.
restore_all_masks(
    cropping_metadata_path="/data/nnUNet/Cropped/final_images/cropping_metadata.json",
    predicted_masks_dir="/results/nnUNet/nnUNet_results/Dataset108_multiC/nnUNetTrainer__.../fold_0/validation",
    original_images_root="/data/MAMA-MIA/images",
    output_root="/data/nnUNet/RestoredMasks"
)
