# Import necessary libraries

In [None]:
import os
import glob
# import keras
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from PIL import Image
import numpy as np
from tensorflow.keras import layers
import keras
import wandb
from wandb.integration.keras import WandbMetricsLogger
import datetime
import pandas as pd

In [None]:
# tf.test.is_gpu_available()
# List the GPUs TensorFlow can see (an empty list means no GPU is visible)
tf.config.list_physical_devices('GPU')

In [None]:
# Authenticate with Weights & Biases (a run is started with wandb.init further down)
wandb.login()

# Retrieve images and masks from file directory

In [None]:
#Get file paths
# Training images are JPEGs and masks are PNGs. glob returns paths in arbitrary
# order; images and masks are paired later by sorting both lists.

image_dir = '../data/feb-25/train/images/'
mask_dir = '../data/feb-25/train/masks/'

image_paths = glob.glob(image_dir+"*.jpg")
mask_paths = glob.glob(mask_dir+"*.png")

In [None]:
# Sanity check: each image needs exactly one mask, so these counts should match
print(f'Images: {len(image_paths)}')
print(f'Masks: {len(mask_paths)}')

In [None]:
from random import shuffle
from sklearn.model_selection import train_test_split

In [None]:
# # Train test validation split function
# def get_train_and_validation_splits(img_paths, mask_paths):
#     img_paths.sort()
#     mask_paths.sort()
    
#     train_img_paths, temp_img_paths, train_mask_paths, temp_mask_paths = train_test_split(image_paths, mask_paths, 
#                                                                           train_size=0.7, random_state=711)
    
# #     print(len(temp_img_paths))

#     val_img_paths, test_img_paths, val_mask_paths,  test_mask_paths = train_test_split(temp_img_paths, temp_mask_paths,
#                                                                                      test_size=0.4)

#     return train_img_paths, train_mask_paths, val_img_paths, val_mask_paths, test_img_paths, test_mask_paths

In [None]:
# Train test validation split function
def get_train_and_validation_splits(img_paths, mask_paths, test_size=0.25, random_state=711):
    '''
    Splits paired image/mask paths into training and validation sets.

    Both lists are sorted so that the i-th image lines up with the i-th mask
    (this assumes image and mask file names sort into the same order).

    Args:
        img_paths: list of image file paths.
        mask_paths: list of mask file paths, one per image.
        test_size: fraction of pairs held out for validation (default 0.25).
        random_state: seed for a reproducible split (default 711).

    Returns:
        train_img_paths, train_mask_paths, val_img_paths, val_mask_paths

    Raises:
        ValueError: if the number of images and masks differs.
    '''
    # Sort copies so the caller's lists are not mutated
    img_paths = sorted(img_paths)
    mask_paths = sorted(mask_paths)

    if len(img_paths) != len(mask_paths):
        raise ValueError(f'Found {len(img_paths)} images but {len(mask_paths)} masks; '
                         'they must pair one-to-one')

    # Split the function's own arguments (not the notebook-level globals)
    train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = train_test_split(
        img_paths, mask_paths, test_size=test_size, random_state=random_state)

    return train_img_paths, train_mask_paths, val_img_paths, val_mask_paths

In [None]:
# Test images (JPEG) and masks. Masks are globbed as PNG like the training masks:
# pixel_map needs exact colours, which lossy JPEG masks cannot guarantee.
# NOTE(review): confirm the test masks on disk are PNG files.
X_test = sorted(glob.glob('../data/feb-25/test/images/*.jpg'))
Y_test = sorted(glob.glob('../data/feb-25/test/masks/*.png'))

In [None]:
# 75/25 train/validation split of the training pool (the test set is loaded separately)
X_train, Y_train, X_val, Y_val, = get_train_and_validation_splits(image_paths, mask_paths)

In [None]:
# Report the number of file paths in each split
print(f'Train: {len(X_train)} \nTest: {len(X_test)} \nVal: {len(X_val)}')

# Define helper functions for loading, preprocessing and displaying images

In [None]:
#Map class values to class names
classes = {0:'damage', 
           1:'corm', 
           2:'background'}

#Map pixel values to class values
# RGB colour in the annotation masks -> one-element list holding the class value
# (black = background, green = corm, red = damage). The one-element lists let
# reshape_mask build an (h, w, 1) class-index array.
pixel_map = {(0,0,0):[2], 
             (0,255,0):[1], 
             (255,0,0):[0]}

#Map class values to pixel values (the inverse of pixel_map, used by revert_mask)
class_map = {2:[0,0,0], 
             1:[0,255,0], 
             0:[255,0,0]}

# (height, width, channels) used for resizing and as the U-Net input shape
image_shape = (512,512,3)

In [None]:
# wandb.log raises when no run is active, and wandb.init is otherwise only called
# just before training; start the run here (once) so these stats are recorded
if wandb.run is None:
    wandb.init(project='corm_experiments')

wandb.log({
    'image_size': str(image_shape[0]),
    'train_samples': len(X_train),
    'val_samples': len(X_val),
    'test_samples': len(X_test),
})

In [None]:
def random_rotation(img, mask):
    '''
    With probability 0.5, rotates the image and its mask together by
    90 degrees anticlockwise (tf.image.rot90); otherwise returns them unchanged.
    '''
    should_rotate = tf.random.uniform(()) > 0.5
    if should_rotate:
        img, mask = tf.image.rot90(img), tf.image.rot90(mask)

    return img, mask

In [None]:
def random_flip(img, mask):
    '''
    With probability 0.5, mirrors the image and its mask together
    left-to-right; otherwise returns them unchanged.
    '''
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        img, mask = tf.image.flip_left_right(img), tf.image.flip_left_right(mask)

    return img, mask

In [None]:

def gaussian_blur(img):
    '''
    Blurs an image with tfa's default Gaussian kernel, padding borders with a
    constant value (padding='CONSTANT').
    '''
    return tfa.image.gaussian_filter2d(img, padding='CONSTANT')

