# Utils

In [None]:
# | default_exp utils


In [None]:
# | export
from vid_chains.imports import *

In [None]:
# | hide

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# | export

def load_obj_model(
        yolo_name:str = "yolov8n.pt",
        sam_name:str = "sam_b.pt", 
        task:str = "detect"
    ):
    """Build the model object that corresponds to `task`.

    Args:
        yolo_name: Weights identifier passed to `YOLO` when `task` is "detect".
        sam_name: Weights identifier passed to `SAM` when `task` is "segment".
        task: "detect" or "segment".

    Returns:
        The constructed `YOLO` or `SAM` object, or None (after printing a
        message) when `task` is not one of the two supported values.
    """
    if task not in ("detect", "segment"):
        print(f"Model does not exist for the following task: {task}. Please select one of the following tasks: detect or segment")
        return None
    if task == "segment":
        return SAM(sam_name)
    return YOLO(yolo_name)

def detect_objects(
        model,
        image:Union[np.ndarray, str],
        stream:bool = True,
        task:str = "detect",
        conf:float = 0.25,
        iou:float = 0.7,
        augment:bool = False,
        imgsz:int = 640, 
        names:list = None, 
        exclude:list = None,   
        return_only_boxes:bool = True, 
        points:list = None, 
        labels:list = None, 
        bboxes:list = None
    ):
    """Run `model` on `image` for the given task and return its detections.

    Args:
        model: Callable model; invoked as `model(image, ...)`.
        image: Image array or path, forwarded unchanged to `model`.
        stream: Forwarded to `model` as `stream`.
        task: "detect" (class-filtered detection) or "segment" (prompted
            segmentation using `bboxes`, `points` and `labels`).
        conf, iou, augment, imgsz: Forwarded to `model` for the "detect" task.
        names: Class names to keep (case-insensitive). Ignored when `exclude`
            is also given, because `exclude` then defines the class list.
        exclude: Class names to drop (case-insensitive).
        return_only_boxes: For "detect", return a list with one
            `{"boxes": [...]}` dict per result instead of a detections object.
        points, labels, bboxes: Prompts forwarded to `model` for "segment".

    Returns:
        For "detect" with `return_only_boxes`: a list of `{"boxes": ...}` dicts.
        Otherwise the `sv.Detections` built from the last result, or None when
        there are no results. An unsupported `task` prints a message and
        returns None.

    Raises:
        ValueError: If `names` contains a name that is not a known class.
    """
    if task not in ("detect", "segment"):
        print(f"Can not process the following task: {task}. Please select one of the following tasks: detect or segment")
        return None

    if task == "detect": # Using the yolo model
        classes_dict = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
        # Reverse map built once so every name lookup is O(1).
        name_to_id = {class_name: class_id for class_id, class_name in classes_dict.items()}
        classes = list(classes_dict)
        if names:
            unknown = [name for name in names if name.lower() not in name_to_id]
            if unknown:
                raise ValueError(f"Unknown class name(s): {unknown}")
            classes = [name_to_id[name.lower()] for name in names]
        if exclude:
            # As before, an exclusion list replaces any selection made via `names`.
            excluded = {name.lower() for name in exclude}
            classes = [class_id for class_id, class_name in classes_dict.items() if class_name not in excluded]
        results = model(image, stream=stream, classes=classes, conf=conf, iou=iou, augment=augment, imgsz=imgsz)
        if return_only_boxes:
            return [{"boxes": r.boxes.data.detach().cpu().tolist()} for r in results]
    else: # Using the SAM model..
        results = model(image, stream = stream, bboxes = bboxes, points = points, labels = labels)

    detections = None
    if results:
        # Only the detections of the last result are kept (one result per input image).
        for result in results:
            detections = sv.Detections.from_ultralytics(result)
    return detections

    
def annotateImage(
        image:np.ndarray, 
        results, 
        draw_bbox:bool = True,
        draw_mask:bool = False, 
        label:bool = True,
        conf_thresh:float = 0.0, 
        names:list = None, 
        exclude:list = None, 
        area_thresh:int = 0
    ):
    """Draw boxes and/or masks for `results` on a copy of `image`.

    Args:
        image: Image to annotate; it is copied, not modified.
        results: Detections object supporting boolean-mask indexing and the
            `confidence`, `class_id` and `area` attributes.
        draw_bbox: Draw bounding boxes with `sv.BoxAnnotator`.
        draw_mask: Draw masks with `sv.MaskAnnotator`.
        label: When True, filter by confidence and class and build
            "<class name> <confidence>" labels for the boxes.
        conf_thresh: Detections must have confidence strictly above this
            (applied only when `label` is True).
        names: Class names to keep (case-insensitive); ignored when `exclude`
            is given (applied only when `label` is True).
        exclude: Class names to drop (case-insensitive).
        area_thresh: Detections must have area strictly above this.

    Returns:
        The annotated copy of `image`.

    Raises:
        ValueError: If `names` contains a name that is not a known class.
    """
    labels = None
    detections = results
    classes_dict = None
    if label:
        classes_dict = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
        # Reverse map built once so every name lookup is O(1).
        name_to_id = {class_name: class_id for class_id, class_name in classes_dict.items()}
        classes = list(classes_dict)
        if names:
            unknown = [name for name in names if name.lower() not in name_to_id]
            if unknown:
                raise ValueError(f"Unknown class name(s): {unknown}")
            classes = [name_to_id[name.lower()] for name in names]
        if exclude:
            # As before, an exclusion list replaces any selection made via `names`.
            excluded = {name.lower() for name in exclude}
            classes = [class_id for class_id, class_name in classes_dict.items() if class_name not in excluded]
        detections = detections[detections.confidence > conf_thresh]
        detections = detections[np.isin(detections.class_id, classes)]

    detections = detections[detections.area > area_thresh]

    if label:
        # Labels are built only after every filter has run, so that there is
        # exactly one label per detection that is actually drawn.
        labels = [
            f"{classes_dict[class_id]} {confidence:.2f}"
            for class_id, confidence
            in zip(detections.class_id, detections.confidence)
        ]

    annotated_image = image.copy()
    if draw_bbox:
        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(annotated_image, detections, labels=labels)
    if draw_mask:
        mask_annotator = sv.MaskAnnotator()
        annotated_image = mask_annotator.annotate(scene=annotated_image, detections=detections)
    return annotated_image

