In [1]:
from google.colab import drive
drive.mount('/content/drive')

# Example: Accessing a file in your Google Drive
ACDC_dataset_path = '/content/drive/My Drive/GP_Data_Folder/GP_Data-Sets/ACDC/database'
print(f"File path: {ACDC_dataset_path}")

# To get the current working directory:
import os
current_directory = os.getcwd()
print(f"\nCurrent working directory: {current_directory}")

# To list files in a directory:
import os
directory_to_list = '/content/drive/My Drive'  # Replace with your desired directory
try:
  files = os.listdir(ACDC_dataset_path)
  print(f"\nFiles and directories in {ACDC_dataset_path}:")
  for file in files:
    print(file)
except FileNotFoundError:
  print(f"Error: Directory not found at {ACDC_dataset_path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
File path: /content/drive/My Drive/GP_Data_Folder/GP_Data-Sets/ACDC/database

Current working directory: /content

Files and directories in /content/drive/My Drive/GP_Data_Folder/GP_Data-Sets/ACDC/database:
MANDATORY_CITATION.md
LICENSE_TERMS.md
STATE-OF-THE-ART-nnUNet-method.md
training
test_standardized
train_standardized
testing
processed_training
processed_testing


In [2]:
!pip install SimpleITK nibabel pydicom gdown



In [3]:
import os
import re
import sys
import logging
import importlib.util
import numpy as np
import nibabel as nib
import SimpleITK as sitk
import scipy.interpolate as spi
import matplotlib.pyplot as plt
%matplotlib inline
from pathlib import Path
from tqdm import tqdm
from typing import Tuple, Optional, List
import shutil


# Configuration
class Config:
    # BASE_PATH = Path(os.path.abspath(os.path.join(os.getcwd(), "../../Data/ACDC/database"))) # Locally
    BASE_PATH = Path(ACDC_dataset_path) # Google Drive
    TARGET_SPACING = (1.0, 1.0)  # (x, y)
    TARGET_SHAPE = (512, 512)     # (height, width)
    MAX_PATIENTS = None              # Set to None for all patients
    SEED = 42                     # For reproducible sampling

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

In [4]:
def process_4d_volume(nifti_path: Path, output_dir: Path) -> None:
    """Process a single 4D NIfTI file into standardized slices."""
    try:
        img = nib.load(nifti_path)
        data = img.get_fdata()  # Shape: (X, Y, Z, T)
        patient_id = nifti_path.stem.split('_')[0]

        for t in range(data.shape[-1]):  # Time frames
            for z in range(data.shape[2]):  # Slices
                process_slice(data[:, :, z, t], patient_id, t, z, output_dir)

    except Exception as e:
        logger.error(f"Failed {nifti_path.name}: {str(e)}")

def process_slice(slice_data: np.ndarray, pid: str, t: int, z: int, output_dir: Path) -> None:
    """Process and save a single slice with localization."""
    try:
        # Resample
        slice_sitk = sitk.GetImageFromArray(slice_data.T)  # Handle axis order
        resampled = resample_slice(slice_sitk, Config.TARGET_SPACING)

        # Normalize
        normalized = normalize_slice(sitk.GetArrayFromImage(resampled))

        # Pad
        padded = pad_to_shape(normalized, Config.TARGET_SHAPE)

        # Save
        save_path = output_dir / f"{pid}_t{t:02d}_z{z:02d}.npy"
        np.save(save_path, padded.astype(np.float32))

    except Exception as e:
        logger.error(f"Failed {pid} t{t} z{z}: {str(e)}")

def resample_slice(slice_sitk: sitk.Image, target_spacing: Tuple[float, float]) -> sitk.Image:
    """
    Resample slice using your custom monotonic interpolation method.
    Returns a SimpleITK image with correct metadata.
    """
    # Get original metadata
    original_spacing = slice_sitk.GetSpacing()
    original_direction = slice_sitk.GetDirection()
    original_origin = slice_sitk.GetOrigin()

    # Calculate resize factor (X, Y only)
    resize_factor = np.array([
        original_spacing[0] / target_spacing[0],
        original_spacing[1] / target_spacing[1]
    ])

    # Convert to numpy array (preserve axis order)
    image_np = sitk.GetArrayFromImage(slice_sitk)  # Shape: (Z, Y, X) -> but Z=1 for 2D

    # Remove singleton dimension if needed
    if image_np.shape[0] == 1:
        image_np = image_np[0]  # Shape: (Y, X)

    # Apply your custom interpolation
    resampled_np = monotonic_zoom_interpolate(
        image_np,
        resize_factor[::-1]  # Reverse for (Y, X) axes
    )

    # Create new SimpleITK image
    resampled_sitk = sitk.GetImageFromArray(resampled_np)

    # Set metadata correctly
    resampled_sitk.SetSpacing((target_spacing[0], target_spacing[1]))
    resampled_sitk.SetDirection(original_direction)
    resampled_sitk.SetOrigin(original_origin)

    return resampled_sitk

def normalize_slice(slice_np: np.ndarray) -> np.ndarray:
    """Robust normalization using percentile clipping."""
    p1, p99 = np.percentile(slice_np, [1, 99])
    return np.clip((slice_np - p1) / (p99 - p1), 0, 1)

def pad_to_shape(slice_np: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray:
    """Symmetrically pad slice to target shape."""
    pads = [(max(0, (ts - s) // 2), max(0, ts - s - (ts - s) // 2))
            for s, ts in zip(slice_np.shape, target_shape)]
    return np.pad(slice_np, pads, mode='constant')

def monotonic_zoom_interpolate(image_np, resize_factor):
    """
    Apply monotonic zoom interpolation to a given image.
    """
    result = image_np.copy()

    for axis, factor in enumerate(resize_factor[::-1]):
        # Create a new array for the interpolated values
        new_length = int(result.shape[axis] * factor)
        x_old = np.arange(result.shape[axis])
        x_new = np.linspace(0, result.shape[axis] - 1, new_length)

        # Perform monotonic interpolation
        pchip_interp = spi.PchipInterpolator(x_old, result.take(indices=x_old, axis=axis), axis=axis)
        result = pchip_interp(x_new)

    return result


In [5]:
def run_pipeline(dataset_type: str = "training", output_dir_name: Optional[str] = None) -> None:
    """
    Main pipeline execution function.

    Args:
        dataset_type: Type of dataset to process (e.g., "training", "testing").
        output_dir_name: Custom name for the output directory. If None, defaults to "processed_{dataset_type}".
    """
    input_dir = Config.BASE_PATH / dataset_type

    # Set output directory name
    if output_dir_name is None:
        output_dir_name = f"processed_{dataset_type}"
    output_dir = Config.BASE_PATH / output_dir_name
    output_dir.mkdir(exist_ok=True)

    nifti_files = list(input_dir.glob("**/*_4d.nii.gz"))
    if Config.MAX_PATIENTS:
        nifti_files = nifti_files[:Config.MAX_PATIENTS]

    logger.info(f"Processing {len(nifti_files)} patients...")
    for nifti_path in tqdm(nifti_files, desc="Patients"):
        process_4d_volume(nifti_path, output_dir)

In [6]:
def validate_processing(output_dir: Path, num_samples: int = 5, dataset_type: str = "training") -> None:
    """Validate processed data with multiple checks."""
    # Check file consistency
    npy_files = list(output_dir.glob("*.npy"))
    assert len(npy_files) > 0, "No processed files found!"

    # Check array properties
    sample = np.load(npy_files[0])
    assert sample.shape == Config.TARGET_SHAPE, f"Shape mismatch: {sample.shape}"
    assert sample.dtype == np.float32, f"Dtype mismatch: {sample.dtype}"

    # Value range check
    min_val, max_val = sample.min(), sample.max()
    assert 0 <= min_val <= max_val <= 1, f"Value range error: [{min_val}, {max_val}]"

    logger.info("Basic validation passed!")

    # New: Intensity distribution check
    sample_values = sample.flatten()
    hist, bins = np.histogram(sample_values, bins=100)
    assert np.percentile(sample_values, 99) > 0.1, "Suspicious intensity distribution"

    logger.info("Enhanced validation passed!")

    # Visualize samples
    visualize_samples(output_dir, num_samples, dataset_type)
    logger.info("Visualization check passed!")

def visualize_samples(output_dir: Path, num_samples: int = 5, dataset_type: str = "training") -> None:
    """Visualize random samples with before/after comparison and save as PNG."""
    # Create a folder to save the visualizations
    print(output_dir)
    visualization_dir = output_dir / "visualizations"
    if not visualization_dir.exists():
        visualization_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Created visualization directory: {visualization_dir}")
    else:
        logger.info(f"Visualization directory already exists: {visualization_dir}")

    np.random.seed(Config.SEED)
    sample_files = np.random.choice(list(output_dir.glob("*.npy")), num_samples)

    for sf in sample_files:
        # Parse filename to get patient ID, time frame, and slice index
        parts = sf.stem.split('_')
        patient_id = parts[0]
        time_frame = int(parts[1][1:])  # Extract number after 't'
        slice_index = int(parts[2][1:])  # Extract number after 'z'

        # Locate original 4D NIfTI file
        original_path = find_original_image(sf, dataset_type)
        if not original_path:
            logger.warning(f"Original image not found for {sf.name}")
            continue

        # Load original 4D NIfTI file
        orig_img = sitk.GetArrayFromImage(sitk.ReadImage(str(original_path)))  # Shape: (T, Z, Y, X)
        orig_img = orig_img.transpose()

        # Ensure the dimensions are correct
        if orig_img.ndim != 4:
            logger.error(f"Original image {original_path.name} is not 4D. Shape: {orig_img.shape}")
            continue

        # Extract the corresponding slice and time frame
        try:
            # Correctly extract the slice: (X, Y, Z, T) -> (X, Y) for the given Z and T
            orig_slice = orig_img[:, :, slice_index, time_frame].T  # Transpose for correct orientation
        except IndexError:
            logger.error(f"Invalid slice or time frame for {sf.name}: slice={slice_index}, frame={time_frame}")
            continue

        # Load processed numpy file
        processed = np.load(sf)  # Shape: (H, W)

        # Plot original and processed slices
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

        # Original slice
        ax1.imshow(orig_slice, cmap='gray')
        ax1.set_title(f"Original\n{original_path.name}\nSlice {slice_index}, Frame {time_frame}")

        # Processed slice
        ax2.imshow(processed, cmap='gray')
        ax2.set_title(f"Processed\n{sf.name}")

        # Save the figure as a PNG file
        output_file = visualization_dir / f"{sf.stem}_comparison.png"
        fig.savefig(output_file, bbox_inches="tight", dpi=300)
        plt.close(fig)  # Close the figure to free memory
        logger.info(f"Saved visualization: {output_file}")

def find_original_image(processed_path: Path, dataset_type: str = "training") -> Optional[Path]:
    """
    Locate original NIfTI file from processed numpy path.

    Args:
        processed_path: Path to the processed numpy file.
        dataset_type: Type of dataset to search in (e.g., "training", "testing").

    Returns:
        Path to the original NIfTI file, or None if not found.
    """
    parts = processed_path.stem.split('_')
    patient_id = parts[0]

    # Search in the correct dataset folder
    patient_folder = Config.BASE_PATH / dataset_type / patient_id
    return next(patient_folder.glob("*4d.nii.gz"), None)

In [None]:
# Process training data
run_pipeline("training", "processed_training")

# Process testing data (optional)
run_pipeline("testing", "processed_testing")

  return np.clip((slice_np - p1) / (p99 - p1), 0, 1)
Patients: 100%|██████████| 100/100 [20:26<00:00, 12.27s/it]
Patients: 100%|██████████| 50/50 [10:26<00:00, 12.52s/it]


In [7]:
# output_dir = Config.BASE_PATH / "processed_training"
# validate_processing(output_dir, num_samples=1000, dataset_type="training")
output_dir = Config.BASE_PATH / "processed_testing"
validate_processing(output_dir, num_samples=500, dataset_type="testing")

/content/drive/My Drive/GP_Data_Folder/GP_Data-Sets/ACDC/database/processed_testing
