In [3]:
import torch
from torchvision import transforms
import torchvision
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

In [5]:
# Object detector: the small YOLOv5 model from the Ultralytics torch.hub repo.
yolov5_model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

# Instance segmenter: torchvision Mask R-CNN (ResNet-50 FPN backbone).
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
# COCO_V1 is the same COCO checkpoint that `pretrained=True` loaded.
mask_rcnn_model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.COCO_V1
)
# Eval mode: the model returns detections instead of expecting training targets.
mask_rcnn_model.eval()

def load_image(image_path):
    """Read an image from disk with OpenCV.

    Returns:
        (bgr, rgb): the array exactly as decoded by ``cv2.imread`` and an
        RGB-ordered copy of it.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read the path (it returns None).
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        return bgr, cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Image file not found at {image_path}")

def detect_objects(image):
    """Run the YOLOv5 detector on one image and return its detections as NumPy.

    Each returned row is [x1, y1, x2, y2, confidence, class_id] for the first
    (and only) image in the batch.
    """
    first_image_detections = yolov5_model(image).xyxy[0]
    return first_image_detections.cpu().numpy()

def segment_objects(image, boxes, model=None):
    """Segment the object inside each detection box with an instance-segmentation model.

    Args:
        image: HxWx3 uint8 RGB array; converted to a float CHW tensor in [0, 1].
        boxes: iterable of detection rows whose first four entries are
            x1, y1, x2, y2 in pixels (any further columns are ignored).
        model: optional model following the torchvision detection interface
            (takes a batch, returns a list of dicts with a 'masks' tensor).
            Defaults to the module-level ``mask_rcnn_model``.

    Returns:
        A list with one (x1, y1, x2, y2, mask) tuple per box. The integer box is
        clipped to the image, and ``mask`` is a uint8 array of shape
        (y2 - y1, x2 - x1) holding the first (highest-scoring) instance mask
        scaled to 0..255. The mask is all zeros when the clipped box is empty
        or the model detects nothing in the crop.
    """
    if model is None:
        model = mask_rcnn_model
    height, width = image.shape[:2]
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    masks = []
    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        # Clip to the image so the crop and the returned mask always agree in size.
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        mask = np.zeros((max(y2 - y1, 0), max(x2 - x1, 0)), dtype=np.uint8)
        if x2 > x1 and y2 > y1:
            cropped_image = image_tensor[:, y1:y2, x1:x2].unsqueeze(0)
            with torch.no_grad():
                output = model(cropped_image)
            instance_masks = output[0]['masks']
            # The model may find no instance in the crop; indexing [0] would then raise.
            if len(instance_masks) > 0:
                mask = instance_masks[0, 0].mul(255).byte().cpu().numpy()
        masks.append((x1, y1, x2, y2, mask))
    return masks

def draw_boxes_and_masks(image, boxes, masks):
    """Blend instance masks into an image, draw labelled boxes, and show the figure.

    Args:
        image: HxWx3 uint8 RGB array. It is not modified; a blended copy is returned.
        boxes: detection rows [x1, y1, x2, y2, confidence, class_id] (the format
            returned by ``detect_objects``), paired element-wise with ``masks``.
        masks: (x1, y1, x2, y2, mask) tuples as returned by ``segment_objects``.
            These integer coordinates position both the rectangle and the mask.

    Returns:
        uint8 RGB array: ``image`` with every mask pixel above 127 blended with
        50% green (saturating). All other pixels are unchanged.
    """
    pairs = list(zip(boxes, masks))

    # Pass 1: blend each mask into its box region only. addWeighted(roi, 1, overlay, 0.5, 0)
    # leaves pixels where the overlay is zero unchanged, so this matches blending a
    # full-frame overlay without allocating one per box.
    blended = image.copy()
    for _box, (x1, y1, x2, y2, mask) in pairs:
        box_w, box_h = x2 - x1, y2 - y1
        if box_w <= 0 or box_h <= 0 or mask.size == 0:
            continue  # cv2.resize rejects empty sizes; there is nothing to blend
        resized_mask = cv2.resize(mask, (box_w, box_h))
        roi = blended[y1:y2, x1:x2]
        overlay = np.zeros_like(roi)
        overlay[resized_mask > 127] = (0, 255, 0)
        blended[y1:y2, x1:x2] = cv2.addWeighted(roi, 1, overlay, 0.5, 0)

    # Pass 2: draw the final image once, with the boxes and labels on top.
    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(blended)
    for box, (x1, y1, x2, y2, _mask) in pairs:
        label = f'{yolov5_model.names[int(box[5])]} {box[4]:.2f}'
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2)
        ax.add_patch(rect)
        ax.text(x1, y1 - 10, label, color='red', fontsize=12, backgroundcolor="none")
    ax.axis('off')
    plt.show()
    return blended

def main(image_path, output_path='output_image_with_masks.jpg', class_ids=(2, 5, 7)):
    """Detect vehicles, segment them, save the annotated image, and show it.

    Args:
        image_path: path of the input image.
        output_path: file the annotated image is written to (BGR, via cv2.imwrite).
        class_ids: YOLOv5 class indices to keep. The default (2, 5, 7) selects
            car, bus, and truck in the COCO label order used by yolov5s.

    Returns:
        [fig, ax, image_with_masks]: an empty 1x2 figure with its axes array for a
        side-by-side comparison, and the annotated RGB image. The figure is
        already closed in pyplot, so display it by evaluating ``fig`` (e.g. as a
        cell's last expression); ``plt.show()`` will not render it.

    Raises:
        FileNotFoundError: if the input image cannot be read.
        OSError: if the annotated image cannot be written to ``output_path``.
    """
    _, image_rgb = load_image(image_path)
    boxes = detect_objects(image_rgb)
    vehicle_boxes = [box for box in boxes if int(box[5]) in class_ids]
    masks = segment_objects(image_rgb, vehicle_boxes)
    image_with_masks = draw_boxes_and_masks(image_rgb, vehicle_boxes, masks)
    output_image = cv2.cvtColor(image_with_masks, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failures (e.g. a missing directory) by returning False.
    if not cv2.imwrite(output_path, output_image):
        raise OSError(f"Could not write output image to {output_path}")

    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    # Detach the still-empty figure from pyplot. Otherwise Jupyter's inline backend
    # renders a blank canvas at the end of the calling cell and then closes it.
    plt.close(fig)

    return [fig, ax, image_with_masks]

Using cache found in /home/vscode/.cache/torch/hub/ultralytics_yolov5_master


Defaulting to user installation because normal site-packages is not writeable
Collecting ultralytics
  Downloading ultralytics-8.2.17-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.7/40.7 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Collecting scipy>=1.4.1 (from ultralytics)
  Downloading scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Collecting py-cpuinfo (from ultralytics)
  Downloading py_cpuinfo-9.0.0-py3-none-any.whl.metadata (794 bytes)
Collecting thop>=0.1.1 (from ultralytics)
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting seaborn>=0.11.0 (from ultralytics)
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Downloading ultralytics-8.2.17-py3-none-any.whl (757 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

YOLOv5 🚀 2024-5-18 Python-3.11.9 torch-2.2.1+cu121 CPU

Downloading https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.pt to yolov5s.pt...
100%|██████████| 14.1M/14.1M [00:01<00:00, 11.8MB/s]

Fusing layers... 
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
Adding AutoShape... 
Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /home/vscode/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:15<00:00, 11.8MB/s] 


In [6]:
# Run the full detect -> segment -> draw pipeline on the example image.
# `fig` and `ax` are the comparison canvas that the next cell draws on.
image_path = '../data/example/hello.jpg'
pipeline_result = main(image_path)
fig, ax, image_with_masks = pipeline_result

In [7]:
# Side-by-side comparison on the 1x2 figure returned by main().
# load_image raises FileNotFoundError on a missing file instead of failing inside cvtColor.
original_rgb = load_image(image_path)[1]

ax[0].imshow(original_rgb)
ax[0].set_title("Original Image")
ax[0].axis('off')

ax[1].imshow(image_with_masks)
ax[1].set_title("Image with Masks")
ax[1].axis('off')

# The figure was created in an earlier cell, so pyplot may no longer manage it and
# plt.show() would render nothing. Display the Figure object directly instead.
fig