#### Import Python libraries

In [None]:
# Re-import edited project modules automatically before each cell executes,
# so changes to the local packages take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
# Standard library, third-party and plotting imports for the tracking notebook.
import os
import sys
import matplotlib.pyplot as plt  # NOTE(review): not referenced in the cells shown here
import numpy as np
import time
from tqdm.autonotebook import tqdm  # NOTE(review): `tqdm.auto` is the current recommended entry point
import torch
from torch.utils.data import DataLoader
from scipy.optimize import linear_sum_assignment as linear_assignment  # NOTE(review): not referenced in the cells shown here
import os.path as osp
import motmetrics as mm
# Select the `lap` package as motmetrics' linear-assignment solver
# (requires `lap` to be installed in the kernel environment).
mm.lap.default_solver = 'lap'
%matplotlib inline


## Import local modules

In [None]:
# Project root relative to this notebook; the importable packages live in <root>/src.
root_dir = ".."
# Guard the append so re-running this cell does not keep adding duplicate
# entries to sys.path (keeps the cell idempotent).
_src_dir = os.path.join(root_dir, "src")
if _src_dir not in sys.path:
    sys.path.append(_src_dir)


In [None]:
from mot.data.data_track import MOT16Sequences
from mot.data.data_obj_detect import MOT16ObjDetect
from mot.models.object_detector import FRCNN_FPN
from mot.tracker.base import Tracker
from mot.visualize import plot_sequence
from mot.transforms import obj_detect_transforms
from mot.eval import evaluate_mot_accums, get_mot_accum, evaluate_obj_detect
from mot.models.gnn import BipartiteNeuralMessagePassingLayer, SimilarityNet
from mot.tracker.advanced import MPNTracker
from market.models import build_model




## Configuration

In [None]:
# Number GPUs in PCI bus order (see issue #152) and expose only GPU index 2
# to this process. These variables only take effect if set before CUDA is
# initialized in the kernel.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "2",
    }
)


In [None]:
# Use CUDA when a GPU is visible to this process, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Multi-object tracking

## Configuration: random seeds and constants

In [None]:
seed = 12345
_UNMATCHED_COST = 255.0

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True


## Load pretrained models

In [None]:
# ReID (appearance embedding) network used by the tracker.
# NOTE(review): the checkpoint file is named "resnet50_reid.pth" while the
# architecture built is "resnet34". load_state_dict below is strict, so the
# file presumably does hold resnet34 weights and only the name is misleading.
# Confirm (SimilarityNet is configured with reid_dim=512).
model_path = os.path.join(root_dir, "models", "resnet50_reid.pth")

# 751 = number of classifier outputs; presumably the Market-1501 training
# identities — confirm against the training setup.
reid_model = build_model("resnet34", 751, loss="softmax", pretrained=True)
# Deserialize onto CPU first; the model is moved to `device` afterwards.
reid_ckpt = torch.load(model_path, map_location=lambda storage, loc: storage)
reid_model.load_state_dict(reid_ckpt)
# Inference only: switch BatchNorm/Dropout to eval behavior, as is already
# done for obj_detect and similarity_net.
reid_model = reid_model.to(device).eval()


In [None]:
# Checkpoint of the Faster R-CNN (FPN) detector and the NMS IoU threshold it uses.
obj_detect_model_file = os.path.join(root_dir, "models/faster_rcnn_fpn.model")
obj_detect_nms_thresh = 0.3

# Build the detector (num_classes=2 is presumably background + pedestrian),
# then restore its trained weights, deserialized on CPU.
obj_detect = FRCNN_FPN(num_classes=2, nms_thresh=obj_detect_nms_thresh)
obj_detect_state_dict = torch.load(obj_detect_model_file, map_location="cpu")
obj_detect.load_state_dict(obj_detect_state_dict)

# Inference mode, then move to the selected device (both return the module).
obj_detect = obj_detect.eval().to(device)


In [None]:
# Learned similarity network for association. No ReID backbone is embedded:
# appearance features come from the separate reid_model handed to MPNTracker.
similarity_net_hparams = dict(
    reid_network=None,
    node_dim=32,
    edge_dim=64,
    reid_dim=512,
    edges_in_dim=6,
    num_steps=10,
)
similarity_net = SimilarityNet(**similarity_net_hparams)

# Restore the best checkpoint on CPU, then move the network to `device`.
best_ckpt_path = osp.join(root_dir, "models", "best_ckpt.pth")
best_ckpt = torch.load(best_ckpt_path, map_location="cpu")
similarity_net.load_state_dict(best_ckpt)
similarity_net = similarity_net.to(device)


## Inference from raw sequences (no precomputed features)

In [None]:
# Message-passing tracker combining the detector, the ReID network and the
# learned similarity network. (Earlier experiments in this notebook used
# ReIDHungarianTracker / ReIDHungarianTracker2 with obj_detect instead.)
similarity_net.eval()  # returns the same module; inference mode for tracking
tracker = MPNTracker(
    similarity_net=similarity_net,
    reid_model=reid_model,
    obj_detect=obj_detect,
)


In [None]:
# Validation split of MOT16; vis_threshold=0.0 presumably keeps every
# ground-truth box regardless of its visibility — confirm in MOT16Sequences.
mot16_root = osp.join(root_dir, "data/MOT16")
val_sequences = MOT16Sequences("MOT16-val2", mot16_root, vis_threshold=0.0)

# List the sequences that come with ground truth (the ones that get evaluated).
gt_seq_names = [str(s) for s in val_sequences if not s.no_gt]
print("seqs", gt_seq_names)


