In [1]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python matplotlib torch torchvision pycocotools
!pip install torch --index-url https://download.pytorch.org/whl/cu118


Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-3ayvb8bh
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-3ayvb8bh
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://download.pytorch.org/whl/cu118


In [1]:
from google.colab import drive

# Expose Google Drive inside this runtime; the dataset folder lives under MyDrive.
drive.mount('/content/drive')

# Dataset root on the mounted Drive (the SAM cell defines it again after the restart).
source_dir = "/content/drive/MyDrive/merged_dataset"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import zipfile

# Unpack the dataset archive stored on Drive into the folder the SAM cell reads from.
# Re-running the cell only extracts entries that are missing or whose size differs.
# That keeps "Run All" from rewriting every file over the Drive mount again.
zip_path = "/content/drive/MyDrive/merged_dataset.zip"
extract_to = "/content/drive/MyDrive/merged_dataset"


def _needs_extract(member: zipfile.ZipInfo, root: str) -> bool:
    """Return True if `member` is not already present (with the same size) under `root`."""
    target = os.path.join(root, member.filename)
    if member.is_dir():
        return not os.path.isdir(target)
    # file_size is the uncompressed size recorded in the archive.
    return not (os.path.isfile(target) and os.path.getsize(target) == member.file_size)


# Create target directory if it doesn't exist
os.makedirs(extract_to, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    members = zip_ref.infolist()
    pending = [m for m in members if _needs_extract(m, extract_to)]
    # An empty `pending` list extracts nothing (only members=None means "everything").
    zip_ref.extractall(extract_to, members=pending)

print(f"Extraction complete! ({len(pending)} of {len(members)} entries extracted)")

In [None]:
import os
import signal

# Hard-restart: kill the current kernel process so the notebook front-end has to
# start a fresh kernel.
# NOTE(review): presumably done to release GPU memory and/or pick up packages
# installed in the first cell. Confirm.
# Calling gc.collect() / torch.cuda.empty_cache() beforehand is unnecessary: all
# process memory, including the CUDA context, is released when the process dies.
# WARNING: "Run All" stops here. Run the cells below manually after the restart.
os.kill(os.getpid(), signal.SIGKILL)


In [2]:
import os, sys, json, cv2, torch, numpy as np, urllib.request
from tqdm import tqdm
from collections import defaultdict
from segment_anything import SamPredictor, sam_model_registry

# === CONFIGURATION ===
# All paths are resolved against the dataset root on the mounted Drive.
source_dir = "/content/drive/MyDrive/merged_dataset"
coco_file = "_annotations_all.coco.json"            # input annotations, inside source_dir
output_json = "_annotations_masks_auto.coco.json"   # output annotations, written to source_dir

images_dir, masks_dir = (
    os.path.join(source_dir, sub) for sub in ("images", "auto_masks")
)
os.makedirs(masks_dir, exist_ok=True)  # destination for the per-annotation mask PNGs

# === MODEL SELECTION ===
# Interactive prompt. Any answer other than h/l/b (including an empty one) falls back to vit_h.
model_choice = input("Select SAM model (h=vit_h, l=vit_l, b=vit_b) [default h]: ").strip().lower()
model_type = {"h": "vit_h", "l": "vit_l", "b": "vit_b"}.get(model_choice, "vit_h")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
sam_urls = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}

# Relative path: the checkpoint is cached in the current working directory.
sam_checkpoint = f"sam_{model_type}.pth"


def _download_checkpoint(url: str, dest: str) -> None:
    """Download `url` to `dest` without ever leaving a truncated file at `dest`.

    The data is written to `dest + ".part"` and only renamed into place after the
    transfer completes. An interrupted download (error or KeyboardInterrupt) therefore
    cannot leave a partial checkpoint that the `os.path.exists` guard below would
    later accept as complete.
    """
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download SAM checkpoint (skipped when a complete copy is already cached)
if not os.path.exists(sam_checkpoint):
    print(f"Downloading {model_type} model...")
    _download_checkpoint(sam_urls[model_type], sam_checkpoint)
    print("Done!")

# Load SAM and wrap it in a predictor (set_image once per image, predict once per box).
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Read the source COCO annotation file from the dataset root.
with open(os.path.join(source_dir, coco_file)) as fh:
    coco = json.load(fh)

