In [None]:
import os

# Notebook root directory: the repo is cloned under it and the video path is built from it.
HOME = os.getcwd()
print(f"HOME: {HOME}")

# Install and import dependencies

In [None]:
# Pinned transformers release; presumably required by the Grounded-SAM-2 stack installed below — confirm.
!pip install transformers==4.49.0

In [None]:
# Clone the real-time SAM2 fork into HOME, install it in editable mode, and download its checkpoints.
!git clone https://github.com/zdata-inc/sam2_realtime
%cd {HOME}/sam2_realtime
!pip install -e . -q

# NOTE(review): imported while the CWD is the repo root, so the local `sam2` package is presumably
# importable even if the editable install is not yet visible to this kernel — confirm on a fresh kernel.
from sam2.build_sam import build_sam2_object_tracker

%cd checkpoints
!sh download_ckpts.sh
# Go back up two levels (checkpoints -> sam2_realtime -> HOME) so later relative paths resolve from HOME.
%cd ..
%cd ..


In [None]:
# autodistill wrapper providing GroundedSAM2, used to detect and segment objects in the first frame.
!pip install autodistill-grounded-sam-2

In [None]:
# supervision provides VideoInfo / VideoSink used for the output video.
# NOTE(review): jupyter_bbox_widget is not referenced by any visible cell.
!pip install -q supervision jupyter_bbox_widget


In [None]:
import os
import time
import urllib

import cv2
import numpy as np
import torch
from IPython.display import clear_output, display
from PIL import Image
import supervision as sv

In [None]:
class Visualizer:
    """Overlay per-object segmentation masks on video frames and show them inline.

    Each frame is resized to (video_width, video_height), every object's mask pixels
    are painted with a palette colour, the result is displayed in the notebook, and a
    BGR copy is returned so it can be written with OpenCV-based video sinks.
    """

    # One RGB colour per tracked object; cycled when there are more objects than colours.
    _COLORS = [[255, 20, 147], [255, 99, 71], [1, 255, 20], [255, 215, 0], [0, 226, 255],
               [255, 0, 43], [128, 128, 128], [0, 10, 64], [92, 247, 107], [221, 194, 255]]

    def __init__(self,
                 video_width,
                 video_height,
                 ):
        """Store the output frame size in pixels used for resizing frames and masks."""
        self.video_width = video_width
        self.video_height = video_height

    def resize_mask(self, mask):
        """Bilinearly resize per-object masks to the video size.

        Args:
            mask: numpy array or torch tensor with 4 dims (num_objects, 1, H, W), on any
                device and of any dtype (e.g. bfloat16 under autocast, or bool).

        Returns:
            A float32 CPU tensor shaped (num_objects, 1, video_height, video_width).
        """
        # as_tensor avoids torch.tensor()'s copy-construct warning for tensor input;
        # casting to float32 lets bool / half / bfloat16 masks go through CPU interpolation.
        mask = torch.as_tensor(mask).detach().to(device='cpu', dtype=torch.float32)
        return torch.nn.functional.interpolate(mask,
                                               size=(self.video_height, self.video_width),
                                               mode="bilinear",
                                               align_corners=False,
                                               )

    def add_frame(self, frame, mask):
        """Paint object masks onto an RGB frame, display it inline, and return it as BGR.

        Args:
            frame: RGB image, presumably uint8 (H, W, 3). It is treated as RGB for display
                and converted RGB->BGR before returning, so pass RGB, not a raw cv2 BGR frame.
            mask: per-object masks accepted by ``resize_mask``; a pixel belongs to object i
                where its resized value is > 0.

        Returns:
            The annotated frame at (video_height, video_width) in BGR channel order.
        """
        frame = frame.copy()
        frame = cv2.resize(frame, (self.video_width, self.video_height))

        mask = (self.resize_mask(mask=mask) > 0.0).numpy()

        for i in range(mask.shape[0]):
            # Cycle through the palette so more than len(_COLORS) objects do not raise IndexError.
            frame[mask[i, 0, :, :]] = self._COLORS[i % len(self._COLORS)]

        rgb_frame = Image.fromarray(frame)
        # Replace the previous frame's output instead of appending a new image each call.
        clear_output(wait=True)
        display(rgb_frame)
        return cv2.cvtColor(np.array(rgb_frame), cv2.COLOR_RGB2BGR)

In [None]:
# Set SAM2 Configuration
VIDEO_STREAM = f"{HOME}/vid_dr03.mp4"
YOLO_CHECKPOINT_FILEPATH = "yolov8x-seg.pt"  # NOTE(review): not used by any visible cell
# Anchored at HOME so the checkpoint resolves regardless of which directory the
# earlier %cd magics left the kernel in.
SAM_CHECKPOINT_FILEPATH = f"{HOME}/sam2_realtime/checkpoints/sam2.1_hiera_base_plus.pt"
# Passed as-is to build_sam2_object_tracker; presumably resolved relative to the
# sam2 package's config directory rather than the CWD — confirm.
SAM_CONFIG_FILEPATH = "./configs/samurai/sam2.1_hiera_b+.yaml"
# "<video name>_segmented.mp4" next to the input; splitext avoids a doubled ".mp4_segmented.mp4".
OUTPUT_PATH = os.path.splitext(VIDEO_STREAM)[0] + "_segmented.mp4"
DEVICE = 'cuda'