In [None]:
# Run the tracker over every validation sequence, keep per-sequence results,
# and accumulate motmetrics for sequences that have ground truth.
time_total = 0.0
mot_accums = []
results_seq = {}

# Tracking is pure inference: disabling autograd avoids building graphs and
# keeps memory bounded on long sequences (assumes tracker.step needs no
# gradients — confirm if the tracker ever trains online).
with torch.no_grad():
    for seq in val_sequences:
        tracker.reset()
        start = time.time()

        print(f"Tracking: {seq}")

        data_loader = DataLoader(seq, batch_size=1, shuffle=False)
        for frame in tqdm(data_loader):
            tracker.step(frame)

        results = tracker.get_results()
        results_seq[str(seq)] = results

        if seq.no_gt:
            print("No GT evaluation data available.")
        else:
            mot_accums.append(get_mot_accum(results, seq))

        # Measure once so the per-sequence runtime printed below matches
        # exactly what is added to the total.
        elapsed = time.time() - start
        time_total += elapsed

        print(f"Tracks found: {len(results)}")
        print(f"Runtime for {seq}: {elapsed:.1f} s.")

        # seq.write_results(results, os.path.join(output_dir))

print(f"Runtime for all sequences: {time_total:.1f} s.")
if mot_accums:
    evaluate_mot_accums(
        mot_accums,
        [str(s) for s in val_sequences if not s.no_gt],
        generate_overall=True,
    )


## Baseline Tracker Results

            IDF1   IDP   IDR   Rcll  Prcn  GT  MT  PT ML  FP   FN IDs   FM  MOTA  MOTP
    MOT16-02 32.2% 49.8% 23.8% 30.8% 64.4%  62  5  22 35 3170 12858  52   93 13.5% 0.086
    MOT16-05 47.7% 53.9% 42.8% 57.8% 72.7% 133 39  64 30 1502  2917  87  103 34.9% 0.144
    MOT16-09 43.0% 48.8% 38.4% 51.9% 66.1%  26  7  14  5 1420  2559  39   66 24.5% 0.107
    MOT16-11 49.0% 54.1% 44.8% 55.8% 67.5%  75 15  32 28 2542  4166  20   39 28.7% 0.080
    OVERALL  41.0% 51.8% 33.9% 44.1% 67.3% 296 66 132 98 8634 22500 198  301 22.2% 0.101

## Hungarian Tracker Results

             IDF1    IDP  IDR  Rcll  Prcn   GT  MT  PT ML   FP    FN IDs   FM  MOTA  MOTP 
    MOT16-02 39.1% 55.5% 30.2% 52.3% 96.2%  62  11  38 13  383  8870 246  215 48.9% 0.096 
    MOT16-05 55.1% 65.2% 47.7% 68.8% 94.2% 133  55  66 12  295  2158 199  155 61.7% 0.143  
    MOT16-09 50.2% 62.0% 42.1% 66.4% 97.8%  26  13  12  1   80  1789  76   78 63.5% 0.083  
    MOT16-11 60.4% 66.6% 55.3% 80.2% 96.6%  75  42  26  7  266  1868  99   86 76.3% 0.083  
    OVERALL  49.0% 61.5% 40.6% 63.5% 96.2% 296 121 142 33 1024 14685 620  534 59.4% 0.099 

## ReidHungarianIoU Tracker results

              IDF1   IDP   IDR  Rcll  Prcn  GT  MT  PT ML   FP    FN IDs   FM  MOTA  MOTP IDt IDa IDm
    MOT16-02 41.0% 58.2% 31.6% 52.3% 96.2%  62  11  38 13  383  8870 334  221 48.4% 0.095 173 154  10
    MOT16-05 55.8% 66.1% 48.3% 68.8% 94.2% 133  56  65 12  295  2158 218  150 61.4% 0.142  96 144  25
    MOT16-09 52.4% 64.7% 44.0% 66.4% 97.8%  26  12  13  1   80  1789  80   79 63.4% 0.083  27  58   5
    MOT16-11 62.2% 68.6% 56.9% 80.2% 96.6%  75  42  26  7  266  1868 112   86 76.2% 0.083  41  77  11
    OVERALL  50.6% 63.7% 42.1% 63.5% 96.2% 296 121 142 33 1024 14685 744  536 59.1% 0.099 337 433  51

## ReidHungarianIoU Tracker results (later run, MOT16-02 and MOT16-11 only)

              IDF1   IDP   IDR  Rcll  Prcn  GT MT PT ML  FP    FN IDs   FM  MOTA  MOTP IDt IDa IDm
    MOT16-02 42.7% 60.7% 33.0% 52.3% 96.2%  62 11 38 13 383  8870 191  214 49.2% 0.095  24 167   4
    MOT16-11 62.6% 69.0% 57.3% 80.2% 96.6%  75 42 26  7 266  1868  83   86 76.5% 0.083  13  71   3
    OVERALL  50.2% 64.3% 41.1% 61.7% 96.4% 137 53 64 20 649 10738 274  300 58.4% 0.090  37 238   7

## Visualize tracking results

### new

In [None]:
# Overlay the tracker output on the first frames of one validation sequence.
# The sequence name is bound once so the results and the frames always agree.
plot_seq_name = "MOT16-02"
plot_seq = next((s for s in val_sequences if str(s) == plot_seq_name), None)
if plot_seq is None:
    raise KeyError(f"{plot_seq_name} is not among the validation sequences")
plot_sequence(
    results_seq[plot_seq_name],
    plot_seq,
    first_n_frames=3,
)
