## KITTI Object Detection finetuning
### This notebook launches the fine-tuning of a Faster R-CNN (FPN) detector on the KITTI object detection benchmark. The code fetches COCO-pretrained weights for weight initialization.

In [None]:
# Root directory handed to the dataset loaders, which join
# 'image_2/<id>.png' and 'label_2/<id>.txt' onto it.
data_path = "../datasets/KITTI/data_object_image_2/training"

In [None]:
import detectron2
from detectron2.utils.logger import setup_logger

setup_logger()

import numpy as np
import cv2
import random

from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.modeling import build_model
from detectron2.evaluation import COCOEvaluator,PascalVOCDetectionEvaluator
import matplotlib.pyplot as plt
import torch.tensor as tensor
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import inference_on_dataset
import torch
from detectron2.structures.instances import Instances
from detectron2.modeling import build_model
%matplotlib inline

## Dataset Parsing

In [None]:
import os
import numpy as np
import json
from detectron2.structures import BoxMode

# Category ids assigned to the KITTI label types that are kept; annotation
# lines of any other type are skipped.
# NOTE(review): these ids coincide with person / bicycle / car in detectron2's
# contiguous COCO class order, presumably so the COCO-pretrained box head can
# be reused as-is — confirm.
KITTI_CLASS_IDS = {'Pedestrian': 0, 'Cyclist': 1, 'Car': 2}


def _read_kitti_annotations(ann_path):
    """Parse one KITTI label file into a list of detectron2 annotation dicts."""
    annotations = []
    with open(ann_path) as ann_file:
        for ann_line in ann_file:
            fields = ann_line.split()
            # Skip blank lines and label types that are not being trained on.
            if not fields or fields[0] not in KITTI_CLASS_IDS:
                continue
            annotations.append({
                # Fields 4..7 are read as the box corners and rounded to whole
                # pixels, stored in XYXY_ABS order.
                'bbox': [np.round(float(fields[i])) for i in range(4, 8)],
                "category_id": KITTI_CLASS_IDS[fields[0]],
                "iscrowd": 0,
                "bbox_mode": BoxMode.XYXY_ABS,
            })
    return annotations


def _load_kitti_split(img_dir, split_file, pad_width=None):
    """Build detectron2 dataset dicts for every sample id listed in split_file.

    Args:
        img_dir: directory containing the 'image_2' and 'label_2' sub-folders.
        split_file: text file with one sample id per line.
        pad_width: when set, ids are left-padded with zeros to this width
            before being turned into file names.

    Returns:
        A list of dicts with the keys 'file_name', 'image_id', 'height',
        'width' and 'annotations'.

    Raises:
        FileNotFoundError: if an image listed in the split cannot be read.
    """
    dataset_dicts = []
    with open(split_file) as f:
        for line in f:
            sample_id = line.strip()
            if not sample_id:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            stem = sample_id.zfill(pad_width) if pad_width else sample_id

            image_path = os.path.join(img_dir, 'image_2/%s.png' % stem)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError('Could not read image: %s' % image_path)
            height, width = image.shape[:2]

            ann_path = os.path.join(img_dir, 'label_2/%s.txt' % stem)
            dataset_dicts.append({
                "file_name": image_path,
                "image_id": int(sample_id),
                "height": height,
                "width": width,
                "annotations": _read_kitti_annotations(ann_path),
            })
    return dataset_dicts


def get_kitti_dicts(img_dir, split_file='../datasets/KITTI/kitti_train.txt'):
    """Load the KITTI training split as detectron2 dataset dicts.

    Args:
        img_dir: directory containing the 'image_2' and 'label_2' sub-folders.
        split_file: text file listing one sample id per line; ids are used
            verbatim as file stems.

    Returns:
        A list of detectron2-style records, one per listed image.
    """
    return _load_kitti_split(img_dir, split_file)


def get_kitti_val(img_dir, split_file='kitti_val.txt'):
    """Load the KITTI validation split as detectron2 dataset dicts.

    Args:
        img_dir: directory containing the 'image_2' and 'label_2' sub-folders.
        split_file: text file listing one sample id per line; ids are
            zero-padded to six digits to form the file stems.
            NOTE(review): the default is resolved against the current working
            directory, unlike the training split's default path — confirm
            that this is where the file lives.

    Returns:
        A list of detectron2-style records, one per listed image.
    """
    return _load_kitti_split(img_dir, split_file, pad_width=6)


In [None]:
from detectron2.data import DatasetCatalog, MetadataCatalog
# Register each split with its own loader so that "kitti_val" really serves the
# validation list instead of a second copy of the training list.
kitti_loaders = {"train": get_kitti_dicts, "val": get_kitti_val}
for d, loader in kitti_loaders.items():
    # Bind `loader` as a default argument: a plain closure would see only the
    # last loop value by the time detectron2 calls it.
    DatasetCatalog.register("kitti_" + d, lambda loader=loader: loader(data_path))
    


## Training Parameters

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os
# Start from detectron2's default config and overlay the Faster R-CNN R50-FPN recipe.
cfg = get_cfg()
cfg.merge_from_file("../configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")

# Train on the registered KITTI training split; no dataset is evaluated during training.
cfg.DATASETS.TRAIN = ("kitti_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2

# Load COCO weights for initialization.
cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"

# Solver settings.
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.0025  # pick a good LR
cfg.SOLVER.MAX_ITER = 20000
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512   # (default: 512)

# NOTE(review): MODEL.ROI_HEADS.NUM_CLASSES is not overridden here, so it keeps
# whatever the merged config provides — confirm this is intended given that the
# dataset loaders only emit category ids 0-2.

# Directory that receives checkpoints and logs.
cfg.OUTPUT_DIR = '../models/KITTI/KITTI_DET'


### Initialize the trainer and load the dataset

In [None]:

# Create the output directory up front; exist_ok=True keeps re-running this cell harmless.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
# NOTE(review): upstream detectron2's DefaultTrainer constructor accepts only `cfg`;
# the extra positional False presumably targets a modified trainer in this repo — confirm.
trainer = DefaultTrainer(cfg,False) 


### Begin Training

In [None]:
# resume=False: do not pick up a previous checkpoint from cfg.OUTPUT_DIR.
# Presumably this makes the run start from cfg.MODEL.WEIGHTS — verify against
# the trainer's resume_or_load implementation.
trainer.resume_or_load(resume=False)
# Launch the training run.
trainer.train()