In [2]:
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras import backend as K
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from patchify import patchify, unpatchify
from tensorflow.keras import backend as K
from skimage.measure import label

def f1(y_true, y_pred, threshold=0.3):
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    TP = tf.keras.backend.sum(tf.keras.backend.round(tf.keras.backend.clip(y_true * y_pred, 0, 1)))
    Positives = tf.keras.backend.sum(tf.keras.backend.round(tf.keras.backend.clip(y_true, 0, 1)))
    Pred_Positives = tf.keras.backend.sum(tf.keras.backend.round(tf.keras.backend.clip(y_pred, 0, 1)))
    precision = TP / (Pred_Positives + tf.keras.backend.epsilon())
    recall = TP / (Positives + tf.keras.backend.epsilon())
    return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))
    
@tf.keras.utils.register_keras_serializable()
def f1_metric(y_true, y_pred):
    return f1(y_true, y_pred, threshold=0.3)


def weighted_binary_crossentropy(y_true, y_pred):
    """
    Weighted binary cross-entropy to address class imbalance.
    """
    # Define weights for foreground (root, shoot, seed) and background
    weight_foreground = 10.0
    weight_background = 1.0

    # Compute weighted binary cross-entropy
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    weights = tf.where(y_true == 1, weight_foreground, weight_background)
    loss = K.binary_crossentropy(y_true, y_pred)
    weighted_loss = loss * weights
    return K.mean(weighted_loss)
def dice_loss(y_true, y_pred):
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    dice = (2. * intersection + 1e-7) / (union + 1e-7)
    return 1 - dice

@tf.keras.utils.register_keras_serializable()
def combined_loss(y_true, y_pred):
    return 0.5 * dice_loss(y_true, y_pred) + 0.5 * weighted_binary_crossentropy(y_true, y_pred)





In [5]:
import os
import cv2
import numpy as np
import tensorflow as tf
import logging
import json

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("pipeline_log.txt"), logging.StreamHandler()],
)

def preprocess_image(image_path):
    """Load and preprocess the input image."""
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")
    return image

