Skip to content

Commit

Permalink
update config and config factory (#1893)
Browse files Browse the repository at this point in the history
* update config and config factory

* update config and config factory

* address comments
  • Loading branch information
chesterxgchen committed Aug 2, 2023
1 parent fd724ab commit 23d5c72
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 5 deletions.
12 changes: 11 additions & 1 deletion nvflare/fuel/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional


class ConfigFormat(Enum):
Expand All @@ -36,6 +36,13 @@ def config_ext_formats(cls):
}
)

@classmethod
def extensions(cls, target_fmt=None) -> List[str]:
if target_fmt is None:
return [ext for ext, fmt in cls.config_ext_formats().items()]
else:
return [ext for ext, fmt in cls.config_ext_formats().items() if fmt == target_fmt]


class Config(ABC):
def __init__(self, conf: Any, fmt: ConfigFormat, file_path: Optional[str] = None):
Expand All @@ -50,6 +57,9 @@ def get_format(self) -> ConfigFormat:
"""
return self.format

def get_exts(self) -> List[str]:
return ConfigFormat.extensions(self.format)

def get_native_conf(self):
"""Return the original underline config object representation if you prefer to use it directly
Pyhocon → ConfigTree
Expand Down
17 changes: 13 additions & 4 deletions nvflare/fuel/utils/config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ConfigFactory:

@staticmethod
def search_config_format(
init_file_path, search_dirs: Optional[List[str]] = None
init_file_path: str, search_dirs: Optional[List[str]] = None, target_fmt: Optional[ConfigFormat] = None
) -> Tuple[Optional[ConfigFormat], Optional[str]]:

"""find the configuration format and the location (file_path) for given initial init_file_path and search directories.
Expand All @@ -55,6 +55,7 @@ def search_config_format(
Args:
init_file_path: initial file_path for the configuration
search_dirs: search directory. If none, the parent directory of init_file_path will be used as search dir
target_fmt: (ConfigFormat) if specified, only this format searched, ignore all other formats.
Returns:
Tuple of None,None or ConfigFormat and real configuration file path
Expand All @@ -65,12 +66,17 @@ def search_config_format(
parent_dir = pathlib.Path(init_file_path).parent
search_dirs = [str(parent_dir)]

target_exts = None
if target_fmt:
target_exts = ConfigFormat.extensions(target_fmt)

# we ignore the original extension
file_basename = ConfigFactory.get_file_basename(init_file_path)
ext2fmt_map = ConfigFormat.config_ext_formats()
extensions = target_exts if target_fmt else ext2fmt_map.keys()
for search_dir in search_dirs:
logger.debug(f"search file basename:'{file_basename}', search dirs = {search_dirs}")
for ext in ext2fmt_map:
for ext in extensions:
fmt = ext2fmt_map[ext]
filename = f"{file_basename}{ext}"
for root, dirs, files in os.walk(search_dir):
Expand All @@ -87,7 +93,9 @@ def get_file_basename(init_file_path):
return file_basename

@staticmethod
def load_config(file_path: str, search_dirs: Optional[List[str]] = None) -> Optional[Config]:
def load_config(
file_path: str, search_dirs: Optional[List[str]] = None, target_fmt: Optional[ConfigFormat] = None
) -> Optional[Config]:

"""Find the configuration for given initial init_file_path and search directories.
for example, the initial config file path given is config_client.json
Expand All @@ -97,12 +105,13 @@ def load_config(file_path: str, search_dirs: Optional[List[str]] = None) -> Opti
Args:
file_path: initial file path
search_dirs: search directory. If none, the parent directory of init_file_path will be used as search dir
target_fmt: (ConfigFormat) if specified, only this format searched, ignore all other formats.
Returns:
None if not found, or Config
"""
config_format, real_config_file_path = ConfigFactory.search_config_format(file_path, search_dirs)
config_format, real_config_file_path = ConfigFactory.search_config_format(file_path, search_dirs, target_fmt)
if config_format is not None and real_config_file_path is not None:
config_loader = ConfigFactory.get_config_loader(config_format)
if config_loader:
Expand Down
12 changes: 12 additions & 0 deletions tests/unit_test/fuel/utils/config_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ def test_config_exts(self):
def test_config_exts2(self):
exts2fmt_map = ConfigFormat.config_ext_formats()
assert "|".join(exts2fmt_map.keys()) == ".json|.conf|.yml|.json.default|.conf.default|.yml.default"

def test_config_exts3(self):
exts = ConfigFormat.extensions()
assert "|".join(exts) == ".json|.conf|.yml|.json.default|.conf.default|.yml.default"

def test_config_exts4(self):
exts = ConfigFormat.extensions(target_fmt=ConfigFormat.JSON)
assert "|".join(exts) == ".json|.json.default"
exts = ConfigFormat.extensions(target_fmt=ConfigFormat.OMEGACONF)
assert "|".join(exts) == ".yml|.yml.default"
exts = ConfigFormat.extensions(target_fmt=ConfigFormat.PYHOCON)
assert "|".join(exts) == ".conf|.conf.default"

0 comments on commit 23d5c72

Please sign in to comment.