def get_width(l):
    """Return the horizontal extent of a box: element 2 minus element 0."""
    left, right = l[0], l[2]
    return right - left


def list_widths(obj, class_id=39.0):
    """Collect the widths of the boxes in `obj` that belong to one class.

    Args:
        obj: Mapping with a "boxes" entry; each box is a sequence whose
            elements 0 and 2 are the left/right x coordinates and whose
            element 5 is the class id.
        class_id: Class id to keep. Defaults to 39.0 (bottle), which was the
            previously hard-coded value. Pass None to keep every box.

    Returns:
        List of widths (element 2 minus element 0), in box order.
    """
    widths = []
    for box in obj.get("boxes"):
        if class_id is None or box[5] == class_id:
            widths.append(box[2] - box[0])
    return widths


def centroid(l):
    """Return `[cx, cy]`, the midpoint of a box given as `[x1, y1, x2, y2]`."""
    x_sum = l[0] + l[2]
    y_sum = l[1] + l[3]
    return [x_sum / 2.0, y_sum / 2.0]


def list_centroids(obj):
    """Return the centroid of every box stored under `obj["boxes"]`.

    Only the first four elements of each box are passed on to `centroid`.
    """
    boxes = obj.get("boxes")
    return [centroid([box[k] for k in range(4)]) for box in boxes]


def inter_dist(obj):
    """Return the distances between every unordered pair of box centroids.

    Pairs are emitted in the order (0,1), (0,2), ..., (1,2), ... and each
    distance is computed with `math.dist`.
    """
    centres = list_centroids(obj)
    count = len(centres)
    return [
        math.dist(centres[first], centres[second])
        for first in range(count)
        for second in range(first + 1, count)
    ]


def focal_len_to_px(focal_len, sensor_px):
    """Return `focal_len / sensor_px`, scaled by 1000 and rounded to an int.

    NOTE(review): the units of the two arguments are not visible here; the
    factor 1000 suggests a mm-to-µm conversion — confirm with the caller.
    """
    ratio = focal_len / sensor_px
    return round(ratio * 1000)


def camera_to_obj_dist(focal_length_px, obj, real_width):
    """Estimate a camera-to-object distance for each width from `list_widths`.

    Each distance is `(real_width * focal_length_px) / width`, where `width`
    is one of the box widths returned by `list_widths(obj)`.
    """
    return [
        (real_width * focal_length_px) / width
        for width in list_widths(obj)
    ]

In [None]:
# | export


# Extract direction and speed from the selected objects using RAFT (optical flow algorithm)..


def get_points(
    yolo,
    names:list,
    img: Union[str, np.ndarray],
    annotate: bool = False,
    return_img: bool = False,
    stream: bool = True,
):
    """Detect objects and return their boxes together with the box centres.

    Args:
        yolo: Detection model handed to `detect_objects`.
        names: Class names to detect.
        img: Image (array or path) to run detection on.
        annotate: Draw the detected boxes onto a copy of `img`.
        return_img: Also return the image as the first element of the result.
        stream: Forwarded to `detect_objects`.

    Returns:
        `(boxes, points, labels)`, or `(image, boxes, points, labels)` when
        `return_img` is True. `points` holds the integer centre `[x, y]` of
        each box and `labels` holds a 1 per point. The returned image is the
        annotated copy when `annotate` is True, otherwise `img` itself.
    """
    detections = detect_objects(model=yolo, image=img, stream=stream, task="detect", names=names, return_only_boxes=False)

    # Fall back to the input image so that return_img works without annotate.
    annotated_image = img
    if detections is None:
        # detect_objects produced no result at all: report zero boxes.
        boxes = np.empty((0, 4))
    else:
        boxes = detections.xyxy
        if annotate:
            annotated_image = annotateImage(image=img, results=detections, draw_bbox=True, draw_mask=False)

    points = []
    labels = []
    for box in boxes:
        x1, y1, x2, y2 = box
        mid_x = int(x1 + ((x2 - x1) / 2))
        mid_y = int(y1 + ((y2 - y1) / 2))
        points.append([mid_x, mid_y])
        labels.append(1)

    if return_img:
        return annotated_image, boxes, points, labels
    return boxes, points, labels

def draw_arrow(img: np.ndarray, mean_u, mean_v, points):
    """Draw an arrow on `img` that starts at `points` and follows (mean_u, mean_v).

    The flow components are truncated to integers and multiplied by 10 to
    obtain the arrow length in pixels. Returns the value of
    `cv2.arrowedLine`.
    """
    scale_x = 10
    scale_y = 10
    start = (points[0], points[1])
    end = (points[0] + int(mean_u) * scale_x, points[1] + int(mean_v) * scale_y)
    return cv2.arrowedLine(
        img=img,
        pt1=start,
        pt2=end,
        color=(0, 0, 255),
        thickness=5,
        line_type=8,
        tipLength=0.5,
    )


