In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from skimage import io, segmentation, color
import networkx as nx
import cv2

In [2]:
def load_image(filepath):
    image_path = filepath
    image = io.imread(image_path)

    # If the image has 4 channels (RGBA), remove the alpha channel
    if image.shape[-1] == 4: image = image[..., :3]
        
    return image

In [3]:
def show_images(images):
    n = len(images)

    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5))
    if n == 1: axes = [axes]

    for i, ax in enumerate(axes):
        if len(images[i].shape) == 3: ax.imshow(images[i])
        else: ax.imshow(images[i], cmap='gray')
        
        ax.axis('off')

    plt.show()

In [4]:
import cv2
import numpy as np
import torch

# Load YOLOv5 model (keeping it as an option for detection)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

def isolate_objects(image):
    """
    Isolate objects in an image by detecting them using YOLO and removing the background using
    color thresholding and contour-based segmentation.
    
    Args:
    - image: np.array, input image (RGB).
    
    Returns:
    - isolated_image: np.array, image with objects isolated and background removed.
    """
    
    # Convert the image to the HSV color space for better color segmentation
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    
    # Define color thresholds (adjust these based on actual colors)
    lower_color = np.array([0, 50, 50])
    upper_color = np.array([180, 255, 255])

    # Create a binary mask where the detected colors are in the range and the background is black
    mask = cv2.inRange(hsv, lower_color, upper_color)
    
    # Apply morphological operations to clean up the mask
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    
    # Find contours based on the mask
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Create a blank mask to draw the contours
    object_mask = np.zeros_like(mask)

    # Draw filled contours for each detected object on the mask
    cv2.drawContours(object_mask, contours, -1, 255, thickness=cv2.FILLED)
    
    # Create a 3-channel mask from the single-channel mask
    object_mask_3channel = cv2.merge([object_mask, object_mask, object_mask])
    
    # Apply the mask to the image to remove the background
    isolated_image = cv2.bitwise_and(image, object_mask_3channel)

    # Check the isolated image
    cv2.imshow('Isolated Image', isolated_image)
    cv2.waitKey(0)

    return isolated_image


Using cache found in /home/theo/.cache/torch/hub/ultralytics_yolov5_master
YOLOv5 🚀 2024-10-11 Python-3.12.6 torch-2.4.1+cu121 CPU

Fusing layers... 
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
Adding AutoShape... 


In [None]:
# Load the image (replace 'path_to_image.jpg' with your actual image path)
image = cv2.imread('../COMP90086_2024_Project_train/train/173.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB

# Isolate objects and remove the background
isolated_image = isolate_objects(image)

show_images([isolated_image])

for i in isolated_image: print(i)

# Display the original and isolated images
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.title('Original Image')
plt.imshow(image)

plt.subplot(1, 2, 2)
plt.title('Isolated Objects with Black Background')
plt.imshow(isolated_image)

plt.show()



In [8]:
img = load_image('../COMP90086_2024_Project_train/train/173.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
isolated_img = isolate_objects(img)
print(isolated_img)
show_images([isolated_img])

[[[ 68  52  35]
  [ 68  52  35]
  [ 68  52  35]
  ...
  [ 68  52  35]
  [ 68  52  35]
  [ 68  52  35]]

 [[ 68  52  35]
  [ 68  52  35]
  [ 68  52  35]
  ...
  [ 68  52  35]
  [ 68  52  35]
  [ 68  52  35]]

 [[ 68  52  35]
  [ 68  52  35]
  [ 68  52  35]
  ...
  [ 68  52  35]
  [ 68  52  35]
  [ 68  52  35]]

 ...

 [[176 186 216]
  [ 99 110 140]
  [ 67  78 108]
  ...
  [ 86 101 127]
  [ 61  75 103]
  [142 156 184]]

 [[198 209 241]
  [142 155 187]
  [134 147 179]
  ...
  [103 119 142]
  [136 151 177]
  [152 169 195]]

 [[185 198 230]
  [164 177 209]
  [160 175 208]
  ...
  [ 64  80 103]
  [ 74  91 117]
  [108 125 151]]]