In [None]:
# Open Video Stream
video_stream = cv2.VideoCapture(VIDEO_STREAM)
if not video_stream.isOpened():
    # Fail fast: otherwise the frame size would read as 0 and the tracking loop
    # would silently produce an empty output video.
    raise FileNotFoundError(f"Could not open video: {VIDEO_STREAM}")

video_height = int(video_stream.get(cv2.CAP_PROP_FRAME_HEIGHT))
video_width = int(video_stream.get(cv2.CAP_PROP_FRAME_WIDTH))

# For real-time visualization: frames and masks are resized to the source video size.
visualizer = Visualizer(video_width=video_width,
                        video_height=video_height
                        )

In [None]:
from autodistill_grounded_sam_2 import GroundedSAM2
from autodistill.detection import CaptionOntology

def tratar_first_frame(first_frame, captions=("sky", "sea", "mountain")):
    """Detect and segment the target classes in the first video frame with Grounded-SAM-2.

    ("tratar" means "process" in Portuguese.)

    Args:
        first_frame: the frame passed straight to ``GroundedSAM2.predict``; the tracking
            loop passes the RGB-converted first frame.
        captions: text prompts to detect. Each prompt is also used as its class name.
            The default reproduces the original sky / sea / mountain ontology in that order.

    Returns:
        (detections, num_objects): the object returned by ``predict`` (the caller reads
        its ``.mask`` and ``.xyxy``), and the number of detected objects.
    """
    # Prompt -> class name (identical strings). NOTE(review): class_id presumably indexes
    # this insertion order — confirm against autodistill's CaptionOntology.
    ontology = CaptionOntology({caption: caption for caption in captions})
    base_model = GroundedSAM2(ontology=ontology)

    detections = base_model.predict(first_frame)

    # len(detections) counts bounding boxes; unlike len(detections.mask) it does not
    # fail when no mask array is attached (e.g. an empty result).
    num_objects = len(detections)

    return detections, num_objects

In [None]:
# Track the Grounded-SAM-2 objects found in frame 0 through the whole video and
# write an annotated copy to OUTPUT_PATH.

# Class names in the same order as the ontology used in tratar_first_frame;
# each detection's class_id indexes this list.
CLASS_NAMES = ["sky", "sea", "mountain"]


def box_centers(detections):
    """Return the integer (x, y) centre of each xyxy bounding box, in detection order."""
    return [(int((x_min + x_max) / 2), int((y_min + y_max) / 2))
            for x_min, y_min, x_max, y_max in detections.xyxy]


def detection_labels(detections, class_names=CLASS_NAMES):
    """Return one text label per detection, looked up by its class_id.

    Falls back to positional order (the old sky/sea/mountain assumption) when the
    detections carry no class_id, and to the numeric id when it is out of range.
    """
    class_ids = detections.class_id
    if class_ids is None:
        class_ids = range(len(detections.xyxy))
    return [class_names[cid] if 0 <= cid < len(class_names) else str(cid)
            for cid in map(int, class_ids)]


def draw_labels(image, centers, labels):
    """Write each label at its centre point ((255, 0, 0) = blue in a BGR image) and return the image."""
    for center, label in zip(centers, labels):
        image = cv2.putText(image, label, center, cv2.FONT_HERSHEY_SIMPLEX,
                            1, (255, 0, 0), 2, cv2.LINE_AA)
    return image


# Re-running this cell after a previous run released the capture: reopen it.
if not video_stream.isOpened():
    video_stream.open(VIDEO_STREAM)

video_info = sv.VideoInfo.from_video_path(VIDEO_STREAM)
# torch.autocast expects a bare device type ("cuda" / "cpu"), not e.g. "cuda:0".
autocast_device = torch.device(DEVICE).type

first_frame = True

try:
    with sv.VideoSink(OUTPUT_PATH, video_info=video_info) as sink:
        with torch.inference_mode(), torch.autocast(autocast_device, dtype=torch.bfloat16):
            while video_stream.isOpened():
                ret, frame = video_stream.read()
                if not ret:
                    break

                # OpenCV decodes BGR; the detector, tracker and Visualizer are fed RGB.
                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                if first_frame:
                    ground_masks, n_obj = tratar_first_frame(img)
                    if n_obj == 0:
                        raise RuntimeError("No objects detected in the first frame; nothing to track.")

                    # Label positions stay fixed at the first-frame box centres.
                    save_centers = box_centers(ground_masks)
                    save_labels = detection_labels(ground_masks)

                    sam = build_sam2_object_tracker(num_objects=n_obj,
                                                    config_file=SAM_CONFIG_FILEPATH,
                                                    ckpt_path=SAM_CHECKPOINT_FILEPATH,
                                                    device=DEVICE,
                                                    verbose=False
                                                    )
                    # Seed the tracker with the first-frame masks.
                    sam_out = sam.track_new_object(img=img,
                                                   mask=ground_masks.mask
                                                   )
                    first_frame = False
                else:
                    sam_out = sam.track_all_objects(img=img)

                # add_frame treats its input as RGB and returns BGR, so pass the RGB image.
                final = visualizer.add_frame(frame=img, mask=sam_out['pred_masks'])
                final = draw_labels(final, save_centers, save_labels)
                sink.write_frame(final)
finally:
    # Release the capture even if detection/tracking raises mid-video.
    video_stream.release()