### Imports:

In [1]:
import pickle

from functools import partial

from tensorflow.keras.layers import Input, LeakyReLU, Add, UpSampling3D, Activation, SpatialDropout3D
from keras.optimizer_v2 import adam
from keras.optimizer_v2 import rmsprop 

import pandas as pd

import numpy as np
from keras import backend as K
from keras import Input
from keras import Model
from keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, BatchNormalization, PReLU, Flatten, Dense, GlobalAveragePooling3D

# K.set_image_dim_ordering('th')
# K.set_image_dim_ordering('tf')
K.set_image_data_format('channels_first')

try:
    from keras.engine import merge
except ImportError:
    from keras.layers.merge import concatenate

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
print(len(device_lib.list_local_devices()))

C:\Users\HassanIslam\anaconda3\envs\FYP\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll
C:\Users\HassanIslam\anaconda3\envs\FYP\lib\site-packages\numpy\.libs\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll


[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 3796633128729504359
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 10091102208
locality {
  bus_id: 1
  links {
  }
}
incarnation: 4195951953474789381
physical_device_desc: "device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1"
]
2


### Reading in survival data.csv:

In [2]:
survival_data = pd.read_csv('survival_data_filled.csv')
# m = survival_data['Age'].mean()
# survival_data['Age'].fillna(value=m, inplace=True)
# survival_data['Survival'] = survival_data['Survival'].fillna(0)
# survival_data.to_csv('./features/survival_data_filled.csv')

In [3]:
ID = 'BraTS19_CBICA_ABN_1'
survival_data[survival_data.BraTS19ID==ID].Survival.astype(int).values.item(0)

1278

### Make tumor type dictionary:

In [4]:
tumor_type_dict = {}

In [5]:
import os

HGG_dir_list = next(os.walk('./dataCorrected/'))[1]
# print(len(HGG_dir_list))
# LGG_dir_list = next(os.walk('./LGG/'))[1]
# print(len(LGG_dir_list))


for patientID in HGG_dir_list:
#     print(patientID)
    if patientID in HGG_dir_list:
#         tumor_type_dict[patientID] = "HGG"
        tumor_type_dict[patientID] = 0
    elif patientID in LGG_dir_list:
#         tumor_type_dict[patientID] = "LGG"
        tumor_type_dict[patientID] = 1

print(len(tumor_type_dict))
# tumor_type_dict[(HGG_dir_list+LGG_dir_list)[0]]

257


### Calculating metrics:

In [2]:
def dice_coefficient(y_true, y_pred, smooth=1.):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coefficient_loss(y_true, y_pred):
    return -dice_coefficient(y_true, y_pred)


def weighted_dice_coefficient(y_true, y_pred, axis=(-3, -2, -1), smooth=0.00001):
    """
    Weighted dice coefficient. Default axis assumes a "channels first" data structure
    :param smooth:
    :param y_true:
    :param y_pred:
    :param axis:
    :return:
    """
    return K.mean(2. * (K.sum(y_true * y_pred,
                              axis=axis) + smooth/2)/(K.sum(y_true,
                                                            axis=axis) + K.sum(y_pred,
                                                                               axis=axis) + smooth))


def weighted_dice_coefficient_loss(y_true, y_pred):
    return -weighted_dice_coefficient(y_true, y_pred)


def label_wise_dice_coefficient(y_true, y_pred, label_index):
    return dice_coefficient(y_true[:, label_index], y_pred[:, label_index])


def get_label_dice_coefficient_function(label_index):
    f = partial(label_wise_dice_coefficient, label_index=label_index)
    f.__setattr__('__name__', 'label_{0}_dice_coef'.format(label_index))
    return f


dice_coef = dice_coefficient
dice_coef_loss = dice_coefficient_loss

In [3]:
def create_convolution_block(input_layer, n_filters, batch_normalization=False, kernel=(3, 3, 3), activation=None,
                             padding='same', strides=(1, 1, 1), instance_normalization=False):
    """
    :param strides:
    :param input_layer:
    :param n_filters:
    :param batch_normalization:
    :param kernel:
    :param activation: Keras activation layer to use. (default is 'relu')
    :param padding:
    :return:
    """
    layer = Conv3D(n_filters, kernel, padding=padding, strides=strides)(input_layer)
    if batch_normalization:
        layer = BatchNormalization(axis=1)(layer)
    elif instance_normalization:
        try:
            from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
        except ImportError:
            raise ImportError("Install keras_contrib in order to use instance normalization."
                              "\nTry: pip install git+https://www.github.com/farizrahman4u/keras-contrib.git")
        layer = InstanceNormalization(axis=1)(layer)
    if activation is None:
        return Activation('relu')(layer)
    else:
        return activation()(layer)


def compute_level_output_shape(n_filters, depth, pool_size, image_shape):
    """
    Each level has a particular output shape based on the number of filters used in that level and the depth or number 
    of max pooling operations that have been done on the data at that point.
    :param image_shape: shape of the 3d image.
    :param pool_size: the pool_size parameter used in the max pooling operation.
    :param n_filters: Number of filters used by the last node in a given level.
    :param depth: The number of levels down in the U-shaped model a given node is.
    :return: 5D vector of the shape of the output node 
    """
    output_image_shape = np.asarray(np.divide(image_shape, np.power(pool_size, depth)), dtype=np.int32).tolist()
    return tuple([None, n_filters] + output_image_shape)


def get_up_convolution(n_filters, pool_size, kernel_size=(2, 2, 2), strides=(2, 2, 2),
                       deconvolution=False):
    if deconvolution:
        return Deconvolution3D(filters=n_filters, kernel_size=kernel_size,
                               strides=strides)
    else:
        return UpSampling3D(size=pool_size)

def create_localization_module(input_layer, n_filters):
    convolution1 = create_convolution_block(input_layer, n_filters)
    convolution2 = create_convolution_block(convolution1, n_filters, kernel=(1, 1, 1))
    return convolution2


def create_up_sampling_module(input_layer, n_filters, size=(2, 2, 2)):
    up_sample = UpSampling3D(size=size)(input_layer)
    convolution = create_convolution_block(up_sample, n_filters)
    return convolution