def get_velocity(
    img1: Union[str, np.ndarray],
    img2: Union[str,np.ndarray],
    boxes: list,
    res: np.ndarray,
    model = None,
    save_img: bool = True,
    out_dir: str = "./frames/",
    config_file: str = "raft_8x2_50k_kitti2015_288x960.py",
    checkpoint_file: str = "raft_8x2_50k_kitti2015_288x960.pth",
    device: str = "cpu",
):
    """Estimate the mean optical-flow magnitude inside each box.

    Args:
        img1, img2: The two frames handed to `inference_model`.
        boxes: Iterable of `(x1, y1, x2, y2)` boxes.
        res: Image on which one arrow per box is drawn.
        model: Flow model; built with `init_model` when None.
        save_img: Write the arrow image to `{out_dir}/arrow_and_box.png`.
        out_dir: Directory for the flow map and the arrow image.
        config_file, checkpoint_file, device: Used only to build the model.

    Returns:
        `(vel, img, flow_map)`: one flow magnitude per box, the image with
        arrows, and the value of `visualize_flow` (None when there are no
        boxes).
    """
    if model is None:
        model = init_model(config_file, checkpoint_file, device=device)
    result = inference_model(model, img1, img2)
    img = res
    vel = []
    flow_map = None
    if len(boxes) > 0:
        # The flow map depends only on `result`, so render and save it once
        # instead of once per box.
        flow_map = visualize_flow(result, save_file=f"{out_dir}/flow_map.png")
    for box in boxes:
        x1, y1, x2, y2 = box
        mid_x = int(x1 + ((x2 - x1) / 2))
        mid_y = int(y1 + ((y2 - y1) / 2))
        region = result[int(y1) : int(y2), int(x1) : int(x2)]
        if region.size == 0:
            # A box that collapses to zero pixels has no flow to average;
            # treat it as stationary rather than propagating NaN.
            mean_u = 0.0
            mean_v = 0.0
        else:
            mean_u = region[..., 0].mean()
            mean_v = region[..., 1].mean()
        img = draw_arrow(img, mean_u, mean_v, (mid_x, mid_y))
        vel.append(math.sqrt(pow(mean_u, 2) + (pow(mean_v, 2))))
    if save_img:
        cv2.imwrite(f"{out_dir}/arrow_and_box.png", img)
    return vel, img, flow_map

def infer_video(
    video_path:str,
    names:list = ["person", "car", "airplane"],
    config_file: str = "raft_8x2_50k_kitti2015_288x960.py",
    checkpoint_file: str = "raft_8x2_50k_kitti2015_288x960.pth",
    device:str="cpu",
):    
    """Detect objects in every frame of a video and estimate their flow speed.

    Args:
        video_path: Path of the video to read.
        names: Class names to detect (the default list is never mutated here).
        config_file, checkpoint_file, device: Passed to `init_model` to build
            the optical-flow model.

    Returns:
        `(frames, speeds, fps)`: the annotated frames, the per-frame list of
        speeds from `get_velocity`, and the FPS reported by the capture.
        Returns None (after printing a message) when the first frame cannot
        be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame1 = cap.read()
        if not ret:
            print("ERROR! In Reading the video file..")
            return
        frames = []
        speeds = []
        yolo = load_obj_model(task="detect")
        model = init_model(config=config_file, checkpoint=checkpoint_file, device=device)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame2 = cap.read()
            if not ret:
                break
            img, boxes, _, _ = get_points(yolo=yolo, names=names, img=frame2, annotate=True, return_img=True, stream=False)
            speed, img, _ = get_velocity(img1=frame1, img2=frame2, boxes=boxes, res=img, model=model, save_img=False)
            frames.append(img)
            speeds.append(speed)
            # The current frame becomes the reference for the next flow estimate.
            frame1 = frame2
    finally:
        # Always free the capture handle, including on errors and early return.
        cap.release()
    return frames, speeds, fps

In [None]:
# | export

def generate_video(
        frames:list, 
        fps:int, 
        video_path:str
    ):
    """Write `frames` to `video_path` as an 'mp4v'-encoded video.

    Args:
        frames: Non-empty sequence of image arrays; the size of the first
            frame defines the video size. Frames are cast to uint8.
        fps: Frame rate of the output video.
        video_path: Destination file.

    Raises:
        ValueError: If `frames` is empty.
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one image")
    height, width = frames[0].shape[0], frames[0].shape[1]
    out = cv2.VideoWriter(video_path,cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for frame in frames:
            out.write(frame.astype(np.uint8))
    finally:
        # Release the writer even if a frame cannot be written.
        out.release()

In [None]:
# | hide
# | eval: false

# Smoke test: detect objects in a sample image with YOLO, draw the boxes and save the result.
img = cv2.imread("sample_1.jpg")
obj_model = load_obj_model(task="detect")
# return_only_boxes=False makes detect_objects return the detections object that annotateImage indexes.
yolo_results = detect_objects(model=obj_model, task="detect", image=img, return_only_boxes=False)
annotated_image = annotateImage(image=img, results=yolo_results, draw_bbox=True, draw_mask=False, conf_thresh=0.5, area_thresh=200)
cv2.imwrite("yolo_test.jpg", annotated_image)

# img = cv2.imread("sample_1.jpg")
# obj_model = load_obj_model(task="segment")
# sam_results = detect_objects(model=obj_model, task="segment", image=img, return_only_boxes=False)
# annotated_image = annotateImage(image=img, results=sam_results, draw_bbox=False, draw_mask=True, conf_thresh=0.5, area_thresh=200)
# cv2.imwrite("sam_test.jpg", annotated_image)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]



0: 480x640 6 persons, 6 cars, 2 handbags, 283.1ms
Speed: 28.8ms preprocess, 283.1ms inference, 1.9ms postprocess per image at shape (1, 3, 480, 640)


True

In [None]:
# | export

def intersection_area(a, b):
    """Return the overlap area of two `[x1, y1, x2, y2]` rectangles.

    Returns None when the rectangles do not intersect; rectangles that only
    touch along an edge or corner yield an area of 0.
    """
    left, top = max(a[0], b[0]), max(a[1], b[1])
    right, bottom = min(a[2], b[2]), min(a[3], b[3])
    overlap_w = right - left
    overlap_h = bottom - top
    if not (overlap_w >= 0 and overlap_h >= 0):
        return None
    return overlap_w * overlap_h


def detect_obstacles(img:np.ndarray,
                     target_box:list, 
                     factor:float = 0.5,
                     conf:float = 0.25, 
                     iou:float = 0.7, 
                     imgsz:int = 640,
                     augment:bool = False,
                     model = None, 
                     objects = ["car"], 
                     alpha:float = 0.4
    ):
    """Find detections that overlap `target_box` by at least `factor`.

    Args:
        img: Image to analyse.
        target_box: Region of interest as `[x1, y1, x2, y2]`.
        factor: Minimum ratio of (intersection area / target-box area) for a
            detection to count as an obstacle.
        conf, iou, imgsz, augment: Forwarded to `detect_objects`.
        model: Detection model; `load_obj_model()` is used when None.
        objects: Class names to detect (never mutated here).
        alpha: Currently unused; kept for interface compatibility.

    Returns:
        `(obstacles_list, res_img)`: the overlapping detections and a copy of
        the annotated image with the target box drawn in (0, 255, 0) when
        there are no obstacles and in (255, 0, 0) otherwise.
    """
    if model is None:
        model = load_obj_model()
    # The same filtered call is used whether or not the caller supplied a
    # model, so `objects`, `conf`, `iou`, `augment` and `imgsz` always apply.
    result = detect_objects(model=model, image=img, stream=True, task="detect", return_only_boxes=False, names=objects, conf=conf, iou=iou, augment=augment, imgsz=imgsz)

    target_area = (target_box[2]-target_box[0]) * (target_box[3]-target_box[1])
    obstacles_list = []
    if result is not None:
        img = annotateImage(image=img, results=result, label=True)
        for index, box in enumerate(result.xyxy):
            inter_area = intersection_area(box, target_box)
            # A None or zero intersection is skipped, which also keeps a
            # zero-area target box from ever reaching the division.
            if inter_area and inter_area / target_area >= factor:
                obstacles_list.append(result[index])

    box_color = (255, 0, 0) if obstacles_list else (0, 255, 0)
    res_img = cv2.rectangle(img.copy(), (int(target_box[0]), int(target_box[1])), (int(target_box[2]), int(target_box[3])), box_color, 2)
    return obstacles_list, res_img



In [None]:
# | export

def obstacle_avoidance(video_path:str, 
                       target_box:list, 
                       factor:float=0.01, 
                       objects:list = ["airplane", "car", "person", "bus", "truck"]
    ):
    """Run `detect_obstacles` on every frame of a video.

    Args:
        video_path: Path of the video to read.
        target_box: Region of interest as `[x1, y1, x2, y2]`.
        factor: Overlap ratio threshold forwarded to `detect_obstacles`.
        objects: Class names forwarded to `detect_obstacles` (never mutated
            here).

    Returns:
        `(frames, fps, obstacle_list)`: the annotated frames after a
        BGR-to-RGB conversion, the FPS reported by the capture, and the
        obstacles found in the last frame read (an empty list when no frame
        could be read).
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    # Defined up front so the return value is valid even for an empty video.
    obstacle_list = []
    try:
        yolo = load_obj_model(task="detect")
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            obstacle_list, res_img = detect_obstacles(img=frame, target_box=target_box, factor=factor, model=yolo, objects=objects)
            res_img = cv2.cvtColor(res_img, cv2.COLOR_BGR2RGB)
            frames.append(res_img)
            print("frame No:", len(frames))
    finally:
        # Always free the capture handle, including on errors.
        cap.release()
    return frames, fps, obstacle_list

In [None]:
# | hide
# | eval: false

# img = cv2.imread(filename="/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/sample_1.jpg")
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# # plt.imshow(img)
# model = load_obj_model()
# objs = detect_objects(model=model, image=img, return_only_boxes=False, names=["person"])
# # print(objs[0].xyxy[0])
# obstacles, res_img = detect_obstacles(img=img, target_box=objs[0].xyxy[0], factor=0.1)
# print(obstacles)
# plt.imshow(res_img)


# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/example_obs_det.mp4"
# target_box = [2000, 1600, 2250, 1700]
# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/camera-104-2023-07-05_15-36-29.ts"
# target_box = [300, 20, 520, 320]
# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/jbt_videos/camera-163-2023-07-05_22-12-55.ts"
# target_box = [250, 0, 448,  800]
# (800, 448, 3)
# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/jbt_videos/camera-159-2023-07-05_23-29-46.ts"
# target_box = [250, 0, 448, 800]

# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/jbt_videos/camera-107-2023-07-14_01-50-52.ts"
# target_box = [250, 0, 448, 800]
# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/jbt_videos/camera-107-2023-07-05_22-35-27.ts"
# target_box = [250, 0, 448, 800]

# Active test case: one camera recording and the region of interest to monitor.
video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/camera-102-2023-07-07_00-11-27.ts"
# target_box is [x1, y1, x2, y2] in pixel coordinates.
target_box = [150, 450, 300, 700]
# video_path = "/media/ali/A4D00431D0040BEC/ALI/Ali/dreamai_pocs/JBT/SAM/jbt_videos/camera-32-2023-07-12_11-35-06.ts"
# target_box = [0, 200, 512, 384]
# (384, 512, 3)


# Run obstacle detection over the whole video, then write the annotated frames to a new video file.
frames, fps, obstacle_list = obstacle_avoidance(video_path=video_path, target_box=target_box, factor=0.01)
generate_video(video_path="example_det_obs.mp4", frames=frames, fps=fps)

In [None]:
# | hide
# | eval: false

# Run detection plus optical-flow speed estimation on a recording and save the annotated frames as a video.
frames, vels, fps = infer_video(video_path="camera-104-2023-07-05_15-34-29.ts",names=["person", "car", "bus", "truck", "suitcase", "backpack", "handbag", "airplane"])
generate_video(frames=frames, fps=fps, video_path="output_arrow_hahah.mp4")


load checkpoint from local path: raft_8x2_50k_kitti2015_288x960.pth



0: 384x640 2 persons, 96.3ms
Speed: 2.3ms preprocess, 96.3ms inference, 0.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 persons, 116.2ms
Speed: 4.3ms preprocess, 116.2ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 persons, 100.8ms
Speed: 2.1ms preprocess, 100.8ms inference, 1.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 persons, 113.5ms
Speed: 2.2ms preprocess, 113.5ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 persons, 120.9ms
Speed: 2.8ms preprocess, 120.9ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 persons, 160.4ms
Speed: 4.5ms preprocess, 160.4ms inference, 3.6ms postprocess per image at shape (1, 3, 384, 640)


In [None]:
# | hide
# | eval: false

# Object Segmentation with SAM...


def get_mask_area(mask: np.ndarray):
    """Return the sum of all elements of `mask`.

    For a binary mask (True/1 marks a covered pixel) this is the number of
    covered pixels.
    """
    return mask.sum()


def calculateIoU(gtMask, predMask):
    """Compute the intersection-over-union of two 2-D binary masks.

    Only pixels equal to 1 (or True) count as foreground and only pixels equal
    to 0 (or False) count as background; any other value is ignored. The
    comparison covers the extent of `gtMask`.

    Args:
        gtMask: Ground-truth mask, a 2-D array.
        predMask: Predicted mask, at least as large as `gtMask`.

    Returns:
        `tp / (tp + fp + fn)` as a float, or 0.0 when the union is empty.
    """
    gt = np.asarray(gtMask)
    rows, cols = gt.shape[0], gt.shape[1]
    # Restrict the prediction to the ground-truth extent, matching the region
    # the pixel-by-pixel comparison used to visit.
    pred = np.asarray(predMask)[:rows, :cols]

    gt_on, gt_off = gt == 1, gt == 0
    pred_on, pred_off = pred == 1, pred == 0

    # Vectorised counts replace the Python loop over every pixel.
    tp = int(np.count_nonzero(gt_on & pred_on))
    fp = int(np.count_nonzero(gt_off & pred_on))
    fn = int(np.count_nonzero(gt_on & pred_off))

    union = tp + fp + fn
    if union == 0:
        # Both masks are empty: report zero overlap instead of dividing by zero.
        return 0.0
    return tp / union


def segment_with_prompts(sam_model: Sam, image: np.ndarray, **kwargs):
    """Segment `image` with SAM using point prompts and an optional mask prompt.

    Args:
        sam_model: SAM model wrapped in a `SamPredictor`.
        image: Image array of shape (height, width, channels).
        **kwargs:
            points: Prompt coordinates. Defaults to the image centre plus the
                four image corners.
            labels: One label per point. Defaults to 1 for the centre and 0
                for the corners.
            mask: Optional mask prompt; it is resized to 256x256 and given a
                leading axis before being passed as `mask_input`.

    Returns:
        The masks returned by `SamPredictor.predict` (single-mask output).
    """
    h, w, _ = image.shape
    points = kwargs.get(
        "points", np.array([[w * 0.5, h * 0.5], [0, h], [w, 0], [0, 0], [w, h]])
    )
    labels = kwargs.get("labels", np.array([1, 0, 0, 0, 0]))
    mask = kwargs.get("mask", None)
    # Identity check: comparing an array with `!= None` is element-wise and
    # cannot be used as an `if` condition.
    if mask is not None:
        mask = st.resize(mask, (256, 256), order=0, preserve_range=True, anti_aliasing=False)
        mask = np.stack((mask,) * 1, axis=0)
    predictor = SamPredictor(sam_model)
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=points, point_labels=labels, mask_input=mask, multimask_output=False
    )
    return masks


def load_sam_model(
    sam_checkpoint: str = "sam_vit_h_4b8939.pth",
    model_type: str = "vit_h",
    device: str = "cuda",
):
    """Build a SAM model from a checkpoint and move it to `device`.

    Args:
        sam_checkpoint: Checkpoint path passed to the registry builder.
        model_type: Key into `sam_model_registry`.
        device: Device the model is moved to.

    Returns:
        The constructed model.
    """
    build_sam = sam_model_registry[model_type]
    model = build_sam(checkpoint=sam_checkpoint)
    model.to(device=device)
    return model

def segment_everything(sam_model:Sam, image:np.ndarray, **kwargs):
    """Generate all SAM masks for `image`, optionally picking the best match.

    Args:
        sam_model: SAM model used by `SamAutomaticMaskGenerator`.
        image: Image to segment.
        **kwargs:
            mask: Optional reference mask. Other keyword arguments are
                accepted and ignored.

    Returns:
        Without a reference mask: the full list of generated annotations.
        With a reference mask: the annotation, among the 10 largest by area,
        whose segmentation has the highest IoU with the reference (the
        largest one wins ties), or None when no annotation was generated.
    """
    mask = kwargs.get("mask")
    mask_generator = SamAutomaticMaskGenerator(sam_model)
    masks = mask_generator.generate(image)
    # Identity check: `mask == None` is element-wise for arrays and cannot be
    # used as an `if` condition.
    if mask is None:
        return masks
    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)

    reference = mask.astype(int)
    best_ann = None
    best_iou = -1.0
    # Slicing copes with fewer than 10 generated masks.
    for ann in sorted_anns[:10]:
        val = calculateIoU(reference, ann['segmentation'].astype(int))
        # Both branches of the former area-difference test made the same
        # update, so the selection depends on IoU alone.
        if val > best_iou:
            best_ann = ann
            best_iou = val
    return best_ann

def segment(sam_model:Sam, image:np.ndarray, seg_function=segment_with_prompts, **kwargs):
    """Convert `image` from BGR to RGB and segment it with `seg_function`.

    Args:
        sam_model: SAM model forwarded to `seg_function`.
        image: BGR image array of shape (height, width, channels).
        seg_function: Callable invoked as
            `seg_function(sam_model, image, mask=..., points=..., labels=...)`.
        **kwargs:
            mask_path: Optional path of a mask image; it is read, converted
                to grayscale and cast to bool.
            points: Prompt coordinates. Defaults to the image centre plus the
                four image corners.
            labels: One label per point. Defaults to 1 for the centre and 0
                for the corners.

    Returns:
        Whatever `seg_function` returns.

    Raises:
        FileNotFoundError: If `mask_path` is given but the image cannot be read.
    """
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask_fname = kwargs.get("mask_path", None)
    mask = None
    if mask_fname is not None:
        mask = cv2.imread(mask_fname)
        if mask is None:
            # cv2.imread signals failure by returning None; fail with a clear
            # message instead of an obscure error inside cvtColor.
            raise FileNotFoundError(f"Could not read mask image: {mask_fname}")
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        mask = mask.astype(bool)
    h, w, _ = image.shape
    points = kwargs.get(
        "points", np.array([[w * 0.5, h * 0.5], [0, h], [w, 0], [0, 0], [w, h]])
    )
    labels = kwargs.get("labels", np.array([1, 0, 0, 0, 0]))

    masks = seg_function(sam_model, image, mask=mask, points=points, labels=labels)
    return masks

In [None]:
# | hide
# | eval: false

def get_points(
    yolo, 
    img: Union[str, np.ndarray],
    draw_bbox: bool = False,
    return_img: bool = False,
    stream: bool = True,
    names: list = None,
):
    """Detect objects and return the boxes of the first result with their centres.

    This variant calls `detect_objects` with the `img=` / `draw_bbox=`
    keywords and expects a list of `{"boxes": ...}` dicts back.

    Args:
        yolo: Detection model handed to `detect_objects`.
        img: Image (array or path) to run detection on.
        draw_bbox: Forwarded to `detect_objects`.
        return_img: Also return an image as the first element of the result.
        stream: Forwarded to `detect_objects`.
        names: Optional class names forwarded to `detect_objects`.

    Returns:
        `(boxes, points, labels)`, or `(image, boxes, points, labels)` when
        `return_img` is True. `points` holds the integer centre `[x, y]` of
        each box and `labels` holds a 1 per point. The image is the "img"
        entry of the first result when present, otherwise `img` itself.
    """
    result = detect_objects(model=yolo, img=img, names=names, stream=stream, draw_bbox=draw_bbox)
    # An empty result list means there is nothing to report.
    first = result[0] if result else {"boxes": []}
    boxes = first["boxes"]
    points = []
    labels = []
    for box in boxes:
        x1, y1, x2, y2 = box[:4]
        mid_x = int(x1 + ((x2 - x1) / 2))
        mid_y = int(y1 + ((y2 - y1) / 2))
        points.append([mid_x, mid_y])
        labels.append(1)
    if return_img:
        # Results may carry no "img" entry; fall back to the input image.
        return first.get("img", img), boxes, points, labels
    return boxes, points, labels


# Extract direction and speed from the selected objects using RAFT (optical flow algorithm)..


def display_direction(result: np.ndarray, mean_u, mean_v, points):
    """Draw an arrow on `result` that starts at `points` and follows (mean_u, mean_v).

    The flow components are truncated to integers and multiplied by 10 to
    obtain the arrow length in pixels. Returns the value of
    `cv2.arrowedLine`.
    """
    scale_x = 10
    scale_y = 10
    start = (points[0], points[1])
    end = (points[0] + int(mean_u) * scale_x, points[1] + int(mean_v) * scale_y)
    return cv2.arrowedLine(
        img=result,
        pt1=start,
        pt2=end,
        color=(0, 0, 255),
        thickness=5,
        line_type=8,
        tipLength=0.5,
    )


def get_velocity(
    img1: Union[str, np.ndarray],
    img2: Union[str,np.ndarray],
    boxes: list,
    res: np.ndarray,
    model=None,
    save_img: bool = True,
    out_dir: str = "./frames/",
    config_file: str = "raft_8x2_50k_kitti2015_288x960.py",
    checkpoint_file: str = "raft_8x2_50k_kitti2015_288x960.pth",
    device: str = "cpu",
):
    """Estimate the mean optical-flow magnitude inside the given boxes.

    Args:
        img1, img2: The two frames handed to `inference_model`.
        boxes: Iterable of boxes whose first four elements are
            `(x1, y1, x2, y2)`.
        res: Image on which one arrow per box is drawn.
        model: Flow model; built with `init_model` when None.
        save_img: Write the arrow image to "arrow_and_box.png".
        out_dir: Unused by this variant; files go to the working directory.
        config_file, checkpoint_file, device: Used only to build the model.

    Returns:
        `(vel, img, flow_map)`: the flow magnitude of the LAST box (0 when
        there are no boxes), the image with arrows, and the value of
        `visualize_flow` (None when there are no boxes).
    """
    if model is None:
        model = init_model(config_file, checkpoint_file, device=device)
    result = inference_model(model, img1, img2)
    img = res
    vel = 0
    flow_map = None
    if len(boxes) > 0:
        # The flow map depends only on `result`, so render and save it once
        # instead of once per box.
        flow_map = visualize_flow(result, save_file="flow_map.png")
    for box in boxes:
        x1, y1, x2, y2 = box[:4]
        mid_x = int(x1 + ((x2 - x1) / 2))
        mid_y = int(y1 + ((y2 - y1) / 2))
        region = result[int(y1) : int(y2), int(x1) : int(x2)]
        if region.size == 0:
            # A box that collapses to zero pixels has no flow to average;
            # treat it as stationary rather than propagating NaN.
            mean_u = 0.0
            mean_v = 0.0
        else:
            mean_u = region[..., 0].mean()
            mean_v = region[..., 1].mean()
        img = display_direction(img, mean_u, mean_v, (mid_x, mid_y))
        vel = math.sqrt(pow(mean_u, 2) + (pow(mean_v, 2)))
    if save_img:
        cv2.imwrite("arrow_and_box.png", img)
    return vel, img, flow_map

def infer_video(video_path:str,
    names:list = ["person", "car", "airplane"],
    config_file: str = "raft_8x2_50k_kitti2015_288x960.py",
    checkpoint_file: str = "raft_8x2_50k_kitti2015_288x960.pth",
    device:str="cpu",
):    
    """Detect objects in every frame of a video and estimate their flow speed.

    Args:
        video_path: Path of the video to read.
        names: Class names to detect (the default list is never mutated here).
        config_file, checkpoint_file, device: Passed to `init_model` to build
            the optical-flow model.

    Returns:
        `(frames, speeds, fps)`: the annotated frames, one speed value per
        frame from `get_velocity`, and the FPS reported by the capture.
        Returns None (after printing a message) when the first frame cannot
        be read.
    """
    print(names)
    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame1 = cap.read()
        if not ret:
            print("ERROR! In Reading the video file..")
            return
        frames = []
        speeds = []
        yolo = load_obj_model()
        model = init_model(config=config_file, checkpoint=checkpoint_file, device=device)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame2 = cap.read()
            if not ret:
                break
            # NOTE(review): this call needs a `get_points` that accepts both
            # `names` and `draw_bbox` — confirm the definition in scope does.
            img, boxes, _, _ = get_points(yolo=yolo, names=names,img=frame2, draw_bbox=True, return_img=True)
            speed, img, _ = get_velocity(img1=frame1, img2=frame2, boxes=boxes, res=img, model=model, save_img=False)
            frames.append(img)
            speeds.append(speed)
            # The current frame becomes the reference for the next flow estimate.
            frame1 = frame2
    finally:
        # Always free the capture handle, including on errors and early return.
        cap.release()
    return frames, speeds, fps





In [None]:
# | hide
# | eval: false

# Run detection and flow-speed estimation on one archived camera recording.
frames, vels, fps = infer_video(video_path="/home/trillo3/projects/vid_chains/static_files/104/archive/2023/07/05/camera-104-2023-07-05_15-34-29.ts",names=["person", "car", "bus", "truck", "suitcase", "backpack", "handbag", "airplane"])


In [None]:
# | hide
# | eval: false

def load_obj_model(name="yolov8n.pt"):
    """Construct and return a `YOLO` model from the weights called `name`."""
    detector = YOLO(name)
    return detector


def detect_objects(model, img, names=None, exclude=None,  stream=True, draw_bbox=False, return_only_boxes=True):
    """Run `model` on `img`, restricted to the selected classes.

    Args:
        model: Callable model; invoked as `model(img, stream=..., classes=...)`.
        img: Image array or path, forwarded unchanged to `model`.
        names: Class names to keep (case-insensitive). Ignored when `exclude`
            is also given, because `exclude` then defines the class list.
        exclude: Class names to drop (case-insensitive).
        stream: Forwarded to `model`.
        draw_bbox: Currently unused; kept for interface compatibility.
        return_only_boxes: Return a list with one `{"boxes": [...]}` dict per
            result instead of the raw model output.

    Returns:
        A list of `{"boxes": ...}` dicts, or the raw output of `model` when
        `return_only_boxes` is False.

    Raises:
        ValueError: If `names` contains a name that is not a known class.
    """
    # Named `classes_dict` so the builtin `dict` is not shadowed.
    classes_dict = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
    # Reverse map built once so every name lookup is O(1).
    name_to_id = {class_name: class_id for class_id, class_name in classes_dict.items()}
    classes = list(classes_dict)
    if names:
        unknown = [name for name in names if name.lower() not in name_to_id]
        if unknown:
            raise ValueError(f"Unknown class name(s): {unknown}")
        classes = [name_to_id[name.lower()] for name in names]
    if exclude:
        # As before, an exclusion list replaces any selection made via `names`.
        excluded = {name.lower() for name in exclude}
        classes = [class_id for class_id, class_name in classes_dict.items() if class_name not in excluded]
    res = model(img, stream=stream, classes=classes)
    if return_only_boxes:
        return [{"boxes": r.boxes.data.detach().cpu().tolist()} for r in res]
    return res



In [None]:
# | hide
# | eval: false

def show_mask(mask, ax, random_color=False):
    """Render `mask` on `ax` as a semi-transparent RGBA overlay.

    The overlay colour is a fixed blue with alpha 0.6, or a random RGB colour
    with alpha 0.6 when `random_color` is True. The last two dimensions of
    `mask` are taken as height and width.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on `ax`: label 1 in green, label 0 in red.

    `coords` is indexed with boolean masks derived from `labels`, and its
    columns 0 and 1 are used as the x and y positions.
    """
    groups = (
        (coords[labels == 1], "green"),
        (coords[labels == 0], "red"),
    )
    for selected, colour in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )


def show_box(box, ax):
    """Add a green, unfilled rectangle for `box` (`[x0, y0, x1, y1]`) to `ax`."""
    origin = (box[0], box[1])
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(origin, width, height, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)


def show_anns(anns):
    """Overlay every annotation's segmentation on the current axes.

    Annotations are painted from largest to smallest "area", each in a random
    colour with alpha 0.35, onto an initially transparent RGBA canvas.
    Nothing is drawn when `anns` is empty.
    """
    if len(anns) == 0:
        return
    ordered = sorted(anns, key=(lambda ann: ann["area"]), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    reference = ordered[0]["segmentation"]
    canvas = np.ones((reference.shape[0], reference.shape[1], 4))
    canvas[:, :, 3] = 0
    for ann in ordered:
        region = ann["segmentation"]
        tint = np.concatenate([np.random.random(3), [0.35]])
        canvas[region] = tint
    ax.imshow(canvas)

def combine_mask(image:np.ndarray, mask:np.ndarray, color:tuple=None):
    """Paint the pixels selected by `mask` with `color`, modifying `image` in place.

    Args:
        image: Image array to paint on; it is modified and returned.
        mask: Boolean mask whose last two dimensions are (height, width).
        color: Colour written to the selected pixels. Defaults to
            (30, 144, 255).

    Returns:
        The same `image` object.
    """
    # Identity check so that array-like colours are accepted too.
    if color is None:
        color = (30, 144, 255)
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h,w)
    image[mask_image] = color
    return image

In [None]:
# | hide
# | eval: false

# with segment_everything..

# frame_dir = 'frames'
# frames = os.listdir(frame_dir)
# # print(frames)
# for frame in frames[:5]:
#   image = cv2.imread(f'{frame_dir}/{frame}')
#   sam = load_sam_model()
#   mask_3 = segment(sam_model=sam, image=image, seg_function = segment_everything)
#   plt.figure(figsize=(10,10))
#   plt.imshow(image)
#   show_mask(mask_3, plt.gca())
#   plt.show()

In [None]:
# | hide
# | eval: false

# Clone the repository:
# !git clone https://github.com/gaomingqi/Track-Anything.git
# %cd /content/Track-Anything

# Install dependencies:
# !pip install -r requirements.txt
# new libraries: progressbar2 gdown gitpython openmim av hickle tqdm psutil gradio

In [None]:
# | hide
# | eval: false

# Object tracking with TAN..

# download checkpoints
def download_checkpoint(url, folder, filename):
    """Download `url` to `folder/filename` unless that file already exists.

    Args:
        url: Address of the checkpoint.
        folder: Destination directory; created when missing.
        filename: Name of the file inside `folder`.

    Returns:
        The path of the checkpoint file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if os.path.exists(filepath):
        return filepath

    print("download checkpoints ......")
    # Download to a temporary name first so that an interrupted transfer is
    # never mistaken for a complete checkpoint on the next call.
    partial_path = filepath + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail early rather than saving an error page as a checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, filepath)
    finally:
        # After a successful rename the partial file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("download successfully!")
    return filepath

