In [1]:
import os
import numpy as np
import nibabel as nib
import cv2
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import IPython.display as ipyd
from PIL import Image
from tqdm import tqdm

ModuleNotFoundError: No module named 'nibabel'

In [2]:
n_slices = 155
label_color_map = {0: [0,0,0], 63: [0,0,255], 127: [0,255,0], 255: [255,0,0]}
modes = ['flair', 't1', 't1ce', 't2']

In [3]:
def load_nifti(img_path):
    """
    Loads data from nifti file into 3D numpy array
    """
    img = nib.load(img_path)
    img = img.get_fdata()
    img = (img/img.max())*255
    img = img.astype(np.uint8)
    return img


def generate_img_slices(img_path, target_path, resize=None):
    """
    Generates 2D image slices from 3D nifti image files and saves them in the target folder
    """
    img_name = os.path.split(img_path)[-1].split('.')[0]
    img = load_nifti(img_path)
    
    # Generate slices
    for i in range(img.shape[2]):
        img_slice_name = img_name + '_' + str(i) + '.jpg'
        img_slice_path = os.path.join(target_path, img_slice_name)
        img_slice = img[:,:,i]
        
        if resize:
            img_slice = cv2.resize(img_slice, dsize=(resize,resize), interpolation=cv2.INTER_CUBIC)
       
        # Make output directory if not exists
        try: os.makedirs(target_path)
        except FileExistsError: pass
        cv2.imwrite(img_slice_path, img_slice)
        
        
def generate_seg_slices(seg_path, target_path):
    """
    Generates 2D segmentation mask slices from 3D nifti segmentation mask files and saves them in the target folder
    """
    seg_name = os.path.split(seg_path)[-1].split('.')[0]
    seg = load_nifti(seg_path)
    
    # Generate slices
    for i in range(seg.shape[2]):
        seg_slice_name = seg_name + '_' + str(i) + '.npy'
        seg_slice_path = os.path.join(target_path, seg_slice_name)
        seg_slice = seg[:,:,i]
        
        # Make output directory if not exists
        try: os.makedirs(target_path)
        except FileExistsError: pass
        np.save(seg_slice_path, seg_slice)
    
    
def apply_seg_mask(img_arr, seg_arr):
    """
    Receives numpy arrays of a 2D RGB image and a 2D RGB segmentation mask slice and applies the latter on the former, producing a numpy array of a segmented image.
    """
    seg_img_arr = img_arr.copy()
    mask = np.any(seg_arr!=[0,0,0],axis=2)  #  Select the pixels of the segmentation mask where the mask is not empty (axis 2 is for the color channel)
    seg_img_arr[mask] = seg_arr[mask]  # Apply the non-empty pixels of the segmentation mask to the image
    return seg_img_arr

    
def plot_gif(data_path, id, save_path):
    """
    Plots gif for each modality from a data sample id
    """
    frames = {mode: [] for mode in modes}
    
    for n in range(n_slices): # For each slice
        
        # Load segmentation mask slice
        seg_name = f'BraTS20_Training_{"0" * max(0, 3 - len(str(id))) + str(id)}_seg'
        seg_slice_path = os.path.join(data_path, seg_name + f'_{n}.npy')
        seg_slice =  np.load(seg_slice_path)
        
        # Convert segmentation mask slice to RGB numpy array using a label-color mapping
        seg_slice_rgb = np.zeros([dim for dim in list(seg_slice.shape)] + [3])  # Add extra dimension (3) for RGB
        for label in label_color_map.keys():
            seg_slice_rgb[seg_slice==label] = label_color_map[label]
            
        # TODO: Might need to resize the segmentation mask  array to match the image array in case we resized the images

        # Plot gif for each modality of slice n
        for mode in modes:
            
            # Load image slice
            img_name = f'BraTS20_Training_{"0" * max(0, 3 - len(str(id))) + str(id)}_{mode}'  
            img_slice_path = os.path.join(data_path, img_name + f'_{n}.jpg')
            img_slice = Image.open(img_slice_path)
             
             # Convert image slice to RGB numpy array
            img_slice_rgb = np.expand_dims(np.asarray(img_slice),axis=2).repeat(3,axis=2)
            
            seg_img_slice = Image.fromarray(apply_seg_mask(img_slice_rgb, seg_slice_rgb))  # Generate a segmented image slice
            frames[mode].append(seg_img_slice)  # Add it as a frame to the gif
        
    for mode in modes:  
        img_name = f'BraTS20_Training_{"0" * max(0, 3 - len(str(id))) + str(id)}_{mode}' 
        gif_path = os.path.join(save_path, img_name + '.gif')
        frame_one = frames[mode][0]
        frame_one.save(gif_path, format="GIF", append_images=frames[mode],
                   save_all=True, duration=100, loop=0)
        print(f'Saved at: {gif_path}')
        
        
def generate_2D_dataset(source_path, target_path, resize=None):
    """
    Generates 2D dataset in target path from the 3D dataset in source path
    """
    for sample_name in tqdm(next(os.walk(source_path))[1]):  # For every folder (dataset sample) in the source path
        sample_path = os.path.join(source_path, sample_name)
        for file_name in next(os.walk(sample_path))[2]:  # For every file in the sample path
            file_path = os.path.join(sample_path,file_name)
            if 'seg' not in file_name:  # If the file is a modality and not a segmentation mask
                generate_img_slices(file_path, target_path, resize)
            else:
                generate_seg_slices(file_path, target_path)
            

In [16]:
generate_2D_dataset('/mnt/5C5C25FB5C25D116/data/BraTS2020/train', '/mnt/5C5C25FB5C25D116/data/BraTS2020/train_2D')

100%|██████████| 369/369 [04:17<00:00,  1.43it/s]


In [17]:
generate_2D_dataset('/mnt/5C5C25FB5C25D116/data/BraTS2020/val', '/mnt/5C5C25FB5C25D116/data/BraTS2020/val_2D')

100%|██████████| 125/125 [01:26<00:00,  1.44it/s]


In [55]:
plot_gif('/mnt/5C5C25FB5C25D116/data/BraTS2020/train_2D', 26, '../results')

Saved at: ../results/BraTS20_Training_026_flair.gif
Saved at: ../results/BraTS20_Training_026_t1.gif
Saved at: ../results/BraTS20_Training_026_t1ce.gif
Saved at: ../results/BraTS20_Training_026_t2.gif


In [46]:
# from scipy.linalg import circulant
# x_vec = np.concatenate([np.linspace(start=0, stop=1, num=50),np.linspace(start=1, stop=0, num=50)])
# x = np.expand_dims(circulant(x_vec),axis=2).repeat(3,axis=2)
# 
# k = 25
# y = np.ones((x.shape[0], x.shape[1]))
# for i in range(-k,k):
#     y += np.eye(x.shape[0], x.shape[1], i)
# y_rgb = np.zeros([dim for dim in list(y.shape)] + [3])
# y_rgb[y==1] = [1,0,0]
# new_x = x.copy()
# new_x[np.any(y_rgb!=[0,0,0],axis=2)] = y_rgb[np.any(y_rgb!=[0,0,0],axis=2)]
# plt.figure()
# plt.imshow(x)
# plt.figure()
# plt.imshow(y_rgb)
# plt.figure()
# plt.imshow(new_x)
# Image.fromarray((new_x * 255).astype(np.uint8)).show()

In [3]:
import numpy as np
from glob import glob
import SimpleITK as sitk
from tqdm import tqdm
import os

def load_sitk_with_resample(img_path):
    outsize = [0, 0, 0]
    outspacing = [1, 1, 1]

    vol = sitk.ReadImage(img_path)
    inputsize = vol.GetSize()
    return inputsize

data_path = '/projects/0/prjs0905/data/LUNA16'
xyz = []

for index_subset in range(0,10):
    luna_subset_path = os.path.join(data_path, "subset" + str(index_subset))
    file_list = glob(os.path.join(luna_subset_path, "*.mhd"))  # Only selects mhd files (excludes the segmentations)
    subset = 'subset' + str(index_subset)
    
    for img_file in tqdm(file_list, desc=f'Images in fold {index_subset} parsed'):
        img_name = os.path.split(img_file)[-1]
        size = load_sitk_with_resample(img_file)
        xyz.append(size[2])
        

Images in fold 0 parsed: 100%|██████████| 89/89 [00:18<00:00,  4.72it/s]
Images in fold 1 parsed: 100%|██████████| 89/89 [00:16<00:00,  5.48it/s]
Images in fold 2 parsed: 100%|██████████| 89/89 [00:20<00:00,  4.31it/s]
Images in fold 3 parsed: 100%|██████████| 89/89 [00:22<00:00,  4.01it/s]
Images in fold 4 parsed: 100%|██████████| 89/89 [00:17<00:00,  5.07it/s]
Images in fold 5 parsed: 100%|██████████| 89/89 [00:17<00:00,  5.08it/s]
Images in fold 6 parsed: 100%|██████████| 89/89 [00:16<00:00,  5.42it/s]
Images in fold 7 parsed: 100%|██████████| 89/89 [00:17<00:00,  5.02it/s]
Images in fold 8 parsed: 100%|██████████| 88/88 [00:15<00:00,  5.50it/s]
Images in fold 9 parsed: 100%|██████████| 88/88 [00:16<00:00,  5.20it/s]


In [4]:
print(np.max(xyz))
print(np.min(xyz))

764
95
