In [None]:
# Imports
import os
import shutil
import nibabel as nib
import numpy as np
import random
from random import shuffle
from glob import glob
from nilearn.plotting import plot_roi, plot_epi

In [None]:
#%% Move into the project directory; every relative path below starts from it
project_dir = '/home/uziel/DISS'
os.chdir(project_dir)
# folder holding one sub-directory per training subject
root = './data_processed/ISLES2017/training'

In [None]:
# Map each subject (keyed by its position in the sorted folder listing)
# to the sorted paths of its original entries; anything whose name
# contains 'clone' is left out.
subjects_paths = sorted(os.listdir(root))
channels_per_subject = {}
for i, subject_dir in enumerate(subjects_paths):
    s_path = os.path.join(root, subject_dir)
    originals = [os.path.join(s_path, x)
                 for x in os.listdir(s_path)
                 if 'clone' not in x]
    originals.sort()
    channels_per_subject[i] = originals


In [None]:
# For each subject, create n new ones (default=1),
# whose lesion region is randomly sampled.
clones_number = 1
for subject, entries in channels_per_subject.items():
    # split the subject entries into image channels, brain mask and label (OT)
    channels = [x for x in entries if "OT" not in x and "mask" not in x]
    mask = [x for x in entries if "mask" in x]
    label = [x for x in entries if "OT" in x]
    subject_path = os.path.dirname(channels[0])

    # load subject label
    label_img = nib.load(label[0])
    # np.asanyarray(img.dataobj) is nibabel's documented replacement for the
    # deprecated img.get_data() and returns the same array
    label_data = np.asanyarray(label_img.dataobj)
    label_name = os.path.basename(label[0])
    # indices of the lesion voxels (roi); they are the same for every channel
    # and every clone, so they are computed once per subject
    roi_index = np.nonzero(label_data)
    # load subject mask
    mask_img = nib.load(mask[0])
    mask_name = os.path.basename(mask[0])

    # create new clones
    for i in range(clones_number):
        clone_prefix = 'clone_V2_' + str(i)
        # create path to save clone data (any previous clone is discarded)
        clone_path = os.path.join(subject_path, clone_prefix)
        if os.path.exists(clone_path): shutil.rmtree(clone_path)
        os.makedirs(clone_path)

        # create each clone channel
        for channel_path in channels:
            channel_img = nib.load(channel_path)
            # copy so the values of the loaded image are never modified
            channel_data = np.asanyarray(channel_img.dataobj).copy()
            # get data within roi (label)
            roi_data = channel_data[roi_index]
            # An empty lesion has no statistics to sample from (np.mean would
            # give nan plus a RuntimeWarning); the channel is then saved unchanged.
            if roi_data.size > 0:
                # new data follows gaussian distribution
                mean_value, std_value = np.mean(roi_data), np.std(roi_data)
                channel_data[roi_index] = np.array([random.gauss(mean_value, std_value)
                                                    for _ in range(roi_data.shape[0])])
            # create modified channel for clone
            modified_channel = nib.Nifti1Image(channel_data, channel_img.affine)
            #TODO: Normalize image?
            # save clone channel
            channel_name = os.path.basename(channel_path)
            nib.save(modified_channel, os.path.join(clone_path, clone_prefix + '.' + channel_name))

        # save unaltered label for clone
        nib.save(label_img, os.path.join(clone_path, clone_prefix + '.' + label_name))
        # save unaltered mask for clone
        nib.save(mask_img, os.path.join(clone_path, clone_prefix + '.' + mask_name))

    print("Subject " + str(subject) + " finished.")
            

In [None]:
def data_to_file(data, path):
    """Write every item of ``data`` to the text file ``path``, one per line.

    Parameters
    ----------
    data : iterable
        Items to write; each one is rendered with ``%s`` formatting and
        followed by a newline.
    path : str
        Destination file. It is created, or truncated if it already exists.
    """
    # 'with' closes the file even when writing fails part-way, and
    # file.write works on Python 2 and 3 (``print >> out`` is Python 2 only).
    with open(path, "w") as out:
        for line in data:
            # the 1-tuple keeps the formatting correct if an item is itself a tuple
            out.write("%s\n" % (line,))

In [None]:
###############################################################
##### FILES FOR DM_V2 (BASELINE + RANDOM LESION SAMPLING) #####
###############################################################

# copy configFiles from DM_V1 (baseline); any previous DM_V2 copy is replaced
config_path = './ischleseg/deepmedic/versions/DM_V1/configFiles'
new_config_path = './ischleseg/deepmedic/versions/DM_V2/configFiles'
if os.path.exists(new_config_path): shutil.rmtree(new_config_path)
shutil.copytree(config_path, new_config_path)

