Skip to content

Commit

Permalink
[Tools] Support replacing the ${key} with the value of cfg.key (open-…
Browse files Browse the repository at this point in the history
…mmlab#7492)

* Support replacing config

* Support replacing config

* Add unit test for replace_cfig

* pre-commit

* fix

* modify the docstring

* rename function

* fix a bug

* fix a bug and simplify the code

* simplify the code

* add replace_cfg_vals for some scripts

* add replace_cfg_vals for some scripts

* add some unit tests
  • Loading branch information
Czm369 authored and ZwwWayne committed Jul 19, 2022
1 parent 7ee1f61 commit 1d7786e
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 6 deletions.
3 changes: 2 additions & 1 deletion mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .misc import find_latest_checkpoint, update_data_root
from .parallel import MMDataParallel, MMDistributedDataParallel
from .setup_env import register_all_modules, setup_multi_processes
from .replace_cfg_vals import replace_cfg_vals
from .split_batch import split_batch
from .util_distribution import build_ddp, build_dp, get_device

Expand All @@ -13,5 +14,5 @@
'update_data_root', 'setup_multi_processes', 'get_caller_name',
'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp',
'get_device', 'MMDataParallel', 'MMDistributedDataParallel',
'register_all_modules'
'register_all_modules', 'replace_cfg_vals'
]
70 changes: 70 additions & 0 deletions mmdet/utils/replace_cfg_vals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re

from mmcv.utils import Config


def replace_cfg_vals(ori_cfg):
"""Replace the string "${key}" with the corresponding value.
Replace the "${key}" with the value of ori_cfg.key in the config. And
support replacing the chained ${key}. Such as, replace "${key0.key1}"
with the value of cfg.key0.key1. Code is modified from `vars.py
< https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501
Args:
ori_cfg (mmcv.utils.config.Config):
The origin config with "${key}" generated from a file.
Returns:
updated_cfg [mmcv.utils.config.Config]:
The config with "${key}" replaced by the corresponding value.
"""

def get_value(cfg, key):
for k in key.split('.'):
cfg = cfg[k]
return cfg

def replace_value(cfg):
if isinstance(cfg, dict):
return {key: replace_value(value) for key, value in cfg.items()}
elif isinstance(cfg, list):
return [replace_value(item) for item in cfg]
elif isinstance(cfg, tuple):
return tuple([replace_value(item) for item in cfg])
elif isinstance(cfg, str):
# the format of string cfg may be:
# 1) "${key}", which will be replaced with cfg.key directly
# 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
# which will be replaced with the string of the cfg.key
keys = pattern_key.findall(cfg)
values = [get_value(ori_cfg, key[2:-1]) for key in keys]
if len(keys) == 1 and keys[0] == cfg:
# the format of string cfg is "${key}"
cfg = values[0]
else:
for key, value in zip(keys, values):
# the format of string cfg is
# "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx"
assert not isinstance(value, (dict, list, tuple)), \
f'for the format of string cfg is ' \
f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \
f"the type of the value of '${key}' " \
f'can not be dict, list, or tuple' \
f'but you input {type(value)} in {cfg}'
cfg = cfg.replace(key, str(value))
return cfg
else:
return cfg

# the pattern of string "${key}"
pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')
# the type of ori_cfg._cfg_dict is mmcv.utils.config.ConfigDict
updated_cfg = Config(
replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
# replace the model with model_wrapper
if updated_cfg.get('model_wrapper', None) is not None:
updated_cfg.model = updated_cfg.model_wrapper
updated_cfg.pop('model_wrapper')
return updated_cfg
83 changes: 83 additions & 0 deletions tests/test_utils/test_replace_cfg_vals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os.path as osp
import tempfile
from copy import deepcopy

import pytest
from mmcv.utils import Config

from mmdet.utils import replace_cfg_vals


def test_replace_cfg_vals():
temp_file = tempfile.NamedTemporaryFile()
cfg_path = f'{temp_file.name}.py'
with open(cfg_path, 'w') as f:
f.write('configs')

ori_cfg_dict = dict()
ori_cfg_dict['cfg_name'] = osp.basename(temp_file.name)
ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}'
ori_cfg_dict['percent'] = 5
ori_cfg_dict['fold'] = 1
ori_cfg_dict['model_wrapper'] = dict(
type='SoftTeacher', detector='${model}')
ori_cfg_dict['model'] = dict(
type='FasterRCNN',
backbone=dict(type='ResNet'),
neck=dict(type='FPN'),
rpn_head=dict(type='RPNHead'),
roi_head=dict(type='StandardRoIHead'),
train_cfg=dict(
rpn=dict(
assigner=dict(type='MaxIoUAssigner'),
sampler=dict(type='RandomSampler'),
),
rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)),
rcnn=dict(
assigner=dict(type='MaxIoUAssigner'),
sampler=dict(type='RandomSampler'),
),
),
test_cfg=dict(
rpn=dict(nms=dict(type='nms', iou_threshold=0.7)),
rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)),
),
)
ori_cfg_dict['iou_threshold'] = dict(
rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}',
test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}',
test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}',
)

