In [1]:
#imports
import os
from timeit import time
import warnings
import sys
import cv2
import numpy as np
from PIL import Image
#imports from file yolo.py
from yolo import YOLO
#imports from folder deep_sort
#Simple Online and Realtime Tracking with a Deep Association Metric (Deep SORT)
from deep_sort import preprocessing
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
#imports from folder tools
from tools import generate_detections as gdet
#imports from file detection.py in folder deep_sort
from deep_sort.detection import Detection as ddet

Using TensorFlow backend.


In [2]:
def _is_active_track(track):
    """Return True if `track` should be drawn and logged for the current frame.

    Only confirmed tracks whose ``time_since_update`` is at most 1 are
    reported; tentative (unconfirmed) tracks and tracks that have gone longer
    without an update are skipped.
    """
    return track.is_confirmed() and track.time_since_update <= 1


def _format_track_line(frame_index, track_id, bbox):
    """Format one record of the per-frame track log (``detection.txt``).

    Args:
        frame_index: 0-based index of the frame being processed.
        track_id: identity assigned to the track by the tracker.
        bbox: box returned by ``track.to_tlbr()``; its first four values are
            used as the two opposite corners ``x1 y1 x2 y2``, truncated to int.

    Returns:
        ``"<frame_index> - <track_id>: <x1> <y1> <x2> <y2> \\n"`` -- the
        trailing space before the newline is kept for output compatibility.
    """
    x1, y1, x2, y2 = (int(v) for v in bbox[:4])
    return "{} - {}: {} {} {} {} \n".format(frame_index, track_id, x1, y1, x2, y2)


def main(yolo, video_path="videos/traffic.mp4", output_path="output.avi",
         detections_path="detection.txt", write_video=True, output_fps=15):
    """Run YOLO detection + Deep SORT tracking over a video file and display it.

    For every frame: detect objects with ``yolo.detect_image``, embed each box
    with the MARS appearance encoder, run non-maxima suppression, update the
    Deep SORT tracker, then draw the active tracks (white boxes, green ids) and
    the detections (blue boxes, red confidences) on the frame and show it in
    an OpenCV window. Press ``q`` in that window to stop early.

    Args:
        yolo: detector exposing ``detect_image(PIL.Image) -> (boxes, scores)``.
            Boxes are passed straight to ``Detection`` and are presumably
            (top-left x, top-left y, width, height) -- TODO confirm in yolo.py.
        video_path: input video file.
        output_path: annotated video, written when ``write_video`` is True.
        detections_path: text log of active tracks (one line per track per
            frame), written when ``write_video`` is True.
        write_video: write the annotated video and the track log to disk.
        output_fps: frame rate passed to the output ``cv2.VideoWriter``.

    Raises:
        IOError: if ``video_path`` cannot be opened.
    """
    # Deep SORT parameters.
    max_cosine_distance = 0.3  # matching threshold for the cosine appearance metric
    nn_budget = None           # gallery budget for the metric; None presumably = unbounded (TODO confirm)
    nms_max_overlap = 1.0      # NOTE(review): at 1.0 NMS presumably suppresses nothing -- confirm

    # Appearance encoder + tracker.
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(metric)

    font = cv2.FONT_HERSHEY_SIMPLEX  # font id 0, as used by putText below
    font_scale = 5e-3 * 200

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        video_capture.release()
        raise IOError("Could not open video file: {}".format(video_path))

    out = None
    list_file = None
    frame_index = 0
    try:
        if write_video:
            w = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            out = cv2.VideoWriter(output_path, fourcc, output_fps, (w, h))
            list_file = open(detections_path, 'w')

        while True:
            ret, frame = video_capture.read()
            if not ret:  # end of stream (or read failure)
                break

            # OpenCV frames are BGR, while Image.fromarray interprets a 3-channel
            # uint8 array as RGB -- convert so the detector sees correct colours.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            boxs, scores = yolo.detect_image(image)
            # The appearance encoder receives the original BGR OpenCV frame.
            features = encoder(frame, boxs)
            detections = [Detection(bbox, score, feature)
                          for bbox, score, feature in zip(boxs, scores, features)]

            # Run non-maxima suppression.
            boxes = np.array([d.tlwh for d in detections])
            confidences = np.array([d.confidence for d in detections])
            indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, confidences)
            detections = [detections[i] for i in indices]

            # Advance every track one step, then associate the new detections.
            tracker.predict()
            tracker.update(detections)

            for track in tracker.tracks:
                if not _is_active_track(track):
                    continue
                bbox = track.to_tlbr()
                x1, y1, x2, y2 = (int(v) for v in bbox[:4])
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 255, 255), 2)  # white
                cv2.putText(frame, str(track.track_id), (x2, y2), font, font_scale, (0, 255, 0), 2)
                if list_file is not None:
                    list_file.write(_format_track_line(frame_index, track.track_id, bbox))

            for det in detections:
                x1, y1, x2, y2 = (int(v) for v in det.to_tlbr()[:4])
                cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)  # blue (BGR)
                cv2.putText(frame, "{:.2f}".format(det.confidence), (x2, y2 + 20),
                            font, font_scale, (0, 0, 255), 2)

            cv2.imshow('', frame)
            if out is not None:
                out.write(frame)
            frame_index += 1

            # Press Q to stop!
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release the capture/writer and close the log, even on error.
        video_capture.release()
        if out is not None:
            out.release()
        if list_file is not None:
            list_file.close()
        cv2.destroyAllWindows()
    
yolo_detector = YOLO()  # build the detector first (loads model_data/yolo.h5 per the cell output)
main(yolo_detector)

model_data/yolo.h5 model, anchors, and classes loaded.
