<a href="https://colab.research.google.com/github/LHBuilder/SA-Segment-Anything/blob/main/integ_yolo_sam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Integrate YOLO-NAS and Meta SAM

Setup Environment

In [None]:
# py -3.10 -m venv myvenv
# myvenv\Scripts\activate

# !pip install super-gradients==3.1.0
# !pip install imutils
# !pip install pytube --upgrade

# !pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
import torch

# Only the last expression of a notebook cell is echoed, so print both values
# explicitly. torch.cuda.get_device_name() raises when no CUDA device is
# available, so guard it instead of crashing the cell on CPU-only machines.
print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))
else:
    print("CUDA is not available; models will run on the CPU.")

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [None]:
def show_mask(mask, ax, random_color=False):
  """Overlay a single binary mask on a Matplotlib axis.

  Args:
    mask: Array whose last two dimensions are (height, width). Any leading
      dimensions must have total size 1 (e.g. a (1, H, W) mask), because the
      mask is reshaped to (H, W, 1).
    ax: Object with an ``imshow`` method, normally a Matplotlib ``Axes``.
    random_color: If True, use a random RGB colour; otherwise a fixed blue.
      Both use an alpha of 0.6.
  """
  if random_color:
    color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  else:
    color = np.array([30/255, 144/255, 255/255, 0.6])
  h, w = mask.shape[-2:]
  # Broadcasting the (H, W, 1) mask against the (1, 1, 4) RGBA colour gives an
  # (H, W, 4) image whose pixels are fully transparent wherever the mask is 0.
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
  """Draw labelled points on a Matplotlib axis as star markers.

  Args:
    coords: 2-D array whose rows are (x, y) point coordinates.
    labels: Array aligned with ``coords``. Points labelled 1 are drawn green,
      points labelled 0 are drawn red. Any other label is not drawn.
    ax: Object with a ``scatter`` method, normally a Matplotlib ``Axes``.
    marker_size: Marker area, passed to ``ax.scatter`` as ``s``.
  """
  pos_points = coords[labels == 1]
  neg_points = coords[labels == 0]
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
  ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
  """Draw an axis-aligned bounding box outline on a Matplotlib axis.

  Args:
    box: Sequence ``(x0, y0, x1, y1)`` giving two opposite corners, i.e. the
      xyxy layout used for the YOLO-NAS boxes in this notebook.
    ax: Matplotlib ``Axes`` to draw on.
  """
  x0, y0 = box[0], box[1]
  w, h = box[2] - box[0], box[3] - box[1]
  # Rectangle takes the anchor corner as a single (x, y) tuple, followed by
  # width and height. The fully transparent facecolor leaves only the outline.
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_anns(anns):
  """Overlay automatic-mask annotations on the current Matplotlib axis.

  Each mask is drawn as a random solid colour at 35% opacity. Annotations are
  drawn from largest to smallest ``'area'``, so smaller masks end up on top.

  Args:
    anns: List of dicts, e.g. from ``SamAutomaticMaskGenerator.generate``.
      Each dict needs a 2-D ``'segmentation'`` mask and an ``'area'`` value
      used for ordering. An empty list draws nothing.
  """
  if not anns:
    return
  sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
  ax = plt.gca()
  # Stop the overlays from rescaling the axis limits set by the base image.
  ax.set_autoscale_on(False)
  for ann in sorted_anns:
    m = ann['segmentation']
    # Solid-colour RGB layer; the mask scaled by 0.35 becomes its alpha channel.
    img = np.empty((m.shape[0], m.shape[1], 3))
    img[:, :] = np.random.random(3)
    ax.imshow(np.dstack((img, m * 0.35)))

In [None]:
image_path = 'images/person.jpg'
image = cv2.imread(image_path)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable. Fail here with a clear message rather than inside cvtColor.
if image is None:
  raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads images in BGR channel order; Matplotlib expects RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('off')
plt.show()

**YOLO-NAS Detects Objects**



In [None]:
from super_gradients.training import models
import cv2
import matplotlib.pyplot as plt
import pickle

image_path = 'images/person.jpg'
conf_threshold = 0.25

# yolo_nas_l is the large YOLO-NAS model, loaded with COCO-pretrained weights.
model = models.get('yolo_nas_l', pretrained_weights='coco')