def create_context_module(input_layer, n_level_filters, dropout_rate=0.3, data_format="channels_first"):
    convolution1 = create_convolution_block(input_layer=input_layer, n_filters=n_level_filters)
    dropout = SpatialDropout3D(rate=dropout_rate, data_format=data_format)(convolution1)
    convolution2 = create_convolution_block(input_layer=dropout, n_filters=n_level_filters)
    return convolution2


create_convolution_block = partial(create_convolution_block, activation=LeakyReLU, instance_normalization=True)

### Make the labels and test train dictionaries:

In [8]:
# from glob import glob
# paths = glob('/Users/etheredgej/Desktop/MICCAI_BraTS17_Data_Training/train/HGG/*/')
# print(paths)

import os
HGG_dir_list = next(os.walk('./dataCorrected/'))[1]
print(len(HGG_dir_list))
# LGG_dir_list = next(os.walk('./LGG/'))[1]
# print(len(LGG_dir_list))

257


### Dictionary for all samples:

In [9]:
completelist = HGG_dir_list


In [10]:
# completelist = HGG_dir_list + LGG_dir_list

# completelist = list(survival_data.Brats17ID.copy())

# print(completelist[0:4])
np.random.shuffle(completelist) # shuffles in place
# print(completelist[0:4])

partition={}

holdout_percentage=0.15
partition['holdout']=completelist[0:int(len(completelist)*holdout_percentage)]
trainlist=completelist[int(len(completelist)*holdout_percentage):len(completelist)]

train_percentage=0.7
partition['train']=trainlist[0:int(len(trainlist)*train_percentage)]
partition['test']=trainlist[int(len(trainlist)*train_percentage):len(trainlist)]


labels={}
# HGG=0
# LGG=1
for directory in HGG_dir_list:
    labels[directory]=0
# for directory in LGG_dir_list:
#     labels[directory]=1
    
print(len(partition['holdout']))
print(len(partition['train']))
print(len(partition['test']))

38
153
66


In [11]:
partition['holdout']

['BraTS19_TCIA01_150_1',
 'BraTS19_TCIA08_280_1',
 'BraTS19_CBICA_AOZ_1',
 'BraTS19_TCIA01_186_1',
 'BraTS19_TCIA02_377_1',
 'BraTS19_TCIA02_368_1',
 'BraTS19_TCIA02_151_1',
 'BraTS19_CBICA_AXM_1',
 'BraTS19_TCIA02_394_1',
 'BraTS19_2013_23_1',
 'BraTS19_CBICA_AUX_1',
 'BraTS19_TCIA06_603_1',
 'BraTS19_CBICA_AQV_1',
 'BraTS19_CBICA_AUQ_1',
 'BraTS19_TCIA08_234_1',
 'BraTS19_TCIA02_135_1',
 'BraTS19_TCIA06_372_1',
 'BraTS19_CBICA_ANG_1',
 'BraTS19_TCIA02_171_1',
 'BraTS19_CBICA_ASR_1',
 'BraTS19_TCIA01_390_1',
 'BraTS19_CBICA_BNR_1',
 'BraTS19_CBICA_ABN_1',
 'BraTS19_CBICA_BGO_1',
 'BraTS19_CBICA_AAL_1',
 'BraTS19_CBICA_ATX_1',
 'BraTS19_CBICA_AYA_1',
 'BraTS19_2013_13_1',
 'BraTS19_CBICA_AVV_1',
 'BraTS19_TCIA08_242_1',
 'BraTS19_CBICA_AQR_1',
 'BraTS19_TCIA02_607_1',
 'BraTS19_CBICA_AXO_1',
 'BraTS19_TCIA02_118_1',
 'BraTS19_2013_7_1',
 'BraTS19_CBICA_AOH_1',
 'BraTS19_TCIA01_401_1',
 'BraTS19_CBICA_BIC_1']

### Dictionary for the samples with survival data:

In [12]:
completelist = HGG_dir_list

completelist = list(survival_data.BraTS19ID.copy())

# print(completelist[0:4])
np.random.shuffle(completelist) # shuffles in place
# print(completelist[0:4])

subpartition={}

holdout_percentage=0.15
subpartition['holdout']=completelist[0:int(len(completelist)*holdout_percentage)]
trainlist=completelist[int(len(completelist)*holdout_percentage):len(completelist)]

train_percentage=0.7
subpartition['train']=trainlist[0:int(len(trainlist)*train_percentage)]
subpartition['test']=trainlist[int(len(trainlist)*train_percentage):len(trainlist)]


labels={}
# HGG=0
# LGG=1
for directory in HGG_dir_list:
    labels[directory]=0
# for directory in LGG_dir_list:
#     labels[directory]=1
    
print(len(subpartition['holdout']))
print(len(subpartition['train']))
print(len(subpartition['test']))

38
153
66


In [13]:
subpartition['holdout']
# len(survival_data.Brats17ID)
# len(set.intersection(set(completelist),set(survival_data.Brats17ID)))

['BraTS19_CBICA_BFP_1',
 'BraTS19_TCIA01_401_1',
 'BraTS19_TCIA03_375_1',
 'BraTS19_CBICA_AWH_1',
 'BraTS19_CBICA_ATF_1',
 'BraTS19_CBICA_ASE_1',
 'BraTS19_TCIA02_473_1',
 'BraTS19_CBICA_ASF_1',
 'BraTS19_CBICA_AQR_1',
 'BraTS19_2013_22_1',
 'BraTS19_CBICA_BGE_1',
 'BraTS19_CBICA_AQT_1',
 'BraTS19_CBICA_ARW_1',
 'BraTS19_TCIA01_180_1',
 'BraTS19_TCIA01_390_1',
 'BraTS19_TCIA01_335_1',
 'BraTS19_CBICA_BGN_1',
 'BraTS19_TCIA01_412_1',
 'BraTS19_CBICA_ASW_1',
 'BraTS19_TCIA01_131_1',
 'BraTS19_2013_4_1',
 'BraTS19_TMC_12866_1',
 'BraTS19_CBICA_AUA_1',
 'BraTS19_TCIA02_208_1',
 'BraTS19_CBICA_BEM_1',
 'BraTS19_CBICA_APZ_1',
 'BraTS19_2013_17_1',
 'BraTS19_CBICA_AVG_1',
 'BraTS19_CBICA_AZD_1',
 'BraTS19_2013_25_1',
 'BraTS19_TCIA08_105_1',
 'BraTS19_TCIA08_205_1',
 'BraTS19_CBICA_AUN_1',
 'BraTS19_CBICA_ALX_1',
 'BraTS19_TCIA08_469_1',
 'BraTS19_2013_26_1',
 'BraTS19_TCIA01_147_1',
 'BraTS19_CBICA_BJY_1']

