In [5]:
# preprocess_masks.ipynb

import multiprocessing
from pathlib import Path
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from tqdm import tqdm

# --- Configuration ---
# Directory holding the raw per-patient NIfTI masks (``{patient_id}_EXP_mask.nii.gz``).
RAW_DATA_ROOT = Path("/mnt/hot/public/COPDGene-1")
# Root of the processed dataset; resized masks are written under ``masks/exhale``.
PROCESSED_DATA_DIR = Path("/mnt/hot/public/Akul/exhale_pred_data")
# Output grid size; every mask is resampled so its shape equals this tuple.
TARGET_SHAPE = (128, 128, 128)

def process_mask(patient_id: str):
    """
    Load, resize, and save the exhale (EXP) segmentation mask for one patient.

    The mask is read from ``RAW_DATA_ROOT/{patient_id}_EXP_mask.nii.gz``,
    resampled to ``TARGET_SHAPE`` with nearest-neighbor interpolation, and
    written to ``PROCESSED_DATA_DIR/masks/exhale/{patient_id}_EXP_mask.npy``
    as a ``uint8`` array.

    Args:
        patient_id: Prefix of the mask file name (the text preceding
            ``_EXP_mask.nii.gz``).

    Returns:
        ``None`` on success; otherwise a human-readable string explaining why
        the patient was skipped or failed. Exceptions are converted to
        messages rather than raised so a single bad file cannot abort a
        worker pool.
    """
    try:
        # We only need the exhale mask for evaluating the exhale reconstruction
        mask_path = RAW_DATA_ROOT / f"{patient_id}_EXP_mask.nii.gz"

        if not mask_path.exists():
            return f"Skipped {patient_id}: Missing exhale mask file."

        mask_nii = nib.load(mask_path)
        # Read through the array proxy instead of get_fdata(): get_fdata()
        # materialises (and caches) a float64 copy of the whole volume, which
        # is 8x the memory of the uint8 labels that are actually kept.
        mask_data = np.asanyarray(mask_nii.dataobj).astype(np.uint8, copy=False)

        # zip() below would silently truncate on a rank mismatch, so reject
        # anything that is not the same dimensionality as TARGET_SHAPE up front.
        if mask_data.ndim != len(TARGET_SHAPE):
            return (
                f"Error processing {patient_id} mask: expected a "
                f"{len(TARGET_SHAPE)}D volume, got shape {mask_data.shape}"
            )

        # --- IMPORTANT ---
        # Resize the mask using nearest-neighbor interpolation (order=0) to
        # preserve the binary nature of the mask without creating intermediate values.
        zoom_factors = [t / s for t, s in zip(TARGET_SHAPE, mask_data.shape)]
        mask_resized = zoom(mask_data, zoom_factors, order=0, prefilter=False)

        # Save the processed mask. The directory is created here as well so the
        # function also works when it is called without going through main().
        save_dir = PROCESSED_DATA_DIR / "masks" / "exhale"
        save_dir.mkdir(parents=True, exist_ok=True)
        np.save(save_dir / f"{patient_id}_EXP_mask.npy", mask_resized)

        return None
    except Exception as e:
        return f"Error processing {patient_id} mask: {e}"