In [None]:
def translate(img, mask):
    '''
    Shifts an image and its mask by the same random (dx, dy) offset, each drawn
    uniformly from [-15, 15) pixels.

    Uncovered image pixels use tfa's default fill value. Uncovered mask pixels
    are filled with 2, the 'background' class value in `classes`.
    '''
    original_shapes = (img.shape, mask.shape)

    offset = [tf.random.uniform((), minval=-15, maxval=15),
              tf.random.uniform((), minval=-15, maxval=15)]

    shifted_img = tfa.image.translate(img, offset)
    shifted_mask = tfa.image.translate(mask, offset, fill_value=2)

    # Re-assert the static shapes so downstream tf.data ops see fixed shapes
    shifted_img.set_shape(original_shapes[0])
    shifted_mask.set_shape(original_shapes[1])

    return shifted_img, shifted_mask

In [None]:
def noise_injection(img, im_shape=image_shape):
    '''
    Adds zero-mean Gaussian noise whose standard deviation is drawn uniformly
    from [0, 0.1), then clips the result back into [0, 1].

    im_shape is the shape of the generated noise and must broadcast against img.
    '''
    sigma = tf.random.uniform((), maxval=0.1)
    noisy = img + sigma * tf.random.normal(shape=im_shape)
    return tf.clip_by_value(noisy, 0.0, 1.0)

In [None]:
def color_transformations(img, mask):
    '''
    Augmentation branch: random rotation/flip plus random colour jitter.

    Rotation and flip are applied to both image and mask. Brightness,
    contrast, saturation and hue changes are applied to the image only.

    Args:
        img: float32 image, expected in [0, 1] (it is normalized before this map).
        mask: mask tensor, transformed geometrically alongside the image.

    Returns:
        (img, mask) with img clipped back into [0, 1].
    '''
    img, mask = random_rotation(img, mask)
    img, mask = random_flip(img, mask)

    img = tf.image.random_brightness(img, 0.4)
    img = tf.image.random_contrast(img, 0.5, 2.0)
    # Brightness/contrast can push float pixels outside [0, 1]. Clip before the
    # HSV-based ops, and again at the end, so this branch's output range
    # matches the other branches (noise_injection also clips to [0, 1]).
    img = tf.clip_by_value(img, 0.0, 1.0)
    img = tf.image.random_saturation(img, 0.75, 1.25)
    img = tf.image.random_hue(img, 0.1)
    img = tf.clip_by_value(img, 0.0, 1.0)

    return img, mask

In [None]:
def geometric_transformations(img, mask):
    '''
    Augmentation branch: applies a random translation, then a random
    90-degree rotation, then a random horizontal flip to image and mask together.
    '''
    for augment in (translate, random_rotation, random_flip):
        img, mask = augment(img, mask)

    return img, mask

In [None]:
def noise_transformations(img, mask):
    '''
    Augmentation branch: random rotation and flip, then either a Gaussian blur
    (probability 0.7) or additive Gaussian noise (probability 0.3) on the image.
    The mask only receives the geometric transforms.
    '''
    img, mask = random_flip(*random_rotation(img, mask))

    use_blur = tf.random.uniform(()) > 0.3
    if use_blur:
        img = gaussian_blur(img)
    else:
        img = noise_injection(img)

    return img, mask

In [None]:
def reshape_mask(mask, im_shape=image_shape):
    '''
    Converts an RGB colour mask into a single-channel class-index mask,
    i.e. reshapes masks from (h, w, 3) to (h, w, 1).

    Each pixel colour is looked up in the global `pixel_map`.

    Args:
        mask: colour mask tensor/array of shape (h, w, c) with c >= 3; any
            channels after the first three (e.g. alpha) are ignored.
        im_shape: target shape; only im_shape[0] and im_shape[1] are used.

    Returns:
        tf.uint8 tensor of shape (im_shape[0], im_shape[1], 1).

    Raises:
        ValueError: if the mask is not a (h, w, >=3) array.
        KeyError: if a pixel colour is not a key of `pixel_map`.
    '''
    rgb = np.asarray(mask)
    if rgb.ndim != 3 or rgb.shape[-1] < 3:
        raise ValueError(f'Expected a colour mask of shape (h, w, 3), got {rgb.shape}')
    rgb = rgb[..., :3]

    # Vectorised lookup: label every pixel matching each known colour at once,
    # instead of a Python-level dict lookup per pixel.
    # 255 marks "unmatched"; it is assumed not to be a real class value.
    unmatched = 255
    labels = np.full(rgb.shape[:2], unmatched, dtype=np.uint8)
    for colour, (class_value,) in pixel_map.items():
        labels[np.all(rgb == np.array(colour), axis=-1)] = class_value

    unknown = labels == unmatched
    if unknown.any():
        # An unknown colour means a corrupt/anti-aliased annotation: fail loudly
        raise KeyError(tuple(int(c) for c in rgb[unknown][0]))

    labels = np.reshape(labels, (im_shape[0], im_shape[1], 1))

    return tf.convert_to_tensor(labels, dtype=tf.uint8)

In [None]:
def tf_reshape_mask(img, mask):
    '''
    tf.data-compatible wrapper around reshape_mask.

    reshape_mask runs arbitrary Python/numpy logic, so it is executed via
    tf.py_function, which loses static shape information; the static shape
    is restored here.

    Args:
        img: image tensor, returned unchanged.
        mask: colour mask tensor of shape (h, w, c).

    Returns:
        (img, mask) where mask is a tf.uint8 class-index tensor of shape (h, w, 1).
    '''
    input_shape = mask.shape

    [mask,] = tf.py_function(reshape_mask, [mask], [tf.uint8])

    # reshape_mask collapses the colour channels into a single class channel,
    # so the output keeps the input's spatial dims but has exactly 1 channel
    # (re-asserting the input's channel count would declare the wrong shape).
    if input_shape.rank is not None:
        mask.set_shape(input_shape[:-1].concatenate([1]))

    return img, mask

In [None]:
def revert_mask(mask, im_shape=image_shape):
    '''
    Converts a class-index mask back into an RGB colour mask (the inverse of
    reshape_mask), e.g. (512, 512, 1) -> (512, 512, 3) for the default image_shape.

    Each class value is looked up in the global `class_map`.

    Args:
        mask: tensor or array of class indices with exactly
            im_shape[0] * im_shape[1] values, e.g. shape (h, w, 1) or (h, w).
        im_shape: shape of the returned colour mask.

    Returns:
        tf.uint8 tensor of shape im_shape.

    Raises:
        KeyError: if the mask contains a value that is not a key of `class_map`.
    '''
    labels = np.reshape(np.asarray(mask), (im_shape[0], im_shape[1]))

    # Vectorised lookup: paint all pixels of each class with that class's colour
    rgb = np.zeros((im_shape[0], im_shape[1], 3), dtype=np.uint8)
    matched = np.zeros(labels.shape, dtype=bool)
    for class_value, colour in class_map.items():
        is_class = labels == class_value
        rgb[is_class] = colour
        matched |= is_class

    if not matched.all():
        # Unknown class values indicate a bad mask; fail loudly like a dict lookup
        raise KeyError(labels[~matched][0].item())

    return tf.convert_to_tensor(np.reshape(rgb, im_shape), dtype=tf.uint8)

