New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feature] Add a demo to **simply** load a pth and an image and get the visualized result #3
Comments
Try to add this demo: from inspect import ArgSpec
import os
import cv2
import numpy as np
import torch
from torch.nn.modules.utils import _pair
from easydict import EasyDict
import argparse
from eod.data.datasets.transforms import build_transformer
from eod.data.data_utils import get_image_size
from eod.utils.general.yaml_loader import load_yaml
from eod.utils.general.vis_helper import BaseVisualizer, OpenCVVisualizer
from eod.utils.general.log_helper import default_logger as logger
from eod.utils.general.registry_factory import (
MODEL_HELPER_REGISTRY,
BATCHING_REGISTRY,
IMAGE_READER_REGISTRY,
)
from eod.utils.general.registry_factory import (
INFERENCER_REGISTRY,
SAVER_REGISTRY,
VISUALIZER_REGISTRY,
)
__all__ = ["BaseInference"]
class BaseInference(object):
    """Run a detector on one image or a directory of images and draw the results.

    The constructor wires everything up from ``config``: the image reader,
    transformer and batch padder (``build_data``), the model (``build_model``),
    the checkpoint saver (``build_saver``), an ``OpenCVVisualizer`` and finally
    the checkpoint weights (``resume``).

    Args:
        config: dict-like configuration. ``config["args"]`` must provide
            ``ckpt``, ``vis_dir`` and ``image_path``; the ``dataset``,
            ``runtime``, ``net`` and ``saver`` sections are read by the
            ``build_*`` methods.
        work_dir: working directory handed to the saver.
    """

    def __init__(self, config, work_dir="./"):
        self.args = config.get("args", {})
        self.config = config
        self.class_names = config.get("class_names", None)
        self.work_dir = work_dir
        self.ckpt = self.args["ckpt"]
        self.vis_dir = self.args["vis_dir"]
        self.image_path = self.args["image_path"]
        assert self.image_path and os.path.exists(
            self.image_path
        ), "Invalid images path."
        # build DataFetch
        self.build_data()
        logger.info("build data fetcher done")
        # build model
        self.build_model()
        logger.info("build model done")
        # build saver
        self.build_saver()
        logger.info("build saver done")
        # build visualizer
        self.visualizer = OpenCVVisualizer()
        logger.info("build visualizer done")
        # resume
        self.resume()
        logger.info("load weights done")

    def tensor2numpy(self, x):
        """Convert a tensor, or the tensors inside a list, to numpy arrays.

        ``None`` and any other non-tensor value are returned unchanged.
        """
        if x is None:
            return x
        if torch.is_tensor(x):
            return x.cpu().numpy()
        if isinstance(x, list):
            x = [_.cpu().numpy() if torch.is_tensor(_) else _ for _ in x]
        return x

    def resume(self):
        """Load the checkpoint at ``self.ckpt`` into the detector (non-strict)."""
        checkpoint = self.saver.load_checkpoint(self.ckpt)
        # Weights are looked up under "model" first, then "state_dict".
        state_dict = checkpoint.get("model", checkpoint.get("state_dict", {}))
        self.detector.load(state_dict, strict=False)

    def build_saver(self):
        """Build ``self.saver`` from ``config["saver"]``.

        When the saver config has no ``kwargs`` entry, a default
        ``{"type": "base", "kwargs": {}}`` config is used instead. In both
        cases ``save_cfg`` and ``work_dir`` are injected into ``kwargs``.
        """
        cfg_saver = self.config["saver"]
        _cfg_saver = cfg_saver
        if "kwargs" not in cfg_saver:
            _cfg_saver = {"type": "base", "kwargs": {}}
        _cfg_saver["kwargs"]["save_cfg"] = cfg_saver
        _cfg_saver["kwargs"]["work_dir"] = self.work_dir
        self.saver = SAVER_REGISTRY.build(_cfg_saver)

    def build_model(self):
        """Build ``self.detector``, move it to CUDA and switch it to eval mode.

        The model helper config comes from ``config["runtime"]["model_helper"]``;
        ``type`` defaults to ``"base"`` and ``kwargs`` to
        ``{"cfg": config["net"]}``.
        """
        model_helper_cfg = self.config["runtime"].get("model_helper", {})
        model_helper_cfg["type"] = model_helper_cfg.get("type", "base")
        model_helper_cfg["kwargs"] = model_helper_cfg.get(
            "kwargs", {"cfg": self.config["net"]}
        )
        self.detector = MODEL_HELPER_REGISTRY.build(model_helper_cfg).cuda().eval()

    def build_data(self):
        """Build the image reader, transformer and batch padder.

        Everything is taken from the ``test`` dataset section and the
        ``dataloader`` kwargs of ``config["dataset"]``.
        """
        data_cfg = self.config["dataset"]
        assert "test" in data_cfg, "Test dataset config must need !"
        dataset_cfg = data_cfg["test"]["dataset"]["kwargs"]
        self.color_mode = dataset_cfg["image_reader"]["kwargs"]["color_mode"]
        # build image_reader
        self.image_reader = IMAGE_READER_REGISTRY.build(dataset_cfg["image_reader"])
        self.transformer = build_transformer(dataset_cfg["transformer"])
        pad_type = data_cfg["dataloader"]["kwargs"].get("pad_type", "batch_pad")
        pad_value = data_cfg["dataloader"]["kwargs"].get("pad_value", 0)
        alignment = data_cfg["dataloader"]["kwargs"]["alignment"]
        self.batch_pad = BATCHING_REGISTRY.get(pad_type)(alignment, pad_value)

    def iterate_image(self, image_dir):
        """Yield the path of every image file found recursively under ``image_dir``.

        A file counts as an image when its extension (case-insensitive) is one
        of jpg, jpeg, png, svg or bmp.
        """
        exts = {".jpg", ".jpeg", ".png", ".svg", ".bmp"}
        for root, _subdirs, subfiles in os.walk(image_dir):
            for filename in subfiles:
                # splitext returns "" for names without an extension, so a file
                # literally called "jpg" is no longer mistaken for an image.
                if os.path.splitext(filename)[1].lower() in exts:
                    yield os.path.join(root, filename)

    def map_back(self, output):
        """Map predictions back to the original image scale.

        Args:
            output: dict with ``origin_image``, ``image_info`` and
                ``dt_bboxes``. Column 0 of ``dt_bboxes`` is the index of the
                image inside the batch; columns 1-4 are the box coordinates
                that get rescaled; remaining columns are passed through.

        Returns:
            list of dict, one per image, with keys ``image``, ``image_info``
            and ``dt_bboxes`` (batch-index column removed).
        """
        origin_images = output["origin_image"]
        image_info = output["image_info"]
        bboxes = self.tensor2numpy(output["dt_bboxes"])
        output_list = []
        for b_ix, img_info in enumerate(image_info):
            # NOTE(review): the scale is read from index 2, but fetch_single
            # stores scale_factor at index 3 (after a leading 1) — confirm
            # which image_info layout is intended.
            scale_h, scale_w = _pair(img_info[2])
            keep_ix = np.where(bboxes[:, 0] == b_ix)[0]
            # Fancy indexing copies, so rescaling does not touch `bboxes`.
            img_bboxes = bboxes[keep_ix]
            img_bboxes[:, 1] /= scale_w
            img_bboxes[:, 2] /= scale_h
            img_bboxes[:, 3] /= scale_w
            img_bboxes[:, 4] /= scale_h
            output_list.append(
                {
                    "image": origin_images[b_ix],
                    "image_info": img_info,
                    "dt_bboxes": img_bboxes[:, 1:],
                }
            )
        return output_list

    def fetch_single(self, filename):
        """Read and transform one image and attach its ``image_info``.

        Returns:
            EasyDict with the transformed ``image`` moved to CUDA, the
            untouched ``origin_image`` and an 8-element ``image_info`` list.
        """
        img = self.image_reader.read(filename)
        data = EasyDict(
            {"filename": filename, "origin_image": img, "image": img, "flipped": False}
        )
        data = self.transformer(data)
        scale_factor = data.get("scale_factor", 1)
        image_h, image_w = get_image_size(img)
        new_image_h, new_image_w = get_image_size(data.image)
        # NOTE(review): because of the leading 1, scale_factor sits at index 3,
        # while map_back reads the scale from index 2 (new_image_w here).
        # Confirm the layout the rest of the pipeline expects.
        data.image_info = [
            1,
            new_image_h,
            new_image_w,
            scale_factor,
            image_h,
            image_w,
            data.flipped,
            filename,
        ]
        data.image = data.image.cuda()
        return data

    def fetch(self, filename_list):
        """Load every file in ``filename_list`` and collate them into one padded batch."""
        batch = [self.fetch_single(filename) for filename in filename_list]
        batch_keys = list(batch[0].keys())

        def batch_value(key, default=None):
            return [_.get(key, default) for _ in batch]

        data = EasyDict({k: batch_value(k) for k in batch_keys})
        data = self.batch_pad(data)
        return data

    def predict(self):
        """Run the detector on ``self.image_path``, one image per forward pass.

        ``self.image_path`` may be a single image or a directory, which is
        searched recursively.

        Returns:
            list of dict as produced by ``map_back``, one entry per image.
            Every entry keeps its original image, so memory grows with the
            number of images processed.
        """
        output_list = []
        if os.path.isdir(self.image_path):
            list_imgs = self.iterate_image(self.image_path)
        else:
            list_imgs = [self.image_path]
        for img_idx, filename in enumerate(list_imgs):
            logger.info("predicting {}:{}".format(img_idx, filename))
            batch = self.fetch([filename])
            with torch.no_grad():
                output = self.detector(batch)
            # Collect the per-image results; without this the method always
            # returned an empty list and vis() had nothing to draw.
            output_list.extend(self.map_back(output))
        return output_list

    def vis(self, outputs):
        """Draw the detections of each entry in ``outputs`` and save them under ``self.vis_dir``.

        Args:
            outputs: list of dict as returned by ``predict``. The last column
                of ``dt_bboxes`` is used as the class id; the last element of
                ``image_info`` is used as the source filename.
        """
        for img_idx, output in enumerate(outputs):
            img = output["image"]
            if self.color_mode != "RGB":
                cvt_color_vis = getattr(cv2, "COLOR_{}2RGB".format(self.color_mode))
                img = cv2.cvtColor(img, cvt_color_vis)
            boxes = output["dt_bboxes"]
            filename = os.path.basename(output["image_info"][-1])
            logger.info("visualizing {}:{}".format(img_idx, filename))
            classes = boxes[:, -1].astype(np.int32)
            boxes = boxes[:, :-1]
            output_name = os.path.join(self.vis_dir, filename)
            self.visualizer.vis(img, boxes, classes, output_name, absolute_path=True)
