In [None]:
import tensorflow as tf
# `tf.test.is_gpu_available()` is deprecated; listing the visible physical GPU
# devices is the supported replacement.
print("GPU available:", len(tf.config.list_physical_devices('GPU')) > 0)


from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import os
import datetime


def display(display_list, cmap='bone'):
    """Plot the given images side by side in one 9x9-inch figure.

    Panels are titled, in order, 'Input Image', 'True Mask', 'Predicted Mask';
    passing more than three images raises IndexError.

    Args:
        display_list: sequence of array-likes accepted by
            `tf.keras.preprocessing.image.array_to_img` (presumably (H, W, C)
            slices — TODO confirm).
        cmap: matplotlib colormap forwarded to `imshow`.
    """
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(9, 9))
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(panel_titles[idx])
        as_img = tf.keras.preprocessing.image.array_to_img(item)
        plt.imshow(as_img, cmap=cmap)
        plt.axis('off')
    plt.show()

def create_mask(pred_mask):
    """Binarise raw model output and return the first item of the batch.

    The U-Net's final Conv2D has no activation and is trained with
    `BinaryCrossentropy(from_logits=True)`, so its outputs are logits; a logit
    > 0 corresponds to a probability > 0.5.

    Args:
        pred_mask: array-like of logits whose first axis is the batch axis
            (e.g. the return value of `model.predict`). It is NOT modified.

    Returns:
        `pred_mask[0]` with every value > 0 replaced by 1 and every other value
        by 0, in the same dtype as `pred_mask`.
    """
    # np.asarray also accepts eager tf.Tensors, which do not support item assignment.
    logits = np.asarray(pred_mask)
    binary = (logits > 0).astype(logits.dtype)
    return binary[0]


def show_predictions(dataset=None, num=1):
    """Plot input, ground-truth mask and predicted mask.

    Uses the notebook-level `model`, and, when no dataset is given,
    `sample_image` / `sample_mask`.

    Args:
        dataset: optional batched dataset of (image, mask) pairs. For each of the
            first `num` batches, the first item is plotted. When None, the global
            sample is plotted instead.
        num: number of batches to take from `dataset`.
    """
    # Compare against None explicitly: object truthiness is not a reliable
    # "was a dataset passed?" test.
    if dataset is None:
        prediction = model.predict(sample_image[tf.newaxis, ...])
        display([sample_image, sample_mask, create_mask(prediction)])
        return

    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)])



class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that visualises the current prediction after every epoch.

    It calls `show_predictions()` with no dataset, so it plots the notebook-level
    `sample_image` / `sample_mask` together with the model's prediction for them.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Plot the sample prediction, then print a banner with `epoch + 1`."""
        # `logs` is part of the Keras callback signature but is not used here.
        show_predictions()
        print(f'\nSample Prediction after epoch {epoch + 1}\n')

