In [None]:
def timing(f):
    """Decorator that prints how long each call to ``f`` took, in milliseconds.

    The wrapper forwards all positional and keyword arguments to ``f`` and
    returns ``f``'s result unchanged. If ``f`` raises, the exception
    propagates and no timing line is printed.

    Args:
        f: The callable to time.

    Returns:
        A wrapper that carries ``f``'s ``__name__``/``__doc__`` metadata.
    """
    # Local import so this cell does not depend on the notebook's import cell.
    import functools

    # Without wraps() every decorated function would report itself as "wrap"
    # and lose its docstring.
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments the way time.time() can.
        started = time.perf_counter()
        ret = f(*args, **kwargs)
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        print('{:s} function took {:.3f} ms'.format(f.__name__, elapsed_ms))

        return ret
    return wrap

In [None]:
from imageai import Detection
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
import cv2
import numpy as np
import time

In [None]:
@timing
def get_yolo(model_path='yolo.h5', tiny=False):
    """Create an ImageAI object detector and load its weights.

    Args:
        model_path: Path of the weights file passed to ``setModelPath``.
            Defaults to ``'yolo.h5'``, the value that used to be hard-coded.
        tiny: When true, the detector is configured with
            ``setModelTypeAsTinyYOLOv3`` instead of ``setModelTypeAsYOLOv3``.
            The caller is expected to pass a matching weights file
            (e.g. ``'yolo-tiny.h5'``).

    Returns:
        The ``Detection.ObjectDetection`` instance after ``loadModel()``.
    """
    detector = Detection.ObjectDetection()
    if tiny:
        detector.setModelTypeAsTinyYOLOv3()
    else:
        detector.setModelTypeAsYOLOv3()
    detector.setModelPath(model_path)
    detector.loadModel()
    return detector

In [None]:
def show_frame(frame,cap):
    """Show ``frame`` in the 'frame' window and handle the quit key.

    Waits up to 25 ms for a key press. If the key is 'q', all OpenCV windows
    are destroyed and ``cap`` is released; otherwise nothing else happens.
    """
    cv2.imshow('frame', frame)
    # Mask to the low byte before comparing against the key code.
    pressed_key = cv2.waitKey(25) & 0xFF
    if pressed_key != ord('q'):
        return
    cv2.destroyAllWindows()
    cap.release()

In [None]:
def get_bboxes(frame,detector):
    """Run person-only detection on ``frame`` using ``detector``.

    The detector is asked for the ``person`` class only, with array input and
    array output, a minimum probability of 50 percent, and with name and
    probability labels switched off.

    Returns:
        The two values produced by ``detectCustomObjectsFromImage``, as a
        tuple -- presumably the annotated image and the list of detections
        (confirm against the ImageAI documentation).
    """
    detection_options = {
        'custom_objects': detector.CustomObjects(person=True),
        'input_type': "array",
        'output_type': "array",
        'minimum_percentage_probability': 50,
        'display_percentage_probability': False,
        'display_object_name': False,
    }
    annotated, detections = detector.detectCustomObjectsFromImage(
        input_image=frame, **detection_options)
    return annotated, detections

In [None]:
def get_center_pts(preds,frame):
    """Return the integer centre point of every prediction's bounding box.

    Each item of ``preds`` must provide ``'box_points'`` as four values
    ``(x1, y1, x2, y2)``. ``frame`` is accepted but not used.

    Returns:
        A list of ``(cx, cy)`` tuples, one per prediction, in input order.
    """
    def _box_center(box):
        # Midpoint of opposite corners, truncated to int.
        left, top, right, bottom = box
        return (int((left + right) / 2), int((top + bottom) / 2))

    return [_box_center(pred.get('box_points')) for pred in preds]

In [None]:
def find_close_points(frame,centers,cut_off):
    """Mark centre points on ``frame`` and link those closer than ``cut_off``.

    Every pair of centres whose distance is at most ``cut_off`` is joined by
    a line. Centres that belong to at least one such pair get a dot coloured
    (0, 0, 255); all other centres get a dot coloured (0, 255, 0).

    Args:
        frame: Image array to draw on; it is modified in place.
        centers: Sequence of ``(x, y)`` points. May be empty.
        cut_off: Distance threshold, in the same units as ``centers``.

    Returns:
        The same ``frame`` object that was passed in.
    """
    # A frame with no detections has nothing to draw, and KDTree cannot be
    # built from an empty point set.
    if len(centers) == 0:
        return frame

    X = np.array(centers)
    tree = KDTree(X)

    # Plain Python ints: some OpenCV builds reject numpy integer scalars as
    # point coordinates.
    points = [(int(x), int(y)) for x, y in X]

    # query_pairs yields each close pair (i, j) once with i < j, so every
    # violating link is drawn exactly once and between the right two points.
    too_close = set()
    for i, j in tree.query_pairs(r=cut_off):
        cv2.line(frame, points[i], points[j], (0, 255, 0), thickness=2)
        too_close.update((i, j))

    # Dots are drawn after the lines so that a line never covers a dot.
    for idx, point in enumerate(points):
        dot_color = (0, 0, 255) if idx in too_close else (0, 255, 0)
        cv2.circle(frame, point, radius=4, color=dot_color, thickness=-1)
    return frame

In [None]:
@timing
def process_frame(video_path, output_path='output_1.avi', fps=24.0, cut_off=70):
    """Annotate every frame of a video and write the result as an MJPG file.

    For each frame: detect people, compute the centre of each bounding box,
    mark centres that are within ``cut_off`` of each other, and append the
    frame to the output video.

    Args:
        video_path: Path of the input video.
        output_path: Path of the video to write.
        fps: Frame rate recorded in the output video.
        cut_off: Distance threshold passed to ``find_close_points``.

    Returns:
        The number of frames processed (0 if the input could not be opened).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Bail out before creating the writer or loading the model.
        print('No video file or incorrect path')
        cap.release()
        return 0

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'MJPG'),
                             fps, (width, height))

    frame_count = 0
    try:
        detector = get_yolo()
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            box_img, preds = get_bboxes(frame, detector)
            centers = get_center_pts(preds, box_img)
            # Frames without detections are written through unchanged.
            if centers:
                frame = find_close_points(frame, centers, cut_off=cut_off)
            writer.write(frame)
    finally:
        # Release both handles even if detection or drawing raises, so the
        # output file is closed properly.
        cap.release()
        writer.release()
    return frame_count

In [None]:
# Run the pipeline on the sample clip.
INPUT_VIDEO = 'input.mp4'
process_frame(video_path=INPUT_VIDEO)