In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [2]:
def dice_loss(y_true, y_pred):
    intersection = K.sum(y_true * y_pred)
    union = K.sum(y_true) + K.sum(y_pred)
    dice = (2. * intersection + K.epsilon()) / (union + K.epsilon())
    return 1 - dice
def f1(y_true, y_pred):
    def recall_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = TP / (Positives+K.epsilon())
        return recall
    
    def precision_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Pred_Positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = TP / (Pred_Positives+K.epsilon())
        return precision
    
    precision, recall = precision_m(y_true, y_pred), recall_m(y_true, y_pred)
    
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

In [3]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from stable_baselines3 import PPO  # Import RL model
import matplotlib.pyplot as plt
import networkx as nx
from skimage.morphology import skeletonize
from ot2_env_wrapper import OT2Env  # Custom environment wrapper
from clearml import Task  # Import ClearML's Task



def crop_petri_dish(image, patch_size):
    """
    Detect and crop the Petri dish from the image.

    Parameters:
    - image: Input image (numpy array).
    - patch_size: Tuple (height, width) to pad the cropped Petri dish.

    Returns:
    - Cropped image focused on the Petri dish.
    - Bounding box of the Petri dish.
    - Success flag.
    """
    # Threshold the image to create a binary mask
    _, binary = cv2.threshold(image, 100, 255, cv2.THRESH_BINARY)

    # Find contours in the binary mask
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Detect the largest contour as the Petri dish
    largest_contour = max(contours, key=cv2.contourArea, default=None)
    if largest_contour is None:
        print("Error: No Petri dish detected.")
        return None, None, False

    # Get the bounding box of the Petri dish
    x, y, w, h = cv2.boundingRect(largest_contour)

    # Crop the image based on the bounding box
    cropped_image = image[y:y + h, x:x + w]

    # Pad the cropped image to ensure it matches the patch size
    padded_image = pad_image(cropped_image, patch_size)

    return padded_image, (x, y, w, h), True


def pad_image(image, patch_size):
    """
    Pad the cropped image to match the required patch size.

    Parameters:
    - image: Input cropped image (numpy array).
    - patch_size: Tuple (height, width) for padding.

    Returns:
    - Padded image.
    """
    height, width = image.shape[:2]
    pad_height = (patch_size[0] - height % patch_size[0]) % patch_size[0]
    pad_width = (patch_size[1] - width % patch_size[1]) % patch_size[1]
    return cv2.copyMakeBorder(image, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=0)


def split_image(image, num_parts):
    """
    Split the image into equal vertical parts.

    Parameters:
    - image: Input image (numpy array).
    - num_parts: Number of parts to split the image into.

    Returns:
    - List of image splits.
    """
    height, width = image.shape[:2]
    part_width = width // num_parts
    return [image[:, i * part_width:(i + 1) * part_width] for i in range(num_parts)]


def merge_images(splits, original_shape):
    """
    Merge image splits back into the original shape.

    Parameters:
    - splits: List of image splits.
    - original_shape: Tuple (height, width) of the original image.

    Returns:
    - Merged image.
    """
    merged_image = np.zeros(original_shape, dtype=np.uint8)
    part_width = original_shape[1] // len(splits)
    for i, split in enumerate(splits):
        merged_image[:, i * part_width:(i + 1) * part_width] = split
    return merged_image


