In [2]:
import os
import yaml
import time
from copy import deepcopy
import numpy as np
import torch

os.chdir('../')

import mmengine
from mot_3d.data_protos import BBox
from mot_3d.mot import MOTModel
from mot_3d.frame_data import FrameData
from mot_3d.update_info_data import UpdateInfoData
import mot_3d.tracklet as tracklet
from mot_3d.association import associate_dets_to_tracks, associate_unmatched_trks

In [None]:
def load_frame(data, detection_result, i, track_labels, point_cloud_range):
    """Build the tracker input and ground truth for one frame.

    Args:
        data: Info dict of the frame. Reads the keys 'sample_idx',
            'ego2global', 'time_stamp' and 'instances'.
        detection_result: Detector outputs. Reads the keys 'sample_idx',
            'pre_bboxes' and 'pre_labels', each indexed by ``i``.
        i: Index of this frame inside ``detection_result``.
        track_labels: Label ids to keep; detections with any other label
            are dropped before tracking.
        point_cloud_range: Currently unused. No point cloud is loaded or
            cropped here, so ``pc=None`` is handed to ``FrameData``.

    Returns:
        Tuple ``(frame_data, gt_bboxes, gt_labels, gt_ids)``. Detections in
        ``frame_data`` and the ground-truth boxes are both passed through
        ``BBox.bbox2world`` with the frame's ``ego2global`` transform.
        ``gt_ids`` holds the part of each 'instance_id' after the last ':'.
    """
    # The detection entry at index i must describe this very frame.
    assert data['sample_idx'] == detection_result['sample_idx'][i]

    ego2global = data['ego2global']

    # Keep only detections whose label is one of the tracked classes.
    all_bboxes = detection_result['pre_bboxes'][i]
    all_labels = detection_result['pre_labels'][i]
    keep = np.isin(all_labels, np.array(track_labels))
    kept_bboxes = all_bboxes[keep].tolist()
    kept_labels = all_labels[keep].tolist()

    frame_data = FrameData(
        dets=kept_bboxes,
        ego=ego2global,
        pc=None,
        det_types=kept_labels,
        time_stamp=data['time_stamp'],
    )
    frame_data.dets = [BBox.bbox2world(ego2global, det) for det in frame_data.dets]

    # Ground truth, collected field by field over the annotated instances.
    instances = data['instances']
    gt_bboxes = [
        BBox.bbox2world(ego2global, BBox.array2bbox(instance['bbox_3d']))
        for instance in instances
    ]
    gt_labels = [instance['bbox_label_3d'] for instance in instances]
    gt_ids = [str(instance['instance_id'].split(':')[-1]) for instance in instances]

    return frame_data, gt_bboxes, gt_labels, gt_ids

In [None]:
# Dataset info and detector outputs; both are indexed frame-by-frame in the
# tracking loop.
# NOTE(review): .pkl files are unpickled on load, which can execute arbitrary
# code -- only load files that come from a trusted source.
data_info = mmengine.load('./data/CODA/coda_infos_val.pkl')
detection_results = mmengine.load('./data/CODA/detection_results.pkl')

# Tracker configuration.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; this is
# acceptable for a repository-local config, but switch to yaml.safe_load if
# the config could ever come from an untrusted source.
config_path = 'configs/nus_configs/diou.yaml'
with open(config_path, 'r') as config_file:  # close the handle deterministically
    configs = yaml.load(config_file, Loader=yaml.Loader)

# Passed to load_frame as its point_cloud_range argument.
point_cloud_range = [-21.0, -21.0, -2.0, 21.0, 21.0, 6.0]

In [None]:
# Run the tracker over every frame, grouping frames into sequences.
# A sequence is a run of frames from the same scene with consecutive frame
# numbers; each sequence gets a fresh MOTModel and its own result dict.
scene = -1   # scene id of the sequence currently being processed
frame = -1   # frame number of the previously processed frame
all_results = []
sequence_results = None
track_labels = [1, 2]  # detection label ids kept for tracking (see load_frame)
num_frames = len(data_info['data_list'])

