In [1]:
import os.path 
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from PIL import Image
from numpy import ndarray

# Folder where the created images will be saved
out_path = r'/local/data1/elech646/Tumor_grade_classification/single_channel_slices'
# Root folder of the original HGG volumes (one sub-folder per patient)
dataset_path = r'/local/data1/elech646/Tumor_grade_classification/original_dataset/HGG'

# One output tree per anatomical plane, each with an "HGG" sub-folder.
sag_path = os.path.join(out_path, "sagittal_grade_classification", "HGG")
fro_path = os.path.join(out_path, "frontal_grade_classification", "HGG")
tra_path = os.path.join(out_path, "trans_grade_classification", "HGG")

for plane_path in (sag_path, fro_path, tra_path):
    # makedirs also creates missing parents (including out_path itself), and
    # exist_ok makes the cell safe to re-run without an exists()/mkdir() race.
    os.makedirs(plane_path, exist_ok=True)

In [4]:
# Gather every segmentation volume (.nii / .nii.gz) below the annotations
# folder. The patient id is the first three underscore-separated parts of
# the file name; seg_path[i] and patient_name[i] describe the same file.
seg_path = []
patient_name = []

for seg_root, _subdirs, seg_files in os.walk(
        "/local/data1/elech646/Tumor_grade_classification/original_dataset/HGG_tumor_annotations"):
    nifti_files = [fname for fname in seg_files if fname.endswith((".nii.gz", ".nii"))]
    seg_path.extend(seg_root + os.path.sep + fname for fname in nifti_files)
    patient_name.extend('_'.join(fname.split('_')[:3]) for fname in nifti_files)

# Loop through the subjects and export one RGB PNG per tumour slice.

# Modalities stacked as the image channels, in order. Fewer than three are
# padded by repeating the last one, so ['t1', 'flair'] -> t1, flair, flair.
modalities = ['t1', 'flair']

# Planes exported for every patient. This replaces the previous "uncomment the
# function you want" mechanism; list any subset of
# 'sagittal', 'frontal', 'transversal'.
planes = ['sagittal']


def _pad_to_three(items):
    '''Return a NEW list with at least three elements, repeating the last one.

    Used for both the image channels and the modality part of the file names.
    The input is never modified (the previous code appended to ``modalities``
    itself through an alias).
    '''
    items = list(items)
    return items + [items[-1]] * (3 - len(items))


def _normalize_to_uint8(volume):
    '''Linearly rescale ``volume`` to [0, 255] and cast it to uint8.

    A constant volume (max == min) is mapped to all zeros instead of
    dividing by zero.
    '''
    min_v = volume.min()
    max_v = volume.max()
    if max_v == min_v:
        return np.zeros(volume.shape, dtype=np.uint8)
    return (255 * (volume - min_v) / (max_v - min_v)).astype(np.uint8)


# Per plane: (re-orientation applied to the segmentation and to every modality
# volume before slicing -- "the magic for Mango" --, axis of the re-oriented
# volume that is sliced, tag used in the output file names).
_PLANE_SPECS = {
    # yz plane, then flip the image left/right
    'sagittal': (lambda v: np.flip(np.rot90(v, axes=(1, 2)), 0), 0, 'sag'),
    # xz plane
    'frontal': (lambda v: np.rot90(v, axes=(0, 2)), 1, 'fro'),
    # xy plane
    'transversal': (lambda v: np.rot90(v, 3, axes=(0, 1)), 2, 'trans'),
}


def create_plane_slices(list_modality, segmentation, plane, save_path, p_name, modalities):
    '''Save one RGB PNG per tumour slice of ``plane``, one modality per channel.

    Arguments:
    list_modality : list
        Each element coming from something like mod_img.get_fdata(); assumed
        to have the same shape as ``segmentation``. Not modified.
    segmentation : np array
        volume whose non-zero voxels mark the tumour
    plane : str
        'sagittal', 'frontal' or 'transversal'
    save_path : str
        folder that contains one sub-folder per patient
    p_name : str
        patient id, used as sub-folder name and file-name prefix
    modalities : list of str
        modality names, used in the file names

    Every slice from the first to the last one containing tumour (along the
    plane's axis, after re-orientation) is written to
    ``<save_path>/<p_name>/<p_name>_<tag>_<mods>_<index>_<perc>.png``, where
    ``perc`` is the integer position (0-100) within that range.

    Returns the number of images written (0 for an empty segmentation).
    '''
    orient, axis, tag = _PLANE_SPECS[plane]
    seg = orient(segmentation)
    # Build new re-oriented volumes instead of overwriting list_modality in
    # place, so several planes can be exported from the same input list.
    volumes = [orient(v) for v in list_modality]

    tumor_idx = np.nonzero(seg)[axis]
    if tumor_idx.size == 0:
        return 0  # no tumour voxels: nothing to export
    first, last = int(tumor_idx.min()), int(tumor_idx.max())
    span = last - first

    aus_name = '_'.join(_pad_to_three(modalities))
    for idx in range(first, last + 1):
        # Percentage along the selected slices; a single-slice tumour has no
        # range to divide by, so it is reported as 0.
        perc = int(((idx - first) / span) * 100) if span else 0

        # [W, H, CH] array, one modality per channel
        channels = _pad_to_three([np.take(v, idx, axis=axis) for v in volumes])
        slices = np.stack(channels, axis=-1)

        title = os.path.join(save_path, p_name, p_name + '_' + tag + '_' + aus_name +
                             '_' + str(idx) + '_' + str(perc) + '.png')
        Image.fromarray(slices).convert('RGB').save(title)
    return span + 1


# Backward-compatible wrappers with the old signatures; like the originals,
# they read the module-level ``p_name`` and ``modalities``.
def create_sagittal(list_modality, segmentation, save_path=None):
    '''Export sagittal tumour slices for the current ``p_name``.'''
    return create_plane_slices(list_modality, segmentation, 'sagittal', save_path, p_name, modalities)


def create_frontal(list_modality, segmentation, save_path=None):
    '''Export frontal tumour slices for the current ``p_name``.'''
    return create_plane_slices(list_modality, segmentation, 'frontal', save_path, p_name, modalities)


def create_transversal(list_modality, segmentation, save_path=None):
    '''Export transversal tumour slices for the current ``p_name``.'''
    return create_plane_slices(list_modality, segmentation, 'transversal', save_path, p_name, modalities)


plane_dirs = {'sagittal': sag_path, 'frontal': fro_path, 'transversal': tra_path}
unknown_planes = set(planes) - set(plane_dirs)
if unknown_planes:
    # Fail before any (slow) volume is loaded.
    raise ValueError(f'Unknown planes: {sorted(unknown_planes)}')

for p_name, s_path in zip(patient_name, seg_path):
    # Load the segmentation file
    seg_img = nib.load(s_path)
    seg_img_data = seg_img.get_fdata()

    # Create the per-patient sub-folders (for every plane, as before)
    for plane_dir in plane_dirs.values():
        os.makedirs(os.path.join(plane_dir, p_name), exist_ok=True)

    # Load and normalise each modality of choice
    aus_modalities = []
    for m in modalities:
        mod_img = nib.load(os.path.join(dataset_path, p_name, '_'.join([p_name, m]) + '.nii'))
        aus_modalities.append(_normalize_to_uint8(mod_img.get_fdata()))

    for plane in planes:
        create_plane_slices(aus_modalities, seg_img_data, plane, plane_dirs[plane], p_name, modalities)