Import weights

In [1]:
import os

CHECKPOINT_PATH = os.path.join("weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, ", exist: ", os.path.isfile(CHECKPOINT_PATH))

weights/sam_vit_h_4b8939.pth , exist:  True


Load the SAM model

In [2]:
import torch

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"

In [3]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

  state_dict = torch.load(f)


In [4]:
sam_predictor = SamPredictor(sam)

Load the YOLO model

In [5]:
from ultralytics import YOLO

In [6]:
model = YOLO("yolov8n.pt")
model.to(DEVICE)

YOLO(
  (model): DetectionModel(
    (model): Sequential(
      (0): Conv(
        (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): Conv(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (2): C2f(
        (cv1): Conv(
          (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (cv2): Conv(
          (conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_s

Import other library

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

Helping function

In [8]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))

Start inferencing the video

In [27]:
video_path = "fight_2.mp4"
cap = cv2.VideoCapture(video_path)

output_path = "result_2.avi"
fourcc = cv2.VideoWriter_fourcc(*'XVID')
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    composite_mask = np.zeros_like(frame, dtype=np.uint8)

    objects = model(frame, classes=[0], conf=0.6)

    for obj in objects:
        boxes = obj.boxes
        cls = boxes.cls

        classes = ['person']

        for i in range(len(boxes)):
            class_name = classes[int(cls[i])]
            print(class_name)

            #Extort the box coordination
            xyxy = boxes.xyxy[i]
            x1, y1, x2, y2 = xyxy.cpu().numpy()
            
            #Plot the rectangle on to the frame
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

            #Add text to the rectangle
            text = class_name
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2
            text_size, _ = cv2.getTextSize(text, font, font_scale, thickness)
            text_x = int(x1 + 5)
            text_y = int(y1 + text_size[1] + 5)
            cv2.putText(frame, text, (text_x, text_y), font, font_scale, (0, 0, 255), thickness)

            """
            Use the box and inference it to the SAM model
            """

            sam_predictor.set_image(frame)
            input_box = np.array(xyxy.cpu().numpy())

            masks, _, _ = sam_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=input_box[None, :],
                multimask_output=False,
            )

            for mask in masks:
                colored_mask = np.zeros_like(frame)
                colored_mask[mask > 0] = [0, 255, 0]
                composite_mask = cv2.add(composite_mask, colored_mask)
    
    alpha = 0.5
    frame = cv2.addWeighted(frame, 1 - alpha, composite_mask, alpha, 0)

    out.write(frame)
    
    cv2.imshow("Frame with person segmentation", frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()


0: 384x640 4 persons, 68.3ms
Speed: 23.3ms preprocess, 68.3ms inference, 19.4ms postprocess per image at shape (1, 3, 384, 640)
person
person
person
person

0: 384x640 4 persons, 15.2ms
Speed: 7.3ms preprocess, 15.2ms inference, 2.2ms postprocess per image at shape (1, 3, 384, 640)
person
person
person
person

0: 384x640 4 persons, 8.9ms
Speed: 9.1ms preprocess, 8.9ms inference, 6.5ms postprocess per image at shape (1, 3, 384, 640)
person
person
person
person

0: 384x640 4 persons, 10.0ms
Speed: 2.3ms preprocess, 10.0ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)
person
person
person
person


Inferencing a picture

In [33]:
image = cv2.imread("punch3.png")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

objects = model(image, classes=[0], conf=0.5)

for obj in objects:
    boxes = obj.boxes
    cls = boxes.cls

    classes = ['person']
    _len = len(boxes)
    composite_mask = np.zeros_like(image, dtype=np.uint8)

    for i in range(_len):
        class_name = classes[int(cls[i])]

        #Extort the box coordination
        xyxy = boxes.xyxy[i]
        x1, y1, x2, y2 = xyxy.cpu().numpy()
        
        """
        #Plot the rectangle on to the frame
        cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

        #Add text to the rectangle
        text = class_name
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2
        text_size, _ = cv2.getTextSize(text, font, font_scale, thickness)
        text_x = int(x1 + 5)
        text_y = int(y1 + text_size[1] + 5)
        cv2.putText(image, text, (text_x, text_y), font, font_scale, (0, 0, 255), thickness)

        Use the box and inference it to the SAM model
        """

        sam_predictor.set_image(image)
        input_box = np.array(xyxy.cpu().numpy())

        masks, _, _ = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=False,
        )

        for mask in masks:
            print("chk")
            colored_mask = np.zeros_like(image)
            colored_mask[mask > 0] = [0, 255, 0]
            composite_mask = cv2.add(composite_mask, colored_mask)

alpha = 1
image = cv2.addWeighted(image, 1 - alpha, composite_mask, alpha, 0)

cv2.imshow("Final image", image)
cv2.waitKey(0)
cap.release()
cv2.destroyAllWindows()



0: 448x640 2 persons, 53.0ms
Speed: 4.3ms preprocess, 53.0ms inference, 4.1ms postprocess per image at shape (1, 3, 448, 640)
chk
chk
