In [1]:
import os
import random
import numpy as np
import json
import matplotlib.pyplot as plt
import SimpleITK as sitk
import nibabel as nib

from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Resized,
    SpatialPadd,
    ScaleIntensityRangePercentilesd,
    ToTensord,
)

In [2]:
# Move two directory levels up from where the kernel started (the output below
# shows this lands in the project's `src` directory). The starting directory is
# remembered once per kernel so that re-running this cell does not keep climbing
# the directory tree (os.chdir state persists across cells).
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.dirname(os.path.dirname(NOTEBOOK_DIR)))
print(os.getcwd())

/home/AR32500/AR32500/MyPapers/box-prompt-learning-VFM/src


## Functions for preprocessing

In [4]:
def get_val_patient_idx(train_patient_idx=None, num_val_patients=10, rng=None):
    """Randomly draw validation patients from the training patients.

    Args:
        train_patient_idx (list, optional): candidate (training) patient names.
            Defaults to an empty list (``None`` is used as the default to avoid
            a shared mutable default argument).
        num_val_patients (int, optional): number of patients to draw. Defaults to 10.
        rng (random.Random, optional): generator used for the draw, to make the
            split reproducible. Defaults to the global ``random`` module state.

    Returns:
        list: the drawn patient names, sorted.

    Raises:
        ValueError: if ``num_val_patients`` is negative or larger than the number
            of candidates (raised by ``sample``).
    """
    if train_patient_idx is None:
        train_patient_idx = []
    sampler = random if rng is None else rng
    print('Creating a validation dataset with {} patients'.format(num_val_patients))
    # list() also accepts non-sequence iterables such as dict keys.
    return sorted(sampler.sample(list(train_patient_idx), num_val_patients))


def get_train_test_patient_idx(data_dir, images_list, num_test_patients=10, rng=None):
    """Split patients into train and test sets, reusing an existing split if present.

    If ``<parent of data_dir>/preprocessed`` exists, the split is read from the keys
    of ``preprocessed/train/scan_info.json`` and ``preprocessed/test/scan_info.json``
    and ``num_test_patients`` / ``rng`` are ignored. Otherwise ``num_test_patients``
    patients are drawn at random for the test set and the remaining ones form the
    train set.

    Args:
        data_dir (str): data directory; its *parent* is searched for ``preprocessed``.
        images_list (list): paths to each image (ending with .nii.gz) or bare patient
            names; the patient name is the basename without a trailing ``.nii.gz``.
        num_test_patients (int, optional): number of test patients. Defaults to 10.
        rng (random.Random, optional): generator used for the random draw, to make
            the split reproducible. Defaults to the global ``random`` module state.

    Returns:
        train_patient_idx (list): training patient names (in ``images_list`` / JSON order)
        test_patient_idx (list): test patient names (sorted when drawn at random)

    Raises:
        ValueError: if ``num_test_patients`` is negative or exceeds the number of patients.
    """
    # NOTE(review): with data_dir = '<dataset>/raw/database_nifti' this resolves to
    # '<dataset>/raw/preprocessed' (as the notebook output shows) — confirm this is
    # the intended location of a previously preprocessed split.
    preprocessed_data_dir = os.path.join(os.path.dirname(data_dir), 'preprocessed')

    if os.path.exists(preprocessed_data_dir):
        # Reuse the patient separation of the existing preprocessed data.
        splits = []
        for split_name in ('train', 'test'):
            scan_info_path = os.path.join(preprocessed_data_dir, split_name, 'scan_info.json')
            with open(scan_info_path) as f:
                splits.append(list(json.load(f).keys()))
        train_patient_idx, test_patient_idx = splits
        return train_patient_idx, test_patient_idx

    print('{} does not exist'.format(preprocessed_data_dir))

    # Strip only a trailing extension (str.replace would also remove it mid-name).
    ext = '.nii.gz'
    patient_name_list = []
    for path in images_list:
        name = os.path.basename(path)
        if name.endswith(ext):
            name = name[:-len(ext)]
        patient_name_list.append(name)

    sampler = random if rng is None else rng
    test_patient_idx = sorted(sampler.sample(patient_name_list, num_test_patients))
    test_set = set(test_patient_idx)  # O(1) membership instead of scanning the list
    train_patient_idx = [name for name in patient_name_list if name not in test_set]

    return train_patient_idx, test_patient_idx
    
    
