In [1]:
from matplotlib import pyplot as plt
import cv2
import time
import nibabel as nib
import os
import numpy as np
from shutil import rmtree
from threading import Thread

In [3]:
# Root directory holding one sub-folder per case.
image_folder = './datasets/MICCAI_BraTS2020_TrainingData/'
# os.listdir() returns entries in arbitrary order, so slicing off "the last
# two" entries (presumably meant to skip two non-case files - TODO confirm)
# could silently drop real case folders. Keep every sub-directory instead and
# sort for a stable processing order. Each path keeps its trailing '/'
# because downstream code relies on it when splitting the path.
all_images_folder = [
    image_folder + f + '/'
    for f in sorted(os.listdir(image_folder))
    if os.path.isdir(image_folder + f)
]

In [4]:
# Root of the generated dataset; 'train', 'val' and 'test' sub-folders are
# created beneath it. Must end with '/' because paths are built by plain
# string concatenation.
# NOTE(review): 'F:torch_gan' has no separator after the drive letter; on
# Windows this presumably resolves relative to the current directory of
# drive F: rather than to F:/torch_gan - confirm this is intended.
training_folder = 'F:torch_gan/pix2pix/'

In [5]:
# Start from a clean output tree.
# WARNING: this permanently deletes everything under training_folder.
# Only delete when the folder exists so the cell also works on a first run.
if os.path.isdir(training_folder):
    rmtree(training_folder)

# makedirs also creates training_folder itself (and any missing parents).
for split_name in ('train', 'val', 'test'):
    os.makedirs(training_folder + split_name, exist_ok=True)

In [6]:
# Path template for one volume of a case, filled in as
# (case folder including trailing '/', case folder name, modality suffix),
# e.g. '<folder>/<name>_t1.nii.gz'.
image_file_format = '{}{}_{}.nii.gz'

def get_images(file):
    """Load a NIfTI volume and scale it by its maximum value.

    Args:
        file: Path of a file readable by ``nib.load``.

    Returns:
        A ``(data, depth)`` tuple. ``data`` is the floating-point array
        returned by ``get_fdata()`` divided by its maximum; ``depth`` is the
        size of its last axis. A volume whose maximum is 0 is returned
        unscaled.
    """
    img = nib.load(file)
    data = img.get_fdata()
    maxx = data.max()
    # Dividing by a zero maximum would turn the whole volume into NaN
    # (0/0), which later casts to garbage pixel values.
    if maxx != 0:
        data = data / maxx

    return data, data.shape[-1]

In [7]:
def get_save_folder(val_fraction=0.15, test_fraction=0.15, rng=None):
    """Randomly pick the dataset split sub-folder for one case.

    Args:
        val_fraction: Probability of returning ``'val/'``.
        test_fraction: Probability of returning ``'test/'``.
        rng: Optional object with a ``random()`` method returning a float in
            ``[0, 1)`` (e.g. ``np.random.default_rng(seed)``). Pass a seeded
            generator to make the split reproducible. When ``None`` the
            global ``np.random`` state is used.

    Returns:
        ``'val/'``, ``'test/'`` or ``'train/'`` (with trailing slash).
    """
    rand = np.random.random() if rng is None else rng.random()

    if rand < val_fraction:
        return 'val/'
    if rand < val_fraction + test_fraction:
        return 'test/'
    return 'train/'

In [8]:
def save_image_data_train(image_folder):
    """Export paired 2-D slices of one case as side-by-side JPEG images.

    Loads the case's 't1', 't1ce' and 'seg' volumes through ``get_images``,
    picks one split sub-folder for the whole case via ``get_save_folder``,
    and for every slice index in 27..126 whose segmentation slice (as
    returned by ``get_images``) sums to more than 100, writes a JPEG with
    the t1 slice on the left half and the t1ce slice on the right half.

    Args:
        image_folder: Case folder path ending in '/'; its last path
            component is also the file-name prefix of the volumes.
    """
    case_name = image_folder.split('/')[-2]

    t1_path = image_file_format.format(image_folder, case_name, 't1')
    t1ce_path = image_file_format.format(image_folder, case_name, 't1ce')
    seg_path = image_file_format.format(image_folder, case_name, 'seg')

    t1_volume, _ = get_images(t1_path)
    t1ce_volume, _ = get_images(t1ce_path)
    seg_volume, _ = get_images(seg_path)

    # One split per case, so all slices of a case land in the same folder.
    split_dir = get_save_folder()
    # Output files are named '<text after the last underscore>_<slice>.jpg'.
    case_id = case_name.split('_')[-1]
    side = t1_volume.shape[0]

    for slice_idx in range(27, 127):
        seg_total = cv2.sumElems(seg_volume[:, :, slice_idx])[0]
        if not seg_total > 100:
            continue

        pair = np.empty((side, side * 2), np.uint8)
        pair[:, :side] = (t1_volume[:, :, slice_idx] * 255).astype('int')
        pair[:, side:] = (t1ce_volume[:, :, slice_idx] * 255).astype('int')

        out_path = training_folder + split_dir + case_id + '_' + str(slice_idx) + '.jpg'
        cv2.imwrite(out_path, pair)

In [9]:
# Convert every case folder, running at most `max_threads` workers at once.
processing_threads = []
max_threads = 10

for idx, img_folder in enumerate(all_images_folder):
    print('Processing at index {}'.format(idx), end='\r')
    worker = Thread(target=save_image_data_train, args=[img_folder])
    worker.start()
    processing_threads.append(worker)

    # Throttle: while the pool is full, drop finished workers and, if none
    # has finished yet, sleep briefly before checking again.
    while len(processing_threads) >= max_threads:
        processing_threads = [t for t in processing_threads if t.is_alive()]
        if len(processing_threads) >= max_threads:
            time.sleep(0.3)

# Without this final join the cell would return while the last workers are
# still writing images, so later cells could see an incomplete dataset.
for worker in processing_threads:
    worker.join()
processing_threads = []

Processing at index 368