<a href="https://colab.research.google.com/github/BerensRWU/Complex_YOLO/blob/main/visualize.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import os
import argparse
import cv2
import torch

import utils.utils as utils
from models import *
import torch.utils.data as torch_data

import utils.astyx_bev_utils as bev_utils
from utils.astyx_yolo_dataset import AstyxYOLODataset
import utils.config as cnf

In [None]:
# Mount Google Drive when running on Colab so data/checkpoints stored there are
# reachable. Outside Colab the import fails, so skip the mount instead of
# aborting "Run All" on a local Jupyter install.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available - skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

In [3]:
# --- Run configuration --------------------------------------------------------
# Network definition file and the directory that holds the trained checkpoints.
model_def = "network/yolov3-custom.cfg"
weights_path = "checkpoints"

# Sensor to visualize: True selects the radar data/checkpoint, False the lidar one.
radar = True
# Dataset split passed to AstyxYOLODataset.
split = "valid"

# When False, no model is loaded and only ground-truth boxes are drawn.
estimate_bb = False
# Thresholds handed to utils.non_max_suppression_rotated_bbox.
conf_thres = 0.5
nms_thres = 0.2

In [11]:
# Directory that receives one rendered PNG per sample. exist_ok makes the cell
# idempotent and avoids the exists()/makedirs() check-then-create race.
os.makedirs("output", exist_ok=True)

In [5]:
# Load the trained network only when detections are requested.
if estimate_bb:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the checkpoint matching the sensor. Bind it to a new name instead of
    # overwriting the config value `weights_path`: otherwise re-running this
    # cell appends the file name twice (checkpoints/weights_RADAR.pth/weights_RADAR.pth).
    checkpoint_file = os.path.join(weights_path, "weights_RADAR.pth" if radar else "weights_LIDAR.pth")
    # Set up the model for BEV inputs of size cnf.BEV_WIDTH.
    model = Darknet(model_def, img_size=cnf.BEV_WIDTH).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
    # Inference mode (disables dropout / batch-norm statistic updates).
    model.eval()

In [None]:
# Evaluation-mode dataset for the configured split and sensor. Samples are served
# one at a time and in order, so every written image maps to exactly one sample.
dataset = AstyxYOLODataset(
    cnf.root_dir,
    split=split,
    mode="EVAL",
    radar=radar,
)
data_loader = torch_data.DataLoader(dataset, shuffle=False, batch_size=1)

In [7]:
# Float tensor type used to cast BEV inputs before they are fed to the model;
# it targets the GPU whenever CUDA is available.
if torch.cuda.is_available():
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
# Render every sample of the split as a BEV image with its ground-truth boxes
# and, when `estimate_bb` is set, the model's detections, then write it to
# output/<sample_id>.png. Colors are BGR because cv2 writes BGR:
# [0, 255, 0] is green (ground truth), [0, 0, 255] is red (detections).


def _bev_to_bgr_image(bev_maps):
    """Convert one batch of BEV maps into an 8-bit image ready for cv2.imwrite.

    Args:
        bev_maps: batch tensor yielded by ``data_loader``; after squeezing it is
            indexed as (channel, row, col) with channel 0 = height,
            1 = density, 2 = intensity/velocity.

    Returns:
        ``(BEV_WIDTH, BEV_WIDTH, 3)`` uint8 array with channels in BGR order, so
        height shows as red, density as green and intensity/velocity as blue
        in the saved PNG.
    """
    bev = torch.squeeze(bev_maps).numpy()
    image = np.zeros((cnf.BEV_WIDTH, cnf.BEV_WIDTH, 3))
    image[:, :, 2] = bev[0, :, :]  # height -> red
    image[:, :, 1] = bev[1, :, :]  # density -> green
    image[:, :, 0] = bev[2, :, :]  # intensity/velocity -> blue
    # Saturate instead of wrapping: casting floats outside [0, 255] to uint8
    # does not clip by itself.
    return np.clip(image * 255, 0, 255).astype(np.uint8)


def _detect(bev_maps):
    """Run the model and rotated-box NMS on one batch.

    Returns a list with one entry per image in the batch; an entry may be None
    (such entries are skipped when drawing).
    """
    with torch.no_grad():
        raw_detections = model(bev_maps.type(Tensor))
        detections = utils.non_max_suppression_rotated_bbox(raw_detections, conf_thres, nms_thres)
    return list(detections)


for sample_id, bev_maps, targets in data_loader:
    if estimate_bb:
        img_detections = _detect(bev_maps)

    RGB_Map = _bev_to_bgr_image(bev_maps)

    # batch_size is 1, so targets[0] holds this sample's boxes. Columns 2: are
    # scaled to BEV pixels (presumably stored normalized - TODO confirm);
    # im and re are scaled together, so arctan2(im, re) is unaffected.
    gt_boxes = targets[0]
    gt_boxes[:, 2:] *= cnf.BEV_WIDTH
    for _, _cls, x, y, w, l, im, re in gt_boxes:
        bev_utils.drawRotatedBox(RGB_Map, x, y, w, l, np.arctan2(im, re), [0, 255, 0])

    if estimate_bb:
        for detections in img_detections:
            if detections is None:
                continue
            # Rescale boxes from the network input size to the image size.
            detections = utils.rescale_boxes(detections, cnf.BEV_WIDTH, RGB_Map.shape[:2])
            for x, y, w, l, im, re, _conf, _cls_conf, _cls_pred in detections:
                bev_utils.drawRotatedBox(RGB_Map, x, y, w, l, np.arctan2(im, re), [0, 0, 255])

    # RGB_Map is already in BGR channel order, which is what cv2.imwrite expects.
    out_path = "output/%06d.png" % sample_id
    if not cv2.imwrite(out_path, RGB_Map):
        # cv2.imwrite signals failure through its return value, not an exception.
        raise IOError("cv2.imwrite failed to write %s" % out_path)
