diff --git a/common_utility/configLoader.py b/common_utility/configLoader.py index b505e02..5d202b7 100644 --- a/common_utility/configLoader.py +++ b/common_utility/configLoader.py @@ -2,57 +2,108 @@ # SPDX-FileCopyrightText: 2024 Attila Gombos # SPDX-License-Identifier: MIT -import os +import sys +from argparse import ArgumentParser, Action, Namespace from configparser import ConfigParser from pathlib import Path -from typing import Any +from typing import Any, cast from context_logger import get_logger -from common_utility import copy_file - -log = get_logger('ConfigLoader') - class IConfigLoader(object): - def load(self, arguments: dict[str, Any]) -> dict[str, Any]: + def load(self, argument_parser: ArgumentParser) -> Namespace: raise NotImplementedError() class ConfigLoader(IConfigLoader): - def __init__(self, default_config_file: Path, config_file_argument: str = 'config_file') -> None: + def __init__(self, default_config_file: Path) -> None: + self._config_parser = ConfigParser(interpolation=None) self._default_config_file = default_config_file - self._config_file_argument = config_file_argument + self.log = get_logger(type(self).__name__) + + def load(self, argument_parser: ArgumentParser) -> Namespace: + arguments = argument_parser.parse_known_args()[0] + + self.log.info('Loading default configuration', config_file=str(self._default_config_file)) + loaded = self._config_parser.read(self._default_config_file) + + if str(self._default_config_file) not in loaded: + self.log.warn('Default configuration could not be loaded', config_file=str(self._default_config_file)) + + if custom_config_file := arguments.config: + custom_config_file = Path(custom_config_file) + + self.log.info('Loading custom configuration', config_file=str(custom_config_file)) + loaded = self._config_parser.read(custom_config_file) + + if str(custom_config_file) not in loaded: + self.log.warn('Custom configuration could not be loaded', config_file=str(custom_config_file)) + + configuration = dict(vars(arguments)) + + for section in self._config_parser.sections(): + configuration.update(dict(self._config_parser[section])) + + self.log.info('Loading command line arguments', arguments=vars(arguments)) + cli_overrides = self._get_cli_overrides(argument_parser, arguments) + configuration.update(cli_overrides) + + self._sanitize_config(argument_parser, configuration) + + self.log.info('Configuration loaded', configuration=configuration) + + return Namespace(**configuration) + + def _get_cli_overrides(self, parser: ArgumentParser, arguments: Namespace) -> dict[str, Any]: + cli_overrides: dict[str, Any] = {} + argv_tokens = set(sys.argv[1:]) + + if not argv_tokens: + return cli_overrides + + for action in parser._actions: + if action.dest == 'help': + continue - def load(self, arguments: dict[str, Any]) -> dict[str, Any]: - parser = ConfigParser(interpolation=None) + if action.option_strings: # has --flag or -f + if any(opt in argv_tokens for opt in action.option_strings): + cli_overrides[action.dest] = getattr(arguments, action.dest) - log.info('Loading default configuration', config_file=str(self._default_config_file)) - parser.read(self._default_config_file) + return cli_overrides - if config_file := arguments.get(self._config_file_argument): - custom_config_file = Path(config_file) + def _sanitize_config(self, parser: ArgumentParser, config: dict[str, Any]) -> None: + for config_key in config: + action = self._find_action(parser, config_key) + if action is None or action.default is None: + continue - if os.path.exists(custom_config_file): - log.info('Loading custom configuration', config_file=str(custom_config_file)) - parser.read(custom_config_file) - else: - try: - log.info('Creating custom configuration using default', config_file=str(custom_config_file)) - copy_file(self._default_config_file, custom_config_file) - except Exception as exception: - log.warn('Failed to create custom configuration file', error=str(exception)) + if isinstance(action.default, bool): + self._convert_bool(config, config_key) + elif isinstance(action.default, int): + self._convert_int(config, config_key) + elif isinstance(action.default, float): + self._convert_float(config, config_key) - configuration = {} + for config_key in config: + self.log.debug('Config', key=config_key, value=config[config_key], type=type(config[config_key])) - for section in parser.sections(): - configuration.update(dict(parser[section])) + def _find_action(self, parser: ArgumentParser, config_key: str) -> Action: + return next((a for a in parser._actions if a.dest == config_key), cast(Action, cast(object, None))) - log.info('Loading command line arguments', arguments=arguments) - configuration.update(arguments) + def _convert_bool(self, config: dict[str, Any], config_key: str) -> None: + config[config_key] = str(config[config_key]).lower() in ('true', '1', 'yes') - log.info('Configuration loaded', configuration=configuration) + def _convert_int(self, config: dict[str, Any], config_key: str) -> None: + try: + config[config_key] = int(config[config_key]) + except (TypeError, ValueError): + pass - return configuration + def _convert_float(self, config: dict[str, Any], config_key: str) -> None: + try: + config[config_key] = float(config[config_key]) + except (TypeError, ValueError): + pass diff --git a/tests/config/example.default.conf b/tests/config/example.conf.default similarity index 100% rename from tests/config/example.default.conf rename to tests/config/example.conf.default diff --git a/tests/configLoaderTest.py b/tests/configLoaderTest.py index b825991..4add8fa 100644 --- a/tests/configLoaderTest.py +++ b/tests/configLoaderTest.py @@ -1,77 +1,215 @@ -import os.path +import sys import unittest +from argparse import ArgumentParser from pathlib import Path from unittest import TestCase +from unittest.mock import patch from context_logger import setup_logging from common_utility import delete_file, ConfigLoader, copy_file from tests import TEST_RESOURCE_ROOT, TEST_FILE_SYSTEM_ROOT +DEFAULT_CONFIG_FILE = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf.default' + class ConfigLoaderTest(TestCase): @classmethod def setUpClass(cls): setup_logging('python-common-utility', 'DEBUG', warn_on_overwrite=False) + copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf.default', DEFAULT_CONFIG_FILE) def setUp(self): print() delete_file(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf') + delete_file(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.conf') + delete_file(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.invalid.conf') + + def _create_argument_parser(self) -> ArgumentParser: + argument_parser = ArgumentParser() + argument_parser.add_argument('--config', default=None) + argument_parser.add_argument('--config-key1', default=None) + argument_parser.add_argument('--config-key2', default=None) + argument_parser.add_argument('--example-key1', default=None) + argument_parser.add_argument('--example-key2', default=None) + return argument_parser + + def test_load_config_when_default_config_file_could_not_be_loaded(self): + # Given + config_loader = ConfigLoader(Path('invalid/path/example.conf.default')) + argument_parser = self._create_argument_parser() + + # When + with patch.object(sys, 'argv', ['test', '--config-key1', 'new_value1']): + result = config_loader.load(argument_parser) + + # Then + self.assertEqual('new_value1', result.config_key1) + + def test_load_config_when_no_custom_config_file_specified(self): + # Given + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = self._create_argument_parser() + + # When + with patch.object(sys, 'argv', ['test', '--config-key1', 'new_value1']): + result = config_loader.load(argument_parser) + + # Then + self.assertEqual('new_value1', result.config_key1) + self.assertEqual('value2', result.config_key2) + self.assertEqual('example1', result.example_key1) + self.assertEqual('example2', result.example_key2) + + def test_load_config_when_custom_config_file_specified(self): + # Given + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = self._create_argument_parser() + config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf' + + copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', config_file) + + # When + with patch.object(sys, 'argv', ['test', '--config', config_file, '--example-key1', 'new_example1']): + result = config_loader.load(argument_parser) - def test_load_config_when_custom_configuration_not_exists(self): + # Then + self.assertEqual('value1', result.config_key1) + self.assertEqual('value3', result.config_key2) + self.assertEqual('new_example1', result.example_key1) + self.assertEqual('example4', result.example_key2) + + def test_load_config_when_custom_config_file_could_not_be_loaded(self): # Given - config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.default.conf') - arguments = { - 'config_file': f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf', - 'config_key1': 'new_value1', - } + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = self._create_argument_parser() + # When + with patch.object(sys, 'argv', + ['test', '--config', 'invalid/path/example.conf', '--example-key1', 'new_example1']): + result = config_loader.load(argument_parser) + + # Then + self.assertEqual('value2', result.config_key2) + self.assertEqual('new_example1', result.example_key1) + self.assertEqual('example2', result.example_key2) + + def test_load_config_when_parser_default_values_defined_but_not_passed(self): + # Given + config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.conf.default') + argument_parser = ArgumentParser() + argument_parser.add_argument('--config', default=None) + argument_parser.add_argument('--config-key1', default='cli_default_value1') + argument_parser.add_argument('--example-key1', default='cli_default_example1') + config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf' + + copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', config_file) # When - result = config_loader.load(arguments) + with patch.object(sys, 'argv', ['test', '--config', config_file]): + result = config_loader.load(argument_parser) # Then - self.assertTrue(os.path.exists(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf')) - self.assertEqual('new_value1', result['config_key1']) - self.assertEqual('value2', result['config_key2']) - self.assertEqual('example1', result['example_key1']) - self.assertEqual('example2', result['example_key2']) + self.assertEqual('value1', result.config_key1) + self.assertEqual('example3', result.example_key1) - def test_load_config_when_custom_configuration_exists(self): + def test_load_config_when_short_option_cli_override_is_passed(self): # Given - config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.default.conf') - arguments = { - 'config_file': f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf', - 'example_key1': 'new_example1', - } + config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.conf.default') + argument_parser = ArgumentParser() + argument_parser.add_argument('--config', default=None) + argument_parser.add_argument('--example-key2', '-e2', default=None) + config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf' - copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf') + copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', config_file) # When - result = config_loader.load(arguments) + with patch.object(sys, 'argv', ['test', '--config', config_file, '-e2', 'new_example2']): + result = config_loader.load(argument_parser) # Then - self.assertEqual('value1', result['config_key1']) - self.assertEqual('value3', result['config_key2']) - self.assertEqual('new_example1', result['example_key1']) - self.assertEqual('example4', result['example_key2']) + self.assertEqual('new_example2', result.example_key2) - def test_load_config_when_fail_to_create_custom_configuration(self): + def test_load_config_when_long_option_cli_override_passed_using_equal_sign(self): # Given - config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.default.conf') - arguments = { - 'config_file': '/invalid/path/to/example.conf', - 'config_key1': 'new_value1', - } + config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.conf.default') + argument_parser = self._create_argument_parser() + + # When + with patch.object(sys, 'argv', ['test', '--config-key1=new_value1']): + result = config_loader.load(argument_parser) + + # Then + self.assertEqual('value1', result.config_key1) + + def test_get_cli_overrides_when_no_cli_arguments(self): + # Given + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = self._create_argument_parser() + arguments = argument_parser.parse_args([]) + + # When + with patch.object(sys, 'argv', ['test']): + result = config_loader._get_cli_overrides(argument_parser, arguments) + + # Then + self.assertEqual({}, result) + + def test_get_cli_overrides_when_argument_has_no_option_string(self): + # Given + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = ArgumentParser() + argument_parser.add_argument('--config', default=None) + argument_parser.add_argument('input_file') + arguments = argument_parser.parse_args(['input.txt']) + + # When + with patch.object(sys, 'argv', ['test', 'input.txt']): + result = config_loader._get_cli_overrides(argument_parser, arguments) + + # Then + self.assertEqual({}, result) + + def test_load_config_when_type_values_present_then_sanitize(self): + # Given + config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.conf' + Path(config_file).write_text('[types]\nfeature_enabled = true\nretry_count = 7\ntimeout = 1.5\n', + encoding='utf-8') + + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = ArgumentParser() + argument_parser.add_argument('--config', default=None) + argument_parser.add_argument('--feature-enabled', default=False) + argument_parser.add_argument('--retry-count', default=0) + argument_parser.add_argument('--timeout', default=0.0) + + # When + with patch.object(sys, 'argv', ['test', '--config', config_file]): + result = config_loader.load(argument_parser) + + # Then + self.assertTrue(result.feature_enabled) + self.assertEqual(7, result.retry_count) + self.assertEqual(1.5, result.timeout) + + def test_load_config_when_invalid_numeric_values_present_then_keep_original(self): + # Given + config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.invalid.conf' + Path(config_file).write_text('[types]\nretry_count = invalid\ntimeout = invalid\n', encoding='utf-8') + + config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) + argument_parser = ArgumentParser() + argument_parser.add_argument('--config', default=None) + argument_parser.add_argument('--retry-count', default=0) + argument_parser.add_argument('--timeout', default=0.0) # When - result = config_loader.load(arguments) + with patch.object(sys, 'argv', ['test', '--config', config_file]): + result = config_loader.load(argument_parser) # Then - self.assertEqual('new_value1', result['config_key1']) - self.assertEqual('value2', result['config_key2']) - self.assertEqual('example1', result['example_key1']) - self.assertEqual('example2', result['example_key2']) + self.assertEqual('invalid', result.retry_count) + self.assertEqual('invalid', result.timeout) if __name__ == '__main__':