In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import time
from itertools import product
from typing import List, Iterable, Mapping
import os
from pathlib import Path
import datetime
from collections import defaultdict, deque
import statistics
from copy import deepcopy

import numpy as np
import torch
from torch_geometric.data import Data, DataLoader, Batch
from pytorch_lightning.utilities.seed import seed_everything

import dataset_classes.kitti.mot_kitti as mot_kitti
from dataset_classes.nuscenes.dataset import MOTDatasetNuScenes
from dataset_classes.mot_sequence import MOTSequence
from utils import io
from configs.params import TRAIN_SEQ, VAL_SEQ, TRACK_TRAIN_SEQ, TRACK_VAL_SEQ, build_params_dict, KITTI_BEST_PARAMS, NUSCENES_BEST_PARAMS, variant_name_from_params
from configs.local_variables import KITTI_WORK_DIR, SPLIT, NUSCENES_WORK_DIR, MOUNT_PATH
import inputs.utils as input_utils
from dataset_classes.kitti.classes import KITTIClasses
from dataset_classes.nuscenes.classes import NuScenesClasses
from torch_geometric.data import InMemoryDataset
from pytorch_lightning import Trainer

from models.graph_tracker_offline import GraphTrackerOffline
import data.graph_construction as graph_construction
import models.utils as utils_models
import evaluation.helpers as evaluation
import evaluation.offline_utils as offline_utils
import evaluation.offline_processing as processing

In [3]:
mot_dataset = mot_kitti.MOTDatasetKITTI(work_dir=KITTI_WORK_DIR,
                                        det_source=input_utils.POINTGNN_T3,
                                        seg_source=input_utils.TRACKING_BEST,
                                        params=KITTI_BEST_PARAMS)

# mot_dataset = MOTDatasetNuScenes(work_dir=NUSCENES_WORK_DIR,
#                                  det_source=input_utils.CENTER_POINT,
#                                  seg_source=input_utils.MMDETECTION_CASCADE_NUIMAGES,
#                                  params=NUSCENES_BEST_PARAMS,
#                                  version="v1.0-mini")

In [4]:
CLIP_LENGTH = 11

data_full = True
num_augmentations = 0
num_without_augmentations = 1

seg_class_id = KITTIClasses.car  # pedestrian
# seg_class_id = NuScenesClasses.car  # pedestrian
annotated = False
max_edge_length = 5  # -1 to connect all, 5 otherwise
include_dims = False

deltas_only = True
online_only = False

# At 50 km/h, will cover 1.39 m in a frame at 10Hz on KITTI
# 1.4 m/s - average human walking speed, 0.14 per frame
# NuScenes is at 2Hz
max_edge_distances = {  # in meters
    KITTIClasses.car: 3,
    KITTIClasses.pedestrian: 0.5,
    
    NuScenesClasses.car: 10,
    NuScenesClasses.pedestrian: 2,
}

data_params = {
    "seg_class_id": seg_class_id,
    "annotated": annotated,
    "max_edge_length": max_edge_length,

    "include_dims": include_dims,
    "deltas_only": deltas_only,
    "max_edge_distance": max_edge_distances[seg_class_id],
    "online_only": online_only,
    
    "bbox_drop_p": 0.4,   # 0.3, 0.4 Nu
    "frame_drop_p": 0.1,  # 0.05, 0.1 Nu
    # probably should increase for NuScenes - sparser cloud
    "xz_std": 0.5,   # 0.4, 0.5 Nu
    "theta_std": 0.2,  # 0.17, 0.2 Nu
    "lwh_std": 0,  # 0
    
    "bbox_add_p": 0.4,  # 0.3, 0.4 Nu
    "num_bboxes_to_always_add": 3,  # 1, 3 Nu
}

data_params_no_aug = {
    "seg_class_id": seg_class_id,
    "annotated": annotated,
    "max_edge_length": max_edge_length,
    
    "include_dims": include_dims,
    "deltas_only": deltas_only,
    "max_edge_distance": max_edge_distances[seg_class_id],
    "online_only": online_only,
    
    "bbox_drop_p": 0, 
    "frame_drop_p": 0, 
    "xz_std": 0, 
    "theta_std": 0, 
    "lwh_std": 0,
    
    "bbox_add_p": 0,
    "num_bboxes_to_always_add": 0,
}

if not deltas_only:
    print("Requesting relative features not adjusted for time! Distace thresholds will not work as wellS")

