# In [None]:
import argparse
import os
import time

import cv2
import torch
import numpy as np

from nanodet.data.batch_process import stack_batch_img
from nanodet.data.collate import naive_collate
from nanodet.data.transform import Pipeline
from nanodet.model.arch import build_model
from nanodet.util import Logger, cfg, load_config, load_model_weight
from nanodet.util.path import mkdir

# File extensions treated as still images. Each entry carries the leading dot,
# matching the form returned by os.path.splitext(...)[1].
image_ext = [f".{suffix}" for suffix in ("jpg", "jpeg", "webp", "bmp", "png")]
# Video container extensions. NOTE(review): unlike image_ext these have no
# leading dot, so they do not match os.path.splitext output as-is; they are
# not used anywhere in the visible code.
video_ext = "mp4 mov avi mkv".split()

class Predictor(object):
    """Wrap a NanoDet model together with its validation preprocessing pipeline.

    The model is built from ``cfg.model`` and loaded from a checkpoint. A
    RepVGG backbone is converted to its deploy form. The model is then moved
    to ``device`` and put in eval mode. Images are preprocessed with the
    pipeline described by ``cfg.data.val``.
    """

    def __init__(self, cfg, model_path, logger, device="cuda:0"):
        """Build the model and the preprocessing pipeline.

        Args:
            cfg: NanoDet config node; ``cfg.model`` and ``cfg.data.val`` are read.
                It is not modified.
            model_path: checkpoint path passed to ``torch.load``.
            logger: logger forwarded to ``load_model_weight``.
            device: torch device the model (and later the input tensors) are
                moved to.
        """
        import copy

        self.cfg = cfg
        self.device = device
        model = build_model(cfg.model)
        # Keep every storage on CPU while loading; the model is moved to
        # `device` once it is fully built.
        ckpt = torch.load(model_path, map_location=lambda storage, loc: storage)
        load_model_weight(model, ckpt, logger)
        if cfg.model.arch.backbone.name == "RepVGG":
            # Work on a copy. Updating cfg.model in place would leave
            # deploy=True in the caller's config, so any model built from it
            # afterwards would silently be in deploy mode.
            deploy_config = copy.deepcopy(cfg.model)
            deploy_config.arch.backbone.update({"deploy": True})
            deploy_model = build_model(deploy_config)
            from nanodet.model.backbone.repvgg import repvgg_det_model_convert

            model = repvgg_det_model_convert(model, deploy_model)
        self.model = model.to(device).eval()
        self.pipeline = Pipeline(cfg.data.val.pipeline, cfg.data.val.keep_ratio)

    def _preprocess(self, img, keep_np_img):
        """Run the validation pipeline on one image and collate it into a batch of one.

        Args:
            img: an image path (``str`` or ``os.PathLike``), read with
                ``cv2.imread``, or an already-loaded image array.
            keep_np_img: if true, also store the pipeline's output array under
                ``meta["np_img"]``, before it is converted to a tensor.

        Returns:
            The collated ``meta`` dict. ``meta["img"]`` holds the batched
            tensor on ``self.device``.

        Raises:
            FileNotFoundError: if ``img`` is a path that ``cv2.imread`` cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, (str, os.PathLike)):
            img_path = os.fspath(img)
            img_info["file_name"] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports failure by returning None, not by raising.
                raise FileNotFoundError("Could not read image: {}".format(img_path))
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        # Insert keys in a fixed order: img_info, raw_img, [np_img,] img.
        meta = {"img_info": img_info, "raw_img": img}
        if keep_np_img:
            meta["np_img"] = img
        meta["img"] = img
        meta = self.pipeline(None, meta, self.cfg.data.val.input_size)
        if keep_np_img:
            meta["np_img"] = meta["img"]
        # Move the last axis first (HWC -> CHW for a colour image) for torch.
        meta["img"] = torch.from_numpy(meta["img"].transpose(2, 0, 1)).to(self.device)
        meta = naive_collate([meta])
        # NOTE(review): presumably pads H/W up to multiples of 32 - confirm in nanodet.
        meta["img"] = stack_batch_img(meta["img"], divisible=32)
        return meta

    def inference(self, img, *, run_model=False):
        """Preprocess ``img`` and optionally run the model on it.

        Args:
            img: image path or loaded image array (see ``_preprocess``).
            run_model: when true, call ``self.model.inference(meta)`` under
                ``torch.no_grad()``. Defaults to false, in which case only the
                preprocessing runs and ``results`` is ``None``.

        Returns:
            ``(meta, results)``: the collated meta dict, which includes
            ``"np_img"``, and the model output, or ``None`` when the model
            was not run.
        """
        meta = self._preprocess(img, keep_np_img=True)
        results = None
        if run_model:
            with torch.no_grad():
                results = self.model.inference(meta)
        return meta, results

    def post_process(self, img, preds):
        """Preprocess ``img`` and decode externally computed ``preds`` for it.

        The image goes through the same pipeline as in ``inference`` (without
        ``np_img``), and the resulting ``meta`` is passed together with
        ``preds`` to ``self.model.head.post_process``.

        Returns:
            ``(meta, results)`` where ``results`` is whatever
            ``head.post_process`` returns.
        """
        meta = self._preprocess(img, keep_np_img=False)
        with torch.no_grad():
            results = self.model.head.post_process(preds, meta)
        return meta, results

    def visualize(self, dets, meta, class_names, score_thres, wait=0):
        """Render ``dets`` on the first raw image in ``meta`` via ``head.show_result``.

        ``wait`` is accepted for API compatibility but is not used. The
        rendering time is printed.

        Returns:
            The image returned by ``self.model.head.show_result``.
        """
        time1 = time.time()
        result_img = self.model.head.show_result(
            meta["raw_img"][0], dets, class_names, score_thres=score_thres, show=False
        )
        print("viz time: {:.3f}s".format(time.time() - time1))
        return result_img


def get_image_list(path, extensions=None):
    """Recursively collect image file paths under ``path``.

    Args:
        path: directory walked with ``os.walk``. A missing directory yields
            an empty list.
        extensions: iterable of extensions including the leading dot
            (e.g. ``".jpg"``). Defaults to the module-level ``image_ext``.

    Returns:
        Paths (directory joined with file name) whose extension is in
        ``extensions``. The comparison is case-insensitive, so ``IMG.JPG``
        matches ``.jpg``. Order follows ``os.walk``; callers sort if needed.
    """
    wanted = {ext.lower() for ext in (image_ext if extensions is None else extensions)}
    image_names = []
    for dirpath, _dirnames, filenames in os.walk(path):
        for filename in filenames:
            if os.path.splitext(filename)[1].lower() in wanted:
                image_names.append(os.path.join(dirpath, filename))
    return image_names

# NOTE: these paths are specific to the author's machine; edit before running.
config = "/home/tao/Github/nanodet_custom/results/sim_data_mono_pretrain_0.5x/nanodet_plus_custom.yml"
model = "/home/tao/Github/nanodet_custom/results/sim_data_mono_pretrain_0.5x/model_best/nanodet_model_best.pth"
path = "/home/tao/Pictures/real_data_mono_l_ann/data"
save_result = False
# When True, print preprocessing diagnostics for the first image and stop.
# Predictor.inference does not run the model by default, so no detections
# are available to draw or save in that mode.
inspect_only = True

local_rank = 0
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

load_config(cfg, config)
logger = Logger(local_rank, use_tensorboard=False)
predictor = Predictor(cfg, model, logger, device="cpu")
logger.log('Press "Esc", "q" or "Q" to exit.')
current_time = time.localtime()

files = get_image_list(path) if os.path.isdir(path) else [path]
files.sort()
if not files:
    raise FileNotFoundError("No images found under {}".format(path))

# Create the output folder once, rather than recomputing it for every image.
save_folder = None
if save_result and not inspect_only:
    save_folder = os.path.join(
        cfg.save_dir, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    mkdir(local_rank, save_folder)

for image_name in files:
    meta, res = predictor.inference(image_name)
    if inspect_only:
        print(meta.keys())
        print(meta["img_info"])
        print(meta["raw_img"][0].shape)
        print(meta["np_img"][0].shape)
        print(meta["np_img"][0].dtype)
        print(meta["img"][0].shape)
        break
    if res is None:
        raise RuntimeError(
            "Predictor.inference returned no detections; the model was not run."
        )
    result_image = predictor.visualize(res[0], meta, cfg.class_names, 0.35)
    if save_folder is not None:
        save_file_name = os.path.join(save_folder, os.path.basename(image_name))
        cv2.imwrite(save_file_name, result_image)
