In [1]:
import os
import torch
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import SAM

In [None]:
# Root of the Open Images download. Expected layout: one sub-folder per class, each with
# an `images/` folder and a sibling `labels/` folder of Darknet-format .txt files.
# Set the OPENIMAGES_DIR environment variable to point the notebook at another copy.
base_dir = os.environ.get("OPENIMAGES_DIR", "../../data/raw/openimages-download-v0")

In [3]:
# Set up SAM2 model
def setup_sam2_model(weights='sam2_b.pt'):
    """Initialize and return an Ultralytics SAM2 model.

    Args:
        weights: Checkpoint name or path passed to ``SAM``. Defaults to the base
            model 'sam2_b.pt'; use e.g. 'sam2_l.pt' for the larger model.

    Returns:
        The constructed ``SAM`` model.
    """
    # Per the original note, Ultralytics downloads the named checkpoint if needed.
    return SAM(weights)

model = setup_sam2_model()
print("SAM2 model loaded successfully!")

SAM2 model loaded successfully!


In [4]:
def read_darknet_bboxes(bbox_path, image_width, image_height):
	"""Read bounding boxes from a Darknet-format label file as pixel coordinates.

	Each non-blank line must be ``class_id x_center y_center width height`` where the
	four box values are normalized by the image size. Blank lines are ignored.

	Args:
		bbox_path: Path to the label file.
		image_width: Image width in pixels.
		image_height: Image height in pixels.

	Returns:
		List of ``[x1, y1, x2, y2]`` integer boxes clamped to the image bounds.
		The class id is parsed (so malformed lines fail) but not returned.

	Raises:
		AssertionError: If a non-blank line does not have exactly 5 fields.
		ValueError: If a field cannot be parsed as a number.
	"""
	bboxes = []

	with open(bbox_path, 'r') as f:
		for line in f:
			parts = line.split()
			if not parts:
				# Tolerate blank lines, e.g. an empty line at the end of the file.
				continue
			assert len(parts) == 5, f"Invalid bbox line: {line.strip()}"

			# Darknet format: class_id x_center y_center width height (normalized)
			class_id = int(parts[0])  # parsed only to validate the line; not used further
			x_center, y_center, width, height = (float(value) for value in parts[1:])

			# Convert from normalized coordinates to pixel coordinates
			x_center_px = x_center * image_width
			y_center_px = y_center * image_height
			width_px = width * image_width
			height_px = height * image_height

			# Convert to x1, y1, x2, y2 format
			x1 = int(x_center_px - width_px / 2)
			y1 = int(y_center_px - height_px / 2)
			x2 = int(x_center_px + width_px / 2)
			y2 = int(y_center_px + height_px / 2)

			# Ensure coordinates are within image bounds
			x1 = max(0, min(x1, image_width - 1))
			y1 = max(0, min(y1, image_height - 1))
			x2 = max(0, min(x2, image_width - 1))
			y2 = max(0, min(y2, image_height - 1))

			bboxes.append([x1, y1, x2, y2])

	return bboxes

def generate_automatic_masks(image_path, label_path, model):
	"""Segment the labelled objects in an image by prompting the model with its boxes.

	Args:
		image_path: Path to the image file.
		label_path: Path to the Darknet-format label file for the image.
		model: Model called as ``model(str(image_path), bboxes=...)`` whose first
			result exposes ``masks.data`` (e.g. an Ultralytics SAM model).

	Returns:
		A ``uint8`` array of shape (height, width): 255 wherever any predicted
		mask covers the pixel, 0 elsewhere.

	Raises:
		FileNotFoundError: If the image cannot be read.
		AssertionError: If the label file yields no boxes or the model returns no masks.
	"""
	# Load the image only for its dimensions (the label boxes are normalized).
	image = cv2.imread(str(image_path))
	if image is None:
		# cv2.imread signals a missing/unreadable file by returning None instead of raising.
		raise FileNotFoundError(f"Could not read image: {image_path}")
	height, width = image.shape[:2]

	# Look for corresponding bboxes from label file
	bboxes = read_darknet_bboxes(label_path, width, height)
	assert bboxes, "No bounding boxes found in the file: {}".format(label_path)

	results = model(str(image_path), bboxes=bboxes)

	assert (len(results) != 0) and (results[0].masks is not None), f"No masks generated for {image_path}"

	# Get all masks from the result
	masks = results[0].masks.data.cpu().numpy()

	# Union of all per-box masks
	combined_mask = np.zeros((height, width), dtype=bool)
	for mask in masks:
		# Threshold before casting so soft (probability) masks are not truncated to 0.
		binary = (mask > 0.5).astype(np.uint8)
		if binary.shape != (height, width):
			# Nearest-neighbour interpolation keeps the resized mask strictly binary.
			binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
		combined_mask |= binary.astype(bool)

	return combined_mask.astype(np.uint8) * 255  # Convert to 0-255 range

In [None]:
"""Process all images in the directory structure and generate masks"""
base_path = Path(base_dir)

assert base_path.exists(),"Directory {base_dir} does not exist!"

processed_count = 0
error_count = 0

# Walk through all class directories
for subdir in base_path.iterdir():
	if not subdir.is_dir():
		continue

	images_dir = subdir / "images"
	if not images_dir.exists():
		continue

	print(f"Processing subfolder: {subdir.name}")

	# Process all images in the class directory
	image_files = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.jpeg")) + list(images_dir.glob("*.png"))
	
	for image_path in image_files:
		# Skip if this is already a mask file
		if "_mask" in image_path.stem:
			continue
		
		print(f"  Processing: {image_path.name}")
		
		# Generate mask
		label_path = str(image_path).replace('/images/', '/labels/').replace('jpg','txt')
		mask = generate_automatic_masks(image_path, label_path, model)

		if mask is not None:
			# Create mask filename
			mask_filename = f"{image_path.stem}_mask{image_path.suffix}"
			mask_path = images_dir / mask_filename
			
			# Save mask
			cv2.imwrite(str(mask_path), mask)
			processed_count += 1
			print(f"    Saved mask: {mask_filename}")
		else:
			print(f"    Failed to generate mask for {image_path.name}")
			error_count += 1

