In [2]:
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras import backend as K
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from patchify import patchify, unpatchify
from tensorflow.keras import backend as K
from skimage.measure import label

def f1(y_true, y_pred, threshold=0.3):
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    TP = tf.keras.backend.sum(tf.keras.backend.round(tf.keras.backend.clip(y_true * y_pred, 0, 1)))
    Positives = tf.keras.backend.sum(tf.keras.backend.round(tf.keras.backend.clip(y_true, 0, 1)))
    Pred_Positives = tf.keras.backend.sum(tf.keras.backend.round(tf.keras.backend.clip(y_pred, 0, 1)))
    precision = TP / (Pred_Positives + tf.keras.backend.epsilon())
    recall = TP / (Positives + tf.keras.backend.epsilon())
    return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))
    
@tf.keras.utils.register_keras_serializable()
def f1_metric(y_true, y_pred):
    return f1(y_true, y_pred, threshold=0.3)


def weighted_binary_crossentropy(y_true, y_pred):
    """
    Weighted binary cross-entropy to address class imbalance.
    """
    # Define weights for foreground (root, shoot, seed) and background
    weight_foreground = 10.0
    weight_background = 1.0

    # Compute weighted binary cross-entropy
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    weights = tf.where(y_true == 1, weight_foreground, weight_background)
    loss = K.binary_crossentropy(y_true, y_pred)
    weighted_loss = loss * weights
    return K.mean(weighted_loss)
def dice_loss(y_true, y_pred):
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    dice = (2. * intersection + 1e-7) / (union + 1e-7)
    return 1 - dice

@tf.keras.utils.register_keras_serializable()
def combined_loss(y_true, y_pred):
    return 0.5 * dice_loss(y_true, y_pred) + 0.5 * weighted_binary_crossentropy(y_true, y_pred)





In [4]:
import os
import cv2
import numpy as np
import tensorflow as tf
import logging
from pid_class import PIDController
from task10_ot2_gym_wrapper import OT2Env

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("integrated_pipeline_log.txt"), logging.StreamHandler()],
)


# Functions
def preprocess_image(image_path):
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")
    logging.info(f"Loaded image from: {image_path}")
    return image

def extract_petri_dish(image):
    _, thresholded = cv2.threshold(image, 57, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresholded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        logging.warning("No contours detected for Petri dish.")
        return image, None
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    cropped_image = image[y:y+h, x:x+w]
    return cropped_image, None

def predict_root_mask(image, model, patch_size=128, stride=64):
    h, w = image.shape
    patches = []
    positions = []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patch = image[y:y+patch_size, x:x+patch_size]
            patches.append(np.stack([patch]*3, axis=-1))
            positions.append((y, x))
    patches = np.array(patches) / 255.0
    predictions = model.predict(patches, verbose=0)
    reconstructed = np.zeros((h, w), dtype=np.float32)
    counts = np.zeros((h, w), dtype=np.float32)
    for pred, (y, x) in zip(predictions, positions):
        reconstructed[y:y+patch_size, x:x+patch_size] += pred[..., 0]
        counts[y:y+patch_size, x:x+patch_size] += 1
    mask = (reconstructed / np.maximum(counts, 1) > 0.5).astype(np.uint8)
    logging.info("Root mask predicted.")
    return mask

def convert_to_mm(pixel_coords):
    """Convert pixel coordinates to mm-space."""
    return np.array(pixel_coords) * CONVERSION_FACTOR

def convert_to_robot_space(mm_coords):
    """Convert mm coordinates to robot space."""
    return mm_coords + PLATE_POSITION_ROBOT

def find_lowest_point(mask, bounding_box):
    x, y, w, h = bounding_box
    roi = mask[y:y+h, x:x+w]
    coordinates = np.column_stack(np.where(roi > 0))
    if len(coordinates) == 0:
        return None
    lowest_local = coordinates[np.argmax(coordinates[:, 0])]
    return lowest_local + [y, x]

def measure_bounding_boxes(filtered_mask):
    """Measure roots with bounding boxes."""
    contours, _ = cv2.findContours(filtered_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bounding_boxes = [cv2.boundingRect(contour) for contour in contours]
    bounding_boxes = sorted(bounding_boxes, key=lambda b: b[0])  # Sort left-to-right
    logging.info(f"Detected {len(bounding_boxes)} bounding boxes.")
    return bounding_boxes

def visualize_results(image, mask, bounding_boxes, robot_coords):
    """Visualize the predictions and results."""
    boxed_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    for (x, y, w, h), robot_coord in zip(bounding_boxes, robot_coords):
        cv2.rectangle(boxed_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(
            boxed_image,
            f"({robot_coord[0]:.3f}, {robot_coord[1]:.3f}, {robot_coord[2]:.3f})",
            (x, y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            (0, 255, 0),
            1,
        )
    plt.imshow(boxed_image)
    plt.title("Detected Roots with Robot Coordinates")
    plt.axis("off")
    plt.show()

def full_pipeline_with_pid(sim, model_path, env):
    """Full pipeline with integrated PID control."""
    model = tf.keras.models.load_model(model_path)
    image_path = sim.get_plate_image()  # Get image path from simulation
    image = preprocess_image(image_path)
    petri_dish, _ = extract_petri_dish(image)

        # Constants
    PLATE_SIZE_MM = 150  # Plate size in mm
    PLATE_POSITION_ROBOT = np.array([0.10775, 0.088 - 0.026, 0.057])  # Adjusted plate position in robot space
    DEFAULT_Z = 0.057  # Default z-coordinate
    CONVERSION_FACTOR = PLATE_SIZE_MM / image.shape[0]  # Conversion factor (assume plate size in pixels is 1024)

    # Predictions
    predicted_mask = predict_root_mask(petri_dish, model)
    connected_mask = connect_roots(predicted_mask)

    # Bounding Box Detection
    bounding_boxes = measure_bounding_boxes(connected_mask)
    robot_coords = []

    for box in bounding_boxes:
        point = find_lowest_point(connected_mask, box)
        if point is not None:
            mm_coords = convert_to_mm(point)
            robot_coord = convert_to_robot_space(mm_coords)
            robot_coords.append(robot_coord)
            logging.info(f"Root tip robot coordinate: {robot_coord}")
        else:
            logging.warning("No valid lowest point found for bounding box.")

    # Visualize results
    visualize_results(petri_dish, connected_mask, bounding_boxes, robot_coords)

    # PID Control Loop
    for idx, coord in enumerate(robot_coords):
        logging.info(f"Navigating to root tip {idx + 1}: {coord}")
        observation, info = env.reset()
        pid_x, pid_y, pid_z = PIDController(), PIDController(), PIDController()
        pid_x.setpoint, pid_y.setpoint, pid_z.setpoint = coord

        for _ in range(200):
            current_position = observation[:3]
            error = np.linalg.norm(coord - current_position)
            if error < 0.01:  # Threshold for reaching the target
                logging.info(f"Reached root tip {idx + 1} with error {error:.4f}")
                break
            action = np.array([
                pid_x.compute(current_position[0]),
                pid_y.compute(current_position[1]),
                pid_z.compute(current_position[2])
            ])
            observation, _, terminated, _, _ = env.step(action)
            if terminated:
                logging.warning("Simulation terminated unexpectedly.")
                break

# Initialize simulation and environment
env = OT2Env()
sim = OT2Env()  # Replace with your simulator instance
full_pipeline_with_pid(sim, "232430_unet_model_128px_v7md.keras", env)


AttributeError: 'OT2Env' object has no attribute 'get_plate_image'