In [None]:
import joblib
import numpy as np
import os
import tensorflow as tf
import math
import csv

from keras.models import Model
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Dropout
from keras.layers import BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard, CSVLogger
from keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import plot_model
from keras import backend as K
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt # to plot images
%matplotlib inline

### Consts

In [None]:
# Pickled dataset bundles; both are read with joblib.load in the "Load datasets" cell
train_and_val_dataset_file = 'datasets/train-and-val.pkl'
test_dataset_file = 'datasets/test.pkl'
# Checkpoint file-name template handed to ModelCheckpoint; the {epoch} and
# {val_dice_coef_accur} placeholders are filled in when a checkpoint is written
saved_model_filename = "datasets/test-4-new-tentative-{epoch:02d}-{val_dice_coef_accur:.4f}.hdf5"
# Per-epoch training log handed to CSVLogger (written with ';' as separator)
csv_logger_training = "datasets/test-4-new-tentative.csv"

### Load datasets

In [None]:
# NOTE: joblib.load unpickles arbitrary objects -- only point these paths at trusted files.
X_remaining, Y_remaining, remaining_dataset_desc = joblib.load(train_and_val_dataset_file)
Xte, yte, test_dataset_desc = joblib.load(test_dataset_file)  # test images / masks

# The description dict stores the row offsets that delimit the splits inside the
# "remaining" arrays; rows [training_set_index, validation_set_index) are the validation split.
training_set_index = remaining_dataset_desc['training_set_index']
validation_set_index = remaining_dataset_desc['validation_set_index']
validation_rows = slice(training_set_index, validation_set_index)

Xva = X_remaining[validation_rows, :]  # validation images
yva = Y_remaining[validation_rows]     # validation masks

# The training split comes from its own pre-built file rather than from X_remaining.
Xtr, ytr = joblib.load("datasets/train-augmented-11216.pkl")

# Sanity check: shapes of the images first, then of the masks (train, val, test).
for split_array in (Xtr, Xva, Xte, ytr, yva, yte):
    print(split_array.shape)

### Pre processing

In [None]:
# Standardise the images with statistics taken from the training set only, and apply
# the same transform to every split so validation/test data never influence it.
# NOTE: this cell overwrites Xtr/Xva/Xte, so running it twice standardises twice.

# .mean()/.std() are called without an axis: one scalar over the whole training array,
# not a per-pixel statistic.
full_image_mean_value = Xtr.mean()
full_image_sd = Xtr.std()


def _standardise(images):
    """Return `images` shifted and scaled by the training-set mean and standard deviation."""
    return (images - full_image_mean_value) / full_image_sd


Xtr = _standardise(Xtr)
Xva = _standardise(Xva)
Xte = _standardise(Xte)

### Pre-configurations

In [None]:
K.set_image_data_format('channels_last')  # TF dimension

# Drop the first (sample) and last (channel) axes of Xtr.shape; what is left is the
# per-image shape that the model's Input layer is built from.
input_image_shape = tuple(Xtr.shape[1:-1])
print(input_image_shape)

# --- Network architecture ---
initial_volume_size = 64   # filters in the first conv stage; deeper stages use multiples of it
kernel_size = (5, 5)
use_regularizers = True    # toggles the BatchNormalization inserted by setup_regularizers
use_dropout = True         # toggles the Dropout inserted by setup_dropout
dropout_rate = 0.5

# --- Training ---
number_of_epochs = 1000
batch_size = 64

# --- Metrics ---
smooth = 1.0               # added to numerator and denominator of the Dice ratios

### Define Unet model

