Skip to content

Commit

Permalink
feat(layers): support jit op (#1241)
Browse files Browse the repository at this point in the history
  • Loading branch information
FateScript committed Apr 15, 2022
1 parent 7ba9fd2 commit 6513f76
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 68 deletions.
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include requirements.txt
recursive-include yolox *.cpp *.h *.cu *.cuh *.cc
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This repo is an implementation of PyTorch version YOLOX, there is also a [MegEng
<img src="assets/git_fig.png" width="1000" >

## Updates!!
* 【2022/04/14】 We suport jit compile op.
* 【2021/08/19】 We optimize the training process with **2x** faster training and **~1%** higher performance! See [notes](docs/updates_note.md) for more details.
* 【2021/08/05】 We release [MegEngine version YOLOX](https://github.com/MegEngine/YOLOX).
* 【2021/07/28】 We fix the fatal error of [memory leak](https://github.com/Megvii-BaseDetection/YOLOX/issues/103)
Expand Down
63 changes: 29 additions & 34 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,14 @@

import re
import setuptools
import glob
from os import path
import torch
from torch.utils.cpp_extension import CppExtension
import sys


def get_extensions():
this_dir = path.dirname(path.abspath(__file__))
extensions_dir = path.join(this_dir, "yolox", "layers", "csrc")

main_source = path.join(extensions_dir, "vision.cpp")
sources = glob.glob(path.join(extensions_dir, "**", "*.cpp"))

sources = [main_source] + sources
extension = CppExtension

extra_compile_args = {"cxx": ["-O3"]}
define_macros = []

include_dirs = [extensions_dir]

ext_modules = [
extension(
"yolox._C",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]

return ext_modules
TORCH_AVAILABLE = True
try:
import torch
except ImportError:
TORCH_AVAILABLE = False
print("[WARNING] Unable to import torch, pre-compiling ops will be disabled.")


def get_package_dir():
Expand Down Expand Up @@ -67,23 +43,42 @@ def get_long_description():
return long_description


def get_ext_modules():
ext_module = []
if sys.platform != "win32": # pre-compile ops on linux
assert TORCH_AVAILABLE, "torch is required for pre-compiling ops, please install it first."
# if any other op is added, please also add it here
from yolox.layers import FastCOCOEvalOp
ext_module.append(FastCOCOEvalOp().build_op())
return ext_module


def get_cmd_class():
cmdclass = {}
if TORCH_AVAILABLE:
cmdclass["build_ext"] = torch.utils.cpp_extension.BuildExtension
return cmdclass


setuptools.setup(
name="yolox",
version=get_yolox_version(),
author="megvii basedet team",
url="https://github.com/Megvii-BaseDetection/YOLOX",
package_dir=get_package_dir(),
packages=setuptools.find_packages(exclude=("tests", "tools")) + list(get_package_dir().keys()),
python_requires=">=3.6",
install_requires=get_install_requirements(),
setup_requires=["wheel"], # avoid building error when pip is not updated
long_description=get_long_description(),
long_description_content_type="text/markdown",
ext_modules=get_extensions(),
include_package_data=True, # include files in MANIFEST.in
ext_modules=get_ext_modules(),
cmdclass=get_cmd_class(),
classifiers=[
"Programming Language :: Python :: 3", "Operating System :: OS Independent",
"License :: OSI Approved :: Apache Software License",
],
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
packages=setuptools.find_packages(),
project_urls={
"Documentation": "https://yolox.readthedocs.io",
"Source": "https://github.com/Megvii-BaseDetection/YOLOX",
Expand Down
10 changes: 9 additions & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@

from yolox.core import launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger
from yolox.utils import (
configure_module,
configure_nccl,
fuse_model,
get_local_rank,
get_model_info,
setup_logger
)


def make_parser():
Expand Down Expand Up @@ -190,6 +197,7 @@ def main(exp, args, num_gpu):


if __name__ == "__main__":
configure_module()
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
exp.merge(args.opts)
Expand Down
3 changes: 2 additions & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from yolox.core import Trainer, launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, configure_omp, get_num_devices
from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices


def make_parser():
Expand Down Expand Up @@ -118,6 +118,7 @@ def main(exp, args):


if __name__ == "__main__":
configure_module()
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
exp.merge(args.opts)
Expand Down
4 changes: 0 additions & 4 deletions yolox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

from .utils import configure_module

configure_module()

__version__ = "0.2.0"
9 changes: 2 additions & 7 deletions yolox/exp/yolox_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,9 @@ def get_data_loader(
MosaicDetection,
worker_init_reset_seed,
)
from yolox.utils import (
wait_for_the_master,
get_local_rank,
)

local_rank = get_local_rank()
from yolox.utils import wait_for_the_master

with wait_for_the_master(local_rank):
with wait_for_the_master():
dataset = COCODataset(
data_dir=self.data_dir,
json_file=self.train_ann,
Expand Down
10 changes: 9 additions & 1 deletion yolox/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,12 @@
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

from .fast_coco_eval_api import COCOeval_opt
# import torch first to make jit op work without `ImportError of libc10.so`
import torch # noqa

from .jit_ops import FastCOCOEvalOp, JitOp

try:
from .fast_coco_eval_api import COCOeval_opt
except ImportError: # exception will be raised when users build yolox from source
pass
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,16 @@ py::dict Accumulate(
const std::vector<ImageEvaluation>& evalutations);

} // namespace COCOeval

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate");
m.def(
"COCOevalEvaluateImages",
&COCOeval::EvaluateImages,
"COCOeval::EvaluateImages");
pybind11::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
.def(pybind11::init<uint64_t, double, double, bool, bool>());
pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
.def(pybind11::init<>());
}
13 changes: 0 additions & 13 deletions yolox/layers/csrc/vision.cpp

This file was deleted.

13 changes: 7 additions & 6 deletions yolox/layers/fast_coco_eval_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
import numpy as np
from pycocotools.cocoeval import COCOeval

# import torch first to make yolox._C work without ImportError of libc10.so
# in YOLOX, env is already set in __init__.py.
from yolox import _C
from .jit_ops import FastCOCOEvalOp


class COCOeval_opt(COCOeval):
"""
This is a slightly modified version of the original COCO API, where the functions evaluateImg()
and accumulate() are implemented in C++ to speedup evaluation
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.module = FastCOCOEvalOp().load()

def evaluate(self):
"""
Expand Down Expand Up @@ -72,7 +73,7 @@ def convert_instances_to_cpp(instances, is_det=False):
# to access in C++
instances_cpp = []
for instance in instances:
instance_cpp = _C.InstanceAnnotation(
instance_cpp = self.module.InstanceAnnotation(
int(instance["id"]),
instance["score"] if is_det else instance.get("score", 0.0),
instance["area"],
Expand Down Expand Up @@ -106,7 +107,7 @@ def convert_instances_to_cpp(instances, is_det=False):
]

# Call C++ implementation of self.evaluateImgs()
self._evalImgs_cpp = _C.COCOevalEvaluateImages(
self._evalImgs_cpp = self.module.COCOevalEvaluateImages(
p.areaRng,
maxDet,
p.iouThrs,
Expand All @@ -131,7 +132,7 @@ def accumulate(self):
if not hasattr(self, "_evalImgs_cpp"):
print("Please run evaluate() first")

self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)
self.eval = self.module.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)

# recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
self.eval["recall"] = np.array(self.eval["recall"]).reshape(
Expand Down

0 comments on commit 6513f76

Please sign in to comment.