### crop_img function:

In [None]:
# import numpy as np
# from nilearn.image.image import check_niimg
# from nilearn.image.image import _crop_img_to as crop_img_to


# def crop_img(img, rtol=1e-8, copy=True, return_slices=False):
#     """Crops img as much as possible
#     Will crop img, removing as many zero entries as possible
#     without touching non-zero entries. Will leave one voxel of
#     zero padding around the obtained non-zero area in order to
#     avoid sampling issues later on.
#     Parameters
#     ----------
#     img: Niimg-like object
#         See http://nilearn.github.io/manipulating_images/input_output.html
#         img to be cropped.
#     rtol: float
#         relative tolerance (with respect to maximal absolute
#         value of the image), under which values are considered
#         negligeable and thus croppable.
#     copy: boolean
#         Specifies whether cropped data is copied or not.
#     return_slices: boolean
#         If True, the slices that define the cropped image will be returned.
#     Returns
#     -------
#     cropped_img: image
#         Cropped version of the input image
#     """

#     img = check_niimg(img)
#     data = img.get_data()
#     infinity_norm = max(-data.min(), data.max())
#     passes_threshold = np.logical_or(data < -rtol * infinity_norm,
#                                      data > rtol * infinity_norm)

#     if data.ndim == 4:
#         passes_threshold = np.any(passes_threshold, axis=-1)
#     coords = np.array(np.where(passes_threshold))
#     start = coords.min(axis=1)
#     end = coords.max(axis=1) + 1

#     # pad with one voxel to avoid resampling problems
#     start = np.maximum(start - 1, 0)
#     end = np.minimum(end + 1, data.shape[:3])

#     slices = [slice(s, e) for s, e in zip(start, end)]

#     if return_slices:
#         return slices

#     return crop_img_to(img, slices, copy=copy)

### Save all the cropped images as pickled numpy arrays:

In [None]:
# import pickle
# import numpy as np
# import nibabel as nib

# for i, ID in enumerate(completelist):
#     print("Reading",completelist[i])

#     img1 = './dataCorrected/' + ID + '/' + ID + '_flair.nii.gz'
#     img2 = './dataCorrected/' + ID +  '/' + ID + '_t1.nii.gz'
#     img3 = './dataCorrected/' + ID +  '/' + ID + '_t1ce.nii.gz'
#     img4 = './dataCorrected/' + ID +  '/' + ID + '_t2.nii.gz'
#     img5 = './data/' + ID +  '/' + ID + '_seg.nii.gz'

#     newimage = nib.concat_images([img1, img2, img3, img4, img5])
#     cropped = crop_img(newimage)         
#     img_array = np.array(cropped.dataobj)
#     z = np.rollaxis(img_array, 3, 0)

#     padded_image = np.zeros((5,160,192,160))
#     padded_image[:z.shape[0],:z.shape[1],:z.shape[2],:z.shape[3]] = z

#     a,b,c,d,seg_mask = np.split(padded_image, 5, axis=0)

#     images = np.concatenate([a,b,c,d], axis=0)

#     # print("images shape:", images.shape, "images values:", np.unique(images.astype(int)))

#     # split the channels:
#     # seg_mask_1 = copy.deepcopy(seg_mask.astype(int))
#     seg_mask_1 = np.zeros((1,160,192,160))
#     seg_mask_1[seg_mask.astype(int) == 1] = 1
#     seg_mask_2 = np.zeros((1,160,192,160))
#     seg_mask_2[seg_mask.astype(int) == 2] = 1
#     seg_mask_3 = np.zeros((1,160,192,160))
#     seg_mask_3[seg_mask.astype(int) == 4] = 1
#     seg_mask_3ch = np.concatenate([seg_mask_1,seg_mask_2,seg_mask_3], axis=0).astype(int)

#     # 1) the "enhancing tumor" (ET), 2) the "tumor core" (TC), and 3) the "whole tumor" (WT) 
#     # The ET is described by areas that show hyper-intensity in T1Gd when compared to T1, but also when compared to “healthy” white matter in T1Gd. The TC describes the bulk of the tumor, which is what is typically resected. The TC entails the ET, as well as the necrotic (fluid-filled) and the non-enhancing (solid) parts of the tumor. The appearance of the necrotic (NCR) and the non-enhancing (NET) tumor core is typically hypo-intense in T1-Gd when compared to T1. The WT describes the complete extent of the disease, as it entails the TC and the peritumoral edema (ED), which is typically depicted by hyper-intense signal in FLAIR.
#     # The labels in the provided data are: 
#     # 1 for NCR & NET (necrotic (NCR) and the non-enhancing (NET) tumor core) = TC ("tumor core")
#     # 2 for ED ("peritumoral edema")
#     # 4 for ET ("enhancing tumor")
#     # 0 for everything else

# #     X[i,] = images
# #     y1[i,] = seg_mask_3ch
#     pickle.dump( images, open( "D:\croppedFiles/%s_images.pkl"%(ID), "wb" ) )
#     pickle.dump( seg_mask_3ch, open( "D:\croppedFiles/%s_seg_mask_3ch.pkl"%(ID), "wb" ) )
#     print("Saving", i+1,completelist[i], "of", len(completelist))


### Data generator all samples (1 predictions):

### Data generator all samples (2 predictions):

### Data generator for samples with survival data (3 predictions, subset of images):

In [14]:
# https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly.html

import numpy as np
import keras
import nibabel as nib

