This is a dataset annotator using SAM (Segment Anything Model). For every image in the dataset, it lets the user manually draw bounding boxes on the image, generates a mask for each box, and converts the masks into YOLO format for training purposes.

In [None]:
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import torch
import random
from matplotlib import pyplot as plt
import os

In [None]:
# Dataset layout configuration.
# Raw images are read from `images_folder_path`; annotated output goes under
# `<dataset_name>/train/{images,labels}`. The valid/test split paths are only
# declared here — this cell does not create those directories.
dataset_name = 'my_custom_dataset'
images_folder_path = "test_images"

train_split_path = f"{dataset_name}/train"
val_split_path = f"{dataset_name}/valid"
test_split_path = f"{dataset_name}/test"

train_images_folder_path = f"{train_split_path}/images"
train_labels_folder_path = f"{train_split_path}/labels"

# Create the dataset root and the train image/label folders (no-op if present).
for _folder in (dataset_name, train_images_folder_path, train_labels_folder_path):
    os.makedirs(_folder, exist_ok=True)



In [None]:
# Interactive annotation loop.
#
# For every readable image in `images_folder_path`:
#   1. copy it to the train images folder as "<n>.png";
#   2. ask the user for labels and let them draw one or more boxes per label;
#   3. run SAM on each box to obtain a mask;
#   4. write the mask outlines to "<n>.txt" as normalised polygons, one line
#      per contour: "<label> x1 y1 x2 y2 ...";
#   5. show the drawn contours for a visual check.

# Build the SAM predictor once: loading the ViT-H checkpoint is loop-invariant
# and far too expensive to repeat for every image.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
# NOTE(review): moving the model with `sam.to(device)` was disabled in the
# original notebook, so `device` is computed but not applied; enable it
# deliberately if GPU inference is wanted.
mask_predictor = SamPredictor(sam)

# os.listdir returns entries in arbitrary order; sorting makes the mapping from
# source file to output number reproducible between runs.
image_files = sorted(
    f for f in os.listdir(images_folder_path)
    if os.path.isfile(os.path.join(images_folder_path, f))
)
image_number = 0
for image_file in image_files:
    image_path = os.path.join(images_folder_path, image_file)
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread yields None for files it cannot decode (e.g. stray
        # non-image files); skip them without consuming an output number.
        print(f"Skipping {image_path}: could not be read as an image")
        continue

    image_number += 1
    height, width = img.shape[:2]
    cv2.imwrite(os.path.join(train_images_folder_path, f"{image_number}.png"), img)

    # r maps label -> {"bounding_box": [xyxy arrays], "masks": [2-D masks]}
    r = {}
    while True:
        label = input("Enter the label for the mask (exit to move to net image): ")
        if label == 'exit':
            break
        if label not in r:
            r[label] = {"bounding_box": [], "masks": []}
        while True:
            # selectROI returns (x, y, w, h); SAM expects (x0, y0, x1, y1).
            x, y, w, h = cv2.selectROI("Interactive Menu", img)
            cv2.destroyAllWindows()
            if w > 0 and h > 0:
                r[label]["bounding_box"].append(np.asarray([x, y, x + w, y + h]))
            else:
                # A cancelled selection comes back as an empty rectangle.
                print("Empty selection ignored")

            cont = input("Draw another bounding box for the same label? (y/n): ")
            if cont.lower() != 'y':
                break  # Move to the next label

    # Only pay for the image embedding when there is something to segment.
    if any(entry["bounding_box"] for entry in r.values()):
        print("\nSegmenting the objects, this might take a while")
        mask_predictor.set_image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        for label in r:
            for bbox in r[label]["bounding_box"]:
                mask, _, _ = mask_predictor.predict(box=bbox, multimask_output=False)
                r[label]["masks"].append(mask[0])

    color_mask = np.zeros_like(img)
    label_path = os.path.join(train_labels_folder_path, f"{image_number}.txt")
    with open(label_path, "w") as f:
        for label, entry in r.items():
            for mask in entry["masks"]:
                # Convert to a 0/255 uint8 image without mutating the stored mask.
                binary = (mask > 0).astype(np.uint8) * 255
                contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                color = [int(c) for c in np.random.choice(range(256), size=3)]
                for contour in contours:
                    approx = cv2.approxPolyDP(contour, 0.001 * cv2.arcLength(contour, True), True)
                    if len(approx) < 3:
                        # Fewer than three vertices is not a polygon.
                        continue
                    # Contour points are (x, y): x is normalised by the image
                    # width (shape[1]) and y by the height (shape[0]).
                    coords = [
                        f"{float(px) / width} {float(py) / height}"
                        for px, py in approx.reshape(-1, 2)
                    ]
                    # NOTE(review): the label is written exactly as typed; YOLO
                    # tooling expects a numeric class index here — confirm the
                    # labels entered are indices or map them before training.
                    f.write(f"{label} " + " ".join(coords) + "\n")
                    cv2.drawContours(color_mask, [approx], -1, color, 2)

    plt.imshow(color_mask)
    plt.show()