def generate_mask(image, model, patch_size):
    """
    Generate a binary mask from an input image using the trained model.

    Parameters:
    - image: Input image (numpy array).
    - model: Trained Keras model.
    - patch_size: Tuple (height, width) for patching.

    Returns:
    - mask: Predicted binary mask (numpy array).
    """
    height, width = image.shape[:2]
    # Pad the image
    pad_height = (patch_size[0] - height % patch_size[0]) % patch_size[0]
    pad_width = (patch_size[1] - width % patch_size[1]) % patch_size[1]
    padded_image = cv2.copyMakeBorder(image, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=0)

    # Patch the image
    patches = []
    for y in range(0, padded_image.shape[0], patch_size[0]):
        for x in range(0, padded_image.shape[1], patch_size[1]):
            patch = padded_image[y:y + patch_size[0], x:x + patch_size[1]]
            patches.append(patch / 255.0)  # Normalize

    patches = np.array(patches)[..., np.newaxis]  # Add channel dimension

    # Predict patches
    predicted_patches = model.predict(patches)
    predicted_patches = (predicted_patches > 0.5).astype(np.uint8) * 255

    # Reconstruct the mask
    reconstructed_mask = np.zeros_like(padded_image, dtype=np.uint8)
    idx = 0
    for y in range(0, padded_image.shape[0], patch_size[0]):
        for x in range(0, padded_image.shape[1], patch_size[1]):
            reconstructed_mask[y:y + patch_size[0], x:x + patch_size[1]] = predicted_patches[idx].squeeze()
            idx += 1

    # Crop back to the original size
    return reconstructed_mask[:height, :width]


def detect_root_tip_with_skeletonization(mask, kernel_size=10, closing_iterations=3, min_area=400):
    """
    Detect the root tip using the improved skeletonization process.

    Parameters:
    - mask: Binary mask (numpy array).
    - kernel_size: Kernel size for morphological operations.
    - closing_iterations: Number of iterations for morphological closing.
    - min_area: Minimum area for connected components.

    Returns:
    - root_tip: Tuple (y_pixel, x_pixel) of the root tip.
    """
    # Improve connectivity
    improved_mask = improve_connectivity(mask, kernel_size, closing_iterations)

    # Extract large connected components
    large_components = extract_large_components(improved_mask, min_area=min_area)

    # Find the longest path in the largest component
    longest_path = None
    max_length = 0

    for label, component_mask, area, stats in large_components:
        path, length, skeleton = find_longest_path_in_component(component_mask)
        if length > max_length:
            max_length = length
            longest_path = path

    if not longest_path:
        raise ValueError("No valid root tip detected in the skeletonized mask.")

    # The root tip is the endpoint of the longest path (lowest pixel)
    root_tip = longest_path[-1]
    return root_tip

def improve_connectivity(mask, kernel_size=5, closing_iterations=3):
    """
    Improves connectivity in the binary mask by applying morphological closing.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    improved_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=closing_iterations)
    return improved_mask


def skeleton_to_graph(skeleton):
    """Converts a skeletonized image to a graph using networkx."""
    G = nx.Graph()
    skeleton_pixels = np.argwhere(skeleton > 0)
    for y, x in skeleton_pixels:
        G.add_node((y, x))
        for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
            neighbor = (y + dy, x + dx)
            if (neighbor[0] >= 0 and neighbor[1] >= 0 and neighbor in G.nodes):
                G.add_edge((y, x), neighbor)
    return G


def find_longest_path_in_component(component_mask):
    """
    Finds the longest path in a single connected component using its skeleton.
    """
    skeleton = skeletonize(component_mask // 255)
    G = skeleton_to_graph(skeleton)
    topmost_pixel = tuple(np.argwhere(skeleton > 0).min(axis=0))

    if topmost_pixel not in G.nodes:
        skeleton_pixels = np.array(list(G.nodes))
        distances = np.sum(np.abs(skeleton_pixels - np.array(topmost_pixel)), axis=1)
        closest_node = tuple(skeleton_pixels[np.argmin(distances)])
        topmost_pixel = closest_node

    lengths = nx.single_source_dijkstra_path_length(G, source=topmost_pixel)
    bottommost_pixel = max(lengths, key=lengths.get)
    longest_path = nx.shortest_path(G, source=topmost_pixel, target=bottommost_pixel)
    return longest_path, lengths[bottommost_pixel], skeleton


def extract_large_components(mask, min_area=500):
    """
    Extracts connected components larger than a specified area.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    large_components = []
    for label in range(1, num_labels):  # Skip background
        area = stats[label, cv2.CC_STAT_AREA]
        if area >= min_area:
            component_mask = np.zeros_like(mask, dtype=np.uint8)
            component_mask[labels == label] = 255
            large_components.append((label, component_mask, area, stats[label]))
    return large_components

