In [None]:
import os
import cv2
import json
import torch
import numpy as np
from tqdm import tqdm
from datetime import datetime
from segment_anything import sam_model_registry, SamPredictor
from pycocotools import mask as mask_utils

In [None]:
# Sanity check: show whether PyTorch reports a usable CUDA device in this kernel.
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
# --- Configuration -------------------------------------------------------------------
# Folder that is scanned for input images (.png / .jpg / .jpeg).
image_folder = r"Dataset_resized"

# The checkpoint location is machine-specific, so it can be overridden with the
# SAM_CHECKPOINT_PATH environment variable instead of editing the notebook.
# When the variable is unset the original hard-coded path is used.
checkpoint_path = os.environ.get(
    "SAM_CHECKPOINT_PATH",
    r"C:\Users\poten\Downloads\sam_vit_h_4b8939.pth",
)

# Key into `sam_model_registry`; must match the checkpoint, hence overridable alongside it.
model_type = os.environ.get("SAM_MODEL_TYPE", "vit_h")

# All overlays, per-image JSON files and the resume list are written below this folder.
output_root = "segmentation_outputs"
os.makedirs(output_root, exist_ok=True)

In [None]:
# Build the SAM variant selected by `model_type` from the checkpoint file, place it on
# the GPU when one is available (CPU otherwise), and wrap it in a point-prompt predictor.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device)

predictor = SamPredictor(sam)

In [None]:
# Resume support: `processed_file` holds a JSON list of the image file names that were
# already annotated. Load it into a set when present; start with an empty set otherwise.
processed_file = os.path.join(output_root, "processed_images.json")

processed_images = set()
if os.path.exists(processed_file):
    with open(processed_file, "r") as f:
        processed_images.update(json.load(f))

In [None]:
def binary_mask_to_rle(mask):
    """Encode a mask as a JSON-serialisable run-length encoding.

    The mask is cast to ``uint8``, converted to Fortran (column-major) memory order
    and passed to ``mask_utils.encode``. The ``"counts"`` entry of the result comes
    back as ``bytes``, which ``json.dump`` cannot write, so it is decoded to ``str``.

    Args:
        mask: array-like mask supporting ``.astype`` (presumably a 2-D boolean/0-1
            array for a single object — confirm against the caller).

    Returns:
        The dict produced by ``mask_utils.encode`` with ``"counts"`` as a UTF-8 string.
    """
    fortran_mask = np.asfortranarray(mask.astype(np.uint8))
    encoded = mask_utils.encode(fortran_mask)
    counts_bytes = encoded["counts"]
    encoded["counts"] = counts_bytes.decode("utf-8")
    return encoded


# Running ids written into the per-image COCO JSON files; both start at 1 and are
# incremented by the annotation loop after each saved image.
# NOTE(review): these restart at 1 on every run, so a resumed session reuses ids that
# were already written to JSON files in an earlier session.
annotation_id, image_id = 1, 1

In [None]:
# Interactive annotation loop.
#
# For every not-yet-processed image in `image_folder`:
#   1. show the image and collect left-click points,
#   2. 's' runs the predictor on those points (needs at least one click),
#      'k' skips the image, ESC stops the whole run,
#   3. show the mask overlay; 'r' discards it and restarts the clicks for the same image,
#      any other key accepts it,
#   4. on accept, write the overlay PNG plus a single-image COCO-style JSON and record the
#      file name in `processed_file` so a later run skips it.
CLICK_WINDOW = "Click points, press 's' to segment or 'k' to skip"
RESULT_WINDOW = "Result - press 'r' to retry or any key to accept"
ESC_KEY = 27


def click_event(event, x, y, flags, param):
    """Mouse callback: record a left click as a point prompt and draw it.

    Args:
        event: OpenCV mouse event code; only ``cv2.EVENT_LBUTTONDOWN`` is handled.
        x, y: pixel coordinates of the click inside the window.
        flags: unused, required by the OpenCV callback signature.
        param: state dict supplied through ``cv2.setMouseCallback`` with keys
            ``"image"`` (BGR image being displayed, drawn on in place),
            ``"points"`` (list of ``[x, y]``) and ``"labels"`` (list of ints).
    """
    if event != cv2.EVENT_LBUTTONDOWN:
        return
    param["points"].append([x, y])
    # Every click is stored with label 1 (a foreground point in SAM's prompt convention).
    param["labels"].append(1)
    cv2.circle(param["image"], (x, y), 5, (0, 255, 0), -1)
    cv2.imshow(CLICK_WINDOW, param["image"])