def get_parser():
    """Build the command-line parser for the demo script.

    Returns:
        argparse.ArgumentParser with options for the config file, the input
        image path, the checkpoint and the visualization output directory,
        plus webcam/video, threshold and ``--opts`` options.
    """
    parser = argparse.ArgumentParser(description="EOD demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/retinanet/retinanet-r18-improve.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--webcam", action="store_true", help="Take inputs from webcam."
    )
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--vis_dir",
        default='results',
        help="Directory where visualization results are written.",
    )
    parser.add_argument(
        "--input",
        # nargs="+",
        help="Path to a single image, or to a directory that is searched "
        "recursively for images",
    )
    parser.add_argument(
        "--ckpt",
        help="Path to the checkpoint file to load.",
    )
    parser.add_argument(
        "-c",
        "--confidence-threshold",
        type=float,
        default=0.21,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "-n",
        "--nms-threshold",
        type=float,
        default=0.6,
        help="Threshold for non-maximum suppression (NMS)",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser
# Script entry point: load the YAML config, attach the command-line options
# the predictor reads, then run prediction over the requested input.
if __name__ == "__main__":
    args = get_parser().parse_args()
    cfg = load_yaml(args.config_file)
    cfg['args'] = {
        'ckpt': args.ckpt,
        'image_path': args.input,
        # Write to the directory given on the command line; it was previously
        # hard-coded to 'results', so --vis_dir created a directory that was
        # never written to.
        'vis_dir': args.vis_dir,
        'opts': args.opts,
    }
    os.makedirs(args.vis_dir, exist_ok=True)
    # The predictor reads cfg['runtime'], so make sure the section exists.
    cfg.setdefault('runtime', {})
    predictor = BaseInference(cfg)
    output_list = predictor.predict()
predictor.vis(output_list) But the result is wrong, and I don't know why. I found that your implementation is GOOD but very deeply packaged; it would be better to provide a DEMO file that visualizes a single image. |
Thanks for your advice, we will fix it later. @jinfagang |
Hi, we found a small bug in your debug.py, as follows: |
@jinfagang |
@Joker-co No, you will get an index out of range error if you don't make the length 8 |
|
@Joker-co The inference.py approach is good; running inference this way:
works without problems. But this logic is not good:
It always runs inference on all images and only then visualizes or saves them, whereas we want to process images one by one. It also doesn't support video. |
I think you want to run inference on only one image, not on all images in the folder? Does 'this logic is not good' mean wrong detection bboxes or wrong confidence? |
@jinfagang |
@Joker-co Hi, currently, in images-folder mode, it runs inference on all images and only then starts visualizing all the results, right? This is not good, since we want to run inference on one image, visualize it, and then move on to the next one. Otherwise it is really annoying: if I just want to visualize a few images from the COCO val folder, I have to run inference on all of them before I can see any visualization. That doesn't make any sense. |
No description provided.
The text was updated successfully, but these errors were encountered: