### Setting Up the Environment
* Install libraries
* Import libraries
* Define helper functions

In [None]:
import sys

# sam
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

# ultralytics
!pip install ultralytics

# other installations
!pip install -q roboflow dataclasses-json supervision==0.17

In [None]:
import glob
import os
from IPython import display
display.clear_output()

import ultralytics
from ultralytics import YOLO
from IPython.display import display, Image
import torch

import cv2
import matplotlib.pyplot as plt
import supervision as sv
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

ultralytics.checks()

In [None]:
# Project root; later paths such as the SAM weights directory are built from it.
HOME = os.getcwd()
print(f"HOME: {HOME}")

# Extensions scanned for in the input directory. Glob matching is case-sensitive
# on Linux, which is why both spellings are listed.
POSSIBLE_IMAGE_EXTENSIONS = ["jpg", "JPG", "png", "PNG", "JPEG", "jpeg"]

# Target size as (width, height), the order cv2.resize expects.
RESIZED_IMAGE_SIZE = (640, 490)

In [None]:
# setting up weights directory and downloading a particular sam model
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [24]:
def load_image(img_name, img_address, resized_image_size):
  """Read one image from disk and prepare the variants used by this notebook.

  Args:
    img_name: file name of the image; used only to make error messages readable.
    img_address: path passed to ``cv2.imread``.
    resized_image_size: ``(width, height)`` target size, in the order ``cv2.resize`` expects.

  Returns:
    ``(image_bgr, image_rgb, resized_image, original_image_size)``:
    - ``image_bgr``: the image as read by OpenCV.
    - ``image_rgb``: the same image converted to RGB.
    - ``resized_image``: ``image_rgb`` resized to ``resized_image_size``.
    - ``original_image_size``: ``(width, height)`` of the image on disk.

  Raises:
    OSError: if OpenCV cannot read ``img_address`` (missing or undecodable file).
  """
  image_bgr = cv2.imread(img_address)
  if image_bgr is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise OSError(f"Could not read image {img_name!r} at {img_address!r}")

  height, width = image_bgr.shape[:2]
  original_image_size = (width, height)

  image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
  resized_image = cv2.resize(image_rgb, resized_image_size)  # size used for model inference

  return image_bgr, image_rgb, resized_image, original_image_size


def draw_bboxes_xyxyn(bboxes, img, color=(150, 150, 150), thickness=10):
  """Draw normalized ``(x1, y1, x2, y2)`` boxes onto a copy of ``img``.

  Args:
    bboxes: iterable of 4-element boxes with coordinates normalized to [0, 1]
      (e.g. ``result.boxes.xyxyn`` from an ultralytics result).
    img: image array whose first two dimensions are (height, width).
    color: outline colour passed to ``cv2.rectangle``; its channel order is that of ``img``.
    thickness: outline thickness in pixels.

  Returns:
    A copy of ``img`` with the boxes drawn; ``img`` itself is not modified.
  """
  height, width = img.shape[:2]
  drawn_img = img.copy()
  for x1, y1, x2, y2 in bboxes:
    # Scale normalized coordinates to this image's pixel grid.
    top_left = (int(x1 * width), int(y1 * height))
    bottom_right = (int(x2 * width), int(y2 * height))
    cv2.rectangle(drawn_img, top_left, bottom_right, color, thickness)
  return drawn_img


def get_sam_masks(yolo_result, sam, resized_image):
  """Segment every YOLO box of one image with SAM, using the boxes as prompts.

  Args:
    yolo_result: one ultralytics result. Its ``boxes.xyxy`` are used as pixel
      coordinates in ``resized_image``'s frame without rescaling, so YOLO is
      assumed to have run on this same resized image.
    sam: a SAM model, e.g. from ``sam_model_registry``.
    resized_image: HxWx3 uint8 image, passed to ``SamPredictor.set_image`` with
      its default ``image_format`` (RGB).

  Returns:
    ``(masks, class_ids)``:
    - ``masks``: boolean masks from ``predict_torch`` with ``multimask_output=False``,
      i.e. one mask per box.
    - ``class_ids``: the YOLO class ids as a numpy array.
  """
  mask_predictor = SamPredictor(sam)

  # YOLO and SAM may run on different devices; SAM's prompt tensors must be on SAM's device.
  input_boxes = yolo_result.boxes.xyxy.to(mask_predictor.device)
  class_ids = yolo_result.boxes.cls.cpu().numpy()

  transformed_boxes = mask_predictor.transform.apply_boxes_torch(input_boxes, resized_image.shape[:2])
  mask_predictor.set_image(resized_image)
  masks, _iou_predictions, _low_res_masks = mask_predictor.predict_torch(
      point_coords=None,
      point_labels=None,
      boxes=transformed_boxes,
      multimask_output=False
  )
  return masks, class_ids

