## MOT Joint Detection and Embedding Finetuning
### This notebook is used to launch the fine-tuning of FPN on MOT joint detection and embedding using the tracking ground truth. The code can optionally start from the weights of the previously trained object detector.

In [None]:
# Pick which MOT challenge to fine-tune on; two benchmarks are supported.
mot_datasets = ["MOT17", "MOT20"]
dataset = mot_datasets[0]  # index 0 -> "MOT17", index 1 -> "MOT20"

In [None]:
# Root folder holding the training sequences of the selected benchmark.
# The trailing slash is kept so that plain string concatenation onto this
# prefix stays valid.
data_path = '../datasets/MOT/%s/train/' % dataset

In [None]:
import detectron2
from detectron2.utils.logger import setup_logger

# Initialise detectron2's logger (setup_logger is imported from
# detectron2.utils.logger above) before any detectron2 component is used.
setup_logger()

# import some common libraries
import numpy as np
import cv2
import random



# import some common detectron2 utilities

from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.modeling import build_model
from detectron2.evaluation import COCOEvaluator,PascalVOCDetectionEvaluator
import matplotlib.pyplot as plt
import torch.tensor as tensor
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import inference_on_dataset
import torch
from detectron2.structures.instances import Instances
from detectron2.modeling import build_model
%matplotlib inline

In [None]:
import os
import numpy as np
import json
from detectron2.structures import BoxMode
def _read_mot_gt(gt_file, n_frames, min_visibility):
    """Parse one ``gt.txt`` file into per-frame track ids and box annotations.

    A row is kept only when its 7th and 8th comma-separated fields both equal
    1 and its 9th field is at least ``min_visibility`` (presumably the
    MOTChallenge "consider" flag, the pedestrian class and the visibility
    ratio -- confirm against the dataset documentation).

    Args:
        gt_file: path of the ground-truth text file.
        n_frames: number of frames in the sequence; frame numbers found in
            the file index directly into the returned lists.
        min_visibility: lower bound applied to the 9th field of each row.

    Returns:
        ``(track_ids, boxes)``; both are lists of length ``n_frames + 1``
        indexed by frame number (slot 0 stays empty because frame numbers in
        the file are used as-is). ``track_ids[f]`` is the set of integer
        track ids kept for frame ``f``; ``boxes[f]`` maps the track id (as
        the string read from the file) to its annotation dict.
    """
    track_ids = [set() for _ in range(n_frames + 1)]
    boxes = [{} for _ in range(n_frames + 1)]

    with open(gt_file) as gt:
        for line in gt:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            parts = line.split(',')
            keep = (int(parts[6]) == 1 and int(parts[7]) == 1
                    and float(parts[8]) >= min_visibility)
            if not keep:
                continue

            frame = int(parts[0])
            track_ids[frame].add(int(parts[1]))

            # The file stores left, top, width, height; convert to absolute
            # corner coordinates rounded to whole pixels.
            xmin = int(round(float(parts[2])))
            ymin = int(round(float(parts[3])))
            xmax = xmin + int(round(float(parts[4])))
            ymax = ymin + int(round(float(parts[5])))

            boxes[frame][parts[1]] = {
                'bbox': [xmin, ymin, xmax, ymax],
                "img_number": parts[0],
                "obj_id": parts[1],
                "category_id": 0,
                "iscrowd": 0,
                "bbox_mode": BoxMode.XYXY_ABS,
            }
    return track_ids, boxes


