# In [1]:
# Script to segment HADES root images using Breda's student model

# load libraries
from collections import defaultdict
import re
import cv2
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from patchify import patchify, unpatchify
import csv
import imagecodecs

import tensorflow as tf
import keras.backend as K
from keras.models import load_model

import skimage
from skimage.morphology import skeletonize
from skan import Skeleton, summarize
from skan.csr import skeleton_to_csgraph
from skan import draw
import networkx as nx

# load functions
def list_files_recursive(directory):
    """Return the paths of all files found under *directory*, recursively.

    Each path is the walked folder joined with the file name, in the order
    produced by ``os.walk``. A directory that does not exist yields an empty
    list because ``os.walk`` silently produces nothing for it.
    """
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in names
    ]

def extract_individual_mask(coords, img_label):
    """Build a mask with 1 at every (row, col) in *coords* and 0 elsewhere.

    Parameters
    ----------
    coords : sequence of (row, col) pairs
        Pixel coordinates to switch on (e.g. ``regionprops(...).coords``).
        May be empty.
    img_label : numpy.ndarray
        Template array; the result has the same shape and dtype.

    Returns
    -------
    numpy.ndarray
        New array shaped like *img_label*; *img_label* itself is not modified.
    """
    individual_mask = np.zeros_like(img_label)
    # Forcing an integer (N, 2) layout keeps an empty coordinate list valid
    # and lets all pixels be set in one indexed assignment instead of a
    # per-pixel Python loop.
    coords = np.asarray(coords, dtype=np.intp).reshape(-1, 2)
    individual_mask[coords[:, 0], coords[:, 1]] = 1
    return individual_mask

def grep(lst, substrings):
    """Return the items of *lst* that contain at least one of *substrings*.

    Parameters
    ----------
    lst : iterable of str
        Candidate strings (e.g. file names).
    substrings : str or iterable of str
        Substring(s) to look for. A single string is treated as one
        substring; previously it was iterated character by character, so
        ``grep(names, '.png')`` matched any name containing ".", "p", "n"
        or "g".

    Returns
    -------
    list of str
        Matching items, in their original order.
    """
    if isinstance(substrings, str):
        substrings = (substrings,)
    else:
        # Materialise once so a one-shot iterator is not exhausted after the
        # first item has been checked.
        substrings = tuple(substrings)
    return [item for item in lst if any(substring in item for substring in substrings)]

def f1(y_true, y_pred):
    """F1 score computed with Keras backend ops.

    Tensors are clipped to [0, 1] and rounded before counting, so predictions
    are effectively binarised. ``K.epsilon()`` is added to every denominator
    to avoid division by zero.
    """
    def _rounded_count(tensor):
        # Sum of the tensor after clipping to [0, 1] and rounding.
        return K.sum(K.round(K.clip(tensor, 0, 1)))

    true_positives = _rounded_count(y_true * y_pred)
    precision = true_positives / (_rounded_count(y_pred) + K.epsilon())
    recall = true_positives / (_rounded_count(y_true) + K.epsilon())

    harmonic_ratio = (precision * recall) / (precision + recall + K.epsilon())
    return 2 * harmonic_ratio

def iou(y_true, y_pred):
    """Mean intersection-over-union computed with Keras backend ops.

    Predictions are binarised at 0.5, then intersection and union are summed
    over axes 1, 2 and 3 (so 4-D tensors are expected). ``K.epsilon()`` is
    added to numerator and denominator so an empty union gives 1 rather than
    a division by zero. The per-sample ratios are averaged.
    """
    threshold = 0.5
    reduce_axes = [1, 2, 3]

    # Shifting by (0.5 - threshold) before rounding binarises at `threshold`.
    y_pred_binary = K.round(y_pred + 0.5 - threshold)

    intersection = K.sum(K.abs(y_true * y_pred_binary), axis=reduce_axes)
    true_energy = K.sum(K.square(y_true), reduce_axes)
    pred_energy = K.sum(K.square(y_pred_binary), reduce_axes)
    union = (true_energy + pred_energy) - intersection

    per_sample = (intersection + K.epsilon()) / (union + K.epsilon())
    return K.mean(per_sample, axis=-1)