In [None]:
def load_raw_images_and_masks(img_path, mask_path):
    '''
    Reads and decodes one image/mask file pair.

    Args:
        img_path: path (string tensor) to a JPEG image.
        mask_path: path (string tensor) to a PNG colour mask.

    Returns:
        (img, mask) uint8 tensors. The image is forced to 3 channels so that
        grayscale JPEGs still match the model's 3-channel input (and the static
        channel dimension is known). The mask keeps the file's own channel
        count; it is converted to class indices later in the pipeline.
    '''
    img = tf.image.decode_jpeg(tf.io.read_file(img_path), channels=3)
    mask = tf.image.decode_png(tf.io.read_file(mask_path))

    return img, mask

In [None]:
def resize_image(img, mask, size=(image_shape[0], image_shape[1])):
    '''
    Resizes an image and its mask to `size` using nearest-neighbour sampling.

    Nearest-neighbour keeps mask colours exact (no blended colours that
    pixel_map cannot look up) and preserves the input dtype.
    '''
    resized_img = tf.image.resize(img, size, method='nearest')
    resized_mask = tf.image.resize(mask, size, method='nearest')

    return resized_img, resized_mask

In [None]:
def normalize_image(img):
    '''
    Casts an image to float32 and divides by 255, mapping 0-255 pixel values into [0, 1].
    '''
    return tf.cast(img, tf.float32) / 255.0

In [None]:
@tf.function
def load_train_images_and_masks(img_path, mask_path):
    '''
    Training loader: reads the pair, resizes it, applies a random rotation and
    flip to both, then scales the image into [0, 1].

    The mask is returned as the resized colour mask; conversion to class
    indices is a separate step.
    '''
    pair = load_raw_images_and_masks(img_path, mask_path)
    pair = resize_image(*pair)
    pair = random_rotation(*pair)
    img, mask = random_flip(*pair)

    return normalize_image(img), mask


In [None]:
def load_validation_images_and_masks(img_path, mask_path):
    '''
    Validation loader: reads and resizes the pair and scales the image into
    [0, 1]. No random augmentation is applied.
    '''
    img, mask = resize_image(*load_raw_images_and_masks(img_path, mask_path))

    return normalize_image(img), mask

In [None]:
# def get_dataset(image_paths, mask_paths):
#     '''
#     Generates the training dataset
#     '''
#     train_dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

#     #Apply preprocessing across the dataset
#     train_dataset = train_dataset.map(load_train_images_and_masks, 
#                                         num_parallel_calls=tf.data.experimental.AUTOTUNE)

#     return train_dataset