In [None]:
# Define loss function
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient computed over the whole batch at once.

    Both tensors are flattened completely, so every pixel of every image in the
    batch contributes to a single ratio:
    (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2 * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return K.mean(numerator / denominator)

def dice_coef_per_image_in_batch(y_true, y_pred):
    """Smoothed Dice coefficient computed per image, then averaged over the batch.

    Each sample is flattened to one row, a Dice ratio
    (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)
    is formed for every row, and the mean of those ratios is returned.
    """
    truth_rows = K.batch_flatten(y_true)
    prediction_rows = K.batch_flatten(y_pred)

    def _row_sum(rows):
        # One sum per sample; keepdims keeps the (batch, 1) shape.
        return K.sum(rows, axis=1, keepdims=True)

    numerator = 2. * _row_sum(truth_rows * prediction_rows) + smooth
    denominator = _row_sum(truth_rows) + _row_sum(prediction_rows) + smooth
    return K.mean(numerator / denominator)

def dice_coef_loss(y_true, y_pred):
    """Training loss: the negated per-image Dice score, so minimising it maximises Dice."""
    per_image_dice = dice_coef_per_image_in_batch(y_true, y_pred)
    return -per_image_dice

def dice_coef_accur(y_true, y_pred):
    """Reporting metric: the per-image Dice score averaged over the batch (positive sign)."""
    per_image_dice = dice_coef_per_image_in_batch(y_true, y_pred)
    return per_image_dice

def IOU_calc(y_true, y_pred):
    """Smoothed intersection-over-union (Jaccard index) over the whole batch.

    Both tensors are flattened completely and a single ratio is returned:
    (intersection + smooth) / (union + smooth), where
    union = sum(y_true) + sum(y_pred) - intersection.

    The previous formula, 2 * (intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth),
    was a Dice-style ratio rather than IoU, and it exceeded 1 for a perfect match.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    # |A or B| = |A| + |B| - |A and B|
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection

    return (intersection + smooth) / (union + smooth)

def IOU_calc_loss(y_true, y_pred):
    """Loss form of IOU_calc: the same score with its sign flipped."""
    score = IOU_calc(y_true, y_pred)
    return -score

def setup_regularizers(conv_layer):
    """Append a BatchNormalization layer when `use_regularizers` is set; otherwise pass through.

    Returns the normalised tensor, or `conv_layer` itself when the switch is off.
    """
    if not use_regularizers:
        return conv_layer
    return BatchNormalization()(conv_layer)

def setup_dropout(conv_layer):
    """Append a Dropout(dropout_rate) layer when `use_dropout` is set; otherwise pass through.

    Returns the dropout output, or `conv_layer` itself when the switch is off.
    """
    if not use_dropout:
        return conv_layer
    return Dropout(dropout_rate)(conv_layer)

# Define model
def _double_conv(tensor, n_filters):
    """Apply two stacked same-padded ReLU convolutions with `n_filters` filters each."""
    for _ in range(2):
        tensor = Conv2D(n_filters, kernel_size, activation='relu', padding='same')(tensor)
    return tensor


def _upsample_and_merge(tensor, skip):
    """Upsample `tensor` 2x, concatenate it with `skip` on the channel axis, then apply optional dropout."""
    merged = concatenate([UpSampling2D(size=(2, 2))(tensor), skip], axis=3)
    return setup_dropout(merged)


# Single-channel input. The network pools 2x four times and upsamples 2x four times,
# so the spatial dimensions must be divisible by 16 for the skip concatenations to line up.
inputs = Input((*input_image_shape, 1))

# Contracting path: double conv -> optional batch norm -> 2x2 max pooling,
# doubling the number of filters at every stage.
conv1 = setup_regularizers(_double_conv(inputs, initial_volume_size))
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

conv2 = setup_regularizers(_double_conv(pool1, initial_volume_size*2))
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

conv3 = setup_regularizers(_double_conv(pool2, initial_volume_size*4))
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

conv4 = setup_regularizers(_double_conv(pool3, initial_volume_size*8))
# Fix: the original line had pasted text glued to the end of this statement (a SyntaxError).
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

# Bottleneck (no pooling afterwards)
conv5 = setup_regularizers(_double_conv(pool4, initial_volume_size*16))

# Expanding path: upsample, concatenate with the matching encoder output,
# optional dropout, then a double conv with half as many filters as the stage below.
up6 = _upsample_and_merge(conv5, conv4)
conv6 = _double_conv(up6, initial_volume_size*8)

up7 = _upsample_and_merge(conv6, conv3)
conv7 = _double_conv(up7, initial_volume_size*4)

up8 = _upsample_and_merge(conv7, conv2)
conv8 = _double_conv(up8, initial_volume_size*2)

up9 = _upsample_and_merge(conv8, conv1)
conv9 = _double_conv(up9, initial_volume_size)

# 1x1 convolution with a sigmoid: one value in [0, 1] per pixel
conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

model = Model(inputs=[inputs], outputs=[conv10])


# With staircase=True the learning rate is multiplied by 0.96 once every 2000 optimizer steps.
initial_learning_rate = 1e-5
lr_schedule = ExponentialDecay(
    initial_learning_rate,
    decay_steps=2000,
    decay_rate=0.96,
    staircase=True)
# Fix: pass the schedule as `learning_rate`; the `lr` alias is deprecated and newer Keras
# releases either ignore it or reject it.
model.compile(optimizer=Adam(learning_rate=lr_schedule), loss=dice_coef_loss, metrics=[dice_coef_accur])
print("Size of the CNN: %s" % model.count_params())


In [None]:
# summary() prints the layer table itself and returns None, so wrapping it in print()
# only appended a stray "None" line to the output.
model.summary()

### Train model

In [None]:
# Define callbacks
# Keep only the checkpoints that improve the validation Dice score. The direction is stated
# explicitly (mode='max') instead of relying on Keras guessing it from the metric name.
model_checkpoint = ModelCheckpoint(
    saved_model_filename,
    monitor='val_dice_coef_accur',
    mode='max',
    save_best_only=True,
    verbose=1,
)
# append=True: re-running this cell adds rows to the existing log instead of overwriting it.
csv_logger = CSVLogger(csv_logger_training, append=True, separator=';')

# Train
history = model.fit(
    Xtr,
    ytr,
    batch_size=batch_size,
    epochs=number_of_epochs,
    verbose=1,
    shuffle=True,
    callbacks=[model_checkpoint, csv_logger],
    validation_data=(Xva, yva),
)

### Show training metrics / loss

In [None]:
csv_history_file = "datasets/test-4-new-tentative.csv"

# Columns of the CSVLogger output that are plotted below. They are looked up by header
# name rather than by position, so an extra logged column cannot shift the values.
history_columns = ('dice_coef_accur', 'loss', 'val_dice_coef_accur', 'val_loss')

# data maps each column name to the list of its per-epoch values, in file order.
data = {column: [] for column in history_columns}

with open(csv_history_file, "r", newline="") as f:
    # DictReader consumes the first row as the header and yields one dict per logged epoch.
    for row in csv.DictReader(f, delimiter=";"):
        for column in history_columns:
            data[column].append(float(row[column]))

# Figure 1: Dice score per epoch for the training and validation sets (all logged epochs).
for history_key, curve_label in (('dice_coef_accur', 'train'), ('val_dice_coef_accur', 'val')):
    plt.plot(data[history_key], label=curve_label)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.show()

# Figure 2: loss per epoch, restricted to the first 300 logged epochs, also saved to disk.
for history_key, curve_label in (('loss', 'train'), ('val_loss', 'val')):
    plt.plot(data[history_key][:300], label=curve_label)
plt.title("Training and validation loss over epochs")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
# savefig has to come before show(), which finishes the current figure.
plt.savefig("Training_loss.jpg", dpi=1200,  bbox_inches='tight')
plt.show()

### Evaluate the model

### Predict masks using the trained model

In [None]:
# Restore the weights of one specific checkpoint. Its name follows the saved_model_filename
# template (<epoch>-<val_dice_coef_accur>), i.e. epoch 821 with a validation Dice of 0.9578.
selected_checkpoint_path = "datasets/test-4-new-tentative-821-0.9578.hdf5"
model.load_weights(selected_checkpoint_path)

# Sigmoid outputs for every test image, one value per pixel.
imgs_mask_test = model.predict(Xte, verbose=1)

In [None]:
print(np.median(imgs_mask_test[0]))

# Keras metrics are updated with update_state(y_true, y_pred): ground truth FIRST,
# predictions second. The arguments used to be swapped, which compared thresholded
# ground truth against raw probabilities and exchanged the roles of FP and FN
# (and therefore of precision and recall).
# NOTE(review): assumes yte holds binary {0, 1} masks -- confirm against the dataset.

acc_metric = tf.keras.metrics.BinaryAccuracy()
acc_metric.update_state(yte, imgs_mask_test)
print("Accuracy: ", acc_metric.result().numpy())

# MeanIoU works on class labels, not probabilities, so binarise the predictions first
# (0.5 matches the default threshold of the other metrics).
imgs_mask_test_binary = (imgs_mask_test > 0.5).astype(np.int32)
meaniou_metric = tf.keras.metrics.MeanIoU(2)
meaniou_metric.update_state(yte, imgs_mask_test_binary)
print("MeanIOU: ", meaniou_metric.result().numpy())

recall_metric = tf.keras.metrics.Recall()
recall_metric.update_state(yte, imgs_mask_test)
print("Recall / Sensitivity: ", recall_metric.result().numpy())

precision_metric = tf.keras.metrics.Precision()
precision_metric.update_state(yte, imgs_mask_test)
print("Precision: ", precision_metric.result().numpy())

# Pixel-level confusion-matrix counts
tp_metric = tf.keras.metrics.TruePositives()
tp_metric.update_state(yte, imgs_mask_test)
tp = tp_metric.result().numpy()
print("TP: ", tp)

tn_metric = tf.keras.metrics.TrueNegatives()
tn_metric.update_state(yte, imgs_mask_test)
tn = tn_metric.result().numpy()
print("TN: ", tn)

fp_metric = tf.keras.metrics.FalsePositives()
fp_metric.update_state(yte, imgs_mask_test)
fp = fp_metric.result().numpy()
print("FP: ", fp)

fn_metric = tf.keras.metrics.FalseNegatives()
fn_metric.update_state(yte, imgs_mask_test)
fn = fn_metric.result().numpy()
print("FN: ", fn)

# Specificity = TN / (TN + FP): share of background pixels correctly left unmarked
print("Specificity: ", tn/(fp+tn))

### Show results

In [None]:
ncols = 3  # one column each for the input image, the ground truth and the prediction
nrows = 8  # only the first few test images; drawing all of them takes some time
_, axes = plt.subplots(nrows, ncols, figsize=(17, 17*nrows/ncols))
for axis in axes.flatten():
    axis.set_axis_off()
    axis.set_aspect('equal')

# (column title, image stack) pairs in left-to-right column order
panel_sources = (
    ("Original Test Image", Xte),
    ("Ground Truth", yte),
    ("Predicted", imgs_mask_test),
)

for k in range(0, nrows):
    for column, (panel_title, image_stack) in enumerate(panel_sources):
        # Drop the channel axis so imshow receives a 2-D greyscale image.
        panel_image = image_stack[k].reshape(*input_image_shape)
        axes[k, column].set_title(panel_title)
        axes[k, column].imshow(panel_image, cmap='gray')
plt.savefig("Examples.jpg", dpi=500)

# Show best and worst test example by some metric
# best_index = 0
# worst_index = 0
# best_m = 0
# worst_m = 1
# for i in range(len(yte)):
#     acc_metric.reset_state()
#     acc_metric.update_state(yte[i], imgs_mask_test[i])
#     m = acc_metric.result().numpy()

#     if m < worst_m:
#         worst_m = m
#         worst_index = i

#     if m > best_m:
#         best_m = m
#         best_index = i

# print(best_index, best_m)
# print(worst_index, worst_m)
# idxs = [best_index, worst_index]

# _, axes = plt.subplots(2, ncols, figsize=(17, 17*2/ncols))
# for axis in axes.flatten():
#     axis.set_axis_off()
#     axis.set_aspect('equal')

# for k in range(0, 2):
#     im_test_original = Xte[idxs[k]].reshape(*input_image_shape)
#     im_result = imgs_mask_test[idxs[k]].reshape(*input_image_shape)
#     im_ground_truth = yte[idxs[k]].reshape(*input_image_shape)

#     axes[k, 0].set_title("Original Test Image")
#     axes[k, 0].imshow(im_test_original, cmap='gray')

#     axes[k, 1].set_title("Ground Truth")
#     axes[k, 1].imshow(im_ground_truth, cmap='gray')

#     axes[k, 2].set_title("Predicted")
#     axes[k, 2].imshow(im_result, cmap='gray')