In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.morphology import remove_small_objects, skeletonize
from skimage.measure import label
from skan import Skeleton, summarize
from sim_class import Simulation
import pandas as pd
from stable_baselines3 import PPO
from ot2_gym_wrapper_team import OT2_wrapper
import os
from tensorflow.keras.models import load_model
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K


from patchify import patchify, unpatchify

In [2]:
env = OT2_wrapper(render=False, max_steps=1000, accuracy_threshold=0.001)
obs, info = env.reset()
sim = env.sim # 

In [4]:
image_path = sim.get_plate_image()

In [5]:
# Initialize the environment
env = OT2_wrapper(render=False, max_steps=1000, accuracy_threshold=0.001)
obs, info = env.reset()

# Access the simulation instance
sim = env.sim

# Function to process and display each image
def process_image(image_path, save_dir, index):
    # Load the image in grayscale mode
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    # Validate if the image is loaded correctly
    if image is None:
        print(f"Failed to load the image from the path: {image_path}")
        return


    # Save the original grayscale image for further processing
    output_path = os.path.join(save_dir, f"image_{index}.png")
    cv2.imwrite(output_path, image)
    print(f"Image {index} saved to {output_path}")

# Directory to save the cropped and padded images
save_directory = "processed_images"
os.makedirs(save_directory, exist_ok=True)

# Loop through all images
print("Processing all plate images...")
for i in range(10):  # Replace 10 with the actual number of images you want to process
    try:
        # Retrieve the image path
        image_path = sim.get_plate_image()
        process_image(image_path, save_directory, i)
    except Exception as e:
        print(f"Error processing image {i}: {e}")

cv2.destroyAllWindows()
print("All images processed!")


Processing all plate images...
Image 0 saved to processed_images\image_0.png
Image 1 saved to processed_images\image_1.png
Image 2 saved to processed_images\image_2.png
Image 3 saved to processed_images\image_3.png
Image 4 saved to processed_images\image_4.png
Image 5 saved to processed_images\image_5.png
Image 6 saved to processed_images\image_6.png
Image 7 saved to processed_images\image_7.png
Image 8 saved to processed_images\image_8.png
Image 9 saved to processed_images\image_9.png
All images processed!


In [6]:
# Initialize the environment
env = OT2_wrapper(render=False, max_steps=1000, accuracy_threshold=0.001)
obs, info = env.reset()

# Access the simulation instance
sim = env.sim

# Directory to save the processed images
save_directory = "processed_images"
os.makedirs(save_directory, exist_ok=True)

# Function to process and save each image
def process_and_save_image(image_path, save_dir, index):
    """
    Load, validate, and save an image from the given path.
    """
    # Load the image in grayscale mode
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    # Validate if the image is loaded correctly
    if image is None:
        print(f"Failed to load the image from the path: {image_path}")
        return None

    # Save the grayscale image for further processing
    output_path = os.path.join(save_dir, f"image_{index}.png")
    cv2.imwrite(output_path, image)
    print(f"Image {index} saved to {output_path}")

    return image

# Preprocessing helper functions
def detect_edges(image, max_size=2800):
    """
    Detect edges and find the approximate square Petri dish.
    """
    blurred_image = cv2.GaussianBlur(image, (51, 51), 0)
    sobel_x = cv2.Sobel(blurred_image, cv2.CV_64F, 1, 0, ksize=5)
    sobel_y = cv2.Sobel(blurred_image, cv2.CV_64F, 0, 1, ksize=5)
    gradient_magnitude = cv2.magnitude(sobel_x, sobel_y)
    _, edges = cv2.threshold(gradient_magnitude, 50, 255, cv2.THRESH_BINARY)
    edges = edges.astype(np.uint8)

    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    max_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(max_contour)

    side_length = min(max(w, h), max_size)
    center_x, center_y = x + w // 2, y + h // 2
    half_side = side_length // 2
    new_x = max(center_x - half_side, 0)
    new_y = max(center_y - half_side, 0)
    new_w = new_h = min(side_length, min(image.shape[1] - new_x, image.shape[0] - new_y))
    return new_x, new_x + new_w, new_y, new_y + new_h

def crop_image(image, edges):
    """
    Crop the image based on detected edges.
    """
    left, right, top, bottom = edges
    return image[top:bottom, left:right]

def padder(image, patch_size=256):
    """
    Add padding to make dimensions divisible by the patch size.
    """
    h, w = image.shape[:2]
    height_padding = ((h // patch_size) + 1) * patch_size - h
    width_padding = ((w // patch_size) + 1) * patch_size - w
    top_padding = int(height_padding / 2)
    bottom_padding = height_padding - top_padding
    left_padding = int(width_padding / 2)
    right_padding = width_padding - left_padding
    return cv2.copyMakeBorder(image, top_padding, bottom_padding, left_padding, right_padding, cv2.BORDER_CONSTANT, value=0)

def preprocess_image(image, patch_size=256):
    """
    Preprocess the image: pad, patchify, and normalize.
    """
    padded_image = padder(image, patch_size)
    patches = patchify(padded_image, (patch_size, patch_size), step=patch_size)
    patches_reshaped = patches.reshape(-1, patch_size, patch_size, 1)  # Add channel dimension
    patches_normalized = patches_reshaped / 255.0
    return patches, patches_normalized

# Process all images and apply preprocessing
print("Processing all plate images...")
preprocessed_data = []
patch_size = 256

for i in range(10):  # Adjust range as needed
    try:
        # Retrieve the image path from the simulation
        image_path = sim.get_plate_image()

        # Process and save the image
        image = process_and_save_image(image_path, save_directory, i)

        if image is not None:
            # Apply preprocessing
            edges = detect_edges(image)
            cropped_image = crop_image(image, edges)
            padded_image = padder(cropped_image, patch_size)
            patches, patches_normalized = preprocess_image(padded_image, patch_size)

            preprocessed_data.append((f"image_{i}", patches_normalized))

            # Print progress
            print(f"Processed image_{i}: Original shape {image.shape}, "
                  f"Padded shape {padded_image.shape}, Patches {patches.shape}")
    except Exception as e:
        print(f"Error processing image {i}: {e}")

print(f"Total images processed: {len(preprocessed_data)}")


Processing all plate images...
Image 0 saved to processed_images\image_0.png
Processed image_0: Original shape (3006, 4202), Padded shape (2816, 2816), Patches (12, 12, 256, 256)
Image 1 saved to processed_images\image_1.png
Processed image_1: Original shape (3006, 4202), Padded shape (2816, 2816), Patches (12, 12, 256, 256)
Image 2 saved to processed_images\image_2.png
Processed image_2: Original shape (3006, 4202), Padded shape (2816, 2816), Patches (12, 12, 256, 256)
Image 3 saved to processed_images\image_3.png
Processed image_3: Original shape (3006, 4202), Padded shape (2816, 2816), Patches (12, 12, 256, 256)
Image 4 saved to processed_images\image_4.png
Processed image_4: Original shape (3006, 4202), Padded shape (2816, 2816), Patches (12, 12, 256, 256)
Image 5 saved to processed_images\image_5.png
Processed image_5: Original shape (3006, 4202), Padded shape (2816, 2816), Patches (12, 12, 256, 256)
Image 6 saved to processed_images\image_6.png
Processed image_6: Original shape (

In [7]:
def f1(y_true, y_pred):
    def recall_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = TP / (Positives+K.epsilon())
        return recall
    
    def precision_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Pred_Positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = TP / (Pred_Positives+K.epsilon())
        return precision
    
    precision, recall = precision_m(y_true, y_pred), recall_m(y_true, y_pred)
    
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

In [14]:
model_path = "model.h5"
model = load_model(model_path)


ValueError: Unknown metric function: f1. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

In [13]:
# **Run predictions**
predictions = []
patch_size = 256  # Assuming patches are 256x256 as specified in preprocessing

for filename, patches_normalized in preprocessed_data:
    try:
        # Predict on patches
        predicted_patches = model.predict(patches_normalized, batch_size=16)

        # Inspect the shape of predicted patches
        print(f"Predicted patches shape for {filename}: {predicted_patches.shape}")

        # Reshape predictions back to grid dimensions (num_patches_y, num_patches_x, patch_size, patch_size)
        num_patches_y, num_patches_x = 12, 12  # Replace with actual grid dimensions if dynamic
        predicted_patches = predicted_patches.reshape((num_patches_y, num_patches_x, patch_size, patch_size))

        # Drop the last channel dimension (if single channel output)
        predicted_patches = np.squeeze(predicted_patches, axis=-1)

        # Reconstruct the full image using unpatchify
        petri_dish_padded_shape = (num_patches_y * patch_size, num_patches_x * patch_size)
        reconstructed_prediction = unpatchify(predicted_patches, petri_dish_padded_shape)

        # Store the prediction
        predictions.append((filename, reconstructed_prediction))
        print(f"Prediction completed for {filename}")

    except Exception as e:
        print(f"Error processing {filename}: {e}")

# **Visualize predictions**
for filename, predicted_mask in predictions:
    plt.figure(dpi=100)
    plt.imshow(predicted_mask, cmap='gray')
    plt.title(f"Prediction for {filename}")
    plt.axis('off')
    plt.show()


Predicted patches shape for image_0: (144, 256, 256, 1)
Error processing image_0: cannot select an axis to squeeze out which has size not equal to one
Predicted patches shape for image_1: (144, 256, 256, 1)
Error processing image_1: cannot select an axis to squeeze out which has size not equal to one
Predicted patches shape for image_2: (144, 256, 256, 1)
Error processing image_2: cannot select an axis to squeeze out which has size not equal to one
Predicted patches shape for image_3: (144, 256, 256, 1)
Error processing image_3: cannot select an axis to squeeze out which has size not equal to one
Predicted patches shape for image_4: (144, 256, 256, 1)
Error processing image_4: cannot select an axis to squeeze out which has size not equal to one
Predicted patches shape for image_5: (144, 256, 256, 1)
Error processing image_5: cannot select an axis to squeeze out which has size not equal to one
Predicted patches shape for image_6: (144, 256, 256, 1)
Error processing image_6: cannot selec

In [12]:
print(predicted_patches.shape)


(144, 256, 256, 1)