In [None]:
# To be replaced
def get_train_dataset(image_paths, mask_paths):
    '''
    Builds the training tf.data pipeline: one (image, mask) element per path
    pair, loaded with load_train_images_and_masks in parallel.
    '''
    path_pairs = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    return path_pairs.map(
        load_train_images_and_masks,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

In [None]:
# To be replaced
def get_validation_dataset(image_paths, mask_paths):
    '''
    Builds the validation tf.data pipeline: one (image, mask) element per path
    pair, loaded with load_validation_images_and_masks.

    Args:
        image_paths: list of image file paths.
        mask_paths: list of mask file paths aligned with image_paths.

    Returns:
        tf.data.Dataset of (image, mask) pairs in the same order as the inputs.
    '''
    validation_dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    # Parallel map, like the training pipeline. tf.data keeps element order
    # deterministic by default, so predictions can still be paired with the
    # unbatched validation masks by position.
    validation_dataset = validation_dataset.map(load_validation_images_and_masks,
                                                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return validation_dataset


In [None]:
def display(plots, titles, cmap=None, n_examples=1, metrics=None):
    '''
    Plots images in a grid of n_examples rows x 3 columns, optionally
    captioning some of them with metric dictionaries.

    Args:
        plots: sequence of images accepted by plt.imshow.
        titles: one title per entry in plots.
        cmap: optional colormap passed to plt.imshow.
        n_examples: number of subplot rows.
        metrics: optional list of dicts. The k-th dict is rendered as
            "key: value" lines under the subplot at position labels[k]
            (0-based positions 1, 2, 4, 5, 7, 8: columns 2 and 3 of each row).
            Positions without a matching dict get no caption.
    '''
    # One multi-line caption per metrics dict
    display_strings = [
        ''.join(f'{key}: {value}\n' for key, value in metric.items())
        for metric in (metrics or [])
    ]

    # 0-based subplot positions that can receive a caption
    labels = [1,2,4,5,7,8]

    plt.figure(figsize=(15,15))

    for i, plot in enumerate(plots):
        plt.subplot(n_examples, 3, i+1)
        plt.title(titles[i])
        # Only caption when a metrics dict exists for this position, so passing
        # fewer dicts than captioned subplots (e.g. metrics=[]) no longer raises
        if i in labels and labels.index(i) < len(display_strings):
            plt.xlabel(display_strings[labels.index(i)], fontsize=12)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(plot, cmap=cmap)

    plt.subplots_adjust(hspace=0.8)

In [None]:
def display_(plots, titles, cmap=None, n_examples=1, metrics=None, centroids=None):
    '''
    Plots images in a grid of n_examples rows x 3 columns. When both metrics
    and centroids are given, the first subplot is annotated with lesion
    indices at their centroids, a '+' at the root centroid, and a caption
    listing each lesion's metrics.

    Args:
        plots: sequence of images accepted by plt.imshow.
        titles: one title per entry in plots.
        cmap: optional colormap passed to plt.imshow.
        n_examples: number of subplot rows.
        metrics: optional dict lesion_index -> (area, coverage %, distance).
        centroids: optional [lesion_centroids, root_centroid].
    '''
    if metrics:
        header = '[Lesion]: [Area, Coverage(%), Distance from centroid(+)]\n\n'
        rows = [f'Lesion{key!s}:  {value[0]!s},  {value[1]!s},  {value[2]!s}\n'
                for key, value in metrics.items()]
        display_string = header + ''.join(rows)

    plt.figure(figsize=(15,15))

    for i, plot in enumerate(plots):
        plt.subplot(n_examples, 3, i+1)
        plt.title(titles[i])
        if i == 0 and metrics and centroids:
            lesion_points, root_point = centroids[0], centroids[1]
            for lesion_idx, point in enumerate(lesion_points):
                plt.annotate(lesion_idx, point, color='white')
            plt.annotate('+', root_point, color='white')
            plt.xlabel(display_string, fontsize=12)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(plot, cmap=cmap)

    plt.subplots_adjust(hspace=0.8)

In [None]:
# def plot_metric(name, title):
#     '''
#     Plots model metrics
#     '''
#     plt.plot(model_history.history[name], color='blue', label=name)
#     plt.plot(model_history.history['val_'+name], color='green', label='val_'+name)
#     plt.xlabel('epochs')
#     plt.ylabel(name)
#     plt.ylim(top=1)
#     plt.title(title)
#     plt.legend()
#     plt.show()

In [None]:
def plot_metric(name, title, save_path, history=None):
    '''
    Plots the training and validation curves of one metric, saves the
    figure to save_path and shows it.

    Args:
        name: metric key in the Keras history (e.g. 'loss'); 'val_' + name
            must also be present.
        title: figure title.
        save_path: file path passed to plt.savefig.
        history: optional dict of per-epoch metric lists; defaults to the
            notebook's model_history.history.
    '''
    if history is None:
        history = model_history.history

    train_values = history[name]
    val_values = history['val_'+name]

    # Fresh figure per call so successive plots never share the same axes
    plt.figure()
    plt.plot(train_values, color='blue', label=name)
    plt.plot(val_values, color='green', label='val_'+name)
    plt.xlabel('epochs')
    plt.ylabel(name)
    # Cap the y-axis at 1 only for bounded metrics (e.g. dice); metrics that
    # exceed 1, such as cross-entropy loss, keep autoscaling instead of being cut off
    if max(list(train_values) + list(val_values), default=0) <= 1:
        plt.ylim(top=1)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    

# Create datasets and perform image augmentation

Note: This process increases the size of the training set by a factor of 4 (important to remember this when determining training steps).

In [None]:
# Base pipelines: train is lightly augmented (rotation/flip), val is not
train = get_train_dataset(X_train, Y_train)
val  = get_validation_dataset(X_val, Y_val)

# Convert colour masks into (h, w, 1) class-index masks
train = train.map(tf_reshape_mask)
val = val.map(tf_reshape_mask)

# Three augmented views, each mapped over the base training pipeline
augmented = train.map(color_transformations)
augmented_ = train.map(geometric_transformations)
augmented__ = train.map(noise_transformations)

# Concatenated in order: base, colour, geometric, noise (4x the base element count)
train = train.concatenate(augmented)
train = train.concatenate(augmented_)
train = train.concatenate(augmented__)

In [None]:
#Ensure shape consistency
# element_spec shows the static shape and dtype of each (image, mask) element
print(train.element_spec)
print(val.element_spec)

# Prepare datasets

In [None]:
# Samples per training/validation batch
batch_size = 2
# Number of elements held in the tf.data shuffle buffer (not the dataset size)
buffer_size = 100

In [None]:
# wandb.log raises when no run is active; start the run here if no earlier cell has
if wandb.run is None:
    wandb.init(project='corm_experiments')

wandb.log({
    'batch_size': batch_size
})

In [None]:
#Shuffle and group into batches
# NOTE(review): the concatenated dataset is ordered base -> colour -> geometric
# -> noise; if it is much larger than buffer_size, batches mostly mix samples
# from the same augmentation block. Consider a larger buffer.
train_dataset = train.shuffle(buffer_size)
# repeat() makes the stream infinite, so fit() relies on steps_per_epoch
train_dataset = train_dataset.batch(batch_size).repeat()

#Prefetch to optimize processing
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

#Group into batches (not shuffled, so predictions keep the validation order)
validation_dataset = val.batch(batch_size).repeat()

In [None]:
#Get samples for exploration
# One unbatched (image, mask) element from the concatenated training dataset
samples = [(image, mask) for image, mask in train.take(1)]

In [None]:
# Image and class-index mask of the single sampled element
sample_image = samples[0][0]
sample_mask = samples[0][1]

In [None]:
#Ensure shape consistency
# Expect image_shape for the image and a single-channel class-index mask
print(sample_image.shape)
print(sample_mask.shape)

In [None]:
#Visualize samples
# Image, class-index mask (channel 0), and the mask converted back to RGB colours
display([sample_image, sample_mask[:,:,0], revert_mask(sample_mask)], ['image','reshaped mask', 'true mask'])

# Define UNet model

In [None]:
#Import necessary tensorflow modules
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, MaxPooling2D
from tensorflow.keras.layers import Dropout, Input, Activation, concatenate, BatchNormalization
from tensorflow.keras.models import Model

In [None]:
def double_conv_block(x, n_filters):
    '''
    Two stacked 3x3 'same'-padded Conv2D layers, each with ReLU activation
    and He-normal initialisation.
    '''
    for _ in range(2):
        x = layers.Conv2D(n_filters, 3, padding="same", activation="relu",
                          kernel_initializer="he_normal")(x)

    return x

In [None]:
def downsample_block(x, n_filters):
    '''
    U-Net encoder step: double conv, then 2x2 max-pooling and 30% dropout.

    Returns:
        (features, pooled): pre-pooling features (kept for the skip
        connection) and the downsampled output.
    '''
    features = double_conv_block(x, n_filters)
    pooled = layers.MaxPool2D(2)(features)
    pooled = layers.Dropout(0.3)(pooled)

    return features, pooled

In [None]:
def upsample_block(x, conv_features, n_filters):
    '''
    U-Net decoder step: 2x upsampling with a strided 3x3 transposed
    convolution, concatenation with the matching encoder features (skip
    connection), 30% dropout, then a double conv block.
    '''
    upsampled = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    merged = layers.concatenate([upsampled, conv_features])
    merged = layers.Dropout(0.3)(merged)

    return double_conv_block(merged, n_filters)

In [None]:
def build_unet_model(in_size=image_shape):
    '''
    Builds a 4-level U-Net with a 3-class per-pixel softmax head.

    Encoder filters 64-128-256-512, a 1024-filter bottleneck, and a mirrored
    decoder with skip connections.

    Args:
        in_size: input shape, default image_shape.

    Returns:
        Uncompiled tf.keras.Model named "U-Net".
    '''
    inputs = layers.Input(shape=in_size)

    # Contracting path: remember each level's pre-pool features for the skips
    skips = []
    x = inputs
    for n_filters in (64, 128, 256, 512):
        features, x = downsample_block(x, n_filters)
        skips.append((features, n_filters))

    # Bottleneck
    x = double_conv_block(x, 1024)

    # Expanding path: consume the skips deepest-first
    for features, n_filters in reversed(skips):
        x = upsample_block(x, features, n_filters)

    outputs = layers.Conv2D(3, 1, padding="same", activation = "softmax")(x)

    return tf.keras.Model(inputs, outputs, name="U-Net")

In [None]:
#Create model instance (input shape defaults to image_shape)
model = build_unet_model()

In [None]:
#Check model summary (layer output shapes and parameter counts)
model.summary()
# keras.utils.plot_model(model, show_shapes=True)

# Compile and build model instance

In [None]:
#Define some training parameters

#Define training epochs
epochs = 300

#Define number of training examples
# The concatenated training dataset holds the originals plus three augmented
# copies (a factor of 4), so take its size from tf.data rather than len(X_train).
# Otherwise each epoch covers only a quarter of the augmented stream.
n_train_examples = int(tf.data.experimental.cardinality(train).numpy())
if n_train_examples < 0:
    # Negative values are tf.data's UNKNOWN/INFINITE sentinels; fall back to the 4x rule
    n_train_examples = 4 * len(X_train)

#Define number of validation examples
n_val_examples = len(X_val)

#Define training steps (at least one, even for very small datasets)
steps_per_epoch = max(1, n_train_examples//batch_size)

#Define validation steps
val_steps = n_val_examples//batch_size

# wandb.log raises when no run is active; start the run here if no earlier cell has
if wandb.run is None:
    wandb.init(project='corm_experiments')

wandb.log({
    'epochs': epochs,
    'steps_per_epoch': steps_per_epoch
})

print(f'Train exmaples: {n_train_examples}')
print(f'Eval exmaples: {n_val_examples}')
print(f'Steps per Epoch: {steps_per_epoch}')
print(f'Validation steps: {val_steps}')

In [None]:
#Define custom training metric

def dice_coeff(true_mask, pred_mask):
    '''
    Training metric: dice coefficient for class 0 ('damage' in `classes`),
    computed over the whole batch.

    Args:
        true_mask: class-index masks, assumed shape (batch, h, w, 1).
        pred_mask: per-class softmax scores, shape (batch, h, w, n_classes).

    Returns:
        Scalar dice score in [0, 1]. A batch where class 0 is absent from both
        the prediction and the ground truth scores 1.0.
    '''
    # Class indices with a trailing channel axis, to match true_mask
    pred_mask = tf.argmax(pred_mask, axis=-1)[..., tf.newaxis]

    #Avoid zero division
    smoothing_factor = 0.000001

    pred_is_class = tf.cast(pred_mask == 0, tf.float32)
    true_is_class = tf.cast(true_mask == 0, tf.float32)

    intersection = tf.reduce_sum(pred_is_class * true_is_class)
    combined_area = tf.reduce_sum(pred_is_class) + tf.reduce_sum(true_is_class)

    # Smooth the numerator and denominator by the same constant so an empty
    # batch scores exactly 1.0 and the metric never exceeds 1
    return (2.0 * intersection + smoothing_factor) / (combined_area + smoothing_factor)

In [None]:
#Compile model
# Sparse categorical crossentropy: targets are integer class-index masks,
# predictions are per-pixel softmax scores over the 3 classes
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), 
              loss='sparse_categorical_crossentropy', 
              metrics=[dice_coeff])

In [None]:
# create results directory
# One timestamped folder per run (minute resolution) for checkpoints, history and plots
training_dir = f'unet-training/{datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")}/'
os.makedirs(training_dir, exist_ok=True)

In [None]:
# Displays the path pattern only: it is recomputed from the current time and can
# differ from training_dir if the minute has changed since that cell ran
f'unet-training/{datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")}/'

In [None]:
proj_name = 'corm_experiments'
# Earlier logging cells may already have started a run; initialise only if none
# is active so all logs for this experiment land in a single run
if wandb.run is None:
    wandb.init(project=proj_name)

In [None]:
#Build model
# Checkpoint file name: `epochs` is substituted once here (the configured epoch
# count, not the current epoch). With save_best_only=True the best weights, by
# ModelCheckpoint's default monitor, keep overwriting this single file.
checkpoint_path = f'{training_dir}model_feb25_epoch-{epochs:02d}_{datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")}.h5'

# batch_size is not passed: the datasets are already batched, and Keras documents
# that batch_size must not be specified for tf.data inputs
model_history = model.fit(train_dataset, 
                          epochs=epochs, 
                          validation_data=validation_dataset, 
                          steps_per_epoch=steps_per_epoch, 
                          validation_steps=val_steps, 
                          callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_dice_coeff', 
                                                                      mode='max', 
                                                                      patience=100),
                                     tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 
                                                                        save_best_only=True, 
                                                                        save_weights_only=True),
                                    WandbMetricsLogger()]
                          )

