In [6]:
%matplotlib inline
from matplotlib import pyplot as plt
plt.rcParams["figure.figsize"] = (10, 7) # (w, h)
import numpy as np
import cv2

In [7]:
vidcap = cv2.VideoCapture('cat.mp4')
if not vidcap.isOpened():
    # fail early with a clear message instead of getting empty reads later
    raise IOError("could not open video file 'cat.mp4'")

def getFrame(sec, state):
    """Read the frame at ``sec`` seconds from ``vidcap`` and append it, converted to RGB, to ``state``.

    Args:
        sec: timestamp in seconds to seek to.
        state: list that receives the RGB frame when the read succeeds.

    Returns:
        True if a frame was read, False otherwise (e.g. the timestamp is past the end of the video).
    """
    vidcap.set(cv2.CAP_PROP_POS_MSEC, sec * 1000)
    hasFrames, image = vidcap.read()
    # Convert only after a successful read: on failure `image` is None and cvtColor would raise.
    if hasFrames:
        state.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    return hasFrames

# Sample one frame every `frameRate` seconds (it is a step in seconds, not a rate), starting at 3.5 s.
# Frames are only taken up to 20 s because after the 20th second the background becomes
# cluttered and the model can no longer detect the cat.
sec = 3.5
frameRate = 0.3
state = []
success = getFrame(sec, state)
# stop early if a read fails (e.g. the video is shorter than 20 s)
while success and sec < 20:
    # round so float error does not accumulate in the timestamps
    sec = round(sec + frameRate, 2)
    success = getFrame(sec, state)

# all frames we need are in `state`; free the capture handle
vidcap.release()

### I use the model from https://github.com/alisonswu/cat-detector, with a slightly modified `detect` function

In [9]:
import torch
import os
from PIL import Image

from model import catdetector
from utils import detect
from tqdm import tqdm

# Download the pretrained weights once. Fail loudly on a failed download instead of
# continuing to torch.load a missing or truncated file.
if not os.path.exists('catdetector.pkl'):
    exit_code = os.system("wget https://s3.amazonaws.com/cat-detector/catdetector.pkl")
    if exit_code != 0:
        # remove any partial file so the next run retries the download
        if os.path.exists('catdetector.pkl'):
            os.remove('catdetector.pkl')
        raise RuntimeError("downloading catdetector.pkl failed (exit code %d)" % exit_code)

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load weights from a source you trust.
save_model = torch.load('catdetector.pkl', map_location='cpu')
model_dict = catdetector.state_dict()
# Keep only checkpoint entries whose names exist in the current model.
# Any other entries are silently dropped.
state_dict = {k: v for k, v in save_model.items() if k in model_dict}
model_dict.update(state_dict)
catdetector.load_state_dict(model_dict)

_ = catdetector.eval()

# Run the detector on every sampled frame. The drawing cell treats a None entry as "no detection".
# NOTE(review): the saved output shows this cell failing with
# `AttributeError: module 'scipy.misc' has no attribute 'imresize'`.
# utils.detect must be updated for newer SciPy releases.
bboxes = [detect(frame, catdetector) for frame in tqdm(state, desc="Finding object")]

Finding object:   0%|          | 0/56 [00:00<?, ?it/s]


AttributeError: module 'scipy.misc' has no attribute 'imresize'

In [6]:
def find_bb_center(bbox):
    """Return the integer center ``(x, y)`` of a bounding box, or None if ``bbox`` is None.

    ``bbox`` is indexed as ``(x1, y1, x2, y2)``, the same order the drawing code passes to
    ``cv2.rectangle``. The corners may come in either order. The center is measured from the
    smaller coordinate on each axis, so it always lies inside the box.
    """
    if bbox is None:
        return None

    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    half_w = int(abs(x2 - x1) / 2)
    half_h = int(abs(y2 - y1) / 2)

    return (min(x1, x2) + half_w, min(y1, y2) + half_h)

