In [None]:
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize, remove_small_objects
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model
from ot2_gym_wrapper import OT2Env  # Assuming this is where the OT2Env is defined
from scipy.spatial.distance import euclidean
from skimage.graph import route_through_array
import pandas as pd
import re
from stable_baselines3 import PPO

# ------------------------------------
# OLD PIPELINE FUNCTIONS
# ------------------------------------

def f1_score(y_true, y_pred):
    def recall_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = TP / (Positives + K.epsilon())
        return recall
    
    def precision_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Pred_Positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = TP / (Pred_Positives + K.epsilon())
        return precision
    
    precision_val, recall_val = precision_m(y_true, y_pred), recall_m(y_true, y_pred)
    
    return 2 * ((precision_val * recall_val) / (precision_val + recall_val + K.epsilon()))


def padder(image, divisor, padding_value=(0, 0, 0)):
    """
    Applies balanced padding to an image to make its dimensions divisible by a given divisor.
    """
    original_height, original_width = image.shape[:2]
    pad_height = (divisor - (original_height % divisor)) % divisor
    pad_width = (divisor - (original_width % divisor)) % divisor

    top_pad = pad_height // 2
    bottom_pad = pad_height - top_pad
    left_pad = pad_width // 2
    right_pad = pad_width - left_pad

    padded_image = cv2.copyMakeBorder(
        image,
        top_pad, bottom_pad,
        left_pad, right_pad,
        cv2.BORDER_CONSTANT,
        value=padding_value
    )

    return padded_image


def reduce_noise(image):
    """
    Reduce noise in the input grayscale image using Gaussian blur and morphological operations.
    """
    blurred_image = cv2.GaussianBlur(image, (3, 3), 0)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    opened_image = cv2.morphologyEx(blurred_image, cv2.MORPH_OPEN, kernel)
    closed_image = cv2.morphologyEx(opened_image, cv2.MORPH_CLOSE, kernel)
    return closed_image


