In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from tqdm import tqdm


In [4]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-08-07 18:50:10--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 2600:9000:2105:9e00:13:6e38:acc0:93a1, 2600:9000:2105:ae00:13:6e38:acc0:93a1, 2600:9000:2105:c000:13:6e38:acc0:93a1, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|2600:9000:2105:9e00:13:6e38:acc0:93a1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: ‘sam_vit_h_4b8939.pth’


2025-08-07 18:51:21 (34.5 MB/s) - ‘sam_vit_h_4b8939.pth’ saved [2564550879/2564550879]



In [2]:
# --- Paths (relative to the notebook's working directory) -------------------
# Folder containing the source images to segment.
IMAGES_DIR = "dataset"
# Folder where one crop per SAM mask is written for labeling.
CROPS_DIR = "crops_for_labeling"
# Folder for YOLO label .txt files.
YOLO_LABELS_DIR = "yolo_labels"
# Path to the SAM ViT-H checkpoint weights.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"

In [3]:
# Create every output folder up front; exist_ok=True keeps this cell re-runnable.
for output_dir in (CROPS_DIR, YOLO_LABELS_DIR):
    os.makedirs(output_dir, exist_ok=True)

In [5]:
# Use the GPU when PyTorch can see one; on CPU the automatic mask generator
# took roughly 2+ minutes per image in the recorded run.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build SAM ViT-H from the downloaded checkpoint and move it to the chosen device.
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
# Automatic (prompt-free) mask generator with default settings.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
# Segment every image with SAM, save one crop per mask, and write a CSV for labeling.
CROPS_CSV_PATH = "rnd/crops_to_label.csv"  # labeling sheet written at the end of this cell
CSV_COLUMNS = [
    "image_file", "crop_file",
    "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2",
    "img_w", "img_h",
    "label",
]


def _mask_to_bbox(mask):
    """Return the inclusive pixel box ``(x_min, y_min, x_max, y_max)`` of a boolean mask.

    Args:
        mask: 2-D array whose truthy entries mark the object (row = y, column = x).

    Returns:
        A tuple of Python ints, or ``None`` when the mask has no truthy pixels.
    """
    y_idx, x_idx = np.where(mask)
    if x_idx.size == 0:
        return None
    return int(x_idx.min()), int(y_idx.min()), int(x_idx.max()), int(y_idx.max())


crops_data = []

# Filter before tqdm so the progress bar counts only images, and sort because
# os.listdir order is arbitrary (keeps processing/CSV order reproducible).
image_files = sorted(
    name for name in os.listdir(IMAGES_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)

for img_file in tqdm(image_files):
    img_path = os.path.join(IMAGES_DIR, img_file)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports unreadable/corrupt files by returning None, not raising.
        print(f"Skipping unreadable image: {img_path}")
        continue
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(img_rgb)

    img_pil = Image.fromarray(img_rgb)
    W, H = img_pil.size  # PIL's .size is (width, height)

    for i, mask_dict in enumerate(masks):
        bbox = _mask_to_bbox(mask_dict['segmentation'])
        if bbox is None:
            continue
        x_min, y_min, x_max, y_max = bbox

        # PIL crop boxes are (left, upper, right, lower) with right/lower exclusive;
        # +1 keeps the mask's last column and row inside the crop.
        crop = img_pil.crop((x_min, y_min, x_max + 1, y_max + 1))
        crop_filename = f"{os.path.splitext(img_file)[0]}_obj{i+1}.jpg"
        crop.save(os.path.join(CROPS_DIR, crop_filename))

        # Record for labeling. bbox_x2/bbox_y2 are inclusive pixel indices
        # (box width = x2 - x1 + 1, height = y2 - y1 + 1).
        crops_data.append({
            "image_file": img_file,
            "crop_file": crop_filename,
            "bbox_x1": x_min, "bbox_y1": y_min, "bbox_x2": x_max, "bbox_y2": y_max,
            "img_w": W, "img_h": H,
            "label": "",      # <--- Fill this after labeling (manual or with LLM)
        })

# Save all crops to a CSV for fast labeling. Explicit columns keep the header even
# when no crops were produced; makedirs ensures the output folder exists.
os.makedirs(os.path.dirname(CROPS_CSV_PATH) or ".", exist_ok=True)
df = pd.DataFrame(crops_data, columns=CSV_COLUMNS)
df.to_csv(CROPS_CSV_PATH, index=False)
print("Crops ready for labeling! Fill in the 'label' column in crops_to_label.csv.")


100%|██████████| 15/15 [34:35<00:00, 138.34s/it]

Crops ready for labeling! Fill in the 'label' column in crops_to_label.csv.