user_aborted = False

for fname in tqdm(os.listdir(image_folder)):
    if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
        continue
    if fname in processed_images:
        continue

    image_path = os.path.join(image_folder, fname)
    source_bgr = cv2.imread(image_path)
    if source_bgr is None:
        # cv2.imread reports an unreadable/corrupt file by returning None instead of
        # raising; skip it rather than crashing in cvtColor below.
        print(f"⚠️ Could not read image, skipping: {image_path}")
        continue

    image_rgb = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB)
    h, w = image_rgb.shape[:2]

    # set_image only depends on the image, not on the clicks, so it is done once per
    # image instead of once per retry.
    predictor.set_image(image_rgb)

    print(f"\nProcessing image: {fname}")

    while True:  # one iteration per click-and-segment attempt on this image
        # Fresh copy per attempt so circles from a rejected attempt do not linger.
        click_state = {"image": source_bgr.copy(), "points": [], "labels": []}
        cv2.imshow(CLICK_WINDOW, click_state["image"])
        cv2.setMouseCallback(CLICK_WINDOW, click_event, click_state)

        # Poll the keyboard until the user picks an action. waitKey also services the
        # GUI event queue, which is what lets the mouse callback fire.
        action = None
        while action is None:
            # Mask to the low byte: some backends set extra high bits in the key code.
            key = cv2.waitKey(1) & 0xFF
            if key == ord("s") and click_state["points"]:
                action = "segment"
            elif key == ord("k"):
                print("⏭️ Skipping this image...")
                action = "skip"
            elif key == ESC_KEY:
                action = "abort"

        if action == "abort":
            # Leave the loops instead of calling exit(): exit() would terminate the
            # Python process (the kernel, in a notebook) and throw away the loaded model.
            user_aborted = True
            break
        if action == "skip":
            break  # Move to next image

        masks, _, _ = predictor.predict(
            point_coords=np.array(click_state["points"]),
            point_labels=np.array(click_state["labels"]),
            multimask_output=False,
        )
        # multimask_output=False is requested, so only the first mask is used.
        mask = masks[0]

        # Overlay is drawn on the clicked image, so the saved PNG also shows the clicks.
        # `mask` is used as an index array here (presumably boolean, image-sized).
        result_overlay = click_state["image"].copy()
        result_overlay[mask] = [0, 255, 0]

        # Show result and ask user to retry or accept
        cv2.imshow(RESULT_WINDOW, result_overlay)
        key = cv2.waitKey(0) & 0xFF
        cv2.destroyAllWindows()

        if key == ord("r"):
            print("🔁 Retrying segmentation for this image...")
            continue  # Re-do the same image

        # Save segmentation result
        base_name = os.path.splitext(fname)[0]
        overlay_path = os.path.join(output_root, f"{base_name}_segmented.png")
        if not cv2.imwrite(overlay_path, result_overlay):
            # imwrite signals failure through its return value, not an exception.
            print(f"⚠️ Failed to write overlay image: {overlay_path}")

        coco_data = {
            "info": {
                "description": "Manual SAM Segmentation",
                "date_created": datetime.now().isoformat(),
            },
            "images": [{"id": image_id, "file_name": fname, "width": w, "height": h}],
            "annotations": [
                {
                    "id": annotation_id,
                    "image_id": image_id,
                    "category_id": 1,
                    "segmentation": binary_mask_to_rle(mask),
                    "area": int(mask.sum()),
                    "bbox": list(cv2.boundingRect(mask.astype(np.uint8))),
                    "iscrowd": 0,
                }
            ],
            "categories": [{"id": 1, "name": "object"}],
        }

        json_path = os.path.join(output_root, f"{base_name}.json")
        with open(json_path, "w") as f:
            json.dump(coco_data, f)

        # ✅ Save processed image info (rewritten after every image so an interrupted
        # run can resume; sorted so the file content is stable between runs).
        processed_images.add(fname)
        with open(processed_file, "w") as f:
            json.dump(sorted(processed_images), f)

        annotation_id += 1
        image_id += 1
        break  # Move to next image

    if user_aborted:
        break

# Close any window still open (e.g. the click window after a final skip or ESC).
cv2.destroyAllWindows()

if user_aborted:
    print("\n🛑 Stopped by user. Results accepted so far have been saved.")
else:
    print("\n✅ All segmentations complete. JSON and segmented images saved.")