In [None]:
# wandb.finish()

In [None]:
# Save history
# One row per epoch, one column per logged metric (including the val_* metrics)
hist_df = pd.DataFrame(model_history.history) 
hist_df.to_csv(f'{training_dir}history.csv', index=False)

In [None]:
#Plot model metrics (each figure is also saved into training_dir)
plot_metric('loss', 'Training loss vs. Validation loss', f'{training_dir}loss.png')
plot_metric('dice_coeff', 'Training dice_coeff vs. Validation dice_coeff', f'{training_dir}Dice_coeff.png') 

# Evaluate model

In [None]:
#Evaluation utility functions

def unbatch_validation_images_and_masks():
    '''
    Collects one pass of validation images and masks as numpy arrays.

    Reads the notebook globals validation_dataset, n_val_examples and
    batch_size. The arrays are truncated to a whole number of batches, which
    is the number of samples model.predict returns for val_steps steps.

    Returns:
        (val_images, val_masks); both stay empty lists if the dataset yields nothing.
    '''
    # Largest multiple of batch_size not exceeding n_val_examples
    n_kept = n_val_examples - (n_val_examples % batch_size)

    images, masks = [], []

    # Rebatch the (repeated) dataset into one batch holding the whole validation set
    one_pass = validation_dataset.unbatch().batch(batch_size=n_val_examples)
    for batch_images, batch_masks in one_pass.take(1):
        images = batch_images.numpy()
        masks = batch_masks.numpy()

    return images[:n_kept], masks[:n_kept]