In [None]:
def downsample_block(inputs, filters, size):
    """U-Net encoder stage: two (Conv2D -> BatchNorm -> ReLU) units, then MaxPool2D.

    Args:
        inputs: Keras tensor fed into the stage.
        filters: number of filters of both Conv2D layers.
        size: convolution kernel size.

    Returns:
        (pooled, features): `pooled` is the max-pooled output for the next, deeper
        stage; `features` is the pre-pooling activation kept as a skip connection.
    """
    features = inputs
    for _ in range(2):
        features = layers.Conv2D(filters, size, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.ReLU()(features)

    pooled = layers.MaxPool2D()(features)
    return pooled, features


def upsample_block(inputs, inputs_skip, filters, size):
    """U-Net decoder stage.

    Runs a stride-2 transposed convolution, concatenates the encoder skip tensor,
    then applies two (Conv2D -> BatchNorm -> ReLU) units.

    Args:
        inputs: Keras tensor from the deeper stage.
        inputs_skip: matching encoder activation; its spatial size must equal the
            upsampled `inputs`.
        filters: filters for the transposed conv and both Conv2D layers.
        size: kernel size.

    Returns:
        Keras tensor with `filters` channels.
    """
    upsampled = layers.Conv2DTranspose(filters, size, strides=2, padding='same')(inputs)
    out = layers.Concatenate()([upsampled, inputs_skip])
    for _ in range(2):
        out = layers.Conv2D(filters, size, padding='same')(out)
        out = layers.BatchNormalization()(out)
        out = layers.ReLU()(out)
    return out


def bottom_block(inputs, filters, size, dropout_rate=0.5):
    """U-Net bottleneck: two (Conv2D -> BatchNorm -> ReLU -> Dropout) units.

    Args:
        inputs: Keras tensor from the deepest encoder stage.
        filters: number of filters of both Conv2D layers.
        size: convolution kernel size.
        dropout_rate: fraction of units dropped after each ReLU. The default 0.5
            is the previously hard-coded value.

    Returns:
        Keras tensor with `filters` channels and the same spatial size as
        `inputs` (stride-1, 'same'-padded convolutions).
    """
    x = inputs
    for _ in range(2):
        x = layers.Conv2D(filters, size, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def unet_model(input_shape, base_filters=64, depth=4, num_classes=1):
    """Build a 2-D U-Net that outputs per-pixel logits.

    With the defaults this is the classic 4-level U-Net:
    64-128-256-512 encoder filters, a 1024-filter bottleneck, a mirrored decoder,
    and a single-channel output.

    Args:
        input_shape: (H, W, C) of one input slice. H and W must be divisible by
            2**depth, otherwise the decoder's skip concatenations will not line up.
        base_filters: filters of the first encoder stage; each deeper stage doubles it.
        depth: number of encoder/decoder stages.
        num_classes: channels of the final Conv2D. It has no activation, so it
            emits logits (matching a `from_logits=True` loss).

    Returns:
        An uncompiled `tf.keras.Model` named 'Unet'.
    """
    inputs = tf.keras.Input(shape=input_shape)

    x = inputs
    skips = []
    for level in range(depth):
        x, skip = downsample_block(x, base_filters * 2 ** level, 3)
        skips.append(skip)

    x = bottom_block(x, base_filters * 2 ** depth, 3)

    # Walk back up, pairing each stage with its encoder skip (deepest first).
    for level in reversed(range(depth)):
        x = upsample_block(x, skips[level], base_filters * 2 ** level, 3)

    output = layers.Conv2D(num_classes, 3, padding='same')(x)

    return tf.keras.Model(inputs, output, name='Unet')


In [None]:
def label_map_3(seg_2d, img_2d=None):
    """Collapse a label array into two values: 1 (gray matter) and 2 (everything else).

    Works for inputs of any rank, including whole 3-D volumes. The mapping is
    applied to a copy of `seg_2d`:
      * labels 1035 and 2035 (the insula, per the original comments) -> 2
      * label 0 -> 2 (background)
      * every other label -> 1 (gray matter)

    Args:
        seg_2d: numpy label array. It is NOT modified.
        img_2d: unused; kept so existing two-argument calls keep working.

    Returns:
        A new array with the same shape and dtype as `seg_2d`, containing only 1s and 2s.
    """
    # The previous body referenced an undefined `mask_image` (NameError on every call).
    mask_image = np.array(seg_2d, copy=True)
    mask_image[mask_image == 1035] = 0  # treat the insula as background
    mask_image[mask_image == 2035] = 0
    mask_image[mask_image != 0] = 1  # gray matter
    mask_image[mask_image == 0] = 2  # background
    return mask_image

def label_map_dkt31_6(seg_2d):
    """Map DKT31 cortical labels to 12 hemisphere-specific lobe classes.

    Works for arrays of any rank (e.g. a whole 3-D volume). Left-hemisphere
    labels 1002..1035 and right-hemisphere labels 2002..2035 are reduced to their
    region number (label - 1000 / label - 2000) and grouped into six lobes. Lobe
    k (0-based, in the order listed below) becomes class 2 + 2k on the left and
    3 + 2k on the right. Label 0, region numbers not listed in any lobe (e.g. 35,
    which is deliberately excluded) and every label outside both ranges map to 0.

    Args:
        seg_2d: numpy array of label values. They are assumed to be
            integer-valued; floats are truncated toward zero before the lookup,
            like `int()`. The array is NOT modified.

    Returns:
        A new array with the same shape and dtype as `seg_2d`.
    """
    # Region numbers per lobe, in output-class order.
    TEMPORAL_LOBE_MEDIAL = {6, 16, 7}                          # 2 (left), 3 (right)
    TEMPORAL_LOBE_LATERAL = {30, 15, 9, 34}                    # 4, 5
    FRONTAL_LOBE = {28, 12, 14, 24, 17, 3, 18, 19, 20, 27}     # 6, 7
    PARIETAL_LOBE = {22, 31, 29, 8, 25}                        # 8, 9
    OCCIPITAL_LOBE = {13, 21, 5, 11}                           # 10, 11
    CINGULATE_CORTEX = {10, 23, 26, 2}                         # 12, 13
    label_cluster = [TEMPORAL_LOBE_MEDIAL, TEMPORAL_LOBE_LATERAL, FRONTAL_LOBE,
                     PARIETAL_LOBE, OCCIPITAL_LOBE, CINGULATE_CORTEX]

    # Lookup tables indexed by region number 0..35 (0 = "not in any lobe").
    left_lut = np.zeros(36, dtype=np.int64)
    right_lut = np.zeros(36, dtype=np.int64)
    # Fill in reverse so that, if a region were ever listed in two lobes, the
    # first lobe wins (first-match precedence).
    for idx in reversed(range(len(label_cluster))):
        for region in label_cluster[idx]:
            left_lut[region] = 2 + 2 * idx
            right_lut[region] = 3 + 2 * idx

    seg = np.asarray(seg_2d)
    labels = seg.astype(np.int64)  # truncates toward zero, like int(val)
    out = np.zeros_like(seg)       # same shape/dtype; unmatched voxels stay 0

    # Vectorised replacement for a per-voxel Python loop (millions of iterations per volume).
    left = (labels >= 1002) & (labels <= 1035)
    right = (labels >= 2002) & (labels <= 2035)
    out[left] = left_lut[labels[left] - 1000]
    out[right] = right_lut[labels[right] - 2000]
    return out

# Module-level scratch/debug state; no code in this notebook reads these back.
extra_mask, temp_arr_inp, temp_arr_label = None, [], []

def _brain_seg_file_names(structures, base_raw, mni152):
    """Return the (scan, mask) NIfTI file names expected in every subject directory."""
    # Raw T1 when base_raw, otherwise the '_brain' variant.
    scan_file_name = 't1weighted' if base_raw else 't1weighted_brain'
    # Without `structures`, the mask is derived from 't1weighted_brain'
    # (the loader turns its nonzero voxels into 1).
    mask_file_name = 'labels.DKT31.manual+aseg' if structures else 't1weighted_brain'
    if mni152:
        scan_file_name += '.MNI152'
        mask_file_name += '.MNI152'
    return scan_file_name + '.nii.gz', mask_file_name + '.nii.gz'


def load_whole_brain_seg_data(set_start: int, set_end: int, structures: bool, base_raw: bool,
                              mni152: bool, structures_cortex: bool,
                              resize: bool, resize_shape: [int, int, int],
                              augment: bool, augment_only_images: bool):
    """Yield (scan_slice, mask_slice) pairs for every subject of the selected datasets.

    Used as the generator of `tf.data.Dataset.from_generator`, which passes the
    arguments as NumPy values.

    Datasets live under 'dataverse_files' and are indexed in this order:
    NKI-RS, OASIS-TRT, MMRR, NKI-TRT, Extra. ``set_start == set_end == 0``
    selects all of them; otherwise ``[set_start:set_end]`` of that list is used.

    For each subject directory (entries whose name contains '-'), in sorted order:
      1. Read scan and mask as float32 arrays and divide the scan by its maximum.
      2. Build the mask. It is binary (nonzero -> 1) unless `structures` is set.
         With `structures` and not `structures_cortex`, values >= 2000 and then
         values >= 1000 are each reduced by 1000, so e.g. 1035 and 2035 both
         become 35.
      3. Optionally resize (`resize_256`) and augment (`image_augmentation`).
      4. Add a trailing channel axis and yield one pair per index of axis 0.

    Args:
        set_start, set_end: dataset slice bounds (see above).
        structures: use the DKT31+aseg label file instead of a binary brain mask.
        base_raw: input is 't1weighted' (True) or 't1weighted_brain' (False).
        mni152: load the '.MNI152' variants of both files.
        structures_cortex: with `structures`, keep the 1xxx/2xxx label values.
        resize: center pad/crop each volume to `resize_shape` (0 keeps that axis).
        resize_shape: per-axis target size; it is not modified.
        augment: apply random transposes/flips jointly to scan and mask.
        augment_only_images: forwarded to `image_augmentation`. False permutes
            whole-volume axes, which can yield slices whose shape no longer
            matches the last two entries of `resize_shape`.

    Yields:
        (scan_slice, mask_slice) float arrays with a trailing channel axis of size 1.
    """
    datadirpath = 'dataverse_files'
    datasetpath_dic = {
        'NKI-RS': 'NKI-RS-22_volumes',
        'OASIS-TRT': 'OASIS-TRT-20_volumes',
        'MMRR': 'MMRR-21_volumes',
        'NKI-TRT': 'NKI-TRT-20_volumes',
        'Extra': 'Extra-18_volumes'
    }

    dataset_list = list(datasetpath_dic.values())
    if not (set_start == 0 and set_end == 0):
        dataset_list = dataset_list[set_start:set_end]
    print('dataset_list', dataset_list)

    scan_file_name, mask_file_name = _brain_seg_file_names(structures, base_raw, mni152)

    for dataset in dataset_list:
        dataset_path = os.path.join(datadirpath, dataset)
        # os.listdir order is arbitrary; sort so slices come out in a reproducible order.
        dir_list = sorted(x for x in os.listdir(dataset_path) if '-' in x)
        for directory in dir_list:
            raw_scan_file = os.path.join(dataset_path, directory, scan_file_name)
            brain_scan_file = os.path.join(dataset_path, directory, mask_file_name)
            print("raw_scan_file", raw_scan_file)
            print("brain_scan_file", brain_scan_file)

            # .nii.gz -> numpy array
            scan_image = sitk.GetArrayFromImage(sitk.ReadImage(raw_scan_file, sitk.sitkFloat32))
            scan_max = np.amax(scan_image)
            if scan_max > 0:  # an all-zero volume would otherwise become all-NaN
                scan_image /= scan_max

            mask_image = sitk.GetArrayFromImage(sitk.ReadImage(brain_scan_file, sitk.sitkFloat32))

            if not structures:
                # Binary brain mask: every nonzero voxel becomes 1.
                mask_image = np.where(0 == mask_image, mask_image, 1)

            if structures and not structures_cortex:
                mask_image = np.where(2000.0 > mask_image, mask_image, mask_image - 1000.0)
                mask_image = np.where(1000.0 > mask_image, mask_image, mask_image - 1000.0)

            if resize:
                # Pass copies so the shared resize_shape can never be altered by a
                # helper that fills in its 0 entries in place.
                scan_image = resize_256(scan_image, list(resize_shape))
                mask_image = resize_256(mask_image, list(resize_shape))

            if augment:
                scan_image, mask_image = image_augmentation(scan_image, mask_image, augment_only_images)

            scan_image = np.expand_dims(scan_image, axis=-1)
            mask_image = np.expand_dims(mask_image, axis=-1)

            for slice_idx in range(len(scan_image)):
                yield scan_image[slice_idx], mask_image[slice_idx]

        print('Loaded ' + dataset)


def resize_256(array, resize_shape: [int, int, int]) -> np.ndarray:
    """Center-pad with zeros and/or center-crop `array` to `resize_shape`.

    Despite the name, the target size comes from `resize_shape` (historically 256
    per in-plane axis). A 0 entry keeps that axis's current length. When the
    amount to pad or crop on an axis is odd, the extra element is added/removed
    at the end of the axis.

    Args:
        array: numpy array with one axis per entry of `resize_shape`.
        resize_shape: target length per axis (0 = keep). It is NOT modified.

    Returns:
        An array of shape `resize_shape`, with 0 entries replaced by the input's
        length on that axis.
    """
    # Resolve 0 entries without writing back into the caller's sequence.
    target = [array.shape[axis] if n == 0 else int(n) for axis, n in enumerate(resize_shape)]

    pad_width = []
    crop = []
    for current, wanted in zip(array.shape, target):
        dif = wanted - current
        if dif > 0:
            before = dif // 2
            pad_width.append((before, dif - before))
            crop.append(slice(0, wanted))
        else:
            pad_width.append((0, 0))
            start = (-dif) // 2
            # Crop to the target length, not a hard-coded 256.
            crop.append(slice(start, start + wanted))

    padded = np.pad(array, pad_width, mode='constant', constant_values=0.0)
    return padded[tuple(crop)]


def image_augmentation(scan_image, mask_image, augment_only_images=True, rng=None):
    """Apply identical random transposes/flips to a scan volume and its mask.

    With augment_only_images=True, each slice along axis 0 is independently
    transposed (its two in-plane axes swapped) with probability 1/2. This
    requires square slices.

    With augment_only_images=False, the whole volume is transposed with one
    random axis permutation, which can change the slice shape.

    In both modes, each slice along axis 0 of the result is then independently
    flipped along its first and its second axis, each with probability 1/2.

    Args:
        scan_image, mask_image: 3-D numpy arrays of the same shape. They are NOT modified.
        augment_only_images: per-slice (True) or whole-volume (False) transposes.
        rng: optional `np.random.Generator` for reproducibility; a fresh unseeded
            generator is used when None.

    Returns:
        (scan, mask): new augmented arrays.
    """
    if rng is None:
        rng = np.random.default_rng()

    if augment_only_images:
        # Copy so the caller's arrays are not mutated through aliasing.
        scan = np.array(scan_image, copy=True)
        mask = np.array(mask_image, copy=True)
        for index in range(len(scan)):
            if rng.integers(2):
                scan[index] = np.transpose(scan[index], [1, 0])
                mask[index] = np.transpose(mask[index], [1, 0])
    else:
        transpose_axes = [0, 1, 2]
        rng.shuffle(transpose_axes)
        # np.transpose returns a view; copy so the flips below cannot write back
        # into the caller's arrays.
        scan = np.transpose(scan_image, transpose_axes).copy()
        mask = np.transpose(mask_image, transpose_axes).copy()

    # Iterate over the augmented volume: after a whole-volume transpose its first
    # axis can be longer or shorter than the input's (len(scan_image) could be
    # out of range or skip slices).
    for index in range(len(scan)):
        if rng.integers(2):
            scan[index] = np.flip(scan[index], axis=0)
            mask[index] = np.flip(mask[index], axis=0)

        if rng.integers(2):
            scan[index] = np.flip(scan[index], axis=1)
            mask[index] = np.flip(mask[index], axis=1)

    return scan, mask


# In-plane size of every slice fed to the network.
dim1, dim2 = 256, 256

# Loader options (forwarded positionally to load_whole_brain_seg_data).
structures = False              # False: binary mask from t1weighted_brain; True: DKT31+aseg labels
base_raw = True                 # input is 't1weighted' rather than 't1weighted_brain'
mni152 = True                   # load the '.MNI152' file variants
structures_cortex = False
resize = True
resize_shape = [0, dim1, dim2]  # 0 keeps the volume's original depth
augment = True
augment_only_images = True

# args = [set_start, set_end, structures, base_raw, mni152, structures_cortex, resize, resize_shape, augment, augment_only_images]

# tf.data.Dataset.from_generator builds the input pipeline: it wraps the Python
# generator as a source dataset that later cells transform, batch and iterate over.

def make_brain_seg_dataset(set_start, set_end, augment_data):
    """Wrap load_whole_brain_seg_data in a tf.data.Dataset of (image, mask) slice pairs.

    Args:
        set_start, set_end: positions in the loader's dataset list
            (0, 0 selects all datasets).
        augment_data: whether the loader applies random transposes/flips.

    Returns:
        A dataset whose elements are two float32 tensors of shape (dim1, dim2, 1).
    """
    return tf.data.Dataset.from_generator(
        load_whole_brain_seg_data,
        args=[set_start, set_end, structures, base_raw, mni152, structures_cortex,
              resize, resize_shape, augment_data, augment_only_images],
        output_signature=(
            tf.TensorSpec(shape=(dim1, dim2, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(dim1, dim2, 1), dtype=tf.float32)))


# Split by source dataset (loader order): NKI-RS, OASIS-TRT, MMRR -> train;
# NKI-TRT -> validation; Extra -> test.
train_ds = make_brain_seg_dataset(0, 3, augment)
# Validation/test data are evaluated as-is: random flips/transposes belong to
# training only and would make the reported metrics unrepresentative.
val_ds = make_brain_seg_dataset(3, 4, False)
test_ds = make_brain_seg_dataset(4, 5, False)


In [None]:
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 16
shuffle_buffer_size = 10000


def performance(ds, shuffle_buffer=0, batch=None):
    """Cache, optionally shuffle, batch and prefetch a tf.data pipeline.

    The cache sits before the shuffle, so the expensive generator (NIfTI reads,
    resizing, augmentation) runs once while the shuffle still produces a new
    order on every pass. Shuffling before `cache()` would freeze the first
    pass's order into the cache and replay it every epoch.

    Args:
        ds: unbatched dataset of (image, mask) pairs.
        shuffle_buffer: shuffle buffer size; 0 (default) disables shuffling.
        batch: batch size; defaults to the notebook-level `batch_size`.

    Returns:
        The batched, prefetched dataset.
    """
    ds = ds.cache()
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size if batch is None else batch)
    return ds.prefetch(buffer_size=AUTOTUNE)


train_ds = performance(train_ds, shuffle_buffer=shuffle_buffer_size).repeat()
# Loss/accuracy over a full pass do not depend on order, so validation/test are not shuffled.
val_ds = performance(val_ds)
test_ds = performance(test_ds)

In [None]:
# Preview one training batch and keep its first slice as the sample that
# show_predictions() / DisplayCallback visualise.
for image, mask in train_ds.take(1):
    sample_image = image[0]
    sample_mask = mask[0]
    display([sample_image, sample_mask])

In [None]:
# print(train_ds)
# print(val_ds)
# print(train_ds.cardinality())
# print(val_ds.cardinality())

# for image, mask in test_ds.take(1):
#     sample_image, sample_mask = image[0], mask[0]
#     tranpose_array = [0, 1, 2]
# #     rng.shuffle(tranpose_array)
#     print(image.shape)
#     display(sample_image[np.newaxis,...])
#     scan = np.transpose(image[0], tranpose_array)
#     print("scan",scan.shape)
#     display(scan[np.newaxis,...])

In [None]:
# Take another training batch and make its first slice the new sample
# (presumably a different slice, since the training pipeline is shuffled).
for image, mask in train_ds.take(1):
    sample_image, sample_mask = image[0], mask[0]
    display([sample_image, sample_mask])

In [None]:
model = unet_model((dim1, dim2, 1))

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
# The U-Net's last Conv2D has no activation, so it outputs logits, hence
# `from_logits=True`. For the same reason, accuracy must threshold at 0
# (= probability 0.5). The plain 'accuracy' string resolves to binary accuracy
# with a 0.5 threshold applied to the raw outputs. name='accuracy' keeps the
# history keys ('accuracy' / 'val_accuracy') used for plotting.
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0)])

