## KITTI Joint Detection and Embedding Finetuning
### This notebook launches the fine-tuning of an FPN model on KITTI for joint detection and embedding, using the tracking ground truth. It can optionally start from the weights of a previously trained object detector.

In [None]:
import detectron2
from detectron2.utils.logger import setup_logger

setup_logger()

# import some common libraries
import numpy as np
import cv2
import random



# import some common detectron2 utilities

from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.modeling import build_model
from detectron2.evaluation import COCOEvaluator,PascalVOCDetectionEvaluator
import matplotlib.pyplot as plt
import torch.tensor as tensor
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import inference_on_dataset
import torch
from detectron2.structures.instances import Instances
from detectron2.modeling import build_model
%matplotlib inline

## Dataset Building

In [None]:
# Root folder of the KITTI tracking data. get_kitti_joint appends sub-folder
# names to this string by plain concatenation, so the trailing slash is required.
data_path = "../datasets/KITTI/tracking/"

In [None]:
import os
import numpy as np
import json
from detectron2.structures import BoxMode
def _read_kitti_vehicle_labels(label_file):
    """Parse one KITTI tracking label file, keeping only ``Car``/``Van`` objects.

    Every non-blank line is split on single spaces: field 0 is the frame
    number, field 1 the track id, field 2 the object type and fields 6-9 the
    box corners (xmin, ymin, xmax, ymax), which are truncated to integers.

    Args:
        label_file: path of the ``<sequence>.txt`` label file.

    Returns:
        ``(track_ids, boxes)``: two lists indexed by frame number, covering
        frames 0 up to the highest frame number found in the file.
        ``track_ids[f]`` lists the integer track ids of the vehicles in frame
        ``f``; ``boxes[f]`` maps the track id (as the string read from the
        file) to its annotation dict. Both lists are empty when the file has
        no label lines.
    """
    vehicle_rows = []
    last_frame = -1
    with open(label_file) as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            parts = line.split(' ')
            frame_number = int(parts[0])
            # Use the highest frame number rather than trusting the file to be sorted.
            last_frame = max(last_frame, frame_number)
            if parts[2] in ('Car', 'Van'):
                vehicle_rows.append((frame_number, parts))

    track_ids = [[] for _ in range(last_frame + 1)]
    boxes = [{} for _ in range(last_frame + 1)]
    for frame_number, parts in vehicle_rows:
        track_ids[frame_number].append(int(parts[1]))
        boxes[frame_number][parts[1]] = {
            'bbox': [int(float(parts[k])) for k in (6, 7, 8, 9)],
            "img_number": parts[0],
            "obj_id": parts[1],
            # Car and Van are merged into the single category 0.
            "category_id": 0,
            "iscrowd": 0,
            "bbox_mode": BoxMode.XYXY_ABS,
        }
    return track_ids, boxes


def get_kitti_joint(path, window_size=16, frames_per_group=8,
                    min_track_frames=4, min_tracks=8, num_passes=5):
    """Build the dataset records for joint detection/embedding training on KITTI.

    For every sequence folder under
    ``path + 'data_tracking_image_2/training/image_02/'`` the matching label
    file ``path + 'data_tracking_label_2/training/label_02/<seq>.txt'`` is
    parsed (Car/Van only). A window of ``window_size`` consecutive frame
    indices slides over the sequence, which is repeated ``num_passes`` times
    so that windows wrap around the end. In each window ``frames_per_group``
    frames are drawn with ``random.sample``. The group is kept only if at
    least ``min_tracks`` track ids each appear in at least
    ``min_track_frames`` of the drawn frames. A kept group contributes one
    record per drawn frame, in drawn order.

    Sampling uses the global ``random`` module: seed it before calling if a
    reproducible dataset is needed.

    Args:
        path: dataset root; sub-folder names are appended by plain string
            concatenation, so it must end with a path separator.
        window_size: number of consecutive frames a group is drawn from.
        frames_per_group: number of frames drawn per window.
        min_track_frames: minimum number of drawn frames a track must appear
            in to count as usable.
        min_tracks: minimum number of usable tracks for a group to be kept.
        num_passes: how many times the frame sequence is repeated.

    Returns:
        A list of dicts with the keys ``file_name``, ``width``, ``height``,
        ``image_id`` (the frame number), ``ids`` (track ids in the frame),
        ``classes``, ``pairs_used`` (usable track ids of the group, the same
        list object for all frames of a group), ``annotations`` and ``labels``
        (always an empty list).

    Raises:
        ValueError: if ``frames_per_group`` is not in ``1..window_size``.
        FileNotFoundError: if a label file is missing or an image referenced
            by a kept group cannot be read.
    """
    import random

    if not 0 < frames_per_group <= window_size:
        raise ValueError('frames_per_group must be between 1 and window_size')

    img_folder = path + 'data_tracking_image_2/training/image_02/'
    label_folder = path + 'data_tracking_label_2/training/label_02/'

    dict_arr = []
    # Each frame appears in many groups: decode it once, remember (width, height).
    image_sizes = {}

    # Sorted so that the record order does not depend on the file system.
    for seq_name in sorted(os.listdir(img_folder)):
        base_url = '%s%s/' % (img_folder, seq_name)
        if not os.path.isdir(base_url):
            # Stray files next to the sequence folders have no label file.
            continue

        track_ids, boxes = _read_kitti_vehicle_labels('%s%s.txt' % (label_folder, seq_name))
        frame_sets = [set(ids) for ids in track_ids]
        frame_order = list(range(len(frame_sets))) * num_passes

        # range() is empty when the repeated sequence is not longer than one window.
        for start in range(len(frame_order) - window_size):
            ids_to_keep = random.sample(frame_order[start:start + window_size],
                                        k=frames_per_group)

            # Count in how many of the drawn frames each track id is present.
            freqs = {}
            for frame_id in ids_to_keep:
                for t_id in frame_sets[frame_id]:
                    freqs[t_id] = freqs.get(t_id, 0) + 1

            pairs_used = [t_id for t_id, count in freqs.items()
                          if count >= min_track_frames]
            if len(pairs_used) < min_tracks:
                continue

            for frame_id in ids_to_keep:
                file_name = base_url + '%s.png' % str(frame_id).zfill(6)
                if file_name not in image_sizes:
                    image = cv2.imread(file_name)
                    if image is None:
                        # cv2.imread signals failure by returning None, not by raising.
                        raise FileNotFoundError('Could not read image: %s' % file_name)
                    image_sizes[file_name] = (image.shape[1], image.shape[0])
                width, height = image_sizes[file_name]

                frame_boxes = boxes[frame_id]
                dict_arr.append({
                    'file_name': file_name,
                    'width': width,
                    'height': height,
                    'image_id': int(frame_id),
                    'ids': [int(k) for k in frame_boxes],
                    'classes': [int(box['category_id']) for box in frame_boxes.values()],
                    'pairs_used': pairs_used,
                    'annotations': list(frame_boxes.values()),
                    'labels': [],
                })

    return dict_arr


In [None]:
from detectron2.data import DatasetCatalog, MetadataCatalog
# Register the KITTI joint detection/embedding split(s) with detectron2.
# The loader is wrapped in a lambda so the (slow) dataset construction only
# happens when the dataset is actually requested.
for d in ["train"]:
    dataset_name = "kitti_" + d
    # DatasetCatalog.register refuses a name that is already registered, which
    # would make this cell fail when it is re-run in the same kernel. The
    # lambda looks up get_kitti_joint and data_path at call time, so an
    # existing registration still picks up later redefinitions of either.
    if dataset_name not in DatasetCatalog.list():
        DatasetCatalog.register(dataset_name, lambda d=d: get_kitti_joint(data_path))
    


## Training Parameters

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os
# Build the training configuration: start from get_cfg() and merge the
# project's FPN "Video" config file on top of it.
cfg = get_cfg()
#cfg.merge_from_file("./detectron2_repo/configs/COCO-Detection/faster_rcnn_R_50_C4_3x.yaml")
cfg.merge_from_file("../configs/COCO-Detection/faster_rcnn_R_50_FPN_3x_Video.yaml")
# Train on the dataset registered above as "kitti_train"; no test set is configured.
cfg.DATASETS.TRAIN = ("kitti_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2

# Initial weights: the previously trained KITTI object detector.
cfg.MODEL.WEIGHTS="../models/KITTI/KITTI_DET/model_final.pth"

#cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_C4_3x/137849393/model_final_f97cb7.pkl"  # Let training initialize from model zoo
# NOTE(review): 8 images per batch presumably matches the groups of 8 frames
# emitted by get_kitti_joint — TODO confirm against the data loader.
cfg.SOLVER.IMS_PER_BATCH = 8
cfg.SOLVER.BASE_LR = 0.0009  # base learning rate
cfg.SOLVER.MAX_ITER = 25000  # total number of training iterations
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   # RoIs sampled per image for the ROI heads (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES =1
# Single foreground class: get_kitti_joint maps both Car and Van to category 0.
# Checkpoints and logs are written to OUTPUT_DIR.
cfg.OUTPUT_DIR='../models/KITTI/KITTI_JOINT'
# Print the settings most relevant to this run for a quick sanity check.
print(cfg.OUTPUT_DIR)
print(cfg.MODEL.ANCHOR_GENERATOR)
print(cfg.INPUT)

In [None]:
# Show the anchor generator settings once more before the trainer is built
# (same value as printed at the end of the configuration cell).
print(cfg.MODEL.ANCHOR_GENERATOR)

In [None]:
# Optional anchor overrides, currently disabled (the config-file values are used):
#cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS=[[0.25,0.5,1,2]]
#cfg.MODEL.ANCHOR_GENERATOR.SIZES=[[32,64,128,256,512]]
# Make sure the output directory exists before the trainer is created.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
# NOTE(review): upstream detectron2's DefaultTrainer takes only `cfg`; the second
# positional argument is presumably a flag added by this project's modified
# detectron2 (enabling the joint/video mode?) — TODO confirm against the fork.
trainer = DefaultTrainer(cfg,True) 


In [None]:
# resume=False: do not continue from a checkpoint in cfg.OUTPUT_DIR; per the
# detectron2 documentation this loads cfg.MODEL.WEIGHTS and starts at iteration 0.
trainer.resume_or_load(resume=False)
# Run the training loop (cfg.SOLVER.MAX_ITER iterations).
trainer.train()