def get_train_val_test_list(data_dir, images_list, num_test_patients=10, num_val_patients=10):
    """Split patients into disjoint train / val / test lists.

    The train/test separation comes from ``get_train_test_patient_idx`` (which may
    reuse an existing split read from disk). Validation patients are then drawn at
    random from the training patients and removed from them. Unlike train/test, the
    validation draw is never read back from disk, so it changes between runs unless
    the ``random`` state is seeded.

    Args:
        data_dir (str): data directory (passed to ``get_train_test_patient_idx``).
        images_list (list): image paths or patient names.
        num_test_patients (int, optional): number of test patients. Defaults to 10.
        num_val_patients (int, optional): number of validation patients. Defaults to 10.

    Returns:
        tuple(list, list, list): train, val and test patient names.

    Raises:
        AssertionError: if the test patients overlap the train/val patients
            (e.g. inconsistent scan_info.json files reused from disk).
    """
    _train_patient_idx, test_patient_idx = get_train_test_patient_idx(data_dir, images_list, num_test_patients)
    val_patient_idx = get_val_patient_idx(_train_patient_idx, num_val_patients)

    val_set = set(val_patient_idx)
    assert val_set <= set(_train_patient_idx)
    train_patient_idx = [idx for idx in _train_patient_idx if idx not in val_set]

    # Guard against patient leakage between the training pool and the test set.
    assert set(test_patient_idx).isdisjoint(_train_patient_idx), 'train/val and test patients overlap'

    print('train patients:', len(train_patient_idx))
    print('val patients:', len(val_patient_idx))
    print('test patients:', len(test_patient_idx))

    return train_patient_idx, val_patient_idx, test_patient_idx

In [5]:
def create_sam_directories(base_dir, type='slice'):
    """Create the output directories for the preprocessed dataset.

    Args:
        base_dir (str): root output directory.
        type (str, optional): 'slice' creates '<split>_2d_images' and
            '<split>_2d_masks' for each split in train / val / test;
            'volume' creates 'imagesTr' and 'labelsTr'. Defaults to 'slice'.

    Returns:
        dict: maps each sub-directory name to its path; every directory exists on return.

    Raises:
        ValueError: if ``type`` is neither 'slice' nor 'volume'.
    """
    if type == 'slice':
        names = [f'{dataset}_{data_type}'
                 for dataset in ('train', 'val', 'test')
                 for data_type in ('2d_images', '2d_masks')]
    elif type == 'volume':
        names = ['imagesTr', 'labelsTr']
    else:
        # Fail loudly instead of returning an empty dict that breaks later lookups.
        raise ValueError("type must be 'slice' or 'volume', got {!r}".format(type))

    dir_paths = {}
    for name in names:
        dir_path = os.path.join(base_dir, name)
        os.makedirs(dir_path, exist_ok=True)  # safe to call again on existing dirs
        dir_paths[name] = dir_path
    return dir_paths


def ceil_to_multiple_of_5(n):
    """Round ``n`` up to the nearest multiple of 5.

    Works element-wise on NumPy arrays. The result is a float (from ``np.ceil``),
    so wrap it in ``int(...)`` when a count is needed.
    """
    step = 5.0
    return np.ceil(n / step) * step

# Preprocessing dataset

In [6]:
# Root of the datasets; can be overridden with the DATA_DIR environment variable
# instead of editing this absolute path.
data_dir = os.environ.get('DATA_DIR', '/data/users/melanie/data/')
dataset_name = 'CAMUS_public'
file_type = '_2CH_ED'  # selects '<patient><file_type>.nii.gz' (and its '_gt' mask) per patient
suffix = '_niigz'  # NOTE(review): not used by the visible cells

frac_test_patients = 0.2
frac_val_patients = 0.1

# Seed the global `random` state used by the patient-split helpers so that the
# train / val / test split is reproducible across kernel restarts.
random_seed = 42
random.seed(random_seed)

remove_background_slices = True  # skip patients whose mask lacks any class in class_list
class_list = [1, 2, 3]

# For _512 data
crop_pad_size = (512, 512)
new_size = (512, 512)  # same as crop_pad_size -> no resizing


In [9]:
# We create 6 folders in the subfolder 'preprocessed_sam' of the dataset directory:
# 2 folders ('<split>_2d_images' and '<split>_2d_masks') for each split in ['train', 'val', 'test'].

raw_data_dir = os.path.join(data_dir, dataset_name, 'raw', 'database_nifti')
base_dir_slice = os.path.join(data_dir, dataset_name, 'preprocessed_sam')

# One sub-folder per patient; ignore stray files so they are not treated as patients.
patient_name_list = sorted(
    name for name in os.listdir(raw_data_dir)
    if os.path.isdir(os.path.join(raw_data_dir, name))
)
print(len(patient_name_list))

