In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import argparse

import cv2
import torch
import numpy as np
from glob import glob

from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker

import time

In [2]:
# Sanity check: show whether PyTorch can see a CUDA device before running the tracker.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [None]:
# Download the pretrained ResNet-50 SiamRPN weights if they are not already on disk.
# Manual alternative: download the model from
# 'https://drive.google.com/file/d/1-tEtYQdT1G9kn8HsqKNDHVqjE16F8YQH/view?usp=drive_link'
# and place/rename it to match model_path.
# NOTE(review): the tracking cell below loads "models/rpn_res50.pth" (relative to the
# working directory), not this path - confirm which location is intended.
model_path = "./pysot/models/rpn_res50.pth"

if not os.path.exists(model_path):
    import requests

    # NOTE(review): the `uuid` / `at` query parameters look like a browser-captured,
    # session-bound confirmation token and may stop working - confirm.
    url = 'https://drive.google.com/u/0/uc?id=1-tEtYQdT1G9kn8HsqKNDHVqjE16F8YQH&export=download&confirm=t&uuid=afd42841-2c3c-42a6-80eb-bef8b5157555&at=AB6BwCCrpEW_LEhqItrzgSkQn8S3:1698336708620'

    with requests.get(url, allow_redirects=True, timeout=60, stream=True) as response:
        # Fail loudly instead of saving an HTTP error page as the "model".
        response.raise_for_status()
        if 'text/html' in response.headers.get('Content-Type', ''):
            raise RuntimeError(
                "Google Drive returned an HTML page instead of the model weights; "
                "download the file manually from the link above.")

        model_dir = os.path.dirname(model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        # Write to a temporary file and rename at the end, so an interrupted download
        # does not leave a truncated file that the os.path.exists() guard would accept.
        tmp_path = model_path + ".part"
        with open(tmp_path, 'wb') as model_file:
            # Stream in 1 MiB chunks instead of buffering the whole checkpoint in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                model_file.write(chunk)
        os.replace(tmp_path, model_path)

In [4]:
def _build_tracker(local_config: dict):
    """Merge the pysot config, load the model weights and wrap them in a tracker."""
    cfg.merge_from_file(local_config["config"])
    # Only use CUDA when the config asks for it AND a device is actually available.
    cfg.CUDA = torch.cuda.is_available() and cfg.CUDA
    device = torch.device('cuda' if cfg.CUDA else 'cpu')

    model = ModelBuilder()
    # torch.load unpickles the checkpoint: only load weight files from trusted sources.
    model.load_state_dict(torch.load(local_config["model"],
        map_location=lambda storage, loc: storage.cpu()))
    model.eval().to(device)
    return build_tracker(model)


def _select_init_rect(window_name: str, frame, local_config: dict):
    """Return the initial box (x, y, w, h) for the tracker, or None if none was chosen.

    Uses ``local_config["init_rect"]`` when present; otherwise lets the user draw the
    box on ``frame`` with ``cv2.selectROI``. A selection with zero width or height is
    treated as "no box" and yields None.
    """
    if "init_rect" in local_config:
        return local_config["init_rect"]
    init_rect = cv2.selectROI(window_name, frame, False, False)
    print(init_rect)
    if init_rect[2] <= 0 or init_rect[3] <= 0:
        return None
    return init_rect


def _open_video_writer(input_cap, output_file: str):
    """Open a VideoWriter matching the input's frame size and FPS; None if it fails."""
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # cv2.VideoWriter does not create missing directories; it just fails to open.
        os.makedirs(output_dir, exist_ok=True)
    width = int(input_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(input_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = input_cap.get(cv2.CAP_PROP_FPS)
    out_video = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc('F', 'M', 'P', '4'),
                                fps, (width, height))
    if not out_video.isOpened():
        # Keep tracking/benchmarking even if the output video cannot be written.
        print(f"Warning: could not open video writer for {output_file!r}; no output video will be saved")
        out_video.release()
        return None
    return out_video


def track_inference(local_config: dict):
    """Run a pysot tracker over a video, display the boxes and time tracker.track().

    Args:
        local_config: dict with keys
            "config"      - path to the pysot YAML config merged into the global ``cfg``;
            "model"       - path to the model state_dict;
            "video"       - path to the input video;
            "init_rect"   - optional initial box (x, y, w, h); when absent the user
                            draws it on the first frame;
            "output_file" - optional path of an annotated output video;
            "max_img"     - optional maximum number of tracked frames (default: no limit).

    Returns:
        (frame_count, total_time): the number of tracked frames after the warm-up
        frames, and the summed tracker.track() time in seconds over those frames.
        frame_count is never negative; both values are 0 when no frame got past
        the warm-up.

    Raises:
        IOError: if the input video cannot be opened.

    Pressing "q" in the display window stops early.
    """
    tracker = _build_tracker(local_config)

    video_path = local_config["video"]
    # basename/splitext handle OS-specific separators and dotted file names.
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    max_img = local_config.get("max_img", np.inf)
    # Number of initial tracked frames excluded from the timing (warm-up).
    img_to_skip_for_warm_up = 10

    input_cap = cv2.VideoCapture(video_path)
    if not input_cap.isOpened():
        input_cap.release()
        raise IOError(f"Could not open video {video_path!r}")

    cv2.namedWindow(video_name, cv2.WND_PROP_FULLSCREEN)
    out_video = None
    processed_img_count = 0
    total_time = 0.0
    first_frame = True
    try:
        if "output_file" in local_config:
            out_video = _open_video_writer(input_cap, local_config["output_file"])

        while input_cap.isOpened() and processed_img_count < max_img:
            success, frame = input_cap.read()
            if not success:
                break  # end of the video (or a read error)

            if first_frame:
                init_rect = _select_init_rect(video_name, frame, local_config)
                if init_rect is None:
                    print("No initial box selected, stopping")
                    break
                tracker.init(frame, init_rect)
                first_frame = False
                print("Tracker initialized")
            else:
                # perf_counter is monotonic and high-resolution, unlike time.time().
                start_time = time.perf_counter()
                outputs = tracker.track(frame)
                elapsed = time.perf_counter() - start_time
                processed_img_count += 1
                if processed_img_count > img_to_skip_for_warm_up:
                    total_time += elapsed

                bbox = list(map(int, outputs['bbox']))
                cv2.rectangle(frame, (bbox[0], bbox[1]),
                              (bbox[0] + bbox[2], bbox[1] + bbox[3]),
                              (0, 255, 0), 3)
                cv2.imshow(video_name, frame)

                if out_video is not None:
                    out_video.write(frame)

            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Always free the capture, writer and window, even if tracking raised.
        input_cap.release()
        if out_video is not None:
            out_video.release()
        cv2.destroyAllWindows()

    # Clamp so short videos (<= warm-up frames) do not report a negative count.
    return (max(processed_img_count - img_to_skip_for_warm_up, 0), total_time)

In [5]:
# TRACK_CONFIG = {
#     "config": "experiments/siamrpn_alex_dwxcorr/config.yaml",
#     "model": "models/rpn_alex.pth",
#     "video": "demo/traffic_big.mp4",
#     "output_file": "output/out_alex.mp4",
#     "init_rect": (1147, 429, 51, 44)
# }
# img_count, processed_time = track_inference(TRACK_CONFIG)

# print()
# print(TRACK_CONFIG)
# print(f"FPS: {(img_count / processed_time)}")


Tracker initialized

{'config': 'experiments/siamrpn_alex_dwxcorr/config.yaml', 'model': 'models/rpn_alex.pth', 'video': 'demo/traffic_big.mp4', 'output_file': 'output/out_alex.mp4', 'init_rect': (1147, 429, 51, 44)}
FPS: 68.03375098067357


In [8]:
# Benchmark the ResNet-50 model (config dir `siamrpn_r50_l234_dwxcorr`) on the demo video.
# Forward slashes work on both Windows and POSIX and avoid the invalid "\s" / "\c"
# escape sequences a backslash path produces in a normal string literal.
TRACK_CONFIG = {
    "config": "experiments/siamrpn_r50_l234_dwxcorr/config.yaml",
    "model": "models/rpn_res50.pth",
    "video": "demo/traffic_big.mp4",
    "output_file": "output/out_res_50.mp4",
    "init_rect": (1147, 429, 51, 44)
}
img_count, processed_time = track_inference(TRACK_CONFIG)

print()
print(TRACK_CONFIG)
if processed_time > 0:
    print(f"FPS: {(img_count / processed_time)}")
else:
    # Too few frames were tracked to get past the warm-up, so nothing was timed.
    print("FPS: n/a (no frames were timed after the warm-up)")

Tracker initialized

{'config': 'experiments\\siamrpn_r50_l234_dwxcorr\\config.yaml', 'model': 'models/rpn_res50.pth', 'video': 'demo/traffic_big.mp4', 'output_file': 'output/out_res_50.mp4', 'init_rect': (1147, 429, 51, 44)}
FPS: 33.18419856658559