ori_cfg_dict['str'] = 'Hello, world!'
ori_cfg_dict['dict'] = {'Hello': 'world!'}
ori_cfg_dict['list'] = [
'Hello, world!',
]
ori_cfg_dict['tuple'] = ('Hello, world!', )
ori_cfg_dict['test_str'] = 'xxx${str}xxx'

ori_cfg = Config(ori_cfg_dict, filename=cfg_path)
updated_cfg = replace_cfg_vals(deepcopy(ori_cfg))

assert updated_cfg.work_dir \
== f'work_dirs/{osp.basename(temp_file.name)}/5/1'
assert updated_cfg.model.detector == ori_cfg.model
assert updated_cfg.iou_threshold.rpn_proposal_nms \
== ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold
assert updated_cfg.test_str == 'xxxHello, world!xxx'
ori_cfg_dict['test_dict'] = 'xxx${dict}xxx'
ori_cfg_dict['test_list'] = 'xxx${list}xxx'
ori_cfg_dict['test_tuple'] = 'xxx${tuple}xxx'
with pytest.raises(AssertionError):
cfg = deepcopy(ori_cfg)
cfg['test_dict'] = 'xxx${dict}xxx'
updated_cfg = replace_cfg_vals(cfg)
with pytest.raises(AssertionError):
cfg = deepcopy(ori_cfg)
cfg['test_list'] = 'xxx${list}xxx'
updated_cfg = replace_cfg_vals(cfg)
with pytest.raises(AssertionError):
cfg = deepcopy(ori_cfg)
cfg['test_tuple'] = 'xxx${tuple}xxx'
updated_cfg = replace_cfg_vals(cfg)
5 changes: 4 additions & 1 deletion tools/analysis_tools/analyze_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mmdet.core.evaluation import eval_map
from mmdet.core.visualization import imshow_gt_det_bboxes
from mmdet.datasets import build_dataset, get_loading_pipeline
from mmdet.utils import update_data_root
from mmdet.utils import replace_cfg_vals, update_data_root


def bbox_map_eval(det_result, annotation):
Expand Down Expand Up @@ -188,6 +188,9 @@ def main():

cfg = Config.fromfile(args.config)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

Expand Down
5 changes: 4 additions & 1 deletion tools/analysis_tools/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.datasets import build_dataset
from mmdet.utils import update_data_root
from mmdet.utils import replace_cfg_vals, update_data_root


def parse_args():
Expand Down Expand Up @@ -232,6 +232,9 @@ def main():

cfg = Config.fromfile(args.config)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

Expand Down
5 changes: 4 additions & 1 deletion tools/analysis_tools/eval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mmcv import Config, DictAction

from mmdet.datasets import build_dataset
from mmdet.utils import update_data_root
from mmdet.utils import replace_cfg_vals, update_data_root


def parse_args():
Expand Down Expand Up @@ -50,6 +50,9 @@ def main():

cfg = Config.fromfile(args.config)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

Expand Down
5 changes: 4 additions & 1 deletion tools/analysis_tools/optimize_anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from mmdet.core import bbox_cxcywh_to_xyxy, bbox_overlaps, bbox_xyxy_to_cxcywh
from mmdet.datasets import build_dataset
from mmdet.utils import get_root_logger, update_data_root
from mmdet.utils import get_root_logger, replace_cfg_vals, update_data_root


def parse_args():
Expand Down Expand Up @@ -325,6 +325,9 @@ def main():
cfg = args.config
cfg = Config.fromfile(cfg)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

Expand Down
5 changes: 4 additions & 1 deletion tools/misc/print_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from mmcv import Config, DictAction

from mmdet.utils import update_data_root
from mmdet.utils import replace_cfg_vals, update_data_root


def parse_args():
Expand Down Expand Up @@ -45,6 +45,9 @@ def main():

cfg = Config.fromfile(args.config)

# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

Expand Down
8 changes: 8 additions & 0 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,17 @@ def trigger_visualization_hook(cfg, args):
def main():
args = parse_args()

<<<<<<< HEAD
# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
=======
# replace the ${key} with the value of cfg.key
cfg = replace_cfg_vals(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)
>>>>>>> 0db1b9b3 ([Tools] Support replacing the ${key} with the value of cfg.key (#7492))

# load config
cfg = Config.fromfile(args.config)
Expand Down

0 comments on commit 1d7786e

Please sign in to comment.