def most_frequent(List):
    """Return the value that occurs most often in ``List``.

    Ties go to whichever tied value comes first when iterating ``set(List)``.
    Raises ValueError if ``List`` is empty.
    """
    tally = {value: List.count(value) for value in set(List)}
    return max(tally, key=tally.get)

def off_px(p, offset):
    """Shift point ``p`` horizontally by ``offset`` pixels and return the new ``(x, y)`` tuple."""
    shifted_x = p[0] + offset
    return shifted_x, p[1]

# Center point of every detection, with None where the detector found nothing.
centers = [find_bb_center(bbox) for bbox in bboxes]

# Number of most recent centers used to vote on the movement direction.
# Keep it even: N centers give N-1 (odd) direction votes, so the majority vote cannot tie.
centers_to_check = 6
# sign applied to the arrow's x offset
direction_right = 1
direction_left = -1

# Annotate each frame with a box around the detected cat. Once enough detections have been
# seen, also draw an arrow showing the majority horizontal movement over the recent centers.
for_video = []
for i in tqdm(range(len(state)), desc="Drawing graphics"):
    # copy so the frames in `state` stay unannotated
    cv2img = state[i].copy()
    ori_H = cv2img.shape[0]

    bbox = bboxes[i]
    if bbox is None:
        # no detection on this frame: keep it as is
        for_video.append(cv2img)
        continue

    # Draw the bounding box (red, since frames are RGB here). Thickness scales with frame
    # height but is clamped to at least 1 px, so short frames still get a visible box.
    # NOTE(review): cv2 needs integer coordinates; this assumes detect() returns ints.
    cv2.rectangle(cv2img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0),
                  thickness=max(1, int(ori_H / 150)))

    # centers of every frame up to and including this one that had a detection
    seen_centers = [c for c in centers[:i + 1] if c is not None]
    if len(seen_centers) >= centers_to_check:
        # center of this frame's box
        current_p = seen_centers[-1]

        # last `centers_to_check` centers, ordered oldest -> newest
        recent = seen_centers[-centers_to_check:]

        # One vote per consecutive pair. x grows to the right in image coordinates, so a
        # larger x in the newer center means the cat moved right (no change counts as left).
        directions = [
            direction_right if newer[0] > older[0] else direction_left
            for older, newer in zip(recent, recent[1:])
        ]
        # majority vote over the recent moves
        direction = most_frequent(directions)

        # The arrow starts 20% of the box width away from the center
        # and is 50% of the box width long.
        bb_width = abs(bbox[2] - bbox[0])
        l_offset = int(bb_width * 0.2)
        arr_width = int(bb_width * 0.5) + l_offset

        # draw arrow from current center with detected direction
        cv2.arrowedLine(cv2img, off_px(current_p, direction * l_offset),
                        off_px(current_p, direction * arr_width), (0, 255, 0), 2)

    for_video.append(cv2img)

Drawing graphics: 100%|██████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 1749.97it/s]


In [7]:
# Write the annotated frames to an AVI file. The frames were converted to RGB when read,
# so convert them back to OpenCV's BGR order before writing.
# NOTE(review): frames are sampled every 0.3 s upstream, so 15 fps plays back about 4.5x
# faster than real time. Confirm this is intended.
if not for_video:
    raise ValueError("no frames to write: for_video is empty")

height, width = for_video[0].shape[:2]
# VideoWriter expects (width, height)
size = (width, height)

out = cv2.VideoWriter('cat_tracking.avi', cv2.VideoWriter_fourcc(*'DIVX'), 15, size)
# VideoWriter does not raise if the codec/container is unavailable, so check explicitly
if not out.isOpened():
    raise IOError("could not open 'cat_tracking.avi' for writing with the DIVX codec")
try:
    for img in for_video:
        out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
finally:
    # release even if a write fails, so the file is finalized and the handle freed
    out.release()