# glob pattern of every file list to generate, keyed by the suffix of the
# cfg file it is written to ('train' + key + '.cfg')
file_patterns = {
    # channels - sequences
    'Channels_ADC': '*ADC*.nii.gz',
    'Channels_MTT': '*MTT*.nii.gz',
    'Channels_rCBF': '*rCBF*.nii.gz',
    'Channels_rCBV': '*rCBV*.nii.gz',
    'Channels_Tmax': '*Tmax*.nii.gz',
    'Channels_TTP': '*TTP*.nii.gz',
    # labels
    'GtLabels': '*OT*.nii.gz',
    # masks
    'RoiMasks': '*mask.nii.gz',
}
# relative prefix prepended to every listed path (needed for deepmedic)
path_prefix = '../../../../../../'

# get train directories
train_dirs = [x for x in os.listdir(new_config_path)
              if 'train_' in x]

# process each train set
for train_dir in train_dirs:
    train_path = os.path.join(new_config_path, train_dir)
    # read subject codes (name of the folder holding each listed file);
    # the file is closed before it gets overwritten below
    with open(os.path.join(train_path, 'trainChannels_ADC.cfg'), 'r') as cfg_file:
        subject_list = [os.path.dirname(line.strip()).split('/')[-1]
                        for line in cfg_file
                        if line.strip()]

    # Walk the data tree once and fill all the lists together. A folder is
    # kept when it is a listed subject or sits directly inside one (clones).
    channels = dict((name, []) for name in file_patterns)
    for walk_entry in os.walk(root):
        dir_path = walk_entry[0]
        if (os.path.dirname(dir_path).split('/')[-1] in subject_list or
                os.path.basename(dir_path) in subject_list):
            for name, pattern in file_patterns.items():
                channels[name].extend(os.path.join(path_prefix, y)
                                      for y in glob(os.path.join(dir_path, pattern)))

    for name, files in channels.items():
        # save train channel files
        data_to_file(files, os.path.join(train_path, 'train' + name + '.cfg'))

# modelConfig.cfg, trainConfig.cfg and testConfig.cfg must be added and modified manually.

In [None]:
################################
##### TEST RESULTING FILES #####
################################

# Get training subjects. subject_list already holds subject folder names
# (set while building the config files), so it is used as is; indexing each
# name with [0] would reduce it to a single character and match nothing.
checked_dirs = set(subject_list)
checked_dirs.add('clone_V2_0') #add clone subdir name
#%% Generate files listing all images per channel
os.chdir('/home/uziel/DISS')
root = './data_processed/ISLES2017/training'

# glob pattern of every kind of image to list
file_patterns = {
    # channels - sequences
    'Channels_ADC': '*ADC*.nii.gz',
    'Channels_MTT': '*MTT*.nii.gz',
    'Channels_rCBF': '*rCBF*.nii.gz',
    'Channels_rCBV': '*rCBV*.nii.gz',
    'Channels_Tmax': '*Tmax*.nii.gz',
    'Channels_TTP': '*TTP*.nii.gz',
    # labels
    'GtLabels': '*OT*.nii.gz',
    # masks
    'RoiMasks': '*mask.nii.gz',
}

# a single walk fills every list, so index i refers to the same folder in all
# of them as long as each folder holds exactly one file per pattern
channels = dict((name, []) for name in file_patterns)
for walk_entry in os.walk(root):
    dir_path = walk_entry[0]
    if os.path.basename(dir_path) in checked_dirs:
        for name, pattern in file_patterns.items():
            channels[name].extend(glob(os.path.join(dir_path, pattern)))

# take 5 random subjects (or all of them if fewer) and check their images;
# random.sample works on Python 2 and 3, shuffling a range() only on Python 2
n_images = len(channels['Channels_ADC'])
indices = random.sample(range(n_images), min(5, n_images))

# keys holding image sequences, i.e. everything except labels and masks
image_keys = [x for x in channels.keys() if 'Mask' not in x and "Label" not in x]

for i in indices:
    # load a random channel of subject i
    channel = random.choice(image_keys)
    img = nib.load(channels[channel][i])
    mask = nib.load(channels['RoiMasks'][i])
    print('Subject: ' + str(i) + '. Channel: ' + str(channel) + '. Shape: ' + str(img.shape))
    plot_epi(img) # plot_epi(img, cut_coords=(0,0,0)) -> use to see co-registered channels per subject
    plot_roi(mask, img)