num_test_patients = int(ceil_to_multiple_of_5(len(patient_name_list) * frac_test_patients))
num_val_patients = int(len(patient_name_list) * frac_val_patients)

# Persist the split next to the outputs and reuse it on later runs. Without this,
# every run draws a new random split, while the files of previous runs stay in the
# output folders (they are created with exist_ok=True). A patient could then appear
# in both the train and the test folders.
split_path = os.path.join(base_dir_slice, 'split.json')
if os.path.exists(split_path):
    with open(split_path) as f:
        saved_split = json.load(f)
    train_patient_idx = saved_split['train']
    val_patient_idx = saved_split['val']
    test_patient_idx = saved_split['test']
    print('Loaded existing split from {}'.format(split_path))
    print('train / val / test patients: {} / {} / {}'.format(
        len(train_patient_idx), len(val_patient_idx), len(test_patient_idx)))
else:
    train_patient_idx, val_patient_idx, test_patient_idx = get_train_val_test_list(
        raw_data_dir, patient_name_list, num_test_patients, num_val_patients)
    os.makedirs(base_dir_slice, exist_ok=True)
    with open(split_path, 'w') as f:
        json.dump({'train': train_patient_idx, 'val': val_patient_idx, 'test': test_patient_idx}, f, indent=2)

# The three splits must not share patients.
assert set(train_patient_idx).isdisjoint(val_patient_idx)
assert set(train_patient_idx).isdisjoint(test_patient_idx)
assert set(val_patient_idx).isdisjoint(test_patient_idx)

# Create directories to save preprocessed slices
dir_paths_slice = create_sam_directories(base_dir_slice, type="slice")

dir_paths_slice

500
/data/users/melanie/data/CAMUS_public/raw/preprocessed does not exist
Creating a validation dataset with 50 patients
train patients: 350
val patients: 50
test patients: 100


{'train_2d_images': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/train_2d_images',
 'train_2d_masks': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/train_2d_masks',
 'val_2d_images': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/val_2d_images',
 'val_2d_masks': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/val_2d_masks',
 'test_2d_images': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/test_2d_images',
 'test_2d_masks': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/test_2d_masks'}

In [None]:
# val_patient_idx = ['patient0014', 'patient0016', 'patient0040', 'patient0057', 'patient0059', 'patient0068', 'patient0069', 'patient0076', 'patient0085', 'patient0091', 'patient0099', 'patient0107', 'patient0135', 'patient0156', 'patient0165', 'patient0187', 'patient0194', 'patient0203', 'patient0213', 'patient0225', 'patient0236', 'patient0237', 'patient0239', 'patient0251', 'patient0258', 'patient0259', 'patient0281', 'patient0293', 'patient0302', 'patient0305', 'patient0321', 'patient0324', 'patient0326', 'patient0329', 'patient0340', 'patient0344', 'patient0347', 'patient0356', 'patient0370', 'patient0379', 'patient0385', 'patient0414', 'patient0422', 'patient0431', 'patient0448', 'patient0458', 'patient0460', 'patient0475', 'patient0485', 'patient0494']
# test_patient_idx = ['patient0011', 'patient0022', 'patient0029', 'patient0030', 'patient0034', 'patient0037', 'patient0041', 'patient0051', 'patient0055', 'patient0058', 'patient0060', 'patient0062', 'patient0066', 'patient0072', 'patient0073', 'patient0074', 'patient0089', 'patient0096', 'patient0097', 'patient0101', 'patient0113', 'patient0114', 'patient0115', 'patient0120', 'patient0124', 'patient0125', 'patient0134', 'patient0138', 'patient0146', 'patient0150', 'patient0158', 'patient0161', 'patient0162', 'patient0173', 'patient0183', 'patient0189', 'patient0190', 'patient0210', 'patient0212', 'patient0221', 'patient0227', 'patient0228', 'patient0229', 'patient0233', 'patient0244', 'patient0252', 'patient0255', 'patient0257', 'patient0261', 'patient0266', 'patient0276', 'patient0277', 'patient0292', 'patient0295', 'patient0296', 'patient0303', 'patient0318', 'patient0330', 'patient0334', 'patient0335', 'patient0336', 'patient0338', 'patient0341', 'patient0343', 'patient0349', 'patient0350', 'patient0352', 'patient0354', 'patient0358', 'patient0365', 'patient0369', 'patient0372', 'patient0373', 'patient0380', 'patient0386', 'patient0390', 'patient0393', 'patient0400', 'patient0404', 'patient0407', 'patient0413', 'patient0423', 'patient0428', 'patient0430', 'patient0438', 'patient0439', 'patient0442', 'patient0443', 'patient0444', 'patient0445', 'patient0456', 'patient0464', 'patient0473', 'patient0476', 'patient0477', 'patient0486', 'patient0490', 'patient0491', 'patient0495', 'patient0499']
# train_patient_idx = [f for f in patient_name_list if (f not in val_patient_idx) and (f not in test_patient_idx)]
# print(len(train_patient_idx), len(val_patient_idx), len(test_patient_idx))

350 50 100


In [None]:
# Added ScaleIntensityRangePercentilesd on Feb 14th 2024
# Pipeline applied to each (image, mask) pair; the mask only receives the
# loading / geometric steps, the intensity scaling is applied to the image only.
transforms = Compose([
    LoadImaged(keys=["img", "label"]),  # load .nii or .nii.gz files
    EnsureChannelFirstd(keys=['img', 'label']),
    # Center-crop to crop_pad_size, then pad anything smaller back up to it, so every
    # output has exactly crop_pad_size spatially.
    CenterSpatialCropd(keys=['img', 'label'], roi_size=crop_pad_size),
    SpatialPadd(keys=["img", "label"], spatial_size=crop_pad_size),
    # Bilinear for the image, nearest for the label map so class ids are not blended.
    # With the current config (new_size == crop_pad_size) the size does not change.
    Resized(keys=["img", "label"], spatial_size=new_size, mode=['bilinear', 'nearest']),
    # Map the 0.5th..99.5th intensity percentiles of the image to [0, 255] and clip
    # values outside that range.
    ScaleIntensityRangePercentilesd(keys=["img"],
                                    lower=0.5, upper=99.5,
                                    b_min=0, b_max=255, clip=True),
    ToTensord(keys=["img", "label"]),
])

In [None]:
# Run `transforms` on each patient's image/mask pair and save the result into the
# folders of the patient's split. Sets give O(1) membership tests inside the loop.
train_patients = set(train_patient_idx)
val_patients = set(val_patient_idx)

for patient_name in patient_name_list:
    patient_dir = os.path.join(raw_data_dir, patient_name)
    img_path = os.path.join(patient_dir, patient_name + file_type + '.nii.gz')
    mask_path = os.path.join(patient_dir, patient_name + file_type + '_gt.nii.gz')

    data_dict = transforms({'img': img_path, 'label': mask_path})
    # Take channel 0 (EnsureChannelFirstd put channels first) and cast to uint8 for
    # saving. Assumes 2D images, i.e. a single channel of shape (H, W) — TODO confirm.
    img = data_dict['img'][0, :, :].astype(np.uint8)
    mask = data_dict['label'][0, :, :].astype(np.uint8)

    # `.shape` comes from the NIfTI header, so the voxel data is not loaded just to print it.
    print(patient_name, nib.load(img_path).shape)

    # Optionally skip patients whose mask does not contain every class in class_list.
    present_labels = set(np.unique(mask).tolist())
    if remove_background_slices and not set(class_list) <= present_labels:
        print('not all label classes: {}'.format(patient_name))
    else:
        # Select appropriate directories; anything not in train/val goes to test.
        if patient_name in train_patients:
            subset = 'train'
        elif patient_name in val_patients:
            subset = 'val'
        else:
            subset = 'test'
        img_dir = dir_paths_slice[f'{subset}_2d_images']
        mask_dir = dir_paths_slice[f'{subset}_2d_masks']

        # Image and mask are saved under the image's filename, in separate folders.
        filename = os.path.basename(img_path)
        img_slice_path = os.path.join(img_dir, filename)
        mask_slice_path = os.path.join(mask_dir, filename)

        # NOTE(review): GetImageFromArray reads the array as (rows, cols); if LoadImaged
        # returns (x, y) order the saved files are transposed (identically for image and
        # mask), and spacing/origin from the source file are not carried over — confirm.
        sitk.WriteImage(sitk.GetImageFromArray(img), img_slice_path)
        sitk.WriteImage(sitk.GetImageFromArray(mask), mask_slice_path)

        # Quick visual check. Close each figure so that hundreds of patients do not
        # keep hundreds of open figures alive in memory.
        fig, (ax_img, ax) = plt.subplots(1, 2, figsize=(10, 10))
        ax_img.imshow(img)
        ax_img.set_title('{} image'.format(patient_name))
        ax.imshow(mask)
        ax.set_title('{} mask'.format(patient_name))
        plt.show()
        plt.close(fig)