diff --git a/README.md b/README.md
index b512d984..54f238df 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,7 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
### Object Detection
-Object detection is another common task in computer vision. We provide two classic detection models, [Retinanet](./official/vision/detection/models/retinanet.py) and [Faster R-CNN](./official/vision/detection/models/faster_rcnn.py), whose results on the **COCO validation set** are as follows:
+Object detection is another common task in computer vision. We provide several classic detection models, whose results on the COCO2017 validation set are as follows:
| Model | mAP@5-95 |
| --- | :---: |
@@ -81,21 +81,27 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH
| retinanet-resx101-coco-2x-800size | 42.7 |
| faster-rcnn-res50-coco-1x-800size | 38.0 |
| faster-rcnn-res101-coco-2x-800size | 42.5 |
-| faster-rcnn-resx101-coco-2x-800size | 44.7 * |
+| faster-rcnn-resx101-coco-2x-800size | 43.6 |
| fcos-res50-coco-1x-800size | 39.7 |
| fcos-res101-coco-2x-800size | 44.1 |
-| fcos-resx101-coco-2x-800size | 39.7 * |
+| fcos-resx101-coco-2x-800size | 44.9 |
| atss-res50-coco-1x-800size | 40.1 |
| atss-res101-coco-2x-800size | 44.5 |
| atss-resx101-coco-2x-800size | 45.9 |
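As a quick usage note (not part of this diff): the model names above map onto `megengine.hub` entrypoints registered in `hubconf.py` below, with dashes replaced by underscores. A minimal sketch; the repo string and the `pretrained=True` behaviour are assumptions based on the `@hub.pretrained` entrypoints touched later in this PR.

```python
# Hypothetical sketch: load one of the detection models listed above via megengine.hub.
import megengine.hub as hub

# "retinanet_res50_coco_1x_800size" mirrors the "retinanet-res50-coco-1x-800size" row;
# pretrained=True is expected to fetch the released weights registered with @hub.pretrained.
model = hub.load("megengine/models", "retinanet_res50_coco_1x_800size", pretrained=True)
model.eval()
```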
### Image Segmentation
-We also provide a classic semantic segmentation model, [Deeplabv3plus](./official/vision/segmentation/). Its results on the **PASCAL VOC validation set** are as follows:
+We also provide a classic semantic segmentation model, [DeepLabV3+](./official/vision/segmentation/). Its results on the Pascal VOC2012 validation set are as follows:
- | Model | Backbone | mIoU_single | mIoU_multi |
- | -- | :--: | :--: | :--: |
- | Deeplabv3plus | Resnet101 | 79.0 | 79.8 |
+| Model | mIoU |
+| --- | :--: |
+| deeplabv3plus-res101-voc-512size | 79.5 |
+
+Test results on the Cityscapes validation set are as follows:
+
+| Model | mIoU |
+| --- | :--: |
+| deeplabv3plus-res101-cityscapes-768size | 78.5 |
### Human Keypoint Detection
diff --git a/hubconf.py b/hubconf.py
index a932bef4..33db85d3 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -29,15 +29,19 @@
)
from official.vision.detection.configs import (
faster_rcnn_res50_coco_1x_800size,
- faster_rcnn_res50_coco_1x_800size_syncbn,
faster_rcnn_res101_coco_2x_800size,
faster_rcnn_resx101_coco_2x_800size,
retinanet_res50_coco_1x_800size,
- retinanet_res50_coco_1x_800size_syncbn,
retinanet_res101_coco_2x_800size,
retinanet_resx101_coco_2x_800size,
+ fcos_res50_coco_1x_800size,
+ fcos_res101_coco_2x_800size,
+ fcos_resx101_coco_2x_800size,
+ atss_res50_coco_1x_800size,
+ atss_res101_coco_2x_800size,
+ atss_resx101_coco_2x_800size,
)
-from official.vision.detection.models import FasterRCNN, RetinaNet
+from official.vision.detection.models import FasterRCNN, RetinaNet, FCOS, ATSS
from official.vision.detection.tools.utils import DetEvaluator
from official.vision.keypoints.inference import KeypointEvaluator
from official.vision.keypoints.models import (
@@ -46,7 +50,8 @@
simplebaseline_res101,
simplebaseline_res152,
)
-from official.vision.segmentation.deeplabv3plus import (
- DeepLabV3Plus,
- deeplabv3plus_res101,
+from official.vision.segmentation.configs import (
+ deeplabv3plus_res101_cityscapes_768size,
+ deeplabv3plus_res101_voc_512size,
)
+from official.vision.segmentation.models import DeepLabV3Plus
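The entrypoints newly imported above can also be used directly in Python, since `hubconf.py` simply re-exports the config functions. A minimal sketch, assuming the detection entrypoints construct their config internally and accept `pretrained=True` through the `@hub.pretrained` decorator:

```python
# Hypothetical sketch for the ATSS/FCOS entrypoints registered above.
from official.vision.detection.configs import atss_res50_coco_1x_800size
from official.vision.detection.tools.utils import DetEvaluator

model = atss_res50_coco_1x_800size(pretrained=True)  # weights URL comes from @hub.pretrained
model.eval()
evaluator = DetEvaluator(model)  # same helper used by tools/test.py and tools/inference.py
```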
diff --git a/official/assets/cat_seg_out.jpg b/official/assets/cat_seg_out.jpg
index be5a2c7e..0be80c88 100644
Binary files a/official/assets/cat_seg_out.jpg and b/official/assets/cat_seg_out.jpg differ
diff --git a/official/vision/detection/README.md b/official/vision/detection/README.md
index 99670ba4..5920f6e2 100644
--- a/official/vision/detection/README.md
+++ b/official/vision/detection/README.md
@@ -2,7 +2,12 @@
## Introduction
-This directory contains classic network architectures implemented with MegEngine, such as [RetinaNet](https://arxiv.org/pdf/1708.02002) and [Faster R-CNN](https://arxiv.org/pdf/1612.03144.pdf), together with complete training and testing code on the COCO2017 dataset.
+This directory contains the following classic network architectures implemented with MegEngine, together with complete training and testing code on the COCO2017 dataset:
+
+- [RetinaNet](https://arxiv.org/abs/1708.02002)
+- [Faster R-CNN](https://arxiv.org/abs/1612.03144)
+- [FCOS](https://arxiv.org/abs/1904.01355)
+- [ATSS](https://arxiv.org/abs/1912.02424)
The performance and results of the networks on the COCO2017 validation set are as follows:
@@ -13,10 +18,10 @@
| retinanet-resx101-coco-2x-800size | 42.7 | 2 |
| faster-rcnn-res50-coco-1x-800size | 38.0 | 2 |
| faster-rcnn-res101-coco-2x-800size | 42.5 | 2 |
-| faster-rcnn-resx101-coco-2x-800size | 44.7 * | 2 |
+| faster-rcnn-resx101-coco-2x-800size | 43.6 | 2 |
| fcos-res50-coco-1x-800size | 39.7 | 2 |
| fcos-res101-coco-2x-800size | 44.1 | 2 |
-| fcos-resx101-coco-2x-800size | 39.7 * | 2 |
+| fcos-resx101-coco-2x-800size | 44.9 | 2 |
| atss-res50-coco-1x-800size | 40.1 | 2 |
| atss-res101-coco-2x-800size | 44.5 | 2 |
| atss-resx101-coco-2x-800size | 45.9 | 2 |
@@ -119,7 +124,7 @@ python3 tools/test.py -f configs/retinanet_res50_coco_1x_800size.py -n 8 \
## References
-- [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002) Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. Proceedings of the IEEE international conference on computer vision. 2017: 2980-2988.
-- [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/pdf/1506.01497.pdf) S. Ren, K. He, R. Girshick, and J. Sun. In: Neural Information Processing Systems(NIPS)(2015).
-- [Feature Pyramid Networks for Object Detection](https://arxiv.org/pdf/1612.03144.pdf) T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan and S. Belongie. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI, 2017, pp. 936-944, doi: 10.1109/CVPR.2017.106.
-- [Microsoft COCO: Common Objects in Context](https://arxiv.org/pdf/1405.0312.pdf) Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Dollár, Piotr and Zitnick, C Lawrence, Lin T Y, Maire M, Belongie S, et al. European conference on computer vision. Springer, Cham, 2014: 740-755.
+- [Microsoft COCO: Common Objects in Context](https://arxiv.org/abs/1405.0312) Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, C. Lawrence Zitnick. European conference on computer vision. Springer, Cham, 2014: 740-755.
+- [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. Proceedings of the IEEE international conference on computer vision. 2017: 2980-2988.
+- [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497) S. Ren, K. He, R. Girshick, and J. Sun. In: Neural Information Processing Systems(NIPS)(2015).
+- [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144) T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan and S. Belongie. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI, 2017, pp. 936-944, doi: 10.1109/CVPR.2017.106.
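For completeness, loading a locally trained checkpoint follows the pattern of `tools/inference.py` shown later in this diff. A sketch with placeholder paths; it assumes the detection tools expose the same `import_from_file` helper that the segmentation tools use:

```python
# Hypothetical sketch: build a model from its config file and load trained weights,
# mirroring official/vision/detection/tools/inference.py.
import megengine as mge
from official.vision.detection.tools.utils import import_from_file

current_network = import_from_file("configs/fcos_res50_coco_1x_800size.py")
cfg = current_network.Cfg()
cfg.backbone_pretrained = False   # weights come from the checkpoint, not ImageNet
model = current_network.Net(cfg)
model.eval()

state_dict = mge.load("/path/to/model_weights.pkl")
if "state_dict" in state_dict:    # training checkpoints wrap the raw weights
    state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
```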
diff --git a/official/vision/detection/configs/__init__.py b/official/vision/detection/configs/__init__.py
index 2abd428a..023fa844 100644
--- a/official/vision/detection/configs/__init__.py
+++ b/official/vision/detection/configs/__init__.py
@@ -1,11 +1,15 @@
from .faster_rcnn_res50_coco_1x_800size import faster_rcnn_res50_coco_1x_800size
-from .faster_rcnn_res50_coco_1x_800size_syncbn import faster_rcnn_res50_coco_1x_800size_syncbn
from .faster_rcnn_res101_coco_2x_800size import faster_rcnn_res101_coco_2x_800size
from .faster_rcnn_resx101_coco_2x_800size import faster_rcnn_resx101_coco_2x_800size
from .retinanet_res50_coco_1x_800size import retinanet_res50_coco_1x_800size
-from .retinanet_res50_coco_1x_800size_syncbn import retinanet_res50_coco_1x_800size_syncbn
from .retinanet_res101_coco_2x_800size import retinanet_res101_coco_2x_800size
from .retinanet_resx101_coco_2x_800size import retinanet_resx101_coco_2x_800size
+from .fcos_res50_coco_1x_800size import fcos_res50_coco_1x_800size
+from .fcos_res101_coco_2x_800size import fcos_res101_coco_2x_800size
+from .fcos_resx101_coco_2x_800size import fcos_resx101_coco_2x_800size
+from .atss_res50_coco_1x_800size import atss_res50_coco_1x_800size
+from .atss_res101_coco_2x_800size import atss_res101_coco_2x_800size
+from .atss_resx101_coco_2x_800size import atss_resx101_coco_2x_800size
_EXCLUDE = {}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
diff --git a/official/vision/detection/configs/faster_rcnn_res50_objects365_1x_800size.py b/official/vision/detection/configs/faster_rcnn_res50_objects365_1x_800size.py
deleted file mode 100644
index 32094117..00000000
--- a/official/vision/detection/configs/faster_rcnn_res50_objects365_1x_800size.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from official.vision.detection import models
-
-
-class CustomFasterRCNNConfig(models.FasterRCNNConfig):
- def __init__(self):
- super().__init__()
-
- # ------------------------ data cfg -------------------------- #
- self.train_dataset = dict(
- name="objects365",
- root="train",
- ann_file="annotations/objects365_train_20190423.json",
- remove_images_without_annotations=True,
- )
- self.test_dataset = dict(
- name="objects365",
- root="val",
- ann_file="annotations/objects365_val_20190423.json",
- remove_images_without_annotations=False,
- )
- self.num_classes = 365
-
- # ------------------------ training cfg ---------------------- #
- self.nr_images_epoch = 400000
-
-
-def faster_rcnn_res50_objects365_1x_800size(**kwargs):
- r"""
- Faster-RCNN FPN trained from Objects365 dataset.
-    `"Faster-RCNN" <https://arxiv.org/abs/1506.01497>`_
-    `"FPN" <https://arxiv.org/abs/1612.03144>`_
- """
- cfg = CustomFasterRCNNConfig()
- cfg.backbone_pretrained = False
- return models.FasterRCNN(cfg, **kwargs)
-
-
-Net = models.FasterRCNN
-Cfg = CustomFasterRCNNConfig
diff --git a/official/vision/detection/configs/faster_rcnn_resx101_coco_2x_800size.py b/official/vision/detection/configs/faster_rcnn_resx101_coco_2x_800size.py
index c358fa60..75a9b028 100644
--- a/official/vision/detection/configs/faster_rcnn_resx101_coco_2x_800size.py
+++ b/official/vision/detection/configs/faster_rcnn_resx101_coco_2x_800size.py
@@ -24,7 +24,7 @@ def __init__(self):
@hub.pretrained(
"https://data.megengine.org.cn/models/weights/"
- "faster_rcnn_resx101_coco_2x_800size_44dot7_d03b05b2.pkl"
+ "faster_rcnn_resx101_coco_2x_800size_43dot6_79fb71a7.pkl"
)
def faster_rcnn_resx101_coco_2x_800size(**kwargs):
r"""
diff --git a/official/vision/detection/configs/fcos_resx101_coco_2x_800size.py b/official/vision/detection/configs/fcos_resx101_coco_2x_800size.py
index ec6573da..66d56f1f 100644
--- a/official/vision/detection/configs/fcos_resx101_coco_2x_800size.py
+++ b/official/vision/detection/configs/fcos_resx101_coco_2x_800size.py
@@ -24,7 +24,7 @@ def __init__(self):
@hub.pretrained(
"https://data.megengine.org.cn/models/weights/"
- "fcos_resx101_coco_2x_800size_39dot7_313ef718.pkl"
+ "fcos_resx101_coco_2x_800size_44dot9_37e7b921.pkl"
)
def fcos_resx101_coco_2x_800size(**kwargs):
r"""
diff --git a/official/vision/detection/configs/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/configs/retinanet_res50_objects365_1x_800size.py
deleted file mode 100644
index 5233cdfc..00000000
--- a/official/vision/detection/configs/retinanet_res50_objects365_1x_800size.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from official.vision.detection import models
-
-
-class CustomRetinaNetConfig(models.RetinaNetConfig):
- def __init__(self):
- super().__init__()
-
- # ------------------------ data cfg -------------------------- #
- self.train_dataset = dict(
- name="objects365",
- root="train",
- ann_file="annotations/objects365_train_20190423.json",
- remove_images_without_annotations=True,
- )
- self.test_dataset = dict(
- name="objects365",
- root="val",
- ann_file="annotations/objects365_val_20190423.json",
- remove_images_without_annotations=False,
- )
- self.num_classes = 365
-
- # ------------------------ training cfg ---------------------- #
- self.nr_images_epoch = 400000
-
-
-def retinanet_res50_objects365_1x_800size(**kwargs):
- r"""
- RetinaNet trained from Objects365 dataset.
-    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
-    `"FPN" <https://arxiv.org/abs/1612.03144>`_
- """
- cfg = CustomRetinaNetConfig()
- cfg.backbone_pretrained = False
- return models.RetinaNet(cfg, **kwargs)
-
-
-Net = models.RetinaNet
-Cfg = CustomRetinaNetConfig
diff --git a/official/vision/detection/tools/inference.py b/official/vision/detection/tools/inference.py
index 3a23fbe2..0bab6690 100644
--- a/official/vision/detection/tools/inference.py
+++ b/official/vision/detection/tools/inference.py
@@ -40,6 +40,7 @@ def main():
cfg.backbone_pretrained = False
model = current_network.Net(cfg)
model.eval()
+
state_dict = mge.load(args.weight_file)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
diff --git a/official/vision/detection/tools/test.py b/official/vision/detection/tools/test.py
index d725cae3..4c53b1ac 100644
--- a/official/vision/detection/tools/test.py
+++ b/official/vision/detection/tools/test.py
@@ -105,13 +105,12 @@ def main():
result_list.append(result_queue.get())
for p in procs:
p.join()
-
else:
result_list = []
worker(
current_network, weight_file, args.dataset_dir,
- None, None, args.ngpus, 0, result_list
+ None, None, 1, 0, result_list
)
all_results = DetEvaluator.format(result_list, cfg)
diff --git a/official/vision/detection/tools/test_random.py b/official/vision/detection/tools/test_random.py
index 5f33047a..d105e036 100644
--- a/official/vision/detection/tools/test_random.py
+++ b/official/vision/detection/tools/test_random.py
@@ -66,10 +66,6 @@ def main():
args.end_epoch = args.start_epoch
assert 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch
- master_ip = "localhost"
- port = dist.get_free_ports(1)[0]
- dist.Server(port)
-
for epoch_num in range(args.start_epoch, args.end_epoch + 1):
if args.weight_file:
weight_file = args.weight_file
@@ -78,32 +74,44 @@ def main():
os.path.basename(args.file).split(".")[0], epoch_num
)
- result_list = []
- result_queue = Queue(2000)
- procs = []
- for i in range(args.ngpus):
- proc = Process(
- target=worker,
- args=(
- current_network,
- weight_file,
- args.dataset_dir,
- master_ip,
- port,
- args.ngpus,
- i,
- result_queue,
- ),
- )
- proc.start()
- procs.append(proc)
-
- num_imgs = dict(coco=5000, objects365=30000)
+ if args.ngpus > 1:
+ master_ip = "localhost"
+ port = dist.get_free_ports(1)[0]
+ dist.Server(port)
+
+ result_list = []
+ result_queue = Queue(2000)
+ procs = []
+ for i in range(args.ngpus):
+ proc = Process(
+ target=worker,
+ args=(
+ current_network,
+ weight_file,
+ args.dataset_dir,
+ master_ip,
+ port,
+ args.ngpus,
+ i,
+ result_queue,
+ ),
+ )
+ proc.start()
+ procs.append(proc)
+
+ num_imgs = dict(coco=5000, objects365=30000)
+
+ for _ in tqdm(range(num_imgs[cfg.test_dataset["name"]])):
+ result_list.append(result_queue.get())
+ for p in procs:
+ p.join()
+ else:
+ result_list = []
- for _ in tqdm(range(num_imgs[cfg.test_dataset["name"]])):
- result_list.append(result_queue.get())
- for p in procs:
- p.join()
+ worker(
+ current_network, weight_file, args.dataset_dir,
+ None, None, 1, 0, result_list
+ )
all_results = DetEvaluator.format(result_list, cfg)
json_path = "log-of-{}/epoch_{}.json".format(
@@ -146,15 +154,18 @@ def main():
def worker(
- current_network, weight_file, dataset_dir, master_ip, port, world_size, rank, result_queue
+ current_network, weight_file, dataset_dir,
+ master_ip, port, world_size, rank, result_list
):
- dist.init_process_group(
- master_ip=master_ip,
- port=port,
- world_size=world_size,
- rank=rank,
- device=rank,
- )
+ if world_size > 1:
+ dist.init_process_group(
+ master_ip=master_ip,
+ port=port,
+ world_size=world_size,
+ rank=rank,
+ device=rank,
+ )
+
mge.device.set_default_device("gpu{}".format(rank))
cfg = current_network.Cfg()
@@ -170,6 +181,9 @@ def worker(
evaluator = DetEvaluator(model)
test_loader = build_dataloader(rank, world_size, dataset_dir, model.cfg)
+ if world_size == 1:
+ test_loader = tqdm(test_loader)
+
for data in test_loader:
image, im_info = DetEvaluator.process_inputs(
data[0][0],
@@ -180,10 +194,14 @@ def worker(
image=mge.tensor(image),
im_info=mge.tensor(im_info)
)
- result_queue.put_nowait({
+ result = {
"det_res": pred_res,
"image_id": int(data[1][2][0].split(".")[0].split("_")[-1]),
- })
+ }
+ if world_size > 1:
+ result_list.put_nowait(result)
+ else:
+ result_list.append(result)
def build_dataloader(rank, world_size, dataset_dir, cfg):
diff --git a/official/vision/segmentation/README.md b/official/vision/segmentation/README.md
index cffaf6cd..c80bba0b 100644
--- a/official/vision/segmentation/README.md
+++ b/official/vision/segmentation/README.md
@@ -1,26 +1,54 @@
-# Semantic Segmentation
+# Megengine Semantic Segmentation Models
-This directory contains the classic [Deeplabv3plus](https://arxiv.org/abs/1802.02611.pdf) network implemented with MegEngine, together with complete training and testing code on the PASCAL VOC and Cityscapes datasets.
+## Introduction
-The performance and results of the network on the PASCAL VOC2012 validation set are as follows:
+This directory contains the classic [Deeplabv3plus](https://arxiv.org/abs/1802.02611.pdf) network implemented with MegEngine, together with complete training and testing code on the Pascal VOC2012 and Cityscapes datasets.
- Methods | Backbone | TrainSet | EvalSet | mIoU_single | mIoU_multi |
- :--: |:--: |:--: |:--: |:--: |:--: |
- Deeplabv3plus | Resnet101 | train_aug | val | 79.0 | 79.8 |
+The performance and results of the network on the Pascal VOC2012 validation set are as follows:
+
+| Model | mIoU |
+| --- | :--: |
+| deeplabv3plus-res101-voc-512size | 79.5 |
+
+The performance and results of the network on the Cityscapes validation set are as follows:
+
+| Model | mIoU |
+| --- | :--: |
+| deeplabv3plus-res101-cityscapes-768size | 78.5 |
## Installation and Environment Setup
-Before running the code in this directory, please make sure the environment is configured correctly following the [README](../../../../README.md).
+The code in this directory is based on MegEngine v1.0. Before running it, please make sure the environment is configured correctly following the [README](../../../README.md).
+
+## How to Use
+
+Take DeepLabV3+ as an example. After the model is trained, you can test a single image with the following command:
+
+```bash
+python3 tools/inference.py -f configs/deeplabv3plus_res101_voc_512size.py \
+ -w /path/to/model_weights.pkl \
+ -i ../../assets/cat.jpg
+```
+
+The command-line options of `tools/inference.py` are:
+- `-f`, the network description file to test.
+- `-w`, the model weights to test.
+- `-i`, the sample image to test.
+
+The result of testing with the default image and the default model is shown below:
+
+
## How to Train
-1. Before training, please download the [official VOC2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#data) and extract it to an appropriate directory. To reproduce the same training setup, you also need to download [SegmentationClassAug](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0&file_subpath=%2FSegmentationClassAug). See this [guide](https://www.sun11.me/blog/2018/how-to-use-10582-trainaug-images-on-DeeplabV3-code/) for details.
+Take training DeepLabV3+ on the Pascal VOC2012 dataset as an example.
+
+1. Before training, please download the [Pascal VOC2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#data) and extract it to an appropriate directory. To reproduce the same training setup, you also need to download [SegmentationClassAug](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0&file_subpath=%2FSegmentationClassAug). See this [guide](https://www.sun11.me/blog/2018/how-to-use-10582-trainaug-images-on-DeeplabV3-code/) for details.
The prepared VOC data directory structure is as follows:
-```bash
+```
/path/to/
|->VOC2012
| |Annotations
@@ -29,67 +57,48 @@
| |SegmentationClass
| |SegmentationClass_aug
```
+
Here, ImageSets/Segmentation contains [trainaug.txt](https://gist.githubusercontent.com/sun11/2dbda6b31acc7c6292d14a872d0c90b7/raw/5f5a5270089239ef2f6b65b1cc55208355b5acca/trainaug.txt).
Note: the data formats in SegmentationClass_aug and SegmentationClass are different.
-2. Prepare the pretrained backbone weights. You can directly download the ResNet-101 model pretrained on ImageNet provided by MegEngine.
+2. Prepare the pretrained `backbone` weights: you can use megengine.hub to download the ImageNet-pretrained models provided by `megengine`, and save the weights to `/path/to/pretrain.pkl`.
-3. Start training:
-
-The command-line arguments of `train.py` are:
-- `--config`, the configuration file used for training; one default each for VOC and Cityscapes;
-- `--dataset_dir`, the directory containing the training set;
-- `--weight_file`, the pretrained weights used for training;
-- `--ngpus`, the number of GPUs used for training, default 8; set to 1 for single-GPU training
-- `--resume`, whether to resume from a trained model, default `None`;
-
-```bash
-python3 train.py --config cfg_voc.py \
- --dataset_dir /path/to/VOC2012 \
- --weight_file /path/to/weights.pkl \
- --ngpus 8
-```
+3. Start training:
-Or train on the Cityscapes dataset:
```bash
-python3 train.py --config cfg_cityscapes.py \
- --dataset_dir /path/to/Cityscapes \
- --weight_file /path/to/weights.pkl \
- --ngpus 8
+python3 tools/train.py -f configs/deeplabv3plus_res101_voc_512size.py -n 8 \
+ -d /path/to/VOC2012
```
-## How to Test
+The command-line options of `tools/train.py` are:
-After the model is trained, you can evaluate its performance on the VOC2012 validation set with the following command:
+- `-f`, the network description file to train.
+- `-n`, the number of devices (GPUs) used for training.
+- `-w`, the pretrained backbone weights.
+- `-d`, the parent directory of the dataset, default `/data/datasets`.
+- `-r`, whether to resume from a trained model, default `None`.
-```bash
-python3 test.py --config cfg_voc.py \
- --dataset_dir /path/to/VOC2012 \
- --model_path /path/to/model.pkl
-```
+By default, models are saved in the `log-of-<model_name>` directory.
-The command-line arguments of `test.py` are:
-- `--config`, the configuration file used for training; one default each for VOC and Cityscapes;
-- `--dataset_dir`, the directory of the validation set;
-- `--model_path`, the trained model to load;
+## How to Test
-## How to Use
+Take testing DeepLabV3+ on the Pascal VOC2012 dataset as an example.
-After the model is trained, you can test a single image and obtain the segmentation result with the following command:
+After training and saving the model, you can evaluate its performance on the validation set with tools/test.py.
```bash
-python3 inference.py --model_path /path/to/model \
- --image_path /path/to/image.jpg
+python3 tools/test.py -f configs/deeplabv3plus_res101_voc_512size.py -n 8 \
+ -w /path/to/model_weights.pkl \
+ -d /path/to/VOC2012
```
-The command-line arguments of `inference.py` are:
-- `--model_path`, the trained model to load;
-- `--image_path`, the image to test
+The command-line options of `tools/test.py` are:
-
-


-
+- `-f`, the network description file to test.
+- `-n`, the number of devices (GPUs) used for testing.
+- `-w`, the model weights to test.
+- `-d`, the parent directory of the dataset, default `/data/datasets`.
## References
diff --git a/official/vision/segmentation/cfg_cityscapes.py b/official/vision/segmentation/cfg_cityscapes.py
deleted file mode 100644
index 7b54c434..00000000
--- a/official/vision/segmentation/cfg_cityscapes.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-import os
-
-
-class Config:
- DATASET = "Cityscapes"
-
- BATCH_SIZE = 4
- LEARNING_RATE = 0.0065
- EPOCHS = 200
-
- ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__")))
- MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
- LOG_DIR = MODEL_SAVE_DIR
- if not os.path.isdir(MODEL_SAVE_DIR):
- os.makedirs(MODEL_SAVE_DIR)
-
- DATA_WORKERS = 4
-
- IGNORE_INDEX = 255
- NUM_CLASSES = 19
- IMG_HEIGHT = 800
- IMG_WIDTH = 800
- IMG_MEAN = [103.530, 116.280, 123.675]
- IMG_STD = [57.375, 57.120, 58.395]
-
- VAL_HEIGHT = 800
- VAL_WIDTH = 800
- VAL_BATCHES = 1
- VAL_MULTISCALE = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
- VAL_FLIP = False
- VAL_SLIP = True
- VAL_SAVE = None
-
-
-cfg = Config()
diff --git a/official/vision/segmentation/cfg_voc.py b/official/vision/segmentation/cfg_voc.py
deleted file mode 100644
index 0c010da1..00000000
--- a/official/vision/segmentation/cfg_voc.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-import os
-
-
-class Config:
- DATASET = "VOC2012"
-
- BATCH_SIZE = 8
- LEARNING_RATE = 0.002
- EPOCHS = 100
-
- ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__")))
- MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
- LOG_DIR = MODEL_SAVE_DIR
- if not os.path.isdir(MODEL_SAVE_DIR):
- os.makedirs(MODEL_SAVE_DIR)
-
- DATA_WORKERS = 4
- DATA_TYPE = "trainaug"
-
- IGNORE_INDEX = 255
- NUM_CLASSES = 21
- IMG_HEIGHT = 512
- IMG_WIDTH = 512
- IMG_MEAN = [103.530, 116.280, 123.675]
- IMG_STD = [57.375, 57.120, 58.395]
-
- VAL_HEIGHT = 512
- VAL_WIDTH = 512
- VAL_BATCHES = 1
- VAL_MULTISCALE = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
- VAL_FLIP = False
- VAL_SLIP = False
- VAL_SAVE = None
-
-
-cfg = Config()
diff --git a/official/vision/segmentation/configs/__init__.py b/official/vision/segmentation/configs/__init__.py
new file mode 100644
index 00000000..3298f637
--- /dev/null
+++ b/official/vision/segmentation/configs/__init__.py
@@ -0,0 +1,5 @@
+from .deeplabv3plus_res101_cityscapes_768size import deeplabv3plus_res101_cityscapes_768size
+from .deeplabv3plus_res101_voc_512size import deeplabv3plus_res101_voc_512size
+
+_EXCLUDE = {}
+__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
diff --git a/official/vision/segmentation/configs/deeplabv3plus_res101_cityscapes_768size.py b/official/vision/segmentation/configs/deeplabv3plus_res101_cityscapes_768size.py
new file mode 100644
index 00000000..8f559060
--- /dev/null
+++ b/official/vision/segmentation/configs/deeplabv3plus_res101_cityscapes_768size.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+from megengine import hub
+
+from official.vision.segmentation import models
+
+
+class CityscapesConfig:
+ def __init__(self):
+ self.dataset = "Cityscapes"
+
+ self.backbone = "resnet101"
+ self.backbone_pretrained = True
+
+ self.batch_size = 4
+ self.learning_rate = 0.01
+ self.momentum = 0.9
+ self.weight_decay = 0.0001
+ self.max_epoch = 40
+ self.nr_images_epoch = 32000
+
+ self.ignore_label = 255
+ self.num_classes = 19
+ self.img_height = 768
+ self.img_width = 768
+ self.img_mean = [103.530, 116.280, 123.675] # BGR
+ self.img_std = [57.375, 57.120, 58.395]
+
+ self.val_height = 1024
+ self.val_width = 2048
+ self.val_multiscale = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+ self.val_flip = False
+ self.val_slip = False
+ self.val_save_path = None
+
+ self.log_interval = 20
+
+
+@hub.pretrained(
+ "https://data.megengine.org.cn/models/weights/"
+ "deeplabv3plus_res101_cityscapes_768size_78dot5_c45e0cb9.pkl"
+)
+def deeplabv3plus_res101_cityscapes_768size(**kwargs):
+ r"""DeepLab v3+ model from
+    `"Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation" <https://arxiv.org/abs/1802.02611>`_
+ """
+    cfg = CityscapesConfig()
+    return models.DeepLabV3Plus(cfg, **kwargs)
+
+
+Net = models.DeepLabV3Plus
+Cfg = CityscapesConfig
diff --git a/official/vision/segmentation/configs/deeplabv3plus_res101_voc_512size.py b/official/vision/segmentation/configs/deeplabv3plus_res101_voc_512size.py
new file mode 100644
index 00000000..f54c9136
--- /dev/null
+++ b/official/vision/segmentation/configs/deeplabv3plus_res101_voc_512size.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+from megengine import hub
+
+from official.vision.segmentation import models
+
+
+class VOCConfig:
+ def __init__(self):
+ self.dataset = "VOC2012"
+ self.data_type = "trainaug"
+
+ self.backbone = "resnet101"
+ self.backbone_pretrained = True
+
+ self.batch_size = 8
+ self.learning_rate = 0.02
+ self.momentum = 0.9
+ self.weight_decay = 0.0001
+ self.max_epoch = 40
+ self.nr_images_epoch = 64000
+
+ self.ignore_label = 255
+ self.num_classes = 21
+ self.img_height = 512
+ self.img_width = 512
+ self.img_mean = [103.530, 116.280, 123.675] # BGR
+ self.img_std = [57.375, 57.120, 58.395]
+
+ self.val_height = 512
+ self.val_width = 512
+ self.val_multiscale = [1.0] # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+ self.val_flip = False
+ self.val_slip = False
+ self.val_save_path = None
+
+ self.log_interval = 20
+
+
+@hub.pretrained(
+ "https://data.megengine.org.cn/models/weights/"
+ "deeplabv3plus_res101_voc_512size_79dot5_7856dc84.pkl"
+)
+def deeplabv3plus_res101_voc_512size(**kwargs):
+ r"""DeepLab v3+ model from
+    `"Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation" <https://arxiv.org/abs/1802.02611>`_
+ """
+    cfg = VOCConfig()
+    return models.DeepLabV3Plus(cfg, **kwargs)
+
+
+Net = models.DeepLabV3Plus
+Cfg = VOCConfig
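`Net` and `Cfg` at the bottom of this file are what `tools/train.py`, `tools/test.py`, and `tools/inference.py` look up after loading the config as a network description file. A minimal sketch of doing the same by hand, using the repo's own `import_from_file` helper:

```python
# Hypothetical sketch: instantiate DeepLabV3+ from this config the way the tools do.
from official.vision.segmentation.tools.utils import import_from_file

net_module = import_from_file("configs/deeplabv3plus_res101_voc_512size.py")
cfg = net_module.Cfg()            # VOCConfig defined above
cfg.backbone_pretrained = False   # avoid the ImageNet download for a quick smoke test
model = net_module.Net(cfg)       # models.DeepLabV3Plus
model.eval()
```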
diff --git a/official/vision/segmentation/inference.py b/official/vision/segmentation/inference.py
deleted file mode 100644
index 1e1c9830..00000000
--- a/official/vision/segmentation/inference.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import argparse
-
-import cv2
-import megengine as mge
-import megengine.data.dataset as dataset
-import megengine.jit as jit
-import numpy as np
-
-from megengine.utils.http_download import download_from_url
-from official.vision.segmentation.deeplabv3plus import DeepLabV3Plus
-
-
-class Config:
- NUM_CLASSES = 21
- IMG_SIZE = 512
- IMG_MEAN = [103.530, 116.280, 123.675]
- IMG_STD = [57.375, 57.120, 58.395]
-
-
-cfg = Config()
-
-# pre-defined colors for at most 20 categories
-class_colors = [
- [0, 0, 0], # background
- [0, 0, 128],
- [0, 128, 0],
- [0, 128, 128],
- [128, 0, 0],
- [128, 0, 128],
- [128, 128, 0],
- [128, 128, 128],
- [0, 0, 64],
- [0, 0, 192],
- [0, 128, 64],
- [0, 128, 192],
- [128, 0, 64],
- [128, 0, 192],
- [128, 128, 64],
- [128, 128, 192],
- [0, 64, 0],
- [0, 64, 128],
- [0, 192, 0],
- [0, 192, 128],
- [128, 64, 0],
-]
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("-i", "--image_path", type=str, default=None, help="inference image")
- parser.add_argument("-m", "--model_path", type=str, default=None, help="inference model")
- args = parser.parse_args()
-
- net = load_model(args.model_path)
- if args.image_path is None:
- download_from_url("https://data.megengine.org.cn/images/cat.jpg", "test.jpg")
- img = cv2.imread("test.jpg")
- else:
- img = cv2.imread(args.image_path)
- pred = inference(img, net)
- cv2.imwrite("out.jpg", pred)
-
-def load_model(model_path):
- model_dict = mge.load(model_path)
- net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES)
- net.load_state_dict(model_dict["state_dict"])
- print("load model %s" % (model_path))
- net.eval()
- return net
-
-
-def inference(img, net):
- @jit.trace(symbolic=True, opt_level=2)
- def pred_fun(data, net=None):
- net.eval()
- pred = net(data)
- return pred
-
- img = (img.astype("float32") - np.array(cfg.IMG_MEAN)) / np.array(cfg.IMG_STD)
- orih, oriw = img.shape[:2]
- img = cv2.resize(img, (cfg.IMG_SIZE, cfg.IMG_SIZE))
- img = img.transpose(2, 0, 1)[np.newaxis]
-
- data = mge.tensor()
- data.set_value(img)
- pred = pred_fun(data, net=net)
- pred = pred.numpy().squeeze().argmax(0)
- pred = cv2.resize(
- pred.astype("uint8"), (oriw, orih), interpolation=cv2.INTER_NEAREST
- )
-
- out = np.zeros((orih, oriw, 3))
- nids = np.unique(pred)
- for t in nids:
- out[pred == t] = class_colors[t]
- return out
-
-
-if __name__ == "__main__":
- main()
diff --git a/official/vision/segmentation/models/__init__.py b/official/vision/segmentation/models/__init__.py
new file mode 100644
index 00000000..a834581f
--- /dev/null
+++ b/official/vision/segmentation/models/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .deeplabv3plus import *
+
+_EXCLUDE = {}
+__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
diff --git a/official/vision/segmentation/deeplabv3plus.py b/official/vision/segmentation/models/deeplabv3plus.py
similarity index 51%
rename from official/vision/segmentation/deeplabv3plus.py
rename to official/vision/segmentation/models/deeplabv3plus.py
index 02fd940a..3df5c983 100644
--- a/official/vision/segmentation/deeplabv3plus.py
+++ b/official/vision/segmentation/models/deeplabv3plus.py
@@ -6,48 +6,10 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import megengine as mge
import megengine.functional as F
-import megengine.hub as hub
import megengine.module as M
-from official.vision.classification.resnet.model import Bottleneck, ResNet
-
-
-class ModifiedResNet(ResNet):
- def _make_layer(
- self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
- ):
- if dilate:
- self.dilation *= stride
- stride = 1
-
- layers = []
- layers.append(
- block(
- self.in_channels,
- channels,
- stride,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm=norm,
- )
- )
- self.in_channels = channels * block.expansion
- for _ in range(1, blocks):
- layers.append(
- block(
- self.in_channels,
- channels,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm=norm,
- )
- )
-
- return M.Sequential(*layers)
+import official.vision.classification.resnet.model as resnet
class ASPP(M.Module):
@@ -56,7 +18,7 @@ def __init__(self, in_channels, out_channels, dr=1):
self.conv1 = M.Sequential(
M.Conv2d(
- in_channels, out_channels, 1, 1, padding=0, dilation=dr, bias=True
+ in_channels, out_channels, 1, 1, padding=0, dilation=dr, bias=False
),
M.BatchNorm2d(out_channels),
M.ReLU(),
@@ -69,7 +31,7 @@ def __init__(self, in_channels, out_channels, dr=1):
1,
padding=6 * dr,
dilation=6 * dr,
- bias=True,
+ bias=False,
),
M.BatchNorm2d(out_channels),
M.ReLU(),
@@ -82,7 +44,7 @@ def __init__(self, in_channels, out_channels, dr=1):
1,
padding=12 * dr,
dilation=12 * dr,
- bias=True,
+ bias=False,
),
M.BatchNorm2d(out_channels),
M.ReLU(),
@@ -95,18 +57,18 @@ def __init__(self, in_channels, out_channels, dr=1):
1,
padding=18 * dr,
dilation=18 * dr,
- bias=True,
+ bias=False,
),
M.BatchNorm2d(out_channels),
M.ReLU(),
)
- self.convgp = M.Sequential(
- M.Conv2d(in_channels, out_channels, 1, 1, 0, bias=True),
+ self.conv_gp = M.Sequential(
+ M.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
M.BatchNorm2d(out_channels),
M.ReLU(),
)
- self.convout = M.Sequential(
- M.Conv2d(out_channels * 5, out_channels, 1, 1, padding=0, bias=True),
+ self.conv_out = M.Sequential(
+ M.Conv2d(out_channels * 5, out_channels, 1, 1, padding=0, bias=False),
M.BatchNorm2d(out_channels),
M.ReLU(),
)
@@ -117,23 +79,23 @@ def forward(self, x):
conv32 = self.conv3(x)
conv33 = self.conv4(x)
- gp = F.mean(x, 2, True)
- gp = F.mean(gp, 3, True)
- gp = self.convgp(gp)
- gp = F.interpolate(gp, (x.shapeof(2), x.shapeof(3)))
+ gp = F.mean(x, [2, 3], True)
+ gp = self.conv_gp(gp)
+ gp = F.nn.interpolate(gp, (x.shape[2], x.shape[3]))
out = F.concat([conv1, conv31, conv32, conv33, gp], axis=1)
- out = self.convout(out)
+ out = self.conv_out(out)
return out
class DeepLabV3Plus(M.Module):
- def __init__(self, class_num=21, pretrained=None):
+ def __init__(self, cfg):
super().__init__()
+ self.cfg = cfg
self.output_stride = 16
self.sub_output_stride = self.output_stride // 4
- self.class_num = class_num
+ self.num_classes = cfg.num_classes
self.aspp = ASPP(
in_channels=2048, out_channels=256, dr=16 // self.output_stride
@@ -141,22 +103,22 @@ def __init__(self, class_num=21, pretrained=None):
self.dropout = M.Dropout(0.5)
self.upstage1 = M.Sequential(
- M.Conv2d(256, 48, 1, 1, padding=1 // 2, bias=True),
+ M.Conv2d(256, 48, 1, 1, padding=1 // 2, bias=False),
M.BatchNorm2d(48),
M.ReLU(),
)
self.upstage2 = M.Sequential(
- M.Conv2d(256 + 48, 256, 3, 1, padding=1, bias=True),
+ M.Conv2d(256 + 48, 256, 3, 1, padding=1, bias=False),
M.BatchNorm2d(256),
M.ReLU(),
M.Dropout(0.5),
- M.Conv2d(256, 256, 3, 1, padding=1, bias=True),
+ M.Conv2d(256, 256, 3, 1, padding=1, bias=False),
M.BatchNorm2d(256),
M.ReLU(),
M.Dropout(0.1),
)
- self.convout = M.Conv2d(256, self.class_num, 1, 1, padding=0)
+ self.conv_out = M.Conv2d(256, self.num_classes, 1, 1, padding=0)
for m in self.modules():
if isinstance(m, M.Conv2d):
@@ -165,49 +127,24 @@ def __init__(self, class_num=21, pretrained=None):
M.init.ones_(m.weight)
M.init.zeros_(m.bias)
- self.backbone = ModifiedResNet(
- Bottleneck, [3, 4, 23, 3], replace_stride_with_dilation=[False, False, True]
+ self.backbone = getattr(resnet, cfg.backbone)(
+ replace_stride_with_dilation=[False, False, True],
+ pretrained=cfg.backbone_pretrained,
)
- if pretrained is not None:
- model_dict = mge.load(pretrained)
- self.backbone.load_state_dict(model_dict)
+ del self.backbone.fc
def forward(self, x):
layers = self.backbone.extract_features(x)
up0 = self.aspp(layers["res5"])
up0 = self.dropout(up0)
- up0 = F.interpolate(up0, scale_factor=self.sub_output_stride)
+ up0 = F.nn.interpolate(up0, scale_factor=self.sub_output_stride)
up1 = self.upstage1(layers["res2"])
up1 = F.concat([up0, up1], 1)
up2 = self.upstage2(up1)
- out = self.convout(up2)
- out = F.interpolate(out, scale_factor=4)
+ out = self.conv_out(up2)
+ out = F.nn.interpolate(out, scale_factor=4)
return out
-
-
-def softmax_cross_entropy(pred, label, axis=1, ignore_index=255):
- offset = F.zero_grad(pred.max(axis=axis, keepdims=True))
- pred = pred - offset
- log_prob = pred - F.log(F.exp(pred).sum(axis=axis, keepdims=True))
-
- mask = 1 - F.equal(label, ignore_index)
- vlabel = label * mask
- loss = -(F.indexing_one_hot(log_prob, vlabel, axis) * mask).sum() / F.maximum(
- mask.sum(), 1
- )
- return loss
-
-
-@hub.pretrained(
- "https://data.megengine.org.cn/models/weights/"
- "sematicseg_0f8e02aa_deeplabv3plus.pkl"
-)
-def deeplabv3plus_res101(**kwargs):
- r"""DeepLab v3+ model from
-    `"Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation" <https://arxiv.org/abs/1802.02611>`_
- """
- return DeepLabV3Plus(**kwargs)
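As a sanity check on the refactored module: `res5` at 1/16 input resolution passes through ASPP, is upsampled by 4, fused with `res2`, and upsampled by 4 again, so the output keeps the input resolution with `num_classes` channels. A small smoke-test sketch under that assumption (random input, no pretrained weights):

```python
# Hypothetical smoke test for the refactored DeepLabV3Plus forward pass.
import numpy as np
import megengine as mge
from official.vision.segmentation.configs.deeplabv3plus_res101_voc_512size import VOCConfig
from official.vision.segmentation.models import DeepLabV3Plus

cfg = VOCConfig()
cfg.backbone_pretrained = False        # keep the test offline
model = DeepLabV3Plus(cfg)
model.eval()

x = mge.tensor(np.random.randn(1, 3, 512, 512).astype("float32"))
out = model(x)
print(out.shape)                       # expected: (1, 21, 512, 512)
```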
diff --git a/official/vision/segmentation/test.py b/official/vision/segmentation/test.py
deleted file mode 100644
index 0fa8bed0..00000000
--- a/official/vision/segmentation/test.py
+++ /dev/null
@@ -1,263 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import argparse
-import multiprocessing as mp
-import os
-
-import cv2
-import megengine as mge
-import megengine.data as data
-import megengine.data.dataset as dataset
-import megengine.data.transform as T
-import megengine.jit as jit
-import numpy as np
-from tqdm import tqdm
-
-from official.vision.segmentation.deeplabv3plus import DeepLabV3Plus
-from official.vision.segmentation.utils import import_config_from_file
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-c", "--config", type=str, required=True, help="configuration file"
- )
- parser.add_argument(
- "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012",
- )
- parser.add_argument(
- "-m", "--model_path", type=str, default=None, help="eval model file"
- )
- args = parser.parse_args()
-
- cfg = import_config_from_file(args.config)
-
- test_loader, test_size = build_dataloader(args.dataset_dir, cfg)
- print("number of test images: %d" % (test_size))
- net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES)
- model_dict = mge.load(args.model_path)
-
- net.load_state_dict(model_dict["state_dict"])
- print("load model %s" % (args.model_path))
- net.eval()
-
- result_list = []
- for sample_batched in tqdm(test_loader):
- img = sample_batched[0].squeeze()
- label = sample_batched[1].squeeze()
- im_info = sample_batched[2]
- pred = evaluate(net, img, cfg)
- result_list.append({"pred": pred, "gt": label, "name":im_info[2]})
- if cfg.VAL_SAVE:
- save_results(result_list, cfg.VAL_SAVE, cfg)
- compute_metric(result_list, cfg)
-
-
-## inference one image
-def pad_image_to_shape(img, shape, border_mode, value):
- margin = np.zeros(4, np.uint32)
- pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
- pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0
- margin[0] = pad_height // 2
- margin[1] = pad_height // 2 + pad_height % 2
- margin[2] = pad_width // 2
- margin[3] = pad_width // 2 + pad_width % 2
- img = cv2.copyMakeBorder(
- img, margin[0], margin[1], margin[2], margin[3], border_mode, value=value
- )
- return img, margin
-
-
-def eval_single(net, img, is_flip):
- @jit.trace(symbolic=True, opt_level=2)
- def pred_fun(data, net=None):
- net.eval()
- pred = net(data)
- return pred
-
- data = mge.tensor()
- data.set_value(img.transpose(2, 0, 1)[np.newaxis])
- pred = pred_fun(data, net=net)
- if is_flip:
- img_flip = img[:, ::-1, :]
- data.set_value(img_flip.transpose(2, 0, 1)[np.newaxis])
- pred_flip = pred_fun(data, net=net)
- pred = (pred + pred_flip[:, :, :, ::-1]) / 2.0
- del pred_flip
- pred = pred.numpy().squeeze().transpose(1, 2, 0)
- del data
- return pred
-
-
-def evaluate(net, img, cfg):
- ori_h, ori_w, _ = img.shape
- pred_all = np.zeros((ori_h, ori_w, cfg.NUM_CLASSES))
- for rate in cfg.VAL_MULTISCALE:
- if cfg.VAL_SLIP:
- new_h, new_w = int(ori_h*rate), int(ori_w*rate)
- val_size = (cfg.VAL_HEIGHT, cfg.VAL_WIDTH)
- else:
- new_h, new_w = int(cfg.VAL_HEIGHT*rate), int(cfg.VAL_WIDTH*rate)
- val_size = (new_h, new_w)
- img_scale = cv2.resize(
- img, (new_w, new_h), interpolation=cv2.INTER_LINEAR
- )
-
- if (new_h <= val_size[0]) and (new_h <= val_size[1]):
- img_pad, margin = pad_image_to_shape(
- img_scale, val_size, cv2.BORDER_CONSTANT, value=0
- )
- pred = eval_single(net, img_pad, cfg.VAL_FLIP)
- pred = pred[
- margin[0] : (pred.shape[0] - margin[1]),
- margin[2] : (pred.shape[1] - margin[3]),
- :,
- ]
- else:
- stride_rate = 2 / 3
- stride = [int(np.ceil(i * stride_rate)) for i in val_size]
- img_pad, margin = pad_image_to_shape(
- img_scale, val_size, cv2.BORDER_CONSTANT, value=0
- )
- pad_h, pad_w = img_pad.shape[:2]
- r_grid, c_grid = [
- int(np.ceil((ps - cs) / stride)) + 1
- for ps, cs, stride in zip(img_pad.shape, val_size, stride)
- ]
-
- pred_scale = np.zeros((pad_h, pad_w, cfg.NUM_CLASSES))
- count_scale = np.zeros((pad_h, pad_w, cfg.NUM_CLASSES))
- for grid_yidx in range(r_grid):
- for grid_xidx in range(c_grid):
- s_x = grid_xidx * stride[1]
- s_y = grid_yidx * stride[0]
- e_x = min(s_x + val_size[1], pad_w)
- e_y = min(s_y + val_size[0], pad_h)
- s_x = e_x - val_size[1]
- s_y = e_y - val_size[0]
- img_sub = img_pad[s_y:e_y, s_x:e_x, :]
- tpred = eval_single(net, img_sub, cfg.VAL_FLIP)
- count_scale[s_y:e_y, s_x:e_x, :] += 1
- pred_scale[s_y:e_y, s_x:e_x, :] += tpred
- #pred_scale = pred_scale / count_scale
- pred = pred_scale[
- margin[0] : (pred_scale.shape[0] - margin[1]),
- margin[2] : (pred_scale.shape[1] - margin[3]),
- :,
- ]
-
- pred = cv2.resize(pred, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
- pred_all = pred_all + pred
-
- #pred_all = pred_all / len(cfg.VAL_MULTISCALE)
- result = np.argmax(pred_all, axis=2).astype(np.uint8)
- return result
-
-
-def save_results(result_list, save_dir, cfg):
- if not os.path.exists(save_dir):
- os.makedirs(save_dir)
- for idx, sample in enumerate(result_list):
- if cfg.DATASET == "Cityscapes":
- name = sample["name"].split('/')[-1][:-4]
- else:
- name = sample["name"]
- file_path = os.path.join(save_dir, "%s.png"%name)
- cv2.imwrite(file_path, sample["pred"])
- file_path = os.path.join(save_dir, "%s.gt.png"%name)
- cv2.imwrite(file_path, sample["gt"])
-
-# voc cityscapes metric
-def compute_metric(result_list, cfg):
- class_num = cfg.NUM_CLASSES
- hist = np.zeros((class_num, class_num))
- correct = 0
- labeled = 0
- count = 0
- for idx in range(len(result_list)):
- pred = result_list[idx]['pred']
- gt = result_list[idx]['gt']
- assert(pred.shape == gt.shape)
-        k = (gt >= 0) & (gt < class_num)
-        labeled += np.sum(k)
-        correct += np.sum((pred[k] == gt[k]))
-        hist += np.bincount(
-            class_num * gt[k].astype(int) + pred[k].astype(int),
-            minlength=class_num ** 2
-        ).reshape(class_num, class_num)
-        count += 1
-
-    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
-    mean_IU = np.nanmean(iu)
-    mean_IU_no_back = np.nanmean(iu[1:])
-    freq = hist.sum(1) / hist.sum()
-    freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
- mean_pixel_acc = correct / labeled
-
- if cfg.DATASET == "VOC2012":
- class_names = ("background", ) + dataset.PascalVOC.class_names
- elif cfg.DATASET == "Cityscapes":
- class_names = dataset.Cityscapes.class_names
- else:
- raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
-
- n = iu.size
- lines = []
- for i in range(n):
- if class_names is None:
- cls = 'Class %d:' % (i+1)
- else:
- cls = '%d %s' % (i+1, class_names[i])
- lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100))
- lines.append('---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%' % ('mean_IU', mean_IU * 100,'mean_pixel_ACC',mean_pixel_acc*100))
- line = "\n".join(lines)
- print(line)
- return mean_IU
-
-
-class EvalPascalVOC(dataset.PascalVOC):
- def _trans_mask(self, mask):
- label = np.ones(mask.shape[:2]) * 255
- class_colors = self.class_colors.copy()
- class_colors.insert(0, [0,0,0])
- for i in range(len(class_colors)):
- b, g, r = class_colors[i]
- label[
- (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
- ] = i
- return label.astype(np.uint8)
-
-def build_dataloader(dataset_dir, cfg):
- if cfg.DATASET == "VOC2012":
- val_dataset = EvalPascalVOC(
- dataset_dir,
- "val",
- order=["image", "mask", "info"]
- )
- elif cfg.DATASET == "Cityscapes":
- val_dataset = dataset.Cityscapes(
- dataset_dir,
- "val",
- mode='gtFine',
- order=["image", "mask", "info"]
- )
- else:
- raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
-
- val_sampler = data.SequentialSampler(val_dataset, cfg.VAL_BATCHES)
- val_dataloader = data.DataLoader(
- val_dataset,
- sampler=val_sampler,
- transform=T.Normalize(
- mean=cfg.IMG_MEAN, std=cfg.IMG_STD, order=["image", "mask"]
- ),
- num_workers=cfg.DATA_WORKERS,
- )
- return val_dataloader, val_dataset.__len__()
-
-
-if __name__ == "__main__":
- main()
diff --git a/official/vision/segmentation/tools/inference.py b/official/vision/segmentation/tools/inference.py
new file mode 100644
index 00000000..08c0a2d8
--- /dev/null
+++ b/official/vision/segmentation/tools/inference.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import argparse
+
+import cv2
+import numpy as np
+
+import megengine as mge
+
+from official.vision.segmentation.tools.utils import class_colors, import_from_file
+
+logger = mge.get_logger(__name__)
+logger.setLevel("INFO")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-f", "--file", default="net.py", type=str, help="net description file"
+ )
+ parser.add_argument(
+ "-w", "--weight_file", default=None, type=str, help="weights file",
+ )
+ parser.add_argument("-i", "--image", type=str)
+ args = parser.parse_args()
+
+ current_network = import_from_file(args.file)
+ cfg = current_network.Cfg()
+ cfg.backbone_pretrained = False
+ model = current_network.Net(cfg)
+ model.eval()
+
+ state_dict = mge.load(args.weight_file)
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ model.load_state_dict(state_dict)
+
+ img = cv2.imread(args.image)
+ pred = inference(img, model)
+ cv2.imwrite("results.jpg", pred)
+
+
+def inference(img, model):
+ def pred_func(data):
+ pred = model(data)
+ return pred
+
+ img = (
+ img.astype("float32") - np.array(model.cfg.img_mean)
+ ) / np.array(model.cfg.img_std)
+ ori_h, ori_w = img.shape[:2]
+ img = cv2.resize(img, (model.cfg.val_height, model.cfg.val_width))
+ img = img.transpose(2, 0, 1)[np.newaxis]
+
+ pred = pred_func(mge.tensor(img))
+ pred = pred.numpy().squeeze().argmax(0)
+ pred = cv2.resize(
+ pred.astype("uint8"), (ori_w, ori_h), interpolation=cv2.INTER_NEAREST
+ )
+
+ out = np.zeros((ori_h, ori_w, 3))
+ nids = np.unique(pred)
+ for t in nids:
+ out[pred == t] = class_colors[t]
+ return out
+
+
+if __name__ == "__main__":
+ main()
diff --git a/official/vision/segmentation/tools/test.py b/official/vision/segmentation/tools/test.py
new file mode 100644
index 00000000..4c3aa245
--- /dev/null
+++ b/official/vision/segmentation/tools/test.py
@@ -0,0 +1,335 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import argparse
+import os
+from multiprocessing import Process, Queue
+from tqdm import tqdm
+
+import cv2
+import numpy as np
+
+import megengine as mge
+import megengine.distributed as dist
+from megengine.data import DataLoader, dataset
+from megengine.data import transform as T
+# from megengine.jit import trace
+
+from official.vision.segmentation.tools.utils import (
+ InferenceSampler,
+ class_colors,
+ import_from_file
+)
+
+logger = mge.get_logger(__name__)
+logger.setLevel("INFO")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-f", "--file", default="net.py", type=str, help="net description file"
+ )
+ parser.add_argument(
+ "-w", "--weight_file", default=None, type=str, help="weights file",
+ )
+ parser.add_argument(
+ "-n", "--ngpus", default=1, type=int, help="total number of gpus for testing",
+ )
+ parser.add_argument(
+ "-d", "--dataset_dir", type=str, default="/data/datasets",
+ )
+ args = parser.parse_args()
+
+ current_network = import_from_file(args.file)
+ cfg = current_network.Cfg()
+
+ if args.ngpus > 1:
+ master_ip = "localhost"
+ port = dist.get_free_ports(1)[0]
+ dist.Server(port)
+
+ result_list = []
+ result_queue = Queue(500)
+ procs = []
+ for i in range(args.ngpus):
+ proc = Process(
+ target=worker,
+ args=(
+ current_network,
+ args.weight_file,
+ args.dataset_dir,
+ master_ip,
+ port,
+ args.ngpus,
+ i,
+ result_queue,
+ ),
+ )
+ proc.start()
+ procs.append(proc)
+
+ num_imgs = dict(VOC2012=1449, Cityscapes=500)
+
+ for _ in tqdm(range(num_imgs[cfg.dataset])):
+ result_list.append(result_queue.get())
+ for p in procs:
+ p.join()
+ else:
+ result_list = []
+
+ worker(
+ current_network, args.weight_file, args.dataset_dir,
+ None, None, 1, 0, result_list
+ )
+
+
+ if cfg.val_save_path is not None:
+ save_results(result_list, cfg.val_save_path, cfg)
+ logger.info("Start evaluation!")
+ compute_metric(result_list, cfg)
+
+
+def worker(
+ current_network, weight_file, dataset_dir,
+ master_ip, port, world_size, rank, result_list
+):
+ if world_size > 1:
+ dist.init_process_group(
+ master_ip=master_ip,
+ port=port,
+ world_size=world_size,
+ rank=rank,
+ device=rank,
+ )
+
+ mge.device.set_default_device("gpu{}".format(rank))
+
+ cfg = current_network.Cfg()
+ cfg.backbone_pretrained = False
+ model = current_network.Net(cfg)
+ model.eval()
+
+ state_dict = mge.load(weight_file)
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ model.load_state_dict(state_dict)
+
+ # @trace(symbolic=True)
+ def pred_func(data):
+ pred = model(data)
+ return pred
+
+ test_loader = build_dataloader(rank, world_size, dataset_dir, model.cfg)
+ if world_size == 1:
+ test_loader = tqdm(test_loader)
+
+ for data in test_loader:
+ img = data[0].squeeze()
+ label = data[1].squeeze()
+ im_info = data[2]
+ pred = evaluate(pred_func, img, model.cfg)
+ result = {"pred": pred, "gt": label, "name": im_info[2]}
+ if world_size > 1:
+ result_list.put_nowait(result)
+ else:
+ result_list.append(result)
+
+
+## inference one image
+def pad_image_to_shape(img, shape, border_mode, value):
+ margin = np.zeros(4, np.uint32)
+ pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
+ pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0
+ margin[0] = pad_height // 2
+ margin[1] = pad_height // 2 + pad_height % 2
+ margin[2] = pad_width // 2
+ margin[3] = pad_width // 2 + pad_width % 2
+ img = cv2.copyMakeBorder(
+ img, margin[0], margin[1], margin[2], margin[3], border_mode, value=value
+ )
+ return img, margin
+
+
+def eval_single(pred_func, img, is_flip):
+ pred = pred_func(mge.tensor(img.transpose(2, 0, 1)[np.newaxis]))
+ if is_flip:
+ pred_flip = pred_func(mge.tensor(img[:, ::-1].transpose(2, 0, 1)[np.newaxis]))
+ pred = (pred + pred_flip[:, :, :, ::-1]) / 2.0
+ del pred_flip
+ pred = pred.numpy().squeeze().transpose(1, 2, 0)
+ return pred
+
+
+def evaluate(pred_func, img, cfg):
+ ori_h, ori_w, _ = img.shape
+ pred_all = np.zeros((ori_h, ori_w, cfg.num_classes))
+ for rate in cfg.val_multiscale:
+ if cfg.val_slip:
+ new_h, new_w = int(ori_h * rate), int(ori_w * rate)
+ val_size = (cfg.val_height, cfg.val_width)
+ else:
+ new_h, new_w = int(cfg.val_height * rate), int(cfg.val_width * rate)
+ val_size = (new_h, new_w)
+ img_scale = cv2.resize(
+ img, (new_w, new_h), interpolation=cv2.INTER_LINEAR
+ )
+
+ if (new_h <= val_size[0]) and (new_h <= val_size[1]):
+ img_pad, margin = pad_image_to_shape(
+ img_scale, val_size, cv2.BORDER_CONSTANT, value=0
+ )
+ pred = eval_single(pred_func, img_pad, cfg.val_flip)
+ pred = pred[
+ margin[0]:(pred.shape[0] - margin[1]),
+ margin[2]:(pred.shape[1] - margin[3]),
+ ]
+ else:
+ stride_rate = 2 / 3
+ stride = [int(np.ceil(i * stride_rate)) for i in val_size]
+ img_pad, margin = pad_image_to_shape(
+ img_scale, val_size, cv2.BORDER_CONSTANT, value=0
+ )
+ pad_h, pad_w = img_pad.shape[:2]
+ r_grid, c_grid = [
+ int(np.ceil((ps - cs) / stride)) + 1
+ for ps, cs, stride in zip(img_pad.shape, val_size, stride)
+ ]
+
+ pred_scale = np.zeros((pad_h, pad_w, cfg.num_classes))
+ count_scale = np.zeros((pad_h, pad_w, cfg.num_classes))
+ for grid_yidx in range(r_grid):
+ for grid_xidx in range(c_grid):
+ s_x = grid_xidx * stride[1]
+ s_y = grid_yidx * stride[0]
+ e_x = min(s_x + val_size[1], pad_w)
+ e_y = min(s_y + val_size[0], pad_h)
+ s_x = e_x - val_size[1]
+ s_y = e_y - val_size[0]
+ img_sub = img_pad[s_y:e_y, s_x:e_x]
+ tpred = eval_single(pred_func, img_sub, cfg.val_flip)
+ count_scale[s_y:e_y, s_x:e_x] += 1
+ pred_scale[s_y:e_y, s_x:e_x] += tpred
+ # pred_scale = pred_scale / count_scale
+ pred = pred_scale[
+ margin[0]:(pred_scale.shape[0] - margin[1]),
+ margin[2]:(pred_scale.shape[1] - margin[3]),
+ ]
+
+ pred_all += cv2.resize(pred, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
+
+ # pred_all = pred_all / len(cfg.val_multiscale)
+ result = np.argmax(pred_all, axis=2).astype(np.uint8)
+ return result
+
+
+def save_results(result_list, save_dir, cfg):
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ for sample in result_list:
+ if cfg.dataset == "Cityscapes":
+ name = sample["name"].split("/")[-1][:-4]
+ else:
+ name = sample["name"]
+ file_path = os.path.join(save_dir, "%s.png" % name)
+ cv2.imwrite(file_path, sample["pred"])
+ file_path = os.path.join(save_dir, "%s.gt.png" % name)
+ cv2.imwrite(file_path, sample["gt"])
+
+
+# voc cityscapes metric
+def compute_metric(result_list, cfg):
+ num_classes = cfg.num_classes
+ hist = np.zeros((num_classes, num_classes))
+ correct = 0
+ labeled = 0
+ count = 0
+ for result in result_list:
+ pred = result["pred"]
+ gt = result["gt"]
+ assert pred.shape == gt.shape
+ k = (gt >= 0) & (gt < num_classes)
+ labeled += np.sum(k)
+ correct += np.sum((pred[k] == gt[k]))
+ hist += np.bincount(
+ num_classes * gt[k].astype(int) + pred[k].astype(int),
+ minlength=num_classes ** 2
+ ).reshape(num_classes, num_classes)
+ count += 1
+
+ iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
+ mean_IU = np.nanmean(iu)
+ # mean_IU_no_back = np.nanmean(iu[1:])
+ # freq = hist.sum(1) / hist.sum()
+ # freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
+ mean_pixel_acc = correct / labeled
+
+ if cfg.dataset == "VOC2012":
+ class_names = ("background", ) + dataset.PascalVOC.class_names
+ elif cfg.dataset == "Cityscapes":
+ class_names = dataset.Cityscapes.class_names
+ else:
+ raise ValueError("Unsupported dataset {}".format(cfg.dataset))
+
+ n = iu.size
+ lines = []
+ for i in range(n):
+ if class_names is None:
+ cls = "Class %d:" % (i + 1)
+ else:
+ cls = "%d %s" % (i + 1, class_names[i])
+ lines.append("%-8s\t%.3f%%" % (cls, iu[i] * 100))
+ lines.append(
+ "---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%" % (
+ "mean_IU", mean_IU * 100, "mean_pixel_ACC", mean_pixel_acc * 100
+ )
+ )
+ line = "\n".join(lines)
+ logger.info(line)
+
+
+class EvalPascalVOC(dataset.PascalVOC):
+ def _trans_mask(self, mask):
+ label = np.ones(mask.shape[:2]) * 255
+ for i, (b, g, r) in enumerate(class_colors):
+ label[
+ (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
+ ] = i
+ return label.astype(np.uint8)
+
+def build_dataloader(rank, world_size, dataset_dir, cfg):
+ if cfg.dataset == "VOC2012":
+ val_dataset = EvalPascalVOC(
+ dataset_dir,
+ "val",
+ order=["image", "mask", "info"]
+ )
+ elif cfg.dataset == "Cityscapes":
+ val_dataset = dataset.Cityscapes(
+ dataset_dir,
+ "val",
+ mode="gtFine",
+ order=["image", "mask", "info"]
+ )
+ else:
+ raise ValueError("Unsupported dataset {}".format(cfg.dataset))
+
+ val_sampler = InferenceSampler(val_dataset, 1, world_size=world_size, rank=rank)
+ val_dataloader = DataLoader(
+ val_dataset,
+ sampler=val_sampler,
+ transform=T.Normalize(
+ mean=cfg.img_mean, std=cfg.img_std, order=["image", "mask"]
+ ),
+ num_workers=2,
+ )
+ return val_dataloader
+
+
+if __name__ == "__main__":
+ main()
diff --git a/official/vision/segmentation/tools/train.py b/official/vision/segmentation/tools/train.py
new file mode 100644
index 00000000..f2611ef0
--- /dev/null
+++ b/official/vision/segmentation/tools/train.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import argparse
+import multiprocessing as mp
+import os
+import time
+
+import numpy as np
+
+import megengine as mge
+import megengine.distributed as dist
+import megengine.functional as F
+from megengine.autodiff import GradManager
+from megengine.data import DataLoader, Infinite, RandomSampler, dataset
+from megengine.data import transform as T
+# from megengine.jit import trace
+from megengine.optimizer import SGD
+
+from official.vision.segmentation.tools.utils import AverageMeter, get_config_info, import_from_file
+
+logger = mge.get_logger(__name__)
+logger.setLevel("INFO")
+mge.device.set_prealloc_config(1024, 1024, 512 * 1024 * 1024, 2.0)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-f", "--file", default="net.py", type=str, help="net description file"
+ )
+ parser.add_argument(
+ "-n", "--ngpus", type=int, default=8, help="batchsize for training"
+ )
+ parser.add_argument(
+ "-d", "--dataset_dir", type=str, default="/data/datasets",
+ )
+ parser.add_argument(
+ "-r", "--resume", type=str, default=None, help="resume model file"
+ )
+ args = parser.parse_args()
+
+ # ------------------------ begin training -------------------------- #
+ logger.info("Device Count = %d", args.ngpus)
+
+ log_dir = "log-of-{}".format(os.path.basename(args.file).split(".")[0])
+ if not os.path.isdir(log_dir):
+ os.makedirs(log_dir)
+
+ if args.ngpus > 1:
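+        # Multi-GPU launch: grab a free port, start the distributed Server,
+        # and spawn one worker process per GPU.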
+ master_ip = "localhost"
+ port = dist.get_free_ports(1)[0]
+ dist.Server(port)
+ processes = list()
+ for rank in range(args.ngpus):
+ process = mp.Process(
+ target=worker, args=(master_ip, port, args.ngpus, rank, args)
+ )
+ process.start()
+ processes.append(process)
+
+ for p in processes:
+ p.join()
+ else:
+ worker(None, None, 1, 0, args)
+
+
+def worker(master_ip, port, world_size, rank, args):
+ if world_size > 1:
+ dist.init_process_group(
+ master_ip=master_ip,
+ port=port,
+ world_size=world_size,
+ rank=rank,
+ device=rank,
+ )
+ logger.info("Init process group for gpu{} done".format(rank))
+
+ current_network = import_from_file(args.file)
+
+ model = current_network.Net(current_network.Cfg())
+ model.train()
+
+ if dist.get_rank() == 0:
+ logger.info(get_config_info(model.cfg))
+ logger.info(repr(model))
+
+ backbone_params = []
+ head_params = []
+ for name, param in model.named_parameters():
+ if "backbone" in name:
+ backbone_params.append(param)
+ else:
+ head_params.append(param)
+
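+    # Two parameter groups: the backbone trains at 0.1x the head learning
+    # rate; weight decay is multiplied by world_size to stay consistent with
+    # the SUM (rather than MEAN) gradient all-reduce attached below.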
+ opt = SGD(
+ [
+ {"params": backbone_params, "lr": model.cfg.learning_rate * 0.1},
+ {"params": head_params},
+ ],
+ lr=model.cfg.learning_rate,
+ momentum=model.cfg.momentum,
+ weight_decay=model.cfg.weight_decay * dist.get_world_size(),
+ )
+
+ gm = GradManager()
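+    # In distributed training, attach an all-reduce callback so every
+    # worker's gradients are summed across the group during backward.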
+ if dist.get_world_size() > 1:
+ gm.attach(
+ model.parameters(),
+ callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)]
+ )
+ else:
+ gm.attach(model.parameters())
+
+ cur_epoch = 0
+ if args.resume is not None:
+ pretrained = mge.load(args.resume)
+ cur_epoch = pretrained["epoch"] + 1
+ model.load_state_dict(pretrained["state_dict"])
+ opt.load_state_dict(pretrained["opt"])
+ if dist.get_rank() == 0:
+ logger.info("load success: epoch %d", cur_epoch)
+
+ if dist.get_world_size() > 1:
+ dist.bcast_list_(model.parameters(), dist.WORLD) # sync parameters
+
+ if dist.get_rank() == 0:
+ logger.info("Prepare dataset")
+ train_loader = iter(
+ build_dataloader(model.cfg.batch_size, args.dataset_dir, model.cfg)
+ )
+
+ for epoch in range(cur_epoch, model.cfg.max_epoch):
+ train_one_epoch(model, train_loader, opt, gm, epoch)
+ if dist.get_rank() == 0:
+ save_path = "log-of-{}/epoch_{}.pkl".format(
+ os.path.basename(args.file).split(".")[0], epoch
+ )
+ mge.save({
+ "epoch": epoch,
+ "state_dict": model.state_dict(),
+ "opt": opt.state_dict()
+ }, save_path)
+ logger.info("dump weights to %s", save_path)
+
+
+def train_one_epoch(model, data_queue, opt, gm, epoch):
+ # @trace(symbolic=True)
+ def train_func(data, label):
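+        # Ops executed inside "with gm:" are recorded by the GradManager;
+        # gm.backward(loss) then fills the .grad of the attached parameters,
+        # and opt.step().clear_grad() applies the update and resets them.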
+ with gm:
+ pred = model(data)
+ loss = cross_entropy(
+ pred, label, ignore_label=model.cfg.ignore_label
+ )
+ gm.backward(loss)
+ opt.step().clear_grad()
+ return loss
+
+ meter = AverageMeter(record_len=1)
+ time_meter = AverageMeter(record_len=2)
+ log_interval = model.cfg.log_interval
+ tot_step = model.cfg.nr_images_epoch // (
+ model.cfg.batch_size * dist.get_world_size()
+ )
+ for step in range(tot_step):
+ adjust_learning_rate(opt, epoch, step, tot_step, model.cfg)
+
+ data_tik = time.time()
+ inputs, labels = next(data_queue)
+ labels = np.squeeze(labels, axis=1).astype(np.int32)
+ data_tok = time.time()
+
+ tik = time.time()
+ loss = train_func(mge.tensor(inputs), mge.tensor(labels))
+ tok = time.time()
+
+ time_meter.update([tok - tik, data_tok - data_tik])
+
+ if dist.get_rank() == 0:
+ info_str = "e%d, %d/%d, lr:%f, "
+ loss_str = ", ".join(["{}:%f".format(loss) for loss in ["loss"]])
+ time_str = ", train_time:%.3fs, data_time:%.3fs"
+ log_info_str = info_str + loss_str + time_str
+ meter.update([loss.numpy() for loss in [loss]])
+ if step % log_interval == 0:
+ logger.info(
+ log_info_str,
+ epoch,
+ step,
+ tot_step,
+ opt.param_groups[1]["lr"],
+ *meter.average(),
+ *time_meter.average()
+ )
+ meter.reset()
+ time_meter.reset()
+
+
+def adjust_learning_rate(optimizer, epoch, step, tot_step, cfg):
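+    # "Poly" schedule: lr = base_lr * (1 - cur_iter / max_iter) ** 0.9,
+    # with the backbone group (index 0) kept at 0.1x the head group.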
+ max_iter = cfg.max_epoch * tot_step
+ cur_iter = epoch * tot_step + step
+ cur_lr = cfg.learning_rate * (1 - cur_iter / (max_iter + 1)) ** 0.9
+ optimizer.param_groups[0]["lr"] = cur_lr * 0.1
+ optimizer.param_groups[1]["lr"] = cur_lr
+
+
+def cross_entropy(pred, label, axis=1, ignore_label=255):
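+    # Mask out pixels labeled ignore_label: move logits to NHWC so the
+    # boolean mask selects per-pixel class vectors, then apply standard
+    # cross entropy on the remaining pixels only.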
+ mask = label != ignore_label
+ pred = pred.transpose(0, 2, 3, 1)
+ return F.loss.cross_entropy(pred[mask], label[mask], axis)
+
+
+def build_dataloader(batch_size, dataset_dir, cfg):
+ if cfg.dataset == "VOC2012":
+ train_dataset = dataset.PascalVOC(
+ dataset_dir,
+ cfg.data_type,
+ order=["image", "mask"]
+ )
+ elif cfg.dataset == "Cityscapes":
+ train_dataset = dataset.Cityscapes(
+ dataset_dir,
+ "train",
+ mode='gtFine',
+ order=["image", "mask"]
+ )
+ else:
+ raise ValueError("Unsupported dataset {}".format(cfg.dataset))
+
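+    # Infinite turns the sampler into an endless batch stream; the number of
+    # iterations per epoch is governed by cfg.nr_images_epoch in the trainer,
+    # not by the dataset length.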
+ train_sampler = Infinite(RandomSampler(train_dataset, batch_size, drop_last=True))
+ train_dataloader = DataLoader(
+ train_dataset,
+ sampler=train_sampler,
+ transform=T.Compose(
+ transforms=[
+ T.RandomHorizontalFlip(0.5),
+ T.RandomResize(scale_range=(0.5, 2)),
+ T.RandomCrop(
+ output_size=(cfg.img_height, cfg.img_width),
+ padding_value=[0, 0, 0],
+ padding_maskvalue=255,
+ ),
+ T.Normalize(mean=cfg.img_mean, std=cfg.img_std),
+ T.ToMode(),
+ ],
+ order=["image", "mask"],
+ ),
+ num_workers=2,
+ )
+ return train_dataloader
+
+
+if __name__ == "__main__":
+ main()
diff --git a/official/vision/segmentation/tools/utils.py b/official/vision/segmentation/tools/utils.py
new file mode 100644
index 00000000..eef4c474
--- /dev/null
+++ b/official/vision/segmentation/tools/utils.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+import importlib
+import math
+from tabulate import tabulate
+
+import numpy as np
+
+from megengine.data import Sampler
+
+
+class AverageMeter:
+ """Computes and stores the average and current value"""
+
+ def __init__(self, record_len=1):
+ self.record_len = record_len
+ self.reset()
+
+ def reset(self):
+ self.sum = [0 for i in range(self.record_len)]
+ self.cnt = 0
+
+ def update(self, val):
+ self.sum = [s + v for s, v in zip(self.sum, val)]
+ self.cnt += 1
+
+ def average(self):
+ return [s / self.cnt for s in self.sum]
+
+
+def import_from_file(cfg_file):
+ spec = importlib.util.spec_from_file_location("config", cfg_file)
+ cfg_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(cfg_module)
+ return cfg_module
+
+
+def get_config_info(config):
+ config_table = []
+ for c, v in config.__dict__.items():
+ if not isinstance(v, (int, float, str, list, tuple, dict, np.ndarray)):
+            # Every object has __class__, so check functools.partial first;
+            # otherwise the partial branch would never be reached.
+            if isinstance(v, functools.partial):
+                v = v.func.__name__
+            elif hasattr(v, "__name__"):
+                v = v.__name__
+            elif hasattr(v, "__class__"):
+                v = v.__class__
+ config_table.append((str(c), str(v)))
+ config_table = tabulate(config_table)
+ return config_table
+
+
+class InferenceSampler(Sampler):
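+    # Assigns each rank a contiguous, non-overlapping slice of the dataset
+    # so that distributed evaluation covers every sample exactly once.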
+ def __init__(self, dataset, batch_size=1, world_size=None, rank=None):
+ super().__init__(dataset, batch_size, False, None, world_size, rank)
+ begin = self.num_samples * self.rank
+ end = min(self.num_samples * (self.rank + 1), len(self.dataset))
+ self.indices = list(range(begin, end))
+
+ def batch(self):
+ step, length = self.batch_size, len(self.indices)
+ batch_index = [self.indices[i : i + step] for i in range(0, length, step)]
+ return iter(batch_index)
+
+ def __len__(self):
+ return int(math.ceil(len(self.indices) / self.batch_size))
+
+
+# pre-defined colors (in BGR order) for background + 20 Pascal VOC categories
+class_colors = [
+ [0, 0, 0], # background
+ [0, 0, 128],
+ [0, 128, 0],
+ [0, 128, 128],
+ [128, 0, 0],
+ [128, 0, 128],
+ [128, 128, 0],
+ [128, 128, 128],
+ [0, 0, 64],
+ [0, 0, 192],
+ [0, 128, 64],
+ [0, 128, 192],
+ [128, 0, 64],
+ [128, 0, 192],
+ [128, 128, 64],
+ [128, 128, 192],
+ [0, 64, 0],
+ [0, 64, 128],
+ [0, 192, 0],
+ [0, 192, 128],
+ [128, 64, 0],
+]
diff --git a/official/vision/segmentation/train.py b/official/vision/segmentation/train.py
deleted file mode 100644
index d4dff1fa..00000000
--- a/official/vision/segmentation/train.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import argparse
-import multiprocessing as mp
-import os
-
-import megengine as mge
-import megengine.data as data
-import megengine.data.dataset as dataset
-import megengine.data.transform as T
-import megengine.distributed as dist
-import megengine.jit as jit
-import megengine.optimizer as optim
-import numpy as np
-
-from official.vision.segmentation.deeplabv3plus import (
- DeepLabV3Plus,
- softmax_cross_entropy,
-)
-from official.vision.segmentation.utils import import_config_from_file
-
-logger = mge.get_logger(__name__)
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-c", "--config", type=str, required=True, help="configuration file"
- )
- parser.add_argument(
- "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012",
- )
- parser.add_argument(
- "-w", "--weight_file", type=str, default=None, help="pre-train weights file",
- )
- parser.add_argument(
- "-n", "--ngpus", type=int, default=8, help="batchsize for training"
- )
- parser.add_argument(
- "-r", "--resume", type=str, default=None, help="resume model file"
- )
- args = parser.parse_args()
-
- world_size = args.ngpus
- logger.info("Device Count = %d", world_size)
- if world_size > 1:
- mp.set_start_method("spawn")
- processes = []
- for rank in range(world_size):
- p = mp.Process(target=worker, args=(rank, world_size, args))
- p.start()
- processes.append(p)
- for p in processes:
- p.join()
- else:
- worker(0, 1, args)
-
-
-def worker(rank, world_size, args):
- cfg = import_config_from_file(args.config)
-
- if world_size > 1:
- dist.init_process_group(
- master_ip="localhost",
- master_port=23456,
- world_size=world_size,
- rank=rank,
- dev=rank,
- )
- logger.info("Init process group done")
-
- logger.info("Prepare dataset")
- train_loader, epoch_size = build_dataloader(cfg.BATCH_SIZE, args.dataset_dir, cfg)
- batch_iter = epoch_size // (cfg.BATCH_SIZE * world_size)
-
- net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES, pretrained=args.weight_file)
- base_lr = cfg.LEARNING_RATE * world_size
- optimizer = optim.SGD(
- net.parameters(requires_grad=True),
- lr=base_lr,
- momentum=0.9,
- weight_decay=0.00004,
- )
-
- @jit.trace(symbolic=True, opt_level=2)
- def train_func(data, label, net=None, optimizer=None):
- net.train()
- pred = net(data)
- loss = softmax_cross_entropy(pred, label, ignore_index=cfg.IGNORE_INDEX)
- optimizer.backward(loss)
- return pred, loss
-
- begin_epoch = 0
- end_epoch = cfg.EPOCHS
- if args.resume is not None:
- pretrained = mge.load(args.resume)
- begin_epoch = pretrained["epoch"] + 1
- net.load_state_dict(pretrained["state_dict"])
- logger.info("load success: epoch %d", begin_epoch)
-
- itr = begin_epoch * batch_iter
- max_itr = end_epoch * batch_iter
-
- image = mge.tensor(
- np.zeros([cfg.BATCH_SIZE, 3, cfg.IMG_HEIGHT, cfg.IMG_WIDTH]).astype(np.float32),
- dtype="float32",
- )
- label = mge.tensor(
- np.zeros([cfg.BATCH_SIZE, cfg.IMG_HEIGHT, cfg.IMG_WIDTH]).astype(np.int32),
- dtype="int32",
- )
- exp_name = os.path.abspath(os.path.dirname(__file__)).split("/")[-1]
-
- for epoch in range(begin_epoch, end_epoch):
- for i_batch, sample_batched in enumerate(train_loader):
-
- def adjust_lr(optimizer, itr, max_itr):
- now_lr = base_lr * (1 - itr / (max_itr + 1)) ** 0.9
- for param_group in optimizer.param_groups:
- param_group["lr"] = now_lr
- return now_lr
-
- now_lr = adjust_lr(optimizer, itr, max_itr)
- inputs_batched, labels_batched = sample_batched
- labels_batched = np.squeeze(labels_batched, axis=1).astype(np.int32)
- image.set_value(inputs_batched)
- label.set_value(labels_batched)
-
- optimizer.zero_grad()
- _, loss = train_func(image, label, net=net, optimizer=optimizer)
- optimizer.step()
- running_loss = loss.numpy()[0]
-
- if rank == 0:
- logger.info(
- "%s epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g",
- exp_name,
- epoch,
- end_epoch,
- i_batch,
- batch_iter,
- itr + 1,
- now_lr,
- running_loss,
- )
- itr += 1
-
- if rank == 0:
- save_path = os.path.join(cfg.MODEL_SAVE_DIR, "epoch%d.pkl" % (epoch))
- mge.save({"epoch": epoch, "state_dict": net.state_dict()}, save_path)
- logger.info("save epoch%d", epoch)
-
-
-def build_dataloader(batch_size, dataset_dir, cfg):
- if cfg.DATASET == "VOC2012":
- train_dataset = dataset.PascalVOC(
- dataset_dir,
- cfg.DATA_TYPE,
- order=["image", "mask"]
- )
- elif cfg.DATASET == "Cityscapes":
- train_dataset = dataset.Cityscapes(
- dataset_dir,
- "train",
- mode='gtFine',
- order=["image", "mask"]
- )
- else:
- raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
- train_sampler = data.RandomSampler(train_dataset, batch_size, drop_last=True)
- train_dataloader = data.DataLoader(
- train_dataset,
- sampler=train_sampler,
- transform=T.Compose(
- transforms=[
- T.RandomHorizontalFlip(0.5),
- T.RandomResize(scale_range=(0.5, 2)),
- T.RandomCrop(
- output_size=(cfg.IMG_HEIGHT, cfg.IMG_WIDTH),
- padding_value=[0, 0, 0],
- padding_maskvalue=255,
- ),
- T.Normalize(mean=cfg.IMG_MEAN, std=cfg.IMG_STD),
- T.ToMode(),
- ],
- order=["image", "mask"],
- ),
- num_workers=0,
- )
- return train_dataloader, train_dataset.__len__()
-
-
-if __name__ == "__main__":
- main()
diff --git a/official/vision/segmentation/utils.py b/official/vision/segmentation/utils.py
deleted file mode 100644
index 8da4fc0e..00000000
--- a/official/vision/segmentation/utils.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import importlib.util
-import os
-
-
-def import_config_from_file(cfg_file):
- assert os.path.exists(cfg_file), "config file {} not exists".format(cfg_file)
- spec = importlib.util.spec_from_file_location("config", cfg_file)
- cfg_module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(cfg_module)
- return cfg_module.cfg