forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Tools] Support replacing the ${key} with the value of cfg.key (open-…
…mmlab#7492) * Support replacing config * Support replacing config * Add unit test for replace_cfig * pre-commit * fix * modify the docstring * rename function * fix a bug * fix a bug and simplify the code * simplify the code * add replace_cfg_vals for some scripts * add replace_cfg_vals for some scripts * add some unit tests
- Loading branch information
Showing
9 changed files
with
183 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import re | ||
|
||
from mmcv.utils import Config | ||
|
||
|
||
def replace_cfg_vals(ori_cfg): | ||
"""Replace the string "${key}" with the corresponding value. | ||
Replace the "${key}" with the value of ori_cfg.key in the config. And | ||
support replacing the chained ${key}. Such as, replace "${key0.key1}" | ||
with the value of cfg.key0.key1. Code is modified from `vars.py | ||
< https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501 | ||
Args: | ||
ori_cfg (mmcv.utils.config.Config): | ||
The origin config with "${key}" generated from a file. | ||
Returns: | ||
updated_cfg [mmcv.utils.config.Config]: | ||
The config with "${key}" replaced by the corresponding value. | ||
""" | ||
|
||
def get_value(cfg, key): | ||
for k in key.split('.'): | ||
cfg = cfg[k] | ||
return cfg | ||
|
||
def replace_value(cfg): | ||
if isinstance(cfg, dict): | ||
return {key: replace_value(value) for key, value in cfg.items()} | ||
elif isinstance(cfg, list): | ||
return [replace_value(item) for item in cfg] | ||
elif isinstance(cfg, tuple): | ||
return tuple([replace_value(item) for item in cfg]) | ||
elif isinstance(cfg, str): | ||
# the format of string cfg may be: | ||
# 1) "${key}", which will be replaced with cfg.key directly | ||
# 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx", | ||
# which will be replaced with the string of the cfg.key | ||
keys = pattern_key.findall(cfg) | ||
values = [get_value(ori_cfg, key[2:-1]) for key in keys] | ||
if len(keys) == 1 and keys[0] == cfg: | ||
# the format of string cfg is "${key}" | ||
cfg = values[0] | ||
else: | ||
for key, value in zip(keys, values): | ||
# the format of string cfg is | ||
# "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx" | ||
assert not isinstance(value, (dict, list, tuple)), \ | ||
f'for the format of string cfg is ' \ | ||
f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \ | ||
f"the type of the value of '${key}' " \ | ||
f'can not be dict, list, or tuple' \ | ||
f'but you input {type(value)} in {cfg}' | ||
cfg = cfg.replace(key, str(value)) | ||
return cfg | ||
else: | ||
return cfg | ||
|
||
# the pattern of string "${key}" | ||
pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}') | ||
# the type of ori_cfg._cfg_dict is mmcv.utils.config.ConfigDict | ||
updated_cfg = Config( | ||
replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename) | ||
# replace the model with model_wrapper | ||
if updated_cfg.get('model_wrapper', None) is not None: | ||
updated_cfg.model = updated_cfg.model_wrapper | ||
updated_cfg.pop('model_wrapper') | ||
return updated_cfg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import os.path as osp | ||
import tempfile | ||
from copy import deepcopy | ||
|
||
import pytest | ||
from mmcv.utils import Config | ||
|
||
from mmdet.utils import replace_cfg_vals | ||
|
||
|
||
def test_replace_cfg_vals(): | ||
temp_file = tempfile.NamedTemporaryFile() | ||
cfg_path = f'{temp_file.name}.py' | ||
with open(cfg_path, 'w') as f: | ||
f.write('configs') | ||
|
||
ori_cfg_dict = dict() | ||
ori_cfg_dict['cfg_name'] = osp.basename(temp_file.name) | ||
ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}' | ||
ori_cfg_dict['percent'] = 5 | ||
ori_cfg_dict['fold'] = 1 | ||
ori_cfg_dict['model_wrapper'] = dict( | ||
type='SoftTeacher', detector='${model}') | ||
ori_cfg_dict['model'] = dict( | ||
type='FasterRCNN', | ||
backbone=dict(type='ResNet'), | ||
neck=dict(type='FPN'), | ||
rpn_head=dict(type='RPNHead'), | ||
roi_head=dict(type='StandardRoIHead'), | ||
train_cfg=dict( | ||
rpn=dict( | ||
assigner=dict(type='MaxIoUAssigner'), | ||
sampler=dict(type='RandomSampler'), | ||
), | ||
rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)), | ||
rcnn=dict( | ||
assigner=dict(type='MaxIoUAssigner'), | ||
sampler=dict(type='RandomSampler'), | ||
), | ||
), | ||
test_cfg=dict( | ||
rpn=dict(nms=dict(type='nms', iou_threshold=0.7)), | ||
rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)), | ||
), | ||
) | ||
ori_cfg_dict['iou_threshold'] = dict( | ||
rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}', | ||
test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}', | ||
test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}', | ||
) | ||
|
||
ori_cfg_dict['str'] = 'Hello, world!' | ||
ori_cfg_dict['dict'] = {'Hello': 'world!'} | ||
ori_cfg_dict['list'] = [ | ||
'Hello, world!', | ||
] | ||
ori_cfg_dict['tuple'] = ('Hello, world!', ) | ||
ori_cfg_dict['test_str'] = 'xxx${str}xxx' | ||
|
||
ori_cfg = Config(ori_cfg_dict, filename=cfg_path) | ||
updated_cfg = replace_cfg_vals(deepcopy(ori_cfg)) | ||
|
||
assert updated_cfg.work_dir \ | ||
== f'work_dirs/{osp.basename(temp_file.name)}/5/1' | ||
assert updated_cfg.model.detector == ori_cfg.model | ||
assert updated_cfg.iou_threshold.rpn_proposal_nms \ | ||
== ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold | ||
assert updated_cfg.test_str == 'xxxHello, world!xxx' | ||
ori_cfg_dict['test_dict'] = 'xxx${dict}xxx' | ||
ori_cfg_dict['test_list'] = 'xxx${list}xxx' | ||
ori_cfg_dict['test_tuple'] = 'xxx${tuple}xxx' | ||
with pytest.raises(AssertionError): | ||
cfg = deepcopy(ori_cfg) | ||
cfg['test_dict'] = 'xxx${dict}xxx' | ||
updated_cfg = replace_cfg_vals(cfg) | ||
with pytest.raises(AssertionError): | ||
cfg = deepcopy(ori_cfg) | ||
cfg['test_list'] = 'xxx${list}xxx' | ||
updated_cfg = replace_cfg_vals(cfg) | ||
with pytest.raises(AssertionError): | ||
cfg = deepcopy(ori_cfg) | ||
cfg['test_tuple'] = 'xxx${tuple}xxx' | ||
updated_cfg = replace_cfg_vals(cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters