# Utils

In [None]:
# | default_exp utils


In [None]:
# | export

from vid_chains.imports import *
import math

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
import skimage.transform as st
import sys
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from segment_anything.modeling import Sam

In [None]:
#| hide

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# | export


def load_obj_model(name="yolov8n.pt"):
    """Instantiate a YOLO detector.

    `name` is passed straight to the `YOLO` constructor (by default the
    "yolov8n.pt" weights) and the resulting model object is returned.
    """
    detector = YOLO(name)
    return detector


def detect_objects(model, img):
    """Run `model` on `img` and collect the detected boxes per result.

    The model is called in streaming mode. For every result it yields, the
    box tensor is detached, moved to the CPU and converted to nested Python
    lists.

    Returns a list with one ``{"boxes": [...]}`` dict per result.
    """
    detections = []
    for result in model(img, stream=True):
        box_rows = result.boxes.data.detach().cpu().tolist()
        detections.append({"boxes": box_rows})
    return detections


def centroid(l):
    """Return ``[cx, cy]``, the midpoint of a box.

    `l` is indexed as ``l[0], l[1]`` for one corner and ``l[2], l[3]`` for the
    opposite corner; any further entries are ignored.
    """
    mid_x = (l[0] + l[2]) / 2.0
    mid_y = (l[1] + l[3]) / 2.0
    return [mid_x, mid_y]


def list_centroids(objects):
    """Return the centroid of every box detected in the first image.

    `objects` is the list produced by `detect_objects`; only its first entry
    is inspected. Each box contributes its first four values to `centroid`;
    any remaining values in the row are ignored.

    The number of boxes is taken from the data instead of being fixed at 11,
    so images with fewer detections no longer raise IndexError and images
    with more are no longer truncated.

    Returns a list of ``[cx, cy]`` pairs, empty when there is nothing to
    process.
    """
    if not objects:
        return []
    boxes = objects[0].get("boxes") or []
    return [centroid(box[:4]) for box in boxes]


def inter_dist(objects):
    """Return the pairwise Euclidean distances between box centroids.

    Centroids come from `list_centroids(objects)`. Distances are emitted for
    every unordered pair ``(i, j)`` with ``i < j``, ordered by ``i`` and then
    ``j``: (0, 1), (0, 2), ..., (1, 2), ...

    The pair count follows the number of centroids actually available rather
    than assuming exactly 11 detections. With fewer than two centroids the
    result is an empty list.
    """
    centres = list_centroids(objects)
    return [
        math.dist(first, second)
        for idx, first in enumerate(centres)
        for second in centres[idx + 1:]
    ]

In [None]:
# | export

def show_mask(mask, ax, random_color=False):
    """Draw `mask` on `ax` as a tinted RGBA overlay with alpha 0.6.

    The last two dimensions of `mask` are taken as height and width. Each
    mask value is multiplied by the overlay colour, so zero entries come out
    as all-zero RGBA. With `random_color` the RGB part is drawn from
    `np.random.random`; otherwise a fixed blue is used.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on `ax`: label 1 in green, label 0 in red.

    `coords` is indexed with boolean masks built from `labels`, and columns
    0 and 1 of the selected rows are used as x and y. Both groups are drawn
    as white-edged stars of size `marker_size`.
    """
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for group, colour in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(group[:, 0], group[:, 1], color=colour, **star_style)

def show_box(box, ax):
    """Outline `box` on `ax` with a green, unfilled rectangle.

    `box[0], box[1]` is used as the anchor corner; width and height are
    derived from `box[2]` and `box[3]`.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0,0,0,0), lw=2,
    )
    ax.add_patch(outline)

def show_anns(anns):
    """Paint every annotation's segmentation onto the current axes.

    Annotations are processed from largest to smallest `'area'`, each one
    filled with a random colour at alpha 0.35 on an initially transparent
    canvas sized like the first (largest) segmentation. Does nothing when
    `anns` is empty.
    """
    if len(anns) == 0:
        return
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    canvas_shape = by_area[0]['segmentation'].shape
    canvas = np.ones((canvas_shape[0], canvas_shape[1], 4))
    canvas[:, :, 3] = 0  # alpha 0 everywhere until a mask is painted
    for ann in by_area:
        fill = np.concatenate([np.random.random(3), [0.35]])
        canvas[ann['segmentation']] = fill
    axes.imshow(canvas)


In [None]:
# | export

def get_mask_area(mask:np.ndarray):
    """Return the sum of all entries of `mask`.

    For a binary mask (True/1 inside, False/0 outside) this is the number of
    pixels covered.
    """
    return mask.sum()

def calculateIoU(gtMask, predMask):
    """Return the intersection-over-union of two binary masks.

    Pixels are counted as
      * true positive  - 1 in both masks,
      * false positive - 0 in `gtMask`, 1 in `predMask`,
      * false negative - 1 in `gtMask`, 0 in `predMask`,
    and the result is ``tp / (tp + fp + fn)``. Boolean masks work as well,
    since True compares equal to 1. Values other than 0 and 1 are ignored.

    The counts are computed with vectorised NumPy comparisons instead of a
    per-pixel Python loop.

    Returns 0.0 when neither mask has any foreground pixels (the union is
    empty), rather than dividing by zero.

    Raises:
        ValueError: if the two masks do not have the same shape.
    """
    gt = np.asarray(gtMask)
    pred = np.asarray(predMask)
    if gt.shape != pred.shape:
        raise ValueError(
            f"mask shapes differ: {gt.shape} vs {pred.shape}"
        )

    gt_on, gt_off = gt == 1, gt == 0
    pred_on, pred_off = pred == 1, pred == 0

    tp = int(np.count_nonzero(gt_on & pred_on))
    fp = int(np.count_nonzero(gt_off & pred_on))
    fn = int(np.count_nonzero(gt_on & pred_off))

    union = tp + fp + fn
    if union == 0:
        return 0.0
    return tp / union

def segment_with_prompts(sam_model:Sam, image:np.ndarray, **kwargs):
    """Segment `image` with SAM using fixed point prompts and an optional mask.

    Prompts: one point labelled 1 at the image centre and four points
    labelled 0 at the image corners. If a `mask` keyword argument is given,
    it is resized to 256x256 with nearest-neighbour interpolation, given a
    leading axis of length 1, and passed as `mask_input`.

    When `mask` is missing or None the prediction runs on the point prompts
    alone instead of failing inside `st.resize`.

    Args:
        sam_model: loaded SAM model handed to `SamPredictor`.
        image: array of shape (H, W, C), as unpacked from `image.shape`.
        **kwargs: only `mask` is read.

    Returns:
        The `masks` array returned by `SamPredictor.predict` with
        `multimask_output=False`.
    """
    h, w, _ = image.shape
    points = np.array([[w*0.5, h*0.5], [0, h], [w, 0], [0, 0], [w, h]])
    labels = np.array([1, 0, 0, 0, 0])

    mask = kwargs.get('mask')
    mask_input = None
    if mask is not None:
        # order=0 with no anti-aliasing keeps the mask values untouched.
        low_res = st.resize(mask, (256, 256), order=0, preserve_range=True, anti_aliasing=False)
        mask_input = np.stack((low_res,), axis=0)

    predictor = SamPredictor(sam_model)
    predictor.set_image(image)
    masks, _scores, _logits = predictor.predict(
        point_coords=points,
        point_labels=labels,
        mask_input=mask_input,
        multimask_output=False,
    )
    return masks

def load_sam_model(sam_checkpoint:str = "sam_vit_h_4b8939.pth", model_type:str = "vit_h", device:str = "cuda"):
    """Build a SAM model from a checkpoint and move it to `device`.

    `model_type` selects the builder in `sam_model_registry`, which is called
    with `checkpoint=sam_checkpoint`. The model is returned after
    `.to(device=device)`.
    """
    build_model = sam_model_registry[model_type]
    model = build_model(checkpoint=sam_checkpoint)
    model.to(device=device)
    return model

def segment_everything(sam_model:Sam, image:np.ndarray, **kwargs):
    """Pick the automatic SAM mask that best overlaps a reference mask.

    SAM's automatic mask generator is run on `image`. The generated
    annotations are sorted by `'area'` (largest first), and among the ten
    largest the one whose `'segmentation'` has the highest IoU with the
    reference `mask` is returned. Ties keep the larger annotation.

    Fixes relative to the previous version:
      * the winner is taken from the sorted list that the index refers to,
        not from the unsorted generator output;
      * images yielding fewer than ten masks no longer raise IndexError;
      * the area-difference check was removed because both of its branches
        did the same thing, so it never influenced the choice.

    Args:
        sam_model: loaded SAM model.
        image: image array passed to the mask generator.
        **kwargs: must contain `mask`, the reference mask (an array
            supporting `.astype`).

    Returns:
        The selected annotation dict (with at least `'segmentation'` and
        `'area'` keys).

    Raises:
        KeyError: if `mask` is not supplied.
        ValueError: if SAM generates no masks for the image.
    """
    mask = kwargs['mask']
    mask_generator = SamAutomaticMaskGenerator(sam_model)
    masks = mask_generator.generate(image)
    if not masks:
        raise ValueError("SAM generated no masks for this image")

    sorted_anns = sorted(masks, key=lambda ann: ann['area'], reverse=True)
    reference = mask.astype(int)

    max_candidates = 10  # only the largest regions are compared
    best_ann = None
    best_iou = -1.0
    for ann in sorted_anns[:max_candidates]:
        iou = calculateIoU(reference, ann['segmentation'].astype(int))
        if iou > best_iou:
            best_ann = ann
            best_iou = iou
    return best_ann

def segment(sam_model:Sam, image:np.ndarray, seg_function=segment_with_prompts, **kwargs):
    """Load a reference mask from disk and run `seg_function` with it.

    `image` is converted from BGR to RGB. The mask image is read from
    `kwargs['mask_path']` (default 'mask.png'), reduced to a single channel
    and cast to bool, so every non-zero pixel becomes True. The result of
    ``seg_function(sam_model, image, mask=mask)`` is returned unchanged.

    Args:
        sam_model: loaded SAM model forwarded to `seg_function`.
        image: BGR image array (e.g. from `cv2.imread`).
        seg_function: callable taking ``(sam_model, image, mask=...)``.
        **kwargs: only `mask_path` is read.

    Raises:
        FileNotFoundError: if the mask file is missing or cannot be decoded
            (`cv2.imread` returns None in that case).
    """
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask_fname = kwargs.get('mask_path', 'mask.png')
    mask = cv2.imread(mask_fname)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask image: {mask_fname!r}")

    # cv2.imread returns channels in BGR order, so convert from BGR.
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY).astype(bool)
    return seg_function(sam_model, image, mask=mask)

In [None]:
# | hide
# | eval: false

# with segment_with_prompts..


# Demo: segment the first five frames using point prompts plus the mask file.
frame_dir = 'frames'
# Sort so the "first five" frames are the same on every run.
frames = sorted(os.listdir(frame_dir))
# Load the SAM checkpoint once; it does not change between frames.
sam = load_sam_model()
for frame in frames[:5]:
  image = cv2.imread(f'{frame_dir}/{frame}')
  mask_2 = segment(sam_model=sam, image=image, seg_function = segment_with_prompts)
  plt.figure(figsize=(10,10))
  # cv2.imread returns BGR; convert so matplotlib shows the true colours.
  plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
  show_mask(mask_2, plt.gca())
  plt.show()

In [None]:
# | hide
# | eval: false

# with segment_everything..

# Demo: segment the first five frames with the automatic mask generator.
frame_dir = 'frames'
# Sort so the "first five" frames are the same on every run.
frames = sorted(os.listdir(frame_dir))
# Load the SAM checkpoint once; it does not change between frames.
sam = load_sam_model()
for frame in frames[:5]:
  image = cv2.imread(f'{frame_dir}/{frame}')
  mask_3 = segment(sam_model=sam, image=image, seg_function = segment_everything)
  plt.figure(figsize=(10,10))
  # cv2.imread returns BGR; convert so matplotlib shows the true colours.
  plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
  # segment_everything returns an annotation dict; show_mask needs the array.
  show_mask(mask_3['segmentation'], plt.gca())
  plt.show()

In [None]:
# | hide
# | eval: false

# Demo: run the default detector over the images under ../imgs/.
# `objects` ends up as a list with one {"boxes": [...]} dict per result.
img_path = "../imgs/"

model = load_obj_model()
objects = detect_objects(model, img_path)



image 1/2 /home/hamza/dev/HF/vid_chains/nbs/../imgs/1.jpg: 384x640 6 persons, 1 bicycle, 2 handbags, 2 suitcases, 7.1ms
image 2/2 /home/hamza/dev/HF/vid_chains/nbs/../imgs/2.jpg: 480x640 2 cars, 7 airplanes, 1 truck, 6.3ms
Speed: 2.0ms preprocess, 6.7ms inference, 2.8ms postprocess per image at shape (1, 3, 480, 640)


In [None]:
# | hide
# | eval: false


# Display the detections of the first result returned by detect_objects.
objects[0]


{'boxes': [[80.27601623535156,
   46.615535736083984,
   192.8482666015625,
   470.9010925292969,
   0.8922310471534729,
   0.0],
  [676.040283203125,
   48.51327133178711,
   795.2723388671875,
   403.0267639160156,
   0.8862534761428833,
   0.0],
  [186.3711700439453,
   146.55426025390625,
   482.23590087890625,
   483.7049255371094,
   0.8681676387786865,
   0.0],
  [420.06683349609375,
   103.76722717285156,
   485.9291687011719,
   326.57281494140625,
   0.8210902810096741,
   0.0],
  [314.42913818359375,
   160.44873046875,
   344.228759765625,
   231.4899444580078,
   0.5700167417526245,
   0.0],
  [601.1113891601562,
   272.0596618652344,
   692.3943481445312,
   372.8212585449219,
   0.5182335376739502,
   28.0],
  [159.09445190429688,
   254.00254821777344,
   204.3885040283203,
   320.135009765625,
   0.32958218455314636,
   26.0],
  [143.58888244628906,
   285.8494567871094,
   394.2103576660156,
   515.9329833984375,
   0.3147803246974945,
   1.0],
  [275.177490234375,
  

In [None]:
# | hide
# Run nbdev's export step. Per nbdev's documentation this writes the cells
# marked `# | export` to the module named by the `default_exp` directive.
import nbdev

nbdev.nbdev_export()