# Run inference once and reuse the result for both display and saving.
detection_pred = model.predict(image_path, conf=conf_threshold)
detection_pred.show()
# Save the image annotated with the detected bounding boxes to output_folder.
detection_pred.save('output_folder')

**SAM Selects Objects**

First, load the SAM model and predictor. Change the path below to point to your SAM checkpoint. Running on CUDA with the default (ViT-H) model is recommended for the best results.

In [None]:
from segment_anything import sam_model_registry, SamPredictor
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import matplotlib.pyplot as plt
import torch

image_path = 'images/person.jpg'
image = cv2.imread(image_path)
if image is None:
  raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads BGR; convert to RGB before handing the image to SAM and Matplotlib.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# SAM model for masking. Point sam_checkpoint at the downloaded ViT-H weights.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Fall back to the CPU when no GPU is present (much slower, but it still runs).
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Segment everything in the image without any prompts.
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

Now every object in the image is masked, but we only want to mask the person.
Steps:
1. Object detection using YOLO-NAS
2. Provide bounding box coordinates to SAM
3. SAM will provide the mask on Person

YOLO-NAS inference: extract the confidence scores, class labels, and bounding boxes.
This information is available through the `_images_prediction_lst` attribute of the prediction object returned by `model.predict`.

In [None]:
from super_gradients.training import models
import cv2

image_path = 'images/person.jpg'
conf_threshold = 0.25
model = models.get("yolo_nas_l", pretrained_weights="coco")

# _images_prediction_lst is a private attribute of the object returned by
# predict(). It holds the per-image predictions, presumably one entry per
# input image, so a single entry for this single-image call.
detection_pred = model.predict(image_path, conf=conf_threshold)._images_prediction_lst

Extract only the desired information

In [None]:
# Pull the boxes, scores and class ids of the first (only) image out of the
# YOLO-NAS prediction and convert each array to a plain Python list.
first_prediction = detection_pred[0].prediction
bboxes_xyxy = first_prediction.bboxes_xyxy.tolist()
confidence = first_prediction.confidence.tolist()
labels = first_prediction.labels.tolist()

for caption, values in (
    ("Bounding Boxes (xyxy):", bboxes_xyxy),
    ("Confidence:", confidence),
    ("Labels:", labels),
):
  print(caption, values)

bboxes_xyxy, confidence, labels

In [None]:
import os
import sys
sys.path.append("..")
from super_gradients.training import models
from segment_anything import sam_model_registry, SamPredictor
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# Class id that this notebook treats as "person" in the COCO-trained YOLO-NAS labels.
PERSON_LABEL = 0
image_path = 'images/person.jpg'
output_path = 'output_folder/output3.png'

image = cv2.imread(image_path)
if image is None:
  raise FileNotFoundError(f"Could not read image: {image_path}")
# SamPredictor.set_image expects an RGB uint8 HWC array by default and does its
# own resizing/normalisation, so no manual transpose or scaling is needed.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# SAM model for masking
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
# Compute the image embedding once; every box prompt below reuses it.
predictor.set_image(image)

# bboxes_xyxy and labels come from the YOLO-NAS extraction cell above.
# Keep only the detections classified as person.
person_boxes = [np.array(box) for box, label in zip(bboxes_xyxy, labels) if label == PERSON_LABEL]

plt.figure(figsize=(10, 10))
plt.imshow(image)

if person_boxes:
  # Union of all person masks, drawn as one overlay so multiple people do not
  # stack several semi-transparent layers over the whole image.
  person_mask = np.zeros(image.shape[:2], dtype=bool)
  for input_box in person_boxes:
    # predict masks using SAM; one box prompt with multimask_output=False
    # returns a single mask in masks[0].
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    person_mask |= masks[0].astype(bool)
  plt.imshow(person_mask, alpha=0.5)
else:
  print("No person detected; nothing to segment.")

# Draw each person's bounding box with its label.
ax = plt.gca()
for input_box in person_boxes:
  x0, y0, x1, y1 = input_box
  ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                             linewidth=2, edgecolor='r', facecolor='none'))
  plt.text(x0, y0 - 5, 'Person', fontsize=12, color='r', backgroundcolor='w')
plt.axis('off')
# savefig does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
plt.show()