In [1]:
# !pip install SimpleITK

import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import re
import shutil
import random
import SimpleITK as sitk
import os
import numpy as np
from sklearn.model_selection import train_test_split

def save_image(image, output_path):
    """Write a SimpleITK image to ``output_path`` as 8-bit data.

    The intensities are first passed through ``sitk.RescaleIntensity`` (called
    with its default output range) and the result is cast to
    ``sitk.sitkUInt8``, i.e. the unsigned-char pixel type wanted for PNG files.
    The rescale is applied to whatever image is handed in, so a caller that
    passes individual slices gets every slice rescaled on its own.

    Args:
        image: SimpleITK image to export.
        output_path: Destination file path handed to ``sitk.WriteImage``.
    """
    rescaled = sitk.RescaleIntensity(image)
    # PNG output expects unsigned-char pixels
    as_uint8 = sitk.Cast(rescaled, sitk.sitkUInt8)
    sitk.WriteImage(as_uint8, output_path)

def process_and_split_data(case_dir, case_name, train_dir, test_dir, test_size=0.2, random_state=42):
    """Export one case as per-slice PNGs, split between a train and a test folder.

    Reads ``Consensus.nii`` (labels) and ``3DFLAIR.nii`` (images) from
    ``case_dir``, splits the slice indices along the third axis with
    ``train_test_split`` and writes every slice through ``save_image`` to
    ``<split_dir>/labels/<case_name>_consensus_<zzz>.png`` and
    ``<split_dir>/images/<case_name>_flair_<zzz>.png``.

    Args:
        case_dir: Directory containing the two NIfTI volumes of the case.
        case_name: Case identifier used as the prefix of the output file names.
        train_dir: Root folder of the training split.
        test_dir: Root folder of the test split.
        test_size: Forwarded to ``train_test_split`` (default 0.2).
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible (default 42, the value that used to be hard-coded).

    Raises:
        ValueError: If the two volumes do not have the same size, because
            slices are paired purely by index.
    """
    print(f'Processing case: {case_name}...')

    # Load the label (consensus) and image (FLAIR) volumes
    consensus_image = sitk.ReadImage(os.path.join(case_dir, 'Consensus.nii'))
    flair_image = sitk.ReadImage(os.path.join(case_dir, '3DFLAIR.nii'))

    # Image and label slices are matched by index only, so a size mismatch
    # would either fail midway or silently write misaligned pairs.
    if consensus_image.GetSize() != flair_image.GetSize():
        raise ValueError(
            f"Size mismatch for case {case_name}: Consensus.nii is "
            f"{consensus_image.GetSize()} but 3DFLAIR.nii is {flair_image.GetSize()}"
        )

    # Assuming that the third dimension Z is the slice direction
    z_slices = list(range(consensus_image.GetSize()[2]))

    # NOTE(review): the split is done per slice inside a single case, so slices
    # of the same case end up in both train and test. If the test set is meant
    # to measure generalisation to unseen cases, split by case instead.
    train_indices, test_indices = train_test_split(
        z_slices, test_size=test_size, random_state=random_state
    )

    # Same export logic for both splits; only the target folder and indices differ
    for split_dir, indices in ((train_dir, train_indices), (test_dir, test_indices)):
        labels_dir = os.path.join(split_dir, 'labels')
        images_dir = os.path.join(split_dir, 'images')
        # Do not rely on the caller having created the sub-folders beforehand
        os.makedirs(labels_dir, exist_ok=True)
        os.makedirs(images_dir, exist_ok=True)

        for z_slice in indices:
            save_image(consensus_image[:, :, z_slice], os.path.join(labels_dir, f"{case_name}_consensus_{z_slice:03}.png"))
            save_image(flair_image[:, :, z_slice], os.path.join(images_dir, f"{case_name}_flair_{z_slice:03}.png"))

    print(f'Finished processing case: {case_name}')

# Directories for train and test datasets
output_dir = './data'
train_dir = os.path.join(output_dir, 'train')
test_dir = os.path.join(output_dir, 'test')

# Create the images/labels folder of each split if it doesn't exist
# (os.makedirs also creates the missing parent folders)
for split_dir in (train_dir, test_dir):
    for subfolder in ('images', 'labels'):
        os.makedirs(os.path.join(split_dir, subfolder), exist_ok=True)