def main():
    """
    Finds all patient IDs with exhale masks and processes them in parallel.

    Patient IDs are derived from every ``*_EXP_mask.nii.gz`` file directly
    under ``RAW_DATA_ROOT``. Each ID is handed to ``process_mask`` in a
    ``multiprocessing.Pool`` with one worker per CPU. Any non-``None`` result
    (skip or error message) is printed and counted as a failure in the final
    summary.
    """
    # Create the output directory
    (PROCESSED_DATA_DIR / "masks" / "exhale").mkdir(parents=True, exist_ok=True)

    mask_suffix = "_EXP_mask.nii.gz"
    all_mask_files = RAW_DATA_ROOT.glob(f"*{mask_suffix}")
    # Strip the known suffix instead of splitting on "_": an ID that itself
    # contains an underscore would otherwise be truncated and then reported
    # as a missing file by process_mask.
    patient_ids = sorted(f.name[: -len(mask_suffix)] for f in all_mask_files)

    if not patient_ids:
        print(f"Error: No exhale masks found in {RAW_DATA_ROOT}. Please check the path and file naming.")
        return

    print(f"Found {len(patient_ids)} patient exhale masks. Starting preprocessing...")

    num_processes = multiprocessing.cpu_count()
    print(f"Starting parallel processing with {num_processes} workers.")

    with multiprocessing.Pool(processes=num_processes) as pool:
        # imap_unordered yields results as workers finish, which keeps the
        # progress bar moving; ordering is irrelevant for the summary below.
        results = list(tqdm(pool.imap_unordered(process_mask, patient_ids), total=len(patient_ids)))

    # process_mask returns None on success and a message otherwise.
    messages = [res for res in results if res is not None]
    for message in messages:
        print(message)
    error_count = len(messages)

    # A real newline here: the previous "\\n" printed a literal backslash-n.
    print("\n--- Mask Preprocessing Complete ---")
    print(f"Successfully processed: {len(patient_ids) - error_count} masks.")
    print(f"Failed to process: {error_count} masks.")

# Entry point: run the full preprocessing pipeline when this code is executed
# directly rather than imported.
if __name__ == "__main__":
    main()

Found 8701 patient exhale masks. Starting preprocessing...
Starting parallel processing with 40 workers.


100%|██████████| 8701/8701 [08:55<00:00, 16.26it/s]


\n--- Mask Preprocessing Complete ---
Successfully processed: 8701 masks.
Failed to process: 0 masks.


In [4]:
# preprocess_masks.ipynb

import multiprocessing
from pathlib import Path
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
from tqdm import tqdm

# --- Configuration ---
# Directory holding the raw per-patient NIfTI masks (``{patient_id}_<phase>_mask.nii.gz``).
RAW_DATA_ROOT = Path("/mnt/hot/public/COPDGene-1")
# Root of the processed dataset; resized masks are written under ``masks/inhale``.
PROCESSED_DATA_DIR = Path("/mnt/hot/public/Akul/exhale_pred_data")
# Output grid size; every mask is resampled so its shape equals this tuple.
TARGET_SHAPE = (128, 128, 128)

def process_mask(patient_id: str):
    """
    Load, resize, and save the inhale (INSP) segmentation mask for one patient.

    The mask is read from ``RAW_DATA_ROOT/{patient_id}_INSP_mask.nii.gz``,
    resampled to ``TARGET_SHAPE`` with nearest-neighbor interpolation, and
    written to ``PROCESSED_DATA_DIR/masks/inhale/{patient_id}_INSP_mask.npy``
    as a ``uint8`` array.

    NOTE(review): this function previously wrote its output as
    ``{patient_id}_EXP_mask.npy`` inside the ``inhale`` directory, which
    mislabelled inhale data as exhale. Any reader of that directory must use
    the ``_INSP_mask.npy`` name -- confirm downstream loaders.

    Args:
        patient_id: Prefix of the mask file name (the text preceding
            ``_INSP_mask.nii.gz``).

    Returns:
        ``None`` on success; otherwise a human-readable string explaining why
        the patient was skipped or failed. Exceptions are converted to
        messages rather than raised so a single bad file cannot abort a
        worker pool.
    """
    try:
        # This variant handles the inhale (INSP) mask.
        mask_path = RAW_DATA_ROOT / f"{patient_id}_INSP_mask.nii.gz"

        if not mask_path.exists():
            return f"Skipped {patient_id}: Missing inhale mask file."

        mask_nii = nib.load(mask_path)
        # Read through the array proxy instead of get_fdata(): get_fdata()
        # materialises (and caches) a float64 copy of the whole volume, which
        # is 8x the memory of the uint8 labels that are actually kept.
        mask_data = np.asanyarray(mask_nii.dataobj).astype(np.uint8, copy=False)

        # zip() below would silently truncate on a rank mismatch, so reject
        # anything that is not the same dimensionality as TARGET_SHAPE up front.
        if mask_data.ndim != len(TARGET_SHAPE):
            return (
                f"Error processing {patient_id} mask: expected a "
                f"{len(TARGET_SHAPE)}D volume, got shape {mask_data.shape}"
            )

        # --- IMPORTANT ---
        # Resize the mask using nearest-neighbor interpolation (order=0) to
        # preserve the binary nature of the mask without creating intermediate values.
        zoom_factors = [t / s for t, s in zip(TARGET_SHAPE, mask_data.shape)]
        mask_resized = zoom(mask_data, zoom_factors, order=0, prefilter=False)

        # Save the processed mask under a name that matches its phase. The
        # directory is created here as well so the function also works when it
        # is called without going through main().
        save_dir = PROCESSED_DATA_DIR / "masks" / "inhale"
        save_dir.mkdir(parents=True, exist_ok=True)
        np.save(save_dir / f"{patient_id}_INSP_mask.npy", mask_resized)

        return None
    except Exception as e:
        return f"Error processing {patient_id} mask: {e}"

