In [1]:
import glob
import os
import shutil
import subprocess
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from ultralytics import SAM

In [2]:
# Class folders should be organized as follows (this is `class_folder` below):
#
# ../data/raw/<name>/
# ├── darknet/
# │   ├── 1.txt
# │   └── ...
# └── images/
#     ├── 1.jpg
#     └── ...
#
# Keep the base names identical (1.jpg <-> 1.txt) so images and labels line up.

# Darknet files should contain exactly one non-empty line in the following format:
# class x_center y_center width height
#
# x_center, y_center, width and height are fractions of the image size
# (they are multiplied by the image width/height when the label is loaded).
#
# e.g.
# 0 0.5 0.5 0.2 0.2

# Class to process; it selects the input and output folders below.
name = 'screwdriver'
# Folder holding the 'images' and 'darknet' sub-folders for this class.
class_folder = f'../data/raw/{name}/'
# Folder where the segmentation step writes the image copies and their masks.
output_dir = f'../data/raw/{name}_masked'


In [3]:
def xywh_to_bbox(x, y, w, h):
    """
    Turn a corner-plus-size box (x, y, w, h) into corner form [x1, y1, x2, y2].

    (x, y) is returned unchanged as the first corner; the second corner is
    obtained by adding the width to x and the height to y.
    """
    right = x + w
    bottom = y + h
    return [x, y, right, bottom]

def _read_single_darknet_box(bbox_path, img_w, img_h):
    """
    Read the one bounding box stored in a darknet label file.

    The file must contain exactly one non-empty line of the form
    ``<class> x_center y_center width height`` where the four numbers are
    fractions of the image width/height. The class token is not used, so it
    may be a name or a numeric id.

    Args:
        bbox_path: Path of the label file.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``(x, y, w, h)`` in integer pixels with (x, y) the top-left corner,
        or ``None`` (after printing the reason) if the file is unusable.
    """
    # Read inside a context manager so the handle is always closed, and keep
    # only non-empty lines so a leading blank line cannot hide the label.
    with open(bbox_path, 'r') as f:
        lines = [ln.strip() for ln in f if ln.strip()]

    if len(lines) != 1:
        print(f"There should be exactly one line in {bbox_path}.")
        return None

    fields = lines[0].split()
    try:
        # fields[0] is the class token; the next four are the normalized box.
        x_c, y_c, w, h = map(float, fields[1:5])
    except ValueError:
        # Raised both for non-numeric values and for too few fields.
        print(f"Could not parse a bounding box from {bbox_path}.")
        return None

    # Scale to pixels, then shift from the box centre to its top-left corner.
    w_px, h_px = int(w * img_w), int(h * img_h)
    x_px = int(x_c * img_w) - w_px // 2
    y_px = int(y_c * img_h) - h_px // 2
    return x_px, y_px, w_px, h_px


def segment_images_from_folder_bbox(folder_path, out_dir=None, save_masked=False):
    """
    Segment every image in a class folder with SAM, prompted by its darknet box.

    ``folder_path`` must contain two sub-folders, 'images' (``*.jpg``) and
    'darknet' (``*.txt``). Each image is matched to the label file with the
    same base name. A label file holds exactly one non-empty line,
    ``<class> x_center y_center width height``, with the box given as
    fractions of the image size and (x_center, y_center) the box centre.

    For every image that has a usable label and a predicted mask, two files
    are written to the output folder:

    * a copy of the image under its original file name, and
    * ``<stem>_mask.png``, the inverted mask (object = 0, background = 255).

    Images with a missing/invalid label, an unreadable file or no predicted
    mask are reported with ``print`` and skipped.

    Args:
        folder_path: Class folder containing 'images' and 'darknet'.
        out_dir: Folder to write results to. Defaults to the notebook-level
            ``output_dir`` variable when left as ``None``.
        save_masked: If true, also write ``<stem>_masked.jpg``, the image with
            everything outside the mask painted white.
    """
    if out_dir is None:
        out_dir = output_dir  # notebook-level default from the config cell
    # Create the folder once up front; makedirs also handles nested paths.
    os.makedirs(out_dir, exist_ok=True)

    model = SAM("sam2.1_b.pt")

    image_paths = sorted(glob.glob(os.path.join(folder_path, 'images', '*.jpg')))
    for image_path in image_paths:
        file_name = os.path.basename(image_path)
        stem = os.path.splitext(file_name)[0]

        # Pair by base name rather than by list position, so one missing
        # file cannot shift every following image onto the wrong label.
        bbox_path = os.path.join(folder_path, 'darknet', stem + '.txt')
        if not os.path.isfile(bbox_path):
            print(f"No label file found for {image_path}.")
            continue

        # Decode once; the same array supplies the dimensions and is what
        # gets written out, so the mask always matches the saved image.
        image = cv2.imread(image_path)
        if image is None:
            print(f"Could not read {image_path}.")
            continue
        img_h, img_w = image.shape[:2]

        box = _read_single_darknet_box(bbox_path, img_w, img_h)
        if box is None:
            continue

        # Predict segmentation using the SAM model with the bounding box prompt
        results = model(image_path, bboxes=xywh_to_bbox(*box))
        for result in results:
            masks = result.masks
            if masks is None or len(masks) == 0:
                print(f"No mask predicted for {image_path}.")
                continue

            # Single box prompt -> use the first mask, binarised to 0/1.
            mask = masks[0].data.squeeze().cpu().numpy()
            mask = (mask > 0).astype(np.uint8)
            if mask.shape[:2] != (img_h, img_w):
                # Nearest-neighbour keeps the mask strictly binary.
                mask = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)

            # Inverted mask: background = 1, object = 0.
            negative_mask = 1 - mask

            cv2.imwrite(os.path.join(out_dir, file_name), image)
            cv2.imwrite(os.path.join(out_dir, stem + '_mask.png'), negative_mask * 255)

            if save_masked:
                # Keep the object pixels and paint the background white.
                masked_image = image.copy()
                masked_image[mask == 0] = 255
                cv2.imwrite(os.path.join(out_dir, stem + '_masked.jpg'), masked_image)