print('Processing %d frames' % num_frames)
for i, data in enumerate(data_info['data_list']):
    # Start a new sequence when the scene changes or the frame numbers stop
    # being consecutive.
    if data['scene'] != scene or data['frame'] - 1 != frame:
        if sequence_results is not None:
            all_results.append(sequence_results)
            print('Processing scene %d end at frame %d (%d / %d)' % (scene, frame, i, num_frames))
        scene = data['scene']
        print('Processing scene %d start from frame %d' % (scene, data['frame']))
        tracker = MOTModel(configs)
        sequence_results = {'token': data['token'], 'data_list': []}
    # Remember the last frame seen on EVERY iteration so the consecutive-frame
    # check above compares against the previous frame, not the sequence start.
    frame = data['frame']

    frame_data, gt_bboxes, gt_labels, gt_ids = load_frame(data, detection_results, i, track_labels, point_cloud_range)

    # Each tracker output is a tuple (bbox, track id, state string, label).
    results = tracker.frame_mot(frame_data)
    result_pred_bboxes = [trk[0] for trk in results]
    result_pred_ids = [trk[1] for trk in results]
    result_pred_states = [trk[2] for trk in results]
    result_labels = [trk[3] for trk in results]

    frame_result = {}
    frame_result['track_ids'] = result_pred_ids
    frame_result['track_bboxes'] = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes]
    frame_result['track_states'] = result_pred_states
    frame_result['track_labels'] = result_labels
    frame_result['gt_bboxes'] = [BBox.bbox2array(bbox) for bbox in gt_bboxes]
    frame_result['gt_labels'] = gt_labels
    frame_result['gt_ids'] = gt_ids
    sequence_results['data_list'].append(frame_result)

# The loop only stores a sequence when the next one begins, so the final
# sequence has to be flushed explicitly.
if sequence_results is not None:
    all_results.append(sequence_results)
    print('Processing scene %d end at frame %d (%d / %d)' % (scene, frame, num_frames, num_frames))

In [None]:
# Persist the per-sequence tracking results collected in all_results.
# NOTE(review): the file name reads 'code_track_result' while the dataset
# directory is 'CODA' -- confirm whether 'coda_track_result' was intended
# before renaming, since downstream readers may depend on this path.
mmengine.dump(all_results, './data/CODA/code_track_result.pkl')

In [None]:
# Stage-by-stage timing of one tracking pass. The per-frame logic is written
# out inline so each stage can be timed separately.
# NOTE(review): presumably this mirrors what MOTModel.frame_mot does
# internally -- confirm against mot_3d.mot.
infer_time1 = 0  # frame loading (load_frame)
infer_time2 = 0  # score filtering, track prediction, first-stage association
infer_time3 = 0  # second-stage association for unmatched tracks
infer_time4 = 0  # updating matched and unmatched tracks
infer_time5 = 0  # track creation/removal and result packaging