seg_class_id

<KITTIClasses.car: 1>

In [5]:
# split = "mini_val"
# seq_names = mot_dataset.sequence_names(split)

split = "training"
seq_names = TRACK_VAL_SEQ

seq_names

['0001',
 '0006',
 '0008',
 '0010',
 '0012',
 '0013',
 '0014',
 '0015',
 '0016',
 '0018',
 '0019']

In [6]:
class SequenceDataset(InMemoryDataset):
    def __init__(self, root, sequence_name: str, transform=None, pre_transform=None):
        super(SequenceDataset, self).__init__(root, transform, pre_transform)
        self.sequence_name = sequence_name
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return [f"{self.sequence_name}.pt"]

In [75]:
save_folder_name = f"{datetime.datetime.now().strftime('%y-%m-%d_%H:%M')}_dets/processed"
# save_root_dir = Path(f"/storage/slurm/kimal/graphmot_workspace/nuscenes/{split}")
save_root_dir = Path(f"/storage/slurm/kimal/graphmot_workspace/kitti/{split}")
save_dir = save_root_dir / save_folder_name

load_graphs = True
# dir_to_load = save_root_dir / "21-08-18_23:36_dets"
dir_to_load = save_root_dir / "21-08-21_13:39_dets"

In [76]:
data_lists_for_sequences = {}
print(f"{'NOT ' if not load_graphs else ''}loading saved graphs from disk")
for seq_name in seq_names:
    print(f"Processing the graph for {seq_name}")
    if load_graphs:
        graph_data = SequenceDataset(dir_to_load, seq_name)
    else:
        graph_data = graph_construction.from_dataset(mot_dataset, CLIP_LENGTH,
                                                     split, [seq_name], **data_params_no_aug)
        data, slices = InMemoryDataset.collate(graph_data)
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save((data, slices), save_dir / f"{seq_name}.pt")
    data_lists_for_sequences[seq_name] = graph_data


loading saved graphs from disk
Processing the graph for 0001
Processing the graph for 0006
Processing the graph for 0008
Processing the graph for 0010
Processing the graph for 0012
Processing the graph for 0013
Processing the graph for 0014
Processing the graph for 0015
Processing the graph for 0016
Processing the graph for 0018
Processing the graph for 0019


In [77]:
NUM_WORKERS = 2
batch_size = 32  # 32  # 8 Nuscenes

In [78]:
log_folder_name = Path(mot_dataset.work_dir) / "gnn_training"
log_folder_name

PosixPath('/storage/slurm/kimal/graphmot_workspace/kitti/gnn_training')

### NuScenes mini

In [23]:
# # NuScenes mini check
# ckpt_folder_name = "21-08-17_20:19_aug1_1xpos_slope0.2_len5_edgedist_mini__recurr_edgedim16_steps4_focal_lr0.002_wd0_clip11_batch32_car"
# ckpt_name = "val_loss-epoch=335-val_loss=0.017069"

# ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
# ckpt_path = str(log_folder_name / ckpt_path)
    
# hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
# model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

# is_consec = False

### NuScenes Full

In [13]:
ckpt_folder_name = "21-08-18_10:12_aug1_0.5xpos_slope0.2_len5_edgedist_recurr_edgedim16_steps4_focal_lr0.002_wd0.015_clip11_batch32_car"
ckpt_name = "val_loss=0.008631-epoch=233-"

ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
ckpt_path = str(log_folder_name / ckpt_path)
    
hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

is_consec = False

In [61]:
# # Trained on Clean data
# ckpt_folder_name = "21-08-17_03:27_clean_1xpos_slope0.2_len5_edgedist_clip11_edgedim16_steps4_focal_lr0.002_wd0_batch32_car"
# ckpt_name = "val_loss-epoch=62-val_loss=0.0007"

# ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
# ckpt_path = str(log_folder_name / ckpt_path)
    
# hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
# model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

# is_consec = False

### KITTI

In [11]:
# ### KITTI world
ckpt_folder_name = "21-08-21_02:32_aug2_0.5xpos_slope0.2_world_offline_recurr_edgedim16_steps4_focal_lr0.004_wd0.005_clip11_batch64_car"
ckpt_name = "val_loss=0.021501-epoch=149"

ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
ckpt_path = str(log_folder_name / ckpt_path)
    
hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

is_consec = False