In [15]:
# Run the SAM segmentation over the images and darknet labels found in `class_folder`.
segment_images_from_folder_bbox(class_folder)


image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/src/../data/raw/screwdriver/images/000872c905a57b65.jpg: 1024x1024 1 0, 116.5ms
Speed: 2.5ms preprocess, 116.5ms inference, 0.6ms postprocess per image at shape (1, 3, 1024, 1024)

image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/src/../data/raw/screwdriver/images/00d9a083e1cba929.jpg: 1024x1024 1 0, 99.3ms
Speed: 2.5ms preprocess, 99.3ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)

image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/src/../data/raw/screwdriver/images/01e80d7cf80f3921.jpg: 1024x1024 1 0, 98.3ms
Speed: 2.6ms preprocess, 98.3ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)

image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/src/../data/raw/screwdriver/images/02a90407e6bf3ef8.jpg: 1024x1024 1 0, 99.7ms
Speed: 2.5ms preprocess, 99.7ms inference, 0.6ms postprocess per image at shape (1, 3, 1024, 1024)

image 1/1 /home/d

In [None]:
# Copy the images and masks from `output_dir` into the per-class folder under
# ../data/cut_and_paste_root, which is passed to the dataset generator below.
cut_and_paste_dir = f"../data/cut_and_paste_root/{name}"
os.makedirs(cut_and_paste_dir, exist_ok=True)

# Pure-Python copy instead of shelling out to `cp`: works on any OS and is
# not affected by spaces or shell metacharacters in the paths.
for src_path in sorted(glob.glob(os.path.join(output_dir, '*'))):
    if os.path.isfile(src_path):  # like plain `cp`, do not descend into sub-folders
        shutil.copy(src_path, cut_and_paste_dir)

0

In [6]:
# Run the Cut-and-Paste dataset generator on the prepared root folder and
# write the synthetic dataset to ../data/processed/<name>.
# - sys.executable runs the script with the same interpreter as this kernel.
# - An argument list (no shell) avoids quoting problems with interpolated paths.
# - check=True raises CalledProcessError if the generator exits with an error,
#   instead of silently returning a non-zero status.
subprocess.run(
    [
        sys.executable, 'Cut-and-Paste/dataset_generator.py',
        '--scale', '--rotation',
        '--num', '1',
        '../data/cut_and_paste_root',
        f'../data/processed/{name}',
    ],
    check=True,
);

Number of background images : 8124
List of distractor files collected: []
Working on ../data/processed/screwdriver/images/1_none.jpg
Working on ../data/processed/screwdriver/images/2_none.jpg
Working on ../data/processed/screwdriver/images/3_none.jpg
Working on ../data/processed/screwdriver/images/4_none.jpg


  backgrounds[i] = Image.fromarray(background_array, 'RGB')
  backgrounds[i] = Image.fromarray(background_array, 'RGB')
  backgrounds[i] = Image.fromarray(background_array, 'RGB')
  blurred_img = Image.fromarray(blurred_img, 'RGB')


Working on ../data/processed/screwdriver/images/5_none.jpg


  blurred_img = Image.fromarray(blurred_img, 'RGB')


Working on ../data/processed/screwdriver/images/6_none.jpg


  blurred_img = Image.fromarray(blurred_img, 'RGB')
  backgrounds[i] = Image.fromarray(background_array, 'RGB')


Working on ../data/processed/screwdriver/images/7_none.jpg


  blurred_img = Image.fromarray(blurred_img, 'RGB')


Working on ../data/processed/screwdriver/images/8_none.jpg
Working on ../data/processed/screwdriver/images/9_none.jpg
Working on ../data/processed/screwdriver/images/10_none.jpg
Working on ../data/processed/screwdriver/images/11_none.jpg
Working on ../data/processed/screwdriver/images/12_none.jpg
Working on ../data/processed/screwdriver/images/13_none.jpg
Working on ../data/processed/screwdriver/images/14_none.jpg


0

In [66]:
# Print the array shape of one image on disk, i.e. (height, width, channels) — not its size in bytes
sample_image = cv2.imread('data/cut_and_paste_croissants/images/8_poisson.jpg')
print(sample_image.shape)

(1920, 2560, 3)