# Root directory of the cases
data_root_dir = '../0_Data_reg_inter_rigid'

# Process each case: every sub-directory of the root is treated as one case.
# os.listdir returns entries in arbitrary order, so sort them to keep the
# processing order and the printed log reproducible between runs.
case_dirs = sorted(d for d in os.listdir(data_root_dir) if os.path.isdir(os.path.join(data_root_dir, d)))
for case_name in case_dirs:
    case_dir = os.path.join(data_root_dir, case_name)
    process_and_split_data(case_dir, case_name, train_dir, test_dir)


Processing case: 01016SACH...
Finished processing case: 01016SACH
Processing case: 01038PAGU...
Finished processing case: 01038PAGU
Processing case: 01039VITE...
Finished processing case: 01039VITE
Processing case: 01040VANE...
Finished processing case: 01040VANE
Processing case: 01042GULE...
Finished processing case: 01042GULE
Processing case: 07001MOEL...
Finished processing case: 07001MOEL
Processing case: 07003SATH...
Finished processing case: 07003SATH
Processing case: 07010NABO...
Finished processing case: 07010NABO
Processing case: 07040DORE...
Finished processing case: 07040DORE
Processing case: 07043SEME...
Finished processing case: 07043SEME
Processing case: 08002CHJE...
Finished processing case: 08002CHJE
Processing case: 08027SYBR...
Finished processing case: 08027SYBR
Processing case: 08029IVDI...
Finished processing case: 08029IVDI
Processing case: 08031SEVE...
Finished processing case: 08031SEVE
Processing case: 08037ROGU...
Finished processing case: 08037ROGU


In [13]:
# def generate_consensus_labels():
#     image_path = 'data/test/images'
#     label_path = 'data/test/labels'
    
#     caster = sitk.CastImageFilter()
#     caster.SetOutputPixelType(sitk.sitkUInt8)

#     # Iterate over all files in the folder
#     for filename in os.listdir(image_path):
#         file_path = os.path.join(image_path, filename)

#         # Check if it's a file (not a subdirectory)
#         if os.path.isfile(file_path):
#             # Split the filename into its underscore-separated parts
#             parts = filename.split('_')

#             if len(parts) == 3:
#                 # Extract imageName and slice number
#                 imageName, fileType, slice_extension = parts
#                 sliceNumber, extension = slice_extension.split('.')
#                 sliceNumber = int(sliceNumber.lstrip('0'))
                
#                 print(f"Generating consensus for: {imageName}-{sliceNumber}...")
                
#                 ### LOAD SEGMENTATION FILES
#                 con_image = sitk.ReadImage(f'data/{imageName}/Consensus.nii.gz')
                
#                 sliced_image = con_image[sliceNumber, :, :]
#                 sliced_array = sitk.GetArrayFromImage(sliced_image)
                
#                 binary_mask = (sliced_array > 0).astype(np.uint8)
#                 output_image = sitk.GetImageFromArray(binary_mask)
                
#                 output_image.SetOrigin(sliced_image.GetOrigin())
#                 output_image.SetSpacing(sliced_image.GetSpacing())
#                 output_image.SetDirection(sliced_image.GetDirection())
                
#                 sliceNumber = str(sliceNumber).zfill(3)
#                 png_output_path = f'{label_path}/{imageName}_{sliceNumber}.png'
#                 sitk.WriteImage(output_image, png_output_path)
                
#                 file_path_to_delete = f'{label_path}/{imageName}_{fileType}_{sliceNumber}.{extension}'
#                 if os.path.exists(file_path_to_delete):
#                     os.remove(file_path_to_delete)

#                 # Create the new file name without {type}
#                 new_filename = f'{imageName}_{sliceNumber}.{extension}'  # Adjust the file extension if needed
#                 if not os.path.exists(os.path.join(image_path, new_filename)):
#                     os.rename(file_path, os.path.join(image_path, new_filename))
                    
#                 file_path_to_delete = f'{image_path}/{imageName}_{fileType}_{sliceNumber}.{extension}'
#                 if os.path.exists(file_path_to_delete):
#                     os.remove(file_path_to_delete)
# # generate consensus images for test set
# generate_consensus_labels()