## Libraries

In [22]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model, layers
from tensorflow.keras.layers import Lambda, Conv2D, BatchNormalization, MaxPooling2D, Conv2DTranspose, concatenate, Activation, Concatenate
from tensorflow.keras.metrics import IoU, BinaryIoU
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import keras.backend as K
import cv2 as cv

## Loading Data

In [24]:
home = os.environ['HOME']

In [25]:
path_small_X = os.path.join(home,'raw_data/small_dataset/sample_images')
path_small_y = os.path.join(home,'raw_data/small_dataset/sample_masks')

In [26]:
split_ratio = 0.9

In [28]:
def train_val_split (path_X, path_y, split_ratio):
    X_names = os.listdir(path_X)
    y_names = os.listdir(path_y)
    y_path = [f'{path_y}/{file}' for file in y_names]
    X_path = [f'{path_X}/{file}' for file in X_names]
    train_X, val_X = X_path[:int(len(X_path)*split_ratio)], X_path[int(len(X_path)*split_ratio):]
    train_y, val_y = y_path[:int(len(y_path)*split_ratio)], y_path[int(len(y_path)*split_ratio):]
    return train_X, val_X, train_y, val_y 

In [29]:
train_X, val_X, train_y, val_y = train_val_split (path_small_X, path_small_y, split_ratio) #small dataset

In [30]:
def verify_matching_input_labels(X_names, y_names):
    for x, y in zip(X_names, y_names):
        if os.path.basename(x) != os.path.basename(y):
            raise ValueError(f"X and Y not matching: {x, y}")

In [31]:
verify_matching_input_labels(train_X, train_y)

In [32]:
verify_matching_input_labels(val_X, val_y)

In [33]:
def process_path(image_path, mask_path):
    image = tf.io.read_file(image_path)
    mask = tf.io.read_file(mask_path)
    image = tf.image.decode_png(image, channels = 3)
    mask = tf.image.decode_png(mask, channels = 1) / 255 
    return image, mask

In [38]:
def batch_data (X_path, y_path, batch_size):
    ds_train = tf.data.Dataset.from_tensor_slices((X_path, y_path))
    return ds_train.map(process_path).batch(batch_size)

### Training Dataset

In [40]:
train_dataset = batch_data(train_X, train_y, batch_size=16)

### Validation Dataset

In [41]:
val_dataset = batch_data(val_X, val_y, batch_size=16)

### Test Dataset

In [58]:
path_X_TEST = os.path.join(home,'raw_data/TEST_slices/test_image_slices')
path_y_TEST = os.path.join(home,'raw_data/TEST_slices/test_mask_slices')

In [60]:
def batch_data_test (path_X, path_y, batch_size):
    X_names = os.listdir(path_X)
    X_path = [f'{path_X}/{file}' for file in X_names]
    y_names = os.listdir(path_y)
    y_path = [f'{path_y}/{file}' for file in y_names]
    ds_train = tf.data.Dataset.from_tensor_slices((X_path, y_path))
    return ds_train.map(process_path).batch(batch_size)

In [62]:
TEST_dataset = batch_data_test(path_X_TEST, path_y_TEST, batch_size=16)

## Model Definition

In [42]:
def conv_block(inputs, num_filters):
    x = Conv2D(num_filters, (3,3), padding="same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

In [43]:
def encoder_block(inputs, num_filters):
    x = conv_block(inputs, num_filters) #can be used as skip connection 
    p = MaxPooling2D((2,2))(x)
    return x, p

In [44]:
def decoder_block(inputs, skip_features, num_filters): #skip features are going to be the x returned from the encoder block
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [45]:
def dice_loss(targets, inputs, smooth=1e-6):
    
    #flatten label and prediction tensors
    inputs = K.flatten(inputs)
    targets = K.flatten(targets)
    
    intersection = K.sum(targets * inputs)
    dice = (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    return 1 - dice

In [53]:
def build_unet(img_height, img_width, channels):
    
    inputs = Input((img_height, img_width, channels))
    inputs = Lambda(lambda x: x / 255)(inputs) #Normalize the pixels by dividing by 255

    #Encoder - downscaling (creating features/filter)
    skip1, pool1 = encoder_block(inputs, 16)
    skip2, pool2 = encoder_block(pool1, 32) 
    skip3, pool3 = encoder_block(pool2, 64)
    skip4, pool4 = encoder_block(pool3, 128) 
    
    #Bottleneck or bridge between encoder and decoder
    b1 = conv_block(pool4, 256)
    
    #Decoder - upscaling (reconstructing the image and giving it precise spatial location)
    decoder1 = decoder_block(b1, skip4, 128)
    decoder2 = decoder_block(decoder1, skip3, 64)
    decoder3 = decoder_block(decoder2, skip2, 32)
    decoder4 = decoder_block(decoder3, skip1, 16)
    
    #Output
    outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(decoder4)
    model = Model(inputs, outputs)
    
    iou = BinaryIoU()
    
    model.compile(optimizer='adam', loss=dice_loss, metrics=['accuracy', iou])
    
    #model.summary()
    
    return model


In [54]:
model = build_unet(256, 256, 3)

In [55]:
checkpoint_filepath = '../tmp/checkpoint'
es = EarlyStopping(patience=5, restore_best_weights=True)
checkpoint = ModelCheckpoint(filepath=checkpoint_filepath, save_weights_only=True, monitor='val_loss', restore_best_weights=True)

In [56]:
history = model.fit(train_dataset, validation_data=val_dataset, epochs = 500, callbacks=[es, checkpoint], verbose=1)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500


In [63]:
model.evaluate(TEST_dataset)



[0.212067112326622, 0.9378632307052612, 0.7928847670555115]