# Look up image records by id, and bucket annotations under the image they belong to,
# so the segmentation loop embeds each image with SAM only once.
images = {record["id"]: record for record in coco["images"]}
anns = coco["annotations"]

anns_by_image = defaultdict(list)
for annotation in anns:
    anns_by_image[annotation["image_id"]].append(annotation)

# Segmentation loop: prompt SAM with every COCO bbox, save the best mask as a PNG,
# and write its polygons back into the annotation.


def _bbox_to_xyxy(bbox, width, height):
    """Convert a COCO [x, y, w, h] box into an XYXY float array clipped to the image.

    Coordinates stay floats, and SAM accepts them as-is. The old `map(int, bbox)`
    truncated x, y, w and h separately, which could pull the right/bottom edge
    inward by up to 2 px. COCO boxes use continuous coordinates, so x + w may equal
    the image width, and that is the clip limit here.
    """
    x, y, w, h = (float(v) for v in bbox)
    return np.array([
        max(0.0, x),
        max(0.0, y),
        min(x + w, float(width)),
        min(y + h, float(height)),
    ])


def _mask_to_polygons(mask):
    """Return the external contours of a binary uint8 mask as flat COCO polygons.

    Contours with fewer than 3 points (< 6 coordinates) are dropped: they enclose
    no area. Pycocotools would also misread a 4-value list as a bbox.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return [
        coords
        for coords in (contour.flatten().tolist() for contour in contours)
        if len(coords) >= 6
    ]


def _box_polygon(box):
    """Return the XYXY `box` as a single rectangular COCO polygon."""
    x1, y1, x2, y2 = (float(v) for v in box)
    return [[x1, y1, x2, y1, x2, y2, x1, y2]]


print(f"Segmenting {len(anns)} bboxes across {len(anns_by_image)} images...")

skipped_images = 0
box_fallbacks = 0

for image_id, ann_list in tqdm(anns_by_image.items(), desc="images"):
    img_info = images.get(image_id)
    if img_info is None:
        # The annotation references an id that is not in coco["images"].
        print(f"Unknown image_id {image_id}: {len(ann_list)} annotations left unchanged")
        skipped_images += 1
        continue

    img_path = os.path.join(images_dir, img_info["file_name"])
    image = cv2.imread(img_path)  # None for missing or unreadable files
    if image is None:
        print(f"Missing or unreadable: {img_path}")
        skipped_images += 1
        continue

    height, width = image.shape[:2]
    predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # one embedding per image
    stem = os.path.splitext(img_info["file_name"])[0]

    # Process all bboxes in this image
    for ann in ann_list:
        box = _bbox_to_xyxy(ann["bbox"], width, height)
        masks, scores, _ = predictor.predict(box=box, multimask_output=True)
        best_mask = masks[np.argmax(scores)].astype(np.uint8)

        # Save mask as PNG. file_name may contain sub-directories, so create them first.
        # cv2.imwrite reports failure only through its return value, so check it.
        mask_filename = f"{stem}_{ann['id']}.png"
        mask_path = os.path.join(masks_dir, mask_filename)
        os.makedirs(os.path.dirname(mask_path), exist_ok=True)
        if not cv2.imwrite(mask_path, best_mask * 255):
            raise OSError(f"Failed to write mask: {mask_path}")

        polygons = _mask_to_polygons(best_mask)
        if not polygons:
            # An empty segmentation list breaks pycocotools: frPyObjects indexes the
            # first polygon. Fall back to the prompt box as a rectangle.
            polygons = _box_polygon(box)
            box_fallbacks += 1

        # Update annotation
        ann["segmentation"] = polygons
        ann["segmentation_mask"] = mask_filename
        ann["iscrowd"] = 0

print("All masks saved.")
print(f"Skipped images: {skipped_images}; box fallbacks for empty masks: {box_fallbacks}")

# Save updated COCO JSON
with open(os.path.join(source_dir, output_json), "w") as f:
    json.dump(coco, f, indent=2)

print("COCO file updated.")


Select SAM model (h=vit_h, l=vit_l, b=vit_b) [default h]: h
cuda
Segmenting 9831 bboxes across 425 images...


100%|██████████| 425/425 [22:05<00:00,  3.12s/it]


All masks saved.
COCO file updated.