In [24]:
ckpt_folder_name = "21-08-16_16:30_aug2_1xpos_0.3bboxdrop_slope0.2_len5_edgedist_2lnode_recurr_edgedim16_steps4_focal_lr0.004_wd0.005_clip11_batch64_car"
ckpt_name = "val_loss-epoch=111-val_loss=0.0886"

ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
ckpt_path = str(log_folder_name / ckpt_path)
    
hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

is_consec = False

In [None]:
ckpt_folder_name = "21-08-11_03:59_augmented_full_2_0.5xpos_2layers_0.3bboxdrop_clip11_edgedim16_steps5_focal_lr0.003_wd0.001_batch32_car"
ckpt_name = "val_loss-epoch=195-val_loss=0.0555"

ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
ckpt_path = str(log_folder_name / ckpt_path)
    
hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

is_consec = False

In [44]:
ckpt_folder_name = "21-08-11_22:41_aug2_0.5xpos_2layers_0.3bboxdrop_consec_length5_clip11_edgedim16_steps5_focal_lr0.003_wd0.005_batch32_car"
ckpt_name = "val_loss-epoch=193-val_loss=0.0427"

ckpt_path = f"{ckpt_folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
ckpt_path = str(log_folder_name / ckpt_path)
    
hparams_file = ckpt_path.split("checkpoints")[0] + "hparams.yaml"
model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path, hparams_file=hparams_file)

is_consec = True

In [14]:
model_to_test