class SubDataGenerator(keras.utils.data_utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, batch_size=1, dim=(240,240,155), n_channels=4,
                 n_classes=3, shuffle=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y1, y2 = self.__data_generation(list_IDs_temp)

        return X, [y1, y2]

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, self.n_channels, *self.dim))
        y1 = np.empty((self.batch_size, 3, *self.dim))
        y2 = np.empty(self.batch_size)

        # Generate data
        # Decode and load the data
        for i, ID in enumerate(list_IDs_temp):

            # 1) the "enhancing tumor" (ET), 2) the "tumor core" (TC), and 3) the "whole tumor" (WT) 
            # The ET is described by areas that show hyper-intensity in T1Gd when compared to T1, but also when compared to “healthy” white matter in T1Gd. The TC describes the bulk of the tumor, which is what is typically resected. The TC entails the ET, as well as the necrotic (fluid-filled) and the non-enhancing (solid) parts of the tumor. The appearance of the necrotic (NCR) and the non-enhancing (NET) tumor core is typically hypo-intense in T1-Gd when compared to T1. The WT describes the complete extent of the disease, as it entails the TC and the peritumoral edema (ED), which is typically depicted by hyper-intense signal in FLAIR.
            # The labels in the provided data are: 
            # 1 for NCR & NET (necrotic (NCR) and the non-enhancing (NET) tumor core) = TC ("tumor core")
            # 2 for ED ("peritumoral edema")
            # 4 for ET ("enhancing tumor")
            # 0 for everything else

            X[i,] = pickle.load( open( "D:\croppedFiles/%s_images.pkl"%(ID), "rb" ) )
            y1[i,] = pickle.load( open( "D:\croppedFiles/%s_seg_mask_3ch.pkl"%(ID), "rb" ) )            
            y2[i,] = survival_data[survival_data.BraTS19ID==ID].Survival.astype(int).values.item(0)

        return X, y1, y2

In [None]:
# import numpy as np

# from keras.models import Sequential
# from my_classes import DataGenerator

In [None]:
ID = LGG_dir_list[0]
type(tumor_type_dict[ID])

### Single prediction compilation:

In [None]:
# # change the number of labels?
# # loss_function={'activation_block': weighted_dice_coefficient_loss, 'survival_block': 'mean_squared_error'}
# # selected_optimizer = RMSprop
# # selected_initial_learning_rate = 5e-4

# model = isensee2017_model(input_shape=(4, 160, 192, 160), n_base_filters=12, depth=5, dropout_rate=0.3,
#                       n_segmentation_levels=3, n_labels=3, activation_name="sigmoid")

# model.compile(optimizer=RMSprop(lr=5e-4), 
#               loss={'activation_block': weighted_dice_coefficient_loss}, 
#               loss_weights={'activation_block': 1.},
#              metrics={'activation_block': ['accuracy',weighted_dice_coefficient, dice_coefficient]})

# model.summary(line_length=150) # add the parameter that allows me to show everything instead of cutting it off

In [4]:
# change the number of labels?
# loss_function={'activation_block': weighted_dice_coefficient_loss, 'survival_block': 'mean_squared_error'}
# selected_optimizer = RMSprop
# selected_initial_learning_rate = 5e-4



input_shape=(4, 160, 192, 160)
n_base_filters=8
depth=5
dropout_rate=0.3
n_segmentation_levels=3
n_labels=3
activation_name="sigmoid"

"""
This function builds a model proposed by Isensee et al. for the BRATS 2017 competition:
https://www.cbica.upenn.edu/sbia/Spyridon.Bakas/MICCAI_BraTS/MICCAI_BraTS_2017_proceedings_shortPapers.pdf
This network is highly similar to the model proposed by Kayalibay et al. "CNN-based Segmentation of Medical
Imaging Data", 2017: https://arxiv.org/pdf/1701.03056.pdf
:param input_shape:
:param n_base_filters:
:param depth:
:param dropout_rate:
:param n_segmentation_levels:
:param n_labels:
:param optimizer:
:param initial_learning_rate:
:param loss_function:
:param activation_name:
:return:
"""
inputs = Input(input_shape)

current_layer = inputs
level_output_layers = list()
level_filters = list()
for level_number in range(depth):
    n_level_filters = (2**level_number) * n_base_filters
    level_filters.append(n_level_filters)

    if current_layer is inputs:
        in_conv = create_convolution_block(current_layer, n_level_filters)
    else:
        in_conv = create_convolution_block(current_layer, n_level_filters, strides=(2, 2, 2))

    context_output_layer = create_context_module(in_conv, n_level_filters, dropout_rate=dropout_rate)

    summation_layer = Add()([in_conv, context_output_layer])
    level_output_layers.append(summation_layer)
    current_layer = summation_layer

segmentation_layers = list()
for level_number in range(depth - 2, -1, -1):
    up_sampling = create_up_sampling_module(current_layer, level_filters[level_number])
    concatenation_layer = concatenate([level_output_layers[level_number], up_sampling], axis=1)
    localization_output = create_localization_module(concatenation_layer, level_filters[level_number])
    current_layer = localization_output
    if level_number < n_segmentation_levels:
        segmentation_layers.insert(0, create_convolution_block(current_layer, n_filters=n_labels, kernel=(1, 1, 1)))

for l in segmentation_layers:
    print(l.shape)
        
output_layer = None
for level_number in reversed(range(n_segmentation_levels)):
    segmentation_layer = segmentation_layers[level_number]
    if output_layer is None:
        output_layer = segmentation_layer
    else:
        print(level_number)
        output_layer = Add()([output_layer, segmentation_layer])

    if level_number > 0:
        output_layer = UpSampling3D(size=(2, 2, 2))(output_layer)

activation_block = Activation(activation = activation_name, name='activation_block')(output_layer)
#     survival_block = Activation("linear")(summation_layer)
#     activation_block = Dense(1, activation=activation_name, name='activation_block')(output_layer)
#     flatten = Flatten(name='flatten')(summation_layer)
#     survival_block = Dense(1, activation='linear', name='survival_block')(flatten)

survival_conv_1 = Conv3D(filters=n_level_filters, kernel_size=(3, 3, 3), padding='same', strides=(1, 1, 1), name='survival_conv_1')(summation_layer)
survival_conv_2 = Conv3D(filters=n_level_filters, kernel_size=(3, 3, 3), padding='same', strides=(1, 1, 1), name='survival_conv_2')(survival_conv_1)
dropout = SpatialDropout3D(rate=dropout_rate, data_format='channels_first', name='dropout')(survival_conv_2)
survival_conv_3 = Conv3D(filters=n_level_filters, kernel_size=(3, 3, 3), padding='same', strides=(1, 1, 1), name='survival_conv_3')(dropout)
survival_GAP = GlobalAveragePooling3D(name='survival_GAP')(survival_conv_3)
#     flatten = Flatten(name='flatten')(survival_GAP)
#     survival_block = Activation("linear", name='survival_block')(flatten)
survival_block = Dense(1, activation='linear', name='survival_block')(survival_GAP)

