
[feature] Add a demo to **simply** load a pth and an image and get a visualized result #3

Closed
lucasjinreal opened this issue Nov 1, 2021 · 10 comments

Comments

@lucasjinreal
No description provided.

@lucasjinreal (Author)

I tried adding this demo:

import os
import cv2
import numpy as np
import torch
from torch.nn.modules.utils import _pair
from easydict import EasyDict
import argparse

from eod.data.datasets.transforms import build_transformer
from eod.data.data_utils import get_image_size
from eod.utils.general.yaml_loader import load_yaml
from eod.utils.general.vis_helper import OpenCVVisualizer
from eod.utils.general.log_helper import default_logger as logger
from eod.utils.general.registry_factory import (
    MODEL_HELPER_REGISTRY,
    BATCHING_REGISTRY,
    IMAGE_READER_REGISTRY,
    SAVER_REGISTRY,
)

__all__ = ["BaseInference"]


class BaseInference(object):
    def __init__(self, config, work_dir="./"):
        self.args = config.get("args", {})
        self.config = config
        # cfg_infer = config['inference']
        self.class_names = config.get("class_names", None)
        self.work_dir = work_dir
        self.ckpt = self.args["ckpt"]
        self.vis_dir = self.args["vis_dir"]
        self.image_path = self.args["image_path"]

        assert self.image_path and os.path.exists(
            self.image_path
        ), "Invalid images path."

        # build DataFetch
        self.build_data()
        logger.info("build data fetcher done")
        # build model
        self.build_model()
        logger.info("build model done")
        # build saver
        self.build_saver()
        logger.info("build saver done")
        # build visualizer
        self.visualizer = OpenCVVisualizer()
        logger.info("build visualizer done")
        # resume
        self.resume()
        logger.info("load weights done")

    def tensor2numpy(self, x):
        if x is None:
            return x
        if torch.is_tensor(x):
            return x.cpu().numpy()
        if isinstance(x, list):
            x = [_.cpu().numpy() if torch.is_tensor(_) else _ for _ in x]
        return x

    def resume(self):
        checkpoint = self.saver.load_checkpoint(self.ckpt)
        state_dict = checkpoint.get("model", checkpoint.get("state_dict", {}))
        self.detector.load(state_dict, strict=False)

    def build_saver(self):
        cfg_saver = self.config["saver"]
        _cfg_saver = cfg_saver
        if "kwargs" not in cfg_saver:
            _cfg_saver = {"type": "base", "kwargs": {}}
            _cfg_saver["kwargs"]["save_cfg"] = cfg_saver
            _cfg_saver["kwargs"]["work_dir"] = self.work_dir
        self.saver = SAVER_REGISTRY.build(_cfg_saver)

    def build_model(self):
        model_helper_cfg = self.config["runtime"].get("model_helper", {})
        model_helper_cfg["type"] = model_helper_cfg.get("type", "base")
        model_helper_cfg["kwargs"] = model_helper_cfg.get(
            "kwargs", {"cfg": self.config["net"]}
        )
        self.detector = MODEL_HELPER_REGISTRY.build(model_helper_cfg).cuda().eval()

    def build_data(self):
        data_cfg = self.config["dataset"]
        assert "test" in data_cfg, "Test dataset config must need !"
        dataset_cfg = data_cfg["test"]["dataset"]["kwargs"]
        self.color_mode = dataset_cfg["image_reader"]["kwargs"]["color_mode"]
        # build image_reader
        self.image_reader = IMAGE_READER_REGISTRY.build(dataset_cfg["image_reader"])

        self.transformer = build_transformer(dataset_cfg["transformer"])
        pad_type = data_cfg["dataloader"]["kwargs"].get("pad_type", "batch_pad")
        pad_value = data_cfg["dataloader"]["kwargs"].get("pad_value", 0)
        alignment = data_cfg["dataloader"]["kwargs"]["alignment"]
        self.batch_pad = BATCHING_REGISTRY.get(pad_type)(alignment, pad_value)

    def iterate_image(self, image_dir):
        EXTS = ["jpg", "jpeg", "png", "svg", "bmp"]

        for root, subdirs, subfiles in os.walk(image_dir):
            for filename in subfiles:
                ext = filename.rsplit(".", 1)[-1].lower()
                filepath = os.path.join(root, filename)
                if ext in EXTS:
                    yield filepath

    def map_back(self, output):
        """Map predictions to original image
        Args:
           - output: dict
        Returns:
           - output_list: list of dict,
        """
        origin_images = output["origin_image"]
        image_info = output["image_info"]
        bboxes = self.tensor2numpy(output["dt_bboxes"])
        batch_size = len(image_info)

        output_list = []
        for b_ix in range(batch_size):

            origin_image = origin_images[b_ix]
            if origin_image.ndim == 3:
                origin_image_h, origin_image_w, _ = origin_image.shape
            else:
                origin_image_h, origin_image_w = origin_image.shape

            img_info = image_info[b_ix]
            unpad_image_h, unpad_image_w = img_info[:2]
            scale_h, scale_w = _pair(img_info[2])
            keep_ix = np.where(bboxes[:, 0] == b_ix)[0]

            # resize bbox
            img_bboxes = bboxes[keep_ix]
            img_bboxes[:, 1] /= scale_w
            img_bboxes[:, 2] /= scale_h
            img_bboxes[:, 3] /= scale_w
            img_bboxes[:, 4] /= scale_h
            img_bboxes = img_bboxes[:, 1:]

            img_output = {
                "image": origin_image,
                "image_info": img_info,
                "dt_bboxes": img_bboxes,
            }
            output_list.append(img_output)

        return output_list

    def fetch_single(self, filename):
        img = self.image_reader.read(filename)
        data = EasyDict(
            {"filename": filename, "origin_image": img, "image": img, "flipped": False}
        )
        data = self.transformer(data)
        scale_factor = data.get("scale_factor", 1)

        image_h, image_w = get_image_size(img)
        new_image_h, new_image_w = get_image_size(data.image)
        data.image_info = [
            1,
            new_image_h,
            new_image_w,
            scale_factor,
            image_h,
            image_w,
            data.flipped,
            filename,
        ]
        data.image = data.image.cuda()
        return data

    def fetch(self, filename_list):
        batch = [self.fetch_single(filename) for filename in filename_list]

        batch_keys = list(batch[0].keys())

        def batch_value(key, default=None):
            return [_.get(key, default) for _ in batch]

        data = EasyDict({k: batch_value(k) for k in batch_keys})
        data = self.batch_pad(data)

        return data

    def predict(self):
        output_list = []
        if os.path.isdir(self.image_path):
            list_imgs = self.iterate_image(self.image_path)
        else:
            list_imgs = [self.image_path]
        for img_idx, filename in enumerate(list_imgs):
            logger.info("predicting {}:{}".format(img_idx, filename))
            batch = self.fetch([filename])
            with torch.no_grad():
                output = self.detector(batch)
            # map predictions back to the original image and collect them,
            # so vis() below has something to draw
            output_list.extend(self.map_back(output))
        return output_list

    def vis(self, outputs):
        for img_idx, output in enumerate(outputs):
            img = output["image"]
            if self.color_mode != "RGB":
                cvt_color_vis = getattr(cv2, "COLOR_{}2RGB".format(self.color_mode))
                img = cv2.cvtColor(img, cvt_color_vis)
            boxes = output["dt_bboxes"]
            filename = os.path.basename(output["image_info"][-1])
            logger.info("visualizing {}:{}".format(img_idx, filename))

            img_h, img_w = img.shape[:2]
            classes = boxes[:, -1].astype(np.int32)
            boxes = boxes[:, :-1]
            output_name = os.path.join(self.vis_dir, filename)

            self.visualizer.vis(img, boxes, classes, output_name, absolute_path=True)


def get_parser():
    parser = argparse.ArgumentParser(description="EOD demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/retinanet/retinanet-r18-improve.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--webcam", action="store_true", help="Take inputs from webcam."
    )
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument("--vis_dir", default='results', help="Path to video file.")
    parser.add_argument(
        "--input",
        # nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--ckpt",
        help="Path to the checkpoint (.pth) file to load.",
    )

    parser.add_argument(
        "-c",
        "--confidence-threshold",
        type=float,
        default=0.21,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "-n",
        "--nms-threshold",
        type=float,
        default=0.6,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args()

    cfg = load_yaml(args.config_file)
    cfg['args'] = {
        'ckpt': args.ckpt,
        'image_path': args.input,
        'vis_dir': args.vis_dir,
        'opts': args.opts
    }
    os.makedirs(args.vis_dir, exist_ok=True)
    cfg.setdefault('runtime', {})

    predictor = BaseInference(cfg)
    output_list = predictor.predict()
    predictor.vis(output_list)
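
Assuming this is saved as debug.py (the filename LitPrice refers to below), a sample invocation might look like:

python debug.py --config-file configs/retinanet/retinanet-r18-improve.yaml --ckpt weights/retinanet_r18.pth --input images --vis_dir results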

But the result is wrong, and I don't know why.

I found your implementation is good but very deeply packaged; it would be better to provide a demo file that visualizes results on a single image.

yqyao commented Nov 1, 2021

Thanks for your advice, we will fix it later. @jinfagang

LitPrice (Contributor) commented Nov 2, 2021

Hi, we found a little bug in your debug.py, as follows:

data.image_info = [1, new_image_h, new_image_w, scale_factor, image_h, image_w, data.flipped, filename]

image_info should instead be organized as:

data.image_info = [new_image_h, new_image_w, scale_factor, image_h, image_w, data.flipped, filename]

LitPrice (Contributor) commented Nov 2, 2021

@jinfagang

@lucasjinreal (Author)

@Joker-co No, you will get an index out of range error if you don't make the length 8.

LitPrice (Contributor) commented Nov 3, 2021

img_info = image_info[b_ix]
unpad_image_h, unpad_image_w = img_info[:2]
scale_h, scale_w = _pair(img_info[2])
keep_ix = np.where(bboxes[:, 0] == b_ix)[0]

Are scale_h and scale_w correct here?
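
With the 8-element image_info built in your fetch_single, every index is shifted by one; a quick trace of the values map_back actually sees:

# image_info as built in the demo above:
# [1, new_image_h, new_image_w, scale_factor, image_h, image_w, flipped, filename]
unpad_image_h, unpad_image_w = img_info[:2]  # -> (1, new_image_h), not the unpadded size
scale_h, scale_w = _pair(img_info[2])        # -> (new_image_w, new_image_w), not the scale factor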

@lucasjinreal (Author)

@Joker-co The inference.py way is good; running inference like this:

python -m eod inference --config configs/retinanet/retinanet-r18-improve.yaml --ckpt weights/retinanet_r18.pth -i images -v results

there is no problem.

But this logic is not good:

output_list = inferencer.predict()
inferencer.vis(output_list)

It always runs inference on all images first and only then visualizes or saves the results, but we want it to handle images one by one, as in the sketch below.
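
A minimal sketch reusing the demo's own methods, visualizing each image as soon as it is inferred (method names as in the BaseInference class above):

def predict_and_vis(self):
    if os.path.isdir(self.image_path):
        list_imgs = self.iterate_image(self.image_path)
    else:
        list_imgs = [self.image_path]
    for img_idx, filename in enumerate(list_imgs):
        logger.info("predicting {}:{}".format(img_idx, filename))
        batch = self.fetch([filename])
        with torch.no_grad():
            output = self.detector(batch)
        # visualize this image immediately instead of collecting all outputs
        self.vis(self.map_back(output))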

And it doesn't support video.

LitPrice (Contributor) commented Nov 3, 2021

I think you want to run inference on only one image, not on all images in a folder? For that, you can pass -i with a single image path.

Does "this logic is not good" mean wrong detection bboxes or wrong confidences? Could you provide more detailed error information?

LitPrice (Contributor) commented Nov 3, 2021

@jinfagang

@lucasjinreal (Author)

@Joker-co Hi, currently, in image-folder mode it runs inference on all images and only then starts visualizing all the results, right?

This is not good: we want to infer one image, visualize it, and then move on to the next.

Otherwise it is really annoying: if I just want to visualize a few images from the COCO val folder, I have to run inference on all of them before I can see any visualization.

It doesn't make sense.