def morphological_petri_dish_crop(image):
    """
    Detect and crop the Petri dish using adaptive thresholding and morphological operations.
    Dynamically adjusts the crop size based on the detected dish dimensions.
    """
    thresh = cv2.adaptiveThreshold(
        image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )
    
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    
    contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    largest_contour = max(contours, key=cv2.contourArea)
    
    x, y, w, h = cv2.boundingRect(largest_contour)
    
    size = int(max(w, h) * 0.98)
    cx, cy = x + w // 2, y + h // 2
    
    x1 = max(0, cx - size // 2)
    y1 = max(0, cy - size // 2)
    x2 = min(image.shape[1], cx + size // 2)
    y2 = min(image.shape[0], cy + size // 2)
    
    cropped_img = image[y1:y2, x1:x2].astype(np.float32) / 255.0
    cropped_img = cv2.resize(cropped_img, (size, size), interpolation=cv2.INTER_AREA)
    
    bbox = (x1, y1, x2, y2)
    return cropped_img, bbox


def padder_with_overlap(image, divisor, padding_value=(0, 0, 0)):
    """
    Applies overlapping padding to an image to reduce its dimensions such that
    the dimensions become divisible by the given divisor.
    """
    original_height, original_width = image.shape[:2]
    new_height = original_height - (original_height % divisor)
    new_width = original_width - (original_width % divisor)

    top_overlap = (original_height - new_height) // 2
    bottom_overlap = original_height - new_height - top_overlap
    left_overlap = (original_width - new_width) // 2
    right_overlap = original_width - new_width - left_overlap

    cropped_image = image[top_overlap:original_height-bottom_overlap, left_overlap:original_width-right_overlap]

    padded_image = cv2.copyMakeBorder(
        cropped_image,
        top_overlap, bottom_overlap,
        left_overlap, right_overlap,
        cv2.BORDER_CONSTANT,
        value=padding_value
    )

    return padded_image, (new_height, new_width)


def patch_image(image, patch_size=256, stride=128):
    """Divides an image into overlapping patches."""
    patches = []
    positions = []
    for i in range(0, image.shape[0] - patch_size + 1, stride):
        for j in range(0, image.shape[1] - patch_size + 1, stride):
            patch = image[i:i + patch_size, j:j + patch_size]
            patches.append(patch)
            positions.append((i, j))
    return np.array(patches), positions


def predict_patches(patches, model, batch_size=32):
    """
    Predict a batch of patches using the model for faster inference.
    """
    patches = np.array(patches)
    if patches.ndim == 2:  # Ensure the patches have a channel dimension
        patches = patches[..., np.newaxis]
    predictions = model.predict(patches, batch_size=batch_size)
    return predictions


def unpatch_image(patches, positions, image_shape, patch_size=256):
    """Reconstructs the full image from patches."""
    reconstructed = np.zeros((*image_shape, 1), dtype=np.float32)
    patch_count = np.zeros((*image_shape, 1), dtype=np.float32)

    for patch, (i, j) in zip(patches, positions):
        if patch.ndim == 2:
            patch = patch[..., np.newaxis]
        
        reconstructed[i:i + patch_size, j:j + patch_size, :] += patch
        patch_count[i:i + patch_size, j:j + patch_size, :] += 1

    reconstructed /= np.maximum(patch_count, 1)
    return np.squeeze(reconstructed)


def reverse_padding_and_cropping(reconstructed, original_shape, bbox):
    """
    Reverse the effects of cropping and padding on the mask and restore to original dimensions.
    """
    # Initialize final_mask as a 2D array (grayscale)
    final_mask = np.zeros(original_shape, dtype=reconstructed.dtype)

    x1, y1, x2, y2 = bbox
    final_mask[y1:y2, x1:x2] = reconstructed[:y2 - y1, :x2 - x1]

    return final_mask


# --------------------------------
# POST-PROCESSING FUNCTIONS
# --------------------------------

def process_root_mask(mask, kernel_size=1, iterations=1400, min_area=150):
    """
    Processes the mask to highlight roots, connect disconnected parts, and remove noise.
    """
    mask_normalized = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX)
    mask_8bit = np.uint8(mask_normalized)

    # Apply Otsu's thresholding to create a binary mask
    _, binary_mask = cv2.threshold(mask_8bit, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Morphological kernel
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))

    # Opening -> Dilation -> Closing
    opened_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
    dilated_mask = cv2.dilate(opened_mask, kernel, iterations=iterations)
    closed_mask = cv2.morphologyEx(dilated_mask, cv2.MORPH_CLOSE, kernel)

    # Connected component filtering
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(closed_mask)
    filtered_mask = np.zeros_like(closed_mask)
    for i in range(1, num_labels):  # skip background
        if stats[i, cv2.CC_STAT_AREA] >= min_area:
            filtered_mask[labels == i] = 255

    return filtered_mask


def skeletonize_mask_skimage(processed_mask, min_size):
    """
    Skeletonizes the given binary mask and removes small objects.
    """
    binary_mask = np.array(processed_mask > 0, dtype=bool)
    cleaned_mask = remove_small_objects(binary_mask, min_size=min_size)
    skeleton = skeletonize(cleaned_mask)
    skeleton_uint8 = np.uint8(skeleton) * 255
    return skeleton_uint8


def create_overlay(image, mask, alpha=0.5):
    """
    Creates an overlay of the predicted mask on the original image.
    """
    image = cv2.normalize(image, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    mask = cv2.normalize(mask, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    
    image_rgb = np.stack([image]*3, axis=-1)  # Convert grayscale to RGB
    mask_rgb = np.zeros_like(image_rgb)
    mask_rgb[..., 0] = mask
    
    overlay_image = cv2.addWeighted(image_rgb, 1 - alpha, mask_rgb, alpha, 0)
    return overlay_image


def find_endpoints(skeleton):
    """
    Detect endpoints in a skeletonized image.
    Endpoints are pixels with only one neighbor.
    """
    skeleton_coords = np.column_stack(np.where(skeleton > 0))
    endpoints = []
    for coord in skeleton_coords:
        x, y = coord
        neighborhood = skeleton[max(0, x - 1):x + 2, max(0, y - 1):y + 2]
        if np.sum(neighborhood) == 2:
            endpoints.append((x, y))
    return endpoints


def measure_root_from_component(component_image, label_id):
    """
    Measure the root length and locate its endpoints for a single segmented component.
    Ensures the 'start_point' is the topmost endpoint, and the 'tip' is the farthest from it
    (i.e. the bottom one for vertical roots).
    """
    root_mask = (component_image == label_id).astype(np.uint8)
    skeleton = skeletonize(root_mask > 0)

    endpoints = find_endpoints(skeleton)
    if len(endpoints) < 2:
        raise ValueError(f"Not enough endpoints detected for label {label_id}.")

    # Pick topmost endpoint as 'start_point'
    start_point = min(endpoints, key=lambda p: p[0])  # row=0 => "topmost"
    # Pick the farthest endpoint from 'start_point' as 'tip'
    tip = max(endpoints, key=lambda p: euclidean(start_point, p))

    # Measure root length (Euclidean distance from 'start_point' to all skeleton coords)
    skeleton_coords = np.column_stack(np.where(skeleton > 0))
    distances = [euclidean(start_point, coord) for coord in skeleton_coords]
    length = max(distances)

    return length, start_point, tip, skeleton



def is_moderately_vertical(skeleton_coords, max_horizontal_to_vertical_ratio=0.5):
    """
    Allow roots that are mostly vertical but exclude highly diagonal roots.
    """
    if len(skeleton_coords) < 2:
        return False

    y_coords = skeleton_coords[:, 0]
    x_coords = skeleton_coords[:, 1]

    total_vertical_change = np.ptp(y_coords)
    total_horizontal_change = np.ptp(x_coords)

    ratio = total_horizontal_change / (total_vertical_change + 1e-6)
    return ratio <= max_horizontal_to_vertical_ratio


def isolate_and_measure_roots_by_plant(
    image_path, 
    min_area=80, 
    max_horizontal_to_vertical_ratio=0.5, 
    min_length=10, 
    dish_bbox=None
):
    """
    Isolate and measure the primary roots of 5 plants aligned horizontally within a Petri dish.
    Assigns roots into 5 vertical bins and assigns empty bins a measurement of 0.
    """

    if dish_bbox is None or len(dish_bbox) != 4:
        raise ValueError("A valid dish_bbox (dish_x1, dish_y1, dish_x2, dish_y2) must be provided.")

    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"Could not read the image: {image_path}")

    retval, labels, stats, centroids = cv2.connectedComponentsWithStats(gray)

    results_by_plant = []
    for plant_label in range(1, retval):  # Ignore background
        plant_area = stats[plant_label, cv2.CC_STAT_AREA]
        if plant_area >= min_area:
            x, y, w, h = (
                stats[plant_label, cv2.CC_STAT_LEFT],
                stats[plant_label, cv2.CC_STAT_TOP],
                stats[plant_label, cv2.CC_STAT_WIDTH],
                stats[plant_label, cv2.CC_STAT_HEIGHT]
            )

            plant_mask = (labels == plant_label).astype(np.uint8)
            plant_image = plant_mask[y:y+h, x:x+w]

            retval_roots, root_labels, root_stats, _ = cv2.connectedComponentsWithStats(plant_image)

            plant_roots = []
            for root_label in range(1, retval_roots):  # Ignore background
                root_area = root_stats[root_label, cv2.CC_STAT_AREA]
                if root_area >= min_area:
                    try:
                        length, start, tip, skeleton = measure_root_from_component(root_labels, root_label)
                        skeleton_coords = np.column_stack(np.where(skeleton > 0))
                        if length >= min_length and is_moderately_vertical(skeleton_coords, max_horizontal_to_vertical_ratio):
                            # bounding box in global coords
                            root_x_rel = root_stats[root_label, cv2.CC_STAT_LEFT]
                            root_y_rel = root_stats[root_label, cv2.CC_STAT_TOP]
                            root_w = root_stats[root_label, cv2.CC_STAT_WIDTH]
                            root_h = root_stats[root_label, cv2.CC_STAT_HEIGHT]

                            root_x = x + root_x_rel
                            root_y = y + root_y_rel

                            # convert local coords to global
                            global_start = (start[0] + y, start[1] + x)
                            global_tip   = (tip[0] + y,   tip[1] + x)

                            plant_roots.append({
                                "root_label": root_label,
                                "length": length,
                                "start": global_start,  # (row, col)
                                "tip": global_tip,      # (row, col)
                                "skeleton": skeleton,
                                "bounding_box": (root_x, root_y, root_w, root_h)
                            })
                    except ValueError:
                        # Not enough endpoints, skip
                        pass

            results_by_plant.append({
                "plant_label": plant_label,
                "plant_area": plant_area,
                "roots": plant_roots,
                "bounding_box": (x, y, w, h)
            })

    # Flatten all roots
    all_roots = [root for plant in results_by_plant for root in plant["roots"]]

    # 5 vertical bins across dish
    dish_x1, dish_y1, dish_x2, dish_y2 = dish_bbox
    dish_width = dish_x2 - dish_x1
    segment_width = dish_width / 5.0
    plant_bins = []
    for i in range(5):
        left_bound = dish_x1 + int(round(i * segment_width))
        right_bound = dish_x1 + int(round((i + 1) * segment_width))
        plant_bins.append((left_bound, right_bound))

    # Sort by x-coord of start point (global_start is (row, col), so col=1)
    all_roots = sorted(all_roots, key=lambda root: root["start"][1])

    # Assign roots to bins
    final_results = [{"plant_id": i+1, "length": 0.0, "roots": []} for i in range(5)]
    for root in all_roots:
        start_x = root["start"][1]
        for i, (left_bound, right_bound) in enumerate(plant_bins):
            if left_bound <= start_x < right_bound:
                final_results[i]["roots"].append(root)
                if root["length"] > final_results[i]["length"]:
                    final_results[i]["length"] = root["length"]
                break

    # Rename roots within each bin
    for plant_result in final_results:
        sorted_roots = sorted(plant_result["roots"], key=lambda r: r["length"], reverse=True)
        for idx, root in enumerate(sorted_roots):
            root["root_id"] = f"Root {plant_result['plant_id']}-{idx+1}"

    # Empty bins => length=0.0
    for i, plant_result in enumerate(final_results):
        if not plant_result["roots"]:
            plant_result["length"] = 0.0

    return final_results


def display_and_save_roots_by_plant(results_by_plant, output_directory, image_basename):
    """
    Display and save skeleton overlays for each plant and its roots.
    The output images are named using the original image base name followed by _root_<index>.png
    """
    os.makedirs(output_directory, exist_ok=True)

    for plant in results_by_plant:
        for root in plant["roots"]:
            root_id = root.get("root_id", "unknown")
            root_filename = f"{image_basename}_{root_id}.png"

            x, y, w, h = root["bounding_box"]
            skeleton_mask = root['skeleton']

            # Create a blank RGB canvas
            blank_canvas = np.zeros((h, w, 3), dtype=np.uint8)
            # Overlay the skeleton in red channel
            # note that 'skeleton_mask' is local: top-left corner = (0,0) in that bounding box
            skeleton_overlay = (skeleton_mask * 255).astype(np.uint8)
            blank_canvas[:skeleton_mask.shape[0], :skeleton_mask.shape[1], 2] = skeleton_overlay

            output_path = os.path.join(output_directory, root_filename)
            cv2.imwrite(output_path, blank_canvas)

def extract_root_coordinates(final_results):
    """
    Converts 'results_by_plant' into a dictionary of root coordinates
    that can be used by an RL pipeline or elsewhere.

    final_results is a list of 5 dicts (bins), each with:
       {
         "plant_id": int,
         "length": float,         # longest root in that bin
         "roots": [
             {
               "root_id": str,    # e.g., "Root 1-1"
               "length": float,   # length of this root
               "start": (int, int),  # (row, col)
               "tip": (int, int),    # (row, col)
               ...
             },
             ...
         ]
       }
    Returns a dictionary, e.g.:
       {
         "plant_1": [
            {
              "root_id": "Root 1-1",
              "start": (row, col),
              "tip": (row, col),
              "length": float
            },
            ...
         ],
         "plant_2": [...],
         ...
       }
    """
    coords_dict = {}

    for plant_data in final_results:
        plant_id = plant_data["plant_id"]  # e.g. 1, 2, 3...
        plant_key = f"plant_{plant_id}"
        coords_dict[plant_key] = []
        
        for root in plant_data["roots"]:
            # Grab relevant info
            root_id = root.get("root_id", "unknown_root")
            start_pt = root["start"]  # (row, col)
            tip_pt   = root["tip"]    # (row, col)
            length   = root["length"]

            coords_dict[plant_key].append({
                "root_id": root_id,
                "start": start_pt,
                "tip": tip_pt,
                "length": length
            })

    return coords_dict
# ----------------------------------------------------------
# NEW PIPELINE REQUIREMENT: PLOTTING TIPS ON ORIGINAL IMAGE
# ----------------------------------------------------------
def plot_tips_on_original_image(original_image_path, final_results, output_path):
    """
    Plots the tip of each root onto the original Petri dish image so we can see if the coordinates
    (in global space) are accurate.
    - final_results: the data structure returned by isolate_and_measure_roots_by_plant().
      final_results is a list of 5 dicts, each with "roots" containing global tip coords.
    """
    # Read original image in color
    original_img = cv2.imread(original_image_path, cv2.IMREAD_COLOR)
    if original_img is None:
        raise FileNotFoundError(f"Could not load original image: {original_image_path}")

    # Loop over each bin (plant_id) and each root
    for plant_data in final_results:
        for root in plant_data["roots"]:
            tip = root["tip"]  # (row, col)
            tip_x = tip[1]
            tip_y = tip[0]

            # Draw a small circle at the tip
            cv2.circle(original_img, (tip_x, tip_y), radius=6, color=(0, 0, 255), thickness=-1)

            # Optionally, label it with the root_id
            root_id = root.get("root_id", "UnknownRoot")
            cv2.putText(original_img, root_id, (tip_x+8, tip_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    # Save the result
    cv2.imwrite(output_path, original_img)


# ---------------------------------------
# MAIN EXECUTION WITH BOTH STEPS
# ---------------------------------------
def main_pipeline_example():
    # 1) Load model
    model_path = "final_modified_model.h5"
    model = load_model(model_path, custom_objects={"f1_score": f1_score})

    # 2) Initialize the OT2 environment
    env = OT2Env(render=True)
    image_path = env.get_plate_image()
    if not image_path:
        raise ValueError("Failed to retrieve plate image path from the environment.")

    # 3) Load the grayscale image
    original_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    # 4) Morphological crop
    cropped_img, petri_dish_bbox = morphological_petri_dish_crop(original_image)

    # 5) Pad
    patching_size = 256
    padded_image, _ = padder_with_overlap(cropped_img, patching_size)

    # 6) Patch
    stride = 256
    patched, positions = patch_image(padded_image, patch_size=patching_size, stride=stride)

    # 7) Predict
    predicted = predict_patches(patched, model)

    # 8) Unpatch
    unpatched = unpatch_image(predicted, positions, padded_image.shape, patch_size=patching_size)

    # 9) Reverse crop/padding
    final_mask = reverse_padding_and_cropping(unpatched, original_image.shape, petri_dish_bbox)

    # 10) Post-process
    processed_mask = process_root_mask(final_mask)
    skeletonized_mask = skeletonize_mask_skimage(processed_mask, min_size=150)

    # 11) Create overlay for quick visualization
    overlay = create_overlay(original_image, skeletonized_mask, alpha=0.5)

    # 12) Save final mask and overlay
    output_dir = "Final_Masks_Iteration_10"
    os.makedirs(output_dir, exist_ok=True)
    base_name = "plate_image"

    final_mask_path = os.path.join(output_dir, f"{base_name}_mask.png")
    overlay_path = os.path.join(output_dir, f"{base_name}_overlay.png")

    cv2.imwrite(final_mask_path, (final_mask * 255).astype(np.uint8))
    overlay_bgr = cv2.cvtColor((overlay * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(overlay_path, overlay_bgr)

    print(f"Saved final mask to {final_mask_path}")
    print(f"Saved overlay to {overlay_path}")

    # 13) Run isolate & measure on that final mask
    results_by_plant = isolate_and_measure_roots_by_plant(
        final_mask_path,   # path to the binary mask
        min_area=500,
        max_horizontal_to_vertical_ratio=0.5,
        min_length=10,
        dish_bbox=petri_dish_bbox
    )

    # 14) Save skeleton overlays (debug)
    skeleton_output_dir = "Skeleton_Iteration_10"
    display_and_save_roots_by_plant(results_by_plant, skeleton_output_dir, base_name)

    # 15) Plot tips on the *original* image for a final coordinate check
    final_tips_plot_path = os.path.join(skeleton_output_dir, f"{base_name}_tips_plotted.png")
    plot_tips_on_original_image(
        original_image_path=image_path,
        final_results=results_by_plant,
        output_path=final_tips_plot_path
    )
    print(f"Plotted root endpoints on original image: {final_tips_plot_path}")

    # 16) Convert the measurement results into a dictionary for RL pipeline
    coords_dict = extract_root_coordinates(results_by_plant)
    print("\n--- Coordinates Dictionary ---")
    for plant_key, roots_list in coords_dict.items():
        print(f"{plant_key}:")
        for root_info in roots_list:
            print(f"  {root_info}")
    
    # Now you can pass 'coords_dict' to your RL pipeline.
    # e.g. env.step(coords_dict) or however you incorporate it.
    return coords_dict


if __name__ == "__main__":
    coordinates_dict = main_pipeline_example()
    print(coordinates_dict)


In [1]:
import os
import cv2
import numpy as np
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
import tensorflow as tf
from skimage.morphology import skeletonize
import pandas as pd
from skimage.morphology import skeletonize
from skimage.measure import regionprops
import numpy as np

# ======================================
#           Custom F1 Score Metric
# ======================================
def f1_score(y_true, y_pred):
    y_pred_bin = tf.round(tf.clip_by_value(y_pred, 0, 1))
    tp = K.sum(tf.round(y_true * y_pred_bin))
    fp = K.sum(tf.round((1 - y_true) * y_pred_bin))
    fn = K.sum(tf.round(y_true * (1 - y_pred_bin)))
    precision = tp / (tp + fp + K.epsilon())
    recall = tp / (tp + fn + K.epsilon())
    f1 = 2 * precision * recall / (precision + recall + K.epsilon())
    return f1


# ======================================
#       Image Preprocessing
# ======================================
def combine_edges_and_threshold(gray_img):
    """
    Combine Canny edges and a binary threshold for better contour detection.
    """
    gray_img_uint8 = (gray_img * 255).astype(np.uint8)
    detected_edges = cv2.Canny(gray_img_uint8, 50, 150)
    _, thresholded_img = cv2.threshold(gray_img_uint8, 100, 255, cv2.THRESH_BINARY)
    overlay = cv2.addWeighted(detected_edges, 0.6, thresholded_img, 0.4, 0)
    return overlay


def locate_main_contour(img_shape, edge_img):
    """
    Find the bounding box of the largest external contour.
    """
    contours, _ = cv2.findContours(edge_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return 0, 0, img_shape[1], img_shape[0]
    big_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(big_contour)
    return x, y, w, h


def square_crop(gray_img, bounding_box, buffer_percent=0.04):
    """
    Crop around the main contour to produce a roughly square region.
    """
    x, y, width, height = bounding_box
    side_len = max(width, height)
    center_x = x + width // 2
    center_y = y + height // 2
    pad = int(side_len * buffer_percent)
    top_edge = max(0, center_y - side_len // 2 - pad)
    bottom_edge = center_y + side_len // 2 + pad
    left_edge = max(0, center_x - side_len // 2 - pad)
    right_edge = center_x + side_len // 2 + pad
    return gray_img[top_edge:bottom_edge, left_edge:right_edge]



# ======================================
#       Patch-based Segmentation
# ======================================
def segment_image_in_patches(segm_model, gray_img, patch_size=256):
    """
    Perform patch-based segmentation and stitch the result into a full mask.
    """
    border_size = 1
    # Pad the image to handle edges
    padded = cv2.copyMakeBorder(
        gray_img, border_size, border_size, border_size, border_size,
        cv2.BORDER_CONSTANT, value=0
    )

    patch_list = []
    coords_list = []
    h, w = padded.shape

    # Extract patches
    for row_start in range(0, h, patch_size):
        for col_start in range(0, w, patch_size):
            patch = padded[row_start:row_start + patch_size, col_start:col_start + patch_size]
            if patch.shape[0] < patch_size or patch.shape[1] < patch_size:
                patch_full = np.zeros((patch_size, patch_size), dtype=np.float32)
                patch_full[:patch.shape[0], :patch.shape[1]] = patch
                patch = patch_full

            patch_list.append(patch)
            coords_list.append((row_start, col_start))

    # Model inference
    batch_input = np.array(patch_list)[..., np.newaxis]
    batch_output = segm_model.predict(batch_input)

    # Reassemble
    complete_mask = np.zeros_like(padded, dtype=np.float32)
    for idx, (r, c) in enumerate(coords_list):
        piece_height, piece_width = padded[r:r + patch_size, c:c + patch_size].shape
        full_prediction = batch_output[idx].squeeze()[:piece_height, :piece_width]
        complete_mask[r:r + piece_height, c:c + piece_width] = full_prediction

    # Remove padding
    return complete_mask[border_size:-border_size, border_size:-border_size]


# ====================================================
#   Post-processing of Predictions-Regions of Interest
# ====================================================
def refine_mask_for_shoots(predicted_mask, min_pixels=1):
    """
    Refine the mask for shoot predictions using morphological closing and connected component filtering.
    """
    bin_mask = (predicted_mask > 0.5).astype(np.uint8)
    shape_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    closed_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, shape_kernel)

    label_img = label(closed_mask)
    refined = np.zeros_like(closed_mask, dtype=np.uint8)
    for props in regionprops(label_img):
        if props.area >= min_pixels:
            refined[label_img == props.label] = 1
    return refined


def refine_mask_for_roots(predicted_mask, min_pixels=1):
    """
    Refine the mask for root predictions using additional morphological operations to improve connectivity.
    """
    # Binarize the mask
    bin_mask = (predicted_mask > 0.5).astype(np.uint8)

    # Step 1: Dilate to connect nearby segments
    dilation_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (14, 14)) 
    dilated_mask = cv2.dilate(bin_mask, dilation_kernel, iterations=1)

    # Step 2: Morphological closing to fill small gaps
    shape_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
    closed_mask = cv2.morphologyEx(dilated_mask, cv2.MORPH_CLOSE, shape_kernel)

    # Step 3: Retain only connected components with a minimum pixel size
    label_img = label(closed_mask)
    refined = np.zeros_like(closed_mask, dtype=np.uint8)
    for props in regionprops(label_img):
        if props.area >= min_pixels:
            refined[label_img == props.label] = 1

    return refined


def extract_bounding_boxes_for_shoots(mask, min_area=1):
    """
    Extract bounding boxes for shoot regions in a binary mask.
    """
    labeled_mask = label(mask)
    bboxes = []
    for region in regionprops(labeled_mask):
        if region.area >= min_area:
            minr, minc, maxr, maxc = region.bbox
            bboxes.append((minc, minr, maxc - minc, maxr - minr))  # (x, y, w, h)
    return bboxes


def extract_bounding_boxes_for_roots(mask, min_area=1):
    """
    Extract bounding boxes for root regions in a binary mask.
    """
    labeled_mask = label(mask)
    bboxes = []
    for region in regionprops(labeled_mask):
        if region.area >= min_area:
            minr, minc, maxr, maxc = region.bbox
            bboxes.append((minc, minr, maxc - minc, maxr - minr))  # (x, y, w, h)
    return bboxes


def normalize_image(image):
    """
    Normalize the image to the range [0, 1].
    """
    return image / 255.0


def apply_shoot_roi_mask(mask, roi):
    """
    Apply a mask to keep only the region of interest (ROI) for shoot predictions.
    """
    x, y, w, h = roi
    roi_mask = np.zeros_like(mask, dtype=np.uint8)
    roi_mask[y:y + h, x:x + w] = 1
    return mask * roi_mask


def calculate_roi_with_margin(image, roi_height_fraction=0.25, top_margin_fraction=0.1, 
                              inner_margin_fraction=0.09, outer_margin_fraction=0.32):
    img_height, img_width = image.shape

    # Detect transitions (boundaries) dynamically
    left_boundary = next(
        (i for i in range(img_width // 2) if image[:, i].mean() < 200), 10
    )
    right_boundary = next(
        (i for i in range(img_width - 1, img_width // 2, -1) if image[:, i].mean() < 200),
        img_width - 10,
    )

    # Expand the left boundary dynamically
    inner_margin = int(img_width * inner_margin_fraction)
    x = max(left_boundary - 20 + inner_margin, 0)  # Adjust left boundary
    
    # Adjust ROI width
    outer_margin = int(img_width * outer_margin_fraction)
    width = max(right_boundary - x - outer_margin, 0)

    # Define top and height of the ROI
    y = int(img_height * top_margin_fraction)
    height = int(img_height * roi_height_fraction)

    return (x, y, width, height)



# ======================================
#    Zone-Based Filtering & RSA
# ======================================
def get_zone_from_bbox(bbox, img_width):
    """
    Determine which zone a bounding box belongs to (0 to 4).
    """
    x, y, w, h = bbox
    zone_width = img_width // 5  # 5 zones
    center_x = x + w / 2.0

    for zone in range(5):
        if center_x < (zone + 1) * zone_width:
            return zone
    return 4


def filter_root_bboxes_in_zone(root_bboxes, zone, img_width):
    """
    Filter root bounding boxes that belong to a specific zone (by center_x).
    """
    zone_bboxes = []
    zone_w = img_width // 5
    zone_start_x = zone * zone_w
    zone_end_x = zone_start_x + zone_w

    for bbox in root_bboxes:
        x, y, w, h = bbox
        center_x = x + w / 2.0
        if zone_start_x <= center_x < zone_end_x:
            zone_bboxes.append(bbox)
    return zone_bboxes

def filter_primary_root_in_each_zone(root_bboxes, img_width):
    primary_roots_by_zone = {}
    for zone in range(5):
        zone_bboxes = filter_root_bboxes_in_zone(root_bboxes, zone, img_width)
        if zone_bboxes:
            # Add a tolerance to include partially overlapping boxes
            largest_bbox = max(zone_bboxes, key=lambda bbox: bbox[2] * bbox[3])
            if zone == 0:
                # Include small left boxes
                largest_bbox = max(zone_bboxes, key=lambda bbox: bbox[2] * bbox[3] + bbox[0])
            primary_roots_by_zone[zone] = largest_bbox
    return primary_roots_by_zone

def divide_bboxes_into_zones(bboxes, img_width):
    zones = {i: [] for i in range(5)}
    zone_width = img_width / 5
    zone_boundaries = [(i * zone_width, (i + 1) * zone_width) for i in range(5)]

    for bbox in bboxes:
        x, y, w, h = bbox
        center_x = x + w / 2
        for i, (zone_start, zone_end) in enumerate(zone_boundaries):
            if zone_start - 10 <= center_x < zone_end + 10:  # Add tolerance
                zones[i].append(bbox)
                break

    return zones

def filter_largest_shoot_bbox(bboxes, shoot_roi):
    """
    Retain only the largest shoot bounding box within the shoot ROI.
    """
    roi_x, roi_y, roi_w, roi_h = shoot_roi
    filtered_bboxes = []

    for bbox in bboxes:
        x, y, w, h = bbox
        # Check intersection with the ROI
        if x + w > roi_x and x < roi_x + roi_w and y + h > roi_y and y < roi_y + roi_h:
            filtered_bboxes.append(bbox)

    if not filtered_bboxes:
        return None

    largest_bbox = max(filtered_bboxes, key=lambda bx: bx[2] * bx[3])
    return largest_bbox


def filter_root_bboxes_in_zones_with_priority(bboxes, zones, min_area_ratio=0.1):
    """
    Retain the largest root bounding box in each zone and include smaller ones if above area threshold.
    """
    filtered_bboxes = []

    for zone, zone_bboxes in zones.items():
        if zone_bboxes:
            largest_bbox = max(zone_bboxes, key=lambda bbox: bbox[2] * bbox[3])
            largest_area = largest_bbox[2] * largest_bbox[3]
            filtered_bboxes.append(largest_bbox)
            for bbox in zone_bboxes:
                if bbox != largest_bbox:
                    area = bbox[2] * bbox[3]
                    if area >= largest_area * min_area_ratio:
                        filtered_bboxes.append(bbox)

    return filtered_bboxes


def filter_primary_roots_by_region(filtered_root_bboxes, root_zones):
    """
    Pick the largest bounding box in each zone as the 'primary root'.
    """
    primary_roots_by_region = {}
    for zone in range(5):
        if zone in root_zones and root_zones[zone]:
            largest_bbox = max(root_zones[zone], key=lambda bbox: bbox[2] * bbox[3])
            primary_roots_by_region[zone] = largest_bbox
            print(f"Zone {zone}: Selected Primary Root: {largest_bbox}")
        else:
            primary_roots_by_region[zone] = None
            print(f"Zone {zone}: No roots found.")
    return primary_roots_by_region


def calculate_endpoint_distance_with_tip(skeleton):
    """
    Calculate the distance between the first and last endpoints of a skeleton
    and return the coordinates of these endpoints.
    """
    # Find endpoints: pixels with exactly one neighbor
    endpoint_coords = []
    for y in range(1, skeleton.shape[0] - 1):
        for x in range(1, skeleton.shape[1] - 1):
            if skeleton[y, x]:  # Check if the pixel is part of the skeleton
                neighbors = np.sum(skeleton[y - 1:y + 2, x - 1:x + 2]) - 1
                if neighbors == 1:  # Endpoint has only one neighbor
                    endpoint_coords.append((y, x))

    # If there are at least two endpoints, calculate distance
    if len(endpoint_coords) >= 2:
        p1, p2 = endpoint_coords[0], endpoint_coords[-1]
        distance = np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
        return int(distance), p2  # Return distance and the tip coordinate (last endpoint)

    # If no valid endpoints, return 0 and None
    return 0, None

def extract_region_masks_with_tips(primary_roots_by_region, root_mask, cropped_img):
    """
    Extract region masks for the primary roots, calculate their skeletonized lengths,
    and return their tip coordinates.
    """
    skeletonized_roots_by_region = {}
    root_lengths_by_zone = {}
    root_tips_by_zone = {}

    for zone in range(5):
        if zone in primary_roots_by_region and primary_roots_by_region[zone] is not None:
            x, y, w, h = primary_roots_by_region[zone]
            region_mask = root_mask[y:y + h, x:x + w]
            region_cropped_img = cropped_img[y:y + h, x:x + w]

            # Skeletonize
            skeleton = skeletonize(region_mask)
            skeletonized_roots_by_region[zone] = (skeleton, region_cropped_img)

            # Calculate the endpoint distance and the tip coordinate
            endpoint_distance, tip_coord = calculate_endpoint_distance_with_tip(skeleton)
            root_lengths_by_zone[zone] = endpoint_distance

            # Convert tip coordinates to global image coordinates
            if tip_coord:
                global_tip_coord = (tip_coord[1] + x, tip_coord[0] + y)  # (x, y) in global coordinates
                root_tips_by_zone[zone] = global_tip_coord

        else:
            skeletonized_roots_by_region[zone] = None
            root_lengths_by_zone[zone] = 0
            root_tips_by_zone[zone] = None

    return skeletonized_roots_by_region, root_lengths_by_zone, root_tips_by_zone

def visualize_predictions_with_tips(image, shoots, roots, zones, roi, root_tips, title):
    """
    Visualize bounding boxes for shoots/roots plus vertical zone divisions, ROI rectangle,
    and plot primary root tips.
    """
    image_uint8 = (image * 255).astype(np.uint8)
    overlay = cv2.cvtColor(image_uint8, cv2.COLOR_GRAY2BGR)

    img_height, img_width = image.shape
    zone_width = img_width // 5

    # Draw zone lines
    for i in range(1, 5):
        cv2.line(overlay, (i * zone_width, 0), (i * zone_width, img_height), (255, 255, 0), 2)

    # Shoots in green
    for bbox in shoots:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Roots in red
    for bbox in roots:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 0, 255), 2)

    # ROI in blue
    roi_x, roi_y, roi_w, roi_h = roi
    cv2.rectangle(overlay, (roi_x, roi_y), (roi_x + roi_w, roi_y + roi_h), (255, 0, 0), 2)

    # Annotate zones
    for zone in range(5):
        zone_bboxes = zones.get(zone, [])
        for bbox in zone_bboxes:
            x, y, w, h = bbox
            cx = x + w // 2
            cy = y + h // 2
            cv2.putText(overlay, f"Z{zone}", (int(cx), int(cy)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)

    # Plot root tips in magenta
    for zone, tip in root_tips.items():
        if tip:
            cv2.circle(overlay, tip, 5, (255, 0, 255), -1)

    plt.figure(figsize=(12, 12))
    plt.imshow(overlay[..., ::-1])  # Convert BGR to RGB
    plt.title(title)
    plt.axis("off")
    plt.show()

# ======================================
#   Visualization Functions
# ======================================
def save_and_visualize_predicted_mask(mask, output_path=None, title="Predicted Mask"):
    """
    Save and visualize the predicted mask.
    Args:
        mask (numpy array): The predicted mask to be visualized.
        output_path (str, optional): Path to save the mask. If None, it won't save.
        title (str): Title for the visualization.
    """
    plt.figure(figsize=(8, 8))
    plt.imshow(mask, cmap="gray")
    plt.title(title)
    plt.axis("off")
    plt.show()

def visualize_predictions_with_roi(image, shoots, roots, roi, title):
    """
    Visualize predictions with bounding boxes and an ROI rectangle.
    Shoot boxes in green, root boxes in red, and ROI in blue.
    """
    # Convert the input image to uint8
    image_uint8 = (image * 255).astype(np.uint8)

    # Convert grayscale to BGR
    overlay = cv2.cvtColor(image_uint8, cv2.COLOR_GRAY2BGR)

    # Draw shoot bounding boxes in green
    for bbox in shoots:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Draw root bounding boxes in red
    for bbox in roots:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 0, 255), 2)

    # Draw the ROI rectangle in blue
    roi_x, roi_y, roi_w, roi_h = roi
    cv2.rectangle(overlay, (roi_x, roi_y), (roi_x + roi_w, roi_y + roi_h), (255, 0, 0), 2)

    # Display the image
    plt.figure(figsize=(10, 10))
    plt.imshow(overlay[..., ::-1])  # Convert BGR to RGB for display
    plt.title(title)
    plt.axis('off')
    plt.show()


def visualize_skeletonized_roots_by_zone(skeletonized_roots_by_zone, title):
    """
    Visualize the skeletonized primary roots for each zone (0 to 4) in subplots.
    """
    fig, axes = plt.subplots(1, 5, figsize=(20, 5))
    fig.suptitle(title, fontsize=16)

    for zone in range(5):
        ax = axes[zone]
        if zone in skeletonized_roots_by_zone:
            skeleton, region_cropped_img = skeletonized_roots_by_zone[zone]
            overlay = cv2.cvtColor((region_cropped_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
            # Overlay the skeleton in red
            overlay[skeleton > 0] = [255, 0, 0]
            ax.imshow(overlay)
            ax.set_title(f"Zone {zone}")
        else:
            ax.imshow(np.zeros((100, 100)), cmap="gray")
            ax.set_title(f"Zone {zone} (No root)")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

def visualize_predictions_with_zones(image, shoots, roots, zones, roi, title):
    """
    Visualize bounding boxes for shoots/roots plus vertical zone divisions and ROI rectangle.
    """
    image_uint8 = (image * 255).astype(np.uint8)
    overlay = cv2.cvtColor(image_uint8, cv2.COLOR_GRAY2BGR)

    img_height, img_width = image.shape
    zone_width = img_width // 5

    # Draw zone lines
    for i in range(1, 5):
        cv2.line(overlay, (i * zone_width, 0), (i * zone_width, img_height), (255, 255, 0), 2)

    # Shoots in green
    for bbox in shoots:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Roots in red
    for bbox in roots:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 0, 255), 2)

    # ROI in blue
    roi_x, roi_y, roi_w, roi_h = roi
    cv2.rectangle(overlay, (roi_x, roi_y), (roi_x + roi_w, roi_y + roi_h), (255, 0, 0), 2)

    # Annotate zones
    for zone in range(5):
        zone_bboxes = zones.get(zone, [])
        for bbox in zone_bboxes:
            x, y, w, h = bbox
            cx = x + w // 2
            cy = y + h // 2
            cv2.putText(overlay, f"Z{zone}", (int(cx), int(cy)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)

    for zone, bxs in zones.items():
        print(f"Zone {zone}: {len(bxs)} bounding boxes")  # Debug

    plt.figure(figsize=(12, 12))
    plt.imshow(overlay[..., ::-1])
    plt.title(title)
    plt.axis("off")
    plt.show()
def display_skeletonized_root_masks(cropped_img, root_mask, title="Skeletonized Root Mask"):
    """
    Display the skeleton overlay (red) in place of the predicted root mask.
    Left subplot: original cropped image (grayscale).
    Right subplot: overlay of the skeleton on grayscale.
    """
    # Skeletonize the binarized root_mask
    skeleton = skeletonize(root_mask > 0)

    # Prepare overlay
    cropped_img_normalized = (cropped_img * 255).astype(np.uint8)
    overlay = cv2.cvtColor(cropped_img_normalized, cv2.COLOR_GRAY2BGR)
    overlay[skeleton > 0] = [0, 0, 255]  # Red color for skeleton

    # Plot side by side
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.imshow(cropped_img, cmap="gray")
    plt.title("Cropped Grayscale Image")
    plt.axis("off")

    plt.subplot(1, 2, 1)
    plt.imshow(root_mask, cmap="gray")
    plt.title("Cropped Grayscale Image")
    plt.axis("off")
    
    plt.subplot(1, 2, 2)
    plt.imshow(overlay[..., ::-1])  # BGR -> RGB
    plt.title(title)
    plt.axis("off")

    plt.tight_layout()
    plt.show()


def visualize_skeletonized_roots_with_lengths(skeletonized_roots_by_region, root_lengths_by_zone, title):
    """
    Display each zone's cropped region with skeleton overlay, also showing length.
    """
    fig, axes = plt.subplots(1, 5, figsize=(20, 5))
    fig.suptitle(title, fontsize=16)

    for zone in range(5):
        ax = axes[zone]
        if zone in skeletonized_roots_by_region and skeletonized_roots_by_region[zone] is not None:
            skeleton, region_cropped_img = skeletonized_roots_by_region[zone]
            overlay = cv2.cvtColor((region_cropped_img * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
            overlay[skeleton > 0] = [255, 0, 0]  # Red for skeleton

            ax.imshow(overlay)
            root_length = root_lengths_by_zone.get(zone, 0)
            ax.set_title(f"Zone {zone}\nLength: {root_length} pixels")
        else:
            ax.imshow(np.zeros((100, 100)), cmap="gray")
            ax.set_title(f"Zone {zone}\nLength: 0 px")

        ax.axis("off")

    plt.tight_layout()
    plt.show()




# ----------------------Advanced ROI-------------------------
def calculate_zone_boundaries_within_roi(shoot_roi, num_zones, img_width):
    """
    Calculate zone boundaries within the shoot ROI (x splits into num_zones).
    """
    roi_x, roi_y, roi_w, roi_h = shoot_roi
    zone_width = roi_w / num_zones
    zones = [
        (roi_x + i * zone_width, roi_x + (i + 1) * zone_width, roi_y, roi_y + roi_h)
        for i in range(num_zones)
    ]
    return zones


def assign_roots_to_zones_within_roi(root_bboxes, zones):
    assigned_roots = {i: [] for i in range(len(zones))}
    for i, (zone_start_x, zone_end_x, zone_start_y, zone_end_y) in enumerate(zones):
        for bbox in root_bboxes:
            x, y, w, h = bbox
            cx = x + w / 2.0
            if (zone_start_x <= cx < zone_end_x) or (i == 0 and x < zone_start_x + 10):
                assigned_roots[i].append(bbox)
    return assigned_roots



def select_primary_root_with_extensions(root_bboxes, shoot_roi, zones, img_width):
    """
    1. Filter bounding boxes to those that start in the shoot ROI
       (top edge within ROI's vertical extent + some horizontal overlap).
    2. In each zone, pick the bounding box with the largest area.
    3. Expand that bounding box by merging any other bounding boxes 
       that overlap it (based on bounding-box overlap, not center).
    4. Return the final expanded bounding box for each zone.
    """

    def boxes_overlap(boxA, boxB):
        """
        Check if two bounding boxes overlap.
        boxA, boxB: [xmin, ymin, xmax, ymax]
        Returns True if they overlap.
        """
        return not (
            boxB[0] > boxA[2] + 10 or  # boxB is entirely to the right of boxA
            boxB[2] < boxA[0] - 10 or  # boxB is entirely to the left of boxA
            boxB[1] > boxA[3] + 10 or  # boxB is entirely below boxA
            boxB[3] < boxA[1] - 10     # boxB is entirely above boxA
        )

    primary_roots_by_zone = {i: None for i in range(len(zones))}
    roi_x, roi_y, roi_w, roi_h = shoot_roi

    # Step 1: Pick a "primary" root bbox in each zone (i.e., the largest one that starts in the ROI).
    for bbox in root_bboxes:
        x, y, w, h = bbox
        area = w * h

        # Check if the bbox starts in the shoot ROI (top edge within ROI vertical extent)
        within_top = (y >= roi_y) and (y < roi_y + roi_h)
        overlaps_horizontally = ((x + w) > roi_x) and (x < roi_x + roi_w)
        if not (within_top and overlaps_horizontally):
            continue

        # Assign this bbox to the zone based on its center_x
        cx = x + w / 2.0
        for zone_idx, (zsx, zex, zsy, zey) in enumerate(zones):
            if (zsx <= cx < zex) or (zone_idx == 0 and x < zsx + 10):  # Allow for overlap tolerance
                if (
                    primary_roots_by_zone[zone_idx] is None
                    or area > primary_roots_by_zone[zone_idx]['area']
                ):
                    primary_roots_by_zone[zone_idx] = {
                        'bbox': bbox,
                        'area': area
                    }
                break

    # Step 2: Expand each zone's bounding box by merging any overlapping bounding boxes.
    for zone_idx, root_info in primary_roots_by_zone.items():
        if root_info is None:
            continue
        x, y, w, h = root_info['bbox']
        total_bbox = [x, y, x + w, y + h]  # [xmin, ymin, xmax, ymax]

        # Merge all other bounding boxes that overlap the current bounding box
        for bbox in root_bboxes:
            ox, oy, ow, oh = bbox
            other_box = [ox, oy, ox + ow, oy + oh]

            # If they overlap, expand the total_bbox to include the other box
            if boxes_overlap(total_bbox, other_box):
                total_bbox[0] = min(total_bbox[0], other_box[0])
                total_bbox[1] = min(total_bbox[1], other_box[1])
                total_bbox[2] = max(total_bbox[2], other_box[2])
                total_bbox[3] = max(total_bbox[3], other_box[3])

        # Convert [xmin, ymin, xmax, ymax] back to (x, y, w, h)
        final_xmin, final_ymin, final_xmax, final_ymax = total_bbox
        final_w = final_xmax - final_xmin
        final_h = final_ymax - final_ymin
        final_area = final_w * final_h

        # Update the zone's bounding-box info
        root_info['extended_bbox'] = (final_xmin, final_ymin, final_w, final_h)
        root_info['total_area'] = final_area

    return primary_roots_by_zone

# ======================================
#   Main Pipeline with Visualization
# ======================================

def main_pipeline_with_dual_shoot_models_and_tips(root_model1_path, root_model2_path, shoot_model1_path, shoot_model2_path, data_dir, visualize=False):
    # Load models
    root_model1 = load_model(root_model1_path, compile=False)
    root_model2 = load_model(root_model2_path, compile=False)
    shoot_model1 = load_model(shoot_model1_path, compile=False)
    shoot_model2 = load_model(shoot_model2_path, compile=False)

    # Load image (grayscale)
    original_image = cv2.imread(data_dir, cv2.IMREAD_GRAYSCALE)
    image_height, image_width = original_image.shape

    normalized_image = original_image / 255.0

    # Preprocess, find bounding box, and crop
    edges = combine_edges_and_threshold(normalized_image)
    bbox = locate_main_contour(normalized_image.shape, edges)
    cropped_img = square_crop(normalized_image, bbox)

    # Calculate shoot ROI
    shoot_roi = calculate_roi_with_margin(
        image=original_image,
        roi_height_fraction=0.15,
        top_margin_fraction=0.1,
        inner_margin_fraction=0.09,
        outer_margin_fraction=0.35
    )

    # Segment shoots using both models
    shoot_mask1 = segment_image_in_patches(shoot_model1, cropped_img)
    shoot_mask2 = segment_image_in_patches(shoot_model2, cropped_img)

    # Combine shoot predictions 
    combined_shoot_mask = (shoot_mask1 + shoot_mask2)

    # Refine the combined shoot mask
    combined_shoot_mask = apply_shoot_roi_mask(combined_shoot_mask, shoot_roi)
    combined_shoot_mask = refine_mask_for_shoots(combined_shoot_mask, min_pixels=2)
    shoot_bboxes = extract_bounding_boxes_for_shoots(combined_shoot_mask, min_area=2)

    # Segment roots using both models
    root_mask1 = segment_image_in_patches(root_model1, cropped_img)
    root_mask2 = segment_image_in_patches(root_model2, cropped_img)

    # Combine root predictions 
    combined_root_mask = (root_mask1 + root_mask2)

    # Refine the combined root mask
    combined_root_mask = refine_mask_for_roots(combined_root_mask, min_pixels=1)

    # Extract root bounding boxes
    root_bboxes = extract_bounding_boxes_for_roots(combined_root_mask, min_area=1)

    # Filter bboxes to those below the shoot ROI
    filtered_root_bboxes_in_shoot_roi = [
        bbox for bbox in root_bboxes
        if bbox[1] >= shoot_roi[1] and bbox[1] < (shoot_roi[1] + shoot_roi[3])
    ]

    # Divide the region into zones (5) inside the shoot ROI
    zones_within_roi = calculate_zone_boundaries_within_roi(shoot_roi, num_zones=5, img_width=image_width)
    roots_by_zone = assign_roots_to_zones_within_roi(filtered_root_bboxes_in_shoot_roi, zones_within_roi)

    # Pick primary root in each zone with bounding box extension
    extended_roots_by_zone = select_primary_root_with_extensions(
        filtered_root_bboxes_in_shoot_roi, shoot_roi, zones_within_roi, image_width
    )

    # Extract skeleton masks and measure length per zone
    primary_roots_dict = {
        zone: root_info['bbox']
        for zone, root_info in extended_roots_by_zone.items()
        if root_info is not None
    }

    skeletonized_roots_by_zone, root_lengths_by_zone, root_tips_by_zone = extract_region_masks_with_tips(
        primary_roots_dict,
        combined_root_mask,
        cropped_img
    )

    # Optional visualizations
    if visualize:
        visualize_predictions_with_tips(
            cropped_img,
            shoot_bboxes,
            filtered_root_bboxes_in_shoot_roi,
            roots_by_zone,
            shoot_roi,
            root_tips_by_zone,
            title="Filtered Predictions with Tips"
        )

    return {"Root Lengths by Zone": root_lengths_by_zone, "Root Tips by Zone": root_tips_by_zone}

# ======================================
#   Batch Execution Across Directory
# ======================================
def run_analysis_across_directory(root_model1, root_model2, shoot_model1, shoot_model2, folder_path, visualize=False):
    """
    Process all images in the directory and return the analysis results.
    Optionally, visualize the results for each image.
    """
    # Validate folder_path is a directory
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Provided path '{folder_path}' is not a directory.")
    
    gathered_data = []

    for file_name in os.listdir(folder_path):
        if file_name.lower().endswith(('.png', '.jpg', '.jpeg')):  # Accept common image formats
            image_path = os.path.join(folder_path, file_name)
            try:
                # Load the image
                original_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
                if original_image is None:
                    print(f"Warning: Skipping file {file_name} (unable to load image).")
                    continue
                
                # Run the main pipeline with dual shoot models
                results = main_pipeline_with_dual_shoot_models_and_tips(
                    root_model1_path=root_model1,
                    root_model2_path=root_model2,
                    shoot_model1_path=shoot_model1,
                    shoot_model2_path=shoot_model2,
                    data_dir=image_path,
                    visualize=visualize
                )
                # Format results
                for zone, length in results["Root Lengths by Zone"].items():
                    tip = results["Root Tips by Zone"].get(zone, None)
                    plant_id = f"{os.path.splitext(file_name)[0]}_plant_{zone + 1}"
                    gathered_data.append([plant_id, length, tip])
            except Exception as err:
                print(f"Failed to process {file_name}: {err}")

    # Return results as a list
    return gathered_data


# ======================================
#  Running the analysis
# ======================================

if __name__ == "__main__":
    root_model1_path = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_1_Root_Final_256px.h5'
    root_model2_path = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_10_Test_256px.h5'
    shoot_model1_path = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_1_Shoot_Final_256px.h5'
    shoot_model2_path = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_1_Shoot_256px.h5'
    data_dir = 'C:/Users/User/Desktop/2024-25b-fai2-adsai-AlexiKehayias232230/datalab_tasks/task8/Test_Data'

    results = run_analysis_across_directory(
        root_model1=root_model1_path,
        root_model2=root_model2_path,
        shoot_model1=shoot_model1_path,
        shoot_model2=shoot_model2_path,
        folder_path=data_dir,
        visualize=True
    )

    for result in results:
        print(result)


Failed to process test_image_1.png: [Errno 2] Unable to synchronously open file (unable to open file: name = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_1_Root_Final_256px.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)
Failed to process test_image_10.png: [Errno 2] Unable to synchronously open file (unable to open file: name = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_1_Root_Final_256px.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)
Failed to process test_image_11.png: [Errno 2] Unable to synchronously open file (unable to open file: name = 'converted_models/converted_AlexiKehayias_232230_unet_model_Iteration_1_Root_Final_256px.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)
Failed to process test_image_12.png: [Errno 2] Unable to synchronously open file (unable to open file: name = 'converted_models/converted_AlexiKehayi