# Import Libraries

In [None]:
import os
import random
from tkinter import Image

import torch
import torchvision

import numpy as np

# Download Pre-trained model

In [None]:
# Load Mask R-CNN (ResNet-50 + FPN backbone) with its default pre-trained weights.
# `weights="DEFAULT"` is the supported way to request them; the older `pretrained`
# argument is deprecated and expected a boolean, not the string 'defaults'.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
# Put the model in inference mode; eval() returns the module, so the result is
# bound to `_` (presumably to keep the notebook cell from echoing it).
_ = model.eval()

# Utilities

In [None]:
def random_color_masks(image, colors):
    """
    Paints a binary object mask with a color picked at random from a palette.

    Args:
        image: (numpy.ndarray): 2-D binary mask where detected objects are 1 (or True).
        colors: (list): Non-empty list of RGB colors, each a sequence of three values.

    Returns:
        numpy.ndarray: uint8 array of shape (H, W, 3); object pixels carry the
            chosen color and every other pixel is 0.

    Raises:
        IndexError: If colors is empty.
    """
    # random.choice always yields a valid element. The previous
    # random.randint(0, len(colors)) is inclusive at both ends and could
    # index one past the end of the palette.
    red, green, blue = random.choice(colors)

    object_pixels = image == 1

    # Build one uint8 channel per color component, filled only where the mask is set
    channels = []
    for value in (red, green, blue):
        channel = np.zeros_like(image, dtype=np.uint8)
        channel[object_pixels] = value
        channels.append(channel)

    # Stack the channels to create a colored mask
    return np.stack(channels, axis=2)

def _load_rgb_image(img_path):
    """
    Reads an image from disk as an RGB uint8 array.

    Args:
        img_path: (str): Path to the image.

    Returns:
        numpy.ndarray: Image of shape (H, W, 3) in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None rather than raising
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # OpenCV decodes to BGR; the model and matplotlib both expect RGB
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def get_prediction(img_path, threshold):
    """
    Performs object instance segmentation on an image.

    Uses the module-level `model` for inference and `CATEGORY` to map label
    ids to class names.

    Args:
        img_path: (str): Path to the image.
        threshold: (float): Confidence threshold for filtering predictions.

    Returns:
        tuple: (masks, bounding_boxes, classes, scores). All four are empty
            lists when no prediction scores above the threshold. Otherwise
            masks is a boolean array of shape (K, H, W), bounding_boxes is a
            list of [(x1, y1), (x2, y2)] corner pairs, classes is a list of
            names taken from CATEGORY and scores is an array of K confidences.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # Decode with the same loader the visualization uses so that masks and
    # boxes line up with the displayed pixels.
    img = _load_rgb_image(img_path)

    # Convert HWC uint8 image to a CHW float tensor scaled to [0, 1]
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous().float().div(255)

    # Performs inference with the model in no-grad mode to save memory.
    # torchvision detection models take a list of images, not a bare tensor.
    with torch.no_grad():
        prediction = model([img_tensor])

    # Extract predictions and convert to NumPy arrays
    pred_data = prediction[0]
    scores = pred_data['scores'].detach().cpu().numpy()
    labels = pred_data['labels'].detach().cpu().numpy()
    boxes = pred_data['boxes'].detach().cpu().numpy()
    # Masks are returned as (N, 1, H, W). Squeeze only the channel axis so that
    # a single detection keeps its leading dimension for the indexing below.
    masks = (pred_data['masks'] > 0.5).squeeze(1).detach().cpu().numpy()

    # Boolean selector for predictions above the threshold
    valid_indices = scores > threshold
    if not np.any(valid_indices):
        # Same arity as the normal return so callers can always unpack four values
        return [], [], [], []

    # Select predictions that meet the confidence threshold
    pred_boxes = [[(b[0], b[1]), (b[2], b[3])] for b in boxes[valid_indices]]
    pred_classes = [CATEGORY[i] for i in labels[valid_indices]]
    masks = masks[valid_indices]

    return masks, pred_boxes, pred_classes, scores[valid_indices]

# Perform instance segmentation and visualize the result
def instance_segmentation(img_path, threshold=0.5):
    """
    Performs instance segmentation on an image and displays the results.

    Each detected object gets a randomly colored translucent mask, a bounding
    box and a class label. Relies on the module-level `COLORS` palette and
    `CATEGORY` name list, and shows the result with matplotlib.

    Args:
        img_path: (str): Path to the image.
        threshold: (float): Confidence threshold for filtering predictions.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # Get segmentation predictions (scores are not needed for drawing)
    masks, boxes, pred_cls, _scores = get_prediction(img_path, threshold)

    # Load the image as RGB for drawing and display
    img = _load_rgb_image(img_path)

    # Calculate rectangle and text thickness based on image size
    rect_th = max(round(sum(img.shape) / 2 * 0.003), 2)
    text_th = max(rect_th - 1, 1)
    font_scale = rect_th / 3

    # Iterate over detected objects
    for mask, box, label in zip(masks, boxes, pred_cls):
        # Get bounding box corners as integer pixel coordinates
        (x1, y1), (x2, y2) = box
        p1, p2 = (int(x1), int(y1)), (int(x2), int(y2))

        # Generate a random color mask for the object
        rgb_mask = random_color_masks(mask, COLORS)

        # Blend mask with original image using weighted sum
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)

        # Select bbox color based on object class, wrapping around in case the
        # palette has fewer entries than CATEGORY
        color = COLORS[CATEGORY.index(label) % len(COLORS)]

        # Draw bbox around the detected object
        cv2.rectangle(img, p1, p2, color, thickness=rect_th)

        # Get text size for label display; getTextSize returns ((width, height), baseline)
        (w, h), _baseline = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=text_th)
        outside = p1[1] - h >= 3  # Label goes above the box only if it fits inside the image
        p2 = (p1[0] + w, p1[1] - h - 3) if outside else (p1[0] + w, p1[1] + h + 3)

        # Draw filled rectangle as background for the class label
        cv2.rectangle(img, p1, p2, color, -1, lineType=cv2.LINE_AA)

        # Draw the class label text on top of the background
        cv2.putText(
            img,
            label,
            (p1[0], p1[1] - 5 if outside else p1[1] + h + 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            thickness=text_th + 1,
        )

    # Display the final image with masks and bbox
    plt.figure(figsize=(20, 17))
    plt.imshow(img)
    plt.axis("off")
    plt.show()