In [None]:
%pip install autodistill-grounded-sam autodistill-yolov8 roboflow

In [27]:
from autodistill_grounded_sam import GroundedSAM
from autodistill.detection import CaptionOntology
from autodistill.utils import plot
import cv2

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

# define an ontology to map class names to our GroundedSAM prompt
# the ontology dictionary has the format {caption: class}
# where caption is the prompt sent to the base model, and class is the label that will
# be saved for that caption in the generated annotations
# then, load the model

In [14]:
np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders = True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image =  mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

In [None]:
classes = {
    "man": "man",
    "id": "id",
}

base_model = GroundedSAM(
    ontology=CaptionOntology(
        classes
    )
)

In [None]:
#predict on single image
results = base_model.predict("logistics.jpeg")

In [None]:
#one mask - one photo
if plot := False:
  show_masks(Image.open("/content/паспорт+лицо.jpg"), masks=results.mask, scores=results.confidence)

In [None]:
#all masks on same photo
if plot := False:
  plot(
      image=cv2.imread("/content/паспорт+лицо.jpg"),
      classes=base_model.ontology.classes(),
      detections=results
  )

In [32]:
def inference_seg_on_fold(base_model, fold_path):
  preds = base_model.label(fold_path, extension=".jpg")
  res_dict = {}
  for img_name, detections in tqdm(preds.annotations.items()):
    res_dict[img_name] = {
        'masks' : detections.mask,
        'conf' : detections.confidence,
        'class_ids' : detections.class_id
    }
  return res_dict

In [33]:
resutls = inference_seg_on_fold(
    base_model=base_model,
    fold_path='/content/context_images'
)

Labeling /content/context_images/паспорт+лицо.jpg:   0%|          | 0/1 [00:00<?, ?it/s]The `device` argument is deprecated and will be removed in v5 of Transformers.
torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
None of the inputs have requires_grad=True. Gradients will be None
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
Labeling /content/context_images/паспорт+лицо.jpg: 100%|██████████| 1/1 [00:02<00:00,  2.88s/it]
Passing a `Dict[str, np.ndarray]` into `DetectionDataset` is deprecated and will be removed in `supervision-0.26.0`. Use a list of paths `List[str]` instead.


Found /content/context_images_labeled/valid/images/паспорт+лицо.jpg as already present, not moving anything to /content/context_images_labeled/valid/images
Found /content/context_images_labeled/valid/labels/паспорт+лицо.txt as already present, not moving anything to /content/context_images_labeled/valid/labels
Labeled dataset created - ready for distillation.


100%|██████████| 1/1 [00:00<00:00, 1301.37it/s]


In [34]:
resutls

{'паспорт+лицо.jpg': {'masks': array([[[False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          ...,
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False]],
  
         [[False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          ...,
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False]],
  
         [[False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          [False, False, False, ..., False, False, False],
          ...,
          [False, False, False, .