## Importing necessary packages

In [2]:
# Standard library imports
import os

# Third-party libraries for data manipulation, and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from patchify import patchify, unpatchify

# Deep learning libraries
import tensorflow as tf
import keras.backend as K
from keras.models import load_model

# Scientific image processing and analysis
import cv2
from skimage.morphology import skeletonize
from skan.csr import skeleton_to_csgraph
from skan import summarize, Skeleton
import networkx as nx

# Custom functions I made
from utils.metrics import f1, iou  # Importing the f1 and iou metrics for model compilation
from utils.data_processing import create_prediction, padder, cropping_images_in_folder  # Helper functions for plant_analyser


## Function to generate an interactive HTML image plot

In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def _paste_mask(canvas, mask, top, left):
    """Copy ``mask`` into ``canvas`` with its top-left corner at (top, left).

    The paste is clipped to the canvas, so a negative offset or a mask that
    overhangs the canvas edge no longer raises a broadcasting error. When the
    mask fits entirely inside the canvas this is a plain slice assignment.
    """
    mask_height, mask_width = mask.shape[:2]
    row_start, col_start = max(top, 0), max(left, 0)
    row_stop = min(top + mask_height, canvas.shape[0])
    col_stop = min(left + mask_width, canvas.shape[1])
    if row_stop <= row_start or col_stop <= col_start:
        return  # The mask lies completely outside the canvas
    canvas[row_start:row_stop, col_start:col_stop] = mask[
        row_start - top:row_stop - top,
        col_start - left:col_stop - left,
    ]


def _draw_landmarks(image, branch_data, offset_x, offset_y):
    """Draw the landmark circles of every skeleton onto ``image`` (in place, BGR).

    ``offset_x`` / ``offset_y`` translate skeleton coordinates into coordinates
    of ``image``. Per skeleton the circles are drawn in this order, so later
    ones may cover earlier ones: red, blue, black, magenta.
    """
    if branch_data is None or branch_data.empty:
        return

    def mark(row, col, colour_bgr):
        # cv2 expects (x, y), i.e. (column, row)
        cv2.circle(image, (int(col) + offset_x, int(row) + offset_y), 15, colour_bgr, 4)

    for skeleton_id in branch_data['skeleton-id'].unique():
        branches = branch_data[branch_data['skeleton-id'] == skeleton_id]

        # Lateral root tips (red): destination end of branch-type 1 branches
        # (junction-to-endpoint in skan's classification).
        # NOTE(review): this assumes the endpoint is the 'dst' node - confirm.
        for _, branch in branches[branches['branch-type'] == 1].iterrows():
            mark(branch['coord-dst-0'], branch['coord-dst-1'], (0, 0, 255))

        # Junction points (blue): source end of branch-type 1 and 2 branches
        for _, branch in branches[branches['branch-type'].isin([1, 2])].iterrows():
            mark(branch['coord-src-0'], branch['coord-src-1'], (255, 0, 0))

        # Top-most source point (black): junction between primary root and hypocotyl
        top_branch = branches.loc[branches['coord-src-0'].idxmin()]
        mark(top_branch['coord-src-0'], top_branch['coord-src-1'], (0, 0, 0))

        # Bottom-most destination point (magenta): primary root tip
        bottom_branch = branches.loc[branches['coord-dst-0'].idxmax()]
        mark(bottom_branch['coord-dst-0'], bottom_branch['coord-dst-1'], (255, 0, 255))


def _measure_root_lengths(branch_data):
    """Return one record per skeleton with its primary and total lateral root length.

    Records are ordered left to right (by the mean x coordinate of each
    skeleton's branches). Lengths are shortest-path lengths over the skeleton
    graph, weighted by 'branch-distance', i.e. in pixels.
    """
    if branch_data is None or branch_data.empty:
        return []

    graph = nx.from_pandas_edgelist(
        branch_data, source='node-id-src', target='node-id-dst', edge_attr='branch-distance'
    )

    # Order the skeletons from left to right in the image
    branch_mean_x = branch_data[['coord-src-1', 'coord-dst-1']].mean(axis=1)
    ordered_skeleton_ids = (
        branch_mean_x.groupby(branch_data['skeleton-id']).mean().sort_values().index
    )

    records = []
    for skeleton_id in ordered_skeleton_ids:
        branches = branch_data[branch_data['skeleton-id'] == skeleton_id]

        # Primary root: from the top-most source node to the bottom-most destination node
        junction_node_id = branches.loc[branches['coord-src-0'].idxmin()]['node-id-src']
        root_tip_node_id = branches.loc[branches['coord-dst-0'].idxmax()]['node-id-dst']
        primary_root_length = nx.dijkstra_path_length(
            graph, junction_node_id, root_tip_node_id, weight='branch-distance'
        )

        # Lateral roots: junction-to-endpoint branches that neither end at the
        # primary root tip nor start at the hypocotyl junction
        lateral_branches = branches[
            (branches['branch-type'] == 1)
            & (branches['node-id-dst'] != root_tip_node_id)
            & (branches['node-id-src'] != junction_node_id)
        ]

        total_lateral_root_length = 0
        for _, lateral_branch in lateral_branches.iterrows():
            tip_node_id = lateral_branch['node-id-dst']
            incoming = branches[branches['node-id-dst'] == tip_node_id]
            if incoming.empty:
                continue
            start_node_id = incoming.iloc[0]['node-id-src']
            try:
                total_lateral_root_length += nx.dijkstra_path_length(
                    graph, start_node_id, tip_node_id, weight='branch-distance'
                )
            except nx.NetworkXNoPath:
                # Deliberate best effort: a lateral without a route adds nothing
                continue

        records.append({
            'skeleton_id': skeleton_id,
            'primary root length': primary_root_length,
            'total lateral root length': total_lateral_root_length,
        })

    return records


def _build_interactive_figure(processed_rgb, original_rgb, root_lengths_df, filename):
    """Build the Plotly figure: toggleable image on top, root length table below."""
    fig = make_subplots(
        rows=2, cols=1,
        row_heights=[0.7, 0.3],  # Allocate more space to the image
        specs=[[{"type": "image"}], [{"type": "table"}]]
    )
    fig.update_layout(title_text=f'Interactive Arabidopsis Analysis for {filename}')

    # Both images share the first subplot; only one is visible at a time
    fig.add_trace(go.Image(z=processed_rgb, visible=True), row=1, col=1)
    fig.add_trace(go.Image(z=original_rgb, visible=False), row=1, col=1)

    fig.add_trace(go.Table(
        header=dict(values=['Primary Root Length in Pixels (left to right)', 'Total Lateral Root Length In Pixels (left to right)'],
                    fill_color='paleturquoise',
                    align='left'),
        cells=dict(values=[root_lengths_df['primary root length'], root_lengths_df['total lateral root length']],
                   fill_color='lavender',
                   align='left')
    ), row=2, col=1)

    # Buttons toggling between the processed and the original image
    fig.update_layout(
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [{"visible": [True, False]}],  # Show processed image, hide original
                        "label": "Processed Image",
                        "method": "update"
                    },
                    {
                        "args": [{"visible": [False, True]}],  # Show original image, hide processed
                        "label": "Original Image",
                        "method": "update"
                    }
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 10},
                "showactive": True,
                "type": "buttons",
                "x": 0.5,
                "xanchor": "center",
                "y": 1.15,
                "yanchor": "top"
            }
        ]
    )
    return fig


def interactive_arabidopsis(uncropped_folder_path, root_model, shoot_model, reference_image_path):
    """
    Segment every image in a folder, measure root lengths and save an interactive HTML report.

    For each image returned by ``cropping_images_in_folder`` the function:
      1. predicts a root and a shoot mask,
      2. keeps the six largest connected components of the dilated root mask,
      3. skeletonizes the result and draws landmark circles on the uncropped image,
      4. writes per-skeleton root lengths to
         ``processed_images/<filename>_root_lengths.csv``,
      5. writes ``<filename>_interactive.html`` to the current working directory
         (image with mask overlays, toggle to the untouched original, length table).

    Parameters:
        uncropped_folder_path (str): Path to folder containing original uncropped images
        root_model: Trained model for root detection
        shoot_model: Trained model for shoot detection
        reference_image_path (str): Path to the reference image for cropping

    Returns:
        None. Results are written to disk and one line per saved plot is printed.

    Raises:
        FileNotFoundError: If an image listed by ``cropping_images_in_folder``
            cannot be read from ``uncropped_folder_path``.
    """
    # Create processed_images directory if it doesn't exist
    os.makedirs('processed_images', exist_ok=True)

    # Get cropped images and cropping coordinates (only the crop origin is needed here)
    cropped_images, (x, y, _crop_width, _crop_height) = cropping_images_in_folder(
        uncropped_folder_path, reference_image_path
    )

    # The dilation kernel is the same for every image
    dilation_kernel = np.ones((5, 5), dtype="uint8")

    for filename, cropped_image in cropped_images.items():
        # Read the uncropped original first so an unreadable file fails early and clearly
        image_path = os.path.join(uncropped_folder_path, filename)
        annotated_image = cv2.imread(image_path)
        if annotated_image is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        untouched_image = annotated_image.copy()  # Kept free of circles and overlays

        # Only the padding values are needed; the padded image itself is not used
        _, (left_padding, top_padding) = padder(cropped_image)

        # Offsets translating mask coordinates into uncropped-image coordinates.
        # NOTE(review): assumes create_prediction returns masks in the padded
        # image's coordinate frame - confirm against utils.data_processing.
        offset_x = int(x - left_padding)
        offset_y = int(y - top_padding)

        # Create root and shoot predictions
        root_mask = create_prediction(root_model, cropped_image)
        shoot_mask = create_prediction(shoot_model, cropped_image)

        # Discard shoot predictions in the first 250 rows of the mask
        shoot_mask[:250, :] = 0

        # Dilate the root mask before labelling its connected components
        dilated_root_mask = cv2.dilate(root_mask, dilation_kernel, iterations=2)

        # Keep the six largest components. Label 0 is the background in OpenCV's
        # labelling, so when it is among the six largest (the usual case) five
        # root blobs remain.
        _, labels, stats, _ = cv2.connectedComponentsWithStats(dilated_root_mask, connectivity=8)
        labels_by_area = np.argsort(stats[:, cv2.CC_STAT_AREA])[::-1]
        filtered_root_mask = np.zeros_like(dilated_root_mask)
        for label in labels_by_area[:6]:
            component = labels == label
            filtered_root_mask[component] = dilated_root_mask[component]

        # Skeletonize the filtered root mask; skip the graph analysis when no
        # root pixels remain, so the report is still produced with an empty table
        skeleton = skeletonize(filtered_root_mask)
        skeleton_branch_data = summarize(Skeleton(skeleton)) if skeleton.any() else None

        # Landmarks are drawn on the image that later receives the mask overlays
        _draw_landmarks(annotated_image, skeleton_branch_data, offset_x, offset_y)

        # Measure root lengths and save them as CSV. Explicit columns keep the
        # CSV header and the table lookups valid even without any skeleton.
        root_lengths_df = pd.DataFrame(
            _measure_root_lengths(skeleton_branch_data),
            columns=['skeleton_id', 'primary root length', 'total lateral root length'],
        )
        root_lengths_df.to_csv(f'processed_images/{filename}_root_lengths.csv', index=False)

        # Place the masks at their position inside full-size single-channel canvases
        root_mask_padded = np.zeros_like(annotated_image[:, :, 0])
        shoot_mask_padded = np.zeros_like(annotated_image[:, :, 0])
        _paste_mask(root_mask_padded, filtered_root_mask, offset_y, offset_x)
        _paste_mask(shoot_mask_padded, shoot_mask, offset_y, offset_x)

        # Create colored overlays (only pixels equal to 255 are colored)
        yellow_overlay = np.zeros_like(annotated_image)
        green_overlay = np.zeros_like(annotated_image)
        yellow_overlay[root_mask_padded == 255] = [0, 255, 255]  # BGR: Yellow for root mask
        green_overlay[shoot_mask_padded == 255] = [0, 255, 0]    # BGR: Green for shoot mask

        # Combine overlays with the annotated image
        combined_image = cv2.addWeighted(annotated_image, 1, yellow_overlay, 0.25, 0)
        combined_image = cv2.addWeighted(combined_image, 1, green_overlay, 0.25, 0)

        # Plotly expects RGB, OpenCV works in BGR
        combined_image_rgb = cv2.cvtColor(combined_image, cv2.COLOR_BGR2RGB)
        untouched_image_rgb = cv2.cvtColor(untouched_image, cv2.COLOR_BGR2RGB)

        fig = _build_interactive_figure(
            combined_image_rgb, untouched_image_rgb, root_lengths_df, filename
        )

        # Save as HTML file. The original file extension stays part of the name
        # (e.g. 'image.tif_interactive.html').
        html_filename = f'{filename}_interactive.html'
        fig.write_html(html_filename)
        print(f'Plot saved as {html_filename}')

## Example Usage

In [11]:
# Inputs: the folder with the images to process and the reference image used for cropping
uncropped_folder_path = 'measurement_set'
reference_image_path = 'utils/reference_image.png'

# Both saved models reference the custom f1 and iou metrics, so Keras needs
# them passed in to load the models
custom_metrics = {'f1': f1, 'iou': iou}
root_model = load_model('utils/root_1.h5', custom_objects=custom_metrics)
shoot_model = load_model('utils/shoot_1.h5', custom_objects=custom_metrics)

# Run the full analysis; one interactive HTML plot is saved per image
interactive_arabidopsis(
    uncropped_folder_path=uncropped_folder_path,
    root_model=root_model,
    shoot_model=shoot_model,
    reference_image_path=reference_image_path,
)

Plot saved as measurement_image_1.tif_interactive.html
Plot saved as measurement_image_2.tif_interactive.html
