In [1]:
import glob
import math
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG the notebook uses: numpy draws the class colors, while the
# stdlib `random` module picks which samples plot() shows. Seeding only numpy
# left the sample selection different on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [2]:
# Class id -> class name, plus one random color per class. OpenCV applies the
# color in the image's own channel order (BGR for images from cv2.imread).
class_names = ["tumor"]
colors = np.random.uniform(low=0.0, high=255.0, size=(len(class_names), 3))

# Functions to draw polygon segmentation masks on an image
def _denormalize_polygon(polygon, width, height):
    """Convert flat normalized coordinates ``[x1, y1, x2, y2, ...]`` to pixel points.

    Args:
        polygon: sequence of numbers or numeric strings alternating x and y,
            each normalized by the image width / height respectively.
        width: image width in pixels (scales x).
        height: image height in pixels (scales y).

    Returns:
        ``int32`` array of shape ``(N, 1, 2)``, the layout passed to OpenCV's
        polygon drawing calls. Values are truncated toward zero, like ``int()``.

    Raises:
        ValueError: if the polygon has an odd number of coordinates.
    """
    coords = np.asarray(polygon, dtype=np.float64)
    if coords.size % 2:
        raise ValueError(
            f"Polygon has an odd number of coordinates ({coords.size}); expected x/y pairs."
        )
    pixels = coords.reshape(-1, 2) * (width, height)
    return pixels.astype(np.int32).reshape(-1, 1, 2)


def plot_segmentation(image, polygons, labels, names=None, palette=None, alpha=1.0):
    """Draw labelled polygon masks onto ``image`` in place and return it.

    Args:
        image: OpenCV image array; it is modified in place.
        polygons: one flat ``[x1, y1, x2, y2, ...]`` sequence of normalized
            coordinates (numbers or numeric strings) per object.
        labels: class index for each polygon, parallel to ``polygons``.
        names: class names indexed by label; defaults to the module-level
            ``class_names``.
        palette: one color per class indexed by label; defaults to the
            module-level ``colors``.
        alpha: fill opacity in [0, 1]. ``1.0`` (the default) paints opaque
            fills; lower values blend each fill with the underlying pixels so
            the masked region stays visible.

    Returns:
        The same ``image`` object with fills, outlines and class names drawn.

    Raises:
        ValueError: if ``polygons`` and ``labels`` differ in length, a polygon
            has an odd number of coordinates, or ``alpha`` is outside [0, 1].
    """
    if len(polygons) != len(labels):
        raise ValueError(f"Got {len(polygons)} polygons but {len(labels)} labels.")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}.")
    names = class_names if names is None else names
    palette = colors if palette is None else palette

    # shape[:2] also works for single-channel images, unlike `h, w, _ = image.shape`.
    h, w = image.shape[:2]

    # With alpha < 1 the fills are drawn on a copy and blended back afterwards.
    canvas = image if alpha >= 1.0 else image.copy()
    annotations = []
    for polygon, label in zip(polygons, labels):
        points = _denormalize_polygon(polygon, w, h)
        if len(points) == 0:
            continue  # class id with no coordinates: nothing to draw
        class_idx = int(label)
        # Plain Python floats: some OpenCV versions reject NumPy integer
        # scalars as color components.
        color = tuple(float(c) for c in palette[class_idx])
        cv2.fillPoly(canvas, [points], color=color)
        annotations.append((points, color, names[class_idx]))

    if annotations and canvas is not image:
        image[:] = cv2.addWeighted(canvas, alpha, image, 1.0 - alpha, 0)

    # Outlines and text are drawn after every fill so an overlapping mask
    # cannot paint over another object's label.
    font_scale = 0.5
    font_thickness = 1
    for points, color, class_name in annotations:
        cv2.polylines(image, [points], isClosed=True, color=color, thickness=2)
        centroid_x = int(np.mean(points[:, 0, 0]))
        centroid_y = int(np.mean(points[:, 0, 1]))
        cv2.putText(image, class_name, (centroid_x, centroid_y - 10), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, (255, 255, 255), font_thickness)

    return image

In [3]:
# Functions to plot images with segmentation masks
def _pair_by_stem(image_files, label_files):
    """Return ``(image, label)`` path pairs whose file names match without extensions."""
    labels_by_stem = {os.path.splitext(os.path.basename(p))[0]: p for p in label_files}
    pairs = []
    for image_file in image_files:
        stem = os.path.splitext(os.path.basename(image_file))[0]
        if stem in labels_by_stem:
            pairs.append((image_file, labels_by_stem[stem]))
    return pairs


def _read_polygon_labels(label_file):
    """Parse a label file into parallel lists of polygons (coordinate strings) and int class ids.

    Each non-blank line is ``<class_id> x1 y1 x2 y2 ...``.
    """
    polygons, labels = [], []
    with open(label_file, 'r') as f:
        for line in f:
            elements = line.split()
            if not elements:
                continue  # tolerate blank lines instead of crashing on elements[0]
            labels.append(int(elements[0]))
            polygons.append(elements[1:])
    return polygons, labels


def plot(image_paths, label_paths, num_samples, figsize=(15, 12)):
    """Show randomly chosen images with their polygon labels drawn on, in a grid.

    Images and label files are paired by file stem (``a.jpg`` <-> ``a.txt``)
    rather than by position in two independently sorted lists, so a missing
    or extra file cannot shift every later pairing.

    Args:
        image_paths: glob pattern matching the image files.
        label_paths: glob pattern matching the label files. Each line is
            ``<class_id> x1 y1 x2 y2 ...`` with coordinates normalized by the
            image size.
        num_samples: number of images to show. They are drawn without
            replacement, capped at the number of image/label pairs.
        figsize: matplotlib figure size in inches.

    Returns:
        None. Prints a message and returns early when nothing can be shown.
    """
    image_files = sorted(glob.glob(image_paths))
    label_files = sorted(glob.glob(label_paths))

    if not image_files or not label_files:
        print("Error: No images or labels found. Check the paths.")
        return

    pairs = _pair_by_stem(image_files, label_files)
    if not pairs:
        print("Error: No image has a label file with a matching name.")
        return
    unmatched = len(image_files) - len(pairs)
    if unmatched:
        print(f"Warning: {unmatched} image(s) have no matching label file and were skipped.")

    num_shown = min(num_samples, len(pairs))
    if num_shown <= 0:
        return
    chosen = random.sample(pairs, num_shown)

    # Near-square grid sized to the sample count (4 samples -> 2x2).
    ncols = math.ceil(math.sqrt(num_shown))
    nrows = math.ceil(num_shown / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    for ax, (image_file, label_file) in zip(axes.flat, chosen):
        image = cv2.imread(image_file)
        if image is None:
            print(f"Error: Could not read image {image_file}")
            continue
        polygons, labels = _read_polygon_labels(label_file)
        result_image = plot_segmentation(image, polygons, labels)
        ax.imshow(result_image[:, :, ::-1])  # cv2 loads BGR; matplotlib expects RGB

    fig.tight_layout()
    plt.show()

In [None]:
# Dataset root: set the DATA_DIR environment variable to run on another machine.
DATA_DIR = os.environ.get('DATA_DIR', '/home/users/xt37/zz/data')

plot(
    image_paths=os.path.join(DATA_DIR, 'train', 'images', '*'),
    label_paths=os.path.join(DATA_DIR, 'train', 'labels', '*'),
    num_samples=4,
)