#### Import Python libraries

In [1]:
# Automatically reload edited modules (e.g. the project code under ../src) before each cell runs.
%load_ext autoreload
%autoreload 2


In [2]:
import os
import sys

root_dir = ".."
# Make the project packages under ../src (tracker, gnn, market, ...) importable.
# Guard against duplicates: without the check, every re-run of this cell would
# append the same entry to sys.path again.
src_dir = os.path.join(root_dir, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)


In [17]:
import matplotlib.pyplot as plt
import numpy as np
import time
import gif
from tqdm.autonotebook import tqdm

import torch
from torch.utils.data import DataLoader

from tracker.data_track import MOT16Sequences
from tracker.data_obj_detect import MOT16ObjDetect
from tracker.object_detector import FRCNN_FPN
from tracker.tracker import Tracker
from tracker.utils import ( evaluate_mot_accums, get_mot_accum,
                           evaluate_obj_detect, obj_detect_transforms)
from gnn.models import BipartiteNeuralMessagePassingLayer, SimilarityNet
# from gnn.tracker import MPNTracker
from tracker.visualize import plot_sequence, collect_frames_for_gif
from scipy.optimize import linear_sum_assignment as linear_assignment
from market.models import build_model
from tracker.tracker import Tracker, ReIDTracker
import os.path as osp

import motmetrics as mm
# Use the `lap` package as motmetrics' default linear-assignment solver.
mm.lap.default_solver = 'lap'
%matplotlib inline



## Device configuration

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Multi-object tracking

## Reproducibility (random seeds)

In [5]:
# Seed every random number generator in play so that runs are reproducible.
seed = 12345
np.random.seed(seed)                       # NumPy global RNG
torch.manual_seed(seed)                    # PyTorch default generator
torch.cuda.manual_seed(seed)               # current CUDA device (no-op without CUDA)
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN kernels


## Load the pretrained model

In [6]:
# Define our model, and init
# Define the GNN similarity network used by the MPN tracker.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint loaded afterwards was trained with — load_state_dict fails otherwise.
refine_gnn_net = SimilarityNet(
    reid_network=None,  # Not needed since we work with precomputed features
    node_dim=32,
    edge_dim=64,
    reid_dim=512,
    edges_in_dim=6,
    num_steps=10,
).to(device)  # honour the CPU fallback instead of hard-coding .cuda()

In [7]:
from gnn.tracker import MPNTracker


In [8]:
# Restore the best weights into `refine_gnn_net`. `map_location` lets a checkpoint
# that was saved from a GPU be loaded on a CPU-only machine as well.
best_ckpt = torch.load(
    osp.join(root_dir, "output", "best_ckpt.pth"), map_location=device
)
refine_gnn_net.load_state_dict(best_ckpt)

# Passed to MPNTracker as `patience`; presumably the number of frames a track may
# stay unmatched before it is terminated — NOTE(review): confirm in MPNTracker.
MAX_PATIENCE = 20

# The network is switched to eval mode before it is handed to the tracker.
tracker = MPNTracker(
    assign_net=refine_gnn_net.eval(), obj_detect=None, patience=MAX_PATIENCE
)


In [9]:
#from tracker.utils import get_mot_accum,evaluate_mot_accums,plot_sequence

In [10]:
# Validation split used for evaluation (MOT16-02 and MOT16-11 in the recorded run).
val_sequences = MOT16Sequences(
    "MOT16-val2",
    osp.join(root_dir, "data/MOT16"),
    vis_threshold=0.0,
)

# Precomputed per-sequence frame data, indexed by sequence name when tracking.
# NOTE(review): despite the "train" file name this database is used to track the
# validation sequences — confirm that this is intended and not a train/val mix-up.
# torch.load unpickles the file, so only point it at trusted data.
train_db = torch.load(
    osp.join(root_dir, "data/preprocessed_data/preprocessed_data_train_2.pth")
)


In [11]:
# Track every validation sequence with the pretrained MPN tracker, feeding it the
# precomputed frames stored in `train_db`, then score the results with motmetrics
# for the sequences that have ground truth.
output_dir = None  # set to a directory path to also save results via seq.write_results
db = train_db
time_total = 0
mot_accums = []
eval_seq_names = []  # names of the evaluated sequences, in lockstep with mot_accums
results_seq = {}
for seq in val_sequences:
    tracker.reset()
    now = time.perf_counter()  # monotonic clock, suited to measuring intervals

    print(f"Tracking: {seq}")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for frame in db[str(seq)]:
            tracker.step(frame)

    results = tracker.get_results()
    results_seq[str(seq)] = results

    if seq.no_gt:
        print("No GT evaluation data available.")
    else:
        mot_accums.append(get_mot_accum(results, seq))
        eval_seq_names.append(str(seq))

    # Measure once so that the per-sequence print and the running total agree.
    elapsed = time.perf_counter() - now
    time_total += elapsed

    print(f"Tracks found: {len(results)}")
    print(f"Runtime for {seq}: {elapsed:.1f} s.")

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        seq.write_results(results, os.path.join(output_dir))

print(f"Runtime for all sequences: {time_total:.1f} s.")
if mot_accums:
    evaluate_mot_accums(mot_accums, eval_seq_names, generate_overall=True)


Tracking: MOT16-02
Tracks found: 100
Runtime for MOT16-02: 5.0 s.
Tracking: MOT16-11
Tracks found: 83
Runtime for MOT16-11: 5.1 s.
Runtime for all sequences: 10.0 s.
          IDF1   IDP   IDR  Rcll  Prcn  GT MT PT ML  FP    FN IDs   FM  MOTA  MOTP IDt IDa IDm
MOT16-02 48.4% 68.7% 37.3% 52.2% 96.1%  62 11 38 13 390  8873  98  222 49.6% 0.096  59  47  12
MOT16-11 71.2% 78.5% 65.1% 80.2% 96.6%  75 44 24  7 266  1871  37   90 77.0% 0.083  31  15  14
OVERALL  56.9% 73.0% 46.7% 61.7% 96.3% 137 55 62 20 656 10744 135  312 58.8% 0.090  90  62  26


In [None]:
# Visual sanity check: plot the MOT16-02 tracker output for its first two frames.
plot_sequence(
    results_seq["MOT16-02"],
    [candidate for candidate in val_sequences if str(candidate) == "MOT16-02"][0],
    first_n_frames=2,
)


# Create a GIF of the tracking results

In [14]:
# Choose which validation sequence to render, and fetch its tracker output.
seq_name = "MOT16-02"
sequence = [candidate for candidate in val_sequences if str(candidate) == seq_name][0]
tracker_seq_res = results_seq[seq_name]

In [15]:
# Render the frames of `sequence` with `tracker_seq_res` for the GIF.
# Slow: the recorded run needed ~6 minutes for MOT16-02's 600 frames.
frames = collect_frames_for_gif(sequence, tracker_seq_res)

100%|█████████████████████████████████████████| 600/600 [05:46<00:00,  1.73it/s]


In [18]:
# Save the rendered frames as an animated GIF named after the selected sequence
# (identical file name to before for seq_name == "MOT16-02").
# NOTE(review): the `gif` package interprets `duration` in milliseconds by default,
# and GIF stores frame delays in 1/100 s, so 0.001 becomes a zero delay that many
# viewers replace with a slow default (~100 ms/frame). The "0001sec" in the name
# suggests 0.001 s per frame was intended — confirm the desired playback speed.
gif.save(frames, f"{seq_name}_result_0001sec.gif", duration=0.001)