Skip to content

Commit

Permalink
add ppyoloe-seg
Browse files Browse the repository at this point in the history
  • Loading branch information
MINGtoMING committed Oct 2, 2023
1 parent 7d6dc40 commit 80fd251
Show file tree
Hide file tree
Showing 11 changed files with 905 additions and 8 deletions.
35 changes: 35 additions & 0 deletions configs/ppyoloe_seg/README_cn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# PP-YOLOE-Seg


## 简介
PP-YOLOE-Seg是基于PP-YOLOE并结合YOLACT的实时实例分割模型。

### 训练

请执行以下指令训练

```bash
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/ppyoloe_plus_seg_crn_l_80e_coco.yml --eval --amp
```
### 评估

执行以下命令在单个GPU上评估COCO val2017数据集

```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/ppyoloe/ppyoloe_plus_seg_crn_l_80e_coco.yml -o weights=${model_weights}
```

在coco test-dev2017上评估,请先从[COCO数据集下载](https://cocodataset.org/#download)下载COCO test-dev2017数据集,然后解压到COCO数据集文件夹并像`configs/ppyolo/ppyolo_test.yml`一样配置`EvalDataset`

### 推理

使用以下命令在单张GPU上预测图片,使用`--infer_img`推理单张图片以及使用`--infer_dir`推理文件中的所有图片。


```bash
# 推理单张图片
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyoloe/ppyoloe_plus_seg_crn_l_80e_coco.yml -o weights=${model_weights} --infer_img=demo/000000014439_640x640.jpg

# 推理文件中的所有图片
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml -o weights=${model_weights} --infer_dir=demo
```
18 changes: 18 additions & 0 deletions configs/ppyoloe_seg/_base_/optimizer_80e.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
epoch: 80

LearningRate:
base_lr: 0.001
schedulers:
- name: CosineDecay
max_epochs: 96
- name: LinearWarmup
start_factor: 0.
epochs: 5

OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
50 changes: 50 additions & 0 deletions configs/ppyoloe_seg/_base_/ppyoloe_plus_seg_crn.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
architecture: PPYOLOE
with_mask: True
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']

PPYOLOE:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOESegHead
post_process: ~

CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
use_alpha: True

CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true

PPYOLOESegHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: 30
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5, mask: 2.5, dice: 2.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
return_index: True
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
41 changes: 41 additions & 0 deletions configs/ppyoloe_seg/_base_/ppyoloe_plus_seg_reader.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
worker_num: 4
eval_height: &eval_height 640
eval_width: &eval_width 640
eval_size: &eval_size [*eval_height, *eval_width]

TrainReader:
sample_transforms:
- Decode: {}
- Poly2Mask: {del_poly: True}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true

EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2

TestReader:
inputs_def:
image_shape: [3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
16 changes: 16 additions & 0 deletions configs/ppyoloe_seg/ppyoloe_plus_seg_crn_l_80e_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_BASE_: [
'../datasets/coco_instance.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_seg_crn.yml',
'./_base_/ppyoloe_plus_seg_reader.yml',
]

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_seg_crn_l_80e_coco/model_final


pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams
depth_mult: 1.0
width_mult: 1.0
15 changes: 15 additions & 0 deletions configs/ppyoloe_seg/ppyoloe_plus_seg_crn_s_80e_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_BASE_: [
'../datasets/coco_instance.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_seg_crn.yml',
'./_base_/ppyoloe_plus_seg_reader.yml',
]

log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_seg_crn_s_80e_coco/model_final

pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
7 changes: 7 additions & 0 deletions ppdet/data/transform/batch_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,13 @@ def __call__(self, samples, context=None):
if num_gt > 0:
pad_gt_areas[:num_gt, 0] = sample['gt_areas']
sample['gt_areas'] = pad_gt_areas
if 'gt_segm' in sample:
pad_gt_segm = np.zeros(
(num_max_boxes, *sample['gt_segm'].shape[-2:]),
dtype=np.uint8)
if num_gt > 0:
pad_gt_segm[:num_gt] = sample['gt_segm']
sample['gt_segm'] = pad_gt_segm
return samples


Expand Down
14 changes: 14 additions & 0 deletions ppdet/data/transform/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,20 @@ def apply(self, sample, context=None):
sample['gt_keypoint'] = self.apply_keypoint(sample['gt_keypoint'],
offsets)

if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
masks = [
cv2.copyMakeBorder(
gt_segm,
offset_y,
h - (offset_y + im_h),
offset_x,
w - (offset_x + im_w),
borderType=cv2.BORDER_CONSTANT,
value=[0, ])
for gt_segm in sample['gt_segm']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)

return sample


Expand Down
28 changes: 21 additions & 7 deletions ppdet/modeling/architectures/ppyoloe.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class PPYOLOE(BaseArch):
"""

__category__ = 'architecture'
__shared__ = ['for_distill']
__shared__ = ['for_distill', 'with_mask']
__inject__ = ['post_process', 'ssod_loss']

def __init__(self,
Expand All @@ -55,13 +55,15 @@ def __init__(self,
ssod_loss='SSODPPYOLOELoss',
for_distill=False,
feat_distill_place='neck_feats',
for_mot=False):
for_mot=False,
with_mask=False):
super(PPYOLOE, self).__init__()
self.backbone = backbone
self.neck = neck
self.yolo_head = yolo_head
self.post_process = post_process
self.for_mot = for_mot
self.with_mask = with_mask

# for ssod, semi-det
self.is_teacher = False
Expand Down Expand Up @@ -110,13 +112,22 @@ def _forward(self):
yolo_head_outs = self.yolo_head(neck_feats)

if self.post_process is not None:
bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
if not self.with_mask:
bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
else:
bbox, bbox_num, mask, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])

else:
bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
if not self.with_mask:
bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
else:
bbox, bbox_num, mask, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])

if self.use_extra_data:
extra_data = {} # record the bbox output before nms, such like scores and nms_keep_idx
Expand All @@ -131,6 +142,9 @@ def _forward(self):
else:
output = {'bbox': bbox, 'bbox_num': bbox_num}

if self.with_mask:
output['mask'] = mask

return output

def get_loss(self):
Expand Down
4 changes: 3 additions & 1 deletion ppdet/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from . import sparse_roi_head
from . import vitpose_head
from . import clrnet_head
from . import ppyoloe_seg_head

from .bbox_head import *
from .mask_head import *
Expand Down Expand Up @@ -71,4 +72,5 @@
from .sparse_roi_head import *
from .petr_head import *
from .vitpose_head import *
from .clrnet_head import *
from .clrnet_head import *
from .ppyoloe_seg_head import *

0 comments on commit 80fd251

Please sign in to comment.