## Importing necessary packages

In [1]:
# Standard library imports
import os

# Third-party libraries for data manipulation and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from patchify import patchify, unpatchify

# Deep learning libraries
import tensorflow as tf
import keras.backend as K
from keras.models import load_model

# Scientific image processing and analysis
import cv2
from skimage.morphology import skeletonize
from skan.csr import skeleton_to_csgraph
from skan import summarize, Skeleton
import networkx as nx

# Custom functions I made
from utils.metrics import f1, iou  # Importing the f1 and iou metrics for model compilation
from utils.data_processing import create_prediction, padder, cropping_images_in_folder  # Helper functions for arabidopsis_analyser


## The computer vision pipeline: **arabidopsis_analyser**

In [2]:
# Landmark drawing style (OpenCV colours are BGR).
_LANDMARK_RADIUS = 15
_LANDMARK_THICKNESS = 4
_COLOUR_LATERAL_TIP = (0, 0, 255)      # red
_COLOUR_BRANCH_POINT = (255, 0, 0)     # blue
_COLOUR_HYPOCOTYL_JOIN = (0, 0, 0)     # black
_COLOUR_PRIMARY_TIP = (255, 0, 255)    # magenta

# CSV schemas. The column names are consumed downstream, so they are kept
# exactly as they were (including the historical 'hypthol' spelling).
_COORDINATE_COLUMNS = [
    'skeleton_id',
    'junction_root_hypthol_V1_x',
    'junction_root_hypthol_V1_y',
    'primary_root_tip_V2_x',
    'primary_root_tip_V2_y',
]
_LENGTH_COLUMNS = ['skeleton_id', 'primary root length', 'total lateral root length']


def _keep_largest_components(mask, max_components):
    """Return a copy of `mask` containing only its `max_components` largest 8-connected blobs."""
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    # Label 0 is the background, so it is left out of the ranking explicitly
    # instead of relying on it being one of the largest "components".
    foreground_areas = stats[1:, cv2.CC_STAT_AREA]
    largest_labels = np.argsort(foreground_areas)[::-1][:max_components] + 1
    keep = np.isin(labels, largest_labels)
    filtered = np.zeros_like(mask)
    filtered[keep] = mask[keep]
    return filtered


def _paste_clipped(canvas, patch, top, left):
    """Copy `patch` into `canvas` with its top-left corner at (top, left), clipping at the canvas edges."""
    canvas_height, canvas_width = canvas.shape[:2]
    patch_height, patch_width = patch.shape[:2]
    row_start, col_start = max(top, 0), max(left, 0)
    row_stop = min(top + patch_height, canvas_height)
    col_stop = min(left + patch_width, canvas_width)
    if row_start >= row_stop or col_start >= col_stop:
        return  # the patch lies entirely outside the canvas
    canvas[row_start:row_stop, col_start:col_stop] = patch[
        row_start - top:row_stop - top,
        col_start - left:col_stop - left,
    ]


def _primary_root_ends(branches):
    """Return the (topmost-source, bottommost-destination) branch rows of one skeleton."""
    top_row = branches.loc[branches['coord-src-0'].idxmin()]
    tip_row = branches.loc[branches['coord-dst-0'].idxmax()]
    return top_row, tip_row


def _draw_landmarks(image, branch_data, offset_x, offset_y):
    """Draw landmark circles on `image` in place and return one coordinate record per skeleton."""

    def mark(row, col, colour):
        # Skeleton coordinates are (row, col); OpenCV points are (x, y).
        centre = (int(col) + offset_x, int(row) + offset_y)
        cv2.circle(image, centre, _LANDMARK_RADIUS, colour, _LANDMARK_THICKNESS)
        return centre

    records = []
    for skeleton_id in branch_data['skeleton-id'].unique():
        branches = branch_data[branch_data['skeleton-id'] == skeleton_id]
        branch_type = branches['branch-type']

        # Destination end of every branch of type 1 -> lateral root tip.
        tip_branches = branches[branch_type == 1]
        for row, col in zip(tip_branches['coord-dst-0'], tip_branches['coord-dst-1']):
            mark(row, col, _COLOUR_LATERAL_TIP)

        # Source end of every branch of type 1 or 2 -> branching point.
        fork_branches = branches[branch_type.isin([1, 2])]
        for row, col in zip(fork_branches['coord-src-0'], fork_branches['coord-src-1']):
            mark(row, col, _COLOUR_BRANCH_POINT)

        # Extremes of the skeleton: topmost source = junction with the
        # hypocotyl, bottommost destination = primary root tip.
        top_row, tip_row = _primary_root_ends(branches)
        top_x, top_y = mark(top_row['coord-src-0'], top_row['coord-src-1'], _COLOUR_HYPOCOTYL_JOIN)
        tip_x, tip_y = mark(tip_row['coord-dst-0'], tip_row['coord-dst-1'], _COLOUR_PRIMARY_TIP)

        records.append({
            'skeleton_id': skeleton_id,
            'junction_root_hypthol_V1_x': top_x,
            'junction_root_hypthol_V1_y': top_y,
            'primary_root_tip_V2_x': tip_x,
            'primary_root_tip_V2_y': tip_y,
        })
    return records


def _measure_root_lengths(branch_data):
    """Return one {'skeleton_id', primary length, total lateral length} record per skeleton."""
    graph = nx.from_pandas_edgelist(
        branch_data, source='node-id-src', target='node-id-dst', edge_attr='branch-distance'
    )
    records = []
    for skeleton_id in branch_data['skeleton-id'].unique():
        branches = branch_data[branch_data['skeleton-id'] == skeleton_id]
        top_row, tip_row = _primary_root_ends(branches)
        top_node = top_row['node-id-src']
        tip_node = tip_row['node-id-dst']

        primary_length = nx.dijkstra_path_length(graph, top_node, tip_node, weight='branch-distance')

        # Lateral roots: type-1 branches that neither end at the primary tip
        # nor start at the hypocotyl junction.
        laterals = branches[
            (branches['branch-type'] == 1)
            & (branches['node-id-dst'] != tip_node)
            & (branches['node-id-src'] != top_node)
        ]
        lateral_total = 0
        for lateral_tip in laterals['node-id-dst']:
            incoming = branches[branches['node-id-dst'] == lateral_tip]
            if incoming.empty:
                continue
            start_node = incoming.iloc[0]['node-id-src']
            try:
                lateral_total += nx.dijkstra_path_length(
                    graph, start_node, lateral_tip, weight='branch-distance'
                )
            except nx.NetworkXNoPath:
                # Best effort: a lateral with no route contributes nothing.
                continue

        records.append({
            'skeleton_id': skeleton_id,
            'primary root length': primary_length,
            'total lateral root length': lateral_total,
        })
    return records


def _blend_mask_overlays(base_image, root_overlay, shoot_overlay):
    """Return `base_image` with both colour overlays added at 25% weight."""
    blended = cv2.addWeighted(base_image, 1, root_overlay, 0.25, 0)
    return cv2.addWeighted(blended, 1, shoot_overlay, 0.25, 0)


def _save_figure_with_legend(image_bgr, title, output_path):
    """Render `image_bgr` with a title and the landmark/mask legend, and save it to `output_path`."""
    legend_patches = [
        mpatches.Patch(color='red', label='Lateral root tip'),
        mpatches.Patch(color='blue', label='Point where lateral roots branch out from the primary root'),
        mpatches.Patch(color='black', label='Junction between the primary root and the hypocotyl'),
        mpatches.Patch(color='#FF00FF', label='Primary root tip'),
        mpatches.Patch(color='yellow', alpha=0.3, label='Root mask'),
        mpatches.Patch(color='green', alpha=0.3, label='Shoot mask'),
    ]
    fig = plt.figure(figsize=(12, 8))
    try:
        # Matplotlib expects RGB; OpenCV images are BGR.
        plt.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        plt.title(title)
        plt.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)


