## Libraries

In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Lambda, Conv2D, BatchNormalization, MaxPooling2D, Conv2DTranspose, Activation, Concatenate
from tensorflow.keras.metrics import BinaryIoU
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import keras.backend as K
import cv2 as cv
import matplotlib.pyplot as plt

## Loading Data

In [2]:
home = os.environ['HOME']

In [3]:
path_X = os.path.join(home,'raw_data/image_slices')
path_y = os.path.join(home,'raw_data/mask_slices')

In [4]:
split_ratio = 0.9

In [5]:
def train_val_split (path_X, path_y, split_ratio):
    X_names = os.listdir(path_X)
    y_names = os.listdir(path_y)
    y_path = [f'{path_y}/{file}' for file in y_names]
    X_path = [f'{path_X}/{file}' for file in X_names]
    train_X, val_X = X_path[:int(len(X_path)*split_ratio)], X_path[int(len(X_path)*split_ratio):]
    train_y, val_y = y_path[:int(len(y_path)*split_ratio)], y_path[int(len(y_path)*split_ratio):]
    return train_X, val_X, train_y, val_y 

In [6]:
train_X, val_X, train_y, val_y = train_val_split (path_X, path_y, split_ratio)

In [7]:
def verify_matching_input_labels(X_names, y_names):
    for x, y in zip(X_names, y_names):
        if os.path.basename(x) != os.path.basename(y):
            raise ValueError(f"X and Y not matching: {x, y}")

In [8]:
verify_matching_input_labels(train_X, train_y)

In [9]:
verify_matching_input_labels(val_X, val_y)

In [10]:
def process_path(image_path, mask_path):
    image = tf.io.read_file(image_path)
    mask = tf.io.read_file(mask_path)
    image = tf.image.decode_png(image, channels = 3)
    mask = tf.image.decode_png(mask, channels = 1) / 255 
    return image, mask

In [11]:
def batch_data (X_path, y_path, batch_size):
    ds_train = tf.data.Dataset.from_tensor_slices((X_path, y_path))
    return ds_train.shuffle(buffer_size = len(X_path), seed = 10).map(process_path).batch(batch_size)

### Training Dataset

In [12]:
train_dataset = batch_data(train_X, train_y, batch_size=16)

2022-12-02 10:21:17.427015: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-12-02 10:21:17.573086: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-12-02 10:21:17.574919: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-12-02 10:21:17.579678: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compil

### Validation Dataset

In [13]:
val_dataset = batch_data(val_X, val_y, batch_size=16)

### Test Dataset

In [14]:
path_X_TEST = os.path.join(home,'raw_data/TEST_slices/test_image_slices')
path_y_TEST = os.path.join(home,'raw_data/TEST_slices/test_mask_slices')

In [15]:
def batch_data_test (path_X, path_y, batch_size):
    X_names = os.listdir(path_X)
    X_path = [f'{path_X}/{file}' for file in X_names]
    y_names = os.listdir(path_y)
    y_path = [f'{path_y}/{file}' for file in y_names]
    ds_train = tf.data.Dataset.from_tensor_slices((X_path, y_path))
    return ds_train.map(process_path).batch(batch_size)

In [16]:
TEST_dataset = batch_data_test(path_X_TEST, path_y_TEST, batch_size=16)

## Model Definition