def padder(image, patch_size=256):
    """Zero-pad *image* so its height and width are multiples of *patch_size*.

    The padding is split between both sides of each axis; when the amount is
    odd the extra pixel goes to the bottom / right. A dimension that is
    already a multiple of *patch_size* is left untouched (previously a whole
    extra patch of padding was added in that case).

    Parameters
    ----------
    image : numpy.ndarray
        Image with height on axis 0 and width on axis 1; any further axes
        (e.g. channels) are not padded.
    patch_size : int, optional
        Patch edge length the padded image must be divisible by.

    Returns
    -------
    tuple
        ``(padded_image, padding_info)`` where *padding_info* is a dict with
        the keys ``top_padding``, ``bottom_padding``, ``left_padding`` and
        ``right_padding``, as consumed by ``unpadder``.
    """
    h, w = image.shape[:2]

    # (-n) % p is the distance from n up to the next multiple of p,
    # and 0 when n is already a multiple.
    height_padding = -h % patch_size
    width_padding = -w % patch_size

    top_padding = height_padding // 2
    bottom_padding = height_padding - top_padding

    left_padding = width_padding // 2
    right_padding = width_padding - left_padding

    pad_width = [(top_padding, bottom_padding), (left_padding, right_padding)]
    pad_width += [(0, 0)] * (image.ndim - 2)
    padded_image = np.pad(image, pad_width, mode="constant", constant_values=0)

    return padded_image, {'top_padding': top_padding, 'bottom_padding': bottom_padding, 'left_padding': left_padding, 'right_padding': right_padding}

def unpadder(padded_image, padding_info):
    """Crop away the border described by *padding_info*.

    *padding_info* must provide the keys ``top_padding``, ``bottom_padding``,
    ``left_padding`` and ``right_padding``. The result is a slice of
    *padded_image*, not a copy.
    """
    top, bottom, left, right = (
        padding_info[key]
        for key in ('top_padding', 'bottom_padding', 'left_padding', 'right_padding')
    )
    last_row = padded_image.shape[0] - bottom
    last_col = padded_image.shape[1] - right
    return padded_image[top:last_row, left:last_col]

def shortest_path(start,end,binary_image):
    """Return the cheapest pixel path from *start* to *end*.

    Non-zero pixels of *binary_image* cost 1 and zero pixels cost 1000, so
    the route stays on the foreground wherever it can. Diagonal moves are
    allowed (``fully_connected=True``). Only the path is returned; the
    accumulated cost is discarded.
    """
    traversal_cost = np.where(binary_image, 1, 1000)
    route_and_cost = skimage.graph.route_through_array(
        traversal_cost, start=start, end=end, fully_connected=True)
    return route_and_cost[0]

def find_tip(skeleton_binary, mean_filter):
    """Return a 0/255 mask of skeleton pixels flagged by their local mean.

    A pixel is set to 255 when it lies on the skeleton (non-zero in
    *skeleton_binary*) and its value in *mean_filter* is 56. Skeleton pixels
    whose mean is already 255 also stay at 255. Everything else is 0.
    The result has the shape and dtype of *mean_filter*; neither input is
    modified.
    """
    on_skeleton = skeleton_binary != 0
    flagged = (mean_filter == 56) | (mean_filter == 255)

    tip_mask = np.zeros_like(mean_filter)
    tip_mask[on_skeleton & flagged] = 255
    return tip_mask