tumortype_conv_1 = Conv3D(filters=n_level_filters, kernel_size=(3, 3, 3), padding='same', strides=(1, 1, 1), name='tumortype_conv_1')(summation_layer)
tumortype_conv_2 = Conv3D(filters=n_level_filters, kernel_size=(3, 3, 3), padding='same', strides=(1, 1, 1), name='tumortype_conv_2')(tumortype_conv_1)
tumortype_dropout = SpatialDropout3D(rate=dropout_rate, data_format='channels_first', name='tumortype_dropout')(tumortype_conv_2)
tumortype_conv_3 = Conv3D(filters=n_level_filters, kernel_size=(3, 3, 3), padding='same', strides=(1, 1, 1), name='tumortype_conv_3')(tumortype_dropout)
tumortype_GAP = GlobalAveragePooling3D(name='tumortype_GAP')(tumortype_conv_3)
#     flatten = Flatten(name='flatten')(tumortype_GAP)
#     tumortype_block = Activation("linear", name='tumortype_block')(flatten)
tumortype_block = Dense(1, activation='sigmoid', name='tumortype_block')(tumortype_GAP)  

model = Model(inputs=inputs, outputs=[activation_block])
#     model.compile(optimizer=optimizer(lr=initial_learning_rate), loss=loss_function)
#     loss={'activation_block': 'binary_crossentropy', 'survival_block': 'mean_squared_error'}
# assign weights and loss as dictionaries
# functional-api-guide
# loss_weights define the ratio of how much I care about optimizing each one

# model.load_weights("./weights/1pred_weights.25--0.08.hdf5", by_name=True) # the by_name=True allows you to use a different architecture and bring in the weights from the matching layers 

model.compile(optimizer="adam", 
              loss={'activation_block': weighted_dice_coefficient_loss}, 
              loss_weights={'activation_block': 1.},
             metrics={'activation_block': ['accuracy',weighted_dice_coefficient, dice_coefficient]})

model.summary() # add the parameter that allows me to show everything instead of cutting it off

(None, 3, 160, 192, 160)
(None, 3, 80, 96, 80)
(None, 3, 40, 48, 40)
1
0
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 4, 160, 192, 0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 8, 160, 192,  872         input_1[0][0]                    
__________________________________________________________________________________________________
instance_normalization (Instanc (None, 8, 160, 192,  16          conv3d[0][0]                     
__________________________________________________________________________________________________
tf.nn.leaky_relu (TFOpLambda)   (None, 8, 160, 192,  0           instance_normalization[0][0]     
_____________________

In [None]:
tf.config.list_physical_devices('GPU')

### Save training history and predictions for 1 prediction:

### 2 predictions compilation (all data):

In [5]:
model = Model(inputs=inputs, outputs=[activation_block,tumortype_block])
model.load_weights("./weights/model_1_weights.h5", by_name=True) # the by_name=True allows you to use a different architecture and bring in the weights from the matching layers 

model.compile(optimizer="adam", 
              loss={'activation_block': weighted_dice_coefficient_loss, 'tumortype_block': 'binary_crossentropy'}, 
              loss_weights={'activation_block': 1., 'tumortype_block': 0.2},
             metrics={'activation_block': ['accuracy',weighted_dice_coefficient, dice_coefficient], 'tumortype_block': ['accuracy']})


model.summary(line_length=150) # add the parameter that allows me to show everything instead of cutting it off

Model: "model_1"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, 4, 160, 192, 160)]       0                                                                   
______________________________________________________________________________________________________________________________________________________
conv3d (Conv3D)                                  (None, 8, 160, 192, 160)         872               input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
instance_normalization (InstanceNormalization)   (None, 8, 160, 192, 160)    

### Train the 2 prediction full data net:

In [5]:
model.load_weights('./weights/model_3_weights_new.h5')
model.save('model3_new.h5')

ValueError: You are trying to load a weight file containing 64 layers into a model with 60 layers.

In [None]:
params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': True}

# Generators
training_generator = DataGenerator(partition['train'], **params)
validation_generator = DataGenerator(partition['test'], **params)

# cb_1=keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=2, verbose=0, mode='auto')
cb_2=keras.callbacks.ModelCheckpoint(filepath="./weights/2pred_weights.{epoch:02d}-{val_loss:.2f}.hdf5", monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)

results = model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                   epochs=100,
                   callbacks=[cb_2])

model.save_weights("./weights/model_2_weights.h5")
print("Saved model to disk")

### Save training history and predictions for 2 predictions (all data):

In [None]:
history_2_pred = results.history
pickle.dump( history_2_pred, open( "./weights/history_2_pred.pkl", "wb" ) )

params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': False}

# Turned shuffle off so that we can match the values in the dictionary to the predictions. 
# This way we can compare the predictions side-by-side with the ground truth.

validation_generator = DataGenerator(partition['holdout'], **params)

predictions_2_pred = model.predict_generator(generator=validation_generator)

pickle.dump( predictions_2_pred, open( "./weights/predictions_2_pred.pkl", "wb" ) )

### 3 predictions compilation (all data):

In [6]:
model = Model(inputs=inputs, outputs=[activation_block,survival_block])
model.load_weights("./weights/model_1_weights.h5", by_name=True)

model.compile(optimizer="adam", 
              loss={'activation_block': weighted_dice_coefficient_loss, 'survival_block': 'mean_squared_error'}, 
              loss_weights={'activation_block': 1., 'survival_block': 0.2,},
             metrics={'activation_block': ['accuracy',weighted_dice_coefficient, dice_coefficient], 'survival_block': ['accuracy', 'mae']})

model.summary(line_length=150) # add the parameter that allows me to show everything instead of cutting it off

Model: "model_1"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, 4, 160, 192, 160)]       0                                                                   
______________________________________________________________________________________________________________________________________________________
conv3d (Conv3D)                                  (None, 8, 160, 192, 160)         872               input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
instance_normalization (InstanceNormalization)   (None, 8, 160, 192, 160)    

### Train the 3 prediction subset data net:

In [7]:
model.load_weights('./weights/model_3_weights_new.h5')
model.save('model3_new.h5')

In [19]:
model.load_weights('./weights/model_3_weights.h5')
params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': False}

# Turned shuffle off so that we can match the values in the dictionary to the predictions. 
# This way we can compare the predictions side-by-side with the ground truth.

validation_generator = SubDataGenerator(["BraTS19_TMC_11964_1"], **params)

predictions_1_pred = model.predict_generator(generator=validation_generator)



In [23]:
print(predictions_1_pred[1])

[[172.54855]]


In [20]:
params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': True}

# Generators
training_generator = SubDataGenerator(subpartition['train'], **params)
validation_generator = SubDataGenerator(subpartition['test'], **params)

# cb_1=keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=2, verbose=0, mode='auto')
cb_2=keras.callbacks.ModelCheckpoint(filepath="./weights/3pred_weights.{epoch:02d}-{val_loss:.2f}.hdf5", monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)

results = model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                   epochs=100,
                   callbacks=[cb_2])

model.save_weights("./weights/model_3_weights_new.h5")
print("Saved model to disk")





Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100


Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100


Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100


Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100


Epoch 98/100
Epoch 99/100
Epoch 100/100
Saved model to disk


### Save training history and predictions for 3 predictions (subset of data with survival predictions):

In [None]:
history_3_pred = results.history
pickle.dump( history_3_pred, open( "./weights/history_3_pred.pkl", "wb" ) )

params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': False}

# Turned shuffle off so that we can match the values in the dictionary to the predictions. 
# This way we can compare the predictions side-by-side with the ground truth.

validation_generator = SubDataGenerator(subpartition['holdout'], **params)

predictions_3_pred = model.predict_generator(generator=validation_generator)

pickle.dump( predictions_3_pred, open( "./weights/predictions_3_pred.pkl", "wb" ) )

### Validation set predictions for 2 predictions (all data):

In [None]:
# score = model.evaluate(x_test, y_test, verbose=0)

params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': False}

# Turned shuffle off so that we can match the values in the dictionary to the predictions. 
# This way we can compare the predictions side-by-side with the ground truth.

validation_generator = DataGenerator(partition['holdout'], **params)

prediction = model.predict_generator(generator=validation_generator)
# print('Test loss:', score[0])
# print('Test accuracy:', score[1])
prediction

### Validation set predictions for 3 predictions (subset of data with survival predictions):

In [None]:
# score = model.evaluate(x_test, y_test, verbose=0)

params = {'dim': (160,192,160),
          'batch_size': 1,
          'n_classes': 3,
          'n_channels': 4,
          'shuffle': False}

# Turned shuffle off so that we can match the values in the dictionary to the predictions. 
# This way we can compare the predictions side-by-side with the ground truth.

validation_generator = SubDataGenerator(subpartition['holdout'], **params)

prediction = model.predict_generator(generator=validation_generator)
# print('Test loss:', score[0])
# print('Test accuracy:', score[1])
prediction

In [None]:
# sanity check on the predictions:
len(prediction)
prediction[0].shape # segmentation mask
prediction[1].shape # survival
prediction[2].shape # tumor type
# np.unique(prediction)

In [None]:
# len(completelist)
len(set.intersection(set(HGG_dir_list), set(completelist)))
len(set.intersection(set(LGG_dir_list), set(completelist)))

In [None]:
# for ID in partition['holdout']:
#     print(tumor_type_dict[ID])

for ID in partition['holdout']:
    print(tumor_type_dict[ID])

In [None]:
# ! mkdir to_categorical_try
# ! mkdir channel_split

In [None]:
# import pickle

# pickle.dump( partition, open( "./channel_split/partition.pkl", "wb" ) ) # this has the test/train ID matches

# # # access the test list:
# # testIDlist = partition['test']
# # testIDlist

In [None]:
# for i in range(len(prediction)):
#     pickle.dump( prediction[i], open( "./channel_split/prediction_"+str(i)+".pkl", "wb" ) )

In [None]:
# # https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model
# # pickle.dump( model, open( "model.pkl", "wb" ) )
# model.save_weights('./channel_split/my_model_weights.h5')

In [None]:
import pickle

partition = pickle.load(open( "./channel_split/partition.pkl", "rb" ) ) # this has the test/train ID matches

# access the test list:
testIDlist = partition['test']
# testIDlist

### Write images to .tif:

In [None]:
from tifffile import imsave
for i in range(len(prediction)):
    imarray = pickle.load(open( "./channel_split/prediction_"+str(i)+".pkl", "rb" ) )

    imarray[]
    imarray *= 255.0/imarray.max()
    print(np.unique())
    
    imsave("./channel_split/"+testIDlist[i]+"prediction.tif", imarray)

    # make ground truth 
    ID = testIDlist[i]
    img1 = './data/' + ID + '_flair.nii.gz'
    img2 = './data/' + ID + '_t1.nii.gz'
    img3 = './data/' + ID + '_t1ce.nii.gz'
    img4 = './data/' + ID + '_t2.nii.gz'
    img5 = './data/' + ID + '_seg.nii.gz'

    newimage = nib.concat_images([img1, img2, img3, img4, img5])
    cropped = crop_img(newimage)         
    img_array = np.array(cropped.dataobj)
    z = np.rollaxis(img_array, 3, 0)

    padded_image = np.zeros((5,160,192,160))
    padded_image[:z.shape[0],:z.shape[1],:z.shape[2],:z.shape[3]] = z

    a,b,c,d,seg_mask = np.split(padded_image, 5, axis=0)

    images = np.concatenate([a,b,c,d], axis=0)
    imsave("./channel_split/"+testIDlist[i]+"ground_truth.tif", images)
    

### Testing:

In [None]:
import pickle
import numpy as np
import copy
import nibabel as nib

i = 0

imarray = pickle.load(open( "./channel_split/prediction_"+str(i)+".pkl", "rb" ) )
# threshold the channels (for prediction):
prediction_thresh = copy.deepcopy(imarray)
prediction_thresh[prediction_thresh < 0.5] = 0.
prediction_thresh[prediction_thresh >= 0.5] = 1.
prediction_thresh = prediction_thresh
print(np.unique(prediction_thresh))
prediction_thresh *= 255.0/prediction_thresh.max() # convert to 8-bit pixel values
prediction_thresh = prediction_thresh.astype(int)
print(np.unique(prediction_thresh))
print(prediction_thresh.shape)