In [17]:
def conv_block(inputs, num_filters):
    x = Conv2D(num_filters, (3,3), padding="same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

In [18]:
def encoder_block(inputs, num_filters):
    x = conv_block(inputs, num_filters) #can be used as skip connection 
    p = MaxPooling2D((2,2))(x)
    return x, p

In [19]:
def decoder_block(inputs, skip_features, num_filters): #skip features are going to be the x returned from the encoder block
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [20]:
def dice_loss(targets, inputs, smooth=1e-6):
    
    #flatten label and prediction tensors
    inputs = K.flatten(inputs)
    targets = K.flatten(targets)
    
    intersection = K.sum(targets * inputs)
    dice = (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    return 1 - dice

In [21]:
def loss_sum(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    o = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred) + dice_loss(y_true, y_pred)
    return tf.reduce_mean(o)

In [38]:
def build_unet(img_height, img_width, channels):
    
    inputs = Input((img_height, img_width, channels))
    inputs = Lambda(lambda x: x / 255)(inputs) #Normalize the pixels by dividing by 255

    #Encoder - downscaling (creating features/filter)
    skip1, pool1 = encoder_block(inputs, 16)
    skip2, pool2 = encoder_block(pool1, 32) 
    skip3, pool3 = encoder_block(pool2, 64)
    skip4, pool4 = encoder_block(pool3, 128) 
    
    #Bottleneck or bridge between encoder and decoder
    b1 = conv_block(pool4, 256)
    
    #Decoder - upscaling (reconstructing the image and giving it precise spatial location)
    decoder1 = decoder_block(b1, skip4, 128)
    decoder2 = decoder_block(decoder1, skip3, 64)
    decoder3 = decoder_block(decoder2, skip2, 32)
    decoder4 = decoder_block(decoder3, skip1, 16)
    
    #Output
    outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(decoder4)
    model = Model(inputs, outputs)
    
    iou = BinaryIoU()
    
    model.compile(optimizer='adam', loss=loss_sum, metrics=['accuracy', iou])
    
    model.summary()
    
    return model


In [23]:
model = build_unet(256, 256, 3)

In [24]:
checkpoint_filepath = '../tmp/simple_unet/loss_sum_trainingset'
es = EarlyStopping(patience=5, restore_best_weights=True)
checkpoint = ModelCheckpoint(filepath=checkpoint_filepath, save_weights_only=True, monitor='val_loss', restore_best_weights=True)

In [25]:
#history = model.fit(train_dataset, validation_data=val_dataset, epochs = 500, callbacks=[es, checkpoint], verbose=1)

In [26]:
# plt.plot(history.history['val_binary_io_u'], label='validation set')
# plt.plot(history.history['binary_io_u'], label='training set')
# plt.legend()
# plt.show()

In [27]:
loaded_model = build_unet(256, 256, 3)

In [28]:
loaded_model.load_weights(checkpoint_filepath)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f45db607d10>

In [36]:
loaded_model.evaluate(TEST_dataset)

2022-12-02 10:24:08.489435: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8200




[0.27950355410575867, 0.9610777497291565, 0.8586697578430176]

In [29]:
for x, y in TEST_dataset:
    print(x, y)
    break

tf.Tensor(
[[[[ 40  64  57]
   [ 37  61  54]
   [ 39  63  56]
   ...
   [ 40  59  57]
   [ 45  64  63]
   [ 40  62  60]]

  [[ 43  65  57]
   [ 44  67  59]
   [ 47  69  62]
   ...
   [ 40  60  57]
   [ 49  68  66]
   [ 45  64  62]]

  [[ 40  63  55]
   [ 47  70  62]
   [ 52  73  67]
   ...
   [ 52  70  66]
   [ 52  71  68]
   [ 48  67  64]]

  ...

  [[ 77  95  77]
   [ 86  99  83]
   [ 88  99  87]
   ...
   [ 38  61  55]
   [ 43  69  61]
   [ 43  68  59]]

  [[ 81 101  87]
   [ 95 113  98]
   [ 74  93  80]
   ...
   [ 41  62  58]
   [ 37  64  56]
   [ 42  67  58]]

  [[ 64  88  74]
   [ 76  99  85]
   [ 60  83  69]
   ...
   [ 45  65  62]
   [ 40  63  56]
   [ 41  65  56]]]


 [[[195 199 202]
   [193 198 201]
   [193 198 201]
   ...
   [127 155 132]
   [126 154 131]
   [138 166 143]]

  [[187 192 195]
   [187 192 195]
   [188 193 196]
   ...
   [132 160 137]
   [128 156 133]
   [140 168 145]]

  [[185 190 193]
   [185 190 193]
   [186 191 194]
   ...
   [135 163 138]
   [128 156 131]


In [30]:
#loaded_model.predict(x)