In [None]:
def compute_class_metrics(true_mask, pred_mask, n_classes=3):
    '''
    Calculates class-wise intersection over union (IoU) and dice coefficient.

    Args:
        true_mask: array of ground-truth class indices.
        pred_mask: array of predicted class indices, broadcastable to true_mask.
        n_classes: number of classes to score (classes 0 .. n_classes-1).

    Returns:
        (class_ious, class_dice_scores): dicts mapping class value -> score.
        A class absent from both masks scores 1.0 for both metrics.
    '''
    class_ious = {}
    class_dice_scores = {}

    #Avoid zero division
    smoothing_factor = 0.000001

    for i in range(n_classes):
        pred_is_class = (pred_mask == i)
        true_is_class = (true_mask == i)

        intersection = np.sum(np.logical_and(pred_is_class, true_is_class))
        combined_area = np.sum(pred_is_class) + np.sum(true_is_class)

        #Calculate class IoU (union = combined area - intersection)
        class_ious[i] = (intersection + smoothing_factor) / (combined_area - intersection + smoothing_factor)

        # Calculate class dice score. Smoothing both numerator and denominator
        # by the same constant keeps the score in [0, 1] (1.0 when the class is absent from both)
        class_dice_scores[i] = (2 * intersection + smoothing_factor) / (combined_area + smoothing_factor)

    return class_ious, class_dice_scores

In [None]:
#Get validation images and masks
# Truncated to whole batches so they line up with the predictions below
val_images, val_masks = unbatch_validation_images_and_masks()

In [None]:
#Ensure shape consistency
# Leading dimension should be val_steps * batch_size for both arrays
print(val_images.shape)
print(val_masks.shape)

In [None]:
# model.load_weights('/content/drive/MyDrive/unet weights/unet_model_v3.h5')

In [None]:
#Get predictions
# val_steps batches from the unshuffled, repeated validation dataset, i.e.
# val_steps * batch_size samples in the same order as val_images/val_masks
predictions = model.predict(validation_dataset, steps=val_steps)

In [None]:
# (n_samples, h, w, 3): per-pixel softmax scores for the 3 classes
predictions.shape

In [None]:
#Get overall predicted mask
# argmax over the class axis gives class indices; the trailing axis matches val_masks' (h, w, 1) layout
results_ = np.argmax(predictions, axis=3)
results = results_[..., tf.newaxis]

In [None]:
# (n, h, w) and (n, h, w, 1)
print(results_.shape)
print(results.shape)

In [None]:
#Get class scores
# Per-class IoU and dice computed over the whole validation set at once
ious, dice_scores = compute_class_metrics(val_masks, results)

In [None]:
# Class value -> name, then per-class IoU and dice (both keyed by class value)
print(classes)
print(ious)
print(dice_scores)

In [None]:
# Log the evaluation scores (keyed by class value; see `classes` for names)
wandb.log({
    'eval' : {
        'classes': classes,
        'iou': ious,
        'dice_scores': dice_scores
    }
})

In [None]:
# Close the W&B run; the remaining cells do not log to W&B
wandb.finish()

In [None]:
# Softmax scores for the first validation sample, shape (h, w, 3)
p = predictions[0]

In [None]:
# Shape of a single prediction
predictions[0].shape

In [None]:
# Highest class score anywhere in this prediction
np.max(p)

In [None]:
# Rescale the scores so the maximum maps to 255, then convert to an 8-bit image for display
data = p.astype(np.float64) / np.max(p) # normalize the data to 0 - 1
data = 255 * data # Now scale by 255
img = data.astype(np.uint8)


In [None]:
# Scores shown as RGB: channel 0 (red) = damage, 1 (green) = corm, 2 (blue) = background
plt.imshow(img)

In [None]:
# Class scores of the top-left pixel, scaled to 0-255
(predictions[0]*255.0)[0][0]

In [None]:
# Softmax scores scaled to 0-255 as 8-bit values; uint8 holds the full range,
# whereas int8 would wrap every value above 127 to a negative number
(predictions[0]*255.0).astype(np.uint8)

In [None]:
# dtype of the raw model output
predictions[0].dtype

In [None]:
# Raw softmax scores shown directly as an RGB image (values already lie in [0, 1])
plt.imshow(predictions[0])

# Post-detection analysis

In [None]:
import cv2 as cv

In [None]:
def get_damage_and_corm_masks(mask):
  '''
  Extracts/separates the damage and corm annotations from a class-index mask.

  The mask is converted back to RGB with revert_mask. The red channel
  (class 0, 'damage') becomes the damage mask and the green channel
  (class 1, 'corm') becomes the corm mask; all other channels are zero.

  Args:
    mask: class-index mask accepted by revert_mask.

  Returns:
    (damage_mask, corm_mask): uint8 tensors with the same (h, w, 3) shape as
    revert_mask(mask).
  '''
  rgb_mask = revert_mask(mask)
  red = rgb_mask[:, :, 0]
  green = rgb_mask[:, :, 1]

  # Empty channels sized from the mask itself, so stacking works for any image size
  blank = tf.zeros_like(red)

  damage_mask = tf.stack([red, blank, blank], axis=2)
  corm_mask = tf.stack([blank, green, blank], axis=2)

  return damage_mask, corm_mask

In [None]:
#Post-detection utility functions

def get_necrosis_and_root_masks(mask):
  '''
  Extracts/separates the root and necrosis annotations from a class-index mask.

  The mask is converted back to RGB with revert_mask. The red channel
  (class 0) becomes the necrosis mask and the green channel (class 1)
  becomes the root mask; all other channels are zero.

  Args:
    mask: class-index mask accepted by revert_mask.

  Returns:
    (necrosis_mask, root_mask): uint8 tensors with the same (h, w, 3) shape
    as revert_mask(mask).
  '''
  rgb_mask = revert_mask(mask)
  red = rgb_mask[:, :, 0]
  green = rgb_mask[:, :, 1]

  # Empty channels sized from the mask itself rather than a fixed resolution,
  # so tf.stack works for any image_shape
  blank = tf.zeros_like(red)

  necrosis_mask = tf.stack([red, blank, blank], axis=2)
  root_mask = tf.stack([blank, green, blank], axis=2)

  return necrosis_mask, root_mask

def annotate_mask(mask):
  '''
  Finds contours in a colour mask and outlines the outermost ones.

  The mask is converted to grayscale and binarised with Otsu thresholding,
  then contours are found with the full hierarchy (RETR_TREE).

  Returns:
    (contours, annotated_mask): contours is [outer, nested] where outer
    holds parent-less contours and nested holds every contour with a parent.
    annotated_mask is a uint8 copy of the mask with the outer contours drawn
    in white, 1 px wide.
  '''
  mask_array = mask.numpy().astype(np.uint8)
  gray = cv.cvtColor(mask_array, cv.COLOR_BGR2GRAY)
  _, binary = cv.threshold(gray, 64, 128, cv.THRESH_BINARY + cv.THRESH_OTSU)

  found, hierarchy = cv.findContours(binary, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)

  if isinstance(hierarchy, np.ndarray):
    # hierarchy rows are [next, previous, first_child, parent]; parent == -1 means top level
    links = hierarchy[0].tolist()
    outer = [contour for contour, link in zip(found, links) if link[-1] == -1]
    nested = [contour for contour, link in zip(found, links) if link[-1] != -1]
  else:
    # No hierarchy returned (no contours found)
    outer, nested = found, []

  annotated_mask = cv.drawContours(mask_array.copy(), outer, -1, (255,255,255), 1)

  return [outer, nested], annotated_mask

def cbsd_scoring(necrosis_percentage):
  '''
  Maps the percentage of root affected by necrosis to a Cassava Brown
  Streak Disease (CBSD) score from 1 to 5:
    <= 2% -> 1, <= 5% -> 2, <= 10% -> 3, <= 25% -> 4, otherwise 5.
  '''
  # (inclusive upper bound in %, score), checked from the mildest band upwards
  bands = ((2, 1), (5, 2), (10, 3), (25, 4))

  for upper_bound, score in bands:
    if necrosis_percentage <= upper_bound:
      return score

  return 5

def post_detection_analysis(mask):
  '''
  Performs post-detection analysis on one class-index mask:
   - identifies, counts and outlines necrosis lesions
   - calculates the percentage of root affected by necrosis
   - determines the CBSD score and a verdict

  Args:
    mask: class-index mask accepted by get_necrosis_and_root_masks.

  Returns:
    (results, annotated_mask): results is a dict with the number of lesions,
    root area, total lesion area, necrosis percentage, CBSD score and verdict.
    annotated_mask is the saturating sum of the outlined necrosis and root masks.
  '''
  necrosis_mask, root_mask = get_necrosis_and_root_masks(mask)

  necrosis_contours, necrosis_annotated_mask = annotate_mask(necrosis_mask)
  root_contours, root_annotated_mask = annotate_mask(root_mask)

  # Lesion area = outer contour areas minus nested contour areas, so ring-shaped
  # lesions do not count their hole (sum([]) is 0 when there are no holes)
  necrosis_contour_areas = [cv.contourArea(contour) for contour in necrosis_contours[0]]
  hole_areas = [cv.contourArea(contour) for contour in necrosis_contours[1]]
  total_necrosis_area = sum(necrosis_contour_areas) - sum(hole_areas)

  #Determine size/area of cassava root
  root_contour_areas = [cv.contourArea(contour) for contour in root_contours[0]]
  total_root_area = sum(root_contour_areas)

  # Root area is measured on the convex hull of the largest root contour.
  # A (predicted) mask may contain no root at all; argmax of [] would raise.
  if root_contour_areas:
    largest_root = root_contours[0][int(np.argmax(root_contour_areas))]
    root_area_hull = cv.contourArea(cv.convexHull(largest_root))
  else:
    root_area_hull = 0.0

  #Account for boundary lesions/convexity defects
  if total_necrosis_area > root_area_hull:
    root_area_hull = total_necrosis_area + total_root_area

  #Determine number of lesions
  #and calculate percentage of root affected by necrosis
  n_lesions = len(necrosis_contour_areas)
  # Guard 0/0 when neither root nor necrosis area was found
  if root_area_hull > 0:
    necrosis_percentage = total_necrosis_area/root_area_hull * 100
  else:
    necrosis_percentage = 0.0

  #Determine Cassava Brown Streak Disease (CBSD) score
  cbsd_score = cbsd_scoring(necrosis_percentage)

  #Determine verdict based on CBSD score
  verdict = '**necrotic**' if cbsd_score > 1 else '**no necrosis**'

  results = {'Number of lesions':n_lesions, 
             'Area of root':root_area_hull, 
             'Area of lesions (sum)':total_necrosis_area, 
             'Necrosis percentage':round(necrosis_percentage, 2), 
             'CBSD score':cbsd_score, 
             'Verdict':verdict}

  # Saturating add so overlapping white outlines cannot wrap around in uint8
  annotated_mask = cv.add(necrosis_annotated_mask, root_annotated_mask)

  return results, annotated_mask

In [None]:
#Lesion analysis utility functions

def get_thresh(mask):
  '''
  Binarises a colour mask: grayscale conversion followed by Otsu
  thresholding. Foreground pixels become 128 and background pixels 0.
  '''
  gray = cv.cvtColor(mask.numpy().astype(np.uint8), cv.COLOR_BGR2GRAY)
  # With THRESH_OTSU the 64 passed here is ignored; Otsu picks the threshold
  _, binary = cv.threshold(gray, 64, 128, cv.THRESH_BINARY + cv.THRESH_OTSU)
  return binary

def get_contours(mask):
  '''
  Returns the outermost (parent-less) contours of a thresholded mask.
  If findContours returns no hierarchy, its contour result is returned as-is.
  '''
  found, hierarchy = cv.findContours(get_thresh(mask), cv.RETR_TREE, cv.CHAIN_APPROX_NONE)

  if not isinstance(hierarchy, np.ndarray):
    return found

  # hierarchy rows are [next, previous, first_child, parent]; keep parent == -1
  return [contour for contour, link in zip(found, hierarchy[0].tolist()) if link[-1] == -1]

def get_centroids(contours):
  '''
  Returns the (cx, cy) centroid of each contour from its image moments:
  cx = m10 / m00 and cy = m01 / m00. A tiny constant is added to m00 so
  zero-area contours do not cause a division by zero.
  '''
  eps = 1e-10

  centroids = []
  for contour in contours:
    m = cv.moments(contour)
    area = m['m00'] + eps
    centroids.append((m['m10'] / area, m['m01'] / area))

  return centroids

def distance_transform(mask_threshold):
  '''
  Applies tfa's Euclidean distance transform to a binary mask, measured
  against the mask's zero (background) pixels.
  '''
  return tfa.image.euclidean_dist_transform(mask_threshold)

def centroid_distance_transform(root, root_centroid, necrosis_thresh, dtype=np.uint8):
  '''
  Builds an image in which each lesion pixel holds its Euclidean distance to
  the root centroid, with the root outline drawn faintly (value 4) for reference.

  Args:
    root: list of root contours.
    root_centroid: (x, y) centroid of the root.
    necrosis_thresh: 2-D binary lesion mask (non-zero = lesion), e.g. from
      get_thresh. Its height/width set the output size.
    dtype: output dtype (default uint8). For integer dtypes, distances are
      clipped to the dtype's maximum instead of wrapping around.

  Returns:
    2-D array of the same height/width as necrosis_thresh.
  '''
  # Size the representation from the mask rather than a fixed resolution
  height, width = necrosis_thresh.shape[:2]
  dist_trans = np.zeros((height, width), dtype=dtype)

  #Faintly annotate representation with root
  cv.drawContours(dist_trans, root, -1, (4), 1)

  # findNonZero returns None when there are no lesion pixels
  nonzero = cv.findNonZero(necrosis_thresh)
  if nonzero is None:
    return dist_trans

  # (n, 2) array of (x, y) lesion pixel coordinates
  lesion_pixels = nonzero[:, 0]

  #Calculate euclidean distance from centre of root for each pixel
  euclidean_dist = np.linalg.norm(lesion_pixels - np.asarray(root_centroid), axis=1)

  # Avoid uint8 wraparound (distances can exceed 255 on large images)
  if np.issubdtype(dist_trans.dtype, np.integer):
    euclidean_dist = np.clip(euclidean_dist, 0, np.iinfo(dist_trans.dtype).max)

  #Update representation with euclidean distances for lesion pixels (row = y, column = x)
  dist_trans[lesion_pixels[:, 1], lesion_pixels[:, 0]] = euclidean_dist

  return dist_trans

def lesion_analysis(lesions, root, root_, lesion_centroids, root_centroid):
  '''
  Returns per-lesion metrics as {lesion_index: (area, coverage %, distance)}:
   - area: pixel area of the lesion contour
   - coverage: lesion area as a percentage of the root hull area (2 dp)
   - distance: distance between the lesion centroid and the root centroid (2 dp)

  If the lesions together exceed the hull area (boundary lesions/convexity
  defects), the reference area becomes root contour area + total lesion area.
  '''
  lesion_areas = [cv.contourArea(lesion) for lesion in lesions]
  root_area = sum(cv.contourArea(contour) for contour in root)
  root_area_hull = cv.contourArea(root_)

  total_lesion_area = sum(lesion_areas)
  if total_lesion_area > root_area_hull:
    root_area_hull = root_area + total_lesion_area

  coverage = [round(area / root_area_hull * 100, 2) for area in lesion_areas]

  root_point = np.array(root_centroid)
  distances = [round(np.linalg.norm(root_point - np.array(centroid)), 2)
               for centroid in lesion_centroids]

  return dict(enumerate(zip(lesion_areas, coverage, distances)))

In [None]:
#Get random prediction to analyze
# Draw idx only from positions present in every array used below. val_images and
# val_masks are truncated to whole batches and `results` has val_steps * batch_size
# rows, so an index up to n_val_examples - 1 could be out of range.
idx = np.random.randint(low=0, high=min(len(val_images), len(val_masks), len(results)), size=1)[0]

plots = []
metrics = []
titles = ['Image', 'True Mask', 'Predicted Mask']

plots.append(val_images[idx])

#Get true mask analysis results
analysis, annotated_mask = post_detection_analysis(val_masks[idx])
metrics.append(analysis)
plots.append(annotated_mask)

#Get predicted mask analysis results
analysis, annotated_mask = post_detection_analysis(results[idx])
metrics.append(analysis)
plots.append(annotated_mask)

In [None]:
#Get masks
# Analyse the predicted mask at the index chosen in the previous cell
necrosis_mask, root_mask = get_necrosis_and_root_masks(results[idx])
combined_mask = necrosis_mask + root_mask

#Get mask thresholds
necrosis_thresh = get_thresh(necrosis_mask)
root_thresh = get_thresh(root_mask)

#Get distance transforms
#with respect to entire root
necrosis_trans = distance_transform(necrosis_thresh)
root_trans = distance_transform(root_thresh) 
combined_trans = necrosis_trans + root_trans

#Get lesions and root
lesions = get_contours(necrosis_mask)
root = get_contours(root_mask)
# NOTE(review): np.argmax raises on an empty list, i.e. when no root contour is found
root_idx = np.argmax([cv.contourArea(contour) for contour in root])
# Convex hull of the largest root contour, used as the root outline/area reference
root_ = cv.convexHull(root[root_idx])

#Get centroids
lesion_centroids = get_centroids(lesions)
root_centroid = get_centroids([root_])[0]

#Get distance transform
#with respect to the centre of the root
centroid_trans = centroid_distance_transform(root, root_centroid, necrosis_thresh)

plots_ = [combined_mask, centroid_trans, combined_trans]
titles_ = ['Predicted Mask', 'Transform (w.r.t root centroid)', 'Transform (w.r.t root & background)']
metrics_ = lesion_analysis(lesions, root, root_, lesion_centroids, root_centroid)

In [None]:
#Visualize results
# First figure: image, true mask and predicted mask, with analysis summaries as captions
display(plots, titles, metrics=metrics)
# Second figure: predicted mask with lesion indices and root centroid marker, plus both distance transforms
display_(plots_, titles_, cmap='gray', metrics=metrics_, centroids=[lesion_centroids, root_centroid])