def extract_petri_dish(image):
    """Extract the largest contour assumed to be the Petri dish."""
    _, thresholded = cv2.threshold(image, 57, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresholded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        logging.warning("No contours detected for Petri dish.")
        return image, np.ones_like(image)

    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    cropped_image = image[y:y+h, x:x+w]
    return cropped_image

def predict_root_mask(image, model, patch_size=128, stride=64, batch_size=16):
    """Predict root mask using a patch-based model with smaller batches."""
    h, w = image.shape
    patches = []
    positions = []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patch = image[y:y+patch_size, x:x+patch_size]
            patch_rgb = np.stack([patch] * 3, axis=-1)
            patches.append(patch_rgb)
            positions.append((y, x))

    patches = np.array(patches) / 255.0
    predictions = []
    for i in range(0, len(patches), batch_size):
        batch = patches[i:i + batch_size]
        batch_predictions = model.predict(batch, verbose=0)
        predictions.extend(batch_predictions)

    reconstructed = np.zeros((h, w), dtype=np.float32)
    counts = np.zeros((h, w), dtype=np.float32)
    for pred, (y, x) in zip(predictions, positions):
        pred = pred[..., 0]
        reconstructed[y:y+patch_size, x:x+patch_size] += pred
        counts[y:y+patch_size, x:x+patch_size] += 1
    return (reconstructed / np.maximum(counts, 1) > 0.3).astype(np.uint8)

def connect_roots(mask):
    """Dilate the mask to connect fragmented roots."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
    return cv2.dilate(mask, kernel, iterations=27)

def filter_and_select_largest_objects(mask, min_area=500, max_objects=5):
    """Filter and retain the largest root components."""
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    valid_objects = [(i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num_labels) if stats[i, cv2.CC_STAT_AREA] >= min_area]
    largest_objects = sorted(valid_objects, key=lambda x: x[1], reverse=True)[:max_objects]

    filtered_mask = np.zeros_like(mask, dtype=np.uint8)
    for obj_id, _ in largest_objects:
        filtered_mask[labels == obj_id] = 255
    return filtered_mask, largest_objects

def measure_bounding_boxes(filtered_mask):
    """Measure roots with bounding boxes."""
    contours, _ = cv2.findContours(filtered_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bounding_boxes = [cv2.boundingRect(contour) for contour in contours]
    bounding_boxes = sorted(bounding_boxes, key=lambda b: b[0])
    return bounding_boxes

def find_lowest_point(mask, bounding_box):
    """Find the lowest point in a bounding box."""
    x, y, w, h = bounding_box
    roi = mask[y:y+h, x:x+w]
    coordinates = np.column_stack(np.where(roi > 0))
    if len(coordinates) == 0:
        return None
    lowest_local = coordinates[np.argmax(coordinates[:, 0])]
    return lowest_local + [y, x]

def convert_to_mm(pixel_coords, image_shape, plate_size_mm=150):
    """Convert pixel coordinates to mm-space."""
    h, w = image_shape
    conversion_factor_x = plate_size_mm / w
    conversion_factor_y = plate_size_mm / h
    return np.array([pixel_coords[1] * conversion_factor_x, pixel_coords[0] * conversion_factor_y])

# Directory containing images
input_dir = "kaggle_test"
output_file = "root_tip_coordinates.json"
model_path = "232430_unet_model_128px_v8md.keras"
plate_position_robot = np.array([0.10775, 0.088 - 0.026, 0.057])  # Adjust based on your setup

# Register custom objects for model deserialization
custom_objects = {
    "combined_loss": combined_loss,
    "f1_metric": f1_metric
}

# Limit GPU memory growth
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    try:
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        logging.info("Enabled memory growth for GPU.")
    except Exception as e:
        logging.error(f"Error enabling memory growth for GPU: {e}")

# Load the model with custom objects
model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

# Process all images in the directory
all_root_tip_coords = {}

for image_name in sorted(os.listdir(input_dir)):
    if not image_name.endswith((".png", ".jpg", ".jpeg")):
        continue

    image_path = os.path.join(input_dir, image_name)
    logging.info(f"Processing image: {image_name}")

    try:
        # Process the image and extract root tip coordinates
        image = preprocess_image(image_path)
        petri_dish = extract_petri_dish(image)
        predicted_mask = predict_root_mask(petri_dish, model)
        connected_mask = connect_roots(predicted_mask)
        filtered_mask, _ = filter_and_select_largest_objects(connected_mask, min_area=500, max_objects=5)
        bounding_boxes = measure_bounding_boxes(filtered_mask)

        root_tip_coords = []
        for box in bounding_boxes:
            point = find_lowest_point(filtered_mask, box)
            if point is not None:
                mm_coords = convert_to_mm(point, petri_dish.shape)
                robot_coords = np.append(mm_coords, 0) + plate_position_robot
                root_tip_coords.append(robot_coords.tolist())  # Convert to list for JSON serialization

        all_root_tip_coords[image_name] = root_tip_coords
        logging.info(f"Extracted Root Tip Coordinates for {image_name}: {root_tip_coords}")

    except Exception as e:
        logging.error(f"Error processing image {image_name}: {e}")

# Save the root tip coordinates to a JSON file
with open(output_file, "w") as f:
    json.dump(all_root_tip_coords, f, indent=4)
    logging.info(f"Saved root tip coordinates to {output_file}")


2025-01-12 17:11:12,666 - INFO - Enabled memory growth for GPU.
2025-01-12 17:11:13,170 - INFO - Processing image: test_image_1.png
2025-01-12 17:11:20,499 - INFO - Extracted Root Tip Coordinates for test_image_1.png: [[6.8935945822692455, 43.06671185212033, 0.057], [23.201511400948558, 54.59262703878216, 0.057], [28.18144573148486, 31.975736861181588, 0.057], [76.1201542320321, 39.47845523740485, 0.057], [142.5010371214885, 26.1040442189199, 0.057]]
2025-01-12 17:11:20,500 - INFO - Processing image: test_image_10.png
2025-01-12 17:11:28,595 - INFO - Extracted Root Tip Coordinates for test_image_10.png: [[14.689594174410292, 90.18015252416755, 0.057], [32.21997301644031, 115.79776799140708, 0.057], [71.46228895639742, 107.09744575725027, 0.057], [107.7560702287348, 96.73224704618688, 0.057], [135.57951554681915, 88.78380451127819, 0.057]]
2025-01-12 17:11:28,596 - INFO - Processing image: test_image_11.png
2025-01-12 17:11:36,677 - INFO - Extracted Root Tip Coordinates for test_image_1