In [54]:
import os
from patchify import unpatchify
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
import glob

In [55]:
patch_size=256
from custom_functions import predict_patches, cropping
from custom_metrics import f1,iou

### Model

In [56]:
# Define custom_objects dictionary with your metric functions
custom_objects = {'f1': f1, 'iou': iou}

# Load the model using the custom_objects parameter
root_model = load_model('root_model.h5', custom_objects=custom_objects)

In [57]:
for image_path in sorted(glob.glob('Kaggle_dataset/*')):
    cropping(image_path)

### Instance segmentation and saving instances


In [58]:
def instance_segmentation(image_path, min_area_threshold, output_dir, filename, num_largest=5, threshold=0.5):
    # Predict roots
    image, preds = predict_patches(image_path, root_model, patch_size)
    predicted_mask = unpatchify(preds, (image.shape[0], image.shape[1]))
    # Threshold predictions
    binary_mask = (predicted_mask > threshold).astype(np.uint8)
    # Connected components analysis
    _, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask)

    # Filter connected components based on area threshold
    mask = stats[:, 4] >= min_area_threshold
    labels = np.where(mask[labels], labels, 0)
    stats, centroids = stats[mask], centroids[mask]
    # Create a new image with filtered connected components
    filtered_image = np.where(labels > 0, binary_mask, 0)
    top_cutoff, bottom_cutoff = 400, 2100
    filtered_image[:top_cutoff, :] = 0
    filtered_image[bottom_cutoff:, :] = 0

    # Connected components analysis on the filtered image
    _, labels, stats, centroids = cv2.connectedComponentsWithStats(filtered_image)

    # Find indices of the largest components
    largest_label_indices = np.argsort(stats[1:, cv2.CC_STAT_AREA])[-num_largest:] + 1
    # Create a mask for the largest components
    largest_components_mask = np.where(np.isin(labels, largest_label_indices), 255, 0).astype(np.uint8)

    # Divide the image width into equal parts
    width = largest_components_mask.shape[1] // num_largest
    for i in range(num_largest):
        # Crop the corresponding part of the predicted_mask
        part = i * width
        cropped_image = largest_components_mask[:, part:part + width]
        # Save the cropped part as a separate image
        part_filename = f"{filename}_plant_{i + 1}.png"
        cv2.imwrite(os.path.join(output_dir, part_filename), cropped_image)

    return output_dir

In [82]:
# Cropped images for Kaggle dataset
input_dir = 'Cropped'
# Output directory for plant instances
output_dir = 'Instance_segmentation_dataset'

# Loop through all the files in the input directory
for filename in os.listdir(input_dir):
    if filename.endswith(".tif"):
        # Construct the full path of the image file
        image_path = os.path.join(input_dir, filename)
        # Construct the output filename based on the input filename
        output_filename = os.path.splitext(filename)[0]
        instance_segmentation(image_path, 400, output_dir, output_filename)

### Primary roots measures

In [83]:
import networkx as nx
from skimage.morphology import skeletonize
from skan import Skeleton, summarize

def measure_roots(image_path):
    # Load the image
    plant = cv2.imread(image_path)
    # Check if the image is empty 
    if np.all(plant == 0):
        # If the image is empty set a root length to 0.0 
        return 0.0

    _, binary = cv2.threshold(plant, 0.5, 1, cv2.THRESH_BINARY)
    binary = binary.astype('uint8')
    # Create skeleton
    skeleton = skeletonize(binary)
    # Skeleton summary
    graph = summarize(Skeleton(skeleton))
    # Create a graph using NetworkX
    G = nx.from_pandas_edgelist(graph, source='node-id-src', target='node-id-dst', edge_attr='branch-distance')
    # Calculate primary root length
    start_point = graph['node-id-src'].min()
    end_point = graph['node-id-dst'].max()
    primary_root_len = nx.dijkstra_path_length(G, start_point, end_point, weight='branch-distance')

    return primary_root_len


In [107]:
# Create an empty list to store DataFrames
results = []

# Loop through images in the directory and measure all roots
for filename in os.listdir(output_dir):
    if filename.endswith(".png"): 
        image_path = os.path.join(output_dir, filename)
        primary_root_len = measure_roots(image_path)
        measurements = pd.DataFrame({'Plant ID': [filename[:-4]], 'Length (px)': [round(primary_root_len, 1)]})
        results.append(measurements)

# Concatenate all DataFrames into one
final_results_df = pd.concat(results, ignore_index=True)

### Changing order of the results

In [108]:
# Function to change order of results to be from image 1 to image 11
def change_order(s):
    current_number = 0
    result = []
    for char in s:
        if char.isdigit():
            current_number = current_number * 10 + int(char)
        elif current_number != 0:
            result.append(current_number)
            current_number = 0
    if current_number != 0:
        result.append(current_number)

    return tuple(result)

# Apply the custom sorting function 
sorted_data = final_results_df.sort_values(by='Plant ID', key=lambda x: x.map(change_order))
sorted_data.to_csv('Results/primary_root_lengths.csv', index=False)

In [109]:
sorted_data

Unnamed: 0,Plant ID,Length (px)
10,test_image_1_plant_1,591.1
11,test_image_1_plant_2,529.7
12,test_image_1_plant_3,684.7
13,test_image_1_plant_4,448.2
14,test_image_1_plant_5,654.6
15,test_image_2_plant_1,1136.0
16,test_image_2_plant_2,1607.5
17,test_image_2_plant_3,1271.2
18,test_image_2_plant_4,1292.7
19,test_image_2_plant_5,1121.1
