In [1]:
%%capture

!pip install datasets
!pip install ultralytics=8.0.227

In [12]:
import re
import cv2
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Polygon
from pathlib import Path
import torch

from datasets import load_dataset

import ultralytics
from ultralytics import SAM, YOLO

# Record which ultralytics release is actually loaded in this kernel.
print(ultralytics.__version__)

8.0.227


In [3]:
def custom_auto_annotate(data, bbx_file, sam_model='sam_b.pt', device='', output_dir='./'):
    """
    Generate segmentation label files by prompting a SAM model with pre-annotated bounding boxes.

    The bounding-box file is read as a sequence of records. Each record starts with a
    relative image path line (``<dir>/<name>.jpg``), optionally followed by a line holding
    only an integer (ignored), followed by one line per box made of ten space-separated
    integers. Only the first four integers of a box line are used, read as
    ``x y width height`` and converted to ``[x1, y1, x2, y2]``.
    NOTE(review): this looks like the WIDER FACE annotation layout - confirm.

    For every image that has between 1 and 50 boxes, SAM is run with those boxes as prompts
    and one ``<image stem>.txt`` file is written to ``output_dir``; each non-empty segment
    becomes a line ``0 x1 y1 x2 y2 ...`` using the normalized polygon coordinates returned
    by SAM (class id 0 because only one class, face, is annotated).

    Args:
        data (str | Path): Folder that the image paths in ``bbx_file`` are relative to.
        bbx_file (str | Path): Path to the text file with the bounding boxes per image.
        sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
        device (str | torch.device, optional): Device forwarded to the SAM call. Defaults to ''.
        output_dir (str | Path, optional): Directory to save the label files. Created if missing.
    """
    max_boxes = 50  # images with more boxes than this are skipped

    image_path_re = re.compile(r'^([^/]+/[^ ]+\.jpg)$')  # Regex for image path
    box_count_re = re.compile(r'\d+$')  # Regex for number of bounding boxes
    box_re = re.compile(r'(\d+) (\d+) (\d+) (\d+) (\d+) (\d+) (\d+) (\d+) (\d+) (\d+)$')  # Regex for bounding boxes

    # Load SAM model
    sam = SAM(sam_model)

    # Create directory if doesn't exist
    data_dir = Path(data)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    def _annotate(image_path, boxes):
        """Run SAM on one image and write its label file; skip images with 0 or > max_boxes boxes."""
        if image_path is None or not boxes or len(boxes) > max_boxes:
            return

        print(f"processing image: {image_path}")

        # Compute mask with SAM
        sam_results = sam(data_dir / image_path, bboxes=np.array(boxes), verbose=False, save=False, device=device)
        segments = sam_results[0].masks.xyn

        # Save mask
        with open(output_path / f'{Path(image_path).stem}.txt', 'w') as f:
            for segment in segments:
                if len(segment) == 0:
                    continue
                coords = ' '.join(map(str, segment.reshape(-1).tolist()))
                f.write(f'0 {coords}\n')  # We put 0 because we only consider face

    image_path = None
    bounding_boxes = []
    with open(bbx_file, 'r') as bbx_f:
        for raw_line in bbx_f:
            line = raw_line.strip()
            if not line:
                continue  # skip empty lines

            image_path_match = image_path_re.match(line)
            if image_path_match:
                # A new image header closes the previous record: process it, then start
                # from a fresh box list so boxes never leak into the next image, even
                # when the previous image was skipped.
                _annotate(image_path, bounding_boxes)
                image_path = image_path_match.group(1)
                bounding_boxes = []
                continue

            if box_count_re.fullmatch(line):
                continue  # We don't need this information

            box_match = box_re.fullmatch(line)
            if box_match:
                # Get current bounding box: (x, y, width, height) -> (x1, y1, x2, y2)
                x, y, width, height = (int(val) for val in box_match.groups()[:4])
                bounding_boxes.append(np.array([x, y, x + width, y + height]))

    # The last record has no following header to trigger it, so flush it explicitly.
    _annotate(image_path, bounding_boxes)

In [15]:
# Run configuration: input locations, SAM checkpoint, compute device and output folder.
data = "../data/images"
bbx_file = '../data/bounding_boxes/wider_faceseglite_bbx.txt'
sam_model = 'sam_b.pt'
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Format the timestamp explicitly: the default str(datetime) contains a space and colons,
# which are invalid in Windows paths and awkward to quote in shells.
output_directory = f"../data/masks_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

In [None]:
# Annotate the configured image folder; label files are written under `output_directory`.
custom_auto_annotate(
    data=data,
    bbx_file=bbx_file,
    sam_model=sam_model,
    device=device,
    output_dir=output_directory,
)