epochs = 20
steps_per_epoch = 256

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
earlystop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

checkpoint_filepath = 'epochmodels/{epoch:02d}-' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# save_freq counts batches: one checkpoint every 4 epochs.
model_save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filepath, save_freq=steps_per_epoch * 4)

model_history = model.fit(train_ds, epochs=epochs,
                          steps_per_epoch=steps_per_epoch,
                          validation_data=val_ds,
                          callbacks=[DisplayCallback(), tensorboard_callback, earlystop_callback, model_save_callback])

In [None]:
acc = model_history.history['accuracy']
val_acc = model_history.history['val_accuracy']
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

# EarlyStopping can end training before `epochs`. Size the x-axis from the
# recorded history (1-based epochs), not the configured count, or plot() fails
# on a length mismatch.
epochs_range = range(1, len(loss) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(8, 8))

ax_acc.plot(epochs_range, acc, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc, label='Validation Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set(title='Accuracy', xlabel='Epoch')

ax_loss.plot(epochs_range, loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.legend(loc='lower left')
ax_loss.set(title='Loss', xlabel='Epoch')

fig.suptitle(f'brainseg_whole_210303_6\nBatch size: {batch_size}, Epochs: {epochs}')
plt.show()

In [None]:
# Final qualitative check: input, ground truth and prediction for the first
# slice of each of the first 10 test batches.
show_predictions(num=10, dataset=test_ds)