In [None]:
import os
import argparse

import cv2

In [None]:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader as DataLoader
from loader.kitti import collate_fn

from config import cfg
from utils.utils import *
from model.model import RPN3D
from dataloader.kitti import KITTI_Loader as Dataset

In [None]:
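# TODO: define `args` here (e.g. an argparse.Namespace) with the attributes main() reads: batch_size, workers, pre_trained, output_path, vis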

In [None]:
def _write_labels(label_dir, tags, ret_box3d_scores):
    """Write one ``<tag>.txt`` label file per sample into ``label_dir``.

    Args:
        label_dir: existing directory that receives the label files.
        tags: per-sample identifiers used as file stems.
        ret_box3d_scores: per-sample box arrays; per the predict() output the
            columns are (class, x, y, z, h, w, l, rz, score).
    """
    for tag, score in zip(tags, ret_box3d_scores):
        # Split the columns into boxes (x..rz), class and confidence for box3d_to_label.
        labels = box3d_to_label([score[:, 1:8]], [score[:, 0]], [score[:, -1]], coordinate = 'lidar')[0]
        with open(os.path.join(label_dir, tag + '.txt'), 'w') as f:
            f.writelines(labels)
        print('Write out {} objects to {}'.format(len(labels), tag))


def _write_visualizations(vis_dir, tags, front_images, bird_views, heatmaps):
    """Save the front-view, bird's-eye-view and heatmap images of each sample as JPEGs."""
    for tag, front_image, bird_view, heatmap in zip(tags, front_images, bird_views, heatmaps):
        for suffix, image in (('_front.jpg', front_image), ('_bv.jpg', bird_view), ('_heatmap.jpg', heatmap)):
            image_path = os.path.join(vis_dir, tag + suffix)
            # cv2.imwrite reports failure via its return value instead of raising.
            if not cv2.imwrite(image_path, image):
                print('Warning: failed to write {}'.format(image_path))


def main(args):
    """Run RPN3D inference on the validation split and dump the predictions.

    Args:
        args: namespace providing ``batch_size``, ``workers``, ``pre_trained``
            (checkpoint path), ``output_path`` and ``vis`` (bool).

    Writes one ``<tag>.txt`` label file per sample to ``<output_path>/data``.
    When ``args.vis`` is set, it also writes ``<tag>_front.jpg``,
    ``<tag>_bv.jpg`` and ``<tag>_heatmap.jpg`` to ``<output_path>/vis``.
    """
    label_dir = os.path.join(args.output_path, 'data')
    vis_dir = os.path.join(args.output_path, 'vis')
    # Create the output folders up front: open() raises and cv2.imwrite silently
    # returns False when the target directory does not exist.
    os.makedirs(label_dir, exist_ok = True)
    if args.vis:
        os.makedirs(vis_dir, exist_ok = True)

    val_dataset = Dataset(os.path.join(cfg.DATA_DIR, 'validation'), shuffle = False, aug = False, is_testset = False)
    val_dataloader = DataLoader(val_dataset, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn,
                                num_workers = args.workers, pin_memory = False)

    # TODO: select a device and move the model there; as written the model is built on CPU.
    model = RPN3D(cfg.DETECT_OBJ)
    # map_location='cpu' lets GPU-saved checkpoints load on CUDA-less machines;
    # load_state_dict copies the tensors into the model's own parameters anyway.
    # NOTE(review): a checkpoint saved from an nn.DataParallel wrapper has
    # 'module.'-prefixed keys and will not load into a bare RPN3D as-is.
    weights = torch.load(args.pre_trained, map_location = 'cpu')
    model.load_state_dict(weights['state_dict'])
    model.eval()
    # Only wrappers such as nn.DataParallel expose `.module`; a bare RPN3D does not,
    # so fall back to the model itself to reach `predict`.
    net = getattr(model, 'module', model)

    with torch.no_grad():
        for val_data in val_dataloader:
            # Forward pass; only probs/deltas feed prediction, the loss terms are unused here.
            probs, deltas, _val_loss, _val_cls_loss, _val_reg_loss, _cls_pos_loss_rec, _cls_neg_loss_rec = model(val_data)

            if args.vis:
                tags, ret_box3d_scores, front_images, bird_views, heatmaps = net.predict(
                    val_data, probs, deltas, summary = False, vis = True)
            else:
                tags, ret_box3d_scores = net.predict(val_data, probs, deltas, summary = False, vis = False)

            # tags: (N); ret_box3d_scores: (N, N') with columns (class, x, y, z, h, w, l, rz, score)
            _write_labels(label_dir, tags, ret_box3d_scores)

            if args.vis:
                _write_visualizations(vis_dir, tags, front_images, bird_views, heatmaps)
    

In [None]:
# TODO: run inference once `args` is defined above: main(args)