<a href="https://colab.research.google.com/github/BerensRWU/Complex_YOLO/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import numpy as np
import os
import argparse
import cv2
import torch
import torch.utils.data as torch_data

from models import Darknet
from detector import detector, setup_detector
from visualize import visualize_func
from evaluation import get_batch_statistics_rotated_bbox

from utils.astyx_yolo_dataset import AstyxYOLODataset
import utils.config as cnf

In [3]:
# Colab-only step: mount Google Drive into the runtime's file system.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
# This class replaces the argparse argument namespace.
class arguments:
  """Plain container that exposes every constructor argument as an attribute.

  Each parameter is stored unchanged under an attribute of the same name:
  model_def, weights_path, conf_thres, nms_thres, iou_thres, split, radar,
  estimate_bb, evaluate and visualize.
  """

  def __init__(self, model_def, weights_path, conf_thres, nms_thres, iou_thres,
               split, radar, estimate_bb, evaluate, visualize):
    # Keyword order below fixes the attribute order of the instance.
    self.__dict__.update(
        model_def=model_def,
        weights_path=weights_path,
        conf_thres=conf_thres,
        nms_thres=nms_thres,
        iou_thres=iou_thres,
        split=split,
        radar=radar,
        estimate_bb=estimate_bb,
        evaluate=evaluate,
        visualize=visualize,
    )

In [15]:
# Options for this run, collected in one place and handed to `arguments`.
run_options = dict(
    model_def="network/yolov3-custom.cfg",
    weights_path="checkpoints",
    conf_thres=0.5,
    nms_thres=0.2,
    iou_thres=0.5,
    split="valid",
    radar=False,
    estimate_bb=True,
    evaluate=False,
    visualize=True,
)
opt = arguments(**run_options)

# Dataset root; this is a relative path, so it is resolved against the
# current working directory.
cnf.root_dir = "drive/My Drive/dataset"

In [4]:
# Make sure the "output" directory exists. exist_ok=True makes this cell safe
# to re-run and avoids the check-then-create race of a separate exists() test.
os.makedirs("output", exist_ok=True)

In [14]:
# The model is only needed when bounding boxes are to be estimated.
if opt.estimate_bb:
  model = setup_detector(opt)
  if opt.evaluate:
    # Accumulators for the evaluation:
    #   ngt            - number of all targets
    #   true_positives - filled while iterating over the frames
    #   pred_scores    - filled while iterating over the frames
    ngt, true_positives, pred_scores = 0, [], []

In [9]:
# Build the Astyx dataset for the selected split and wrap it in a DataLoader
# that yields one sample per batch, in dataset order.
dataset = AstyxYOLODataset(
    cnf.root_dir,
    split=opt.split,
    mode="EVAL",
    radar=opt.radar,
)
data_loader = torch_data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)


Load EVAL samples from drive/My Drive/dataset/dataset_astyx_hires2019
Done: total EVAL samples 107


In [10]:
# Loop over all frames from the split file: optionally detect objects,
# optionally collect evaluation statistics, optionally visualize.
for index, (sample_id, bev_maps, targets) in enumerate(data_loader):
    # Detections for the current frame; reset on every iteration.
    img_detections = []

    # Target position and dimension values are between 0 - 1, so they have to
    # be transformed to pixel coordinates. The scaling is done in place.
    targets[:, 2:] *= cnf.BEV_WIDTH

    if opt.estimate_bb:
        # Detect objects in the current frame.
        predictions = detector(model, bev_maps, opt)
        img_detections.extend(predictions)

        if opt.evaluate:
            # Count the ground-truth targets and determine which predictions
            # are true detections at the configured IoU threshold.
            ngt += len(targets)
            true_positive, pred_score = get_batch_statistics_rotated_bbox(predictions, targets, opt.iou_thres)
            # Concatenate the per-frame results into the two long
            # true_positives and pred_scores lists used after the loop.
            # Previously the per-frame results were computed and discarded,
            # leaving both lists empty.
            # NOTE(review): assumes both return values are flat per-detection
            # sequences (list / 1-D array) - confirm against evaluation.py.
            true_positives.extend(true_positive)
            pred_scores.extend(pred_score)

    # Visualization of the ground truth and, if estimated, the predicted boxes.
    if opt.visualize:
        visualize_func(bev_maps[0], targets, img_detections, sample_id, opt.estimate_bb)

In [16]:
# The average precision is only computed when boxes were both estimated and
# evaluated, since the accumulators exist only in that case.
# NOTE(review): `calculate_ap` is not imported in the visible import cell;
# confirm where it is defined before enabling `opt.evaluate`.
evaluation_enabled = opt.estimate_bb and opt.evaluate
if evaluation_enabled:
    AP = calculate_ap(true_positives, pred_scores, ngt)