def get_mot_joint(path, P=8, K=4, window=16, min_visibility=0.25):
    """Build dataset records organised in groups of ``P`` frames.

    A window of ``window`` consecutive frames is slid over a sequence one
    frame at a time. From every window ``P`` frames are drawn at random
    (``random.sample``, so the output changes between calls unless the caller
    seeds ``random``). A draw is kept only when at least ``P`` track ids each
    occur in at least ``K`` of the drawn frames; a kept draw contributes ``P``
    consecutive records to the result.

    NOTE(review): only the first usable sequence folder returned by
    ``os.listdir(path)`` is processed (see the ``break`` at the end of the
    loop), and ``os.listdir`` order is not guaranteed. This mirrors the
    original behaviour; ``group_id`` restarts at 0 for each sequence, so
    simply removing the ``break`` would produce colliding group ids.

    Args:
        path: folder containing one sub-folder per sequence, each with an
            ``img1`` image folder and a ``gt/gt.txt`` file. A trailing slash
            is optional.
        P: number of frames per group, and minimum number of qualifying
            track ids required to keep a group.
        K: minimum number of drawn frames a track id must appear in to count
            as qualifying.
        window: number of consecutive frames each group is drawn from.
        min_visibility: minimum value of the 9th ground-truth field for a box
            to be used.

    Returns:
        A list of dicts with the keys ``file_name``, ``width``, ``height``,
        ``image_id`` (a single-element list ``[window_start + position]``,
        not unique across groups), ``group_id`` (window start index),
        ``ids``, ``classes``, ``pairs_used`` (up to ``P`` qualifying track
        ids), ``annotations`` and ``labels`` (always empty).

    Raises:
        FileNotFoundError: a sampled frame image could not be read.
        ValueError: raised by ``random.sample`` when ``P`` exceeds ``window``.
    """
    import random

    dict_arr = []

    for seq_name in os.listdir(path):
        img_dir = os.path.join(path, seq_name, 'img1')
        if not os.path.isdir(img_dir):
            # Skip stray files or folders that are not sequence directories.
            continue

        n_frames = len(os.listdir(img_dir))
        track_ids, boxes = _read_mot_gt(
            os.path.join(path, seq_name, 'gt', 'gt.txt'), n_frames, min_visibility)

        # Frame numbers are 1-based, matching the zero-padded image names.
        frame_order = list(range(1, n_frames + 1))
        # A frame can be drawn from several overlapping windows; remember its
        # size so that each image is decoded only once.
        image_sizes = {}

        for i in range(n_frames + 1 - window):
            ids_to_keep = random.sample(frame_order[i:i + window], k=P)

            # Count in how many of the drawn frames each track id appears.
            freqs = {}
            for frame_id in ids_to_keep:
                for t_id in track_ids[frame_id]:
                    freqs[t_id] = freqs.get(t_id, 0) + 1

            eligible = [t_id for t_id, count in freqs.items() if count >= K]
            if len(eligible) < P:
                continue
            pairs_used = eligible[:P]

            for position, frame_id in enumerate(ids_to_keep):
                file_name = os.path.join(img_dir, '%s.jpg' % str(frame_id).zfill(6))

                if frame_id not in image_sizes:
                    image = cv2.imread(file_name)
                    if image is None:
                        raise FileNotFoundError(
                            "could not read image %r" % file_name)
                    image_sizes[frame_id] = (image.shape[1], image.shape[0])
                width, height = image_sizes[frame_id]

                frame_boxes = boxes[frame_id]
                dict_arr.append({
                    'file_name': file_name,
                    'width': width,
                    'height': height,
                    "image_id": [i + position],
                    "group_id": i,
                    "ids": [int(k) for k in frame_boxes.keys()],
                    "classes": [int(box['category_id']) for box in frame_boxes.values()],
                    "pairs_used": pairs_used,
                    "annotations": list(frame_boxes.values()),
                    "labels": [],
                })

        # Only the first usable sequence is processed (original behaviour).
        break

    return dict_arr
   

In [None]:
from detectron2.data import DatasetCatalog, MetadataCatalog
# Register the training split under "<dataset>_train" (e.g. "MOT17_train").
# The loader is wrapped in a lambda so the records are only built when the
# dataset is actually requested.
for d in ["train"]:
    split_name = "%s_" % dataset + d
    # Registering an already-registered name raises, so skip it when the cell
    # is re-run; the lambda resolves get_mot_joint/data_path at call time anyway.
    if split_name not in DatasetCatalog.list():
        DatasetCatalog.register(split_name,
                                lambda d=d: get_mot_joint(data_path))

# Fetch the metadata under the same name that was registered above. The
# previous hard-coded "mot20_train" never matched "MOT17_train"/"MOT20_train".
# The variable name is kept as-is in case later cells refer to it.
balloon_metadata = MetadataCatalog.get("%s_train" % dataset)

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os
# ---- Base configuration ----------------------------------------------------
cfg = get_cfg()
cfg.merge_from_file("../configs/COCO-Detection/faster_rcnn_R_50_FPN_3x_Video.yaml")

# ---- Data ------------------------------------------------------------------
# Train on the split registered for the selected benchmark; no test split.
cfg.DATASETS.TRAIN = ("%s_train" % dataset,)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 1

# ---- Model -----------------------------------------------------------------
# Start from the detector checkpoint stored for the same benchmark.
cfg.MODEL.WEIGHTS = "../models/MOT/%s_DET/model_final.pth" % dataset
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512

# ---- Solver ----------------------------------------------------------------
# 8 equals the default group size P of get_mot_joint; presumably one batch is
# meant to hold one frame group -- confirm against the data loader.
cfg.SOLVER.IMS_PER_BATCH = 8
cfg.SOLVER.BASE_LR = 0.00005
cfg.SOLVER.MAX_ITER = 34000

# ---- Output ----------------------------------------------------------------
cfg.OUTPUT_DIR = '../models/MOT/%s_JOINT' % dataset
print(cfg.OUTPUT_DIR)
print(cfg.MODEL.ANCHOR_GENERATOR)

### Trainer initialization and dataset loading

In [None]:

# Create the output directory up front; exist_ok=True makes re-running this
# cell harmless.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
# NOTE(review): upstream detectron2's DefaultTrainer takes only `cfg`. The
# extra positional `True` presumably maps to an option added by this
# project's modified trainer -- confirm its meaning in the local sources.
trainer = DefaultTrainer(cfg,True) 


### Training

In [None]:
# resume=False: presumably starts from cfg.MODEL.WEIGHTS instead of picking up
# the last checkpoint in cfg.OUTPUT_DIR (standard detectron2 behaviour --
# confirm for this project's trainer).
trainer.resume_or_load(resume=False)
# Run the training loop configured above.
trainer.train()