# Utils

In [None]:
# | default_exp utils


In [None]:
# | export

from vid_chains.imports import *
import math

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
import skimage.transform as st
import sys
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from segment_anything.modeling import Sam

import gdown
sys.path.insert(3, os.getcwd()+"/Track_Anything")
sys.path.insert(1, os.getcwd()+"/Track_Anything/tracker")
sys.path.insert(2, sys.path[1]+"/model")
from track_anything import TrackingAnything
from track_anything import parse_augment
import requests
import torchvision

In [None]:
#| hide

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# | export


def load_obj_model(name="yolov8n.pt"):
    """Construct and return a ``YOLO`` detector for the given weights/model name.

    ``YOLO`` comes from the ``vid_chains.imports`` star import (presumably
    ``ultralytics.YOLO`` — TODO confirm).
    """
    detector = YOLO(name)
    return detector


def detect_objects(model, img):
    """Run ``model`` on ``img`` in streaming mode and collect the box tensors.

    Returns:
        A list with one ``{"boxes": [...]}`` dict per result yielded by the
        model, where ``"boxes"`` is ``result.boxes.data`` converted to a nested
        Python list (moved to CPU first).
    """
    detections = []
    for result in model(img, stream=True):
        box_rows = result.boxes.data.detach().cpu().tolist()
        detections.append({"boxes": box_rows})
    return detections


def centroid(l):
    """Return the centre ``[cx, cy]`` of a box given as ``(x1, y1, x2, y2, ...)``.

    Only the first four entries of ``l`` are read, so detector rows carrying
    extra columns can be passed directly.
    """
    return [(l[0] + l[2]) / 2.0, (l[1] + l[3]) / 2.0]


def list_centroids(objects, max_objects=11):
    """Return the centroids of the boxes in the first detection result.

    Args:
        objects: output of ``detect_objects``; only ``objects[0]["boxes"]`` is used.
        max_objects: keep at most this many boxes (in their original order).
            Defaults to 11, the count that used to be hard-coded; pass ``None``
            to use every box. Fewer boxes than the limit no longer raise
            IndexError — all available boxes are used.

    Returns:
        A list of ``[cx, cy]`` lists, one per kept box.
    """
    boxes = objects[0].get("boxes") or []
    if max_objects is not None:
        boxes = boxes[:max_objects]
    return [centroid(box) for box in boxes]


def inter_dist(objects, max_objects=11):
    """Return the Euclidean distance between every pair of box centroids.

    Pairs are ordered (0, 1), (0, 2), ..., (1, 2), ... — i.e. each pair
    ``(i, j)`` with ``i < j`` once. ``max_objects`` is forwarded to
    ``list_centroids``.
    """
    from itertools import combinations

    centres = list_centroids(objects, max_objects)
    return [math.dist(a, b) for a, b in combinations(centres, 2)]

In [None]:
# | export

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent (alpha 0.6) colour overlay.

    The overlay colour is a fixed blue, or a random RGB triple when
    ``random_color`` is true. ``mask`` is reshaped to its last two dimensions,
    so a leading singleton axis (e.g. ``(1, H, W)``) is accepted.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter label-1 points as green stars, then label-0 points as red stars.

    ``coords`` rows are ``(x, y)``; ``labels`` is a parallel array used as a
    boolean selector. Points with any other label are not drawn.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Add an unfilled green rectangle for ``box = (x0, y0, x1, y1, ...)`` to ``ax``."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_anns(anns):
    """Overlay every annotation's ``segmentation`` on the current axes.

    Annotations are painted largest-``area`` first, each in a random RGB colour
    with alpha 0.35, onto an initially fully transparent RGBA canvas sized like
    the largest annotation's segmentation. Does nothing for an empty list.
    """
    if len(anns) == 0:
        return
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    height, width = by_area[0]['segmentation'].shape[:2]
    canvas = np.ones((height, width, 4))
    canvas[..., 3] = 0
    for ann in by_area:
        canvas[ann['segmentation']] = np.concatenate([np.random.random(3), [0.35]])
    axes.imshow(canvas)


In [None]:
# | export

def get_mask_area(mask:np.ndarray):
    """Return the pixel area of a binary mask, i.e. the sum of its elements.

    Assumes foreground pixels are ``True``/``1`` and background ``False``/``0``.
    """
    return np.sum(mask)

def calculateIoU(gtMask, predMask):
    """Return the intersection-over-union of two binary masks of the same shape.

    Pixels equal to 1 are foreground and pixels equal to 0 are background; any
    other value is ignored, exactly like the previous per-pixel loop. The
    counts are now computed with vectorised NumPy comparisons instead of a
    Python double loop over every pixel.

    Args:
        gtMask: reference mask (array-like of 0/1 or bool values).
        predMask: predicted mask with the same shape as ``gtMask``.

    Returns:
        ``TP / (TP + FP + FN)`` as a float. When neither mask has any
        foreground pixel the union is empty and 0.0 is returned (previously
        this raised ZeroDivisionError).
    """
    gt = np.asarray(gtMask)
    pred = np.asarray(predMask)
    gt_fg, gt_bg = gt == 1, gt == 0
    pred_fg, pred_bg = pred == 1, pred == 0

    tp = int(np.count_nonzero(gt_fg & pred_fg))   # foreground in both
    fp = int(np.count_nonzero(gt_bg & pred_fg))   # predicted but not in reference
    fn = int(np.count_nonzero(gt_fg & pred_bg))   # in reference but missed

    union = tp + fp + fn
    if union == 0:
        return 0.0
    return tp / union

def segment_with_prompts(sam_model:Sam, image:np.ndarray, **kwargs):
    """Run a single-mask SAM prediction on ``image`` from point (and optional mask) prompts.

    Args:
        sam_model: loaded SAM model.
        image: ``(H, W, C)`` image passed to ``SamPredictor.set_image``.
        points (kwarg): ``(N, 2)`` click coordinates; default is the image centre
            plus the four corners.
        labels (kwarg): ``(N,)`` labels; default ``[1, 0, 0, 0, 0]``.
        mask (kwarg, optional): 2-D mask prompt. It is nearest-neighbour resized
            to 256x256 and given a leading axis before being passed as
            ``mask_input``. When omitted, no mask prompt is used (previously the
            unconditional resize of ``None`` crashed).
            NOTE(review): SAM's ``mask_input`` is documented as low-res logits;
            a 0/1 mask is a weak prompt — confirm this is intended.

    Returns:
        The ``masks`` array from ``SamPredictor.predict`` (``multimask_output=False``).
    """
    h, w, _ = image.shape
    points = kwargs.get('points', np.array([[w*0.5, h*0.5], [0, h], [w, 0], [0, 0], [w, h]]))
    labels = kwargs.get('labels', np.array([1, 0, 0, 0, 0]))
    mask = kwargs.get('mask', None)
    if mask is not None:
        mask = st.resize(mask, (256, 256), order=0, preserve_range=True, anti_aliasing=False)
        mask = mask[np.newaxis, :, :]  # (1, 256, 256)
    predictor = SamPredictor(sam_model)
    predictor.set_image(image)
    masks, _scores, _logits = predictor.predict(
        point_coords=points, point_labels=labels, mask_input=mask, multimask_output=False
    )
    return masks

def load_sam_model(sam_checkpoint:str = "sam_vit_h_4b8939.pth", model_type:str = "vit_h", device:str = "cuda"):
    """Build the SAM variant ``model_type`` from ``sam_checkpoint`` and move it to ``device``."""
    build_sam = sam_model_registry[model_type]
    model = build_sam(checkpoint=sam_checkpoint)
    model.to(device=device)
    return model

def segment_everything(sam_model:Sam, image:np.ndarray, **kwargs):
    """Pick the automatically generated SAM mask that best matches a reference mask.

    All masks are generated with ``SamAutomaticMaskGenerator``, sorted by
    ``area`` (largest first), and among the ``top_k`` largest the one with the
    highest IoU against ``kwargs['mask']`` is returned (the first one wins ties).

    Fixes over the previous version: the winning index was computed on the
    sorted list but applied to the unsorted ``masks`` list (returning the wrong
    annotation); fewer than 10 generated masks raised IndexError; and an
    area-difference test had no effect because both branches were identical.

    Args:
        sam_model: loaded SAM model.
        image: image passed to ``SamAutomaticMaskGenerator.generate``.
        mask (kwarg, required): reference mask (bool/0-1 array).
        top_k (kwarg, default 10): how many of the largest masks to consider.

    Returns:
        The selected annotation dict as produced by the generator (its
        ``'segmentation'`` entry holds the mask array).

    Raises:
        ValueError: if the generator produces no masks.
    """
    reference = kwargs['mask'].astype(int)
    top_k = kwargs.get('top_k', 10)
    masks = SamAutomaticMaskGenerator(sam_model).generate(image)
    if not masks:
        raise ValueError("SamAutomaticMaskGenerator produced no masks for this image")
    sorted_anns = sorted(masks, key=lambda ann: ann['area'], reverse=True)
    candidates = sorted_anns[:top_k]
    # max() keeps the first maximal element, matching the old strict '>' scan.
    return max(
        candidates,
        key=lambda ann: calculateIoU(reference, ann['segmentation'].astype(int)),
    )

def segment(sam_model:Sam, image:np.ndarray, seg_function=segment_with_prompts, **kwargs):
    """Segment ``image`` with ``seg_function`` using a reference mask read from disk.

    Args:
        sam_model: loaded SAM model.
        image: BGR image (as returned by ``cv2.imread``); converted to RGB here.
        seg_function: called as ``seg_function(sam_model, image, mask=..., points=..., labels=...)``.
        mask_path (kwarg, default ``'mask.png'``): image file whose non-zero
            pixels form the reference mask.
        points / labels (kwargs): click prompts; by default one label-1 point at
            the image centre and label-0 points at the four corners.

    Returns:
        Whatever ``seg_function`` returns.

    Raises:
        FileNotFoundError: if the mask file cannot be read.
    """
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask_fname = kwargs.get('mask_path', 'mask.png')
    raw_mask = cv2.imread(mask_fname)
    if raw_mask is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"could not read mask image: {mask_fname!r}")
    # cv2.imread returns BGR channel order, so use the BGR conversion.
    mask = cv2.cvtColor(raw_mask, cv2.COLOR_BGR2GRAY).astype(bool)
    h, w, _ = image.shape
    points = kwargs.get('points', np.array([[w*0.5, h*0.5], [0, h], [w, 0], [0, 0], [w, h]]))
    labels = kwargs.get('labels', np.array([1, 0, 0, 0, 0]))
    return seg_function(sam_model, image, mask=mask, points=points, labels=labels)

In [None]:
# | export

# download checkpoints
def download_checkpoint(url, folder, filename):
    """Download ``url`` to ``folder/filename`` unless that file already exists.

    The payload is streamed into ``<filepath>.part`` and only renamed to
    ``filepath`` once complete. An interrupted or failed download therefore
    never leaves a truncated file that a later call would mistake for a valid
    checkpoint. HTTP error statuses now raise instead of saving the error page.

    Returns:
        The path of the checkpoint file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if os.path.exists(filepath):
        return filepath

    print("download checkpoints ......")
    tmp_path = filepath + ".part"
    try:
        # timeout bounds the connect and each wait between received bytes.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Remove the partial file so the next call retries the download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("download successfully!")
    return filepath

def download_checkpoint_from_google_drive(file_id, folder, filename):
    """Download Google Drive file ``file_id`` to ``folder/filename`` via gdown, if missing.

    Returns:
        The path of the checkpoint file.

    Raises:
        RuntimeError: if the file is still absent after the gdown call.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if not os.path.exists(filepath):
        # Implicit literal concatenation avoids the stray indentation that the
        # old backslash continuation embedded in the message.
        print(
            "Downloading checkpoints from Google Drive... tips: If you cannot see the "
            "progress bar, please try to download it manually and put it in the "
            "checkpoints directory. E2FGVI-HQ-CVPR22.pth: "
            "https://github.com/MCG-NKU/E2FGVI (E2FGVI-HQ model)"
        )
        url = f"https://drive.google.com/uc?id={file_id}"
        gdown.download(url, filepath, quiet=False)
        if not os.path.exists(filepath):
            raise RuntimeError(f"Failed to download {url} to {filepath}")
        print("Downloaded successfully!")

    return filepath

# generate video after vos inference
def generate_video_from_frames(frames:list, output_path:str, fps:int=30):
    """
    Generates a video from a list of frames.

    The frames are stacked into one tensor and written with
    ``torchvision.io.write_video`` using the libx264 codec. That function
    expects a ``(T, H, W, C)`` uint8 tensor, so each frame should be an
    ``(H, W, C)`` uint8 array.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video. Its parent
            directory is created if needed; a bare filename (no directory
            part) is written to the current directory. Previously a bare
            filename crashed in ``os.makedirs('')``.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")
    video = torch.from_numpy(np.asarray(frames))
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchvision.io.write_video(output_path, video, fps=fps, video_codec="libx264")
    return output_path

def generate_frames_from_video(video_path:str, start_time:int):
    """Read every frame of ``video_path`` from ``start_time`` seconds onward, as RGB arrays.

    The capture is seeked with ``CAP_PROP_POS_MSEC`` and always released when
    reading finishes, which fixes a handle leak. If one of the listed
    exceptions occurs, an error is printed and the frames read so far are
    returned (best-effort behaviour kept from the original).

    Returns:
        A list of frames converted from OpenCV's BGR order to RGB; empty if
        the video could not be opened.
    """
    frames = []
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_MSEC, start_time*1000)
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    finally:
        if cap is not None:
            cap.release()
    return frames

def track_object(images:list, points:np.ndarray, labels:np.ndarray, e2fgvi_checkpoint:str, sam_checkpoint:str, xmem_checkpoint:str, **kwargs):
    """Segment an object in ``images[0]`` from click prompts and track it through ``images``.

    Args:
        images: list of frames; ``images[0]`` receives the click prompt.
        points: click coordinates forwarded to ``first_frame_click``.
        labels: one label per point, forwarded to ``first_frame_click``.
        e2fgvi_checkpoint, sam_checkpoint, xmem_checkpoint: checkpoint paths
            passed to ``TrackingAnything``.
        multimask (kwarg, default True): forwarded to ``first_frame_click``.

    Returns:
        ``(masks, logits, painted_images)`` as returned by ``track_model.generator``.
    """
    multimask = kwargs.get('multimask', True)
    # parse_augment is presumably argparse-based: hide the notebook kernel's
    # command-line arguments from it, then restore sys.argv so the rest of the
    # process is not left with a clobbered global.
    # NOTE(review): "cuda:0" sits in argv[0] (the program-name slot, which
    # argparse skips), so it does not select a device — confirm against parse_augment.
    saved_argv = sys.argv
    sys.argv = ["cuda:0"]
    try:
        args = parse_augment()
    finally:
        sys.argv = saved_argv
    track_model = TrackingAnything(sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args)
    track_model.samcontroler.sam_controler.reset_image()
    track_model.samcontroler.sam_controler.set_image(images[0])
    mask, _, _ = track_model.first_frame_click(image=images[0], points=points, labels=labels, multimask=multimask)
    masks, logits, painted_images = track_model.generator(images, mask)
    return masks, logits, painted_images

In [None]:
# | hide
# | eval: false


def get_points(img:np.ndarray):
    """Return one positive click prompt at the centre of every YOLO detection.

    The previous version could not run. It called ``cv2.imread()`` with no
    argument, swapped the ``detect_objects`` arguments, indexed the result
    list with ``'boxes'``, unpacked six-value rows into four names, and called
    ``append`` with two arguments.

    Args:
        img: image array, or a path string that is read with ``cv2.imread``.

    Returns:
        ``(points, labels)``: ``points`` is a list of ``[x, y]`` integer box
        centres, and ``labels`` is a list of 1s of the same length.

    Raises:
        FileNotFoundError: if ``img`` is a path that cv2 cannot read.
    """
    if isinstance(img, str):
        path = img
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"could not read image: {path!r}")
    yolo = load_obj_model()
    points = []
    labels = []
    for result in detect_objects(yolo, img):
        for box in result['boxes']:
            # rows carry extra columns after the corners; only x1, y1, x2, y2 are used
            x1, y1, x2, y2 = box[:4]
            points.append([int(x1 + (x2 - x1) / 2), int(y1 + (y2 - y1) / 2)])
            labels.append(1)
    return points, labels

In [None]:
# | hide
# | eval: false

# with segment_with_prompts..


frame_dir = 'frames'
# Sort so the same five frames are shown on every run (os.listdir order is arbitrary).
frames = sorted(os.listdir(frame_dir))
# Load SAM once; rebuilding the ViT-H model for every frame is very slow.
sam = load_sam_model()
for frame in frames[:5]:
    image = cv2.imread(os.path.join(frame_dir, frame))
    mask_2 = segment(sam_model=sam, image=image, seg_function=segment_with_prompts)
    plt.figure(figsize=(10, 10))
    # cv2 loads BGR; convert so matplotlib shows true colours.
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    show_mask(mask_2, plt.gca())
    plt.show()

In [None]:
# | hide
# | eval: false

# with segment_everything..

frame_dir = 'frames'
# Sort so the same five frames are shown on every run (os.listdir order is arbitrary).
frames = sorted(os.listdir(frame_dir))
# Load SAM once; rebuilding the ViT-H model for every frame is very slow.
sam = load_sam_model()
for frame in frames[:5]:
    image = cv2.imread(os.path.join(frame_dir, frame))
    mask_3 = segment(sam_model=sam, image=image, seg_function=segment_everything)
    plt.figure(figsize=(10, 10))
    # cv2 loads BGR; convert so matplotlib shows true colours.
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    # segment_everything returns a generator annotation dict; draw its mask array.
    show_mask(mask_3['segmentation'], plt.gca())
    plt.show()

In [None]:
# | hide
# | eval: false

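# Clone the repository. The export cell adds ./Track_Anything (with an underscore)
# to sys.path, so clone it under that directory name in the working directory:
# !git clone https://github.com/gaomingqi/Track-Anything.git Track_Anything
# %cd /content/Track_Anything

# Install dependencies (from inside the cloned repository, then cd back to the
# parent directory so the sys.path entries above resolve):
# !pip install -r requirements.txt
# Additional libraries required: progressbar2 gdown gitpython openmim av hickle tqdm psutil gradio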

In [None]:
# | hide
# | eval: false

# Check for the SAM / XMem / E2FGVI checkpoints and download any that are missing.
SAM_checkpoint_dict = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}
# Every SAM checkpoint URL is the shared download prefix followed by its file name.
SAM_checkpoint_url_dict = {
    variant: "https://dl.fbaipublicfiles.com/segment_anything/" + fname
    for variant, fname in SAM_checkpoint_dict.items()
}
sam_checkpoint, sam_checkpoint_url = SAM_checkpoint_dict["vit_h"], SAM_checkpoint_url_dict["vit_h"]
xmem_checkpoint, xmem_checkpoint_url = (
    "XMem-s012.pth",
    "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth",
)
e2fgvi_checkpoint, e2fgvi_checkpoint_id = "E2FGVI-HQ-CVPR22.pth", "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"

folder = "./checkpoints"
sam_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)



In [None]:
# | hide
# | eval: false

# extract frames from the video...
# Decode every frame from the 1-second mark onward (RGB arrays).
frames = generate_frames_from_video(video_path='vid_shorts.mp4', start_time=1)

In [None]:
# | hide
# | eval: false

h, w, _ = frames[0].shape
# Label 1 for the frame-centre click; label 0 for four points 10 px inside the corners.
points = np.array([
    [int(w * 0.5), int(h * 0.5)],
    [0, h - 10],
    [w - 10, 0],
    [0, 0],
    [w - 10, h - 10],
])
labels = np.array([1, 0, 0, 0, 0])
masks, logits, painted_images = track_object(
    frames,
    points=points,
    labels=labels,
    e2fgvi_checkpoint=e2fgvi_checkpoint,
    sam_checkpoint=sam_checkpoint,
    xmem_checkpoint=xmem_checkpoint,
)



In [None]:
# | hide
# | eval: false

# Save the return frames in the form of a video..
# Save the tracked, mask-painted frames returned by track_object as a video
# (previously the raw input `frames` were written, discarding the tracking result).
# NOTE(review): assumes painted_images is a list of HxWx3 uint8 frames — TODO confirm.
# The "./" prefix gives output_path a directory component for the directory check.
output_path = generate_video_from_frames(frames=painted_images, output_path='./output.mp4')

In [None]:
# | hide
# | eval: false

img_path = "../imgs/"

# The directory path is handed straight to the detector as its input source.
model = load_obj_model()
objects = detect_objects(model=model, img=img_path)


In [None]:
# | hide
# | eval: false


# Peek at the first detection result: a {"boxes": [...]} dict built by
# detect_objects (presumably one row per detected box — TODO confirm).
objects[0]


In [None]:
# | hide
from nbdev import nbdev_export

# Regenerate the library modules from the cells marked for export.
nbdev_export()