GraphTrackerOffline: 
initial_edge_model=MLP(
  (nonlinearity): LeakyReLU(negative_slope=0.2, inplace=True)
  (fc_layers): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
  )
), 
initial_node_model=InitialNodeModel(
  (node_mlp): MLP(
    (nonlinearity): LeakyReLU(negative_slope=0.2, inplace=True)
    (fc_layers): Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Linear(in_features=32, out_features=64, bias=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
      (4): Linear(in_features=64, out_features=32, bias=True)
      (5): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
), 
mpn_model=MessagePassingNetworkRecurrent(
  (edge_model): BasicEdgeModel(
    (edge_mlp): MLP(
      (nonlinearity): L

In [59]:
data_list = graph_construction.from_dataset(mot_dataset, CLIP_LENGTH, split, ["0001"], **data_params_no_aug, starting_frame=0)

Processing sequence 0001
Processing frame 359
5 5 5 4 4 4 4 4 4 4 4 
Processing frame 360
5 5 4 4 4 4 4 4 4 4 4 
Building data objects for training from 1 sequences for length 11    : took 0.01 minutes
2 data objects in total
time_annotations      0.000 minutes
time_building_clip      0.001 minutes
time_attr_compute      0.001 minutes
time_drop_bboxes_frames 0.000 minutes
time_jitter_bboxes      0.000 minutes
time_add_bboxes         0.000 minutes


In [66]:
data_list

[Data(edge_attr=[244, 4], edge_distances=[244], edge_index=[2, 244], edge_polar_angles=[244], start_frame_i=358),
 Data(edge_attr=[242, 4], edge_distances=[242], edge_index=[2, 242], edge_polar_angles=[242], start_frame_i=359),
 Data(edge_attr=[241, 4], edge_distances=[241], edge_index=[2, 241], edge_polar_angles=[241], start_frame_i=360),
 Data(edge_attr=[237, 4], edge_distances=[237], edge_index=[2, 237], edge_polar_angles=[237], start_frame_i=361),
 Data(edge_attr=[233, 4], edge_distances=[233], edge_index=[2, 233], edge_polar_angles=[233], start_frame_i=362),
 Data(edge_attr=[231, 4], edge_distances=[231], edge_index=[2, 231], edge_polar_angles=[231], start_frame_i=363),
 Data(edge_attr=[218, 4], edge_distances=[218], edge_index=[2, 218], edge_polar_angles=[218], start_frame_i=364),
 Data(edge_attr=[206, 4], edge_distances=[206], edge_index=[2, 206], edge_polar_angles=[206], start_frame_i=365),
 Data(edge_attr=[204, 4], edge_distances=[204], edge_index=[2, 204], edge_polar_angles=[

In [89]:
target_class = seg_class_id.value

score_reduce_list = ["mean"]
pred_score_threshold = 0
reduced_score_threshold_list = [0.5]
max_time_diff_list = [5]  # 1-2-3-4 don't make a diff

combos = list(product(score_reduce_list, reduced_score_threshold_list, max_time_diff_list))

for (score_reduce, reduced_score_threshold, max_time_diff) in combos:

    mot_dataset.reset(only_submission=True)
    for seq_name, data_list in data_lists_for_sequences.items():
        print(f"Processing sequence {seq_name}")
        dataloader = DataLoader(data_list, shuffle=False, drop_last=False,
                                    batch_size=batch_size, num_workers=NUM_WORKERS)

        batched_preds = Trainer(gpus=1).predict(model_to_test, dataloaders=dataloader)
        batched_preds, _ = zip(*batched_preds)

        sequence = mot_dataset.get_sequence(split, seq_name)

        start_time = time.time()
        (num_nodes_before_frame, instance_id_to_node_id, 
             node_id_to_frame_i, num_nodes_processed) = processing.prep_sequence(sequence, data_list, 
                                                                                 CLIP_LENGTH, target_class)
        print(f"prep_sequence {time.time() - start_time:.2f}")

        start_time = time.time()
        node_matches = processing.map_predictions_to_detections(sequence, data_list, 
                                                                dataloader, batch_size, batched_preds,
                                                                num_nodes_before_frame, num_nodes_processed,
                                                                max_time_diff, pred_score_threshold)
        print(f"map_predictions_to_detections {time.time() - start_time:.2f}")

        start_time = time.time()
        node_match_stats = offline_utils.reduce_match_scores(node_matches, score_reduce)
        node_matches_sorted = sorted(node_match_stats, key=lambda triplet: triplet[-1], reverse=True)
        print(f"reduce_match_scores + sort {time.time() - start_time:.2f}")

        start_time = time.time()
        node_id_to_track_id, same_tracks_map, \
            track_id_to_frame_indices = processing.map_nodes_to_tracks(node_matches_sorted, 
                                                            node_id_to_frame_i, 
                                                            reduced_score_threshold)

        print(f"map_nodes_to_tracks {time.time() - start_time:.2f}")

        direct_mapping = offline_utils.streamline_mapping(same_tracks_map)

        start_time = time.time()
        node_id_to_final_track_id = {}
        track_ids_unique = set()
        for node_id, track_id in node_id_to_track_id.items():
            track_id = direct_mapping.get(track_id, track_id)
            assert track_id not in direct_mapping
            node_id_to_final_track_id[node_id] = track_id
            track_ids_unique.add(track_id)
        print(f"map to final tracks {time.time() - start_time:.2f}")

        print(f"{len(track_ids_unique)} unique tracks in total")

        suffix = f"sorted_when_prep_only_maxtime{max_time_diff}_{'consec_' if is_consec else ''}"
        suffix += f"pred{pred_score_threshold}_{reduced_score_threshold}{score_reduce}_{ckpt_folder_name[:14]}"
        run_info = sequence.report_offline_tracking(target_class, instance_id_to_node_id, node_id_to_final_track_id, suffix)
        break
    mot_dataset.save_all_mot_results(run_info["mot_3d_file"])
print("done")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Processing sequence 0001


Predicting: 0it [00:00, ?it/s]

prep_sequence 0.64
map_predictions_to_detections 11.15
reduce_match_scores + sort 1.88
Assigned 437 tracks, combined 260 of them
map_nodes_to_tracks 0.01
map to final tracks 0.00
177 unique tracks in total
Processing frame 000000
Processing frame 000100
Processing frame 000200
Processing frame 000300
Processing frame 000400
done


In [83]:
len(node_id_to_track_id)

2176

In [85]:
len(instance_id_to_node_id)

2913

In [87]:
instance_id_to_node_id_unsorted = instance_id_to_node_id
instance_id_to_node_id_unsorted

{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 5: 5,
 1000: 6,
 1001: 7,
 1002: 8,
 1003: 9,
 1004: 10,
 2000: 11,
 2001: 12,
 2002: 13,
 2003: 14,
 2004: 15,
 2005: 16,
 3000: 17,
 3001: 18,
 3002: 19,
 3003: 20,
 3004: 21,
 3005: 22,
 3006: 23,
 3007: 24,
 4000: 25,
 4001: 26,
 4002: 27,
 4003: 28,
 4004: 29,
 4005: 30,
 5000: 31,
 5001: 32,
 5002: 33,
 5003: 34,
 5004: 35,
 5005: 36,
 5006: 37,
 5007: 38,
 6000: 39,
 6001: 40,
 6002: 41,
 6003: 42,
 6004: 43,
 6005: 44,
 6006: 45,
 7000: 46,
 7001: 47,
 7002: 48,
 7003: 49,
 7004: 50,
 7005: 51,
 7006: 52,
 7007: 53,
 8000: 54,
 8001: 55,
 8002: 56,
 8003: 57,
 8004: 58,
 8005: 59,
 8006: 60,
 8007: 61,
 9000: 62,
 9001: 63,
 9002: 64,
 9003: 65,
 9004: 66,
 9005: 67,
 9006: 68,
 9007: 69,
 9008: 70,
 10000: 71,
 10001: 72,
 10002: 73,
 10003: 74,
 10004: 75,
 10005: 76,
 10006: 77,
 10007: 78,
 11000: 79,
 11001: 80,
 11002: 81,
 11003: 82,
 11004: 83,
 11005: 84,
 11006: 85,
 11007: 86,
 11008: 87,
 11009: 88,
 12000: 89,
 12001: 90,
 12002:

In [74]:
for i, scores in enumerate(node_matches.values()):
    if len(scores) > 1:
        print(len(scores))
    if i == 200:
        break

2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3


7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
1

6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
1

8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6

6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
8
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9


KeyboardInterrupt: 

In [99]:
node_id_to_track_id = {}
same_tracks_map = {}
# record which frames each track has already covered - to avoid overlap when joining tracks
track_id_to_frame_indices = defaultdict(set)

start_nodes_done = set()
end_nodes_done = set()

track_id_latest = 0

# links with higher scores will be popped first
next_nodes_stack = deque(start_node for start_node, end_node, score in reversed(node_matches_sorted))

In [100]:
while len(next_nodes_stack):
    start_node = next_nodes_stack.pop()
    if start_node in start_nodes_done:
        print(f"{start_node} - as start done")
        continue

    targets_with_scores_queue = matches_dict_sorted.get(start_node, None)  # if it has no connections
    if not targets_with_scores_queue:
        print(f"{start_node} - no targets from it")
        continue

    while len(targets_with_scores_queue):
        end_node, score = targets_with_scores_queue.popleft()
        if end_node not in end_nodes_done:
            print(f" -> {end_node} - chosen as target from {start_node}")
            break
    else:  
        assert not targets_with_scores_queue
        assert not matches_dict_sorted[start_node]
        del matches_dict_sorted[start_node]
        # exhausted all targets
        print(f"{start_node} - no more free targets from it")
        continue

    if score < reduced_score_threshold:
        print(f"Assigned {track_id_latest} tracks, combined {len(same_tracks_map)} of them")
        continue

    assert end_node not in end_nodes_done

    start_node_frame_i = node_id_to_frame_i[start_node]
    end_node_frame_i = node_id_to_frame_i[end_node]
    assert start_node_frame_i != end_node_frame_i, f"{start_node}, {end_node}, {score} same frame"

    if start_node not in node_id_to_track_id and end_node not in node_id_to_track_id:
        # both are new - start a new track
        track_to_assign = track_id_latest
        track_id_latest += 1
    elif start_node in node_id_to_track_id and end_node not in node_id_to_track_id:
        # only one is already assigned
        track_to_assign = node_id_to_track_id[start_node]
        while track_to_assign in same_tracks_map:
            track_to_assign = same_tracks_map[track_to_assign]

        if end_node_frame_i in track_id_to_frame_indices[track_to_assign]:
            # this track already has a detection in that frame
            continue
    elif end_node in node_id_to_track_id and start_node not in node_id_to_track_id:
        track_to_assign = node_id_to_track_id[end_node]
        while track_to_assign in same_tracks_map:
            track_to_assign = same_tracks_map[track_to_assign]

        if start_node_frame_i in track_id_to_frame_indices[track_to_assign]:
            # this track already has a detection in that frame
            continue
    else:
        start_node_track = node_id_to_track_id[start_node]
        end_node_track = node_id_to_track_id[end_node]

        while start_node_track in same_tracks_map:
            start_node_track = same_tracks_map[start_node_track]
        while end_node_track in same_tracks_map:
            end_node_track = same_tracks_map[end_node_track]

        if start_node_track == end_node_track:
            continue

        if any(x in track_id_to_frame_indices[start_node_track] 
               for x in track_id_to_frame_indices[end_node_track]):
            # these tracks cover at least one same frame - overlapping tracks remain independent
            continue

        if start_node_track != end_node_track:
            earlier_track = min(start_node_track, end_node_track)
            later_track = max(start_node_track, end_node_track)
            same_tracks_map[earlier_track] = later_track
            track_id_to_frame_indices[later_track].update(track_id_to_frame_indices[earlier_track])
            track_to_assign = later_track

    node_id_to_track_id[start_node] = track_to_assign
    node_id_to_track_id[end_node] = track_to_assign

    track_id_to_frame_indices[track_to_assign].add(start_node_frame_i)
    track_id_to_frame_indices[track_to_assign].add(end_node_frame_i)

    start_nodes_done.add(start_node)
    end_nodes_done.add(end_node)

    next_nodes_stack.append(end_node)

 -> 567 - chosen as target from 537
 -> 569 - chosen as target from 567
Assigned 1 tracks, combined 0 of them
 -> 132 - chosen as target from 126
 -> 146 - chosen as target from 132
 -> 154 - chosen as target from 146
 -> 161 - chosen as target from 154
 -> 170 - chosen as target from 161
 -> 178 - chosen as target from 170
 -> 182 - chosen as target from 178
 -> 193 - chosen as target from 182
 -> 198 - chosen as target from 193
 -> 208 - chosen as target from 198
 -> 217 - chosen as target from 208
 -> 227 - chosen as target from 217
 -> 247 - chosen as target from 227
 -> 264 - chosen as target from 247
 -> 282 - chosen as target from 264
 -> 316 - chosen as target from 282
 -> 323 - chosen as target from 316
 -> 346 - chosen as target from 323
 -> 354 - chosen as target from 346
 -> 362 - chosen as target from 354
 -> 369 - chosen as target from 362
 -> 378 - chosen as target from 369
 -> 386 - chosen as target from 378
 -> 396 - chosen as target from 386
 -> 412 - chosen as target

AssertionError: 593, 597, 0.9488247036933899 same frame

In [119]:
node_match_stats = reduce_match_scores(node_matches, score_reduce)

In [129]:
node_match_stats

[(0, 6, 1, 0.8023928999900818),
 (0, 7, 1, 0.2695392668247223),
 (0, 8, 1, 0.183773010969162),
 (0, 9, 1, 0.15450535714626312),
 (0, 10, 1, 0.15554189682006836),
 (0, 11, 2, 0.05288366973400116),
 (0, 12, 2, 0.7419271469116211),
 (0, 13, 2, 0.010699267499148846),
 (0, 14, 2, 0.02443588525056839),
 (0, 15, 2, 0.024670930579304695),
 (0, 16, 2, 0.25350552797317505),
 (0, 17, 3, 0.0637764111161232),
 (0, 18, 3, 0.6304059028625488),
 (0, 19, 3, 0.008165400475263596),
 (0, 20, 3, 0.013989799655973911),
 (0, 21, 3, 0.028728950768709183),
 (0, 22, 3, 0.16156984865665436),
 (0, 23, 3, 0.2480320930480957),
 (0, 24, 3, 0.2654702663421631),
 (0, 25, 4, 0.11307983845472336),
 (0, 26, 4, 0.6999945640563965),
 (0, 27, 4, 0.014082674868404865),
 (0, 28, 4, 0.014057829976081848),
 (0, 29, 4, 0.023379676043987274),
 (0, 30, 4, 0.10300933569669724),
 (0, 31, 5, 0.14182287454605103),
 (0, 32, 5, 0.8192587494850159),
 (0, 33, 5, 0.01059158518910408),
 (0, 34, 5, 0.0071989428251981735),
 (0, 35, 5, 0.02480

In [130]:
node_matches_sorted = sorted(node_match_stats, key=lambda triplet: triplet[-1], reverse=True)

In [178]:
node_matches_sorted

[(2901, 2903, 2, 1.0),
 (2902, 2903, 1, 1.0),
 (2898, 2904, 6, 1.0),
 (2900, 2904, 4, 1.0),
 (2901, 2904, 3, 1.0),
 (2902, 2904, 2, 1.0),
 (2903, 2904, 1, 1.0),
 (2897, 2905, 8, 1.0),
 (2898, 2905, 7, 1.0),
 (2899, 2905, 6, 1.0),
 (2900, 2905, 5, 1.0),
 (2901, 2905, 4, 1.0),
 (2902, 2905, 3, 1.0),
 (2903, 2905, 2, 1.0),
 (2904, 2905, 1, 1.0),
 (2898, 2906, 8, 1.0),
 (2899, 2906, 7, 1.0),
 (2900, 2906, 6, 1.0),
 (2901, 2906, 5, 1.0),
 (2902, 2906, 4, 1.0),
 (2903, 2906, 3, 1.0),
 (2904, 2906, 2, 1.0),
 (2905, 2906, 1, 1.0),
 (2898, 2907, 9, 1.0),
 (2899, 2907, 8, 1.0),
 (2900, 2907, 7, 1.0),
 (2901, 2907, 6, 1.0),
 (2902, 2907, 5, 1.0),
 (2903, 2907, 4, 1.0),
 (2904, 2907, 3, 1.0),
 (2905, 2907, 2, 1.0),
 (2906, 2907, 1, 1.0),
 (2898, 2908, 10, 1.0),
 (2899, 2908, 9, 1.0),
 (2900, 2908, 8, 1.0),
 (2901, 2908, 7, 1.0),
 (2902, 2908, 6, 1.0),
 (2903, 2908, 5, 1.0),
 (2904, 2908, 4, 1.0),
 (2905, 2908, 3, 1.0),
 (2906, 2908, 2, 1.0),
 (2907, 2908, 1, 1.0),
 (2899, 2909, 10, 1.0),
 (2900, 2

In [191]:
node_id_to_track_id, same_tracks_map = map_nodes_to_tracks(node_matches_sorted, score_threshold)

Assigned 582 tracks, combined somem of them probably


In [45]:
sequence = mot_dataset.get_sequence(SPLIT, "0001")

In [21]:
batched_preds_0001 = pl.Trainer(gpus=1).predict(model_to_test, dataloaders=dataloader_val)
batched_preds_0001, _ = zip(*batched_preds_0001)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

NameError: name 'batched_preds' is not defined

In [33]:
for data_batch, batched_preds_0001_clip in zip(dataloader_val, batched_preds_0001):
    print(data_batch)
    print(batched_preds_0001_clip.shape)
    break

Batch(batch=[2982], edge_attr=[126607, 7], edge_index=[2, 126607], ptr=[33], start_frame_i=[32])
torch.Size([126607, 1])


In [35]:
for data_0001_single in data_batch.to_data_list():
    print(data_0001_single)
    break

Data(edge_attr=[2829, 7], edge_index=[2, 2829], start_frame_i=[1])


In [41]:
data_0001_single.edge_index[:, :10]

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])

In [43]:
batched_preds_0001_clip[:10].reshape(1, -1)

tensor([[0.8024, 0.2695, 0.1838, 0.1545, 0.1555, 0.0529, 0.7419, 0.0107, 0.0244,
         0.0247]])

In [53]:
len(batched_preds_0001_clip)

126607

In [55]:
data_0001_preds = batched_preds_0001_clip[:data_0001_single.num_edges]

In [60]:
data_0001_preds = torch.hstack((data_0001_single.edge_index.T, data_0001_preds))

In [65]:
data_0001_preds[:10]

tensor([[0.0000e+00, 6.0000e+00, 8.0239e-01],
        [0.0000e+00, 7.0000e+00, 2.6954e-01],
        [0.0000e+00, 8.0000e+00, 1.8377e-01],
        [0.0000e+00, 9.0000e+00, 1.5451e-01],
        [0.0000e+00, 1.0000e+01, 1.5554e-01],
        [0.0000e+00, 1.1000e+01, 5.2884e-02],
        [0.0000e+00, 1.2000e+01, 7.4193e-01],
        [0.0000e+00, 1.3000e+01, 1.0699e-02],
        [0.0000e+00, 1.4000e+01, 2.4436e-02],
        [0.0000e+00, 1.5000e+01, 2.4671e-02]])

In [78]:
for i in range(20):
    edges = data_0001_single.edge_index[:, i]
    score = batched_preds_0001_clip[i]
    print(f"{edges[0].item():>2}->{edges[1].item():<2}: {score.item():.2f}")
#     print(data_0001_single.edge_index[0, i].item(), data_0001_single.edge_index[1, i].item(), " - ", batched_preds_0001_clip[i].item())

 0->6 : 0.80
 0->7 : 0.27
 0->8 : 0.18
 0->9 : 0.15
 0->10: 0.16
 0->11: 0.05
 0->12: 0.74
 0->13: 0.01
 0->14: 0.02
 0->15: 0.02
 0->16: 0.25
 0->17: 0.06
 0->18: 0.63
 0->19: 0.01
 0->20: 0.01
 0->21: 0.03
 0->22: 0.16
 0->23: 0.25
 0->24: 0.27
 0->25: 0.11


In [None]:
0 6 11 17 25

In [None]:
# discrad scores above threshold if already matched to something in that frame
# can be done early in the process

In [56]:
data_0001_preds.shape

torch.Size([2829, 1])

In [47]:
dets_per_frame = sequence.dets_3d_per_frame

In [77]:
print(len([det for det in dets_per_frame["000000"] if det.seg_class_id == seg_class_id.value]))
print(len([det for det in dets_per_frame["000001"] if det.seg_class_id == seg_class_id.value]))
print(len([det for det in dets_per_frame["000002"] if det.seg_class_id == seg_class_id.value]))
print(len([det for det in dets_per_frame["000003"] if det.seg_class_id == seg_class_id.value]))


6
5
6
8


In [29]:
len(batched_preds_0001)

14

In [200]:
data_list_0001 = data_lists_for_sequences["0001"]

In [202]:
len(data_list_0001)

420

In [203]:
data_0001_000000 = data_list_0001[0]

In [206]:
data_0001_000000.edge_index[:, :10]

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])

In [149]:
folder_name = "21-08-10_13:53_augmented_full_2_0.5xpos_2layers_clip11_edgedim16_steps5_focal_batch32_car"
ckpt_name = "val_loss-epoch=149-val_loss=0.0493"

ckpt_path = log_folder_name / f"{folder_name}/version_0/checkpoints/{ckpt_name}.ckpt"
print(f"Loading {ckpt_path}")
model_to_test = GraphTrackerOffline.load_from_checkpoint(ckpt_path)
model_to_test.eval()

Loading /storage/slurm/kimal/graphmot_workspace/kitti/gnn_training/21-08-10_13:53_augmented_full_2_0.5xpos_2layers_clip11_edgedim16_steps5_focal_batch32_car/version_0/checkpoints/val_loss-epoch=149-val_loss=0.0493.ckpt


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

Inference took 6.0 sec


[(tensor([[0.8989],
          [0.7105],
          [0.8802],
          [0.8649],
          [0.1972],
          [0.8562],
          [0.1696],
          [0.5702],
          [0.1606],
          [0.4787],
          [0.2001],
          [0.2059],
          [0.2045],
          [0.1447],
          [0.7910],
          [0.9521],
          [0.9381],
          [0.0690],
          [0.9163],
          [0.0747],
          [0.5803],
          [0.0718],
          [0.5075],
          [0.0872],
          [0.0934],
          [0.1045],
          [0.0326],
          [0.8782],
          [0.8506],
          [0.2264],
          [0.8335],
          [0.1933],
          [0.5753],
          [0.1989],
          [0.5174],
          [0.2404],
          [0.2407],
          [0.2488],
          [0.1780],
          [0.9484],
          [0.0558],
          [0.9339],
          [0.0608],
          [0.5709],
          [0.0647],
          [0.5199],
          [0.0870],
          [0.0893],
          [0.1078],
          [0.0271],


In [13]:
output = model_to_test(batch)
preds = torch.hstack([model_to_test._predictions_from_forward([out]) for out in output])
print(preds.shape)

torch.Size([18687, 5])


In [14]:
for i, batch in enumerate(dataloader_val):
    output = model_to_test(batch)
    preds = torch.hstack([model_to_test._predictions_from_forward([out]) for out in output])
    for col in preds.T[:-1]:
        fraction_diff = 1 - ((col == preds[:, -1]).sum() / preds.shape[0])
        print(f"{fraction_diff * 100:>5.2f}% -", end=" ")
    print()
    if i == 5:
        break


37.70% - 23.19% - 11.76% -  7.77% - 
19.63% - 17.38% - 13.88% -  7.22% - 
11.59% - 10.85% -  8.22% -  5.26% - 
 8.54% -  7.06% -  5.87% -  4.00% - 
18.75% - 13.61% - 12.50% -  7.99% - 
23.69% - 15.51% -  6.82% -  5.97% - 