def download_checkpoint_from_google_drive(file_id, folder, filename):
    """Fetch a Google Drive file into `folder/filename` unless it already exists.

    Args:
        file_id: Google Drive file id used to build the download URL.
        folder: Destination directory; created when missing.
        filename: Name of the file inside `folder`.

    Returns:
        The path of the checkpoint file.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if os.path.exists(filepath):
        return filepath

    print("Downloading checkpoints from Google Drive... tips: If you cannot see the progress bar, please try to download it manuall \
              and put it in the checkpointes directory. E2FGVI-HQ-CVPR22.pth: https://github.com/MCG-NKU/E2FGVI(E2FGVI-HQ model)")
    drive_url = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(drive_url, filepath, quiet=False)
    print("Downloaded successfully!")
    return filepath

# generate video after vos inference
def generate_video_from_frames(frames:list, output_path:str, fps:int=30):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        The `output_path` that was written.
    """
    frames = torch.from_numpy(np.asarray(frames))
    out_dir = os.path.dirname(output_path)
    # A bare filename has an empty dirname, and os.makedirs('') raises, so
    # only create a directory when the path actually contains one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path

def generate_frames_from_video(video_path:str, start_time:int):
    """Read a video from `start_time` seconds onward and return its frames.

    Args:
        video_path: Path of the video to read.
        start_time: Position, in seconds, at which reading starts.

    Returns:
        List of frames converted from BGR to RGB. Errors of the listed types
        are printed and the frames read so far are returned.
    """
    frames = []
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        # CAP_PROP_POS_MSEC expects milliseconds.
        cap.set(cv2.CAP_PROP_POS_MSEC, start_time*1000)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    finally:
        # Always free the capture handle once reading is over.
        if cap is not None:
            cap.release()
    return frames

def track_object(images:list, points:np.ndarray, labels:np.ndarray, e2fgvi_checkpoint:str, sam_checkpoint:str, xmem_checkpoint:str, **kwargs):
    """Segment an object in the first frame from point prompts and track it.

    Args:
        images: Sequence of frames; the first one receives the prompts.
        points: Prompt coordinates for the first frame.
        labels: One label per prompt point.
        e2fgvi_checkpoint, sam_checkpoint, xmem_checkpoint: Checkpoint paths
            passed to `TrackingAnything`.
        **kwargs:
            multimask: Forwarded to `first_frame_click`; defaults to True.

    Returns:
        `(masks, logits, painted_images)` as produced by the tracker's
        `generator`.
    """
    # parse_augment presumably reads sys.argv; give it a minimal argv for the
    # duration of the call and then restore the real one, so the rest of the
    # process is unaffected.
    saved_argv = sys.argv
    sys.argv = ["cuda:0"]
    try:
        args = parse_augment()
    finally:
        sys.argv = saved_argv
    multimask = kwargs.get('multimask', True)
    track_model = TrackingAnything(sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args)
    track_model.samcontroler.sam_controler.reset_image()
    track_model.samcontroler.sam_controler.set_image(images[0])
    mask,_,_ = track_model.first_frame_click(image = images[0], points = points, labels = labels, multimask = multimask)
    masks, logits ,painted_images= track_model.generator(images, mask)
    return masks, logits, painted_images

