# In [None]:
# Import required libraries
import os
import sys
import numpy as np
import skimage.io
import matplotlib.pyplot as plt

# Path to the local Mask R-CNN checkout, resolved against the current working
# directory (adjust if the checkout lives elsewhere).
ROOT_DIR = os.path.abspath("Mask_RCNN")  # Adjust if necessary
# Make the `mrcnn` package importable. The membership check keeps repeated
# executions of this notebook cell from piling duplicate entries onto sys.path.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Import Mask R-CNN modules
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))
import coco  # COCO dataset configurations

# Model configuration for inference
class InferenceConfig(coco.CocoConfig):
    GPU_COUNT = 1  # Use only 1 GPU
    IMAGES_PER_GPU = 1  # Process one image at a time

config = InferenceConfig()

# Mask R-CNN's log directory (passed as model_dir) and the local copy of the
# COCO-pretrained weights, both kept inside the checkout.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Fetch the COCO weights only when no local file exists yet; later runs
# reuse the file already on disk.
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

# Build the network in inference mode, then load the COCO weights
# (by_name=True) into it.
model = modellib.MaskRCNN(config=config, model_dir=MODEL_DIR, mode="inference")
model.load_weights(COCO_MODEL_PATH, by_name=True)

# COCO class names in model output order: 81 entries, index 0 is the
# background ('BG'). Passed alongside r['class_ids'] to display_instances,
# presumably so class_names[class_id] gives each detection's label.
class_names = [
    'BG', 'person', 'bicycle', 'car', 'motorcycle',  # 0-4
    'airplane', 'bus', 'train', 'truck', 'boat',  # 5-9
    'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',  # 10-14
    'bird', 'cat', 'dog', 'horse', 'sheep',  # 15-19
    'cow', 'elephant', 'bear', 'zebra', 'giraffe',  # 20-24
    'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',  # 25-29
    'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',  # 30-34
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',  # 35-39
    'bottle', 'wine glass', 'cup', 'fork', 'knife',  # 40-44
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',  # 45-49
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',  # 50-54
    'donut', 'cake', 'chair', 'couch', 'potted plant',  # 55-59
    'bed', 'dining table', 'toilet', 'tv', 'laptop',  # 60-64
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',  # 65-69
    'oven', 'toaster', 'sink', 'refrigerator', 'book',  # 70-74
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',  # 75-79
    'toothbrush',  # 80
]

def ensure_rgb(image):
    """Return *image* as an H x W x 3 array.

    skimage.io.imread can return 2-D grayscale, H x W x 1, gray+alpha
    (H x W x 2) or RGBA (H x W x 4) arrays. The COCO weights presumably expect
    3-channel RGB input (mrcnn's own Dataset.load_image performs a similar
    conversion -- confirm against your checkout). Normalize before calling
    model.detect. Already-3-channel input is returned unchanged (same object).
    """
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] in (1, 2):
        # Grayscale (optionally with alpha): replicate the gray channel.
        image = np.repeat(image[..., :1], 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        image = image[..., :3]
    return image


def segment(image, r):
    """Return a copy of *image* that shows only the highest-scoring detection.

    Pixels inside that detection's mask keep their original values; every
    other pixel is set to 255 (white for uint8 images).

    Args:
        image: the image array that was passed to ``model.detect``.
        r: one element of ``model.detect``'s result list. Only ``r['scores']``
            (one score per detection) and ``r['masks']`` (indexed as
            ``[:, :, detection]``) are read.

    Returns:
        An array with the same shape and dtype as *image*.

    Raises:
        ValueError: if *r* contains no detections.
    """
    scores = r['scores']
    if len(scores) == 0:
        raise ValueError("No objects were detected; nothing to segment.")
    best = int(np.argmax(scores))  # most confident detection
    mask = r['masks'][:, :, best].astype(bool)
    # Start from an all-255 background and copy only the masked pixels over.
    # Boolean indexing works for any channel count and preserves the dtype.
    result = np.full_like(image, 255)
    result[mask] = image[mask]
    return result


def show_side_by_side(original, segmented):
    """Plot *original* and *segmented* next to each other, axes hidden."""
    plt.figure(figsize=(16, 8))
    panels = [(original, "Original Image"), (segmented, "Segmented Image")]
    for position, (picture, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis('off')
    plt.show()


# Guarded so the helpers above can be imported (e.g. by tests) without
# triggering the input() prompt. Jupyter/IPython and `python file.py` both run
# with __name__ == "__main__", so the interactive flow still executes there.
if __name__ == "__main__":
    # Paths pasted from a file manager or terminal often carry surrounding
    # whitespace or quotes, which would make the existence check below fail.
    image_path = input("Enter the full path of your image: ").strip().strip('"\'')

    # isfile (not exists) so a directory is rejected here instead of failing
    # inside imread.
    if not os.path.isfile(image_path):
        print("Error: The file does not exist! Please check the path.")
    else:
        image = ensure_rgb(skimage.io.imread(image_path))
        print("Image successfully loaded!")

        # Run detection on a one-image batch and take that image's result.
        results = model.detect([image], verbose=1)
        r = results[0]

        # Display detected objects with masks and bounding boxes.
        visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                                    class_names, r['scores'])

        if len(r['scores']) == 0:
            print("No objects detected; skipping segmentation.")
        else:
            segmented_image = segment(image, r)
            show_side_by_side(image, segmented_image)