ID = testIDlist[i]
img1 = './data/' + ID + '_flair.nii.gz'
flairimg = nib.load(img1)
flairimg = np.array(flairimg.dataobj)
flairimg = np.expand_dims(flairimg, axis=0)
flairimg = np.rollaxis(flairimg, 3, 0)
print(np.unique(flairimg))
flairimg = flairimg.astype(float)
flairimg *= 255.0/flairimg.max() # convert to 8-bit pixel values
flairimg = flairimg.astype(int)
print(np.unique(flairimg))
print(flairimg.shape)

### Making ground truth .tiff files: 
- Testing:

In [None]:
import numpy as np
import copy
import nibabel as nib

from tifffile import imsave
from libtiff import TIFF

from skimage.io._plugins import freeimage_plugin as fi

# import javabridge
# import bioformats
# javabridge.start_vm(class_path=bioformats.JARS)

# your program goes here


# ID = testIDlist[i]
# for i in range(len(testIDlist)):
for i in range(2):

    print("current image:", i)

    ID = testIDlist[i]
    img1 = './data/' + ID + '_flair.nii.gz'
    img2 = './data/' + ID + '_t1.nii.gz'
    img3 = './data/' + ID + '_t1ce.nii.gz'
    img4 = './data/' + ID + '_t2.nii.gz'
    img5 = './data/' + ID + '_seg.nii.gz'

    newimage = nib.concat_images([img1, img2, img3, img4, img5])
    cropped = crop_img(newimage)
    img_array = np.array(cropped.dataobj)
    z = np.rollaxis(img_array, 3, 0)

    padded_image = np.zeros((5, 160, 192, 160))
    padded_image[:z.shape[0], :z.shape[1], :z.shape[2], :z.shape[3]] = z

    a, b, c, d, seg_mask = np.split(padded_image, 5, axis=0)

    images = np.concatenate([a, b, c, d], axis=0)

    # print("images shape:", images.shape, "images values:", np.unique(images.astype(int)))

    # split the channels:
    # seg_mask_1 = copy.deepcopy(seg_mask.astype(int))
    seg_mask_1 = np.zeros((1, 160, 192, 160))
    seg_mask_1[seg_mask.astype(int) > 0] = 1
    seg_mask_2 = np.zeros((1, 160, 192, 160))
    seg_mask_2[seg_mask.astype(int) > 1] = 1
    seg_mask_3 = np.zeros((1, 160, 192, 160))
    seg_mask_3[seg_mask.astype(int) > 2] = 1
    seg_mask_3ch = np.concatenate(
        [seg_mask_1, seg_mask_2, seg_mask_3], axis=0).astype(int)

    # def scale_image(image_array):
    #     image_array = image_array.astype(float)
    #     image_array *= 255.0/image_array.max() # convert to 8-bit pixel values
    #     image_array = image_array.astype(int)
    #     return image_array

    # img_array_list = [a,seg_mask_1,seg_mask_2,seg_mask_3]
    # for img_array in img_array_list:
    #     img_array = scale_image(img_array)

    a = a.astype(float)
    a *= 255.0/a.max()  # convert to 8-bit pixel values
    a = np.rollaxis(a, 0, 2)
    a = a.astype('uint8')
#     print("unique flair values:", np.unique(a))

    seg_mask_1 = seg_mask_1.astype(float)
    seg_mask_1 *= 255.0/seg_mask_1.max()  # convert to 8-bit pixel values
    seg_mask_1 = np.rollaxis(seg_mask_1, 0, 2)
    seg_mask_1 = seg_mask_1.astype('uint8')
#     print("unique segment mask values:", np.unique(seg_mask_1))

    seg_mask_2 = seg_mask_2.astype(float)
    seg_mask_2 *= 255.0/seg_mask_2.max()  # convert to 8-bit pixel values
    seg_mask_2 = np.rollaxis(seg_mask_2, 0, 2)
    seg_mask_2 = seg_mask_2.astype('uint8')

    seg_mask_3 = seg_mask_3.astype(float)
    seg_mask_3 *= 255.0/seg_mask_3.max()  # convert to 8-bit pixel values
    seg_mask_3 = np.rollaxis(seg_mask_3, 0, 2)
    seg_mask_3 = seg_mask_3.astype('uint8')

#     ground_truth = np.concatenate(
#         [a, seg_mask_1, seg_mask_2, seg_mask_3], axis=0).astype('uint8')

#     print("unique flair + segment mask values:", np.unique(ground_truth))
    # shape.ground_truth
    # flairimg = flairimg.astype(float)
    # flairimg *= 255.0/flairimg.max() # convert to 8-bit pixel values
    # flairimg = flairimg.astype(int)
    # print(np.unique(flairimg))
#     print("final image shape:", ground_truth.shape)
#     imsave("./channel_split/"+testIDlist[i]+"ground_truth.tif", ground_truth, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"flair.tif", a, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"ground_truth_1.tif", seg_mask_1, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"ground_truth_2.tif", seg_mask_2, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"ground_truth_3.tif", seg_mask_3, 'imagej')

#     tiff = TIFF.open("./channel_split/"+testIDlist[i]+"ground_truth.tif", mode='w')
#     tiff.write_image(ground_truth)
#     tiff.close()
#     fi.write_multipage(ground_truth, "./channel_split/"+testIDlist[i]+"ground_truth.tif")
#     bioformats.write_image(pathname="./channel_split/"+testIDlist[i]+"ground_truth.tif", 
#                            pixels=ground_truth, 
#                            pixel_type=u'uint8',
#                            size_c=4, size_z=160, size_t=1,
#                            channel_names=None)

# javabridge.kill_vm()

### Adding predictions:

In [None]:
import numpy as np
import copy
import nibabel as nib

from tifffile import imsave
from libtiff import TIFF

from skimage.io._plugins import freeimage_plugin as fi

# import javabridge
# import bioformats
# javabridge.start_vm(class_path=bioformats.JARS)

# your program goes here