# NOTE(review): `create_detections` is defined a second time later in this cell, and the
# later definition is the one that takes effect. Keep the two in sync, or delete one.
def create_detections(masks, class_ids):
  """Build an ``sv.Detections`` object from a batch of binary masks.

  Args:
    masks: boolean masks, expected to be a torch tensor of shape (N, 1, H, W).
      An empty list or an N == 0 tensor is also accepted.
    class_ids: array-like of length N with one class index per mask.

  Returns:
    ``sv.Detections`` with:
    - ``xyxy``: each mask's tight bounding box.
    - ``class_id``: the class ids as integers.
    - ``mask``: the masks with shape (N, H, W).
    For empty input, an empty Detections (no boxes, no masks) is returned.
  """
  # YOLO class ids arrive as floats; COCO category ids are integers.
  class_ids = np.asarray(class_ids).astype(int)
  if len(masks) == 0:
    # Images without detections: squeezing a 0-length box array would raise.
    return sv.Detections(xyxy=np.empty((0, 4), dtype=np.float32), class_id=class_ids)

  numpy_masks = masks.cpu().numpy().squeeze(1)  # (N, 1, H, W) -> (N, H, W)
  xyxys = sv.mask_to_xyxy(masks=numpy_masks)    # (N, 4)
  return sv.Detections(class_id=class_ids, xyxy=xyxys, mask=numpy_masks)

def draw_masks_image(image_bgr, detections):
  """Return a copy of ``image_bgr`` with every detection's mask overlaid in red.

  The previous version also built a box-annotated copy that was never returned.
  That dead work has been dropped.

  Args:
    image_bgr: image to draw on; it is copied, not modified.
    detections: ``sv.Detections`` whose ``mask`` arrays match the image size.

  Returns:
    The mask-annotated copy of the image.
  """
  mask_annotator = sv.MaskAnnotator(color=sv.Color.red())
  return mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

def create_detections(masks, class_ids):
  """Build an ``sv.Detections`` object from a batch of binary masks.

  This definition shadows the earlier ``create_detections`` in this cell.

  Args:
    masks: boolean masks, expected to be a torch tensor of shape (N, 1, H, W).
      An empty list or an N == 0 tensor is also accepted.
    class_ids: array-like of length N with one class index per mask.

  Returns:
    ``sv.Detections`` with:
    - ``xyxy``: each mask's tight bounding box.
    - ``class_id``: the class ids as integers.
    - ``mask``: the masks with shape (N, H, W).
    For empty input, an empty Detections (no boxes, no masks) is returned.
  """
  # YOLO class ids arrive as floats; COCO category ids are integers.
  class_ids = np.asarray(class_ids).astype(int)
  if len(masks) == 0:
    # Images without detections: squeezing a 0-length box array would raise.
    return sv.Detections(xyxy=np.empty((0, 4), dtype=np.float32), class_id=class_ids)

  numpy_masks = masks.cpu().numpy().squeeze(1)  # (N, 1, H, W) -> (N, H, W)
  xyxys = sv.mask_to_xyxy(masks=numpy_masks)    # (N, 4)
  return sv.Detections(class_id=class_ids, xyxy=xyxys, mask=numpy_masks)

def remove_small_contours(masks):
    """Keep only the largest contour of each mask and zero out every other contour.

    Args:
        masks: boolean mask tensor, expected to be (N, 1, H, W) as returned by
            SAM's ``predict_torch``. Only channel 0 of each mask is used.

    Returns:
        Boolean tensor of shape (N, 1, H, W) on the same device as ``masks``.
    """
    if len(masks) == 0:
        # torch.concat([]) would raise; return an empty batch with the same layout.
        return masks[:, :1].bool()

    cleaned_masks = []
    for mask in masks:
        # astype() returns a new, writable array that drawContours can paint into.
        single_mask = mask[0].cpu().numpy().astype(np.uint8)

        contours, _hierarchy = cv2.findContours(
            single_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(contours) > 1:
            # Sort by ascending area and fill everything except the largest with 0.
            contours_by_area = sorted(contours, key=cv2.contourArea)
            cv2.drawContours(single_mask, contours_by_area[:-1], -1, color=0, thickness=cv2.FILLED)

        # Return on the input's device rather than a hard-coded CUDA device, so CPU-only runs work.
        cleaned_masks.append(torch.from_numpy(single_mask.astype(bool)).to(mask.device)[None, None, :, :])
    return torch.concat(cleaned_masks)

### Establish Google Drive Connection


In [None]:
# Mount Google Drive so the model weights and test images under MyDrive can be read.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

### Custom Variables - I: Path to the Trained YOLO Weights

In [9]:
# Path on the mounted Drive to the YOLO weights file (.pt) used for detection.
# Keep the next line's format intact: Colab renders it as a form field via `#@param`.
yolo_model_address = '/content/drive/MyDrive/Projects/Coral Microfragmentation/coral_detection_model/Yolo + SAM/yolo_ar_best.pt' #@param {type:"string"}

### Setting up the Models

In [8]:
# YOLO detector, loaded from the weights path configured above.
model = YOLO(yolo_model_address)

# SAM ViT-H, loaded from the checkpoint in HOME/weights and moved to the GPU when available.
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

# Prompt-free automatic mask generator. The box-prompted pipeline in this notebook does not call it.
mask_generator = SamAutomaticMaskGenerator(sam)

### Custom Variables - II: Directory of Images to Annotate

In [35]:
# Directory on the mounted Drive holding the images to annotate.
# The COCO output is written to an "Output" subfolder of this directory.
# Keep the next line's format intact: Colab renders it as a form field via `#@param`.
address_of_image_dir = "/content/drive/MyDrive/Projects/Coral Microfragmentation/coral_detection_model/Coral Detection Test Images/23-24 Coral Table Test Images" #@param {type:"string"}

### Getting the Predictions

In [57]:
# Reading in multiple images.
# - glob.escape protects directory names that contain glob metacharacters such as [ ].
# - The set removes duplicate paths that case-variant extension patterns can produce
#   on case-insensitive filesystems.
# - Sorting makes the image order reproducible across runs.
escaped_image_dir = glob.escape(address_of_image_dir)
img_addresses = sorted({
    img_address
    for extension in POSSIBLE_IMAGE_EXTENSIONS
    for img_address in glob.iglob(os.path.join(escaped_image_dir, f"*.{extension}"))
})
if not img_addresses:
  raise FileNotFoundError(
      f"No images with extensions {POSSIBLE_IMAGE_EXTENSIONS} found in {address_of_image_dir!r}"
  )
img_names = [os.path.basename(img_address) for img_address in img_addresses]

image_bgrs = []
image_rgbs = []
resized_images = []
original_image_sizes = []

for img_name, img_address in zip(img_names, img_addresses):
  image_bgr, image_rgb, resized_image, original_image_size = load_image(img_name, img_address, RESIZED_IMAGE_SIZE)
  image_bgrs.append(image_bgr)
  image_rgbs.append(image_rgb)
  resized_images.append(resized_image)
  original_image_sizes.append(original_image_size)

In [None]:
# Getting YOLO predictions for multiple images.
# Ultralytics treats numpy inputs as BGR (OpenCV order), but `resized_images` are RGB,
# so convert them back before inference.
YOLO_CONF_THRESHOLD = 0.1
yolo_inputs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in resized_images]
results = model.predict(yolo_inputs, conf=YOLO_CONF_THRESHOLD)

# Draw each image's own boxes on that image. xyxyn is normalized, so the
# original-size RGB image can be used directly.
images_with_bboxes = [
    draw_bboxes_xyxyn(result.boxes.xyxyn, img)
    for result, img in zip(results, image_rgbs)
]

In [59]:
# Getting SAM predictions for multiple images
all_masks = []
all_class_ids = []
for result, resized_image in zip(results, resized_images):
  if len(result.boxes) > 0:
    # The YOLO boxes become SAM box prompts: one mask per box.
    masks, class_ids = get_sam_masks(result, sam, resized_image)
    ### TODO: Find better way to deal with multiple small contours
    masks = remove_small_contours(masks)
  else:
    # No detections: use an empty (0, 1, H, W) batch so every image has the same mask layout.
    class_ids = np.array([], dtype=float)
    masks = torch.zeros((0, 1, RESIZED_IMAGE_SIZE[1], RESIZED_IMAGE_SIZE[0]), dtype=bool)
  all_masks.append(masks)
  all_class_ids.append(class_ids)

# Resizing segmentation maps back to each image's original resolution.
# original_image_sizes holds (width, height), while interpolate expects (height, width).
all_big_masks = []
for masks, (orig_width, orig_height) in zip(all_masks, original_image_sizes):
  if len(masks) > 0:
    # Resize the whole (N, 1, H, W) batch in one call. interpolate defaults to
    # nearest-neighbour, so this matches resizing each mask separately.
    big_masks = torch.nn.functional.interpolate(
        masks.to(torch.float32), size=(orig_height, orig_width)
    ).to(bool)
  else:
    # Keep empty results as an empty tensor rather than a bare Python list.
    big_masks = torch.zeros((0, 1, orig_height, orig_width), dtype=bool)
  all_big_masks.append(big_masks)

# Creating detection objects for multiple images
all_detections = [
    create_detections(big_masks, class_ids)
    for big_masks, class_ids in zip(all_big_masks, all_class_ids)
]

### Saving the Results as COCO Annotations

In [60]:
# Class names indexed by class id. These are presumably in the YOLO model's class
# order (0 = coral, 1 = ref) — TODO confirm against model.names.
CLASS_NAMES = ["coral", "ref"]

# Images and annotations are both keyed by file name.
inference_dataset = sv.DetectionDataset(
    CLASS_NAMES,
    dict(zip(img_names, image_rgbs)),
    dict(zip(img_names, all_detections)),
)

In [61]:
# Export every detection as a COCO annotation file in an "Output" folder inside the image directory.
coco_output_dir = os.path.join(address_of_image_dir, "Output")
# Create the folder first, so a missing directory cannot fail the export after the long inference run.
os.makedirs(coco_output_dir, exist_ok=True)
save_coco_file_address = os.path.join(coco_output_dir, "coco_from_AI.json")
inference_dataset.as_coco(annotations_path=save_coco_file_address,
                          approximation_percentage=0)