def skeleton_pruning(skeleton_image, length):
    """Remove short terminal branches from a skeleton.

    The skeleton is split into segments by deleting its branch points; every
    segment that is shorter than *length* pixels and contains a tip is
    erased.

    Parameters
    ----------
    skeleton_image : numpy.ndarray
        2-D skeleton; any truthy pixel counts as skeleton.
    length : int
        Segments with fewer pixels than this are candidates for removal.

    Returns
    -------
    tuple
        ``(skeleton_pruned, new_tip_mask)``, both uint8 images holding 0/255.
        ``new_tip_mask`` marks the tips of the pruned skeleton.
    """
    # 3x3 neighbourhood of ones, i.e. what skimage.morphology.square(3)
    # returns, without depending on that deprecated helper.
    neighbourhood = np.ones((3, 3), dtype=np.uint8)

    # Build the 0/255 uint8 image directly. Routing the int64 output of
    # np.where through img_as_ubyte only works thanks to skimage's
    # "max value fits" downcast shortcut and emits a warning on every call.
    skeleton_binary = np.where(skeleton_image, 255, 0).astype(np.uint8)

    # Annotate each pixel with the mean of its 3x3 neighbourhood. Away from
    # the image border this encodes how many neighbourhood pixels are set:
    # 2 of 9 gives 56 (a tip), 4 of 9 gives 113.
    mean_filter = skimage.filters.rank.mean(skeleton_binary, footprint=neighbourhood)

    # Detect tips based on the mean filter
    tip_mask = find_tip(skeleton_binary, mean_filter)

    # Remove branch points (mean above 112) so the skeleton falls apart into
    # separate segments.
    skeleton_nonode = mean_filter.copy()
    skeleton_nonode[skeleton_binary == 0] = 0
    skeleton_nonode[skeleton_nonode > 112] = 0
    skeleton_nonode[skeleton_nonode > 0] = 255

    # Label the segments of the skeleton
    segments = skimage.measure.label(skeleton_nonode)

    # Erase short segments that carry a tip
    skeleton_pruned = skeleton_binary.copy()
    for segment in skimage.measure.regionprops(segments):
        if segment.area >= length:
            continue
        rows, cols = segment.coords[:, 0], segment.coords[:, 1]
        if np.any(tip_mask[rows, cols] == 255):
            skeleton_pruned[rows, cols] = 0

    # Recreate the tip mask for the pruned skeleton
    new_mean_filter = skimage.filters.rank.mean(skeleton_pruned, footprint=neighbourhood)
    new_tip_mask = find_tip(skeleton_pruned, new_mean_filter)

    return skeleton_pruned, new_tip_mask

def skeleton(individual_mask, skeleton_prunning_length, minimum_area_lateral):
    """Skeletonise a root mask and split it into main root and laterals.

    Parameters
    ----------
    individual_mask : numpy.ndarray
        2-D mask of one seedling's roots (non-zero = root).
    skeleton_prunning_length : int
        Passed to ``skeleton_pruning`` as the maximum length of terminal
        branches to remove.
    minimum_area_lateral : int
        Lateral skeleton components smaller than this are dropped.

    Returns
    -------
    tuple
        ``(pruned_skeleton, main_root_skeleton, lateral_root_skeleton_pruned,
        node_mask, tip_mask)``. ``main_root_skeleton`` holds 0/1; the other
        four are uint8 images holding 0/255. ``node_mask`` marks the tips of
        the lateral skeleton that are not tips of the full pruned skeleton
        (used as lateral-root origins). ``tip_mask`` marks the tips of the
        main root plus the kept laterals.
    """
    # 3x3 neighbourhood of ones, i.e. what skimage.morphology.square(3)
    # returns, without depending on that deprecated helper.
    neighbourhood = np.ones((3, 3), dtype=np.uint8)

    # extract skeleton from root mask
    skeleton_image = skimage.morphology.medial_axis(individual_mask, return_distance=False)

    # prune the skeleton according to the pruning value
    pruned_skeleton, tip_mask = skeleton_pruning(skeleton_image, length=skeleton_prunning_length)

    # The first and last skeleton pixels in row-major order serve as the top
    # and bottom ends of the main root.
    skeleton_pixels = np.argwhere(pruned_skeleton == 255)

    # find the shortest path between both ends to identify the main root
    if len(skeleton_pixels) != 0:
        keep_path = shortest_path(start=skeleton_pixels[0], end=skeleton_pixels[-1],
                                  binary_image=pruned_skeleton)
    else:
        keep_path = []

    # recreate the main root (value 1) from the shortest path
    main_root_skeleton = np.zeros_like(pruned_skeleton)
    if len(keep_path) != 0:
        path = np.asarray(keep_path)
        main_root_skeleton[path[:, 0], path[:, 1]] = 1

    # remove the main root from the skeleton to keep only the lateral roots
    lateral_root_skeleton = pruned_skeleton.copy()
    lateral_root_skeleton[main_root_skeleton == 1] = 0

    # drop lateral components that are too small
    lateral_keep = skimage.morphology.remove_small_objects(skimage.util.img_as_bool(lateral_root_skeleton),
                                                           minimum_area_lateral, connectivity=2)

    # Build the 0/255 uint8 image directly. Passing the int64 output of
    # np.where through img_as_ubyte triggers a downcast warning on every call.
    lateral_root_skeleton_pruned = np.where(lateral_keep, 255, 0).astype(np.uint8)

    # Lateral origins: tips of the lateral-only skeleton that were not
    # already tips of the full pruned skeleton.
    mean_filter = skimage.filters.rank.mean(lateral_root_skeleton_pruned, footprint=neighbourhood)
    node_mask = find_tip(lateral_root_skeleton_pruned, mean_filter)
    node_mask[tip_mask == 255] = 0

    # put the main root back with the kept laterals to find the final tips
    skeleton_pruned_laterals = lateral_root_skeleton_pruned.copy()
    skeleton_pruned_laterals[main_root_skeleton == 1] = 255
    mean_filter = skimage.filters.rank.mean(skeleton_pruned_laterals, footprint=neighbourhood)
    tip_mask = find_tip(skeleton_pruned_laterals, mean_filter)

    return pruned_skeleton, main_root_skeleton, lateral_root_skeleton_pruned, node_mask, tip_mask