tracker = None
prev_scene = None
prev_frame = None
for data_idx, data in enumerate(data_info['data_list']):
    # Same sequence rule as the tracking cell: a new tracker is created when
    # the scene changes or the frame numbers stop being consecutive.
    if tracker is None or data['scene'] != prev_scene or data['frame'] - 1 != prev_frame:
        tracker = MOTModel(configs)
        scene_results = []
    prev_scene, prev_frame = data['scene'], data['frame']

    time0 = time.perf_counter()
    # load_frame also returns ground truth, which is not needed for timing.
    input_data = load_frame(data, detection_results, data_idx, track_labels, point_cloud_range)[0]

    time1 = time.perf_counter()
    tracker.frame_count += 1
    # initialize the time stamp on frame 0
    if tracker.time_stamp is None:
        tracker.time_stamp = input_data.time_stamp

    # filter out low-score detections
    dets = input_data.dets
    det_indexes = [idx for idx, det in enumerate(dets) if det.s >= tracker.score_threshold]
    dets = [dets[idx] for idx in det_indexes]

    # prediction and association
    trk_preds = [trk.predict(input_data.time_stamp) for trk in tracker.trackers]
    matched, unmatched_dets, unmatched_trks = associate_dets_to_tracks(
        dets, trk_preds, tracker.match_type, tracker.asso, tracker.asso_thres)
    # Association ran on the filtered list; map its detection indices back to
    # positions in input_data.dets.
    for k in range(len(matched)):
        matched[k][0] = det_indexes[matched[k][0]]
    for k in range(len(unmatched_dets)):
        unmatched_dets[k] = det_indexes[unmatched_dets[k]]

    time2 = time.perf_counter()
    # association in second stage
    dets = input_data.dets
    det_indexes = [idx for idx, det in enumerate(dets) if det.s >= tracker.score_threshold_second_stage]
    dets = [dets[idx] for idx in det_indexes]
    unmatched_trk_preds = [tracker.trackers[t].get_state() for t in unmatched_trks]
    update_modes = associate_unmatched_trks(dets, unmatched_trk_preds, tracker.asso, tracker.asso_thres_second_stage)

    time3 = time.perf_counter()
    # update the matched tracks
    for k in range(len(matched)):
        d = matched[k][0]
        trk = tracker.trackers[matched[k][1]]
        update_info = UpdateInfoData(mode=1, bbox=input_data.dets[d], ego=input_data.ego,
                    frame_index=tracker.frame_count, pc=input_data.pc, dets=input_data.dets)
        trk.update(update_info)
    # update the unmatched tracks with their own state and the second-stage mode
    for k in range(len(unmatched_trks)):
        trk = tracker.trackers[unmatched_trks[k]]
        update_info = UpdateInfoData(mode=update_modes[k], bbox=unmatched_trk_preds[k], ego=input_data.ego,
                    frame_index=tracker.frame_count, pc=input_data.pc, dets=input_data.dets)
        trk.update(update_info)

    time4 = time.perf_counter()
    # create new tracks for unmatched detections
    for index in unmatched_dets:
        track = tracklet.Tracklet(tracker.configs, tracker.count, input_data.dets[index], input_data.det_types[index],
            tracker.frame_count, time_stamp=input_data.time_stamp)
        tracker.trackers.append(track)
        tracker.count += 1

    # remove dead tracks (walk from the back so popping keeps the remaining
    # positions valid)
    track_num = len(tracker.trackers)
    for index, trk in enumerate(reversed(tracker.trackers)):
        if trk.death(tracker.frame_count):
            tracker.trackers.pop(track_num - 1 - index)

    # output the results
    results = list()
    for trk in tracker.trackers:
        state_string = trk.state_string(tracker.frame_count)
        results.append((trk.get_state(), trk.id, state_string, trk.det_type))

    # wrap up and update the information about the mot trackers
    tracker.time_stamp = input_data.time_stamp
    for trk in tracker.trackers:
        trk.sync_time_stamp(tracker.time_stamp)

    result_pred_bboxes = [trk[0] for trk in results]
    result_pred_ids = [trk[1] for trk in results]
    result_pred_states = [trk[2] for trk in results]

    frame_result = {}
    frame_result['track_ids'] = result_pred_ids
    frame_result['track_bboxes'] = [BBox.bbox2array(bbox) for bbox in result_pred_bboxes]
    frame_result['track_states'] = result_pred_states
    scene_results.append(frame_result)

    time5 = time.perf_counter()
    infer_time1 += time1 - time0
    infer_time2 += time2 - time1
    infer_time3 += time3 - time2
    infer_time4 += time4 - time3
    infer_time5 += time5 - time4

In [None]:
# Accumulated seconds per stage of the profiling loop, in order:
# frame loading; score filtering, prediction and first-stage association;
# second-stage association; track updates; track creation/removal plus
# result packaging.
infer_time1, infer_time2, infer_time3, infer_time4, infer_time5