**Задача 1**

Дополните код алгоритма SORT так, чтобы можно было задать область интереса (ROI) и считать, сколько объектов каждого класса находится внутри неё.

**Задача 2**

Дополните код алгоритма SORT так, чтобы можно было задать линию и считать, сколько объектов её пересекло.

In [1]:
# Install the Kalman-filter dependency quietly. The previous `>> None` did not
# discard output: it appended pip's stdout to a file literally named "None".
!pip install -q filterpy

In [2]:
import numpy as np
import cv2
import torch
import torchvision
from filterpy.kalman import KalmanFilter
from collections import defaultdict
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm
from IPython.display import Video, display

In [3]:
class KalmanBoxTracker(object):
    """Track a single bounding box with a constant-velocity Kalman filter.

    The 7-dimensional state is ``[x, y, s, r, vx, vy, vs]``: box centre
    ``(x, y)``, area ``s``, aspect ratio ``r = w / h`` and the velocities of
    ``x``, ``y`` and ``s`` (``r`` has no velocity term in the transition
    matrix).  Measurements are the first four components.
    """

    # Class-wide counter used to hand out ids; it is never reset, so ids keep
    # growing across every tracker created in the process.
    count = 0

    def __init__(self, bbox, label=None):
        """Start a track from ``bbox`` (``[x1, y1, x2, y2, ...]``).

        Only the first four entries of ``bbox`` are read.  ``label`` is stored
        as-is on the tracker.
        """
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        # Transition: x += vx, y += vy, s += vs; r and the velocities persist.
        self.kf.F = np.array([
            [1,0,0,0,1,0,0],
            [0,1,0,0,0,1,0],
            [0,0,1,0,0,0,1],
            [0,0,0,1,0,0,0],
            [0,0,0,0,1,0,0],
            [0,0,0,0,0,1,0],
            [0,0,0,0,0,0,1]
        ])
        # Observation: the measurement is the first four state components.
        self.kf.H = np.array([
            [1,0,0,0,0,0,0],
            [0,1,0,0,0,0,0],
            [0,0,1,0,0,0,0],
            [0,0,0,1,0,0,0]
        ])
        # Scale the noise/covariance defaults: larger measurement noise on
        # (s, r), large initial uncertainty on the unobserved velocities,
        # small process noise on the velocities.
        self.kf.R[2:,2:] *= 10.
        self.kf.P[4:,4:] *= 1000.
        self.kf.P *= 10.
        self.kf.Q[-1,-1] *= 0.01
        self.kf.Q[4:,4:] *= 0.01
        self.kf.x[:4] = self.convert_bbox_to_z(bbox)
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        self.time_since_update = 0
        self.label = label

    def update(self, bbox, label=None):
        """Correct the filter with a matched detection.

        Resets ``time_since_update`` and the prediction history, bumps the hit
        counters and, when ``label`` is not ``None``, overwrites the stored
        label.
        """
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(self.convert_bbox_to_z(bbox))
        if label is not None:
            self.label = label

    def predict(self):
        """Advance the filter one step and return the predicted box (1x4)."""
        # Zero the area velocity if it would drive the area to <= 0.
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        # A prediction without an intervening update breaks the hit streak.
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(self.convert_x_to_bbox(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """Return the current box estimate as a 1x4 ``[x1, y1, x2, y2]`` array."""
        return self.convert_x_to_bbox(self.kf.x)

    @staticmethod
    def convert_bbox_to_z(bbox):
        """Convert ``[x1, y1, x2, y2, ...]`` to a 4x1 ``[x, y, s, r]`` column.

        ``(x, y)`` is the centre, ``s`` the area and ``r`` the width/height
        ratio.  A zero-height box makes ``r`` undefined (division by zero).
        """
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        x = bbox[0] + w/2.
        y = bbox[1] + h/2.
        s = w * h
        r = w / float(h)
        return np.array([x, y, s, r]).reshape((4, 1))

    @staticmethod
    def convert_x_to_bbox(x, score=None):
        """Convert a state ``[x, y, s, r, ...]`` to ``[x1, y1, x2, y2]``.

        Returns a 1x4 array, or a 1x5 array with ``score`` appended when it is
        given.  ``x`` may be a column vector (as stored by the filter) or flat.
        """
        # Flatten first so every component is a scalar.  Indexing the (7, 1)
        # filter state directly yields shape-(1,) arrays, and mixing those
        # with a scalar ``score`` builds a ragged array, which recent NumPy
        # rejects with a ValueError.
        x = np.asarray(x, dtype=float).reshape(-1)
        w = np.sqrt(x[2] * x[3])
        h = x[2] / w
        corners = [x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]
        if score is None:
            return np.array(corners).reshape((1,4))
        return np.array(corners + [score]).reshape((1,5))

In [4]:
def iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2, ...]`` boxes.

    Widths and heights are computed inclusively (``x2 - x1 + 1``), so boxes
    that merely touch still share a one-unit-wide strip.  Only the first four
    entries of each box are read.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp each overlap extent at zero so disjoint boxes give no intersection.
    overlap_w = max(0, right - left + 1)
    overlap_h = max(0, bottom - top + 1)
    inter = overlap_w * overlap_h

    area1 = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
    area2 = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
    union = area1 + area2 - inter
    return inter / float(union)

In [5]:
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """Match detections to predicted tracker boxes by IoU.

    Returns a tuple ``(matches, unmatched_detections, unmatched_trackers)``:
    ``matches`` is an (M, 2) array of ``[detection_index, tracker_index]``
    pairs, the other two are arrays of indices.  With no trackers every
    detection is reported as unmatched.
    """
    n_det = len(detections)
    n_trk = len(trackers)
    if n_trk == 0:
        return np.empty((0,2), int), np.arange(n_det), np.empty((0,5), int)

    # Pairwise overlap table: rows are detections, columns are trackers.
    overlaps = np.zeros((n_det, n_trk), dtype=np.float32)
    for row, det in enumerate(detections):
        for col, trk in enumerate(trackers):
            overlaps[row, col] = iou(det, trk)

    if min(overlaps.shape) > 0:
        above = (overlaps > iou_threshold).astype(int)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            # Thresholding alone already gives a one-to-one pairing.
            pairs = np.stack(np.where(above), axis=1)
        else:
            # Otherwise solve the assignment problem maximising total IoU.
            rows, cols = linear_sum_assignment(-overlaps)
            pairs = np.array(list(zip(rows, cols)))
    else:
        pairs = np.empty((0,2), int)

    paired_dets = set(pairs[:, 0].tolist())
    paired_trks = set(pairs[:, 1].tolist())
    unmatched_d = [d for d in range(n_det) if d not in paired_dets]
    unmatched_t = [t for t in range(n_trk) if t not in paired_trks]

    # Drop assigned pairs whose overlap is below the threshold; both sides
    # then count as unmatched.
    kept = []
    for pair in pairs:
        det_idx, trk_idx = pair[0], pair[1]
        if overlaps[det_idx, trk_idx] < iou_threshold:
            unmatched_d.append(det_idx)
            unmatched_t.append(trk_idx)
            continue
        kept.append(pair.reshape(1,2))

    if kept:
        matches = np.concatenate(kept, axis=0)
    else:
        matches = np.empty((0,2), int)
    return matches, np.array(unmatched_d), np.array(unmatched_t)

In [6]:
class Sort(object):
    """Multi-object tracker: Kalman box trackers matched to detections by IoU.

    Attributes:
        max_age: a tracker is dropped once ``time_since_update`` exceeds this.
        min_hits: hit streak required before a track is reported (waived while
            ``frame_count <= min_hits``).
        iou_threshold: minimum IoU passed to the association step.
        trackers: list of live ``KalmanBoxTracker`` objects.
        track_labels: maps a tracker's ``id`` to its latest label.  Entries are
            kept after a tracker is dropped.
        frame_count: number of ``update`` calls so far.
    """

    def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.track_labels = {}
        self.frame_count = 0

    def update(self, dets=np.empty((0,5)), labels=None):
        """Advance all tracks by one frame and fold in the new detections.

        Args:
            dets: array of detections, one ``[x1, y1, x2, y2, score]`` row
                each.  Call with the empty default on frames with no
                detections.
            labels: optional per-detection labels, indexed like ``dets``.

        Returns:
            Array with one ``[x1, y1, x2, y2, track_id]`` row per reported
            track, where ``track_id`` is the tracker's ``id + 1``; shape
            ``(0, 5)`` when nothing is reported.
        """
        self.frame_count += 1

        # Predict every tracker forward and note the ones whose prediction is
        # not finite.
        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        for t, tracker in enumerate(self.trackers):
            pos = tracker.predict()[0]
            trks[t] = [pos[0], pos[1], pos[2], pos[3], 0]
            # Test for inf as well as NaN, and use the same index list to drop
            # both the prediction rows and the trackers.  Dropping rows by one
            # criterion and trackers by another lets row t of ``trks`` stop
            # corresponding to ``self.trackers[t]``.
            if not np.all(np.isfinite(pos)):
                to_del.append(t)
        trks = np.delete(trks, to_del, axis=0)
        for t in reversed(to_del):
            self.trackers.pop(t)

        matched, un_d, un_t = associate_detections_to_trackers(dets, trks, self.iou_threshold)

        # Matched detections correct their tracker.
        for m in matched:
            di, ti = m[0], m[1]
            tracker = self.trackers[ti]
            tracker.update(dets[di], labels[di] if labels is not None else None)
            self.track_labels[tracker.id] = tracker.label

        # Every unmatched detection starts a new tracker.
        for i in un_d:
            trk = KalmanBoxTracker(dets[i], labels[i] if labels is not None else None)
            self.trackers.append(trk)
            self.track_labels[trk.id] = trk.label

        # Walk backwards so popping stale trackers does not shift the indices
        # still to be visited.
        ret = []
        for idx in range(len(self.trackers) - 1, -1, -1):
            trk = self.trackers[idx]
            d = trk.get_state()[0]
            confirmed = trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits
            if trk.time_since_update < 1 and confirmed:
                ret.append(np.concatenate((d, [trk.id+1])).reshape(1,-1))
            if trk.time_since_update > self.max_age:
                self.trackers.pop(idx)
        return np.concatenate(ret) if ret else np.empty((0,5))

In [7]:
# Load torchvision's Faster R-CNN (ResNet-50 FPN backbone) with pretrained
# weights and switch it to inference mode. `model` is a module-level global
# that track_objects passes to detect_objects.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
# Class names indexed by the integer label id the detector returns
# (track_objects looks names up as COCO_INSTANCE_CATEGORY_NAMES[label]).
# The 'N/A' entries keep the list positions aligned with the ids; presumably
# those ids are never produced by the pretrained model -- verify if relied upon.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
    'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:00<00:00, 171MB/s]


In [8]:
def detect_objects(img, model, threshold=0.5):
    """Run the detector on one image and keep confident detections.

    Args:
        img: image accepted by torchvision's ``to_tensor`` (the caller passes
            an RGB frame).
        model: callable returning a list of dicts with ``'boxes'``,
            ``'scores'`` and ``'labels'`` tensors.
        threshold: detections scoring below this are discarded.

    Returns:
        ``(dets, labels)``: ``dets`` holds the kept boxes with their score
        appended as the last column, ``labels`` the matching label ids.
    """
    batch = torchvision.transforms.functional.to_tensor(img).unsqueeze(0)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        prediction = model(batch)[0]

    boxes, scores, labels = (
        prediction[key].cpu().numpy() for key in ('boxes', 'scores', 'labels')
    )
    keep = scores >= threshold
    dets = np.hstack((boxes[keep], scores[keep, None]))
    return dets, labels[keep]

In [9]:
def ccw(A,B,C):
    """Orientation test for the point triple ``A, B, C`` (``(x, y)`` pairs).

    Returns True when ``(C.y - A.y) * (B.x - A.x)`` is strictly greater than
    ``(B.y - A.y) * (C.x - A.x)``; collinear points give False.
    """
    lhs = (C[1] - A[1]) * (B[0] - A[0])
    rhs = (B[1] - A[1]) * (C[0] - A[0])
    return lhs > rhs

def line_intersect(A,B,C,D):
    """Report whether segment ``AB`` crosses segment ``CD``.

    The segments cross when ``A`` and ``B`` get different ``ccw`` results
    against ``CD`` and ``C`` and ``D`` get different results against ``AB``.
    """
    ab_split_by_cd = ccw(A,C,D) != ccw(B,C,D)
    cd_split_by_ab = ccw(A,B,C) != ccw(A,B,D)
    return ab_split_by_cd and cd_split_by_ab

In [10]:
def track_objects(video_path, output_path='output.mp4', roi=None, line=None):
    """Detect and track objects in a video and write an annotated copy.

    Every frame is run through the module-level ``model`` via
    ``detect_objects`` and the detections are fed to a ``Sort`` tracker.  Each
    reported track is drawn with its id.

    Args:
        video_path: path of the input video.
        output_path: path of the annotated video (written with the ``mp4v``
            fourcc at the input's size and frame rate).
        roi: optional ``(x1, y1, x2, y2)`` rectangle.  For each frame, the
            number of tracks whose box centre lies inside it is drawn per
            class name just above the rectangle.
        line: optional ``((x1, y1), (x2, y2))`` segment.  A track is counted
            once, the first time the segment between its previous and current
            box centre crosses the line; the running total is drawn on the
            frame.

    Raises:
        IOError: if ``video_path`` cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        # Fail loudly instead of silently writing an empty output file.
        raise IOError(f'Cannot open video: {video_path}')

    out = None
    progress = None
    try:
        mot = Sort(max_age=5, min_hits=3, iou_threshold=0.3)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w,h))
        last_pos = {}     # track id -> box centre seen on the previous report
        crossed = set()   # ids of tracks already counted as crossing the line

        # The reported frame count is only used to size the progress bar.
        # Reading stops when the capture runs out of frames, so a missing or
        # wrong count neither skips nor truncates the video.
        progress = tqdm(total=total_frames if total_frames > 0 else None,
                        desc="Processing frames")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            dets, labels = detect_objects(rgb, model)
            tracks = mot.update(dets, labels)
            counts = defaultdict(int)   # per-frame ROI occupancy by class name
            if roi:
                cv2.rectangle(frame, (roi[0],roi[1]), (roi[2],roi[3]), (0,255,0), 2)
            if line:
                cv2.line(frame, line[0], line[1], (0,0,255), 2)

            for tr in tracks:
                x1,y1,x2,y2,tid = tr
                tid = int(tid)
                cx,cy = int((x1+x2)/2), int((y1+y2)/2)
                # Sort reports id + 1, while track_labels is keyed by raw id.
                lbl = COCO_INSTANCE_CATEGORY_NAMES[mot.track_labels[tid-1]]
                if roi and roi[0] <= cx <= roi[2] and roi[1] <= cy <= roi[3]:
                    counts[lbl] += 1
                if line:
                    if (tid in last_pos and tid not in crossed
                            and line_intersect(last_pos[tid], (cx,cy), line[0], line[1])):
                        crossed.add(tid)
                    last_pos[tid] = (cx,cy)

            if roi:
                for i,(cls,cnt) in enumerate(counts.items()):
                    cv2.putText(frame, f'{cls}: {cnt}', (roi[0], roi[1]-10-15*i),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
            if line:
                cv2.putText(frame, f'Crossed: {len(crossed)}', (10,30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2)

            for tr in tracks:
                x1,y1,x2,y2,tid = map(int, tr)
                # Derive a colour from the id so a track keeps its colour.
                col = (int(255*(tid%3)/3), int(255*(tid%6)/6), int(255*(tid%9)/9))
                cv2.rectangle(frame, (x1,y1), (x2,y2), col, 2)
                cv2.putText(frame, f'ID:{tid}', (x1,y1-10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, col, 2)
            out.write(frame)
            progress.update(1)
    finally:
        # Release the capture and writer even if detection or drawing raised,
        # so the output file is finalised and handles are not leaked.
        if progress is not None:
            progress.close()
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

In [11]:
# Run the tracker on the sample clip. `roi` is (x1, y1, x2, y2) and `line` is
# a pair of (x, y) endpoints, matching how track_objects indexes them; here the
# line is the vertical segment x = 200 from y = 0 to y = 480.
track_objects('pedestrian.mp4', 'output.mp4', roi=(100,100,600,600),
              line=((200,0),(200,480)))

Processing frames: 100%|██████████| 149/149 [19:37<00:00,  7.90s/it]


In [12]:
# Re-encode the OpenCV output (written with the 'mp4v' fourcc) to H.264 video
# with AAC audio settings -- presumably so the notebook's player can show it;
# confirm. NOTE(review): without -y, ffmpeg asks before overwriting an existing
# output_.mp4, which may stall a re-run of this cell.
!ffmpeg -i output.mp4 -c:v libx264 -c:a aac output_.mp4

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

In [13]:
# Show the re-encoded result inline; embed=True stores the video data in the
# notebook itself rather than linking to the file.
video_path = 'output_.mp4'
display(Video(video_path, embed=True, width=640, height=480))