def convert_pixel_to_mm(root_tip_pixel, image_height, plate_height_mm):
    """
    Convert pixel coordinates to millimeters.

    Parameters:
    - root_tip_pixel: Root tip coordinates in pixels (y_pixel, x_pixel).
    - image_height: Height of the original image in pixels.
    - plate_height_mm: Real-world height of the plate in millimeters.

    Returns:
    - root_tip_mm: Root tip coordinates in millimeters (x_mm, y_mm, z_mm).
    """
    scale = plate_height_mm / image_height  # mm per pixel
    y_mm = root_tip_pixel[0] * scale
    x_mm = root_tip_pixel[1] * scale
    return (x_mm, y_mm, 0)  


def convert_to_robot_coordinates(root_tip_mm, plate_position_robot):
    """
    Convert root tip positions in mm (relative to the plate) to robot space.

    Parameters:
    - root_tip_mm: Root tip position in mm (x_mm, y_mm, z_mm).
    - plate_position_robot: Position of the top-left corner of the plate in robot space [x, y, z].

    Returns:
    - root_tip_robot: Root tip position in robot space (x_robot, y_robot, z_robot).
    """
    # Convert from mm to meters
    root_tip_m = [
        root_tip_mm[0] / 1000,
        root_tip_mm[1] / 1000,
        root_tip_mm[2] / 1000,
    ]

    # Add the plate position to get the robot space coordinates
    return [
        root_tip_m[0] + plate_position_robot[0],
        root_tip_m[1] + plate_position_robot[1],
        root_tip_m[2] + plate_position_robot[2],
    ]

def inoculate_with_rl(env, root_tips_robot, model_path_LR):
    """
    Perform root tip inoculation using an RL model, ensuring X and Y accuracy.

    Parameters:
    - env: The simulation environment.
    - root_tips_robot: List of root tip coordinates in robot space [(x, y, z), ...].
    - model_path_LR: Path to the trained RL model.
    """
    # Load the trained RL model
    rl_model = PPO.load(model_path_LR)

    for idx, root_tip in enumerate(root_tips_robot):
        print(f"Starting inoculation for root tip {idx + 1} at {root_tip}")
        is_done = False
        is_truncated = False
        obs, _ = env.reset()
        env.goal_position = np.array(root_tip)  # Set the goal position for the environment

        while not is_done and not is_truncated:
            try:
                # Predict the action using the RL model
                action, _ = rl_model.predict(obs)
                obs, reward, is_done, is_truncated, info = env.step(action)

                # Validate step outputs
                assert isinstance(is_done, bool), "`done` should be a bool."
                assert isinstance(is_truncated, bool), "`truncated` should be a bool."
                assert isinstance(obs, np.ndarray), "`obs` should be a numpy array."

                # Log the current pipette position and error
                current_position = obs[:3]
                xy_error = np.linalg.norm(env.goal_position[:2] - current_position[:2])
                print(f"Step: Current XY: {current_position[:2]}, Goal XY: {env.goal_position[:2]}, Error XY: {xy_error}")

                if is_done:
                    print(f"Inoculating at position {current_position[:2]} (XY Accuracy Met)")
                    print("Simulating inoculum drop...")
                    break

            except Exception as e:
                print(f"Error during RL step execution: {e}")
                break

        print(f"Finished processing root tip {idx + 1}\n")


# Main Workflow for RL Model with Mask Segmentation
if __name__ == "__main__":
    # Initialize environment and parameters
    env = OT2Env(render=True)
    
    # Parameters
    mask_model_path = r"C:\Users\Edopi\Desktop\2024-25b-fai2-adsai-EdoardoPierezza231412\datalab_tasks\Task8\Edoardo_231412_undet_model256px_data augmentation_1.h5"  # Path to segmentation (mask) model
    model_path_LR = r"C:\Users\Edopi\Downloads\model (7).zip" # Path to RL model
    patch_size = (256, 256)
    plate_position_robot = [0.10775, 0.088 - 0.026, 0.057]  # Adjusted plate position
    image_height = 2816  # Original image height in pixels
    plate_height_mm = 150  # Plate height in millimeters

    try:
        # Capture the image from the environment
        print("Capturing image from the environment...")
        image_path = env.image()
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

        if image is None:
            raise ValueError("Failed to load the image. Check the image path or capture process.")

        # Crop the Petri dish from the captured image
        print("Cropping the Petri dish from the image...")
        cropped_image, bbox, success = crop_petri_dish(image, patch_size)
        if not success:
            raise RuntimeError("Failed to detect and crop the Petri dish.")

        # Split the cropped image into parts for processing
        print("Splitting the cropped image...")
        splits = split_image(cropped_image, num_parts=5)

        # Load the mask model
        print("Loading the mask segmentation model...")
        mask_model = load_model(mask_model_path, custom_objects={"f1": f1, "dice_loss": dice_loss})

        # Initialize variables to store results
        root_tips_mm = []
        root_tips_robot = []

        # Process each split image
        for idx, split in enumerate(splits):
            print(f"Processing split {idx + 1} of {len(splits)}...")

            try:
                # Generate the mask for the current split
                mask = generate_mask(split, mask_model, patch_size)

                # Detect the root tip using skeletonization
                root_tip_pixel = detect_root_tip_with_skeletonization(
                    mask, kernel_size=10, closing_iterations=3, min_area=400
                )

                # Convert the root tip's pixel coordinates to millimeters
                root_tip_mm = convert_pixel_to_mm(
                    root_tip_pixel, image_height // len(splits), plate_height_mm // len(splits)
                )
                root_tips_mm.append(root_tip_mm)

                # Convert the millimeter coordinates to robot space
                root_tip_robot = convert_to_robot_coordinates(root_tip_mm, plate_position_robot)
                root_tips_robot.append(root_tip_robot)

                print(f"Split {idx + 1} root tip in robot coordinates: {root_tip_robot}")

            except Exception as e:
                print(f"Warning: No root detected in split {idx + 1}. Skipping to the next split. Error: {e}")

        # Perform inoculation using RL
        print("Starting inoculation process with RL model...")
        inoculate_with_rl(env, root_tips_robot, model_path_LR)

        print("Inoculation process with RL model completed.")

    except Exception as e:
        print(f"An error occurred during execution: {e}")
    
    finally:
        # Close the environment after execution
        print("Closing the environment...")
        env.close()


   


Capturing image from the environment...
Image captured and saved at: textures/_plates/034_43-13-ROOT1-2023-08-08_control_pH7_-Fe+B_col0_02-Fish Eye Corrected.png
Cropping the Petri dish from the image...
Splitting the cropped image...
Loading the mask segmentation model...
Processing split 1 of 5...
Split 1 root tip in robot coordinates: [0.10775, 0.14054351687388988, 0.057]
Processing split 2 of 5...
Split 2 root tip in robot coordinates: [0.1319418294849023, 0.11970870337477796, 0.057]
Processing split 3 of 5...
Split 3 root tip in robot coordinates: [0.12474822380106572, 0.12359857904085257, 0.057]
Processing split 4 of 5...
Split 4 root tip in robot coordinates: [0.12128463587921848, 0.10841207815275311, 0.057]
Processing split 5 of 5...
Split 5 root tip in robot coordinates: [0.12005905861456483, 0.11549911190053286, 0.057]
Starting inoculation process with RL model...
Starting inoculation for root tip 1 at [0.10775, 0.14054351687388988, 0.057]
Reset: Pipette Position [0.073  0.08