print(f"\nProcessing complete!")
print(f"Successfully processed: {processed_count} images")
print(f"Errors encountered: {error_count} images")

Processing subfolder: screwdriver
  Processing: 8abe387edad1ab14.jpg

image 1/1 /home/vikhyat/RIPS25-AnalogDevices-ObjectDetection/src/Cut-and-Paste/../../data/raw/openimages-download-v0/screwdriver/images/8abe387edad1ab14.jpg: 1024x1024 1 0, 566.7ms
Speed: 19.1ms preprocess, 566.7ms inference, 35.1ms postprocess per image at shape (1, 3, 1024, 1024)
    Saved mask: 8abe387edad1ab14_mask.jpg
  Processing: 1f10b3e18c251566.jpg

image 1/1 /home/vikhyat/RIPS25-AnalogDevices-ObjectDetection/src/Cut-and-Paste/../../data/raw/openimages-download-v0/screwdriver/images/1f10b3e18c251566.jpg: 1024x1024 1 0, 1 1, 1 2, 122.1ms
Speed: 5.8ms preprocess, 122.1ms inference, 0.7ms postprocess per image at shape (1, 3, 1024, 1024)
    Saved mask: 1f10b3e18c251566_mask.jpg
  Processing: 38cf55592ebb90d7.jpg

image 1/1 /home/vikhyat/RIPS25-AnalogDevices-ObjectDetection/src/Cut-and-Paste/../../data/raw/openimages-download-v0/screwdriver/images/38cf55592ebb90d7.jpg: 1024x1024 1 0, 99.2ms
Speed: 4.1ms preproc

In [6]:
# Optional: Visualize some results
def visualize_segmentation_results(base_dir, class_name, num_samples=3):
    """Plot original images side by side with their generated masks.

    Only ``*.jpg`` originals are considered. The mask for ``name.jpg`` is looked up
    as ``name_mask.jpg`` in the same folder, matching the generation cell's naming.

    Args:
        base_dir: Dataset root containing one sub-folder per class.
        class_name: Class sub-folder whose ``images/`` directory is shown.
        num_samples: Maximum number of image/mask pairs to plot.

    Returns:
        None. Prints a message and returns early if the folder is missing or holds
        no images.
    """
    images_dir = Path(base_dir) / class_name / "images"

    if not images_dir.exists():
        print(f"Directory {images_dir} does not exist!")
        return

    # Sorted so the same samples are shown on every run; skip generated mask files.
    candidates = sorted(f for f in images_dir.glob("*.jpg") if "_mask" not in f.stem)
    image_files = candidates[:max(num_samples, 0)]
    if not image_files:
        print(f"No images found in {images_dir}")
        return

    # Size the grid by the images actually found; squeeze=False keeps `axes` 2-D
    # even for a single row.
    n_rows = len(image_files)
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4 * n_rows), squeeze=False)

    for (ax_image, ax_mask), image_path in zip(axes, image_files):
        ax_image.axis('off')
        ax_mask.axis('off')

        original = cv2.imread(str(image_path))
        if original is None:
            # cv2.imread returns None for unreadable files.
            print(f"Could not read {image_path.name}")
            continue
        ax_image.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
        ax_image.set_title(f"Original: {image_path.name}")

        mask_path = images_dir / f"{image_path.stem}_mask{image_path.suffix}"
        if not mask_path.exists():
            print(f"Mask not found for {image_path.name}")
            continue
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        ax_mask.imshow(mask, cmap='gray')
        ax_mask.set_title(f"Mask: {mask_path.name}")

    fig.tight_layout()
    plt.show()

# Example: Visualize results for a specific class (replace with actual class name)
# visualize_segmentation_results("../../data/raw/openimages-download-v0", "your_class_name", 3)

In [7]:
# Optional: Check processing status
# Count, per class, how many original images already have a generated mask.
base_path = Path(base_dir)

# Must match the extensions handled by the mask-generation cell.
IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

total_images = 0
total_masks = 0

for class_dir in sorted(base_path.iterdir()):
	if not class_dir.is_dir():
		continue

	images_dir = class_dir / "images"
	if not images_dir.exists():
		continue

	# Count original images (non-mask files) of every supported extension
	image_files = [
		f
		for suffix in IMAGE_SUFFIXES
		for f in images_dir.glob(f"*{suffix}")
		if "_mask" not in f.stem
	]

	# An image is done when its mask exists under the generation cell's naming scheme;
	# this ignores orphan mask files whose original image was removed.
	class_images = len(image_files)
	class_masks = sum((images_dir / f"{f.stem}_mask{f.suffix}").exists() for f in image_files)

	total_images += class_images
	total_masks += class_masks

	print(f"{class_dir.name}: {class_masks}/{class_images} masks generated")

# Guard against an empty dataset, which would otherwise divide by zero.
coverage = total_masks / total_images * 100 if total_images else 0.0
print(f"\nOverall: {total_masks}/{total_images} masks generated ({coverage:.1f}%)")

screwdriver: 37/37 masks generated
wrench: 31/31 masks generated
power plugs and sockets: 50/50 masks generated
door handle: 50/50 masks generated
hammer: 50/50 masks generated

Overall: 218/218 masks generated (100.0%)