def arabidopsis_analyser(uncropped_folder_path, root_model, shoot_model, reference_image_path,
                         *, output_dir='processed_images', max_roots=5, shoot_clear_rows=250):
    """
    Segment roots and shoots in plant images, mark landmarks and measure root lengths.

    For every image in `uncropped_folder_path` the function predicts a root and a
    shoot mask, skeletonizes the largest root blobs, marks landmarks on the
    original image, measures primary and lateral root lengths on the skeleton
    graph, and writes the results to a per-image sub-folder of `output_dir`.

    Parameters:
    ----------
    uncropped_folder_path : str
        Path to the folder containing original uncropped plant images.
    root_model : object
        Pre-trained model passed to `create_prediction` to obtain the root mask.
    shoot_model : object
        Pre-trained model passed to `create_prediction` to obtain the shoot mask.
    reference_image_path : str
        Path to the reference image handed to `cropping_images_in_folder`.
    output_dir : str, keyword-only, default 'processed_images'
        Directory under which one sub-folder per image is created.
    max_roots : int, keyword-only, default 5
        Number of largest connected root blobs (background excluded) that are kept.
    shoot_clear_rows : int, keyword-only, default 250
        Number of rows at the top of the shoot mask that are zeroed.

    Returns:
    -------
    None

    Raises:
    -------
    FileNotFoundError
        If an image listed by `cropping_images_in_folder` cannot be read from
        `uncropped_folder_path`.

    Side Effects:
    -------------
    For each image, the following files are written to `<output_dir>/<image name>/`:
    - original_image_with_landmarks.png : original image with landmark circles
      (written once per image; without circles if no root skeleton was found).
    - primary_root_tips_coordinates.csv : hypocotyl junction and primary root tip
      coordinates per skeleton, in original-image pixels.
    - root_lengths.csv : primary and total lateral root length per skeleton.
    - full_root_mask.png / full_shoot_mask.png : masks at original image size.
    - overlay_only.png : original image with root (yellow) and shoot (green) overlays.
    - final_image_with_legend.png : overlays plus landmarks, with a legend.

    Detailed Steps:
    ---------------
    1. Crop all images via `cropping_images_in_folder`; this also yields the crop box (x, y, w, h).
    2. Obtain the padding offsets for each cropped image from `padder`.
    3. Predict root and shoot masks with `create_prediction`.
    4. Zero the top `shoot_clear_rows` rows of the shoot mask.
    5. Dilate the root mask and keep its `max_roots` largest connected components.
    6. Skeletonize the result and summarize its branches with skan.
    7. Draw lateral tips (red), branching points (blue), the hypocotyl junction
       (black) and the primary root tip (magenta) on the original image.
    8. Measure primary and lateral root lengths as shortest paths on the skeleton graph.
    9. Paste both masks into full-size canvases (clipped to the image bounds).
    10. Save the masks, the overlay-only image, and the final figure with legend.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Only the crop origin (x, y) is needed to map back to the original image.
    cropped_images, (x, y, _w, _h) = cropping_images_in_folder(uncropped_folder_path, reference_image_path)

    for filename, cropped_image in cropped_images.items():
        image_name = os.path.splitext(filename)[0]
        image_output_dir = os.path.join(output_dir, image_name)
        os.makedirs(image_output_dir, exist_ok=True)

        # Read the original up front so an unreadable file fails with a clear
        # message before any model inference is spent on it.
        image_path = os.path.join(uncropped_folder_path, filename)
        original_image = cv2.imread(image_path)
        if original_image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')

        # Only the padding offsets are needed here; the padded image itself is unused.
        # NOTE(review): this assumes create_prediction pads its input the same
        # way padder does, so mask coordinates are in padded-image space - confirm.
        _, (left_padding, top_padding) = padder(cropped_image)
        offset_x = x - left_padding
        offset_y = y - top_padding

        root_mask = create_prediction(root_model, cropped_image)
        shoot_mask = create_prediction(shoot_model, cropped_image)

        # Clear the top rows of the shoot mask (treated as noise).
        shoot_mask[:shoot_clear_rows, :] = 0

        # Dilate to bridge small gaps, then keep only the largest root blobs.
        dilated_root_mask = cv2.dilate(root_mask, np.ones((5, 5), dtype="uint8"), iterations=2)
        filtered_root_mask = _keep_largest_components(dilated_root_mask, max_roots)

        skeleton = skeletonize(filtered_root_mask)

        # Landmarks are drawn on a copy so the clean original stays available.
        annotated_image = original_image.copy()

        if skeleton.any():
            # NOTE(review): the dashed column names used below ('skeleton-id', ...)
            # depend on skan's default column separator - confirm for the
            # installed skan version.
            branch_data = summarize(Skeleton(skeleton))
            coordinate_records = _draw_landmarks(annotated_image, branch_data, offset_x, offset_y)
        else:
            # No root pixels detected: skip skeleton analysis and emit
            # header-only CSVs instead of handing skan an empty image.
            branch_data = None
            coordinate_records = []

        # Written once per image, after all skeletons have been drawn.
        cv2.imwrite(os.path.join(image_output_dir, 'original_image_with_landmarks.png'), annotated_image)
        pd.DataFrame(coordinate_records, columns=_COORDINATE_COLUMNS).to_csv(
            os.path.join(image_output_dir, 'primary_root_tips_coordinates.csv'), index=False
        )

        length_records = [] if branch_data is None else _measure_root_lengths(branch_data)
        pd.DataFrame(length_records, columns=_LENGTH_COLUMNS).to_csv(
            os.path.join(image_output_dir, 'root_lengths.csv'), index=False
        )

        # Place both masks at their position in a canvas the size of the
        # original image; clipping prevents shape-mismatch errors when the
        # padded mask extends past an image border.
        canvas_shape = original_image.shape[:2]
        root_mask_full = np.zeros(canvas_shape, dtype=original_image.dtype)
        shoot_mask_full = np.zeros(canvas_shape, dtype=original_image.dtype)
        _paste_clipped(root_mask_full, filtered_root_mask, offset_y, offset_x)
        _paste_clipped(shoot_mask_full, shoot_mask, offset_y, offset_x)
        root_pixels = root_mask_full == 255
        shoot_pixels = shoot_mask_full == 255

        # White-on-black 3-channel masks.
        full_root_mask = np.zeros_like(original_image)
        full_shoot_mask = np.zeros_like(original_image)
        full_root_mask[root_pixels] = [255, 255, 255]
        full_shoot_mask[shoot_pixels] = [255, 255, 255]
        cv2.imwrite(os.path.join(image_output_dir, 'full_root_mask.png'), full_root_mask)
        cv2.imwrite(os.path.join(image_output_dir, 'full_shoot_mask.png'), full_shoot_mask)

        # Coloured overlays (BGR): yellow for roots, green for shoots.
        yellow_overlay = np.zeros_like(original_image)
        green_overlay = np.zeros_like(original_image)
        yellow_overlay[root_pixels] = [0, 255, 255]
        green_overlay[shoot_pixels] = [0, 255, 0]

        # Overlay-only image: masks on the clean original (no landmark circles).
        cv2.imwrite(
            os.path.join(image_output_dir, 'overlay_only.png'),
            _blend_mask_overlays(original_image, yellow_overlay, green_overlay),
        )

        # Final figure: masks on the annotated image, plus title and legend.
        _save_figure_with_legend(
            _blend_mask_overlays(annotated_image, yellow_overlay, green_overlay),
            f'Original Image with Landmarks and Masks: {filename}',
            os.path.join(image_output_dir, 'final_image_with_legend.png'),
        )


## Example Usage

In [3]:
# --- Inputs -----------------------------------------------------------------
# Folder holding the images to process, and the reference image used for cropping.
uncropped_folder_path = 'measurement_set'
reference_image_path = 'utils/reference_image.png'

# --- Models -----------------------------------------------------------------
# Each model is loaded with the project's f1 and iou metrics registered under
# their own names via custom_objects.
root_model = load_model('utils/root_1.h5', custom_objects=dict(f1=f1, iou=iou))
shoot_model = load_model('utils/shoot_1.h5', custom_objects=dict(f1=f1, iou=iou))

# --- Run the pipeline ---------------------------------------------------------
arabidopsis_analyser(
    uncropped_folder_path=uncropped_folder_path,
    root_model=root_model,
    shoot_model=shoot_model,
    reference_image_path=reference_image_path,
)