def main():
    """
    Finds all patient IDs with inhale masks and processes them in parallel.

    Patient IDs are derived from every ``*_INSP_mask.nii.gz`` file directly
    under ``RAW_DATA_ROOT`` -- the same files ``process_mask`` in this cell
    reads -- so discovery and processing agree. Each ID is handed to
    ``process_mask`` in a ``multiprocessing.Pool`` with one worker per CPU.
    Any non-``None`` result (skip or error message) is printed and counted as
    a failure in the final summary.
    """
    # Create the output directory
    (PROCESSED_DATA_DIR / "masks" / "inhale").mkdir(parents=True, exist_ok=True)

    # Discover by the INSP suffix: globbing the EXP files here would miss
    # patients that only have an inhale mask and would report patients that
    # only have an exhale mask as failures.
    mask_suffix = "_INSP_mask.nii.gz"
    all_mask_files = RAW_DATA_ROOT.glob(f"*{mask_suffix}")
    # Strip the known suffix instead of splitting on "_": an ID that itself
    # contains an underscore would otherwise be truncated and then reported
    # as a missing file by process_mask.
    patient_ids = sorted(f.name[: -len(mask_suffix)] for f in all_mask_files)

    if not patient_ids:
        print(f"Error: No inhale masks found in {RAW_DATA_ROOT}. Please check the path and file naming.")
        return

    print(f"Found {len(patient_ids)} patient inhale masks. Starting preprocessing...")

    num_processes = multiprocessing.cpu_count()
    print(f"Starting parallel processing with {num_processes} workers.")

    with multiprocessing.Pool(processes=num_processes) as pool:
        # imap_unordered yields results as workers finish, which keeps the
        # progress bar moving; ordering is irrelevant for the summary below.
        results = list(tqdm(pool.imap_unordered(process_mask, patient_ids), total=len(patient_ids)))

    # process_mask returns None on success and a message otherwise.
    messages = [res for res in results if res is not None]
    for message in messages:
        print(message)
    error_count = len(messages)

    # A real newline here: the previous "\\n" printed a literal backslash-n.
    print("\n--- Mask Preprocessing Complete ---")
    print(f"Successfully processed: {len(patient_ids) - error_count} masks.")
    print(f"Failed to process: {error_count} masks.")

# Entry point: run the full preprocessing pipeline when this code is executed
# directly rather than imported.
if __name__ == "__main__":
    main()

Found 8701 patient exhale masks. Starting preprocessing...
Starting parallel processing with 40 workers.


100%|██████████| 8701/8701 [09:21<00:00, 15.49it/s]


\n--- Mask Preprocessing Complete ---
Successfully processed: 8701 masks.
Failed to process: 0 masks.