In [None]:
# | hide
# | eval: false

# Check for the required checkpoints and download any that are missing.
# Checkpoint file name for each SAM model type.
SAM_checkpoint_dict = {
    'vit_h': "sam_vit_h_4b8939.pth",
    'vit_l': "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth"
}
# Download URL for each SAM model type.
SAM_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
# The 'vit_h' SAM variant is the one used here.
sam_checkpoint = SAM_checkpoint_dict['vit_h']
sam_checkpoint_url = SAM_checkpoint_url_dict['vit_h']
xmem_checkpoint = "XMem-s012.pth"
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
# Google Drive file id of the E2FGVI checkpoint.
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"

# From here on the *_checkpoint names hold local file paths instead of bare file names.
folder = "./checkpoints"
sam_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)



In [None]:
# | hide
# | eval: false

# Extract the frames of the video, starting 1 second in.
frames = generate_frames_from_video('vid_shorts.mp4', start_time=1)

In [None]:
# | hide
# | eval: false

h, w, _ = frames[0].shape
# Prompt points: the image centre (label 1) and four points at or near the corners (label 0).
points=np.array([[int(w*0.5), int(h*0.5)], [0, h-10], [w-10, 0], [0,0], [w-10,h-10]])
labels = np.array([1, 0, 0, 0, 0])
# Track the masked object through all frames using the point prompts.
masks, logits, painted_images = track_object(frames, points = points, labels = labels, e2fgvi_checkpoint = e2fgvi_checkpoint, sam_checkpoint = sam_checkpoint, xmem_checkpoint = xmem_checkpoint)



In [None]:
# | hide
# | eval: false

# Save frames as a video.
# NOTE(review): this writes `frames` (the extracted input frames), not the `painted_images` returned by track_object — confirm which one is intended.
output_path = 'output.mp4'
output_path = generate_video_from_frames(frames=frames, output_path=output_path)

In [None]:
# | hide
# | eval: false

# Run detection on the images folder with the default model.
img_path = "../imgs/"

model = load_obj_model()
objects = detect_objects(model, img_path)


In [None]:
# | hide
# | eval: false


# Display the first entry of the detection output.
objects[0]


In [None]:
# | hide
import nbdev

# Run nbdev's export step for this notebook project.
nbdev.nbdev_export()
