# In [None]:
from ultralytics import YOLO
import cv2
import numpy as np
import os

# Path to the trained segmentation weights
MODEL_PATH = 'runs/segment/train10/weights/best.pt'

# Load the YOLO model with error handling. Nothing below can run without it,
# so stop with a non-zero exit status. The site-module `exit()` used before
# exits with status 0 and is meant only for the interactive shell.
try:
    model = YOLO(MODEL_PATH)
except Exception as e:
    print(f"Error loading YOLO model: {e}")
    raise SystemExit(1) from e

# Palette used to paint the segmentation masks. Entries are BGR tuples
# (OpenCV's channel order). These are example colors; replace or extend them.
colors = [
    (0, 0, 255),    # red
    (0, 255, 0),    # green
    (255, 0, 0),    # blue
    (255, 255, 0),  # cyan
    (0, 255, 255),  # yellow
]

# Folder of images to segment
image_dir = 'test_segment'

# Iterate through each image in the folder. The listing is sorted because
# os.listdir order is arbitrary, which makes the windows open in a stable order.
for filename in sorted(os.listdir(image_dir)):
    # Compare extensions case-insensitively so files such as "photo.JPG" are not skipped
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue

    image_path = os.path.join(image_dir, filename)

    # Perform segmentation on the image. Ultralytics also saves annotated
    # images and label .txt files (save=True, save_txt=True).
    # retina_masks=True returns masks at the original image resolution, so
    # they align with the image. Without it, masks are at the letterboxed
    # (resized + padded) inference resolution and misalign when stretched.
    try:
        predict = model.predict(image_path, save=True, save_txt=True, retina_masks=True)
    except Exception as e:
        print(f"Error performing object detection on {filename}: {e}")
        continue
    result = predict[0]

    # Reuse the image Ultralytics already loaded for inference instead of
    # reading the file from disk a second time (and not checking the result)
    original_image = result.orig_img
    height, width = original_image.shape[:2]

    # Create a white background
    background_color = (255, 255, 255)  # Specify background color (white in BGR format)
    background = np.full_like(original_image, background_color)

    # With no detections result.masks is None. Show the plain background
    # instead of failing and reporting it as a detection error.
    if result.masks is None:
        masks, class_ids = [], []
    else:
        masks = result.masks.data.cpu().numpy()  # one mask per detection
        class_ids = result.boxes.cls.cpu().numpy().astype(int)  # class id of each detection

    # Paint each mask onto the background in the color of its class (not its
    # detection index). Later detections overwrite earlier ones where they overlap.
    for mask, class_id in zip(masks, class_ids):
        # Safety net in case the mask resolution still differs from the image.
        # Nearest-neighbour keeps the mask binary instead of growing its edges.
        if mask.shape != (height, width):
            mask = cv2.resize(mask.astype(np.float32), (width, height),
                              interpolation=cv2.INTER_NEAREST)
        binary_mask = mask > 0.5
        background[binary_mask] = colors[class_id % len(colors)]

    # Display the result image
    cv2.imshow(f"Objects with colored background - {filename}", background)

# Block until any key is pressed. A delay of 0 means wait indefinitely; there
# is no timeout. This is also where OpenCV processes GUI events, so the
# windows opened with cv2.imshow are actually drawn.
cv2.waitKey(0)

# Close all OpenCV windows
cv2.destroyAllWindows()