def root_model_predictions(reference_image_path, input_folder, list_images, output_folder, model,
                           patch_size=256, threshold=0.25):
    """Segment roots in a list of images with *model* and save the masks.

    The bounding box of the largest contour of the Otsu-thresholded reference
    image is used as the region of interest; every input image is resized to
    the reference size and blanked outside that box before prediction.

    For each image two files are written to
    ``<output_folder>/<name parts after the second "_">/``:
    ``<stem>_predicted.tif`` (0/255 mask) and ``<stem>_predicted_rgb.tif``
    (the original image with mask pixels painted ``[255, 0, 0]``).

    Parameters
    ----------
    reference_image_path : str
        Image used to locate the region of interest.
    input_folder : str
        Folder containing the images named in *list_images*.
    list_images : iterable of str
        File names (not paths) of the images to process.
    output_folder : str
        Root folder for the results; created if missing.
    model : object
        Model whose ``predict`` accepts an array of shape
        ``(n, patch_size, patch_size, 1)`` scaled to [0, 1].
        NOTE(review): the prediction is reshaped to one value per pixel, so a
        single-channel output is assumed - confirm against the model.
    patch_size : int, optional
        Edge length of the square patches fed to the model.
    threshold : float, optional
        Predictions strictly above this value become mask pixels.

    Raises
    ------
    OSError
        If the reference image or an input image cannot be read.
    ValueError
        If no contour is found in the reference image.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    # Read the reference image in color mode
    reference_im = cv2.imread(reference_image_path, cv2.IMREAD_COLOR)
    if reference_im is None:
        # cv2.imread returns None instead of raising on failure.
        raise OSError(f"Could not read reference image: {reference_image_path}")
    reference_height, reference_width = reference_im.shape[0], reference_im.shape[1]

    # Binarise the grayscale reference image with Otsu's method
    reference_im_gray = cv2.cvtColor(reference_im, cv2.COLOR_BGR2GRAY)
    _, reference_im_binary = cv2.threshold(reference_im_gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # The largest contour should be the petri dish; its bounding rectangle is
    # the region of interest.
    contours, _hierarchy = cv2.findContours(reference_im_binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        raise ValueError(f"No contour found in reference image: {reference_image_path}")
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)

    for image_filename in list_images:

        image_path = os.path.join(input_folder, image_filename)

        # Read the image in grayscale mode
        im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if im is None:
            raise OSError(f"Could not read image: {image_path}")

        # Resize the input image to the shape of the reference image
        im = cv2.resize(im, (reference_width, reference_height))

        # Keep only the pixels inside the region of interest
        inside_roi = np.zeros_like(im)
        inside_roi[y:y + h, x:x + w] = im[y:y + h, x:x + w]
        im = inside_roi

        # Pad the image so it divides evenly into patches
        padded_image, padding_info = padder(im, patch_size)

        # Cut into non-overlapping patches and remember the grid layout
        patches = patchify(padded_image, (patch_size, patch_size), step=patch_size)
        grid_rows, grid_cols = patches.shape[0], patches.shape[1]
        patches = patches.reshape(-1, patch_size, patch_size, 1)

        # Apply the model to the patches scaled to [0, 1]
        preds = model.predict(patches / 255)

        # Reassemble the patch predictions into one image
        preds = preds.reshape(grid_rows, grid_cols, patch_size, patch_size)
        predicted_mask = unpatchify(preds, (padded_image.shape[0], padded_image.shape[1]))

        # Threshold and scale to a 0/255 uint8 mask
        predicted_mask = (predicted_mask > threshold).astype(np.uint8) * 255

        # Unpad the predicted mask to get the original image shape
        predicted_mask = unpadder(predicted_mask, padding_info)

        # Create the individual output folder before saving. The dot is
        # escaped so only a literal ".png" is stripped.
        output_link = re.sub(r"\.png", "", '_'.join(image_filename.split("_")[2:]))
        output_link = os.path.join(output_folder, output_link)
        os.makedirs(output_link, exist_ok=True)

        # Save the predicted mask
        stem = os.path.splitext(image_filename)[0]
        output_path = os.path.join(output_link, f"{stem}_predicted.tif")
        cv2.imwrite(output_path, predicted_mask)

        # Save the mask overlaid on the original image for inspection
        mask_rgb = cv2.imread(image_path, cv2.IMREAD_COLOR)
        mask_rgb = skimage.img_as_ubyte(mask_rgb)
        if mask_rgb.shape[:2] != predicted_mask.shape[:2]:
            # The mask has the reference image's size; bring the overlay to
            # the same size so the boolean index below is valid.
            mask_rgb = cv2.resize(mask_rgb, (predicted_mask.shape[1], predicted_mask.shape[0]))
        mask_rgb[predicted_mask == 255] = [255, 0, 0]
        # Build the name directly: substituting "_predicted" in the mask path
        # used to produce a doubled "..._predicted_rgb.tif.tif" extension.
        output_path = os.path.join(output_link, f"{stem}_predicted_rgb.tif")
        cv2.imwrite(output_path, mask_rgb)

def root_traits(image_path, mask_temp, mask_rgb, skeleton_prunning_length, minimum_area_lateral):
    """Measure root traits of one seedling mask and draw them on *mask_rgb*.

    When the mask contains several connected objects, the largest one is
    analysed with ``skeleton`` (main root, laterals, nodes, tips). The
    remaining objects are skeletonised with ``medial_axis`` and added to the
    lateral skeleton, so they contribute to ``lateral_roots_length`` and
    ``total_root_length`` but not to the node / tip analysis.

    Parameters
    ----------
    image_path : str
        Stored in the result under ``"image_path"``; not read here.
    mask_temp : numpy.ndarray
        2-D mask of one seedling with root pixels equal to 255.
    mask_rgb : numpy.ndarray
        Colour image with the same height and width as *mask_temp*; modified
        in place with the skeleton overlay.
    skeleton_prunning_length, minimum_area_lateral : int
        Forwarded to ``skeleton``.

    Returns
    -------
    collections.defaultdict
        Keys: ``image_path``, ``root_area``, ``total_root_length``,
        ``main_root_length``, ``lateral_roots_length``,
        ``origin_lateral_roots`` and ``depth_origin_lateral_roots``. Lengths
        and areas are pixel counts. ``depth_origin_lateral_roots`` is an
        array of row distances between each node pixel and the first
        skeleton pixel, or ``[]`` when there is no node or no skeleton.
    """
    mask_label = skimage.measure.label(mask_temp)
    label_list = skimage.measure.regionprops(mask_label)

    if len(label_list) <= 1:
        pruned_skeleton, main_root_skeleton, lateral_root_skeleton, node_mask, tip_mask = skeleton(
            mask_temp, skeleton_prunning_length, minimum_area_lateral)
    else:
        # Select the largest object for the morphology analysis.
        largest_region = max(label_list, key=lambda region: region.area)
        mask_first = extract_individual_mask(largest_region.coords, mask_temp)
        mask_second = mask_temp.copy()
        mask_second[mask_first == 1] = 0

        pruned_skeleton, main_root_skeleton, lateral_root_skeleton, node_mask, tip_mask = skeleton(
            mask_first, skeleton_prunning_length, minimum_area_lateral)
        skeleton_second = skimage.morphology.medial_axis(mask_second, return_distance=False)
        # The lateral skeleton is a 0/255 image. Writing 1 here (as before)
        # left these pixels out of both the overlay and the length count,
        # which test for 255.
        lateral_root_skeleton[skeleton_second != 0] = 255

    # Overlay, one colour per structure (channel order as loaded by the caller).
    mask_rgb[skimage.morphology.dilation(lateral_root_skeleton, skimage.morphology.disk(1)) == 255] = [255, 125, 50]
    mask_rgb[skimage.morphology.dilation(main_root_skeleton, skimage.morphology.disk(1)) == 1] = [125, 125, 255]
    mask_rgb[skimage.morphology.dilation(node_mask, skimage.morphology.disk(1)) == 255] = [255, 0, 255]
    mask_rgb[skimage.morphology.dilation(tip_mask, skimage.morphology.disk(1)) == 255] = [0, 255, 0]

    # Pixel counts
    root_area = int(np.count_nonzero(mask_temp == 255))
    main_root_length = int(np.count_nonzero(main_root_skeleton == 1))
    lateral_roots_length = int(np.count_nonzero(lateral_root_skeleton == 255))
    total_root_length = main_root_length + lateral_roots_length
    number_lateral_roots = int(np.count_nonzero(node_mask == 255))

    # Row distance from the first skeleton pixel to each lateral origin.
    node_rows = np.where(node_mask == 255)[0]
    skeleton_rows = np.where(pruned_skeleton == 255)[0]
    if len(node_rows) != 0 and len(skeleton_rows) != 0:
        depth_origin_lateral_roots = abs(node_rows - skeleton_rows[0])
    else:
        depth_origin_lateral_roots = []

    mydict = defaultdict(int)
    mydict["image_path"] = image_path
    mydict["root_area"] = root_area
    mydict['total_root_length'] = total_root_length
    mydict["main_root_length"] = main_root_length
    mydict["lateral_roots_length"] = lateral_roots_length
    mydict["origin_lateral_roots"] = number_lateral_roots
    mydict["depth_origin_lateral_roots"] = depth_origin_lateral_roots

    return mydict

def individual_processing(rgb_folder, output_folder, mask_seeding_position, list_images,
                          skeleton_prunning_length=10, minimum_area_lateral=10):
    """Split predicted root masks into per-seedling masks and export traits.

    For every predicted mask, the objects that overlap each seeding position
    are collected into an individual mask, analysed with ``root_traits`` and
    written to ``<output_folder>/<name parts after the second "_">/``:

    * ``<name>_predicted_<k>.tif`` - mask of the k-th seeding position,
    * ``<name>_predicted_<k>_metrics.csv`` - one ``key,value`` row per trait,
    * ``<name>_predicted.tif`` - union of all individual masks,
    * ``<name>_predicted_rgb.tif`` - original image with skeleton overlay.

    Parameters
    ----------
    rgb_folder : str
        Folder holding the original ``.png`` images.
    output_folder : str
        Root folder for the results.
    mask_seeding_position : numpy.ndarray
        2-D template; after resizing, pixels >= 200 mark seeding positions.
        The array passed in is not modified.
    list_images : iterable of str
        Paths of ``*_predicted.tif`` masks.
    skeleton_prunning_length, minimum_area_lateral : int, optional
        Forwarded to ``root_traits``.

    Raises
    ------
    OSError
        If a predicted mask or its original image cannot be read.
    """
    for image_filename in list_images:
        base_name = os.path.basename(image_filename)

        # create individual output folder (dot escaped: literal ".tif" only)
        output_link = re.sub(r"_predicted\.tif", "", '_'.join(base_name.split("_")[2:]))
        output_link = os.path.join(output_folder, output_link)
        os.makedirs(output_link, exist_ok=True)

        # load the predicted mask (first channel only)
        predicted_mask = cv2.imread(image_filename)
        if predicted_mask is None:
            # cv2.imread returns None instead of raising on failure.
            raise OSError(f"Could not read predicted mask: {image_filename}")
        predicted_mask = predicted_mask[:, :, 0]
        predicted_mask = skimage.util.img_as_ubyte(predicted_mask)

        # Resize the seeding template for this image. Always start from the
        # caller's template rather than from the previous iteration's result,
        # so repeated resizing cannot accumulate.
        seeding_resized = cv2.resize(mask_seeding_position, (predicted_mask.shape[1], predicted_mask.shape[0]))
        seeding_mask = np.where(seeding_resized < 200, 0, 255).astype(seeding_resized.dtype)

        seeding_mask_label = skimage.measure.label(seeding_mask)
        seeding_position_List = skimage.measure.regionprops(seeding_mask_label)

        # dilate then erode the predicted mask to reconnect broken roots
        predicted_mask = skimage.morphology.dilation(predicted_mask)
        predicted_mask = skimage.morphology.erosion(predicted_mask)

        # label the predicted mask
        predicted_mask_label = skimage.measure.label(predicted_mask)

        # load the original image used for the overlay
        link_mask_rgb = os.path.join(rgb_folder, re.sub(r"_predicted\.tif", ".png", base_name))
        mask_rgb = cv2.imread(link_mask_rgb)
        if mask_rgb is None:
            raise OSError(f"Could not read original image: {link_mask_rgb}")
        mask_rgb = skimage.img_as_ubyte(mask_rgb)
        if mask_rgb.shape[:2] != predicted_mask.shape[:2]:
            # root_traits indexes the overlay with masks of the predicted
            # mask's shape, so both must match.
            mask_rgb = cv2.resize(mask_rgb, (predicted_mask.shape[1], predicted_mask.shape[0]))

        # union of all individual masks
        mask = np.zeros_like(predicted_mask)

        # Number the positions from 1; unlike a fixed tuple of five suffixes
        # this does not fail when more positions are present.
        for index, position in enumerate(seeding_position_List, start=1):
            suffix = f"_{index}"

            # Labels of predicted objects that share at least one pixel with
            # this seeding position (0 is background).
            seed_rows, seed_cols = position.coords[:, 0], position.coords[:, 1]
            touched_labels = np.unique(predicted_mask_label[seed_rows, seed_cols])
            touched_labels = touched_labels[touched_labels != 0]

            mask_temp = np.zeros_like(predicted_mask)
            mask_temp[np.isin(predicted_mask_label, touched_labels)] = 255

            # export the individual seedling mask
            output_path = os.path.join(output_link, re.sub(r"\.tif", suffix + ".tif", base_name))
            cv2.imwrite(output_path, mask_temp)

            # extract the traits and overlay the skeleton on the rgb image
            results = root_traits(image_filename, mask_temp, mask_rgb,
                                  skeleton_prunning_length=skeleton_prunning_length,
                                  minimum_area_lateral=minimum_area_lateral)

            output_path = os.path.join(output_link, re.sub(r"\.tif", suffix + "_metrics.csv", base_name))
            with open(output_path, 'w', newline="") as csv_file:
                writer = csv.writer(csv_file)
                for key, value in results.items():
                    writer.writerow([key, value])

            # accumulate the individual masks
            mask[mask_temp == 255] = 255

        # save the combined mask at the end of the iterations
        output_path = os.path.join(output_link, base_name)
        cv2.imwrite(output_path, mask)

        # save the overlay on the rgb image for inspection
        output_path = os.path.join(output_link, re.sub(r"_predicted\.tif", "_predicted_rgb.tif", base_name))
        cv2.imwrite(output_path, mask_rgb)


####################

# Define the working directory, input and output folders
wd = "/Users/viniciuslube/Desktop/Pipeline/BUas-NPEC/Valerian"
model = load_model(wd + '/processed_images/model_reference/root_4.h5', custom_objects={'f1': f1, 'iou': iou})

# Tags "_1_" ... "_20_" identifying the rounds to process
round_tags = [f"_{round_number}_" for round_number in range(1, 21)]

# Subset the images of the input folder to PNG files. The substring is passed
# in a list: a bare string would be iterated character by character.
list_images = os.listdir(wd+"/rootcam1")
list_images = grep(list_images, ['.png'])

# subset the list to the correct rounds
list_images = grep(list_images, round_tags)

# run the model for the root segmentation
root_model_predictions(reference_image_path=wd + '/processed_images/model_reference/014_43-18-ROOT1-2023-08-08_pvd_OD001_f6h1_03-Fish Eye Corrected.png',
input_folder=wd+"/rootcam1",
list_images=list_images,
output_folder=wd + "/processed_images/rootcam1/mask_model",
model=model)

# load the seed template for postprocessing (first channel only)
mask_seeding_position = skimage.io.imread(wd + "/processed_images/seeding_mask_2.tif")
mask_seeding_position = mask_seeding_position[:, :, 0]

# collect the predicted masks, leaving out the rgb overlays
list_images = list_files_recursive(wd+"/processed_images/rootcam1/mask_model")
list_images_rgb = grep(list_images, ["rgb"])
list_images = [item for item in list_images if item not in list_images_rgb]
list_images = grep(list_images, round_tags)

# segment the individual seedlings
individual_processing(rgb_folder=wd + "/rootcam1",
                      output_folder=wd + "/processed_images/rootcam1/ind_mask/",
                      mask_seeding_position=mask_seeding_position,
                      list_images=list_images)





# Captured notebook output (looks like the tail of a dtype-conversion warning, shown twice):
#   return _convert(image, np.uint8, force_copy)
#   return _convert(image, np.uint8, force_copy)
