In [10]:
import os
from ultralytics import SAM
import matplotlib.pyplot as plt
import glob
import cv2
import numpy as np
from PIL import Image

In [57]:
def xywh_to_bbox(x, y, w, h):
    """
    Convert (x, y, w, h) to (x1, y1, x2, y2) format.
    """
    x1 = x
    y1 = y
    x2 = x + w
    y2 = y + h
    return [x1, y1, x2, y2]

def segment_images_from_folder_bbox(folder_path, object_type):
    """
    Segments images in the specified folder using the SAM model with bbox information.
    Assumes folder_path contains two folders: 'images' and 'darknet'.
    Each image in 'images' should have a corresponding label file in 'darknet' with
    bounding box information in the format: x y w h (where x, y are the
    top-left corner coordinates and w, h are the width and height of the bounding box).
    """
    model = SAM("sam2.1_b.pt")
    
    for image_path, bbox_path in zip(sorted(glob.glob(os.path.join(folder_path, 'images', '*.jpg'))), 
                                     sorted(glob.glob(os.path.join(folder_path, 'darknet', '*.txt')))):
        x, y, w, h = 0, 0, 0, 0
        image_dimensions = plt.imread(image_path).shape
        with open(bbox_path, 'r') as f:
            line = f.readline().strip()
            if line:
                x, y, w, h = map(float, line.split()[1:5])  # Skip the class label (first element)
                x, y, w, h = int(x * image_dimensions[1]), int(y * image_dimensions[0]), int(w * image_dimensions[1]), int(h * image_dimensions[0])
                x = x - w // 2
                y = y - h // 2
        os.mkdir(f'data/{object_type}_images_sam2_mask') if not os.path.exists(f'data/{object_type}_images_sam2_mask') else None
        
        # Predict segmentation using the SAM model with bounding box
        results = model(image_path, bboxes=xywh_to_bbox(x, y, w, h))
        for result in results:
            # Display the image with the segmentation mask
            # result.show()

            # Access the masks
            masks = results[0].masks

            # Assuming single class segmentation for simplicity, adjust as needed
            mask = masks[0].data.squeeze().cpu().numpy()  # For multi-class, iterate over masks
            mask = mask.astype(np.uint8) # Convert mask to uint8 if needed)
            mask = cv2.resize(mask, (image_dimensions[1], image_dimensions[0]))
            
            image = cv2.imread(image_path)
            image = cv2.resize(image, (image_dimensions[1], image_dimensions[0]))
            
            # Negate the mask and mask the image
            negative_mask = 1-mask
            print(np.unique(mask))
            negative_image = cv2.bitwise_not(image)
            negative_image = cv2.bitwise_and(negative_image, negative_image, mask=mask)
            masked_image = cv2.bitwise_not(negative_image)

            # negative_mask = negative_mask[y: y + h, x: x + w]  # Crop the mask to the bounding box
            # image = image[y: y + h, x: x + w] # Crop the masked image to the bounding box

            # Save the masked image
            output_dir = f'data/{object_type}_images_sam2_mask'
            if not os.path.exists(output_dir):
                os.makedirs(output_dir, exist_ok=True)
            
            # plt.imshow(negative_mask, cmap='gray')
            # plt.axis('off')
            # plt.show()

            # cv2.imwrite(os.path.join(output_dir, os.path.basename(image_path)), image)
            # cv2.imwrite(os.path.join(output_dir, os.path.basename(image_path).split('.')[0] + '_masked.png'), masked_image)
            cv2.imwrite(os.path.join(output_dir, os.path.basename(image_path).split('.')[0] + '.pbm'), negative_mask*255)
            # img = Image.open(os.path.join(output_dir, os.path.basename(image_path).split('.')[0] + '_mask.jpg'))
            # plt.imshow(img, cmap='gray')
            # plt.axis('off')
            # plt.show()

            # # Plot the masked image
            # plt.imshow(cv2.cvtColor(masked_image, cv2.COLOR_BGR2RGB))
            # plt.axis('off')
            # plt.show()

if __name__ == "__main__":
    folder_path = 'download/croissant'
    segment_images_from_folder_bbox(folder_path, 'croissant')


image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/download/croissant/images/007762c8816ddf93.jpg: 1024x1024 1 0, 115.2ms
Speed: 2.7ms preprocess, 115.2ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)
[0 1]

image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/download/croissant/images/00b6c796919ea125.jpg: 1024x1024 1 0, 100.0ms
Speed: 2.7ms preprocess, 100.0ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)
[0 1]

image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/download/croissant/images/00c3a38d7a8c908b.jpg: 1024x1024 1 0, 99.7ms
Speed: 2.6ms preprocess, 99.7ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)
[0 1]

image 1/1 /home/decla_5ay7wb/RIPS25-AnalogDevices-ObjectDetection/download/croissant/images/018d3b8f75ec8e49.jpg: 1024x1024 1 0, 99.8ms
Speed: 2.6ms preprocess, 99.8ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)
[0 1]

image 1/1 /home/decla_5ay7w

In [None]:
for pbm_path in glob.glob('data/croissant_images_sam2_mask/*_mask.jpg'):
    mask = cv2.imread(pbm_path)
    plt.imshow(mask*255, cmap='gray')
    plt.axis('off')
    plt.show()
    # cv2.imwrite(os.path.join('data/cut_and_paste_root/croissant', os.path.basename(pbm_path).split('.')[0]+'.pbm'), mask)

In [61]:
for imgpath in glob.glob('data/croissant_images_sam2_mask/*.pbm'):
    basenames = [os.path.basename(f) for f in glob.glob('data/cut_and_paste_root/croissant/*')]
    if os.path.basename(imgpath).split('.')[0]+'.jpg' in basenames:
        os.system(f'cp {imgpath} data/cut_and_paste_root/croissant/{os.path.basename(imgpath).split(".")[0]}.pbm')

In [62]:
os.system('python Cut-and-Paste/dataset_generator.py --scale --rotation --num 5 data/cut_and_paste_root data/cut_and_paste_croissants')

Number of background images : 8124
List of distractor files collected: []
Working on data/cut_and_paste_croissants/images/1_none.jpg
Working on data/cut_and_paste_croissants/images/2_none.jpg
Working on data/cut_and_paste_croissants/images/3_none.jpg
Working on data/cut_and_paste_croissants/images/5_none.jpg
Working on data/cut_and_paste_croissants/images/6_none.jpg
Working on data/cut_and_paste_croissants/images/4_none.jpg
Working on data/cut_and_paste_croissants/images/7_none.jpg
Working on data/cut_and_paste_croissants/images/8_none.jpg
Working on data/cut_and_paste_croissants/images/9_none.jpg
Working on data/cut_and_paste_croissants/images/11_none.jpg
Working on data/cut_and_paste_croissants/images/12_none.jpg
Working on data/cut_and_paste_croissants/images/13_none.jpg
Working on data/cut_and_paste_croissants/images/10_none.jpg
Working on data/cut_and_paste_croissants/images/14_none.jpg
Working on data/cut_and_paste_croissants/images/15_none.jpg
Working on data/cut_and_paste_crois

0

In [66]:
# Get image size in bytes
print(cv2.imread('data/cut_and_paste_croissants/images/8_poisson.jpg').shape)

(1920, 2560, 3)