# ID = testIDlist[i]
# for i in range(len(testIDlist)):
for i in range(2):

    print("current image:", i)

    ID = testIDlist[i]
    img1 = './data/' + ID + '_flair.nii.gz'
    img2 = './data/' + ID + '_t1.nii.gz'
    img3 = './data/' + ID + '_t1ce.nii.gz'
    img4 = './data/' + ID + '_t2.nii.gz'
    img5 = './data/' + ID + '_seg.nii.gz'

    newimage = nib.concat_images([img1, img2, img3, img4, img5])
    cropped = crop_img(newimage)
    img_array = np.array(cropped.dataobj)
    z = np.rollaxis(img_array, 3, 0)

    padded_image = np.zeros((5, 160, 192, 160))
    padded_image[:z.shape[0], :z.shape[1], :z.shape[2], :z.shape[3]] = z

    a, b, c, d, seg_mask = np.split(padded_image, 5, axis=0)

    images = np.concatenate([a, b, c, d], axis=0)

    # print("images shape:", images.shape, "images values:", np.unique(images.astype(int)))

    # split the channels:
    # seg_mask_1 = copy.deepcopy(seg_mask.astype(int))
    seg_mask_1 = np.zeros((1, 160, 192, 160))
    seg_mask_1[seg_mask.astype(int) > 0] = 1
    seg_mask_2 = np.zeros((1, 160, 192, 160))
    seg_mask_2[seg_mask.astype(int) > 1] = 1
    seg_mask_3 = np.zeros((1, 160, 192, 160))
    seg_mask_3[seg_mask.astype(int) > 2] = 1
    seg_mask_3ch = np.concatenate(
        [seg_mask_1, seg_mask_2, seg_mask_3], axis=0).astype(int)

    # def scale_image(image_array):
    #     image_array = image_array.astype(float)
    #     image_array *= 255.0/image_array.max() # convert to 8-bit pixel values
    #     image_array = image_array.astype(int)
    #     return image_array

    # img_array_list = [a,seg_mask_1,seg_mask_2,seg_mask_3]
    # for img_array in img_array_list:
    #     img_array = scale_image(img_array)

    a = a.astype(float)
    a *= 255.0/a.max()  # convert to 8-bit pixel values
    a = np.rollaxis(a, 0, 2) # cxyz -> xycz for imagej
    a = np.rollaxis(a, 0, 3) # switching x and z
    a = a.astype('uint8')
#     print("unique flair values:", np.unique(a))

    seg_mask_1 = seg_mask_1.astype(float)
    seg_mask_1 *= 255.0/seg_mask_1.max()  # convert to 8-bit pixel values
    seg_mask_1 = np.rollaxis(seg_mask_1, 0, 2) # cxyz -> xycz for imagej
    seg_mask_1 = np.rollaxis(seg_mask_1, 0, 3) # switching x and z
    seg_mask_1 = seg_mask_1.astype('uint8')
#     print("unique segment mask values:", np.unique(seg_mask_1))

    seg_mask_2 = seg_mask_2.astype(float)
    seg_mask_2 *= 255.0/seg_mask_2.max()  # convert to 8-bit pixel values
    seg_mask_2 = np.rollaxis(seg_mask_2, 0, 2) # cxyz -> xycz for imagej
    seg_mask_2 = np.rollaxis(seg_mask_2, 0, 3) # switching x and z
    seg_mask_2 = seg_mask_2.astype('uint8')

    seg_mask_3 = seg_mask_3.astype(float)
    seg_mask_3 *= 255.0/seg_mask_3.max()  # convert to 8-bit pixel values
    seg_mask_3 = np.rollaxis(seg_mask_3, 0, 2) # cxyz -> xycz for imagej
    seg_mask_3 = np.rollaxis(seg_mask_3, 0, 3) # switching x and z
    seg_mask_3 = seg_mask_3.astype('uint8')

#     ground_truth = np.concatenate(
#         [a, seg_mask_1, seg_mask_2, seg_mask_3], axis=0).astype('uint8')

#     print("unique flair + segment mask values:", np.unique(ground_truth))
    # shape.ground_truth
    # flairimg = flairimg.astype(float)
    # flairimg *= 255.0/flairimg.max() # convert to 8-bit pixel values
    # flairimg = flairimg.astype(int)
    # print(np.unique(flairimg))
#     print("final image shape:", ground_truth.shape)
#     imsave("./channel_split/"+testIDlist[i]+"ground_truth.tif", ground_truth, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"_flair.tif", a, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"_ground_truth_1.tif", seg_mask_1, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"_ground_truth_2.tif", seg_mask_2, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"_ground_truth_3.tif", seg_mask_3, 'imagej')

    imarray = pickle.load(open( "./channel_split/prediction_"+str(i)+".pkl", "rb" ) )

    prediction_thresh = copy.deepcopy(imarray)
    prediction_thresh[prediction_thresh < 0.5] = 0.
    prediction_thresh[prediction_thresh >= 0.5] = 1.
    prediction_thresh = prediction_thresh
    print(np.unique(prediction_thresh))
    prediction_thresh *= 255.0/prediction_thresh.max() # convert to 8-bit pixel values
    prediction_thresh = prediction_thresh.astype('uint8')
    prediction_thresh = np.rollaxis(prediction_thresh, 1, 3) # switching x and z; c will be taken care of in split
    print(np.unique(prediction_thresh))
    print(prediction_thresh.shape)

    pred1, pred2, pred3 = np.split(prediction_thresh, 3, axis=0)

    imsave("./channel_split/"+testIDlist[i]+"_predicted_1.tif", pred1, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"_predicted_2.tif", pred2, 'imagej')
    imsave("./channel_split/"+testIDlist[i]+"_predicted_3.tif", pred3, 'imagej')

    # print("images shape:", images.shape, "images values:", np.unique(images.astype(int)))

    # split the channels:
    # seg_mask_1 = copy.deepcopy(seg_mask.astype(int))


    #     seg_mask_3ch = np.concatenate(
#         [seg_mask_1, seg_mask_2, seg_mask_3], axis=0).astype(int)

#     imarray *= 255.0/imarray.max()

#     imsave("./channel_split/"+testIDlist[i]+"ground_truth_1.tif", seg_mask_1, 'imagej')
#     imsave("./channel_split/"+testIDlist[i]+"ground_truth_2.tif", seg_mask_2, 'imagej')
#     imsave("./channel_split/"+testIDlist[i]+"ground_truth_3.tif", seg_mask_3, 'imagej')